liurenchaxin/scripts/generate_embeddings.py

75 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
为MongoDB中的文章生成向量embeddings
用于swarm辩论系统的语义搜索和内容聚类
"""
import os
import time
from typing import Dict, List, Optional

import openai
from pymongo import MongoClient
def get_mongodb_client():
    """Build a MongoDB client from the ``MONGODB_URI`` environment variable.

    The original note says the connection settings are supplied via Doppler;
    this function itself only reads the process environment.

    Returns:
        A ``MongoClient`` constructed from the configured URI.

    Raises:
        ValueError: If ``MONGODB_URI`` is unset or empty.
    """
    uri = os.environ.get('MONGODB_URI')
    if uri:
        return MongoClient(uri)
    raise ValueError("MONGODB_URI not found in environment variables")
def generate_embedding(text: str) -> Optional[List[float]]:
    """Generate an embedding vector for ``text`` via the OpenAI API.

    Args:
        text: The text to embed.

    Returns:
        The embedding taken from the first item of the API response, or
        ``None`` when ``text`` is empty or the request fails for any reason.
        Failures are reported on stdout rather than raised.
    """
    # The embeddings endpoint rejects empty input, so skip the round trip.
    if not text:
        print("生成embedding失败: empty input text")
        return None
    try:
        response = openai.Embedding.create(
            model="text-embedding-ada-002",
            input=text,
        )
        return response['data'][0]['embedding']
    except Exception as e:
        # Deliberate best-effort boundary: one failed request must not
        # abort the caller's batch, so report it and signal with None.
        print(f"生成embedding失败: {e}")
        return None
def update_articles_with_embeddings():
    """Add an ``embedding`` field to every article that lacks one.

    Reads documents from the ``articles`` collection of the ``taigong``
    database, embeds each non-empty title with ``generate_embedding`` and
    stores the resulting vector on the document. Articles without a title
    are skipped silently; articles whose embedding could not be generated
    are reported and left unchanged. Progress is printed to stdout.
    """
    client = get_mongodb_client()
    # try/finally so the connection is released even when a query or an
    # update raises part-way through the run.
    try:
        collection = client.taigong.articles
        # Select only articles that have no embedding yet. Only ``_id``
        # (returned by default) and ``title`` are used below, so project
        # just the title instead of transferring whole documents.
        articles = collection.find(
            {"embedding": {"$exists": False}},
            {"title": 1},
        )
        count = 0
        for article in articles:
            title = article.get('title', '')
            if not title:
                continue
            print(f"处理文章: {title[:50]}...")
            embedding = generate_embedding(title)
            if embedding:
                # Persist the vector on the article document.
                collection.update_one(
                    {"_id": article["_id"]},
                    {"$set": {"embedding": embedding}}
                )
                count += 1
                print(f"✓ 已更新 {count} 篇文章")
            else:
                print(f"× 跳过文章: {title[:50]}")
            # Throttle after every API request, successful or not, so that
            # failures (e.g. rate-limit errors) are not retried back-to-back.
            time.sleep(0.1)
        print(f"\n完成!共处理 {count} 篇文章")
    finally:
        client.close()
if __name__ == "__main__":
    # Configure the OpenAI API key from the environment (the original note
    # says it should be provided via Doppler).
    openai.api_key = os.getenv('OPENAI_API_KEY')
    if not openai.api_key:
        print("警告: OPENAI_API_KEY 未设置请先在Doppler中配置")
        # The bare ``exit`` helper is injected by the ``site`` module for
        # interactive use and may be missing (e.g. ``python -S``); raising
        # SystemExit gives the same exit status without relying on it.
        raise SystemExit(1)
    update_articles_with_embeddings()