#!/usr/bin/env python3
"""
Qdrant MCP 服务器

此脚本实现了一个通过标准输入/输出通信的 MCP 服务器，与 Qdrant 向量数据库集成。
"""
import asyncio
import hashlib
import json
import logging
import os
import sys
import uuid
from typing import Any, Dict, List, Optional

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter
# Module logger. Output goes to stderr because stdout is reserved for the
# JSON-RPC messages exchanged with the MCP client.
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
class QdrantMCPServer:
    """MCP server exposing a Qdrant collection through three tools.

    Tools: ``qdrant_search``, ``qdrant_add`` and ``qdrant_delete``.
    Configuration is read from environment variables:

    - ``QDRANT_URL``       (default ``http://localhost:6333``)
    - ``QDRANT_API_KEY``   (empty or unset means no API key)
    - ``COLLECTION_NAME``  (default ``mcp``)
    - ``EMBEDDING_MODEL``  (default ``bge-m3``; currently only logged)
    - ``VECTOR_SIZE``      (default 1024; dimension of the collection created
      on startup and of the placeholder vectors)

    Every handler returns a dict holding either a ``"result"`` or an
    ``"error"`` member. ``handle_request`` merges it into a JSON-RPC 2.0
    response.

    NOTE: no embedding model is wired in yet. Every text maps to the same
    constant placeholder vector, so search scores carry no meaning until
    ``_placeholder_embedding`` is replaced with a real embedding.
    """

    # Vector dimension used when VECTOR_SIZE is not set (the original value).
    DEFAULT_VECTOR_SIZE = 1024

    def __init__(self):
        # Read configuration from the environment.
        self.qdrant_url = os.getenv("QDRANT_URL", "http://localhost:6333")
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY", "")
        self.collection_name = os.getenv("COLLECTION_NAME", "mcp")
        self.embedding_model = os.getenv("EMBEDDING_MODEL", "bge-m3")
        self.vector_size = int(
            os.getenv("VECTOR_SIZE", str(self.DEFAULT_VECTOR_SIZE))
        )

        # An empty API key means "connect without authentication".
        self.client = QdrantClient(
            url=self.qdrant_url,
            api_key=self.qdrant_api_key or None,
        )

        self._ensure_collection_exists()

        logger.info("Qdrant MCP 服务器已初始化")
        logger.info(f"Qdrant URL: {self.qdrant_url}")
        logger.info(f"集合名称: {self.collection_name}")
        logger.info(f"嵌入模型: {self.embedding_model}")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _error(code: int, message: str) -> Dict[str, Any]:
        """Build the ``{"error": {...}}`` fragment of a JSON-RPC response."""
        return {"error": {"code": code, "message": message}}

    @staticmethod
    def _text_result(text: str) -> Dict[str, Any]:
        """Wrap *text* as a single text item of an MCP tool-call result."""
        return {"result": {"content": [{"type": "text", "text": text}]}}

    @staticmethod
    def _parse_limit(value: Any, default: int = 5) -> int:
        """Validate the ``limit`` argument of ``qdrant_search``.

        Returns *default* when *value* is None. Accepts integers, integral
        floats (``2.0``) and decimal strings (``"3"``).

        Raises:
            ValueError: for anything else, including booleans and values < 1.
        """
        if value is None:
            return default
        limit: Optional[int] = None
        if isinstance(value, bool):
            limit = None  # bool is an int subclass, but True/False is not a count
        elif isinstance(value, int):
            limit = value
        elif isinstance(value, float) and value.is_integer():
            limit = int(value)
        elif isinstance(value, str) and value.strip().isdecimal():
            limit = int(value)
        if limit is None or limit < 1:
            raise ValueError(f"参数 limit 必须是正整数: {value!r}")
        return limit

    @staticmethod
    def _point_id(text: str) -> str:
        """Derive a deterministic Qdrant point ID from *text*.

        The MD5 digest of the UTF-8 text is formatted as a hyphenated UUID.
        Qdrant accepts only unsigned integers or UUIDs as point IDs, and this
        canonical form matches the ID format returned by search. Identical
        text always yields the same ID, so re-adding a text overwrites its
        point.
        """
        digest = hashlib.md5(text.encode("utf-8")).hexdigest()
        return str(uuid.UUID(hex=digest))

    @staticmethod
    def _build_payload(text: str, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Return the point payload: ``text`` first, then the caller metadata.

        A ``"text"`` key inside *metadata* is ignored so caller-supplied
        metadata can never overwrite the stored text.
        """
        extra = {key: value for key, value in metadata.items() if key != "text"}
        return {"text": text, **extra}

    def _placeholder_embedding(self, text: str) -> List[float]:
        """Return a stand-in vector for *text*.

        TODO: embed *text* with ``self.embedding_model``. Until then every
        input maps to the same constant vector of length ``self.vector_size``.
        """
        return [0.1] * self.vector_size

    def _ensure_collection_exists(self):
        """Create the configured collection if it does not exist yet.

        A new collection uses ``self.vector_size``-dimensional vectors with
        cosine distance. An existing collection is left as is; its vector size
        is not checked. Errors are logged and re-raised.
        """
        try:
            existing = {c.name for c in self.client.get_collections().collections}
            if self.collection_name in existing:
                logger.info(f"集合已存在: {self.collection_name}")
                return
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.vector_size, distance=Distance.COSINE
                ),
            )
            logger.info(f"已创建新集合: {self.collection_name}")
        except Exception as e:
            logger.error(f"确保集合存在时出错: {e}")
            raise

    # ----------------------------------------------------------- JSON-RPC entry

    async def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch one JSON-RPC request and build its response.

        Args:
            request: Decoded JSON-RPC message with ``method`` and optional
                ``params`` and ``id``.

        Returns:
            A JSON-RPC 2.0 response dict carrying the request ``id`` and
            either ``result`` or ``error``. Exceptions raised by a handler are
            converted to a -32603 internal error instead of propagating.
        """
        method = request.get("method")
        # "params" may be omitted or explicitly null; both mean "no params".
        params = request.get("params")
        if params is None:
            params = {}
        request_id = request.get("id")

        logger.info(f"收到请求: {method}")

        handlers = {
            "initialize": self.initialize,
            "ping": self.ping,
            "tools/list": self.list_tools,
            "tools/call": self.call_tool,
            "resources/list": self.list_resources,
            "resources/read": self.read_resource,
        }
        # Guard the lookup: a non-string method (e.g. a list) is unhashable.
        handler = handlers.get(method) if isinstance(method, str) else None

        try:
            if handler is None:
                result = self._error(-32601, f"未知方法: {method}")
            elif not isinstance(params, dict):
                result = self._error(-32602, "无效参数: params 必须是对象")
            else:
                result = await handler(params)
        except Exception as e:
            logger.exception(f"处理请求时出错: {e}")
            result = self._error(-32603, f"内部错误: {str(e)}")

        return {"jsonrpc": "2.0", "id": request_id, **result}

    async def initialize(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Answer the MCP ``initialize`` handshake with version and capabilities."""
        logger.info("初始化 Qdrant MCP 服务器")

        return {
            "result": {
                "protocolVersion": "2024-11-05",
                "capabilities": {
                    "tools": {"listChanged": False},
                    "resources": {"subscribe": False, "listChanged": False},
                },
                "serverInfo": {"name": "qdrant-mcp-server", "version": "1.0.0"},
            }
        }

    async def ping(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Answer an MCP ``ping`` with an empty result."""
        return {"result": {}}

    async def list_tools(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Describe the three tools this server exposes and their input schemas."""
        return {
            "result": {
                "tools": [
                    {
                        "name": "qdrant_search",
                        "description": "在 Qdrant 中搜索相似向量",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "搜索查询文本",
                                },
                                "limit": {
                                    "type": "integer",
                                    "default": 5,
                                    "description": "返回结果数量限制",
                                },
                            },
                            "required": ["query"],
                        },
                    },
                    {
                        "name": "qdrant_add",
                        "description": "向 Qdrant 添加向量",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "text": {
                                    "type": "string",
                                    "description": "要添加的文本内容",
                                },
                                "metadata": {
                                    "type": "object",
                                    "description": "与文本关联的元数据",
                                },
                            },
                            "required": ["text"],
                        },
                    },
                    {
                        "name": "qdrant_delete",
                        "description": "从 Qdrant 删除向量",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "id": {
                                    "type": "string",
                                    "description": "要删除的向量ID",
                                }
                            },
                            "required": ["id"],
                        },
                    },
                ]
            }
        }

    async def call_tool(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Route a ``tools/call`` request to the named tool implementation.

        Unknown tool names and non-object ``arguments`` are reported as
        -32602 (invalid params), as the MCP specification does for unknown
        tools.
        """
        name = params.get("name")
        # "arguments" may be omitted or explicitly null.
        arguments = params.get("arguments")
        if arguments is None:
            arguments = {}

        tools = {
            "qdrant_search": self._search_vectors,
            "qdrant_add": self._add_vector,
            "qdrant_delete": self._delete_vector,
        }
        tool = tools.get(name) if isinstance(name, str) else None
        if tool is None:
            return self._error(-32602, f"未知工具: {name}")
        if not isinstance(arguments, dict):
            return self._error(-32602, "无效参数: arguments 必须是对象")
        return await tool(arguments)

    # -------------------------------------------------------------------- tools

    async def _search_vectors(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Run a similarity search and return the hits as a JSON text item.

        Args:
            params: ``query`` (required, non-empty string) and optional
                ``limit`` (positive integer, default 5).
        """
        query = params.get("query")
        if not isinstance(query, str) or not query:
            return self._error(-32602, "参数 query 必须是非空字符串")
        try:
            limit = self._parse_limit(params.get("limit"))
        except ValueError as e:
            return self._error(-32602, str(e))

        query_vector = self._placeholder_embedding(query)

        try:
            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                limit=limit,
            )
        except Exception as e:
            logger.error(f"搜索向量时出错: {e}")
            return self._error(-32603, f"搜索向量时出错: {str(e)}")

        results = [
            {"id": hit.id, "score": hit.score, "payload": hit.payload}
            for hit in hits
        ]
        return self._text_result(
            f"搜索结果: {json.dumps(results, ensure_ascii=False)}"
        )

    async def _add_vector(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Upsert *text* (plus optional metadata) as a point and report its ID.

        Args:
            params: ``text`` (required, non-empty string) and optional
                ``metadata`` object stored alongside the text in the payload.
        """
        text = params.get("text")
        if not isinstance(text, str) or not text:
            return self._error(-32602, "参数 text 必须是非空字符串")
        metadata = params.get("metadata")
        if metadata is None:
            metadata = {}
        if not isinstance(metadata, dict):
            return self._error(-32602, "参数 metadata 必须是对象")

        vector_id = self._point_id(text)
        vector = self._placeholder_embedding(text)

        try:
            self.client.upsert(
                collection_name=self.collection_name,
                points=[
                    PointStruct(
                        id=vector_id,
                        vector=vector,
                        payload=self._build_payload(text, metadata),
                    )
                ],
            )
        except Exception as e:
            logger.error(f"添加向量时出错: {e}")
            return self._error(-32603, f"添加向量时出错: {str(e)}")

        return self._text_result(f"已添加向量,ID: {vector_id}")

    async def _delete_vector(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Delete the point whose ID is ``params["id"]`` (string or integer)."""
        vector_id = params.get("id")
        is_valid_id = (isinstance(vector_id, str) and vector_id != "") or (
            isinstance(vector_id, int) and not isinstance(vector_id, bool)
        )
        if not is_valid_id:
            return self._error(-32602, "参数 id 必须是非空字符串或整数")

        try:
            self.client.delete(
                collection_name=self.collection_name,
                points_selector=[vector_id],
            )
        except Exception as e:
            logger.error(f"删除向量时出错: {e}")
            return self._error(-32603, f"删除向量时出错: {str(e)}")

        return self._text_result(f"已删除向量,ID: {vector_id}")

    # ---------------------------------------------------------------- resources

    async def list_resources(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return an empty resource list; this server exposes no resources."""
        return {"result": {"resources": []}}

    async def read_resource(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Reject every ``resources/read`` call; reading is unsupported."""
        return self._error(-32601, "不支持读取资源")
async def main():
    """Serve MCP JSON-RPC messages read line by line from stdin.

    Each non-blank stdin line is parsed as one JSON-RPC message and passed to
    ``QdrantMCPServer.handle_request``. Each reply is written to stdout as a
    single JSON line. Messages without an ``"id"`` member are JSON-RPC
    notifications (e.g. MCP's ``notifications/initialized``). They are still
    dispatched, but no reply is sent. Returns when stdin reaches EOF.
    """
    server = QdrantMCPServer()

    def send(message: Dict[str, Any]) -> None:
        # One JSON object per line, flushed so the client sees it immediately.
        print(json.dumps(message, ensure_ascii=False), flush=True)

    def error_reply(request_id: Any, code: int, message: str) -> Dict[str, Any]:
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {"code": code, "message": message},
        }

    # NOTE: iterating sys.stdin blocks the event loop between messages; that
    # is acceptable here because requests are handled strictly one at a time.
    for line in sys.stdin:
        # Blank lines (e.g. stray newlines from the client) are not messages.
        if not line.strip():
            continue

        try:
            request = json.loads(line)
        except json.JSONDecodeError as e:
            logger.error(f"解析 JSON 时出错: {e}")
            send(error_reply(None, -32700, f"解析 JSON 时出错: {str(e)}"))
            continue

        # Batch (array) and scalar messages are not supported.
        if not isinstance(request, dict):
            logger.error(f"无效请求: {line.strip()[:200]}")
            send(error_reply(None, -32600, "无效请求: 消息必须是 JSON 对象"))
            continue

        try:
            response = await server.handle_request(request)
        except Exception as e:
            logger.exception(f"处理请求时出错: {e}")
            response = error_reply(request.get("id"), -32603, f"内部错误: {str(e)}")

        # JSON-RPC 2.0: the server MUST NOT reply to a notification.
        if "id" not in request:
            continue
        send(response)
if __name__ == "__main__":
    # Script entry point: run the stdio server loop until stdin is closed.
    asyncio.run(main())