436 lines
18 KiB
Python
436 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Cloudflare Memory Bank 实现测试
|
||
"""
|
||
|
||
import unittest
|
||
import asyncio
|
||
import os
|
||
import sys
|
||
from unittest.mock import patch, MagicMock, AsyncMock
|
||
from datetime import datetime
|
||
|
||
# 添加项目根目录到Python路径
|
||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||
|
||
from src.jixia.memory.cloudflare_memory_bank import CloudflareMemoryBank, MemoryEntry
|
||
|
||
|
||
class TestCloudflareMemoryBank(unittest.TestCase):
|
||
"""测试CloudflareMemoryBank类"""
|
||
|
||
def setUp(self):
|
||
"""测试前的设置"""
|
||
# Mock掉 aiohttp.ClientSession 以避免实际网络请求
|
||
self.patcher = patch('src.jixia.memory.cloudflare_memory_bank.aiohttp.ClientSession')
|
||
self.mock_session_class = self.patcher.start()
|
||
self.mock_session = AsyncMock()
|
||
self.mock_session_class.return_value = self.mock_session
|
||
|
||
# Mock掉 get_cloudflare_config 以避免实际读取配置
|
||
self.config_patcher = patch('src.jixia.memory.cloudflare_memory_bank.get_cloudflare_config')
|
||
self.mock_get_config = self.config_patcher.start()
|
||
self.mock_get_config.return_value = {
|
||
'account_id': 'test-account',
|
||
'api_token': 'test-token',
|
||
'vectorize_index': 'test-index',
|
||
'embed_model': '@cf/baai/bge-m3',
|
||
'autorag_domain': 'test.example.com'
|
||
}
|
||
|
||
# 创建CloudflareMemoryBank实例
|
||
self.memory_bank = CloudflareMemoryBank()
|
||
|
||
# 重置一些内部状态
|
||
self.memory_bank.config = self.mock_get_config.return_value
|
||
self.memory_bank.account_id = 'test-account'
|
||
self.memory_bank.api_token = 'test-token'
|
||
self.memory_bank.vectorize_index = 'test-index'
|
||
self.memory_bank.embed_model = '@cf/baai/bge-m3'
|
||
self.memory_bank.autorag_domain = 'test.example.com'
|
||
|
||
def tearDown(self):
|
||
"""测试后的清理"""
|
||
self.patcher.stop()
|
||
self.config_patcher.stop()
|
||
|
||
def test_init(self):
|
||
"""测试初始化"""
|
||
self.assertEqual(self.memory_bank.account_id, "test-account")
|
||
self.assertEqual(self.memory_bank.api_token, "test-token")
|
||
self.assertEqual(self.memory_bank.vectorize_index, "test-index")
|
||
self.assertEqual(self.memory_bank.embed_model, "@cf/baai/bge-m3")
|
||
self.assertEqual(self.memory_bank.autorag_domain, "test.example.com")
|
||
|
||
async def test_create_memory_bank(self):
|
||
"""测试创建记忆空间"""
|
||
memory_bank_id = await self.memory_bank.create_memory_bank("tieguaili")
|
||
|
||
# 验证返回的ID格式
|
||
self.assertEqual(memory_bank_id, "cf_memory_tieguaili")
|
||
|
||
async def test_create_memory_bank_with_display_name(self):
|
||
"""测试创建记忆空间时指定显示名称"""
|
||
memory_bank_id = await self.memory_bank.create_memory_bank(
|
||
"tieguaili",
|
||
"铁拐李的专属记忆银行"
|
||
)
|
||
|
||
# 验证返回的ID格式
|
||
self.assertEqual(memory_bank_id, "cf_memory_tieguaili")
|
||
|
||
async def test_generate_embedding(self):
|
||
"""测试生成嵌入向量"""
|
||
# Mock响应
|
||
mock_response = AsyncMock()
|
||
mock_response.status = 200
|
||
mock_response.json = AsyncMock(return_value={
|
||
"result": {
|
||
"data": [
|
||
{
|
||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]
|
||
}
|
||
]
|
||
}
|
||
})
|
||
|
||
# Mock session.post
|
||
self.mock_session.post.return_value.__aenter__.return_value = mock_response
|
||
|
||
# 调用方法
|
||
embedding = await self.memory_bank._generate_embedding("测试文本")
|
||
|
||
# 验证结果
|
||
self.assertEqual(embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||
|
||
# 验证调用了正确的URL和参数
|
||
expected_url = "https://api.cloudflare.com/client/v4/accounts/test-account/ai/run/@cf/baai/bge-m3"
|
||
self.mock_session.post.assert_called_once()
|
||
call_args = self.mock_session.post.call_args
|
||
self.assertEqual(call_args[0][0], expected_url)
|
||
self.assertEqual(call_args[1]['json'], {"text": ["测试文本"]})
|
||
|
||
async def test_generate_embedding_api_error(self):
|
||
"""测试生成嵌入向量时API错误"""
|
||
# Mock响应
|
||
mock_response = AsyncMock()
|
||
mock_response.status = 500
|
||
mock_response.text = AsyncMock(return_value="Internal Server Error")
|
||
|
||
# Mock session.post
|
||
self.mock_session.post.return_value.__aenter__.return_value = mock_response
|
||
|
||
# 验证抛出异常
|
||
with self.assertRaises(Exception) as context:
|
||
await self.memory_bank._generate_embedding("测试文本")
|
||
|
||
self.assertIn("Failed to generate embedding", str(context.exception))
|
||
|
||
async def test_add_memory(self):
|
||
"""测试添加记忆"""
|
||
# Mock _generate_embedding 方法
|
||
with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])) as mock_embed:
|
||
# Mock upsert 响应
|
||
mock_response = AsyncMock()
|
||
mock_response.status = 200
|
||
mock_response.json = AsyncMock(return_value={"result": {"upserted": 1}})
|
||
|
||
# Mock session.post
|
||
self.mock_session.post.return_value.__aenter__.return_value = mock_response
|
||
|
||
# 添加记忆
|
||
memory_id = await self.memory_bank.add_memory(
|
||
agent_name="tieguaili",
|
||
content="在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
|
||
memory_type="preference",
|
||
debate_topic="NVIDIA投资分析",
|
||
metadata={"source": "manual"}
|
||
)
|
||
|
||
# 验证返回的ID格式 (以mem_开头)
|
||
self.assertTrue(memory_id.startswith("mem_tieguaili_"))
|
||
|
||
# 验证调用了生成嵌入的方法
|
||
mock_embed.assert_called_once_with("在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。")
|
||
|
||
# 验证调用了upsert API
|
||
self.mock_session.post.assert_called()
|
||
|
||
# 验证upsert调用的参数
|
||
upsert_call = None
|
||
for call in self.mock_session.post.call_args_list:
|
||
if 'vectorize/indexes/test-index/upsert' in call[0][0]:
|
||
upsert_call = call
|
||
break
|
||
|
||
self.assertIsNotNone(upsert_call)
|
||
call_args, call_kwargs = upsert_call
|
||
self.assertIn("vectorize/indexes/test-index/upsert", call_args[0])
|
||
self.assertIn("vectors", call_kwargs['json'])
|
||
|
||
async def test_add_memory_api_error(self):
|
||
"""测试添加记忆时API错误"""
|
||
# Mock _generate_embedding 方法
|
||
with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])):
|
||
# Mock upsert 响应
|
||
mock_response = AsyncMock()
|
||
mock_response.status = 500
|
||
mock_response.text = AsyncMock(return_value="Internal Server Error")
|
||
|
||
# Mock session.post
|
||
self.mock_session.post.return_value.__aenter__.return_value = mock_response
|
||
|
||
# 验证抛出异常
|
||
with self.assertRaises(Exception) as context:
|
||
await self.memory_bank.add_memory(
|
||
agent_name="tieguaili",
|
||
content="测试内容"
|
||
)
|
||
|
||
self.assertIn("Failed to upsert memory", str(context.exception))
|
||
|
||
async def test_search_memories(self):
|
||
"""测试搜索记忆"""
|
||
# Mock _generate_embedding 方法
|
||
with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])) as mock_embed:
|
||
# Mock query 响应
|
||
mock_response = AsyncMock()
|
||
mock_response.status = 200
|
||
mock_response.json = AsyncMock(return_value={
|
||
"result": {
|
||
"matches": [
|
||
{
|
||
"metadata": {
|
||
"content": "在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。",
|
||
"memory_type": "preference",
|
||
"agent_name": "tieguaili"
|
||
},
|
||
"score": 0.95
|
||
}
|
||
]
|
||
}
|
||
})
|
||
|
||
# Mock session.post
|
||
self.mock_session.post.return_value.__aenter__.return_value = mock_response
|
||
|
||
# 搜索记忆
|
||
results = await self.memory_bank.search_memories(
|
||
agent_name="tieguaili",
|
||
query="NVIDIA",
|
||
memory_type="preference",
|
||
limit=5
|
||
)
|
||
|
||
# 验证结果
|
||
self.assertEqual(len(results), 1)
|
||
self.assertEqual(results[0]["content"], "在讨论NVIDIA股票时,我倾向于逆向思维,关注潜在风险。")
|
||
self.assertEqual(results[0]["relevance_score"], 0.95)
|
||
|
||
# 验证调用了生成嵌入的方法
|
||
mock_embed.assert_called_once_with("NVIDIA")
|
||
|
||
# 验证调用了query API
|
||
self.mock_session.post.assert_called()
|
||
|
||
# 验证query调用的参数
|
||
query_call = None
|
||
for call in self.mock_session.post.call_args_list:
|
||
if 'vectorize/indexes/test-index/query' in call[0][0]:
|
||
query_call = call
|
||
break
|
||
|
||
self.assertIsNotNone(query_call)
|
||
call_args, call_kwargs = query_call
|
||
self.assertIn("vectorize/indexes/test-index/query", call_args[0])
|
||
self.assertIn("vector", call_kwargs['json'])
|
||
self.assertIn("filter", call_kwargs['json'])
|
||
self.assertEqual(call_kwargs['json']['filter'], {"agent_name": "tieguaili", "memory_type": "preference"})
|
||
|
||
async def test_search_memories_api_error(self):
|
||
"""测试搜索记忆时API错误"""
|
||
# Mock _generate_embedding 方法
|
||
with patch.object(self.memory_bank, '_generate_embedding', new=AsyncMock(return_value=[0.1, 0.2, 0.3])):
|
||
# Mock query 响应
|
||
mock_response = AsyncMock()
|
||
mock_response.status = 500
|
||
mock_response.text = AsyncMock(return_value="Internal Server Error")
|
||
|
||
# Mock session.post
|
||
self.mock_session.post.return_value.__aenter__.return_value = mock_response
|
||
|
||
# 验证返回空列表而不是抛出异常
|
||
results = await self.memory_bank.search_memories(
|
||
agent_name="tieguaili",
|
||
query="NVIDIA"
|
||
)
|
||
|
||
self.assertEqual(results, [])
|
||
|
||
async def test_get_agent_context(self):
|
||
"""测试获取智能体上下文"""
|
||
# Mock search_memories 方法
|
||
with patch.object(self.memory_bank, 'search_memories', new=AsyncMock()) as mock_search:
|
||
# 设置mock返回值
|
||
mock_search.side_effect = [
|
||
[ # conversation memories
|
||
{"content": "NVIDIA的估值过高,存在泡沫风险。", "relevance_score": 0.9}
|
||
],
|
||
[ # preference memories
|
||
{"content": "倾向于逆向思维,关注潜在风险。", "relevance_score": 0.8}
|
||
],
|
||
[ # strategy memories
|
||
{"content": "使用技术分析策略。", "relevance_score": 0.7}
|
||
]
|
||
]
|
||
|
||
# 获取上下文
|
||
context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
|
||
|
||
# 验证上下文包含预期内容
|
||
self.assertIn("# 铁拐李的记忆上下文", context)
|
||
self.assertIn("## 历史对话记忆", context)
|
||
self.assertIn("## 偏好记忆", context)
|
||
self.assertIn("## 策略记忆", context)
|
||
self.assertIn("NVIDIA的估值过高,存在泡沫风险。", context)
|
||
self.assertIn("倾向于逆向思维,关注潜在风险。", context)
|
||
self.assertIn("使用技术分析策略。", context)
|
||
|
||
async def test_get_agent_context_no_memories(self):
|
||
"""测试获取智能体上下文但无相关记忆"""
|
||
# Mock search_memories 方法
|
||
with patch.object(self.memory_bank, 'search_memories', new=AsyncMock(return_value=[])):
|
||
# 获取上下文
|
||
context = await self.memory_bank.get_agent_context("tieguaili", "NVIDIA投资分析")
|
||
|
||
# 验证上下文包含暂无相关记忆的提示
|
||
self.assertIn("# 铁拐李的记忆上下文", context)
|
||
self.assertIn("暂无相关记忆。", context)
|
||
|
||
async def test_save_debate_session(self):
|
||
"""测试保存辩论会话"""
|
||
# Mock add_memory 方法
|
||
with patch.object(self.memory_bank, 'add_memory', new=AsyncMock()) as mock_add:
|
||
conversation_history = [
|
||
{"agent": "tieguaili", "content": "NVIDIA的估值过高,存在泡沫风险。"},
|
||
{"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
|
||
{"agent": "tieguaili", "content": "但我们需要考虑竞争加剧和增长放缓的可能性。"}
|
||
]
|
||
|
||
outcomes = {
|
||
"winner": "lvdongbin",
|
||
"insights": {
|
||
"tieguaili": "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。"
|
||
}
|
||
}
|
||
|
||
# 保存辩论会话
|
||
await self.memory_bank.save_debate_session(
|
||
debate_topic="NVIDIA投资分析",
|
||
participants=["tieguaili", "lvdongbin"],
|
||
conversation_history=conversation_history,
|
||
outcomes=outcomes
|
||
)
|
||
|
||
# 验证调用了add_memory两次(对话总结和策略洞察)
|
||
self.assertEqual(mock_add.call_count, 2)
|
||
|
||
# 验证第一次调用是对话总结
|
||
call_args1 = mock_add.call_args_list[0][1]
|
||
self.assertEqual(call_args1['agent_name'], 'tieguaili')
|
||
self.assertEqual(call_args1['memory_type'], 'conversation')
|
||
self.assertIn('铁拐李在本次辩论中的主要观点', call_args1['content'])
|
||
|
||
# 验证第二次调用是策略洞察
|
||
call_args2 = mock_add.call_args_list[1][1]
|
||
self.assertEqual(call_args2['agent_name'], 'tieguaili')
|
||
self.assertEqual(call_args2['memory_type'], 'strategy')
|
||
self.assertIn('铁拐李的风险意识值得肯定', call_args2['content'])
|
||
|
||
def test_summarize_conversation(self):
|
||
"""测试对话总结"""
|
||
conversation_history = [
|
||
{"agent": "tieguaili", "content": "第一点看法:NVIDIA的估值过高,存在泡沫风险。"},
|
||
{"agent": "lvdongbin", "content": "NVIDIA在AI领域的领先地位不可忽视。"},
|
||
{"agent": "tieguaili", "content": "第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。"},
|
||
{"agent": "tieguaili", "content": "第三点看法:从技术分析角度看,股价已出现超买信号。"}
|
||
]
|
||
|
||
summary = self.memory_bank._summarize_conversation(conversation_history, "tieguaili")
|
||
|
||
# 验证总结包含预期内容
|
||
self.assertIn("铁拐李在本次辩论中的主要观点", summary)
|
||
self.assertIn("第一点看法:NVIDIA的估值过高,存在泡沫风险。", summary)
|
||
self.assertIn("第二点看法:我们需要考虑竞争加剧和增长放缓的可能性。", summary)
|
||
self.assertIn("第三点看法:从技术分析角度看,股价已出现超买信号。", summary)
|
||
|
||
def test_extract_strategy_insight_winner(self):
|
||
"""测试提取策略洞察 - 获胜者"""
|
||
outcomes = {
|
||
"winner": "tieguaili",
|
||
"insights": {}
|
||
}
|
||
|
||
insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
|
||
|
||
self.assertIn("铁拐李在本次辩论中获胜", insight)
|
||
|
||
def test_extract_strategy_insight_from_insights(self):
|
||
"""测试从洞察中提取策略洞察"""
|
||
outcomes = {
|
||
"winner": "lvdongbin",
|
||
"insights": {
|
||
"tieguaili": "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。"
|
||
}
|
||
}
|
||
|
||
insight = self.memory_bank._extract_strategy_insight(outcomes, "tieguaili")
|
||
|
||
self.assertEqual(insight, "铁拐李的风险意识值得肯定,但在AI趋势的判断上略显保守。")
|
||
|
||
|
||
if __name__ == '__main__':
|
||
# 创建一个异步测试运行器
|
||
def run_async_test(test_case):
|
||
"""运行异步测试用例"""
|
||
loop = asyncio.new_event_loop()
|
||
asyncio.set_event_loop(loop)
|
||
try:
|
||
return loop.run_until_complete(test_case)
|
||
finally:
|
||
loop.close()
|
||
|
||
# 获取所有以test_开头的异步方法并运行它们
|
||
suite = unittest.TestSuite()
|
||
test_instance = TestCloudflareMemoryBank()
|
||
test_instance.setUp()
|
||
test_instance.addCleanup(test_instance.tearDown)
|
||
|
||
# 添加同步测试
|
||
suite.addTest(TestCloudflareMemoryBank('test_init'))
|
||
suite.addTest(TestCloudflareMemoryBank('test_summarize_conversation'))
|
||
suite.addTest(TestCloudflareMemoryBank('test_extract_strategy_insight_winner'))
|
||
suite.addTest(TestCloudflareMemoryBank('test_extract_strategy_insight_from_insights'))
|
||
|
||
# 添加异步测试
|
||
async_tests = [
|
||
'test_create_memory_bank',
|
||
'test_create_memory_bank_with_display_name',
|
||
'test_generate_embedding',
|
||
'test_generate_embedding_api_error',
|
||
'test_add_memory',
|
||
'test_add_memory_api_error',
|
||
'test_search_memories',
|
||
'test_search_memories_api_error',
|
||
'test_get_agent_context',
|
||
'test_get_agent_context_no_memories',
|
||
'test_save_debate_session'
|
||
]
|
||
|
||
for test_name in async_tests:
|
||
test_method = getattr(test_instance, test_name)
|
||
suite.addTest(unittest.FunctionTestCase(lambda tm=test_method: run_async_test(tm())))
|
||
|
||
# 运行测试
|
||
runner = unittest.TextTestRunner(verbosity=2)
|
||
runner.run(suite) |