436 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			436 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
| #!/usr/bin/env python3
 | ||
| """
 | ||
| Cloudflare Memory Bank 实现测试
 | ||
| """
 | ||
| 
 | ||
| import unittest
 | ||
| import asyncio
 | ||
| import os
 | ||
| import sys
 | ||
| from unittest.mock import patch, MagicMock, AsyncMock
 | ||
| from datetime import datetime
 | ||
| 
 | ||
| # 添加项目根目录到Python路径
 | ||
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 | ||
| 
 | ||
| from src.jixia.memory.cloudflare_memory_bank import CloudflareMemoryBank, MemoryEntry
 | ||
| 
 | ||
| 
 | ||
| class TestCloudflareMemoryBank(unittest.TestCase):
 | ||
|     """测试CloudflareMemoryBank类"""
 | ||
| 
 | ||
|     def setUp(self):
 | ||
|         """测试前的设置"""
 | ||
|         # Mock掉 aiohttp.ClientSession 以避免实际网络请求
 | ||
|         self.patcher = patch('src.jixia.memory.cloudflare_memory_bank.aiohttp.ClientSession')
 | ||
|         self.mock_session_class = self.patcher.start()
 | ||
|         self.mock_session = AsyncMock()
 | ||
|         self.mock_session_class.return_value = self.mock_session
 | ||
|         
 | ||
|         # Mock掉 get_cloudflare_config 以避免实际读取配置
 | ||
|         self.config_patcher = patch('src.jixia.memory.cloudflare_memory_bank.get_cloudflare_config')
 | ||
|         self.mock_get_config = self.config_patcher.start()
 | ||
|         self.mock_get_config.return_value = {
 | ||
|             'account_id': 'test-account',
 | ||
|             'api_token': 'test-token',
 | ||
|             'vectorize_index': 'test-index',
 | ||
|             'embed_model': '@cf/baai/bge-m3',
 | ||
|             'autorag_domain': 'test.example.com'
 | ||
|         }
 | ||
|         
 | ||
|         # 创建CloudflareMemoryBank实例
 | ||
|         self.memory_bank = CloudflareMemoryBank()
 | ||
|         
 | ||
|         # 重置一些内部状态
 | ||
|         self.memory_bank.config = self.mock_get_config.return_value
 | ||
|         self.memory_bank.account_id = 'test-account'
 | ||
|         self.memory_bank.api_token = 'test-token'
 | ||
|         self.memory_bank.vectorize_index = 'test-index'
 | ||
|         self.memory_bank.embed_model = '@cf/baai/bge-m3'
 | ||
|         self.memory_bank.autorag_domain = 'test.example.com'
 | ||
| 
 | ||
|     def tearDown(self):
 | ||
|         """测试后的清理"""
 | ||
|         self.patcher.stop()
 | ||
|         self.config_patcher.stop()
 | ||
| 
 | ||
|     def test_init(self):
 | ||
|         """测试初始化"""
 | ||
|         self.assertEqual(self.memory_bank.account_id, "test-account")
 | ||
|         self.assertEqual(self.memory_bank.api_token, "test-token")
 | ||
|         self.assertEqual(self.memory_bank.vectorize_index, "test-index")
 | ||
|         self.assertEqual(self.memory_bank.embed_model, "@cf/baai/bge-m3")
 | ||
|         self.assertEqual(self.memory_bank.autorag_domain, "test.example.com")
 | ||
| 
 | ||
|     async def test_create_memory_bank(self):
 | ||
|         """测试创建记忆空间"""
 | ||
|         memory_bank_id = await self.memory_bank.create_memory_bank("tieguaili")
 | ||
|         
 | ||
|         # 验证返回的ID格式
 | ||
|         self.assertEqual(memory_bank_id, "cf_memory_tieguaili")
 | ||
| 
 | ||
|     async def test_create_memory_bank_with_display_name(self):
 | ||
|         """测试创建记忆空间时指定显示名称"""
 | ||
|         memory_bank_id = await self.memory_bank.create_memory_bank(
 | ||
|             "tieguaili", 
 | ||
|             "铁拐李的专属记忆银行"
 | ||
|         )
 | ||
|         
 | ||
|         # 验证返回的ID格式
 | ||
|         self.assertEqual(memory_bank_id, "cf_memory_tieguaili")
 | ||
| 
 | ||
|     async def test_generate_embedding(self):
 | ||
|         """测试生成嵌入向量"""
 | ||
