浏览代码

优化了mcp调用逻辑

关习习 1 月之前
父节点
当前提交
9359b98307
共有 4 个文件被更改,包括 182 次插入189 次删除
  1. 1 1
      main.py
  2. 146 136
      mcp_api_server.py
  3. 6 5
      requirements.txt
  4. 29 47
      tools.py

+ 1 - 1
main.py

@@ -36,7 +36,7 @@ def stream_response(messages, tools=None):
         # 提取delta内容
         delta = chunk.choices[0].delta
 
-        # 处理文本内容(最终回答)
+        # 处理思考内容,没有添加标签直接打印文本了
         if hasattr(delta, 'reasoning') and delta.reasoning:
             content = delta.reasoning
             full_response += content

+ 146 - 136
mcp_api_server.py

@@ -1,8 +1,4 @@
 import json
-import subprocess
-import asyncio
-import sys
-import os
 from mcp import ClientSession, StdioServerParameters
 from mcp.client.stdio import stdio_client
 from pydantic import BaseModel
@@ -10,6 +6,7 @@ from typing import Dict, List, Any, Optional
 import uvicorn
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
+from contextlib import asynccontextmanager
 
 class ToolModel(BaseModel):
     name: str
@@ -20,15 +17,115 @@ class CallRequest(BaseModel):
     name: str
     arguments: Optional[Dict] = {}
 
+class PersistentMCPService:
+    """
+    持久化MCP服务类,用于管理长期运行的MCP服务连接
+    """
+    def __init__(self, server_name: str, command: str, args: List[str]):
+        self.server_name = server_name
+        self.command = command
+        self.args = args
+        self.client_context = None
+        self.session = None
+
+    async def start(self):
+        """
+        启动MCP服务并建立持久化连接
+        """
+        try:
+            print(f"正在启动MCP服务: {self.server_name}")
+            
+            # 创建stdio客户端连接
+            stdio_params = StdioServerParameters(
+                command=self.command,
+                args=self.args,
+                env=None
+            )
+            
+            # 使用上下文管理器创建持久化连接
+            self.client_context = stdio_client(stdio_params)
+            read, write = await self.client_context.__aenter__()
+            
+            # 创建会话
+            self.session = ClientSession(read, write)
+            await self.session.__aenter__()
+            
+            # 初始化会话
+            await self.session.initialize()
+            
+            print(f"✓ MCP服务 {self.server_name} 启动成功")
+            return True
+        except Exception as e:
+            print(f"✗ 启动MCP服务 {self.server_name} 失败: {e}")
+            await self.stop()
+            return False
+
+    async def stop(self):
+        """
+        停止MCP服务
+        """
+        try:
+            if self.session:
+                await self.session.__aexit__(None, None, None)
+                self.session = None
+            
+            if self.client_context:
+                await self.client_context.__aexit__(None, None, None)
+                self.client_context = None
+                
+        except Exception as e:
+            print(f"停止MCP服务 {self.server_name} 时出错: {e}")
+
+    async def call_method(self, method: str, params: Optional[Dict] = None):
+        """
+        调用MCP方法
+        """
+        if not self.session:
+            return {"error": f"MCP服务 {self.server_name} 未连接"}
+        
+        try:
+            if method == "prompts/list":
+                result = await self.session.list_prompts()
+            elif method == "prompts/get" and params:
+                result = await self.session.get_prompt(params["name"])
+            elif method == "resources/list":
+                result = await self.session.list_resources()
+            elif method == "resources/read" and params:
+                result = await self.session.read_resource(params["uri"])
+            elif method == "tools/list":
+                result = await self.session.list_tools()
+            elif method == "tools/call" and params:
+                result = await self.session.call_tool(params["name"], params.get("arguments", {}))
+            else:
+                # 通用方法调用
+                result = await self.session.send_request(method, params or {})
+            
+            return self._serialize_result(result)
+        except Exception as e:
+            return {"error": f"调用MCP服务 {self.server_name} 时出错: {str(e)}"}
+
+    def _serialize_result(self, result):
+        """
+        将MCP结果转换为可JSON序列化的格式
+        """
+        from pydantic import BaseModel
+        if isinstance(result, BaseModel):
+            return result.model_dump()
+        elif isinstance(result, dict):
+            return {key: self._serialize_result(value) for key, value in result.items()}
+        elif isinstance(result, list):
+            return [self._serialize_result(item) for item in result]
+        else:
+            return result
 
 class MCPServerManager:
     def __init__(self, config_file="mcp_config.json"):
         # 从配置文件加载MCP服务配置
         self.mcp_servers = self._load_config(config_file)
