2024-11-06 22:16:17 +00:00
|
|
|
package litemigrate
|
2024-09-27 23:04:07 +00:00
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"database/sql"
|
|
|
|
"fmt"
|
|
|
|
"github.com/Southclaws/fault"
|
|
|
|
"io"
|
|
|
|
)
|
|
|
|
|
|
|
|
func (m *Migrator) runMigrationFile(ctx context.Context, tx *sql.Tx, filename string) error {
|
|
|
|
f, err := m.fs.Open(filename)
|
|
|
|
if err != nil {
|
|
|
|
return fault.Wrap(err)
|
|
|
|
}
|
|
|
|
defer f.Close()
|
|
|
|
|
|
|
|
bts, err := io.ReadAll(f)
|
|
|
|
if err != nil {
|
|
|
|
return fault.Wrap(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if _, err := tx.ExecContext(ctx, string(bts)); err != nil {
|
|
|
|
return fault.Wrap(err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Migrator) getVersion(ctx context.Context, tx *sql.Tx) (userVersion int, err error) {
|
|
|
|
err = tx.QueryRowContext(ctx, `SELECT * FROM pragma_user_version;`).Scan(&userVersion)
|
|
|
|
return userVersion, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Migrator) setVersion(ctx context.Context, tx *sql.Tx, userVersion int) (err error) {
|
|
|
|
_, err = tx.ExecContext(ctx, fmt.Sprintf("PRAGMA user_version = %v", userVersion))
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Migrator) inTX(ctx context.Context, txFn func(tx *sql.Tx) error) error {
|
|
|
|
tx, err := m.db.BeginTx(ctx, nil)
|
|
|
|
if err != nil {
|
|
|
|
return fault.Wrap(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := txFn(tx); err != nil {
|
|
|
|
_ = tx.Rollback()
|
|
|
|
return fault.Wrap(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return fault.Wrap(tx.Commit())
|
|
|
|
}
|