From 7afe8a2de24c2b374dd0e853950468addcaaf2da Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Wed, 21 Oct 2020 09:00:51 +0300 Subject: [PATCH] [7.x][ML] Extend default evaluation metrics to all available (#63939) This commit extends the set of default metrics for the data frame analytics evaluation API to all available metrics. The motivation is that if the user skips setting an explicit set of metrics, they get most of the evaluation offering. Backport of #63939 --- .../evaluation/classification/Classification.java | 2 +- .../ml/dataframe/evaluation/regression/Huber.java | 4 ++++ .../dataframe/evaluation/regression/Regression.java | 2 +- .../classification/ClassificationTests.java | 9 +++++++++ .../outlierdetection/OutlierDetectionTests.java | 12 ++++++++++++ .../evaluation/regression/RegressionTests.java | 9 +++++++++ .../ml/integration/ClassificationEvaluationIT.java | 9 ++++++++- .../xpack/ml/integration/RegressionEvaluationIT.java | 8 +++++++- .../rest-api-spec/test/ml/evaluate_data_frame.yml | 6 +++++- 9 files changed, 56 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java index ab0fc45461f1c..f3515109a3fac 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java @@ -93,7 +93,7 @@ public Classification(String actualField, } private static List defaultMetrics() { - return Arrays.asList(new MulticlassConfusionMatrix()); + return Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()); } public Classification(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java index 978ac0c74cded..4c6ee1af7f6bb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java @@ -80,6 +80,10 @@ public Huber(StreamInput in) throws IOException { this.delta = in.readDouble(); } + public Huber() { + this(DEFAULT_DELTA); + } + public Huber(@Nullable Double delta) { this.delta = delta != null ? delta : DEFAULT_DELTA; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java index a90e6821255d7..a8bf53c781272 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -76,7 +76,7 @@ public Regression(String actualField, String predictedField, @Nullable List defaultMetrics() { - return Arrays.asList(new MeanSquaredError(), new RSquared()); + return Arrays.asList(new MeanSquaredError(), new RSquared(), new Huber()); } public Regression(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java index 5466d2fa088ba..5a13540dc6c4c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -40,6 +40,7 @@ import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -110,6 +111,14 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + Classification classification = new Classification("actual", "predicted", null, null); + + List metrics = classification.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())); + } + public void testGetFields() { Classification evaluation = new Classification("foo", "bar", "results", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java index c0b72dbe1c234..d68daf218882f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -89,6 +90,17 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[outlier_detection] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + OutlierDetection outlierDetection = new OutlierDetection("actual", "predicted", null); + + List metrics = outlierDetection.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new AucRoc(false), + new Precision(Arrays.asList(0.25, 0.5, 0.75)), + new Recall(Arrays.asList(0.25, 0.5, 0.75)), + new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)))); + } + public void testGetFields() { OutlierDetection evaluation = new OutlierDetection("foo", "bar", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java index c8fc2d5d67d55..c6d060a3b84e4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; @@ -76,6 +77,14 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics")); } + public void testConstructor_GivenDefaultMetrics() { + Regression regression = new Regression("actual", "predicted", null); + + List metrics = regression.getMetrics(); + + assertThat(metrics, containsInAnyOrder(new Huber(), new MeanSquaredError(), new RSquared())); + } + public void testGetFields() { Regression evaluation = new Regression("foo", "bar", null); EvaluationFields fields = evaluation.getFields(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 931db5026535b..e4e8373d56239 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -33,6 +33,7 @@ import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -82,7 +83,13 @@ public void testEvaluate_DefaultMetrics() { assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), - contains(MulticlassConfusionMatrix.NAME.getPreferredName())); + containsInAnyOrder( + MulticlassConfusionMatrix.NAME.getPreferredName(), + Accuracy.NAME.getPreferredName(), + Precision.NAME.getPreferredName(), + Recall.NAME.getPreferredName() + ) + ); } public void testEvaluate_AllMetrics() { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java index 6aca51bc15e5d..012a89a13247e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java @@ -25,6 +25,7 @@ import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -53,7 +54,12 @@ public void testEvaluate_DefaultMetrics() { assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName())); assertThat( evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), - contains(MeanSquaredError.NAME.getPreferredName(), RSquared.NAME.getPreferredName())); + containsInAnyOrder( + MeanSquaredError.NAME.getPreferredName(), + RSquared.NAME.getPreferredName(), + Huber.NAME.getPreferredName() + ) + ); } public void testEvaluate_AllMetrics() { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 09cf11d266612..83fe922c02492 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -938,6 +938,10 @@ setup: } - is_true: classification.multiclass_confusion_matrix + - is_true: classification.accuracy + - is_true: classification.precision + - is_true: classification.recall + - is_false: classification.auc_roc --- "Test classification given missing actual_field": - do: @@ -1104,8 +1108,8 @@ setup: - match: { regression.mse.value: 28.67749840974834 } - match: { regression.r_squared.value: 0.8551031778603486 } + - match: { regression.huber.value: 1.9205280586939963 } - is_false: regression.msle.value - - is_false: regression.huber.value --- "Test regression given missing actual_field": - do: