384 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			384 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
| #!/usr/bin/env python3
 | ||
| """
 | ||
| Vertex Memory Bank 实现测试
 | ||
| """
 | ||
| 
 | ||
| import unittest
 | ||
| import asyncio
 | ||
| import os
 | ||
| import sys
 | ||
| from unittest.mock import patch, MagicMock, AsyncMock
 | ||
| from datetime import datetime
 | ||
| 
 | ||
| # 添加项目根目录到Python路径
 | ||
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 | ||
| 
 | ||
| from src.jixia.memory.vertex_memory_bank import VertexMemoryBank, MemoryEntry
 | ||
| 
 | ||
| 
 | ||
| class TestVertexMemoryBank(unittest.TestCase):
 | ||
|     """测试VertexMemoryBank类"""
 | ||
| 
 | ||
|     def setUp(self):
 | ||
|         """测试前的设置"""
 | ||
|         # Mock掉aiplatform.init以避免实际初始化
 | ||
|         patcher = patch('src.jixia.memory.vertex_memory_bank.aiplatform.init')
 | ||
|         self.mock_init = patcher.start()
 | ||
|         self.addCleanup(patcher.stop)
 | ||
|         
 | ||
|         # 创建VertexMemoryBank实例
 | ||
|         self.memory_bank = VertexMemoryBank(
 | ||
|             project_id="test-project",
 | ||
|             location="us-central1"
 | ||
|         )
 | ||
|         
 | ||
|         # 重置本地存储
 | ||
|         self.memory_bank.local_memories = {}
 | ||
|         self.memory_bank.memory_banks = {}
 | ||
| 
 | ||
|     def test_init(self):
 | ||
|         """测试初始化"""
 | ||
|         self.assertEqual(self.memory_bank.project_id, "test-project")
 | ||
|         self.assertEqual(self.memory_bank.location, "us-central1")
 | ||
|         self.assertEqual(self.memory_bank.local_memories, {})
 | ||
|         self.assertEqual(self.memory_bank.memory_banks, {})
 | ||
|         
 | ||
|         # 验证调用了aiplatform.init
 | ||
|         self.mock_init.assert_called_once_with(project="test-project", location="us-central1")
 | ||
| 
 | ||
|     def test_from_config(self):
 | ||
|         """测试从配置创建实例"""
 | ||
|         with patch('src.jixia.memory.vertex_memory_bank.get_google_genai_config') as mock_config:
 | ||
|             mock_config.return_value = {
 | ||
|                 'project_id': 'config-project',
 | ||
|                 'location': 'europe-west1'
 | ||
|             }
 | ||
|             
 | ||
|             memory_bank = VertexMemoryBank.from_config()
 | ||
|             
 | ||
|             self.assertEqual(memory_bank.project_id, "config-project")
 | ||
|             self.assertEqual(memory_bank.location, "europe-west1")
 | ||
| 
 | ||
|     def test_from_config_missing_project_id(self):
 | ||
|         """测试从配置创建实例时缺少project_id"""
 | ||
|         with patch('src.jixia.memory.vertex_memory_bank.get_google_genai_config') as mock_config:
 | ||
|             mock_config.return_value = {
 | ||
|                 'project_id': None,
 | ||
|                 'location': 'europe-west1'
 | ||
|             }
 | ||
|             
 | ||
|             with self.assertRaises(ValueError) as context:
 | ||
|                 VertexMemoryBank.from_config()
 | ||
|             
 | ||
|             self.assertIn("Google Cloud Project ID 未配置", str(context.exception))
 | ||
| 
 | ||
|     async def test_create_memory_bank(self):
 | ||
|         """测试创建记忆银行"""
 | ||
|         memory_bank_id = await self.memory_bank.create_memory_bank("tieguaili")
 | ||
|         
 | ||
|         # 验证返回的ID格式
 | ||
|         self.assertEqual(memory_bank_id, "memory_bank_tieguaili_test-project")
 | ||
