diff --git a/build/testing/integration/api/api.go b/build/testing/integration/api/api.go index 0637058e86..af9ad63eee 100644 --- a/build/testing/integration/api/api.go +++ b/build/testing/integration/api/api.go @@ -4,7 +4,11 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" + "strings" "testing" + "time" "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" @@ -13,11 +17,18 @@ import ( sdk "go.flipt.io/flipt/sdk/go" ) -func API(t *testing.T, ctx context.Context, client sdk.SDK, namespace string, authenticated bool) { +var ( + httpClient = &http.Client{ + Timeout: 5 * time.Second, + } +) + +func API(t *testing.T, ctx context.Context, client sdk.SDK, fliptAddr string, namespace string, authenticated bool) { t.Run("Namespaces", func(t *testing.T) { if !namespaceIsDefault(namespace) { t.Log(`Create namespace.`) + fmt.Println("The namespace is: ", namespace) created, err := client.Flipt().CreateNamespace(ctx, &flipt.CreateNamespaceRequest{ Key: namespace, Name: namespace, @@ -55,6 +66,20 @@ func API(t *testing.T, ctx context.Context, client sdk.SDK, namespace string, au require.NoError(t, err) assert.Equal(t, "Some kind of description", updated.Description) + + t.Log(`Namespace request with trailing slash should succeed.`) + + if isHTTPProtocol(fliptAddr) { + reader := makeRequestWithTrailingSlash(t, ctx, http.MethodGet, fmt.Sprintf("%s/api/v1/namespaces", fliptAddr)) + + var fliptNamespaces *flipt.NamespaceList + + err = json.NewDecoder(reader).Decode(&fliptNamespaces) + assert.NoError(t, err) + + assert.Equal(t, "default", fliptNamespaces.Namespaces[0].Key) + assert.Equal(t, namespace, fliptNamespaces.Namespaces[1].Key) + } } else { t.Log(`Ensure default cannot be created.`) @@ -641,3 +666,21 @@ func API(t *testing.T, ctx context.Context, client sdk.SDK, namespace string, au func namespaceIsDefault(ns string) bool { return ns == "" || ns == "default" } + +func isHTTPProtocol(fliptAddr string) bool { + protocol, _, _ := strings.Cut(fliptAddr, "://") + + return protocol == "http" || protocol == "https" +} + +func makeRequestWithTrailingSlash(t *testing.T, ctx context.Context, method string, fliptAddrWithPath string) io.Reader { + t.Helper() + + req, err := http.NewRequestWithContext(ctx, method, fliptAddrWithPath+"/", nil) + assert.NoError(t, err) + + res, err := httpClient.Do(req) + assert.NoError(t, err) + + return res.Body +} diff --git a/build/testing/integration/api/api_test.go b/build/testing/integration/api/api_test.go index 7cd785da01..20902ed1f6 100644 --- a/build/testing/integration/api/api_test.go +++ b/build/testing/integration/api/api_test.go @@ -10,10 +10,10 @@ import ( ) func TestAPI(t *testing.T) { - integration.Harness(t, func(t *testing.T, sdk sdk.SDK, namespace string, authentication bool) { + integration.Harness(t, func(t *testing.T, sdk sdk.SDK, fliptAddr string, namespace string, authentication bool) { ctx := context.Background() - api.API(t, ctx, sdk, namespace, authentication) + api.API(t, ctx, sdk, fliptAddr, namespace, authentication) // run extra tests in authenticated context if authentication { diff --git a/build/testing/integration/integration.go b/build/testing/integration/integration.go index 66d7389829..8424441226 100644 --- a/build/testing/integration/integration.go +++ b/build/testing/integration/integration.go @@ -20,7 +20,7 @@ var ( fliptNamespace = flag.String("flipt-namespace", "", "Namespace used to scope API calls.") ) -func Harness(t *testing.T, fn func(t *testing.T, sdk sdk.SDK, ns string, authenticated bool)) { +func Harness(t *testing.T, fn func(t *testing.T, sdk sdk.SDK, fliptAddr string, ns string, authenticated bool)) { var transport sdk.Transport protocol, host, _ := strings.Cut(*fliptAddr, "://") @@ -54,6 +54,6 @@ func Harness(t *testing.T, fn func(t *testing.T, sdk sdk.SDK, ns string, authent name := fmt.Sprintf("[Protocol %q; Namespace %q; Authentication %v]", protocol, namespace, authentication) t.Run(name, func(t *testing.T) { - fn(t, sdk.New(transport, opts...), namespace, authentication) + fn(t, sdk.New(transport, opts...), *fliptAddr, namespace, authentication) }) } diff --git a/build/testing/integration/readonly/readonly_test.go b/build/testing/integration/readonly/readonly_test.go index e547ad418f..d3561633bd 100644 --- a/build/testing/integration/readonly/readonly_test.go +++ b/build/testing/integration/readonly/readonly_test.go @@ -15,7 +15,7 @@ import ( // folder has been loaded into the target instance being tested. // It then exercises a bunch of read operations via the provided SDK in the target namespace. func TestReadOnly(t *testing.T) { - integration.Harness(t, func(t *testing.T, sdk sdk.SDK, namespace string, authenticated bool) { + integration.Harness(t, func(t *testing.T, sdk sdk.SDK, _ string, namespace string, authenticated bool) { ctx := context.Background() ns, err := sdk.Flipt().GetNamespace(ctx, &flipt.GetNamespaceRequest{ Key: namespace, diff --git a/internal/cmd/http.go b/internal/cmd/http.go index c7a82e0c64..f3a5276444 100644 --- a/internal/cmd/http.go +++ b/internal/cmd/http.go @@ -7,6 +7,8 @@ import ( "errors" "fmt" "net/http" + "net/url" + "strings" "time" "github.com/fatih/color" @@ -99,29 +101,25 @@ func NewHTTPServer( h.ServeHTTP(w, r) }) }) + r.Use(func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/") { + // Panic if URL can not be parsed if a trailing slash is trimmed. + nurl, err := url.Parse(strings.TrimSuffix(r.URL.String(), "/")) + if err != nil { + panic(err) + } + + r.URL = nurl + } + h.ServeHTTP(w, r) + }) + }) r.Use(middleware.Compress(gzip.DefaultCompression)) r.Use(middleware.Recoverer) r.Mount("/debug", middleware.Profiler()) r.Mount("/metrics", promhttp.Handler()) - // Middleware to trim the trailing slash off of the request URL if it - // exists. - // r.Use(func(h http.Handler) http.Handler { - // return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // if strings.HasSuffix(r.URL.Path, "/") { - // fmt.Println("GETTING IN HERE...") - // nurl, err := url.Parse(strings.TrimSuffix(r.URL.String(), "/")) - // if err != nil { - // panic(err) - // } - - // fmt.Println("NEW URL: ", nurl) - // r.URL = nurl - // } - // h.ServeHTTP(w, r) - // }) - // }) - r.Group(func(r chi.Router) { if key := cfg.Authentication.Session.CSRF.Key; key != "" { logger.Debug("enabling CSRF prevention")