diff --git a/gzhttp/compress.go b/gzhttp/compress.go index 265e71c062..289ae3e2ee 100644 --- a/gzhttp/compress.go +++ b/gzhttp/compress.go @@ -335,7 +335,16 @@ func (w *GzipResponseWriter) Close() error { ce = w.Header().Get(contentEncoding) cr = w.Header().Get(contentRange) ) - // fmt.Println(len(w.buf) == 0, len(w.buf) < w.minSize, len(w.Header()[HeaderNoCompression]) != 0, ce != "", cr != "", !w.contentTypeFilter(ct)) + if ct == "" { + ct = http.DetectContentType(w.buf) + + // Handles the intended case of setting a nil Content-Type (as for http/server or http/fs) + // Set the header only if the key does not exist + if _, ok := w.Header()[contentType]; w.setContentType && !ok { + w.Header().Set(contentType, ct) + } + } + if len(w.buf) == 0 || len(w.buf) < w.minSize || len(w.Header()[HeaderNoCompression]) != 0 || ce != "" || cr != "" || !w.contentTypeFilter(ct) { // GZIP not triggered, write out regular response. return w.startPlain() diff --git a/gzhttp/compress_test.go b/gzhttp/compress_test.go index fc19723892..dde980b5a7 100644 --- a/gzhttp/compress_test.go +++ b/gzhttp/compress_test.go @@ -14,6 +14,7 @@ import ( "net/url" "os" "strconv" + "strings" "testing" "github.com/klauspost/compress/gzip" @@ -1883,3 +1884,68 @@ func Test1xxResponses(t *testing.T) { body, _ := io.ReadAll(res.Body) assertEqual(t, gzipStrLevel(testBody, gzip.DefaultCompression), body) } + +func TestContentTypeDetectWithJitter(t *testing.T) { + t.Parallel() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + content := `` + strings.Repeat("foo", 400) + w.Write([]byte(content)) + }) + + for _, tc := range []struct { + name string + wrapper func(http.Handler) (http.Handler, error) + }{ + { + name: "no wrapping", + wrapper: func(h http.Handler) (http.Handler, error) { + return h, nil + }, + }, + { + name: "default", + wrapper: func(h http.Handler) (http.Handler, error) { + wrapper, err := NewWrapper() + if err != nil { + return nil, err + } + return wrapper(h), nil + }, + }, + { + name: "jitter, default buffer", + wrapper: func(h http.Handler) (http.Handler, error) { + wrapper, err := NewWrapper(RandomJitter(32, 0, false)) + if err != nil { + return nil, err + } + return wrapper(h), nil + }, + }, + { + name: "jitter, small buffer", + wrapper: func(h http.Handler) (http.Handler, error) { + wrapper, err := NewWrapper(RandomJitter(32, DefaultMinSize, false)) + if err != nil { + return nil, err + } + return wrapper(h), nil + }, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handler, err := tc.wrapper(handler) + assertNil(t, err) + + req, resp := httptest.NewRequest(http.MethodGet, "/", nil), httptest.NewRecorder() + req.Header.Add("Accept-Encoding", "gzip") + + handler.ServeHTTP(resp, req) + + assertEqual(t, "text/html; charset=utf-8", resp.Header().Get("Content-Type")) + }) + } +}