63 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			63 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
#!/usr/bin/env python3
"""
Tests for the Memory Bank module.
"""

import os
import sys
import unittest
from unittest.mock import MagicMock, patch

# Put the project root (the parent of this file's directory) at the front of
# sys.path so the ``src.`` package imports below resolve when the tests are run
# directly rather than from the repository root.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
)

from src.jixia.memory.factory import get_memory_backend
from src.jixia.memory.base_memory_bank import MemoryBankProtocol
class TestMemoryBankFactory(unittest.TestCase):
    """Tests for the ``get_memory_backend`` factory function.

    ``VertexMemoryBank`` is patched inside the factory module for each test,
    so the factory's ``from_config`` call is intercepted by a mock.
    """

    @staticmethod
    def _stub_from_config(vertex_cls):
        """Make ``vertex_cls.from_config()`` return a fresh mock; return that mock."""
        instance = MagicMock()
        vertex_cls.from_config.return_value = instance
        return instance

    def test_get_memory_backend_always_returns_vertex(self):
        """get_memory_backend() with no arguments yields the Vertex AI backend."""
        with patch('src.jixia.memory.factory.VertexMemoryBank') as vertex_cls:
            expected = self._stub_from_config(vertex_cls)

            # This test sets no environment variables before calling the factory.
            backend = get_memory_backend()

            self.assertEqual(backend, expected)
            vertex_cls.from_config.assert_called_once()

    def test_get_memory_backend_ignores_prefer_parameter(self):
        """get_memory_backend() ignores its ``prefer`` argument."""
        with patch('src.jixia.memory.factory.VertexMemoryBank') as vertex_cls:
            expected = self._stub_from_config(vertex_cls)

            # Ask for the Cloudflare backend; the factory should still use Vertex.
            backend = get_memory_backend(prefer="cloudflare")

            self.assertEqual(backend, expected)
            vertex_cls.from_config.assert_called_once()
class TestMemoryBankProtocol(unittest.TestCase):
    """Tests for the ``MemoryBankProtocol`` protocol."""

    def test_protocol_methods(self):
        """A class defining every protocol method passes isinstance(); others fail.

        NOTE(review): assumes ``MemoryBankProtocol`` is a ``@runtime_checkable``
        ``typing.Protocol``. Such runtime checks only verify that the members
        exist; they do not check signatures or that the methods are coroutines.
        The negative assertion keeps the test from passing vacuously if the
        protocol ever accepted arbitrary objects.
        """
        # Minimal stub providing the methods the protocol is expected to
        # declare; the bodies are intentionally empty.
        class _StubMemoryBank:
            async def create_memory_bank(self, agent_name: str, display_name=None):
                pass

            async def add_memory(self, agent_name: str, content: str,
                                 memory_type="conversation", debate_topic="",
                                 metadata=None):
                pass

            async def search_memories(self, agent_name: str, query: str,
                                      memory_type=None, limit=10):
                pass

            async def get_agent_context(self, agent_name: str, debate_topic: str):
                pass

            async def save_debate_session(self, debate_topic: str, participants,
                                          conversation_history, outcomes=None):
                pass

        # The stub structurally satisfies the protocol.
        self.assertIsInstance(_StubMemoryBank(), MemoryBankProtocol)
        # An object with none of the protocol's methods must be rejected.
        self.assertNotIsInstance(object(), MemoryBankProtocol)
if __name__ == '__main__':
    # Run this module's tests when it is executed as a script.
    unittest.main()