|         
 | ||
|         # 验证内部状态
 | ||
|         self.assertIn("tieguaili", self.memory_bank.memory_banks)
 | ||
|         self.assertEqual(self.memory_bank.memory_banks["tieguaili"], memory_bank_id)
 | ||
|         self.assertIn("tieguaili", self.memory_bank.local_memories)
 | ||
|         self.assertEqual(self.memory_bank.local_memories["tieguaili"], [])
 | ||
| 
 | ||
|     async def test_create_memory_bank_with_display_name(self):
 | ||
|         """测试创建记忆银行时指定显示名称"""
 | ||
|         memory_bank_id = await self.memory_bank.create_memory_bank(
 | ||
|             "tieguaili", 
 | ||
|             "铁拐李的专属记忆银行"
 | ||
|         )
 | ||
|         
 | ||
|         # 验证返回的ID格式
 | ||
|         self.assertEqual(memory_bank_id, "memory_bank_tieguaili_test-project")
 | ||
| 
 | ||
|     async def test_add_memory(self):
 | ||
|         """测试添加记忆"""
 | ||
|         # 先创建记忆银行
 | ||
|         await self.memory_bank.create_memory_bank("tieguaili")
 | ||
|         
 | ||
|         # 添加记忆
 | ||
|         memory_id = await self.memory_bank.add_memory(
 | ||
|             agent_name="tieguaili",
 | ||
|             content="在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
 | ||
|             memory_type="preference",
 | ||
|             debate_topic="NVIDIA投资分析",
 | ||
|             metadata={"source": "manual"}
 | ||
|         )
 | ||
|         
 | ||
|         # 验证返回的ID格式
 | ||
|         self.assertEqual(memory_id, "memory_tieguaili_0")
 | ||
|         
 | ||
|         # 验证记忆已存储
 | ||
|         self.assertEqual(len(self.memory_bank.local_memories["tieguaili"]), 1)
 | ||
|         
 | ||
|         stored_memory = self.memory_bank.local_memories["tieguaili"][0]
 | ||
|         self.assertEqual(stored_memory["id"], memory_id)
 | ||
|         self.assertEqual(stored_memory["content"], "在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。")
 | ||
|         self.assertEqual(stored_memory["memory_type"], "preference")
 | ||
|         self.assertEqual(stored_memory["debate_topic"], "NVIDIA投资分析")
 | ||
|         self.assertIn("source", stored_memory["metadata"])
 | ||
|         self.assertEqual(stored_memory["metadata"]["source"], "manual")
 | ||
|         self.assertIn("agent_name", stored_memory["metadata"])
 | ||
|         self.assertEqual(stored_memory["metadata"]["agent_name"], "tieguaili")
 | ||
| 
 | ||
|     async def test_add_memory_creates_bank_if_not_exists(self):
 | ||
|         """测试添加记忆时自动创建记忆银行"""
 | ||
|         # 不先创建记忆银行,直接添加记忆
 | ||
|         memory_id = await self.memory_bank.add_memory(
 | ||
|             agent_name="tieguaili",
 | ||
|             content="测试内容"
 | ||
|         )
 | ||
|         
 | ||
|         # 验证记忆银行已被自动创建
 | ||
|         self.assertIn("tieguaili", self.memory_bank.memory_banks)
 | ||
|         self.assertIn("tieguaili", self.memory_bank.local_memories)
 | ||
|         
 | ||
|         # 验证记忆已存储
 | ||
|         self.assertEqual(len(self.memory_bank.local_memories["tieguaili"]), 1)
 | ||
| 
 | ||
|     async def test_search_memories(self):
 | ||
|         """测试搜索记忆"""
 | ||
|         # 先创建记忆银行并添加一些记忆
 | ||
|         await self.memory_bank.create_memory_bank("tieguaili")
 | ||
|         
 | ||
