diff --git a/server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java b/server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java index 7fc13e252768b..4edfa437eafaa 100644 --- a/server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java +++ b/server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java @@ -33,6 +33,7 @@ import org.elasticsearch.action.admin.cluster.state.ClusterStateAction; import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest; import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; @@ -373,9 +374,11 @@ void forceConnect() { private void connect(ActionListener connectListener, boolean forceRun) { final boolean runConnect; final Collection> toNotify; + final ActionListener listener = connectListener == null ? null : + ContextPreservingActionListener.wrapPreservingContext(connectListener, transportService.getThreadPool().getThreadContext()); synchronized (queue) { - if (connectListener != null && queue.offer(connectListener) == false) { - connectListener.onFailure(new RejectedExecutionException("connect queue is full")); + if (listener != null && queue.offer(listener) == false) { + listener.onFailure(new RejectedExecutionException("connect queue is full")); return; } if (forceRun == false && queue.isEmpty()) { diff --git a/server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java b/server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java index e1fe265a4d7e2..e082a4d0ef8e3 100644 --- a/server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java +++ b/server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java @@ -48,6 +48,7 @@ import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.CancellableThreads; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.core.internal.io.IOUtils; @@ -562,6 +563,64 @@ public void testFetchShards() throws Exception { } } + public void testFetchShardsThreadContextHeader() throws Exception { + List knownNodes = new CopyOnWriteArrayList<>(); + try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT); + MockTransportService discoverableTransport = startTransport("discoverable_node", knownNodes, Version.CURRENT)) { + DiscoveryNode seedNode = seedTransport.getLocalDiscoNode(); + knownNodes.add(seedTransport.getLocalDiscoNode()); + knownNodes.add(discoverableTransport.getLocalDiscoNode()); + Collections.shuffle(knownNodes, random()); + try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) { + service.start(); + service.acceptIncomingRequests(); + List nodes = Collections.singletonList(seedNode); + try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster", + nodes, service, Integer.MAX_VALUE, n -> true)) { + SearchRequest request = new SearchRequest("test-index"); + Thread[] threads = new Thread[10]; + for (int i = 0; i < threads.length; i++) { + final String threadId = Integer.toString(i); + threads[i] = new Thread(() -> { + ThreadContext threadContext = seedTransport.threadPool.getThreadContext(); + threadContext.putHeader("threadId", threadId); + AtomicReference reference = new AtomicReference<>(); + AtomicReference failReference = new AtomicReference<>(); + final ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest("test-index") + .indicesOptions(request.indicesOptions()).local(true).preference(request.preference()) + .routing(request.routing()); + CountDownLatch responseLatch = new CountDownLatch(1); + connection.fetchSearchShards(searchShardsRequest, + new LatchedActionListener<>(ActionListener.wrap( + resp -> { + reference.set(resp); + assertEquals(threadId, seedTransport.threadPool.getThreadContext().getHeader("threadId")); + }, + failReference::set), responseLatch)); + try { + responseLatch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + assertNull(failReference.get()); + assertNotNull(reference.get()); + ClusterSearchShardsResponse clusterSearchShardsResponse = reference.get(); + assertEquals(knownNodes, Arrays.asList(clusterSearchShardsResponse.getNodes())); + }); + } + for (int i = 0; i < threads.length; i++) { + threads[i].start(); + } + + for (int i = 0; i < threads.length; i++) { + threads[i].join(); + } + assertTrue(connection.assertNoRunningConnections()); + } + } + } + } + public void testFetchShardsSkipUnavailable() throws Exception { List knownNodes = new CopyOnWriteArrayList<>(); try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT)) {