|
@@ -0,0 +1,417 @@
|
|
|
+import json
|
|
|
+import subprocess
|
|
|
+import asyncio
|
|
|
+import sys
|
|
|
+from mcp import ClientSession, StdioServerParameters
|
|
|
+from mcp.client.stdio import stdio_client
|
|
|
+from pydantic import BaseModel
|
|
|
+from http.server import HTTPServer, BaseHTTPRequestHandler
|
|
|
+from urllib.parse import urlparse, parse_qs
|
|
|
+
|
|
|
+class MCPServerManager:
|
|
|
+ def __init__(self, config_file="mcp_config.json"):
|
|
|
+ # 从配置文件加载MCP服务配置
|
|
|
+ self.mcp_servers = self._load_config(config_file)
|
|
|
+ # 缓存每个服务的工具列表
|
|
|
+ self._tools_cache = {}
|
|
|
+ # 存储活动的会话
|
|
|
+ self._active_sessions = {}
|
|
|
+
|
|
|
+ def _load_config(self, config_file):
|
|
|
+ """
|
|
|
+ 从配置文件加载MCP服务配置
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config_file (str): 配置文件路径
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: MCP服务配置
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 加载配置文件
|
|
|
+ with open(config_file, 'r', encoding='utf-8') as f:
|
|
|
+ config = json.load(f)
|
|
|
+ return config.get("mcpServers", {})
|
|
|
+ except Exception as e:
|
|
|
+ print(f"加载MCP配置文件时出错: {e}")
|
|
|
+ # 返回空配置
|
|
|
+ return {}
|
|
|
+
|
|
|
+ async def _test_server_command(self, server_name):
|
|
|
+ """
|
|
|
+ 测试服务器命令是否可用
|
|
|
+
|
|
|
+ Args:
|
|
|
+ server_name (str): 服务器名称
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 命令是否可用
|
|
|
+ """
|
|
|
+ if server_name not in self.mcp_servers:
|
|
|
+ return False
|
|
|
+
|
|
|
+ server_config = self.mcp_servers[server_name]
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 检查命令是否存在
|
|
|
+ process = subprocess.Popen([server_config["command"]] + server_config["args"],
|
|
|
+ stdout=subprocess.PIPE,
|
|
|
+ stderr=subprocess.PIPE,
|
|
|
+ shell=(sys.platform == 'win32'))
|
|
|
+ # 给进程一点时间启动
|
|
|
+ await asyncio.sleep(0.5)
|
|
|
+
|
|
|
+ # 检查进程是否仍在运行
|
|
|
+ if process.poll() is None:
|
|
|
+ # 进程仍在运行,终止它
|
|
|
+ process.terminate()
|
|
|
+ try:
|
|
|
+ process.wait(timeout=1)
|
|
|
+ except subprocess.TimeoutExpired:
|
|
|
+ process.kill()
|
|
|
+ return True
|
|
|
+ else:
|
|
|
+ # 进程已退出,检查返回码
|
|
|
+ _, stderr = process.communicate()
|
|
|
+ if process.returncode == 0 or process.returncode == 1:
|
|
|
+ # 返回码为0(成功)或1(一般错误)表示命令存在
|
|
|
+ return True
|
|
|
+ else:
|
|
|
+ print(f"命令执行失败 {server_name}: {stderr.decode()}")
|
|
|
+ return False
|
|
|
+ except FileNotFoundError:
|
|
|
+ print(f"命令未找到: {server_config['command']}")
|
|
|
+ return False
|
|
|
+ except Exception as e:
|
|
|
+ print(f"测试命令时出错 {server_name}: {e}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ async def call_mcp_server(self, server_name, method, params=None):
|
|
|
+ """
|
|
|
+ 调用指定的MCP服务器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ server_name (str): 服务器名称
|
|
|
+ method (str): 要调用的方法
|
|
|
+ params (dict): 方法参数
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: MCP服务器的响应
|
|
|
+ """
|
|
|
+ if server_name not in self.mcp_servers:
|
|
|
+ return {"error": f"MCP server '{server_name}' not found"}
|
|
|
+
|
|
|
+ # 检查命令是否可用
|
|
|
+ if not await self._test_server_command(server_name):
|
|
|
+ return {"error": f"MCP server '{server_name}' 命令不可用"}
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 创建临时会话(每次调用都创建新会话)
|
|
|
+ server_config = self.mcp_servers[server_name]
|
|
|
+
|
|
|
+ async with stdio_client(
|
|
|
+ StdioServerParameters(
|
|
|
+ command=server_config["command"],
|
|
|
+ args=server_config["args"],
|
|
|
+ env=None
|
|
|
+ )
|
|
|
+ ) as (read, write):
|
|
|
+ async with ClientSession(read, write) as session:
|
|
|
+ # 初始化MCP服务器
|
|
|
+ await session.initialize()
|
|
|
+
|
|
|
+ # 调用指定方法
|
|
|
+ if method == "prompts/list":
|
|
|
+ result = await session.list_prompts()
|
|
|
+ elif method == "prompts/get" and params:
|
|
|
+ result = await session.get_prompt(params["name"])
|
|
|
+ elif method == "resources/list":
|
|
|
+ result = await session.list_resources()
|
|
|
+ elif method == "resources/read" and params:
|
|
|
+ result = await session.read_resource(params["uri"])
|
|
|
+ elif method == "tools/list":
|
|
|
+ result = await session.list_tools()
|
|
|
+ elif method == "tools/call" and params:
|
|
|
+ result = await session.call_tool(params["name"], params.get("arguments", {}))
|
|
|
+ else:
|
|
|
+ # 通用方法调用
|
|
|
+ result = await session.send_request(method, params or {})
|
|
|
+
|
|
|
+ # 将结果转换为可序列化的字典
|
|
|
+ return self._serialize_result(result)
|
|
|
+ except Exception as e:
|
|
|
+ return {"error": f"调用MCP服务时出错: {str(e)}"}
|
|
|
+
|
|
|
+ def _serialize_result(self, result):
|
|
|
+ """
|
|
|
+ 将MCP结果转换为可JSON序列化的格式
|
|
|
+ """
|
|
|
+ if isinstance(result, BaseModel):
|
|
|
+ return result.model_dump()
|
|
|
+ elif isinstance(result, dict):
|
|
|
+ return {key: self._serialize_result(value) for key, value in result.items()}
|
|
|
+ elif isinstance(result, list):
|
|
|
+ return [self._serialize_result(item) for item in result]
|
|
|
+ else:
|
|
|
+ return result
|
|
|
+
|
|
|
+ async def get_actual_mcp_tools(self, server_name):
|
|
|
+ """
|
|
|
+ 获取指定MCP服务的实际工具列表
|
|
|
+
|
|
|
+ Args:
|
|
|
+ server_name (str): 服务器名称
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ list: 实际的工具列表
|
|
|
+ """
|
|
|
+ if server_name not in self.mcp_servers:
|
|
|
+ return []
|
|
|
+
|
|
|
+ try:
|
|
|
+ result = await self.call_mcp_server(server_name, "tools/list")
|
|
|
+ if isinstance(result, dict) and "tools" in result:
|
|
|
+ # 缓存工具列表
|
|
|
+ self._tools_cache[server_name] = result["tools"]
|
|
|
+ return result["tools"]
|
|
|
+ return []
|
|
|
+ except Exception as e:
|
|
|
+ print(f"获取 {server_name} 工具列表时出错: {e}")
|
|
|
+ return []
|
|
|
+
|
|
|
+ async def get_all_mcp_tools(self):
|
|
|
+ """
|
|
|
+ 获取所有MCP服务的实际工具列表
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ list: 所有MCP服务的工具列表,每个工具名称后缀添加_mcp
|
|
|
+ """
|
|
|
+ all_tools = []
|
|
|
+ for server_name in self.mcp_servers.keys():
|
|
|
+ try:
|
|
|
+ tools = await self.get_actual_mcp_tools(server_name)
|
|
|
+ # 为每个工具名称添加_mcp后缀以区分
|
|
|
+ for tool in tools:
|
|
|
+ # 创建工具副本并修改名称
|
|
|
+ tool_copy = tool.copy()
|
|
|
+ tool_copy['name'] = f"{tool['name']}_mcp"
|
|
|
+ all_tools.append(tool_copy)
|
|
|
+ except Exception as e:
|
|
|
+ print(f"获取 {server_name} 工具列表时出错: {e}")
|
|
|
+ return all_tools
|
|
|
+
|
|
|
+ async def call_mcp_tool(self, tool_name, arguments):
|
|
|
+ """
|
|
|
+ 执行MCP工具调用
|
|
|
+
|
|
|
+ Args:
|
|
|
+ tool_name (str): 工具名称(需要以_mcp结尾)
|
|
|
+ arguments (dict): 工具参数
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: 工具执行结果
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 验证工具名称格式
|
|
|
+ if not tool_name.endswith("_mcp"):
|
|
|
+ return {"error": f"无效的工具名称格式: {tool_name}"}
|
|
|
+
|
|
|
+ # 移除_mcp后缀获取原始工具名称
|
|
|
+ original_tool_name = tool_name[:-4]
|
|
|
+
|
|
|
+ # 查找工具所属的服务器
|
|
|
+ server_name = self._find_server_for_tool(original_tool_name)
|
|
|
+ if not server_name:
|
|
|
+ return {"error": f"未找到工具 {tool_name} 对应的MCP服务器"}
|
|
|
+
|
|
|
+ # 调用工具
|
|
|
+ result = await self.call_mcp_server(
|
|
|
+ server_name,
|
|
|
+ "tools/call",
|
|
|
+ {
|
|
|
+ "name": original_tool_name,
|
|
|
+ "arguments": arguments
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ return result
|
|
|
+ except Exception as e:
|
|
|
+ return {"error": f"MCP工具调用失败: {str(e)}"}
|
|
|
+
|
|
|
+ def _find_server_for_tool(self, tool_name):
|
|
|
+ """
|
|
|
+ 根据工具名称查找对应的服务器名称
|
|
|
+
|
|
|
+ Args:
|
|
|
+ tool_name (str): 工具名称
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ str or None: 服务器名称或None
|
|
|
+ """
|
|
|
+ # 遍历缓存的工具列表查找对应的服务器
|
|
|
+ for server_name, tools in self._tools_cache.items():
|
|
|
+ for tool in tools:
|
|
|
+ if tool.get("name") == tool_name:
|
|
|
+ return server_name
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def initialize_all_servers(self):
|
|
|
+ """
|
|
|
+ 初始化所有MCP服务器并获取工具列表
|
|
|
+ """
|
|
|
+ print("正在测试MCP服务器连接...")
|
|
|
+ for server_name in self.mcp_servers.keys():
|
|
|
+ if await self._test_server_command(server_name):
|
|
|
+ print(f"✓ {server_name} 命令可用")
|
|
|
+ else:
|
|
|
+ print(f"✗ {server_name} 命令不可用")
|
|
|
+
|
|
|
+class MCPAPIHandler(BaseHTTPRequestHandler):
|
|
|
+ """
|
|
|
+ MCP API HTTP请求处理器
|
|
|
+ """
|
|
|
+
|
|
|
+ # 类变量,用于存储MCPServerManager实例
|
|
|
+ server_manager = None
|
|
|
+
|
|
|
+ def _set_headers(self, status_code=200, content_type='application/json'):
|
|
|
+ """
|
|
|
+ 设置HTTP响应头
|
|
|
+ """
|
|
|
+ self.send_response(status_code)
|
|
|
+ self.send_header('Content-type', content_type)
|
|
|
+ self.send_header('Access-Control-Allow-Origin', '*')
|
|
|
+ self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
|
|
|
+ self.send_header('Access-Control-Allow-Headers', 'Content-Type')
|
|
|
+ self.end_headers()
|
|
|
+
|
|
|
+ def do_OPTIONS(self):
|
|
|
+ """
|
|
|
+ 处理CORS预检请求
|
|
|
+ """
|
|
|
+ self._set_headers()
|
|
|
+
|
|
|
+ def do_GET(self):
|
|
|
+ """
|
|
|
+ 处理GET请求
|
|
|
+ """
|
|
|
+ parsed_path = urlparse(self.path)
|
|
|
+
|
|
|
+ # 处理 /tools 接口
|
|
|
+ if parsed_path.path == '/tools':
|
|
|
+ self._handle_tools_request()
|
|
|
+ else:
|
|
|
+ self._set_headers(404)
|
|
|
+ self.wfile.write(json.dumps({"error": "未找到接口"}).encode('utf-8'))
|
|
|
+
|
|
|
+ def do_POST(self):
|
|
|
+ """
|
|
|
+ 处理POST请求
|
|
|
+ """
|
|
|
+ parsed_path = urlparse(self.path)
|
|
|
+
|
|
|
+ # 处理 /call 接口
|
|
|
+ if parsed_path.path == '/call':
|
|
|
+ self._handle_call_request()
|
|
|
+ else:
|
|
|
+ self._set_headers(404)
|
|
|
+ self.wfile.write(json.dumps({"error": "未找到接口"}).encode('utf-8'))
|
|
|
+
|
|
|
+ def _handle_tools_request(self):
|
|
|
+ """
|
|
|
+ 处理 /tools 请求,返回所有MCP工具列表
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 在新的事件循环中运行异步代码
|
|
|
+ loop = asyncio.new_event_loop()
|
|
|
+ asyncio.set_event_loop(loop)
|
|
|
+
|
|
|
+ tools = loop.run_until_complete(
|
|
|
+ self.server_manager.get_all_mcp_tools()
|
|
|
+ )
|
|
|
+ loop.close()
|
|
|
+
|
|
|
+ self._set_headers()
|
|
|
+ self.wfile.write(json.dumps(tools, ensure_ascii=False).encode('utf-8'))
|
|
|
+ except Exception as e:
|
|
|
+ self._set_headers(500)
|
|
|
+ self.wfile.write(json.dumps({"error": str(e)}).encode('utf-8'))
|
|
|
+
|
|
|
+ def _handle_call_request(self):
|
|
|
+ """
|
|
|
+ 处理 /call 请求,调用指定的MCP工具
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 读取请求体
|
|
|
+ content_length = int(self.headers['Content-Length'])
|
|
|
+ post_data = self.rfile.read(content_length)
|
|
|
+
|
|
|
+ # 解析JSON数据
|
|
|
+ data = json.loads(post_data.decode('utf-8'))
|
|
|
+
|
|
|
+ # 获取必要参数
|
|
|
+ tool_name = data.get('name')
|
|
|
+ arguments = data.get('arguments', {})
|
|
|
+
|
|
|
+ if not tool_name:
|
|
|
+ self._set_headers(400)
|
|
|
+ self.wfile.write(json.dumps({"error": "缺少工具名称"}).encode('utf-8'))
|
|
|
+ return
|
|
|
+
|
|
|
+ # 在新的事件循环中运行异步代码
|
|
|
+ loop = asyncio.new_event_loop()
|
|
|
+ asyncio.set_event_loop(loop)
|
|
|
+
|
|
|
+ result = loop.run_until_complete(
|
|
|
+ self.server_manager.call_mcp_tool(tool_name, arguments)
|
|
|
+ )
|
|
|
+ loop.close()
|
|
|
+
|
|
|
+ # 检查是否有错误
|
|
|
+ if isinstance(result, dict) and "error" in result:
|
|
|
+ self._set_headers(500)
|
|
|
+ else:
|
|
|
+ self._set_headers()
|
|
|
+
|
|
|
+ self.wfile.write(json.dumps(result, ensure_ascii=False).encode('utf-8'))
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ self._set_headers(400)
|
|
|
+ self.wfile.write(json.dumps({"error": "无效的JSON格式"}).encode('utf-8'))
|
|
|
+ except Exception as e:
|
|
|
+ self._set_headers(500)
|
|
|
+ self.wfile.write(json.dumps({"error": str(e)}).encode('utf-8'))
|
|
|
+
|
|
|
+class MCPAPIServer:
|
|
|
+ """
|
|
|
+ MCP API服务器主类
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, host='localhost', port=8000, config_file='mcp_config.json'):
|
|
|
+ self.host = host
|
|
|
+ self.port = port
|
|
|
+ self.server_manager = MCPServerManager(config_file)
|
|
|
+ # 将server_manager设置为请求处理器的类变量
|
|
|
+ MCPAPIHandler.server_manager = self.server_manager
|
|
|
+
|
|
|
+ def start(self):
|
|
|
+ """
|
|
|
+ 启动MCP API服务器
|
|
|
+ """
|
|
|
+ # 初始化所有服务器
|
|
|
+ print("正在初始化MCP服务器...")
|
|
|
+ loop = asyncio.new_event_loop()
|
|
|
+ asyncio.set_event_loop(loop)
|
|
|
+ loop.run_until_complete(self.server_manager.initialize_all_servers())
|
|
|
+ loop.close()
|
|
|
+ print("MCP服务器初始化完成")
|
|
|
+
|
|
|
+ # 启动HTTP服务器
|
|
|
+ server_address = (self.host, self.port)
|
|
|
+ httpd = HTTPServer(server_address, MCPAPIHandler)
|
|
|
+ print(f"启动MCP API服务器: http://{self.host}:{self.port}")
|
|
|
+ httpd.serve_forever()
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ # 创建并启动MCP API服务器
|
|
|
+ api_server = MCPAPIServer()
|
|
|
+ api_server.start()
|