Bump RCF Version and Fix Default Rules Bug in AnomalyDetector #1334

Merged · 1 commit · Oct 11, 2024
9 changes: 3 additions & 6 deletions build.gradle
@@ -126,9 +126,9 @@ dependencies {
implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2'
implementation group: 'commons-lang', name: 'commons-lang', version: '2.6'
implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.12.0'
-implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.1.0'
-implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.1.0'
-implementation 'software.amazon.randomcutforest:randomcutforest-core:4.1.0'
+implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.2.0'
+implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.2.0'
+implementation 'software.amazon.randomcutforest:randomcutforest-core:4.2.0'

// we inherit jackson-core from opensearch core
implementation "com.fasterxml.jackson.core:jackson-databind:2.16.1"
@@ -700,9 +700,6 @@ List<String> jacocoExclusions = [

// TODO: add test coverage (kaituo)
'org.opensearch.forecast.*',
-'org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler',
-'org.opensearch.timeseries.transport.SingleStreamResultRequest',
-'org.opensearch.timeseries.rest.handler.IndexJobActionHandler.1',
'org.opensearch.timeseries.transport.SuggestConfigParamResponse',
'org.opensearch.timeseries.transport.SuggestConfigParamRequest',
'org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap',
@@ -233,7 +233,7 @@ public AnomalyDetector(

this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name();

-this.rules = rules == null ? getDefaultRule() : rules;
+this.rules = rules == null || rules.isEmpty() ? getDefaultRule() : rules;
}

/*
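The effect of this one-line change: an explicitly empty rules list now falls back to the default suppression rule, just as a null list always did, so a detector created with an empty rules array no longer loses anomaly suppression. A minimal, self-contained sketch of the corrected fallback logic (names are illustrative only; the real logic lives in the AnomalyDetector constructor above, and the assumption that the default rule ignores deviations within 20% of the expected value comes from the ADColdStartTests assertions later in this diff):

import java.util.Collections;
import java.util.List;

class RuleFallbackSketch {
    // Assumption: the default rule suppresses results whose actual-vs-expected
    // deviation is within 20%, matching the {0, 0, 0.2, 0.2} test assertions.
    static List<String> defaultRules() {
        return List.of("ignore if actual vs. expected deviation is within 20%");
    }

    // Old behavior: only a null list fell back to the default rule.
    // New behavior (this PR): an empty list falls back as well.
    static List<String> effectiveRules(List<String> rules) {
        return (rules == null || rules.isEmpty()) ? defaultRules() : rules;
    }

    public static void main(String[] args) {
        System.out.println(effectiveRules(null));                    // default rule
        System.out.println(effectiveRules(Collections.emptyList())); // default rule (the fixed case)
        System.out.println(effectiveRules(List.of("custom rule")));  // user rule preserved
    }
}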
8 changes: 4 additions & 4 deletions src/main/java/org/opensearch/timeseries/JobProcessor.java
@@ -200,7 +200,7 @@ public void process(Job jobParameter, JobExecutionContext context) {
* @param executionStartTime analysis start time
* @param executionEndTime analysis end time
* @param recorder utility to record job execution result
-* @param detector associated detector accessor
+* @param config associated config accessor
*/
public void runJob(
Job jobParameter,
@@ -209,7 +209,7 @@ public void runJob(
Instant executionStartTime,
Instant executionEndTime,
ExecuteResultResponseRecorderType recorder,
-Config detector
+Config config
) {
String configId = jobParameter.getName();
if (lock == null) {
@@ -222,7 +222,7 @@ public void runJob(
"Can't run job due to null lock",
false,
recorder,
-detector
+config
);
return;
}
@@ -243,7 +243,7 @@
user,
roles,
recorder,
-detector
+config
);
}

@@ -163,19 +163,18 @@ protected void executeRequest(FeatureRequest coldStartRequest, ActionListener<Vo
);
IntermediateResultType result = modelManager.getResult(currentSample, modelState, modelId, config, taskId);
resultSaver.saveResult(result, config, coldStartRequest, modelId);
}

// only load model to memory for real time analysis that has no task id
if (null == coldStartRequest.getTaskId()) {
boolean hosted = cacheProvider.hostIfPossible(configOptional.get(), modelState);
LOG
.debug(
hosted
? new ParameterizedMessage("Loaded model {}.", modelState.getModelId())
: new ParameterizedMessage("Failed to load model {}.", modelState.getModelId())
);
// only load model to memory for real time analysis that has no task id
if (null == coldStartRequest.getTaskId()) {
boolean hosted = cacheProvider.hostIfPossible(configOptional.get(), modelState);
LOG
.debug(
hosted
? new ParameterizedMessage("Loaded model {}.", modelState.getModelId())
: new ParameterizedMessage("Failed to load model {}.", modelState.getModelId())
);
}
}

} finally {
listener.onResponse(null);
}
@@ -462,7 +462,7 @@ public void stopJob(String configId, TransportService transportService, ActionLi
}));
}

-private ActionListener<StopConfigResponse> stopConfigListener(
+public ActionListener<StopConfigResponse> stopConfigListener(
String configId,
TransportService transportService,
ActionListener<JobResponse> listener
@@ -145,7 +145,7 @@ public void bulk(String resultIndexOrAlias, List<ResultType> results, String con
} catch (Exception e) {
String error = "Failed to bulk index result";
LOG.error(error, e);
-listener.onFailure(new TimeSeriesException(error, e));
+listener.onFailure(new TimeSeriesException(configId, error, e));
}
}

@@ -84,7 +84,7 @@ protected String genDetector(String datasetName, int intervalMinutes, int trainT
if (relative) {
thresholdType1 = "actual_over_expected_ratio";
thresholdType2 = "expected_over_actual_ratio";
-value = 0.3;
+value = 0.2;
} else {
thresholdType1 = "actual_over_expected_margin";
thresholdType2 = "expected_over_actual_margin";
63 changes: 63 additions & 0 deletions src/test/java/org/opensearch/ad/ml/ADColdStartTests.java
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ad.ml;

import java.io.IOException;
import java.util.ArrayList;

import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.timeseries.TestHelpers;

import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;

public class ADColdStartTests extends OpenSearchTestCase {
private int baseDimensions = 1;
private int shingleSize = 8;
private int dimensions;

@Override
public void setUp() throws Exception {
super.setUp();
dimensions = baseDimensions * shingleSize;
}

/**
* Test that when no explicit rule is provided, the default 20% rule is applied.
* @throws IOException when failing to construct the detector
*/
public void testEmptyRule() throws IOException {
AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(new ArrayList<>()).build();
ThresholdedRandomCutForest.Builder builder = new ThresholdedRandomCutForest.Builder<>()
.dimensions(dimensions)
.shingleSize(shingleSize);
ADColdStart.applyRule(builder, detector);

ThresholdedRandomCutForest forest = builder.build();
double[] ignore = forest.getPredictorCorrector().getIgnoreNearExpected();

// Specify a small delta for floating-point comparison
double delta = 1e-6;

assertArrayEquals("The double arrays are not equal", new double[] { 0, 0, 0.2, 0.2 }, ignore, delta);
}

public void testNullRule() throws IOException {
AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(null).build();
ThresholdedRandomCutForest.Builder builder = new ThresholdedRandomCutForest.Builder<>()
.dimensions(dimensions)
.shingleSize(shingleSize);
ADColdStart.applyRule(builder, detector);

ThresholdedRandomCutForest forest = builder.build();
double[] ignore = forest.getPredictorCorrector().getIgnoreNearExpected();

// Specify a small delta for floating-point comparison
double delta = 1e-6;

assertArrayEquals("The double arrays are not equal", new double[] { 0, 0, 0.2, 0.2 }, ignore, delta);
}
}
29 changes: 29 additions & 0 deletions src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java
@@ -61,6 +61,7 @@
import org.apache.lucene.search.TotalHits;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.Version;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.DocWriteResponse;
@@ -104,6 +105,7 @@
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.index.IndexNotFoundException;
@@ -136,6 +138,7 @@
import org.opensearch.timeseries.transport.JobResponse;
import org.opensearch.timeseries.transport.StatsNodeResponse;
import org.opensearch.timeseries.transport.StatsNodesResponse;
import org.opensearch.timeseries.transport.StopConfigResponse;
import org.opensearch.timeseries.util.ClientUtil;
import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
import org.opensearch.transport.TransportResponseHandler;
@@ -1544,4 +1547,30 @@ public void testDeleteTaskDocs() {
verify(adTaskCacheManager, times(1)).addDeletedTask(anyString());
verify(function, times(1)).execute();
}

public void testStopConfigListener_onResponse_failure() {
// Arrange
String configId = randomAlphaOfLength(5);
TransportService transportService = mock(TransportService.class);
@SuppressWarnings("unchecked")
ActionListener<JobResponse> listener = mock(ActionListener.class);

// Act
ActionListener<StopConfigResponse> stopConfigListener = indexAnomalyDetectorJobActionHandler
.stopConfigListener(configId, transportService, listener);
StopConfigResponse stopConfigResponse = mock(StopConfigResponse.class);
when(stopConfigResponse.success()).thenReturn(false);

stopConfigListener.onResponse(stopConfigResponse);

// Assert
ArgumentCaptor<OpenSearchStatusException> exceptionCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);

verify(adTaskManager, times(1))
.stopLatestRealtimeTask(eq(configId), eq(TaskState.FAILED), exceptionCaptor.capture(), eq(transportService), eq(listener));

OpenSearchStatusException capturedException = exceptionCaptor.getValue();
assertEquals("Failed to delete model", capturedException.getMessage());
assertEquals(RestStatus.INTERNAL_SERVER_ERROR, capturedException.status());
}
}
@@ -27,6 +27,7 @@
import java.time.Clock;
import java.util.Optional;

import org.opensearch.ResourceAlreadyExistsException;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.bulk.BulkAction;
@@ -43,11 +44,13 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.engine.VersionConflictEngineException;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.TestHelpers;
import org.opensearch.timeseries.common.exception.TimeSeriesException;
import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler;
import org.opensearch.timeseries.util.ClientUtil;
import org.opensearch.timeseries.util.IndexUtils;
@@ -232,4 +235,127 @@ private AnomalyResult wrongAnomalyResult() {
null
);
}

public void testResponseIsAcknowledgedTrue() throws InterruptedException {
String testIndex = "testIndex";

// Set up mocks for doesIndexExist and doesAliasExist
when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false);
when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false);

// Mock initCustomResultIndexDirectly to simulate index creation and call the listener
doAnswer(invocation -> {
ActionListener<CreateIndexResponse> listener = invocation.getArgument(1);
// Simulate immediate onResponse call
listener.onResponse(new CreateIndexResponse(true, true, testIndex));
return null;
}).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any());

AnomalyResult result = mock(AnomalyResult.class);

// Call bulk method
bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener);

// Verify that the bulk request is prepared, i.e., indexing proceeds once index creation is acknowledged
verify(client, times(1)).prepareBulk();
}

public void testResponseIsAcknowledgedFalse() {
String testIndex = "testIndex";
when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false);
when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false);

doAnswer(invocation -> {
ActionListener<CreateIndexResponse> listener = invocation.getArgument(1);
listener.onResponse(new CreateIndexResponse(false, false, testIndex));
return null;
}).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any());

AnomalyResult result = mock(AnomalyResult.class);
bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener);

verify(listener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals("Creating custom result index with mappings call not acknowledged", exceptionCaptor.getValue().getMessage());
}

public void testResourceAlreadyExistsException() {
String testIndex = "testIndex";
when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false, true);
when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false, false);

doAnswer(invocation -> {
ActionListener<CreateIndexResponse> listener = invocation.getArgument(1);
listener.onFailure(new ResourceAlreadyExistsException("index already exists"));
return null;
}).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any());

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(1);
listener.onResponse(true);
return null;
}).when(anomalyDetectionIndices).validateResultIndexMapping(eq(testIndex), any());

AnomalyResult result = mock(AnomalyResult.class);
bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener);

// Verify that indexing proceeds (bulk request prepared) after the existing index passes mapping validation
verify(client, times(1)).prepareBulk();
}

public void testOtherException() {
String testIndex = "testIndex";
when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false);
when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false);

Exception testException = new OpenSearchRejectedExecutionException("Test exception");

doAnswer(invocation -> {
ActionListener<CreateIndexResponse> listener = invocation.getArgument(1);
listener.onFailure(testException);
return null;
}).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any());

AnomalyResult result = mock(AnomalyResult.class);
bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener);

verify(listener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals(testException, exceptionCaptor.getValue());
}

public void testTimeSeriesExceptionCaughtInBulk() {
String testIndex = "testIndex";
TimeSeriesException testException = new TimeSeriesException("Test TimeSeriesException");

// Mock doesIndexExist to throw TimeSeriesException
when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenThrow(testException);

AnomalyResult result = mock(AnomalyResult.class);

// Call bulk method
bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener);

// Verify that listener.onFailure is called with the TimeSeriesException
verify(listener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals(testException, exceptionCaptor.getValue());
}

public void testExceptionCaughtInBulk() {
String testIndex = "testIndex";
NullPointerException testException = new NullPointerException("Test NullPointerException");

// Mock doesIndexExist to throw NullPointerException
when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenThrow(testException);

AnomalyResult result = mock(AnomalyResult.class);

// Call bulk method
bulkIndexHandler.bulk(testIndex, ImmutableList.of(result), configId, listener);

// Verify that listener.onFailure is called with a TimeSeriesException wrapping the original exception
verify(listener, times(1)).onFailure(exceptionCaptor.capture());
Exception capturedException = exceptionCaptor.getValue();
assertTrue(capturedException instanceof TimeSeriesException);
assertEquals("Failed to bulk index result", capturedException.getMessage());
assertEquals(testException, capturedException.getCause());
}
}