Skip to content

Commit

Permalink
refactor: Remove database/sql from Driver interface
Browse files Browse the repository at this point in the history
Just realized that the core godfish library code doesn't even need to
use the database/sql.DB connection. Removing this dependency frees up
the library to work with other databases that may not implement the
database/sql/driver.Driver interface at all.
  • Loading branch information
rafaelespinoza committed Apr 30, 2021
1 parent b01af18 commit a658a52
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 31 deletions.
18 changes: 4 additions & 14 deletions driver.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
package godfish

import "database/sql"

// A Driver describes what a database driver (anything at
// https:/golang/go/wiki/SQLDrivers) should be able to do.
// Driver adapts a database implementation to use godfish.
type Driver interface {
// Name should return the name of the driver: ie: postgres, mysql, etc
Name() string

// Connect should open a connection (a *sql.DB) to the database and save an
// internal reference to that connection for later use. This library might
// call this method multiple times, so use the internal reference if it's
// present instead of reconnecting to the database.
Connect(dsn string) (*sql.DB, error)
// Close should check if there's an internal reference to a database
// connection (a *sql.DB) and if it's present, close it. Then reset the
// internal reference to that connection to nil.
// Connect should open a connection to the database.
Connect(dsn string) error
// Close should close the database connection.
Close() error

// AppliedVersions queries the schema migrations table for migration
Expand Down Expand Up @@ -47,5 +39,3 @@ type AppliedVersions interface {
Next() bool
Scan(dest ...interface{}) error
}

var _ AppliedVersions = (*sql.Rows)(nil)
6 changes: 3 additions & 3 deletions drivers/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ type driver struct {
}

func (d *driver) Name() string { return "mysql" }
func (d *driver) Connect(dsn string) (conn *sql.DB, err error) {
func (d *driver) Connect(dsn string) (err error) {
if d.connection != nil {
conn = d.connection
return
}
if conn, err = sql.Open(d.Name(), dsn); err != nil {
conn, err := sql.Open(d.Name(), dsn)
if err != nil {
return
}
d.connection = conn
Expand Down
6 changes: 3 additions & 3 deletions drivers/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ type driver struct {
}

func (d *driver) Name() string { return "postgres" }
func (d *driver) Connect(dsn string) (conn *sql.DB, err error) {
func (d *driver) Connect(dsn string) (err error) {
if d.connection != nil {
conn = d.connection
return
}
if conn, err = sql.Open(d.Name(), dsn); err != nil {
conn, err := sql.Open(d.Name(), dsn)
if err != nil {
return
}
d.connection = conn
Expand Down
8 changes: 4 additions & 4 deletions godfish.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func Migrate(driver Driver, directoryPath string, direction Direction, finishAtV
if dsn, err = getDSN(); err != nil {
return
}
if _, err = driver.Connect(dsn); err != nil {
if err = driver.Connect(dsn); err != nil {
return
}
defer driver.Close()
Expand Down Expand Up @@ -72,7 +72,7 @@ func ApplyMigration(driver Driver, directoryPath string, direction Direction, ve
if dsn, err = getDSN(); err != nil {
return
}
if _, err = driver.Connect(dsn); err != nil {
if err = driver.Connect(dsn); err != nil {
return
}
defer driver.Close()
Expand Down Expand Up @@ -201,7 +201,7 @@ func CreateSchemaMigrationsTable(driver Driver) (err error) {
if dsn, err = getDSN(); err != nil {
return
}
if _, err = driver.Connect(dsn); err != nil {
if err = driver.Connect(dsn); err != nil {
return err
}
defer driver.Close()
Expand All @@ -214,7 +214,7 @@ func Info(driver Driver, directoryPath string, direction Direction, finishAtVers
if dsn, err = getDSN(); err != nil {
return
}
if _, err = driver.Connect(dsn); err != nil {
if err = driver.Connect(dsn); err != nil {
return err
}
defer driver.Close()
Expand Down
10 changes: 10 additions & 0 deletions godfish_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package godfish_test

import (
"database/sql"
"encoding/json"
"os"
"testing"
Expand Down Expand Up @@ -67,3 +68,12 @@ func TestInit(t *testing.T) {
)
}
}

func TestAppliedVersions(t *testing.T) {
// Regression test on the API. It's supposed to wrap this type from the
// standard library for the most common cases.
var thing interface{} = new(sql.Rows)
if _, ok := thing.(godfish.AppliedVersions); !ok {
t.Fatalf("expected %T to implement godfish.AppliedVersions", thing)
}
}
8 changes: 3 additions & 5 deletions internal/stub/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,23 @@
package stub

import (
"database/sql"
"fmt"
"strings"

"github.com/rafaelespinoza/godfish"
)

type Driver struct {
connection *sql.DB
appliedVersions godfish.AppliedVersions
err error
errorOnExecute error
}

var _ godfish.Driver = (*Driver)(nil)

func (d *Driver) Name() string { return "stub" }
func (d *Driver) Connect(dsn string) (*sql.DB, error) { return d.connection, d.err }
func (d *Driver) Close() error { return d.err }
func (d *Driver) Name() string { return "stub" }
func (d *Driver) Connect(dsn string) error { return d.err }
func (d *Driver) Close() error { return d.err }
func (d *Driver) CreateSchemaMigrationsTable() error {
if d.appliedVersions == nil {
d.appliedVersions = MakeAppliedVersions()
Expand Down
4 changes: 2 additions & 2 deletions internal/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ func setup(driver godfish.Driver, testName string, stubs []testDriverStub, migra
// teardown clears state after running a test.
func teardown(driver godfish.Driver, path string, tablesToDrop ...string) {
var err error
if _, err = driver.Connect(mustDSN()); err != nil {
if err = driver.Connect(mustDSN()); err != nil {
panic(err)
}

Expand Down Expand Up @@ -811,7 +811,7 @@ func generateMigrationFiles(pathToTestDir string, stubs []testDriverStub) error
func collectAppliedVersions(driver godfish.Driver) (out []string, err error) {
// Collect output of AppliedVersions.
// Reconnect here because ApplyMigration closes the connection.
if _, err = driver.Connect(mustDSN()); err != nil {
if err = driver.Connect(mustDSN()); err != nil {
return
}
defer driver.Close()
Expand Down

0 comments on commit a658a52

Please sign in to comment.