Skip to content

Commit

Permalink
[Feature/extensions]Enforce type safety for RegisterTransportActionsR…
Browse files Browse the repository at this point in the history
…equest (#4796)

* fix compiler error for test file

Signed-off-by: mloufra <[email protected]>

* fix test class RegisterTransportActionsRequestTests

Signed-off-by: mloufra <[email protected]>

Signed-off-by: mloufra <[email protected]>
  • Loading branch information
mloufra authored Oct 21, 2022
1 parent 401c210 commit 58f2c6e
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Return consumed params and content from extensions ([#4705](https:/opensearch-project/OpenSearch/pull/4705))
- Modified EnvironmentSettingsRequest to pass entire Settings object ([#4731](https:/opensearch-project/OpenSearch/pull/4731))
- Added contentParser method to ExtensionRestRequest ([#4760](https:/opensearch-project/OpenSearch/pull/4760))
- Enforce type safety for RegisterTransportActionsRequest([#4796](https:/opensearch-project/OpenSearch/pull/4796))

## [2.x]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

package org.opensearch.extensions;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionResponse;
import org.opensearch.action.support.TransportAction;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.transport.TransportRequest;
Expand All @@ -16,6 +19,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Map.Entry;

/**
* Request to register extension Transport actions
Expand All @@ -24,22 +28,27 @@
*/
public class RegisterTransportActionsRequest extends TransportRequest {
private String uniqueId;
private Map<String, Class> transportActions;
private Map<String, Class<? extends TransportAction<? extends ActionRequest, ? extends ActionResponse>>> transportActions;

public RegisterTransportActionsRequest(String uniqueId, Map<String, Class> transportActions) {
public RegisterTransportActionsRequest(
String uniqueId,
Map<String, Class<? extends TransportAction<? extends ActionRequest, ? extends ActionResponse>>> transportActions
) {
this.uniqueId = uniqueId;
this.transportActions = new HashMap<>(transportActions);
}

public RegisterTransportActionsRequest(StreamInput in) throws IOException {
super(in);
this.uniqueId = in.readString();
Map<String, Class> actions = new HashMap<>();
Map<String, Class<? extends TransportAction<? extends ActionRequest, ? extends ActionResponse>>> actions = new HashMap<>();
int actionCount = in.readVInt();
for (int i = 0; i < actionCount; i++) {
try {
String actionName = in.readString();
Class transportAction = Class.forName(in.readString());
@SuppressWarnings("unchecked")
Class<? extends TransportAction<? extends ActionRequest, ? extends ActionResponse>> transportAction = (Class<
? extends TransportAction<? extends ActionRequest, ? extends ActionResponse>>) Class.forName(in.readString());
actions.put(actionName, transportAction);
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Could not read transport action");
Expand All @@ -52,7 +61,9 @@ public String getUniqueId() {
return uniqueId;
}

public Map<String, Class> getTransportActions() {
/// comments

public Map<String, Class<? extends TransportAction<? extends ActionRequest, ? extends ActionResponse>>> getTransportActions() {
return transportActions;
}

Expand All @@ -61,7 +72,8 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(uniqueId);
out.writeVInt(this.transportActions.size());
for (Map.Entry<String, Class> action : transportActions.entrySet()) {
for (Entry<String, Class<? extends TransportAction<? extends ActionRequest, ? extends ActionResponse>>> action : transportActions
.entrySet()) {
out.writeString(action.getKey());
out.writeString(action.getValue().getName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.extensions;

import org.junit.Before;
import org.opensearch.action.admin.indices.create.AutoCreateAction.TransportAction;
import org.opensearch.common.collect.Map;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
Expand All @@ -21,7 +22,7 @@ public class RegisterTransportActionsRequestTests extends OpenSearchTestCase {

@Before
public void setup() {
this.originalRequest = new RegisterTransportActionsRequest("extension-uniqueId", Map.of("testAction", Map.class));
this.originalRequest = new RegisterTransportActionsRequest("extension-uniqueId", Map.of("testAction", TransportAction.class));
}

public void testRegisterTransportActionsRequest() throws IOException {
Expand All @@ -39,7 +40,7 @@ public void testRegisterTransportActionsRequest() throws IOException {
public void testToString() {
assertEquals(
originalRequest.toString(),
"TransportActionsRequest{uniqueId=extension-uniqueId, actions={testAction=class org.opensearch.common.collect.Map}}"
"TransportActionsRequest{uniqueId=extension-uniqueId, actions={testAction=class org.opensearch.action.admin.indices.create.AutoCreateAction$TransportAction}}"
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.junit.After;
import org.junit.Before;
import org.opensearch.Version;
import org.opensearch.action.admin.indices.create.AutoCreateAction.TransportAction;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
Expand Down Expand Up @@ -129,10 +130,7 @@ public void testRegisterAction() {

public void testRegisterTransportActionsRequest() {
String action = "test-action";
RegisterTransportActionsRequest request = new RegisterTransportActionsRequest(
"uniqueid1",
Map.of(action, ExtensionTransportActionsHandlerTests.class)
);
RegisterTransportActionsRequest request = new RegisterTransportActionsRequest("uniqueid1", Map.of(action, TransportAction.class));
ExtensionBooleanResponse response = (ExtensionBooleanResponse) extensionTransportActionsHandler
.handleRegisterTransportActionsRequest(request);
assertTrue(response.getStatus());
Expand Down Expand Up @@ -165,7 +163,7 @@ public void testSendTransportRequestToExtension() throws InterruptedException {
// Register Action
RegisterTransportActionsRequest registerRequest = new RegisterTransportActionsRequest(
"uniqueid1",
Map.of(action, ExtensionTransportActionsHandlerTests.class)
Map.of(action, TransportAction.class)
);
ExtensionBooleanResponse response = (ExtensionBooleanResponse) extensionTransportActionsHandler
.handleRegisterTransportActionsRequest(registerRequest);
Expand Down

0 comments on commit 58f2c6e

Please sign in to comment.