Skip to content

Commit

Permalink
feat: add max reponse body limit (#830)
Browse files Browse the repository at this point in the history
  • Loading branch information
trim21 authored Aug 29, 2024
1 parent f575bf6 commit 10bf84f
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 6 deletions.
61 changes: 55 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ type Client struct {
// HeaderAuthorizationKey is used to set/access Request Authorization header
// value when `SetAuthToken` option is used.
HeaderAuthorizationKey string
ResponseBodyLimit int

jsonEscapeHTML bool
setContentLength bool
Expand Down Expand Up @@ -442,11 +443,12 @@ func (c *Client) R() *Request {
RawPathParams: map[string]string{},
Debug: c.Debug,

client: c,
multipartFiles: []*File{},
multipartFields: []*MultipartField{},
jsonEscapeHTML: c.jsonEscapeHTML,
log: c.log,
client: c,
multipartFiles: []*File{},
multipartFields: []*MultipartField{},
jsonEscapeHTML: c.jsonEscapeHTML,
log: c.log,
responseBodyLimit: c.ResponseBodyLimit,
}
return r
}
Expand Down Expand Up @@ -1089,6 +1091,20 @@ func (c *Client) SetJSONEscapeHTML(b bool) *Client {
return c
}

// SetResponseBodyLimit set a max body size limit on response, avoid reading too many data to memory.
//
// Client will return [resty.ErrResponseBodyTooLarge] if uncompressed response body size if larger than limit.
// Body size limit will not be enforced in following case:
// - ResponseBodyLimit <= 0, which is the default behavior.
// - [Request.SetOutput] is called to save a response data to file.
// - "DoNotParseResponse" is set for client or request.
//
// this can be overridden at client level with [Request.SetResponseBodyLimit]
func (c *Client) SetResponseBodyLimit(v int) *Client {
c.ResponseBodyLimit = v
return c
}

// EnableTrace method enables the Resty client trace for the requests fired from
// the client using `httptrace.ClientTrace` and provides insights.
//
Expand Down Expand Up @@ -1238,7 +1254,7 @@ func (c *Client) execute(req *Request) (*Response, error) {
}
}

if response.body, err = io.ReadAll(body); err != nil {
if response.body, err = readAllWithLimit(body, req.responseBodyLimit); err != nil {
response.setReceivedAt()
return response, err
}
Expand All @@ -1258,6 +1274,39 @@ func (c *Client) execute(req *Request) (*Response, error) {
return response, wrapNoRetryErr(err)
}

var ErrResponseBodyTooLarge = errors.New("resty: response body too large")

// https:/golang/go/issues/51115
// [io.LimitedReader] can only return [io.EOF]
func readAllWithLimit(r io.Reader, maxSize int) ([]byte, error) {
if maxSize <= 0 {
return io.ReadAll(r)
}

var buf [512]byte // make buf stack allocated
result := make([]byte, 0, 512)
total := 0
for {
n, err := r.Read(buf[:])
total += n
if total > maxSize {
return nil, ErrResponseBodyTooLarge
}

if err != nil {
if err == io.EOF {
result = append(result, buf[:n]...)
break
}
return nil, err
}

result = append(result, buf[:n]...)
}

return result, nil
}

// getting TLS client config if not exists then create one
func (c *Client) tlsConfig() (*tls.Config, error) {
transport, err := c.Transport()
Expand Down
55 changes: 55 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package resty

import (
"bytes"
"compress/gzip"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -1097,3 +1099,56 @@ func TestClone(t *testing.T) {
assertEqual(t, "clone", parent.UserInfo.Username)
assertEqual(t, "clone", clone.UserInfo.Username)
}

func TestResponseBodyLimit(t *testing.T) {
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
io.CopyN(w, rand.Reader, 100*800)
})
defer ts.Close()

t.Run("Client body limit", func(t *testing.T) {
c := dc().SetResponseBodyLimit(1024)

_, err := c.R().Get(ts.URL + "/")
assertNotNil(t, err)
assertEqual(t, err, ErrResponseBodyTooLarge)
})

t.Run("request body limit", func(t *testing.T) {
c := dc()

_, err := c.R().SetResponseBodyLimit(1024).Get(ts.URL + "/")
assertNotNil(t, err)
assertEqual(t, err, ErrResponseBodyTooLarge)
})

t.Run("body less than limit", func(t *testing.T) {
c := dc()

res, err := c.R().SetResponseBodyLimit(800*100 + 10).Get(ts.URL + "/")
assertNil(t, err)
assertEqual(t, 800*100, len(res.body))
})

t.Run("no body limit", func(t *testing.T) {
c := dc()

res, err := c.R().Get(ts.URL + "/")
assertNil(t, err)
assertEqual(t, 800*100, len(res.body))
})

t.Run("read error", func(t *testing.T) {
tse := createTestServer(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(hdrContentEncodingKey, "gzip")
var buf [1024]byte
w.Write(buf[:])
})
defer tse.Close()

c := dc()

_, err := c.R().SetResponseBodyLimit(10240).Get(tse.URL + "/")
assertErrorIs(t, err, gzip.ErrHeader)
})
}
15 changes: 15 additions & 0 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ type Request struct {
multipartFiles []*File
multipartFields []*MultipartField
retryConditions []RetryConditionFunc
responseBodyLimit int
}

// Generate curl command for the request.
Expand Down Expand Up @@ -600,6 +601,20 @@ func (r *Request) SetDoNotParseResponse(parse bool) *Request {
return r
}

// SetResponseBodyLimit set a max body size limit on response, avoid reading too many data to memory.
//
// Request will return [resty.ErrResponseBodyTooLarge] if uncompressed response body size if larger than limit.
// Body size limit will not be enforced in following case:
// - ResponseBodyLimit <= 0, which is the default behavior.
// - [Request.SetOutput] is called to save a response data to file.
// - "DoNotParseResponse" is set for client or request.
//
// This will override Client config.
func (r *Request) SetResponseBodyLimit(v int) *Request {
r.responseBodyLimit = v
return r
}

// SetPathParam method sets single URL path key-value pair in the
// Resty current request instance.
//
Expand Down
2 changes: 2 additions & 0 deletions resty_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,7 @@ func dclr() *Request {
}

func assertNil(t *testing.T, v interface{}) {
t.Helper()
if !isNil(v) {
t.Errorf("[%v] was expected to be nil", v)
}
Expand Down Expand Up @@ -841,6 +842,7 @@ func assertErrorIs(t *testing.T, e, g error) (r bool) {
}

func assertEqual(t *testing.T, e, g interface{}) (r bool) {
t.Helper()
if !equal(e, g) {
t.Errorf("Expected [%v], got [%v]", e, g)
}
Expand Down

0 comments on commit 10bf84f

Please sign in to comment.