coze-studio/backend/internal/mock/infra/contract/orm/sqlitedb.go

155 lines
3.0 KiB
Go

package orm
import (
"fmt"
"reflect"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// MockDB represents a test database setup helper
type MockDB struct {
db *gorm.DB
tables map[any]*Table
}
// NewMockDB creates a new MockDB instance
func NewMockDB() *MockDB {
return &MockDB{
tables: make(map[any]*Table),
}
}
// AddTable adds a table to the test database
func (s *MockDB) AddTable(tableModel any) *Table {
if tb, ok := s.tables[tableModel]; ok {
return tb
}
tb := &Table{
rows: make([]any, 0, 10),
}
s.tables[tableModel] = tb
return tb
}
type Table struct {
rows []any
}
func (t *Table) AddRows(rows ...any) *Table {
t.rows = append(t.rows, rows...)
return t
}
// DB returns the underlying gorm.DB instance
func (s *MockDB) DB() (*gorm.DB, error) {
db, err := newSQLiteDB(":memory:")
if err != nil {
return nil, err
}
s.db = db
if err := s.setup(); err != nil {
return nil, err
}
return s.db, nil
}
func (s *MockDB) SharedDB(name string) (*gorm.DB, error) {
db, err := newSQLiteDB(fmt.Sprintf("file:%s?mode=memory&cache=shared", name))
if err != nil {
return nil, err
}
s.db = db
if err := s.setup(); err != nil {
return nil, err
}
return s.db, nil
}
// Close cleans up the test database
func (s *MockDB) Close() error {
tables := make([]any, 0, len(s.tables))
for tb := range s.tables {
tables = append(tables, tb)
}
if err := s.tearDown(tables...); err != nil {
return fmt.Errorf("failed to tear down database: %w", err)
}
return nil
}
func (s *MockDB) setup() error {
for tableModel, tb := range s.tables {
// Create tables
if err := s.createTableFromStruct(tableModel); err != nil {
return fmt.Errorf("failed to create table: %w", err)
}
// Insert test data
if err := s.tearUp(tableModel, tb.rows); err != nil {
return fmt.Errorf("failed to insert test data: %w", err)
}
}
return nil
}
// newSQLiteDB creates a new in-memory SQLite database for testing
func newSQLiteDB(dsn string) (*gorm.DB, error) {
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
return nil, fmt.Errorf("failed to connect database: %w", err)
}
return db, nil
}
// createTableFromStruct creates a table in the database based on the provided struct
func (s *MockDB) createTableFromStruct(model any) error {
if reflect.TypeOf(model).Kind() != reflect.Ptr {
return fmt.Errorf("model must be a pointer to struct")
}
return s.db.AutoMigrate(model)
}
// tearUp inserts test data into the database
func (s *MockDB) tearUp(tableModel any, rows []any) error {
if len(rows) == 0 {
return nil
}
for _, row := range rows {
err := s.db.Model(tableModel).Create(row).Error
if err != nil {
return err
}
}
return nil
}
// tearDown cleans up the test data
func (s *MockDB) tearDown(models ...any) error {
for _, model := range models {
if err := s.db.Where("1 = 1").Delete(model).Error; err != nil {
return err
}
if err := s.db.Migrator().DropTable(model); err != nil {
return err
}
}
return nil
}