Skip to content

Commit

Permalink
feat: update token on login
Browse files Browse the repository at this point in the history
  • Loading branch information
david972 authored and aeneasr committed May 27, 2022
1 parent 37cb4ce commit fb6d6ca
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 10 deletions.
6 changes: 4 additions & 2 deletions selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ continueLogin:
}

var i *identity.Identity
var s Strategy
for _, ss := range h.d.AllLoginStrategies() {
interim, err := ss.Login(w, r, f, sess)
if errors.Is(err, flow.ErrStrategyNotResponsible) {
Expand All @@ -608,15 +609,16 @@ continueLogin:
method := ss.CompletedAuthenticationMethod(r.Context())
sess.CompletedLoginFor(method.Method, method.AAL)
i = interim
s = ss
break
}

if i == nil {
if i == nil || s == nil {
h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, errors.WithStack(schema.NewNoLoginStrategyResponsible()))
return
}

if err := h.d.LoginHookExecutor().PostLoginHook(w, r, f, i, sess); err != nil {
if err := h.d.LoginHookExecutor().PostLoginHook(w, r, s.ID(), f, i, sess); err != nil {
if errors.Is(err, ErrAddressNotVerified) {
h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, errors.WithStack(schema.NewAddressNotVerifiedError()))
return
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func TestFlowLifecycle(t *testing.T) {
assert.NotEqual(t, gjson.Get(a, "session_token").String(), gjson.Get(b, "session_token").String())

assert.NotEmpty(t, gjson.Get(b, "session.id").String())
assert.NotEqual(t, gjson.Get(b, "session.id").String(), gjson.Get(a, "id").String())
assert.NotEqual(t, gjson.Get(b, "session.id").String(), gjson.Get(a, "session.id").String())
})
})
})
Expand Down
12 changes: 11 additions & 1 deletion selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type (
type (
executorDependencies interface {
config.Provider
identity.ManagementProvider
session.ManagementProvider
session.PersistenceProvider
x.WriterProvider
Expand Down Expand Up @@ -74,7 +75,7 @@ func (e *HookExecutor) requiresAAL2(r *http.Request, s *session.Session, a *Flow
return aalErr, true
}

func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, a *Flow, i *identity.Identity, s *session.Session) error {
func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, ct identity.CredentialsType, a *Flow, i *identity.Identity, s *session.Session) error {
if err := s.Activate(i, e.d.Config(r.Context()), time.Now().UTC()); err != nil {
return err
}
Expand Down Expand Up @@ -144,6 +145,15 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, a *
return nil
}

if ct == identity.CredentialsTypeOIDC {
options := []identity.ManagerOption{
identity.ManagerExposeValidationErrorsForInternalTypeAssertion,
identity.ManagerAllowWriteProtectedTraits,
}
if err := e.d.IdentityManager().Update(r.Context(), i, options...); err != nil {
return errors.WithStack(err)
}
}
if err := e.d.SessionManager().UpsertAndIssueCookie(r.Context(), w, r, s); err != nil {
return errors.WithStack(err)
}
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestLoginExecutor(t *testing.T) {
}

testhelpers.SelfServiceHookLoginErrorHandler(t, w, r,
reg.LoginHookExecutor().PostLoginHook(w, r, a, useIdentity, sess))
reg.LoginHookExecutor().PostLoginHook(w, r, identity.CredentialsType(strategy), a, useIdentity, sess))
})

ts := httptest.NewServer(router)
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/strategy_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ func newHydra(t *testing.T, subject, website *string, scope *[]string) (remoteAd
fmt.Sprintf("URLS_SELF_ISSUER=http://127.0.0.1:%d/", publicPort),
"URLS_LOGIN=" + hydraIntegrationTSURL + "/login",
"URLS_CONSENT=" + hydraIntegrationTSURL + "/consent",
"TTL_ACCESS_TOKEN=1s",
},
Cmd: []string{"serve", "all", "--dangerous-force-http"},
ExposedPorts: []string{"4444/tcp", "4445/tcp"},
Expand All @@ -222,7 +223,7 @@ func newHydra(t *testing.T, subject, website *string, scope *[]string) (remoteAd
t.Cleanup(func() {
require.NoError(t, hydra.Close())
})
require.NoError(t, hydra.Expire(uint(60*5)))
require.NoError(t, hydra.Expire(uint(60*10)))

require.NotEmpty(t, hydra.GetPort("4444/tcp"), "%+v", hydra.Container.NetworkSettings.Ports)
require.NotEmpty(t, hydra.GetPort("4445/tcp"), "%+v", hydra.Container)
Expand Down
42 changes: 38 additions & 4 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,45 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login
return nil, s.handleError(w, r, a, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("The password credentials could not be decoded properly").WithDebug(err.Error())))
}

