Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature/extensions]Enforce type safety for NamedWriteableRegistryParseRequest #4923

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Modified EnvironmentSettingsRequest to pass entire Settings object ([#4731](https:/opensearch-project/OpenSearch/pull/4731))
- Added contentParser method to ExtensionRestRequest ([#4760](https:/opensearch-project/OpenSearch/pull/4760))
- Enforce type safety for RegisterTransportActionsRequest([#4796](https:/opensearch-project/OpenSearch/pull/4796))
- Enforce type safety for NamedWriteableRegistryParseRequest ([#4923](https:/opensearch-project/OpenSearch/pull/4923))

## [2.x]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
*/
public class NamedWriteableRegistryParseRequest extends TransportRequest {

private final Class categoryClass;
private final Class<? extends NamedWriteable> categoryClass;
private byte[] context;

/**
* @param categoryClass Class category for this parse request
* @param context StreamInput object to convert into a byte array and transport to the extension
* @throws IllegalArgumentException if context bytes could not be read
*/
public NamedWriteableRegistryParseRequest(Class categoryClass, StreamInput context) {
public NamedWriteableRegistryParseRequest(Class<? extends NamedWriteable> categoryClass, StreamInput context) {
try {
byte[] streamInputBytes = context.readAllBytes();
this.categoryClass = categoryClass;
Expand All @@ -42,10 +42,11 @@ public NamedWriteableRegistryParseRequest(Class categoryClass, StreamInput conte
* @param in StreamInput from which class fields are read from
* @throws IllegalArgumentException if the fully qualified class name is invalid and the class object cannot be generated at runtime
*/
@SuppressWarnings("unchecked")
public NamedWriteableRegistryParseRequest(StreamInput in) throws IOException {
super(in);
try {
this.categoryClass = Class.forName(in.readString());
this.categoryClass = (Class<? extends NamedWriteable>) Class.forName(in.readString());
this.context = in.readByteArray();
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Category class definition not found", e);
Expand Down Expand Up @@ -85,7 +86,7 @@ public int hashCode() {
/**
* Returns the class instance of the category class sent over by the SDK
*/
public Class getCategoryClass() {
public Class<? extends NamedWriteable> getCategoryClass() {
return this.categoryClass;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,29 @@
*/
public class NamedWriteableRegistryResponse extends TransportResponse {

private final Map<String, Class> registry;
private final Map<String, Class<? extends NamedWriteable>> registry;

/**
* @param registry Map of writeable names and their associated category class
*/
public NamedWriteableRegistryResponse(Map<String, Class> registry) {
public NamedWriteableRegistryResponse(Map<String, Class<? extends NamedWriteable>> registry) {
this.registry = new HashMap<>(registry);
}

/**
* @param in StreamInput from which map entries of writeable names and their associated category classes are read from
* @throws IllegalArgumentException if the fully qualified class name is invalid and the class object cannot be generated at runtime
*/
@SuppressWarnings("unchecked")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally we try to put @SuppressWarnings at the lowest scope possible. In the constructor we were assigning a value to an instance variable, so we had to use the method level, but here we are assigning to a local variable. You should be able to put this suppression down above line 47.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, I will move @SuppressWarnings in next commit

public NamedWriteableRegistryResponse(StreamInput in) throws IOException {
super(in);
// Stream output for registry map begins with a variable integer that tells us the number of entries being sent across the wire
Map<String, Class> registry = new HashMap<>();
Map<String, Class<? extends NamedWriteable>> registry = new HashMap<>();
int registryEntryCount = in.readVInt();
for (int i = 0; i < registryEntryCount; i++) {
try {
String name = in.readString();
Class categoryClass = Class.forName(in.readString());
Class<? extends NamedWriteable> categoryClass = (Class<? extends NamedWriteable>) Class.forName(in.readString());
registry.put(name, categoryClass);
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Category class definition not found", e);
Expand All @@ -57,7 +58,7 @@ public NamedWriteableRegistryResponse(StreamInput in) throws IOException {
public void writeTo(StreamOutput out) throws IOException {
// Stream out registry size prior to streaming out registry entries
out.writeVInt(this.registry.size());
for (Map.Entry<String, Class> entry : registry.entrySet()) {
for (Map.Entry<String, Class<? extends NamedWriteable>> entry : registry.entrySet()) {
out.writeString(entry.getKey()); // Unique named writeable name
out.writeString(entry.getValue().getName()); // Fully qualified category class name
}
Expand All @@ -84,7 +85,7 @@ public int hashCode() {
/**
* Returns a map of writeable names and their associated category class
*/
public Map<String, Class> getRegistry() {
public Map<String, Class<? extends NamedWriteable>> getRegistry() {
return Collections.unmodifiableMap(this.registry);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.NamedWriteable;
import org.opensearch.extensions.ExtensionsOrchestrator.OpenSearchRequestType;
import org.opensearch.transport.TransportService;

Expand All @@ -29,7 +30,7 @@ public class ExtensionNamedWriteableRegistry {

private static final Logger logger = LogManager.getLogger(ExtensionNamedWriteableRegistry.class);

private Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionNamedWriteableRegistry;
private Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionNamedWriteableRegistry;
private List<DiscoveryExtension> extensionsInitializedList;
private TransportService transportService;

Expand All @@ -54,7 +55,8 @@ public void getNamedWriteables() {
// Retrieve named writeable registry entries from each extension
for (DiscoveryNode extensionNode : extensionsInitializedList) {
try {
Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionRegistry = getNamedWriteables(extensionNode);
Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionRegistry =
getNamedWriteables(extensionNode);
if (extensionRegistry.isEmpty() == false) {
this.extensionNamedWriteableRegistry.putAll(extensionRegistry);
}
Expand All @@ -74,8 +76,9 @@ public void getNamedWriteables() {
* @throws UnknownHostException if connection to the extension node failed
* @return A map of category classes and their associated names and readers for this discovery node
*/
private Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getNamedWriteables(DiscoveryNode extensionNode)
throws UnknownHostException {
private Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> getNamedWriteables(
DiscoveryNode extensionNode
) throws UnknownHostException {
NamedWriteableRegistryResponseHandler namedWriteableRegistryResponseHandler = new NamedWriteableRegistryResponseHandler(
extensionNode,
transportService,
Expand Down Expand Up @@ -104,7 +107,7 @@ private Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getNamedWri
* @throws IllegalArgumentException if there is no reader associated with the given category class and name
* @return A map of the discovery node and its associated extension reader
*/
public Map<DiscoveryNode, ExtensionReader> getExtensionReader(Class categoryClass, String name) {
public Map<DiscoveryNode, ExtensionReader> getExtensionReader(Class<? extends NamedWriteable> categoryClass, String name) {

ExtensionReader reader = null;
DiscoveryNode extension = null;
Expand Down Expand Up @@ -133,9 +136,11 @@ public Map<DiscoveryNode, ExtensionReader> getExtensionReader(Class categoryClas
* @param name Unique name identifying the Writeable object
* @return The extension reader
*/
private ExtensionReader getExtensionReader(DiscoveryNode extensionNode, Class categoryClass, String name) {
private ExtensionReader getExtensionReader(DiscoveryNode extensionNode, Class<? extends NamedWriteable> categoryClass, String name) {
ExtensionReader reader = null;
Map<Class, Map<String, ExtensionReader>> categoryMap = this.extensionNamedWriteableRegistry.get(extensionNode);
Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>> categoryMap = this.extensionNamedWriteableRegistry.get(
extensionNode
);
if (categoryMap != null) {
Map<String, ExtensionReader> readerMap = categoryMap.get(categoryClass);
if (readerMap != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.NamedWriteable;
import org.opensearch.common.io.stream.NamedWriteableRegistryParseRequest;
import org.opensearch.common.io.stream.NamedWriteableRegistryResponse;
import org.opensearch.common.io.stream.StreamInput;
Expand All @@ -34,7 +35,7 @@
public class NamedWriteableRegistryResponseHandler implements TransportResponseHandler<NamedWriteableRegistryResponse> {
private static final Logger logger = LogManager.getLogger(NamedWriteableRegistryResponseHandler.class);

private final Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionRegistry;
private final Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionRegistry;
private final DiscoveryNode extensionNode;
private final TransportService transportService;
private final String requestType;
Expand All @@ -56,7 +57,7 @@ public NamedWriteableRegistryResponseHandler(DiscoveryNode extensionNode, Transp
/**
* @return A map of the given DiscoveryNode and its inner named writeable registry map
*/
public Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getExtensionRegistry() {
public Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> getExtensionRegistry() {
return Collections.unmodifiableMap(this.extensionRegistry);
}

Expand All @@ -68,7 +69,8 @@ public Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getExtension
* @param context StreamInput object to convert into a byte array and transport to the extension
* @throws UnknownHostException if connection to the extension node failed
*/
public void parseNamedWriteable(DiscoveryNode extensionNode, Class categoryClass, StreamInput context) throws UnknownHostException {
public void parseNamedWriteable(DiscoveryNode extensionNode, Class<? extends NamedWriteable> categoryClass, StreamInput context)
throws UnknownHostException {
NamedWriteableRegistryParseResponseHandler namedWriteableRegistryParseResponseHandler =
new NamedWriteableRegistryParseResponseHandler();
try {
Expand Down Expand Up @@ -98,16 +100,16 @@ public void handleResponse(NamedWriteableRegistryResponse response) {
if (response.getRegistry().isEmpty() == false) {

// Extension has sent over entries to register, initialize inner category map
Map<Class, Map<String, ExtensionReader>> categoryMap = new HashMap<>();
Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>> categoryMap = new HashMap<>();

// Reader map associated with this current category
Map<String, ExtensionReader> readers = null;
Class currentCategory = null;
Class<? extends NamedWriteable> currentCategory = null;

for (Map.Entry<String, Class> entry : response.getRegistry().entrySet()) {
for (Map.Entry<String, Class<? extends NamedWriteable>> entry : response.getRegistry().entrySet()) {

String name = entry.getKey();
Class categoryClass = entry.getValue();
Class<? extends NamedWriteable> categoryClass = entry.getValue();
if (currentCategory != categoryClass) {
// After first pass, readers and current category are set
if (currentCategory != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ public void testNamedWriteableRegistryResponseHandler() throws Exception {
String requestType = ExtensionsOrchestrator.REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY;

// Create response to pass to response handler
Map<String, Class> responseRegistry = new HashMap<>();
Map<String, Class<? extends NamedWriteable>> responseRegistry = new HashMap<>();
responseRegistry.put(Example.NAME, Example.class);
NamedWriteableRegistryResponse response = new NamedWriteableRegistryResponse(responseRegistry);

Expand All @@ -761,10 +761,11 @@ public void testNamedWriteableRegistryResponseHandler() throws Exception {
responseHandler.handleResponse(response);

// Ensure that response entries have been processed correctly into their respective maps
Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionsRegistry = responseHandler.getExtensionRegistry();
Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionsRegistry = responseHandler
.getExtensionRegistry();
assertEquals(extensionsRegistry.size(), 1);

Map<Class, Map<String, ExtensionReader>> categoryMap = extensionsRegistry.get(extensionNode);
Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>> categoryMap = extensionsRegistry.get(extensionNode);
assertEquals(categoryMap.size(), 1);

Map<String, ExtensionReader> readerMap = categoryMap.get(Example.class);
Expand Down Expand Up @@ -798,7 +799,7 @@ public void testParseNamedWriteables() throws Exception {
String requestType = ExtensionsOrchestrator.REQUEST_OPENSEARCH_PARSE_NAMED_WRITEABLE;
List<DiscoveryExtension> extensionsList = new ArrayList<>(extensionsOrchestrator.extensionIdMap.values());
DiscoveryNode extensionNode = extensionsList.get(0);
Class categoryClass = Example.class;
Class<? extends NamedWriteable> categoryClass = Example.class;

// convert context into an input stream then stream input for mock
byte[] context = new byte[0];
Expand Down