#!/usr/bin/env python3
"""
|
|
Qdrant MCP Server - 让 AI 可以访问向量数据库中的文章
|
|
"""
|
|
|
|
import sys
|
|
import json
|
|
import qdrant_client
|
|
from qdrant_client.models import VectorParams, Distance
|
|
import uuid
|
|
import time
|
|
import os
|
|
|
|
# Qdrant configuration. Both values can be overridden through environment
# variables so the server can point at another instance/collection without
# editing the source; the defaults are the original hard-coded values.
QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333")
COLLECTION_NAME = os.environ.get("QDRANT_COLLECTION_NAME", "fengtian_articles")
|
|
|
|
class QdrantMCP:
    """Thin wrapper around a Qdrant collection of article chunks.

    Provides semantic search over the collection and a collection status
    report, both returning plain JSON-serializable dicts.
    """

    # Ollama embedding model used for query text. Presumably the same model
    # used when the collection was populated (import_to_qdrant.py) — verify.
    EMBEDDING_MODEL = "nomic-embed-text"
    # Length of the random fallback vector; assumed to match the collection's
    # vector size (768 is nomic-embed-text's output size) — keep in sync.
    EMBEDDING_DIM = 768
    # Query text is truncated to this many characters before embedding.
    MAX_QUERY_CHARS = 8192

    def __init__(self, url=None, collection_name=None):
        """Create the Qdrant client.

        Args:
            url: Qdrant server URL; defaults to the module-level QDRANT_URL.
            collection_name: Collection to query; defaults to COLLECTION_NAME.
        """
        self.client = qdrant_client.QdrantClient(url=url or QDRANT_URL)
        self.collection_name = collection_name or COLLECTION_NAME

    def _embed(self, text):
        """Return the Ollama embedding of *text*, or a random vector on failure."""
        try:
            # Imported lazily so the server still starts without the package.
            import ollama

            response = ollama.embeddings(
                model=self.EMBEDDING_MODEL,
                prompt=text[: self.MAX_QUERY_CHARS],
            )
            return response["embedding"]
        except Exception as e:
            # Deliberate best-effort degradation: a random vector lets the
            # request complete, but the hits it yields are meaningless, so
            # report it on stderr (stdout is reserved for protocol responses).
            print(
                f"Warning: embedding failed ({e!r}); "
                "falling back to a random query vector",
                file=sys.stderr,
            )
            import random

            return [random.random() for _ in range(self.EMBEDDING_DIM)]

    @staticmethod
    def _format_point(point):
        """Convert one scored Qdrant point into a plain result dict."""
        # A point stored without payload comes back with payload=None.
        payload = point.payload or {}
        return {
            "id": point.id,
            "score": point.score,
            "file_path": payload.get("file_path", ""),
            "chunk_index": payload.get("chunk_index", 0),
            "content": payload.get("full_content", ""),
        }

    def search(self, query_text, limit=5):
        """Search the collection for article chunks similar to *query_text*.

        Args:
            query_text: Free-text query. An empty or whitespace-only query
                returns no results without contacting Qdrant.
            limit: Maximum number of hits; coerced with int() so values
                arriving as JSON strings (e.g. "5") are accepted.

        Returns:
            A list of dicts with keys "id", "score", "file_path",
            "chunk_index" and "content", in the order Qdrant returns them.
        """
        if query_text is None or not str(query_text).strip():
            return []

        query_vector = self._embed(str(query_text))
        points = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            limit=int(limit),
        ).points
        return [self._format_point(point) for point in points]

    def get_collection_info(self):
        """Report whether the collection exists and, if so, its point count.

        Never raises: connection or API failures are returned as a dict with
        "error" and "message" keys.
        """
        try:
            existing = {c.name for c in self.client.get_collections().collections}
            if self.collection_name not in existing:
                return {
                    "exists": False,
                    "name": self.collection_name,
                    "message": "Collection not found. Please run import_to_qdrant.py first.",
                }
            collection_info = self.client.get_collection(self.collection_name)
            return {
                "exists": True,
                "name": self.collection_name,
                "points_count": collection_info.points_count,
            }
        except Exception as e:
            return {
                "error": str(e),
                "message": "Failed to connect to Qdrant. Make sure it's running.",
            }
|
|
|
|
|
|
def _dispatch(qdrant, request):
    """Route one parsed request dict to the matching QdrantMCP call and build the response."""
    method = request.get("method")
    # "params" may be absent or explicitly null; treat both as "no params".
    params = request.get("params") or {}
    if not isinstance(params, dict):
        raise ValueError("'params' must be a JSON object")

    if method == "search":
        results = qdrant.search(
            query_text=params.get("query", ""),
            # JSON clients may send the limit as a string.
            limit=int(params.get("limit", 5)),
        )
        return {"result": results, "status": "success"}

    if method == "info":
        info = qdrant.get_collection_info()
        # Give "info" the same "status" field as every other response;
        # get_collection_info signals failure through an "error" key.
        return {**info, "status": "error" if "error" in info else "success"}

    return {"error": f"Unknown method: {method}", "status": "error"}


def _send(response):
    """Write one response dict to stdout as a single JSON line."""
    try:
        line = json.dumps(response)
    except (TypeError, ValueError) as e:
        # An unserializable result must not kill the server loop.
        line = json.dumps({"error": f"Failed to serialize response: {e}", "status": "error"})
    print(line, flush=True)


def main():
    """Serve a simplified line-delimited JSON protocol over stdio.

    Each non-blank stdin line must be a JSON object such as
    {"method": "search", "params": {"query": "...", "limit": 5}} or
    {"method": "info"}. Exactly one JSON response line is written to stdout
    per request; diagnostics go to stderr. The loop ends at EOF or Ctrl+C.
    This is a simplified protocol, not a full MCP / JSON-RPC implementation.
    """
    qdrant = QdrantMCP()

    # stdout carries protocol responses only; diagnostics go to stderr.
    print("Qdrant MCP Server started", file=sys.stderr)

    try:
        for raw_line in sys.stdin:
            line = raw_line.strip()
            if not line:
                # Ignore blank lines rather than replying with a parse error.
                continue
            try:
                request = json.loads(line)
                if not isinstance(request, dict):
                    raise ValueError("Request must be a JSON object")
                response = _dispatch(qdrant, request)
            except Exception as e:
                # Top-level request boundary: report the failure and keep serving.
                response = {"error": str(e), "status": "error"}
            _send(response)
    except KeyboardInterrupt:
        # Ctrl+C: shut down quietly instead of printing a traceback.
        pass


if __name__ == "__main__":
    main()
|