// UPDATE TOKEN
var it string
if idToken, ok := token.Extra("id_token").(string); ok {
if it, err = s.d.Cipher().Encrypt(r.Context(), []byte(idToken)); err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
}
}

cat, err := s.d.Cipher().Encrypt(r.Context(), []byte(token.AccessToken))
if err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
}

crt, err := s.d.Cipher().Encrypt(r.Context(), []byte(token.RefreshToken))
if err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
}

sess := session.NewInactiveSession()
sess.CompletedLoginFor(s.ID(), identity.AuthenticatorAssuranceLevel1)
for _, c := range o.Providers {
if c.Subject == claims.Subject && c.Provider == provider.Config().ID {
if err = s.d.LoginHookExecutor().PostLoginHook(w, r, a, i, sess); err != nil {
for k := range o.Providers {
p := &o.Providers[k]
if p.Subject == claims.Subject && p.Provider == provider.Config().ID {
i, err := s.d.PrivilegedIdentityPool().GetIdentityConfidential(r.Context(), i.ID)
if err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
}

p.InitialIDToken = it
p.InitialAccessToken = cat
p.InitialRefreshToken = crt

c.Config, err = json.Marshal(o)
if err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
}

i.Credentials[s.ID()] = *c
if err = s.d.LoginHookExecutor().PostLoginHook(w, r, identity.CredentialsTypeOIDC, a, i, sess); err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
}
return nil, nil
Expand All @@ -123,7 +157,7 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login
return nil, s.handleError(w, r, a, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to find matching OpenID Connect Credentials.").WithDebugf(`Unable to find credentials that match the given provider "%s" and subject "%s".`, provider.Config().ID, claims.Subject)))
}

func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, ss *session.Session) (i *identity.Identity, err error) {
func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ *session.Session) (i *identity.Identity, err error) {
if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel1); err != nil {
return nil, err
}
Expand Down
17 changes: 17 additions & 0 deletions selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,21 @@ func TestStrategy(t *testing.T) {
)
}

getAccessToken := func(t *testing.T, provider string, body []byte) string {
i, err := reg.PrivilegedIdentityPool().GetIdentityConfidential(context.Background(), uuid.FromStringOrNil(gjson.GetBytes(body, "identity.id").String()))
require.NoError(t, err)
c := i.Credentials[identity.CredentialsTypeOIDC].Config
return gjson.GetBytes(c, "providers.0.initial_access_token").String()
}

var registrationAccessToken string
t.Run("case=should pass registration", func(t *testing.T) {
r := newRegistrationFlow(t, returnTS.URL, time.Minute)
action := afv(t, r.ID, "valid")
res, body := makeRequest(t, "valid", action, url.Values{})
ai(t, res, body)
expectTokens(t, "valid", body)
registrationAccessToken = getAccessToken(t, "valid", body)
})

t.Run("case=should pass login", func(t *testing.T) {
Expand All @@ -350,6 +359,14 @@ func TestStrategy(t *testing.T) {
ai(t, res, body)
expectTokens(t, "valid", body)
})

t.Run("case=token from login should not be the same", func(t *testing.T) {
r := newLoginFlow(t, returnTS.URL, time.Minute)
action := afv(t, r.ID, "valid")
res, body := makeRequest(t, "valid", action, url.Values{})
ai(t, res, body)
assert.NotEqual(t, getAccessToken(t, "valid", body), registrationAccessToken)
})
})

t.Run("case=login without registered account", func(t *testing.T) {
Expand Down

0 comments on commit fb6d6ca

Please sign in to comment.