litemigrate/db.go
2024-11-06 22:16:17 +00:00

52 lines
1.1 KiB
Go

package litemigrate
import (
"context"
"database/sql"
"fmt"
"github.com/Southclaws/fault"
"io"
)
// runMigrationFile opens filename from m.fs, reads it fully, and executes its
// whole contents with a single tx.ExecContext call.
//
// Errors from the open, read, and exec steps are annotated with the step and
// the filename, so a failing migration can be identified from the error alone.
// The original error remains reachable through errors.Is / errors.As.
//
// NOTE(review): running a multi-statement script in one ExecContext relies on
// the SQL driver supporting that — confirm for the driver in use.
func (m *Migrator) runMigrationFile(ctx context.Context, tx *sql.Tx, filename string) error {
	f, err := m.fs.Open(filename)
	if err != nil {
		return fault.Wrap(fmt.Errorf("opening migration %q: %w", filename, err))
	}
	// The file is only read, so an error from Close carries nothing actionable.
	defer f.Close()

	script, err := io.ReadAll(f)
	if err != nil {
		return fault.Wrap(fmt.Errorf("reading migration %q: %w", filename, err))
	}

	if _, err := tx.ExecContext(ctx, string(script)); err != nil {
		return fault.Wrap(fmt.Errorf("executing migration %q: %w", filename, err))
	}
	return nil
}
// getVersion reads SQLite's user_version value by querying the
// pragma_user_version table-valued function inside tx.
//
// On failure it returns 0 and the wrapped error. This matches the error
// wrapping used by the other Migrator helpers.
func (m *Migrator) getVersion(ctx context.Context, tx *sql.Tx) (userVersion int, err error) {
	row := tx.QueryRowContext(ctx, `SELECT * FROM pragma_user_version;`)
	if err = row.Scan(&userVersion); err != nil {
		return 0, fault.Wrap(err)
	}
	return userVersion, nil
}
// setVersion writes userVersion into SQLite's user_version value inside tx.
//
// SQLite PRAGMA statements do not take bound parameters, so the value is
// formatted directly into the SQL text. This is injection-safe only because
// userVersion is an int.
//
// NOTE(review): SQLite stores user_version as a 32-bit signed integer; values
// outside that range are not validated here — confirm callers stay in range.
func (m *Migrator) setVersion(ctx context.Context, tx *sql.Tx, userVersion int) (err error) {
	stmt := fmt.Sprintf("PRAGMA user_version = %d", userVersion)
	if _, err = tx.ExecContext(ctx, stmt); err != nil {
		return fault.Wrap(err)
	}
	return nil
}
// inTX begins a transaction on m.db, passes it to txFn, and commits it if
// txFn returns nil.
//
// If txFn returns an error, or panics, the transaction is rolled back.
// Errors from BeginTx, txFn, and Commit are wrapped with fault.Wrap.
func (m *Migrator) inTX(ctx context.Context, txFn func(tx *sql.Tx) error) error {
	tx, err := m.db.BeginTx(ctx, nil)
	if err != nil {
		return fault.Wrap(err)
	}
	// A deferred rollback also covers panics in txFn. Without it, the
	// transaction and its connection would stay open. After a Commit (successful
	// or not), Rollback just returns sql.ErrTxDone, so the result is
	// intentionally ignored; this is a best-effort cleanup.
	defer func() { _ = tx.Rollback() }()

	if err := txFn(tx); err != nil {
		return fault.Wrap(err)
	}
	// fault.Wrap returns nil for a nil error, so a clean commit yields nil.
	return fault.Wrap(tx.Commit())
}