Skip to content

Commit

Permalink
[Backport to main] update connector API (#1651)
Browse files Browse the repository at this point in the history
* update connector API

Signed-off-by: Xun Zhang <[email protected]>

* more ut test coverage

Signed-off-by: Xun Zhang <[email protected]>

* check connector usage in deployed models before updating connector

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
Co-authored-by: Xun Zhang <[email protected]>
  • Loading branch information
rbhavna and Zhangxunmt authored Nov 16, 2023
1 parent 4d53db5 commit 5759bf2
Show file tree
Hide file tree
Showing 9 changed files with 987 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.opensearch.action.ActionType;
import org.opensearch.action.update.UpdateResponse;

public class MLUpdateConnectorAction extends ActionType<UpdateResponse> {
public static final MLUpdateConnectorAction INSTANCE = new MLUpdateConnectorAction();
public static final String NAME = "cluster:admin/opensearch/ml/connectors/update";

private MLUpdateConnectorAction() { super(NAME, UpdateResponse::new);}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import lombok.Builder;
import lombok.Getter;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
public class MLUpdateConnectorRequest extends ActionRequest {
String connectorId;
Map<String, Object> updateContent;

@Builder
public MLUpdateConnectorRequest(String connectorId, Map<String, Object> updateContent) {
this.connectorId = connectorId;
this.updateContent = updateContent;
}

public MLUpdateConnectorRequest(StreamInput in) throws IOException {
super(in);
this.connectorId = in.readString();
this.updateContent = in.readMap();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.connectorId);
out.writeMap(this.getUpdateContent());
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.connectorId == null) {
exception = addValidationError("ML connector id can't be null", exception);
}

return exception;
}

public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException {
Map<String, Object> dataAsMap = null;
dataAsMap = parser.map();

return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(dataAsMap).build();
}

public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLUpdateConnectorRequest) {
return (MLUpdateConnectorRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLUpdateConnectorRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLUpdateConnectorRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.rest.RestRequest;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.when;

public class MLUpdateConnectorRequestTests {
private String connectorId;
private Map<String, Object> updateContent;
private MLUpdateConnectorRequest mlUpdateConnectorRequest;

@Mock
XContentParser parser;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
this.connectorId = "test-connector_id";
this.updateContent = Map.of("description", "new description");
mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder()
.connectorId(connectorId)
.updateContent(updateContent)
.build();
}

@Test
public void writeTo_Success() throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
mlUpdateConnectorRequest.writeTo(bytesStreamOutput);
MLUpdateConnectorRequest parsedUpdateRequest = new MLUpdateConnectorRequest(bytesStreamOutput.bytes().streamInput());
assertEquals(connectorId, parsedUpdateRequest.getConnectorId());
assertEquals(updateContent, parsedUpdateRequest.getUpdateContent());
}

@Test
public void validate_Success() {
assertNull(mlUpdateConnectorRequest.validate());
}

@Test
public void validate_Exception_NullConnectorId() {
MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.builder().build();
Exception exception = updateConnectorRequest.validate();

assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage());
}

@Test
public void parse_success() throws IOException {
RestRequest.Method method = RestRequest.Method.POST;
final Map<String, Object> updatefields = Map.of("version", "new version", "description", "new description");
when(parser.map()).thenReturn(updatefields);

MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId);
assertEquals(updateConnectorRequest.getConnectorId(), connectorId);
assertEquals(updateConnectorRequest.getUpdateContent(), updatefields);
}

@Test
public void fromActionRequest_Success() {
MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder()
.connectorId(connectorId)
.updateContent(updateContent)
.build();
assertSame(MLUpdateConnectorRequest.fromActionRequest(mlUpdateConnectorRequest), mlUpdateConnectorRequest);
}

@Test
public void fromActionRequest_Success_fromActionRequest() {
MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder()
.connectorId(connectorId)
.updateContent(updateContent)
.build();
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
mlUpdateConnectorRequest.writeTo(out);
}
};
MLUpdateConnectorRequest request = MLUpdateConnectorRequest.fromActionRequest(actionRequest);
assertNotSame(request, mlUpdateConnectorRequest);
assertEquals(mlUpdateConnectorRequest.getConnectorId(), request.getConnectorId());
assertEquals(mlUpdateConnectorRequest.getUpdateContent(), request.getUpdateContent());
}

@Test(expected = UncheckedIOException.class)
public void fromActionRequest_IOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
MLUpdateConnectorRequest.fromActionRequest(actionRequest);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.connector;

import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;
import lombok.extern.log4j.Log4j2;

@Log4j2
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class UpdateConnectorTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;

ConnectorAccessControlHelper connectorAccessControlHelper;
MLModelManager mlModelManager;

@Inject
public UpdateConnectorTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
ConnectorAccessControlHelper connectorAccessControlHelper,
MLModelManager mlModelManager
) {
super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new);
this.client = client;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.mlModelManager = mlModelManager;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
MLUpdateConnectorRequest mlUpdateConnectorAction = MLUpdateConnectorRequest.fromActionRequest(request);
String connectorId = mlUpdateConnectorAction.getConnectorId();
UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId);
updateRequest.doc(mlUpdateConnectorAction.getUpdateContent());
updateRequest.docAsUpsert(true);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(hasPermission -> {
if (Boolean.TRUE.equals(hasPermission)) {
updateUndeployedConnector(connectorId, updateRequest, listener, context);
} else {
listener
.onFailure(
new IllegalArgumentException("You don't have permission to update the connector, connector id: " + connectorId)
);
}
}, exception -> {
log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", connectorId, exception);
listener.onFailure(exception);
}));
} catch (Exception e) {
log.error("Failed to update ML connector for connector id {}. Details {}:", connectorId, e);
listener.onFailure(e);
}
}

private void updateUndeployedConnector(
String connectorId,
UpdateRequest updateRequest,
ActionListener<UpdateResponse> listener,
ThreadContext.StoredContext context
) {
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX);
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
boolQueryBuilder.must(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId));
boolQueryBuilder.must(QueryBuilders.idsQuery().addIds(mlModelManager.getAllModelIds()));
sourceBuilder.query(boolQueryBuilder);
searchRequest.source(sourceBuilder);

client.search(searchRequest, ActionListener.wrap(searchResponse -> {
SearchHit[] searchHits = searchResponse.getHits().getHits();
if (searchHits.length == 0) {
client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context));
} else {
log.error(searchHits.length + " models are still using this connector, please undeploy the models first!");
listener
.onFailure(
new MLValidationException(
searchHits.length + " models are still using this connector, please undeploy the models first!"
)
);
}
}, e -> {
if (e instanceof IndexNotFoundException) {
client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context));
return;
}
log.error("Failed to update ML connector: " + connectorId, e);
listener.onFailure(e);

}));
}

private ActionListener<UpdateResponse> getUpdateResponseListener(
String connectorId,
ActionListener<UpdateResponse> actionListener,
ThreadContext.StoredContext context
) {
return ActionListener.runBefore(ActionListener.wrap(updateResponse -> {
if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) {
log.info("Failed to update the connector with ID: {}", connectorId);
actionListener.onResponse(updateResponse);
return;
}
log.info("Successfully updated the connector with ID: {}", connectorId);
actionListener.onResponse(updateResponse);
}, exception -> {
log.error("Failed to update ML connector with ID {}. Details: {}", connectorId, exception);
actionListener.onFailure(exception);
}), context::restore);
}
}
Loading

0 comments on commit 5759bf2

Please sign in to comment.