From 90747f534553fa98554e056cf4433db19f122418 Mon Sep 17 00:00:00 2001 From: Tom Fenech Date: Tue, 14 May 2024 16:42:21 +0200 Subject: [PATCH] Pass OIDC claims into post-login flow to include in web hook context The login flow doesn't trigger a refresh of the identity when the OIDC claims have changed. By passing the claims through to the web hook context, this means that an external handler can be configured to update the identity as appropriate, when there are changes. --- selfservice/flow/login/handler.go | 2 +- selfservice/flow/login/hook.go | 16 ++-- selfservice/flow/login/hook_test.go | 2 +- selfservice/hook/address_verifier.go | 3 +- selfservice/hook/address_verifier_test.go | 2 +- selfservice/hook/error.go | 3 +- selfservice/hook/session_destroyer.go | 11 ++- selfservice/hook/session_destroyer_test.go | 1 + selfservice/hook/show_verification_ui.go | 3 +- selfservice/hook/show_verification_ui_test.go | 8 +- selfservice/hook/stub/test_body.jsonnet | 10 ++- selfservice/hook/verification.go | 3 +- selfservice/hook/verification_test.go | 4 +- selfservice/hook/web_hook.go | 5 +- selfservice/hook/web_hook_integration_test.go | 32 +++++--- selfservice/strategy/oidc/claims/claims.go | 49 ++++++++++++ .../strategy/oidc/claims/claims_test.go | 18 +++++ selfservice/strategy/oidc/claims/locale.go | 29 +++++++ selfservice/strategy/oidc/provider.go | 79 ++----------------- selfservice/strategy/oidc/provider_apple.go | 10 ++- .../strategy/oidc/provider_apple_test.go | 11 +-- selfservice/strategy/oidc/provider_auth0.go | 5 +- .../strategy/oidc/provider_dingtalk.go | 7 +- selfservice/strategy/oidc/provider_discord.go | 7 +- .../strategy/oidc/provider_facebook.go | 5 +- .../strategy/oidc/provider_generic_oidc.go | 13 +-- selfservice/strategy/oidc/provider_github.go | 5 +- .../strategy/oidc/provider_github_app.go | 5 +- selfservice/strategy/oidc/provider_gitlab.go | 5 +- selfservice/strategy/oidc/provider_google.go | 5 +- selfservice/strategy/oidc/provider_lark.go | 7 +- .../strategy/oidc/provider_linkedin.go | 5 +- .../strategy/oidc/provider_linkedin_test.go | 5 +- .../strategy/oidc/provider_microsoft.go | 5 +- selfservice/strategy/oidc/provider_netid.go | 11 +-- selfservice/strategy/oidc/provider_patreon.go | 5 +- selfservice/strategy/oidc/provider_slack.go | 5 +- selfservice/strategy/oidc/provider_spotify.go | 5 +- selfservice/strategy/oidc/provider_test.go | 17 ++-- .../strategy/oidc/provider_userinfo_test.go | 17 ++-- selfservice/strategy/oidc/provider_vk.go | 5 +- selfservice/strategy/oidc/provider_x.go | 20 +++-- selfservice/strategy/oidc/provider_yandex.go | 5 +- selfservice/strategy/oidc/strategy.go | 5 +- .../strategy/oidc/strategy_helper_test.go | 4 +- selfservice/strategy/oidc/strategy_login.go | 5 +- .../strategy/oidc/strategy_registration.go | 7 +- .../strategy/oidc/strategy_settings.go | 3 +- selfservice/strategy/oidc/token_verifier.go | 6 +- 49 files changed, 290 insertions(+), 210 deletions(-) create mode 100644 selfservice/strategy/oidc/claims/claims.go create mode 100644 selfservice/strategy/oidc/claims/claims_test.go create mode 100644 selfservice/strategy/oidc/claims/locale.go diff --git a/selfservice/flow/login/handler.go b/selfservice/flow/login/handler.go index 88b3712602a0..ff9c1517dc1f 100644 --- a/selfservice/flow/login/handler.go +++ b/selfservice/flow/login/handler.go @@ -822,7 +822,7 @@ continueLogin: return } - if err := h.d.LoginHookExecutor().PostLoginHook(w, r, group, f, i, sess, ""); err != nil { + if err := h.d.LoginHookExecutor().PostLoginHook(w, r, group, f, i, sess, nil, ""); err != nil { if errors.Is(err, ErrAddressNotVerified) { h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, errors.WithStack(schema.NewAddressNotVerifiedError())) return diff --git a/selfservice/flow/login/hook.go b/selfservice/flow/login/hook.go index f0e06ccfc934..3b98097e66d5 100644 --- a/selfservice/flow/login/hook.go +++ b/selfservice/flow/login/hook.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/schema" "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/sessiontokenexchange" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/ui/container" "github.com/ory/kratos/ui/node" @@ -34,7 +35,7 @@ type ( } PostHookExecutor interface { - ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, s *session.Session) error + ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, s *session.Session, c *claims.Claims) error } HooksProvider interface { @@ -125,6 +126,7 @@ func (e *HookExecutor) PostLoginHook( f *Flow, i *identity.Identity, s *session.Session, + c *claims.Claims, provider string, ) (err error) { ctx := r.Context() @@ -140,15 +142,15 @@ func (e *HookExecutor) PostLoginHook( return err } - c := e.d.Config() + cfg := e.d.Config() // Verify the redirect URL before we do any other processing. returnTo, err := x.SecureRedirectTo(r, - c.SelfServiceBrowserDefaultReturnTo(r.Context()), + cfg.SelfServiceBrowserDefaultReturnTo(r.Context()), x.SecureRedirectReturnTo(f.ReturnTo), x.SecureRedirectUseSourceURL(f.RequestURL), - x.SecureRedirectAllowURLs(c.SelfServiceBrowserAllowedReturnToDomains(r.Context())), - x.SecureRedirectAllowSelfServiceURLs(c.SelfPublicURL(r.Context())), - x.SecureRedirectOverrideDefaultReturnTo(c.SelfServiceFlowLoginReturnTo(r.Context(), f.Active.String())), + x.SecureRedirectAllowURLs(cfg.SelfServiceBrowserAllowedReturnToDomains(r.Context())), + x.SecureRedirectAllowSelfServiceURLs(cfg.SelfPublicURL(r.Context())), + x.SecureRedirectOverrideDefaultReturnTo(cfg.SelfServiceFlowLoginReturnTo(r.Context(), f.Active.String())), ) if err != nil { return err @@ -168,7 +170,7 @@ func (e *HookExecutor) PostLoginHook( WithField("flow_method", f.Active). Debug("Running ExecuteLoginPostHook.") for k, executor := range e.d.PostLoginHooks(r.Context(), f.Active) { - if err := executor.ExecuteLoginPostHook(w, r, g, f, s); err != nil { + if err := executor.ExecuteLoginPostHook(w, r, g, f, s, c); err != nil { if errors.Is(err, ErrHookAbortFlow) { e.d.Logger(). WithRequest(r). diff --git a/selfservice/flow/login/hook_test.go b/selfservice/flow/login/hook_test.go index fe73f22d7eef..bcaa8bb58b2f 100644 --- a/selfservice/flow/login/hook_test.go +++ b/selfservice/flow/login/hook_test.go @@ -72,7 +72,7 @@ func TestLoginExecutor(t *testing.T) { } testhelpers.SelfServiceHookLoginErrorHandler(t, w, r, - reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, "")) + reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, nil, "")) }) ts := httptest.NewServer(router) diff --git a/selfservice/hook/address_verifier.go b/selfservice/hook/address_verifier.go index b28ca6d9b7e4..7b713380a6dc 100644 --- a/selfservice/hook/address_verifier.go +++ b/selfservice/hook/address_verifier.go @@ -14,6 +14,7 @@ import ( "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow/login" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" ) @@ -25,7 +26,7 @@ func NewAddressVerifier() *AddressVerifier { return &AddressVerifier{} } -func (e *AddressVerifier) ExecuteLoginPostHook(_ http.ResponseWriter, _ *http.Request, _ node.UiNodeGroup, f *login.Flow, s *session.Session) error { +func (e *AddressVerifier) ExecuteLoginPostHook(_ http.ResponseWriter, _ *http.Request, _ node.UiNodeGroup, f *login.Flow, s *session.Session, _ *claims.Claims) error { // if the login happens using the password method, there must be at least one verified address if f.Active != identity.CredentialsTypePassword { return nil diff --git a/selfservice/hook/address_verifier_test.go b/selfservice/hook/address_verifier_test.go index fa80c3632644..7f6532507f64 100644 --- a/selfservice/hook/address_verifier_test.go +++ b/selfservice/hook/address_verifier_test.go @@ -82,7 +82,7 @@ func TestAddressVerifier(t *testing.T) { Identity: &identity.Identity{ID: x.NewUUID(), VerifiableAddresses: uc.verifiableAddresses}, } - err := verifier.ExecuteLoginPostHook(nil, nil, node.DefaultGroup, tc.flow, sessions) + err := verifier.ExecuteLoginPostHook(nil, nil, node.DefaultGroup, tc.flow, sessions, nil) if tc.neverError || uc.expectedError == nil { assert.NoError(t, err) diff --git a/selfservice/hook/error.go b/selfservice/hook/error.go index bb396578e5f8..a82d9e3511fe 100644 --- a/selfservice/hook/error.go +++ b/selfservice/hook/error.go @@ -12,6 +12,7 @@ import ( "github.com/ory/kratos/selfservice/flow/recovery" "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/ui/node" "github.com/ory/kratos/identity" @@ -64,7 +65,7 @@ func (e Error) ExecuteSettingsPostPersistHook(w http.ResponseWriter, r *http.Req return e.err("ExecuteSettingsPostPersistHook", settings.ErrHookAbortFlow) } -func (e Error) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *login.Flow, s *session.Session) error { +func (e Error) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *login.Flow, s *session.Session, c *claims.Claims) error { return e.err("ExecuteLoginPostHook", login.ErrHookAbortFlow) } diff --git a/selfservice/hook/session_destroyer.go b/selfservice/hook/session_destroyer.go index 16f7aa11b435..0e19b2d85d9a 100644 --- a/selfservice/hook/session_destroyer.go +++ b/selfservice/hook/session_destroyer.go @@ -11,14 +11,17 @@ import ( "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/recovery" "github.com/ory/kratos/selfservice/flow/settings" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/ui/node" "github.com/ory/x/otelx" ) -var _ login.PostHookExecutor = new(SessionDestroyer) -var _ recovery.PostHookExecutor = new(SessionDestroyer) -var _ settings.PostHookPostPersistExecutor = new(SessionDestroyer) +var ( + _ login.PostHookExecutor = new(SessionDestroyer) + _ recovery.PostHookExecutor = new(SessionDestroyer) + _ settings.PostHookPostPersistExecutor = new(SessionDestroyer) +) type ( sessionDestroyerDependencies interface { @@ -34,7 +37,7 @@ func NewSessionDestroyer(r sessionDestroyerDependencies) *SessionDestroyer { return &SessionDestroyer{r: r} } -func (e *SessionDestroyer) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, _ *login.Flow, s *session.Session) error { +func (e *SessionDestroyer) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, _ *login.Flow, s *session.Session, _ *claims.Claims) error { return otelx.WithSpan(r.Context(), "selfservice.hook.SessionDestroyer.ExecuteLoginPostHook", func(ctx context.Context) error { if _, err := e.r.SessionPersister().RevokeSessionsIdentityExcept(ctx, s.Identity.ID, s.ID); err != nil { return err diff --git a/selfservice/hook/session_destroyer_test.go b/selfservice/hook/session_destroyer_test.go index e2d0cc21c2e2..833dd2968a61 100644 --- a/selfservice/hook/session_destroyer_test.go +++ b/selfservice/hook/session_destroyer_test.go @@ -52,6 +52,7 @@ func TestSessionDestroyer(t *testing.T) { node.DefaultGroup, nil, &session.Session{Identity: i}, + nil, ) }, }, diff --git a/selfservice/hook/show_verification_ui.go b/selfservice/hook/show_verification_ui.go index 65a5935ec7a6..c1893485fe7b 100644 --- a/selfservice/hook/show_verification_ui.go +++ b/selfservice/hook/show_verification_ui.go @@ -11,6 +11,7 @@ import ( "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/ui/node" "github.com/ory/kratos/x" @@ -52,7 +53,7 @@ func (e *ShowVerificationUIHook) ExecutePostRegistrationPostPersistHook(_ http.R // ExecuteLoginPostHook adds redirect headers and status code if the request is a browser request. // If the request is not a browser request, this hook does nothing. -func (e *ShowVerificationUIHook) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, f *login.Flow, _ *session.Session) error { +func (e *ShowVerificationUIHook) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, f *login.Flow, _ *session.Session, _ *claims.Claims) error { return otelx.WithSpan(r.Context(), "selfservice.hook.ShowVerificationUIHook.ExecutePostRegistrationPostPersistHook", func(ctx context.Context) error { return e.execute(r.WithContext(ctx), f) }) diff --git a/selfservice/hook/show_verification_ui_test.go b/selfservice/hook/show_verification_ui_test.go index 22171f0c345b..824eefe0bbf3 100644 --- a/selfservice/hook/show_verification_ui_test.go +++ b/selfservice/hook/show_verification_ui_test.go @@ -84,7 +84,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { browserRequest := httptest.NewRequest("GET", "/", nil) f := &login.Flow{} rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil, nil)) require.Equal(t, 200, rec.Code) }) @@ -95,7 +95,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { browserRequest.Header.Add("Accept", "application/json") f := &login.Flow{} rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil, nil)) require.Equal(t, 200, rec.Code) }) @@ -112,7 +112,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { flow.NewContinueWithVerificationUI(vf, "some@ory.sh", ""), } rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil, nil)) assert.Equal(t, 200, rec.Code) assert.Equal(t, "/verification?flow="+vf.ID.String(), rf.ReturnToVerification) }) @@ -127,7 +127,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { flow.NewContinueWithSetToken("token"), } rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil, nil)) assert.Equal(t, 200, rec.Code) }) }) diff --git a/selfservice/hook/stub/test_body.jsonnet b/selfservice/hook/stub/test_body.jsonnet index 117ecc587707..f406ad51076e 100644 --- a/selfservice/hook/stub/test_body.jsonnet +++ b/selfservice/hook/stub/test_body.jsonnet @@ -1,10 +1,14 @@ function(ctx) std.prune({ flow_id: ctx.flow.id, - identity_id: if std.objectHas(ctx, "identity") then ctx.identity.id, - session_id: if std.objectHas(ctx, "session") then ctx.session.id, + identity_id: if std.objectHas(ctx, 'identity') then ctx.identity.id, + session_id: if std.objectHas(ctx, 'session') then ctx.session.id, headers: ctx.request_headers, url: ctx.request_url, method: ctx.request_method, cookies: ctx.request_cookies, - transient_payload: if std.objectHas(ctx.flow, "transient_payload") then ctx.flow.transient_payload, + transient_payload: if std.objectHas(ctx.flow, 'transient_payload') then ctx.flow.transient_payload, + nickname: if std.objectHas(ctx, 'claims') then ctx.claims.nickname, + groups: if std.objectHas(ctx, 'claims') && + std.objectHas(ctx.claims, 'raw_claims') && + std.objectHas(ctx.claims.raw_claims, 'groups') then ctx.claims.raw_claims.groups, }) diff --git a/selfservice/hook/verification.go b/selfservice/hook/verification.go index 6fdd039c146f..bd722f5edfb0 100644 --- a/selfservice/hook/verification.go +++ b/selfservice/hook/verification.go @@ -16,6 +16,7 @@ import ( "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/ui/node" "github.com/ory/kratos/x" @@ -65,7 +66,7 @@ func (e *Verifier) ExecuteSettingsPostPersistHook(w http.ResponseWriter, r *http }) } -func (e *Verifier) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, f *login.Flow, s *session.Session) (err error) { +func (e *Verifier) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, f *login.Flow, s *session.Session, c *claims.Claims) (err error) { ctx, span := e.r.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.hook.Verifier.ExecuteLoginPostHook") r = r.WithContext(ctx) defer otelx.End(span, &err) diff --git a/selfservice/hook/verification_test.go b/selfservice/hook/verification_test.go index 5e815fa4d4a4..9facac1653f5 100644 --- a/selfservice/hook/verification_test.go +++ b/selfservice/hook/verification_test.go @@ -45,7 +45,7 @@ func TestVerifier(t *testing.T) { name: "login", execHook: func(h *hook.Verifier, i *identity.Identity, f flow.Flow) error { return h.ExecuteLoginPostHook( - httptest.NewRecorder(), u, node.CodeGroup, f.(*login.Flow), &session.Session{ID: x.NewUUID(), Identity: i}) + httptest.NewRecorder(), u, node.CodeGroup, f.(*login.Flow), &session.Session{ID: x.NewUUID(), Identity: i}, nil) }, originalFlow: func() flow.FlowWithContinueWith { return &login.Flow{RequestURL: "http://foo.com/login", RequestedAAL: "aal1"} @@ -126,7 +126,7 @@ func TestVerifier(t *testing.T) { h := hook.NewVerifier(reg) i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) f := &login.Flow{RequestedAAL: "aal2"} - require.NoError(t, h.ExecuteLoginPostHook(httptest.NewRecorder(), u, node.CodeGroup, f, &session.Session{ID: x.NewUUID(), Identity: i})) + require.NoError(t, h.ExecuteLoginPostHook(httptest.NewRecorder(), u, node.CodeGroup, f, &session.Session{ID: x.NewUUID(), Identity: i}, nil)) messages, err := reg.CourierPersister().NextMessages(context.Background(), 12) require.EqualError(t, err, "queue is empty") diff --git a/selfservice/hook/web_hook.go b/selfservice/hook/web_hook.go index d4c8131e50d5..17141f27336a 100644 --- a/selfservice/hook/web_hook.go +++ b/selfservice/hook/web_hook.go @@ -32,6 +32,7 @@ import ( "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" @@ -82,6 +83,7 @@ type ( RequestCookies map[string]string `json:"request_cookies"` Identity *identity.Identity `json:"identity,omitempty"` Session *session.Session `json:"session,omitempty"` + Claims *claims.Claims `json:"claims,omitempty"` } WebHook struct { @@ -132,7 +134,7 @@ func (e *WebHook) ExecuteLoginPreHook(_ http.ResponseWriter, req *http.Request, }) } -func (e *WebHook) ExecuteLoginPostHook(_ http.ResponseWriter, req *http.Request, _ node.UiNodeGroup, flow *login.Flow, session *session.Session) error { +func (e *WebHook) ExecuteLoginPostHook(_ http.ResponseWriter, req *http.Request, _ node.UiNodeGroup, flow *login.Flow, session *session.Session, claims *claims.Claims) error { return otelx.WithSpan(req.Context(), "selfservice.hook.WebHook.ExecuteLoginPostHook", func(ctx context.Context) error { return e.execute(ctx, &templateContext{ Flow: flow, @@ -142,6 +144,7 @@ func (e *WebHook) ExecuteLoginPostHook(_ http.ResponseWriter, req *http.Request, RequestCookies: cookies(req), Identity: session.Identity, Session: session, + Claims: claims, }) }) } diff --git a/selfservice/hook/web_hook_integration_test.go b/selfservice/hook/web_hook_integration_test.go index 59159f49b6cf..562195777087 100644 --- a/selfservice/hook/web_hook_integration_test.go +++ b/selfservice/hook/web_hook_integration_test.go @@ -38,6 +38,7 @@ import ( "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/flow/verification" "github.com/ory/kratos/selfservice/hook" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" @@ -55,6 +56,13 @@ var transientPayload = json.RawMessage(`{ } }`) +var oidcClaims = claims.Claims{ + Nickname: "nicky", + RawClaims: map[string]interface{}{ + "groups": []string{"first", "second"}, + }, +} + func TestWebHooks(t *testing.T) { _, reg := internal.NewFastRegistryWithMocks(t) logger := logrusx.New("kratos", "test") @@ -148,7 +156,7 @@ func TestWebHooks(t *testing.T) { }`, f.GetID(), s.Identity.ID, string(h), req.Method, "http://www.ory.sh/some_end_point", string(tp)) } - bodyWithFlowAndIdentityAndSessionAndTransientPayload := func(req *http.Request, f flow.Flow, s *session.Session, tp json.RawMessage) string { + bodyWithFlowAndIdentityAndSessionAndClaimsAndTransientPayload := func(req *http.Request, f flow.Flow, s *session.Session, c *claims.Claims, tp json.RawMessage) string { h, _ := json.Marshal(req.Header) return fmt.Sprintf(`{ "flow_id": "%s", @@ -162,8 +170,10 @@ func TestWebHooks(t *testing.T) { "Some-Cookie-2": "Some-other-Cookie-Value", "Some-Cookie-3": "Third-Cookie-Value" }, - "transient_payload": %s - }`, f.GetID(), s.Identity.ID, s.ID, string(h), req.Method, "http://www.ory.sh/some_end_point", string(tp)) + "transient_payload": %s, + "nickname": "%s", + "groups": ["%s", "%s"] + }`, f.GetID(), s.Identity.ID, s.ID, string(h), req.Method, "http://www.ory.sh/some_end_point", string(tp), oidcClaims.Nickname, oidcClaims.RawClaims["groups"].([]string)[0], oidcClaims.RawClaims["groups"].([]string)[1]) } for _, tc := range []struct { @@ -186,10 +196,10 @@ func TestWebHooks(t *testing.T) { uc: "Post Login Hook", createFlow: func() flow.Flow { return &login.Flow{ID: x.NewUUID(), TransientPayload: transientPayload} }, callWebHook: func(wh *hook.WebHook, req *http.Request, f flow.Flow, s *session.Session) error { - return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s) + return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s, &oidcClaims) }, expectedBody: func(req *http.Request, f flow.Flow, s *session.Session) string { - return bodyWithFlowAndIdentityAndSessionAndTransientPayload(req, f, s, transientPayload) + return bodyWithFlowAndIdentityAndSessionAndClaimsAndTransientPayload(req, f, s, &oidcClaims, transientPayload) }, }, { @@ -428,7 +438,7 @@ func TestWebHooks(t *testing.T) { uc: "Post Login Hook - no block", createFlow: func() flow.Flow { return &login.Flow{ID: x.NewUUID()} }, callWebHook: func(wh *hook.WebHook, req *http.Request, f flow.Flow, s *session.Session) error { - return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s) + return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s, nil) }, webHookResponse: func() (int, []byte) { return http.StatusOK, []byte{} @@ -439,7 +449,7 @@ func TestWebHooks(t *testing.T) { uc: "Post Login Hook - block", createFlow: func() flow.Flow { return &login.Flow{ID: x.NewUUID()} }, callWebHook: func(wh *hook.WebHook, req *http.Request, f flow.Flow, s *session.Session) error { - return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s) + return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s, nil) }, webHookResponse: func() (int, []byte) { return http.StatusBadRequest, webHookResponse @@ -1004,7 +1014,7 @@ func TestDisallowPrivateIPRanges(t *testing.T) { "method": "GET", "body": "file://stub/test_body.jsonnet" }`)) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.Error(t, err) require.Contains(t, err.Error(), "is not a permitted destination") }) @@ -1016,7 +1026,7 @@ func TestDisallowPrivateIPRanges(t *testing.T) { "method": "GET", "body": "file://stub/test_body.jsonnet" }`)) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.Error(t, err, "the target does not exist and we still receive an error") require.NotContains(t, err.Error(), "is not a permitted destination", "but the error is not related to the IP range.") }) @@ -1037,7 +1047,7 @@ func TestDisallowPrivateIPRanges(t *testing.T) { "method": "GET", "body": "http://192.168.178.0/test_body.jsonnet" }`)) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.Error(t, err) require.Contains(t, err.Error(), "is not a permitted destination") }) @@ -1094,7 +1104,7 @@ func TestAsyncWebhook(t *testing.T) { "ignore": true } }`, webhookReceiver.URL))) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.NoError(t, err) // execution returns immediately for async webhook select { case <-time.After(1 * time.Second): diff --git a/selfservice/strategy/oidc/claims/claims.go b/selfservice/strategy/oidc/claims/claims.go new file mode 100644 index 000000000000..73f492652359 --- /dev/null +++ b/selfservice/strategy/oidc/claims/claims.go @@ -0,0 +1,49 @@ +package claims + +import ( + "github.com/pkg/errors" + + "github.com/ory/herodot" + "github.com/ory/kratos/x" +) + +// ConvertibleBoolean is used as Apple casually sends the email_verified field as a string. +type Claims struct { + Issuer string `json:"iss,omitempty"` + Subject string `json:"sub,omitempty"` + Name string `json:"name,omitempty"` + GivenName string `json:"given_name,omitempty"` + FamilyName string `json:"family_name,omitempty"` + LastName string `json:"last_name,omitempty"` + MiddleName string `json:"middle_name,omitempty"` + Nickname string `json:"nickname,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` + Profile string `json:"profile,omitempty"` + Picture string `json:"picture,omitempty"` + Website string `json:"website,omitempty"` + Email string `json:"email,omitempty"` + EmailVerified x.ConvertibleBoolean `json:"email_verified,omitempty"` + Gender string `json:"gender,omitempty"` + Birthdate string `json:"birthdate,omitempty"` + Zoneinfo string `json:"zoneinfo,omitempty"` + Locale Locale `json:"locale,omitempty"` + PhoneNumber string `json:"phone_number,omitempty"` + PhoneNumberVerified bool `json:"phone_number_verified,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + HD string `json:"hd,omitempty"` + Team string `json:"team,omitempty"` + Nonce string `json:"nonce,omitempty"` + NonceSupported bool `json:"nonce_supported,omitempty"` + RawClaims map[string]interface{} `json:"raw_claims,omitempty"` +} + +// Validate checks if the claims are valid. +func (c *Claims) Validate() error { + if c.Subject == "" { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("provider did not return a subject")) + } + if c.Issuer == "" { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("issuer not set in claims")) + } + return nil +} diff --git a/selfservice/strategy/oidc/claims/claims_test.go b/selfservice/strategy/oidc/claims/claims_test.go new file mode 100644 index 000000000000..47ada4a4695c --- /dev/null +++ b/selfservice/strategy/oidc/claims/claims_test.go @@ -0,0 +1,18 @@ +package claims_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/selfservice/strategy/oidc/claims" +) + +func TestClaimsValidate(t *testing.T) { + require.Error(t, new(claims.Claims).Validate()) + require.Error(t, (&claims.Claims{Issuer: "not-empty"}).Validate()) + require.Error(t, (&claims.Claims{Issuer: "not-empty"}).Validate()) + require.Error(t, (&claims.Claims{Subject: "not-empty"}).Validate()) + require.Error(t, (&claims.Claims{Subject: "not-empty"}).Validate()) + require.NoError(t, (&claims.Claims{Issuer: "not-empty", Subject: "not-empty"}).Validate()) +} diff --git a/selfservice/strategy/oidc/claims/locale.go b/selfservice/strategy/oidc/claims/locale.go new file mode 100644 index 000000000000..07fb8576becc --- /dev/null +++ b/selfservice/strategy/oidc/claims/locale.go @@ -0,0 +1,29 @@ +package claims + +import ( + "encoding/json" + "strings" +) + +type Locale string + +func (l *Locale) UnmarshalJSON(data []byte) error { + var linkedInLocale struct { + Language string `json:"language"` + Country string `json:"country"` + } + if err := json.Unmarshal(data, &linkedInLocale); err == nil { + switch { + case linkedInLocale.Language == "": + *l = Locale(linkedInLocale.Country) + case linkedInLocale.Country == "": + *l = Locale(linkedInLocale.Language) + default: + *l = Locale(strings.Join([]string{linkedInLocale.Language, linkedInLocale.Country}, "-")) + } + + return nil + } + + return json.Unmarshal(data, (*string)(l)) +} diff --git a/selfservice/strategy/oidc/provider.go b/selfservice/strategy/oidc/provider.go index 30ea305a22ed..c241d12f257c 100644 --- a/selfservice/strategy/oidc/provider.go +++ b/selfservice/strategy/oidc/provider.go @@ -5,19 +5,14 @@ package oidc import ( "context" - "encoding/json" "net/http" "net/url" - "strings" "github.com/dghubble/oauth1" - "github.com/pkg/errors" - "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "golang.org/x/oauth2" - - "github.com/ory/kratos/x" ) type Provider interface { @@ -28,14 +23,14 @@ type OAuth2Provider interface { Provider AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption OAuth2(ctx context.Context) (*oauth2.Config, error) - Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) + Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) } type OAuth1Provider interface { Provider OAuth1(ctx context.Context) *oauth1.Config AuthURL(ctx context.Context, state string) (string, error) - Claims(ctx context.Context, token *oauth1.Token) (*Claims, error) + Claims(ctx context.Context, token *oauth1.Token) (*claims.Claims, error) ExchangeToken(ctx context.Context, req *http.Request) (*oauth1.Token, error) } @@ -44,75 +39,11 @@ type OAuth2TokenExchanger interface { } type IDTokenVerifier interface { - Verify(ctx context.Context, rawIDToken string) (*Claims, error) + Verify(ctx context.Context, rawIDToken string) (*claims.Claims, error) } type NonceValidationSkipper interface { - CanSkipNonce(*Claims) bool -} - -// ConvertibleBoolean is used as Apple casually sends the email_verified field as a string. -type Claims struct { - Issuer string `json:"iss,omitempty"` - Subject string `json:"sub,omitempty"` - Name string `json:"name,omitempty"` - GivenName string `json:"given_name,omitempty"` - FamilyName string `json:"family_name,omitempty"` - LastName string `json:"last_name,omitempty"` - MiddleName string `json:"middle_name,omitempty"` - Nickname string `json:"nickname,omitempty"` - PreferredUsername string `json:"preferred_username,omitempty"` - Profile string `json:"profile,omitempty"` - Picture string `json:"picture,omitempty"` - Website string `json:"website,omitempty"` - Email string `json:"email,omitempty"` - EmailVerified x.ConvertibleBoolean `json:"email_verified,omitempty"` - Gender string `json:"gender,omitempty"` - Birthdate string `json:"birthdate,omitempty"` - Zoneinfo string `json:"zoneinfo,omitempty"` - Locale Locale `json:"locale,omitempty"` - PhoneNumber string `json:"phone_number,omitempty"` - PhoneNumberVerified bool `json:"phone_number_verified,omitempty"` - UpdatedAt int64 `json:"updated_at,omitempty"` - HD string `json:"hd,omitempty"` - Team string `json:"team,omitempty"` - Nonce string `json:"nonce,omitempty"` - NonceSupported bool `json:"nonce_supported,omitempty"` - RawClaims map[string]interface{} `json:"raw_claims,omitempty"` -} - -type Locale string - -func (l *Locale) UnmarshalJSON(data []byte) error { - var linkedInLocale struct { - Language string `json:"language"` - Country string `json:"country"` - } - if err := json.Unmarshal(data, &linkedInLocale); err == nil { - switch { - case linkedInLocale.Language == "": - *l = Locale(linkedInLocale.Country) - case linkedInLocale.Country == "": - *l = Locale(linkedInLocale.Language) - default: - *l = Locale(strings.Join([]string{linkedInLocale.Language, linkedInLocale.Country}, "-")) - } - - return nil - } - - return json.Unmarshal(data, (*string)(l)) -} - -// Validate checks if the claims are valid. -func (c *Claims) Validate() error { - if c.Subject == "" { - return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("provider did not return a subject")) - } - if c.Issuer == "" { - return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("issuer not set in claims")) - } - return nil + CanSkipNonce(*claims.Claims) bool } // UpstreamParameters returns a list of oauth2.AuthCodeOption based on the upstream parameters. diff --git a/selfservice/strategy/oidc/provider_apple.go b/selfservice/strategy/oidc/provider_apple.go index 706a7150c5e4..0a45411e7b27 100644 --- a/selfservice/strategy/oidc/provider_apple.go +++ b/selfservice/strategy/oidc/provider_apple.go @@ -15,6 +15,8 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt/v4" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" + "github.com/pkg/errors" "golang.org/x/oauth2" @@ -112,7 +114,7 @@ func (a *ProviderApple) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return options } -func (a *ProviderApple) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (a *ProviderApple) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { claims, err := a.ProviderGenericOIDC.Claims(ctx, exchange, query) if err != nil { return claims, err @@ -126,7 +128,7 @@ func (a *ProviderApple) Claims(ctx context.Context, exchange *oauth2.Token, quer // The info is sent as an extra query parameter to the redirect URL. // See https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/configuring_your_webpage_for_sign_in_with_apple#3331292 // Note that there's no way to make sure the info hasn't been tampered with. -func (a *ProviderApple) DecodeQuery(query url.Values, claims *Claims) { +func (a *ProviderApple) DecodeQuery(query url.Values, claims *claims.Claims) { var user struct { Name *struct { FirstName *string `json:"firstName"` @@ -156,7 +158,7 @@ var _ IDTokenVerifier = new(ProviderApple) const issuerUrlApple = "https://appleid.apple.com" -func (a *ProviderApple) Verify(ctx context.Context, rawIDToken string) (*Claims, error) { +func (a *ProviderApple) Verify(ctx context.Context, rawIDToken string) (*claims.Claims, error) { keySet := oidc.NewRemoteKeySet(ctx, a.JWKSUrl) ctx = oidc.ClientContext(ctx, a.reg.HTTPClient(ctx).HTTPClient) @@ -165,6 +167,6 @@ func (a *ProviderApple) Verify(ctx context.Context, rawIDToken string) (*Claims, var _ NonceValidationSkipper = new(ProviderApple) -func (a *ProviderApple) CanSkipNonce(c *Claims) bool { +func (a *ProviderApple) CanSkipNonce(c *claims.Claims) bool { return c.NonceSupported } diff --git a/selfservice/strategy/oidc/provider_apple_test.go b/selfservice/strategy/oidc/provider_apple_test.go index 422ae643708a..39c6c95462f4 100644 --- a/selfservice/strategy/oidc/provider_apple_test.go +++ b/selfservice/strategy/oidc/provider_apple_test.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/internal" "github.com/ory/kratos/selfservice/strategy/oidc" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" ) func TestDecodeQuery(t *testing.T) { @@ -28,15 +29,15 @@ func TestDecodeQuery(t *testing.T) { } for k, tc := range []struct { - claims *oidc.Claims + claims *claims.Claims familyName string givenName string lastName string }{ - {claims: &oidc.Claims{}, familyName: "last", givenName: "first", lastName: "last"}, - {claims: &oidc.Claims{FamilyName: "fam"}, familyName: "fam", givenName: "first", lastName: "last"}, - {claims: &oidc.Claims{FamilyName: "fam", GivenName: "giv"}, familyName: "fam", givenName: "giv", lastName: "last"}, - {claims: &oidc.Claims{FamilyName: "fam", GivenName: "giv", LastName: "las"}, familyName: "fam", givenName: "giv", lastName: "las"}, + {claims: &claims.Claims{}, familyName: "last", givenName: "first", lastName: "last"}, + {claims: &claims.Claims{FamilyName: "fam"}, familyName: "fam", givenName: "first", lastName: "last"}, + {claims: &claims.Claims{FamilyName: "fam", GivenName: "giv"}, familyName: "fam", givenName: "giv", lastName: "last"}, + {claims: &claims.Claims{FamilyName: "fam", GivenName: "giv", LastName: "las"}, familyName: "fam", givenName: "giv", lastName: "las"}, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { a := oidc.NewProviderApple(&oidc.Configuration{}, nil).(*oidc.ProviderApple) diff --git a/selfservice/strategy/oidc/provider_auth0.go b/selfservice/strategy/oidc/provider_auth0.go index a4c9ee46e1ab..479879980371 100644 --- a/selfservice/strategy/oidc/provider_auth0.go +++ b/selfservice/strategy/oidc/provider_auth0.go @@ -11,6 +11,7 @@ import ( "path" "time" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/stringsx" @@ -71,7 +72,7 @@ func (g *ProviderAuth0) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx) } -func (g *ProviderAuth0) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderAuth0) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -113,7 +114,7 @@ func (g *ProviderAuth0) Claims(ctx context.Context, exchange *oauth2.Token, quer } // Once we get here, we know that if there is an updated_at field in the json, it is the correct type. - var claims Claims + var claims claims.Claims if err := json.Unmarshal(b, &claims); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } diff --git a/selfservice/strategy/oidc/provider_dingtalk.go b/selfservice/strategy/oidc/provider_dingtalk.go index 12abffe85942..436dbe931aee 100644 --- a/selfservice/strategy/oidc/provider_dingtalk.go +++ b/selfservice/strategy/oidc/provider_dingtalk.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/hashicorp/go-retryablehttp" @@ -40,7 +41,7 @@ func (g *ProviderDingTalk) Config() *Configuration { } func (g *ProviderDingTalk) oauth2(ctx context.Context) *oauth2.Config { - var endpoint = oauth2.Endpoint{ + endpoint := oauth2.Endpoint{ AuthURL: "https://login.dingtalk.com/oauth2/auth", TokenURL: "https://api.dingtalk.com/v1.0/oauth2/userAccessToken", } @@ -122,7 +123,7 @@ func (g *ProviderDingTalk) ExchangeOAuth2Token(ctx context.Context, code string, return token, nil } -func (g *ProviderDingTalk) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*Claims, error) { +func (g *ProviderDingTalk) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*claims.Claims, error) { userInfoURL := "https://api.dingtalk.com/v1.0/contact/users/me" accessToken := exchange.AccessToken @@ -160,7 +161,7 @@ func (g *ProviderDingTalk) Claims(ctx context.Context, exchange *oauth2.Token, _ return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("userResp.ErrCode = %s, userResp.ErrMsg = %s", user.ErrCode, user.ErrMsg)) } - return &Claims{ + return &claims.Claims{ Issuer: userInfoURL, Subject: user.OpenId, Nickname: user.Nick, diff --git a/selfservice/strategy/oidc/provider_discord.go b/selfservice/strategy/oidc/provider_discord.go index 99bea24d5770..2c542d0aa8f8 100644 --- a/selfservice/strategy/oidc/provider_discord.go +++ b/selfservice/strategy/oidc/provider_discord.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/bwmarrin/discordgo" @@ -66,7 +67,7 @@ func (d *ProviderDiscord) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { } } -func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), " ") for _, check := range d.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -84,7 +85,7 @@ func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, qu return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Issuer: discordgo.EndpointOauth2, Subject: user.ID, Name: fmt.Sprintf("%s#%s", user.Username, user.Discriminator), @@ -93,7 +94,7 @@ func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, qu Picture: user.AvatarURL(""), Email: user.Email, EmailVerified: x.ConvertibleBoolean(user.Verified), - Locale: Locale(user.Locale), + Locale: claims.Locale(user.Locale), } return claims, nil diff --git a/selfservice/strategy/oidc/provider_facebook.go b/selfservice/strategy/oidc/provider_facebook.go index abf9806cce05..be67f20098b1 100644 --- a/selfservice/strategy/oidc/provider_facebook.go +++ b/selfservice/strategy/oidc/provider_facebook.go @@ -16,6 +16,7 @@ import ( "github.com/ory/x/httpx" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/pkg/errors" @@ -62,7 +63,7 @@ func (g *ProviderFacebook) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2ConfigFromEndpoint(ctx, endpoint), nil } -func (g *ProviderFacebook) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderFacebook) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -114,7 +115,7 @@ func (g *ProviderFacebook) Claims(ctx context.Context, exchange *oauth2.Token, q user.EmailVerified = true } - return &Claims{ + return &claims.Claims{ Issuer: u.String(), Subject: user.Id, Name: user.Name, diff --git a/selfservice/strategy/oidc/provider_generic_oidc.go b/selfservice/strategy/oidc/provider_generic_oidc.go index 146505165807..4e28ebcc4318 100644 --- a/selfservice/strategy/oidc/provider_generic_oidc.go +++ b/selfservice/strategy/oidc/provider_generic_oidc.go @@ -13,6 +13,7 @@ import ( gooidc "github.com/coreos/go-oidc/v3/oidc" "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/stringslice" ) @@ -96,13 +97,13 @@ func (g *ProviderGenericOIDC) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption return options } -func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Context, provider *gooidc.Provider, raw string) (*Claims, error) { +func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Context, provider *gooidc.Provider, raw string) (*claims.Claims, error) { token, err := provider.VerifierContext(g.withHTTPClientContext(ctx), &gooidc.Config{ClientID: g.config.ClientID}).Verify(ctx, raw) if err != nil { return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("%s", err)) } - var claims Claims + var claims claims.Claims if err := token.Claims(&claims); err != nil { return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("%s", err)) } @@ -116,7 +117,7 @@ func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Cont return &claims, nil } -func (g *ProviderGenericOIDC) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*Claims, error) { +func (g *ProviderGenericOIDC) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*claims.Claims, error) { switch g.config.ClaimsSource { case ClaimsSourceIDToken, "": return g.claimsFromIDToken(ctx, exchange) @@ -128,7 +129,7 @@ func (g *ProviderGenericOIDC) Claims(ctx context.Context, exchange *oauth2.Token WithReasonf("Unknown claims source: %q", g.config.ClaimsSource)) } -func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange *oauth2.Token) (*Claims, error) { +func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange *oauth2.Token) (*claims.Claims, error) { p, err := g.provider(ctx) if err != nil { return nil, err @@ -139,7 +140,7 @@ func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange * return nil, err } - var claims Claims + var claims claims.Claims if err = userInfo.Claims(&claims); err != nil { return nil, err } @@ -178,7 +179,7 @@ func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange * return &claims, nil } -func (g *ProviderGenericOIDC) claimsFromIDToken(ctx context.Context, exchange *oauth2.Token) (*Claims, error) { +func (g *ProviderGenericOIDC) claimsFromIDToken(ctx context.Context, exchange *oauth2.Token) (*claims.Claims, error) { p, raw, err := g.idTokenAndProvider(ctx, exchange) if err != nil { return nil, err diff --git a/selfservice/strategy/oidc/provider_github.go b/selfservice/strategy/oidc/provider_github.go index fe1d2bc371d1..08b00d19afa5 100644 --- a/selfservice/strategy/oidc/provider_github.go +++ b/selfservice/strategy/oidc/provider_github.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/pkg/errors" @@ -60,7 +61,7 @@ func (g *ProviderGitHub) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (g *ProviderGitHub) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderGitHub) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), ",") for _, check := range g.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -76,7 +77,7 @@ func (g *ProviderGitHub) Claims(ctx context.Context, exchange *oauth2.Token, que return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Subject: fmt.Sprintf("%d", user.GetID()), Issuer: github.Endpoint.TokenURL, Name: user.GetName(), diff --git a/selfservice/strategy/oidc/provider_github_app.go b/selfservice/strategy/oidc/provider_github_app.go index 95801ce59e47..6228af7e104f 100644 --- a/selfservice/strategy/oidc/provider_github_app.go +++ b/selfservice/strategy/oidc/provider_github_app.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/ory/x/httpx" @@ -57,7 +58,7 @@ func (g *ProviderGitHubApp) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (g *ProviderGitHubApp) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderGitHubApp) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { ctx, client := httpx.SetOAuth2(ctx, g.reg.HTTPClient(ctx), g.oauth2(ctx), exchange) gh := ghapi.NewClient(client.HTTPClient) @@ -66,7 +67,7 @@ func (g *ProviderGitHubApp) Claims(ctx context.Context, exchange *oauth2.Token, return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Subject: fmt.Sprintf("%d", user.GetID()), Issuer: github.Endpoint.TokenURL, Name: user.GetName(), diff --git a/selfservice/strategy/oidc/provider_gitlab.go b/selfservice/strategy/oidc/provider_gitlab.go index 9ef55b4beef7..417e6acf60c2 100644 --- a/selfservice/strategy/oidc/provider_gitlab.go +++ b/selfservice/strategy/oidc/provider_gitlab.go @@ -9,6 +9,7 @@ import ( "net/url" "path" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/stringsx" "github.com/hashicorp/go-retryablehttp" @@ -69,7 +70,7 @@ func (g *ProviderGitLab) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx) } -func (g *ProviderGitLab) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderGitLab) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -98,7 +99,7 @@ func (g *ProviderGitLab) Claims(ctx context.Context, exchange *oauth2.Token, que return nil, err } - var claims Claims + var claims claims.Claims if err := json.NewDecoder(resp.Body).Decode(&claims); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } diff --git a/selfservice/strategy/oidc/provider_google.go b/selfservice/strategy/oidc/provider_google.go index e27832692faa..867380dcd720 100644 --- a/selfservice/strategy/oidc/provider_google.go +++ b/selfservice/strategy/oidc/provider_google.go @@ -9,6 +9,7 @@ import ( gooidc "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/stringslice" ) @@ -73,7 +74,7 @@ var _ IDTokenVerifier = new(ProviderGoogle) const issuerUrlGoogle = "https://accounts.google.com" -func (p *ProviderGoogle) Verify(ctx context.Context, rawIDToken string) (*Claims, error) { +func (p *ProviderGoogle) Verify(ctx context.Context, rawIDToken string) (*claims.Claims, error) { keySet := gooidc.NewRemoteKeySet(ctx, p.JWKSUrl) ctx = gooidc.ClientContext(ctx, p.reg.HTTPClient(ctx).HTTPClient) return verifyToken(ctx, keySet, p.config, rawIDToken, issuerUrlGoogle) @@ -81,7 +82,7 @@ func (p *ProviderGoogle) Verify(ctx context.Context, rawIDToken string) (*Claims var _ NonceValidationSkipper = new(ProviderGoogle) -func (a *ProviderGoogle) CanSkipNonce(c *Claims) bool { +func (a *ProviderGoogle) CanSkipNonce(c *claims.Claims) bool { // Not all SDKs support nonce validation, so we skip it if no nonce is present in the claims of the ID Token. return c.Nonce == "" } diff --git a/selfservice/strategy/oidc/provider_lark.go b/selfservice/strategy/oidc/provider_lark.go index 52902dc20e8c..66006d08754c 100644 --- a/selfservice/strategy/oidc/provider_lark.go +++ b/selfservice/strategy/oidc/provider_lark.go @@ -13,6 +13,7 @@ import ( "golang.org/x/oauth2" "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" ) @@ -46,7 +47,6 @@ func (g *ProviderLark) Config() *Configuration { } func (g *ProviderLark) OAuth2(ctx context.Context) (*oauth2.Config, error) { - return &oauth2.Config{ ClientID: g.config.ClientID, ClientSecret: g.config.ClientSecret, @@ -55,10 +55,9 @@ func (g *ProviderLark) OAuth2(ctx context.Context) (*oauth2.Config, error) { Scopes: g.config.Scope, RedirectURL: g.config.Redir(g.reg.Config().OIDCRedirectURIBase(ctx)), }, nil - } -func (g *ProviderLark) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderLark) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { // larkClaim is defined in the https://open.feishu.cn/document/common-capabilities/sso/api/get-user-info type larkClaim struct { Sub string `json:"sub"` @@ -101,7 +100,7 @@ func (g *ProviderLark) Claims(ctx context.Context, exchange *oauth2.Token, query return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - return &Claims{ + return &claims.Claims{ Issuer: larkUserEndpoint, Subject: user.OpenID, Name: user.Name, diff --git a/selfservice/strategy/oidc/provider_linkedin.go b/selfservice/strategy/oidc/provider_linkedin.go index 03a3db3e490d..363a33fe8afd 100644 --- a/selfservice/strategy/oidc/provider_linkedin.go +++ b/selfservice/strategy/oidc/provider_linkedin.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/otelx" "github.com/hashicorp/go-retryablehttp" @@ -165,7 +166,7 @@ func (l *ProviderLinkedIn) ProfilePicture(profile *LinkedInProfile) string { return identifiers[0].Identifier } -func (l *ProviderLinkedIn) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (_ *Claims, err error) { +func (l *ProviderLinkedIn) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (_ *claims.Claims, err error) { ctx, span := l.reg.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.ProviderLinkedIn.Claims") defer otelx.End(span, &err) @@ -185,7 +186,7 @@ func (l *ProviderLinkedIn) Claims(ctx context.Context, exchange *oauth2.Token, q return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Subject: profile.ID, Issuer: "https://login.linkedin.com/", Email: email.Elements[0].Handle.EmailAddress, diff --git a/selfservice/strategy/oidc/provider_linkedin_test.go b/selfservice/strategy/oidc/provider_linkedin_test.go index d5b9df86d25a..cff4f0edb9f7 100644 --- a/selfservice/strategy/oidc/provider_linkedin_test.go +++ b/selfservice/strategy/oidc/provider_linkedin_test.go @@ -19,6 +19,7 @@ import ( "github.com/ory/kratos/internal" "github.com/ory/kratos/selfservice/strategy/oidc" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" ) func TestProviderLinkedin_Claims(t *testing.T) { @@ -122,7 +123,7 @@ func TestProviderLinkedin_Claims(t *testing.T) { ) require.NoError(t, err) - assert.Equal(t, &oidc.Claims{ + assert.Equal(t, &claims.Claims{ Issuer: "https://login.linkedin.com/", Subject: "5foOWOiYXD", GivenName: "John", @@ -198,7 +199,7 @@ func TestProviderLinkedin_No_Picture(t *testing.T) { ) require.NoError(t, err) - assert.Equal(t, &oidc.Claims{ + assert.Equal(t, &claims.Claims{ Issuer: "https://login.linkedin.com/", Subject: "5foOWOiYXD", GivenName: "John", diff --git a/selfservice/strategy/oidc/provider_microsoft.go b/selfservice/strategy/oidc/provider_microsoft.go index d69206ec4d87..c05c31407c27 100644 --- a/selfservice/strategy/oidc/provider_microsoft.go +++ b/selfservice/strategy/oidc/provider_microsoft.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/gofrs/uuid" @@ -53,7 +54,7 @@ func (m *ProviderMicrosoft) OAuth2(ctx context.Context) (*oauth2.Config, error) return m.oauth2ConfigFromEndpoint(ctx, endpoint), nil } -func (m *ProviderMicrosoft) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (m *ProviderMicrosoft) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { raw, ok := exchange.Extra("id_token").(string) if !ok || len(raw) == 0 { return nil, errors.WithStack(ErrIDTokenMissing) @@ -84,7 +85,7 @@ func (m *ProviderMicrosoft) Claims(ctx context.Context, exchange *oauth2.Token, return m.updateSubject(ctx, claims, exchange) } -func (m *ProviderMicrosoft) updateSubject(ctx context.Context, claims *Claims, exchange *oauth2.Token) (*Claims, error) { +func (m *ProviderMicrosoft) updateSubject(ctx context.Context, claims *claims.Claims, exchange *oauth2.Token) (*claims.Claims, error) { if m.config.SubjectSource == "me" { o, err := m.OAuth2(ctx) if err != nil { diff --git a/selfservice/strategy/oidc/provider_netid.go b/selfservice/strategy/oidc/provider_netid.go index dfe83c958433..c9b823ce4cd1 100644 --- a/selfservice/strategy/oidc/provider_netid.go +++ b/selfservice/strategy/oidc/provider_netid.go @@ -11,6 +11,7 @@ import ( gooidc "github.com/coreos/go-oidc/v3/oidc" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/stringslice" "github.com/hashicorp/go-retryablehttp" @@ -71,7 +72,7 @@ func (n *ProviderNetID) oAuth2(ctx context.Context) (*oauth2.Config, error) { }, nil } -func (n *ProviderNetID) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*Claims, error) { +func (n *ProviderNetID) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*claims.Claims, error) { o, err := n.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -103,17 +104,17 @@ func (n *ProviderNetID) Claims(ctx context.Context, exchange *oauth2.Token, _ ur return nil, errors.WithStack(ErrIDTokenMissing) } - claims, err := n.verifyAndDecodeClaimsWithProvider(ctx, p, raw) + dec, err := n.verifyAndDecodeClaimsWithProvider(ctx, p, raw) if err != nil { return nil, err } - var userinfo Claims + var userinfo claims.Claims if err := json.NewDecoder(resp.Body).Decode(&userinfo); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - userinfo.Issuer = claims.Issuer - userinfo.Subject = claims.Subject + userinfo.Issuer = dec.Issuer + userinfo.Subject = dec.Subject return &userinfo, nil } diff --git a/selfservice/strategy/oidc/provider_patreon.go b/selfservice/strategy/oidc/provider_patreon.go index 745dc8fcc199..cf4a09df2773 100644 --- a/selfservice/strategy/oidc/provider_patreon.go +++ b/selfservice/strategy/oidc/provider_patreon.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/pkg/errors" @@ -79,7 +80,7 @@ func (d *ProviderPatreon) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { } } -func (d *ProviderPatreon) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (d *ProviderPatreon) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { identityUrl := "https://www.patreon.com/api/oauth2/v2/identity?fields%5Buser%5D=first_name,last_name,url,full_name,email,image_url" o := d.oauth2(ctx) @@ -107,7 +108,7 @@ func (d *ProviderPatreon) Claims(ctx context.Context, exchange *oauth2.Token, qu return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", jsonErr)) } - claims := &Claims{ + claims := &claims.Claims{ Issuer: "https://www.patreon.com/", Subject: data.Data.Id, Name: data.Data.Attributes.FullName, diff --git a/selfservice/strategy/oidc/provider_slack.go b/selfservice/strategy/oidc/provider_slack.go index 7c7e26c99da4..d1c4a9eb4519 100644 --- a/selfservice/strategy/oidc/provider_slack.go +++ b/selfservice/strategy/oidc/provider_slack.go @@ -9,6 +9,7 @@ import ( "net/url" "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/pkg/errors" "golang.org/x/oauth2" @@ -61,7 +62,7 @@ func (d *ProviderSlack) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (d *ProviderSlack) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (d *ProviderSlack) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), ",") for _, check := range d.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -75,7 +76,7 @@ func (d *ProviderSlack) Claims(ctx context.Context, exchange *oauth2.Token, quer return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Issuer: "https://slack.com/oauth/", Subject: identity.User.ID, Name: identity.User.Name, diff --git a/selfservice/strategy/oidc/provider_spotify.go b/selfservice/strategy/oidc/provider_spotify.go index 366105c94d0e..b1dc3791f423 100644 --- a/selfservice/strategy/oidc/provider_spotify.go +++ b/selfservice/strategy/oidc/provider_spotify.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/stringslice" "github.com/ory/x/stringsx" @@ -60,7 +61,7 @@ func (g *ProviderSpotify) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (g *ProviderSpotify) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderSpotify) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), " ") for _, check := range g.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -85,7 +86,7 @@ func (g *ProviderSpotify) Claims(ctx context.Context, exchange *oauth2.Token, qu userPicture = user.Images[0].URL } - claims := &Claims{ + claims := &claims.Claims{ Subject: user.ID, Issuer: spotify.Endpoint.TokenURL, Name: user.DisplayName, diff --git a/selfservice/strategy/oidc/provider_test.go b/selfservice/strategy/oidc/provider_test.go index a5733d2e95f8..1041ea2c5c56 100644 --- a/selfservice/strategy/oidc/provider_test.go +++ b/selfservice/strategy/oidc/provider_test.go @@ -11,16 +11,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" -) -func TestClaimsValidate(t *testing.T) { - require.Error(t, new(Claims).Validate()) - require.Error(t, (&Claims{Issuer: "not-empty"}).Validate()) - require.Error(t, (&Claims{Issuer: "not-empty"}).Validate()) - require.Error(t, (&Claims{Subject: "not-empty"}).Validate()) - require.Error(t, (&Claims{Subject: "not-empty"}).Validate()) - require.NoError(t, (&Claims{Issuer: "not-empty", Subject: "not-empty"}).Validate()) -} + "github.com/ory/kratos/selfservice/strategy/oidc/claims" +) type TestProvider struct { *ProviderGenericOIDC @@ -43,11 +36,11 @@ func RegisterTestProvider(id string) func() { var _ IDTokenVerifier = new(TestProvider) -func (t *TestProvider) Verify(_ context.Context, token string) (*Claims, error) { +func (t *TestProvider) Verify(_ context.Context, token string) (*claims.Claims, error) { if token == "error" { return nil, fmt.Errorf("stub error") } - c := Claims{} + c := claims.Claims{} if err := json.Unmarshal([]byte(token), &c); err != nil { return nil, err } @@ -95,7 +88,7 @@ func TestLocale(t *testing.T) { expected: "", }} { t.Run(tc.name, func(t *testing.T) { - var c Claims + var c claims.Claims err := json.Unmarshal([]byte(tc.json), &c) if tc.assertErr != nil { tc.assertErr(t, err) diff --git a/selfservice/strategy/oidc/provider_userinfo_test.go b/selfservice/strategy/oidc/provider_userinfo_test.go index 97456dfc404d..a89e8d5eb07c 100644 --- a/selfservice/strategy/oidc/provider_userinfo_test.go +++ b/selfservice/strategy/oidc/provider_userinfo_test.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/driver/config" "github.com/ory/kratos/internal" "github.com/ory/kratos/selfservice/strategy/oidc" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/otelx" @@ -45,7 +46,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { ctx := context.Background() token := &oauth2.Token{AccessToken: "foo", Expiry: time.Now().Add(time.Hour)} - expectedClaims := &oidc.Claims{ + expectedClaims := &claims.Claims{ Issuer: "ignore-me", Subject: "123456789012345", Name: "John Doe", @@ -75,7 +76,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { config *oidc.Configuration provider oidc.Provider userInfoHandler func(req *http.Request) (*http.Response, error) - expectedClaims *oidc.Claims + expectedClaims *claims.Claims useToken *oauth2.Token hook func(t *testing.T) }{ @@ -125,7 +126,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { }, ) }, - expectedClaims: &oidc.Claims{Issuer: "https://broker.netid.de/", Subject: "1234567890", Name: "John Doe", GivenName: "John", FamilyName: "Doe", LastName: "", MiddleName: "", Nickname: "John Doe", PreferredUsername: "John Doe", Profile: "", Picture: "", Website: "", Email: "john.doe@example.com", EmailVerified: true, Gender: "", Birthdate: "01/01/1990", Zoneinfo: "", Locale: "", PhoneNumber: "", PhoneNumberVerified: false, UpdatedAt: 0, HD: "", Team: ""}, + expectedClaims: &claims.Claims{Issuer: "https://broker.netid.de/", Subject: "1234567890", Name: "John Doe", GivenName: "John", FamilyName: "Doe", LastName: "", MiddleName: "", Nickname: "John Doe", PreferredUsername: "John Doe", Profile: "", Picture: "", Website: "", Email: "john.doe@example.com", EmailVerified: true, Gender: "", Birthdate: "01/01/1990", Zoneinfo: "", Locale: "", PhoneNumber: "", PhoneNumberVerified: false, UpdatedAt: 0, HD: "", Team: ""}, }, { name: "vk", @@ -148,7 +149,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { return resp, err }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://api.vk.com/method/users.get", Subject: "123456789012345", Email: "john.doe@example.com", @@ -176,7 +177,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { return resp, err }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://login.yandex.ru/info", Subject: "123456789012345", Email: "john.doe@example.com", @@ -221,7 +222,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { }) return resp, err }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://graph.facebook.com/me?fields=id,name,first_name,last_name,middle_name,email,picture,birthday,gender&appsecret_proof=0c0d98f7e3d9d45e72e8877bc1b104327efb9c07b18f2ffeced76d81307f1fff", Subject: "123456789012345", Name: "John Doe", @@ -292,7 +293,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { }, ) }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://login.microsoftonline.com/a9b86385-f32c-4803-afc8-4b2312fbdf24/v2.0", Subject: "new-id", Name: "John Doe", Email: "john.doe@example.com", RawClaims: map[string]interface{}{"aud": []interface{}{"foo"}, "exp": 4.071728504e+09, "iat": 1.516239022e+09, "iss": "https://login.microsoftonline.com/a9b86385-f32c-4803-afc8-4b2312fbdf24/v2.0", "email": "john.doe@example.com", "name": "John Doe", "sub": "1234567890", "tid": "a9b86385-f32c-4803-afc8-4b2312fbdf24"}, }, @@ -317,7 +318,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { ID: "dingtalk", Provider: "dingtalk", }, reg), - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://api.dingtalk.com/v1.0/contact/users/me", Subject: "123456789012345", Email: "john.doe@example.com", diff --git a/selfservice/strategy/oidc/provider_vk.go b/selfservice/strategy/oidc/provider_vk.go index 2a3513b6e050..97d77cb8af65 100644 --- a/selfservice/strategy/oidc/provider_vk.go +++ b/selfservice/strategy/oidc/provider_vk.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/pkg/errors" @@ -59,7 +60,7 @@ func (g *ProviderVK) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx), nil } -func (g *ProviderVK) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderVK) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -118,7 +119,7 @@ func (g *ProviderVK) Claims(ctx context.Context, exchange *oauth2.Token, query u gender = "male" } - return &Claims{ + return &claims.Claims{ Issuer: "https://api.vk.com/method/users.get", Subject: strconv.Itoa(user.Id), GivenName: user.FirstName, diff --git a/selfservice/strategy/oidc/provider_x.go b/selfservice/strategy/oidc/provider_x.go index f58dbd48182f..3c4933d403b5 100644 --- a/selfservice/strategy/oidc/provider_x.go +++ b/selfservice/strategy/oidc/provider_x.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/otelx" "github.com/dghubble/oauth1" @@ -18,11 +19,15 @@ import ( "github.com/ory/herodot" ) -var _ Provider = (*ProviderX)(nil) -var _ OAuth1Provider = (*ProviderX)(nil) +var ( + _ Provider = (*ProviderX)(nil) + _ OAuth1Provider = (*ProviderX)(nil) +) -const xUserInfoBase = "https://api.twitter.com/1.1/account/verify_credentials.json" -const xUserInfoWithEmail = xUserInfoBase + "?include_email=true" +const ( + xUserInfoBase = "https://api.twitter.com/1.1/account/verify_credentials.json" + xUserInfoWithEmail = xUserInfoBase + "?include_email=true" +) type ProviderX struct { config *Configuration @@ -35,7 +40,8 @@ func (p *ProviderX) Config() *Configuration { func NewProviderX( config *Configuration, - reg Dependencies) Provider { + reg Dependencies, +) Provider { return &ProviderX{ config: config, reg: reg, @@ -107,7 +113,7 @@ func (p *ProviderX) userInfoEndpoint() string { return xUserInfoBase } -func (p *ProviderX) Claims(ctx context.Context, token *oauth1.Token) (*Claims, error) { +func (p *ProviderX) Claims(ctx context.Context, token *oauth1.Token) (*claims.Claims, error) { ctx = context.WithValue(ctx, oauth1.HTTPClient, p.reg.HTTPClient(ctx).HTTPClient) c := p.OAuth1(ctx) @@ -134,7 +140,7 @@ func (p *ProviderX) Claims(ctx context.Context, token *oauth1.Token) (*Claims, e website = *user.URL } - return &Claims{ + return &claims.Claims{ Issuer: endpoint, Subject: user.IDStr, Name: user.Name, diff --git a/selfservice/strategy/oidc/provider_yandex.go b/selfservice/strategy/oidc/provider_yandex.go index 07b30caee52b..6f8582cb5ee3 100644 --- a/selfservice/strategy/oidc/provider_yandex.go +++ b/selfservice/strategy/oidc/provider_yandex.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/herodot" @@ -57,7 +58,7 @@ func (g *ProviderYandex) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx), nil } -func (g *ProviderYandex) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderYandex) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -100,7 +101,7 @@ func (g *ProviderYandex) Claims(ctx context.Context, exchange *oauth2.Token, que user.Picture = "" } - return &Claims{ + return &claims.Claims{ Issuer: "https://login.yandex.ru/info", Subject: user.Id, GivenName: user.FirstName, diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 6515d06367ee..8fc406224fcc 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -26,6 +26,7 @@ import ( "github.com/ory/kratos/cipher" "github.com/ory/kratos/selfservice/flowhelpers" "github.com/ory/kratos/selfservice/sessiontokenexchange" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/jsonnetsecure" "github.com/ory/x/otelx" @@ -421,7 +422,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt return } - var claims *Claims + var claims *claims.Claims var et *identity.CredentialsOIDCEncryptedTokens switch p := provider.(type) { case OAuth2Provider: @@ -726,7 +727,7 @@ func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session. } } -func (s *Strategy) processIDToken(w http.ResponseWriter, r *http.Request, provider Provider, idToken, idTokenNonce string) (*Claims, error) { +func (s *Strategy) processIDToken(w http.ResponseWriter, r *http.Request, provider Provider, idToken, idTokenNonce string) (*claims.Claims, error) { verifier, ok := provider.(IDTokenVerifier) if !ok { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The provider %s does not support id_token verification", provider.Config().Provider)) diff --git a/selfservice/strategy/oidc/strategy_helper_test.go b/selfservice/strategy/oidc/strategy_helper_test.go index 7b39729cc6af..0285f484c286 100644 --- a/selfservice/strategy/oidc/strategy_helper_test.go +++ b/selfservice/strategy/oidc/strategy_helper_test.go @@ -368,7 +368,7 @@ var publicJWKS []byte //go:embed stub/jwks_public2.json var publicJWKS2 []byte -type claims struct { +type jwtClaims struct { *jwt.RegisteredClaims Email string `json:"email"` } @@ -376,7 +376,7 @@ type claims struct { func createIdToken(t *testing.T, cl jwt.RegisteredClaims) string { key := &jwk.KeySpec{} require.NoError(t, json.Unmarshal(rawKey, key)) - token := jwt.NewWithClaims(jwt.SigningMethodRS256, &claims{ + token := jwt.NewWithClaims(jwt.SigningMethodRS256, &jwtClaims{ RegisteredClaims: &cl, Email: "acme@ory.sh", }) diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 42b948ec7c11..5c32de727721 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -19,6 +19,7 @@ import ( "github.com/ory/x/sqlcon" "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/text" @@ -106,7 +107,7 @@ type UpdateLoginFlowWithOidcMethod struct { TransientPayload json.RawMessage `json:"transient_payload,omitempty" form:"transient_payload"` } -func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider, container *AuthCodeContainer) (*registration.Flow, error) { +func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider, container *AuthCodeContainer) (*registration.Flow, error) { i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject)) if err != nil { if errors.Is(err, sqlcon.ErrNoRows) { @@ -175,7 +176,7 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlo httprouter.ParamsFromContext(r.Context()).ByName("organization")) for _, c := range oidcCredentials.Providers { if c.Subject == claims.Subject && c.Provider == provider.Config().ID { - if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID); err != nil { + if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, claims, provider.Config().ID); err != nil { return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err) } return nil, nil diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index 124e8539f6ea..de2aa1843844 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -23,6 +23,7 @@ import ( "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/text" "github.com/ory/kratos/x" "github.com/ory/x/decoderx" @@ -276,7 +277,7 @@ func (s *Strategy) registrationToLogin(w http.ResponseWriter, r *http.Request, r return lf, nil } -func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, rf *registration.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider, container *AuthCodeContainer, idToken string) (*login.Flow, error) { +func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, rf *registration.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider, container *AuthCodeContainer, idToken string) (*login.Flow, error) { if _, _, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject)); err == nil { // If the identity already exists, we should perform the login flow instead. @@ -344,7 +345,7 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r return nil, nil } -func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *Claims, provider Provider, container *AuthCodeContainer, jsonnetSnippet []byte) (*identity.Identity, []VerifiedAddress, error) { +func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a flow.Flow, claims *claims.Claims, provider Provider, container *AuthCodeContainer, jsonnetSnippet []byte) (*identity.Identity, []VerifiedAddress, error) { var jsonClaims bytes.Buffer if err := json.NewEncoder(&jsonClaims).Encode(claims); err != nil { return nil, nil, s.handleError(w, r, a, provider.Config().ID, nil, err) @@ -393,7 +394,7 @@ func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *reg return i, va, nil } -func (s *Strategy) setTraits(w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *Claims, provider Provider, container *AuthCodeContainer, evaluated string, i *identity.Identity) error { +func (s *Strategy) setTraits(w http.ResponseWriter, r *http.Request, a flow.Flow, claims *claims.Claims, provider Provider, container *AuthCodeContainer, evaluated string, i *identity.Identity) error { jsonTraits := gjson.Get(evaluated, "identity.traits") if !jsonTraits.IsObject() { return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("OpenID Connect Jsonnet mapper did not return an object for key identity.traits. Please check your Jsonnet code!")) diff --git a/selfservice/strategy/oidc/strategy_settings.go b/selfservice/strategy/oidc/strategy_settings.go index 4fde3a457548..dac7632b0d90 100644 --- a/selfservice/strategy/oidc/strategy_settings.go +++ b/selfservice/strategy/oidc/strategy_settings.go @@ -17,6 +17,7 @@ import ( "github.com/ory/kratos/continuity" "github.com/ory/kratos/selfservice/strategy" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/decoderx" "github.com/ory/kratos/session" @@ -400,7 +401,7 @@ func (s *Strategy) initLinkProvider(w http.ResponseWriter, r *http.Request, ctxU return errors.WithStack(flow.ErrCompletedByStrategy) } -func (s *Strategy) linkProvider(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider) error { +func (s *Strategy) linkProvider(w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider) error { p := &updateSettingsFlowWithOidcMethod{ Link: provider.Config().ID, FlowID: ctxUpdate.Flow.ID.String(), } diff --git a/selfservice/strategy/oidc/token_verifier.go b/selfservice/strategy/oidc/token_verifier.go index ce9cb8b3d3ee..864d9faa153e 100644 --- a/selfservice/strategy/oidc/token_verifier.go +++ b/selfservice/strategy/oidc/token_verifier.go @@ -9,9 +9,11 @@ import ( "strings" "github.com/coreos/go-oidc/v3/oidc" + + "github.com/ory/kratos/selfservice/strategy/oidc/claims" ) -func verifyToken(ctx context.Context, keySet oidc.KeySet, config *Configuration, rawIDToken, issuerURL string) (*Claims, error) { +func verifyToken(ctx context.Context, keySet oidc.KeySet, config *Configuration, rawIDToken, issuerURL string) (*claims.Claims, error) { tokenAudiences := append([]string{config.ClientID}, config.AdditionalIDTokenAudiences...) var token *oidc.IDToken err := fmt.Errorf("no audience matched the token's audience") @@ -34,7 +36,7 @@ func verifyToken(ctx context.Context, keySet oidc.KeySet, config *Configuration, // None of the allowed audiences matched the audience in the token return nil, fmt.Errorf("token audience didn't match allowed audiences: %+v %w", tokenAudiences, err) } - claims := &Claims{} + claims := &claims.Claims{} if err := token.Claims(claims); err != nil { return nil, err }