Skip to content

Commit

Permalink
feat: Store migration label in DB
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaelespinoza committed Jul 11, 2024
1 parent ac48d69 commit 6748981
Show file tree
Hide file tree
Showing 14 changed files with 175 additions and 122 deletions.
2 changes: 1 addition & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Driver interface {
// UpdateSchemaMigrations records a timestamped version of a migration that
// has been successfully applied by adding a new row to the schema
// migrations table.
UpdateSchemaMigrations(forward bool, version string) error
UpdateSchemaMigrations(forward bool, version string, label string) error
}

// AppliedVersions represents an iterative list of migrations that have been run
Expand Down
15 changes: 9 additions & 6 deletions drivers/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,17 @@ func (d *driver) Execute(query string, args ...interface{}) (err error) {

func (d *driver) CreateSchemaMigrationsTable() (err error) {
err = d.connection.Query(
`CREATE TABLE IF NOT EXISTS schema_migrations (migration_id TEXT PRIMARY KEY)`,
`CREATE TABLE IF NOT EXISTS schema_migrations (
migration_id TEXT PRIMARY KEY,
label TEXT
)`,
).Exec()
return
}

func (d *driver) AppliedVersions() (out godfish.AppliedVersions, err error) {
query := d.connection.Query(
`SELECT migration_id FROM schema_migrations`,
`SELECT migration_id, label FROM schema_migrations`,
)

av := execAllAscending(query)
Expand Down Expand Up @@ -101,13 +104,13 @@ func (d *driver) AppliedVersions() (out godfish.AppliedVersions, err error) {
return
}

func (d *driver) UpdateSchemaMigrations(forward bool, version string) (err error) {
func (d *driver) UpdateSchemaMigrations(forward bool, version, label string) (err error) {
conn := d.connection
if forward {
err = conn.Query(`
INSERT INTO schema_migrations (migration_id)
VALUES (?)`,
version,
INSERT INTO schema_migrations (migration_id, label)
VALUES (?, ?)`,
version, label,
).Exec()
} else {
err = conn.Query(`
Expand Down
49 changes: 37 additions & 12 deletions drivers/cassandra/versions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cassandra

import (
"database/sql"
"fmt"
"sort"

Expand All @@ -12,10 +13,12 @@ import (
// check if an error was encountered.
func execAllAscending(query *gocql.Query) *appliedVersions {
scanner := query.Iter().Scanner()
av := appliedVersions{versions: make([]string, 0)}
av := appliedVersions{versions: make([]migration, 0)}

defer func() {
sort.Strings(av.versions)
sort.Slice(av.versions, func(i, j int) bool {
return av.versions[i].id < av.versions[j].id
})

// The Err method also releases resources. The scanner should not be
// used after this point.
Expand All @@ -31,43 +34,65 @@ func execAllAscending(query *gocql.Query) *appliedVersions {
// Read it all up front so DB resources can be closed while also avoid nil
// access errors.
for scanner.Next() {
var version string
if err := scanner.Scan(&version); err != nil {
var version, label string
if err := scanner.Scan(&version, &label); err != nil {
av.err = err
return &av
}
av.versions = append(av.versions, version)
av.versions = append(av.versions, migration{version, label})
}

return &av
}

type appliedVersions struct {
counter int
versions []string
versions []migration
err error
}

func (a *appliedVersions) Close() error { return a.err }

func (a *appliedVersions) Next() bool {
if a.err != nil {
return false
}
return a.counter < len(a.versions)
}

// Scan is called by the godfish library. Unlike sql.Driver-based
// implementations, the data has already been read from the DB by the time this
// function is called. See details in the execAllAscending function.
func (a *appliedVersions) Scan(dest ...interface{}) error {
if a.err != nil {
return a.err
}

out, ok := dest[0].(*string)
if !ok {
return fmt.Errorf("dest argument should be a %T", out)
}
if !a.Next() {
return nil
}
*out = a.versions[a.counter]
curr := a.versions[a.counter]

switch val := dest[0].(type) {
case *string:
*val = curr.id
default:
return fmt.Errorf("unexpected type (%T) for %q field", val, "migration_id")
}

switch val := dest[1].(type) {
case *sql.NullString:
if err := val.Scan(curr.label); err != nil {
return fmt.Errorf("failed to Scan %q field: %w", "label", err)
}
default:
return fmt.Errorf("unexpected type (%T) for %q field", val, "label")
}

a.counter++
return nil
}

type migration struct {
id string
label string
}
17 changes: 9 additions & 8 deletions drivers/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,15 @@ func (d *driver) Execute(query string, args ...interface{}) (err error) {
func (d *driver) CreateSchemaMigrationsTable() (err error) {
_, err = d.connection.Exec(
`CREATE TABLE IF NOT EXISTS schema_migrations (
migration_id VARCHAR(128) PRIMARY KEY NOT NULL
migration_id VARCHAR(128) PRIMARY KEY NOT NULL,
label VARCHAR(255) DEFAULT ''
)`)
return
}

func (d *driver) AppliedVersions() (out godfish.AppliedVersions, err error) {
rows, err := d.connection.Query(
`SELECT migration_id FROM schema_migrations ORDER BY migration_id ASC`,
`SELECT migration_id, label FROM schema_migrations ORDER BY migration_id ASC`,
)
if ierr, ok := err.(*my.MySQLError); ok {
// https://dev.mysql.com/doc/refman/8.0/en/server-error-reference.html#error_er_no_such_table
Expand All @@ -92,18 +93,18 @@ func (d *driver) AppliedVersions() (out godfish.AppliedVersions, err error) {
return
}

func (d *driver) UpdateSchemaMigrations(forward bool, version string) (err error) {
func (d *driver) UpdateSchemaMigrations(forward bool, version, label string) (err error) {
conn := d.connection
if forward {
_, err = conn.Exec(`
INSERT INTO schema_migrations (migration_id)
VALUES (?)`,
INSERT INTO schema_migrations (migration_id, label)
VALUES (?, ?)`,
version,
label,
)
} else {
_, err = conn.Exec(`
DELETE FROM schema_migrations
WHERE migration_id = ?`,
_, err = conn.Exec(
`DELETE FROM schema_migrations WHERE migration_id = ?`,
version,
)
}
Expand Down
12 changes: 7 additions & 5 deletions drivers/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ func (d *driver) Execute(query string, args ...interface{}) (err error) {
func (d *driver) CreateSchemaMigrationsTable() (err error) {
_, err = d.connection.Exec(
`CREATE TABLE IF NOT EXISTS schema_migrations (
migration_id VARCHAR(128) PRIMARY KEY NOT NULL
migration_id VARCHAR(128) PRIMARY KEY NOT NULL,
label VARCHAR(255) DEFAULT ''
)`)
return
}

func (d *driver) AppliedVersions() (out godfish.AppliedVersions, err error) {
rows, err := d.connection.Query(
`SELECT migration_id FROM schema_migrations ORDER BY migration_id ASC`,
`SELECT migration_id, label FROM schema_migrations ORDER BY migration_id ASC`,
)
if ierr, ok := err.(*pq.Error); ok {
// https://www.postgresql.org/docs/current/errcodes-appendix.html
Expand All @@ -65,14 +66,15 @@ func (d *driver) AppliedVersions() (out godfish.AppliedVersions, err error) {
return
}

func (d *driver) UpdateSchemaMigrations(forward bool, version string) (err error) {
func (d *driver) UpdateSchemaMigrations(forward bool, version, label string) (err error) {
conn := d.connection
if forward {
_, err = conn.Exec(`
INSERT INTO schema_migrations (migration_id)
VALUES ($1)
INSERT INTO schema_migrations (migration_id, label)
VALUES ($1, $2)
RETURNING migration_id`,
version,
label,
)
} else {
_, err = conn.Exec(`
Expand Down
13 changes: 7 additions & 6 deletions drivers/sqlite3/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ func (d *driver) Execute(query string, args ...interface{}) (err error) {
func (d *driver) CreateSchemaMigrationsTable() (err error) {
_, err = d.connection.Exec(
`CREATE TABLE IF NOT EXISTS schema_migrations (
migration_id VARCHAR(128) PRIMARY KEY NOT NULL
migration_id VARCHAR(128) PRIMARY KEY NOT NULL,
label VARCHAR(255) DEFAULT ''
)`)
return
}

func (d *driver) AppliedVersions() (out godfish.AppliedVersions, err error) {
rows, err := d.connection.Query(
`SELECT migration_id FROM schema_migrations ORDER BY migration_id ASC`,
`SELECT migration_id, label FROM schema_migrations ORDER BY migration_id ASC`,
)

var ierr *sqlib.Error
Expand All @@ -67,13 +68,13 @@ func (d *driver) AppliedVersions() (out godfish.AppliedVersions, err error) {
return
}

func (d *driver) UpdateSchemaMigrations(forward bool, version string) (err error) {
func (d *driver) UpdateSchemaMigrations(forward bool, version string, label string) (err error) {
conn := d.connection
if forward {
_, err = conn.Exec(`
INSERT INTO schema_migrations (migration_id)
VALUES ($1)`,
version,
INSERT INTO schema_migrations (migration_id, label)
VALUES ($1, $2)`,
version, label,
)
} else {
_, err = conn.Exec(`
Expand Down
15 changes: 9 additions & 6 deletions drivers/sqlserver/sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,16 @@ func (d *driver) CreateSchemaMigrationsTable() (err error) {
IF NOT EXISTS (
SELECT 1 FROM information_schema.tables WHERE table_schema = (SELECT schema_name()) AND table_name = 'schema_migrations'
)
CREATE TABLE schema_migrations (migration_id VARCHAR(128) PRIMARY KEY NOT NULL)
CREATE TABLE schema_migrations (
migration_id VARCHAR(128) PRIMARY KEY NOT NULL,
label VARCHAR(255) DEFAULT ''
)
`)
return
}

func (d *driver) AppliedVersions() (out godfish.AppliedVersions, err error) {
rows, err := d.connection.Query(`SELECT migration_id FROM schema_migrations ORDER BY migration_id ASC`)
rows, err := d.connection.Query(`SELECT migration_id, label FROM schema_migrations ORDER BY migration_id ASC`)

var ierr mssql.Error
// https://docs.microsoft.com/en-us/sql/relational-databases/errors-events/database-engine-events-and-errors
Expand All @@ -72,13 +75,13 @@ func (d *driver) AppliedVersions() (out godfish.AppliedVersions, err error) {
return
}

func (d *driver) UpdateSchemaMigrations(forward bool, version string) (err error) {
func (d *driver) UpdateSchemaMigrations(forward bool, version, label string) (err error) {
conn := d.connection
if forward {
_, err = conn.Exec(`
INSERT INTO schema_migrations (migration_id)
VALUES (@p1)`,
version,
INSERT INTO schema_migrations (migration_id, label)
VALUES (@p1, @p2)`,
version, label,
)
} else {
_, err = conn.Exec(`
Expand Down
5 changes: 4 additions & 1 deletion godfish.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package godfish

import (
"database/sql"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -185,6 +186,7 @@ func runMigration(driver Driver, pathToFile string, mig internal.Migration) (err
err = driver.UpdateSchemaMigrations(
mig.Indirection().Value == internal.DirForward,
mig.Version().String(),
mig.Label(),
)
if err == nil {
fmt.Fprintln(os.Stderr, "ok")
Expand Down Expand Up @@ -373,8 +375,9 @@ func scanAppliedVersions(driver Driver, directoryPath string) (out []internal.Mi
defer rows.Close()
for rows.Next() {
var version, basename string
var label sql.NullString
var mig internal.Migration
if err = rows.Scan(&version); err != nil {
if err = rows.Scan(&version, &label); err != nil {
return
}
basename, err = figureOutBasename(directoryPath, internal.DirForward, version)
Expand Down
44 changes: 32 additions & 12 deletions internal/stub/appliedversions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package stub

import (
"database/sql"
"errors"
"fmt"

"github.com/rafaelespinoza/godfish"
Expand All @@ -9,17 +11,17 @@ import (

type appliedVersions struct {
counter int
versions []string
versions []internal.Migration
}

// NewAppliedVersions constructs an in-memory AppliedVersions implementation for
// testing purposes.
func NewAppliedVersions(migrations ...internal.Migration) godfish.AppliedVersions {
out := appliedVersions{
versions: make([]string, len(migrations)),
versions: make([]internal.Migration, len(migrations)),
}
for i, mig := range migrations {
out.versions[i] = mig.Version().String()
out.versions[i] = mig
}
return &out
}
Expand All @@ -31,16 +33,34 @@ func (r *appliedVersions) Close() error {

func (r *appliedVersions) Next() bool { return r.counter < len(r.versions) }

func (r *appliedVersions) Scan(dest ...interface{}) error {
var out *string
if s, ok := dest[0].(*string); !ok {
return fmt.Errorf("pass in *string; got %T", s)
} else if !r.Next() {
return fmt.Errorf("no more results")
} else {
out = s
func (r *appliedVersions) Scan(dest ...interface{}) (err error) {
if len(dest) != 2 {
err = fmt.Errorf("expected 2 args, got %d", len(dest))
return
}
*out = r.versions[r.counter]
if !r.Next() {
err = errors.New("no more results")
return
}

curr := r.versions[r.counter]
r.counter++

switch val := dest[0].(type) {
case *string:
*val = curr.Version().String()
default:
return fmt.Errorf("unexpected type (%T) for %q field", val, "version")
}

switch val := dest[1].(type) {
case *sql.NullString:
if err = val.Scan(curr.Label()); err != nil {
return fmt.Errorf("failed to Scan %q field: %w", "label", err)
}
default:
return fmt.Errorf("unexpected type (got %T) for %q field", val, "label")
}

return nil
}
Loading

0 comments on commit 6748981

Please sign in to comment.