liurenchaxin/tests/test_memory_bank.py

147 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试 Google ADK Memory Bank 功能
"""
import os
import asyncio
from google.adk import Agent
from google.adk.memory import MemoryBank, MemoryItem
from datetime import datetime
async def test_memory_bank():
"""测试Memory Bank基本功能"""
print("🧠 测试 Google ADK Memory Bank...")
try:
# 创建记忆银行
memory_bank = MemoryBank(
name="test_memory_bank",
description="测试用的记忆银行"
)
print("✅ Memory Bank 创建成功")
# 添加记忆项
memory_item = MemoryItem(
content="这是一个测试记忆比特币在2021年达到历史最高点69000美元",
metadata={
"type": "market_data",
"asset": "bitcoin",
"timestamp": datetime.now().isoformat()
}
)
await memory_bank.add_memory(memory_item)
print("✅ 记忆添加成功")
# 搜索记忆
search_results = await memory_bank.search("比特币", limit=5)
print(f"✅ 记忆搜索成功,找到 {len(search_results)} 条相关记忆")
for i, memory in enumerate(search_results):
print(f" {i+1}. {memory.content}")
# 创建带记忆银行的智能体
agent = Agent(
name="测试智能体",
model="gemini-2.0-flash-exp",
instruction="你是一个测试智能体,请使用你的记忆银行来回答问题。",
memory_bank=memory_bank
)
print("✅ 带记忆银行的智能体创建成功")
return True
except ImportError as e:
print(f"❌ Memory Bank 模块导入失败: {e}")
print("💡 可能需要更新 Google ADK 版本或启用 Memory Bank 功能")
return False
except Exception as e:
print(f"❌ Memory Bank 测试失败: {e}")
return False
async def test_simple_memory_simulation():
"""模拟Memory Bank功能的简单实现"""
print("\n🔄 使用简单模拟实现...")
class SimpleMemoryBank:
def __init__(self, name: str, description: str):
self.name = name
self.description = description
self.memories = []
async def add_memory(self, content: str, metadata: dict = None):
memory = {
"content": content,
"metadata": metadata or {},
"timestamp": datetime.now().isoformat()
}
self.memories.append(memory)
async def search(self, query: str, limit: int = 5):
# 简单的关键词匹配
results = []
query_lower = query.lower()
for memory in self.memories:
if query_lower in memory["content"].lower():
results.append(memory)
if len(results) >= limit:
break
return results
# 测试简单实现
memory_bank = SimpleMemoryBank(
name="铁拐李记忆银行",
description="铁拐李的逆向投资记忆"
)
# 添加一些记忆
memories = [
"2000年互联网泡沫破裂纳斯达克指数从5048点跌到1114点",
"2008年金融危机雷曼兄弟破产引发全球恐慌",
"2020年3月疫情恐慌美股熔断4次但随后强劲反弹",
"比特币从2017年的2万美元跌到2018年的3200美元"
]
for memory in memories:
await memory_bank.add_memory(memory, {"type": "historical_event"})
print(f"✅ 已添加 {len(memories)} 条记忆")
# 搜索测试
search_queries = ["泡沫", "比特币", "金融危机"]
for query in search_queries:
results = await memory_bank.search(query)
print(f"\n🔍 搜索 '{query}' 找到 {len(results)} 条记忆:")
for i, result in enumerate(results):
print(f" {i+1}. {result['content']}")
return True
async def main():
"""主测试函数"""
print("🚀 Google ADK Memory Bank 功能测试")
# 检查API密钥
api_key = os.getenv('GOOGLE_API_KEY')
if not api_key:
print("❌ 未找到 GOOGLE_API_KEY 环境变量")
return
print(f"✅ API密钥已配置")
# 尝试真实的Memory Bank
success = await test_memory_bank()
if not success:
# 如果真实的Memory Bank不可用使用模拟实现
await test_simple_memory_simulation()
if __name__ == "__main__":
asyncio.run(main())