#!/usr/bin/env python3 """ Cloudflare Memory Bank 实现测试 """ import unittest import asyncio import os import sys from unittest.mock import patch, MagicMock, AsyncMock from datetime import datetime # 添加项目根目录到Python路径 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from src.jixia.memory.cloudflare_memory_bank import CloudflareMemoryBank, MemoryEntry class TestCloudflareMemoryBank(unittest.TestCase): """测试CloudflareMemoryBank类""" def setUp(self): """测试前的设置""" # Mock掉 aiohttp.ClientSession 以避免实际网络请求 self.patcher = patch('src.jixia.memory.cloudflare_memory_bank.aiohttp.ClientSession') self.mock_session_class = self.patcher.start() self.mock_session = AsyncMock() self.mock_session_class.return_value = self.mock_session # Mock掉 get_cloudflare_config 以避免实际读取配置 self.config_patcher = patch('src.jixia.memory.cloudflare_memory_bank.get_cloudflare_config') self.mock_get_config = self.config_patcher.start() self.mock_get_config.return_value = { 'account_id': 'test-account', 'api_token': 'test-token', 'vectorize_index': 'test-index', 'embed_model': '@cf/baai/bge-m3', 'autorag_domain': 'test.example.com' } # 创建CloudflareMemoryBank实例 self.memory_bank = CloudflareMemoryBank() # 重置一些内部状态 self.memory_bank.config = self.mock_get_config.return_value self.memory_bank.account_id = 'test-account' self.memory_bank.api_token = 'test-token' self.memory_bank.vectorize_index = 'test-index' self.memory_bank.embed_model = '@cf/baai/bge-m3' self.memory_bank.autorag_domain = 'test.example.com' def tearDown(self): """测试后的清理""" self.patcher.stop() self.config_patcher.stop() def test_init(self): """测试初始化""" self.assertEqual(self.memory_bank.account_id, "test-account") self.assertEqual(self.memory_bank.api_token, "test-token") self.assertEqual(self.memory_bank.vectorize_index, "test-index") self.assertEqual(self.memory_bank.embed_model, "@cf/baai/bge-m3") self.assertEqual(self.memory_bank.autorag_domain, "test.example.com") async def test_create_memory_bank(self): """测试创建记忆空间""" memory_bank_id = await self.memory_bank.create_memory_bank("tieguaili") # 验证返回的ID格式 self.assertEqual(memory_bank_id, "cf_memory_tieguaili") async def test_create_memory_bank_with_display_name(self): """测试创建记忆空间时指定显示名称""" memory_bank_id = await self.memory_bank.create_memory_bank( "tieguaili", "铁拐李的专属记忆银行" ) # 验证返回的ID格式 self.assertEqual(memory_bank_id, "cf_memory_tieguaili") async def test_generate_embedding(self): """测试生成嵌入向量""" # Mock响应 mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(return_value={ "result": { "data": [ { "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] } ] } }) # Mock session.post self.mock_session.post.return_value.__aenter__.return_value = mock_response # 调用方法 embedding = await self.memory_bank._generate_embedding("测试文本") # 验证结果 self.assertEqual(embedding, [0.1, 0.2, 0.3, 0.4, 0.5]) # 验证调用了正确的URL和参数 expected_url = "https://api.cloudflare.com/client/v4/accounts/test-account/ai/run/@cf/baai/bge-m3" self.mock_session.post.assert_called_once() call_args = self.mock_session.post.call_args self.assertEqual(call_args[0][0], expected_url) self.assertEqual(call_args[1]['json'], {"text": ["测试文本"]}) async def test_generate_embedding_api_error(self): """测试生成嵌入向量时API错误""" # Mock响应 mock_response = AsyncMock() mock_response.status = 500 mock_response.text = AsyncMock(return_value="Internal Server Error") # Mock session.post self.mock_session.post.return_value.__aenter__.return_value = mock_response # 验证抛出异常 with self.assertRaises(Exception) as context: await self.memory_bank._generate_embedding("测试文本") self.assertIn("Failed to generate embedding", str(context.exception)) async def test_add_memory(self): """测试添加记忆""" # Mock _generate_embedding 方法 with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])) as mock_embed: # Mock upsert 响应 mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(return_value={"result": {"upserted": 1}}) # Mock session.post self.mock_session.post.return_value.__aenter__.return_value = mock_response # 添加记忆 memory_id = await self.memory_bank.add_memory( agent_name="tieguaili", content="在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。", memory_type="preference", debate_topic="NVIDIA投资分析", metadata={"source": "manual"} ) # 验证返回的ID格式 (以mem_开头) self.assertTrue(memory_id.startswith("mem_tieguaili_")) # 验证调用了生成嵌入的方法 mock_embed.assert_called_once_with("在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。") # 验证调用了upsert API self.mock_session.post.assert_called() # 验证upsert调用的参数 upsert_call = None for call in self.mock_session.post.call_args_list: if 'vectorize/indexes/test-index/upsert' in call[0][0]: upsert_call = call break self.assertIsNotNone(upsert_call) call_args, call_kwargs = upsert_call self.assertIn("vectorize/indexes/test-index/upsert", call_args[0]) self.assertIn("vectors", call_kwargs['json']) async def test_add_memory_api_error(self): """测试添加记忆时API错误""" # Mock _generate_embedding 方法 with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])): # Mock upsert 响应 mock_response = AsyncMock() mock_response.status = 500 mock_response.text = AsyncMock(return_value="Internal Server Error") # Mock session.post self.mock_session.post.return_value.__aenter__.return_value = mock_response # 验证抛出异常 with self.assertRaises(Exception) as context: await self.memory_bank.add_memory( agent_name="tieguaili", content="测试内容" ) self.assertIn("Failed to upsert memory", str(context.exception)) async def test_search_memories(self): """测试搜索记忆""" # Mock _generate_embedding 方法 with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])) as mock_embed: # Mock query 响应 mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(return_value={ "result": { "matches": [ { "metadata": { "content": "在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。", "memory_type": "preference", "agent_name": "tieguaili" }, "score": 0.95 } ] } }) # Mock session.post self.mock_session.post.return_value.__aenter__.return_value = mock_response # 搜索记忆 results = await self.memory_bank.search_memories( agent_name="tieguaili", query="NVIDIA", memory_type="preference", limit=5 ) # 验证结果 self.assertEqual(len(results), 1) self.assertEqual(results[0]["content"], "在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。") self.assertEqual(results[0]["relevance_score"], 0.95) # 验证调用了生成嵌入的方法 mock_embed.assert_called_once_with("NVIDIA") # 验证调用了query API self.mock_session.post.assert_called() # 验证query调用的参数 query_call = None for call in self.mock_session.post.call_args_list: if 'vectorize/indexes/test-index/query' in call[0][0]: query_call = call break self.assertIsNotNone(query_call) call_args, call_kwargs = query_call self.assertIn("vectorize/indexes/test-index/query", call_args[0]) self.assertIn("vector", call_kwargs['json']) self.assertIn("filter", call_kwargs['json']) self.assertEqual(call_kwargs['json']['filter'], {"agent_name": "tieguaili", "memory_type": "preference"}) async def test_search_memories_api_error(self): """测试搜索记忆时API错误""" # Mock _generate_embedding 方法 with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])): # Mock query 响应 mock_response = AsyncMock() mock_response.status = 500 mock_response.text = AsyncMock(return_value="Internal Server Error") # Mock session.post self.mock_session.post.return_value.__aenter__.return_value = mock_response # 验证返回空列表而不是抛出异常 results = await self.memory_bank.search_memories( agent_name="tieguaili", query="NVIDIA" ) self.assertEqual(results, []) async def test_get_agent_context(self): """测试获取智能体上下文""" # Mock search_memories 方法 with patch.object(self.memory_bank, 'search_memories', new=AsyncMock()) as mock_search: # 设置mock返回值 mock_search.side_effect = [ [ # conversation memories {"content": "NVIDIA的估值过高,存在泡沫风险。", "relevance_score": 0.9} ], [ # preference memories {"content": "倾向于逆向思维,关注潜在风险。", "relevance_score": 0.8} ], [ # strategy memories {"content": "使用技术分析策略。", "relevance_score": 0.7} ] ] # 获取上下文 context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析") # 验证上下文包含预期内容 self.assertIn("# 铁拐李的记忆上下文", context) self.assertIn("## 历史对话记忆", context) self.assertIn("## 偏好记忆", context) self.assertIn("## 策略记忆", context) self.assertIn("NVIDIA的估值过高,存在泡沫风险。", context) self.assertIn("倾向于逆向思维,关注潜在风险。", context) self.assertIn("使用技术分析策略。", context) async def test_get_agent_context_no_memories(self): """测试获取智能体上下文但无相关记忆""" # Mock search_memories 方法 with patch.object(self.memory_bank, 'search_memories', new=AsyncMock(return_value=[])): # 获取上下文 context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析") # 验证上下文包含暂无相关记忆的提示 self.assertIn("# 铁拐李的记忆上下文", context) self.assertIn("暂无相关记忆。", context) async def test_save_debate_session(self): """测试保存辩论会话""" # Mock add_memory 方法 with patch.object(self.memory_bank, 'add_memory', new=AsyncMock()) as mock_add: conversation_history = [ {"agent": "tieguaili", "content": "NVIDIA的估值过高,存在泡沫风险。"}, {"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"}, {"agent": "tieguaili", "content": "但我们需要考虑竞争加剧和增长放缓的可能性。"} ] outcomes = { "winner": "lvdongbin", "insights": { "tieguaili": "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。" } } # 保存辩论会话 await self.memory_bank.save_debate_session( debate_topic="NVIDIA投资分析", participants=["tieguaili", "lvdongbin"], conversation_history=conversation_history, outcomes=outcomes ) # 验证调用了add_memory两次(对话总结和策略洞察) self.assertEqual(mock_add.call_count, 2) # 验证第一次调用是对话总结 call_args1 = mock_add.call_args_list[0][1] self.assertEqual(call_args1['agent_name'], 'tieguaili') self.assertEqual(call_args1['memory_type'], 'conversation') self.assertIn('铁拐李在本次辩论中的主要观点', call_args1['content']) # 验证第二次调用是策略洞察 call_args2 = mock_add.call_args_list[1][1] self.assertEqual(call_args2['agent_name'], 'tieguaili') self.assertEqual(call_args2['memory_type'], 'strategy') self.assertIn('铁拐李的风险意识值得肯定', call_args2['content']) def test_summarize_conversation(self): """测试对话总结""" conversation_history = [ {"agent": "tieguaili", "content": "第一点看法:NVIDIA的估值过高,存在泡沫风险。"}, {"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"}, {"agent": "tieguaili", "content": "第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。"}, {"agent": "tieguaili", "content": "第三点看法:从技术分析角度看,股价已出现超买信号。"} ] summary = self.memory_bank._summarize_conversation(conversation_history, "tieguaili") # 验证总结包含预期内容 self.assertIn("铁拐李在本次辩论中的主要观点", summary) self.assertIn("第一点看法:NVIDIA的估值过高,存在泡沫风险。", summary) self.assertIn("第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。", summary) self.assertIn("第三点看法:从技术分析角度看,股价已出现超买信号。", summary) def test_extract_strategy_insight_winner(self): """测试提取策略洞察 - 获胜者""" outcomes = { "winner": "tieguaili", "insights": {} } insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili") self.assertIn("铁拐李在本次辩论中获胜", insight) def test_extract_strategy_insight_from_insights(self): """测试从洞察中提取策略洞察""" outcomes = { "winner": "lvdongbin", "insights": { "tieguaili": "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。" } } insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili") self.assertEqual(insight, "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。") if __name__ == '__main__': # 创建一个异步测试运行器 def run_async_test(test_case): """运行异步测试用例""" loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: return loop.run_until_complete(test_case) finally: loop.close() # 获取所有以test_开头的异步方法并运行它们 suite = unittest.TestSuite() test_instance = TestCloudflareMemoryBank() test_instance.setUp() test_instance.addCleanup(test_instance.tearDown) # 添加同步测试 suite.addTest(TestCloudflareMemoryBank('test_init')) suite.addTest(TestCloudflareMemoryBank('test_summarize_conversation')) suite.addTest(TestCloudflareMemoryBank('test_extract_strategy_insight_winner')) suite.addTest(TestCloudflareMemoryBank('test_extract_strategy_insight_from_insights')) # 添加异步测试 async_tests = [ 'test_create_memory_bank', 'test_create_memory_bank_with_display_name', 'test_generate_embedding', 'test_generate_embedding_api_error', 'test_add_memory', 'test_add_memory_api_error', 'test_search_memories', 'test_search_memories_api_error', 'test_get_agent_context', 'test_get_agent_context_no_memories', 'test_save_debate_session' ] for test_name in async_tests: test_method = getattr(test_instance, test_name) suite.addTest(unittest.FunctionTestCase(lambda tm=test_method: run_async_test(tm()))) # 运行测试 runner = unittest.TextTestRunner(verbosity=2) runner.run(suite)