mgmt/scripts/mcp/servers/qdrant-ollama-mcp-server.py

357 lines
13 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Qdrant 与 Ollama 嵌入模型集成的 MCP 服务器
此脚本实现了一个 MCP 服务器,使用 Ollama 作为嵌入模型提供者与 Qdrant 向量数据库集成
"""
import asyncio
import json
import logging
import os
import sys
import uuid
from typing import Any, Dict, List, Optional

from langchain_ollama import OllamaEmbeddings
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter
# Set up logging: configure the root logger at INFO level as soon as the
# module is imported, and create the module-level logger used by the code
# in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class QdrantOllamaMCPServer:
    """MCP server exposing a Qdrant collection with Ollama-generated embeddings.

    Requests are newline-delimited JSON-RPC 2.0 messages read from stdin and
    every response is written to stdout as a single JSON line. Because stdout
    carries the protocol stream, diagnostics are emitted through the module
    ``logger`` and never with ``print()``.

    Configuration comes from environment variables, each with a fallback:
    ``QDRANT_URL``, ``QDRANT_API_KEY``, ``COLLECTION_NAME``, ``OLLAMA_MODEL``
    and ``OLLAMA_URL``.
    """

    # Environment variables reported at start-up, in display order.
    _ENV_VARS = (
        "QDRANT_URL",
        "QDRANT_API_KEY",
        "OLLAMA_URL",
        "OLLAMA_MODEL",
        "COLLECTION_NAME",
    )
    # Variables whose values must never appear in logs.
    _SECRET_ENV_VARS = frozenset({"QDRANT_API_KEY"})

    def __init__(self):
        """Load configuration, build both clients and ensure the collection exists."""
        self._log_environment()

        # Read configuration from the environment.
        self.qdrant_url = os.getenv("QDRANT_URL", "http://dev1:6333")
        # NOTE(review): hard-coded fallback credential kept for backward
        # compatibility; prefer supplying QDRANT_API_KEY explicitly.
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY", "313131")
        self.collection_name = os.getenv("COLLECTION_NAME", "ollama_mcp")
        self.ollama_model = os.getenv("OLLAMA_MODEL", "nomic-embed-text")
        self.ollama_url = os.getenv("OLLAMA_URL", "http://dev1:11434")

        # Create the embedding provider and the vector database client.
        self.embeddings = OllamaEmbeddings(
            model=self.ollama_model,
            base_url=self.ollama_url,
        )
        self.client = QdrantClient(
            url=self.qdrant_url,
            api_key=self.qdrant_api_key,
        )

        self._ensure_collection_exists()
        logger.info("初始化完成,使用集合: %s", self.collection_name)

    @classmethod
    def _log_environment(cls) -> None:
        """Log the raw configuration environment variables, masking secrets."""
        logger.info("环境变量:")
        for name in cls._ENV_VARS:
            value = os.getenv(name)
            if value is None:
                shown = "未设置"
            elif name in cls._SECRET_ENV_VARS:
                shown = "******"
            else:
                shown = value
            logger.info("%s: %s", name, shown)

    def _ensure_collection_exists(self):
        """Create the configured collection when it does not exist yet.

        The vector size is discovered by embedding a sample string, so the
        collection always matches the configured embedding model.
        """
        existing = self.client.get_collections().collections
        if any(item.name == self.collection_name for item in existing):
            logger.info("集合已存在")
            return

        # Probe the embedding model to learn its output dimension.
        vector_size = len(self.embeddings.embed_query("sample text"))
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=vector_size,
                distance=Distance.COSINE,
            ),
        )
        logger.info("已创建新集合,向量维度: %s", vector_size)

    @staticmethod
    def _initialize_result() -> Dict[str, Any]:
        """Return the ``initialize`` result.

        Only capabilities that this server actually implements are advertised:
        tools are supported, but no list-changed notifications are ever sent
        and no ``resources/*`` methods exist.
        """
        return {
            "protocolVersion": "2024-11-05",
            "capabilities": {
                "tools": {
                    "listChanged": False,
                },
            },
            "serverInfo": {
                "name": "qdrant-ollama-mcp-server",
                "version": "1.0.0",
            },
        }

    @staticmethod
    def _tool_definitions() -> List[Dict[str, Any]]:
        """Return the tool descriptors served by ``tools/list``."""
        return [
            {
                "name": "add_document",
                "description": "添加文档到向量数据库",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "text": {
                            "type": "string",
                            "description": "文档文本内容",
                        },
                        "metadata": {
                            "type": "object",
                            "description": "文档的元数据",
                        },
                    },
                    "required": ["text"],
                },
            },
            {
                "name": "search_documents",
                "description": "在向量数据库中搜索相似文档",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "搜索查询文本",
                        },
                        "limit": {
                            "type": "integer",
                            "description": "返回结果数量限制",
                            "default": 5,
                        },
                        "filter": {
                            "type": "object",
                            "description": "搜索过滤器",
                        },
                    },
                    "required": ["query"],
                },
            },
            {
                "name": "list_collections",
                "description": "列出所有集合",
                "inputSchema": {
                    "type": "object",
                    "properties": {},
                },
            },
            {
                "name": "get_collection_info",
                "description": "获取集合信息",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "collection_name": {
                            "type": "string",
                            "description": "集合名称",
                        },
                    },
                    "required": ["collection_name"],
                },
            },
        ]

    async def _call_tool(self, tool_name: Any, tool_params: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch a ``tools/call`` request and wrap the tool payload.

        The returned result carries the MCP ``content`` / ``isError`` fields.
        The tool's own keys are kept at the top level as well, so consumers of
        the previous (unwrapped) result shape keep working.

        Raises:
            ValueError: if ``tool_name`` is not a known tool.
        """
        handlers = {
            "add_document": self._add_document,
            "search_documents": self._search_documents,
            "list_collections": self._list_collections,
            "get_collection_info": self._get_collection_info,
        }
        handler = handlers.get(tool_name)
        if handler is None:
            raise ValueError(f"未知工具: {tool_name}")

        payload = await handler(tool_params)
        return {
            **payload,
            "content": [
                {
                    "type": "text",
                    "text": json.dumps(payload, ensure_ascii=False),
                },
            ],
            "isError": False,
        }

    async def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one JSON-RPC request and build its response object.

        Supported methods are ``initialize``, ``tools/list`` and
        ``tools/call``. Any exception raised while handling the request is
        logged and converted into a JSON-RPC error object (code ``-1``).

        Args:
            request: the decoded JSON-RPC request.

        Returns:
            A JSON-RPC response dict holding either ``result`` or ``error``.
        """
        method = request.get("method")
        request_id = request.get("id")
        logger.info("处理请求: %s", method)

        try:
            # An explicit ``null`` for params/arguments is treated as "none given".
            params = request.get("params") or {}
            if method == "initialize":
                result = self._initialize_result()
            elif method == "tools/list":
                result = {"tools": self._tool_definitions()}
            elif method == "tools/call":
                result = await self._call_tool(
                    params.get("name"),
                    params.get("arguments") or {},
                )
            else:
                raise ValueError(f"未知方法: {method}")
        except Exception as e:
            # Protocol boundary: report every failure to the client.
            logger.exception("处理请求时出错: %s", e)
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {
                    "code": -1,
                    "message": str(e),
                },
            }

        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": result,
        }

    @staticmethod
    def _document_id(text: str) -> str:
        """Return a point ID derived deterministically from ``text``.

        The built-in ``hash()`` is salted per process for strings, so it would
        give the same document a different ID on every run. A UUIDv5 is stable
        across runs, so re-adding the same text overwrites the existing point.
        """
        return str(uuid.uuid5(uuid.NAMESPACE_OID, text))

    async def _add_document(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Embed a document and upsert it into the configured collection.

        Args:
            params: tool arguments; ``text`` is required, ``metadata`` optional.

        Raises:
            ValueError: if ``text`` is missing, empty or not a string.
        """
        text = params.get("text")
        metadata = params.get("metadata") or {}
        if not text or not isinstance(text, str):
            raise ValueError("文档文本不能为空")

        # NOTE: the embedding and Qdrant calls are synchronous; run() handles
        # one request at a time, so nothing else is waiting on the loop.
        embedding = self.embeddings.embed_query(text)
        point = PointStruct(
            id=self._document_id(text),
            vector=embedding,
            payload={
                "text": text,
                "metadata": metadata,
            },
        )
        self.client.upsert(
            collection_name=self.collection_name,
            points=[point],
        )
        return {"success": True, "message": "文档已添加"}

    async def _search_documents(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Search the configured collection for documents similar to a query.

        Args:
            params: tool arguments; ``query`` is required, ``limit`` defaults
                to 5 and ``filter`` is passed to ``Filter`` as keyword
                arguments when present.

        Raises:
            ValueError: if ``query`` is missing or empty.
        """
        query = params.get("query")
        limit = params.get("limit", 5)
        filter_dict = params.get("filter")
        if not query:
            raise ValueError("搜索查询不能为空")

        query_embedding = self.embeddings.embed_query(query)
        search_filter = Filter(**filter_dict) if filter_dict else None
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=limit,
            query_filter=search_filter,
        )

        results = []
        for hit in hits:
            # A point stored without payload yields ``None`` here.
            payload = hit.payload or {}
            results.append({
                "text": payload.get("text", ""),
                "metadata": payload.get("metadata", {}),
                "score": hit.score,
            })
        return {"results": results}

    async def _list_collections(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return the names of all collections known to the Qdrant client."""
        collections = self.client.get_collections().collections
        return {
            "collections": [{"name": item.name} for item in collections],
        }

    @staticmethod
    def _vectors_config_to_dict(vectors: Any) -> Any:
        """Convert a vectors configuration into plain data.

        Handles three shapes: ``None`` (returned unchanged), a mapping of
        names to configurations (converted entry by entry), and a single
        model object exposing ``model_dump()`` or ``dict()``.
        """
        if vectors is None:
            return None
        if isinstance(vectors, dict):
            return {
                name: QdrantOllamaMCPServer._vectors_config_to_dict(config)
                for name, config in vectors.items()
            }
        dump = getattr(vectors, "model_dump", None) or getattr(vectors, "dict", None)
        return dump() if callable(dump) else vectors

    async def _get_collection_info(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return the point count and vector configuration of a collection.

        Args:
            params: tool arguments; ``collection_name`` is required.

        Raises:
            ValueError: if ``collection_name`` is missing, or if the
                collection cannot be fetched (original error chained).
        """
        collection_name = params.get("collection_name")
        if not collection_name:
            raise ValueError("集合名称不能为空")

        try:
            info = self.client.get_collection(collection_name)
            return {
                "name": collection_name,
                # Key name kept for compatibility; the value is points_count.
                "vectors_count": info.points_count,
                "vectors_config": self._vectors_config_to_dict(
                    info.config.params.vectors
                ),
            }
        except Exception as e:
            raise ValueError(f"获取集合信息失败: {str(e)}") from e

    async def run(self):
        """Serve requests from stdin until EOF.

        Each non-blank line is parsed as one JSON-RPC message. Requests get a
        single-line JSON response on stdout. Notifications (messages without
        an ``id``) are logged and never answered. Malformed input is logged
        and skipped.
        """
        logger.info("启动 Qdrant-Ollama MCP 服务器")
        logger.info("Qdrant URL: %s", self.qdrant_url)
        logger.info("Ollama URL: %s", self.ollama_url)
        logger.info("Collection: %s", self.collection_name)

        loop = asyncio.get_running_loop()
        while True:
            try:
                # Read in a worker thread so the event loop is not blocked.
                line = await loop.run_in_executor(None, sys.stdin.readline)
                if not line:
                    break  # EOF: the client closed the pipe.

                raw = line.strip()
                if not raw:
                    continue  # Ignore blank separator lines.
                logger.info("收到请求: %s", raw)

                request = json.loads(raw)
                if isinstance(request, dict) and "id" not in request:
                    # JSON-RPC notifications must not receive a response.
                    logger.info("收到通知,无需响应: %s", request.get("method"))
                    continue

                response = await self.handle_request(request)
                response_json = json.dumps(response)
                # stdout is the protocol channel; flush so the client sees it now.
                print(response_json, flush=True)
                logger.info("发送响应: %s", response_json)
            except json.JSONDecodeError as e:
                logger.error("JSON 解析错误: %s", e)
            except Exception as e:
                logger.error("处理请求时出错: %s", e)
            except KeyboardInterrupt:
                logger.info("服务器被中断")
                break
async def main():
    """Configure logging, report the environment, then run the server.

    Nothing is printed to stdout here: stdout is reserved for the JSON-RPC
    responses written by the server, so the environment report goes through
    the logger, with the API key masked.
    """
    # force=True is required because the root logger is already configured
    # when this module is imported; without it this call would be ignored
    # and the format below would never take effect.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        stream=sys.stderr,
        force=True,
    )

    # Report the raw environment variables (not the resolved defaults).
    logger.info("环境变量:")
    for name in (
        "QDRANT_URL",
        "QDRANT_API_KEY",
        "OLLAMA_URL",
        "OLLAMA_MODEL",
        "COLLECTION_NAME",
    ):
        value = os.getenv(name)
        if value is None:
            shown = "未设置"
        elif name == "QDRANT_API_KEY":
            shown = "******"  # never write the credential to logs
        else:
            shown = value
        logger.info("%s: %s", name, shown)

    # Create the server instance and serve until stdin is closed.
    server = QdrantOllamaMCPServer()
    await server.run()
if __name__ == "__main__":
    # A Ctrl-C that is not caught inside the server loop propagates out of
    # asyncio.run(); exit with the conventional SIGINT status (128 + 2)
    # instead of dumping a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        sys.exit(130)