155 lines
3.0 KiB
Go
155 lines
3.0 KiB
Go
package orm
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
// MockDB represents a test database setup helper
|
|
type MockDB struct {
|
|
db *gorm.DB
|
|
tables map[any]*Table
|
|
}
|
|
|
|
// NewMockDB creates a new MockDB instance
|
|
func NewMockDB() *MockDB {
|
|
return &MockDB{
|
|
tables: make(map[any]*Table),
|
|
}
|
|
}
|
|
|
|
// AddTable adds a table to the test database
|
|
func (s *MockDB) AddTable(tableModel any) *Table {
|
|
if tb, ok := s.tables[tableModel]; ok {
|
|
return tb
|
|
}
|
|
|
|
tb := &Table{
|
|
rows: make([]any, 0, 10),
|
|
}
|
|
|
|
s.tables[tableModel] = tb
|
|
return tb
|
|
}
|
|
|
|
// Table collects the rows queued for one registered model.
type Table struct {
	// rows holds the queued records in the order they were added.
	rows []any
}

// AddRows queues the given rows after any that are already queued and returns
// the receiver so calls can be chained.
func (t *Table) AddRows(rows ...any) *Table {
	queued := append(t.rows, rows...)
	t.rows = queued
	return t
}
|
|
|
|
// DB returns the underlying gorm.DB instance
|
|
func (s *MockDB) DB() (*gorm.DB, error) {
|
|
db, err := newSQLiteDB(":memory:")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s.db = db
|
|
|
|
if err := s.setup(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s.db, nil
|
|
}
|
|
|
|
func (s *MockDB) SharedDB(name string) (*gorm.DB, error) {
|
|
db, err := newSQLiteDB(fmt.Sprintf("file:%s?mode=memory&cache=shared", name))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s.db = db
|
|
|
|
if err := s.setup(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s.db, nil
|
|
}
|
|
|
|
// Close cleans up the test database
|
|
func (s *MockDB) Close() error {
|
|
tables := make([]any, 0, len(s.tables))
|
|
for tb := range s.tables {
|
|
tables = append(tables, tb)
|
|
}
|
|
if err := s.tearDown(tables...); err != nil {
|
|
return fmt.Errorf("failed to tear down database: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *MockDB) setup() error {
|
|
|
|
for tableModel, tb := range s.tables {
|
|
// Create tables
|
|
if err := s.createTableFromStruct(tableModel); err != nil {
|
|
return fmt.Errorf("failed to create table: %w", err)
|
|
}
|
|
|
|
// Insert test data
|
|
if err := s.tearUp(tableModel, tb.rows); err != nil {
|
|
return fmt.Errorf("failed to insert test data: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// newSQLiteDB creates a new in-memory SQLite database for testing
|
|
func newSQLiteDB(dsn string) (*gorm.DB, error) {
|
|
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Info),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect database: %w", err)
|
|
}
|
|
return db, nil
|
|
}
|
|
|
|
// createTableFromStruct creates a table in the database based on the provided struct
|
|
func (s *MockDB) createTableFromStruct(model any) error {
|
|
if reflect.TypeOf(model).Kind() != reflect.Ptr {
|
|
return fmt.Errorf("model must be a pointer to struct")
|
|
}
|
|
|
|
return s.db.AutoMigrate(model)
|
|
}
|
|
|
|
// tearUp inserts test data into the database
|
|
func (s *MockDB) tearUp(tableModel any, rows []any) error {
|
|
if len(rows) == 0 {
|
|
return nil
|
|
}
|
|
|
|
for _, row := range rows {
|
|
err := s.db.Model(tableModel).Create(row).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// tearDown cleans up the test data
|
|
func (s *MockDB) tearDown(models ...any) error {
|
|
for _, model := range models {
|
|
if err := s.db.Where("1 = 1").Delete(model).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.db.Migrator().DropTable(model); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|