Skip to content

Commit

Permalink
refactor: Support context-based SQL transactions (#254)
Browse files Browse the repository at this point in the history
Co-authored-by: hackerman <[email protected]>
  • Loading branch information
zepatrik and aeneasr authored Feb 17, 2020
1 parent 3c1c67b commit 6ace1ee
Show file tree
Hide file tree
Showing 14 changed files with 170 additions and 64 deletions.
1 change: 1 addition & 0 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/fsnotify/fsnotify"
"github.com/gobuffalo/packr/v2"

"github.com/ory/viper"
"github.com/ory/x/flagx"
"github.com/ory/x/viperx"
Expand Down
4 changes: 2 additions & 2 deletions courier/courier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ func TestSMTP(t *testing.T) {
}

smtp, api := runTestSMTP(t)
t.Logf("SMTP URL: %s",smtp)
t.Logf("API URL: %s",api)
t.Logf("SMTP URL: %s", smtp)
t.Logf("API URL: %s", api)

conf, reg := internal.NewRegistryDefault(t)
viper.Set(configuration.ViperKeyCourierSMTPURL, smtp)
Expand Down
4 changes: 4 additions & 0 deletions persistence/reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"io"

"github.com/gobuffalo/pop/v5"

"github.com/ory/kratos/courier"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/errorx"
Expand Down Expand Up @@ -33,4 +35,6 @@ type Persister interface {
MigrationStatus(c context.Context, b io.Writer) error
MigrateDown(c context.Context, steps int) error
MigrateUp(c context.Context) error
GetConnection(ctx context.Context) *pop.Connection
Transaction(ctx context.Context, callback func(connection *pop.Connection) error) error
}
14 changes: 7 additions & 7 deletions persistence/sql/persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,26 @@ func NewPersister(r persisterDependencies, conf configuration.Provider, c *pop.C
return &Persister{c: c, mb: m, cf: conf, r: r}, nil
}

func (p *Persister) MigrationStatus(c context.Context, w io.Writer) error {
func (p *Persister) MigrationStatus(ctx context.Context, w io.Writer) error {
return errors.WithStack(p.mb.Status(w))
}

func (p *Persister) MigrateDown(c context.Context, steps int) error {
func (p *Persister) MigrateDown(ctx context.Context, steps int) error {
return errors.WithStack(p.mb.Down(steps))
}

func (p *Persister) MigrateUp(c context.Context) error {
func (p *Persister) MigrateUp(ctx context.Context) error {
return errors.WithStack(p.mb.Up())
}

func (p *Persister) Close(c context.Context) error {
return errors.WithStack(p.c.Close())
func (p *Persister) Close(ctx context.Context) error {
return errors.WithStack(p.GetConnection(ctx).Close())
}

func (p *Persister) Ping(c context.Context) error {
func (p *Persister) Ping(ctx context.Context) error {
type pinger interface {
Ping() error
}

return errors.WithStack(p.c.Store.(pinger).Ping())
return errors.WithStack(p.GetConnection(ctx).Store.(pinger).Ping())
}
8 changes: 4 additions & 4 deletions persistence/sql/persister_courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ var _ courier.Persister = new(Persister)

func (p *Persister) AddMessage(ctx context.Context, m *courier.Message) error {
m.Status = courier.MessageStatusQueued
return sqlcon.HandleError(p.c.Create(m)) // do not create eager to avoid identity injection.
return sqlcon.HandleError(p.GetConnection(ctx).Create(m)) // do not create eager to avoid identity injection.
}

func (p *Persister) NextMessages(ctx context.Context, limit uint8) ([]courier.Message, error) {
var m []courier.Message
if err := p.c.
if err := p.GetConnection(ctx).
Eager().
Where("status != ?", courier.MessageStatusSent).
Order("created_at ASC").Limit(int(limit)).All(&m); err != nil {
Expand All @@ -40,7 +40,7 @@ func (p *Persister) NextMessages(ctx context.Context, limit uint8) ([]courier.Me

func (p *Persister) LatestQueuedMessage(ctx context.Context) (*courier.Message, error) {
var m courier.Message
if err := p.c.
if err := p.GetConnection(ctx).
Eager().
Where("status != ?", courier.MessageStatusSent).
Order("created_at DESC").First(&m); err != nil {
Expand All @@ -54,7 +54,7 @@ func (p *Persister) LatestQueuedMessage(ctx context.Context) (*courier.Message,
}

func (p *Persister) SetMessageStatus(ctx context.Context, id uuid.UUID, ms courier.MessageStatus) error {
count, err := p.c.RawQuery("UPDATE courier_messages SET status = ? WHERE id = ?", ms, id).ExecWithCount()
count, err := p.GetConnection(ctx).RawQuery("UPDATE courier_messages SET status = ? WHERE id = ?", ms, id).ExecWithCount()
if err != nil {
return sqlcon.HandleError(err)
}
Expand Down
10 changes: 5 additions & 5 deletions persistence/sql/persister_errorx.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (p *Persister) Add(ctx context.Context, csrfToken string, errs ...error) (u
WasSeen: false,
}

if err := p.c.Create(c); err != nil {
if err := p.GetConnection(ctx).Create(c); err != nil {
return uuid.Nil, sqlcon.HandleError(err)
}

Expand All @@ -42,11 +42,11 @@ func (p *Persister) Add(ctx context.Context, csrfToken string, errs ...error) (u

func (p *Persister) Read(ctx context.Context, id uuid.UUID) (*errorx.ErrorContainer, error) {
var ec errorx.ErrorContainer
if err := p.c.Find(&ec, id); err != nil {
if err := p.GetConnection(ctx).Find(&ec, id); err != nil {
return nil, sqlcon.HandleError(err)
}

if err := p.c.RawQuery("UPDATE selfservice_errors SET was_seen = true, seen_at = ? WHERE id = ?", time.Now().UTC(), id).Exec(); err != nil {
if err := p.GetConnection(ctx).RawQuery("UPDATE selfservice_errors SET was_seen = true, seen_at = ? WHERE id = ?", time.Now().UTC(), id).Exec(); err != nil {
return nil, sqlcon.HandleError(err)
}

Expand All @@ -55,9 +55,9 @@ func (p *Persister) Read(ctx context.Context, id uuid.UUID) (*errorx.ErrorContai

func (p *Persister) Clear(ctx context.Context, olderThan time.Duration, force bool) (err error) {
if force {
err = p.c.RawQuery("DELETE FROM selfservice_errors WHERE seen_at < ?", olderThan).Exec()
err = p.GetConnection(ctx).RawQuery("DELETE FROM selfservice_errors WHERE seen_at < ?", olderThan).Exec()
} else {
err = p.c.RawQuery("DELETE FROM selfservice_errors WHERE was_seen=true AND seen_at < ?", time.Now().UTC().Add(-olderThan)).Exec()
err = p.GetConnection(ctx).RawQuery("DELETE FROM selfservice_errors WHERE was_seen=true AND seen_at < ?", time.Now().UTC().Add(-olderThan)).Exec()
}

return sqlcon.HandleError(err)
Expand Down
28 changes: 14 additions & 14 deletions persistence/sql/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ var _ identity.PrivilegedPool = new(Persister)

func (p *Persister) FindByCredentialsIdentifier(ctx context.Context, ct identity.CredentialsType, match string) (*identity.Identity, *identity.Credentials, error) {
var cts []identity.CredentialsTypeTable
if err := p.c.All(&cts); err != nil {
if err := p.GetConnection(ctx).All(&cts); err != nil {
return nil, nil, sqlcon.HandleError(err)
}

var find struct {
IdentityID uuid.UUID `db:"identity_id"`
}

if err := p.c.RawQuery(`SELECT
if err := p.GetConnection(ctx).RawQuery(`SELECT
ic.identity_id
FROM identity_credentials ic
INNER JOIN identity_credential_types ict on ic.identity_credential_type_id = ict.id
Expand Down Expand Up @@ -148,7 +148,7 @@ func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) er
return err
}

return sqlcon.HandleError(p.c.Transaction(func(tx *pop.Connection) error {
return sqlcon.HandleError(p.GetConnection(ctx).Transaction(func(tx *pop.Connection) error {
if err := tx.Create(i); err != nil {
return err
}
Expand All @@ -165,7 +165,7 @@ func (p *Persister) ListIdentities(ctx context.Context, limit, offset int) ([]id
is := make([]identity.Identity, 0)

/* #nosec G201 TableName is static */
if err := sqlcon.HandleError(p.c.RawQuery(fmt.Sprintf("SELECT * FROM %s LIMIT ? OFFSET ?", new(identity.Identity).TableName()), limit, offset).All(&is)); err != nil {
if err := sqlcon.HandleError(p.GetConnection(ctx).RawQuery(fmt.Sprintf("SELECT * FROM %s LIMIT ? OFFSET ?", new(identity.Identity).TableName()), limit, offset).All(&is)); err != nil {
return nil, err
}

Expand All @@ -183,7 +183,7 @@ func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) er
return err
}

return sqlcon.HandleError(p.c.Transaction(func(tx *pop.Connection) error {
return sqlcon.HandleError(p.GetConnection(ctx).Transaction(func(tx *pop.Connection) error {
if count, err := tx.Where("id = ?", i.ID).Count(i); err != nil {
return err
} else if count == 0 {
Expand Down Expand Up @@ -214,7 +214,7 @@ func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) er

func (p *Persister) DeleteIdentity(ctx context.Context, id uuid.UUID) error {
/* #nosec G201 TableName is static */
count, err := p.c.RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ?", new(identity.Identity).TableName()), id).ExecWithCount()
count, err := p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE id = ?", new(identity.Identity).TableName()), id).ExecWithCount()
if err != nil {
return sqlcon.HandleError(err)
}
Expand All @@ -226,7 +226,7 @@ func (p *Persister) DeleteIdentity(ctx context.Context, id uuid.UUID) error {

func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID) (*identity.Identity, error) {
var i identity.Identity
if err := p.c.Find(&i, id); err != nil {
if err := p.GetConnection(ctx).Find(&i, id); err != nil {
return nil, sqlcon.HandleError(err)
}
i.Credentials = nil
Expand All @@ -239,19 +239,19 @@ func (p *Persister) GetIdentity(ctx context.Context, id uuid.UUID) (*identity.Id

func (p *Persister) GetIdentityConfidential(ctx context.Context, id uuid.UUID) (*identity.Identity, error) {
var i identity.Identity
if err := p.c.Eager().Find(&i, id); err != nil {
if err := p.GetConnection(ctx).Eager().Find(&i, id); err != nil {
return nil, sqlcon.HandleError(err)
}

var cts []identity.CredentialsTypeTable
if err := p.c.All(&cts); err != nil {
if err := p.GetConnection(ctx).All(&cts); err != nil {
return nil, sqlcon.HandleError(err)
}

i.Credentials = map[identity.CredentialsType]identity.Credentials{}
for _, creds := range i.CredentialsCollection {
var cs identity.CredentialIdentifierCollection
if err := p.c.Where("identity_credential_id = ?", creds.ID).All(&cs); err != nil {
if err := p.GetConnection(ctx).Where("identity_credential_id = ?", creds.ID).All(&cs); err != nil {
return nil, sqlcon.HandleError(err)
}

Expand All @@ -277,7 +277,7 @@ func (p *Persister) GetIdentityConfidential(ctx context.Context, id uuid.UUID) (

func (p *Persister) FindAddressByCode(ctx context.Context, code string) (*identity.VerifiableAddress, error) {
var address identity.VerifiableAddress
if err := p.c.Where("code = ?", code).First(&address); err != nil {
if err := p.GetConnection(ctx).Where("code = ?", code).First(&address); err != nil {
return nil, sqlcon.HandleError(err)
}

Expand All @@ -286,7 +286,7 @@ func (p *Persister) FindAddressByCode(ctx context.Context, code string) (*identi

func (p *Persister) FindAddressByValue(ctx context.Context, via identity.VerifiableAddressType, value string) (*identity.VerifiableAddress, error) {
var address identity.VerifiableAddress
if err := p.c.Where("via = ? AND value = ?", via, value).First(&address); err != nil {
if err := p.GetConnection(ctx).Where("via = ? AND value = ?", via, value).First(&address); err != nil {
return nil, sqlcon.HandleError(err)
}

Expand All @@ -299,7 +299,7 @@ func (p *Persister) VerifyAddress(ctx context.Context, code string) error {
return err
}

return sqlcon.HandleError(p.c.RawQuery(
return sqlcon.HandleError(p.GetConnection(ctx).RawQuery(
/* #nosec G201 TableName is static */
fmt.Sprintf(
"UPDATE %s SET status = ?, verified = true, verified_at = ?, code = ? WHERE code = ?",
Expand All @@ -313,7 +313,7 @@ func (p *Persister) VerifyAddress(ctx context.Context, code string) error {
}

func (p *Persister) UpdateVerifiableAddress(ctx context.Context, address *identity.VerifiableAddress) error {
return sqlcon.HandleError(p.c.Update(address))
return sqlcon.HandleError(p.GetConnection(ctx).Update(address))
}

func (p *Persister) validateIdentity(i *identity.Identity) error {
Expand Down
40 changes: 23 additions & 17 deletions persistence/sql/persister_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package sql
import (
"context"

"github.com/gobuffalo/pop/v5"

"github.com/gofrs/uuid"

"github.com/ory/x/sqlcon"
Expand All @@ -14,35 +16,39 @@ import (
var _ login.RequestPersister = new(Persister)

func (p *Persister) CreateLoginRequest(ctx context.Context, r *login.Request) error {
return p.c.Eager().Create(r)
return p.GetConnection(ctx).Eager().Create(r)
}

func (p *Persister) GetLoginRequest(ctx context.Context, id uuid.UUID) (*login.Request, error) {
conn := p.GetConnection(ctx)
var r login.Request
if err := p.c.Eager().Find(&r, id); err != nil {
if err := conn.Eager().Find(&r, id); err != nil {
return nil, sqlcon.HandleError(err)
}

if err := (&r).AfterFind(p.c); err != nil {
if err := (&r).AfterFind(conn); err != nil {
return nil, err
}

return &r, nil
}

func (p *Persister) UpdateLoginRequest(ctx context.Context, id uuid.UUID, ct identity.CredentialsType, rm *login.RequestMethod) error {
rr, err := p.GetLoginRequest(ctx, id)
if err != nil {
return err
}

method, ok := rr.Methods[ct]
if !ok {
rm.RequestID = rr.ID
rm.Method = ct
return p.c.Save(rm)
}

method.Config = rm.Config
return p.c.Save(method)
return p.Transaction(ctx, func(tx *pop.Connection) error {
ctx := WithTransaction(ctx, tx)
rr, err := p.GetLoginRequest(ctx, id)
if err != nil {
return err
}

method, ok := rr.Methods[ct]
if !ok {
rm.RequestID = rr.ID
rm.Method = ct
return tx.Save(rm)
}

method.Config = rm.Config
return tx.Save(method)
})
}
6 changes: 3 additions & 3 deletions persistence/sql/persister_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ var _ profile.RequestPersister = new(Persister)

func (p *Persister) CreateProfileRequest(ctx context.Context, r *profile.Request) error {
r.IdentityID = r.Identity.ID
return sqlcon.HandleError(p.c.Create(r)) // This must not be eager or identities will be created / updated
return sqlcon.HandleError(p.GetConnection(ctx).Create(r)) // This must not be eager or identities will be created / updated
}

func (p *Persister) GetProfileRequest(ctx context.Context, id uuid.UUID) (*profile.Request, error) {
var r profile.Request
if err := p.c.Eager().Find(&r, id); err != nil {
if err := p.GetConnection(ctx).Eager().Find(&r, id); err != nil {
return nil, sqlcon.HandleError(err)
}
return &r, nil
}

func (p *Persister) UpdateProfileRequest(ctx context.Context, r *profile.Request) error {
return sqlcon.HandleError(p.c.Update(r)) // This must not be eager or identities will be created / updated
return sqlcon.HandleError(p.GetConnection(ctx).Update(r)) // This must not be eager or identities will be created / updated
}
10 changes: 5 additions & 5 deletions persistence/sql/persister_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import (
)

func (p *Persister) CreateRegistrationRequest(ctx context.Context, r *registration.Request) error {
return p.c.Eager().Create(r)
return p.GetConnection(ctx).Eager().Create(r)
}

func (p *Persister) GetRegistrationRequest(ctx context.Context, id uuid.UUID) (*registration.Request, error) {
var r registration.Request
if err := p.c.Eager().Find(&r, id); err != nil {
if err := p.GetConnection(ctx).Eager().Find(&r, id); err != nil {
return nil, sqlcon.HandleError(err)
}

if err := (&r).AfterFind(p.c); err != nil {
if err := (&r).AfterFind(p.GetConnection(ctx)); err != nil {
return nil, err
}

Expand All @@ -38,9 +38,9 @@ func (p *Persister) UpdateRegistrationRequest(ctx context.Context, id uuid.UUID,
if !ok {
rm.RequestID = rr.ID
rm.Method = ct
return p.c.Save(rm)
return p.GetConnection(ctx).Save(rm)
}

method.Config = rm.Config
return p.c.Save(method)
return p.GetConnection(ctx).Save(method)
}
Loading

0 comments on commit 6ace1ee

Please sign in to comment.