From dd9cab0e02400c88e89877f755f03c6179013123 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Thu, 22 Oct 2020 11:32:31 +0200 Subject: [PATCH] fix: return correct error in login csrf Closes #785 --- selfservice/flow/request.go | 2 +- selfservice/strategy/password/login.go | 2 +- selfservice/strategy/password/login_test.go | 37 +++++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/selfservice/flow/request.go b/selfservice/flow/request.go index f434f7c5ace6..bec6415b0567 100644 --- a/selfservice/flow/request.go +++ b/selfservice/flow/request.go @@ -39,7 +39,7 @@ func VerifyRequest( return nil default: if !nosurf.VerifyToken(generator(r), actual) { - return x.ErrInvalidCSRFToken + return errors.WithStack(x.ErrInvalidCSRFToken) } } diff --git a/selfservice/strategy/password/login.go b/selfservice/strategy/password/login.go index 4f704b85492a..e5dd3e0c0e25 100644 --- a/selfservice/strategy/password/login.go +++ b/selfservice/strategy/password/login.go @@ -113,7 +113,7 @@ func (s *Strategy) handleLogin(w http.ResponseWriter, r *http.Request, _ httprou } if err := flow.VerifyRequest(r, ar.Type, s.d.GenerateCSRFToken, p.CSRFToken); err != nil { - s.handleLoginError(w, r, ar, &p, x.ErrInvalidCSRFToken) + s.handleLoginError(w, r, ar, &p, err) return } diff --git a/selfservice/strategy/password/login_test.go b/selfservice/strategy/password/login_test.go index 699c3ac25e4c..ea0e0e8a2e86 100644 --- a/selfservice/strategy/password/login_test.go +++ b/selfservice/strategy/password/login_test.go @@ -1,6 +1,7 @@ package password_test import ( + "bytes" "context" "encoding/json" "fmt" @@ -179,6 +180,42 @@ func TestCompleteLogin(t *testing.T) { assert.EqualValues(t, http.StatusBadRequest, res.StatusCode) assert.Contains(t, actual, "provided credentials are invalid") }) + + t.Run("case=should fail with correct CSRF error cause/type=api", func(t *testing.T) { + for k, tc := range []struct { + mod func(http.Header) + exp string + }{ + { + mod: func(h http.Header) { + h.Add("Cookie", "name=bar") + }, + exp: "The HTTP Request Header included the \\\"Cookie\\\" key", + }, + { + mod: func(h http.Header) { + h.Add("Origin", "www.bar.com") + }, + exp: "The HTTP Request Header included the \\\"Origin\\\" key", + }, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + f := testhelpers.InitializeLoginFlowViaAPI(t, apiClient, publicTS, false) + c := testhelpers.GetLoginFlowMethodConfig(t, f.Payload, identity.CredentialsTypePassword.String()) + + req := testhelpers.NewRequest(t, true, "POST", pointerx.StringR(c.Action), bytes.NewBufferString(testhelpers.EncodeFormAsJSON(t, true, values))) + tc.mod(req.Header) + + res, err := apiClient.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + actual := string(x.MustReadAll(res.Body)) + assert.EqualValues(t, http.StatusBadRequest, res.StatusCode) + assert.Contains(t, actual, tc.exp) + }) + } + }) }) var expectValidationError = func(t *testing.T, isAPI, forced bool, values func(url.Values)) string {