mcp_tools.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. import json
  2. import subprocess
  3. import asyncio
  4. import sys
  5. import os
  6. from mcp import ClientSession, StdioServerParameters
  7. from mcp.client.stdio import stdio_client
  8. from pydantic import BaseModel
  9. class MCPTools:
  10. def __init__(self, config_file="mcp_config.json"):
  11. # 从配置文件加载MCP服务配置
  12. self.mcp_servers = self._load_config(config_file)
  13. # 缓存每个服务的工具列表
  14. self._tools_cache = {}
  15. def _load_config(self, config_file):
  16. """
  17. 从配置文件加载MCP服务配置
  18. Args:
  19. config_file (str): 配置文件路径
  20. Returns:
  21. dict: MCP服务配置
  22. """
  23. try:
  24. # 加载配置文件
  25. with open(config_file, 'r', encoding='utf-8') as f:
  26. config = json.load(f)
  27. return config.get("mcpServers", {})
  28. except Exception as e:
  29. print(f"加载MCP配置文件时出错: {e}")
  30. # 返回空配置
  31. return {}
  32. async def call_mcp_server(self, server_name, method, params=None):
  33. """
  34. 调用指定的MCP服务器
  35. Args:
  36. server_name (str): 服务器名称
  37. method (str): 要调用的方法
  38. params (dict): 方法参数
  39. Returns:
  40. dict: MCP服务器的响应
  41. """
  42. if server_name not in self.mcp_servers:
  43. return {"error": f"MCP server '{server_name}' not found"}
  44. server_config = self.mcp_servers[server_name]
  45. try:
  46. # 检查命令是否存在
  47. try:
  48. process = subprocess.Popen([server_config["command"]] + server_config["args"],
  49. stdout=subprocess.PIPE,
  50. stderr=subprocess.PIPE,
  51. shell=(sys.platform == 'win32'))
  52. process.terminate() # 立即终止,只是检查命令是否存在
  53. except FileNotFoundError:
  54. return {"error": f"命令未找到: {server_config['command']}"}
  55. # 检查service_mcp脚本是否存在 (仅当它是配置中的服务时)
  56. if server_name == "service_mcp" and server_config.get("args"):
  57. script_path = server_config["args"][0] if server_config["args"] else None
  58. if script_path and not os.path.exists(script_path):
  59. return {"error": f"Service MCP脚本不存在: {script_path}"}
  60. async with stdio_client(
  61. StdioServerParameters(
  62. command=server_config["command"],
  63. args=server_config["args"],
  64. env=None
  65. )
  66. ) as (read, write):
  67. async with ClientSession(read, write) as session:
  68. # 初始化MCP服务器
  69. await session.initialize()
  70. # 调用指定方法
  71. if method == "prompts/list":
  72. result = await session.list_prompts()
  73. elif method == "prompts/get" and params:
  74. result = await session.get_prompt(params["name"])
  75. elif method == "resources/list":
  76. result = await session.list_resources()
  77. elif method == "resources/read" and params:
  78. result = await session.read_resource(params["uri"])
  79. elif method == "tools/list":
  80. result = await session.list_tools()
  81. elif method == "tools/call" and params:
  82. result = await session.call_tool(params["name"], params.get("arguments", {}))
  83. else:
  84. # 通用方法调用
  85. result = await session.send_request(method, params or {})
  86. # 将结果转换为可序列化的字典
  87. return self._serialize_result(result)
  88. except Exception as e:
  89. return {"error": f"调用MCP服务时出错: {str(e)}"}
  90. def _serialize_result(self, result):
  91. """
  92. 将MCP结果转换为可JSON序列化的格式
  93. """
  94. if isinstance(result, BaseModel):
  95. return result.model_dump()
  96. elif isinstance(result, dict):
  97. return {key: self._serialize_result(value) for key, value in result.items()}
  98. elif isinstance(result, list):
  99. return [self._serialize_result(item) for item in result]
  100. else:
  101. return result
  102. async def get_actual_mcp_tools(self, server_name):
  103. """
  104. 获取指定MCP服务的实际工具列表
  105. Args:
  106. server_name (str): 服务器名称
  107. Returns:
  108. list: 实际的工具列表
  109. """
  110. if server_name not in self.mcp_servers:
  111. return []
  112. try:
  113. result = await self.call_mcp_server(server_name, "tools/list")
  114. if isinstance(result, dict) and "tools" in result:
  115. return result["tools"]
  116. return []
  117. except Exception as e:
  118. print(f"获取 {server_name} 工具列表时出错: {e}")
  119. return []
  120. def get_all_mcp_tools_sync(self):
  121. """
  122. 同步获取所有MCP服务的实际工具列表
  123. Returns:
  124. list: 所有MCP服务的工具列表
  125. """
  126. all_tools = []
  127. for server_name in self.mcp_servers.keys():
  128. try:
  129. tools = asyncio.run(self.get_actual_mcp_tools(server_name))
  130. # 缓存工具列表及对应的服务器名称
  131. self._tools_cache[server_name] = tools
  132. all_tools.extend(tools)
  133. except Exception as e:
  134. print(f"获取 {server_name} 工具列表时出错: {e}")
  135. return all_tools
  136. def get_server_for_tool(self, tool_name):
  137. """
  138. 根据工具名称获取对应的服务器名称
  139. Args:
  140. tool_name (str): 工具名称
  141. Returns:
  142. str or None: 服务器名称或None
  143. """
  144. # 遍历缓存的工具列表查找对应的服务器
  145. for server_name, tools in self._tools_cache.items():
  146. for tool in tools:
  147. if tool.get("name") == tool_name:
  148. return server_name
  149. return None
  150. def call_mcp_tool(self, tool_name, parameters):
  151. """
  152. 执行MCP工具调用
  153. Args:
  154. tool_name (str): 工具名称
  155. parameters (dict): 工具参数
  156. Returns:
  157. str: 工具执行结果
  158. """
  159. try:
  160. # 从工具名称中提取服务器名称
  161. if tool_name.startswith("call_") and tool_name.endswith("_mcp"):
  162. server_name = tool_name[5:-4] # 移除 "call_" 前缀和 "_mcp" 后缀
  163. else:
  164. return f"未知的MCP工具格式: {tool_name}"
  165. if server_name not in self.mcp_servers:
  166. return f"未配置的MCP服务: {server_name}"
  167. # 在异步环境中运行
  168. result = asyncio.run(
  169. self.call_mcp_server(
  170. server_name,
  171. parameters.get("method", ""),
  172. parameters.get("params")
  173. )
  174. )
  175. return json.dumps(result, ensure_ascii=False)
  176. except Exception as e:
  177. return f"MCP工具调用失败: {str(e)}"
  178. async def test_mcp_connection(self, server_name):
  179. """
  180. 测试MCP服务连接
  181. Args:
  182. server_name (str): 服务器名称
  183. Returns:
  184. dict: 测试结果
  185. """
  186. if server_name not in self.mcp_servers:
  187. return {"error": f"MCP server '{server_name}' not found"}
  188. try:
  189. result = await self.call_mcp_server(server_name, "tools/list")
  190. # 缓存工具列表
  191. if isinstance(result, dict) and "tools" in result:
  192. self._tools_cache[server_name] = result["tools"]
  193. return {
  194. "status": "success",
  195. "server": server_name,
  196. "tools_available": len(result.get("tools", [])) if isinstance(result, dict) else 0,
  197. "result": result
  198. }
  199. except Exception as e:
  200. return {
  201. "status": "error",
  202. "server": server_name,
  203. "error": str(e)
  204. }
  205. def get_available_servers(self):
  206. """
  207. 获取所有可用的MCP服务器名称
  208. Returns:
  209. list: 可用的MCP服务器名称列表
  210. """
  211. return list(self.mcp_servers.keys())