Skip to content

Commit

Permalink
rename a few things and add postgres tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed Oct 9, 2023
1 parent 6a4b75b commit d924454
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 19 deletions.
15 changes: 8 additions & 7 deletions lock/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"time"

"github.com/sethvargo/go-retry"
Expand All @@ -20,9 +21,9 @@ import (
// See [SessionLockerOption] for options that can be used to configure the SessionLocker.
func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error) {
cfg := sessionLockerConfig{
lockID: DefaultLockID,
lockDuration: DefaultLockDuration,
unlockDuration: DefaultUnlockDuration,
lockID: DefaultLockID,
lockTimeout: DefaultLockTimeout,
unlockTimeout: DefaultUnlockTimeout,
}
for _, opt := range opts {
if err := opt.apply(&cfg); err != nil {
Expand All @@ -32,11 +33,11 @@ func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error
return &postgresSessionLocker{
lockID: cfg.lockID,
retryLock: retry.WithMaxDuration(
cfg.lockDuration,
cfg.lockTimeout,
retry.NewConstant(2*time.Second),
),
retryUnlock: retry.WithMaxDuration(
cfg.unlockDuration,
cfg.unlockTimeout,
retry.NewConstant(2*time.Second),
),
}, nil
Expand All @@ -55,7 +56,7 @@ func (l *postgresSessionLocker) SessionLock(ctx context.Context, conn *sql.Conn)
row := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", l.lockID)
var locked bool
if err := row.Scan(&locked); err != nil {
return err
return fmt.Errorf("failed to execute pg_try_advisory_lock: %w", err)
}
if locked {
// A session-level advisory lock was acquired.
Expand All @@ -73,7 +74,7 @@ func (l *postgresSessionLocker) SessionUnlock(ctx context.Context, conn *sql.Con
var unlocked bool
row := conn.QueryRowContext(ctx, "SELECT pg_advisory_unlock($1)", l.lockID)
if err := row.Scan(&unlocked); err != nil {
return err
return fmt.Errorf("failed to execute pg_advisory_unlock: %w", err)
}
if unlocked {
// A session-level advisory lock was released.
Expand Down
187 changes: 187 additions & 0 deletions lock/postgres_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package lock_test

import (
"context"
"database/sql"
"errors"
"sync"
"testing"
"time"

"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testdb"
"github.com/pressly/goose/v3/lock"
)

func TestPostgresSessionLocker(t *testing.T) {
if testing.Short() {
t.Skip("skip long running test")
}
db, cleanup, err := testdb.NewPostgres()
check.NoError(t, err)
t.Cleanup(cleanup)
const (
lockID int64 = 123456789
)

// Do not run tests in parallel, because they are using the same database.

t.Run("lock_and_unlock", func(t *testing.T) {
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockID(lockID),
)
check.NoError(t, err)
ctx := context.Background()
conn, err := db.Conn(ctx)
check.NoError(t, err)
t.Cleanup(func() {
check.NoError(t, conn.Close())
})
err = locker.SessionLock(ctx, conn)
check.NoError(t, err)
pgLocks, err := queryPgLocks(ctx, db)
check.NoError(t, err)
check.Number(t, len(pgLocks), 1)
// Check that the lock was acquired.
check.Bool(t, pgLocks[0].granted, true)
// Check that the custom lock ID is the same as the one used by the locker.
check.Equal(t, pgLocks[0].gooseLockID, lockID)
check.NumberNotZero(t, pgLocks[0].pid)

// Check that the lock is released.
err = locker.SessionUnlock(ctx, conn)
check.NoError(t, err)
pgLocks, err = queryPgLocks(ctx, db)
check.NoError(t, err)
check.Number(t, len(pgLocks), 0)
})
t.Run("lock_close_conn_unlock", func(t *testing.T) {
locker, err := lock.NewPostgresSessionLocker()
check.NoError(t, err)
ctx := context.Background()
conn, err := db.Conn(ctx)
check.NoError(t, err)

err = locker.SessionLock(ctx, conn)
check.NoError(t, err)
pgLocks, err := queryPgLocks(ctx, db)
check.NoError(t, err)
check.Number(t, len(pgLocks), 1)
check.Bool(t, pgLocks[0].granted, true)
check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID)
// Simulate a connection close.
err = conn.Close()
check.NoError(t, err)
// Check an error is returned when unlocking, because the connection is already closed.
err = locker.SessionUnlock(ctx, conn)
check.HasError(t, err)
check.Bool(t, errors.Is(err, sql.ErrConnDone), true)
})
t.Run("multiple_connections", func(t *testing.T) {
const (
workers = 5
)
ch := make(chan error)
var wg sync.WaitGroup
for i := 0; i < workers; i++ {
wg.Add(1)

go func() {
defer wg.Done()
ctx := context.Background()
conn, err := db.Conn(ctx)
check.NoError(t, err)
t.Cleanup(func() {
check.NoError(t, conn.Close())
})
// Exactly one connection should acquire the lock. While the other connections
// should fail to acquire the lock and timeout.
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockTimeout(4 * time.Second),
)
check.NoError(t, err)
ch <- locker.SessionLock(ctx, conn)
}()
}
go func() {
wg.Wait()
close(ch)
}()
var errors []error
for err := range ch {
if err != nil {
errors = append(errors, err)
}
}
check.Equal(t, len(errors), workers-1) // One worker succeeds, the rest fail.
for _, err := range errors {
check.HasError(t, err)
check.Equal(t, err.Error(), "failed to acquire lock")
}
pgLocks, err := queryPgLocks(context.Background(), db)
check.NoError(t, err)
check.Number(t, len(pgLocks), 1)
check.Bool(t, pgLocks[0].granted, true)
check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID)
})
t.Run("unlock_with_different_connection", func(t *testing.T) {
ctx := context.Background()
const (
lockID int64 = 999
)
locker, err := lock.NewPostgresSessionLocker(
lock.WithLockID(lockID),
lock.WithUnlockTimeout(4*time.Second),
)
check.NoError(t, err)

conn1, err := db.Conn(ctx)
check.NoError(t, err)
t.Cleanup(func() {
check.NoError(t, conn1.Close())
})
err = locker.SessionLock(ctx, conn1)
check.NoError(t, err)
pgLocks, err := queryPgLocks(ctx, db)
check.NoError(t, err)
check.Number(t, len(pgLocks), 1)
check.Bool(t, pgLocks[0].granted, true)
check.Equal(t, pgLocks[0].gooseLockID, lockID)
// Unlock with a different connection.
conn2, err := db.Conn(ctx)
check.NoError(t, err)
t.Cleanup(func() {
check.NoError(t, conn2.Close())
})
// Check an error is returned when unlocking with a different connection.
err = locker.SessionUnlock(ctx, conn2)
check.HasError(t, err)
})
}

type pgLock struct {
pid int
granted bool
gooseLockID int64
}

func queryPgLocks(ctx context.Context, db *sql.DB) ([]pgLock, error) {
q := `SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks
WHERE locktype='advisory'`
rows, err := db.QueryContext(ctx, q)
if err != nil {
return nil, err
}
var pgLocks []pgLock
for rows.Next() {
var p pgLock
if err = rows.Scan(&p.pid, &p.granted, &p.gooseLockID); err != nil {
return nil, err
}
pgLocks = append(pgLocks, p)
}
if err := rows.Err(); err != nil {
return nil, err
}
return pgLocks, nil
}
24 changes: 12 additions & 12 deletions lock/session_locker_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ const (
DefaultLockID int64 = 5887940537704921958

// Default values for the lock (time to wait for the lock to be acquired) and unlock (time to
// wait for the lock to be released) durations.
DefaultLockDuration time.Duration = 60 * time.Minute
DefaultUnlockDuration time.Duration = 1 * time.Minute
// wait for the lock to be released) wait durations.
DefaultLockTimeout time.Duration = 60 * time.Minute
DefaultUnlockTimeout time.Duration = 1 * time.Minute
)

// SessionLockerOption is used to configure a SessionLocker.
Expand All @@ -32,26 +32,26 @@ func WithLockID(lockID int64) SessionLockerOption {
})
}

// WithLockDuration sets the max duration to wait for the lock to be acquired.
func WithLockDuration(duration time.Duration) SessionLockerOption {
// WithLockTimeout sets the max duration to wait for the lock to be acquired.
func WithLockTimeout(duration time.Duration) SessionLockerOption {
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
c.lockDuration = duration
c.lockTimeout = duration
return nil
})
}

// WithUnlockDuration sets the max duration to wait for the lock to be released.
func WithUnlockDuration(duration time.Duration) SessionLockerOption {
// WithUnlockTimeout sets the max duration to wait for the lock to be released.
func WithUnlockTimeout(duration time.Duration) SessionLockerOption {
return sessionLockerConfigFunc(func(c *sessionLockerConfig) error {
c.unlockDuration = duration
c.unlockTimeout = duration
return nil
})
}

type sessionLockerConfig struct {
lockID int64
lockDuration time.Duration
unlockDuration time.Duration
lockID int64
lockTimeout time.Duration
unlockTimeout time.Duration
}

var _ SessionLockerOption = (sessionLockerConfigFunc)(nil)
Expand Down

0 comments on commit d924454

Please sign in to comment.