|
@@ -1,8 +1,4 @@
|
|
|
import json
|
|
|
-import subprocess
|
|
|
-import asyncio
|
|
|
-import sys
|
|
|
-import os
|
|
|
from mcp import ClientSession, StdioServerParameters
|
|
|
from mcp.client.stdio import stdio_client
|
|
|
from pydantic import BaseModel
|
|
@@ -10,6 +6,7 @@ from typing import Dict, List, Any, Optional
|
|
|
import uvicorn
|
|
|
from fastapi import FastAPI, HTTPException
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
+from contextlib import asynccontextmanager
|
|
|
|
|
|
class ToolModel(BaseModel):
|
|
|
name: str
|
|
@@ -20,15 +17,115 @@ class CallRequest(BaseModel):
|
|
|
name: str
|
|
|
arguments: Optional[Dict] = {}
|
|
|
|
|
|
+class PersistentMCPService:
|
|
|
+ """
|
|
|
+ 持久化MCP服务类,用于管理长期运行的MCP服务连接
|
|
|
+ """
|
|
|
+ def __init__(self, server_name: str, command: str, args: List[str]):
|
|
|
+ self.server_name = server_name
|
|
|
+ self.command = command
|
|
|
+ self.args = args
|
|
|
+ self.client_context = None
|
|
|
+ self.session = None
|
|
|
+
|
|
|
+ async def start(self):
|
|
|
+ """
|
|
|
+ 启动MCP服务并建立持久化连接
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ print(f"正在启动MCP服务: {self.server_name}")
|
|
|
+
|
|
|
+ # 创建stdio客户端连接
|
|
|
+ stdio_params = StdioServerParameters(
|
|
|
+ command=self.command,
|
|
|
+ args=self.args,
|
|
|
+ env=None
|
|
|
+ )
|
|
|
+
|
|
|
+ # 使用上下文管理器创建持久化连接
|
|
|
+ self.client_context = stdio_client(stdio_params)
|
|
|
+ read, write = await self.client_context.__aenter__()
|
|
|
+
|
|
|
+ # 创建会话
|
|
|
+ self.session = ClientSession(read, write)
|
|
|
+ await self.session.__aenter__()
|
|
|
+
|
|
|
+ # 初始化会话
|
|
|
+ await self.session.initialize()
|
|
|
+
|
|
|
+ print(f"✓ MCP服务 {self.server_name} 启动成功")
|
|
|
+ return True
|
|
|
+ except Exception as e:
|
|
|
+ print(f"✗ 启动MCP服务 {self.server_name} 失败: {e}")
|
|
|
+ await self.stop()
|
|
|
+ return False
|
|
|
+
|
|
|
+ async def stop(self):
|
|
|
+ """
|
|
|
+ 停止MCP服务
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ if self.session:
|
|
|
+ await self.session.__aexit__(None, None, None)
|
|
|
+ self.session = None
|
|
|
+
|
|
|
+ if self.client_context:
|
|
|
+ await self.client_context.__aexit__(None, None, None)
|
|
|
+ self.client_context = None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"停止MCP服务 {self.server_name} 时出错: {e}")
|
|
|
+
|
|
|
+ async def call_method(self, method: str, params: Optional[Dict] = None):
|
|
|
+ """
|
|
|
+ 调用MCP方法
|
|
|
+ """
|
|
|
+ if not self.session:
|
|
|
+ return {"error": f"MCP服务 {self.server_name} 未连接"}
|
|
|
+
|
|
|
+ try:
|
|
|
+ if method == "prompts/list":
|
|
|
+ result = await self.session.list_prompts()
|
|
|
+ elif method == "prompts/get" and params:
|
|
|
+ result = await self.session.get_prompt(params["name"])
|
|
|
+ elif method == "resources/list":
|
|
|
+ result = await self.session.list_resources()
|
|
|
+ elif method == "resources/read" and params:
|
|
|
+ result = await self.session.read_resource(params["uri"])
|
|
|
+ elif method == "tools/list":
|
|
|
+ result = await self.session.list_tools()
|
|
|
+ elif method == "tools/call" and params:
|
|
|
+ result = await self.session.call_tool(params["name"], params.get("arguments", {}))
|
|
|
+ else:
|
|
|
+ # 通用方法调用
|
|
|
+ result = await self.session.send_request(method, params or {})
|
|
|
+
|
|
|
+ return self._serialize_result(result)
|
|
|
+ except Exception as e:
|
|
|
+ return {"error": f"调用MCP服务 {self.server_name} 时出错: {str(e)}"}
|
|
|
+
|
|
|
+ def _serialize_result(self, result):
|
|
|
+ """
|
|
|
+ 将MCP结果转换为可JSON序列化的格式
|
|
|
+ """
|
|
|
+ from pydantic import BaseModel
|
|
|
+ if isinstance(result, BaseModel):
|
|
|
+ return result.model_dump()
|
|
|
+ elif isinstance(result, dict):
|
|
|
+ return {key: self._serialize_result(value) for key, value in result.items()}
|
|
|
+ elif isinstance(result, list):
|
|
|
+ return [self._serialize_result(item) for item in result]
|
|
|
+ else:
|
|
|
+ return result
|
|
|
|
|
|
class MCPServerManager:
|
|
|
def __init__(self, config_file="mcp_config.json"):
|
|
|
# 从配置文件加载MCP服务配置
|
|
|
self.mcp_servers = self._load_config(config_file)
|
|
|
+ # 存储持久化的服务实例
|
|
|
+ self.persistent_services: Dict[str, PersistentMCPService] = {}
|
|
|
# 缓存每个服务的工具列表
|
|
|
self._tools_cache = {}
|
|
|
- # 存储活动的会话
|
|
|
- self._active_sessions = {}
|
|
|
|
|
|
def _load_config(self, config_file):
|
|
|
"""
|
|
@@ -50,54 +147,32 @@ class MCPServerManager:
|
|
|
# 返回空配置
|
|
|
return {}
|
|
|
|
|
|
- async def _test_server_command(self, server_name):
|
|
|
+ async def start_all_services(self):
|
|
|
"""
|
|
|
- 测试服务器命令是否可用
|
|
|
-
|
|
|
- Args:
|
|
|
- server_name (str): 服务器名称
|
|
|
-
|
|
|
- Returns:
|
|
|
- bool: 命令是否可用
|
|
|
+ 启动所有MCP服务
|
|
|
"""
|
|
|
- if server_name not in self.mcp_servers:
|
|
|
- return False
|
|
|
-
|
|
|
- server_config = self.mcp_servers[server_name]
|
|
|
-
|
|
|
- try:
|
|
|
- # 检查命令是否存在
|
|
|
- process = subprocess.Popen([server_config["command"]] + server_config["args"],
|
|
|
- stdout=subprocess.PIPE,
|
|
|
- stderr=subprocess.PIPE,
|
|
|
- shell=(sys.platform == 'win32'))
|
|
|
- # 给进程一点时间启动
|
|
|
- await asyncio.sleep(0.5)
|
|
|
+ print("正在启动所有MCP服务...")
|
|
|
+ for server_name, server_config in self.mcp_servers.items():
|
|
|
+ service = PersistentMCPService(
|
|
|
+ server_name,
|
|
|
+ server_config["command"],
|
|
|
+ server_config["args"]
|
|
|
+ )
|
|
|
|
|
|
- # 检查进程是否仍在运行
|
|
|
- if process.poll() is None:
|
|
|
- # 进程仍在运行,终止它
|
|
|
- process.terminate()
|
|
|
- try:
|
|
|
- process.wait(timeout=1)
|
|
|
- except subprocess.TimeoutExpired:
|
|
|
- process.kill()
|
|
|
- return True
|
|
|
+ success = await service.start()
|
|
|
+ if success:
|
|
|
+ self.persistent_services[server_name] = service
|
|
|
else:
|
|
|
- # 进程已退出,检查返回码
|
|
|
- _, stderr = process.communicate()
|
|
|
- if process.returncode == 0 or process.returncode == 1:
|
|
|
- # 返回码为0(成功)或1(一般错误)表示命令存在
|
|
|
- return True
|
|
|
- else:
|
|
|
- print(f"命令执行失败 {server_name}: {stderr.decode()}")
|
|
|
- return False
|
|
|
- except FileNotFoundError:
|
|
|
- print(f"命令未找到: {server_config['command']}")
|
|
|
- return False
|
|
|
- except Exception as e:
|
|
|
- print(f"测试命令时出错 {server_name}: {e}")
|
|
|
- return False
|
|
|
+ print(f"✗ MCP服务 {server_name} 启动失败")
|
|
|
+
|
|
|
+ async def stop_all_services(self):
|
|
|
+ """
|
|
|
+ 停止所有MCP服务
|
|
|
+ """
|
|
|
+ print("正在停止所有MCP服务...")
|
|
|
+ for server_name, service in self.persistent_services.items():
|
|
|
+ await service.stop()
|
|
|
+ self.persistent_services.clear()
|
|
|
|
|
|
async def call_mcp_server(self, server_name, method, params=None):
|
|
|
"""
|
|
@@ -111,63 +186,11 @@ class MCPServerManager:
|
|
|
Returns:
|
|
|
dict: MCP服务器的响应
|
|
|
"""
|
|
|
- if server_name not in self.mcp_servers:
|
|
|
- return {"error": f"MCP server '{server_name}' not found"}
|
|
|
+ if server_name not in self.persistent_services:
|
|
|
+ return {"error": f"MCP服务 '{server_name}' 未启动或启动失败"}
|
|
|
|
|
|
- # 检查命令是否可用
|
|
|
- if not await self._test_server_command(server_name):
|
|
|
- return {"error": f"MCP server '{server_name}' 命令不可用"}
|
|
|
-
|
|
|
- try:
|
|
|
- # 创建临时会话(每次调用都创建新会话)
|
|
|
- server_config = self.mcp_servers[server_name]
|
|
|
-
|
|
|
- async with stdio_client(
|
|
|
- StdioServerParameters(
|
|
|
- command=server_config["command"],
|
|
|
- args=server_config["args"],
|
|
|
- env=None
|
|
|
- )
|
|
|
- ) as (read, write):
|
|
|
- async with ClientSession(read, write) as session:
|
|
|
- # 初始化MCP服务器
|
|
|
- await session.initialize()
|
|
|
-
|
|
|
- # 调用指定方法
|
|
|
- if method == "prompts/list":
|
|
|
- result = await session.list_prompts()
|
|
|
- elif method == "prompts/get" and params:
|
|
|
- result = await session.get_prompt(params["name"])
|
|
|
- elif method == "resources/list":
|
|
|
- result = await session.list_resources()
|
|
|
- elif method == "resources/read" and params:
|
|
|
- result = await session.read_resource(params["uri"])
|
|
|
- elif method == "tools/list":
|
|
|
- result = await session.list_tools()
|
|
|
- elif method == "tools/call" and params:
|
|
|
- result = await session.call_tool(params["name"], params.get("arguments", {}))
|
|
|
- else:
|
|
|
- # 通用方法调用
|
|
|
- result = await session.send_request(method, params or {})
|
|
|
-
|
|
|
- # 将结果转换为可序列化的字典
|
|
|
- return self._serialize_result(result)
|
|
|
- except Exception as e:
|
|
|
- return {"error": f"调用MCP服务时出错: {str(e)}"}
|
|
|
-
|
|
|
- def _serialize_result(self, result):
|
|
|
- """
|
|
|
- 将MCP结果转换为可JSON序列化的格式
|
|
|
- """
|
|
|
- from pydantic import BaseModel
|
|
|
- if isinstance(result, BaseModel):
|
|
|
- return result.model_dump()
|
|
|
- elif isinstance(result, dict):
|
|
|
- return {key: self._serialize_result(value) for key, value in result.items()}
|
|
|
- elif isinstance(result, list):
|
|
|
- return [self._serialize_result(item) for item in result]
|
|
|
- else:
|
|
|
- return result
|
|
|
+ service = self.persistent_services[server_name]
|
|
|
+ return await service.call_method(method, params)
|
|
|
|
|
|
async def get_actual_mcp_tools(self, server_name):
|
|
|
"""
|
|
@@ -269,21 +292,26 @@ class MCPServerManager:
|
|
|
return server_name
|
|
|
return None
|
|
|
|
|
|
- async def initialize_all_servers(self):
|
|
|
- """
|
|
|
- 初始化所有MCP服务器并获取工具列表
|
|
|
- """
|
|
|
- print("正在测试MCP服务器连接...")
|
|
|
- for server_name in self.mcp_servers.keys():
|
|
|
- if await self._test_server_command(server_name):
|
|
|
- print(f"✓ {server_name} 命令可用")
|
|
|
- else:
|
|
|
- print(f"✗ {server_name} 命令不可用")
|
|
|
-
|
|
|
+# 创建全局MCP服务器管理器实例
|
|
|
+mcp_manager = MCPServerManager()
|
|
|
|
|
|
+@asynccontextmanager
|
|
|
+async def lifespan(app: FastAPI):
|
|
|
+ """
|
|
|
+ 应用生命周期管理器:启动时启动所有MCP服务,关闭时停止所有服务
|
|
|
+ """
|
|
|
+ # 启动事件
|
|
|
+ await mcp_manager.start_all_services()
|
|
|
+
|
|
|
+ yield
|
|
|
+
|
|
|
+ # 关闭事件
|
|
|
+ await mcp_manager.stop_all_services()
|
|
|
|
|
|
# 创建FastAPI应用
|
|
|
-app = FastAPI(title="MCP API Server", description="MCP服务API接口")
|
|
|
+app = FastAPI(
|
|
|
+ lifespan=lifespan
|
|
|
+)
|
|
|
|
|
|
# 添加CORS中间件
|
|
|
app.add_middleware(
|
|
@@ -294,15 +322,6 @@ app.add_middleware(
|
|
|
allow_headers=["*"],
|
|
|
)
|
|
|
|
|
|
-# 创建全局MCP服务器管理器实例
|
|
|
-mcp_manager = MCPServerManager()
|
|
|
-
|
|
|
-
|
|
|
-@app.get("/")
|
|
|
-async def root():
|
|
|
- return {"message": "MCP API Server is running"}
|
|
|
-
|
|
|
-
|
|
|
@app.get("/tools", response_model=List[ToolModel])
|
|
|
async def get_tools():
|
|
|
"""
|
|
@@ -314,7 +333,6 @@ async def get_tools():
|
|
|
except Exception as e:
|
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
-
|
|
|
@app.post("/call")
|
|
|
async def call_tool(request: CallRequest):
|
|
|
"""
|
|
@@ -331,14 +349,6 @@ async def call_tool(request: CallRequest):
|
|
|
except Exception as e:
|
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
-
|
|
|
if __name__ == "__main__":
|
|
|
- # 初始化所有服务器
|
|
|
- print("正在初始化MCP服务器...")
|
|
|
- loop = asyncio.get_event_loop()
|
|
|
- loop.run_until_complete(mcp_manager.initialize_all_servers())
|
|
|
- print("MCP服务器初始化完成")
|
|
|
-
|
|
|
# 启动FastAPI服务器
|
|
|
- print("启动MCP API服务器: http://localhost:8000")
|
|
|
uvicorn.run("mcp_api_server:app", host="localhost", port=8000, reload=False)
|