From 82000118c09142c23f8d96e229f5bfefea663473 Mon Sep 17 00:00:00 2001 From: Tetsuro Aoki Date: Wed, 14 Sep 2022 23:19:42 +0900 Subject: [PATCH] feat: add host param to url function for wait.ForSQL --- e2e/container_test.go | 4 ++-- wait/sql.go | 11 ++++++++--- wait/sql_test.go | 4 ++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/e2e/container_test.go b/e2e/container_test.go index 3557f7c5f1..f9dc377e04 100644 --- a/e2e/container_test.go +++ b/e2e/container_test.go @@ -24,8 +24,8 @@ func TestContainerWithWaitForSQL(t *testing.T) { "POSTGRES_DB": dbname, } var port = "5432/tcp" - dbURL := func(port nat.Port) string { - return fmt.Sprintf("postgres://postgres:password@localhost:%s/%s?sslmode=disable", port.Port(), dbname) + dbURL := func(host string, port nat.Port) string { + return fmt.Sprintf("postgres://postgres:password@%s:%s/%s?sslmode=disable", host, port.Port(), dbname) } t.Run("default query", func(t *testing.T) { diff --git a/wait/sql.go b/wait/sql.go index 551464c7c5..09f5fd7590 100644 --- a/wait/sql.go +++ b/wait/sql.go @@ -12,7 +12,7 @@ import ( const defaultForSqlQuery = "SELECT 1" //ForSQL constructs a new waitForSql strategy for the given driver -func ForSQL(port nat.Port, driver string, url func(nat.Port) string) *waitForSql { +func ForSQL(port nat.Port, driver string, url func(string, nat.Port) string) *waitForSql { return &waitForSql{ Port: port, URL: url, @@ -24,7 +24,7 @@ func ForSQL(port nat.Port, driver string, url func(nat.Port) string) *waitForSql } type waitForSql struct { - URL func(port nat.Port) string + URL func(host string, port nat.Port) string Driver string Port nat.Port startupTimeout time.Duration @@ -57,6 +57,11 @@ func (w *waitForSql) WaitUntilReady(ctx context.Context, target StrategyTarget) ctx, cancel := context.WithTimeout(ctx, w.startupTimeout) defer cancel() + host, err := target.Host(ctx) + if err != nil { + return + } + ticker := time.NewTicker(w.PollInterval) defer ticker.Stop() @@ -72,7 +77,7 @@ func (w *waitForSql) WaitUntilReady(ctx context.Context, target StrategyTarget) } } - db, err := sql.Open(w.Driver, w.URL(port)) + db, err := sql.Open(w.Driver, w.URL(host, port)) if err != nil { return fmt.Errorf("sql.Open: %v", err) } diff --git a/wait/sql_test.go b/wait/sql_test.go index f9ec88525f..63048e6acc 100644 --- a/wait/sql_test.go +++ b/wait/sql_test.go @@ -8,7 +8,7 @@ import ( func Test_waitForSql_WithQuery(t *testing.T) { t.Run("default query", func(t *testing.T) { - w := ForSQL("5432/tcp", "postgres", func(port nat.Port) string { + w := ForSQL("5432/tcp", "postgres", func(host string, port nat.Port) string { return "fake-url" }) @@ -19,7 +19,7 @@ func Test_waitForSql_WithQuery(t *testing.T) { t.Run("custom query", func(t *testing.T) { const q = "SELECT 100;" - w := ForSQL("5432/tcp", "postgres", func(port nat.Port) string { + w := ForSQL("5432/tcp", "postgres", func(host string, port nat.Port) string { return "fake-url" }).WithQuery(q)