diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f9075b6a4ed3..86b72c1fefdb3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add query for initialized extensions ([#5658](https://github.com/opensearch-project/OpenSearch/pull/5658)) - Revert 'Added jackson dependency to server' and change extension reading ([#5768](https://github.com/opensearch-project/OpenSearch/pull/5768)) - Add support to disallow search request with preference parameter with strict weighted shard routing([#5874](https://github.com/opensearch-project/OpenSearch/pull/5874)) +- Replace latches with CompletableFutures for extensions ([#5646](https://github.com/opensearch-project/OpenSearch/pull/5646)) ### Dependencies - Bumps `log4j-core` from 2.18.0 to 2.19.0 diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java index 885e4b0e35ee6..ca65215599891 100644 --- a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java +++ b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java @@ -20,7 +20,9 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; @@ -198,7 +200,7 @@ public void initializeServicesAndRestHandler( * * @param request which was sent by an extension. */ - public ExtensionActionResponse handleTransportRequest(ExtensionActionRequest request) throws InterruptedException { + public ExtensionActionResponse handleTransportRequest(ExtensionActionRequest request) throws Exception { return extensionTransportActionsHandler.sendTransportRequestToExtension(request); } @@ -401,13 +403,17 @@ public String executor() { new InitializeExtensionRequest(transportService.getLocalNode(), extension), initializeExtensionResponseHandler ); - // TODO: make asynchronous - inProgressFuture.get(EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS); - } catch (Exception e) { - try { - throw e; - } catch (Exception e1) { - logger.error(e.toString()); + inProgressFuture.orTimeout(EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof TimeoutException) { + logger.info("No response from extension to request."); + } + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } else if (e.getCause() instanceof Error) { + throw (Error) e.getCause(); + } else { + throw new RuntimeException(e.getCause()); } } } @@ -462,7 +468,7 @@ public void handleResponse(AcknowledgedResponse response) { @Override public void handleException(TransportException exp) { - + inProgressIndexNameFuture.completeExceptionally(exp); } @Override @@ -506,20 +512,21 @@ public void beforeIndexRemoved( new IndicesModuleRequest(indexModule), acknowledgedResponseHandler ); - // TODO: make asynchronous - inProgressIndexNameFuture.get(EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS); - logger.info("Received ack response from Extension"); - } catch (Exception e) { - try { - throw e; - } catch (Exception e1) { - logger.error(e.toString()); - } + inProgressIndexNameFuture.whenComplete((r, e) -> { + if (e != null) { + inProgressFuture.complete(response); + } else if (e == null) { + inProgressFuture.completeExceptionally(e); + } + }); + } catch (Exception ex) { + inProgressFuture.completeExceptionally(ex); } } }); + } else { + inProgressFuture.complete(response); } - inProgressFuture.complete(response); } @Override @@ -542,14 +549,18 @@ public String executor() { new IndicesModuleRequest(indexModule), indicesModuleResponseHandler ); - // TODO: make asynchronous - inProgressFuture.get(EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS); + inProgressFuture.orTimeout(EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).join(); logger.info("Received response from Extension"); - } catch (Exception e) { - try { - throw e; - } catch (Exception e1) { - logger.error(e.toString()); + } catch (CompletionException e) { + if (e.getCause() instanceof TimeoutException) { + logger.info("No response from extension to request."); + } + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } else if (e.getCause() instanceof Error) { + throw (Error) e.getCause(); + } else { + throw new RuntimeException(e.getCause()); } } } diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java index ac3ec6630634a..f76fe794b2f84 100644 --- a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java @@ -28,8 +28,10 @@ import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; /** * This class manages TransportActions for extensions @@ -108,10 +110,9 @@ public TransportResponse handleRegisterTransportActionsRequest(RegisterTransport * @return {@link TransportResponse} which is sent back to the transport action invoker. * @throws InterruptedException when message transport fails. */ - public TransportResponse handleTransportActionRequestFromExtension(TransportActionRequestFromExtension request) - throws InterruptedException { + public TransportResponse handleTransportActionRequestFromExtension(TransportActionRequestFromExtension request) throws Exception { DiscoveryExtensionNode extension = extensionIdMap.get(request.getUniqueId()); - final CountDownLatch inProgressLatch = new CountDownLatch(1); + final CompletableFuture inProgressFuture = new CompletableFuture<>(); final TransportActionResponseToExtension response = new TransportActionResponseToExtension(new byte[0]); client.execute( ExtensionProxyAction.INSTANCE, @@ -120,7 +121,7 @@ public TransportResponse handleTransportActionRequestFromExtension(TransportActi @Override public void onResponse(ExtensionActionResponse actionResponse) { response.setResponseBytes(actionResponse.getResponseBytes()); - inProgressLatch.countDown(); + inProgressFuture.complete(actionResponse); } @Override @@ -128,11 +129,24 @@ public void onFailure(Exception exp) { logger.debug("Transport request failed", exp); byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8); response.setResponseBytes(responseBytes); - inProgressLatch.countDown(); + inProgressFuture.completeExceptionally(exp); } } ); - inProgressLatch.await(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS); + try { + inProgressFuture.orTimeout(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof TimeoutException) { + logger.info("No response from extension to request."); + } + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } else if (e.getCause() instanceof Error) { + throw (Error) e.getCause(); + } else { + throw new RuntimeException(e.getCause()); + } + } return response; } @@ -143,12 +157,12 @@ public void onFailure(Exception exp) { * @return {@link ExtensionActionResponse} which encapsulates the transport response from the extension. * @throws InterruptedException when message transport fails. */ - public ExtensionActionResponse sendTransportRequestToExtension(ExtensionActionRequest request) throws InterruptedException { + public ExtensionActionResponse sendTransportRequestToExtension(ExtensionActionRequest request) throws Exception { DiscoveryExtensionNode extension = actionsMap.get(request.getAction()); if (extension == null) { throw new ActionNotFoundTransportException(request.getAction()); } - final CountDownLatch inProgressLatch = new CountDownLatch(1); + final CompletableFuture inProgressFuture = new CompletableFuture<>(); final ExtensionActionResponse extensionActionResponse = new ExtensionActionResponse(new byte[0]); final TransportResponseHandler extensionActionResponseTransportResponseHandler = new TransportResponseHandler() { @@ -161,7 +175,7 @@ public ExtensionActionResponse read(StreamInput in) throws IOException { @Override public void handleResponse(ExtensionActionResponse response) { extensionActionResponse.setResponseBytes(response.getResponseBytes()); - inProgressLatch.countDown(); + inProgressFuture.complete(response); } @Override @@ -169,7 +183,7 @@ public void handleException(TransportException exp) { logger.debug("Transport request failed", exp); byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8); extensionActionResponse.setResponseBytes(responseBytes); - inProgressLatch.countDown(); + inProgressFuture.completeExceptionally(exp); } @Override @@ -187,7 +201,20 @@ public String executor() { } catch (Exception e) { logger.info("Failed to send transport action to extension " + extension.getName(), e); } - inProgressLatch.await(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS); + try { + inProgressFuture.orTimeout(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof TimeoutException) { + logger.info("No response from extension to request."); + } + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } else if (e.getCause() instanceof Error) { + throw (Error) e.getCause(); + } else { + throw new RuntimeException(e.getCause()); + } + } return extensionActionResponse; } } diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java b/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java index 38e92ed604a09..357be3a9fc2fe 100644 --- a/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java +++ b/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java @@ -32,8 +32,11 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Collections.unmodifiableList; @@ -122,7 +125,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC emptyList(), false ); - final CountDownLatch inProgressLatch = new CountDownLatch(1); + final CompletableFuture inProgressFuture = new CompletableFuture<>(); final TransportResponseHandler restExecuteOnExtensionResponseHandler = new TransportResponseHandler< RestExecuteOnExtensionResponse>() { @@ -143,15 +146,13 @@ public void handleResponse(RestExecuteOnExtensionResponse response) { if (response.isContentConsumed()) { request.content(); } + inProgressFuture.complete(response); } @Override public void handleException(TransportException exp) { logger.debug("REST request failed", exp); - // Status is already defaulted to 500 (INTERNAL_SERVER_ERROR) - byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8); - restExecuteOnExtensionResponse.setContent(responseBytes); - inProgressLatch.countDown(); + inProgressFuture.completeExceptionally(exp); } @Override @@ -172,15 +173,24 @@ public String executor() { new ExtensionRestRequest(method, path, params, contentType, content, requestIssuerIdentity), restExecuteOnExtensionResponseHandler ); - try { - inProgressLatch.await(5, TimeUnit.SECONDS); - } catch (InterruptedException e) { + inProgressFuture.orTimeout(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).join(); + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof TimeoutException) { return channel -> channel.sendResponse( new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, "No response from extension to request.") ); } - } catch (Exception e) { - logger.info("Failed to send REST Actions to extension " + discoveryExtensionNode.getName(), e); + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } else if (e.getCause() instanceof Error) { + throw (Error) e.getCause(); + } else { + throw new RuntimeException(e.getCause()); + } + } catch (Exception ex) { + logger.info("Failed to send REST Actions to extension " + discoveryExtensionNode.getName(), ex); + return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, ex.getMessage())); } BytesRestResponse restResponse = new BytesRestResponse( restExecuteOnExtensionResponse.getStatus(), diff --git a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java index 44cf3a38f01d1..5de2113672ca5 100644 --- a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java +++ b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java @@ -84,6 +84,8 @@ import org.opensearch.test.transport.MockTransportService; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.ConnectTransportException; +import org.opensearch.transport.NodeNotConnectedException; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportService; @@ -427,23 +429,23 @@ public void testInitialize() throws Exception { mockLogAppender.addExpectation( new MockLogAppender.SeenEventExpectation( - "Connect Transport Exception 1", + "Node Not Connected Exception 1", "org.opensearch.extensions.ExtensionsManager", Level.ERROR, - "ConnectTransportException[[firstExtension][127.0.0.0:9300] connect_timeout[30s]]" + "[secondExtension][127.0.0.1:9301] Node not connected" ) ); mockLogAppender.addExpectation( new MockLogAppender.SeenEventExpectation( - "Connect Transport Exception 2", + "Node Not Connected Exception 2", "org.opensearch.extensions.ExtensionsManager", Level.ERROR, - "ConnectTransportException[[secondExtension][127.0.0.1:9301] connect_exception]; nested: ConnectException[Connection refused];" + "[firstExtension][127.0.0.0:9300] Node not connected" ) ); - extensionsManager.initialize(); + expectThrows(ConnectTransportException.class, () -> extensionsManager.initialize()); // Test needs to be changed to mock the connection between the local node and an extension. Assert statment is commented out for // now. @@ -831,21 +833,8 @@ public void testOnIndexModule() throws Exception { new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)), Collections.emptyMap() ); + expectThrows(NodeNotConnectedException.class, () -> extensionsManager.onIndexModule(indexModule)); - try (MockLogAppender mockLogAppender = MockLogAppender.createForLoggers(LogManager.getLogger(ExtensionsManager.class))) { - - mockLogAppender.addExpectation( - new MockLogAppender.SeenEventExpectation( - "IndicesModuleRequest Failure", - "org.opensearch.extensions.ExtensionsManager", - Level.ERROR, - "IndicesModuleRequest failed" - ) - ); - - extensionsManager.onIndexModule(indexModule); - mockLogAppender.assertAllExpectationsMatched(); - } } private void initialize(ExtensionsManager extensionsManager) { diff --git a/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java b/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java index c3d6372a4f6b8..276e47d7f55a8 100644 --- a/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java +++ b/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java @@ -31,6 +31,7 @@ import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.ActionNotFoundTransportException; +import org.opensearch.transport.NodeNotConnectedException; import org.opensearch.transport.TransportService; import org.opensearch.transport.nio.MockNioTransport; @@ -172,10 +173,6 @@ public void testSendTransportRequestToExtension() throws InterruptedException { ); assertTrue(response.getStatus()); - ExtensionActionResponse extensionResponse = extensionTransportActionsHandler.sendTransportRequestToExtension(request); - assertEquals( - "Request failed: [firstExtension][127.0.0.0:9300] Node not connected", - new String(extensionResponse.getResponseBytes(), StandardCharsets.UTF_8) - ); + expectThrows(NodeNotConnectedException.class, () -> extensionTransportActionsHandler.sendTransportRequestToExtension(request)); } }