diff --git a/cmd/daemon/serve.go b/cmd/daemon/serve.go index 48ba9174a8cd..7fce0fc26cc7 100644 --- a/cmd/daemon/serve.go +++ b/cmd/daemon/serve.go @@ -58,6 +58,8 @@ func servePublic(d driver.Driver, wg *sync.WaitGroup, cmd *cobra.Command, args [ c.SelfPublicURL().Hostname(), !flagx.MustGetBool(cmd, "dev"), ) + + n.UseFunc(x.CleanPath) // Prevent double slashes from breaking CSRF. r.WithCSRFHandler(csrf) n.UseHandler(r.CSRFHandler()) diff --git a/internal/testhelpers/httptest.go b/internal/testhelpers/httptest.go new file mode 100644 index 000000000000..9408ab91fb5c --- /dev/null +++ b/internal/testhelpers/httptest.go @@ -0,0 +1,13 @@ +package testhelpers + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func NewHTTPTestServer(t *testing.T, h http.Handler) *httptest.Server { + ts := httptest.NewServer(h) + t.Cleanup(ts.Close) + return ts +} diff --git a/selfservice/flow/login/handler.go b/selfservice/flow/login/handler.go index db3d4bc10066..c37434e7a91c 100644 --- a/selfservice/flow/login/handler.go +++ b/selfservice/flow/login/handler.go @@ -33,6 +33,7 @@ type ( session.ManagementProvider x.WriterProvider x.CSRFTokenGeneratorProvider + x.CSRFProvider } HandlerProvider interface { LoginHandler() *Handler @@ -48,6 +49,7 @@ func NewHandler(d handlerDependencies, c configuration.Provider) *Handler { } func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) { + h.d.CSRFHandler().ExemptPath(RouteInitAPIFlow) public.GET(RouteInitBrowserFlow, h.initBrowserFlow) public.GET(RouteInitAPIFlow, h.initAPIFlow) public.GET(RouteGetFlow, h.fetchFlow) diff --git a/selfservice/flow/recovery/handler.go b/selfservice/flow/recovery/handler.go index fcbcdc8c2c3e..df254eeb2903 100644 --- a/selfservice/flow/recovery/handler.go +++ b/selfservice/flow/recovery/handler.go @@ -36,6 +36,7 @@ type ( FlowPersistenceProvider x.CSRFTokenGeneratorProvider x.WriterProvider + x.CSRFProvider } Handler struct { d handlerDependencies @@ -48,6 +49,8 @@ func NewHandler(d handlerDependencies, c configuration.Provider) *Handler { } func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) { + h.d.CSRFHandler().ExemptPath(RouteInitAPIFlow) + redirect := session.RedirectOnAuthenticated(h.c) public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsNotAuthenticated(h.initBrowserFlow, redirect)) public.GET(RouteInitAPIFlow, h.d.SessionHandler().IsNotAuthenticated(h.initAPIFlow, diff --git a/selfservice/flow/registration/handler.go b/selfservice/flow/registration/handler.go index 29c934cca11c..259a11eaa1ae 100644 --- a/selfservice/flow/registration/handler.go +++ b/selfservice/flow/registration/handler.go @@ -33,6 +33,7 @@ type ( x.CSRFTokenGeneratorProvider HookExecutorProvider FlowPersistenceProvider + x.CSRFProvider } HandlerProvider interface { RegistrationHandler() *Handler @@ -48,6 +49,8 @@ func NewHandler(d handlerDependencies, c configuration.Provider) *Handler { } func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) { + h.d.CSRFHandler().ExemptPath(RouteInitAPIFlow) + public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsNotAuthenticated(h.initBrowserFlow, session.RedirectOnAuthenticated(h.c))) public.GET(RouteInitAPIFlow, h.d.SessionHandler().IsNotAuthenticated(h.initApiFlow, session.RespondWithJSONErrorOnAuthenticated(h.d.Writer(), errors.WithStack(ErrAlreadyLoggedIn)))) diff --git a/selfservice/flow/settings/handler.go b/selfservice/flow/settings/handler.go index f87a684c8580..820af5cdfb7d 100644 --- a/selfservice/flow/settings/handler.go +++ b/selfservice/flow/settings/handler.go @@ -57,6 +57,7 @@ type ( StrategyProvider IdentityTraitsSchemas() schema.Schemas + x.CSRFProvider } HandlerProvider interface { SettingsHandler() *Handler @@ -73,6 +74,8 @@ func NewHandler(d handlerDependencies, c configuration.Provider) *Handler { } func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) { + h.d.CSRFHandler().ExemptPath(RouteInitAPIFlow) + redirect := session.RedirectOnUnauthenticated(h.c.SelfServiceFlowLoginUI().String()) public.GET(RouteInitBrowserFlow, h.d.SessionHandler().IsAuthenticated(h.initBrowserFlow, redirect)) public.GET(RouteInitAPIFlow, h.d.SessionHandler().IsAuthenticated(h.initApiFlow, nil)) diff --git a/selfservice/flow/verification/handler.go b/selfservice/flow/verification/handler.go index 8a689cc8c760..d171156b89cf 100644 --- a/selfservice/flow/verification/handler.go +++ b/selfservice/flow/verification/handler.go @@ -36,6 +36,7 @@ type ( FlowPersistenceProvider ErrorHandlerProvider StrategyProvider + x.CSRFProvider } Handler struct { d handlerDependencies @@ -48,6 +49,8 @@ func NewHandler(d handlerDependencies, c configuration.Provider) *Handler { } func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) { + h.d.CSRFHandler().ExemptPath(RouteInitAPIFlow) + public.GET(RouteInitBrowserFlow, h.initBrowserFlow) public.GET(RouteInitAPIFlow, h.initAPIFlow) public.GET(RouteGetFlow, h.fetch) diff --git a/x/clean_url.go b/x/clean_url.go new file mode 100644 index 000000000000..f367c18d67e7 --- /dev/null +++ b/x/clean_url.go @@ -0,0 +1,13 @@ +package x + +import ( + "net/http" + + "github.com/julienschmidt/httprouter" + "github.com/urfave/negroni" +) + +var CleanPath negroni.HandlerFunc = func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + r.URL.Path = httprouter.CleanPath(r.URL.Path) + next(rw, r) +} diff --git a/x/clean_url_test.go b/x/clean_url_test.go new file mode 100644 index 000000000000..5e5956c8228a --- /dev/null +++ b/x/clean_url_test.go @@ -0,0 +1,35 @@ +package x + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/bmizerany/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/negroni" +) + +func TestCleanPath(t *testing.T) { + n := negroni.New(CleanPath) + n.UseHandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(r.URL.String())) + }) + ts := httptest.NewServer(n) + defer ts.Close() + + for k, tc := range [][]string{ + {"//foo", "/foo"}, + {"//foo//bar", "/foo/bar"}, + } { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + res, err := ts.Client().Get(ts.URL + tc[0]) + require.NoError(t, err) + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + assert.Equal(t, string(body), tc[1]) + }) + } +}