Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix duplicate http client tracing headers #2842

Merged
merged 7 commits into from
Apr 23, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ public void set(AkkaHttpHeaders carrier, String key, String value) {
HttpRequest request = carrier.getRequest();
if (request != null) {
// It looks like this cast is only needed in Java, Scala would have figured it out
carrier.setRequest((HttpRequest) request.addHeader(RawHeader.create(key, value)));
carrier.setRequest(
(HttpRequest) request.removeHeader(key).addHeader(RawHeader.create(key, value)));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ enum ClientRequestContextSetter implements TextMapSetter<ClientRequestContext> {
@Override
public void set(@Nullable ClientRequestContext carrier, String key, String value) {
if (carrier != null) {
carrier.addAdditionalRequestHeader(key, value);
carrier.setAdditionalRequestHeader(key, value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,12 @@ protected TextMapSetter<HttpRequest> getSetter() {
}

public HttpHeaders inject(HttpHeaders original) {
Map<String, List<String>> headerMap = new HashMap<>();
Map<String, List<String>> headerMap = new HashMap<>(original.map());

inject(
Context.current(),
headerMap,
(carrier, key, value) -> carrier.put(key, Collections.singletonList(value)));
headerMap.putAll(original.map());

return HttpHeaders.of(headerMap, (s, s2) -> true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
*/

import com.squareup.okhttp.Callback
import com.squareup.okhttp.Headers
import com.squareup.okhttp.MediaType
import com.squareup.okhttp.OkHttpClient
import com.squareup.okhttp.Request
Expand All @@ -27,11 +26,11 @@ class OkHttp2Test extends HttpClientTest<Request> implements AgentTestTrait {
@Override
Request buildRequest(String method, URI uri, Map<String, String> headers) {
def body = HttpMethod.requiresRequestBody(method) ? RequestBody.create(MediaType.parse("text/plain"), "") : null
return new Request.Builder()
def request = new Request.Builder()
.url(uri.toURL())
.method(method, body)
.headers(Headers.of(HeadersUtil.headersToArray(headers)))
.build()
headers.forEach({ key, value -> request.header(key, value) })
return request.build()
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ public class HeadersInjectAdapter implements TextMapSetter<HttpHeaders> {

@Override
public void set(HttpHeaders carrier, String key, String value) {
carrier.add(key, value);
carrier.set(key, value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ class HttpHeadersInjectAdapter implements TextMapSetter<ClientRequest.Builder> {

@Override
public void set(ClientRequest.Builder carrier, String key, String value) {
carrier.header(key, value);
carrier.headers(httpHeaders -> httpHeaders.set(key, value));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import static org.junit.Assume.assumeTrue

import groovy.transform.stc.ClosureParams
import groovy.transform.stc.SimpleType
import io.opentelemetry.api.GlobalOpenTelemetry
import io.opentelemetry.api.common.AttributeKey
import io.opentelemetry.api.trace.Span
import io.opentelemetry.instrumentation.test.InstrumentationSpecification
Expand Down Expand Up @@ -95,10 +96,18 @@ abstract class HttpClientTest<REQUEST> extends InstrumentationSpecification {
return sendRequest(request, method, uri, headers)
}

// ideally private, but then groovy closures in this class cannot find them
final int doReusedRequest(String method, URI uri, Map<String, String> headers = [:]) {
private int doReusedRequest(String method, URI uri) {
def request = buildRequest(method, uri, [:])
sendRequest(request, method, uri, [:])
return sendRequest(request, method, uri, [:])
}

private int doRequestWithExistingTracingHeaders(String method, URI uri) {
def headers = new HashMap()
for (String field : GlobalOpenTelemetry.getPropagators().getTextMapPropagator().fields()) {
headers.put(field, "12345789")
}
def request = buildRequest(method, uri, headers)
sendRequest(request, method, uri, headers)
return sendRequest(request, method, uri, headers)
}

Expand Down Expand Up @@ -509,6 +518,31 @@ abstract class HttpClientTest<REQUEST> extends InstrumentationSpecification {
url = server.address.resolve(path)
}

// this test verifies two things:
// * the javaagent doesn't cause multiples of tracing headers to be added
// (TestHttpServer throws exception if there are multiples)
// * the javaagent overwrites the existing tracing headers
// (so that it propagates the same trace id / span id that it reports to the backend
// and the trace is not broken)
def "request with existing tracing headers"() {
when:
def responseCode = doRequestWithExistingTracingHeaders(method, url)

then:
responseCode == 200
assertTraces(1) {
trace(0, 2 + extraClientSpans()) {
clientSpan(it, 0, null, method, url)
serverSpan(it, 1 + extraClientSpans(), span(extraClientSpans()))
}
}

where:
path = "/success"
method = "GET"
url = server.address.resolve(path)
}

def "connection error (unopened port)"() {
given:
assumeTrue(testConnectionFailure())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ class TestHttpServer implements AutoCloseable {
for (String field : GlobalOpenTelemetry.getPropagators().getTextMapPropagator().fields()) {
def headers = req.getHeaders(field)
if (headers.hasMoreElements() && headers.nextElement() && headers.hasMoreElements()) {
throw new AssertionError("more than one traceparent header present")
throw new AssertionError("more than one " + field + " header present")
}
}
}
Expand Down