diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java index ab0fc45461f1c..f3515109a3fac 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java @@ -93,7 +93,7 @@ public Classification(String actualField, } private static List defaultMetrics() { - return Arrays.asList(new MulticlassConfusionMatrix()); + return Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()); } public Classification(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java index 978ac0c74cded..4c6ee1af7f6bb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java @@ -80,6 +80,10 @@ public Huber(StreamInput in) throws IOException { this.delta = in.readDouble(); } + public Huber() { + this(DEFAULT_DELTA); + } + public Huber(@Nullable Double delta) { this.delta = delta != null ? delta : DEFAULT_DELTA; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java index a90e6821255d7..a8bf53c781272 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -76,7 +76,7 @@ public Regression(String actualField, String predictedField, @Nullable List defaultMetrics() { - return Arrays.asList(new MeanSquaredError(), new RSquared()); + return Arrays.asList(new MeanSquaredError(), new RSquared(), new Huber()); } public Regression(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java index 5466d2fa088ba..5a13540dc6c4c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -40,6 +40,7 @@ import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -110,6 +111,14 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + Classification classification = new Classification("actual", "predicted", null, null); + + List metrics = classification.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())); + } + public void testGetFields() { Classification evaluation = new Classification("foo", "bar", "results", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java index c0b72dbe1c234..d68daf218882f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -89,6 +90,17 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[outlier_detection] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + OutlierDetection outlierDetection = new OutlierDetection("actual", "predicted", null); + + List metrics = outlierDetection.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new AucRoc(false), + new Precision(Arrays.asList(0.25, 0.5, 0.75)), + new Recall(Arrays.asList(0.25, 0.5, 0.75)), + new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)))); + } + public void testGetFields() { OutlierDetection evaluation = new OutlierDetection("foo", "bar", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java index c8fc2d5d67d55..c6d060a3b84e4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -76,6 +77,14 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + Regression regression = new Regression("actual", "predicted", null); + + List metrics = regression.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new Huber(), new MeanSquaredError(), new RSquared())); + } + public void testGetFields() { Regression evaluation = new Regression("foo", "bar", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 931db5026535b..e4e8373d56239 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -33,6 +33,7 @@ import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -82,7 +83,13 @@ public void testEvaluate_DefaultMetrics() { assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), - contains(MulticlassConfusionMatrix.NAME.getPreferredName())); + containsInAnyOrder( + MulticlassConfusionMatrix.NAME.getPreferredName(), + Accuracy.NAME.getPreferredName(), + Precision.NAME.getPreferredName(), + Recall.NAME.getPreferredName() + ) + ); } public void testEvaluate_AllMetrics() { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java index 6aca51bc15e5d..012a89a13247e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java @@ -25,6 +25,7 @@ import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -53,7 +54,12 @@ public void testEvaluate_DefaultMetrics() { assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName())); assertThat( evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), - contains(MeanSquaredError.NAME.getPreferredName(), RSquared.NAME.getPreferredName())); + containsInAnyOrder( + MeanSquaredError.NAME.getPreferredName(), + RSquared.NAME.getPreferredName(), + Huber.NAME.getPreferredName() + ) + ); } public void testEvaluate_AllMetrics() { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 09cf11d266612..83fe922c02492 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -938,6 +938,10 @@ setup: } - is_true: classification.multiclass_confusion_matrix + - is_true: classification.accuracy + - is_true: classification.precision + - is_true: classification.recall + - is_false: classification.auc_roc --- "Test classification given missing actual_field": - do: @@ -1104,8 +1108,8 @@ setup: - match: { regression.mse.value: 28.67749840974834 } - match: { regression.r_squared.value: 0.8551031778603486 } + - match: { regression.huber.value: 1.9205280586939963 } - is_false: regression.msle.value - - is_false: regression.huber.value --- "Test regression given missing actual_field": - do: