mcp_api_server.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. import json
  2. from mcp import ClientSession, StdioServerParameters
  3. from mcp.client.stdio import stdio_client
  4. from pydantic import BaseModel
  5. from typing import Dict, List, Any, Optional
  6. import uvicorn
  7. from fastapi import FastAPI, HTTPException
  8. from fastapi.middleware.cors import CORSMiddleware
  9. from contextlib import asynccontextmanager
  10. class ToolModel(BaseModel):
  11. name: str
  12. description: Optional[str] = None
  13. inputSchema: Optional[Dict] = None
  14. class CallRequest(BaseModel):
  15. name: str
  16. arguments: Optional[Dict] = {}
  17. class PersistentMCPService:
  18. """
  19. 持久化MCP服务类,用于管理长期运行的MCP服务连接
  20. """
  21. def __init__(self, server_name: str, command: str, args: List[str]):
  22. self.server_name = server_name
  23. self.command = command
  24. self.args = args
  25. self.client_context = None
  26. self.session = None
  27. async def start(self):
  28. """
  29. 启动MCP服务并建立持久化连接
  30. """
  31. try:
  32. print(f"正在启动MCP服务: {self.server_name}")
  33. # 创建stdio客户端连接
  34. stdio_params = StdioServerParameters(
  35. command=self.command,
  36. args=self.args,
  37. env=None
  38. )
  39. # 使用上下文管理器创建持久化连接
  40. self.client_context = stdio_client(stdio_params)
  41. read, write = await self.client_context.__aenter__()
  42. # 创建会话
  43. self.session = ClientSession(read, write)
  44. await self.session.__aenter__()
  45. # 初始化会话
  46. await self.session.initialize()
  47. print(f"✓ MCP服务 {self.server_name} 启动成功")
  48. return True
  49. except Exception as e:
  50. print(f"✗ 启动MCP服务 {self.server_name} 失败: {e}")
  51. await self.stop()
  52. return False
  53. async def stop(self):
  54. """
  55. 停止MCP服务
  56. """
  57. try:
  58. if self.session:
  59. await self.session.__aexit__(None, None, None)
  60. self.session = None
  61. if self.client_context:
  62. await self.client_context.__aexit__(None, None, None)
  63. self.client_context = None
  64. except Exception as e:
  65. print(f"停止MCP服务 {self.server_name} 时出错: {e}")
  66. async def call_method(self, method: str, params: Optional[Dict] = None):
  67. """
  68. 调用MCP方法
  69. """
  70. if not self.session:
  71. return {"error": f"MCP服务 {self.server_name} 未连接"}
  72. try:
  73. if method == "prompts/list":
  74. result = await self.session.list_prompts()
  75. elif method == "prompts/get" and params:
  76. result = await self.session.get_prompt(params["name"])
  77. elif method == "resources/list":
  78. result = await self.session.list_resources()
  79. elif method == "resources/read" and params:
  80. result = await self.session.read_resource(params["uri"])
  81. elif method == "tools/list":
  82. result = await self.session.list_tools()
  83. elif method == "tools/call" and params:
  84. result = await self.session.call_tool(params["name"], params.get("arguments", {}))
  85. else:
  86. # 通用方法调用
  87. result = await self.session.send_request(method, params or {})
  88. return self._serialize_result(result)
  89. except Exception as e:
  90. return {"error": f"调用MCP服务 {self.server_name} 时出错: {str(e)}"}
  91. def _serialize_result(self, result):
  92. """
  93. 将MCP结果转换为可JSON序列化的格式
  94. """
  95. from pydantic import BaseModel
  96. if isinstance(result, BaseModel):
  97. return result.model_dump()
  98. elif isinstance(result, dict):
  99. return {key: self._serialize_result(value) for key, value in result.items()}
  100. elif isinstance(result, list):
  101. return [self._serialize_result(item) for item in result]
  102. else:
  103. return result
  104. class MCPServerManager:
  105. def __init__(self, config_file="mcp_config.json"):
  106. # 从配置文件加载MCP服务配置
  107. self.mcp_servers = self._load_config(config_file)
  108. # 存储持久化的服务实例
  109. self.persistent_services: Dict[str, PersistentMCPService] = {}
  110. # 缓存每个服务的工具列表
  111. self._tools_cache = {}
  112. def _load_config(self, config_file):
  113. """
  114. 从配置文件加载MCP服务配置
  115. Args:
  116. config_file (str): 配置文件路径
  117. Returns:
  118. dict: MCP服务配置
  119. """
  120. try:
  121. # 加载配置文件
  122. with open(config_file, 'r', encoding='utf-8') as f:
  123. config = json.load(f)
  124. return config.get("mcpServers", {})
  125. except Exception as e:
  126. print(f"加载MCP配置文件时出错: {e}")
  127. # 返回空配置
  128. return {}
  129. async def start_all_services(self):
  130. """
  131. 启动所有MCP服务
  132. """
  133. print("正在启动所有MCP服务...")
  134. for server_name, server_config in self.mcp_servers.items():
  135. service = PersistentMCPService(
  136. server_name,
  137. server_config["command"],
  138. server_config["args"]
  139. )
  140. success = await service.start()
  141. if success:
  142. self.persistent_services[server_name] = service
  143. else:
  144. print(f"✗ MCP服务 {server_name} 启动失败")
  145. async def stop_all_services(self):
  146. """
  147. 停止所有MCP服务
  148. """
  149. print("正在停止所有MCP服务...")
  150. for server_name, service in self.persistent_services.items():
  151. await service.stop()
  152. self.persistent_services.clear()
  153. async def call_mcp_server(self, server_name, method, params=None):
  154. """
  155. 调用指定的MCP服务器
  156. Args:
  157. server_name (str): 服务器名称
  158. method (str): 要调用的方法
  159. params (dict): 方法参数
  160. Returns:
  161. dict: MCP服务器的响应
  162. """
  163. if server_name not in self.persistent_services:
  164. return {"error": f"MCP服务 '{server_name}' 未启动或启动失败"}
  165. service = self.persistent_services[server_name]
  166. return await service.call_method(method, params)
  167. async def get_actual_mcp_tools(self, server_name):
  168. """
  169. 获取指定MCP服务的实际工具列表
  170. Args:
  171. server_name (str): 服务器名称
  172. Returns:
  173. list: 实际的工具列表
  174. """
  175. if server_name not in self.mcp_servers:
  176. return []
  177. try:
  178. result = await self.call_mcp_server(server_name, "tools/list")
  179. if isinstance(result, dict) and "tools" in result:
  180. # 缓存工具列表
  181. self._tools_cache[server_name] = result["tools"]
  182. return result["tools"]
  183. return []
  184. except Exception as e:
  185. print(f"获取 {server_name} 工具列表时出错: {e}")
  186. return []
  187. async def get_all_mcp_tools(self):
  188. """
  189. 获取所有MCP服务的实际工具列表
  190. Returns:
  191. list: 所有MCP服务的工具列表,每个工具名称后缀添加_mcp
  192. """
  193. all_tools = []
  194. for server_name in self.mcp_servers.keys():
  195. try:
  196. tools = await self.get_actual_mcp_tools(server_name)
  197. # 为每个工具名称添加_mcp后缀以区分
  198. for tool in tools:
  199. # 创建工具副本并修改名称
  200. tool_copy = tool.copy()
  201. tool_copy['name'] = f"{tool['name']}_mcp"
  202. all_tools.append(tool_copy)
  203. except Exception as e:
  204. print(f"获取 {server_name} 工具列表时出错: {e}")
  205. return all_tools
  206. async def call_mcp_tool(self, tool_name, arguments):
  207. """
  208. 执行MCP工具调用
  209. Args:
  210. tool_name (str): 工具名称(需要以_mcp结尾)
  211. arguments (dict): 工具参数
  212. Returns:
  213. dict: 工具执行结果
  214. """
  215. try:
  216. # 验证工具名称格式
  217. if not tool_name.endswith("_mcp"):
  218. return {"error": f"无效的工具名称格式: {tool_name}"}
  219. # 移除_mcp后缀获取原始工具名称
  220. original_tool_name = tool_name[:-4]
  221. # 查找工具所属的服务器
  222. server_name = self._find_server_for_tool(original_tool_name)
  223. if not server_name:
  224. return {"error": f"未找到工具 {tool_name} 对应的MCP服务器"}
  225. # 调用工具
  226. result = await self.call_mcp_server(
  227. server_name,
  228. "tools/call",
  229. {
  230. "name": original_tool_name,
  231. "arguments": arguments
  232. }
  233. )
  234. return result
  235. except Exception as e:
  236. return {"error": f"MCP工具调用失败: {str(e)}"}
  237. def _find_server_for_tool(self, tool_name):
  238. """
  239. 根据工具名称查找对应的服务器名称
  240. Args:
  241. tool_name (str): 工具名称
  242. Returns:
  243. str or None: 服务器名称或None
  244. """
  245. # 遍历缓存的工具列表查找对应的服务器
  246. for server_name, tools in self._tools_cache.items():
  247. for tool in tools:
  248. if tool.get("name") == tool_name:
  249. return server_name
  250. return None
  251. # 创建全局MCP服务器管理器实例
  252. mcp_manager = MCPServerManager()
  253. @asynccontextmanager
  254. async def lifespan(app: FastAPI):
  255. """
  256. 应用生命周期管理器:启动时启动所有MCP服务,关闭时停止所有服务
  257. """
  258. # 启动事件
  259. await mcp_manager.start_all_services()
  260. yield
  261. # 关闭事件
  262. await mcp_manager.stop_all_services()
  263. # 创建FastAPI应用
  264. app = FastAPI(
  265. lifespan=lifespan
  266. )
  267. # 添加CORS中间件
  268. app.add_middleware(
  269. CORSMiddleware,
  270. allow_origins=["*"],
  271. allow_credentials=True,
  272. allow_methods=["*"],
  273. allow_headers=["*"],
  274. )
  275. @app.get("/tools", response_model=List[ToolModel])
  276. async def get_tools():
  277. """
  278. 获取所有MCP工具列表
  279. """
  280. try:
  281. tools = await mcp_manager.get_all_mcp_tools()
  282. return tools
  283. except Exception as e:
  284. raise HTTPException(status_code=500, detail=str(e))
  285. @app.post("/call")
  286. async def call_tool(request: CallRequest):
  287. """
  288. 调用指定的MCP工具
  289. """
  290. try:
  291. result = await mcp_manager.call_mcp_tool(request.name, request.arguments)
  292. # 检查是否有错误
  293. if isinstance(result, dict) and "error" in result:
  294. raise HTTPException(status_code=500, detail=result["error"])
  295. return result
  296. except HTTPException:
  297. raise
  298. except Exception as e:
  299. raise HTTPException(status_code=500, detail=str(e))
  300. if __name__ == "__main__":
  301. # 启动FastAPI服务器
  302. uvicorn.run("mcp_api_server:app", host="localhost", port=8000, reload=False)