import json import subprocess import asyncio import sys import os from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from pydantic import BaseModel from typing import Dict, List, Any, Optional import uvicorn from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware class ToolModel(BaseModel): name: str description: Optional[str] = None inputSchema: Optional[Dict] = None class CallRequest(BaseModel): name: str arguments: Optional[Dict] = {} class MCPServerManager: def __init__(self, config_file="mcp_config.json"): # 从配置文件加载MCP服务配置 self.mcp_servers = self._load_config(config_file) # 缓存每个服务的工具列表 self._tools_cache = {} # 存储活动的会话 self._active_sessions = {} def _load_config(self, config_file): """ 从配置文件加载MCP服务配置 Args: config_file (str): 配置文件路径 Returns: dict: MCP服务配置 """ try: # 加载配置文件 with open(config_file, 'r', encoding='utf-8') as f: config = json.load(f) return config.get("mcpServers", {}) except Exception as e: print(f"加载MCP配置文件时出错: {e}") # 返回空配置 return {} async def _test_server_command(self, server_name): """ 测试服务器命令是否可用 Args: server_name (str): 服务器名称 Returns: bool: 命令是否可用 """ if server_name not in self.mcp_servers: return False server_config = self.mcp_servers[server_name] try: # 检查命令是否存在 process = subprocess.Popen([server_config["command"]] + server_config["args"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=(sys.platform == 'win32')) # 给进程一点时间启动 await asyncio.sleep(0.5) # 检查进程是否仍在运行 if process.poll() is None: # 进程仍在运行,终止它 process.terminate() try: process.wait(timeout=1) except subprocess.TimeoutExpired: process.kill() return True else: # 进程已退出,检查返回码 _, stderr = process.communicate() if process.returncode == 0 or process.returncode == 1: # 返回码为0(成功)或1(一般错误)表示命令存在 return True else: print(f"命令执行失败 {server_name}: {stderr.decode()}") return False except FileNotFoundError: print(f"命令未找到: {server_config['command']}") return False except Exception as e: print(f"测试命令时出错 {server_name}: {e}") return False async def call_mcp_server(self, server_name, method, params=None): """ 调用指定的MCP服务器 Args: server_name (str): 服务器名称 method (str): 要调用的方法 params (dict): 方法参数 Returns: dict: MCP服务器的响应 """ if server_name not in self.mcp_servers: return {"error": f"MCP server '{server_name}' not found"} # 检查命令是否可用 if not await self._test_server_command(server_name): return {"error": f"MCP server '{server_name}' 命令不可用"} try: # 创建临时会话(每次调用都创建新会话) server_config = self.mcp_servers[server_name] async with stdio_client( StdioServerParameters( command=server_config["command"], args=server_config["args"], env=None ) ) as (read, write): async with ClientSession(read, write) as session: # 初始化MCP服务器 await session.initialize() # 调用指定方法 if method == "prompts/list": result = await session.list_prompts() elif method == "prompts/get" and params: result = await session.get_prompt(params["name"]) elif method == "resources/list": result = await session.list_resources() elif method == "resources/read" and params: result = await session.read_resource(params["uri"]) elif method == "tools/list": result = await session.list_tools() elif method == "tools/call" and params: result = await session.call_tool(params["name"], params.get("arguments", {})) else: # 通用方法调用 result = await session.send_request(method, params or {}) # 将结果转换为可序列化的字典 return self._serialize_result(result) except Exception as e: return {"error": f"调用MCP服务时出错: {str(e)}"} def _serialize_result(self, result): """ 将MCP结果转换为可JSON序列化的格式 """ from pydantic import BaseModel if isinstance(result, BaseModel): return result.model_dump() elif isinstance(result, dict): return {key: self._serialize_result(value) for key, value in result.items()} elif isinstance(result, list): return [self._serialize_result(item) for item in result] else: return result async def get_actual_mcp_tools(self, server_name): """ 获取指定MCP服务的实际工具列表 Args: server_name (str): 服务器名称 Returns: list: 实际的工具列表 """ if server_name not in self.mcp_servers: return [] try: result = await self.call_mcp_server(server_name, "tools/list") if isinstance(result, dict) and "tools" in result: # 缓存工具列表 self._tools_cache[server_name] = result["tools"] return result["tools"] return [] except Exception as e: print(f"获取 {server_name} 工具列表时出错: {e}") return [] async def get_all_mcp_tools(self): """ 获取所有MCP服务的实际工具列表 Returns: list: 所有MCP服务的工具列表,每个工具名称后缀添加_mcp """ all_tools = [] for server_name in self.mcp_servers.keys(): try: tools = await self.get_actual_mcp_tools(server_name) # 为每个工具名称添加_mcp后缀以区分 for tool in tools: # 创建工具副本并修改名称 tool_copy = tool.copy() tool_copy['name'] = f"{tool['name']}_mcp" all_tools.append(tool_copy) except Exception as e: print(f"获取 {server_name} 工具列表时出错: {e}") return all_tools async def call_mcp_tool(self, tool_name, arguments): """ 执行MCP工具调用 Args: tool_name (str): 工具名称(需要以_mcp结尾) arguments (dict): 工具参数 Returns: dict: 工具执行结果 """ try: # 验证工具名称格式 if not tool_name.endswith("_mcp"): return {"error": f"无效的工具名称格式: {tool_name}"} # 移除_mcp后缀获取原始工具名称 original_tool_name = tool_name[:-4] # 查找工具所属的服务器 server_name = self._find_server_for_tool(original_tool_name) if not server_name: return {"error": f"未找到工具 {tool_name} 对应的MCP服务器"} # 调用工具 result = await self.call_mcp_server( server_name, "tools/call", { "name": original_tool_name, "arguments": arguments } ) return result except Exception as e: return {"error": f"MCP工具调用失败: {str(e)}"} def _find_server_for_tool(self, tool_name): """ 根据工具名称查找对应的服务器名称 Args: tool_name (str): 工具名称 Returns: str or None: 服务器名称或None """ # 遍历缓存的工具列表查找对应的服务器 for server_name, tools in self._tools_cache.items(): for tool in tools: if tool.get("name") == tool_name: return server_name return None async def initialize_all_servers(self): """ 初始化所有MCP服务器并获取工具列表 """ print("正在测试MCP服务器连接...") for server_name in self.mcp_servers.keys(): if await self._test_server_command(server_name): print(f"✓ {server_name} 命令可用") else: print(f"✗ {server_name} 命令不可用") # 创建FastAPI应用 app = FastAPI(title="MCP API Server", description="MCP服务API接口") # 添加CORS中间件 app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # 创建全局MCP服务器管理器实例 mcp_manager = MCPServerManager() @app.get("/") async def root(): return {"message": "MCP API Server is running"} @app.get("/tools", response_model=List[ToolModel]) async def get_tools(): """ 获取所有MCP工具列表 """ try: tools = await mcp_manager.get_all_mcp_tools() return tools except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/call") async def call_tool(request: CallRequest): """ 调用指定的MCP工具 """ try: result = await mcp_manager.call_mcp_tool(request.name, request.arguments) # 检查是否有错误 if isinstance(result, dict) and "error" in result: raise HTTPException(status_code=500, detail=result["error"]) return result except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": # 初始化所有服务器 print("正在初始化MCP服务器...") loop = asyncio.get_event_loop() loop.run_until_complete(mcp_manager.initialize_all_servers()) print("MCP服务器初始化完成") # 启动FastAPI服务器 print("启动MCP API服务器: http://localhost:8000") uvicorn.run("mcp_api_server:app", host="localhost", port=8000, reload=False)