Skip to content

Commit

Permalink
chore: code review
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Mar 17, 2023
1 parent 8336c30 commit a18f91d
Show file tree
Hide file tree
Showing 18 changed files with 248 additions and 72 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export PATH := .bin:${PATH}
export PWD := $(shell pwd)
export BUILD_DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
export VCS_REF := $(shell git rev-parse HEAD)
export QUICKSTART_OPTIONS ?=
export QUICKSTART_OPTIONS ?= ""

GO_DEPENDENCIES = github.com/ory/go-acc \
github.com/golang/mock/mockgen \
Expand Down
8 changes: 8 additions & 0 deletions identity/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
package identity

import (
"bytes"
"context"
"encoding/json"
"reflect"
"time"

"github.com/pkg/errors"

"github.com/gobuffalo/pop/v6"

"github.com/ory/kratos/ui/node"
Expand Down Expand Up @@ -117,6 +121,10 @@ func (c *Credentials) AfterEagerFind(tx *pop.Connection) error {
return c.setCredentials()
}

func (c *Credentials) UnmarshalConfig(target interface{}) error {
return errors.WithStack(json.NewDecoder(bytes.NewBuffer(c.Config)).Decode(&target))
}

func (c *Credentials) setCredentials() error {
c.Type = c.IdentityCredentialType.Name
c.Identifiers = make([]string, 0, len(c.CredentialIdentifiers))
Expand Down
15 changes: 15 additions & 0 deletions identity/credentials_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type CredentialsOIDCProvider struct {
InitialIDToken string `json:"initial_id_token"`
InitialAccessToken string `json:"initial_access_token"`
InitialRefreshToken string `json:"initial_refresh_token"`
CurrentIDToken string `json:"current_id_token"`
CurrentAccessToken string `json:"current_access_token"`
CurrentRefreshToken string `json:"current_refresh_token"`
}

// NewCredentialsOIDC creates a new OIDC credential.
Expand All @@ -50,6 +53,9 @@ func NewCredentialsOIDC(idToken, accessToken, refreshToken, provider, subject st
InitialIDToken: idToken,
InitialAccessToken: accessToken,
InitialRefreshToken: refreshToken,
CurrentIDToken: idToken,
CurrentAccessToken: accessToken,
CurrentRefreshToken: refreshToken,
}},
}); err != nil {
return nil, errors.WithStack(x.PseudoPanic.
Expand All @@ -66,3 +72,12 @@ func NewCredentialsOIDC(idToken, accessToken, refreshToken, provider, subject st
func OIDCUniqueID(provider, subject string) string {
return fmt.Sprintf("%s:%s", provider, subject)
}

func (c *CredentialsOIDC) GetProvider(provider, subject string) (k int, found bool) {
for k, p := range c.Providers {
if p.Subject == subject && p.Provider == provider {
return k, true
}
}
return -1, false
}
23 changes: 23 additions & 0 deletions identity/credentials_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,26 @@ func TestNewCredentialsOIDC(t *testing.T) {
_, err = NewCredentialsOIDC("", "", "", "not-empty", "not-empty")
require.NoError(t, err)
}

func TestGetProvider(t *testing.T) {
c := CredentialsOIDC{
Providers: []CredentialsOIDCProvider{
{
Subject: "user-a",
Provider: "google",
},
{
Subject: "user-a",
Provider: "github",
},
},
}

k, found := c.GetProvider("github", "user-a")
require.True(t, found)
require.Equal(t, 1, k)

k, found = c.GetProvider("not-found", "user-a")
require.False(t, found)
require.Equal(t, -1, k)
}
6 changes: 6 additions & 0 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,19 @@ func TestHandler(t *testing.T) {
InitialAccessToken: transform(accessToken + "0"),
InitialRefreshToken: transform(refreshToken + "0"),
InitialIDToken: transform(idToken + "0"),
CurrentAccessToken: transform(accessToken + "current-0"),
CurrentRefreshToken: transform(refreshToken + "current-0"),
CurrentIDToken: transform(idToken + "current-0"),
},
{
Subject: "baz",
Provider: "zab",
InitialAccessToken: transform(accessToken + "1"),
InitialRefreshToken: transform(refreshToken + "1"),
InitialIDToken: transform(idToken + "1"),
CurrentAccessToken: transform(accessToken + "current-1"),
CurrentRefreshToken: transform(refreshToken + "current-1"),
CurrentIDToken: transform(idToken + "current-1"),
},
}}),
},
Expand Down
2 changes: 1 addition & 1 deletion identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ func (i *Identity) WithDeclassifiedCredentialsOIDC(ctx context.Context, c cipher
toPublish := original
toPublish.Config = []byte{}

for _, token := range []string{"initial_id_token", "initial_access_token", "initial_refresh_token"} {
for _, token := range []string{"initial_id_token", "initial_access_token", "initial_refresh_token", "current_id_token", "current_access_token", "current_refresh_token"} {
var i int
var err error
gjson.GetBytes(original.Config, "providers").ForEach(func(_, v gjson.Result) bool {
Expand Down
27 changes: 27 additions & 0 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,33 @@ func (m *Manager) Update(ctx context.Context, updated *Identity, opts ...Manager
return m.r.PrivilegedIdentityPool().UpdateIdentity(ctx, updated)
}

func (m *Manager) UpdateCredentials(ctx context.Context, id uuid.UUID, ct CredentialsType, cb func(*Credentials) error, opts ...ManagerOption) (err error) {
ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.Update")
defer otelx.End(span, &err)

updated, err := m.r.PrivilegedIdentityPool().GetIdentityConfidential(ctx, id)
if err != nil {
return err
}

c, ok := updated.GetCredentials(ct)
if !ok {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected credentials of type %s to exist but they did not.", ct))
}

if err := cb(c); err != nil {
return err
}

updated.SetCredentials(ct, *c)
o := newManagerOptions(opts)
if err := m.ValidateIdentity(ctx, updated, o); err != nil {
return err
}

return m.r.PrivilegedIdentityPool().UpdateIdentity(ctx, updated)
}

func (m *Manager) UpdateSchemaID(ctx context.Context, id uuid.UUID, schemaID string, opts ...ManagerOption) (err error) {
ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.UpdateSchemaID")
defer otelx.End(span, &err)
Expand Down
53 changes: 53 additions & 0 deletions identity/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import (
"testing"
"time"

"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

"github.com/ory/x/assertx"

"github.com/ory/x/sqlxx"

"github.com/ory/kratos/internal/testhelpers"
Expand Down Expand Up @@ -203,6 +208,54 @@ func TestManager(t *testing.T) {
assert.Equal(t, 1, count)
})

t.Run("method=UpdateCredentials", func(t *testing.T) {
original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
original.Traits = newTraits("[email protected]", "")
original.Credentials[identity.CredentialsTypePassword] = identity.Credentials{
Type: identity.CredentialsTypePassword,
Identifiers: []string{"[email protected]"},
Config: []byte(`{"hashed_password":"$argon2id$v=19$m=32,t=2,p=4$cm94YnRVOW5jZzFzcVE4bQ$MNzk5BtR2vUhrp6qQEjRNw"}`),
}
original.Credentials[identity.CredentialsTypeWebAuthn] = identity.Credentials{
Type: identity.CredentialsTypeWebAuthn,
Identifiers: []string{"foo"},
Config: []byte(`{"credentials":[{"is_passwordless":false}]}`),
Version: 1,
}
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))

t.Run("case=can not update non-existing credentials", func(t *testing.T) {
require.Error(t, reg.IdentityManager().UpdateCredentials(context.Background(), original.ID, identity.CredentialsTypeOIDC, func(c *identity.Credentials) error {
return nil
}))
})

t.Run("case=propagates error", func(t *testing.T) {
err := errors.New("foo")
require.ErrorIs(t, reg.IdentityManager().UpdateCredentials(context.Background(), original.ID, identity.CredentialsTypePassword, func(c *identity.Credentials) error {
return err
}), err)
})

t.Run("case=updates credentials", func(t *testing.T) {
require.NoError(t, reg.IdentityManager().UpdateCredentials(context.Background(), original.ID, identity.CredentialsTypePassword, func(c *identity.Credentials) (err error) {
c.Config, err = sjson.SetBytes(c.Config, "new_key", "new_value")
return nil
}))

fromStore, err := reg.PrivilegedIdentityPool().GetIdentityConfidential(context.Background(), original.ID)
require.NoError(t, err)

actual, ok := fromStore.GetCredentials(identity.CredentialsTypePassword)
require.True(t, ok)
assert.Equal(t, "new_value", gjson.GetBytes(actual.Config, "new_key").String())

actual, ok = fromStore.GetCredentials(identity.CredentialsTypeWebAuthn)
require.True(t, ok)
assertx.EqualAsJSONExcept(t, actual, original.Credentials[identity.CredentialsTypeWebAuthn], []string{"updated_at"}, "other credentials should not be changed")
})
})

t.Run("method=UpdateTraits", func(t *testing.T) {
t.Run("case=should update protected traits with option", func(t *testing.T) {
original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
Expand Down
6 changes: 2 additions & 4 deletions selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,6 @@ continueLogin:

var i *identity.Identity
var group node.UiNodeGroup
var s Strategy
for _, ss := range h.d.AllLoginStrategies() {
interim, err := ss.Login(w, r, f, sess.IdentityID)
group = ss.NodeGroup()
Expand All @@ -738,16 +737,15 @@ continueLogin:
method := ss.CompletedAuthenticationMethod(r.Context())
sess.CompletedLoginFor(method.Method, method.AAL)
i = interim
s = ss
break
}

if i == nil || s == nil {
if i == nil {
h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, errors.WithStack(schema.NewNoLoginStrategyResponsible()))
return
}

if err := h.d.LoginHookExecutor().PostLoginHook(w, r, s.ID(), group, f, i, sess); err != nil {
if err := h.d.LoginHookExecutor().PostLoginHook(w, r, group, f, i, sess); err != nil {
if errors.Is(err, ErrAddressNotVerified) {
h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, errors.WithStack(schema.NewAddressNotVerifiedError()))
return
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func TestFlowLifecycle(t *testing.T) {
assert.NotEqual(t, gjson.Get(a, "session_token").String(), gjson.Get(b, "session_token").String())

assert.NotEmpty(t, gjson.Get(b, "session.id").String())
assert.NotEqual(t, gjson.Get(b, "session.id").String(), gjson.Get(a, "session.id").String())
assert.NotEqual(t, gjson.Get(b, "session.id").String(), gjson.Get(a, "id").String())
})
})

Expand Down
14 changes: 2 additions & 12 deletions selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ type (
type (
executorDependencies interface {
config.Provider
identity.ManagementProvider
hydra.HydraProvider
identity.ManagementProvider
session.ManagementProvider
session.PersistenceProvider
x.CSRFTokenGeneratorProvider
Expand Down Expand Up @@ -114,11 +112,12 @@ func (e *HookExecutor) handleLoginError(_ http.ResponseWriter, r *http.Request,
return flowError
}

func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, ct identity.CredentialsType, g node.UiNodeGroup, a *Flow, i *identity.Identity, s *session.Session) (err error) {
func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, i *identity.Identity, s *session.Session) (err error) {
ctx := r.Context()
ctx, span := e.d.Tracer(ctx).Tracer().Start(ctx, "HookExecutor.PostLoginHook")
r = r.WithContext(ctx)
defer otelx.End(span, &err)

if err := s.Activate(r, i, e.d.Config(), time.Now().UTC()); err != nil {
return err
}
Expand Down Expand Up @@ -199,15 +198,6 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, ct
return nil
}

if ct == identity.CredentialsTypeOIDC {
options := []identity.ManagerOption{
identity.ManagerExposeValidationErrorsForInternalTypeAssertion,
identity.ManagerAllowWriteProtectedTraits,
}
if err := e.d.IdentityManager().Update(r.Context(), i, options...); err != nil {
return errors.WithStack(err)
}
}
if err := e.d.SessionManager().UpsertAndIssueCookie(r.Context(), w, r, s); err != nil {
return errors.WithStack(err)
}
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestLoginExecutor(t *testing.T) {
}

testhelpers.SelfServiceHookLoginErrorHandler(t, w, r,
reg.LoginHookExecutor().PostLoginHook(w, r, identity.CredentialsType(strategy), strategy.ToUiNodeGroup(), a, useIdentity, sess))
reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), a, useIdentity, sess))
})

ts := httptest.NewServer(router)
Expand Down
Loading

0 comments on commit a18f91d

Please sign in to comment.