+        # 存储持久化的服务实例
+        self.persistent_services: Dict[str, PersistentMCPService] = {}
         # 缓存每个服务的工具列表
         self._tools_cache = {}
-        # 存储活动的会话
-        self._active_sessions = {}
 
     def _load_config(self, config_file):
         """
@@ -50,54 +147,32 @@ class MCPServerManager:
             # 返回空配置
             return {}
 
-    async def _test_server_command(self, server_name):
+    async def start_all_services(self):
         """
-        测试服务器命令是否可用
-        
-        Args:
-            server_name (str): 服务器名称
-            
-        Returns:
-            bool: 命令是否可用
+        启动所有MCP服务
         """
-        if server_name not in self.mcp_servers:
-            return False
-
-        server_config = self.mcp_servers[server_name]
-        
-        try:
-            # 检查命令是否存在
-            process = subprocess.Popen([server_config["command"]] + server_config["args"], 
-                                     stdout=subprocess.PIPE, 
-                                     stderr=subprocess.PIPE,
-                                     shell=(sys.platform == 'win32'))
-            # 给进程一点时间启动
-            await asyncio.sleep(0.5)
+        print("正在启动所有MCP服务...")
+        for server_name, server_config in self.mcp_servers.items():
+            service = PersistentMCPService(
+                server_name,
+                server_config["command"],
+                server_config["args"]
+            )
             
-            # 检查进程是否仍在运行
-            if process.poll() is None:
-                # 进程仍在运行,终止它
-                process.terminate()
-                try:
-                    process.wait(timeout=1)
-                except subprocess.TimeoutExpired:
-                    process.kill()
-                return True
+            success = await service.start()
+            if success:
+                self.persistent_services[server_name] = service
             else:
-                # 进程已退出,检查返回码
-                _, stderr = process.communicate()
-                if process.returncode == 0 or process.returncode == 1:
-                    # 返回码为0(成功)或1(一般错误)表示命令存在
-                    return True
-                else:
-                    print(f"命令执行失败 {server_name}: {stderr.decode()}")
-                    return False
-        except FileNotFoundError:
-            print(f"命令未找到: {server_config['command']}")
-            return False
-        except Exception as e:
-            print(f"测试命令时出错 {server_name}: {e}")
-            return False
+                print(f"✗ MCP服务 {server_name} 启动失败")
+
+    async def stop_all_services(self):
+        """
+        停止所有MCP服务
+        """
+        print("正在停止所有MCP服务...")
+        for server_name, service in self.persistent_services.items():
+            await service.stop()
+        self.persistent_services.clear()
 
     async def call_mcp_server(self, server_name, method, params=None):
         """
