// Package migrations applies versioned SQL migration files to a database and
// records the applied versions in a schema_migrations table.
package migrations

import (
	"database/sql"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"

	// Imported for its side effect of registering the pgx database/sql driver.
	_ "github.com/jackc/pgx/v5/stdlib"
)

// Migration represents a database migration.
//
// Version is the migration file name with its ".up.sql" or ".down.sql" suffix
// removed. Up and Down hold the raw contents of the corresponding files; Down
// is empty when no ".down.sql" file exists for the version.
type Migration struct {
	Version string
	Up      string
	Down    string
}

// Migrator handles database migrations against a single *sql.DB.
type Migrator struct {
	db *sql.DB
}

// NewMigrator creates a new migrator that runs migrations on db.
func NewMigrator(db *sql.DB) *Migrator {
	return &Migrator{db: db}
}

// RunMigrations runs all pending migrations found in migrationsDir.
//
// It creates the schema_migrations table if needed, loads every
// "<version>.up.sql" / "<version>.down.sql" file from migrationsDir, and
// applies each version that is not yet recorded in schema_migrations.
// Versions are applied in ascending string order, so numeric prefixes must be
// zero-padded to sort as intended ("010" sorts after "002", but "10" sorts
// before "2"). Each migration runs in its own transaction; processing stops
// at the first failure and the error is returned.
func (m *Migrator) RunMigrations(migrationsDir string) error {
	// Create migrations table if it doesn't exist.
	if err := m.createMigrationsTable(); err != nil {
		return fmt.Errorf("failed to create migrations table: %w", err)
	}

	// Load migration files.
	migrations, err := m.loadMigrations(migrationsDir)
	if err != nil {
		return fmt.Errorf("failed to load migrations: %w", err)
	}

	// Get applied migrations.
	applied, err := m.getAppliedMigrations()
	if err != nil {
		return fmt.Errorf("failed to get applied migrations: %w", err)
	}

	// Run pending migrations.
	for _, migration := range migrations {
		if applied[migration.Version] {
			continue
		}
		if err := m.runMigration(migration); err != nil {
			return fmt.Errorf("failed to run migration %s: %w", migration.Version, err)
		}
	}
	return nil
}

// createMigrationsTable creates the schema_migrations bookkeeping table when
// it does not already exist.
func (m *Migrator) createMigrationsTable() error {
	query := `
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version VARCHAR(255) PRIMARY KEY,
			applied_at TIMESTAMP DEFAULT NOW()
		)
	`
	_, err := m.db.Exec(query)
	return err
}

// loadMigrations reads every "*.up.sql" and "*.down.sql" file directly inside
// dir (subdirectories and other files are ignored), pairs them by version, and
// returns the migrations sorted by Version in ascending string order.
//
// It returns an error if the directory or any migration file cannot be read,
// or if a version has a ".down.sql" file but no ".up.sql" file: such a
// migration has nothing to apply, and running it would only record the
// version as applied.
func (m *Migrator) loadMigrations(dir string) ([]Migration, error) {
	const (
		upSuffix   = ".up.sql"
		downSuffix = ".down.sql"
	)

	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	byVersion := make(map[string]*Migration)
	// hasUp tracks versions for which an up file was seen; Up == "" alone
	// cannot distinguish a missing file from an empty one.
	hasUp := make(map[string]bool)

	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()

		// Strip only the suffix that actually matched, so a name such as
		// "x.down.sql.up.sql" is not reduced to version "x".
		var (
			version string
			isUp    bool
		)
		switch {
		case strings.HasSuffix(name, upSuffix):
			version, isUp = strings.TrimSuffix(name, upSuffix), true
		case strings.HasSuffix(name, downSuffix):
			version = strings.TrimSuffix(name, downSuffix)
		default:
			continue
		}

		content, err := os.ReadFile(filepath.Join(dir, name))
		if err != nil {
			return nil, err
		}

		mig := byVersion[version]
		if mig == nil {
			mig = &Migration{Version: version}
			byVersion[version] = mig
		}
		if isUp {
			mig.Up = string(content)
			hasUp[version] = true
		} else {
			mig.Down = string(content)
		}
	}

	// Convert to slice and sort; map iteration order is random.
	result := make([]Migration, 0, len(byVersion))
	for _, mig := range byVersion {
		result = append(result, *mig)
	}
	sort.Slice(result, func(i, j int) bool {
		return result[i].Version < result[j].Version
	})

	// Validate after sorting so the reported version is deterministic.
	for _, mig := range result {
		if !hasUp[mig.Version] {
			return nil, fmt.Errorf("migration %s has a %s file but no %s file",
				mig.Version, downSuffix, upSuffix)
		}
	}
	return result, nil
}

// getAppliedMigrations returns the set of versions currently recorded in
// schema_migrations, keyed by version.
func (m *Migrator) getAppliedMigrations() (map[string]bool, error) {
	rows, err := m.db.Query("SELECT version FROM schema_migrations")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	applied := make(map[string]bool)
	for rows.Next() {
		var version string
		if err := rows.Scan(&version); err != nil {
			return nil, err
		}
		applied[version] = true
	}
	return applied, rows.Err()
}

// runMigration executes migration.Up and records migration.Version in
// schema_migrations inside a single transaction, so a failure in either step
// leaves neither committed.
func (m *Migrator) runMigration(migration Migration) error {
	tx, err := m.db.Begin()
	if err != nil {
		return err
	}
	// Rolls back on every early return. After a successful Commit the
	// transaction is already finished, so the error returned by this call is
	// intentionally ignored.
	defer tx.Rollback()

	// Execute migration.
	if _, err := tx.Exec(migration.Up); err != nil {
		return fmt.Errorf("failed to execute migration: %w", err)
	}

	// Record migration.
	if _, err := tx.Exec(
		"INSERT INTO schema_migrations (version) VALUES ($1)",
		migration.Version,
	); err != nil {
		return fmt.Errorf("failed to record migration: %w", err)
	}

	return tx.Commit()
}