From 10bf84feb224a161d5c9a24e5e72b2f6e055ace5 Mon Sep 17 00:00:00 2001 From: Trim21 Date: Fri, 30 Aug 2024 01:48:17 +0800 Subject: [PATCH] feat: add max reponse body limit (#830) --- client.go | 61 +++++++++++++++++++++++++++++++++++++++++++++----- client_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++ request.go | 15 +++++++++++++ resty_test.go | 2 ++ 4 files changed, 127 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index aeb55077..d33d6297 100644 --- a/client.go +++ b/client.go @@ -128,6 +128,7 @@ type Client struct { // HeaderAuthorizationKey is used to set/access Request Authorization header // value when `SetAuthToken` option is used. HeaderAuthorizationKey string + ResponseBodyLimit int jsonEscapeHTML bool setContentLength bool @@ -442,11 +443,12 @@ func (c *Client) R() *Request { RawPathParams: map[string]string{}, Debug: c.Debug, - client: c, - multipartFiles: []*File{}, - multipartFields: []*MultipartField{}, - jsonEscapeHTML: c.jsonEscapeHTML, - log: c.log, + client: c, + multipartFiles: []*File{}, + multipartFields: []*MultipartField{}, + jsonEscapeHTML: c.jsonEscapeHTML, + log: c.log, + responseBodyLimit: c.ResponseBodyLimit, } return r } @@ -1089,6 +1091,20 @@ func (c *Client) SetJSONEscapeHTML(b bool) *Client { return c } +// SetResponseBodyLimit set a max body size limit on response, avoid reading too many data to memory. +// +// Client will return [resty.ErrResponseBodyTooLarge] if uncompressed response body size if larger than limit. +// Body size limit will not be enforced in following case: +// - ResponseBodyLimit <= 0, which is the default behavior. +// - [Request.SetOutput] is called to save a response data to file. +// - "DoNotParseResponse" is set for client or request. +// +// this can be overridden at client level with [Request.SetResponseBodyLimit] +func (c *Client) SetResponseBodyLimit(v int) *Client { + c.ResponseBodyLimit = v + return c +} + // EnableTrace method enables the Resty client trace for the requests fired from // the client using `httptrace.ClientTrace` and provides insights. // @@ -1238,7 +1254,7 @@ func (c *Client) execute(req *Request) (*Response, error) { } } - if response.body, err = io.ReadAll(body); err != nil { + if response.body, err = readAllWithLimit(body, req.responseBodyLimit); err != nil { response.setReceivedAt() return response, err } @@ -1258,6 +1274,39 @@ func (c *Client) execute(req *Request) (*Response, error) { return response, wrapNoRetryErr(err) } +var ErrResponseBodyTooLarge = errors.New("resty: response body too large") + +// https://github.com/golang/go/issues/51115 +// [io.LimitedReader] can only return [io.EOF] +func readAllWithLimit(r io.Reader, maxSize int) ([]byte, error) { + if maxSize <= 0 { + return io.ReadAll(r) + } + + var buf [512]byte // make buf stack allocated + result := make([]byte, 0, 512) + total := 0 + for { + n, err := r.Read(buf[:]) + total += n + if total > maxSize { + return nil, ErrResponseBodyTooLarge + } + + if err != nil { + if err == io.EOF { + result = append(result, buf[:n]...) + break + } + return nil, err + } + + result = append(result, buf[:n]...) + } + + return result, nil +} + // getting TLS client config if not exists then create one func (c *Client) tlsConfig() (*tls.Config, error) { transport, err := c.Transport() diff --git a/client_test.go b/client_test.go index c89d932d..acd31d44 100644 --- a/client_test.go +++ b/client_test.go @@ -6,6 +6,8 @@ package resty import ( "bytes" + "compress/gzip" + "crypto/rand" "crypto/tls" "errors" "fmt" @@ -1097,3 +1099,56 @@ func TestClone(t *testing.T) { assertEqual(t, "clone", parent.UserInfo.Username) assertEqual(t, "clone", clone.UserInfo.Username) } + +func TestResponseBodyLimit(t *testing.T) { + ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { + io.CopyN(w, rand.Reader, 100*800) + }) + defer ts.Close() + + t.Run("Client body limit", func(t *testing.T) { + c := dc().SetResponseBodyLimit(1024) + + _, err := c.R().Get(ts.URL + "/") + assertNotNil(t, err) + assertEqual(t, err, ErrResponseBodyTooLarge) + }) + + t.Run("request body limit", func(t *testing.T) { + c := dc() + + _, err := c.R().SetResponseBodyLimit(1024).Get(ts.URL + "/") + assertNotNil(t, err) + assertEqual(t, err, ErrResponseBodyTooLarge) + }) + + t.Run("body less than limit", func(t *testing.T) { + c := dc() + + res, err := c.R().SetResponseBodyLimit(800*100 + 10).Get(ts.URL + "/") + assertNil(t, err) + assertEqual(t, 800*100, len(res.body)) + }) + + t.Run("no body limit", func(t *testing.T) { + c := dc() + + res, err := c.R().Get(ts.URL + "/") + assertNil(t, err) + assertEqual(t, 800*100, len(res.body)) + }) + + t.Run("read error", func(t *testing.T) { + tse := createTestServer(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(hdrContentEncodingKey, "gzip") + var buf [1024]byte + w.Write(buf[:]) + }) + defer tse.Close() + + c := dc() + + _, err := c.R().SetResponseBodyLimit(10240).Get(tse.URL + "/") + assertErrorIs(t, err, gzip.ErrHeader) + }) +} diff --git a/request.go b/request.go index 91d432ca..cfbe89b4 100644 --- a/request.go +++ b/request.go @@ -73,6 +73,7 @@ type Request struct { multipartFiles []*File multipartFields []*MultipartField retryConditions []RetryConditionFunc + responseBodyLimit int } // Generate curl command for the request. @@ -600,6 +601,20 @@ func (r *Request) SetDoNotParseResponse(parse bool) *Request { return r } +// SetResponseBodyLimit set a max body size limit on response, avoid reading too many data to memory. +// +// Request will return [resty.ErrResponseBodyTooLarge] if uncompressed response body size if larger than limit. +// Body size limit will not be enforced in following case: +// - ResponseBodyLimit <= 0, which is the default behavior. +// - [Request.SetOutput] is called to save a response data to file. +// - "DoNotParseResponse" is set for client or request. +// +// This will override Client config. +func (r *Request) SetResponseBodyLimit(v int) *Request { + r.responseBodyLimit = v + return r +} + // SetPathParam method sets single URL path key-value pair in the // Resty current request instance. // diff --git a/resty_test.go b/resty_test.go index 22b483c1..95ef0b51 100644 --- a/resty_test.go +++ b/resty_test.go @@ -809,6 +809,7 @@ func dclr() *Request { } func assertNil(t *testing.T, v interface{}) { + t.Helper() if !isNil(v) { t.Errorf("[%v] was expected to be nil", v) } @@ -841,6 +842,7 @@ func assertErrorIs(t *testing.T, e, g error) (r bool) { } func assertEqual(t *testing.T, e, g interface{}) (r bool) { + t.Helper() if !equal(e, g) { t.Errorf("Expected [%v], got [%v]", e, g) }