# mgmt/qdrant_mcp_server.py
#
# (Source-viewer metadata: 380 lines, 13 KiB, Python; viewer links: Raw | Blame | History.)
#
# Note from the source viewer: this file contains ambiguous Unicode characters,
# i.e. characters that might be confused with other characters. If this is
# intentional, the warning can safely be ignored (the viewer's "Escape" button
# reveals them).

#!/usr/bin/env python3
"""
Qdrant MCP 服务器
此脚本实现了一个 MCP 服务器,与 Qdrant 向量数据库集成
"""
import asyncio
import hashlib
import json
import logging
import os
import sys
import uuid
from typing import Any, Dict, List, Optional

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, Filter, PointStruct, VectorParams
# Logging setup. Records go to stderr (basicConfig's default stream, spelled
# out here) because stdout carries the JSON-RPC responses written by main().
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
logger = logging.getLogger(__name__)
class QdrantMCPServer:
    """MCP (Model Context Protocol) server exposing a Qdrant collection as tools.

    Configuration is read from environment variables in ``__init__``:
    ``QDRANT_URL`` (default ``http://localhost:6333``), ``QDRANT_API_KEY``
    (empty means "no key"), ``COLLECTION_NAME`` (default ``mcp``) and
    ``EMBEDDING_MODEL`` (default ``bge-m3``; currently only logged).

    Every handler returns a dict holding either a ``"result"`` or an
    ``"error"`` member; ``handle_request`` wraps it in the JSON-RPC 2.0
    envelope.
    """

    # Dimensionality of the collection's vectors and of the vectors produced
    # by _embed(). Must match the embedding model once a real one is used.
    VECTOR_SIZE = 1024

    def __init__(self):
        """Read configuration, connect to Qdrant and ensure the collection exists.

        Raises:
            Exception: errors from the Qdrant client while listing or creating
                the collection are logged and re-raised.
        """
        self.qdrant_url = os.getenv("QDRANT_URL", "http://localhost:6333")
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY", "")
        self.collection_name = os.getenv("COLLECTION_NAME", "mcp")
        # Only logged for now: _embed() does not call a real model yet.
        self.embedding_model = os.getenv("EMBEDDING_MODEL", "bge-m3")

        # An empty API key is passed as None, i.e. no authentication.
        self.client = QdrantClient(
            url=self.qdrant_url,
            api_key=self.qdrant_api_key or None,
        )
        self._ensure_collection_exists()

        logger.info("Qdrant MCP 服务器已初始化")
        logger.info("Qdrant URL: %s", self.qdrant_url)
        logger.info("集合名称: %s", self.collection_name)
        logger.info("嵌入模型: %s", self.embedding_model)

    def _ensure_collection_exists(self):
        """Create the configured collection (cosine distance) if it is missing.

        Raises:
            Exception: any Qdrant client error is logged and re-raised.
        """
        try:
            existing = {c.name for c in self.client.get_collections().collections}
            if self.collection_name in existing:
                logger.info("集合已存在: %s", self.collection_name)
                return
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.VECTOR_SIZE, distance=Distance.COSINE
                ),
            )
            logger.info("已创建新集合: %s", self.collection_name)
        except Exception as e:
            logger.exception("确保集合存在时出错: %s", e)
            raise

    # ------------------------------------------------------------------
    # Small builders / parsers shared by the handlers
    # ------------------------------------------------------------------

    @staticmethod
    def _error(code: int, message: str) -> Dict[str, Any]:
        """Build a JSON-RPC error member: ``{"error": {"code", "message"}}``."""
        return {"error": {"code": code, "message": message}}

    @staticmethod
    def _text_result(text: str, is_error: bool = False) -> Dict[str, Any]:
        """Build a ``tools/call`` result with a single text content item.

        ``is_error=True`` sets MCP's ``isError`` flag. A failure while the tool
        runs (e.g. Qdrant rejecting a request) is reported this way, so the
        client sees it as a tool result rather than a protocol error.
        """
        result: Dict[str, Any] = {"content": [{"type": "text", "text": text}]}
        if is_error:
            result["isError"] = True
        return {"result": result}

    def _embed(self, text: str) -> List[float]:
        """Return the vector for *text*.

        NOTE: placeholder. No embedding model is called; every input maps to
        the same constant vector, so with cosine distance all stored points
        score identically. Replace with a call to ``self.embedding_model``.
        """
        return [0.1] * self.VECTOR_SIZE

    @staticmethod
    def _point_id_for_text(text: str) -> str:
        """Derive a deterministic point ID from *text*.

        The MD5 digest of the UTF-8 text is rendered as a canonical
        (hyphenated, lowercase) UUID string. It has the same 128-bit value as
        the bare 32-hex-digit form, so re-adding a text overwrites the same
        point, and the ID matches the form search results report.
        """
        digest = hashlib.md5(text.encode("utf-8")).hexdigest()
        return str(uuid.UUID(hex=digest))

    @staticmethod
    def _build_payload(text: str, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Merge caller metadata into a payload whose ``"text"`` is always *text*.

        A ``"text"`` key inside *metadata* is dropped so it cannot silently
        replace the stored document text.
        """
        extra = {key: value for key, value in metadata.items() if key != "text"}
        return {"text": text, **extra}

    @staticmethod
    def _parse_limit(raw: Any) -> Optional[int]:
        """Return *raw* as a positive int, or None if it is not one.

        Accepts ints, integral floats and decimal strings; rejects bools.
        """
        if isinstance(raw, bool):  # bool is an int subclass; True is not a limit
            return None
        if isinstance(raw, float) and raw.is_integer():
            raw = int(raw)
        if isinstance(raw, str):
            try:
                raw = int(raw.strip())
            except ValueError:
                return None
        if isinstance(raw, int) and raw > 0:
            return raw
        return None

    @staticmethod
    def _parse_point_id(raw: Any) -> Optional[Any]:
        """Normalise a point ID to a canonical UUID string or a non-negative int.

        Qdrant point IDs are either unsigned integers or UUIDs. Strings that
        parse as a UUID are canonicalised. Other strings consisting only of
        ASCII digits become ints. Anything else yields None.
        """
        if isinstance(raw, bool):
            return None
        if isinstance(raw, int):
            return raw if raw >= 0 else None
        if not isinstance(raw, str):
            return None
        candidate = raw.strip()
        try:
            return str(uuid.UUID(candidate))
        except ValueError:
            pass
        if candidate.isascii() and candidate.isdigit():
            return int(candidate)
        return None

    # ------------------------------------------------------------------
    # JSON-RPC dispatch
    # ------------------------------------------------------------------

    async def handle_request(self, request: Any) -> Dict[str, Any]:
        """Dispatch one decoded JSON-RPC message and wrap the JSON-RPC envelope.

        Args:
            request: the decoded message; expected to be a JSON object (dict).

        Returns:
            A dict with ``jsonrpc``, ``id`` and either ``result`` or ``error``.
            Non-object messages give -32600, non-object ``params`` give -32602,
            unknown methods give -32601, and exceptions raised by a handler
            give -32603.
        """
        if not isinstance(request, dict):
            # e.g. a JSON array (batch) or a scalar; its id cannot be known.
            return {
                "jsonrpc": "2.0",
                "id": None,
                **self._error(-32600, "无效请求: 请求必须是 JSON 对象"),
            }

        method = request.get("method")
        params = request.get("params")
        if params is None:  # absent or explicit null
            params = {}
        request_id = request.get("id")
        logger.info("收到请求: %s", method)

        try:
            if isinstance(params, dict):
                result = await self._dispatch(method, params)
            else:
                result = self._error(-32602, "参数无效: params 必须是对象")
        except Exception as e:
            logger.exception("处理请求时出错: %s", e)
            result = self._error(-32603, f"内部错误: {str(e)}")

        return {"jsonrpc": "2.0", "id": request_id, **result}

    async def _dispatch(self, method: Any, params: Dict[str, Any]) -> Dict[str, Any]:
        """Route *method* to its handler and return the handler's result/error dict."""
        handlers = {
            "initialize": self.initialize,
            "ping": self.ping,
            "tools/list": self.list_tools,
            "tools/call": self.call_tool,
            "resources/list": self.list_resources,
            "resources/read": self.read_resource,
        }
        handler = handlers.get(method) if isinstance(method, str) else None
        if handler is not None:
            return await handler(params)
        if isinstance(method, str) and method.startswith("notifications/"):
            # Notifications (e.g. notifications/initialized) need no handling
            # and expect no reply. Acknowledge them quietly instead of
            # reporting an unknown method; whether anything is written back is
            # up to the transport loop.
            return {"result": {}}
        return self._error(-32601, f"未知方法: {method}")

    async def initialize(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Answer the MCP ``initialize`` handshake.

        Reports protocol version 2024-11-05, tool and resource capabilities
        (no subscriptions, no list-changed notifications) and server info.
        """
        logger.info("初始化 Qdrant MCP 服务器")
        return {
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": {
                    "tools": {"listChanged": False},
                    "resources": {"subscribe": False, "listChanged": False},
                },
                "serverInfo": {"name": "qdrant-mcp-server", "version": "1.0.0"},
            }
        }

    async def ping(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Answer an MCP ``ping`` with the empty result the protocol expects."""
        return {"result": {}}

    async def list_tools(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """List the three tools (search / add / delete) with their input schemas."""
        return {
            "result": {
                "tools": [
                    {
                        "name": "qdrant_search",
                        "description": "在 Qdrant 中搜索相似向量",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "搜索查询文本",
                                },
                                "limit": {
                                    "type": "integer",
                                    "default": 5,
                                    "minimum": 1,
                                    "description": "返回结果数量限制",
                                },
                            },
                            "required": ["query"],
                        },
                    },
                    {
                        "name": "qdrant_add",
                        "description": "向 Qdrant 添加向量",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "text": {
                                    "type": "string",
                                    "description": "要添加的文本内容",
                                },
                                "metadata": {
                                    "type": "object",
                                    "description": "与文本关联的元数据",
                                },
                            },
                            "required": ["text"],
                        },
                    },
                    {
                        "name": "qdrant_delete",
                        "description": "从 Qdrant 删除向量",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "id": {
                                    "type": "string",
                                    "description": "要删除的向量ID",
                                },
                            },
                            "required": ["id"],
                        },
                    },
                ]
            }
        }

    async def call_tool(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Run the tool named ``params["name"]`` with ``params["arguments"]``.

        Unknown tools and malformed arguments are protocol errors (-32602).
        Failures while the tool runs are returned as results with ``isError``.
        """
        name = params.get("name")
        arguments = params.get("arguments")
        if arguments is None:
            arguments = {}
        if not isinstance(arguments, dict):
            return self._error(-32602, "参数无效: arguments 必须是对象")

        tools = {
            "qdrant_search": self._search_vectors,
            "qdrant_add": self._add_vector,
            "qdrant_delete": self._delete_vector,
        }
        tool = tools.get(name) if isinstance(name, str) else None
        if tool is None:
            return self._error(-32602, f"未知工具: {name}")
        return await tool(arguments)

    async def _search_vectors(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Search the collection for points nearest to the embedded ``query``.

        ``limit`` defaults to 5 and must be a positive integer.
        """
        query = params.get("query")
        if not isinstance(query, str):
            return self._error(-32602, "参数无效: query 必须是字符串")
        raw_limit = params.get("limit")
        limit = 5 if raw_limit is None else self._parse_limit(raw_limit)
        if limit is None:
            return self._error(-32602, "参数无效: limit 必须是正整数")

        try:
            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=self._embed(query),
                limit=limit,
            )
        except Exception as e:
            logger.exception("搜索向量时出错: %s", e)
            return self._text_result(f"搜索向量时出错: {str(e)}", is_error=True)

        results = [
            {"id": hit.id, "score": hit.score, "payload": hit.payload} for hit in hits
        ]
        return self._text_result(
            f"搜索结果: {json.dumps(results, ensure_ascii=False)}"
        )

    async def _add_vector(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Upsert ``text`` (plus optional ``metadata``) as one point.

        The point ID is derived from the text, so adding the same text again
        overwrites the existing point instead of duplicating it.
        """
        text = params.get("text")
        if not isinstance(text, str) or not text:
            return self._error(-32602, "参数无效: text 必须是非空字符串")
        metadata = params.get("metadata")
        if metadata is None:
            metadata = {}
        if not isinstance(metadata, dict):
            return self._error(-32602, "参数无效: metadata 必须是对象")

        point_id = self._point_id_for_text(text)
        try:
            self.client.upsert(
                collection_name=self.collection_name,
                points=[
                    PointStruct(
                        id=point_id,
                        vector=self._embed(text),
                        payload=self._build_payload(text, metadata),
                    )
                ],
            )
        except Exception as e:
            logger.exception("添加向量时出错: %s", e)
            return self._text_result(f"添加向量时出错: {str(e)}", is_error=True)

        return self._text_result(f"已添加向量ID: {point_id}")

    async def _delete_vector(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Delete the point whose ID is ``params["id"]`` (a UUID or an unsigned int)."""
        point_id = self._parse_point_id(params.get("id"))
        if point_id is None:
            return self._error(-32602, "参数无效: id 必须是 UUID 或非负整数")

        try:
            self.client.delete(
                collection_name=self.collection_name,
                points_selector=[point_id],
            )
        except Exception as e:
            logger.exception("删除向量时出错: %s", e)
            return self._text_result(f"删除向量时出错: {str(e)}", is_error=True)

        return self._text_result(f"已删除向量ID: {point_id}")

    async def list_resources(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return an empty resource list; this server exposes no resources."""
        return {"result": {"resources": []}}

    async def read_resource(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Reject every resource read (-32601); no resources are exposed."""
        return self._error(-32601, "不支持读取资源")
async def main(*, server=None, stdin=None, stdout=None):
    """Serve MCP over stdio: one JSON-RPC message per input line, one per output line.

    Behaviour:
      * Blank input lines are skipped.
      * A JSON object without an ``"id"`` member is a JSON-RPC notification;
        it is still passed to the server, but no response is written.
      * Undecodable lines produce a -32700 (parse error) response with id null.
      * An exception from ``server.handle_request`` (or from serialising its
        response) produces a -32603 response that keeps the request's id when
        it is known.

    All parameters are optional and keyword-only; calling ``main()`` with no
    arguments serves a new ``QdrantMCPServer`` on the process's stdin/stdout.

    Args:
        server: object exposing ``async handle_request(request) -> dict``.
            Defaults to a newly constructed ``QdrantMCPServer``.
        stdin: text stream to read messages from. Defaults to ``sys.stdin``.
        stdout: text stream to write responses to. Defaults to ``sys.stdout``.
    """
    # Same named logger as the module-level one (both use __name__).
    log = logging.getLogger(__name__)

    if server is None:
        server = QdrantMCPServer()

    # The MCP stdio transport is UTF-8. Do not rely on the locale's default
    # encoding (which may differ, e.g. on Windows) for the real std streams.
    if stdin is None:
        stdin = sys.stdin
        if hasattr(stdin, "reconfigure"):
            stdin.reconfigure(encoding="utf-8")
    if stdout is None:
        stdout = sys.stdout
        if hasattr(stdout, "reconfigure"):
            stdout.reconfigure(encoding="utf-8")

    def emit(serialized):
        """Write one serialized JSON message as a line and flush immediately."""
        stdout.write(serialized + "\n")
        stdout.flush()

    for line in stdin:
        if not line.strip():
            continue  # tolerate blank lines instead of reporting parse errors

        try:
            request = json.loads(line)
        except json.JSONDecodeError as e:
            log.error("解析 JSON 时出错: %s", e)
            emit(json.dumps(
                {
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32700,
                        "message": f"解析 JSON 时出错: {str(e)}",
                    },
                },
                ensure_ascii=False,
            ))
            continue

        is_object = isinstance(request, dict)
        # JSON-RPC 2.0: a request object without "id" is a notification and
        # must not be answered.
        is_notification = is_object and "id" not in request
        request_id = request.get("id") if is_object else None

        try:
            response = await server.handle_request(request)
            serialized = json.dumps(response, ensure_ascii=False)
        except Exception as e:
            log.exception("处理请求时出错: %s", e)
            serialized = json.dumps(
                {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {"code": -32603, "message": f"内部错误: {str(e)}"},
                },
                ensure_ascii=False,
            )

        if not is_notification:
            emit(serialized)
def _run() -> None:
    """Script entry point: drive the async stdio loop in main() to completion."""
    asyncio.run(main())


if __name__ == "__main__":
    _run()