|         # Mock响应
 | ||
|         mock_response = AsyncMock()
 | ||
|         mock_response.status = 200
 | ||
|         mock_response.json = AsyncMock(return_value={
 | ||
|             "result": {
 | ||
|                 "data": [
 | ||
|                     {
 | ||
|                         "embedding": [0.1, 0.2, 0.3, 0.4, 0.5]
 | ||
|                     }
 | ||
|                 ]
 | ||
|             }
 | ||
|         })
 | ||
|         
 | ||
|         # Mock session.post
 | ||
|         self.mock_session.post.return_value.__aenter__.return_value = mock_response
 | ||
|         
 | ||
|         # 调用方法
 | ||
|         embedding = await self.memory_bank._generate_embedding("测试文本")
 | ||
|         
 | ||
|         # 验证结果
 | ||
|         self.assertEqual(embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
 | ||
|         
 | ||
|         # 验证调用了正确的URL和参数
 | ||
|         expected_url = "https://api.cloudflare.com/client/v4/accounts/test-account/ai/run/@cf/baai/bge-m3"
 | ||
|         self.mock_session.post.assert_called_once()
 | ||
|         call_args = self.mock_session.post.call_args
 | ||
|         self.assertEqual(call_args[0][0], expected_url)
 | ||
|         self.assertEqual(call_args[1]['json'], {"text": ["测试文本"]})
 | ||
| 
 | ||
|     async def test_generate_embedding_api_error(self):
 | ||
|         """测试生成嵌入向量时API错误"""
 | ||
|         # Mock响应
 | ||
|         mock_response = AsyncMock()
 | ||
|         mock_response.status = 500
 | ||
|         mock_response.text = AsyncMock(return_value="Internal Server Error")
 | ||
|         
 | ||
|         # Mock session.post
 | ||
|         self.mock_session.post.return_value.__aenter__.return_value = mock_response
 | ||
|         
 | ||
|         # 验证抛出异常
 | ||
|         with self.assertRaises(Exception) as context:
 | ||
|             await self.memory_bank._generate_embedding("测试文本")
 | ||
|         
 | ||
|         self.assertIn("Failed to generate embedding", str(context.exception))
 | ||
| 
 | ||
|     async def test_add_memory(self):
 | ||
|         """测试添加记忆"""
 | ||
|         # Mock _generate_embedding 方法
 | ||
|         with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])) as mock_embed:
 | ||
|             # Mock upsert 响应
 | ||
|             mock_response = AsyncMock()
 | ||
|             mock_response.status = 200
 | ||
|             mock_response.json = AsyncMock(return_value={"result": {"upserted": 1}})
 | ||
|             
 | ||
|             # Mock session.post
 | ||
|             self.mock_session.post.return_value.__aenter__.return_value = mock_response
 | ||
|             
 | ||
|             # 添加记忆
 | ||
|             memory_id = await self.memory_bank.add_memory(
 | ||
|                 agent_name="tieguaili",
 | ||
|                 content="在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
 | ||
|                 memory_type="preference",
 | ||
|                 debate_topic="NVIDIA投资分析",
 | ||
|                 metadata={"source": "manual"}
 | ||
|             )
 | ||
|             
 | ||
|             # 验证返回的ID格式 (以mem_开头)
 | ||
|             self.assertTrue(memory_id.startswith("mem_tieguaili_"))
 | ||
|             
 | ||
|             # 验证调用了生成嵌入的方法
 | ||
|             mock_embed.assert_called_once_with("在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。")
 | ||
|             
 | ||
|             # 验证调用了upsert API
 | ||
|             self.mock_session.post.assert_called()
 | ||
|             
 | ||
|             # 验证upsert调用的参数
 | ||
|             upsert_call = None
 | ||
|             for call in self.mock_session.post.call_args_list:
 | ||
|                 if 'vectorize/indexes/test-index/upsert' in call[0][0]:
 | ||
|                     upsert_call = call
 | ||
|                     break
 | ||
|             
 | ||
|             self.assertIsNotNone(upsert_call)
 | ||
|             call_args, call_kwargs = upsert_call
 | ||
|             self.assertIn("vectorize/indexes/test-index/upsert", call_args[0])
 | ||
|             self.assertIn("vectors", call_kwargs['json'])
 | ||
| 
 | ||
|     async def test_add_memory_api_error(self):
 | ||
|         """测试添加记忆时API错误"""
 | ||
|         # Mock _generate_embedding 方法
 | ||
