diff --git a/driver.go b/driver.go index a9b7891..d7c3cb7 100644 --- a/driver.go +++ b/driver.go @@ -12,13 +12,11 @@ type Driver interface { // internal reference to that connection for later use. This library might // call this method multiple times, so use the internal reference if it's // present instead of reconnecting to the database. - Connect() (*sql.DB, error) + Connect(dsn string) (*sql.DB, error) // Close should check if there's an internal reference to a database // connection (a *sql.DB) and if it's present, close it. Then reset the // internal reference to that connection to nil. Close() error - // DSN returns data source name info, ie: how do I connect? - DSN() DSN // AppliedVersions queries the schema migrations table for migration // versions that have been executed against the database. If the schema @@ -39,43 +37,6 @@ type Driver interface { UpdateSchemaMigrations(dir Direction, version string) error } -// NewDriver initializes a Driver implementation by name and connection -// parameters. -func NewDriver(dsn DSN, migConf *MigrationsConf) (driver Driver, err error) { - return dsn.NewDriver(migConf) -} - -// DSN generates a data source name or connection URL for DB connections. The -// output will be passed to the standard library's sql.Open method. -type DSN interface { - // Boot takes inputs from the host environment so it can create a Driver. - // - // Deprecated: Set the DB_DSN environment variable instead of using this. - Boot(ConnectionParams) error - // NewDriver calls the constructor of the corresponding Driver. - NewDriver(*MigrationsConf) (Driver, error) - // String uses connection parameters to form the data source name. - String() string -} - -// ConnectionParams is what to use when initializing a DSN. -// -// Deprecated: Set the DB_DSN environment variable instead of using this. -type ConnectionParams struct { - Encoding string // Encoding is the client encoding for the connection. - Host string // Host is the name of the host to connect to. - Name string // Name is the database name. - Pass string // Pass is the password to use for the connection. - Port string // Port is the connection port. - User string // User is the name of the user to connect as. -} - -// MigrationsConf is intended to lend customizations such as specifying the path -// to the migration files. -type MigrationsConf struct { - PathToFiles string `json:"path_to_files"` -} - // AppliedVersions represents an iterative list of migrations that have been run // against the database and have been recorded in the schema migrations table. // It's enough to convert a *sql.Rows struct when implementing the Driver diff --git a/drivers/mysql/godfish/main.go b/drivers/mysql/godfish/main.go index 22dcb40..a2ced4d 100644 --- a/drivers/mysql/godfish/main.go +++ b/drivers/mysql/godfish/main.go @@ -9,9 +9,7 @@ import ( ) func main() { - var dsn mysql.DSN - err := commands.Run(&dsn) - if err != nil { + if err := commands.Run(mysql.NewDriver()); err != nil { log.Println(err) os.Exit(1) } diff --git a/drivers/mysql/mysql.go b/drivers/mysql/mysql.go index c25851d..c513f8e 100644 --- a/drivers/mysql/mysql.go +++ b/drivers/mysql/mysql.go @@ -3,7 +3,6 @@ package mysql import ( "database/sql" "fmt" - "os" "regexp" "strings" @@ -11,56 +10,21 @@ import ( "github.com/rafaelespinoza/godfish" ) -// DSN implements the godfish.DSN interface and defines keys, values needed to -// connect to a mysql database. -type DSN struct { - godfish.ConnectionParams -} - -var _ godfish.DSN = (*DSN)(nil) - -// Boot initializes the DSN from environment inputs. -func (p *DSN) Boot(params godfish.ConnectionParams) error { - p.ConnectionParams = params - return nil -} - // NewDriver creates a new mysql driver. -func (p *DSN) NewDriver(migConf *godfish.MigrationsConf) (godfish.Driver, error) { - return newMySQL(*p) -} - -// String generates a data source name (or connection URL) based on the fields. -func (p *DSN) String() string { - return os.Getenv("DB_DSN") -} +func NewDriver() godfish.Driver { return &driver{} } // driver implements the godfish.Driver interface for mysql databases. type driver struct { connection *sql.DB - dsn DSN -} - -var _ godfish.Driver = (*driver)(nil) - -func newMySQL(dsn DSN) (*driver, error) { - if dsn.Host == "" { - dsn.Host = "localhost" - } - if dsn.Port == "" { - dsn.Port = "3306" - } - return &driver{dsn: dsn}, nil } -func (d *driver) Name() string { return "mysql" } -func (d *driver) DSN() godfish.DSN { return &d.dsn } -func (d *driver) Connect() (conn *sql.DB, err error) { +func (d *driver) Name() string { return "mysql" } +func (d *driver) Connect(dsn string) (conn *sql.DB, err error) { if d.connection != nil { conn = d.connection return } - if conn, err = sql.Open(d.Name(), d.DSN().String()); err != nil { + if conn, err = sql.Open(d.Name(), dsn); err != nil { return } d.connection = conn diff --git a/drivers/mysql/mysql_test.go b/drivers/mysql/mysql_test.go index e0854f9..e8b8ace 100644 --- a/drivers/mysql/mysql_test.go +++ b/drivers/mysql/mysql_test.go @@ -8,5 +8,5 @@ import ( ) func Test(t *testing.T) { - internal.RunDriverTests(t, &mysql.DSN{}) + internal.RunDriverTests(t, mysql.NewDriver()) } diff --git a/drivers/postgres/godfish/main.go b/drivers/postgres/godfish/main.go index ad60579..9ed9fa4 100644 --- a/drivers/postgres/godfish/main.go +++ b/drivers/postgres/godfish/main.go @@ -9,9 +9,7 @@ import ( ) func main() { - var dsn postgres.DSN - err := commands.Run(&dsn) - if err != nil { + if err := commands.Run(postgres.NewDriver()); err != nil { log.Println(err) os.Exit(1) } diff --git a/drivers/postgres/postgres.go b/drivers/postgres/postgres.go index 57c566b..d594b46 100644 --- a/drivers/postgres/postgres.go +++ b/drivers/postgres/postgres.go @@ -2,62 +2,26 @@ package postgres import ( "database/sql" - "os" "github.com/lib/pq" "github.com/rafaelespinoza/godfish" ) -// DSN implements the godfish.DSN interface and defines keys, values needed to -// connect to a postgres database. -type DSN struct { - godfish.ConnectionParams -} - -var _ godfish.DSN = (*DSN)(nil) - -// Boot initializes the DSN from environment inputs. -func (p *DSN) Boot(params godfish.ConnectionParams) error { - p.ConnectionParams = params - return nil -} - // NewDriver creates a new postgres driver. -func (p *DSN) NewDriver(migConf *godfish.MigrationsConf) (godfish.Driver, error) { - return newPostgres(*p) -} - -// String generates a data source name (or connection URL) based on the fields. -func (p *DSN) String() string { - return os.Getenv("DB_DSN") -} +func NewDriver() godfish.Driver { return &driver{} } // driver implements the Driver interface for postgres databases. type driver struct { connection *sql.DB - dsn DSN -} - -var _ godfish.Driver = (*driver)(nil) - -func newPostgres(dsn DSN) (*driver, error) { - if dsn.Host == "" { - dsn.Host = "localhost" - } - if dsn.Port == "" { - dsn.Port = "5432" - } - return &driver{dsn: dsn}, nil } -func (d *driver) Name() string { return "postgres" } -func (d *driver) DSN() godfish.DSN { return &d.dsn } -func (d *driver) Connect() (conn *sql.DB, err error) { +func (d *driver) Name() string { return "postgres" } +func (d *driver) Connect(dsn string) (conn *sql.DB, err error) { if d.connection != nil { conn = d.connection return } - if conn, err = sql.Open(d.Name(), d.DSN().String()); err != nil { + if conn, err = sql.Open(d.Name(), dsn); err != nil { return } d.connection = conn diff --git a/drivers/postgres/postgres_test.go b/drivers/postgres/postgres_test.go index bfce968..e41d8df 100644 --- a/drivers/postgres/postgres_test.go +++ b/drivers/postgres/postgres_test.go @@ -8,5 +8,5 @@ import ( ) func Test(t *testing.T) { - internal.RunDriverTests(t, &postgres.DSN{}) + internal.RunDriverTests(t, postgres.NewDriver()) } diff --git a/godfish.go b/godfish.go index a4c869c..e653443 100644 --- a/godfish.go +++ b/godfish.go @@ -14,7 +14,16 @@ import ( // Migrate executes all migrations at directoryPath in the specified direction. func Migrate(driver Driver, directoryPath string, direction Direction, finishAtVersion string) (err error) { - var migrations []Migration + var ( + dsn string + migrations []Migration + ) + if dsn, err = getDSN(); err != nil { + return + } + if _, err = driver.Connect(dsn); err != nil { + return + } defer driver.Close() if finishAtVersion == "" && direction == DirForward { @@ -23,9 +32,6 @@ func Migrate(driver Driver, directoryPath string, direction Direction, finishAtV finishAtVersion = minVersion } - if _, err = driver.Connect(); err != nil { - return - } finder := migrationFinder{ direction: direction, directoryPath: directoryPath, @@ -57,17 +63,25 @@ var ( // ApplyMigration runs a migration at directoryPath with the specified version // and direction. func ApplyMigration(driver Driver, directoryPath string, direction Direction, version string) (err error) { - var mig Migration - var pathToFile string + var ( + dsn string + pathToFile string + mig Migration + ) + + if dsn, err = getDSN(); err != nil { + return + } + if _, err = driver.Connect(dsn); err != nil { + return + } defer driver.Close() if direction == DirUnknown { err = fmt.Errorf("unknown Direction %q", direction) return } - if _, err = driver.Connect(); err != nil { - return - } + if version == "" { // attempt to find the next version to apply in the direction var limit string @@ -183,7 +197,11 @@ func runMigration(driver Driver, pathToFile string, mig Migration) (err error) { // the database. Running any migration will create the table, so you don't // usually need to call this function. func CreateSchemaMigrationsTable(driver Driver) (err error) { - if _, err = driver.Connect(); err != nil { + var dsn string + if dsn, err = getDSN(); err != nil { + return + } + if _, err = driver.Connect(dsn); err != nil { return err } defer driver.Close() @@ -192,7 +210,11 @@ func CreateSchemaMigrationsTable(driver Driver) (err error) { // Info displays the outputs of various helper functions. func Info(driver Driver, directoryPath string, direction Direction, finishAtVersion string) (err error) { - if _, err = driver.Connect(); err != nil { + var dsn string + if dsn, err = getDSN(); err != nil { + return + } + if _, err = driver.Connect(dsn); err != nil { return err } defer driver.Close() @@ -206,6 +228,11 @@ func Info(driver Driver, directoryPath string, direction Direction, finishAtVers return } +// Config is for various runtime settings. +type Config struct { + PathToFiles string `json:"path_to_files"` +} + // Init creates a configuration file at pathToFile unless it already exists. func Init(pathToFile string) (err error) { _, err = os.Stat(pathToFile) @@ -218,7 +245,7 @@ func Init(pathToFile string) (err error) { } var data []byte - if data, err = json.MarshalIndent(MigrationsConf{}, "", "\t"); err != nil { + if data, err = json.MarshalIndent(Config{}, "", "\t"); err != nil { return err } return os.WriteFile( @@ -454,3 +481,13 @@ func printMigrations(migrations []Migration) { fmt.Printf("\t%-20s | %-s\n", mig.Version().String(), makeMigrationFilename(mig)) } } + +const dsnKey = "DB_DSN" + +func getDSN() (dsn string, err error) { + dsn = os.Getenv(dsnKey) + if dsn == "" { + err = fmt.Errorf("missing environment variable: %s", dsnKey) + } + return +} diff --git a/godfish_test.go b/godfish_test.go index 9465440..964d622 100644 --- a/godfish_test.go +++ b/godfish_test.go @@ -33,7 +33,7 @@ func TestInit(t *testing.T) { if err = godfish.Init(pathToFile); err != nil { t.Fatalf("something else is wrong with setup; %v", err) } - var conf godfish.MigrationsConf + var conf godfish.Config if data, err := os.ReadFile(pathToFile); err != nil { t.Fatal(err) } else if err = json.Unmarshal(data, &conf); err != nil { @@ -54,7 +54,7 @@ func TestInit(t *testing.T) { if err := godfish.Init(pathToFile); err != nil { t.Fatal(err) } - var conf2 godfish.MigrationsConf + var conf2 godfish.Config if data, err := os.ReadFile(pathToFile); err != nil { t.Fatal(err) } else if err = json.Unmarshal(data, &conf2); err != nil { diff --git a/internal/commands/commands.go b/internal/commands/commands.go index a9097d6..c2d3f66 100644 --- a/internal/commands/commands.go +++ b/internal/commands/commands.go @@ -17,7 +17,7 @@ type arguments struct { Conf string Debug bool Direction string - DSN godfish.DSN + Driver godfish.Driver Files string Name string Reversible bool @@ -32,9 +32,9 @@ var ( ) // Run does all the CLI things. -func Run(dsn godfish.DSN) (err error) { +func Run(driver godfish.Driver) (err error) { flag.Parse() - args.DSN = dsn + args.Driver = driver var cmd *subcommand @@ -123,7 +123,7 @@ func initSubcommand(positionalArgs []string, a *arguments) (subcmd *subcommand, } // Read configuration file, if present. Negotiate with Args. - var conf godfish.MigrationsConf + var conf godfish.Config if data, ierr := os.ReadFile(a.Conf); ierr != nil { // probably no config file present, rely on Args instead. } else if ierr = json.Unmarshal(data, &conf); ierr != nil { @@ -225,12 +225,8 @@ var subcommands = map[string]*subcommand{ return flags }, run: func(a arguments) error { - driver, err := bootDriver(a.DSN) - if err != nil { - return err - } direction := whichDirection(a) - return godfish.Info(driver, a.Files, direction, a.Version) + return godfish.Info(a.Driver, a.Files, direction, a.Version) }, }, "init": &subcommand{ @@ -281,13 +277,8 @@ var subcommands = map[string]*subcommand{ return flags }, run: func(a arguments) error { - driver, err := bootDriver(a.DSN) - if err != nil { - return err - } - - err = godfish.Migrate( - driver, + err := godfish.Migrate( + a.Driver, a.Files, godfish.DirForward, a.Version, @@ -311,24 +302,11 @@ var subcommands = map[string]*subcommand{ return flags }, run: func(a arguments) error { - driver, err := bootDriver(a.DSN) + err := godfish.ApplyMigration(a.Driver, a.Files, godfish.DirReverse, "") if err != nil { return err } - if err = godfish.ApplyMigration( - driver, - a.Files, - godfish.DirReverse, - "", - ); err != nil { - return err - } - return godfish.ApplyMigration( - driver, - a.Files, - godfish.DirForward, - "", - ) + return godfish.ApplyMigration(a.Driver, a.Files, godfish.DirForward, "") }, }, "rollback": &subcommand{ @@ -356,20 +334,17 @@ var subcommands = map[string]*subcommand{ return flags }, run: func(a arguments) error { - driver, err := bootDriver(a.DSN) - if err != nil { - return err - } + var err error if a.Version == "" { err = godfish.ApplyMigration( - driver, + a.Driver, a.Files, godfish.DirReverse, a.Version, ) } else { err = godfish.Migrate( - driver, + a.Driver, a.Files, godfish.DirReverse, a.Version, @@ -381,21 +356,6 @@ var subcommands = map[string]*subcommand{ "version": _Version, } -func bootDriver(dsn godfish.DSN) (driver godfish.Driver, err error) { - connParams := godfish.ConnectionParams{ - Host: os.Getenv("DB_HOST"), - Name: os.Getenv("DB_NAME"), - Pass: os.Getenv("DB_PASSWORD"), - Port: os.Getenv("DB_PORT"), - User: os.Getenv("DB_USER"), - } - if err = dsn.Boot(connParams); err != nil { - return - } - driver, err = dsn.NewDriver(nil) - return -} - func whichDirection(a arguments) (direction godfish.Direction) { direction = godfish.DirForward d := strings.ToLower(a.Direction) diff --git a/internal/stub/driver.go b/internal/stub/driver.go index 71b984d..2b06e72 100644 --- a/internal/stub/driver.go +++ b/internal/stub/driver.go @@ -10,7 +10,6 @@ import ( ) type Driver struct { - dsn DSN connection *sql.DB appliedVersions godfish.AppliedVersions err error @@ -19,10 +18,9 @@ type Driver struct { var _ godfish.Driver = (*Driver)(nil) -func (d *Driver) Name() string { return "stub" } -func (d *Driver) Connect() (*sql.DB, error) { return d.connection, d.err } -func (d *Driver) Close() error { return d.err } -func (d *Driver) DSN() godfish.DSN { return &d.dsn } +func (d *Driver) Name() string { return "stub" } +func (d *Driver) Connect(dsn string) (*sql.DB, error) { return d.connection, d.err } +func (d *Driver) Close() error { return d.err } func (d *Driver) CreateSchemaMigrationsTable() error { if d.appliedVersions == nil { d.appliedVersions = MakeAppliedVersions() @@ -77,19 +75,6 @@ func (d *Driver) Teardown() { d.appliedVersions = MakeAppliedVersions() } -type DSN struct{ godfish.ConnectionParams } - -func (d DSN) Boot(params godfish.ConnectionParams) error { - d.ConnectionParams = params - return nil -} -func (d DSN) NewDriver(migConf *godfish.MigrationsConf) (godfish.Driver, error) { - return &Driver{dsn: d}, nil -} -func (d DSN) String() string { return "this://is.a/test" } - -var _ godfish.DSN = (*DSN)(nil) - type AppliedVersions struct { counter int versions []string diff --git a/internal/stub/driver_test.go b/internal/stub/driver_test.go index 1df8285..660910e 100644 --- a/internal/stub/driver_test.go +++ b/internal/stub/driver_test.go @@ -8,5 +8,6 @@ import ( ) func Test(t *testing.T) { - internal.RunDriverTests(t, &stub.DSN{}) + var driver stub.Driver + internal.RunDriverTests(t, &driver) } diff --git a/internal/test.go b/internal/test.go index 4496019..072ad99 100644 --- a/internal/test.go +++ b/internal/test.go @@ -13,30 +13,7 @@ import ( ) // RunDriverTests tests an implementation of the godfish.Driver interface. -func RunDriverTests(t *testing.T, dsn godfish.DSN) { - t.Helper() - connParams := godfish.ConnectionParams{ - Encoding: "UTF8", - Host: os.Getenv("DB_HOST"), - Name: "godfish_test", - Pass: os.Getenv("DB_PASSWORD"), - Port: os.Getenv("DB_PORT"), - User: os.Getenv("DB_USER"), - } - if connParams.Host == "" { - connParams.Host = "localhost" - } - if connParams.User == "" { - connParams.User = "godfish" - } - if err := dsn.Boot(connParams); err != nil { - t.Fatal(err) - } - driver, err := godfish.NewDriver(dsn, nil) - if err != nil { - t.Fatal(err) - } - +func RunDriverTests(t *testing.T, driver godfish.Driver) { // Tests for creating the schema migrations table are deliberately not // included. It should be called as needed by other library functions. @@ -605,6 +582,16 @@ func RunDriverTests(t *testing.T, dsn godfish.DSN) { }) } +const dsnKey = "DB_DSN" + +func mustDSN() string { + dsn := os.Getenv(dsnKey) + if dsn == "" { + panic("empty environment variable " + dsnKey) + } + return dsn +} + // Magic option values for test setup and teardown. const ( _SkipMigration = "-" @@ -628,7 +615,7 @@ func setup(driver godfish.Driver, testName string, stubs []testDriverStub, migra // teardown clears state after running a test. func teardown(driver godfish.Driver, path string, tablesToDrop ...string) { var err error - if _, err = driver.Connect(); err != nil { + if _, err = driver.Connect(mustDSN()); err != nil { panic(err) } @@ -824,7 +811,7 @@ func generateMigrationFiles(pathToTestDir string, stubs []testDriverStub) error func collectAppliedVersions(driver godfish.Driver) (out []string, err error) { // Collect output of AppliedVersions. // Reconnect here because ApplyMigration closes the connection. - if _, err = driver.Connect(); err != nil { + if _, err = driver.Connect(mustDSN()); err != nil { return } defer driver.Close()