Fix model still deployed after calling undeploy API (#2510)

* Fix model still deployed after calling undeploy API
* Add UT coverage
* Fix style
* Add UT coverage
* Add UT coverage

---------

Signed-off-by: Sicheng Song <[email protected]>
b4sjoo authored Jun 11, 2024
1 parent b051160 commit 22b558d
Showing 3 changed files with 437 additions and 122 deletions.
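
The heart of the change, visible in the diff below, is that the transport action no longer answers the undeploy call as soon as the per-node responses arrive; it wraps the caller's listener in doExecute and only completes it after the model index has been updated. The following is a minimal, self-contained Java sketch of that deferred-response pattern. It uses a hand-rolled SketchListener stand-in instead of OpenSearch's real ActionListener, and all names in it are illustrative, not taken from the plugin.

import java.util.List;

// Simplified stand-in for org.opensearch.core.action.ActionListener,
// kept minimal so the sketch compiles on its own.
interface SketchListener<T> {
    void onResponse(T response);
    void onFailure(Exception e);
}

public class UndeployResponseOrderingSketch {

    // Before the fix, the caller was answered directly with the per-node
    // responses, so the model index could still report the model as deployed.
    // After the fix, the listener is wrapped so the index update runs first.
    static void doExecute(List<String> nodeResponses, SketchListener<List<String>> caller) {
        SketchListener<List<String>> wrapped = new SketchListener<>() {
            @Override
            public void onResponse(List<String> responses) {
                // Update the model index, then hand the responses to the caller.
                updateModelIndex(responses, () -> caller.onResponse(responses));
            }

            @Override
            public void onFailure(Exception e) {
                caller.onFailure(e);
            }
        };
        // Stand-in for delegating to the node-level undeploy; here we just
        // pretend it has already completed.
        wrapped.onResponse(nodeResponses);
    }

    static void updateModelIndex(List<String> responses, Runnable andThen) {
        // Placeholder for the bulk update and sync-up done by the real action.
        System.out.println("Updating model index for " + responses);
        andThen.run();
    }

    public static void main(String[] args) {
        doExecute(List.of("node-1: UNDEPLOYED"), new SketchListener<>() {
            @Override
            public void onResponse(List<String> responses) {
                System.out.println("Undeploy API response sent: " + responses);
            }

            @Override
            public void onFailure(Exception e) {
                e.printStackTrace();
            }
        });
    }
}
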
@@ -29,7 +29,6 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
@@ -42,10 +41,10 @@
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodeResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

@@ -59,11 +58,8 @@ public class TransportUndeployModelAction extends
private final MLModelManager mlModelManager;
private final ClusterService clusterService;
private final Client client;
private DiscoveryNodeHelper nodeFilter;
private final DiscoveryNodeHelper nodeFilter;
private final MLStats mlStats;
private NamedXContentRegistry xContentRegistry;

private ModelAccessControlHelper modelAccessControlHelper;

@Inject
public TransportUndeployModelAction(
@@ -74,9 +70,7 @@ public TransportUndeployModelAction(
ThreadPool threadPool,
Client client,
DiscoveryNodeHelper nodeFilter,
MLStats mlStats,
NamedXContentRegistry xContentRegistry,
ModelAccessControlHelper modelAccessControlHelper
MLStats mlStats
) {
super(
MLUndeployModelAction.NAME,
@@ -90,107 +84,128 @@ public TransportUndeployModelAction(
MLUndeployModelNodeResponse.class
);
this.mlModelManager = mlModelManager;

this.clusterService = clusterService;
this.client = client;
this.nodeFilter = nodeFilter;
this.mlStats = mlStats;
this.xContentRegistry = xContentRegistry;
this.modelAccessControlHelper = modelAccessControlHelper;
}

@Override
protected MLUndeployModelNodesResponse newResponse(
MLUndeployModelNodesRequest nodesRequest,
List<MLUndeployModelNodeResponse> responses,
List<FailedNodeException> failures
protected void doExecute(Task task, MLUndeployModelNodesRequest request, ActionListener<MLUndeployModelNodesResponse> listener) {
ActionListener<MLUndeployModelNodesResponse> wrappedListener = ActionListener.wrap(undeployModelNodesResponse -> {
processUndeployModelResponseAndUpdate(undeployModelNodesResponse, listener);
}, listener::onFailure);
super.doExecute(task, request, wrappedListener);
}

void processUndeployModelResponseAndUpdate(
MLUndeployModelNodesResponse undeployModelNodesResponse,
ActionListener<MLUndeployModelNodesResponse> listener
) {
if (responses != null) {
Map<String, List<String>> actualRemovedNodesMap = new HashMap<>();
Map<String, String[]> modelWorkNodesBeforeRemoval = new HashMap<>();
responses.forEach(r -> {
Map<String, String[]> nodeCounts = r.getModelWorkerNodeBeforeRemoval();

if (nodeCounts != null) {
for (Map.Entry<String, String[]> entry : nodeCounts.entrySet()) {
// when undeploy a undeployed model, the entry.getvalue() is null
if (entry.getValue() != null
&& (!modelWorkNodesBeforeRemoval.containsKey(entry.getKey())
|| modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) {
modelWorkNodesBeforeRemoval.put(entry.getKey(), entry.getValue());
}
List<MLUndeployModelNodeResponse> responses = undeployModelNodesResponse.getNodes();
if (responses == null || responses.isEmpty()) {
listener.onResponse(undeployModelNodesResponse);
return;
}

Map<String, List<String>> actualRemovedNodesMap = new HashMap<>();
Map<String, String[]> modelWorkNodesBeforeRemoval = new HashMap<>();
responses.forEach(r -> {
Map<String, String[]> nodeCounts = r.getModelWorkerNodeBeforeRemoval();

if (nodeCounts != null) {
for (Map.Entry<String, String[]> entry : nodeCounts.entrySet()) {
// when undeploying an undeployed model, entry.getValue() is null
if (entry.getValue() != null
&& (!modelWorkNodesBeforeRemoval.containsKey(entry.getKey())
|| modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) {
modelWorkNodesBeforeRemoval.put(entry.getKey(), entry.getValue());
}
}
}

Map<String, String> modelUndeployStatus = r.getModelUndeployStatus();
for (Map.Entry<String, String> entry : modelUndeployStatus.entrySet()) {
String status = entry.getValue();
if (UNDEPLOYED.equals(status)) {
String modelId = entry.getKey();
if (!actualRemovedNodesMap.containsKey(modelId)) {
actualRemovedNodesMap.put(modelId, new ArrayList<>());
}
actualRemovedNodesMap.get(modelId).add(r.getNode().getId());
Map<String, String> modelUndeployStatus = r.getModelUndeployStatus();
for (Map.Entry<String, String> entry : modelUndeployStatus.entrySet()) {
String status = entry.getValue();
if (UNDEPLOYED.equals(status)) {
String modelId = entry.getKey();
if (!actualRemovedNodesMap.containsKey(modelId)) {
actualRemovedNodesMap.put(modelId, new ArrayList<>());
}
actualRemovedNodesMap.get(modelId).add(r.getNode().getId());
}
});

MLSyncUpInput syncUpInput = MLSyncUpInput
.builder()
.removedWorkerNodes(covertRemoveNodesMapForSyncUp(actualRemovedNodesMap))
.build();

MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (actualRemovedNodesMap.size() > 0) {
BulkRequest bulkRequest = new BulkRequest();
Map<String, Boolean> deployToAllNodes = new HashMap<>();
for (String modelId : actualRemovedNodesMap.keySet()) {
UpdateRequest updateRequest = new UpdateRequest();
List<String> removedNodes = actualRemovedNodesMap.get(modelId);
int removedNodeCount = removedNodes.size();
/**
* If allow custom deploy is false, user can only undeploy all nodes and status is undeployed.
* If allow custom deploy is true, user can undeploy all nodes and status is undeployed,
* or undeploy partial nodes, and status is deployed, this case means user created a new deployment plan, and
* we need to update both planning worker nodes (count) and current worker nodes (count)
* and deployToAllNodes value in model index.
*/
Map<String, Object> updateDocument = new HashMap<>();
if (modelWorkNodesBeforeRemoval.get(modelId).length == removedNodeCount) { // undeploy all nodes.
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, ImmutableList.of());
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED);
} else { // undeploy partial nodes.
// TODO (to fix) when undeploy partial nodes, the original model status could be partially_deployed,
// and the user could be undeploying not running model nodes, and we should update model status to deployed.
updateDocument.put(MLModel.DEPLOY_TO_ALL_NODES_FIELD, false);
List<String> newPlanningWorkerNodes = Arrays
.stream(modelWorkNodesBeforeRemoval.get(modelId))
.filter(x -> !removedNodes.contains(x))
.collect(Collectors.toList());
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkerNodes);
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size());
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size());
deployToAllNodes.put(modelId, false);
}
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(updateDocument);
bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
}
});

MLSyncUpInput syncUpInput = MLSyncUpInput
.builder()
.removedWorkerNodes(covertRemoveNodesMapForSyncUp(actualRemovedNodesMap))
.build();

MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (actualRemovedNodesMap.size() > 0) {
BulkRequest bulkRequest = new BulkRequest();
Map<String, Boolean> deployToAllNodes = new HashMap<>();
for (String modelId : actualRemovedNodesMap.keySet()) {
UpdateRequest updateRequest = new UpdateRequest();
List<String> removedNodes = actualRemovedNodesMap.get(modelId);
int removedNodeCount = removedNodes.size();
/**
* If allow custom deploy is false, user can only undeploy all nodes and status is undeployed.
* If allow custom deploy is true, user can undeploy all nodes and status is undeployed,
* or undeploy partial nodes, and status is deployed, this case means user created a new deployment plan, and
* we need to update both planning worker nodes (count) and current worker nodes (count)
* and deployToAllNodes value in model index.
*/
Map<String, Object> updateDocument = new HashMap<>();
if (modelWorkNodesBeforeRemoval.get(modelId).length == removedNodeCount) { // undeploy all nodes.
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, ImmutableList.of());
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, 0);
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, 0);
updateDocument.put(MLModel.MODEL_STATE_FIELD, MLModelState.UNDEPLOYED);
} else { // undeploy partial nodes.
// TODO (to fix) when undeploy partial nodes, the original model status could be partially_deployed,
// and the user could be undeploying not running model nodes, and we should update model status to deployed.
updateDocument.put(MLModel.DEPLOY_TO_ALL_NODES_FIELD, false);
List<String> newPlanningWorkerNodes = Arrays
.stream(modelWorkNodesBeforeRemoval.get(modelId))
.filter(x -> !removedNodes.contains(x))
.collect(Collectors.toList());
updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkerNodes);
updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size());
updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size());
deployToAllNodes.put(modelId, false);
}
syncUpInput.setDeployToAllNodes(deployToAllNodes);
ActionListener<BulkResponse> actionListener = ActionListener.wrap(r -> {
log
.debug(
"updated model state as undeployed for : {}",
Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0]))
);
}, e -> { log.error("Failed to update model state as undeployed", e); });
client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> { syncUpUndeployedModels(syncUpRequest); }));
} else {
syncUpUndeployedModels(syncUpRequest);
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(updateDocument);
bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
}
syncUpInput.setDeployToAllNodes(deployToAllNodes);
ActionListener<BulkResponse> actionListener = ActionListener.wrap(r -> {
log
.debug(
"updated model state as undeployed for : {}",
Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0]))
);
}, e -> { log.error("Failed to update model state as undeployed", e); });
client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> {
syncUpUndeployedModels(syncUpRequest);
listener.onResponse(undeployModelNodesResponse);
}));
} else {
syncUpUndeployedModels(syncUpRequest);
listener.onResponse(undeployModelNodesResponse);
}
}
}

@Override
protected MLUndeployModelNodesResponse newResponse(
MLUndeployModelNodesRequest nodesRequest,
List<MLUndeployModelNodeResponse> responses,
List<FailedNodeException> failures
) {
return new MLUndeployModelNodesResponse(clusterService.getClusterName(), responses, failures);
}

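For reference, the partial-undeploy branch in the diff above derives its update document by filtering the removed nodes out of the pre-removal worker list and marking the model as no longer deployed to all nodes. Below is a standalone Java sketch of that bookkeeping; the literal field-name strings are illustrative stand-ins for the MLModel.*_FIELD constants used in the real code.

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class UndeployUpdateDocSketch {
    public static void main(String[] args) {
        String[] workerNodesBeforeRemoval = { "node-1", "node-2", "node-3" };
        List<String> removedNodes = List.of("node-3");

        Map<String, Object> updateDocument = new HashMap<>();
        if (workerNodesBeforeRemoval.length == removedNodes.size()) {
            // Every worker node was undeployed: clear the plan and mark the
            // model state as undeployed.
            updateDocument.put("planning_worker_nodes", List.of());
            updateDocument.put("planning_worker_node_count", 0);
            updateDocument.put("current_worker_node_count", 0);
            updateDocument.put("model_state", "UNDEPLOYED");
        } else {
            // Partial undeploy: keep only the surviving worker nodes and
            // record that the model no longer deploys to all nodes.
            List<String> newPlanningWorkerNodes = Arrays
                .stream(workerNodesBeforeRemoval)
                .filter(node -> !removedNodes.contains(node))
                .collect(Collectors.toList());
            updateDocument.put("deploy_to_all_nodes", false);
            updateDocument.put("planning_worker_nodes", newPlanningWorkerNodes);
            updateDocument.put("planning_worker_node_count", newPlanningWorkerNodes.size());
            updateDocument.put("current_worker_node_count", newPlanningWorkerNodes.size());
        }
        // Prints the surviving plan, e.g. {planning_worker_node_count=2, ...}
        System.out.println(updateDocument);
    }
}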