Skip to content

Commit

Permalink
Merge pull request #14 from thejoeejoee/feat-sub-http-response-status
Browse files Browse the repository at this point in the history
[watermill-http] Custom HTTP response status
  • Loading branch information
m110 authored Jun 27, 2024
2 parents 92154bf + 0103e46 commit 03bd46c
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 6 deletions.
30 changes: 30 additions & 0 deletions pkg/http/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package http

import (
"context"
"github.com/ThreeDotsLabs/watermill/message"
)

// ctxResponseStatusCodeKey is a context key for the http status code in the message context
type ctxResponseStatusCodeKey struct{}

// StatusCodeFromContext returns the status code from the context.
func StatusCodeFromContext(ctx context.Context, otherwise int) int {
if v := ctx.Value(ctxResponseStatusCodeKey{}); v != nil {
if code, ok := v.(int); ok {
return code
}
}
return otherwise
}

// WithResponseStatusCode returns a new context with the status code.
func WithResponseStatusCode(ctx context.Context, code int) context.Context {
return context.WithValue(ctx, ctxResponseStatusCodeKey{}, code)
}

// SetResponseStatusCode sets a http status code to the given message.
func SetResponseStatusCode(m *message.Message, code int) *message.Message {
m.SetContext(WithResponseStatusCode(m.Context(), code))
return m
}
38 changes: 38 additions & 0 deletions pkg/http/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,44 @@ func TestHttpPubSub(t *testing.T) {
})
}

func TestHttpSubStatusCode(t *testing.T) {
pub, sub := createPubSub(t)

defer func() {
require.NoError(t, pub.Close())
require.NoError(t, sub.Close())
}()

msgs, err := sub.Subscribe(context.Background(), "/test")
require.NoError(t, err)

go func() {
_ = sub.StartHTTPServer()
}()

waitForHTTP(t, sub, time.Second*10)

t.Run("response with custom http status code", func(t *testing.T) {
go func() {
select {
case <-time.After(time.Second * 10):
return
case msg := <-msgs:
http.SetResponseStatusCode(msg, nethttp.StatusForbidden)
msg.Nack()
}
}()

req, err := nethttp.NewRequest(nethttp.MethodPost, fmt.Sprintf("http://%s/test", sub.Addr()), nil)
require.NoError(t, err)

resp, err := nethttp.DefaultClient.Do(req)
require.NoError(t, err)

require.Equal(t, nethttp.StatusForbidden, resp.StatusCode)
})
}

func waitForHTTP(t *testing.T, sub *http.Subscriber, timeoutTime time.Duration) {
timeout := time.After(timeoutTime)
for {
Expand Down
15 changes: 9 additions & 6 deletions pkg/http/subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,17 @@ func (s *Subscriber) Subscribe(ctx context.Context, url string) (<-chan *message
s.logger.Trace("Waiting for ACK", logFields)
select {
case <-msg.Acked():
s.logger.Trace("Message acknowledged", logFields.Add(watermill.LogFields{"err": err}))
w.WriteHeader(http.StatusOK)
code := StatusCodeFromContext(msg.Context(), http.StatusOK)
s.logger.Trace("Message acknowledged", logFields.Add(watermill.LogFields{"err": err, "http_status_code": code}))
w.WriteHeader(code)
case <-msg.Nacked():
s.logger.Trace("Message nacked", logFields.Add(watermill.LogFields{"err": err}))
w.WriteHeader(http.StatusInternalServerError)
code := StatusCodeFromContext(msg.Context(), http.StatusInternalServerError)
s.logger.Trace("Message nacked", logFields.Add(watermill.LogFields{"err": err, "http_status_code": code}))
w.WriteHeader(code)
case <-r.Context().Done():
s.logger.Info("Request stopped without ACK received", logFields)
w.WriteHeader(http.StatusInternalServerError)
code := StatusCodeFromContext(msg.Context(), http.StatusInternalServerError)
s.logger.Info("Request stopped without ACK received", logFields.Add(watermill.LogFields{"http_status_code": code}))
w.WriteHeader(code)
}
})

Expand Down

0 comments on commit 03bd46c

Please sign in to comment.