diff --git a/internal/android/safetynet.go b/internal/android/safetynet.go index ee963c4ef..0467d63c4 100644 --- a/internal/android/safetynet.go +++ b/internal/android/safetynet.go @@ -24,8 +24,8 @@ import ( "runtime/trace" "time" + "github.com/google/exposure-notifications-server/internal/authorizedapp/model" "github.com/google/exposure-notifications-server/internal/base64util" - "github.com/google/exposure-notifications-server/internal/database" "github.com/google/exposure-notifications-server/internal/logging" "github.com/dgrijalva/jwt-go" @@ -139,7 +139,7 @@ func ValidateAttestation(ctx context.Context, attestation string, opts *VerifyOp // VerifyOptsFor returns the Android SafetyNet verification options to be used // based on the AuthorizedApp configuration, request time, and nonce. -func VerifyOptsFor(c *database.AuthorizedApp, from time.Time, nonce string) *VerifyOpts { +func VerifyOptsFor(c *model.AuthorizedApp, from time.Time, nonce string) *VerifyOpts { digests := make([]string, len(c.SafetyNetApkDigestSHA256)) copy(digests, c.SafetyNetApkDigestSHA256) rtn := &VerifyOpts{ diff --git a/internal/android/safetynet_test.go b/internal/android/safetynet_test.go index f68419a6d..6441744c9 100644 --- a/internal/android/safetynet_test.go +++ b/internal/android/safetynet_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/google/exposure-notifications-server/internal/authorizedapp/model" "github.com/google/exposure-notifications-server/internal/base64util" "github.com/google/exposure-notifications-server/internal/database" "github.com/google/go-cmp/cmp" @@ -192,11 +193,11 @@ func TestVerifyOptsFor(t *testing.T) { testTime := time.Date(2020, 1, 13, 5, 6, 4, 6, time.Local) cases := []struct { - cfg *database.AuthorizedApp + cfg *model.AuthorizedApp opts *VerifyOpts }{ { - cfg: &database.AuthorizedApp{ + cfg: &model.AuthorizedApp{ AppPackageName: "foo", SafetyNetBasicIntegrity: true, SafetyNetCTSProfileMatch: true, @@ -213,7 +214,7 @@ func TestVerifyOptsFor(t *testing.T) { }, }, { - cfg: &database.AuthorizedApp{ + cfg: &model.AuthorizedApp{ AppPackageName: "foo", SafetyNetBasicIntegrity: true, SafetyNetCTSProfileMatch: false, @@ -230,7 +231,7 @@ func TestVerifyOptsFor(t *testing.T) { }, }, { - cfg: &database.AuthorizedApp{ + cfg: &model.AuthorizedApp{ AppPackageName: "foo", SafetyNetApkDigestSHA256: []string{"bar"}, SafetyNetBasicIntegrity: true, diff --git a/internal/database/authorized_app.go b/internal/authorizedapp/database/authorized_app.go similarity index 86% rename from internal/database/authorized_app.go rename to internal/authorizedapp/database/authorized_app.go index fbcdc6c25..ef76b61fc 100644 --- a/internal/database/authorized_app.go +++ b/internal/authorizedapp/database/authorized_app.go @@ -20,15 +20,27 @@ import ( "fmt" "time" + "github.com/google/exposure-notifications-server/internal/authorizedapp/model" + "github.com/google/exposure-notifications-server/internal/database" "github.com/google/exposure-notifications-server/internal/ios" "github.com/google/exposure-notifications-server/internal/secrets" pgx "github.com/jackc/pgx/v4" ) +type AuthorizedAppDB struct { + db *database.DB +} + +func NewAuthorizedAppDB(db *database.DB) *AuthorizedAppDB { + return &AuthorizedAppDB{ + db: db, + } +} + // GetAuthorizedApp loads a single AuthorizedApp for the given name. If no row // exists, this returns nil. -func (db *DB) GetAuthorizedApp(ctx context.Context, sm secrets.SecretManager, name string) (*AuthorizedApp, error) { - conn, err := db.pool.Acquire(ctx) +func (db *AuthorizedAppDB) GetAuthorizedApp(ctx context.Context, sm secrets.SecretManager, name string) (*model.AuthorizedApp, error) { + conn, err := db.db.Pool.Acquire(ctx) if err != nil { return nil, fmt.Errorf("acquiring connection: %v", err) } @@ -45,7 +57,7 @@ func (db *DB) GetAuthorizedApp(ctx context.Context, sm secrets.SecretManager, na row := conn.QueryRow(ctx, query, name) - config := NewAuthorizedApp() + config := model.NewAuthorizedApp() var allowedRegions []string var safetyNetPastSeconds, safetyNetFutureSeconds *int var deviceCheckTeamID, deviceCheckKeyID, deviceCheckPrivateKeySecret sql.NullString diff --git a/internal/database/authorized_app_test.go b/internal/authorizedapp/database/authorized_app_test.go similarity index 86% rename from internal/database/authorized_app_test.go rename to internal/authorizedapp/database/authorized_app_test.go index 14e6d4f44..c9202001b 100644 --- a/internal/database/authorized_app_test.go +++ b/internal/authorizedapp/database/authorized_app_test.go @@ -22,9 +22,13 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "log" + "os" "testing" "time" + "github.com/google/exposure-notifications-server/internal/authorizedapp/model" + coredb "github.com/google/exposure-notifications-server/internal/database" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" ) @@ -41,11 +45,25 @@ func (s *testSecretManager) GetSecretValue(ctx context.Context, name string) (st return v, nil } +var testDB *coredb.DB + +func TestMain(m *testing.M) { + ctx := context.Background() + + if os.Getenv("DB_USER") != "" { + var err error + testDB, err = coredb.CreateTestDB(ctx) + if err != nil { + log.Fatalf("creating test DB: %v", err) + } + } + os.Exit(m.Run()) +} func TestGetAuthorizedApp(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer coredb.ResetTestDB(t, testDB) ctx := context.Background() // Create private key for parsing later @@ -72,7 +90,7 @@ func TestGetAuthorizedApp(t *testing.T) { name string sql string args []interface{} - exp *AuthorizedApp + exp *model.AuthorizedApp err bool }{ { @@ -83,7 +101,7 @@ func TestGetAuthorizedApp(t *testing.T) { `, args: []interface{}{"myapp", "android", []string{"US"}}, - exp: &AuthorizedApp{ + exp: &model.AuthorizedApp{ AppPackageName: "myapp", Platform: "android", AllowedRegions: map[string]struct{}{"US": {}}, @@ -98,7 +116,7 @@ func TestGetAuthorizedApp(t *testing.T) { VALUES ($1, $2, $3) `, args: []interface{}{"myapp", "android", []string{}}, - exp: &AuthorizedApp{ + exp: &model.AuthorizedApp{ AppPackageName: "myapp", Platform: "android", AllowedRegions: map[string]struct{}{}, @@ -122,7 +140,7 @@ func TestGetAuthorizedApp(t *testing.T) { "myapp", "android", []string{}, []string{"092fcfb", "252f10c"}, false, false, }, - exp: &AuthorizedApp{ + exp: &model.AuthorizedApp{ AppPackageName: "myapp", Platform: "android", AllowedRegions: map[string]struct{}{}, @@ -141,7 +159,7 @@ func TestGetAuthorizedApp(t *testing.T) { ) VALUES ($1, $2, $3, $4) `, args: []interface{}{"myapp", "android", []string{"US"}, 1800}, - exp: &AuthorizedApp{ + exp: &model.AuthorizedApp{ AppPackageName: "myapp", Platform: "android", AllowedRegions: map[string]struct{}{"US": {}}, @@ -159,7 +177,7 @@ func TestGetAuthorizedApp(t *testing.T) { ) VALUES ($1, $2, $3, $4) `, args: []interface{}{"myapp", "android", []string{"US"}, 1800}, - exp: &AuthorizedApp{ + exp: &model.AuthorizedApp{ AppPackageName: "myapp", Platform: "android", AllowedRegions: map[string]struct{}{"US": {}}, @@ -177,7 +195,7 @@ func TestGetAuthorizedApp(t *testing.T) { ) VALUES ($1, $2, $3, $4, $5, $6) `, args: []interface{}{"myapp", "ios", []string{"US"}, "ABCD1234", "DEFG5678", "private_key"}, - exp: &AuthorizedApp{ + exp: &model.AuthorizedApp{ AppPackageName: "myapp", Platform: "ios", AllowedRegions: map[string]struct{}{"US": {}}, @@ -201,19 +219,19 @@ func TestGetAuthorizedApp(t *testing.T) { t.Run(c.name, func(t *testing.T) { // Acquire a connection - conn, err := testDB.pool.Acquire(ctx) + conn, err := testDB.Pool.Acquire(ctx) if err != nil { t.Fatal(err) } defer conn.Release() - defer resetTestDB(t) + defer coredb.ResetTestDB(t, testDB) // Insert the data if _, err := conn.Exec(ctx, c.sql, c.args...); err != nil { t.Fatal(err) } - config, err := testDB.GetAuthorizedApp(ctx, sm, "myapp") + config, err := NewAuthorizedAppDB(testDB).GetAuthorizedApp(ctx, sm, "myapp") if (err != nil) != c.err { t.Fatal(err) } diff --git a/internal/authorizedapp/database_provider.go b/internal/authorizedapp/database_provider.go index 10ab3ff1d..b67b48e6e 100644 --- a/internal/authorizedapp/database_provider.go +++ b/internal/authorizedapp/database_provider.go @@ -20,7 +20,10 @@ import ( "sync" "time" + authorizedappdb "github.com/google/exposure-notifications-server/internal/authorizedapp/database" + "github.com/google/exposure-notifications-server/internal/authorizedapp/model" "github.com/google/exposure-notifications-server/internal/database" + "github.com/google/exposure-notifications-server/internal/logging" "github.com/google/exposure-notifications-server/internal/secrets" ) @@ -40,7 +43,7 @@ type DatabaseProvider struct { } type cacheItem struct { - value *database.AuthorizedApp + value *model.AuthorizedApp cachedAt time.Time } @@ -74,7 +77,7 @@ func NewDatabaseProvider(ctx context.Context, db *database.DB, config *Config, o // checkCache checks the local cache within a read lock. // The bool on return is true if there was a hit (And an error is a valid hit) // or false if there was a miss (or expiry) and the data source should be queried again. -func (p *DatabaseProvider) checkCache(name string) (*database.AuthorizedApp, bool, error) { +func (p *DatabaseProvider) checkCache(name string) (*model.AuthorizedApp, bool, error) { // Acquire a read lock first, which allows concurrent readers, to check if // there's an item in the cache. p.cacheLock.RLock() @@ -91,7 +94,7 @@ func (p *DatabaseProvider) checkCache(name string) (*database.AuthorizedApp, boo } // AppConfig returns the config for the given app package name. -func (p *DatabaseProvider) AppConfig(ctx context.Context, name string) (*database.AuthorizedApp, error) { +func (p *DatabaseProvider) AppConfig(ctx context.Context, name string) (*model.AuthorizedApp, error) { logger := logging.FromContext(ctx) data, cacheHit, error := p.checkCache(name) @@ -136,11 +139,11 @@ func (p *DatabaseProvider) AppConfig(ctx context.Context, name string) (*databas // loadAuthorizedAppFromDatabase is a lower-level private API that actually loads and parses // a single AuthorizedApp from the database. -func (p *DatabaseProvider) loadAuthorizedAppFromDatabase(ctx context.Context, name string) (*database.AuthorizedApp, error) { +func (p *DatabaseProvider) loadAuthorizedAppFromDatabase(ctx context.Context, name string) (*model.AuthorizedApp, error) { logger := logging.FromContext(ctx) logger.Infof("authorizedapp: loading %v from database", name) - config, err := p.database.GetAuthorizedApp(ctx, p.secretManager, name) + config, err := authorizedappdb.NewAuthorizedAppDB(p.database).GetAuthorizedApp(ctx, p.secretManager, name) if err != nil { return nil, fmt.Errorf("failed to read %v from database: %w", name, err) } diff --git a/internal/authorizedapp/memory_provider.go b/internal/authorizedapp/memory_provider.go index 71944d4e4..ba6d4e5e6 100644 --- a/internal/authorizedapp/memory_provider.go +++ b/internal/authorizedapp/memory_provider.go @@ -17,7 +17,7 @@ package authorizedapp import ( "context" - "github.com/google/exposure-notifications-server/internal/database" + "github.com/google/exposure-notifications-server/internal/authorizedapp/model" ) // Compile-time check to assert implementation. @@ -26,20 +26,20 @@ var _ Provider = (*MemoryProvider)(nil) // MemoryProvider is an Provider that stores values in-memory. It is primarily // used for testing. type MemoryProvider struct { - Data map[string]*database.AuthorizedApp + Data map[string]*model.AuthorizedApp } // NewMemoryProvider creates a new Provider that reads from a database. func NewMemoryProvider(ctx context.Context, _ *Config) (Provider, error) { provider := &MemoryProvider{ - Data: make(map[string]*database.AuthorizedApp), + Data: make(map[string]*model.AuthorizedApp), } return provider, nil } // AppConfig returns the config for the given app package name. -func (p *MemoryProvider) AppConfig(ctx context.Context, name string) (*database.AuthorizedApp, error) { +func (p *MemoryProvider) AppConfig(ctx context.Context, name string) (*model.AuthorizedApp, error) { val, ok := p.Data[name] if !ok { return nil, AppNotFound diff --git a/internal/database/authorized_app_model.go b/internal/authorizedapp/model/authorized_app_model.go similarity index 99% rename from internal/database/authorized_app_model.go rename to internal/authorizedapp/model/authorized_app_model.go index 5bbb65481..b38e3e7e7 100644 --- a/internal/database/authorized_app_model.go +++ b/internal/authorizedapp/model/authorized_app_model.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package database +package model import ( "crypto/ecdsa" diff --git a/internal/database/authorized_app_model_test.go b/internal/authorizedapp/model/authorized_app_model_test.go similarity index 97% rename from internal/database/authorized_app_model_test.go rename to internal/authorizedapp/model/authorized_app_model_test.go index a193f2917..23aa5e066 100644 --- a/internal/database/authorized_app_model_test.go +++ b/internal/authorizedapp/model/authorized_app_model_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package database +package model import ( "testing" diff --git a/internal/authorizedapp/provider.go b/internal/authorizedapp/provider.go index 84247be18..3768f791a 100644 --- a/internal/authorizedapp/provider.go +++ b/internal/authorizedapp/provider.go @@ -18,7 +18,7 @@ import ( "context" "errors" - "github.com/google/exposure-notifications-server/internal/database" + "github.com/google/exposure-notifications-server/internal/authorizedapp/model" ) // AppNotFound is the sentinel error returned when AppConfig fails to find an @@ -30,5 +30,5 @@ type Provider interface { // AppConfig returns the application-specific configuration for the given // name. An error is returned if the configuration fails to load. An error is // returned if no app with the given name is registered in the system. - AppConfig(ctx context.Context, name string) (*database.AuthorizedApp, error) + AppConfig(ctx context.Context, name string) (*model.AuthorizedApp, error) } diff --git a/internal/database/connection.go b/internal/database/connection.go index d804087eb..b83d5b43d 100644 --- a/internal/database/connection.go +++ b/internal/database/connection.go @@ -46,7 +46,7 @@ type config struct { } type DB struct { - pool *pgxpool.Pool + Pool *pgxpool.Pool } // NewFromEnv sets up the database connections using the configuration in the @@ -68,14 +68,14 @@ func NewFromEnv(ctx context.Context, config *Config) (*DB, error) { return nil, fmt.Errorf("creating connection pool: %v", err) } - return &DB{pool: pool}, nil + return &DB{Pool: pool}, nil } // Close releases database connections. func (db *DB) Close(ctx context.Context) { logger := logging.FromContext(ctx) logger.Infof("Closing connection pool.") - db.pool.Close() + db.Pool.Close() } // dbConnectionString builds a connection string suitable for the pgx Postgres driver, using the @@ -91,7 +91,7 @@ func dbConnectionString(ctx context.Context, config *Config) (string, error) { // dbURI builds a Postgres URI suitable for the lib/pq driver, which is used by // github.com/golang-migrate/migrate. -func dbURI(config *Config) string { +func DbURI(config *Config) string { return fmt.Sprintf("postgres://%s/%s?sslmode=disable&user=%s&password=%s&port=%s", config.Host, config.Name, config.User, url.QueryEscape(config.Password), url.QueryEscape(config.Port)) diff --git a/internal/database/database.go b/internal/database/database.go index 4bdde5334..7572db40e 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -43,7 +43,7 @@ func toNullString(s string) sql.NullString { // inTx runs the given function f within a transaction with isolation level isoLevel. func (db *DB) inTx(ctx context.Context, isoLevel pgx.TxIsoLevel, f func(tx pgx.Tx) error) error { - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return fmt.Errorf("acquiring connection: %v", err) } diff --git a/internal/database/database_test.go b/internal/database/database_test_util.go similarity index 88% rename from internal/database/database_test.go rename to internal/database/database_test_util.go index 5ff3b5f2d..14fab994c 100644 --- a/internal/database/database_test.go +++ b/internal/database/database_test_util.go @@ -39,7 +39,7 @@ func TestMain(m *testing.M) { if os.Getenv("DB_USER") != "" { var err error - testDB, err = createTestDB(ctx) + testDB, err = CreateTestDB(ctx) if err != nil { log.Fatalf("creating test DB: %v", err) } @@ -50,7 +50,7 @@ func TestMain(m *testing.M) { // openTestDB connects to the Postgres server specified by the DB_XXX environment // variables, creates an empty test database on it, and returns a *DB connected // to that database. -func createTestDB(ctx context.Context) (*DB, error) { +func CreateTestDB(ctx context.Context) (*DB, error) { const testDBName = "exposure-server-test" // Connect to the default database to create the test database. @@ -65,7 +65,7 @@ func createTestDB(ctx context.Context) (*DB, error) { if err != nil { return nil, err } - if err := db.createDatabase(ctx, testDBName); err != nil { + if err := createDatabase(ctx, db, testDBName); err != nil { return nil, err } db.Close(ctx) @@ -76,8 +76,8 @@ func createTestDB(ctx context.Context) (*DB, error) { if err != nil { return nil, err } - const source = "file://../../migrations" - uri := dbURI(&config) + const source = "file://../../../migrations" + uri := DbURI(&config) m, err := migrate.New(source, uri) if err != nil { return nil, err @@ -96,8 +96,8 @@ func createTestDB(ctx context.Context) (*DB, error) { return db, nil } -func (db *DB) createDatabase(ctx context.Context, name string) error { - conn, err := db.pool.Acquire(ctx) +func createDatabase(ctx context.Context, db *DB, name string) error { + conn, err := db.Pool.Acquire(ctx) if err != nil { return err } @@ -118,10 +118,10 @@ func mustExec(t *testing.T, conn *pgxpool.Conn, stmt string, args ...interface{} } } -func resetTestDB(t *testing.T) { +func ResetTestDB(t *testing.T, testDB *DB) { t.Helper() ctx := context.Background() - conn, err := testDB.pool.Acquire(ctx) + conn, err := testDB.Pool.Acquire(ctx) if err != nil { t.Fatal(err) } diff --git a/internal/database/export.go b/internal/database/export.go index aa73739dd..3ed8137a8 100644 --- a/internal/database/export.go +++ b/internal/database/export.go @@ -76,7 +76,7 @@ func (db *DB) IterateExportConfigs(ctx context.Context, t time.Time, f func(*Exp } }() - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return fmt.Errorf("acquiring connection: %w", err) } @@ -143,7 +143,7 @@ func (db *DB) AddSignatureInfo(ctx context.Context, si *SignatureInfo) error { } func (db *DB) LookupSignatureInfos(ctx context.Context, ids []int64, validUntil time.Time) ([]*SignatureInfo, error) { - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return nil, fmt.Errorf("acquiring connection: %w", err) } @@ -185,7 +185,7 @@ func (db *DB) LookupSignatureInfos(ctx context.Context, ids []int64, validUntil // exists. // TODO(squee1945): This needs a func (db *DB) LatestExportBatchEnd(ctx context.Context, ec *ExportConfig) (time.Time, error) { - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return time.Time{}, fmt.Errorf("acquiring connection: %w", err) } @@ -243,7 +243,7 @@ func (db *DB) LeaseBatch(ctx context.Context, ttl time.Duration, now time.Time) // Lookup a set of candidate batch IDs. var openBatchIDs []int64 err := func() error { // Use a func to allow defer conn.Release() to work. - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return fmt.Errorf("acquiring connection: %w", err) } @@ -348,7 +348,7 @@ func (db *DB) LeaseBatch(ctx context.Context, ttl time.Duration, now time.Time) // LookupExportBatch returns an ExportBatch for the given batchID. func (db *DB) LookupExportBatch(ctx context.Context, batchID int64) (*ExportBatch, error) { - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return nil, fmt.Errorf("acquiring connection: %w", err) } @@ -415,7 +415,7 @@ func (db *DB) FinalizeBatch(ctx context.Context, eb *ExportBatch, files []string // LookupExportFiles returns a list of export files for the given ExportConfig exportConfigID. func (db *DB) LookupExportFiles(ctx context.Context, exportConfigID int64) ([]string, error) { - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return nil, fmt.Errorf("acquiring connection: %w", err) } @@ -464,7 +464,7 @@ type joinedExportBatchFile struct { } func (db *DB) LookupExportFile(ctx context.Context, filename string) (*ExportFile, error) { - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return nil, fmt.Errorf("acquiring connection: %w", err) } @@ -497,7 +497,7 @@ func (db *DB) DeleteFilesBefore(ctx context.Context, before time.Time, blobstore // Fetch filenames for batches where at least one file is not deleted yet. var files []joinedExportBatchFile err := func() error { // Use a func to allow defer conn.Release() to work. - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return fmt.Errorf("acquiring connection: %w", err) } diff --git a/internal/database/export_test.go b/internal/database/export_test.go index cf45d1ac1..9a5662e2f 100644 --- a/internal/database/export_test.go +++ b/internal/database/export_test.go @@ -29,7 +29,7 @@ func TestAddSignatureInfo(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx := context.Background() thruTime := time.Now().UTC().Add(6 * time.Hour).Truncate(time.Microsecond) @@ -42,7 +42,7 @@ func TestAddSignatureInfo(t *testing.T) { if err := testDB.AddSignatureInfo(ctx, want); err != nil { t.Fatal(err) } - conn, err := testDB.pool.Acquire(ctx) + conn, err := testDB.Pool.Acquire(ctx) if err != nil { t.Fatal(err) } @@ -69,7 +69,7 @@ func TestLookupSignatureInfos(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx := context.Background() testTime := time.Now().UTC() @@ -114,7 +114,7 @@ func TestAddExportConfig(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx := context.Background() fromTime := time.Now() @@ -131,7 +131,7 @@ func TestAddExportConfig(t *testing.T) { if err := testDB.AddExportConfig(ctx, want); err != nil { t.Fatal(err) } - conn, err := testDB.pool.Acquire(ctx) + conn, err := testDB.Pool.Acquire(ctx) if err != nil { t.Fatal(err) } @@ -164,7 +164,7 @@ func TestIterateExportConfigs(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx := context.Background() now := time.Now().Truncate(time.Microsecond) @@ -220,7 +220,7 @@ func TestBatches(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx := context.Background() now := time.Now().Truncate(time.Microsecond) @@ -330,7 +330,7 @@ func TestFinalizeBatch(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx := context.Background() now := time.Now().Truncate(time.Microsecond) @@ -424,7 +424,7 @@ func TestKeysInBatch(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx := context.Background() now := time.Now() @@ -521,7 +521,7 @@ func TestAddExportFileSkipsDuplicates(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx := context.Background() // Add foreign key records. diff --git a/internal/database/exposure.go b/internal/database/exposure.go index ec7c1327d..aba2247b4 100644 --- a/internal/database/exposure.go +++ b/internal/database/exposure.go @@ -55,7 +55,7 @@ type IterateExposuresCriteria struct { // the iteration at the failed row. If IterateExposures returns a nil error, // the first return value will be the empty string. func (db *DB) IterateExposures(ctx context.Context, criteria IterateExposuresCriteria, f func(*Exposure) error) (cur string, err error) { - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return "", fmt.Errorf("acquiring connection: %v", err) } diff --git a/internal/database/exposure_test.go b/internal/database/exposure_test.go index 79192acbd..d73728837 100644 --- a/internal/database/exposure_test.go +++ b/internal/database/exposure_test.go @@ -27,7 +27,7 @@ func TestExposures(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx := context.Background() // Insert some Exposures. @@ -156,7 +156,7 @@ func TestIterateExposuresCursor(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx, cancel := context.WithCancel(context.Background()) // Insert some Exposures. exposures := []*Exposure{ diff --git a/internal/database/federationin.go b/internal/database/federationin.go index f437a69b8..21b7870bd 100644 --- a/internal/database/federationin.go +++ b/internal/database/federationin.go @@ -29,7 +29,7 @@ type queryRowFn func(ctx context.Context, query string, args ...interface{}) pgx // GetFederationInQuery returns a query for given queryID. If not found, ErrNotFound will be returned. func (db *DB) GetFederationInQuery(ctx context.Context, queryID string) (*FederationInQuery, error) { - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return nil, fmt.Errorf("acquiring connection: %w", err) } @@ -83,7 +83,7 @@ func (db *DB) AddFederationInQuery(ctx context.Context, q *FederationInQuery) er // GetFederationInSync returns a federation sync record for given syncID. If not found, ErrNotFound will be returned. func (db *DB) GetFederationInSync(ctx context.Context, syncID int64) (*FederationInSync, error) { - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return nil, fmt.Errorf("acquiring connection: %w", err) } @@ -127,7 +127,7 @@ func getFederationInSync(ctx context.Context, syncID int64, queryRowContext quer // StartFederationInSync stores a historical record of a query sync starting. It returns a FederationInSync key, and a FinalizeSyncFn that must be invoked to finalize the historical record. func (db *DB) StartFederationInSync(ctx context.Context, q *FederationInQuery, started time.Time) (int64, FinalizeSyncFn, error) { - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return 0, nil, fmt.Errorf("acquiring connection: %w", err) } diff --git a/internal/database/federationin_test.go b/internal/database/federationin_test.go index 781278ab5..d62de9f39 100644 --- a/internal/database/federationin_test.go +++ b/internal/database/federationin_test.go @@ -29,7 +29,7 @@ func TestFederationIn(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx := context.Background() ts := time.Date(2020, 5, 6, 0, 0, 0, 0, time.UTC) diff --git a/internal/database/federationout.go b/internal/database/federationout.go index 96f291341..b7cff7b5d 100644 --- a/internal/database/federationout.go +++ b/internal/database/federationout.go @@ -45,7 +45,7 @@ func (db *DB) AddFederationOutAuthorization(ctx context.Context, auth *Federatio // GetFederationOutAuthorization returns a FederationOutAuthorization record, or ErrNotFound if not found. func (db *DB) GetFederationOutAuthorization(ctx context.Context, issuer, subject string) (*FederationOutAuthorization, error) { - conn, err := db.pool.Acquire(ctx) + conn, err := db.Pool.Acquire(ctx) if err != nil { return nil, fmt.Errorf("acquiring connection: %w", err) } diff --git a/internal/database/federationout_test.go b/internal/database/federationout_test.go index 58fb2bc29..3f0676bcb 100644 --- a/internal/database/federationout_test.go +++ b/internal/database/federationout_test.go @@ -27,7 +27,7 @@ func TestFederationOutAuthorization(t *testing.T) { if testDB == nil { t.Skip("no test DB") } - defer resetTestDB(t) + defer ResetTestDB(t, testDB) ctx := context.Background() want := &FederationOutAuthorization{ diff --git a/internal/database/lock_test.go b/internal/database/lock_test.go index 0a6051ebe..6348288f0 100644 --- a/internal/database/lock_test.go +++ b/internal/database/lock_test.go @@ -70,7 +70,7 @@ func TestLock(t *testing.T) { } // Lock table should be empty. - conn, err := testDB.pool.Acquire(ctx) + conn, err := testDB.Pool.Acquire(ctx) if err != nil { t.Fatal(err) } diff --git a/internal/pb/federation.pb.go b/internal/pb/federation.pb.go index 3923714a1..11eec2b80 100644 --- a/internal/pb/federation.pb.go +++ b/internal/pb/federation.pb.go @@ -255,7 +255,6 @@ type ContactTracingInfo struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // Transmission risk is an integer with valid values from 1-8. TransmissionRisk int32 `protobuf:"varint,1,opt,name=transmissionRisk,proto3" json:"transmissionRisk,omitempty"` // required ExposureKeys []*ExposureKey `protobuf:"bytes,2,rep,name=exposureKeys,proto3" json:"exposureKeys,omitempty"` } diff --git a/internal/verification/verify.go b/internal/verification/verify.go index a8c44343c..d429b4aed 100644 --- a/internal/verification/verify.go +++ b/internal/verification/verify.go @@ -20,7 +20,9 @@ import ( "time" "github.com/google/exposure-notifications-server/internal/android" + authorizedapp "github.com/google/exposure-notifications-server/internal/authorizedapp/model" "github.com/google/exposure-notifications-server/internal/database" + "github.com/google/exposure-notifications-server/internal/ios" ) @@ -31,7 +33,7 @@ var ( // VerifyRegions checks the request regions against the regions allowed by // the configuration for the application. -func VerifyRegions(cfg *database.AuthorizedApp, data *database.Publish) error { +func VerifyRegions(cfg *authorizedapp.AuthorizedApp, data *database.Publish) error { if cfg == nil { return fmt.Errorf("app configuration is empty") } @@ -48,7 +50,7 @@ func VerifyRegions(cfg *database.AuthorizedApp, data *database.Publish) error { // VerifySafetyNet verifies the Android SafetyNet device attestation against the // allowed configuration for the application. -func VerifySafetyNet(ctx context.Context, requestTime time.Time, cfg *database.AuthorizedApp, publish *database.Publish) error { +func VerifySafetyNet(ctx context.Context, requestTime time.Time, cfg *authorizedapp.AuthorizedApp, publish *database.Publish) error { if cfg == nil { return fmt.Errorf("cannot enforce SafetyNet, missing config") } @@ -62,7 +64,7 @@ func VerifySafetyNet(ctx context.Context, requestTime time.Time, cfg *database.A } // VerifyDeviceCheck verifies an iOS DeviceCheck token against the Apple API. -func VerifyDeviceCheck(ctx context.Context, cfg *database.AuthorizedApp, data *database.Publish) error { +func VerifyDeviceCheck(ctx context.Context, cfg *authorizedapp.AuthorizedApp, data *database.Publish) error { if cfg == nil { return fmt.Errorf("cannot enforce DeviceCheck, missing config") } diff --git a/internal/verification/verify_test.go b/internal/verification/verify_test.go index 51ea56063..2ce3c5e4d 100644 --- a/internal/verification/verify_test.go +++ b/internal/verification/verify_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/google/exposure-notifications-server/internal/android" + authorizedapp "github.com/google/exposure-notifications-server/internal/authorizedapp/model" "github.com/google/exposure-notifications-server/internal/database" ) @@ -33,7 +34,7 @@ func TestVerifyRegions(t *testing.T) { cases := []struct { name string data *database.Publish - cfg *database.AuthorizedApp + cfg *authorizedapp.AuthorizedApp err bool }{ { @@ -45,14 +46,14 @@ func TestVerifyRegions(t *testing.T) { { name: "nil_regions_allows_all", data: &database.Publish{Regions: []string{"US"}}, - cfg: &database.AuthorizedApp{ + cfg: &authorizedapp.AuthorizedApp{ AppPackageName: appPkgName, }, }, { name: "empty_regions_allows_all", data: &database.Publish{Regions: []string{"US"}}, - cfg: &database.AuthorizedApp{ + cfg: &authorizedapp.AuthorizedApp{ AppPackageName: appPkgName, AllowedRegions: map[string]struct{}{}, }, @@ -60,7 +61,7 @@ func TestVerifyRegions(t *testing.T) { { name: "region_matches_one", data: &database.Publish{Regions: []string{"US"}}, - cfg: &database.AuthorizedApp{ + cfg: &authorizedapp.AuthorizedApp{ AppPackageName: appPkgName, AllowedRegions: map[string]struct{}{ "US": {}, @@ -71,7 +72,7 @@ func TestVerifyRegions(t *testing.T) { { name: "region_matches_all", data: &database.Publish{Regions: []string{"US", "CA"}}, - cfg: &database.AuthorizedApp{ + cfg: &authorizedapp.AuthorizedApp{ AppPackageName: appPkgName, AllowedRegions: map[string]struct{}{ "US": {}, @@ -82,7 +83,7 @@ func TestVerifyRegions(t *testing.T) { { name: "region_matches_some", data: &database.Publish{Regions: []string{"US", "MX"}}, - cfg: &database.AuthorizedApp{ + cfg: &authorizedapp.AuthorizedApp{ AppPackageName: appPkgName, AllowedRegions: map[string]struct{}{ "US": {}, @@ -94,7 +95,7 @@ func TestVerifyRegions(t *testing.T) { { name: "region_matches_none", data: &database.Publish{Regions: []string{"MX"}}, - cfg: &database.AuthorizedApp{ + cfg: &authorizedapp.AuthorizedApp{ AppPackageName: appPkgName, AllowedRegions: map[string]struct{}{ "US": {}, @@ -117,14 +118,14 @@ func TestVerifyRegions(t *testing.T) { } func TestVerifySafetyNet(t *testing.T) { - allRegions := &database.AuthorizedApp{ + allRegions := &authorizedapp.AuthorizedApp{ AppPackageName: appPkgName, } cases := []struct { Data *database.Publish Msg string - Cfg *database.AuthorizedApp + Cfg *authorizedapp.AuthorizedApp AttestationResult error }{ {