#!/usr/bin/env python3
"""
Import articles into a Qdrant vector database.

Supports access via MCP.
"""

import os
import sys
import time
import uuid
from pathlib import Path

import qdrant_client
from qdrant_client.models import PointStruct, VectorParams, Distance

# Qdrant configuration
QDRANT_URL = "http://localhost:6333"
COLLECTION_NAME = "fengtian_articles"
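
# Prerequisites (assumptions, not checked by this script): a Qdrant
# instance reachable at QDRANT_URL, and a local Ollama daemon with the
# nomic-embed-text model pulled. nomic-embed-text produces
# 768-dimensional vectors, which must match the vector size the
# collection is created with below.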


class ArticleImporter:
    def __init__(self):
        self.client = qdrant_client.QdrantClient(url=QDRANT_URL)
        self.collection_name = COLLECTION_NAME

    def create_collection(self):
        """Create the collection if it does not already exist."""
        collections = self.client.get_collections().collections
        if not any(c.name == self.collection_name for c in collections):
            print(f"Creating collection: {self.collection_name}")
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=768,  # nomic-embed-text embedding dimension
                    distance=Distance.COSINE
                )
            )
        else:
            print(f"Collection {self.collection_name} already exists")
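
    # Note: newer qdrant-client releases also expose
    # client.collection_exists(name), which avoids listing every
    # collection; the scan above keeps compatibility with older clients.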

    def read_file(self, file_path):
        """Read a file's contents."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            print(f"Read file: {file_path} ({len(content)} characters)")
            return content
        except Exception as e:
            print(f"Failed to read file: {e}")
            return None

    def split_into_chunks(self, content, chunk_size=1000, overlap=100):
        """Split content into overlapping chunks."""
        chunks = []
        start = 0

        while start < len(content):
            end = start + chunk_size

            # Prefer to cut at a sentence boundary or newline
            if end < len(content):
                # Find the nearest CJK full stop or newline within the window
                last_period = content.rfind('。', start, end)
                last_newline = content.rfind('\n', start, end)

                split_pos = max(last_period, last_newline)
                # Only accept a break point in the last 20% of the chunk,
                # so chunks never shrink below 80% of chunk_size
                if split_pos > start + chunk_size * 0.8:
                    end = split_pos + 1

            chunk = content[start:end].strip()
            if chunk:
                chunks.append(chunk)

            start = end - overlap

        print(f"Split into {len(chunks)} chunks")
        return chunks
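
    # Worked example (with the defaults chunk_size=1000, overlap=100):
    # a 2,000-character document with no sentence breaks yields chunks
    # starting at offsets 0, 900, and 1800, each up to 1,000 characters
    # long and sharing 100 characters with its neighbor. When a 。 or
    # newline falls in the final 20% of a window, the cut is pulled
    # back to just after it instead.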

    def generate_embedding(self, text):
        """Generate an embedding vector via Ollama."""
        try:
            import ollama
            response = ollama.embeddings(
                model="nomic-embed-text",
                prompt=text[:8192]  # cap the input length
            )
            return response["embedding"]
        except Exception as e:
            print(f"Failed to generate embedding: {e}")
            # Fall back to a random vector
            import random
            return [random.random() for _ in range(768)]
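
    # Caveat: the random-vector fallback keeps an import run alive when
    # Ollama is unavailable, but the resulting points are noise as far
    # as semantic search is concerned. Skipping the chunk, or raising,
    # may be the safer behavior for a production index.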

    def import_file(self, file_path):
        """Import a single file."""
        content = self.read_file(file_path)
        if not content:
            return

        chunks = self.split_into_chunks(content)
        points = []

        for i, chunk in enumerate(chunks):
            # Embed the chunk (random-vector fallback on failure)
            vector = self.generate_embedding(chunk)

            point_id = str(uuid.uuid4())
            points.append(
                PointStruct(
                    id=point_id,
                    vector=vector,
                    payload={
                        "file_path": str(file_path),
                        "chunk_index": i,
                        "content": chunk[:200] + "..." if len(chunk) > 200 else chunk,
                        "full_content": chunk,
                        "timestamp": int(time.time())
                    }
                )
            )

        # Upsert in batches
        batch_size = 100
        for i in range(0, len(points), batch_size):
            batch = points[i:i + batch_size]
            self.client.upsert(
                collection_name=self.collection_name,
                points=batch
            )
            print(f"Imported {len(batch)} records")

        print(f"\nFile {file_path} imported: {len(points)} records in total")
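
    # Note: upsert overwrites points with matching IDs, but the IDs here
    # are fresh random UUIDs, so re-importing the same file adds
    # duplicate points rather than replacing the old ones. Deriving the
    # ID deterministically from (file_path, chunk_index), e.g. with
    # uuid.uuid5, would make imports idempotent.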

    def import_directory(self, dir_path, pattern="*.md"):
        """Import every matching file under a directory."""
        path = Path(dir_path)
        files = list(path.rglob(pattern))

        print(f"Found {len(files)} files")

        for file_path in files:
            if file_path.is_file():
                print(f"\n{'='*60}")
                print(f"Processing file: {file_path}")
                print(f"{'='*60}")
                self.import_file(file_path)

    def search(self, query_text, limit=5):
        """Search for similar content."""
        query_vector = self.generate_embedding(query_text)

        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            limit=limit
        )

        return results
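
    # Example (hypothetical interactive session; the query string is a
    # placeholder, and each hit is a qdrant_client ScoredPoint):
    #   importer = ArticleImporter()
    #   for hit in importer.search("some topic"):
    #       print(hit.score, hit.payload["file_path"], hit.payload["content"])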


def main():
    importer = ArticleImporter()

    # Create the collection
    importer.create_collection()

    # Import files
    if len(sys.argv) > 1:
        # Import the given file or directory
        path = sys.argv[1]
        if os.path.isdir(path):
            importer.import_directory(path)
        else:
            importer.import_file(path)
    else:
        # Default: import the material, papers, and docs directories
        print("Importing material directory...")
        importer.import_directory("/root/tts/material")

        print("\nImporting papers directory...")
        importer.import_directory("/root/tts/papers")

        print("\nImporting docs directory...")
        importer.import_directory("/root/tts/docs")


if __name__ == "__main__":
    main()
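
# Usage (assuming this file is saved as import_articles.py; the source
# does not give the script's filename):
#   python3 import_articles.py                 # import the default directories
#   python3 import_articles.py /path/to/notes  # import one directory
#   python3 import_articles.py note.md         # import a single file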