Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(experimental): replace dialect with a state.Storage interface #608

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 0 additions & 55 deletions dialect.go

This file was deleted.

69 changes: 69 additions & 0 deletions global.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package goose

import (
"fmt"

"github.com/pressly/goose/v3/state"
"github.com/pressly/goose/v3/state/storage"
)

var global = struct {
storageFactory func(string) state.Storage
tableName string
}{
storageFactory: storage.PostgreSQL,
tableName: "goose_db_version",
}

func globalStorage() state.Storage {
return global.storageFactory(global.tableName)
}

// TableName returns goose db version table name
func TableName() string {
return global.tableName
}

// SetTableName set goose db version table name
func SetTableName(n string) {
global.tableName = n
}

// Dialect is the type of database dialect.
type Dialect string

const (
DialectClickHouse Dialect = "clickhouse"
DialectMSSQL Dialect = "mssql"
DialectMySQL Dialect = "mysql"
DialectPostgres Dialect = "postgres"
DialectRedshift Dialect = "redshift"
DialectSQLite3 Dialect = "sqlite3"
DialectTiDB Dialect = "tidb"
DialectVertica Dialect = "vertica"
)

// SetDialect sets the dialect to use for the goose package.
func SetDialect(s string) error {
switch s {
case "postgres", "pgx":
global.storageFactory = storage.PostgreSQL
// case "mysql":
// d = dialect.Mysql
case "sqlite3", "sqlite":
global.storageFactory = storage.Sqlite3
// case "mssql", "azuresql", "sqlserver":
// d = dialect.Sqlserver
// case "redshift":
// d = dialect.Redshift
// case "tidb":
// d = dialect.Tidb
// case "clickhouse":
// d = dialect.Clickhouse
// case "vertica":
// d = dialect.Vertica
default:
return fmt.Errorf("%q: unknown dialect", s)
}
return nil
}
24 changes: 8 additions & 16 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@ import (
"math"
"sync"

"github.com/pressly/goose/v3/internal/sqladapter"
"github.com/pressly/goose/v3/state"
)

// NewProvider returns a new goose Provider.
//
// The caller is responsible for matching the database dialect with the database/sql driver. For
// example, if the database dialect is "postgres", the database/sql driver could be
// github.com/lib/pq or github.com/jackc/pgx.
// storage implementations are available in the [state/storage] package (e.g. storage.Sqlite3).
//
// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to
// use os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is
Expand All @@ -27,12 +25,12 @@ import (
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
//
// Experimental: This API is experimental and may change in the future.
func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
func NewProvider(storage state.Storage, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
if db == nil {
return nil, errors.New("db must not be nil")
}
if dialect == "" {
return nil, errors.New("dialect must not be empty")
if storage == nil {
return nil, errors.New("storage must not be nil")
}
if fsys == nil {
fsys = noopFS{}
Expand All @@ -46,13 +44,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption
}
}
// Set defaults after applying user-supplied options so option funcs can check for empty values.
if cfg.tableName == "" {
cfg.tableName = DefaultTablename
}
store, err := sqladapter.NewStore(string(dialect), cfg.tableName)
if err != nil {
return nil, err
}

// Collect migrations from the filesystem and merge with registered migrations.
//
// Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed
Expand Down Expand Up @@ -122,7 +114,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption
db: db,
fsys: fsys,
cfg: cfg,
store: store,
store: storage,
migrations: migrations,
}, nil
}
Expand All @@ -136,7 +128,7 @@ type Provider struct {
db *sql.DB
fsys fs.FS
cfg config
store sqladapter.Store
store state.Storage
migrations []*migration
}

Expand Down
21 changes: 2 additions & 19 deletions internal/provider/provider_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,6 @@ type ProviderOption interface {
apply(*config) error
}

// WithTableName sets the name of the database table used to track history of applied migrations.
//
// If WithTableName is not called, the default value is "goose_db_version".
func WithTableName(name string) ProviderOption {
return configFunc(func(c *config) error {
if c.tableName != "" {
return fmt.Errorf("table already set to %q", c.tableName)
}
if name == "" {
return errors.New("table must not be empty")
}
c.tableName = name
return nil
})
}

