Skip to content

Commit

Permalink
fix duplicate node if node has both data and ml role (#1829)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Jan 2, 2024
1 parent 741d4da commit 169df62
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public String[] getEligibleNodeIds(FunctionName functionName) {

public DiscoveryNode[] getEligibleNodes(FunctionName functionName) {
ClusterState state = this.clusterService.state();
final List<DiscoveryNode> eligibleNodes = new ArrayList<>();
final Set<DiscoveryNode> eligibleNodes = new HashSet<>();
for (DiscoveryNode node : state.nodes()) {
if (excludedNodeNames != null && excludedNodeNames.contains(node.getName())) {
continue;
Expand All @@ -88,7 +88,7 @@ public DiscoveryNode[] getEligibleNodes(FunctionName functionName) {
return eligibleNodes.toArray(new DiscoveryNode[0]);
}

private void getEligibleNodes(Set<String> allowedNodeRoles, List<DiscoveryNode> eligibleNodes, DiscoveryNode node) {
private void getEligibleNodes(Set<String> allowedNodeRoles, Set<DiscoveryNode> eligibleNodes, DiscoveryNode node) {
if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) {
eligibleNodes.add(node);
}
Expand All @@ -110,21 +110,21 @@ public String[] filterEligibleNodes(FunctionName functionName, String[] nodeIds)
continue;
}
if (functionName == FunctionName.REMOTE) {// remote model
getEligibleNodes(remoteModelEligibleNodeRoles, eligibleNodes, node);
getEligibleNodeIds(remoteModelEligibleNodeRoles, eligibleNodes, node);
} else { // local model
if (onlyRunOnMLNode) {
if (MLNodeUtils.isMLNode(node)) {
eligibleNodes.add(node.getId());
}
} else {
getEligibleNodes(localModelEligibleNodeRoles, eligibleNodes, node);
getEligibleNodeIds(localModelEligibleNodeRoles, eligibleNodes, node);
}
}
}
return eligibleNodes.toArray(new String[0]);
}

private void getEligibleNodes(Set<String> allowedNodeRoles, Set<String> eligibleNodes, DiscoveryNode node) {
private void getEligibleNodeIds(Set<String> allowedNodeRoles, Set<String> eligibleNodes, DiscoveryNode node) {
if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) {
eligibleNodes.add(node.getId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES;
import static org.opensearch.ml.utils.TestHelper.ALL_ROLES;
import static org.opensearch.ml.utils.TestHelper.ML_ROLE;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

Expand Down Expand Up @@ -52,6 +53,8 @@ public class DiscoveryNodeHelperTests extends OpenSearchTestCase {
private final String mlNode1Name = "mlNodeName1";
private final String mlNode2Id = "mlNode2";
private final String mlNode2Name = "mlNodeName2";
private final String allRoleNodeId = "allRoleNode";
private final String allRoleNodeName = "allRoleNodeName";
private final String clusterName = "multi-node-cluster";

@Mock
Expand All @@ -65,6 +68,7 @@ public class DiscoveryNodeHelperTests extends OpenSearchTestCase {
private DiscoveryNode warmDataNode1;
private DiscoveryNode mlNode1;
private DiscoveryNode mlNode2;
private DiscoveryNode allRoleNode;
private ClusterState clusterState;
private String nonExistingNodeName;

Expand Down Expand Up @@ -122,6 +126,14 @@ public void setup() throws IOException {
ImmutableSet.of(ML_ROLE),
Version.CURRENT
);
allRoleNode = new DiscoveryNode(
allRoleNodeName,
allRoleNodeId,
buildNewFakeTransportAddress(),
emptyMap(),
ALL_ROLES,
Version.CURRENT
);

DiscoveryNodes nodes = DiscoveryNodes
.builder()
Expand All @@ -131,6 +143,7 @@ public void setup() throws IOException {
.add(warmDataNode1)
.add(mlNode1)
.add(mlNode2)
.add(allRoleNode)
.build();
clusterState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, Map.of(), 0, false);

Expand Down Expand Up @@ -158,23 +171,35 @@ private void mockSettings(boolean onlyRunOnMLNode, String excludedNodeName) {

public void testGetEligibleNodes_MLNode_RemoteModel() {
DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(FunctionName.REMOTE);
assertEquals(4, eligibleNodes.length);
assertEquals(5, eligibleNodes.length);
Set<String> nodeIds = new HashSet<>();
nodeIds.addAll(Arrays.asList(eligibleNodes).stream().map(n -> n.getId()).collect(Collectors.toList()));
assertTrue(nodeIds.contains(mlNode1.getId()));
assertTrue(nodeIds.contains(mlNode2.getId()));
assertTrue(nodeIds.contains(dataNode1.getId()));
assertTrue(nodeIds.contains(dataNode2.getId()));
assertTrue(nodeIds.contains(allRoleNode.getId()));
assertFalse(nodeIds.contains(warmDataNode1.getId()));
}

public void testGetEligibleNodes_MLNode_LocalModel() {
DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(FunctionName.TEXT_EMBEDDING);
assertEquals(2, eligibleNodes.length);
assertEquals(3, eligibleNodes.length);
Set<String> nodeIds = new HashSet<>();
nodeIds.addAll(Arrays.asList(eligibleNodes).stream().map(n -> n.getId()).collect(Collectors.toList()));
assertTrue(nodeIds.contains(mlNode1.getId()));
assertTrue(nodeIds.contains(mlNode2.getId()));
assertTrue(nodeIds.contains(allRoleNode.getId()));
}

/**
 * Verifies which nodes the shared {@code discoveryNodeHelper} reports as eligible for a
 * {@code TEXT_EMBEDDING} model: exactly three nodes, namely both ML nodes and the all-role node.
 *
 * NOTE(review): the request and the assertions here match testGetEligibleNodes_MLNode_LocalModel;
 * the "DataModel" name suggests a different scenario may have been intended — confirm.
 */
public void testGetEligibleNodes_MLNode_DataModel() {
    final DiscoveryNode[] candidates = discoveryNodeHelper.getEligibleNodes(FunctionName.TEXT_EMBEDDING);
    assertEquals(3, candidates.length);

    // Gather the ids into a set so the membership checks do not depend on the array's ordering.
    final Set<String> candidateIds = Arrays.stream(candidates).map(DiscoveryNode::getId).collect(Collectors.toCollection(HashSet::new));
    for (DiscoveryNode expected : new DiscoveryNode[] { mlNode1, mlNode2, allRoleNode }) {
        assertTrue(candidateIds.contains(expected.getId()));
    }
}

public void testGetEligibleNodes_DataNode() {
Expand All @@ -186,17 +211,25 @@ public void testGetEligibleNodes_DataNode() {

DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(FunctionName.REMOTE);
assertEquals(2, eligibleNodes.length);
assertEquals(dataNode1.getName(), eligibleNodes[0].getName());
assertEquals(dataNode2.getName(), eligibleNodes[1].getName());
Set<String> nodeNames = new HashSet<>();
nodeNames.add("dataNodeName1");
nodeNames.add("dataNodeName2");
assertTrue(nodeNames.contains(eligibleNodes[0].getName()));
assertTrue(nodeNames.contains(eligibleNodes[1].getName()));
}

public void testGetEligibleNodes_MLNode_Excluded() {
mockSettings(false, mlNode1.getName() + "," + mlNode2.getName());
DiscoveryNodeHelper discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings);
DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(FunctionName.TEXT_EMBEDDING);
assertEquals(2, eligibleNodes.length);
assertEquals(dataNode1.getName(), eligibleNodes[0].getName());
assertEquals(dataNode2.getName(), eligibleNodes[1].getName());
assertEquals(3, eligibleNodes.length);
Set<String> nodeNames = new HashSet<>();
nodeNames.add("dataNodeName1");
nodeNames.add("dataNodeName2");
nodeNames.add("allRoleNodeName");
assertTrue(nodeNames.contains(eligibleNodes[0].getName()));
assertTrue(nodeNames.contains(eligibleNodes[1].getName()));
assertTrue(nodeNames.contains(eligibleNodes[2].getName()));
}

public void testFilterEligibleNodes_Null() {
Expand Down Expand Up @@ -241,7 +274,7 @@ public void testFilterEligibleNodes_BothMLAndDataNodes() {

public void testGetAllNodeIds() {
String[] allNodeIds = discoveryNodeHelper.getAllNodeIds();
assertEquals(6, allNodeIds.length);
assertEquals(7, allNodeIds.length);
}

public void testGetNodes() {
Expand Down
11 changes: 11 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import static org.junit.Assert.assertNotNull;
import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE;
import static org.opensearch.cluster.node.DiscoveryNodeRole.DATA_ROLE;
import static org.opensearch.cluster.node.DiscoveryNodeRole.INGEST_ROLE;
import static org.opensearch.cluster.node.DiscoveryNodeRole.REMOTE_CLUSTER_CLIENT_ROLE;
import static org.opensearch.cluster.node.DiscoveryNodeRole.SEARCH_ROLE;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
Expand All @@ -21,12 +24,15 @@
import java.io.InputStreamReader;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -92,6 +98,11 @@ public Setting<Boolean> legacySetting() {
}
};

/**
 * Unmodifiable, sorted set holding the data, ingest, cluster-manager, remote-cluster-client, search
 * and ML roles. Tests pass it to a {@code DiscoveryNode} to model a node that carries all of them.
 *
 * Declared {@code final} so this shared constant cannot be reassigned by one test and leak into others.
 */
public static final SortedSet<DiscoveryNodeRole> ALL_ROLES = Collections
    .unmodifiableSortedSet(
        new TreeSet<>(Arrays.asList(DATA_ROLE, INGEST_ROLE, CLUSTER_MANAGER_ROLE, REMOTE_CLUSTER_CLIENT_ROLE, SEARCH_ROLE, ML_ROLE))
    );

/**
 * Convenience overload that forwards {@code xc} to the two-argument {@code parser} overload with
 * the second argument set to {@code true}, and returns the parser it produces.
 *
 * @param xc the content handed to the two-argument overload
 * @return the parser created by the two-argument overload
 * @throws IOException if the two-argument overload throws it
 */
public static XContentParser parser(String xc) throws IOException {
    final XContentParser delegate = parser(xc, true);
    return delegate;
}
Expand Down

0 comments on commit 169df62

Please sign in to comment.