liurenchaxin/tests/test_vertex_memory_bank.py

384 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Vertex Memory Bank 实现测试
"""
import unittest
import asyncio
import os
import sys
from unittest.mock import patch, MagicMock, AsyncMock
from datetime import datetime
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.jixia.memory.vertex_memory_bank import VertexMemoryBank, MemoryEntry
class TestVertexMemoryBank(unittest.TestCase):
"""测试VertexMemoryBank类"""
def setUp(self):
"""测试前的设置"""
# Mock掉aiplatform.init以避免实际初始化
patcher = patch('src.jixia.memory.vertex_memory_bank.aiplatform.init')
self.mock_init = patcher.start()
self.addCleanup(patcher.stop)
# 创建VertexMemoryBank实例
self.memory_bank = VertexMemoryBank(
project_id="test-project",
location="us-central1"
)
# 重置本地存储
self.memory_bank.local_memories = {}
self.memory_bank.memory_banks = {}
def test_init(self):
"""测试初始化"""
self.assertEqual(self.memory_bank.project_id, "test-project")
self.assertEqual(self.memory_bank.location, "us-central1")
self.assertEqual(self.memory_bank.local_memories, {})
self.assertEqual(self.memory_bank.memory_banks, {})
# 验证调用了aiplatform.init
self.mock_init.assert_called_once_with(project="test-project", location="us-central1")
def test_from_config(self):
"""测试从配置创建实例"""
with patch('src.jixia.memory.vertex_memory_bank.get_google_genai_config') as mock_config:
mock_config.return_value = {
'project_id': 'config-project',
'location': 'europe-west1'
}
memory_bank = VertexMemoryBank.from_config()
self.assertEqual(memory_bank.project_id, "config-project")
self.assertEqual(memory_bank.location, "europe-west1")
def test_from_config_missing_project_id(self):
"""测试从配置创建实例时缺少project_id"""
with patch('src.jixia.memory.vertex_memory_bank.get_google_genai_config') as mock_config:
mock_config.return_value = {
'project_id': None,
'location': 'europe-west1'
}
with self.assertRaises(ValueError) as context:
VertexMemoryBank.from_config()
self.assertIn("Google Cloud Project ID 未配置", str(context.exception))
async def test_create_memory_bank(self):
"""测试创建记忆银行"""
memory_bank_id = await self.memory_bank.create_memory_bank("tieguaili")
# 验证返回的ID格式
self.assertEqual(memory_bank_id, "memory_bank_tieguaili_test-project")
# 验证内部状态
self.assertIn("tieguaili", self.memory_bank.memory_banks)
self.assertEqual(self.memory_bank.memory_banks["tieguaili"], memory_bank_id)
self.assertIn("tieguaili", self.memory_bank.local_memories)
self.assertEqual(self.memory_bank.local_memories["tieguaili"], [])
async def test_create_memory_bank_with_display_name(self):
"""测试创建记忆银行时指定显示名称"""
memory_bank_id = await self.memory_bank.create_memory_bank(
"tieguaili",
"铁拐李的专属记忆银行"
)
# 验证返回的ID格式
self.assertEqual(memory_bank_id, "memory_bank_tieguaili_test-project")
async def test_add_memory(self):
"""测试添加记忆"""
# 先创建记忆银行
await self.memory_bank.create_memory_bank("tieguaili")
# 添加记忆
memory_id = await self.memory_bank.add_memory(
agent_name="tieguaili",
content="在讨论NVIDIA股票时我倾向于逆向思维关注潜在风险。",
memory_type="preference",
debate_topic="NVIDIA投资分析",
metadata={"source": "manual"}
)
# 验证返回的ID格式
self.assertEqual(memory_id, "memory_tieguaili_0")
# 验证记忆已存储
self.assertEqual(len(self.memory_bank.local_memories["tieguaili"]), 1)
stored_memory = self.memory_bank.local_memories["tieguaili"][0]
self.assertEqual(stored_memory["id"], memory_id)
self.assertEqual(stored_memory["content"], "在讨论NVIDIA股票时我倾向于逆向思维关注潜在风险。")
self.assertEqual(stored_memory["memory_type"], "preference")
self.assertEqual(stored_memory["debate_topic"], "NVIDIA投资分析")
self.assertIn("source", stored_memory["metadata"])
self.assertEqual(stored_memory["metadata"]["source"], "manual")
self.assertIn("agent_name", stored_memory["metadata"])
self.assertEqual(stored_memory["metadata"]["agent_name"], "tieguaili")
async def test_add_memory_creates_bank_if_not_exists(self):
"""测试添加记忆时自动创建记忆银行"""
# 不先创建记忆银行,直接添加记忆
memory_id = await self.memory_bank.add_memory(
agent_name="tieguaili",
content="测试内容"
)
# 验证记忆银行已被自动创建
self.assertIn("tieguaili", self.memory_bank.memory_banks)
self.assertIn("tieguaili", self.memory_bank.local_memories)
# 验证记忆已存储
self.assertEqual(len(self.memory_bank.local_memories["tieguaili"]), 1)
async def test_search_memories(self):
"""测试搜索记忆"""
# 先创建记忆银行并添加一些记忆
await self.memory_bank.create_memory_bank("tieguaili")
await self.memory_bank.add_memory(
agent_name="tieguaili",
content="在讨论NVIDIA股票时我倾向于逆向思维关注潜在风险。",
memory_type="preference",
debate_topic="NVIDIA投资分析"
)
await self.memory_bank.add_memory(
agent_name="tieguaili",
content="我喜欢关注苹果公司的创新产品发布会。",
memory_type="preference",
debate_topic="AAPL投资分析"
)
# 搜索NVIDIA相关记忆
results = await self.memory_bank.search_memories(
agent_name="tieguaili",
query="NVIDIA"
)
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["content"], "在讨论NVIDIA股票时我倾向于逆向思维关注潜在风险。")
self.assertIn("relevance_score", results[0])
async def test_search_memories_with_type_filter(self):
"""测试带类型过滤的搜索记忆"""
# 先创建记忆银行并添加不同类型的记忆
await self.memory_bank.create_memory_bank("tieguaili")
await self.memory_bank.add_memory(
agent_name="tieguaili",
content="在讨论NVIDIA股票时我倾向于逆向思维关注潜在风险。",
memory_type="preference",
debate_topic="NVIDIA投资分析"
)
await self.memory_bank.add_memory(
agent_name="tieguaili",
content="在NVIDIA的辩论中我使用了技术分析策略。",
memory_type="strategy",
debate_topic="NVIDIA投资分析"
)
# 搜索NVIDIA相关记忆只返回preference类型
results = await self.memory_bank.search_memories(
agent_name="tieguaili",
query="NVIDIA",
memory_type="preference"
)
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["metadata"]["memory_type"], "preference")
async def test_search_memories_no_results(self):
"""测试搜索无结果的情况"""
# 搜索不存在的记忆银行
results = await self.memory_bank.search_memories(
agent_name="nonexistent",
query="test"
)
self.assertEqual(results, [])
# 搜索空的记忆银行
await self.memory_bank.create_memory_bank("tieguaili")
results = await self.memory_bank.search_memories(
agent_name="tieguaili",
query="test"
)
self.assertEqual(results, [])
async def test_get_agent_context(self):
"""测试获取智能体上下文"""
# 先创建记忆银行并添加一些记忆
await self.memory_bank.create_memory_bank("tieguaili")
await self.memory_bank.add_memory(
agent_name="tieguaili",
content="在讨论NVIDIA股票时我倾向于逆向思维关注潜在风险。",
memory_type="preference",
debate_topic="NVIDIA投资分析"
)
await self.memory_bank.add_memory(
agent_name="tieguaili",
content="在NVIDIA的辩论中我使用了技术分析策略。",
memory_type="strategy",
debate_topic="NVIDIA投资分析"
)
# 获取上下文
context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
# 验证上下文包含预期内容
self.assertIn("# 铁拐李的记忆上下文", context)
self.assertIn("## 偏好记忆", context)
self.assertIn("## 策略记忆", context)
self.assertIn("在讨论NVIDIA股票时我倾向于逆向思维关注潜在风险。", context)
self.assertIn("在NVIDIA的辩论中我使用了技术分析策略。", context)
async def test_get_agent_context_no_memories(self):
"""测试获取智能体上下文但无相关记忆"""
# 先创建记忆银行
await self.memory_bank.create_memory_bank("tieguaili")
# 获取上下文
context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
# 验证上下文包含暂无相关记忆的提示
self.assertIn("# 铁拐李的记忆上下文", context)
self.assertIn("暂无相关记忆。", context)
async def test_save_debate_session(self):
"""测试保存辩论会话"""
conversation_history = [
{"agent": "tieguaili", "content": "NVIDIA的估值过高存在泡沫风险。"},
{"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
{"agent": "tieguaili", "content": "但我们需要考虑竞争加剧和增长放缓的可能性。"}
]
outcomes = {
"winner": "lvdongbin",
"insights": {
"tieguaili": "铁拐李的风险意识值得肯定但在AI趋势的判断上略显保守。"
}
}
# 保存辩论会话
await self.memory_bank.save_debate_session(
debate_topic="NVIDIA投资分析",
participants=["tieguaili", "lvdongbin"],
conversation_history=conversation_history,
outcomes=outcomes
)
# 验证铁拐李的记忆已保存
self.assertIn("tieguaili", self.memory_bank.local_memories)
self.assertEqual(len(self.memory_bank.local_memories["tieguaili"]), 2)
# 验证第一条记忆是对话总结
summary_memory = self.memory_bank.local_memories["tieguaili"][0]
self.assertEqual(summary_memory["memory_type"], "conversation")
self.assertIn("铁拐李在本次辩论中的主要观点", summary_memory["content"])
# 验证第二条记忆是策略洞察
strategy_memory = self.memory_bank.local_memories["tieguaili"][1]
self.assertEqual(strategy_memory["memory_type"], "strategy")
self.assertIn("铁拐李的风险意识值得肯定", strategy_memory["content"])
def test_summarize_conversation(self):
"""测试对话总结"""
conversation_history = [
{"agent": "tieguaili", "content": "第一点看法NVIDIA的估值过高存在泡沫风险。"},
{"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
{"agent": "tieguaili", "content": "第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。"},
{"agent": "tieguaili", "content": "第三点看法:从技术分析角度看,股价已出现超买信号。"}
]
summary = self.memory_bank._summarize_conversation(conversation_history, "tieguaili")
# 验证总结包含预期内容
self.assertIn("铁拐李在本次辩论中的主要观点", summary)
self.assertIn("第一点看法NVIDIA的估值过高存在泡沫风险。", summary)
self.assertIn("第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。", summary)
self.assertIn("第三点看法:从技术分析角度看,股价已出现超买信号。", summary)
def test_extract_strategy_insight_winner(self):
"""测试提取策略洞察 - 获胜者"""
outcomes = {
"winner": "tieguaili",
"insights": {}
}
insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
self.assertIn("铁拐李在本次辩论中获胜", insight)
def test_extract_strategy_insight_from_insights(self):
"""测试从洞察中提取策略洞察"""
outcomes = {
"winner": "lvdongbin",
"insights": {
"tieguaili": "铁拐李的风险意识值得肯定但在AI趋势的判断上略显保守。"
}
}
insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
self.assertEqual(insight, "铁拐李的风险意识值得肯定但在AI趋势的判断上略显保守。")
if __name__ == '__main__':
# 创建一个异步测试运行器
def run_async_test(test_case):
"""运行异步测试用例"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(test_case)
finally:
loop.close()
# 获取所有以test_开头的异步方法并运行它们
suite = unittest.TestSuite()
test_instance = TestVertexMemoryBank()
test_instance.setUp()
# 添加同步测试
suite.addTest(TestVertexMemoryBank('test_init'))
suite.addTest(TestVertexMemoryBank('test_from_config'))
suite.addTest(TestVertexMemoryBank('test_from_config_missing_project_id'))
suite.addTest(TestVertexMemoryBank('test_summarize_conversation'))
suite.addTest(TestVertexMemoryBank('test_extract_strategy_insight_winner'))
suite.addTest(TestVertexMemoryBank('test_extract_strategy_insight_from_insights'))
# 添加异步测试
async_tests = [
'test_create_memory_bank',
'test_create_memory_bank_with_display_name',
'test_add_memory',
'test_add_memory_creates_bank_if_not_exists',
'test_search_memories',
'test_search_memories_with_type_filter',
'test_search_memories_no_results',
'test_get_agent_context',
'test_get_agent_context_no_memories',
'test_save_debate_session'
]
for test_name in async_tests:
test_method = getattr(test_instance, test_name)
suite.addTest(unittest.FunctionTestCase(lambda tm=test_method: run_async_test(tm())))
# 运行测试
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite)