75 lines
2.1 KiB
Python
75 lines
2.1 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
为MongoDB中的文章生成向量embeddings
|
||
用于swarm辩论系统的语义搜索和内容聚类
|
||
"""
|
||
|
||
import os
|
||
import openai
|
||
from pymongo import MongoClient
|
||
from typing import List, Dict
|
||
import time
|
||
|
||
def get_mongodb_client():
|
||
"""从Doppler获取MongoDB连接"""
|
||
mongodb_uri = os.getenv('MONGODB_URI')
|
||
if not mongodb_uri:
|
||
raise ValueError("MONGODB_URI not found in environment variables")
|
||
return MongoClient(mongodb_uri)
|
||
|
||
def generate_embedding(text: str) -> List[float]:
|
||
"""使用OpenAI API生成文本embedding"""
|
||
try:
|
||
response = openai.Embedding.create(
|
||
model="text-embedding-ada-002",
|
||
input=text
|
||
)
|
||
return response['data'][0]['embedding']
|
||
except Exception as e:
|
||
print(f"生成embedding失败: {e}")
|
||
return None
|
||
|
||
def update_articles_with_embeddings():
|
||
"""为所有文章添加embedding字段"""
|
||
client = get_mongodb_client()
|
||
db = client.taigong
|
||
collection = db.articles
|
||
|
||
# 获取所有没有embedding的文章
|
||
articles = collection.find({"embedding": {"$exists": False}})
|
||
|
||
count = 0
|
||
for article in articles:
|
||
title = article.get('title', '')
|
||
if not title:
|
||
continue
|
||
|
||
print(f"处理文章: {title[:50]}...")
|
||
|
||
# 生成embedding
|
||
embedding = generate_embedding(title)
|
||
if embedding:
|
||
# 更新文档
|
||
collection.update_one(
|
||
{"_id": article["_id"]},
|
||
{"$set": {"embedding": embedding}}
|
||
)
|
||
count += 1
|
||
print(f"✓ 已更新 {count} 篇文章")
|
||
|
||
# 避免API rate limit
|
||
time.sleep(0.1)
|
||
else:
|
||
print(f"× 跳过文章: {title[:50]}")
|
||
|
||
print(f"\n完成!共处理 {count} 篇文章")
|
||
client.close()
|
||
|
||
if __name__ == "__main__":
|
||
# 设置OpenAI API密钥 (应该从Doppler获取)
|
||
openai.api_key = os.getenv('OPENAI_API_KEY')
|
||
if not openai.api_key:
|
||
print("警告: OPENAI_API_KEY 未设置,请先在Doppler中配置")
|
||
exit(1)
|
||
|
||
update_articles_with_embeddings() |