|         await self.memory_bank.add_memory(
 | ||
|             agent_name="tieguaili",
 | ||
|             content="在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
 | ||
|             memory_type="preference",
 | ||
|             debate_topic="NVIDIA投资分析"
 | ||
|         )
 | ||
|         
 | ||
|         await self.memory_bank.add_memory(
 | ||
|             agent_name="tieguaili",
 | ||
|             content="我喜欢关注苹果公司的创新产品发布会。",
 | ||
|             memory_type="preference",
 | ||
|             debate_topic="AAPL投资分析"
 | ||
|         )
 | ||
|         
 | ||
|         # 搜索NVIDIA相关记忆
 | ||
|         results = await self.memory_bank.search_memories(
 | ||
|             agent_name="tieguaili",
 | ||
|             query="NVIDIA"
 | ||
|         )
 | ||
|         
 | ||
|         self.assertEqual(len(results), 1)
 | ||
|         self.assertEqual(results[0]["content"], "在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。")
 | ||
|         self.assertIn("relevance_score", results[0])
 | ||
| 
 | ||
|     async def test_search_memories_with_type_filter(self):
 | ||
|         """测试带类型过滤的搜索记忆"""
 | ||
|         # 先创建记忆银行并添加不同类型的记忆
 | ||
|         await self.memory_bank.create_memory_bank("tieguaili")
 | ||
|         
 | ||
|         await self.memory_bank.add_memory(
 | ||
|             agent_name="tieguaili",
 | ||
|             content="在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
 | ||
|             memory_type="preference",
 | ||
|             debate_topic="NVIDIA投资分析"
 | ||
|         )
 | ||
|         
 | ||
|         await self.memory_bank.add_memory(
 | ||
|             agent_name="tieguaili",
 | ||
|             content="在NVIDIA的辩论中,我使用了技术分析策略。",
 | ||
|             memory_type="strategy",
 | ||
|             debate_topic="NVIDIA投资分析"
 | ||
|         )
 | ||
|         
 | ||
|         # 搜索NVIDIA相关记忆,只返回preference类型
 | ||
|         results = await self.memory_bank.search_memories(
 | ||
|             agent_name="tieguaili",
 | ||
|             query="NVIDIA",
 | ||
|             memory_type="preference"
 | ||
|         )
 | ||
|         
 | ||
|         self.assertEqual(len(results), 1)
 | ||
|         self.assertEqual(results[0]["metadata"]["memory_type"], "preference")
 | ||
| 
 | ||
|     async def test_search_memories_no_results(self):
 | ||
|         """测试搜索无结果的情况"""
 | ||
|         # 搜索不存在的记忆银行
 | ||
|         results = await self.memory_bank.search_memories(
 | ||
|             agent_name="nonexistent",
 | ||
|             query="test"
 | ||
|         )
 | ||
|         
 | ||
|         self.assertEqual(results, [])
 | ||
|         
 | ||
|         # 搜索空的记忆银行
 | ||
|         await self.memory_bank.create_memory_bank("tieguaili")
 | ||
|         results = await self.memory_bank.search_memories(
 | ||
|             agent_name="tieguaili",
 | ||
|             query="test"
 | ||
|         )
 | ||
|         
 | ||
|         self.assertEqual(results, [])
 | ||
| 
 | ||
|     async def test_get_agent_context(self):
 | ||
|         """测试获取智能体上下文"""
 | ||
|         # 先创建记忆银行并添加一些记忆
 | ||
|         await self.memory_bank.create_memory_bank("tieguaili")
 | ||
|         
 | ||
|         await self.memory_bank.add_memory(
 | ||
|             agent_name="tieguaili",
 | ||
|             content="在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
 | ||
|             memory_type="preference",
 | ||
|             debate_topic="NVIDIA投资分析"
 | ||
|         )
 | ||
|         
 | ||
|         await self.memory_bank.add_memory(
 | ||
|             agent_name="tieguaili",
 | ||
|             content="在NVIDIA的辩论中,我使用了技术分析策略。",
 | ||
|             memory_type="strategy",
 | ||
|             debate_topic="NVIDIA投资分析"
 | ||
|         )
 | ||
