357 lines
13 KiB
Python
Executable File
357 lines
13 KiB
Python
Executable File
#!/usr/bin/env python3
"""
Qdrant 与 Ollama 嵌入模型集成的 MCP 服务器。

此脚本实现了一个基于 stdio 的 MCP 服务器:使用 Ollama 作为嵌入模型提供者,并与 Qdrant 向量数据库集成。
"""
|
|
|
|
import asyncio
import hashlib
import json
import logging
import os
import sys
from typing import Any, Dict, List, Optional

from langchain_ollama import OllamaEmbeddings
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter
|
|
|
|
# Configure logging once at import time. Output goes to stderr explicitly: when this
# script runs as an MCP stdio server, stdout carries the JSON-RPC stream and must not
# receive diagnostics. The format is set here because any later basicConfig() call
# without force=True is ignored once the root logger has a handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    stream=sys.stderr,
)

logger = logging.getLogger(__name__)
|
|
|
|
class QdrantOllamaMCPServer:
    """MCP server that stores and searches text documents in a Qdrant collection.

    Embeddings are produced by an Ollama model through ``OllamaEmbeddings``. ``run``
    reads newline-delimited JSON-RPC 2.0 messages from stdin and writes one JSON
    response per line to stdout. Because stdout is the protocol channel, diagnostics
    go through ``logger`` and never through ``print``.

    Configuration is read from the environment: ``QDRANT_URL``, ``QDRANT_API_KEY``,
    ``COLLECTION_NAME``, ``OLLAMA_MODEL`` and ``OLLAMA_URL``.
    """

    # MCP protocol revision reported by ``initialize``.
    PROTOCOL_VERSION = "2024-11-05"

    # Standard JSON-RPC 2.0 error codes.
    _PARSE_ERROR = -32700
    _INVALID_REQUEST = -32600
    _METHOD_NOT_FOUND = -32601
    _INVALID_PARAMS = -32602
    _INTERNAL_ERROR = -32603

    # Environment variables echoed to the log at startup, in this order.
    _ENV_VARS = ("QDRANT_URL", "QDRANT_API_KEY", "OLLAMA_URL", "OLLAMA_MODEL", "COLLECTION_NAME")

    # Tool descriptors returned by ``tools/list``.
    _TOOLS = [
        {
            "name": "add_document",
            "description": "添加文档到向量数据库",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "text": {
                        "type": "string",
                        "description": "文档文本内容"
                    },
                    "metadata": {
                        "type": "object",
                        "description": "文档的元数据"
                    }
                },
                "required": ["text"]
            }
        },
        {
            "name": "search_documents",
            "description": "在向量数据库中搜索相似文档",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "搜索查询文本"
                    },
                    "limit": {
                        "type": "integer",
                        "description": "返回结果数量限制",
                        "default": 5
                    },
                    "filter": {
                        "type": "object",
                        "description": "搜索过滤器"
                    }
                },
                "required": ["query"]
            }
        },
        {
            "name": "list_collections",
            "description": "列出所有集合",
            "inputSchema": {
                "type": "object",
                "properties": {}
            }
        },
        {
            "name": "get_collection_info",
            "description": "获取集合信息",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "collection_name": {
                        "type": "string",
                        "description": "集合名称"
                    }
                },
                "required": ["collection_name"]
            }
        }
    ]

    class _JsonRpcError(Exception):
        """Protocol-level failure reported to the client as a JSON-RPC ``error`` object."""

        def __init__(self, code: int, message: str):
            super().__init__(message)
            self.code = code

    def __init__(self):
        """Read configuration, create the Ollama and Qdrant clients, and ensure the collection exists."""
        # Report the raw environment before reading it. stdout belongs to the MCP
        # stream, so this is logged rather than printed; the API key is masked so the
        # secret does not end up in logs.
        logger.info("环境变量:")
        for var in self._ENV_VARS:
            value = os.getenv(var)
            if value is None:
                shown = "未设置"
            elif var == "QDRANT_API_KEY":
                shown = "***"
            else:
                shown = value
            logger.info("%s: %s", var, shown)

        # Configuration from the environment; the fallbacks point at services on host dev1.
        self.qdrant_url = os.getenv("QDRANT_URL", "http://dev1:6333")
        # NOTE(review): a hard-coded fallback API key is a credential committed to
        # source; consider requiring QDRANT_API_KEY to be set explicitly.
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY", "313131")
        self.collection_name = os.getenv("COLLECTION_NAME", "ollama_mcp")
        self.ollama_model = os.getenv("OLLAMA_MODEL", "nomic-embed-text")
        self.ollama_url = os.getenv("OLLAMA_URL", "http://dev1:11434")

        self.embeddings = OllamaEmbeddings(
            model=self.ollama_model,
            base_url=self.ollama_url
        )
        self.client = QdrantClient(
            url=self.qdrant_url,
            api_key=self.qdrant_api_key
        )

        self._ensure_collection_exists()

        logger.info("初始化完成,使用集合: %s", self.collection_name)

    def _ensure_collection_exists(self):
        """Create the configured collection if it is missing.

        The vector size is taken from embedding a sample string with the configured
        model, and the collection uses cosine distance.
        """
        existing = {collection.name for collection in self.client.get_collections().collections}

        if self.collection_name in existing:
            logger.info("集合已存在")
            return

        # Probe the model once to learn its embedding dimension.
        vector_size = len(self.embeddings.embed_query("sample text"))
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=vector_size,
                distance=Distance.COSINE
            )
        )
        logger.info("已创建新集合,向量维度: %s", vector_size)

    async def handle_request(self, request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Handle one decoded JSON-RPC 2.0 message.

        Args:
            request: the decoded JSON object (``method``, optional ``params``/``id``).

        Returns:
            The JSON-RPC response object. Returns ``None`` when *request* is a
            notification (it carries no ``id``), because JSON-RPC 2.0 forbids replying
            to notifications such as ``notifications/initialized``.
        """
        method = request.get("method")
        params = request.get("params") or {}
        request_id = request.get("id")

        logger.info("处理请求: %s", method)

        if "id" not in request:
            return None

        try:
            result = await self._dispatch(method, params)
        except self._JsonRpcError as e:
            logger.error("处理请求时出错: %s", e)
            return self._error_response(request_id, e.code, str(e))
        except Exception as e:
            logger.exception("处理请求时出错: %s", e)
            return self._error_response(request_id, self._INTERNAL_ERROR, str(e))

        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": result
        }

    async def _dispatch(self, method: Optional[str], params: Dict[str, Any]) -> Dict[str, Any]:
        """Route a request to its handler; raise ``_JsonRpcError`` for unknown methods."""
        if method == "initialize":
            return {
                "protocolVersion": self.PROTOCOL_VERSION,
                "capabilities": {
                    # Only tools are implemented, and this server never sends
                    # notifications/tools/list_changed. Advertising resources would
                    # invite resources/* calls that it cannot answer.
                    "tools": {
                        "listChanged": False
                    }
                },
                "serverInfo": {
                    "name": "qdrant-ollama-mcp-server",
                    "version": "1.0.0"
                }
            }
        if method == "ping":
            return {}
        if method == "tools/list":
            return {"tools": self._TOOLS}
        if method == "tools/call":
            return await self._call_tool(params)
        raise self._JsonRpcError(self._METHOD_NOT_FOUND, f"未知方法: {method}")

    async def _call_tool(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Run the named tool and wrap its output as an MCP ``CallToolResult``.

        An unknown tool name is a protocol error (invalid params). A failure inside
        the tool is returned as a result with ``isError: True`` so the client can see
        the message.
        """
        tool_name = params.get("name")
        tool_args = params.get("arguments") or {}

        handlers = {
            "add_document": self._add_document,
            "search_documents": self._search_documents,
            "list_collections": self._list_collections,
            "get_collection_info": self._get_collection_info,
        }
        handler = handlers.get(tool_name)
        if handler is None:
            raise self._JsonRpcError(self._INVALID_PARAMS, f"未知工具: {tool_name}")

        try:
            payload = await handler(tool_args)
        except Exception as e:
            logger.exception("工具执行失败: %s", tool_name)
            return {
                "content": [{"type": "text", "text": str(e)}],
                "isError": True
            }

        return {
            "content": [{"type": "text", "text": json.dumps(payload, ensure_ascii=False)}],
            "isError": False
        }

    @staticmethod
    def _error_response(request_id: Any, code: int, message: str) -> Dict[str, Any]:
        """Build a JSON-RPC 2.0 error response object."""
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {
                "code": code,
                "message": message
            }
        }

    @staticmethod
    def _point_id(text: str) -> int:
        """Return a stable, non-negative 63-bit Qdrant point ID derived from *text*.

        Python's built-in ``hash()`` of a str is salted per process (PYTHONHASHSEED),
        so it cannot serve as an ID: the same document would get a new ID after every
        restart and be duplicated instead of upserted. SHA-256 is stable across runs.
        """
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big") >> 1

    async def _add_document(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Embed ``params["text"]`` and upsert it with optional ``params["metadata"]``.

        Identical text maps to the same point ID, so re-adding a document overwrites it.

        Raises:
            ValueError: if the text is missing, empty, or not a string.
        """
        text = params.get("text")
        metadata = params.get("metadata") or {}

        if not text:
            raise ValueError("文档文本不能为空")
        if not isinstance(text, str):
            raise ValueError("文档文本必须是字符串")

        embedding = self.embeddings.embed_query(text)

        point = PointStruct(
            id=self._point_id(text),
            vector=embedding,
            payload={
                "text": text,
                "metadata": metadata
            }
        )

        self.client.upsert(
            collection_name=self.collection_name,
            points=[point]
        )

        return {"success": True, "message": "文档已添加"}

    async def _search_documents(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return up to ``limit`` (default 5) documents most similar to ``params["query"]``.

        An optional ``filter`` dict is passed to ``qdrant_client.models.Filter``.

        Raises:
            ValueError: if the query is empty or ``limit`` is not a positive integer.
        """
        query = params.get("query")
        if not query:
            raise ValueError("搜索查询不能为空")

        # Clients may send the limit as a string; Qdrant rejects non-positive limits.
        limit = int(params.get("limit", 5))
        if limit < 1:
            raise ValueError("limit 必须为正整数")

        filter_dict = params.get("filter")
        search_filter = Filter(**filter_dict) if filter_dict else None

        query_embedding = self.embeddings.embed_query(query)

        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=limit,
            query_filter=search_filter
        )

        results = []
        for hit in hits:
            payload = hit.payload or {}
            results.append({
                "text": payload.get("text", ""),
                "metadata": payload.get("metadata", {}),
                "score": hit.score
            })

        return {"results": results}

    async def _list_collections(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return the names of all collections on the Qdrant server (params are ignored)."""
        collections = self.client.get_collections().collections
        return {
            "collections": [
                {"name": collection.name} for collection in collections
            ]
        }

    @staticmethod
    def _model_to_dict(model: Any) -> Any:
        """Serialize a qdrant-client model to JSON-compatible data (pydantic v2 or v1)."""
        dump = getattr(model, "model_dump", None)
        if dump is not None:
            return dump(mode="json")
        return model.dict()

    async def _get_collection_info(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return the point count and vector configuration of ``params["collection_name"]``.

        Collections with named vectors expose ``vectors`` as a dict of vector
        params, which is serialized per name.

        Raises:
            ValueError: if the name is missing or the lookup/serialization fails.
        """
        collection_name = params.get("collection_name")

        if not collection_name:
            raise ValueError("集合名称不能为空")

        try:
            collection_info = self.client.get_collection(collection_name)
            vectors = collection_info.config.params.vectors
            if isinstance(vectors, dict):
                vectors_config = {name: self._model_to_dict(v) for name, v in vectors.items()}
            else:
                vectors_config = self._model_to_dict(vectors)
            return {
                "name": collection_name,
                # Key name kept for compatibility; the value is the collection's point count.
                "vectors_count": collection_info.points_count,
                "vectors_config": vectors_config
            }
        except Exception as e:
            raise ValueError(f"获取集合信息失败: {str(e)}") from e

    @staticmethod
    def _send(message: Dict[str, Any]) -> None:
        """Write one JSON-RPC message to stdout as a single line and flush it."""
        response_json = json.dumps(message)
        # Flush per message: the client waits for each response line.
        print(response_json, flush=True)
        logger.info("发送响应: %s", response_json)

    async def run(self):
        """Serve newline-delimited JSON-RPC messages from stdin until EOF."""
        logger.info("启动 Qdrant-Ollama MCP 服务器")
        logger.info("Qdrant URL: %s", self.qdrant_url)
        logger.info("Ollama URL: %s", self.ollama_url)
        logger.info("Collection: %s", self.collection_name)

        loop = asyncio.get_running_loop()
        while True:
            try:
                # readline() blocks, so it runs in the default executor.
                line = await loop.run_in_executor(None, sys.stdin.readline)
                if not line:
                    break  # EOF: the client closed stdin.

                line = line.strip()
                if not line:
                    continue

                logger.info("收到请求: %s", line)

                try:
                    request = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.error("JSON 解析错误: %s", e)
                    # JSON-RPC requires a parse-error reply with a null id.
                    self._send(self._error_response(None, self._PARSE_ERROR, f"JSON 解析错误: {e}"))
                    continue

                if not isinstance(request, dict):
                    self._send(self._error_response(None, self._INVALID_REQUEST, "请求必须是 JSON 对象"))
                    continue

                response = await self.handle_request(request)
                if response is not None:
                    self._send(response)

            except Exception as e:
                logger.exception("处理请求时出错: %s", e)
            except KeyboardInterrupt:
                logger.info("服务器被中断")
                break
|
|
|
|
async def main():
    """Configure logging, then run the MCP server until stdin is closed."""
    # force=True replaces any handler already installed by an earlier
    # logging.basicConfig() call (the module calls it at import time). Without it,
    # this call is a silent no-op and the format below never takes effect.
    # No stream is given, so output goes to stderr.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        force=True,
    )

    # The relevant environment variables are reported by QdrantOllamaMCPServer's
    # constructor. Printing them here as well would write to stdout, which carries
    # the MCP JSON-RPC stream, and would expose QDRANT_API_KEY.
    server = QdrantOllamaMCPServer()

    await server.run()
|
|
|
|
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C surfaces here as KeyboardInterrupt from asyncio.run(); exit quietly
        # instead of dumping a traceback onto stderr.
        logger.info("服务器被中断")