#!/usr/bin/env python3
"""
Import articles into the Qdrant vector database.
Supports MCP access.
"""
import os
import sys
from pathlib import Path
import qdrant_client
from qdrant_client.models import PointStruct, VectorParams, Distance
import uuid
import time

# Qdrant configuration
QDRANT_URL = "http://localhost:6333"
COLLECTION_NAME = "fengtian_articles"


class ArticleImporter:
    def __init__(self):
        self.client = qdrant_client.QdrantClient(url=QDRANT_URL)
        self.collection_name = COLLECTION_NAME

    def create_collection(self):
        """Create the collection (if it does not already exist)."""
        collections = self.client.get_collections().collections
        if not any(c.name == self.collection_name for c in collections):
            print(f"Creating collection: {self.collection_name}")
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=768,  # nomic-embed-text dimension
                    distance=Distance.COSINE
                )
            )
        else:
            print(f"Collection {self.collection_name} already exists")

    def read_file(self, file_path):
        """Read a file's contents."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            print(f"Read file: {file_path} ({len(content)} characters)")
            return content
        except Exception as e:
            print(f"Failed to read file: {e}")
            return None

    def split_into_chunks(self, content, chunk_size=1000, overlap=100):
        """Split the content into overlapping chunks."""
        chunks = []
        start = 0
        while start < len(content):
            end = start + chunk_size
            # Prefer splitting at a full stop or newline
            if end < len(content):
                # Find the nearest Chinese full stop
                last_period = content.rfind('。', start, end)
                last_newline = content.rfind('\n', start, end)
                split_pos = max(last_period, last_newline)
                if split_pos > start + chunk_size * 0.8:  # only accept a split point in the last 20% of the chunk
                    end = split_pos + 1
            chunk = content[start:end].strip()
            if chunk:
                chunks.append(chunk)
            start = end - overlap
        print(f"Split into {len(chunks)} chunks")
        return chunks

    def generate_embedding(self, text):
        """Generate a vector embedding with Ollama."""
        try:
            import ollama
            response = ollama.embeddings(
                model="nomic-embed-text",
                prompt=text[:8192]  # limit the input length
            )
            return response["embedding"]
        except Exception as e:
            print(f"Failed to generate embedding: {e}")
            # Fall back to a random vector
            import random
            return [random.random() for _ in range(768)]

    def import_file(self, file_path):
        """Import a single file."""
        content = self.read_file(file_path)
        if not content:
            return

        chunks = self.split_into_chunks(content)
        points = []
        for i, chunk in enumerate(chunks):
            # Generate the vector (a real embedding should be used in practice)
            vector = self.generate_embedding(chunk)
            point_id = str(uuid.uuid4())
            points.append(
                PointStruct(
                    id=point_id,
                    vector=vector,
                    payload={
                        "file_path": str(file_path),
                        "chunk_index": i,
                        "content": chunk[:200] + "..." if len(chunk) > 200 else chunk,
                        "full_content": chunk,
                        "timestamp": int(time.time())
                    }
                )
            )

        # Import in batches
        batch_size = 100
        for i in range(0, len(points), batch_size):
            batch = points[i:i + batch_size]
            self.client.upsert(
                collection_name=self.collection_name,
                points=batch
            )
            print(f"Imported {len(batch)} records")

        print(f"\nFile {file_path} imported, {len(points)} records in total")

    def import_directory(self, dir_path, pattern="*.md"):
        """Import all matching files under a directory."""
        path = Path(dir_path)
        files = list(path.rglob(pattern))
        print(f"Found {len(files)} files")
        for file_path in files:
            if file_path.is_file():
                print(f"\n{'='*60}")
                print(f"Processing file: {file_path}")
                print(f"{'='*60}")
                self.import_file(file_path)

    def search(self, query_text, limit=5):
        """Search for similar content."""
        query_vector = self.generate_embedding(query_text)
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            limit=limit
        )
        return results


def main():
    importer = ArticleImporter()

    # Create the collection
    importer.create_collection()

    # Import files
    if len(sys.argv) > 1:
        # Import the specified file or directory
        path = sys.argv[1]
        if os.path.isdir(path):
            importer.import_directory(path)
        else:
            importer.import_file(path)
    else:
        # By default, import the material, papers and docs directories
        print("Importing material directory...")
        importer.import_directory("/root/tts/material")
        print("\nImporting papers directory...")
        importer.import_directory("/root/tts/papers")
        print("\nImporting docs directory...")
        importer.import_directory("/root/tts/docs")


if __name__ == "__main__":
    main()
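
# Usage sketch: a minimal, hypothetical example of querying the imported chunks
# from another module. It assumes this file is saved as import_articles.py (name
# not given in the source) and that Qdrant and Ollama are running locally; the
# query string is only an illustrative placeholder.
#
#     from import_articles import ArticleImporter
#
#     importer = ArticleImporter()
#     hits = importer.search("suspension reliability", limit=3)
#     for hit in hits:
#         # each hit is a qdrant_client ScoredPoint with .score and .payload
#         print(hit.score, hit.payload["file_path"], hit.payload["content"])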