From c590380f39cdea66e5000b3836f146df836a5c74 Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Mon, 9 Oct 2023 15:08:51 -0400 Subject: [PATCH] feat(experimental): add internal migrate package and SessionLocker interface (#606) --- go.mod | 1 + go.sum | 2 + internal/migrate/doc.go | 9 ++ internal/migrate/migration.go | 166 ++++++++++++++++++++++++ internal/migrate/parse.go | 75 +++++++++++ internal/migrate/run.go | 53 ++++++++ internal/sqlextended/sqlextended.go | 2 +- lock/postgres.go | 110 ++++++++++++++++ lock/postgres_test.go | 193 ++++++++++++++++++++++++++++ lock/session_locker.go | 23 ++++ lock/session_locker_options.go | 63 +++++++++ provider_options.go | 29 ++++- 12 files changed, 723 insertions(+), 3 deletions(-) create mode 100644 internal/migrate/doc.go create mode 100644 internal/migrate/migration.go create mode 100644 internal/migrate/parse.go create mode 100644 internal/migrate/run.go create mode 100644 lock/postgres.go create mode 100644 lock/postgres_test.go create mode 100644 lock/session_locker.go create mode 100644 lock/session_locker_options.go diff --git a/go.mod b/go.mod index 9beb962cf..0230e2c97 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/jackc/pgx/v5 v5.4.3 github.com/microsoft/go-mssqldb v1.6.0 github.com/ory/dockertest/v3 v3.10.0 + github.com/sethvargo/go-retry v0.2.4 github.com/vertica/vertica-sql-go v1.3.3 github.com/ziutek/mymysql v1.5.4 go.uber.org/multierr v1.11.0 diff --git a/go.sum b/go.sum index 59b8881ef..e2c6ea03a 100644 --- a/go.sum +++ b/go.sum @@ -127,6 +127,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qq github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/sethvargo/go-retry v0.2.4 h1:T+jHEQy/zKJf5s95UkguisicE0zuF9y7+/vgz08Ocec= +github.com/sethvargo/go-retry v0.2.4/go.mod h1:1afjQuvh7s4gflMObvjLPaWgluLLyhA1wmVZ6KLpICw= github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/internal/migrate/doc.go b/internal/migrate/doc.go new file mode 100644 index 000000000..5fbee1582 --- /dev/null +++ b/internal/migrate/doc.go @@ -0,0 +1,9 @@ +// Package migrate defines a Migration struct and implements the migration logic for executing Go +// and SQL migrations. +// +// - For Go migrations, only *sql.Tx and *sql.DB are supported. *sql.Conn is not supported. +// - For SQL migrations, all three are supported. +// +// Lastly, SQL migrations are lazily parsed. This means that the SQL migration is parsed the first +// time it is executed. +package migrate diff --git a/internal/migrate/migration.go b/internal/migrate/migration.go new file mode 100644 index 000000000..23a0514cf --- /dev/null +++ b/internal/migrate/migration.go @@ -0,0 +1,166 @@ +package migrate + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/pressly/goose/v3/internal/sqlextended" +) + +type Migration struct { + // Fullpath is the full path to the migration file. + // + // Example: /path/to/migrations/123_create_users_table.go + Fullpath string + // Version is the version of the migration. + Version int64 + // Type is the type of migration. + Type MigrationType + // A migration is either a Go migration or a SQL migration, but never both. + // + // Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is + // an optimization to avoid parsing the SQL migration if it is never required. Also, the + // majority of the time migrations are incremental, so it is likely that the user will only want + // to run the last few migrations, and there is no need to parse ALL prior migrations. + // + // Exactly one of these fields will be set: + Go *Go + // -- or -- + SQLParsed bool + SQL *SQL +} + +type MigrationType int + +const ( + TypeGo MigrationType = iota + 1 + TypeSQL +) + +func (t MigrationType) String() string { + switch t { + case TypeGo: + return "go" + case TypeSQL: + return "sql" + default: + // This should never happen. + return "unknown" + } +} + +func (m *Migration) UseTx() bool { + switch m.Type { + case TypeGo: + return m.Go.UseTx + case TypeSQL: + return m.SQL.UseTx + default: + // This should never happen. + panic("unknown migration type: use tx") + } +} + +func (m *Migration) IsEmpty(direction bool) bool { + switch m.Type { + case TypeGo: + return m.Go.IsEmpty(direction) + case TypeSQL: + return m.SQL.IsEmpty(direction) + default: + // This should never happen. + panic("unknown migration type: is empty") + } +} + +func (m *Migration) GetSQLStatements(direction bool) ([]string, error) { + if m.Type != TypeSQL { + return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Type) + } + if m.SQL == nil { + return nil, errors.New("sql migration has not been initialized") + } + if !m.SQLParsed { + return nil, errors.New("sql migration has not been parsed") + } + if direction { + return m.SQL.UpStatements, nil + } + return m.SQL.DownStatements, nil +} + +type Go struct { + // We used an explicit bool instead of relying on a pointer because registered funcs may be nil. + // These are still valid Go and versioned migrations, but they are just empty. + // + // For example: goose.AddMigration(nil, nil) + UseTx bool + + // Only one of these func pairs will be set: + UpFn, DownFn func(context.Context, *sql.Tx) error + // -- or -- + UpFnNoTx, DownFnNoTx func(context.Context, *sql.DB) error +} + +func (g *Go) IsEmpty(direction bool) bool { + if direction { + return g.UpFn == nil && g.UpFnNoTx == nil + } + return g.DownFn == nil && g.DownFnNoTx == nil +} + +func (g *Go) run(ctx context.Context, tx *sql.Tx, direction bool) error { + var fn func(context.Context, *sql.Tx) error + if direction { + fn = g.UpFn + } else { + fn = g.DownFn + } + if fn != nil { + return fn(ctx, tx) + } + return nil +} + +func (g *Go) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { + var fn func(context.Context, *sql.DB) error + if direction { + fn = g.UpFnNoTx + } else { + fn = g.DownFnNoTx + } + if fn != nil { + return fn(ctx, db) + } + return nil +} + +type SQL struct { + UseTx bool + UpStatements []string + DownStatements []string +} + +func (s *SQL) IsEmpty(direction bool) bool { + if direction { + return len(s.UpStatements) == 0 + } + return len(s.DownStatements) == 0 +} + +func (s *SQL) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error { + var statements []string + if direction { + statements = s.UpStatements + } else { + statements = s.DownStatements + } + for _, stmt := range statements { + if _, err := db.ExecContext(ctx, stmt); err != nil { + return err + } + } + return nil +} diff --git a/internal/migrate/parse.go b/internal/migrate/parse.go new file mode 100644 index 000000000..18a66b499 --- /dev/null +++ b/internal/migrate/parse.go @@ -0,0 +1,75 @@ +package migrate + +import ( + "bytes" + "io" + "io/fs" + + "github.com/pressly/goose/v3/internal/sqlparser" +) + +// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it +// will not be parsed again. +// +// Important: This function will mutate SQL migrations. +func ParseSQL(fsys fs.FS, debug bool, migrations []*Migration) error { + for _, m := range migrations { + if m.Type == TypeSQL && !m.SQLParsed { + parsedSQLMigration, err := parseSQL(fsys, m.Fullpath, parseAll, debug) + if err != nil { + return err + } + m.SQLParsed = true + m.SQL = parsedSQLMigration + } + } + return nil +} + +// parse is used to determine which direction to parse the SQL migration. +type parse int + +const ( + // parseAll parses all SQL statements in BOTH directions. + parseAll parse = iota + 1 + // parseUp parses all SQL statements in the UP direction. + parseUp + // parseDown parses all SQL statements in the DOWN direction. + parseDown +) + +func parseSQL(fsys fs.FS, filename string, p parse, debug bool) (*SQL, error) { + r, err := fsys.Open(filename) + if err != nil { + return nil, err + } + by, err := io.ReadAll(r) + if err != nil { + return nil, err + } + if err := r.Close(); err != nil { + return nil, err + } + s := new(SQL) + if p == parseAll || p == parseUp { + s.UpStatements, s.UseTx, err = sqlparser.ParseSQLMigration( + bytes.NewReader(by), + sqlparser.DirectionUp, + debug, + ) + if err != nil { + return nil, err + } + } + if p == parseAll || p == parseDown { + s.DownStatements, s.UseTx, err = sqlparser.ParseSQLMigration( + bytes.NewReader(by), + sqlparser.DirectionDown, + debug, + ) + if err != nil { + return nil, err + } + } + return s, nil +} diff --git a/internal/migrate/run.go b/internal/migrate/run.go new file mode 100644 index 000000000..7b7a883d8 --- /dev/null +++ b/internal/migrate/run.go @@ -0,0 +1,53 @@ +package migrate + +import ( + "context" + "database/sql" + "fmt" + "path/filepath" +) + +// Run runs the migration inside of a transaction. +func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error { + switch m.Type { + case TypeSQL: + if m.SQL == nil || !m.SQLParsed { + return fmt.Errorf("tx: sql migration has not been parsed") + } + return m.SQL.run(ctx, tx, direction) + case TypeGo: + return m.Go.run(ctx, tx, direction) + } + // This should never happen. + return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) +} + +// RunNoTx runs the migration without a transaction. +func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error { + switch m.Type { + case TypeSQL: + if m.SQL == nil || !m.SQLParsed { + return fmt.Errorf("db: sql migration has not been parsed") + } + return m.SQL.run(ctx, db, direction) + case TypeGo: + return m.Go.runNoTx(ctx, db, direction) + } + // This should never happen. + return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) +} + +// RunConn runs the migration without a transaction using the provided connection. +func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error { + switch m.Type { + case TypeSQL: + if m.SQL == nil || !m.SQLParsed { + return fmt.Errorf("conn: sql migration has not been parsed") + } + return m.SQL.run(ctx, conn, direction) + case TypeGo: + return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") + } + // This should never happen. + return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) +} diff --git a/internal/sqlextended/sqlextended.go b/internal/sqlextended/sqlextended.go index e3e763abf..83ca7ae8b 100644 --- a/internal/sqlextended/sqlextended.go +++ b/internal/sqlextended/sqlextended.go @@ -11,7 +11,7 @@ import ( // There is a long outstanding issue to formalize a std lib interface, but alas... See: // https://github.com/golang/go/issues/14468 type DBTxConn interface { - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row } diff --git a/lock/postgres.go b/lock/postgres.go new file mode 100644 index 000000000..3583162e2 --- /dev/null +++ b/lock/postgres.go @@ -0,0 +1,110 @@ +package lock + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/sethvargo/go-retry" +) + +// NewPostgresSessionLocker returns a SessionLocker that utilizes PostgreSQL's exclusive +// session-level advisory lock mechanism. +// +// This function creates a SessionLocker that can be used to acquire and release locks for +// synchronization purposes. The lock acquisition is retried until it is successfully acquired or +// until the maximum duration is reached. The default lock duration is set to 60 minutes, and the +// default unlock duration is set to 1 minute. +// +// See [SessionLockerOption] for options that can be used to configure the SessionLocker. +func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error) { + cfg := sessionLockerConfig{ + lockID: DefaultLockID, + lockTimeout: DefaultLockTimeout, + unlockTimeout: DefaultUnlockTimeout, + } + for _, opt := range opts { + if err := opt.apply(&cfg); err != nil { + return nil, err + } + } + return &postgresSessionLocker{ + lockID: cfg.lockID, + retryLock: retry.WithMaxDuration( + cfg.lockTimeout, + retry.NewConstant(2*time.Second), + ), + retryUnlock: retry.WithMaxDuration( + cfg.unlockTimeout, + retry.NewConstant(2*time.Second), + ), + }, nil +} + +type postgresSessionLocker struct { + lockID int64 + retryLock retry.Backoff + retryUnlock retry.Backoff +} + +var _ SessionLocker = (*postgresSessionLocker)(nil) + +func (l *postgresSessionLocker) SessionLock(ctx context.Context, conn *sql.Conn) error { + return retry.Do(ctx, l.retryLock, func(ctx context.Context) error { + row := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", l.lockID) + var locked bool + if err := row.Scan(&locked); err != nil { + return fmt.Errorf("failed to execute pg_try_advisory_lock: %w", err) + } + if locked { + // A session-level advisory lock was acquired. + return nil + } + // A session-level advisory lock could not be acquired. This is likely because another + // process has already acquired the lock. We will continue retrying until the lock is + // acquired or the maximum number of retries is reached. + return retry.RetryableError(errors.New("failed to acquire lock")) + }) +} + +func (l *postgresSessionLocker) SessionUnlock(ctx context.Context, conn *sql.Conn) error { + return retry.Do(ctx, l.retryUnlock, func(ctx context.Context) error { + var unlocked bool + row := conn.QueryRowContext(ctx, "SELECT pg_advisory_unlock($1)", l.lockID) + if err := row.Scan(&unlocked); err != nil { + return fmt.Errorf("failed to execute pg_advisory_unlock: %w", err) + } + if unlocked { + // A session-level advisory lock was released. + return nil + } + /* + TODO(mf): provide users with some documentation on how they can unlock the session + manually. + + This is probably not an issue for 99.99% of users since pg_advisory_unlock_all() will + release all session level advisory locks held by the current session. This function is + implicitly invoked at session end, even if the client disconnects ungracefully. + + Here is output from a session that has a lock held: + + SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks + WHERE locktype='advisory'; + + | pid | granted | goose_lock_id | + |-----|---------|---------------------| + | 191 | t | 5887940537704921958 | + + A forceful way to unlock the session is to terminate the backend with SIGTERM: + + SELECT pg_terminate_backend(191); + + Subsequent commands on the same connection will fail with: + + Query 1 ERROR: FATAL: terminating connection due to administrator command + */ + return retry.RetryableError(errors.New("failed to unlock session")) + }) +} diff --git a/lock/postgres_test.go b/lock/postgres_test.go new file mode 100644 index 000000000..bfb1a0d99 --- /dev/null +++ b/lock/postgres_test.go @@ -0,0 +1,193 @@ +package lock_test + +import ( + "context" + "database/sql" + "errors" + "sync" + "testing" + "time" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/testdb" + "github.com/pressly/goose/v3/lock" +) + +func TestPostgresSessionLocker(t *testing.T) { + if testing.Short() { + t.Skip("skip long running test") + } + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + const ( + lockID int64 = 123456789 + ) + + // Do not run tests in parallel, because they are using the same database. + + t.Run("lock_and_unlock", func(t *testing.T) { + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockID(lockID), + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + ctx := context.Background() + conn, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn.Close()) + }) + err = locker.SessionLock(ctx, conn) + check.NoError(t, err) + pgLocks, err := queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + // Check that the lock was acquired. + check.Bool(t, pgLocks[0].granted, true) + // Check that the custom lock ID is the same as the one used by the locker. + check.Equal(t, pgLocks[0].gooseLockID, lockID) + check.NumberNotZero(t, pgLocks[0].pid) + + // Check that the lock is released. + err = locker.SessionUnlock(ctx, conn) + check.NoError(t, err) + pgLocks, err = queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 0) + }) + t.Run("lock_close_conn_unlock", func(t *testing.T) { + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + ctx := context.Background() + conn, err := db.Conn(ctx) + check.NoError(t, err) + + err = locker.SessionLock(ctx, conn) + check.NoError(t, err) + pgLocks, err := queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + check.Bool(t, pgLocks[0].granted, true) + check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID) + // Simulate a connection close. + err = conn.Close() + check.NoError(t, err) + // Check an error is returned when unlocking, because the connection is already closed. + err = locker.SessionUnlock(ctx, conn) + check.HasError(t, err) + check.Bool(t, errors.Is(err, sql.ErrConnDone), true) + }) + t.Run("multiple_connections", func(t *testing.T) { + const ( + workers = 5 + ) + ch := make(chan error) + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + ctx := context.Background() + conn, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn.Close()) + }) + // Exactly one connection should acquire the lock. While the other connections + // should fail to acquire the lock and timeout. + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + ch <- locker.SessionLock(ctx, conn) + }() + } + go func() { + wg.Wait() + close(ch) + }() + var errors []error + for err := range ch { + if err != nil { + errors = append(errors, err) + } + } + check.Equal(t, len(errors), workers-1) // One worker succeeds, the rest fail. + for _, err := range errors { + check.HasError(t, err) + check.Equal(t, err.Error(), "failed to acquire lock") + } + pgLocks, err := queryPgLocks(context.Background(), db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + check.Bool(t, pgLocks[0].granted, true) + check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID) + }) + t.Run("unlock_with_different_connection", func(t *testing.T) { + ctx := context.Background() + const ( + lockID int64 = 999 + ) + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockID(lockID), + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + + conn1, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn1.Close()) + }) + err = locker.SessionLock(ctx, conn1) + check.NoError(t, err) + pgLocks, err := queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + check.Bool(t, pgLocks[0].granted, true) + check.Equal(t, pgLocks[0].gooseLockID, lockID) + // Unlock with a different connection. + conn2, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn2.Close()) + }) + // Check an error is returned when unlocking with a different connection. + err = locker.SessionUnlock(ctx, conn2) + check.HasError(t, err) + }) +} + +type pgLock struct { + pid int + granted bool + gooseLockID int64 +} + +func queryPgLocks(ctx context.Context, db *sql.DB) ([]pgLock, error) { + q := `SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks WHERE locktype='advisory'` + rows, err := db.QueryContext(ctx, q) + if err != nil { + return nil, err + } + var pgLocks []pgLock + for rows.Next() { + var p pgLock + if err = rows.Scan(&p.pid, &p.granted, &p.gooseLockID); err != nil { + return nil, err + } + pgLocks = append(pgLocks, p) + } + if err := rows.Err(); err != nil { + return nil, err + } + return pgLocks, nil +} diff --git a/lock/session_locker.go b/lock/session_locker.go new file mode 100644 index 000000000..b74187829 --- /dev/null +++ b/lock/session_locker.go @@ -0,0 +1,23 @@ +// Package lock defines the Locker interface and implements the locking logic. +package lock + +import ( + "context" + "database/sql" + "errors" +) + +var ( + // ErrLockNotImplemented is returned when the database does not support locking. + ErrLockNotImplemented = errors.New("lock not implemented") + // ErrUnlockNotImplemented is returned when the database does not support unlocking. + ErrUnlockNotImplemented = errors.New("unlock not implemented") +) + +// SessionLocker is the interface to lock and unlock the database for the duration of a session. The +// session is defined as the duration of a single connection and both methods must be called on the +// same connection. +type SessionLocker interface { + SessionLock(ctx context.Context, conn *sql.Conn) error + SessionUnlock(ctx context.Context, conn *sql.Conn) error +} diff --git a/lock/session_locker_options.go b/lock/session_locker_options.go new file mode 100644 index 000000000..c3e42151c --- /dev/null +++ b/lock/session_locker_options.go @@ -0,0 +1,63 @@ +package lock + +import ( + "time" +) + +const ( + // DefaultLockID is the id used to lock the database for migrations. It is a crc64 hash of the + // string "goose". This is used to ensure that the lock is unique to goose. + // + // crc64.Checksum([]byte("goose"), crc64.MakeTable(crc64.ECMA)) + DefaultLockID int64 = 5887940537704921958 + + // Default values for the lock (time to wait for the lock to be acquired) and unlock (time to + // wait for the lock to be released) wait durations. + DefaultLockTimeout time.Duration = 60 * time.Minute + DefaultUnlockTimeout time.Duration = 1 * time.Minute +) + +// SessionLockerOption is used to configure a SessionLocker. +type SessionLockerOption interface { + apply(*sessionLockerConfig) error +} + +// WithLockID sets the lock ID to use when locking the database. +// +// If WithLockID is not called, the DefaultLockID is used. +func WithLockID(lockID int64) SessionLockerOption { + return sessionLockerConfigFunc(func(c *sessionLockerConfig) error { + c.lockID = lockID + return nil + }) +} + +// WithLockTimeout sets the max duration to wait for the lock to be acquired. +func WithLockTimeout(duration time.Duration) SessionLockerOption { + return sessionLockerConfigFunc(func(c *sessionLockerConfig) error { + c.lockTimeout = duration + return nil + }) +} + +// WithUnlockTimeout sets the max duration to wait for the lock to be released. +func WithUnlockTimeout(duration time.Duration) SessionLockerOption { + return sessionLockerConfigFunc(func(c *sessionLockerConfig) error { + c.unlockTimeout = duration + return nil + }) +} + +type sessionLockerConfig struct { + lockID int64 + lockTimeout time.Duration + unlockTimeout time.Duration +} + +var _ SessionLockerOption = (sessionLockerConfigFunc)(nil) + +type sessionLockerConfigFunc func(*sessionLockerConfig) error + +func (f sessionLockerConfigFunc) apply(cfg *sessionLockerConfig) error { + return f(cfg) +} diff --git a/provider_options.go b/provider_options.go index 904b3ed34..2370486f9 100644 --- a/provider_options.go +++ b/provider_options.go @@ -3,6 +3,8 @@ package goose import ( "errors" "fmt" + + "github.com/pressly/goose/v3/lock" ) const ( @@ -38,13 +40,36 @@ func WithVerbose() ProviderOption { }) } +// WithSessionLocker enables locking using the provided SessionLocker. +// +// If WithSessionLocker is not called, locking is disabled. +func WithSessionLocker(locker lock.SessionLocker) ProviderOption { + return configFunc(func(c *config) error { + if c.lockEnabled { + return errors.New("lock already enabled") + } + if c.sessionLocker != nil { + return errors.New("session locker already set") + } + if locker == nil { + return errors.New("session locker must not be nil") + } + c.lockEnabled = true + c.sessionLocker = locker + return nil + }) +} + type config struct { tableName string verbose bool + + lockEnabled bool + sessionLocker lock.SessionLocker } type configFunc func(*config) error -func (o configFunc) apply(cfg *config) error { - return o(cfg) +func (f configFunc) apply(cfg *config) error { + return f(cfg) }