|         with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])):
 | ||
|             # Mock upsert 响应
 | ||
|             mock_response = AsyncMock()
 | ||
|             mock_response.status = 500
 | ||
|             mock_response.text = AsyncMock(return_value="Internal Server Error")
 | ||
|             
 | ||
|             # Mock session.post
 | ||
|             self.mock_session.post.return_value.__aenter__.return_value = mock_response
 | ||
|             
 | ||
|             # 验证抛出异常
 | ||
|             with self.assertRaises(Exception) as context:
 | ||
|                 await self.memory_bank.add_memory(
 | ||
|                     agent_name="tieguaili",
 | ||
|                     content="测试内容"
 | ||
|                 )
 | ||
|             
 | ||
|             self.assertIn("Failed to upsert memory", str(context.exception))
 | ||
| 
 | ||
|     async def test_search_memories(self):
 | ||
|         """测试搜索记忆"""
 | ||
|         # Mock _generate_embedding 方法
 | ||
|         with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])) as mock_embed:
 | ||
|             # Mock query 响应
 | ||
|             mock_response = AsyncMock()
 | ||
|             mock_response.status = 200
 | ||
|             mock_response.json = AsyncMock(return_value={
 | ||
|                 "result": {
 | ||
|                     "matches": [
 | ||
|                         {
 | ||
|                             "metadata": {
 | ||
|                                 "content": "在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
 | ||
|                                 "memory_type": "preference",
 | ||
|                                 "agent_name": "tieguaili"
 | ||
|                             },
 | ||
|                             "score": 0.95
 | ||
|                         }
 | ||
|                     ]
 | ||
|                 }
 | ||
|             })
 | ||
|             
 | ||
|             # Mock session.post
 | ||
|             self.mock_session.post.return_value.__aenter__.return_value = mock_response
 | ||
|             
 | ||
|             # 搜索记忆
 | ||
|             results = await self.memory_bank.search_memories(
 | ||
|                 agent_name="tieguaili",
 | ||
|                 query="NVIDIA",
 | ||
|                 memory_type="preference",
 | ||
|                 limit=5
 | ||
|             )
 | ||
|             
 | ||
|             # 验证结果
 | ||
|             self.assertEqual(len(results), 1)
 | ||
|             self.assertEqual(results[0]["content"], "在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。")
 | ||
|             self.assertEqual(results[0]["relevance_score"], 0.95)
 | ||
|             
 | ||
|             # 验证调用了生成嵌入的方法
 | ||
|             mock_embed.assert_called_once_with("NVIDIA")
 | ||
|             
 | ||
|             # 验证调用了query API
 | ||
|             self.mock_session.post.assert_called()
 | ||
|             
 | ||
|             # 验证query调用的参数
 | ||
|             query_call = None
 | ||
|             for call in self.mock_session.post.call_args_list:
 | ||
|                 if 'vectorize/indexes/test-index/query' in call[0][0]:
 | ||
|                     query_call = call
 | ||
|                     break
 | ||
|             
 | ||
|             self.assertIsNotNone(query_call)
 | ||
|             call_args, call_kwargs = query_call
 | ||
|             self.assertIn("vectorize/indexes/test-index/query", call_args[0])
 | ||
|             self.assertIn("vector", call_kwargs['json'])
 | ||
|             self.assertIn("filter", call_kwargs['json'])
 | ||
|             self.assertEqual(call_kwargs['json']['filter'], {"agent_name": "tieguaili", "memory_type": "preference"})
 | ||
| 
 | ||
|     async def test_search_memories_api_error(self):
 | ||
|         """测试搜索记忆时API错误"""
 | ||
|         # Mock _generate_embedding 方法
 | ||
|         with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])):
 | ||
|             # Mock query 响应
 | ||
|             mock_response = AsyncMock()
 | ||
|             mock_response.status = 500
 | ||
|             mock_response.text = AsyncMock(return_value="Internal Server Error")
 | ||
|             
 | ||
|             # Mock session.post
 | ||
|             self.mock_session.post.return_value.__aenter__.return_value = mock_response
 | ||
|             
 | ||
|             # 验证返回空列表而不是抛出异常
 | ||
|             results = await self.memory_bank.search_memories(
 | ||
|                 agent_name="tieguaili",
 | ||
|                 query="NVIDIA"
 | ||
|             )
 | ||
|             
 | ||
|             self.assertEqual(results, [])
 | ||
| 
 | ||
|     async def test_get_agent_context(self):
 | ||
|         """测试获取智能体上下文"""
 | ||
|         # Mock search_memories 方法
 | ||
|         with patch.object(self.memory_bank, 'search_memories', new=AsyncMock()) as mock_search:
 | ||
|             # 设置mock返回值
 | ||
|             mock_search.side_effect = [
 | ||
|                 [  # conversation memories
 | ||
|                     {"content": "NVIDIA的估值过高,存在泡沫风险。", "relevance_score": 0.9}
 | ||
|                 ],
 | ||
|                 [  # preference memories
 | ||
|                     {"content": "倾向于逆向思维,关注潜在风险。", "relevance_score": 0.8}
 | ||
|                 ],
 | ||
|                 [  # strategy memories
 | ||
|                     {"content": "使用技术分析策略。", "relevance_score": 0.7}
 | ||
|                 ]
 | ||
|             ]
 | ||
|             
 | ||
|             # 获取上下文
 | ||
|             context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
 | ||
|             
 | ||
|             # 验证上下文包含预期内容
 | ||
|             self.assertIn("# 铁拐李的记忆上下文", context)
 | ||
|             self.assertIn("## 历史对话记忆", context)
 | ||
|             self.assertIn("## 偏好记忆", context)
 | ||
|             self.assertIn("## 策略记忆", context)
 | ||
|             self.assertIn("NVIDIA的估值过高,存在泡沫风险。", context)
 | ||
|             self.assertIn("倾向于逆向思维,关注潜在风险。", context)
 | ||
|             self.assertIn("使用技术分析策略。", context)
 | ||
| 
 | ||
|     async def test_get_agent_context_no_memories(self):
 | ||
|         """测试获取智能体上下文但无相关记忆"""
 | ||
|         # Mock search_memories 方法
 | ||
|         with patch.object(self.memory_bank, 'search_memories', new=AsyncMock(return_value=[])):
 | ||
|             # 获取上下文
 | ||
|             context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
 | ||
|             
 | ||
|             # 验证上下文包含暂无相关记忆的提示
 | ||
|             self.assertIn("# 铁拐李的记忆上下文", context)
 | ||
|             self.assertIn("暂无相关记忆。", context)
 | ||
| 
 | ||
|     async def test_save_debate_session(self):
 | ||
|         """测试保存辩论会话"""
 | ||
|         # Mock add_memory 方法
 | ||
|         with patch.object(self.memory_bank, 'add_memory', new=AsyncMock()) as mock_add:
 | ||
|             conversation_history = [
 | ||
|                 {"agent": "tieguaili", "content": "NVIDIA的估值过高,存在泡沫风险。"},
 | ||
|                 {"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
 | ||
|                 {"agent": "tieguaili", "content": "但我们需要考虑竞争加剧和增长放缓的可能性。"}
 | ||
|             ]
 | ||
|             
 | ||
|             outcomes = {
 | ||
|                 "winner": "lvdongbin",
 | ||
|                 "insights": {
 | ||
|                     "tieguaili": "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。"
 | ||
|                 }
 | ||
|             }
 | ||
|             
 | ||
|             # 保存辩论会话
 | ||
|             await self.memory_bank.save_debate_session(
 | ||
|                 debate_topic="NVIDIA投资分析",
 | ||
|                 participants=["tieguaili", "lvdongbin"],
 | ||
|                 conversation_history=conversation_history,
 | ||
|                 outcomes=outcomes
 | ||
|             )
 | ||
|             
 | ||
|             # 验证调用了add_memory两次(对话总结和策略洞察)
 | ||
|             self.assertEqual(mock_add.call_count, 2)
 | ||
|             
 | ||
|             # 验证第一次调用是对话总结
 | ||
|             call_args1 = mock_add.call_args_list[0][1]
 | ||
|             self.assertEqual(call_args1['agent_name'], 'tieguaili')
 | ||
|             self.assertEqual(call_args1['memory_type'], 'conversation')
 | ||
|             self.assertIn('铁拐李在本次辩论中的主要观点', call_args1['content'])
 | ||
|             
 | ||
|             # 验证第二次调用是策略洞察
 | ||
|             call_args2 = mock_add.call_args_list[1][1]
 | ||
|             self.assertEqual(call_args2['agent_name'], 'tieguaili')
 | ||
|             self.assertEqual(call_args2['memory_type'], 'strategy')
 | ||
|             self.assertIn('铁拐李的风险意识值得肯定', call_args2['content'])
 | ||
| 
 | ||
|     def test_summarize_conversation(self):
 | ||
|         """测试对话总结"""
 | ||
|         conversation_history = [
 | ||
|             {"agent": "tieguaili", "content": "第一点看法:NVIDIA的估值过高,存在泡沫风险。"},
 | ||
|             {"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
 | ||
|             {"agent": "tieguaili", "content": "第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。"},
 | ||
|             {"agent": "tieguaili", "content": "第三点看法:从技术分析角度看,股价已出现超买信号。"}
 | ||
|         ]
 | ||
|         
 | ||
|         summary = self.memory_bank._summarize_conversation(conversation_history, "tieguaili")
 | ||
|         
 | ||
|         # 验证总结包含预期内容
 | ||
|         self.assertIn("铁拐李在本次辩论中的主要观点", summary)
 | ||
|         self.assertIn("第一点看法:NVIDIA的估值过高,存在泡沫风险。", summary)
 | ||
|         self.assertIn("第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。", summary)
 | ||
|         self.assertIn("第三点看法:从技术分析角度看,股价已出现超买信号。", summary)
 | ||
| 
 | ||
|     def test_extract_strategy_insight_winner(self):
 | ||
|         """测试提取策略洞察 - 获胜者"""
 | ||
|         outcomes = {
 | ||
|             "winner": "tieguaili",
 | ||
|             "insights": {}
 | ||
|         }
 | ||
|         
 | ||
|         insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
 | ||
|         
 | ||
|         self.assertIn("铁拐李在本次辩论中获胜", insight)
 | ||
| 
 | ||
|     def test_extract_strategy_insight_from_insights(self):
 | ||
|         """测试从洞察中提取策略洞察"""
 | ||
|         outcomes = {
 | ||
|             "winner": "lvdongbin",
 | ||
|             "insights": {
 | ||
|                 "tieguaili": "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。"
 | ||
|             }
 | ||
|         }
 | ||
|         
 | ||
|         insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
 | ||
|         
 | ||
|         self.assertEqual(insight, "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。")
 | ||
| 
 | ||
| 
 | ||
| if __name__ == '__main__':
 | ||
|     # 创建一个异步测试运行器
 | ||
|     def run_async_test(test_case):
 | ||
|         """运行异步测试用例"""
 | ||
|         loop = asyncio.new_event_loop()
 | ||
|         asyncio.set_event_loop(loop)
 | ||
|         try:
 | ||
|             return loop.run_until_complete(test_case)
 | ||
|         finally:
 | ||
|             loop.close()
 | ||
|     
 | ||
|     # 获取所有以test_开头的异步方法并运行它们
 | ||
|     suite = unittest.TestSuite()
 | ||
|     test_instance = TestCloudflareMemoryBank()
 | ||
|     test_instance.setUp()
 | ||
|     test_instance.addCleanup(test_instance.tearDown)
 | ||
|     
 | ||
|     # 添加同步测试
 | ||
|     suite.addTest(TestCloudflareMemoryBank('test_init'))
 | ||
|     suite.addTest(TestCloudflareMemoryBank('test_summarize_conversation'))
 | ||
|     suite.addTest(TestCloudflareMemoryBank('test_extract_strategy_insight_winner'))
 | ||
|     suite.addTest(TestCloudflareMemoryBank('test_extract_strategy_insight_from_insights'))
 | ||
|     
 | ||
|     # 添加异步测试
 | ||
|     async_tests = [
 | ||
|         'test_create_memory_bank',
 | ||
|         'test_create_memory_bank_with_display_name',
 | ||
|         'test_generate_embedding',
 | ||
|         'test_generate_embedding_api_error',
 | ||
|         'test_add_memory',
 | ||
|         'test_add_memory_api_error',
 | ||
|         'test_search_memories',
 | ||
|         'test_search_memories_api_error',
 | ||
|         'test_get_agent_context',
 | ||
|         'test_get_agent_context_no_memories',
 | ||
|         'test_save_debate_session'
 | ||
|     ]
 | ||
|     
 | ||
|     for test_name in async_tests:
 | ||
|         test_method = getattr(test_instance, test_name)
 | ||
|         suite.addTest(unittest.FunctionTestCase(lambda tm=test_method: run_async_test(tm())))
 | ||
|     
 | ||
|     # 运行测试
 | ||
|     runner = unittest.TextTestRunner(verbosity=2)
 | ||
|     runner.run(suite) |