Skip to content

Commit

Permalink
add generic type
Browse files Browse the repository at this point in the history
Signed-off-by: mloufra <[email protected]>
  • Loading branch information
mloufra committed Oct 25, 2022
1 parent 58f2c6e commit 5f59f41
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
*/
public class NamedWriteableRegistryParseRequest extends TransportRequest {

private final Class categoryClass;
private final Class<? extends NamedWriteable> categoryClass;
private byte[] context;

/**
* @param categoryClass Class category for this parse request
* @param context StreamInput object to convert into a byte array and transport to the extension
* @throws IllegalArgumentException if context bytes could not be read
*/
public NamedWriteableRegistryParseRequest(Class categoryClass, StreamInput context) {
public NamedWriteableRegistryParseRequest(Class<? extends NamedWriteable> categoryClass, StreamInput context) {
try {
byte[] streamInputBytes = context.readAllBytes();
this.categoryClass = categoryClass;
Expand All @@ -42,10 +42,11 @@ public NamedWriteableRegistryParseRequest(Class categoryClass, StreamInput conte
* @param in StreamInput from which class fields are read from
* @throws IllegalArgumentException if the fully qualified class name is invalid and the class object cannot be generated at runtime
*/
@SuppressWarnings("unchecked")
public NamedWriteableRegistryParseRequest(StreamInput in) throws IOException {
super(in);
try {
this.categoryClass = Class.forName(in.readString());
this.categoryClass = (Class<? extends NamedWriteable>) Class.forName(in.readString());
this.context = in.readByteArray();
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Category class definition not found", e);
Expand Down Expand Up @@ -85,7 +86,7 @@ public int hashCode() {
/**
* Returns the class instance of the category class sent over by the SDK
*/
public Class getCategoryClass() {
public Class<? extends NamedWriteable> getCategoryClass() {
return this.categoryClass;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,29 @@
*/
public class NamedWriteableRegistryResponse extends TransportResponse {

private final Map<String, Class> registry;
private final Map<String, Class<? extends NamedWriteable>> registry;

/**
* @param registry Map of writeable names and their associated category class
*/
public NamedWriteableRegistryResponse(Map<String, Class> registry) {
public NamedWriteableRegistryResponse(Map<String, Class<? extends NamedWriteable>> registry) {
this.registry = new HashMap<>(registry);
}

/**
* @param in StreamInput from which map entries of writeable names and their associated category classes are read from
* @throws IllegalArgumentException if the fully qualified class name is invalid and the class object cannot be generated at runtime
*/
@SuppressWarnings("unchecked")
public NamedWriteableRegistryResponse(StreamInput in) throws IOException {
super(in);
// Stream output for registry map begins with a variable integer that tells us the number of entries being sent across the wire
Map<String, Class> registry = new HashMap<>();
Map<String, Class<? extends NamedWriteable>> registry = new HashMap<>();
int registryEntryCount = in.readVInt();
for (int i = 0; i < registryEntryCount; i++) {
try {
String name = in.readString();
Class categoryClass = Class.forName(in.readString());
Class<? extends NamedWriteable> categoryClass = (Class<? extends NamedWriteable>) Class.forName(in.readString());
registry.put(name, categoryClass);
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Category class definition not found", e);
Expand All @@ -57,7 +58,7 @@ public NamedWriteableRegistryResponse(StreamInput in) throws IOException {
public void writeTo(StreamOutput out) throws IOException {
// Stream out registry size prior to streaming out registry entries
out.writeVInt(this.registry.size());
for (Map.Entry<String, Class> entry : registry.entrySet()) {
for (Map.Entry<String, Class<? extends NamedWriteable>> entry : registry.entrySet()) {
out.writeString(entry.getKey()); // Unique named writeable name
out.writeString(entry.getValue().getName()); // Fully qualified category class name
}
Expand All @@ -84,7 +85,7 @@ public int hashCode() {
/**
* Returns a map of writeable names and their associated category class
*/
public Map<String, Class> getRegistry() {
public Map<String, Class<? extends NamedWriteable>> getRegistry() {
return Collections.unmodifiableMap(this.registry);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.NamedWriteable;
import org.opensearch.extensions.ExtensionsOrchestrator.OpenSearchRequestType;
import org.opensearch.transport.TransportService;

Expand All @@ -29,7 +30,7 @@ public class ExtensionNamedWriteableRegistry {

private static final Logger logger = LogManager.getLogger(ExtensionNamedWriteableRegistry.class);

private Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionNamedWriteableRegistry;
private Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionNamedWriteableRegistry;
private List<DiscoveryExtension> extensionsInitializedList;
private TransportService transportService;

Expand All @@ -54,7 +55,8 @@ public void getNamedWriteables() {
// Retrieve named writeable registry entries from each extension
for (DiscoveryNode extensionNode : extensionsInitializedList) {
try {
Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionRegistry = getNamedWriteables(extensionNode);
Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionRegistry =
getNamedWriteables(extensionNode);
if (extensionRegistry.isEmpty() == false) {
this.extensionNamedWriteableRegistry.putAll(extensionRegistry);
}
Expand All @@ -74,8 +76,9 @@ public void getNamedWriteables() {
* @throws UnknownHostException if connection to the extension node failed
* @return A map of category classes and their associated names and readers for this discovery node
*/
private Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getNamedWriteables(DiscoveryNode extensionNode)
throws UnknownHostException {
private Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> getNamedWriteables(
DiscoveryNode extensionNode
) throws UnknownHostException {
NamedWriteableRegistryResponseHandler namedWriteableRegistryResponseHandler = new NamedWriteableRegistryResponseHandler(
extensionNode,
transportService,
Expand Down Expand Up @@ -104,7 +107,7 @@ private Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getNamedWri
* @throws IllegalArgumentException if there is no reader associated with the given category class and name
* @return A map of the discovery node and its associated extension reader
*/
public Map<DiscoveryNode, ExtensionReader> getExtensionReader(Class categoryClass, String name) {
public Map<DiscoveryNode, ExtensionReader> getExtensionReader(Class<? extends NamedWriteable> categoryClass, String name) {

ExtensionReader reader = null;
DiscoveryNode extension = null;
Expand Down Expand Up @@ -133,9 +136,11 @@ public Map<DiscoveryNode, ExtensionReader> getExtensionReader(Class categoryClas
* @param name Unique name identifying the Writeable object
* @return The extension reader
*/
private ExtensionReader getExtensionReader(DiscoveryNode extensionNode, Class categoryClass, String name) {
private ExtensionReader getExtensionReader(DiscoveryNode extensionNode, Class<? extends NamedWriteable> categoryClass, String name) {
ExtensionReader reader = null;
Map<Class, Map<String, ExtensionReader>> categoryMap = this.extensionNamedWriteableRegistry.get(extensionNode);
Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>> categoryMap = this.extensionNamedWriteableRegistry.get(
extensionNode
);
if (categoryMap != null) {
Map<String, ExtensionReader> readerMap = categoryMap.get(categoryClass);
if (readerMap != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.NamedWriteable;
import org.opensearch.common.io.stream.NamedWriteableRegistryParseRequest;
import org.opensearch.common.io.stream.NamedWriteableRegistryResponse;
import org.opensearch.common.io.stream.StreamInput;
Expand All @@ -34,7 +35,7 @@
public class NamedWriteableRegistryResponseHandler implements TransportResponseHandler<NamedWriteableRegistryResponse> {
private static final Logger logger = LogManager.getLogger(NamedWriteableRegistryResponseHandler.class);

private final Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionRegistry;
private final Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionRegistry;
private final DiscoveryNode extensionNode;
private final TransportService transportService;
private final String requestType;
Expand All @@ -56,7 +57,7 @@ public NamedWriteableRegistryResponseHandler(DiscoveryNode extensionNode, Transp
/**
* @return A map of the given DiscoveryNode and its inner named writeable registry map
*/
public Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getExtensionRegistry() {
public Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> getExtensionRegistry() {
return Collections.unmodifiableMap(this.extensionRegistry);
}

Expand All @@ -68,7 +69,8 @@ public Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getExtension
* @param context StreamInput object to convert into a byte array and transport to the extension
* @throws UnknownHostException if connection to the extension node failed
*/
public void parseNamedWriteable(DiscoveryNode extensionNode, Class categoryClass, StreamInput context) throws UnknownHostException {
public void parseNamedWriteable(DiscoveryNode extensionNode, Class<? extends NamedWriteable> categoryClass, StreamInput context)
throws UnknownHostException {
NamedWriteableRegistryParseResponseHandler namedWriteableRegistryParseResponseHandler =
new NamedWriteableRegistryParseResponseHandler();
try {
Expand Down Expand Up @@ -98,16 +100,16 @@ public void handleResponse(NamedWriteableRegistryResponse response) {
if (response.getRegistry().isEmpty() == false) {

// Extension has sent over entries to register, initialize inner category map
Map<Class, Map<String, ExtensionReader>> categoryMap = new HashMap<>();
Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>> categoryMap = new HashMap<>();

// Reader map associated with this current category
Map<String, ExtensionReader> readers = null;
Class currentCategory = null;
Class<? extends NamedWriteable> currentCategory = null;

for (Map.Entry<String, Class> entry : response.getRegistry().entrySet()) {
for (Map.Entry<String, Class<? extends NamedWriteable>> entry : response.getRegistry().entrySet()) {

String name = entry.getKey();
Class categoryClass = entry.getValue();
Class<? extends NamedWriteable> categoryClass = entry.getValue();
if (currentCategory != categoryClass) {
// After first pass, readers and current category are set
if (currentCategory != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ public void testNamedWriteableRegistryResponseHandler() throws Exception {
String requestType = ExtensionsOrchestrator.REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY;

// Create response to pass to response handler
Map<String, Class> responseRegistry = new HashMap<>();
Map<String, Class<? extends NamedWriteable>> responseRegistry = new HashMap<>();
responseRegistry.put(Example.NAME, Example.class);
NamedWriteableRegistryResponse response = new NamedWriteableRegistryResponse(responseRegistry);

Expand All @@ -761,10 +761,11 @@ public void testNamedWriteableRegistryResponseHandler() throws Exception {
responseHandler.handleResponse(response);

// Ensure that response entries have been processed correctly into their respective maps
Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionsRegistry = responseHandler.getExtensionRegistry();
Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionsRegistry = responseHandler
.getExtensionRegistry();
assertEquals(extensionsRegistry.size(), 1);

Map<Class, Map<String, ExtensionReader>> categoryMap = extensionsRegistry.get(extensionNode);
Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>> categoryMap = extensionsRegistry.get(extensionNode);
assertEquals(categoryMap.size(), 1);

Map<String, ExtensionReader> readerMap = categoryMap.get(Example.class);
Expand Down Expand Up @@ -798,7 +799,7 @@ public void testParseNamedWriteables() throws Exception {
String requestType = ExtensionsOrchestrator.REQUEST_OPENSEARCH_PARSE_NAMED_WRITEABLE;
List<DiscoveryExtension> extensionsList = new ArrayList<>(extensionsOrchestrator.extensionIdMap.values());
DiscoveryNode extensionNode = extensionsList.get(0);
Class categoryClass = Example.class;
Class<? extends NamedWriteable> categoryClass = Example.class;

// convert context into an input stream then stream input for mock
byte[] context = new byte[0];
Expand Down

0 comments on commit 5f59f41

Please sign in to comment.