// WithVerbose enables verbose logging.
func WithVerbose(b bool) ProviderOption {
return configFunc(func(c *config) error {
Expand Down Expand Up @@ -143,9 +127,8 @@ func WithNoVersioning(b bool) ProviderOption {
}

type config struct {
tableName string
verbose bool
excludes map[string]bool
verbose bool
excludes map[string]bool

// Go migrations registered by the user. These will be merged/resolved with migrations from the
// filesystem and init() functions.
Expand Down
31 changes: 7 additions & 24 deletions internal/provider/provider_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/provider"
"github.com/pressly/goose/v3/state/storage"
_ "modernc.org/sqlite"
)

Expand All @@ -23,42 +24,24 @@ func TestNewProvider(t *testing.T) {
}
t.Run("invalid", func(t *testing.T) {
// Empty dialect not allowed
_, err = provider.NewProvider("", db, fsys)
check.HasError(t, err)
// Invalid dialect not allowed
_, err = provider.NewProvider("unknown-dialect", db, fsys)
_, err = provider.NewProvider(nil, db, fsys)
check.HasError(t, err)
// Nil db not allowed
_, err = provider.NewProvider("sqlite3", nil, fsys)
_, err = provider.NewProvider(storage.Sqlite3(""), nil, fsys)
check.HasError(t, err)
// Nil fsys not allowed
_, err = provider.NewProvider("sqlite3", db, nil)
check.HasError(t, err)
// Duplicate table name not allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
provider.WithTableName("foo"),
provider.WithTableName("bar"),
)
_, err = provider.NewProvider(storage.Sqlite3(""), db, nil)
check.HasError(t, err)
check.Equal(t, `table already set to "foo"`, err.Error())
// Empty table name not allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
provider.WithTableName(""),
)
check.HasError(t, err)
check.Equal(t, "table must not be empty", err.Error())
})
t.Run("valid", func(t *testing.T) {
// Valid dialect, db, and fsys allowed
_, err = provider.NewProvider("sqlite3", db, fsys)
_, err = provider.NewProvider(storage.Sqlite3(""), db, fsys)
check.NoError(t, err)
// Valid dialect, db, fsys, and table name allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
provider.WithTableName("foo"),
)
_, err = provider.NewProvider(storage.Sqlite3("foo"), db, fsys)
check.NoError(t, err)
// Valid dialect, db, fsys, and verbose allowed
_, err = provider.NewProvider("sqlite3", db, fsys,
_, err = provider.NewProvider(storage.Sqlite3(""), db, fsys,
provider.WithVerbose(testing.Verbose()),
)
check.NoError(t, err)
Expand Down
13 changes: 7 additions & 6 deletions internal/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/provider"
"github.com/pressly/goose/v3/state/storage"
_ "modernc.org/sqlite"
)

Expand All @@ -19,7 +20,7 @@ func TestProvider(t *testing.T) {
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
t.Run("empty", func(t *testing.T) {
_, err := provider.NewProvider("sqlite3", db, fstest.MapFS{})
_, err := provider.NewProvider(storage.Sqlite3(""), db, fstest.MapFS{})
check.HasError(t, err)
check.Bool(t, errors.Is(err, provider.ErrNoMigrations), true)
})
Expand All @@ -30,7 +31,7 @@ func TestProvider(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
p, err := provider.NewProvider("sqlite3", db, fsys)
p, err := provider.NewProvider(storage.Sqlite3(""), db, fsys)
check.NoError(t, err)
sources := p.ListSources()
check.Equal(t, len(sources), 2)
Expand All @@ -51,7 +52,7 @@ func TestProvider(t *testing.T) {
t.Cleanup(provider.ResetGlobalGoMigrations)

db := newDB(t)
_, err = provider.NewProvider(provider.DialectSQLite3, db, nil,
_, err = provider.NewProvider(storage.Sqlite3(""), db, nil,
provider.WithGoMigration(1, nil, nil),
)
check.HasError(t, err)
Expand All @@ -60,7 +61,7 @@ func TestProvider(t *testing.T) {
t.Run("empty_go", func(t *testing.T) {
db := newDB(t)
// explicit
_, err := provider.NewProvider(provider.DialectSQLite3, db, nil,
_, err := provider.NewProvider(storage.Sqlite3(""), db, nil,
provider.WithGoMigration(1, &provider.GoMigration{Run: nil}, &provider.GoMigration{Run: nil}),
)
check.HasError(t, err)
Expand All @@ -77,7 +78,7 @@ func TestProvider(t *testing.T) {
check.NoError(t, err)
t.Cleanup(provider.ResetGlobalGoMigrations)
db := newDB(t)
_, err = provider.NewProvider(provider.DialectSQLite3, db, nil)
_, err = provider.NewProvider(storage.Sqlite3(""), db, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), "registered migration with both UpFnContext and UpFnNoTxContext")
})
Expand All @@ -92,7 +93,7 @@ func TestProvider(t *testing.T) {
check.NoError(t, err)
t.Cleanup(provider.ResetGlobalGoMigrations)
db := newDB(t)
_, err = provider.NewProvider(provider.DialectSQLite3, db, nil)
_, err = provider.NewProvider(storage.Sqlite3(""), db, nil)
check.HasError(t, err)
check.Contains(t, err.Error(), "registered migration with both DownFnContext and DownFnNoTxContext")
})
Expand Down
Loading