import json from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from pydantic import BaseModel from typing import Dict, List, Any, Optional import uvicorn from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager class ToolModel(BaseModel): name: str description: Optional[str] = None inputSchema: Optional[Dict] = None class CallRequest(BaseModel): name: str arguments: Optional[Dict] = {} class PersistentMCPService: """ 持久化MCP服务类,用于管理长期运行的MCP服务连接 """ def __init__(self, server_name: str, command: str, args: List[str]): self.server_name = server_name self.command = command self.args = args self.client_context = None self.session = None async def start(self): """ 启动MCP服务并建立持久化连接 """ try: print(f"正在启动MCP服务: {self.server_name}") # 创建stdio客户端连接 stdio_params = StdioServerParameters( command=self.command, args=self.args, env=None ) # 使用上下文管理器创建持久化连接 self.client_context = stdio_client(stdio_params) read, write = await self.client_context.__aenter__() # 创建会话 self.session = ClientSession(read, write) await self.session.__aenter__() # 初始化会话 await self.session.initialize() print(f"✓ MCP服务 {self.server_name} 启动成功") return True except Exception as e: print(f"✗ 启动MCP服务 {self.server_name} 失败: {e}") await self.stop() return False async def stop(self): """ 停止MCP服务 """ try: if self.session: await self.session.__aexit__(None, None, None) self.session = None if self.client_context: await self.client_context.__aexit__(None, None, None) self.client_context = None except Exception as e: print(f"停止MCP服务 {self.server_name} 时出错: {e}") async def call_method(self, method: str, params: Optional[Dict] = None): """ 调用MCP方法 """ if not self.session: return {"error": f"MCP服务 {self.server_name} 未连接"} try: if method == "prompts/list": result = await self.session.list_prompts() elif method == "prompts/get" and params: result = await self.session.get_prompt(params["name"]) elif method == "resources/list": result = await self.session.list_resources() elif method == "resources/read" and params: result = await self.session.read_resource(params["uri"]) elif method == "tools/list": result = await self.session.list_tools() elif method == "tools/call" and params: result = await self.session.call_tool(params["name"], params.get("arguments", {})) else: # 通用方法调用 result = await self.session.send_request(method, params or {}) return self._serialize_result(result) except Exception as e: return {"error": f"调用MCP服务 {self.server_name} 时出错: {str(e)}"} def _serialize_result(self, result): """ 将MCP结果转换为可JSON序列化的格式 """ from pydantic import BaseModel if isinstance(result, BaseModel): return result.model_dump() elif isinstance(result, dict): return {key: self._serialize_result(value) for key, value in result.items()} elif isinstance(result, list): return [self._serialize_result(item) for item in result] else: return result class MCPServerManager: def __init__(self, config_file="mcp_config.json"): # 从配置文件加载MCP服务配置 self.mcp_servers = self._load_config(config_file) # 存储持久化的服务实例 self.persistent_services: Dict[str, PersistentMCPService] = {} # 缓存每个服务的工具列表 self._tools_cache = {} def _load_config(self, config_file): """ 从配置文件加载MCP服务配置 Args: config_file (str): 配置文件路径 Returns: dict: MCP服务配置 """ try: # 加载配置文件 with open(config_file, 'r', encoding='utf-8') as f: config = json.load(f) return config.get("mcpServers", {}) except Exception as e: print(f"加载MCP配置文件时出错: {e}") # 返回空配置 return {} async def start_all_services(self): """ 启动所有MCP服务 """ print("正在启动所有MCP服务...") for server_name, server_config in self.mcp_servers.items(): service = PersistentMCPService( server_name, server_config["command"], server_config["args"] ) success = await service.start() if success: self.persistent_services[server_name] = service else: print(f"✗ MCP服务 {server_name} 启动失败") async def stop_all_services(self): """ 停止所有MCP服务 """ print("正在停止所有MCP服务...") for server_name, service in self.persistent_services.items(): await service.stop() self.persistent_services.clear() async def call_mcp_server(self, server_name, method, params=None): """ 调用指定的MCP服务器 Args: server_name (str): 服务器名称 method (str): 要调用的方法 params (dict): 方法参数 Returns: dict: MCP服务器的响应 """ if server_name not in self.persistent_services: return {"error": f"MCP服务 '{server_name}' 未启动或启动失败"} service = self.persistent_services[server_name] return await service.call_method(method, params) async def get_actual_mcp_tools(self, server_name): """ 获取指定MCP服务的实际工具列表 Args: server_name (str): 服务器名称 Returns: list: 实际的工具列表 """ if server_name not in self.mcp_servers: return [] try: result = await self.call_mcp_server(server_name, "tools/list") if isinstance(result, dict) and "tools" in result: # 缓存工具列表 self._tools_cache[server_name] = result["tools"] return result["tools"] return [] except Exception as e: print(f"获取 {server_name} 工具列表时出错: {e}") return [] async def get_all_mcp_tools(self): """ 获取所有MCP服务的实际工具列表 Returns: list: 所有MCP服务的工具列表,每个工具名称后缀添加_mcp """ all_tools = [] for server_name in self.mcp_servers.keys(): try: tools = await self.get_actual_mcp_tools(server_name) # 为每个工具名称添加_mcp后缀以区分 for tool in tools: # 创建工具副本并修改名称 tool_copy = tool.copy() tool_copy['name'] = f"{tool['name']}_mcp" all_tools.append(tool_copy) except Exception as e: print(f"获取 {server_name} 工具列表时出错: {e}") return all_tools async def call_mcp_tool(self, tool_name, arguments): """ 执行MCP工具调用 Args: tool_name (str): 工具名称(需要以_mcp结尾) arguments (dict): 工具参数 Returns: dict: 工具执行结果 """ try: # 验证工具名称格式 if not tool_name.endswith("_mcp"): return {"error": f"无效的工具名称格式: {tool_name}"} # 移除_mcp后缀获取原始工具名称 original_tool_name = tool_name[:-4] # 查找工具所属的服务器 server_name = self._find_server_for_tool(original_tool_name) if not server_name: return {"error": f"未找到工具 {tool_name} 对应的MCP服务器"} # 调用工具 result = await self.call_mcp_server( server_name, "tools/call", { "name": original_tool_name, "arguments": arguments } ) return result except Exception as e: return {"error": f"MCP工具调用失败: {str(e)}"} def _find_server_for_tool(self, tool_name): """ 根据工具名称查找对应的服务器名称 Args: tool_name (str): 工具名称 Returns: str or None: 服务器名称或None """ # 遍历缓存的工具列表查找对应的服务器 for server_name, tools in self._tools_cache.items(): for tool in tools: if tool.get("name") == tool_name: return server_name return None # 创建全局MCP服务器管理器实例 mcp_manager = MCPServerManager() @asynccontextmanager async def lifespan(app: FastAPI): """ 应用生命周期管理器:启动时启动所有MCP服务,关闭时停止所有服务 """ # 启动事件 await mcp_manager.start_all_services() yield # 关闭事件 await mcp_manager.stop_all_services() # 创建FastAPI应用 app = FastAPI( lifespan=lifespan ) # 添加CORS中间件 app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/tools", response_model=List[ToolModel]) async def get_tools(): """ 获取所有MCP工具列表 """ try: tools = await mcp_manager.get_all_mcp_tools() return tools except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/call") async def call_tool(request: CallRequest): """ 调用指定的MCP工具 """ try: result = await mcp_manager.call_mcp_tool(request.name, request.arguments) # 检查是否有错误 if isinstance(result, dict) and "error" in result: raise HTTPException(status_code=500, detail=result["error"]) return result except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": # 启动FastAPI服务器 uvicorn.run("mcp_api_server:app", host="localhost", port=8000, reload=False)