litemigrate/migration.go

126 lines
2.4 KiB
Go
Raw Normal View History

package litemigrate
import (
"context"
"database/sql"
"github.com/Southclaws/fault"
"github.com/lmika/gopkgs/fp/maps"
"io/fs"
"log/slog"
"regexp"
"sort"
"strconv"
)
// Migrator applies versioned SQL migration scripts, read from a filesystem,
// to a database.
type Migrator struct {
	// fs is the filesystem whose root directory is scanned for migration
	// scripts.
	fs fs.FS

	// db is the database handle supplied to New.
	db *sql.DB

	// preStepHooks are invoked, in order, with the name of each "up" script
	// just before that script is run.
	preStepHooks []func(filename string)
}
// New returns a Migrator that reads migration scripts from fs and works
// against db. Each option in opts is applied to the new Migrator, in the order
// given, before it is returned.
func New(fs fs.FS, db *sql.DB, opts ...Option) *Migrator {
	m := &Migrator{fs: fs, db: db}
	for i := range opts {
		opts[i](m)
	}
	return m
}
// Version returns the schema version reported by m.getVersion, which is
// queried inside the callback handed to m.inTX.
//
// On failure the returned version is 0 and the error is wrapped with
// fault.Wrap.
func (m *Migrator) Version(ctx context.Context) (userVersion int, err error) {
	var current int

	txErr := m.inTX(ctx, func(tx *sql.Tx) error {
		v, getErr := m.getVersion(ctx, tx)
		current = v
		return getErr
	})
	if txErr != nil {
		return 0, fault.Wrap(txErr)
	}

	return current, nil
}
// MigrateUp brings the database up to the newest version for which a
// migration script exists.
//
// Inside the callback handed to m.inTX it reads the current version, runs the
// "up" script of every version greater than it in ascending order, and finally
// records the newest version with m.setVersion. Before each script is run,
// every registered pre-step hook is called with the script's file name.
//
// MigrateUp returns nil without touching the database when no migration
// scripts are found. It also makes no changes when the database is already at,
// or ahead of, the newest available version: a database that is ahead (for
// example after an older binary is deployed) must not have its recorded
// version lowered, because no "down" scripts have been run.
func (m *Migrator) MigrateUp(ctx context.Context) error {
	sf, err := m.readFiles()
	if err != nil {
		return fault.Wrap(err)
	}
	if len(sf) == 0 {
		return nil
	}

	// readFiles returns the versions sorted ascending, so the last entry is
	// the newest one.
	latestVersion := sf[len(sf)-1].ver

	return m.inTX(ctx, func(tx *sql.Tx) error {
		currentVersion, err := m.getVersion(ctx, tx)
		if err != nil {
			return fault.Wrap(err)
		}

		// ">=" rather than "==": when the database is newer than the newest
		// script, every script would be skipped below and the recorded
		// version would then be silently downgraded.
		if currentVersion >= latestVersion {
			slog.Debug("no DB migration necessary", "current_version", currentVersion, "latest_version", latestVersion)
			return nil
		}

		slog.Debug("starting migration", "current_version", currentVersion, "latest_version", latestVersion)
		for _, mf := range sf {
			if mf.ver <= currentVersion {
				continue
			}

			for _, hook := range m.preStepHooks {
				hook(mf.upFile)
			}

			if err := m.runMigrationFile(ctx, tx, mf.upFile); err != nil {
				return fault.Wrap(err)
			}
		}

		if err := m.setVersion(ctx, tx, latestVersion); err != nil {
			return fault.Wrap(err)
		}
		return nil
	})
}
// readFiles scans the root directory of m.fs and groups the migration scripts
// found there by version number.
//
// Only regular entries whose name matches migrateFileName are considered;
// directories and non-matching names are ignored. The result holds one
// schemaFiles value per version, sorted by ascending version. An error is
// returned only when the directory itself cannot be read.
func (m *Migrator) readFiles() ([]schemaFiles, error) {
	entries, err := fs.ReadDir(m.fs, ".")
	if err != nil {
		return nil, fault.Wrap(err)
	}

	verFiles := make(map[int]schemaFiles)
	for _, entry := range entries {
		// A directory is never a migration script, even when its name
		// happens to look like one.
		if entry.IsDir() {
			continue
		}

		name := entry.Name()
		parts := migrateFileName.FindStringSubmatch(name)
		if len(parts) != 3 {
			continue
		}

		versionID, err := strconv.Atoi(parts[1])
		if err != nil {
			// The pattern only captures digits, so this can only be an
			// out-of-range number.
			// NOTE(review): such a file is silently skipped, as before;
			// confirm that is intended rather than an error.
			continue
		}

		// NOTE(review): two scripts with the same version and direction
		// silently overwrite each other here; the last one listed wins.
		vf := verFiles[versionID]
		vf.ver = versionID
		switch parts[2] {
		case "up":
			vf.upFile = name
		default:
			// The pattern allows only "up" or "down".
			vf.downFile = name
		}
		verFiles[versionID] = vf
	}

	files := maps.Values(verFiles)
	sort.Slice(files, func(i, j int) bool { return files[i].ver < files[j].ver })
	return files, nil
}
// migrateFileName recognises migration script names such as
// "0001_create-users.up.sql" or "2-add-index.down.sql".
//
// Submatch 1 is the version number without its leading zeros and submatch 2 is
// the direction, "up" or "down".
//
// The pattern is anchored at the end so that a name which merely starts like a
// migration script (an editor backup such as "0001_init.up.sql.bak" or
// "0001_init.up.sql~") is not picked up as a migration. The start is left
// unanchored on purpose: any name ending in ".up.sql" or ".down.sql" that this
// pattern accepted without the anchor is still accepted.
var migrateFileName = regexp.MustCompile(`0*([0-9]+)[-_].*\.(up|down)\.sql$`)
// schemaFiles groups the migration scripts that belong to one schema version.
type schemaFiles struct {
	// ver is the version number parsed from the script file names.
	ver int

	// upFile and downFile are the file names of this version's "up" and
	// "down" scripts. Either is empty when no such script was found.
	upFile, downFile string
}