diff --git a/pkg/http/context.go b/pkg/http/context.go new file mode 100644 index 0000000..e444b70 --- /dev/null +++ b/pkg/http/context.go @@ -0,0 +1,30 @@ +package http + +import ( + "context" + "github.com/ThreeDotsLabs/watermill/message" +) + +// ctxResponseStatusCodeKey is a context key for the http status code in the message context +type ctxResponseStatusCodeKey struct{} + +// StatusCodeFromContext returns the status code from the context. +func StatusCodeFromContext(ctx context.Context, otherwise int) int { + if v := ctx.Value(ctxResponseStatusCodeKey{}); v != nil { + if code, ok := v.(int); ok { + return code + } + } + return otherwise +} + +// WithResponseStatusCode returns a new context with the status code. +func WithResponseStatusCode(ctx context.Context, code int) context.Context { + return context.WithValue(ctx, ctxResponseStatusCodeKey{}, code) +} + +// SetResponseStatusCode sets a http status code to the given message. +func SetResponseStatusCode(m *message.Message, code int) *message.Message { + m.SetContext(WithResponseStatusCode(m.Context(), code)) + return m +} diff --git a/pkg/http/pubsub_test.go b/pkg/http/pubsub_test.go index 2755935..554ee9f 100644 --- a/pkg/http/pubsub_test.go +++ b/pkg/http/pubsub_test.go @@ -95,6 +95,44 @@ func TestHttpPubSub(t *testing.T) { }) } +func TestHttpSubStatusCode(t *testing.T) { + pub, sub := createPubSub(t) + + defer func() { + require.NoError(t, pub.Close()) + require.NoError(t, sub.Close()) + }() + + msgs, err := sub.Subscribe(context.Background(), "/test") + require.NoError(t, err) + + go func() { + _ = sub.StartHTTPServer() + }() + + waitForHTTP(t, sub, time.Second*10) + + t.Run("response with custom http status code", func(t *testing.T) { + go func() { + select { + case <-time.After(time.Second * 10): + return + case msg := <-msgs: + http.SetResponseStatusCode(msg, nethttp.StatusForbidden) + msg.Nack() + } + }() + + req, err := nethttp.NewRequest(nethttp.MethodPost, fmt.Sprintf("http://%s/test", sub.Addr()), nil) + require.NoError(t, err) + + resp, err := nethttp.DefaultClient.Do(req) + require.NoError(t, err) + + require.Equal(t, nethttp.StatusForbidden, resp.StatusCode) + }) +} + func waitForHTTP(t *testing.T, sub *http.Subscriber, timeoutTime time.Duration) { timeout := time.After(timeoutTime) for { diff --git a/pkg/http/subscriber.go b/pkg/http/subscriber.go index d50813e..eaca068 100644 --- a/pkg/http/subscriber.go +++ b/pkg/http/subscriber.go @@ -139,14 +139,17 @@ func (s *Subscriber) Subscribe(ctx context.Context, url string) (<-chan *message s.logger.Trace("Waiting for ACK", logFields) select { case <-msg.Acked(): - s.logger.Trace("Message acknowledged", logFields.Add(watermill.LogFields{"err": err})) - w.WriteHeader(http.StatusOK) + code := StatusCodeFromContext(msg.Context(), http.StatusOK) + s.logger.Trace("Message acknowledged", logFields.Add(watermill.LogFields{"err": err, "http_status_code": code})) + w.WriteHeader(code) case <-msg.Nacked(): - s.logger.Trace("Message nacked", logFields.Add(watermill.LogFields{"err": err})) - w.WriteHeader(http.StatusInternalServerError) + code := StatusCodeFromContext(msg.Context(), http.StatusInternalServerError) + s.logger.Trace("Message nacked", logFields.Add(watermill.LogFields{"err": err, "http_status_code": code})) + w.WriteHeader(code) case <-r.Context().Done(): - s.logger.Info("Request stopped without ACK received", logFields) - w.WriteHeader(http.StatusInternalServerError) + code := StatusCodeFromContext(msg.Context(), http.StatusInternalServerError) + s.logger.Info("Request stopped without ACK received", logFields.Add(watermill.LogFields{"http_status_code": code})) + w.WriteHeader(code) } })