mcp_api_server.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. import json
  2. import subprocess
  3. import asyncio
  4. import sys
  5. import os
  6. from mcp import ClientSession, StdioServerParameters
  7. from mcp.client.stdio import stdio_client
  8. from pydantic import BaseModel
  9. from typing import Dict, List, Any, Optional
  10. import uvicorn
  11. from fastapi import FastAPI, HTTPException
  12. from fastapi.middleware.cors import CORSMiddleware
  13. class ToolModel(BaseModel):
  14. name: str
  15. description: Optional[str] = None
  16. inputSchema: Optional[Dict] = None
  17. class CallRequest(BaseModel):
  18. name: str
  19. arguments: Optional[Dict] = {}
  20. class MCPServerManager:
  21. def __init__(self, config_file="mcp_config.json"):
  22. # 从配置文件加载MCP服务配置
  23. self.mcp_servers = self._load_config(config_file)
  24. # 缓存每个服务的工具列表
  25. self._tools_cache = {}
  26. # 存储活动的会话
  27. self._active_sessions = {}
  28. def _load_config(self, config_file):
  29. """
  30. 从配置文件加载MCP服务配置
  31. Args:
  32. config_file (str): 配置文件路径
  33. Returns:
  34. dict: MCP服务配置
  35. """
  36. try:
  37. # 加载配置文件
  38. with open(config_file, 'r', encoding='utf-8') as f:
  39. config = json.load(f)
  40. return config.get("mcpServers", {})
  41. except Exception as e:
  42. print(f"加载MCP配置文件时出错: {e}")
  43. # 返回空配置
  44. return {}
  45. async def _test_server_command(self, server_name):
  46. """
  47. 测试服务器命令是否可用
  48. Args:
  49. server_name (str): 服务器名称
  50. Returns:
  51. bool: 命令是否可用
  52. """
  53. if server_name not in self.mcp_servers:
  54. return False
  55. server_config = self.mcp_servers[server_name]
  56. try:
  57. # 检查命令是否存在
  58. process = subprocess.Popen([server_config["command"]] + server_config["args"],
  59. stdout=subprocess.PIPE,
  60. stderr=subprocess.PIPE,
  61. shell=(sys.platform == 'win32'))
  62. # 给进程一点时间启动
  63. await asyncio.sleep(0.5)
  64. # 检查进程是否仍在运行
  65. if process.poll() is None:
  66. # 进程仍在运行,终止它
  67. process.terminate()
  68. try:
  69. process.wait(timeout=1)
  70. except subprocess.TimeoutExpired:
  71. process.kill()
  72. return True
  73. else:
  74. # 进程已退出,检查返回码
  75. _, stderr = process.communicate()
  76. if process.returncode == 0 or process.returncode == 1:
  77. # 返回码为0(成功)或1(一般错误)表示命令存在
  78. return True
  79. else:
  80. print(f"命令执行失败 {server_name}: {stderr.decode()}")
  81. return False
  82. except FileNotFoundError:
  83. print(f"命令未找到: {server_config['command']}")
  84. return False
  85. except Exception as e:
  86. print(f"测试命令时出错 {server_name}: {e}")
  87. return False
  88. async def call_mcp_server(self, server_name, method, params=None):
  89. """
  90. 调用指定的MCP服务器
  91. Args:
  92. server_name (str): 服务器名称
  93. method (str): 要调用的方法
  94. params (dict): 方法参数
  95. Returns:
  96. dict: MCP服务器的响应
  97. """
  98. if server_name not in self.mcp_servers:
  99. return {"error": f"MCP server '{server_name}' not found"}
  100. # 检查命令是否可用
  101. if not await self._test_server_command(server_name):
  102. return {"error": f"MCP server '{server_name}' 命令不可用"}
  103. try:
  104. # 创建临时会话(每次调用都创建新会话)
  105. server_config = self.mcp_servers[server_name]
  106. async with stdio_client(
  107. StdioServerParameters(
  108. command=server_config["command"],
  109. args=server_config["args"],
  110. env=None
  111. )
  112. ) as (read, write):
  113. async with ClientSession(read, write) as session:
  114. # 初始化MCP服务器
  115. await session.initialize()
  116. # 调用指定方法
  117. if method == "prompts/list":
  118. result = await session.list_prompts()
  119. elif method == "prompts/get" and params:
  120. result = await session.get_prompt(params["name"])
  121. elif method == "resources/list":
  122. result = await session.list_resources()
  123. elif method == "resources/read" and params:
  124. result = await session.read_resource(params["uri"])
  125. elif method == "tools/list":
  126. result = await session.list_tools()
  127. elif method == "tools/call" and params:
  128. result = await session.call_tool(params["name"], params.get("arguments", {}))
  129. else:
  130. # 通用方法调用
  131. result = await session.send_request(method, params or {})
  132. # 将结果转换为可序列化的字典
  133. return self._serialize_result(result)
  134. except Exception as e:
  135. return {"error": f"调用MCP服务时出错: {str(e)}"}
  136. def _serialize_result(self, result):
  137. """
  138. 将MCP结果转换为可JSON序列化的格式
  139. """
  140. from pydantic import BaseModel
  141. if isinstance(result, BaseModel):
  142. return result.model_dump()
  143. elif isinstance(result, dict):
  144. return {key: self._serialize_result(value) for key, value in result.items()}
  145. elif isinstance(result, list):
  146. return [self._serialize_result(item) for item in result]
  147. else:
  148. return result
  149. async def get_actual_mcp_tools(self, server_name):
  150. """
  151. 获取指定MCP服务的实际工具列表
  152. Args:
  153. server_name (str): 服务器名称
  154. Returns:
  155. list: 实际的工具列表
  156. """
  157. if server_name not in self.mcp_servers:
  158. return []
  159. try:
  160. result = await self.call_mcp_server(server_name, "tools/list")
  161. if isinstance(result, dict) and "tools" in result:
  162. # 缓存工具列表
  163. self._tools_cache[server_name] = result["tools"]
  164. return result["tools"]
  165. return []
  166. except Exception as e:
  167. print(f"获取 {server_name} 工具列表时出错: {e}")
  168. return []
  169. async def get_all_mcp_tools(self):
  170. """
  171. 获取所有MCP服务的实际工具列表
  172. Returns:
  173. list: 所有MCP服务的工具列表,每个工具名称后缀添加_mcp
  174. """
  175. all_tools = []
  176. for server_name in self.mcp_servers.keys():
  177. try:
  178. tools = await self.get_actual_mcp_tools(server_name)
  179. # 为每个工具名称添加_mcp后缀以区分
  180. for tool in tools:
  181. # 创建工具副本并修改名称
  182. tool_copy = tool.copy()
  183. tool_copy['name'] = f"{tool['name']}_mcp"
  184. all_tools.append(tool_copy)
  185. except Exception as e:
  186. print(f"获取 {server_name} 工具列表时出错: {e}")
  187. return all_tools
  188. async def call_mcp_tool(self, tool_name, arguments):
  189. """
  190. 执行MCP工具调用
  191. Args:
  192. tool_name (str): 工具名称(需要以_mcp结尾)
  193. arguments (dict): 工具参数
  194. Returns:
  195. dict: 工具执行结果
  196. """
  197. try:
  198. # 验证工具名称格式
  199. if not tool_name.endswith("_mcp"):
  200. return {"error": f"无效的工具名称格式: {tool_name}"}
  201. # 移除_mcp后缀获取原始工具名称
  202. original_tool_name = tool_name[:-4]
  203. # 查找工具所属的服务器
  204. server_name = self._find_server_for_tool(original_tool_name)
  205. if not server_name:
  206. return {"error": f"未找到工具 {tool_name} 对应的MCP服务器"}
  207. # 调用工具
  208. result = await self.call_mcp_server(
  209. server_name,
  210. "tools/call",
  211. {
  212. "name": original_tool_name,
  213. "arguments": arguments
  214. }
  215. )
  216. return result
  217. except Exception as e:
  218. return {"error": f"MCP工具调用失败: {str(e)}"}
  219. def _find_server_for_tool(self, tool_name):
  220. """
  221. 根据工具名称查找对应的服务器名称
  222. Args:
  223. tool_name (str): 工具名称
  224. Returns:
  225. str or None: 服务器名称或None
  226. """
  227. # 遍历缓存的工具列表查找对应的服务器
  228. for server_name, tools in self._tools_cache.items():
  229. for tool in tools:
  230. if tool.get("name") == tool_name:
  231. return server_name
  232. return None
  233. async def initialize_all_servers(self):
  234. """
  235. 初始化所有MCP服务器并获取工具列表
  236. """
  237. print("正在测试MCP服务器连接...")
  238. for server_name in self.mcp_servers.keys():
  239. if await self._test_server_command(server_name):
  240. print(f"✓ {server_name} 命令可用")
  241. else:
  242. print(f"✗ {server_name} 命令不可用")
  243. # 创建FastAPI应用
  244. app = FastAPI(title="MCP API Server", description="MCP服务API接口")
  245. # 添加CORS中间件
  246. app.add_middleware(
  247. CORSMiddleware,
  248. allow_origins=["*"],
  249. allow_credentials=True,
  250. allow_methods=["*"],
  251. allow_headers=["*"],
  252. )
  253. # 创建全局MCP服务器管理器实例
  254. mcp_manager = MCPServerManager()
  255. @app.get("/")
  256. async def root():
  257. return {"message": "MCP API Server is running"}
  258. @app.get("/tools", response_model=List[ToolModel])
  259. async def get_tools():
  260. """
  261. 获取所有MCP工具列表
  262. """
  263. try:
  264. tools = await mcp_manager.get_all_mcp_tools()
  265. return tools
  266. except Exception as e:
  267. raise HTTPException(status_code=500, detail=str(e))
  268. @app.post("/call")
  269. async def call_tool(request: CallRequest):
  270. """
  271. 调用指定的MCP工具
  272. """
  273. try:
  274. result = await mcp_manager.call_mcp_tool(request.name, request.arguments)
  275. # 检查是否有错误
  276. if isinstance(result, dict) and "error" in result:
  277. raise HTTPException(status_code=500, detail=result["error"])
  278. return result
  279. except HTTPException:
  280. raise
  281. except Exception as e:
  282. raise HTTPException(status_code=500, detail=str(e))
  283. if __name__ == "__main__":
  284. # 初始化所有服务器
  285. print("正在初始化MCP服务器...")
  286. loop = asyncio.get_event_loop()
  287. loop.run_until_complete(mcp_manager.initialize_all_servers())
  288. print("MCP服务器初始化完成")
  289. # 启动FastAPI服务器
  290. print("启动MCP API服务器: http://localhost:8000")
  291. uvicorn.run("mcp_api_server:app", host="localhost", port=8000, reload=False)