147 lines
4.6 KiB
Python
147 lines
4.6 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
测试 Google ADK Memory Bank 功能
|
||
"""
|
||
|
||
import os
|
||
import asyncio
|
||
from google.adk import Agent
|
||
from google.adk.memory import MemoryBank, MemoryItem
|
||
from datetime import datetime
|
||
|
||
async def test_memory_bank():
|
||
"""测试Memory Bank基本功能"""
|
||
print("🧠 测试 Google ADK Memory Bank...")
|
||
|
||
try:
|
||
# 创建记忆银行
|
||
memory_bank = MemoryBank(
|
||
name="test_memory_bank",
|
||
description="测试用的记忆银行"
|
||
)
|
||
|
||
print("✅ Memory Bank 创建成功")
|
||
|
||
# 添加记忆项
|
||
memory_item = MemoryItem(
|
||
content="这是一个测试记忆:比特币在2021年达到历史最高点69000美元",
|
||
metadata={
|
||
"type": "market_data",
|
||
"asset": "bitcoin",
|
||
"timestamp": datetime.now().isoformat()
|
||
}
|
||
)
|
||
|
||
await memory_bank.add_memory(memory_item)
|
||
print("✅ 记忆添加成功")
|
||
|
||
# 搜索记忆
|
||
search_results = await memory_bank.search("比特币", limit=5)
|
||
print(f"✅ 记忆搜索成功,找到 {len(search_results)} 条相关记忆")
|
||
|
||
for i, memory in enumerate(search_results):
|
||
print(f" {i+1}. {memory.content}")
|
||
|
||
# 创建带记忆银行的智能体
|
||
agent = Agent(
|
||
name="测试智能体",
|
||
model="gemini-2.0-flash-exp",
|
||
instruction="你是一个测试智能体,请使用你的记忆银行来回答问题。",
|
||
memory_bank=memory_bank
|
||
)
|
||
|
||
print("✅ 带记忆银行的智能体创建成功")
|
||
|
||
return True
|
||
|
||
except ImportError as e:
|
||
print(f"❌ Memory Bank 模块导入失败: {e}")
|
||
print("💡 可能需要更新 Google ADK 版本或启用 Memory Bank 功能")
|
||
return False
|
||
except Exception as e:
|
||
print(f"❌ Memory Bank 测试失败: {e}")
|
||
return False
|
||
|
||
async def test_simple_memory_simulation():
|
||
"""模拟Memory Bank功能的简单实现"""
|
||
print("\n🔄 使用简单模拟实现...")
|
||
|
||
class SimpleMemoryBank:
|
||
def __init__(self, name: str, description: str):
|
||
self.name = name
|
||
self.description = description
|
||
self.memories = []
|
||
|
||
async def add_memory(self, content: str, metadata: dict = None):
|
||
memory = {
|
||
"content": content,
|
||
"metadata": metadata or {},
|
||
"timestamp": datetime.now().isoformat()
|
||
}
|
||
self.memories.append(memory)
|
||
|
||
async def search(self, query: str, limit: int = 5):
|
||
# 简单的关键词匹配
|
||
results = []
|
||
query_lower = query.lower()
|
||
|
||
for memory in self.memories:
|
||
if query_lower in memory["content"].lower():
|
||
results.append(memory)
|
||
if len(results) >= limit:
|
||
break
|
||
|
||
return results
|
||
|
||
# 测试简单实现
|
||
memory_bank = SimpleMemoryBank(
|
||
name="铁拐李记忆银行",
|
||
description="铁拐李的逆向投资记忆"
|
||
)
|
||
|
||
# 添加一些记忆
|
||
memories = [
|
||
"2000年互联网泡沫破裂,纳斯达克指数从5048点跌到1114点",
|
||
"2008年金融危机,雷曼兄弟破产引发全球恐慌",
|
||
"2020年3月疫情恐慌,美股熔断4次,但随后强劲反弹",
|
||
"比特币从2017年的2万美元跌到2018年的3200美元"
|
||
]
|
||
|
||
for memory in memories:
|
||
await memory_bank.add_memory(memory, {"type": "historical_event"})
|
||
|
||
print(f"✅ 已添加 {len(memories)} 条记忆")
|
||
|
||
# 搜索测试
|
||
search_queries = ["泡沫", "比特币", "金融危机"]
|
||
|
||
for query in search_queries:
|
||
results = await memory_bank.search(query)
|
||
print(f"\n🔍 搜索 '{query}' 找到 {len(results)} 条记忆:")
|
||
for i, result in enumerate(results):
|
||
print(f" {i+1}. {result['content']}")
|
||
|
||
return True
|
||
|
||
async def main():
|
||
"""主测试函数"""
|
||
print("🚀 Google ADK Memory Bank 功能测试")
|
||
|
||
# 检查API密钥
|
||
api_key = os.getenv('GOOGLE_API_KEY')
|
||
if not api_key:
|
||
print("❌ 未找到 GOOGLE_API_KEY 环境变量")
|
||
return
|
||
|
||
print(f"✅ API密钥已配置")
|
||
|
||
# 尝试真实的Memory Bank
|
||
success = await test_memory_bank()
|
||
|
||
if not success:
|
||
# 如果真实的Memory Bank不可用,使用模拟实现
|
||
await test_simple_memory_simulation()
|
||
|
||
if __name__ == "__main__":
|
||
asyncio.run(main()) |