liurenchaxin/tests/test_memory_bank_factory.py

63 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Memory Bank 模块测试
"""
import unittest
import os
import sys
from unittest.mock import patch, MagicMock
# Put the project root (the directory above this tests/ folder) at the front of
# sys.path so the `src.jixia...` imports below resolve even when this file is
# executed directly rather than from the project root.
sys.path.insert(
    0,
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
)
from src.jixia.memory.factory import get_memory_backend
from src.jixia.memory.base_memory_bank import MemoryBankProtocol
class TestMemoryBankFactory(unittest.TestCase):
    """Tests for the memory-bank factory function ``get_memory_backend``."""

    @patch('src.jixia.memory.factory.VertexMemoryBank')
    def test_get_memory_backend_always_returns_vertex(self, mock_vertex):
        """get_memory_backend() with no arguments returns the Vertex AI backend.

        ``VertexMemoryBank`` is patched inside the factory module, so the test
        checks that the factory calls ``VertexMemoryBank.from_config`` exactly
        once and returns the very object that call produced.
        """
        mock_instance = MagicMock()
        mock_vertex.from_config.return_value = mock_instance
        # Run with an empty environment so that no variables inherited from the
        # caller's shell can influence which backend the factory picks.
        # patch.dict restores the original environment afterwards.
        with patch.dict(os.environ, {}, clear=True):
            memory_bank = get_memory_backend()
        # Identity check: the factory must hand back from_config()'s result.
        self.assertIs(memory_bank, mock_instance)
        mock_vertex.from_config.assert_called_once()

    @patch('src.jixia.memory.factory.VertexMemoryBank')
    def test_get_memory_backend_ignores_prefer_parameter(self, mock_vertex):
        """get_memory_backend() ignores ``prefer`` and still returns Vertex AI."""
        mock_instance = MagicMock()
        mock_vertex.from_config.return_value = mock_instance
        # prefer is set to "cloudflare" but is expected to have no effect.
        memory_bank = get_memory_backend(prefer="cloudflare")
        self.assertIs(memory_bank, mock_instance)
        mock_vertex.from_config.assert_called_once()
class TestMemoryBankProtocol(unittest.TestCase):
    """Tests for the ``MemoryBankProtocol`` structural interface."""

    def test_protocol_methods(self):
        """A class providing the protocol's methods satisfies isinstance().

        This relies on ``MemoryBankProtocol`` supporting runtime isinstance()
        checks (presumably via ``@runtime_checkable``; ``isinstance`` against a
        plain Protocol raises TypeError). Such runtime checks only verify that
        the members exist, not their signatures or that they are coroutines.
        """
        # Minimal stub exposing the protocol's methods; the bodies are irrelevant.
        class StubMemoryBank:
            async def create_memory_bank(self, agent_name: str, display_name=None):
                ...

            async def add_memory(self, agent_name: str, content: str,
                                 memory_type="conversation", debate_topic="",
                                 metadata=None):
                ...

            async def search_memories(self, agent_name: str, query: str,
                                      memory_type=None, limit=10):
                ...

            async def get_agent_context(self, agent_name: str, debate_topic: str):
                ...

            async def save_debate_session(self, debate_topic: str, participants,
                                          conversation_history, outcomes=None):
                ...

        # The stub should be accepted as a MemoryBankProtocol implementation.
        self.assertIsInstance(StubMemoryBank(), MemoryBankProtocol)

    def test_object_without_methods_is_rejected(self):
        """An object lacking the protocol's methods does not satisfy it.

        Without this negative case, a protocol that constrained nothing would
        still pass ``test_protocol_methods``.
        """
        class EmptyMemoryBank:
            pass

        self.assertNotIsInstance(EmptyMemoryBank(), MemoryBankProtocol)
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly.
    unittest.main()