@@ -111,63 +186,11 @@ class MCPServerManager:
         Returns:
             dict: MCP服务器的响应
         """
-        if server_name not in self.mcp_servers:
-            return {"error": f"MCP server '{server_name}' not found"}
+        if server_name not in self.persistent_services:
+            return {"error": f"MCP服务 '{server_name}' 未启动或启动失败"}
 
-        # 检查命令是否可用
-        if not await self._test_server_command(server_name):
-            return {"error": f"MCP server '{server_name}' 命令不可用"}
-
-        try:
-            # 创建临时会话(每次调用都创建新会话)
-            server_config = self.mcp_servers[server_name]
-            
-            async with stdio_client(
-                StdioServerParameters(
-                    command=server_config["command"],
-                    args=server_config["args"],
-                    env=None
-                )
-            ) as (read, write):
-                async with ClientSession(read, write) as session:
-                    # 初始化MCP服务器
-                    await session.initialize()
-                    
-                    # 调用指定方法
-                    if method == "prompts/list":
-                        result = await session.list_prompts()
-                    elif method == "prompts/get" and params:
-                        result = await session.get_prompt(params["name"])
-                    elif method == "resources/list":
-                        result = await session.list_resources()
-                    elif method == "resources/read" and params:
-                        result = await session.read_resource(params["uri"])
-                    elif method == "tools/list":
-                        result = await session.list_tools()
-                    elif method == "tools/call" and params:
-                        result = await session.call_tool(params["name"], params.get("arguments", {}))
-                    else:
-                        # 通用方法调用
-                        result = await session.send_request(method, params or {})
-                    
-                    # 将结果转换为可序列化的字典
-                    return self._serialize_result(result)
-        except Exception as e:
-            return {"error": f"调用MCP服务时出错: {str(e)}"}
-
-    def _serialize_result(self, result):
-        """
-        将MCP结果转换为可JSON序列化的格式
-        """
-        from pydantic import BaseModel
-        if isinstance(result, BaseModel):
-            return result.model_dump()
-        elif isinstance(result, dict):
-            return {key: self._serialize_result(value) for key, value in result.items()}
-        elif isinstance(result, list):
-            return [self._serialize_result(item) for item in result]
-        else:
-            return result
+        service = self.persistent_services[server_name]
+        return await service.call_method(method, params)
 
     async def get_actual_mcp_tools(self, server_name):
         """
@@ -269,21 +292,26 @@ class MCPServerManager:
                     return server_name
         return None
 
-    async def initialize_all_servers(self):
-        """
-        初始化所有MCP服务器并获取工具列表
-        """
-        print("正在测试MCP服务器连接...")
-        for server_name in self.mcp_servers.keys():
-            if await self._test_server_command(server_name):
-                print(f"✓ {server_name} 命令可用")
-            else:
-                print(f"✗ {server_name} 命令不可用")
-
+# 创建全局MCP服务器管理器实例
+mcp_manager = MCPServerManager()
 
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """
+    应用生命周期管理器:启动时启动所有MCP服务,关闭时停止所有服务
+    """
+    # 启动事件
+    await mcp_manager.start_all_services()
+    
+    yield
+    
+    # 关闭事件
+    await mcp_manager.stop_all_services()
 
 # 创建FastAPI应用
-app = FastAPI(title="MCP API Server", description="MCP服务API接口")
+app = FastAPI(
+    lifespan=lifespan
+)
 
 # 添加CORS中间件
 app.add_middleware(
@@ -294,15 +322,6 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# 创建全局MCP服务器管理器实例
-mcp_manager = MCPServerManager()
-
-
-@app.get("/")
-async def root():
-    return {"message": "MCP API Server is running"}
-
-
 @app.get("/tools", response_model=List[ToolModel])
 async def get_tools():
     """
@@ -314,7 +333,6 @@ async def get_tools():
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
-
 @app.post("/call")
 async def call_tool(request: CallRequest):
     """
@@ -331,14 +349,6 @@ async def call_tool(request: CallRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
-
 if __name__ == "__main__":
-    # 初始化所有服务器
-    print("正在初始化MCP服务器...")
-    loop = asyncio.get_event_loop()
-    loop.run_until_complete(mcp_manager.initialize_all_servers())
-    print("MCP服务器初始化完成")
-    
     # 启动FastAPI服务器
-    print("启动MCP API服务器: http://localhost:8000")
     uvicorn.run("mcp_api_server:app", host="localhost", port=8000, reload=False)

+ 6 - 5
requirements.txt

@@ -1,8 +1,9 @@
 openai
 json
-subprocess
-asyncio
-sys
-os
 mcp
-pydantic
+pydantic
+uvicorn
+typing
+contextlib
+fastapi
+requests

+ 29 - 47
tools.py

@@ -1,48 +1,39 @@
 import json
-import urllib.request
-import urllib.parse
-import urllib.error
-from html.parser import HTMLParser
+import requests
 
-class MLStripper(HTMLParser):
-    def __init__(self):
-        super().__init__()
-        self.reset()
-        self.fed = []
-    
-    def handle_data(self, d):
-        self.fed.append(d)
-    
-    def get_data(self):
-        return ''.join(self.fed)
-
-def strip_tags(html):
-    s = MLStripper()
-    s.feed(html)
-    return s.get_data()
+def old_tool():
+    """
+    模拟器旧工具,固定返回100
+    """
+    return 100
 
 class Tools:
     def __init__(self, mcp_api_url="http://localhost:8000"):
         self.mcp_api_url = mcp_api_url
+        self.mcp_tools_cache = []
 
     def get_tool_list(self):
-        # 基础工具
+
+        # 首先从缓存中读取tools列表
+        if self.mcp_tools_cache:
+            return self.mcp_tools_cache
+
+        # 基础工具,旧版tools
         tools = [
             {
                 "type": "function",
                 "function": {
-                    "name": "read_webpage",
-                    "description": "读取指定URL网页的文本内容",
-                    "parameters": {"type": "object", "properties": {"url": {"type": "string", "description": "要读取的网页URL"}}}
+                    "name": "luck_num",
+                    "description": "获取幸运数字",
+                    "parameters": {"type": "object", "properties": {"choice": {"type": "string", "description": "默认传入yes"}}}
                 }
             }
         ]
 
         # 获取所有MCP服务的实际工具
         try:
-            req = urllib.request.Request(f"{self.mcp_api_url}/tools")
-            with urllib.request.urlopen(req) as response:
-                mcp_tools = json.loads(response.read().decode('utf-8'))
+            response = requests.get(f"{self.mcp_api_url}/tools")
+            mcp_tools = response.json()
             
             # 为每个MCP工具创建独立的工具定义
             for tool in mcp_tools:
@@ -60,6 +51,9 @@ class Tools:
                 })
         except Exception as e:
             print(f"获取MCP工具列表时出错: {e}")
+
+        # 存入缓存,便于下次快速调用
+        self.mcp_tools_cache =  tools
         
         return tools
 
@@ -67,19 +61,8 @@ class Tools:
         print(f"🔧 正在执行工具: {tool_name}({parameters})")
 
         # 处理原有的工具
-        if tool_name == "read_webpage":
-            # 实现读取网页内容功能
-            try:
-                url = parameters["url"]
-                req = urllib.request.Request(url, headers={
-                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
-                })
-                with urllib.request.urlopen(req) as response:
-                    html = response.read().decode('utf-8')
-                    text = strip_tags(html)
-                    return text
-            except Exception as e:
-                return f"读取网页错误: {str(e)}"
+        if tool_name == "luck_num":
+            return str(old_tool())
         
         # 处理MCP工具
         elif tool_name.endswith("_mcp"):
@@ -92,14 +75,13 @@ class Tools:
                 }
                 
                 # 发送POST请求到MCP API服务器
-                req = urllib.request.Request(
+                response = requests.post(
                     f"{self.mcp_api_url}/call",
-                    data=json.dumps(data).encode('utf-8'),
-                    headers={'Content-Type': 'application/json'}
+                    json=data  # 使用json参数自动设置Content-Type为application/json
                 )
+                response.raise_for_status()
                 
-                with urllib.request.urlopen(req) as response:
-                    result = json.loads(response.read().decode('utf-8'))
+                result = response.json()
                 
                 # 检查是否有错误
                 if "error" in result:
@@ -108,8 +90,8 @@ class Tools:
                 
                 # 返回结果
                 return json.dumps(result, ensure_ascii=False)
-            except urllib.error.HTTPError as e:
-                return f"MCP调用HTTP错误: {e.code} - {e.reason}"
+            except requests.exceptions.RequestException as e:
+                return f"MCP调用HTTP错误: {str(e)}"
             except Exception as e:
                 return f"MCP调用失败: {str(e)}"
         else: