liurenchaxin/tests/test_cloudflare_memory_bank.py

436 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Cloudflare Memory Bank 实现测试
"""
import unittest
import asyncio
import os
import sys
from unittest.mock import patch, MagicMock, AsyncMock
from datetime import datetime
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.jixia.memory.cloudflare_memory_bank import CloudflareMemoryBank, MemoryEntry
class TestCloudflareMemoryBank(unittest.TestCase):
"""测试CloudflareMemoryBank类"""
def setUp(self):
"""测试前的设置"""
# Mock掉 aiohttp.ClientSession 以避免实际网络请求
self.patcher = patch('src.jixia.memory.cloudflare_memory_bank.aiohttp.ClientSession')
self.mock_session_class = self.patcher.start()
self.mock_session = AsyncMock()
self.mock_session_class.return_value = self.mock_session
# Mock掉 get_cloudflare_config 以避免实际读取配置
self.config_patcher = patch('src.jixia.memory.cloudflare_memory_bank.get_cloudflare_config')
self.mock_get_config = self.config_patcher.start()
self.mock_get_config.return_value = {
'account_id': 'test-account',
'api_token': 'test-token',
'vectorize_index': 'test-index',
'embed_model': '@cf/baai/bge-m3',
'autorag_domain': 'test.example.com'
}
# 创建CloudflareMemoryBank实例
self.memory_bank = CloudflareMemoryBank()
# 重置一些内部状态
self.memory_bank.config = self.mock_get_config.return_value
self.memory_bank.account_id = 'test-account'
self.memory_bank.api_token = 'test-token'
self.memory_bank.vectorize_index = 'test-index'
self.memory_bank.embed_model = '@cf/baai/bge-m3'
self.memory_bank.autorag_domain = 'test.example.com'
def tearDown(self):
"""测试后的清理"""
self.patcher.stop()
self.config_patcher.stop()
def test_init(self):
"""测试初始化"""
self.assertEqual(self.memory_bank.account_id, "test-account")
self.assertEqual(self.memory_bank.api_token, "test-token")
self.assertEqual(self.memory_bank.vectorize_index, "test-index")
self.assertEqual(self.memory_bank.embed_model, "@cf/baai/bge-m3")
self.assertEqual(self.memory_bank.autorag_domain, "test.example.com")
async def test_create_memory_bank(self):
"""测试创建记忆空间"""
memory_bank_id = await self.memory_bank.create_memory_bank("tieguaili")
# 验证返回的ID格式
self.assertEqual(memory_bank_id, "cf_memory_tieguaili")
async def test_create_memory_bank_with_display_name(self):
"""测试创建记忆空间时指定显示名称"""
memory_bank_id = await self.memory_bank.create_memory_bank(
"tieguaili",
"铁拐李的专属记忆银行"
)
# 验证返回的ID格式
self.assertEqual(memory_bank_id, "cf_memory_tieguaili")
async def test_generate_embedding(self):
"""测试生成嵌入向量"""
# Mock响应
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
"result": {
"data": [
{
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]
}
]
}
})
# Mock session.post
self.mock_session.post.return_value.__aenter__.return_value = mock_response
# 调用方法
embedding = await self.memory_bank._generate_embedding("测试文本")
# 验证结果
self.assertEqual(embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
# 验证调用了正确的URL和参数
expected_url = "https://api.cloudflare.com/client/v4/accounts/test-account/ai/run/@cf/baai/bge-m3"
self.mock_session.post.assert_called_once()
call_args = self.mock_session.post.call_args
self.assertEqual(call_args[0][0], expected_url)
self.assertEqual(call_args[1]['json'], {"text": ["测试文本"]})
async def test_generate_embedding_api_error(self):
"""测试生成嵌入向量时API错误"""
# Mock响应
mock_response = AsyncMock()
mock_response.status = 500
mock_response.text = AsyncMock(return_value="Internal Server Error")
# Mock session.post
self.mock_session.post.return_value.__aenter__.return_value = mock_response
# 验证抛出异常
with self.assertRaises(Exception) as context:
await self.memory_bank._generate_embedding("测试文本")
self.assertIn("Failed to generate embedding", str(context.exception))
async def test_add_memory(self):
"""测试添加记忆"""
# Mock _generate_embedding 方法
with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])) as mock_embed:
# Mock upsert 响应
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={"result": {"upserted": 1}})
# Mock session.post
self.mock_session.post.return_value.__aenter__.return_value = mock_response
# 添加记忆
memory_id = await self.memory_bank.add_memory(
agent_name="tieguaili",
content="在讨论NVIDIA股票时我倾向于逆向思维关注潜在风险。",
memory_type="preference",
debate_topic="NVIDIA投资分析",
metadata={"source": "manual"}
)
# 验证返回的ID格式 (以mem_开头)
self.assertTrue(memory_id.startswith("mem_tieguaili_"))
# 验证调用了生成嵌入的方法
mock_embed.assert_called_once_with("在讨论NVIDIA股票时我倾向于逆向思维关注潜在风险。")
# 验证调用了upsert API
self.mock_session.post.assert_called()
# 验证upsert调用的参数
upsert_call = None
for call in self.mock_session.post.call_args_list:
if 'vectorize/indexes/test-index/upsert' in call[0][0]:
upsert_call = call
break
self.assertIsNotNone(upsert_call)
call_args, call_kwargs = upsert_call
self.assertIn("vectorize/indexes/test-index/upsert", call_args[0])
self.assertIn("vectors", call_kwargs['json'])
async def test_add_memory_api_error(self):
"""测试添加记忆时API错误"""
# Mock _generate_embedding 方法
with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])):
# Mock upsert 响应
mock_response = AsyncMock()
mock_response.status = 500
mock_response.text = AsyncMock(return_value="Internal Server Error")
# Mock session.post
self.mock_session.post.return_value.__aenter__.return_value = mock_response
# 验证抛出异常
with self.assertRaises(Exception) as context:
await self.memory_bank.add_memory(
agent_name="tieguaili",
content="测试内容"
)
self.assertIn("Failed to upsert memory", str(context.exception))
async def test_search_memories(self):
"""测试搜索记忆"""
# Mock _generate_embedding 方法
with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])) as mock_embed:
# Mock query 响应
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
"result": {
"matches": [
{
"metadata": {
"content": "在讨论NVIDIA股票时我倾向于逆向思维关注潜在风险。",
"memory_type": "preference",
"agent_name": "tieguaili"
},
"score": 0.95
}
]
}
})
# Mock session.post
self.mock_session.post.return_value.__aenter__.return_value = mock_response
# 搜索记忆
results = await self.memory_bank.search_memories(
agent_name="tieguaili",
query="NVIDIA",
memory_type="preference",
limit=5
)
# 验证结果
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["content"], "在讨论NVIDIA股票时我倾向于逆向思维关注潜在风险。")
self.assertEqual(results[0]["relevance_score"], 0.95)
# 验证调用了生成嵌入的方法
mock_embed.assert_called_once_with("NVIDIA")
# 验证调用了query API
self.mock_session.post.assert_called()
# 验证query调用的参数
query_call = None
for call in self.mock_session.post.call_args_list:
if 'vectorize/indexes/test-index/query' in call[0][0]:
query_call = call
break
self.assertIsNotNone(query_call)
call_args, call_kwargs = query_call
self.assertIn("vectorize/indexes/test-index/query", call_args[0])
self.assertIn("vector", call_kwargs['json'])
self.assertIn("filter", call_kwargs['json'])
self.assertEqual(call_kwargs['json']['filter'], {"agent_name": "tieguaili", "memory_type": "preference"})
async def test_search_memories_api_error(self):
"""测试搜索记忆时API错误"""
# Mock _generate_embedding 方法
with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])):
# Mock query 响应
mock_response = AsyncMock()
mock_response.status = 500
mock_response.text = AsyncMock(return_value="Internal Server Error")
# Mock session.post
self.mock_session.post.return_value.__aenter__.return_value = mock_response
# 验证返回空列表而不是抛出异常
results = await self.memory_bank.search_memories(
agent_name="tieguaili",
query="NVIDIA"
)
self.assertEqual(results, [])
async def test_get_agent_context(self):
"""测试获取智能体上下文"""
# Mock search_memories 方法
with patch.object(self.memory_bank, 'search_memories', new=AsyncMock()) as mock_search:
# 设置mock返回值
mock_search.side_effect = [
[ # conversation memories
{"content": "NVIDIA的估值过高存在泡沫风险。", "relevance_score": 0.9}
],
[ # preference memories
{"content": "倾向于逆向思维,关注潜在风险。", "relevance_score": 0.8}
],
[ # strategy memories
{"content": "使用技术分析策略。", "relevance_score": 0.7}
]
]
# 获取上下文
context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
# 验证上下文包含预期内容
self.assertIn("# 铁拐李的记忆上下文", context)
self.assertIn("## 历史对话记忆", context)
self.assertIn("## 偏好记忆", context)
self.assertIn("## 策略记忆", context)
self.assertIn("NVIDIA的估值过高存在泡沫风险。", context)
self.assertIn("倾向于逆向思维,关注潜在风险。", context)
self.assertIn("使用技术分析策略。", context)
async def test_get_agent_context_no_memories(self):
"""测试获取智能体上下文但无相关记忆"""
# Mock search_memories 方法
with patch.object(self.memory_bank, 'search_memories', new=AsyncMock(return_value=[])):
# 获取上下文
context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
# 验证上下文包含暂无相关记忆的提示
self.assertIn("# 铁拐李的记忆上下文", context)
self.assertIn("暂无相关记忆。", context)
async def test_save_debate_session(self):
"""测试保存辩论会话"""
# Mock add_memory 方法
with patch.object(self.memory_bank, 'add_memory', new=AsyncMock()) as mock_add:
conversation_history = [
{"agent": "tieguaili", "content": "NVIDIA的估值过高存在泡沫风险。"},
{"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
{"agent": "tieguaili", "content": "但我们需要考虑竞争加剧和增长放缓的可能性。"}
]
outcomes = {
"winner": "lvdongbin",
"insights": {
"tieguaili": "铁拐李的风险意识值得肯定但在AI趋势的判断上略显保守。"
}
}
# 保存辩论会话
await self.memory_bank.save_debate_session(
debate_topic="NVIDIA投资分析",
participants=["tieguaili", "lvdongbin"],
conversation_history=conversation_history,
outcomes=outcomes
)
# 验证调用了add_memory两次对话总结和策略洞察
self.assertEqual(mock_add.call_count, 2)
# 验证第一次调用是对话总结
call_args1 = mock_add.call_args_list[0][1]
self.assertEqual(call_args1['agent_name'], 'tieguaili')
self.assertEqual(call_args1['memory_type'], 'conversation')
self.assertIn('铁拐李在本次辩论中的主要观点', call_args1['content'])
# 验证第二次调用是策略洞察
call_args2 = mock_add.call_args_list[1][1]
self.assertEqual(call_args2['agent_name'], 'tieguaili')
self.assertEqual(call_args2['memory_type'], 'strategy')
self.assertIn('铁拐李的风险意识值得肯定', call_args2['content'])
def test_summarize_conversation(self):
"""测试对话总结"""
conversation_history = [
{"agent": "tieguaili", "content": "第一点看法NVIDIA的估值过高存在泡沫风险。"},
{"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
{"agent": "tieguaili", "content": "第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。"},
{"agent": "tieguaili", "content": "第三点看法:从技术分析角度看,股价已出现超买信号。"}
]
summary = self.memory_bank._summarize_conversation(conversation_history, "tieguaili")
# 验证总结包含预期内容
self.assertIn("铁拐李在本次辩论中的主要观点", summary)
self.assertIn("第一点看法NVIDIA的估值过高存在泡沫风险。", summary)
self.assertIn("第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。", summary)
self.assertIn("第三点看法:从技术分析角度看,股价已出现超买信号。", summary)
def test_extract_strategy_insight_winner(self):
"""测试提取策略洞察 - 获胜者"""
outcomes = {
"winner": "tieguaili",
"insights": {}
}
insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
self.assertIn("铁拐李在本次辩论中获胜", insight)
def test_extract_strategy_insight_from_insights(self):
"""测试从洞察中提取策略洞察"""
outcomes = {
"winner": "lvdongbin",
"insights": {
"tieguaili": "铁拐李的风险意识值得肯定但在AI趋势的判断上略显保守。"
}
}
insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
self.assertEqual(insight, "铁拐李的风险意识值得肯定但在AI趋势的判断上略显保守。")
if __name__ == '__main__':
# 创建一个异步测试运行器
def run_async_test(test_case):
"""运行异步测试用例"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(test_case)
finally:
loop.close()
# 获取所有以test_开头的异步方法并运行它们
suite = unittest.TestSuite()
test_instance = TestCloudflareMemoryBank()
test_instance.setUp()
test_instance.addCleanup(test_instance.tearDown)
# 添加同步测试
suite.addTest(TestCloudflareMemoryBank('test_init'))
suite.addTest(TestCloudflareMemoryBank('test_summarize_conversation'))
suite.addTest(TestCloudflareMemoryBank('test_extract_strategy_insight_winner'))
suite.addTest(TestCloudflareMemoryBank('test_extract_strategy_insight_from_insights'))
# 添加异步测试
async_tests = [
'test_create_memory_bank',
'test_create_memory_bank_with_display_name',
'test_generate_embedding',
'test_generate_embedding_api_error',
'test_add_memory',
'test_add_memory_api_error',
'test_search_memories',
'test_search_memories_api_error',
'test_get_agent_context',
'test_get_agent_context_no_memories',
'test_save_debate_session'
]
for test_name in async_tests:
test_method = getattr(test_instance, test_name)
suite.addTest(unittest.FunctionTestCase(lambda tm=test_method: run_async_test(tm())))
# 运行测试
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite)