diff --git a/Makefile b/Makefile index 46b444fbc..8a3a05fb1 100644 --- a/Makefile +++ b/Makefile @@ -33,6 +33,9 @@ tools: test-packages: go test $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples) +test-packages-short: + go test -test.short $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples) + test-e2e: test-e2e-postgres test-e2e-mysql test-e2e-clickhouse test-e2e-vertica test-e2e-postgres: diff --git a/create_test.go b/create_test.go index fddf48d85..34791cc65 100644 --- a/create_test.go +++ b/create_test.go @@ -11,6 +11,9 @@ import ( func TestSequential(t *testing.T) { t.Parallel() + if testing.Short() { + t.Skip("skip long running test") + } dir := t.TempDir() defer os.Remove("./bin/create-goose") // clean up diff --git a/fix_test.go b/fix_test.go index 5c982dbe8..6a5e0842b 100644 --- a/fix_test.go +++ b/fix_test.go @@ -11,6 +11,9 @@ import ( func TestFix(t *testing.T) { t.Parallel() + if testing.Short() { + t.Skip("skip long running test") + } dir := t.TempDir() defer os.Remove("./bin/fix-goose") // clean up diff --git a/internal/migrationstats/migrationstats_test.go b/internal/migrationstats/migrationstats_test.go index 67a65a3cf..26c49fd38 100644 --- a/internal/migrationstats/migrationstats_test.go +++ b/internal/migrationstats/migrationstats_test.go @@ -8,6 +8,7 @@ import ( ) func TestParsingGoMigrations(t *testing.T) { + t.Parallel() tests := []struct { name string input string @@ -38,6 +39,7 @@ func TestParsingGoMigrations(t *testing.T) { } func TestParsingGoMigrationsError(t *testing.T) { + t.Parallel() _, err := parseGoFile(strings.NewReader(emptyInit)) check.HasError(t, err) check.Contains(t, err.Error(), "no registered goose functions") diff --git a/internal/provider/collect.go b/internal/provider/collect.go new file mode 100644 index 000000000..cf12961fb --- /dev/null +++ b/internal/provider/collect.go @@ -0,0 +1,176 @@ +package provider + +import ( + "errors" + "fmt" + "io/fs" + "path/filepath" + "sort" + "strings" + + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/migrate" +) + +// fileSources represents a collection of migration files on the filesystem. +type fileSources struct { + sqlSources []Source + goSources []Source +} + +// collectFileSources scans the file system for migration files that have a numeric prefix (greater +// than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil, +// in which case an empty fileSources is returned. +// +// If strict is true, then any error parsing the numeric component of the filename will result in an +// error. The file is skipped otherwise. +// +// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects +// migration sources from the filesystem. +func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) { + if fsys == nil { + return new(fileSources), nil + } + sources := new(fileSources) + versionToBaseLookup := make(map[int64]string) // map[version]filepath.Base(fullpath) + for _, pattern := range []string{ + "*.sql", + "*.go", + } { + files, err := fs.Glob(fsys, pattern) + if err != nil { + return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err) + } + for _, fullpath := range files { + base := filepath.Base(fullpath) + // Skip explicit excludes or Go test files. + if excludes[base] || strings.HasSuffix(base, "_test.go") { + continue + } + // If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use + // that as the version. Otherwise, ignore it. This allows users to have arbitrary + // filenames, but still have versioned migrations within the same directory. For + // example, a user could have a helpers.go file which contains unexported helper + // functions for migrations. + version, err := goose.NumericComponent(base) + if err != nil { + if strict { + return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err) + } + continue + } + // Ensure there are no duplicate versions. + if existing, ok := versionToBaseLookup[version]; ok { + return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", + version, + existing, + base, + ) + } + switch filepath.Ext(base) { + case ".sql": + sources.sqlSources = append(sources.sqlSources, Source{ + Fullpath: fullpath, + Version: version, + }) + case ".go": + sources.goSources = append(sources.goSources, Source{ + Fullpath: fullpath, + Version: version, + }) + default: + // Should never happen since we already filtered out all other file types. + return nil, fmt.Errorf("unknown migration type: %s", base) + } + // Add the version to the lookup map. + versionToBaseLookup[version] = base + } + } + return sources, nil +} + +func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migrate.Migration, error) { + var migrations []*migrate.Migration + migrationLookup := make(map[int64]*migrate.Migration) + // Add all SQL migrations to the list of migrations. + for _, s := range sources.sqlSources { + m := &migrate.Migration{ + Type: migrate.TypeSQL, + Fullpath: s.Fullpath, + Version: s.Version, + SQLParsed: false, + } + migrations = append(migrations, m) + migrationLookup[s.Version] = m + } + // If there are no Go files in the filesystem and no registered Go migrations, return early. + if len(sources.goSources) == 0 && len(registerd) == 0 { + return migrations, nil + } + // Return an error if the given sources contain a versioned Go migration that has not been + // registered. This is a sanity check to ensure users didn't accidentally create a valid looking + // Go migration file on disk and forget to register it. + // + // This is almost always a user error. + var unregistered []string + for _, s := range sources.goSources { + if _, ok := registerd[s.Version]; !ok { + unregistered = append(unregistered, s.Fullpath) + } + } + if len(unregistered) > 0 { + return nil, unregisteredError(unregistered) + } + // Add all registered Go migrations to the list of migrations, checking for duplicate versions. + // + // Important, users can register Go migrations manually via goose.Add_ functions. These + // migrations may not have a corresponding file on disk. Which is fine! We include them + // wholesale as part of migrations. This allows users to build a custom binary that only embeds + // the SQL migration files. + for _, r := range registerd { + // Ensure there are no duplicate versions. + if existing, ok := migrationLookup[r.Version]; ok { + return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", + r.Version, + existing, + filepath.Base(r.Source), + ) + } + m := &migrate.Migration{ + Fullpath: r.Source, // May be empty if the migration was registered manually. + Version: r.Version, + Type: migrate.TypeGo, + Go: &migrate.Go{ + UseTx: r.UseTx, + UpFn: r.UpFnContext, + UpFnNoTx: r.UpFnNoTxContext, + DownFn: r.DownFnContext, + DownFnNoTx: r.DownFnNoTxContext, + }, + } + migrations = append(migrations, m) + migrationLookup[r.Version] = m + } + // Sort migrations by version in ascending order. + sort.Slice(migrations, func(i, j int) bool { + return migrations[i].Version < migrations[j].Version + }) + return migrations, nil +} + +func unregisteredError(unregistered []string) error { + f := "file" + if len(unregistered) > 1 { + f += "s" + } + var b strings.Builder + + b.WriteString(fmt.Sprintf("error: detected %d unregistered Go %s:\n", len(unregistered), f)) + for _, name := range unregistered { + b.WriteString("\t" + name + "\n") + } + b.WriteString("\n") + b.WriteString("go functions must be registered and built into a custom binary see:\nhttps://github.com/pressly/goose/tree/master/examples/go-migrations") + + return errors.New(b.String()) +} diff --git a/internal/provider/collect_test.go b/internal/provider/collect_test.go new file mode 100644 index 000000000..a5ee2d352 --- /dev/null +++ b/internal/provider/collect_test.go @@ -0,0 +1,185 @@ +package provider + +import ( + "io/fs" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" +) + +func TestCollectFileSources(t *testing.T) { + t.Parallel() + t.Run("nil", func(t *testing.T) { + sources, err := collectFileSources(nil, false, nil) + check.NoError(t, err) + check.Bool(t, sources != nil, true) + check.Number(t, len(sources.goSources), 0) + check.Number(t, len(sources.sqlSources), 0) + }) + t.Run("empty", func(t *testing.T) { + sources, err := collectFileSources(fstest.MapFS{}, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.goSources), 0) + check.Number(t, len(sources.sqlSources), 0) + check.Bool(t, sources != nil, true) + }) + t.Run("incorrect_fsys", func(t *testing.T) { + mapFS := fstest.MapFS{ + "00000_foo.sql": sqlMapFile, + } + // strict disable - should not error + sources, err := collectFileSources(mapFS, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.goSources), 0) + check.Number(t, len(sources.sqlSources), 0) + // strict enabled - should error + _, err = collectFileSources(mapFS, true, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "migration version must be greater than zero") + }) + t.Run("collect", func(t *testing.T) { + fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 4) + check.Number(t, len(sources.goSources), 0) + expected := fileSources{ + sqlSources: []Source{ + {Fullpath: "00001_foo.sql", Version: 1}, + {Fullpath: "00002_bar.sql", Version: 2}, + {Fullpath: "00003_baz.sql", Version: 3}, + {Fullpath: "00110_qux.sql", Version: 110}, + }, + } + for i := 0; i < len(sources.sqlSources); i++ { + check.Equal(t, sources.sqlSources[i], expected.sqlSources[i]) + } + }) + t.Run("excludes", func(t *testing.T) { + fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") + check.NoError(t, err) + sources, err := collectFileSources( + fsys, + false, + // exclude 2 files explicitly + map[string]bool{ + "00002_bar.sql": true, + "00110_qux.sql": true, + }, + ) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 2) + check.Number(t, len(sources.goSources), 0) + expected := fileSources{ + sqlSources: []Source{ + {Fullpath: "00001_foo.sql", Version: 1}, + {Fullpath: "00003_baz.sql", Version: 3}, + }, + } + for i := 0; i < len(sources.sqlSources); i++ { + check.Equal(t, sources.sqlSources[i], expected.sqlSources[i]) + } + }) + t.Run("strict", func(t *testing.T) { + mapFS := newSQLOnlyFS() + // Add a file with no version number + mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")} + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + _, err = collectFileSources(fsys, true, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`) + }) + t.Run("skip_go_test_files", func(t *testing.T) { + mapFS := fstest.MapFS{ + "1_foo.sql": sqlMapFile, + "2_bar.sql": sqlMapFile, + "3_baz.sql": sqlMapFile, + "4_qux.sql": sqlMapFile, + "5_foo_test.go": {Data: []byte(`package goose_test`)}, + } + sources, err := collectFileSources(mapFS, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 4) + check.Number(t, len(sources.goSources), 0) + }) + t.Run("skip_random_files", func(t *testing.T) { + mapFS := fstest.MapFS{ + "1_foo.sql": sqlMapFile, + "4_something.go": {Data: []byte(`package goose`)}, + "5_qux.sql": sqlMapFile, + "README.md": {Data: []byte(`# README`)}, + "LICENSE": {Data: []byte(`MIT`)}, + "no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)}, + "some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)}, + } + sources, err := collectFileSources(mapFS, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 2) + check.Number(t, len(sources.goSources), 1) + // 1 + check.Equal(t, sources.sqlSources[0].Fullpath, "1_foo.sql") + check.Equal(t, sources.sqlSources[0].Version, int64(1)) + // 2 + check.Equal(t, sources.sqlSources[1].Fullpath, "5_qux.sql") + check.Equal(t, sources.sqlSources[1].Version, int64(5)) + // 3 + check.Equal(t, sources.goSources[0].Fullpath, "4_something.go") + check.Equal(t, sources.goSources[0].Version, int64(4)) + }) + t.Run("duplicate_versions", func(t *testing.T) { + mapFS := fstest.MapFS{ + "001_foo.sql": sqlMapFile, + "01_bar.sql": sqlMapFile, + } + _, err := collectFileSources(mapFS, false, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "found duplicate migration version 1") + }) + t.Run("dirpath", func(t *testing.T) { + mapFS := fstest.MapFS{ + "dir1/101_a.sql": sqlMapFile, + "dir1/102_b.sql": sqlMapFile, + "dir1/103_c.sql": sqlMapFile, + "dir2/201_a.sql": sqlMapFile, + "876_a.sql": sqlMapFile, + } + assertDirpath := func(dirpath string, sqlSources []Source) { + t.Helper() + f, err := fs.Sub(mapFS, dirpath) + check.NoError(t, err) + got, err := collectFileSources(f, false, nil) + check.NoError(t, err) + check.Number(t, len(got.sqlSources), len(sqlSources)) + check.Number(t, len(got.goSources), 0) + for i := 0; i < len(got.sqlSources); i++ { + check.Equal(t, got.sqlSources[i], sqlSources[i]) + } + } + assertDirpath(".", []Source{ + {Fullpath: "876_a.sql", Version: 876}, + }) + assertDirpath("dir1", []Source{ + {Fullpath: "101_a.sql", Version: 101}, + {Fullpath: "102_b.sql", Version: 102}, + {Fullpath: "103_c.sql", Version: 103}, + }) + assertDirpath("dir2", []Source{{Fullpath: "201_a.sql", Version: 201}}) + assertDirpath("dir3", nil) + }) +} + +func newSQLOnlyFS() fstest.MapFS { + return fstest.MapFS{ + "migrations/00001_foo.sql": sqlMapFile, + "migrations/00002_bar.sql": sqlMapFile, + "migrations/00003_baz.sql": sqlMapFile, + "migrations/00110_qux.sql": sqlMapFile, + } +} + +var ( + sqlMapFile = &fstest.MapFile{Data: []byte(`-- +goose Up`)} +) diff --git a/internal/provider/provider.go b/internal/provider/provider.go index c8d899511..6702f0731 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -5,22 +5,29 @@ import ( "database/sql" "errors" "io/fs" + "os" "time" + "github.com/pressly/goose/v3/internal/migrate" "github.com/pressly/goose/v3/internal/sqladapter" ) +var ( + // ErrNoMigrations is returned by [NewProvider] when no migrations are found. + ErrNoMigrations = errors.New("no migrations found") +) + // NewProvider returns a new goose Provider. // // The caller is responsible for matching the database dialect with the database/sql driver. For // example, if the database dialect is "postgres", the database/sql driver could be // github.com/lib/pq or github.com/jackc/pgx. // -// fsys is the filesystem used to read the migration files. Most users will want to use -// os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is -// possible to use a different filesystem, such as embed.FS. +// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to +// use os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is +// possible to use a different filesystem, such as embed.FS or filter out migrations using fs.Sub. // -// Functional options are used to configure the Provider. See [ProviderOption] for more information. +// See [ProviderOption] for more information on configuring the provider. // // Unless otherwise specified, all methods on Provider are safe for concurrent use. // @@ -33,7 +40,7 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) return nil, errors.New("dialect must not be empty") } if fsys == nil { - return nil, errors.New("fsys must not be nil") + fsys = noopFS{} } var cfg config for _, opt := range opts { @@ -41,7 +48,7 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) return nil, err } } - // Set defaults + // Set defaults after applying user-supplied options so option funcs can check for empty values. if cfg.tableName == "" { cfg.tableName = defaultTablename } @@ -49,41 +56,76 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) if err != nil { return nil, err } - // TODO(mf): implement the rest of this function - collect sources - merge sources into - // migrations + // Collect migrations from the filesystem and merge with registered migrations. + // + // Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed + // lazily. + // + // TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to + // return an error if there are any SQL parsing errors. This adds a bit overhead to startup + // though, so we should make it optional. + sources, err := collectFileSources(fsys, false, cfg.excludes) + if err != nil { + return nil, err + } + migrations, err := merge(sources, nil) + if err != nil { + return nil, err + } + if len(migrations) == 0 { + return nil, ErrNoMigrations + } return &Provider{ - db: db, - fsys: fsys, - cfg: cfg, - store: store, + db: db, + fsys: fsys, + cfg: cfg, + store: store, + migrations: migrations, }, nil } +type noopFS struct{} + +var _ fs.FS = noopFS{} + +func (f noopFS) Open(name string) (fs.File, error) { + return nil, os.ErrNotExist +} + // Provider is a goose migration provider. -// Experimental: This API is experimental and may change in the future. type Provider struct { - db *sql.DB - fsys fs.FS - cfg config - store sqladapter.Store + db *sql.DB + fsys fs.FS + cfg config + store sqladapter.Store + migrations []*migrate.Migration } +// State represents the state of a migration. +type State string + +const ( + // StateUntracked represents a migration that is in the database, but not on the filesystem. + StateUntracked State = "untracked" + // StatePending represents a migration that is on the filesystem, but not in the database. + StatePending State = "pending" + // StateApplied represents a migration that is in BOTH the database and on the filesystem. + StateApplied State = "applied" +) + // MigrationStatus represents the status of a single migration. type MigrationStatus struct { - // State represents the state of the migration. One of "untracked", "pending", "applied". - // - untracked: in the database, but not on the filesystem. - // - pending: on the filesystem, but not in the database. - // - applied: in both the database and on the filesystem. - State string - // AppliedAt is the time the migration was applied. Only set if state is applied or untracked. + // State is the state of the migration. + State State + // AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or + // [StateUntracked]. AppliedAt time.Time - // Source is the migration source. Only set if the state is pending or applied. - Source Source + // Source is the migration source. Only set if the state is [StatePending] or [StateApplied]. + Source *Source } // Status returns the status of all migrations, merging the list of migrations from the database and // filesystem. The returned items are ordered by version, in ascending order. -// Experimental: This API is experimental and may change in the future. func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { return nil, errors.New("not implemented") } @@ -91,7 +133,6 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { // GetDBVersion returns the max version from the database, regardless of the applied order. For // example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been // applied, it returns 0. -// Experimental: This API is experimental and may change in the future. func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { return 0, errors.New("not implemented") } @@ -111,7 +152,6 @@ const ( // For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if // the migration has a corresponding file on disk. It will be empty if the migration was registered // manually. -// Experimental: This API is experimental and may change in the future. type Source struct { // Type is the type of migration. Type SourceType @@ -123,22 +163,34 @@ type Source struct { Version int64 } -// ListSources returns a list of all available migration sources the provider is aware of. -// Experimental: This API is experimental and may change in the future. +// ListSources returns a list of all available migration sources the provider is aware of, sorted in +// ascending order by version. func (p *Provider) ListSources() []*Source { - return nil + sources := make([]*Source, 0, len(p.migrations)) + for _, m := range p.migrations { + s := &Source{ + Fullpath: m.Fullpath, + Version: m.Version, + } + switch m.Type { + case migrate.TypeSQL: + s.Type = SourceTypeSQL + case migrate.TypeGo: + s.Type = SourceTypeGo + } + sources = append(sources, s) + } + return sources } // Ping attempts to ping the database to verify a connection is available. -// Experimental: This API is experimental and may change in the future. func (p *Provider) Ping(ctx context.Context) error { - return errors.New("not implemented") + return p.db.PingContext(ctx) } // Close closes the database connection. -// Experimental: This API is experimental and may change in the future. func (p *Provider) Close() error { - return errors.New("not implemented") + return p.db.Close() } // MigrationResult represents the result of a single migration. @@ -150,21 +202,18 @@ type MigrationResult struct{} // // When direction is true, the up migration is executed, and when direction is false, the down // migration is executed. -// Experimental: This API is experimental and may change in the future. func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) { return nil, errors.New("not implemented") } // Up applies all pending migrations. If there are no new migrations to apply, this method returns // empty list and nil error. -// Experimental: This API is experimental and may change in the future. func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { return nil, errors.New("not implemented") } // UpByOne applies the next available migration. If there are no migrations to apply, this method // returns [ErrNoNextVersion]. -// Experimental: This API is experimental and may change in the future. func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { return nil, errors.New("not implemented") } @@ -174,14 +223,12 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { // // For instance, if there are three new migrations (9,10,11) and the current database version is 8 // with a requested version of 10, only versions 9 and 10 will be applied. -// Experimental: This API is experimental and may change in the future. func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { return nil, errors.New("not implemented") } // Down rolls back the most recently applied migration. If there are no migrations to apply, this // method returns [ErrNoNextVersion]. -// Experimental: This API is experimental and may change in the future. func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { return nil, errors.New("not implemented") } @@ -190,7 +237,6 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { // // For instance, if the current database version is 11, and the requested version is 9, only // migrations 11 and 10 will be rolled back. -// Experimental: This API is experimental and may change in the future. func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { return nil, errors.New("not implemented") } diff --git a/internal/provider/provider_options.go b/internal/provider/provider_options.go index bf7b9f9b2..d8060c458 100644 --- a/internal/provider/provider_options.go +++ b/internal/provider/provider_options.go @@ -60,10 +60,24 @@ func WithSessionLocker(locker lock.SessionLocker) ProviderOption { }) } +// WithExcludes excludes the given file names from the list of migrations. +// +// If WithExcludes is called multiple times, the list of excludes is merged. +func WithExcludes(excludes []string) ProviderOption { + return configFunc(func(c *config) error { + for _, name := range excludes { + c.excludes[name] = true + } + return nil + }) +} + type config struct { tableName string verbose bool + excludes map[string]bool + // Locking options lockEnabled bool sessionLocker lock.SessionLocker } diff --git a/internal/provider/provider_options_test.go b/internal/provider/provider_options_test.go index 82362bad1..89a1cda16 100644 --- a/internal/provider/provider_options_test.go +++ b/internal/provider/provider_options_test.go @@ -1,13 +1,13 @@ -package provider +package provider_test import ( "database/sql" - "io/fs" "path/filepath" "testing" "testing/fstest" "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/provider" _ "modernc.org/sqlite" ) @@ -15,86 +15,52 @@ func TestNewProvider(t *testing.T) { dir := t.TempDir() db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) check.NoError(t, err) - fsys := newFsys() + fsys := fstest.MapFS{ + "1_foo.sql": {Data: []byte(migration1)}, + "2_bar.sql": {Data: []byte(migration2)}, + "3_baz.sql": {Data: []byte(migration3)}, + "4_qux.sql": {Data: []byte(migration4)}, + } t.Run("invalid", func(t *testing.T) { // Empty dialect not allowed - _, err = NewProvider("", db, fsys) + _, err = provider.NewProvider("", db, fsys) check.HasError(t, err) // Invalid dialect not allowed - _, err = NewProvider("unknown-dialect", db, fsys) + _, err = provider.NewProvider("unknown-dialect", db, fsys) check.HasError(t, err) // Nil db not allowed - _, err = NewProvider("sqlite3", nil, fsys) + _, err = provider.NewProvider("sqlite3", nil, fsys) check.HasError(t, err) // Nil fsys not allowed - _, err = NewProvider("sqlite3", db, nil) + _, err = provider.NewProvider("sqlite3", db, nil) check.HasError(t, err) // Duplicate table name not allowed - _, err = NewProvider("sqlite3", db, fsys, WithTableName("foo"), WithTableName("bar")) + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithTableName("foo"), + provider.WithTableName("bar"), + ) check.HasError(t, err) check.Equal(t, `table already set to "foo"`, err.Error()) // Empty table name not allowed - _, err = NewProvider("sqlite3", db, fsys, WithTableName("")) + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithTableName(""), + ) check.HasError(t, err) check.Equal(t, "table must not be empty", err.Error()) }) t.Run("valid", func(t *testing.T) { // Valid dialect, db, and fsys allowed - _, err = NewProvider("sqlite3", db, fsys) + _, err = provider.NewProvider("sqlite3", db, fsys) check.NoError(t, err) // Valid dialect, db, fsys, and table name allowed - _, err = NewProvider("sqlite3", db, fsys, WithTableName("foo")) + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithTableName("foo"), + ) check.NoError(t, err) // Valid dialect, db, fsys, and verbose allowed - _, err = NewProvider("sqlite3", db, fsys, WithVerbose()) + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithVerbose(), + ) check.NoError(t, err) }) } - -func newFsys() fs.FS { - return fstest.MapFS{ - "1_foo.sql": {Data: []byte(migration1)}, - "2_bar.sql": {Data: []byte(migration2)}, - "3_baz.sql": {Data: []byte(migration3)}, - "4_qux.sql": {Data: []byte(migration4)}, - } -} - -var ( - migration1 = ` --- +goose Up -CREATE TABLE foo (id INTEGER PRIMARY KEY); --- +goose Down -DROP TABLE foo; -` - migration2 = ` --- +goose Up -ALTER TABLE foo ADD COLUMN name TEXT; --- +goose Down -ALTER TABLE foo DROP COLUMN name; -` - migration3 = ` --- +goose Up -CREATE TABLE bar ( - id INTEGER PRIMARY KEY, - description TEXT -); --- +goose Down -DROP TABLE bar; -` - migration4 = ` --- +goose Up --- Rename the 'foo' table to 'my_foo' -ALTER TABLE foo RENAME TO my_foo; - --- Add a new column 'timestamp' to 'my_foo' -ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP; - --- +goose Down --- Remove the 'timestamp' column from 'my_foo' -ALTER TABLE my_foo DROP COLUMN timestamp; - --- Rename the 'my_foo' table back to 'foo' -ALTER TABLE my_foo RENAME TO foo; -` -) diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go new file mode 100644 index 000000000..10aed48e0 --- /dev/null +++ b/internal/provider/provider_test.go @@ -0,0 +1,83 @@ +package provider_test + +import ( + "database/sql" + "errors" + "io/fs" + "path/filepath" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/provider" + _ "modernc.org/sqlite" +) + +func TestProvider(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + t.Run("empty", func(t *testing.T) { + _, err := provider.NewProvider("sqlite3", db, fstest.MapFS{}) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrNoMigrations), true) + }) + + mapFS := fstest.MapFS{ + "migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)}, + "migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)}, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + p, err := provider.NewProvider("sqlite3", db, fsys) + check.NoError(t, err) + sources := p.ListSources() + check.Equal(t, len(sources), 2) + // 1 + check.Equal(t, sources[0].Version, int64(1)) + check.Equal(t, sources[0].Fullpath, "001_foo.sql") + check.Equal(t, sources[0].Type, provider.SourceTypeSQL) + // 2 + check.Equal(t, sources[1].Version, int64(2)) + check.Equal(t, sources[1].Fullpath, "002_bar.sql") + check.Equal(t, sources[1].Type, provider.SourceTypeSQL) +} + +var ( + migration1 = ` +-- +goose Up +CREATE TABLE foo (id INTEGER PRIMARY KEY); +-- +goose Down +DROP TABLE foo; +` + migration2 = ` +-- +goose Up +ALTER TABLE foo ADD COLUMN name TEXT; +-- +goose Down +ALTER TABLE foo DROP COLUMN name; +` + migration3 = ` +-- +goose Up +CREATE TABLE bar ( + id INTEGER PRIMARY KEY, + description TEXT +); +-- +goose Down +DROP TABLE bar; +` + migration4 = ` +-- +goose Up +-- Rename the 'foo' table to 'my_foo' +ALTER TABLE foo RENAME TO my_foo; + +-- Add a new column 'timestamp' to 'my_foo' +ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP; + +-- +goose Down +-- Remove the 'timestamp' column from 'my_foo' +ALTER TABLE my_foo DROP COLUMN timestamp; + +-- Rename the 'my_foo' table back to 'foo' +ALTER TABLE my_foo RENAME TO foo; +` +) diff --git a/lock/postgres_test.go b/lock/postgres_test.go index bfb1a0d99..2622d5cb6 100644 --- a/lock/postgres_test.go +++ b/lock/postgres_test.go @@ -14,19 +14,20 @@ import ( ) func TestPostgresSessionLocker(t *testing.T) { + t.Parallel() if testing.Short() { t.Skip("skip long running test") } db, cleanup, err := testdb.NewPostgres() check.NoError(t, err) t.Cleanup(cleanup) - const ( - lockID int64 = 123456789 - ) // Do not run tests in parallel, because they are using the same database. t.Run("lock_and_unlock", func(t *testing.T) { + const ( + lockID int64 = 123456789 + ) locker, err := lock.NewPostgresSessionLocker( lock.WithLockID(lockID), lock.WithLockTimeout(4*time.Second), diff --git a/migration.go b/migration.go index dcf0c6118..619e934d0 100644 --- a/migration.go +++ b/migration.go @@ -218,27 +218,27 @@ func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, d return store.DeleteVersionNoTx(ctx, db, TableName(), version) } -// NumericComponent looks for migration scripts with names in the form: -// XXX_descriptivename.ext where XXX specifies the version number -// and ext specifies the type of migration -func NumericComponent(name string) (int64, error) { - base := filepath.Base(name) - +// NumericComponent parses the version from the migration file name. +// +// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of +// migration, either .sql or .go. +func NumericComponent(filename string) (int64, error) { + base := filepath.Base(filename) if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { - return 0, errors.New("not a recognized migration file type") + return 0, errors.New("migration file does not have .sql or .go file extension") } - idx := strings.Index(base, "_") if idx < 0 { return 0, errors.New("no filename separator '_' found") } - - n, e := strconv.ParseInt(base[:idx], 10, 64) - if e == nil && n <= 0 { - return 0, errors.New("migration IDs must be greater than zero") + n, err := strconv.ParseInt(base[:idx], 10, 64) + if err != nil { + return 0, err } - - return n, e + if n < 1 { + return 0, errors.New("migration version must be greater than zero") + } + return n, nil } func truncateDuration(d time.Duration) time.Duration {