mcp_tools.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. import json
  2. import subprocess
  3. import asyncio
  4. import sys
  5. import os
  6. from mcp import ClientSession, StdioServerParameters
  7. from mcp.client.stdio import stdio_client
  8. from pydantic import BaseModel
  9. class MCPTools:
  10. def __init__(self, config_file="mcp_config.json"):
  11. # 从配置文件加载MCP服务配置
  12. self.mcp_servers = self._load_config(config_file)
  13. # 缓存每个服务的工具列表
  14. self._tools_cache = {}
  15. def _load_config(self, config_file):
  16. """
  17. 从配置文件加载MCP服务配置
  18. Args:
  19. config_file (str): 配置文件路径
  20. Returns:
  21. dict: MCP服务配置
  22. """
  23. try:
  24. # 加载配置文件
  25. with open(config_file, 'r', encoding='utf-8') as f:
  26. config = json.load(f)
  27. return config.get("mcpServers", {})
  28. except Exception as e:
  29. print(f"加载MCP配置文件时出错: {e}")
  30. # 返回空配置
  31. return {}
  32. async def call_mcp_server(self, server_name, method, params=None):
  33. """
  34. 调用指定的MCP服务器
  35. Args:
  36. server_name (str): 服务器名称
  37. method (str): 要调用的方法
  38. params (dict): 方法参数
  39. Returns:
  40. dict: MCP服务器的响应
  41. """
  42. if server_name not in self.mcp_servers:
  43. return {"error": f"MCP server '{server_name}' not found"}
  44. server_config = self.mcp_servers[server_name]
  45. try:
  46. # 检查命令是否存在
  47. try:
  48. process = subprocess.Popen([server_config["command"]] + server_config["args"],
  49. stdout=subprocess.PIPE,
  50. stderr=subprocess.PIPE,
  51. shell=(sys.platform == 'win32'))
  52. process.terminate() # 立即终止,只是检查命令是否存在
  53. except FileNotFoundError:
  54. return {"error": f"命令未找到: {server_config['command']}"}
  55. # 检查service_mcp脚本是否存在 (仅当它是配置中的服务时)
  56. if server_name == "service_mcp" and server_config.get("args"):
  57. script_path = server_config["args"][0] if server_config["args"] else None
  58. if script_path and not os.path.exists(script_path):
  59. return {"error": f"Service MCP脚本不存在: {script_path}"}
  60. async with stdio_client(
  61. StdioServerParameters(
  62. command=server_config["command"],
  63. args=server_config["args"],
  64. env=None
  65. )
  66. ) as (read, write):
  67. async with ClientSession(read, write) as session:
  68. # 初始化MCP服务器
  69. await session.initialize()
  70. # 调用指定方法
  71. if method == "prompts/list":
  72. result = await session.list_prompts()
  73. elif method == "prompts/get" and params:
  74. result = await session.get_prompt(params["name"])
  75. elif method == "resources/list":
  76. result = await session.list_resources()
  77. elif method == "resources/read" and params:
  78. result = await session.read_resource(params["uri"])
  79. elif method == "tools/list":
  80. result = await session.list_tools()
  81. elif method == "tools/call" and params:
  82. result = await session.call_tool(params["name"], params.get("arguments", {}))
  83. else:
  84. # 通用方法调用
  85. result = await session.send_request(method, params or {})
  86. # 将结果转换为可序列化的字典
  87. return self._serialize_result(result)
  88. except Exception as e:
  89. return {"error": f"调用MCP服务时出错: {str(e)}"}
  90. def _serialize_result(self, result):
  91. """
  92. 将MCP结果转换为可JSON序列化的格式
  93. """
  94. if isinstance(result, BaseModel):
  95. return result.model_dump()
  96. elif isinstance(result, dict):
  97. return {key: self._serialize_result(value) for key, value in result.items()}
  98. elif isinstance(result, list):
  99. return [self._serialize_result(item) for item in result]
  100. else:
  101. return result
  102. def get_mcp_tool_list(self):
  103. """
  104. 获取MCP工具列表,用于添加到AI工具中
  105. 根据配置动态生成工具列表
  106. """
  107. tools = []
  108. # 为每个配置的MCP服务创建对应的工具
  109. for server_name, server_config in self.mcp_servers.items():
  110. tool_name = f"call_{server_name}_mcp"
  111. description = server_config.get("description", f"调用{server_name} MCP服务")
  112. tools.append({
  113. "type": "function",
  114. "function": {
  115. "name": tool_name,
  116. "description": description,
  117. "parameters": {
  118. "type": "object",
  119. "properties": {
  120. "method": {
  121. "type": "string",
  122. "description": "要调用的MCP方法,如tools/list, tools/call等"
  123. },
  124. "params": {
  125. "type": "object",
  126. "description": "方法参数"
  127. }
  128. },
  129. "required": ["method"]
  130. }
  131. }
  132. })
  133. return tools
  134. async def get_actual_mcp_tools(self, server_name):
  135. """
  136. 获取指定MCP服务的实际工具列表
  137. Args:
  138. server_name (str): 服务器名称
  139. Returns:
  140. list: 实际的工具列表
  141. """
  142. if server_name not in self.mcp_servers:
  143. return []
  144. try:
  145. result = await self.call_mcp_server(server_name, "tools/list")
  146. if isinstance(result, dict) and "tools" in result:
  147. return result["tools"]
  148. return []
  149. except Exception as e:
  150. print(f"获取 {server_name} 工具列表时出错: {e}")
  151. return []
  152. def get_all_mcp_tools_sync(self):
  153. """
  154. 同步获取所有MCP服务的实际工具列表
  155. Returns:
  156. list: 所有MCP服务的工具列表
  157. """
  158. all_tools = []
  159. for server_name in self.mcp_servers.keys():
  160. try:
  161. tools = asyncio.run(self.get_actual_mcp_tools(server_name))
  162. # 缓存工具列表及对应的服务器名称
  163. self._tools_cache[server_name] = tools
  164. all_tools.extend(tools)
  165. except Exception as e:
  166. print(f"获取 {server_name} 工具列表时出错: {e}")
  167. return all_tools
  168. def get_server_for_tool(self, tool_name):
  169. """
  170. 根据工具名称获取对应的服务器名称
  171. Args:
  172. tool_name (str): 工具名称
  173. Returns:
  174. str or None: 服务器名称或None
  175. """
  176. # 遍历缓存的工具列表查找对应的服务器
  177. for server_name, tools in self._tools_cache.items():
  178. for tool in tools:
  179. if tool.get("name") == tool_name:
  180. return server_name
  181. return None
  182. def call_mcp_tool(self, tool_name, parameters):
  183. """
  184. 执行MCP工具调用
  185. Args:
  186. tool_name (str): 工具名称
  187. parameters (dict): 工具参数
  188. Returns:
  189. str: 工具执行结果
  190. """
  191. try:
  192. # 从工具名称中提取服务器名称
  193. if tool_name.startswith("call_") and tool_name.endswith("_mcp"):
  194. server_name = tool_name[5:-4] # 移除 "call_" 前缀和 "_mcp" 后缀
  195. else:
  196. return f"未知的MCP工具格式: {tool_name}"
  197. if server_name not in self.mcp_servers:
  198. return f"未配置的MCP服务: {server_name}"
  199. # 在异步环境中运行
  200. result = asyncio.run(
  201. self.call_mcp_server(
  202. server_name,
  203. parameters.get("method", ""),
  204. parameters.get("params")
  205. )
  206. )
  207. return json.dumps(result, ensure_ascii=False)
  208. except Exception as e:
  209. return f"MCP工具调用失败: {str(e)}"
  210. async def test_mcp_connection(self, server_name):
  211. """
  212. 测试MCP服务连接
  213. Args:
  214. server_name (str): 服务器名称
  215. Returns:
  216. dict: 测试结果
  217. """
  218. if server_name not in self.mcp_servers:
  219. return {"error": f"MCP server '{server_name}' not found"}
  220. try:
  221. result = await self.call_mcp_server(server_name, "tools/list")
  222. # 缓存工具列表
  223. if isinstance(result, dict) and "tools" in result:
  224. self._tools_cache[server_name] = result["tools"]
  225. return {
  226. "status": "success",
  227. "server": server_name,
  228. "tools_available": len(result.get("tools", [])) if isinstance(result, dict) else 0,
  229. "result": result
  230. }
  231. except Exception as e:
  232. return {
  233. "status": "error",
  234. "server": server_name,
  235. "error": str(e)
  236. }
  237. def get_available_servers(self):
  238. """
  239. 获取所有可用的MCP服务器名称
  240. Returns:
  241. list: 可用的MCP服务器名称列表
  242. """
  243. return list(self.mcp_servers.keys())