diff --git a/.golangci.yml b/.golangci.yml index fc627292..d94bf062 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -132,3 +132,6 @@ issues: # We want to show examples with http.Get - linters: [noctx] path: internal/memhttp/memhttp_test.go + # We need to initialize a map of all protocol headers + - linters: [gochecknoglobals] + path: header.go diff --git a/connect_ext_test.go b/connect_ext_test.go index 42294086..81ea030a 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2631,6 +2631,70 @@ func TestBlankImportCodeGeneration(t *testing.T) { assert.NotNil(t, desc) } +// TestSetProtocolHeaders tests that headers required by the protocols are set +// overriding user provided headers. +func TestSetProtocolHeaders(t *testing.T) { + t.Parallel() + tests := []struct { + name string + clientOption connect.ClientOption + expectContentType string + }{{ + name: "connect", + expectContentType: "application/proto", + }, { + name: "grpc", + clientOption: connect.WithGRPC(), + expectContentType: "application/grpc", + }, { + name: "grpcweb", + clientOption: connect.WithGRPCWeb(), + expectContentType: "application/grpc-web+proto", + }} + for _, tt := range tests { + testcase := tt + t.Run(testcase.name, func(t *testing.T) { + t.Parallel() + pingServer := &pingServer{} + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + server := memhttptest.NewServer(t, mux) + + clientOpts := []connect.ClientOption{} + if testcase.clientOption == nil { + // Use a different protocol to test the override. + clientOpts = append(clientOpts, connect.WithGRPC()) + } + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), clientOpts...) + + pingProxyServer := &pluggablePingServer{ + ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + return client.Ping(ctx, request) + }, + } + proxyMux := http.NewServeMux() + proxyMux.Handle(pingv1connect.NewPingServiceHandler(pingProxyServer)) + proxyServer := memhttptest.NewServer(t, proxyMux) + + proxyClientOpts := []connect.ClientOption{} + if testcase.clientOption != nil { + proxyClientOpts = append(proxyClientOpts, testcase.clientOption) + } + proxyClient := pingv1connect.NewPingServiceClient(proxyServer.Client(), proxyServer.URL(), proxyClientOpts...) + + request := connect.NewRequest(&pingv1.PingRequest{Number: 42}) + request.Header().Set("X-Test", t.Name()) + response, err := proxyClient.Ping(context.Background(), request) + if !assert.Nil(t, err) { + return + } + // Assert the Content-Type is set for the proxy clients protocol and not the client's. + assert.Equal(t, response.Header().Get("Content-Type"), testcase.expectContentType) + assert.Equal(t, len(response.Header().Values("Content-Type")), 1) + }) + } +} + type unflushableWriter struct { w http.ResponseWriter } diff --git a/error_writer.go b/error_writer.go index 58ce3c42..f05d19ec 100644 --- a/error_writer.go +++ b/error_writer.go @@ -128,7 +128,7 @@ func (w *ErrorWriter) Write(response http.ResponseWriter, request *http.Request, func (w *ErrorWriter) writeConnectUnary(response http.ResponseWriter, err error) error { if connectErr, ok := asError(err); ok && !connectErr.wireErr { - mergeMetadataHeaders(response.Header(), connectErr.meta) + mergeNonProtocolHeaders(response.Header(), connectErr.meta) } response.WriteHeader(connectCodeToHTTP(CodeOf(err))) data, marshalErr := json.Marshal(newConnectWireError(err)) diff --git a/handler.go b/handler.go index 1d573291..5eab6c71 100644 --- a/handler.go +++ b/handler.go @@ -71,8 +71,8 @@ func NewUnaryHandler[Req, Res any]( if err != nil { return err } - mergeHeaders(conn.ResponseHeader(), response.Header()) - mergeHeaders(conn.ResponseTrailer(), response.Trailer()) + mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) + mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) return conn.Send(response.Any()) } diff --git a/header.go b/header.go index f3c7cacd..b3f05432 100644 --- a/header.go +++ b/header.go @@ -19,6 +19,33 @@ import ( "net/http" ) +var ( + protocolHeaders = map[string]struct{}{ + // HTTP headers. + headerContentType: {}, + headerContentLength: {}, + headerContentEncoding: {}, + headerHost: {}, + headerUserAgent: {}, + headerTrailer: {}, + headerDate: {}, + // Connect headers. + connectUnaryHeaderAcceptCompression: {}, + connectUnaryTrailerPrefix: {}, + connectStreamingHeaderCompression: {}, + connectStreamingHeaderAcceptCompression: {}, + connectHeaderTimeout: {}, + connectHeaderProtocolVersion: {}, + // gRPC headers. + grpcHeaderCompression: {}, + grpcHeaderAcceptCompression: {}, + grpcHeaderTimeout: {}, + grpcHeaderStatus: {}, + grpcHeaderMessage: {}, + grpcHeaderDetails: {}, + } +) + // EncodeBinaryHeader base64-encodes the data. It always emits unpadded values. // // In the Connect, gRPC, and gRPC-Web protocols, binary headers must have keys @@ -57,10 +84,9 @@ func mergeHeaders(into, from http.Header) { } } -// mergeMetdataHeaders merges the metadata headers from the "from" header into -// the "into" header. It skips over non metadata headers that should not be -// propagated from the server to the client. -func mergeMetadataHeaders(into, from http.Header) { +// mergeNonProtocolHeaders merges headers excluding protocol headers defined in +// protocolHeaders. +func mergeNonProtocolHeaders(into, from http.Header) { for key, vals := range from { if len(vals) == 0 { // For response trailers, net/http will pre-populate entries @@ -68,30 +94,7 @@ func mergeMetadataHeaders(into, from http.Header) { // are no actual values for those keys, we skip them. continue } - switch http.CanonicalHeaderKey(key) { - case headerContentType, - headerContentLength, - headerContentEncoding, - headerHost, - headerUserAgent, - headerTrailer, - headerDate: - // HTTP headers. - case connectUnaryHeaderAcceptCompression, - connectUnaryTrailerPrefix, - connectStreamingHeaderCompression, - connectStreamingHeaderAcceptCompression, - connectHeaderTimeout, - connectHeaderProtocolVersion: - // Connect headers. - case grpcHeaderCompression, - grpcHeaderAcceptCompression, - grpcHeaderTimeout, - grpcHeaderStatus, - grpcHeaderMessage, - grpcHeaderDetails: - // gRPC headers. - default: + if _, isProtocolHeader := protocolHeaders[key]; !isProtocolHeader { into[key] = append(into[key], vals...) } } diff --git a/protocol_connect.go b/protocol_connect.go index e3c5e4a5..bf26f1aa 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -765,7 +765,7 @@ func (hc *connectUnaryHandlerConn) mergeResponseHeader(err error) { } if err != nil { if connectErr, ok := asError(err); ok && !connectErr.wireErr { - mergeMetadataHeaders(header, connectErr.meta) + mergeNonProtocolHeaders(header, connectErr.meta) } } for k, v := range hc.responseTrailer { @@ -850,7 +850,7 @@ func (m *connectStreamingMarshaler) MarshalEndStream(err error, trailer http.Hea if err != nil { end.Error = newConnectWireError(err) if connectErr, ok := asError(err); ok && !connectErr.wireErr { - mergeMetadataHeaders(end.Trailer, connectErr.meta) + mergeNonProtocolHeaders(end.Trailer, connectErr.meta) } } data, marshalErr := json.Marshal(end) diff --git a/protocol_grpc.go b/protocol_grpc.go index 32b116f4..5addf31e 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -841,7 +841,7 @@ func grpcErrorToTrailer(trailer http.Header, protobuf Codec, err error) { return } if connectErr, ok := asError(err); ok && !connectErr.wireErr { - mergeMetadataHeaders(trailer, connectErr.meta) + mergeNonProtocolHeaders(trailer, connectErr.meta) } var ( status = grpcStatusFromError(err)