#!/usr/bin/env python3
"""
Tests for the Memory Bank module.
"""

import os
import sys
import unittest
from unittest.mock import MagicMock, patch

# Put the parent of this file's directory on sys.path so the `src` package
# can be imported when the tests are run directly from this directory.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src.jixia.memory.factory import get_memory_backend  # noqa: E402
from src.jixia.memory.base_memory_bank import MemoryBankProtocol  # noqa: E402

# Dotted path of VertexMemoryBank as looked up inside the factory module.
# Patching it replaces the class with a MagicMock for the duration of a test,
# so `from_config` never builds a real backend.
_VERTEX_BANK_PATH = 'src.jixia.memory.factory.VertexMemoryBank'


class TestMemoryBankFactory(unittest.TestCase):
    """Tests for the get_memory_backend factory function."""

    @patch(_VERTEX_BANK_PATH)
    def test_get_memory_backend_always_returns_vertex(self, mock_vertex):
        """Test that get_memory_backend always returns the Vertex AI backend."""
        mock_instance = MagicMock()
        mock_vertex.from_config.return_value = mock_instance

        # This test sets no environment variables itself; it runs with
        # whatever the process environment already contains.
        memory_bank = get_memory_backend()

        # Identity, not equality: the factory must hand back exactly the
        # object produced by VertexMemoryBank.from_config.
        self.assertIs(memory_bank, mock_instance)
        mock_vertex.from_config.assert_called_once()

    @patch(_VERTEX_BANK_PATH)
    def test_get_memory_backend_ignores_prefer_parameter(self, mock_vertex):
        """Test that get_memory_backend ignores the `prefer` argument."""
        mock_instance = MagicMock()
        mock_vertex.from_config.return_value = mock_instance

        # `prefer` asks for cloudflare, but the Vertex backend is still
        # expected to be returned.
        memory_bank = get_memory_backend(prefer="cloudflare")

        self.assertIs(memory_bank, mock_instance)
        mock_vertex.from_config.assert_called_once()


class TestMemoryBankProtocol(unittest.TestCase):
    """Tests for the MemoryBankProtocol protocol."""

    def test_protocol_methods(self):
        """Test that a class defining the protocol's methods is accepted as a MemoryBankProtocol."""

        # Minimal stub that defines the methods the protocol is expected to
        # declare. The bodies are intentionally empty: only the presence of
        # the members is exercised by the isinstance check below.
        class _StubMemoryBank:
            async def create_memory_bank(self, agent_name: str, display_name=None):
                pass

            async def add_memory(self, agent_name: str, content: str,
                                 memory_type="conversation", debate_topic="",
                                 metadata=None):
                pass

            async def search_memories(self, agent_name: str, query: str,
                                      memory_type=None, limit=10):
                pass

            async def get_agent_context(self, agent_name: str, debate_topic: str):
                pass

            async def save_debate_session(self, debate_topic: str, participants,
                                          conversation_history, outcomes=None):
                pass

        # NOTE(review): isinstance() against a typing.Protocol only works if the
        # protocol is decorated with @runtime_checkable (otherwise it raises
        # TypeError), and it checks member presence only, not signatures.
        # Confirm MemoryBankProtocol is runtime_checkable in base_memory_bank.
        self.assertIsInstance(_StubMemoryBank(), MemoryBankProtocol)


if __name__ == '__main__':
    # Run the tests
    unittest.main()