tts/scripts/import_to_qdrant.py
#!/usr/bin/env python3
"""
Import articles into the Qdrant vector database.
Supports access via MCP.
"""
import os
import sys
import time
import uuid
from pathlib import Path

import qdrant_client
from qdrant_client.models import PointStruct, VectorParams, Distance

# Qdrant configuration
QDRANT_URL = "http://localhost:6333"
COLLECTION_NAME = "fengtian_articles"
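
# Assumed environment: a Qdrant instance reachable at QDRANT_URL (for example
# one started with `docker run -p 6333:6333 qdrant/qdrant`) and an Ollama
# server with the nomic-embed-text model pulled (`ollama pull nomic-embed-text`).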

class ArticleImporter:
    def __init__(self):
        self.client = qdrant_client.QdrantClient(url=QDRANT_URL)
        self.collection_name = COLLECTION_NAME

    def create_collection(self):
        """Create the collection if it does not already exist."""
        collections = self.client.get_collections().collections
        if not any(c.name == self.collection_name for c in collections):
            print(f"Creating collection: {self.collection_name}")
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=768,  # dimensionality of nomic-embed-text embeddings
                    distance=Distance.COSINE
                )
            )
        else:
            print(f"Collection {self.collection_name} already exists")

    def read_file(self, file_path):
        """Read a file's contents."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            print(f"Read file: {file_path} ({len(content)} characters)")
            return content
        except Exception as e:
            print(f"Failed to read file: {e}")
            return None

    def split_into_chunks(self, content, chunk_size=1000, overlap=100):
        """Split the content into overlapping chunks."""
        chunks = []
        start = 0
        while start < len(content):
            end = start + chunk_size
            # Prefer to split at a sentence boundary or newline
            if end < len(content):
                # Find the nearest full stop (full-width '。', matching the
                # Chinese source material)
                last_period = content.rfind('。', start, end)
                last_newline = content.rfind('\n', start, end)
                split_pos = max(last_period, last_newline)
                # Only accept a split point in the last 20% of the chunk
                if split_pos > start + chunk_size * 0.8:
                    end = split_pos + 1
            chunk = content[start:end].strip()
            if chunk:
                chunks.append(chunk)
            start = end - overlap
        print(f"Split into {len(chunks)} chunks")
        return chunks
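
    # For illustration: with the defaults above, a 2,500-character document
    # with no sentence boundaries yields three chunks (0-1000, 900-1900,
    # 1800-2500), each neighbouring pair sharing a 100-character overlap.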

    def generate_embedding(self, text):
        """Generate a vector embedding via Ollama."""
        try:
            import ollama
            response = ollama.embeddings(
                model="nomic-embed-text",
                prompt=text[:8192]  # truncate to limit prompt length
            )
            return response["embedding"]
        except Exception as e:
            print(f"Failed to generate embedding: {e}")
            # Fall back to a random vector
            import random
            return [random.random() for _ in range(768)]
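
    # Note: the random fallback keeps the import pipeline running when Ollama
    # is unavailable, but such vectors are meaningless for similarity search;
    # affected files should be re-imported once embeddings succeed.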

    def import_file(self, file_path):
        """Import a single file."""
        content = self.read_file(file_path)
        if not content:
            return
        chunks = self.split_into_chunks(content)
        points = []
        for i, chunk in enumerate(chunks):
            # Generate the embedding for this chunk (random fallback on failure)
            vector = self.generate_embedding(chunk)
            point_id = str(uuid.uuid4())
            points.append(
                PointStruct(
                    id=point_id,
                    vector=vector,
                    payload={
                        "file_path": str(file_path),
                        "chunk_index": i,
                        "content": chunk[:200] + "..." if len(chunk) > 200 else chunk,
                        "full_content": chunk,
                        "timestamp": int(time.time())
                    }
                )
            )
        # Upsert in batches to keep request sizes manageable
        batch_size = 100
        for i in range(0, len(points), batch_size):
            batch = points[i:i + batch_size]
            self.client.upsert(
                collection_name=self.collection_name,
                points=batch
            )
            print(f"Imported {len(batch)} records")
        print(f"\nFile {file_path} imported; {len(points)} records in total")

    def import_directory(self, dir_path, pattern="*.md"):
        """Import all matching files under a directory."""
        path = Path(dir_path)
        files = list(path.rglob(pattern))
        print(f"Found {len(files)} files")
        for file_path in files:
            if file_path.is_file():
                print(f"\n{'='*60}")
                print(f"Processing file: {file_path}")
                print(f"{'='*60}")
                self.import_file(file_path)

    def search(self, query_text, limit=5):
        """Search for similar content."""
        query_vector = self.generate_embedding(query_text)
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            limit=limit
        )
        return results
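
    # Note: search() returns qdrant ScoredPoint objects; with the COSINE
    # distance configured above, a higher score means a closer match.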

def main():
    importer = ArticleImporter()
    # Create the collection
    importer.create_collection()
    # Import files
    if len(sys.argv) > 1:
        # Import the specified file or directory
        path = sys.argv[1]
        if os.path.isdir(path):
            importer.import_directory(path)
        else:
            importer.import_file(path)
    else:
        # By default, import the material, papers and docs directories
        print("Importing the material directory...")
        importer.import_directory("/root/tts/material")
        print("\nImporting the papers directory...")
        importer.import_directory("/root/tts/papers")
        print("\nImporting the docs directory...")
        importer.import_directory("/root/tts/docs")

if __name__ == "__main__":
    main()
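
# Usage (the directories match the defaults in main(); the .md file name
# below is hypothetical):
#   python import_to_qdrant.py                      # import the default directories
#   python import_to_qdrant.py /root/tts/material   # import one directory
#   python import_to_qdrant.py note.md              # import a single file
#
# Querying the imported data from Python:
#   importer = ArticleImporter()
#   for hit in importer.search("some query"):
#       print(hit.score, hit.payload["file_path"], hit.payload["content"])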