384 lines
16 KiB
Python
384 lines
16 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Vertex Memory Bank 实现测试
|
||
"""
|
||
|
||
import unittest
|
||
import asyncio
|
||
import os
|
||
import sys
|
||
from unittest.mock import patch, MagicMock, AsyncMock
|
||
from datetime import datetime
|
||
|
||
# 添加项目根目录到Python路径
|
||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||
|
||
from src.jixia.memory.vertex_memory_bank import VertexMemoryBank, MemoryEntry
|
||
|
||
|
||
class TestVertexMemoryBank(unittest.TestCase):
|
||
"""测试VertexMemoryBank类"""
|
||
|
||
def setUp(self):
|
||
"""测试前的设置"""
|
||
# Mock掉aiplatform.init以避免实际初始化
|
||
patcher = patch('src.jixia.memory.vertex_memory_bank.aiplatform.init')
|
||
self.mock_init = patcher.start()
|
||
self.addCleanup(patcher.stop)
|
||
|
||
# 创建VertexMemoryBank实例
|
||
self.memory_bank = VertexMemoryBank(
|
||
project_id="test-project",
|
||
location="us-central1"
|
||
)
|
||
|
||
# 重置本地存储
|
||
self.memory_bank.local_memories = {}
|
||
self.memory_bank.memory_banks = {}
|
||
|
||
def test_init(self):
|
||
"""测试初始化"""
|
||
self.assertEqual(self.memory_bank.project_id, "test-project")
|
||
self.assertEqual(self.memory_bank.location, "us-central1")
|
||
self.assertEqual(self.memory_bank.local_memories, {})
|
||
self.assertEqual(self.memory_bank.memory_banks, {})
|
||
|
||
# 验证调用了aiplatform.init
|
||
self.mock_init.assert_called_once_with(project="test-project", location="us-central1")
|
||
|
||
def test_from_config(self):
|
||
"""测试从配置创建实例"""
|
||
with patch('src.jixia.memory.vertex_memory_bank.get_google_genai_config') as mock_config:
|
||
mock_config.return_value = {
|
||
'project_id': 'config-project',
|
||
'location': 'europe-west1'
|
||
}
|
||
|
||
memory_bank = VertexMemoryBank.from_config()
|
||
|
||
self.assertEqual(memory_bank.project_id, "config-project")
|
||
self.assertEqual(memory_bank.location, "europe-west1")
|
||
|
||
def test_from_config_missing_project_id(self):
|
||
"""测试从配置创建实例时缺少project_id"""
|
||
with patch('src.jixia.memory.vertex_memory_bank.get_google_genai_config') as mock_config:
|
||
mock_config.return_value = {
|
||
'project_id': None,
|
||
'location': 'europe-west1'
|
||
}
|
||
|
||
with self.assertRaises(ValueError) as context:
|
||
VertexMemoryBank.from_config()
|
||
|
||
self.assertIn("Google Cloud Project ID 未配置", str(context.exception))
|
||
|
||
async def test_create_memory_bank(self):
|
||
"""测试创建记忆银行"""
|
||
memory_bank_id = await self.memory_bank.create_memory_bank("tieguaili")
|
||
|
||
# 验证返回的ID格式
|
||
self.assertEqual(memory_bank_id, "memory_bank_tieguaili_test-project")
|
||
|
||
# 验证内部状态
|
||
self.assertIn("tieguaili", self.memory_bank.memory_banks)
|
||
self.assertEqual(self.memory_bank.memory_banks["tieguaili"], memory_bank_id)
|
||
self.assertIn("tieguaili", self.memory_bank.local_memories)
|
||
self.assertEqual(self.memory_bank.local_memories["tieguaili"], [])
|
||
|
||
async def test_create_memory_bank_with_display_name(self):
|
||
"""测试创建记忆银行时指定显示名称"""
|
||
memory_bank_id = await self.memory_bank.create_memory_bank(
|
||
"tieguaili",
|
||
"铁拐李的专属记忆银行"
|
||
)
|
||
|
||
# 验证返回的ID格式
|
||
self.assertEqual(memory_bank_id, "memory_bank_tieguaili_test-project")
|
||
|
||
async def test_add_memory(self):
|
||
"""测试添加记忆"""
|
||
# 先创建记忆银行
|
||
await self.memory_bank.create_memory_bank("tieguaili")
|
||
|
||
# 添加记忆
|
||
memory_id = await self.memory_bank.add_memory(
|
||
agent_name="tieguaili",
|
||
content="在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
|
||
memory_type="preference",
|
||
debate_topic="NVIDIA投资分析",
|
||
metadata={"source": "manual"}
|
||
)
|
||
|
||
# 验证返回的ID格式
|
||
self.assertEqual(memory_id, "memory_tieguaili_0")
|
||
|
||
# 验证记忆已存储
|
||
self.assertEqual(len(self.memory_bank.local_memories["tieguaili"]), 1)
|
||
|
||
stored_memory = self.memory_bank.local_memories["tieguaili"][0]
|
||
self.assertEqual(stored_memory["id"], memory_id)
|
||
self.assertEqual(stored_memory["content"], "在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。")
|
||
self.assertEqual(stored_memory["memory_type"], "preference")
|
||
self.assertEqual(stored_memory["debate_topic"], "NVIDIA投资分析")
|
||
self.assertIn("source", stored_memory["metadata"])
|
||
self.assertEqual(stored_memory["metadata"]["source"], "manual")
|
||
self.assertIn("agent_name", stored_memory["metadata"])
|
||
self.assertEqual(stored_memory["metadata"]["agent_name"], "tieguaili")
|
||
|
||
async def test_add_memory_creates_bank_if_not_exists(self):
|
||
"""测试添加记忆时自动创建记忆银行"""
|
||
# 不先创建记忆银行,直接添加记忆
|
||
memory_id = await self.memory_bank.add_memory(
|
||
agent_name="tieguaili",
|
||
content="测试内容"
|
||
)
|
||
|
||
# 验证记忆银行已被自动创建
|
||
self.assertIn("tieguaili", self.memory_bank.memory_banks)
|
||
self.assertIn("tieguaili", self.memory_bank.local_memories)
|
||
|
||
# 验证记忆已存储
|
||
self.assertEqual(len(self.memory_bank.local_memories["tieguaili"]), 1)
|
||
|
||
async def test_search_memories(self):
|
||
"""测试搜索记忆"""
|
||
# 先创建记忆银行并添加一些记忆
|
||
await self.memory_bank.create_memory_bank("tieguaili")
|
||
|
||
await self.memory_bank.add_memory(
|
||
agent_name="tieguaili",
|
||
content="在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
|
||
memory_type="preference",
|
||
debate_topic="NVIDIA投资分析"
|
||
)
|
||
|
||
await self.memory_bank.add_memory(
|
||
agent_name="tieguaili",
|
||
content="我喜欢关注苹果公司的创新产品发布会。",
|
||
memory_type="preference",
|
||
debate_topic="AAPL投资分析"
|
||
)
|
||
|
||
# 搜索NVIDIA相关记忆
|
||
results = await self.memory_bank.search_memories(
|
||
agent_name="tieguaili",
|
||
query="NVIDIA"
|
||
)
|
||
|
||
self.assertEqual(len(results), 1)
|
||
self.assertEqual(results[0]["content"], "在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。")
|
||
self.assertIn("relevance_score", results[0])
|
||
|
||
async def test_search_memories_with_type_filter(self):
|
||
"""测试带类型过滤的搜索记忆"""
|
||
# 先创建记忆银行并添加不同类型的记忆
|
||
await self.memory_bank.create_memory_bank("tieguaili")
|
||
|
||
await self.memory_bank.add_memory(
|
||
agent_name="tieguaili",
|
||
content="在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
|
||
memory_type="preference",
|
||
debate_topic="NVIDIA投资分析"
|
||
)
|
||
|
||
await self.memory_bank.add_memory(
|
||
agent_name="tieguaili",
|
||
content="在NVIDIA的辩论中,我使用了技术分析策略。",
|
||
memory_type="strategy",
|
||
debate_topic="NVIDIA投资分析"
|
||
)
|
||
|
||
# 搜索NVIDIA相关记忆,只返回preference类型
|
||
results = await self.memory_bank.search_memories(
|
||
agent_name="tieguaili",
|
||
query="NVIDIA",
|
||
memory_type="preference"
|
||
)
|
||
|
||
self.assertEqual(len(results), 1)
|
||
self.assertEqual(results[0]["metadata"]["memory_type"], "preference")
|
||
|
||
async def test_search_memories_no_results(self):
|
||
"""测试搜索无结果的情况"""
|
||
# 搜索不存在的记忆银行
|
||
results = await self.memory_bank.search_memories(
|
||
agent_name="nonexistent",
|
||
query="test"
|
||
)
|
||
|
||
self.assertEqual(results, [])
|
||
|
||
# 搜索空的记忆银行
|
||
await self.memory_bank.create_memory_bank("tieguaili")
|
||
results = await self.memory_bank.search_memories(
|
||
agent_name="tieguaili",
|
||
query="test"
|
||
)
|
||
|
||
self.assertEqual(results, [])
|
||
|
||
async def test_get_agent_context(self):
|
||
"""测试获取智能体上下文"""
|
||
# 先创建记忆银行并添加一些记忆
|
||
await self.memory_bank.create_memory_bank("tieguaili")
|
||
|
||
await self.memory_bank.add_memory(
|
||
agent_name="tieguaili",
|
||
content="在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
|
||
memory_type="preference",
|
||
debate_topic="NVIDIA投资分析"
|
||
)
|
||
|
||
await self.memory_bank.add_memory(
|
||
agent_name="tieguaili",
|
||
content="在NVIDIA的辩论中,我使用了技术分析策略。",
|
||
memory_type="strategy",
|
||
debate_topic="NVIDIA投资分析"
|
||
)
|
||
|
||
# 获取上下文
|
||
context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
|
||
|
||
# 验证上下文包含预期内容
|
||
self.assertIn("# 铁拐李的记忆上下文", context)
|
||
self.assertIn("## 偏好记忆", context)
|
||
self.assertIn("## 策略记忆", context)
|
||
self.assertIn("在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。", context)
|
||
self.assertIn("在NVIDIA的辩论中,我使用了技术分析策略。", context)
|
||
|
||
async def test_get_agent_context_no_memories(self):
|
||
"""测试获取智能体上下文但无相关记忆"""
|
||
# 先创建记忆银行
|
||
await self.memory_bank.create_memory_bank("tieguaili")
|
||
|
||
# 获取上下文
|
||
context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
|
||
|
||
# 验证上下文包含暂无相关记忆的提示
|
||
self.assertIn("# 铁拐李的记忆上下文", context)
|
||
self.assertIn("暂无相关记忆。", context)
|
||
|
||
async def test_save_debate_session(self):
|
||
"""测试保存辩论会话"""
|
||
conversation_history = [
|
||
{"agent": "tieguaili", "content": "NVIDIA的估值过高,存在泡沫风险。"},
|
||
{"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
|
||
{"agent": "tieguaili", "content": "但我们需要考虑竞争加剧和增长放缓的可能性。"}
|
||
]
|
||
|
||
outcomes = {
|
||
"winner": "lvdongbin",
|
||
"insights": {
|
||
"tieguaili": "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。"
|
||
}
|
||
}
|
||
|
||
# 保存辩论会话
|
||
await self.memory_bank.save_debate_session(
|
||
debate_topic="NVIDIA投资分析",
|
||
participants=["tieguaili", "lvdongbin"],
|
||
conversation_history=conversation_history,
|
||
outcomes=outcomes
|
||
)
|
||
|
||
# 验证铁拐李的记忆已保存
|
||
self.assertIn("tieguaili", self.memory_bank.local_memories)
|
||
self.assertEqual(len(self.memory_bank.local_memories["tieguaili"]), 2)
|
||
|
||
# 验证第一条记忆是对话总结
|
||
summary_memory = self.memory_bank.local_memories["tieguaili"][0]
|
||
self.assertEqual(summary_memory["memory_type"], "conversation")
|
||
self.assertIn("铁拐李在本次辩论中的主要观点", summary_memory["content"])
|
||
|
||
# 验证第二条记忆是策略洞察
|
||
strategy_memory = self.memory_bank.local_memories["tieguaili"][1]
|
||
self.assertEqual(strategy_memory["memory_type"], "strategy")
|
||
self.assertIn("铁拐李的风险意识值得肯定", strategy_memory["content"])
|
||
|
||
def test_summarize_conversation(self):
|
||
"""测试对话总结"""
|
||
conversation_history = [
|
||
{"agent": "tieguaili", "content": "第一点看法:NVIDIA的估值过高,存在泡沫风险。"},
|
||
{"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
|
||
{"agent": "tieguaili", "content": "第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。"},
|
||
{"agent": "tieguaili", "content": "第三点看法:从技术分析角度看,股价已出现超买信号。"}
|
||
]
|
||
|
||
summary = self.memory_bank._summarize_conversation(conversation_history, "tieguaili")
|
||
|
||
# 验证总结包含预期内容
|
||
self.assertIn("铁拐李在本次辩论中的主要观点", summary)
|
||
self.assertIn("第一点看法:NVIDIA的估值过高,存在泡沫风险。", summary)
|
||
self.assertIn("第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。", summary)
|
||
self.assertIn("第三点看法:从技术分析角度看,股价已出现超买信号。", summary)
|
||
|
||
def test_extract_strategy_insight_winner(self):
|
||
"""测试提取策略洞察 - 获胜者"""
|
||
outcomes = {
|
||
"winner": "tieguaili",
|
||
"insights": {}
|
||
}
|
||
|
||
insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
|
||
|
||
self.assertIn("铁拐李在本次辩论中获胜", insight)
|
||
|
||
def test_extract_strategy_insight_from_insights(self):
|
||
"""测试从洞察中提取策略洞察"""
|
||
outcomes = {
|
||
"winner": "lvdongbin",
|
||
"insights": {
|
||
"tieguaili": "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。"
|
||
}
|
||
}
|
||
|
||
insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
|
||
|
||
self.assertEqual(insight, "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。")
|
||
|
||
|
||
if __name__ == '__main__':
|
||
# 创建一个异步测试运行器
|
||
def run_async_test(test_case):
|
||
"""运行异步测试用例"""
|
||
loop = asyncio.new_event_loop()
|
||
asyncio.set_event_loop(loop)
|
||
try:
|
||
return loop.run_until_complete(test_case)
|
||
finally:
|
||
loop.close()
|
||
|
||
# 获取所有以test_开头的异步方法并运行它们
|
||
suite = unittest.TestSuite()
|
||
test_instance = TestVertexMemoryBank()
|
||
test_instance.setUp()
|
||
|
||
# 添加同步测试
|
||
suite.addTest(TestVertexMemoryBank('test_init'))
|
||
suite.addTest(TestVertexMemoryBank('test_from_config'))
|
||
suite.addTest(TestVertexMemoryBank('test_from_config_missing_project_id'))
|
||
suite.addTest(TestVertexMemoryBank('test_summarize_conversation'))
|
||
suite.addTest(TestVertexMemoryBank('test_extract_strategy_insight_winner'))
|
||
suite.addTest(TestVertexMemoryBank('test_extract_strategy_insight_from_insights'))
|
||
|
||
# 添加异步测试
|
||
async_tests = [
|
||
'test_create_memory_bank',
|
||
'test_create_memory_bank_with_display_name',
|
||
'test_add_memory',
|
||
'test_add_memory_creates_bank_if_not_exists',
|
||
'test_search_memories',
|
||
'test_search_memories_with_type_filter',
|
||
'test_search_memories_no_results',
|
||
'test_get_agent_context',
|
||
'test_get_agent_context_no_memories',
|
||
'test_save_debate_session'
|
||
]
|
||
|
||
for test_name in async_tests:
|
||
test_method = getattr(test_instance, test_name)
|
||
suite.addTest(unittest.FunctionTestCase(lambda tm=test_method: run_async_test(tm())))
|
||
|
||
# 运行测试
|
||
runner = unittest.TextTestRunner(verbosity=2)
|
||
runner.run(suite) |