2024-11-06 22:16:17 +00:00
|
|
|
package litemigrate
|
2024-09-27 23:04:07 +00:00
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"database/sql"
|
|
|
|
"github.com/Southclaws/fault"
|
|
|
|
"github.com/lmika/gopkgs/fp/maps"
|
|
|
|
"io/fs"
|
|
|
|
"log/slog"
|
|
|
|
"regexp"
|
|
|
|
"sort"
|
|
|
|
"strconv"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Migrator applies SQL migration files read from a file system to a database.
type Migrator struct {
	// fs is the file system whose root directory is scanned for migration files.
	fs fs.FS

	// db is the database handle the migrations are executed against.
	db *sql.DB

	// preStepHooks are called, in order, with the name of each "up" file
	// immediately before that file is run.
	preStepHooks []func(filename string)
}
|
|
|
|
|
|
|
|
func New(fs fs.FS, db *sql.DB, opts ...Option) *Migrator {
|
|
|
|
migrator := &Migrator{
|
|
|
|
db: db,
|
|
|
|
fs: fs,
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, opt := range opts {
|
|
|
|
opt(migrator)
|
|
|
|
}
|
|
|
|
return migrator
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Migrator) Version(ctx context.Context) (userVersion int, err error) {
|
|
|
|
if err := m.inTX(ctx, func(tx *sql.Tx) error {
|
|
|
|
userVersion, err = m.getVersion(ctx, tx)
|
|
|
|
return err
|
|
|
|
}); err != nil {
|
|
|
|
return 0, fault.Wrap(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return int(userVersion), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Migrator) MigrateUp(ctx context.Context) error {
|
|
|
|
sf, err := m.readFiles()
|
|
|
|
if err != nil {
|
|
|
|
return fault.Wrap(err)
|
|
|
|
} else if len(sf) == 0 {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
return m.inTX(ctx, func(tx *sql.Tx) error {
|
|
|
|
latestVersion := sf[len(sf)-1].ver
|
|
|
|
|
|
|
|
currentVersion, err := m.getVersion(ctx, tx)
|
|
|
|
if err != nil {
|
|
|
|
return fault.Wrap(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if currentVersion == latestVersion {
|
|
|
|
slog.Debug("no DB migration necessary", "current_version", currentVersion, "latest_version", latestVersion)
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
slog.Debug("starting migration", "current_version", currentVersion, "latest_version", latestVersion)
|
|
|
|
for _, mf := range sf {
|
|
|
|
if mf.ver <= currentVersion {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, h := range m.preStepHooks {
|
|
|
|
h(mf.upFile)
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := m.runMigrationFile(ctx, tx, mf.upFile); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return m.setVersion(ctx, tx, latestVersion)
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Migrator) readFiles() ([]schemaFiles, error) {
|
|
|
|
de, err := fs.ReadDir(m.fs, ".")
|
|
|
|
if err != nil {
|
|
|
|
return nil, fault.Wrap(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
verFiles := make(map[int]schemaFiles)
|
|
|
|
for _, f := range de {
|
|
|
|
parts := migrateFileName.FindStringSubmatch(f.Name())
|
|
|
|
if len(parts) != 3 {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
versionID, err := strconv.Atoi(parts[1])
|
|
|
|
if err != nil {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
vf := verFiles[versionID]
|
|
|
|
vf.ver = versionID
|
|
|
|
if parts[2] == "up" {
|
|
|
|
vf.upFile = f.Name()
|
|
|
|
} else {
|
|
|
|
vf.downFile = f.Name()
|
|
|
|
}
|
|
|
|
|
|
|
|
verFiles[versionID] = vf
|
|
|
|
}
|
|
|
|
|
|
|
|
files := maps.Values(verFiles)
|
|
|
|
sort.Slice(files, func(i, j int) bool { return files[i].ver < files[j].ver })
|
|
|
|
return files, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// migrateFileName recognises migration file names of the form
// "<version>[-_]<description>.(up|down).sql".
//
// Submatch 1 is the version number with leading zeros stripped; submatch 2 is
// the direction, "up" or "down". The pattern is anchored at the end so that
// leftovers such as "001_init.up.sql.bak" or "001_init.up.sql~" are not
// mistaken for migrations. The start is left unanchored so that names which
// matched before this anchor was added continue to match.
var migrateFileName = regexp.MustCompile(`0*([0-9]+)[-_].*\.(up|down)\.sql$`)
|
|
|
|
|
|
|
|
// schemaFiles groups the migration files that belong to one schema version.
type schemaFiles struct {
	// ver is the version number parsed from the file names.
	ver int

	// upFile is the name of the "up" migration file, or empty if none was found.
	upFile string

	// downFile is the name of the "down" migration file, or empty if none was found.
	downFile string
}
|