From 2d51fbdf2a366ab7630b94ce5b114d06cfa1a7ae Mon Sep 17 00:00:00 2001 From: splaunov Date: Sun, 25 Feb 2024 23:40:01 +0300 Subject: [PATCH] feat: add sso provider id to list of groups (PS-236) --- selfservice/strategy/oidc/strategy_login.go | 1 + selfservice/strategy/oidc/strategy_registration.go | 1 + selfservice/strategy/oidc/strategy_test.go | 4 ++-- selfservice/strategy/oidc/stub/oidc.hydra.login.jsonnet | 9 +++++++-- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 4240e1da407e..5ab5ea33fe0e 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -327,6 +327,7 @@ func (s *Strategy) updateIdentityFromClaimsAndPersist(w http.ResponseWriter, r * } vm.ExtCode("claims", jsonClaims.String()) + vm.ExtVar("provider", provider.Config().ID) jsonIdentity, err := json.Marshal(i) if err != nil { return nil, err diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index dcd1d9e94ca5..eb070b728186 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -379,6 +379,7 @@ func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *reg } vm.ExtCode("claims", jsonClaims.String()) + vm.ExtVar("provider", provider.Config().ID) evaluated, err := vm.EvaluateAnonymousSnippet(provider.Config().Mapper, string(jsonnetSnippet)) if err != nil { return nil, nil, s.handleError(w, r, a, provider.Config().ID, nil, err) diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 1865cd40f10b..c68f5c0d3370 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -469,7 +469,7 @@ func TestStrategy(t *testing.T) { assertIdentity(t, res, body) expectTokens(t, "valid", body) assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body) - assert.Equal(t, "", gjson.GetBytes(body, "identity.metadata_public.groups").String(), "%s", prettyJSON(t, body)) + assert.Equal(t, "", gjson.GetBytes(body, "identity.metadata_public.sso_groups.valid").String(), "%s", prettyJSON(t, body)) }) t.Run("case=should pass login", func(t *testing.T) { @@ -479,7 +479,7 @@ func TestStrategy(t *testing.T) { assertIdentity(t, res, body) expectTokens(t, "valid", body) assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body) - assert.Equal(t, `["group1","group2"]`, gjson.GetBytes(body, "identity.metadata_public.groups").String(), "%s", prettyJSON(t, body)) + assert.Equal(t, `["group1","group2"]`, gjson.GetBytes(body, "identity.metadata_public.sso_groups.valid").String(), "%s", prettyJSON(t, body)) }) }) diff --git a/selfservice/strategy/oidc/stub/oidc.hydra.login.jsonnet b/selfservice/strategy/oidc/stub/oidc.hydra.login.jsonnet index 428fdce353cd..6ca43235a475 100644 --- a/selfservice/strategy/oidc/stub/oidc.hydra.login.jsonnet +++ b/selfservice/strategy/oidc/stub/oidc.hydra.login.jsonnet @@ -1,12 +1,17 @@ local claims = std.extVar('claims'); +local provider = std.extVar('provider'); local identity = std.extVar('identity'); -local metadata_public = if 'metadata_public' in identity then identity.metadata_public else {}; +local mp = if 'metadata_public' in identity then identity.metadata_public else {}; if std.length(claims.sub) == 0 then error 'claim sub not set' else { identity: { - metadata_public: metadata_public + { [if "groups" in claims.raw_claims then "groups" else null]: claims.raw_claims.groups }, + metadata_public: mp { + sso_groups+: { + [if 'groups' in claims.raw_claims then provider]: claims.raw_claims.groups, + }, + }, }, }