|         
 | ||
|         # 获取上下文
 | ||
|         context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
 | ||
|         
 | ||
|         # 验证上下文包含预期内容
 | ||
|         self.assertIn("# 铁拐李的记忆上下文", context)
 | ||
|         self.assertIn("## 偏好记忆", context)
 | ||
|         self.assertIn("## 策略记忆", context)
 | ||
|         self.assertIn("在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。", context)
 | ||
|         self.assertIn("在NVIDIA的辩论中,我使用了技术分析策略。", context)
 | ||
| 
 | ||
|     async def test_get_agent_context_no_memories(self):
 | ||
|         """测试获取智能体上下文但无相关记忆"""
 | ||
|         # 先创建记忆银行
 | ||
|         await self.memory_bank.create_memory_bank("tieguaili")
 | ||
|         
 | ||
|         # 获取上下文
 | ||
|         context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
 | ||
|         
 | ||
|         # 验证上下文包含暂无相关记忆的提示
 | ||
|         self.assertIn("# 铁拐李的记忆上下文", context)
 | ||
|         self.assertIn("暂无相关记忆。", context)
 | ||
| 
 | ||
|     async def test_save_debate_session(self):
 | ||
|         """测试保存辩论会话"""
 | ||
|         conversation_history = [
 | ||
|             {"agent": "tieguaili", "content": "NVIDIA的估值过高,存在泡沫风险。"},
 | ||
|             {"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
 | ||
|             {"agent": "tieguaili", "content": "但我们需要考虑竞争加剧和增长放缓的可能性。"}
 | ||
|         ]
 | ||
|         
 | ||
|         outcomes = {
 | ||
|             "winner": "lvdongbin",
 | ||
|             "insights": {
 | ||
|                 "tieguaili": "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。"
 | ||
|             }
 | ||
|         }
 | ||
|         
 | ||
|         # 保存辩论会话
 | ||
|         await self.memory_bank.save_debate_session(
 | ||
|             debate_topic="NVIDIA投资分析",
 | ||
|             participants=["tieguaili", "lvdongbin"],
 | ||
|             conversation_history=conversation_history,
 | ||
|             outcomes=outcomes
 | ||
|         )
 | ||
|         
 | ||
|         # 验证铁拐李的记忆已保存
 | ||
|         self.assertIn("tieguaili", self.memory_bank.local_memories)
 | ||
|         self.assertEqual(len(self.memory_bank.local_memories["tieguaili"]), 2)
 | ||
|         
 | ||
|         # 验证第一条记忆是对话总结
 | ||
|         summary_memory = self.memory_bank.local_memories["tieguaili"][0]
 | ||
|         self.assertEqual(summary_memory["memory_type"], "conversation")
 | ||
|         self.assertIn("铁拐李在本次辩论中的主要观点", summary_memory["content"])
 | ||
|         
 | ||
|         # 验证第二条记忆是策略洞察
 | ||
|         strategy_memory = self.memory_bank.local_memories["tieguaili"][1]
 | ||
|         self.assertEqual(strategy_memory["memory_type"], "strategy")
 | ||
|         self.assertIn("铁拐李的风险意识值得肯定", strategy_memory["content"])
 | ||
| 
 | ||
|     def test_summarize_conversation(self):
 | ||
|         """测试对话总结"""
 | ||
|         conversation_history = [
 | ||
|             {"agent": "tieguaili", "content": "第一点看法:NVIDIA的估值过高,存在泡沫风险。"},
 | ||
|             {"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
 | ||
|             {"agent": "tieguaili", "content": "第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。"},
 | ||
|             {"agent": "tieguaili", "content": "第三点看法:从技术分析角度看,股价已出现超买信号。"}
 | ||
|         ]
 | ||
|         
 | ||
|         summary = self.memory_bank._summarize_conversation(conversation_history, "tieguaili")
 | ||
|         
 | ||
|         # 验证总结包含预期内容
 | ||
|         self.assertIn("铁拐李在本次辩论中的主要观点", summary)
 | ||
|         self.assertIn("第一点看法:NVIDIA的估值过高,存在泡沫风险。", summary)
 | ||
|         self.assertIn("第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。", summary)
 | ||
|         self.assertIn("第三点看法:从技术分析角度看,股价已出现超买信号。", summary)
 | ||
| 
 | ||
|     def test_extract_strategy_insight_winner(self):
 | ||
|         """测试提取策略洞察 - 获胜者"""
 | ||
|         outcomes = {
 | ||
|             "winner": "tieguaili",
 | ||
|             "insights": {}
 | ||
|         }
 | ||
|         
 | ||
|         insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
 | ||
|         
 | ||
|         self.assertIn("铁拐李在本次辩论中获胜", insight)
 | ||
| 
 | ||
|     def test_extract_strategy_insight_from_insights(self):
 | ||
|         """测试从洞察中提取策略洞察"""
 | ||
|         outcomes = {
 | ||
|             "winner": "lvdongbin",
 | ||
|             "insights": {
 | ||
|                 "tieguaili": "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。"
 | ||
|             }
 | ||
|         }
 | ||
|         
 | ||
|         insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
 | ||
|         
 | ||
|         self.assertEqual(insight, "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。")
 | ||
| 
 | ||
| 
 | ||
| if __name__ == '__main__':
 | ||
|     # 创建一个异步测试运行器
 | ||
|     def run_async_test(test_case):
 | ||
|         """运行异步测试用例"""
 | ||
|         loop = asyncio.new_event_loop()
 | ||
|         asyncio.set_event_loop(loop)
 | ||
|         try:
 | ||
|             return loop.run_until_complete(test_case)
 | ||
|         finally:
 | ||
|             loop.close()
 | ||
|     
 | ||
|     # 获取所有以test_开头的异步方法并运行它们
 | ||
|     suite = unittest.TestSuite()
 | ||
|     test_instance = TestVertexMemoryBank()
 | ||
|     test_instance.setUp()
 | ||
|     
 | ||
|     # 添加同步测试
 | ||
|     suite.addTest(TestVertexMemoryBank('test_init'))
 | ||
|     suite.addTest(TestVertexMemoryBank('test_from_config'))
 | ||
|     suite.addTest(TestVertexMemoryBank('test_from_config_missing_project_id'))
 | ||
|     suite.addTest(TestVertexMemoryBank('test_summarize_conversation'))
 | ||
|     suite.addTest(TestVertexMemoryBank('test_extract_strategy_insight_winner'))
 | ||
|     suite.addTest(TestVertexMemoryBank('test_extract_strategy_insight_from_insights'))
 | ||
|     
 | ||
|     # 添加异步测试
 | ||
|     async_tests = [
 | ||
|         'test_create_memory_bank',
 | ||
|         'test_create_memory_bank_with_display_name',
 | ||
|         'test_add_memory',
 | ||
|         'test_add_memory_creates_bank_if_not_exists',
 | ||
|         'test_search_memories',
 | ||
|         'test_search_memories_with_type_filter',
 | ||
|         'test_search_memories_no_results',
 | ||
|         'test_get_agent_context',
 | ||
|         'test_get_agent_context_no_memories',
 | ||
|         'test_save_debate_session'
 | ||
|     ]
 | ||
|     
 | ||
|     for test_name in async_tests:
 | ||
|         test_method = getattr(test_instance, test_name)
 | ||
|         suite.addTest(unittest.FunctionTestCase(lambda tm=test_method: run_async_test(tm())))
 | ||
|     
 | ||
|     # 运行测试
 | ||
|     runner = unittest.TextTestRunner(verbosity=2)
 | ||
|     runner.run(suite) |