Skip to content

Commit

Permalink
chore: craft unit test with httptest to test the trailing slash middl…
Browse files Browse the repository at this point in the history
…eware
  • Loading branch information
yquansah committed Jun 15, 2023
1 parent 36ea696 commit ab7d29b
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 14 deletions.
27 changes: 13 additions & 14 deletions internal/cmd/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,7 @@ func NewHTTPServer(
h.ServeHTTP(w, r)
})
})
r.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/") {
// Panic if URL can not be parsed if a trailing slash is trimmed.
nurl, err := url.Parse(strings.TrimSuffix(r.URL.String(), "/"))
if err != nil {
panic(err)
}

r.URL = nurl
}
h.ServeHTTP(w, r)
})
})
r.Use(removeTrailingSlash)
r.Use(middleware.Compress(gzip.DefaultCompression))
r.Use(middleware.Recoverer)
r.Mount("/debug", middleware.Profiler())
Expand Down Expand Up @@ -251,3 +238,15 @@ func (h *HTTPServer) Shutdown(ctx context.Context) error {

return h.Server.Shutdown(ctx)
}

func removeTrailingSlash(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
u, err := url.Parse(strings.TrimSuffix(r.URL.Path, "/"))
// Panic if URL can not be parsed if a trailing slash is trimmed.
if err != nil {
panic(err)
}
r.URL = u
h.ServeHTTP(w, r)
})
}
61 changes: 61 additions & 0 deletions internal/cmd/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package cmd

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
)

const (
tsoHeader = "trailing-slash-on"
)

func TestTrailingSlashMiddleware(t *testing.T) {
r := chi.NewRouter()

r.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tso := r.Header.Get(tsoHeader)
if tso != "" {
tsh := removeTrailingSlash(h)

tsh.ServeHTTP(w, r)
return
}

h.ServeHTTP(w, r)
})
})
r.Get("/hello", func(w http.ResponseWriter, r *http.Request) {
})

s := httptest.NewServer(r)

defer s.Close()

// Request with the middleware on.
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, fmt.Sprintf("%s/hello/", s.URL), nil)
assert.NoError(t, err)
req.Header.Set(tsoHeader, "on")

res, err := http.DefaultClient.Do(req)
assert.NoError(t, err)

assert.Equal(t, http.StatusOK, res.StatusCode)
res.Body.Close()

// Request with the middleware off.
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, fmt.Sprintf("%s/hello/", s.URL), nil)
assert.NoError(t, err)

res, err = http.DefaultClient.Do(req)
assert.NoError(t, err)

assert.Equal(t, http.StatusNotFound, res.StatusCode)
res.Body.Close()
}

0 comments on commit ab7d29b

Please sign in to comment.