Skip to content

Commit

Permalink
Merge pull request #36 from basecamp/sse
Browse files Browse the repository at this point in the history
Don't buffer SSE responses
  • Loading branch information
kevinmcconnell authored Oct 2, 2024
2 parents 745eeec + dbb4824 commit ba1d549
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 0 deletions.
7 changes: 7 additions & 0 deletions internal/server/logging_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,10 @@ func (r *loggerResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
}
return con, rw, err
}

func (r *loggerResponseWriter) Flush() {
flusher, ok := r.ResponseWriter.(http.Flusher)
if ok {
flusher.Flush()
}
}
26 changes: 26 additions & 0 deletions internal/server/response_buffer_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type bufferedResponseWriter struct {
buffer *Buffer
hijacked bool
headerWritten bool
bypass bool
}

func (w *bufferedResponseWriter) Send() error {
Expand All @@ -61,6 +62,7 @@ func (w *bufferedResponseWriter) Send() error {
if w.headerWritten {
w.ResponseWriter.WriteHeader(w.statusCode)
}

return w.buffer.Send(w.ResponseWriter)
}

Expand All @@ -72,10 +74,27 @@ func (w *bufferedResponseWriter) WriteHeader(statusCode int) {
if !w.headerWritten {
w.statusCode = statusCode
w.headerWritten = true

if w.ShouldSwitchToUnbuffered() {
w.SwitchToUnbuffered()
}
}
}

func (w *bufferedResponseWriter) ShouldSwitchToUnbuffered() bool {
return w.Header().Get("Content-Type") == "text/event-stream"
}

func (w *bufferedResponseWriter) SwitchToUnbuffered() {
w.bypass = true
_ = w.Send()
}

func (w *bufferedResponseWriter) Write(data []byte) (int, error) {
if w.bypass {
return w.ResponseWriter.Write(data)
}

n, err := w.buffer.Write(data)
if err == ErrMaximumSizeExceeded {
// Returning an error here will cause the ReverseProxy to panic. If the
Expand All @@ -95,3 +114,10 @@ func (w *bufferedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
}
return nil, nil, http.ErrNotSupported
}

func (w *bufferedResponseWriter) Flush() {
flusher, ok := w.ResponseWriter.(http.Flusher)
if ok {
flusher.Flush()
}
}
7 changes: 7 additions & 0 deletions internal/server/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,10 @@ func (r *targetResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
r.inflightRequest.hijacked = true
return hijacker.Hijack()
}

func (r *targetResponseWriter) Flush() {
flusher, ok := r.ResponseWriter.(http.Flusher)
if ok {
flusher.Flush()
}
}
58 changes: 58 additions & 0 deletions internal/server/target_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"bufio"
"context"
"net"
"net/http"
Expand Down Expand Up @@ -30,6 +31,63 @@ func TestTarget_Serve(t *testing.T) {
require.Equal(t, "ok", string(w.Body.String()))
}

func TestTarget_ServeSSE(t *testing.T) {
receiveSSEMessage := func(bufferRequests, bufferResponses bool) (string, error) {
finishedReading := make(chan struct{})

targetOptions := TargetOptions{
BufferRequests: bufferRequests,
BufferResponses: bufferResponses,
MaxMemoryBufferSize: DefaultMaxMemoryBufferSize,
HealthCheckConfig: defaultHealthCheckConfig,
}

target := testTargetWithOptions(t, targetOptions, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.Write([]byte("data: hello\n\n"))
w.(http.Flusher).Flush()

// Don't return until the client has finished reading. Fail the test if this takes too long.
select {
case <-finishedReading:
break
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for client to finish reading")
}
})

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r, err := target.StartRequest(r)
require.NoError(t, err)
target.SendRequest(w, r)
}))
defer server.Close()
defer close(finishedReading)

resp, err := http.Get(server.URL)
require.NoError(t, err)

scanner := bufio.NewScanner(resp.Body)
if !scanner.Scan() {
return "", scanner.Err()
}

return scanner.Text(), nil
}

t.Run("without buffering", func(t *testing.T) {
message, err := receiveSSEMessage(false, false)
require.NoError(t, err)
assert.Equal(t, "data: hello", message)
})

t.Run("with buffering", func(t *testing.T) {
message, err := receiveSSEMessage(true, true)
require.NoError(t, err)
assert.Equal(t, "data: hello", message)
})
}

func TestTarget_ServeWebSocket(t *testing.T) {
sendWebsocketMessage := func(bufferRequests, bufferResponses bool, body string) (websocket.MessageType, []byte, error) {
targetOptions := TargetOptions{
Expand Down

0 comments on commit ba1d549

Please sign in to comment.