1. 连接服务器 #
1.1. progress.py #
progress.py
# 从client模块导入MCPClient类
from client import MCPClient
# 定义进度跟踪功能演示函数
def progress():
"""进度跟踪功能演示"""
# 创建MCP客户端实例,连接到本地8000端口的服务器
client = MCPClient("http://localhost:8000")
try:
# 打印正在连接服务器的信息
print(" 正在连接到MCP服务器...")
# 初始化客户端连接,如果失败则提示并返回
if not client.initialize():
print(" 连接初始化失败")
return
# 捕获并处理异常
except Exception as e:
# 打印错误信息
print(f" 演示过程中出现错误: {e}")
# 导入traceback模块用于打印详细的异常信息
import traceback
# 打印异常的详细堆栈信息
traceback.print_exc()
# 定义主函数
def main():
"""主函数"""
# 打印演示标题
print("🔧 MCP进度跟踪功能演示")
# 提示用户确保服务器已启动
print("请确保MCP服务器正在运行 (python server.py)")
# 调用进度跟踪演示函数
progress()
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main()
2. 无进度跟踪的工具调用 #
2.1. progress.py #
progress.py
# 从client模块导入MCPClient类
from client import MCPClient
+import time
# 定义进度跟踪功能演示函数
def progress():
"""进度跟踪功能演示"""
# 创建MCP客户端实例,连接到本地8000端口的服务器
client = MCPClient("http://localhost:8000")
try:
# 打印正在连接服务器的信息
print(" 正在连接到MCP服务器...")
# 初始化客户端连接,如果失败则提示并返回
if not client.initialize():
print(" 连接初始化失败")
return
# 打印连接成功的信息
+ print(" 连接成功!")
# 打印当前会话信息
+ print(f" 会话信息: {client.get_session_info()}")
# 获取可用工具列表
+ tools = client.list_tools()
# 打印所有可用工具的名称
+ print(f"🛠️ 可用工具: {[tool['name'] for tool in tools]}")
# 分隔线,提示即将开始演示1
+ print("\n" + "=" * 50)
# 打印演示1的标题
+ print("📊 演示1: 无进度跟踪的工具调用")
# 再次打印分隔线
+ print("=" * 50)
# 记录工具调用开始时间
+ start_time = time.time()
# 调用名为"calculate"的工具,传入表达式参数
+ result = client.call_tool("calculate", {"expression": "10 + 20 * 3"})
# 记录工具调用结束时间
+ end_time = time.time()
# 如果调用有返回结果
+ if result:
# 获取结果中的content字段,默认为空列表
+ content = result.get("content", [])
# 如果content不为空
+ if content:
# 打印计算结果
+ print(f" 计算结果: {content[0].get('text', '')}")
# 打印工具调用的执行时间,保留两位小数
+ print(f"⏱️ 执行时间: {end_time - start_time:.2f}秒")
# 捕获并处理异常
except Exception as e:
# 打印错误信息
print(f" 演示过程中出现错误: {e}")
# 导入traceback模块用于打印详细的异常信息
import traceback
# 打印异常的详细堆栈信息
traceback.print_exc()
# 定义主函数
def main():
"""主函数"""
# 打印演示标题
print("🔧 MCP进度跟踪功能演示")
# 提示用户确保服务器已启动
print("请确保MCP服务器正在运行 (python server.py)")
# 调用进度跟踪演示函数
progress()
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main()
3. 有进度跟踪的工具调用 #
3.1. client.py #
client.py
# 导入json模块,用于处理JSON数据
import json
# 导入logging模块,用于日志记录
import logging
# 导入uuid模块,用于生成唯一标识
+import uuid
# 导入requests模块,用于发送HTTP请求
import requests
# 从urllib.parse模块导入urljoin,用于拼接URL
from urllib.parse import urljoin
# 配置日志级别为INFO
logging.basicConfig(level=logging.INFO)
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 定义MCPClient类,表示MCP HTTP客户端
class MCPClient:
"""MCP HTTP客户端"""
# 初始化方法,设置服务器URL和相关属性
def __init__(self, server_url: str):
# 协议版本号
self.protocol_version = "2025-06-18"
# 客户端信息
self.client_info = {"name": "MCP客户端", "version": "1.0.0"}
# 服务器信息,初始化为None
self.server_info = None
# 服务器能力,初始化为None
self.capabilities = None
# 会话ID,初始化为None
self.session_id = None
# 消息ID,初始值为1
self.message_id = 1
# 去除服务器URL末尾的斜杠
self.server_url = server_url.rstrip("/")
# 拼接MCP接口的完整URL
self.mcp_endpoint = urljoin(self.server_url, "/mcp")
# 创建requests的Session对象
self.session = requests.Session()
# 设置Session的默认请求头
self.session.headers.update(
{
"Accept": "application/json, text/event-stream",
"MCP-Protocol-Version": self.protocol_version,
}
)
# 处理进度通知
+ def _handle_progress_notification(self, notification):
+ """处理进度通知"""
# 获取进度令牌
+ progress_token = notification.get("progressToken")
# 获取进度
+ progress = notification.get("progress")
# 获取总进度
+ total = notification.get("total")
# 获取进度消息
+ message = notification.get("message", "")
# 记录进度通知
+ logger.info(
+ f"收到进度通知 - 令牌: {progress_token}, 进度: {progress}/{total}, 消息: {message}"
+ )
# 重置会话方法
def reset_session(self):
"""重置会话"""
# 保存旧的会话ID
old_session_id = self.session_id
# 清空当前会话ID
self.session_id = None
# 记录会话重置信息
logger.info(f"会话已重置,原会话ID: {old_session_id}")
# 打开SSE流连接方法
def open_sse_stream(self):
"""打开SSE流连接"""
try:
# 构造请求头
headers = {"Accept": "text/event-stream"}
# 如果有会话ID则加入请求头
if self.session_id:
headers["Mcp-Session-Id"] = self.session_id
# 发送GET请求,开启SSE流
response = self.session.get(
+ self.mcp_endpoint, headers=headers, stream=True, timeout=None
)
# 如果响应状态码为200,表示连接成功
if response.status_code == 200:
logger.info("SSE流已打开")
return response
else:
# 连接失败,记录错误
logger.error(f"SSE流打开失败: {response.status_code}")
return None
# 捕获请求异常
except requests.exceptions.RequestException as e:
logger.error(f"SSE流连接错误: {e}")
return None
# 获取当前会话信息方法
def get_session_info(self):
"""获取当前会话信息"""
# 返回会话相关信息字典
return {
"session_id": self.session_id,
"server_url": self.server_url,
"protocol_version": self.protocol_version,
}
# 获取可用工具列表方法
def list_tools(self):
"""获取可用工具列表"""
# 构造请求消息
request = {
"jsonrpc": "2.0",
"id": self._generate_id(),
"method": "tools/list",
"params": {},
}
# 发送请求并获取响应
response = self._send_message(request)
# 如果响应存在且包含result字段
if response and "result" in response:
# 获取工具列表
tools = response["result"].get("tools", [])
logger.info(f"可用工具数量: {len(tools)}")
return tools
else:
# 获取失败,记录错误
logger.error("获取工具列表失败")
return []
# 通过HTTP发送消息的私有方法
def _send_message(self, message):
"""通过HTTP发送消息"""
try:
# 初始化请求头字典
headers = {}
# 如果有会话ID,则添加到请求头
if self.session_id:
headers["Mcp-Session-Id"] = self.session_id
+ logger.info(f"使用会话ID: {self.session_id}")
# 发送POST请求
response = self.session.post(
self.mcp_endpoint, json=message, headers=headers, timeout=30
)
# 检查响应状态码
if response.status_code == 200:
# 检查响应头中是否有新的会话ID
new_session_id = response.headers.get("Mcp-Session-Id")
# 如果有新的会话ID且与当前不同,则更新
if new_session_id and new_session_id != self.session_id:
self.session_id = new_session_id
logger.info(f"更新会话ID: {self.session_id}")
# 返回响应的JSON内容
return response.json()
# 如果状态码为202,表示已接受但无内容
elif response.status_code == 202:
# Accepted - 通知或响应
return None
# 如果状态码为400,表示请求错误
elif response.status_code == 400:
logger.error(f"请求错误: {response.text}")
return None
# 如果状态码为401,表示会话ID无效
elif response.status_code == 401:
logger.error("会话ID无效,需要重新初始化")
self.session_id = None
return None
# 其他HTTP错误
else:
logger.error(f"HTTP错误: {response.status_code} - {response.text}")
return None
# 捕获请求异常
except requests.exceptions.RequestException as e:
logger.error(f"HTTP请求错误: {e}")
return None
# 生成消息ID的方法
def _generate_id(self) -> str:
"""生成消息ID"""
# 消息ID自增
self.message_id += 1
# 返回字符串类型的消息ID
return str(self.message_id)
# 发送消息到服务器的方法(重复定义,实际只会用后面这个)
def _send_message(self, message):
"""通过HTTP发送消息"""
try:
# 初始化请求头字典
headers = {}
# 如果有会话ID,则添加到请求头
if self.session_id:
headers["Mcp-Session-Id"] = self.session_id
+ logger.info(f"使用会话ID: {self.session_id}")
# 发送POST请求到MCP端点
response = self.session.post(
self.mcp_endpoint, json=message, headers=headers, timeout=30
)
# 检查响应状态码
if response.status_code == 200:
# 从响应头获取新的会话ID
new_session_id = response.headers.get("Mcp-Session-Id")
# 如果有新的会话ID且与当前不同,则更新
if new_session_id and new_session_id != self.session_id:
self.session_id = new_session_id
logger.info(f"更新会话ID: {self.session_id}")
# 返回响应的JSON内容
return response.json()
# 如果状态码为202,表示已接受但无内容
elif response.status_code == 202:
# Accepted - 通知或响应
return None
# 如果状态码为400,表示请求错误
elif response.status_code == 400:
logger.error(f"请求错误: {response.text}")
return None
# 如果状态码为401,表示会话ID无效
elif response.status_code == 401:
logger.error("会话ID无效,需要重新初始化")
self.session_id = None
return None
# 其他HTTP错误
else:
logger.error(f"HTTP错误: {response.status_code} - {response.text}")
return None
# 捕获请求异常
except requests.exceptions.RequestException as e:
logger.error(f"HTTP请求错误: {e}")
return None
# 生成进度令牌
+ def _generate_progress_token(self) -> str:
+ """生成唯一的进度令牌"""
# 生成并返回UUID字符串
+ return str(uuid.uuid4())
# 调用工具方法
+ def call_tool(self, tool_name, arguments, enable_progress=False):
"""调用工具"""
# 构造调用工具的请求消息
request = {
"jsonrpc": "2.0",
"id": self._generate_id(),
"method": "tools/call",
"params": {"calls": [{"name": tool_name, "arguments": arguments}]},
}
# 如果启用进度跟踪,添加进度令牌
+ if enable_progress:
# 生成进度令牌
+ progress_token = self._generate_progress_token()
# 添加进度令牌到请求参数
+ request["params"]["_meta"] = {"progressToken": progress_token}
# 记录启用进度跟踪的信息
+ logger.info(f"启用进度跟踪,令牌: {progress_token}")
# 发送请求并获取响应
response = self._send_message(request)
# 如果响应存在且包含result字段
if response and "result" in response:
# 获取调用结果
calls = response["result"].get("calls", [])
if calls:
return calls[0]
else:
# 调用失败,记录错误
logger.error(f"调用工具 {tool_name} 失败")
return None
# 初始化与服务器连接的方法
def initialize(self) -> bool:
"""初始化与服务器的连接"""
# 构造初始化请求消息
request = {
"jsonrpc": "2.0",
"id": self._generate_id(),
"method": "initialize",
"params": {
"protocolVersion": self.protocol_version,
"capabilities": {},
"clientInfo": self.client_info,
},
}
# 发送初始化请求
response = self._send_message(request)
# 如果响应存在且包含result字段
if response and "result" in response:
# 获取result字段
result = response["result"]
# 获取服务器信息
self.server_info = result.get("serverInfo")
# 获取服务器能力
self.capabilities = result.get("capabilities")
# 提取并存储会话ID
if "sessionId" in result:
self.session_id = result["sessionId"]
logger.info(f"获取到会话ID: {self.session_id}")
logger.info(f"服务器初始化成功: {self.server_info}")
return True
else:
# 初始化失败,记录错误
logger.error("服务器初始化失败")
return False
# HTTP传输的函数
def http_transport():
"""HTTP传输"""
# 打印分隔线
print("=== HTTP传输 ===")
# 创建HTTP客户端实例
client = MCPClient("http://localhost:8000")
try:
# 初始化客户端与服务器的连接
if not client.initialize():
print("初始化失败")
return
# 显示会话信息
session_info = client.get_session_info()
print(f" 会话信息: {session_info}")
# 获取工具列表
tools = client.list_tools()
print(f"可用工具: {[tool['name'] for tool in tools]}")
# 调用calculate工具
result = client.call_tool("calculate", {"expression": "10 / 2 + 5"})
if result:
# 获取内容字段
content = result.get("content", [])
if content:
print(f"计算结果: {content[0].get('text', '')}")
# 打开SSE流并持续监听
sse_response = client.open_sse_stream()
if sse_response:
print(" SSE流连接成功")
try:
# 持续监听SSE事件
for line in sse_response.iter_lines(decode_unicode=True):
# 处理event事件类型行
if line.startswith("event: "):
# 去掉 'event: ' 前缀
+ event_type = line[7:]
# 处理data数据行
elif line.startswith("data: "):
# 去掉 'data: ' 前缀
+ data_str = line[6:]
try:
# 解析JSON数据
data = json.loads(data_str)
# 处理connected事件
if event_type == "connected":
status = data.get("status")
session_id = data.get("sessionId")
print(f"🔗 连接状态: {status}, 会话ID: {session_id}")
# 处理进度通知事件
+ elif event_type == "notifications/progress":
+ client._handle_progress_notification(data)
# 处理未知事件类型
else:
print(f" 未知事件类型: {event_type}, 数据: {data}")
# 捕获JSON解析异常
except json.JSONDecodeError:
print(f" 数据解析错误: {data_str}")
# 空行表示事件结束,跳过
elif line == "":
continue
# 捕获用户中断
except KeyboardInterrupt:
print("\n⏹️ 用户停止监听")
finally:
# 关闭SSE流
sse_response.close()
print("🔌 SSE流已关闭")
# 最终会话状态
final_session_info = client.get_session_info()
print(f"🏁 最终会话状态: {final_session_info}")
# 会话重置
print("\n=== 会话重置 ===")
client.reset_session()
reset_session_info = client.get_session_info()
print(f"🔄 重置后会话状态: {reset_session_info}")
# 捕获并打印异常
except Exception as e:
print(f"HTTP错误: {e}")
# 主函数
def main():
"""主函数"""
# 调用HTTP传输函数
# http_transport()
# 判断是否为主程序入口
if __name__ == "__main__":
main()
3.2. progress.py #
progress.py
# 导入threading模块,用于多线程操作
+import threading
# 从client模块导入MCPClient类
from client import MCPClient
# 导入time模块,用于计时
import time
# 导入json模块,用于处理JSON数据
+import json
# 定义SSE监听器函数,负责持续监听服务器推送的事件
+def start_sse_listerner(client):
# 打开SSE流并持续监听
+ sse_response = client.open_sse_stream()
# 如果SSE流连接成功
+ if sse_response:
# 打印SSE流连接成功提示
+ print(" SSE流连接成功")
+ try:
# 持续监听SSE事件,每次读取一行
+ for line in sse_response.iter_lines(chunk_size=1, decode_unicode=True):
# 打印接收到的原始行内容
+ print(line)
# 如果是事件类型行(以"event: "开头)
+ if line.startswith("event: "):
# 提取事件类型(去掉前缀)
+ event_type = line[7:]
# 如果是数据行(以"data: "开头)
+ elif line.startswith("data: "):
# 提取数据内容(去掉前缀)
+ data_str = line[6:]
+ try:
# 尝试将数据内容解析为JSON对象
+ data = json.loads(data_str)
# 如果事件类型为"connected"
+ if event_type == "connected":
# 获取连接状态
+ status = data.get("status")
# 获取会话ID
+ session_id = data.get("sessionId")
# 打印连接状态和会话ID
+ print(f"🔗 连接状态: {status}, 会话ID: {session_id}")
# 如果事件类型为进度通知
+ elif event_type == "notifications/progress":
# 调用客户端的进度通知处理方法
+ client._handle_progress_notification(data)
# 其他未知事件类型
+ else:
# 打印未知事件类型及其数据
+ print(f" 未知事件类型: {event_type}, 数据: {data}")
# 捕获JSON解析异常
+ except json.JSONDecodeError:
# 打印数据解析错误提示
+ print(f" 数据解析错误: {data_str}")
# 如果是空行,表示一个事件结束,跳过
+ elif line == "":
+ continue
# 捕获用户中断(如Ctrl+C)
+ except KeyboardInterrupt:
# 打印用户停止监听提示
+ print("\n⏹️ 用户停止监听")
# 定义进度跟踪功能演示函数
def progress():
"""进度跟踪功能演示"""
# 创建MCP客户端实例,连接到本地8000端口的服务器
client = MCPClient("http://localhost:8000")
try:
# 打印正在连接服务器的信息
print(" 正在连接到MCP服务器...")
# 初始化客户端连接,如果失败则提示并返回
if not client.initialize():
print(" 连接初始化失败")
return
# 启动SSE监听线程,daemon=True表示主线程退出时自动结束
+ sse_listerner_thread = threading.Thread(
+ target=start_sse_listerner, args=(client,), daemon=True
+ )
# 启动SSE监听线程
+ sse_listerner_thread.start()
# 打印连接成功的信息
print(" 连接成功!")
# 打印当前会话信息
print(f" 会话信息: {client.get_session_info()}")
# 获取可用工具列表
tools = client.list_tools()
# 打印所有可用工具的名称
print(f"🛠️ 可用工具: {[tool['name'] for tool in tools]}")
# 打印分隔线,提示即将开始演示1
print("\n" + "=" * 50)
# 打印演示1的标题
print("📊 演示1: 无进度跟踪的工具调用")
# 再次打印分隔线
print("=" * 50)
# 记录工具调用开始时间
start_time = time.time()
# 调用名为"calculate"的工具,传入表达式参数
result = client.call_tool("calculate", {"expression": "10 + 20 * 3"})
# 记录工具调用结束时间
end_time = time.time()
# 如果调用有返回结果
if result:
# 获取结果中的content字段,默认为空列表
content = result.get("content", [])
# 如果content不为空
if content:
# 打印计算结果
print(f" 计算结果: {content[0].get('text', '')}")
# 打印工具调用的执行时间,保留两位小数
print(f"⏱️ 执行时间: {end_time - start_time:.2f}秒")
# 打印演示2的标题
+ print("📊 演示2: 带进度跟踪的工具调用")
# 打印开始执行带进度跟踪的计算提示
+ print("🔄 开始执行带进度跟踪的计算...")
# 记录带进度跟踪工具调用的开始时间
+ start_time = time.time()
# 调用带进度跟踪的工具
+ result = client.call_tool(
+ "calculate", {"expression": "100 / 4 + 25 * 2"}, enable_progress=True
+ )
# 记录带进度跟踪工具调用的结束时间
+ end_time = time.time()
# 如果调用有返回结果
+ if result:
# 获取结果中的content字段
+ content = result.get("content", [])
# 如果content不为空
+ if content:
# 打印计算结果
+ print(f" 计算结果: {content[0].get('text', '')}")
# 打印工具调用的执行时间,保留两位小数
+ print(f"⏱️ 执行时间: {end_time - start_time:.2f}秒")
# 等待SSE监听线程结束(实际上由于daemon=True,主线程结束时会自动退出)
+ sse_listerner_thread.join()
# 捕获并处理异常
except Exception as e:
# 打印错误信息
print(f" 演示过程中出现错误: {e}")
# 导入traceback模块用于打印详细的异常信息
import traceback
# 打印异常的详细堆栈信息
traceback.print_exc()
# 定义主函数
def main():
"""主函数"""
# 打印演示标题
print("🔧 MCP进度跟踪功能演示")
# 提示用户确保服务器已启动
print("请确保MCP服务器正在运行 (python server.py)")
# 调用进度跟踪演示函数
progress()
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main()
3.3. server.py #
server.py
# 导入logging模块,用于日志记录
import logging
# 导入json模块,用于处理JSON数据
import json
# 导入uuid模块,用于生成唯一会话ID
import uuid
# 导入time模块,用于时间戳
import time
# 从tools模块导入tools对象
from tools import tools
# 从http.server模块导入BaseHTTPRequestHandler和ThreadingHTTPServer类
+from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
# 配置日志记录级别为INFO
logging.basicConfig(level=logging.INFO)
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 定义MCPServer类,作为MCP服务器的核心类
class MCPServer:
"""MCP服务器核心类"""
# 初始化方法
def __init__(self):
# 协议版本号
self.protocol_version = "2025-06-18"
# 服务器信息,包括名称和版本
self.server_info = {"name": "MCP服务器", "version": "1.0.0"}
# 服务器能力,包括工具、资源和提示
self.capabilities = {"tools": {}, "resources": {}, "prompts": {}}
# 会话管理字典
self.sessions = {}
# 获取所有活跃会话信息
def get_all_sessions(self):
"""获取所有活跃会话信息"""
# 返回会话字典的副本
return self.sessions.copy()
# 处理工具列表请求
def handle_tools_list(self, params=None, session_id: str = None):
"""处理工具列表请求"""
# 如果有会话ID,则获取或创建会话
if session_id:
self._get_or_create_session(session_id)
# 返回工具列表
return {"tools": tools}
# 处理初始化请求(旧方法,已被覆盖)
def handle_initialize(self, params, session_id: str = None):
"""处理初始化请求"""
# 返回服务器能力
return {"capabilities": self.capabilities}
# 生成唯一的会话ID
def _generate_session_id(self) -> str:
"""生成唯一的会话ID"""
# 使用uuid4生成唯一ID
return str(uuid.uuid4())
# 获取会话信息
def get_session_info(self, session_id: str):
"""获取会话信息"""
# 如果会话ID存在于sessions字典中,返回会话信息
if session_id in self.sessions:
return self.sessions[session_id]
# 否则返回None
return None
# 清理过期会话(默认1小时)
def _cleanup_expired_sessions(self, max_age: int = 3600):
"""清理过期会话(默认1小时)"""
# 获取当前时间戳
current_time = time.time()
# 记录过期会话ID的列表
expired_sessions = []
# 遍历所有会话,查找过期的
for session_id, session_data in self.sessions.items():
if current_time - session_data["last_activity"] > max_age:
expired_sessions.append(session_id)
# 删除所有过期会话
for session_id in expired_sessions:
del self.sessions[session_id]
logger.info(f"清理过期会话: {session_id}")
# 返回清理的会话数量
return len(expired_sessions)
# 获取或创建会话
def _get_or_create_session(self, session_id: str = None):
"""获取或创建会话"""
# 如果没有传入会话ID,则生成一个新的
if not session_id:
session_id = self._generate_session_id()
# 如果会话ID不存在于sessions字典中,则新建一个会话
if session_id not in self.sessions:
self.sessions[session_id] = {
"created_at": time.time(),
"last_activity": time.time(),
"message_count": 0,
"client_info": None,
}
# 记录新会话的创建
logger.info(f"创建新会话: {session_id}")
# 更新会话的最后活动时间
self.sessions[session_id]["last_activity"] = time.time()
# 增加消息计数
self.sessions[session_id]["message_count"] += 1
# 返回会话ID
return session_id
# 发送进度通知
+ def _send_progress_notification(
+ self,
+ progress_token: str,
+ progress: int,
+ total: int,
+ message: str,
+ session_id: str,
+ ):
+ """发送进度通知"""
# 构造进度通知的消息体,符合JSON-RPC 2.0规范
+ notification = {
+ "jsonrpc": "2.0",
+ "method": "notifications/progress",
+ "params": {
+ "progressToken": progress_token, # 进度令牌
+ "progress": progress, # 当前进度
+ "total": total, # 总进度
+ "message": message, # 进度消息
+ },
+ }
# 记录发送进度通知的日志,包含进度和消息内容
+ logger.info(f"发送进度通知: {progress}/{total} - {message}")
# 仅写入该会话的待发送队列
+ if session_id in self.sessions:
+ session_info = self.sessions[session_id]
+ if "pending_notifications" not in session_info:
+ session_info["pending_notifications"] = []
+ session_info["pending_notifications"].append(notification)
# 处理工具调用请求
def handle_tools_call(self, params, session_id: str = None):
"""处理工具调用请求"""
# 如果有会话ID,则获取或创建会话
if session_id:
self._get_or_create_session(session_id)
# 获取调用列表
calls = params.get("calls", [])
# 初始化结果列表
results = []
# 初始化进度令牌
+ progress_token = None
# 检查是否有进度令牌
+ if "_meta" in params and "progressToken" in params["_meta"]:
# 获取进度令牌
+ progress_token = params["_meta"]["progressToken"]
# 记录收到进度令牌的信息
+ logger.info(f"收到进度令牌: {progress_token}")
# 遍历每个调用
for call in calls:
# 获取工具名称
tool_name = call.get("name")
# 获取参数
arguments = call.get("arguments", {})
# 如果工具名称为"calculate"
if tool_name == "calculate":
# 获取表达式
expression = arguments.get("expression", "")
try:
# 如果启用了进度跟踪,发送进度为50的通知,提示“开始计算...”
+ if progress_token:
+ time.sleep(0.1)
+ self._send_progress_notification(
+ progress_token, 50, 100, "开始计算...", session_id
+ )
+ time.sleep(0.2)
+ self._send_progress_notification(
+ progress_token, 100, 100, "计算完成...", session_id
+ )
# 使用eval函数计算表达式的结果
result = eval(expression)
# 添加成功结果到结果列表
results.append(
{
"name": tool_name,
"content": [
{
"type": "text",
"text": f"计算结果: {expression} = {result}",
}
],
}
)
except Exception as e:
# 计算出错,添加错误信息
results.append(
{
"name": tool_name,
"isError": True,
"error": {"message": f"计算错误: {str(e)}"},
}
)
else:
# 未知工具,添加错误信息
results.append(
{
"name": tool_name,
"isError": True,
"error": {"message": f"未知工具: {tool_name}"},
}
)
# 返回所有调用的结果
return {"calls": results}
# 处理初始化请求(实际生效的方法,覆盖上面的方法)
def handle_initialize(self, params, session_id: str = None):
"""处理初始化请求"""
# 获取客户端信息
client_info = params.get("clientInfo", {})
# 获取或创建会话ID
session_id = self._get_or_create_session(session_id)
# 存储客户端信息到会话
self.sessions[session_id]["client_info"] = client_info
# 记录客户端初始化信息
logger.info(f"客户端初始化: {client_info}, 会话ID: {session_id}")
# 返回初始化结果,包括协议版本、能力、服务器信息和会话ID
return {
"protocolVersion": self.protocol_version,
"capabilities": self.capabilities,
"serverInfo": self.server_info,
"sessionId": session_id, # 返回会话ID给客户端
}
# 定义MCPHTTPHandler类,继承自BaseHTTPRequestHandler
class MCPHTTPHandler(BaseHTTPRequestHandler):
"""HTTP处理器,支持Streamable HTTP传输"""
# 初始化方法,接收mcp_server参数
def __init__(self, *args, mcp_server=None, **kwargs):
# 保存MCP服务器实例
self.mcp_server = mcp_server
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 处理GET请求(建立SSE流)
def do_GET(self):
"""处理GET请求(建立SSE流)"""
try:
# 如果请求路径不是/mcp,返回404
if self.path != "/mcp":
self.send_error(404, "MCP endpoint not found")
return
# 检查Accept头
accept_header = self.headers.get("Accept", "")
# 如果Accept头不包含text/event-stream,返回405
if "text/event-stream" not in accept_header:
self.send_error(405, "Method not allowed")
return
# 提取会话ID
session_id = self.headers.get("Mcp-Session-Id")
if session_id:
# 验证会话是否存在
session_info = self.mcp_server.get_session_info(session_id)
if not session_info:
self.send_error(401, "Invalid session ID")
return
# 记录使用现有会话
logger.info(f"SSE连接使用现有会话: {session_id}")
else:
# 为SSE连接创建新会话
session_id = self.mcp_server._get_or_create_session(None)
# 记录新会话
logger.info(f"SSE连接创建新会话: {session_id}")
# 设置SSE头
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.send_header("Connection", "keep-alive")
self.send_header("MCP-Protocol-Version", self.mcp_server.protocol_version)
self.send_header("Mcp-Session-Id", session_id) # 返回会话ID
self.end_headers()
# 发送初始事件
self._send_sse_event(
"connected", {"status": "connected", "sessionId": session_id}
)
# 定时发送服务器时间和进度通知
try:
while True:
# 每0.2秒发送一次
+ time.sleep(0.2)
# 检查并发送待处理的进度通知
+ session_info = self.mcp_server.get_session_info(session_id)
+ if session_info and "pending_notifications" in session_info:
+ pending_notifications = session_info["pending_notifications"]
+ if pending_notifications:
# 发送所有待处理的进度通知
+ for notification in pending_notifications:
+ self._send_sse_event(
+ "notifications/progress", notification["params"]
+ )
+ logger.info(
+ f"通过SSE发送进度通知: {notification['params']}"
+ )
# 清空已发送的通知
+ session_info["pending_notifications"] = []
+ else:
+ pass
# 可以在此处发送ping保持连接
# self.wfile.write(b": ping\n\n")
# self.wfile.flush()
except (BrokenPipeError, ConnectionResetError):
# 连接断开,记录日志
logger.info(f"SSE连接已断开,会话ID: {session_id}")
except Exception as e:
# 发送事件出错,记录日志
logger.error(f"SSE事件发送错误: {e}, 会话ID: {session_id}")
except Exception as e:
# 处理GET请求出错,记录日志
logger.error(f"GET处理错误: {e}")
# 返回500错误
self.send_error(500, "Internal Server Error")
# 发送SSE事件
def _send_sse_event(self, event_type: str, data):
"""发送SSE事件"""
try:
# 构造事件数据
event_data = f"event: {event_type}\n"
event_data += f"data: {json.dumps(data)}\n\n"
# 写入事件到输出流
self.wfile.write(event_data.encode("utf-8"))
self.wfile.flush()
except Exception as e:
# 发送事件出错,记录日志
logger.error(f"发送SSE事件错误: {e}")
# 处理MCP消息
def _handle_mcp_message(self, message, session_id: str = None):
"""处理MCP消息"""
# 获取方法名
method = message.get("method")
# 获取参数
params = message.get("params", {})
# 获取消息ID
msg_id = message.get("id")
# 判断方法类型
if method == "initialize":
# 调用服务器的handle_initialize方法
result = self.mcp_server.handle_initialize(params, session_id)
# 如果有消息ID,返回带ID的结果
return {"id": msg_id, "result": result} if msg_id else None
elif method == "tools/list":
# 处理工具列表请求
result = self.mcp_server.handle_tools_list(params, session_id)
return {"id": msg_id, "result": result} if msg_id else None
elif method == "tools/call":
# 处理工具调用请求
result = self.mcp_server.handle_tools_call(params, session_id)
return {"id": msg_id, "result": result} if msg_id else None
else:
# 如果方法不存在,返回错误
if msg_id:
return {
"id": msg_id,
"error": {"code": -32601, "message": f"Method not found: {method}"},
}
# 没有ID则返回None
return None
# 处理POST请求(发送消息到服务器)
def do_POST(self):
"""处理POST请求(发送消息到服务器)"""
try:
# 检查请求路径是否为/mcp
if self.path != "/mcp":
self.send_error(404, "MCP endpoint not found")
return
# 检查Accept头是否包含application/json或text/event-stream
accept_header = self.headers.get("Accept", "")
if (
"application/json" not in accept_header
and "text/event-stream" not in accept_header
):
self.send_error(400, "Missing required Accept header")
return
# 检查MCP-Protocol-Version头是否匹配
protocol_version = self.headers.get("MCP-Protocol-Version")
if (
protocol_version
and protocol_version != self.mcp_server.protocol_version
):
self.send_error(
400, f"Unsupported protocol version: {protocol_version}"
)
return
# 获取会话ID
session_id = self.headers.get("Mcp-Session-Id")
# 读取请求体长度
content_length = int(self.headers.get("Content-Length", 0))
# 读取请求体内容
body = self.rfile.read(content_length)
try:
# 尝试解析JSON消息
message = json.loads(body.decode("utf-8"))
except json.JSONDecodeError:
# JSON解析失败
self.send_error(400, "Invalid JSON")
return
# 处理MCP消息
response = self._handle_mcp_message(message, session_id)
# 发送响应
if response:
# 发送200响应码
self.send_response(200)
# 设置响应头Content-Type为application/json
self.send_header("Content-Type", "application/json")
# 设置协议版本头
self.send_header(
"MCP-Protocol-Version", self.mcp_server.protocol_version
)
# 如果响应中包含会话ID,则添加到响应头
if "result" in response and "sessionId" in response["result"]:
self.send_header("Mcp-Session-Id", response["result"]["sessionId"])
# 结束响应头
self.end_headers()
# 写入JSON响应体
self.wfile.write(json.dumps(response).encode("utf-8"))
else:
# 没有响应内容,返回202 Accepted
self.send_response(202) # Accepted
self.end_headers()
except Exception as e:
# 记录错误日志
logger.error(f"POST处理错误: {e}")
# 返回500内部服务器错误
self.send_error(500, "Internal Server Error")
# 定义运行HTTP服务器的函数
def run_http_server(mcp_server: MCPServer, host: str = "localhost", port: int = 8000):
"""运行HTTP服务器"""
# 定义处理器工厂函数,用于传递mcp_server实例
def handler_factory(*args, **kwargs):
return MCPHTTPHandler(*args, mcp_server=mcp_server, **kwargs)
# 创建HTTPServer实例,绑定主机和端口
+ server = ThreadingHTTPServer((host, port), handler_factory)
+ server.daemon_threads = True
# 记录服务器启动信息
logger.info(f"HTTP服务器运行在 http://{host}:{port}/mcp")
# 记录会话管理功能启用信息
logger.info("会话管理功能已启用,支持会话ID跟踪和自动清理")
# 启动会话清理线程
import threading
# 定义会话清理函数
def cleanup_sessions():
while True:
try:
# 每5分钟清理一次
time.sleep(300)
# 调用会话清理方法
cleaned_count = mcp_server._cleanup_expired_sessions()
# 如果有清理的会话,记录日志
if cleaned_count > 0:
logger.info(f"清理了 {cleaned_count} 个过期会话")
# 记录当前活跃会话数
active_sessions = len(mcp_server.get_all_sessions())
+ logger.info(f"当前活跃会话数: {active_sessions}")
except Exception as e:
# 清理出错,记录日志
logger.error(f"会话清理错误: {e}")
# 启动清理线程
cleanup_thread = threading.Thread(target=cleanup_sessions, daemon=True)
cleanup_thread.start()
logger.info("会话清理线程已启动")
try:
# 启动服务器,进入循环监听请求
server.serve_forever()
except KeyboardInterrupt:
# 捕获Ctrl+C,记录服务器停止
logger.info("HTTP服务器已停止")
logger.info("正在清理所有会话...")
# 清理所有会话
all_sessions = mcp_server.get_all_sessions()
for session_id in list(all_sessions.keys()):
del mcp_server.sessions[session_id]
logger.info(f"已清理 {len(all_sessions)} 个会话")
finally:
# 关闭服务器
server.server_close()
# 定义主函数
def main():
"""主函数"""
# 导入argparse模块,用于解析命令行参数
import argparse
# 创建ArgumentParser对象,设置描述信息
parser = argparse.ArgumentParser(description="MCP HTTP服务器")
# 添加--host参数,指定服务器主机,默认localhost
parser.add_argument(
"--host", default="localhost", help="HTTP服务器主机 (默认: localhost)"
)
# 添加--port参数,指定服务器端口,默认8000
parser.add_argument(
"--port", type=int, default=8000, help="HTTP服务器端口 (默认: 8000)"
)
# 解析命令行参数
args = parser.parse_args()
# 创建MCP服务器实例
mcp_server = MCPServer()
# 运行HTTP服务器
run_http_server(mcp_server, args.host, args.port)
# 判断是否为主程序入口
if __name__ == "__main__":
# 调用主函数
main()