package migration import ( "context" "database/sql" "github.com/Southclaws/fault" "github.com/lmika/gopkgs/fp/maps" "io/fs" "log/slog" "regexp" "sort" "strconv" ) type Migrator struct { fs fs.FS db *sql.DB preStepHooks []func(filename string) } func New(fs fs.FS, db *sql.DB, opts ...Option) *Migrator { migrator := &Migrator{ db: db, fs: fs, } for _, opt := range opts { opt(migrator) } return migrator } func (m *Migrator) Version(ctx context.Context) (userVersion int, err error) { if err := m.inTX(ctx, func(tx *sql.Tx) error { userVersion, err = m.getVersion(ctx, tx) return err }); err != nil { return 0, fault.Wrap(err) } return int(userVersion), nil } func (m *Migrator) MigrateUp(ctx context.Context) error { sf, err := m.readFiles() if err != nil { return fault.Wrap(err) } else if len(sf) == 0 { return nil } return m.inTX(ctx, func(tx *sql.Tx) error { latestVersion := sf[len(sf)-1].ver currentVersion, err := m.getVersion(ctx, tx) if err != nil { return fault.Wrap(err) } if currentVersion == latestVersion { slog.Debug("no DB migration necessary", "current_version", currentVersion, "latest_version", latestVersion) return nil } slog.Debug("starting migration", "current_version", currentVersion, "latest_version", latestVersion) for _, mf := range sf { if mf.ver <= currentVersion { continue } for _, h := range m.preStepHooks { h(mf.upFile) } if err := m.runMigrationFile(ctx, tx, mf.upFile); err != nil { return err } } return m.setVersion(ctx, tx, latestVersion) }) } func (m *Migrator) readFiles() ([]schemaFiles, error) { de, err := fs.ReadDir(m.fs, ".") if err != nil { return nil, fault.Wrap(err) } verFiles := make(map[int]schemaFiles) for _, f := range de { parts := migrateFileName.FindStringSubmatch(f.Name()) if len(parts) != 3 { continue } versionID, err := strconv.Atoi(parts[1]) if err != nil { continue } vf := verFiles[versionID] vf.ver = versionID if parts[2] == "up" { vf.upFile = f.Name() } else { vf.downFile = f.Name() } verFiles[versionID] = vf } files := maps.Values(verFiles) sort.Slice(files, func(i, j int) bool { return files[i].ver < files[j].ver }) return files, nil } var migrateFileName = regexp.MustCompile(`0*([0-9]+)[-_].*\.(up|down)\.sql`) type schemaFiles struct { ver int upFile string downFile string }