diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java
index 5a1e07f34e256..cb0eac2841439 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java
@@ -30,6 +30,7 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.stream.Collectors;
@@ -225,7 +226,7 @@ public Builder addLeaf(int nodeIndex, double value) {
             for (int i = nodes.size(); i < nodeIndex + 1; i++) {
                 nodes.add(null);
             }
-            nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(value));
+            nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(Collections.singletonList(value)));
             return this;
         }
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNode.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNode.java
index 30344674edae8..3ce66fe6cc047 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNode.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNode.java
@@ -27,6 +27,7 @@
 import org.elasticsearch.common.xcontent.XContentParser;

 import java.io.IOException;
+import java.util.List;
 import java.util.Objects;

 public class TreeNode implements ToXContentObject {
@@ -61,7 +62,7 @@ public class TreeNode implements ToXContentObject {
         PARSER.declareInt(Builder::setSplitFeature, SPLIT_FEATURE);
         PARSER.declareInt(Builder::setNodeIndex, NODE_INDEX);
         PARSER.declareDouble(Builder::setSplitGain, SPLIT_GAIN);
-        PARSER.declareDouble(Builder::setLeafValue, LEAF_VALUE);
+        PARSER.declareDoubleArray(Builder::setLeafValue, LEAF_VALUE);
         PARSER.declareLong(Builder::setNumberSamples, NUMBER_SAMPLES);
     }
@@ -74,7 +75,7 @@ public static Builder fromXContent(XContentParser parser) {
     private final Integer splitFeature;
     private final int nodeIndex;
     private final Double splitGain;
-    private final Double leafValue;
+    private final List<Double> leafValue;
     private final Boolean defaultLeft;
     private final Integer leftChild;
     private final Integer rightChild;
@@ -86,7 +87,7 @@ public static Builder fromXContent(XContentParser parser) {
                      Integer splitFeature,
                      int nodeIndex,
                      Double splitGain,
-                     Double leafValue,
+                     List<Double> leafValue,
                      Boolean defaultLeft,
                      Integer leftChild,
                      Integer rightChild,
@@ -123,7 +124,7 @@ public Double getSplitGain() {
         return splitGain;
     }

-    public Double getLeafValue() {
+    public List<Double> getLeafValue() {
         return leafValue;
     }

@@ -212,7 +213,7 @@ public static class Builder {
         private Integer splitFeature;
         private int nodeIndex;
         private Double splitGain;
-        private Double leafValue;
+        private List<Double> leafValue;
         private Boolean defaultLeft;
         private Integer leftChild;
         private Integer rightChild;
@@ -250,7 +251,7 @@ public Builder setSplitGain(Double splitGain) {
             return this;
         }

-        public Builder setLeafValue(Double leafValue) {
+        public Builder setLeafValue(List<Double> leafValue) {
             this.leafValue = leafValue;
             return this;
         }
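With this change the high-level REST client accepts leaf values as a `List<Double>` instead of a single `Double`. A minimal sketch of building a multi-class leaf against the new builder API (the class name and the three per-class values are illustrative only):

```java
import java.util.Arrays;

import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeNode;

class MultiClassLeafExample {
    static TreeNode multiClassLeaf() {
        // One entry per class; index i is the raw score for class i.
        return TreeNode.builder(3)                      // node index 3
            .setLeafValue(Arrays.asList(0.1, 0.7, 0.2)) // multi-value leaf
            .setNumberSamples(100L)
            .build();
    }
}
```

Single-value callers keep working through `Collections.singletonList`, as the `addLeaf` change above shows.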
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNodeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNodeTests.java
index b6198e716eb11..ea6160cc8b620 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNodeTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNodeTests.java
@@ -23,6 +23,7 @@
 import org.elasticsearch.test.AbstractXContentTestCase;

 import java.io.IOException;
+import java.util.Collections;

 public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {

@@ -48,7 +49,7 @@ protected TreeNode createTestInstance() {
     public static TreeNode createRandomLeafNode(double internalValue) {
         return TreeNode.builder(randomInt(100))
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
-            .setLeafValue(internalValue)
+            .setLeafValue(Collections.singletonList(internalValue))
             .setNumberSamples(randomNonNegativeLong())
             .build();
     }
@@ -60,7 +61,7 @@ public static TreeNode.Builder createRandom(int nodeIndex,
                                                 Integer featureIndex,
                                                 Operator operator) {
         return TreeNode.builder(nodeIndex)
-            .setLeafValue(left == null ? randomDouble() : null)
+            .setLeafValue(left == null ? Collections.singletonList(randomDouble()) : null)
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
             .setLeftChild(left)
             .setRightChild(right)
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java
index add8399e89d00..2c166e39f8067 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java
@@ -5,29 +5,37 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;

-import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.ingest.IngestDocument;

 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Map;
 import java.util.Objects;

-public class RawInferenceResults extends SingleValueInferenceResults {
+public class RawInferenceResults implements InferenceResults {

     public static final String NAME = "raw";

-    public RawInferenceResults(double value, Map<String, Double> featureImportance) {
-        super(value, featureImportance);
+    private final double[] value;
+    private final Map<String, Double> featureImportance;
+
+    public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
+        this.value = value;
+        this.featureImportance = featureImportance;
+    }
+
+    public double[] getValue() {
+        return value;
     }

-    public RawInferenceResults(StreamInput in) throws IOException {
-        super(in);
+    public Map<String, Double> getFeatureImportance() {
+        return featureImportance;
     }

     @Override
     public void writeTo(StreamOutput out) throws IOException {
-        super.writeTo(out);
+        throw new UnsupportedOperationException("[raw] does not support wire serialization");
     }

     @Override
@@ -35,13 +43,13 @@ public boolean equals(Object object) {
         if (object == this) { return true; }
         if (object == null || getClass() != object.getClass()) { return false; }
         RawInferenceResults that = (RawInferenceResults) object;
-        return Objects.equals(value(), that.value())
-            && Objects.equals(getFeatureImportance(), that.getFeatureImportance());
+        return Arrays.equals(value, that.value)
+            && Objects.equals(featureImportance, that.featureImportance);
     }

     @Override
     public int hashCode() {
-        return Objects.hash(value(), getFeatureImportance());
+        return Objects.hash(Arrays.hashCode(value), featureImportance);
     }

     @Override
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
index 74790a693eb15..52cabb4f7b9f3 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
@@ -26,30 +26,29 @@ private InferenceHelpers() { }

     /**
      * @return Tuple of the highest scored index and the top classes
      */
-    public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(List<Double> probabilities,
+    public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(double[] probabilities,
                                                                                                 List<String> classificationLabels,
                                                                                                 @Nullable double[] classificationWeights,
                                                                                                 int numToInclude) {
-        if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
+        if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
             throw ExceptionsHelper
                 .serverError(
                     "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
                     null,
-                    probabilities.size(),
+                    probabilities.length,
                     classificationLabels.size());
         }

-        List<Double> scores = classificationWeights == null ?
+        double[] scores = classificationWeights == null ?
             probabilities :
-            IntStream.range(0, probabilities.size())
-                .mapToDouble(i -> probabilities.get(i) * classificationWeights[i])
-                .boxed()
-                .collect(Collectors.toList());
+            IntStream.range(0, probabilities.length)
+                .mapToDouble(i -> probabilities[i] * classificationWeights[i])
+                .toArray();

-        int[] sortedIndices = IntStream.range(0, probabilities.size())
+        int[] sortedIndices = IntStream.range(0, scores.length)
             .boxed()
-            .sorted(Comparator.comparing(scores::get).reversed())
+            .sorted(Comparator.comparing(i -> scores[(Integer)i]).reversed())
             .mapToInt(i -> i)
             .toArray();

@@ -59,14 +58,14 @@
         List<String> labels = classificationLabels == null ?
             // If we don't have the labels we should return the top classification values anyways, they will just be numeric
-            IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) :
+            IntStream.range(0, probabilities.length).boxed().map(String::valueOf).collect(Collectors.toList()) :
             classificationLabels;

-        int count = numToInclude < 0 ? probabilities.size() : Math.min(numToInclude, probabilities.size());
+        int count = numToInclude < 0 ? probabilities.length : Math.min(numToInclude, probabilities.length);

         List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
         for(int i = 0; i < count; i++) {
             int idx = sortedIndices[i];
-            topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx)));
+            topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities[idx], scores[idx]));
         }

         return Tuple.tuple(sortedIndices[0], topClassEntries);
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java
index 6534766c65f5b..8901c62bb00e8 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java
@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

 import org.apache.lucene.util.Accountable;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
@@ -62,4 +63,8 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {
      * @return A {@code Map<String, Double>} mapping each featureName to its importance
      */
     Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
+
+    default Version getMinimalCompatibilityVersion() {
+        return Version.V_7_6_0;
+    }
 }
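`topClasses` now works on primitive arrays, and the trickiest part is ordering class indices by descending score without losing the index-to-label mapping. A standalone sketch of that idiom (plain Java, method name illustrative):

```java
import java.util.Comparator;
import java.util.stream.IntStream;

class TopClassesSketch {
    // Orders class indices by descending score, e.g. {0.1, 0.7, 0.2} -> [1, 2, 0].
    static int[] sortedIndicesByScore(double[] scores) {
        return IntStream.range(0, scores.length)
            .boxed()
            .sorted(Comparator.comparingDouble((Integer i) -> scores[i]).reversed())
            .mapToInt(Integer::intValue)
            .toArray();
    }
}
```

`sortedIndices[0]` is then the winning class, and the first `numToInclude` indices yield the top-classes list.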
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java
index 0ff88ca1c3b13..5c9996ba21588 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java
@@ -8,6 +8,7 @@
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.Accountables;
 import org.apache.lucene.util.RamUsageEstimator;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.collect.Tuple;
@@ -20,7 +21,6 @@
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
@@ -139,19 +139,20 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
             throw ExceptionsHelper.badRequestException(
                 "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
         }
-        List<Double> inferenceResults = new ArrayList<>(this.models.size());
+        double[][] inferenceResults = new double[this.models.size()][];
         List<Map<String, Double>> featureInfluence = new ArrayList<>();
+        int i = 0;
         NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
-        this.models.forEach(model -> {
+        for (TrainedModel model : models) {
             InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap());
-            assert result instanceof SingleValueInferenceResults;
-            SingleValueInferenceResults inferenceResult = (SingleValueInferenceResults) result;
-            inferenceResults.add(inferenceResult.value());
+            assert result instanceof RawInferenceResults;
+            RawInferenceResults inferenceResult = (RawInferenceResults) result;
+            inferenceResults[i++] = inferenceResult.getValue();
             if (config.requestingImportance()) {
                 featureInfluence.add(inferenceResult.getFeatureImportance());
             }
-        });
-        List<Double> processed = outputAggregator.processValues(inferenceResults);
+        }
+        double[] processed = outputAggregator.processValues(inferenceResults);
         return buildResults(processed, featureInfluence, config, featureDecoderMap);
     }

@@ -160,13 +161,13 @@ public TargetType targetType() {
         return targetType;
     }

-    private InferenceResults buildResults(List<Double> processedInferences,
+    private InferenceResults buildResults(double[] processedInferences,
                                           List<Map<String, Double>> featureInfluence,
                                           InferenceConfig config,
                                           Map<String, String> featureDecoderMap) {
         // Indicates that the config is useless and the caller just wants the raw value
         if (config instanceof NullInferenceConfig) {
-            return new RawInferenceResults(outputAggregator.aggregate(processedInferences),
+            return new RawInferenceResults(new double[] {outputAggregator.aggregate(processedInferences)},
                 InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
         }
         switch(targetType) {
@@ -176,7 +177,7 @@
                     InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
             case CLASSIFICATION:
                 ClassificationConfig classificationConfig = (ClassificationConfig) config;
-                assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
+                assert classificationWeights == null || processedInferences.length == classificationWeights.length;
                 // Adjust the probabilities according to the thresholds
                 Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
                     processedInferences,
@@ -356,6 +357,11 @@ public Collection<Accountable> getChildResources() {
         return Collections.unmodifiableCollection(accountables);
     }

+    @Override
+    public Version getMinimalCompatibilityVersion() {
+        return models.stream().map(TrainedModel::getMinimalCompatibilityVersion).max(Version::compareTo).orElse(Version.V_7_6_0);
+    }
+
     public static class Builder {
         private List<String> featureNames;
         private List<TrainedModel> trainedModels;
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java
index 2dba96916390c..ccd6adbc9a301 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java
@@ -19,9 +19,9 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
-import java.util.stream.IntStream;

 import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.sigmoid;
+import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

 public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
@@ -78,31 +78,39 @@ public Integer expectedValueSize() {
     }

     @Override
-    public List<Double> processValues(List<Double> values) {
+    public double[] processValues(double[][] values) {
         Objects.requireNonNull(values, "values must not be null");
-        if (weights != null && values.size() != weights.length) {
+        if (weights != null && values.length != weights.length) {
             throw new IllegalArgumentException("values must be the same length as weights.");
         }
-        double summation = weights == null ?
-            values.stream().mapToDouble(Double::valueOf).sum() :
-            IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).sum();
-        double probOfClassOne = sigmoid(summation);
+        double[] sumOnAxis1 = new double[values[0].length];
+        for (int j = 0; j < values.length; j++) {
+            double[] value = values[j];
+            double weight = weights == null ? 1.0 : weights[j];
+            for(int i = 0; i < value.length; i++) {
+                if (i >= sumOnAxis1.length) {
+                    throw new IllegalArgumentException("value entries must have the same dimensions");
+                }
+                sumOnAxis1[i] += (value[i] * weight);
+            }
+        }
+        if (sumOnAxis1.length > 1) {
+            return softMax(sumOnAxis1);
+        }
+
+        double probOfClassOne = sigmoid(sumOnAxis1[0]);
         assert 0.0 <= probOfClassOne && probOfClassOne <= 1.0;
-        return Arrays.asList(1.0 - probOfClassOne, probOfClassOne);
+        return new double[] {1.0 - probOfClassOne, probOfClassOne};
     }

     @Override
-    public double aggregate(List<Double> values) {
+    public double aggregate(double[] values) {
         Objects.requireNonNull(values, "values must not be null");
-        assert values.size() == 2;
         int bestValue = 0;
         double bestProb = Double.NEGATIVE_INFINITY;
-        for (int i = 0; i < values.size(); i++) {
-            if (values.get(i) == null) {
-                throw new IllegalArgumentException("values must not contain null values");
-            }
-            if (values.get(i) > bestProb) {
-                bestProb = values.get(i);
+        for (int i = 0; i < values.length; i++) {
+            if (values[i] > bestProb) {
+                bestProb = values[i];
                 bestValue = i;
             }
         }
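To make the new multi-value path concrete: each inner array in `values` is one sub-model's per-class output, and `processValues` reduces them with a weighted column-wise sum before normalizing. A self-contained sketch of the same computation (names illustrative; the real code routes through the guarded `Statistics.softMax` rather than this plain version):

```java
class WeightedColumnSumSketch {
    // values[j][i] is model j's output for class i; weights[j] scales model j.
    static double[] weightedColumnSum(double[][] values, double[] weights) {
        double[] sums = new double[values[0].length];
        for (int j = 0; j < values.length; j++) {
            double w = weights == null ? 1.0 : weights[j];
            for (int i = 0; i < values[j].length; i++) {
                sums[i] += values[j][i] * w;
            }
        }
        return sums;
    }

    // Plain softmax, shifted by the max for numerical stability.
    static double[] softmax(double[] v) {
        double max = Double.NEGATIVE_INFINITY;
        for (double x : v) {
            max = Math.max(max, x);
        }
        double[] out = new double[v.length];
        double sum = 0.0;
        for (int i = 0; i < v.length; i++) {
            out[i] = Math.exp(v[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }
}
```

For a single-column input the code falls back to the binary `sigmoid` path, exactly as before.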
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java
index 16b1fd7c4051e..a7da6a6f5a8ce 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java
@@ -10,8 +10,6 @@
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

-import java.util.List;
-
 public interface OutputAggregator extends NamedXContentObject, NamedWriteable, Accountable {

     /**
@@ -20,15 +18,15 @@
     Integer expectedValueSize();

     /**
-     * This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(List)} method.
+     * This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(double[])} method.
      *
      * Two major types of pre-processed values could be returned:
-     *  - The confidence/probability scaled values given the input values (See: {@link WeightedMode#processValues(List)}
-     *  - A simple transformation of the passed values in preparation for aggregation (See: {@link WeightedSum#processValues(List)}
+     *  - The confidence/probability scaled values given the input values (See: {@link WeightedMode#processValues(double[][])}
+     *  - A simple transformation of the passed values in preparation for aggregation (See: {@link WeightedSum#processValues(double[][])}
      * @param values the values to process
      * @return A new list containing the processed values or the same list if no processing is required
      */
-    List<Double> processValues(List<Double> values);
+    double[] processValues(double[][] values);

     /**
      * Function to aggregate the processed values into a single double
@@ -40,7 +38,7 @@
      * @param processedValues The values to aggregate
      * @return the aggregated value.
      */
-    double aggregate(List<Double> processedValues);
+    double aggregate(double[] processedValues);

     /**
      * @return The name of the output aggregator
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java
index df2e33e4e6f84..02543ff6d7c5a 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java
@@ -89,21 +89,37 @@ public Integer expectedValueSize() {
     }

     @Override
-    public List<Double> processValues(List<Double> values) {
+    public double[] processValues(double[][] values) {
         Objects.requireNonNull(values, "values must not be null");
-        if (weights != null && values.size() != weights.length) {
+        if (weights != null && values.length != weights.length) {
             throw new IllegalArgumentException("values must be the same length as weights.");
         }
+        // Multiple leaf values
+        if (values[0].length > 1) {
+            double[] sumOnAxis1 = new double[values[0].length];
+            for (int j = 0; j < values.length; j++) {
+                double[] value = values[j];
+                double weight = weights == null ? 1.0 : weights[j];
+                for(int i = 0; i < value.length; i++) {
+                    if (i >= sumOnAxis1.length) {
+                        throw new IllegalArgumentException("value entries must have the same dimensions");
+                    }
+                    sumOnAxis1[i] += (value[i] * weight);
+                }
+            }
+            return softMax(sumOnAxis1);
+        }
+        // Singular leaf values
         List<Integer> freqArray = new ArrayList<>();
-        Integer maxVal = 0;
-        for (Double value : values) {
-            if (value == null) {
-                throw new IllegalArgumentException("values must not contain null values");
+        int maxVal = 0;
+        for (double[] value : values) {
+            if (value.length != 1) {
+                throw new IllegalArgumentException("value entries must have the same dimensions");
             }
-            if (Double.isNaN(value) || Double.isInfinite(value) || value < 0.0 || value != Math.rint(value)) {
+            if (Double.isNaN(value[0]) || Double.isInfinite(value[0]) || value[0] < 0.0 || value[0] != Math.rint(value[0])) {
                 throw new IllegalArgumentException("values must be whole, non-infinite, and positive");
             }
-            Integer integerValue = value.intValue();
+            int integerValue = Double.valueOf(value[0]).intValue();
             freqArray.add(integerValue);
             if (integerValue > maxVal) {
                 maxVal = integerValue;
@@ -112,27 +128,27 @@
         if (maxVal >= numClasses) {
             throw new IllegalArgumentException("values contain entries larger than expected max of [" + (numClasses - 1) + "]");
         }
-        List<Double> frequencies = new ArrayList<>(Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY));
+        double[] frequencies = Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY)
+            .stream()
+            .mapToDouble(Double::doubleValue)
+            .toArray();
         for (int i = 0; i < freqArray.size(); i++) {
-            Double weight = weights == null ? 1.0 : weights[i];
-            Integer value = freqArray.get(i);
-            Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight;
-            frequencies.set(value, frequency);
+            double weight = weights == null ? 1.0 : weights[i];
+            int value = freqArray.get(i);
+            double frequency = frequencies[value] == Double.NEGATIVE_INFINITY ? weight : frequencies[value] + weight;
+            frequencies[value] = frequency;
         }
         return softMax(frequencies);
     }

     @Override
-    public double aggregate(List<Double> values) {
+    public double aggregate(double[] values) {
         Objects.requireNonNull(values, "values must not be null");
         int bestValue = 0;
         double bestFreq = Double.NEGATIVE_INFINITY;
-        for (int i = 0; i < values.size(); i++) {
-            if (values.get(i) == null) {
-                throw new IllegalArgumentException("values must not contain null values");
-            }
-            if (values.get(i) > bestFreq) {
-                bestFreq = values.get(i);
+        for (int i = 0; i < values.length; i++) {
+            if (values[i] > bestFreq) {
+                bestFreq = values[i];
                 bestValue = i;
             }
         }
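On the single-value path, `WeightedMode` is a weighted vote: each sub-model emits a class index, weights accumulate per class, and the tallies are softmax-normalized into pseudo-probabilities. A minimal sketch of the voting step (illustrative, without the whole-number validation the real method performs):

```java
class WeightedVoteSketch {
    // votes[j] is the class index chosen by model j; weights may be null.
    static double[] tallyVotes(int[] votes, double[] weights, int numClasses) {
        double[] tallies = new double[numClasses];
        for (int j = 0; j < votes.length; j++) {
            tallies[votes[j]] += weights == null ? 1.0 : weights[j];
        }
        return tallies; // the production code then applies Statistics.softMax
    }
}
```

Note the production code seeds un-voted classes with `Double.NEGATIVE_INFINITY` rather than `0.0`, so after the softmax they contribute exactly zero probability instead of a small positive mass.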
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java
index ed1c13cf10203..3f1d701dba497 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java
@@ -19,8 +19,6 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
-import java.util.Optional;
-import java.util.stream.Collectors;
 import java.util.stream.IntStream;

 public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
@@ -73,28 +71,25 @@ public WeightedSum(StreamInput in) throws IOException {
     }

     @Override
-    public List<Double> processValues(List<Double> values) {
+    public double[] processValues(double[][] values) {
         Objects.requireNonNull(values, "values must not be null");
+        assert values[0].length == 1;
         if (weights == null) {
-            return values;
+            return Arrays.stream(values).mapToDouble(v -> v[0]).toArray();
         }
-        if (values.size() != weights.length) {
+        if (values.length != weights.length) {
             throw new IllegalArgumentException("values must be the same length as weights.");
         }
-        return IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).boxed().collect(Collectors.toList());
+        return IntStream.range(0, weights.length).mapToDouble(i -> values[i][0] * weights[i]).toArray();
     }

     @Override
-    public double aggregate(List<Double> values) {
+    public double aggregate(double[] values) {
         Objects.requireNonNull(values, "values must not be null");
-        if (values.isEmpty()) {
+        if (values.length == 0) {
             throw new IllegalArgumentException("values must not be empty");
         }
-        Optional<Double> summation = values.stream().reduce(Double::sum);
-        if (summation.isPresent()) {
-            return summation.get();
-        }
-        throw new IllegalArgumentException("values must not contain null values");
+        return Arrays.stream(values).sum();
     }

     @Override
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
index ebad13530df2c..f9a5601f8345a 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
@@ -30,7 +30,6 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.stream.Collectors;

 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
@@ -130,7 +129,7 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
         double[] h0 = hiddenLayer.productPlusBias(false, embeddedVector);
         double[] scores = softmaxLayer.productPlusBias(true, h0);

-        List<Double> probabilities = softMax(Arrays.stream(scores).boxed().collect(Collectors.toList()));
+        double[] probabilities = softMax(scores);

         ClassificationConfig classificationConfig = (ClassificationConfig) config;
         Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java
index db2dea9855bcf..e59a7299645bc 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java
@@ -8,6 +8,7 @@
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.Accountables;
 import org.apache.lucene.util.RamUsageEstimator;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.collect.Tuple;
@@ -29,6 +30,7 @@
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.MapHelper;
@@ -100,7 +102,7 @@ public static Tree fromXContentLenient(XContentParser parser) {
         this.nodes = Collections.unmodifiableList(nodes);
         this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
         this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
-        this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
+        this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
     }

     public Tree(StreamInput in) throws IOException {
@@ -112,7 +114,7 @@ public Tree(StreamInput in) throws IOException {
         } else {
             this.classificationLabels = null;
         }
-        this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
+        this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
     }

     @Override
@@ -147,7 +149,8 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
         return buildResult(node.getLeafValue(), featureImportance, config);
     }

-    private InferenceResults buildResult(Double value, Map<String, Double> featureImportance, InferenceConfig config) {
+    private InferenceResults buildResult(double[] value, Map<String, Double> featureImportance, InferenceConfig config) {
+        assert value != null && value.length > 0;
         // Indicates that the config is useless and the caller just wants the raw value
         if (config instanceof NullInferenceConfig) {
             return new RawInferenceResults(value, featureImportance);
@@ -160,13 +163,13 @@
                 classificationLabels,
                 null,
                 classificationConfig.getNumTopClasses());
-            return new ClassificationInferenceResults(value,
+            return new ClassificationInferenceResults(topClasses.v1(),
                 classificationLabel(topClasses.v1(), classificationLabels),
                 topClasses.v2(),
                 featureImportance,
                 config);
         case REGRESSION:
-            return new RegressionInferenceResults(value, config, featureImportance);
+            return new RegressionInferenceResults(value[0], config, featureImportance);
         default:
             throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
         }
@@ -193,14 +196,22 @@ public TargetType targetType() {
         return targetType;
     }

-    private List<Double> classificationProbability(double inferenceValue) {
+    private double[] classificationProbability(double[] inferenceValue) {
+        // Multi-value leaves, indicates that the leaves contain an array of values.
+        // The index of which corresponds to classification values
+        if (inferenceValue.length > 1) {
+            return Statistics.softMax(inferenceValue);
+        }
         // If we are classification, we should assume that the inference return value is whole.
-        assert inferenceValue == Math.rint(inferenceValue);
+        assert inferenceValue[0] == Math.rint(inferenceValue[0]);
         double maxCategory = this.highestOrderCategory.get();
         // If we are classification, we should assume that the largest leaf value is whole.
         assert maxCategory == Math.rint(maxCategory);
-        List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
-        list.set(Double.valueOf(inferenceValue).intValue(), 1.0);
+        double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)
+            .stream()
+            .mapToDouble(Double::doubleValue)
+            .toArray();
+        list[Double.valueOf(inferenceValue[0]).intValue()] = 1.0;
         return list;
     }

@@ -268,6 +279,7 @@ public void validate() {
         checkTargetType();
         detectMissingNodes();
         detectCycle();
+        verifyLeafNodeUniformity();
     }

     @Override
@@ -331,7 +343,6 @@ private void shapRecursive(List<String> processedFeatures,
         TreeNode currNode = nodes.get(nodeIndex);
         nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
         if (currNode.isLeaf()) {
-            // TODO multi-value????
             double leafValue = nodeValues[nodeIndex];
             for (int i = 1; i < nextIndex; ++i) {
                 double scale = splitPath.sumUnwoundPath(i, nextIndex);
@@ -375,7 +386,8 @@ private void shapRecursive(List<String> processedFeatures,
     private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
         TreeNode node = nodes.get(nodeIndex);
         if (node.isLeaf()) {
-            nodeEstimates[nodeIndex] = node.getLeafValue();
+            // TODO multi-value????
+            nodeEstimates[nodeIndex] = node.getLeafValue()[0];
             return 0;
         }

@@ -424,6 +436,10 @@ private void checkTargetType() {
             throw ExceptionsHelper.badRequestException(
                 "[target_type] should be [classification] if [classification_labels] are provided");
         }
+        if (this.targetType != TargetType.CLASSIFICATION && this.nodes.stream().anyMatch(n -> n.getLeafValue().length > 1)) {
+            throw ExceptionsHelper.badRequestException(
+                "[target_type] should be [classification] if leaf nodes have multiple values");
+        }
     }

     private void detectCycle() {
@@ -465,14 +481,39 @@ private void detectMissingNodes() {
         }
     }

+    private void verifyLeafNodeUniformity() {
+        Integer leafValueLengths = null;
+        for (TreeNode node : nodes) {
+            if (node.isLeaf()) {
+                if (leafValueLengths == null) {
+                    leafValueLengths = node.getLeafValue().length;
+                } else if (leafValueLengths != node.getLeafValue().length) {
+                    throw ExceptionsHelper.badRequestException(
+                        "[tree.tree_structure] all leaf nodes must have the same number of values");
+                }
+            }
+        }
+    }
+
     private static boolean nodeMissing(int nodeIdx, List<TreeNode> nodes) {
         return nodeIdx >= nodes.size();
     }

     private Double maxLeafValue() {
-        return targetType == TargetType.CLASSIFICATION ?
-            this.nodes.stream().filter(TreeNode::isLeaf).mapToDouble(TreeNode::getLeafValue).max().getAsDouble() :
-            null;
+        if (targetType != TargetType.CLASSIFICATION) {
+            return null;
+        }
+        double max = 0.0;
+        for (TreeNode node : this.nodes) {
+            if (node.isLeaf()) {
+                if (node.getLeafValue().length > 1) {
+                    return (double)node.getLeafValue().length;
+                } else {
+                    max = Math.max(node.getLeafValue()[0], max);
+                }
+            }
+        }
+        return max;
     }

     @Override
@@ -493,6 +534,14 @@ public Collection<Accountable> getChildResources() {
         return Collections.unmodifiableCollection(accountables);
     }

+    @Override
+    public Version getMinimalCompatibilityVersion() {
+        if (nodes.stream().filter(TreeNode::isLeaf).anyMatch(t -> t.getLeafValue().length > 1)) {
+            return Version.V_7_7_0;
+        }
+        return Version.V_7_6_0;
+    }
+
     public static class Builder {
         private List<String> featureNames;
         private ArrayList<TreeNode.Builder> nodes;
@@ -586,6 +635,10 @@ TreeNode.Builder addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) {
      * @return this
      */
     Tree.Builder addLeaf(int nodeIndex, double value) {
+        return addLeaf(nodeIndex, Arrays.asList(value));
+    }
+
+    Tree.Builder addLeaf(int nodeIndex, List<Double> value) {
         for (int i = nodes.size(); i < nodeIndex + 1; i++) {
             nodes.add(null);
         }
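The two classification paths in `Tree` are worth spelling out: a multi-value leaf carries one raw score per class and is normalized with `Statistics.softMax`, while a single-value leaf encodes the predicted class index directly and is expanded into a one-hot vector sized by the highest leaf value in the tree. A small sketch of the one-hot expansion (names illustrative):

```java
class OneHotSketch {
    // leafValue is the predicted class index; numClasses = highest category + 1.
    static double[] oneHot(double leafValue, int numClasses) {
        assert leafValue == Math.rint(leafValue) : "class indices are whole numbers";
        double[] probabilities = new double[numClasses];
        probabilities[(int) leafValue] = 1.0;
        return probabilities;
    }
}
```

This is also why `verifyLeafNodeUniformity` exists: mixing one-value and n-value leaves in a single tree would make the two interpretations ambiguous.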
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java
index 351e02e8389c3..64a28591130d1 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java
@@ -21,6 +21,8 @@
 import org.elasticsearch.xpack.core.ml.job.config.Operator;

 import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;

@@ -60,7 +62,7 @@ private static ObjectParser<Builder, Void> createParser(boolean lenient) {
         parser.declareInt(TreeNode.Builder::setSplitFeature, SPLIT_FEATURE);
         parser.declareInt(TreeNode.Builder::setNodeIndex, NODE_INDEX);
         parser.declareDouble(TreeNode.Builder::setSplitGain, SPLIT_GAIN);
-        parser.declareDouble(TreeNode.Builder::setLeafValue, LEAF_VALUE);
+        parser.declareDoubleArray(TreeNode.Builder::setLeafValue, LEAF_VALUE);
         parser.declareLong(TreeNode.Builder::setNumberSamples, NUMBER_SAMPLES);
         return parser;
     }
@@ -74,7 +76,7 @@ public static TreeNode.Builder fromXContent(XContentParser parser, boolean lenient) {
     private final int splitFeature;
     private final int nodeIndex;
     private final double splitGain;
-    private final double leafValue;
+    private final double[] leafValue;
     private final boolean defaultLeft;
     private final int leftChild;
     private final int rightChild;
@@ -86,7 +88,7 @@ private TreeNode(Operator operator,
                      Integer splitFeature,
                      int nodeIndex,
                      Double splitGain,
-                     Double leafValue,
+                     List<Double> leafValue,
                      Boolean defaultLeft,
                      Integer leftChild,
                      Integer rightChild,
@@ -96,7 +98,7 @@
         this.splitFeature = splitFeature == null ? -1 : splitFeature;
         this.nodeIndex = nodeIndex;
         this.splitGain = splitGain == null ? Double.NaN : splitGain;
-        this.leafValue = leafValue == null ? Double.NaN : leafValue;
+        this.leafValue = leafValue == null ? new double[0] : leafValue.stream().mapToDouble(Double::doubleValue).toArray();
         this.defaultLeft = defaultLeft == null ? false : defaultLeft;
         this.leftChild = leftChild == null ? -1 : leftChild;
         this.rightChild = rightChild == null ? -1 : rightChild;
@@ -112,7 +114,11 @@ public TreeNode(StreamInput in) throws IOException {
         splitFeature = in.readInt();
         splitGain = in.readDouble();
         nodeIndex = in.readVInt();
-        leafValue = in.readDouble();
+        if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
+            leafValue = in.readDoubleArray();
+        } else {
+            leafValue = new double[]{in.readDouble()};
+        }
         defaultLeft = in.readBoolean();
         leftChild = in.readInt();
         rightChild = in.readInt();
@@ -144,7 +150,7 @@ public double getSplitGain() {
         return splitGain;
     }

-    public double getLeafValue() {
+    public double[] getLeafValue() {
         return leafValue;
     }

@@ -190,7 +196,18 @@ public void writeTo(StreamOutput out) throws IOException {
         out.writeInt(splitFeature);
         out.writeDouble(splitGain);
         out.writeVInt(nodeIndex);
-        out.writeDouble(leafValue);
+        if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
+            out.writeDoubleArray(leafValue);
+        } else {
+            if (leafValue.length > 1) {
+                throw new IOException("Multi-class classification models require that all nodes are at least version 7.7.0.");
+            }
+            if (leafValue.length == 0) {
+                out.writeDouble(Double.NaN);
+            } else {
+                out.writeDouble(leafValue[0]);
+            }
+        }
         out.writeBoolean(defaultLeft);
         out.writeInt(leftChild);
         out.writeInt(rightChild);
@@ -209,7 +226,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         }
         addOptionalDouble(builder, SPLIT_GAIN, splitGain);
         builder.field(NODE_INDEX.getPreferredName(), nodeIndex);
-        addOptionalDouble(builder, LEAF_VALUE, leafValue);
+        if (leafValue.length > 0) {
+            builder.field(LEAF_VALUE.getPreferredName(), leafValue);
+        }
         builder.field(DEFAULT_LEFT.getPreferredName(), defaultLeft);
         if (leftChild >= 0) {
             builder.field(LEFT_CHILD.getPreferredName(), leftChild);
@@ -238,7 +257,7 @@ public boolean equals(Object o) {
             && Objects.equals(splitFeature, that.splitFeature)
             && Objects.equals(nodeIndex, that.nodeIndex)
             && Objects.equals(splitGain, that.splitGain)
-            && Objects.equals(leafValue, that.leafValue)
+            && Arrays.equals(leafValue, that.leafValue)
             && Objects.equals(defaultLeft, that.defaultLeft)
             && Objects.equals(leftChild, that.leftChild)
             && Objects.equals(rightChild, that.rightChild)
@@ -252,7 +271,7 @@ public int hashCode() {
             splitFeature,
             splitGain,
             nodeIndex,
-            leafValue,
+            Arrays.hashCode(leafValue),
             defaultLeft,
             leftChild,
             rightChild,
@@ -270,7 +289,7 @@ public static Builder builder(int nodeIndex) {

     @Override
     public long ramBytesUsed() {
-        return SHALLOW_SIZE;
+        return SHALLOW_SIZE + this.leafValue.length * Double.BYTES;
     }

     public static class Builder {
@@ -279,7 +298,7 @@ public static class Builder {
         private Integer splitFeature;
         private int nodeIndex;
         private Double splitGain;
-        private Double leafValue;
+        private List<Double> leafValue;
         private Boolean defaultLeft;
         private Integer leftChild;
         private Integer rightChild;
@@ -317,11 +336,19 @@ public Builder setSplitGain(Double splitGain) {
             return this;
         }

-        public Builder setLeafValue(Double leafValue) {
+        public Builder setLeafValue(double leafValue) {
+            return this.setLeafValue(Collections.singletonList(leafValue));
+        }
+
+        public Builder setLeafValue(List<Double> leafValue) {
             this.leafValue = leafValue;
             return this;
         }

+        List<Double> getLeafValue() {
+            return this.leafValue;
+        }
+
         public Builder setDefaultLeft(Boolean defaultLeft) {
             this.defaultLeft = defaultLeft;
             return this;
@@ -358,6 +385,9 @@ public void validate() {
                 if (leafValue == null) {
                     throw new IllegalArgumentException("[leaf_value] is required for a leaf node.");
                 }
+                if (leafValue.stream().anyMatch(Objects::isNull)) {
+                    throw new IllegalArgumentException("[leaf_value] cannot have null values.");
+                }
             } else {
                 if (leftChild < 0) {
                     throw new IllegalArgumentException("[left_child] must be a non-negative integer.");
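The `writeTo` change follows the usual Elasticsearch BWC pattern: branch on the stream's version, write the new shape to new nodes, and refuse when no lossless downgrade exists. A condensed restatement of that path as a standalone helper (the empty-leaf `NaN` case from the diff is elided; the helper name is hypothetical):

```java
import java.io.IOException;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamOutput;

class LeafValueBwcSketch {
    static void writeLeafValue(StreamOutput out, double[] leafValue) throws IOException {
        if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
            out.writeDoubleArray(leafValue);   // new wire format: the whole array
        } else if (leafValue.length == 1) {
            out.writeDouble(leafValue[0]);     // legacy single-double format
        } else {
            // Multi-class leaves cannot be represented pre-7.7.0: fail loudly.
            throw new IOException("multi-class models require all nodes on 7.7.0 or later");
        }
    }
}
```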
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java
index 1cdddcd7af26b..260c8137d1d27 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java
@@ -7,8 +7,7 @@

 import org.elasticsearch.common.Numbers;

-import java.util.List;
-import java.util.stream.Collectors;
+import java.util.Arrays;

 public final class Statistics {

@@ -20,28 +19,29 @@ private Statistics(){}
      * Any {@link Double#isInfinite()}, {@link Double#NaN}, or `null` values are ignored in calculation and returned as 0.0 in the
      * softMax.
      * @param values Values on which to run SoftMax.
-     * @return A new list containing the softmax of the passed values
+     * @return A new array containing the softmax of the passed values
      */
-    public static List<Double> softMax(List<Double> values) {
-        Double expSum = 0.0;
-        Double max = values.stream().filter(Statistics::isValid).max(Double::compareTo).orElse(null);
-        if (max == null) {
+    public static double[] softMax(double[] values) {
+        double expSum = 0.0;
+        double max = Arrays.stream(values).filter(Statistics::isValid).max().orElse(Double.NaN);
+        if (isValid(max) == false) {
             throw new IllegalArgumentException("no valid values present");
         }
-        List<Double> exps = values.stream().map(v -> isValid(v) ? v - max : Double.NEGATIVE_INFINITY)
-            .collect(Collectors.toList());
-        for (int i = 0; i < exps.size(); i++) {
-            if (isValid(exps.get(i))) {
-                Double exp = Math.exp(exps.get(i));
+        double[] exps = new double[values.length];
+        for (int i = 0; i < exps.length; i++) {
+            if (isValid(values[i])) {
+                double exp = Math.exp(values[i] - max);
                 expSum += exp;
-                exps.set(i, exp);
+                exps[i] = exp;
+            } else {
+                exps[i] = Double.NaN;
             }
         }
-        for (int i = 0; i < exps.size(); i++) {
-            if (isValid(exps.get(i))) {
-                exps.set(i, exps.get(i)/expSum);
+        for (int i = 0; i < exps.length; i++) {
+            if (isValid(exps[i])) {
+                exps[i] /= expSum;
             } else {
-                exps.set(i, 0.0);
+                exps[i] = 0.0;
             }
         }
         return exps;
@@ -51,8 +51,8 @@ public static double sigmoid(double value) {
         return 1/(1 + Math.exp(-value));
     }

-    private static boolean isValid(Double v) {
-        return v != null && Numbers.isValidDouble(v);
+    private static boolean isValid(double v) {
+        return Numbers.isValidDouble(v);
     }
 }
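The subtraction of `max` before exponentiation is the standard numerical-stability trick; softmax is shift-invariant, so subtracting the maximum keeps every exponent at or below zero and avoids overflow for large scores:

```latex
\mathrm{softmax}(v)_i = \frac{e^{v_i}}{\sum_j e^{v_j}}
                      = \frac{e^{v_i - m}}{\sum_j e^{v_j - m}},
\qquad m = \max_j v_j .
```

Invalid entries (`NaN`, infinities) are excluded from `m` and pinned to `0.0` in the output, which is why the method throws when no valid value exists.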
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java
index 1ebf009add7a7..325768054d0b0 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java
@@ -5,24 +5,37 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;

-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.test.ESTestCase;

+import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;

-public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {
+import static org.hamcrest.CoreMatchers.equalTo;
+
+public class RawInferenceResultsTests extends ESTestCase {

     public static RawInferenceResults createRandomResults() {
-        return new RawInferenceResults(randomDouble(), randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
+        int n = randomIntBetween(1, 10);
+        double[] results = new double[n];
+        for (int i = 0; i < n; i++) {
+            results[i] = randomDouble();
+        }
+        return new RawInferenceResults(results, randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
     }

-    @Override
-    protected RawInferenceResults createTestInstance() {
-        return createRandomResults();
+    public void testEqualityAndHashcode() {
+        int n = randomIntBetween(1, 10);
+        double[] results = new double[n];
+        for (int i = 0; i < n; i++) {
+            results[i] = randomDouble();
+        }
+        Map<String, Double> importance = randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08);
+        RawInferenceResults lft = new RawInferenceResults(results, new HashMap<>(importance));
+        RawInferenceResults rgt = new RawInferenceResults(Arrays.copyOf(results, n), new HashMap<>(importance));
+        assertThat(lft, equalTo(rgt));
+        assertThat(lft.hashCode(), equalTo(rgt.hashCode()));
     }

-    @Override
-    protected Writeable.Reader<RawInferenceResults> instanceReader() {
-        return RawInferenceResults::new;
-    }
 }
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java
index e630a5874fc79..55cbbcb48e59b 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java
@@ -11,11 +11,10 @@
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

 import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
 import java.util.stream.Stream;

 import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;

 public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticRegression> {
@@ -43,7 +42,13 @@ protected Writeable.Reader<LogisticRegression> instanceReader() {

     public void testAggregate() {
         double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
-        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
+        double[][] values = new double[][]{
+            new double[] {1.0},
+            new double[] {2.0},
+            new double[] {2.0},
+            new double[] {3.0},
+            new double[] {5.0}
+        };

         LogisticRegression logisticRegression = new LogisticRegression(ones);
         assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));
@@ -57,6 +62,36 @@ public void testAggregate() {
         assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));
     }

+    public void testAggregateMultiValueArrays() {
+        double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
+        double[][] values = new double[][]{
+            new double[] {1.0, 0.0, 1.0},
+            new double[] {2.0, 0.0, 0.0},
+            new double[] {2.0, 3.0, 1.0},
+            new double[] {3.0, 3.0, 1.0},
+            new double[] {1.0, 1.0, 5.0}
+        };
+
+        LogisticRegression logisticRegression = new LogisticRegression(ones);
+        double[] processedValues = logisticRegression.processValues(values);
+        assertThat(processedValues.length, equalTo(3));
+        assertThat(processedValues[0], closeTo(0.665240955, 0.00001));
+        assertThat(processedValues[1], closeTo(0.090030573, 0.00001));
+        assertThat(processedValues[2], closeTo(0.244728471, 0.00001));
+        assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(0.0));
+
+        double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};
+
+        logisticRegression = new LogisticRegression(variedWeights);
+        processedValues = logisticRegression.processValues(values);
+        assertThat(processedValues.length, equalTo(3));
+        assertThat(processedValues[0], closeTo(0.0, 0.00001));
+        assertThat(processedValues[1], closeTo(0.0, 0.00001));
+        assertThat(processedValues[2], closeTo(0.9999999, 0.00001));
+        assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(2.0));
+
+    }
+
     public void testCompatibleWith() {
         LogisticRegression logisticRegression = createTestInstance();
         assertThat(logisticRegression.compatibleWith(TargetType.CLASSIFICATION), is(true));
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java
index 02bfe2797d990..d9f949be99d45 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java
@@ -8,9 +8,6 @@
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.junit.Before;

-import java.util.ArrayList;
-import java.util.List;
-
 import static org.hamcrest.Matchers.equalTo;

 public abstract class WeightedAggregatorTests<T extends OutputAggregator> extends AbstractSerializingTestCase<T> {
@@ -35,9 +32,9 @@ public void testWithNullValues() {

     public void testWithValuesOfWrongLength() {
         int numberOfValues = randomIntBetween(5, 10);
-        List<Double> values = new ArrayList<>(numberOfValues);
+        double[][] values = new double[numberOfValues][];
         for (int i = 0; i < numberOfValues; i++) {
-            values.add(randomDouble());
+            values[i] = new double[] {randomDouble()};
         }

         OutputAggregator outputAggregatorWithTooFewWeights = createTestInstance(randomIntBetween(1, numberOfValues - 1));
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java
index 6f0496772be0e..4e7258f0dc804 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java
@@ -11,8 +11,6 @@
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

 import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
 import java.util.stream.Stream;

 import static org.hamcrest.CoreMatchers.is;
@@ -44,7 +42,13 @@ protected Writeable.Reader<WeightedMode> instanceReader() {

     public void testAggregate() {
         double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
-        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
+        double[][] values = new double[][]{
+            new double[] {1.0},
+            new double[] {2.0},
+            new double[] {2.0},
+            new double[] {3.0},
+            new double[] {5.0}
+        };

         WeightedMode weightedMode = new WeightedMode(ones, 6);
         assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
@@ -57,19 +61,55 @@ public void testAggregate() {
         weightedMode = new WeightedMode(6);
         assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));

-        values = Arrays.asList(1.0, 1.0, 1.0, 1.0, 2.0);
+        values = new double[][]{
+            new double[] {1.0},
+            new double[] {1.0},
+            new double[] {1.0},
+            new double[] {1.0},
+            new double[] {2.0}
+        };
         weightedMode = new WeightedMode(6);
-        List<Double> processedValues = weightedMode.processValues(values);
-        assertThat(processedValues.size(), equalTo(6));
-        assertThat(processedValues.get(0), equalTo(0.0));
-        assertThat(processedValues.get(1), closeTo(0.95257412, 0.00001));
-        assertThat(processedValues.get(2), closeTo((1.0 - 0.95257412), 0.00001));
-        assertThat(processedValues.get(3), equalTo(0.0));
-        assertThat(processedValues.get(4), equalTo(0.0));
-        assertThat(processedValues.get(5), equalTo(0.0));
+        double[] processedValues = weightedMode.processValues(values);
+        assertThat(processedValues.length, equalTo(6));
+        assertThat(processedValues[0], equalTo(0.0));
+        assertThat(processedValues[1], closeTo(0.95257412, 0.00001));
+        assertThat(processedValues[2], closeTo((1.0 - 0.95257412), 0.00001));
+        assertThat(processedValues[3], equalTo(0.0));
+        assertThat(processedValues[4], equalTo(0.0));
+        assertThat(processedValues[5], equalTo(0.0));
         assertThat(weightedMode.aggregate(processedValues), equalTo(1.0));
     }

+    public void testAggregateMultiValueArrays() {
+        double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
+        double[][] values = new double[][]{
+            new double[] {1.0, 0.0, 1.0},
+            new double[] {2.0, 0.0, 0.0},
+            new double[] {2.0, 3.0, 1.0},
+            new double[] {3.0, 3.0, 1.0},
+            new double[] {1.0, 1.0, 5.0}
+        };
+
+        WeightedMode weightedMode = new WeightedMode(ones, 3);
+        double[] processedValues = weightedMode.processValues(values);
+        assertThat(processedValues.length, equalTo(3));
+        assertThat(processedValues[0], closeTo(0.665240955, 0.00001));
+        assertThat(processedValues[1], closeTo(0.090030573, 0.00001));
+        assertThat(processedValues[2], closeTo(0.244728471, 0.00001));
+        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(0.0));
+
+        double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};
+
+        weightedMode = new WeightedMode(variedWeights, 3);
+        processedValues = weightedMode.processValues(values);
+        assertThat(processedValues.length, equalTo(3));
+        assertThat(processedValues[0], closeTo(0.0, 0.00001));
+        assertThat(processedValues[1], closeTo(0.0, 0.00001));
+        assertThat(processedValues[2], closeTo(0.9999999, 0.00001));
+        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
+
+    }
+
     public void testCompatibleWith() {
         WeightedMode weightedMode = createTestInstance();
         assertThat(weightedMode.compatibleWith(TargetType.CLASSIFICATION), is(true));
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java
index 8e4a6577dbb27..b01ef7788347a 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java
@@ -11,8 +11,6 @@
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

 import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
 import java.util.stream.Stream;

 import static org.hamcrest.CoreMatchers.is;
@@ -43,7 +41,13 @@ protected Writeable.Reader<WeightedSum> instanceReader() {

     public void testAggregate() {
         double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
-        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
+        double[][] values = new double[][]{
+            new double[] {1.0},
+            new double[] {2.0},
+            new double[] {2.0},
+            new double[] {3.0},
+            new double[] {5.0}
+        };

         WeightedSum weightedSum = new WeightedSum(ones);
         assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java
index 1a9aec15b45ed..2814158607adc 100644
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java
index 1a9aec15b45ed..2814158607adc 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java
@@ -55,7 +55,7 @@ public static TreeNode createRandomLeafNode(double internalValue) {
         return TreeNode.builder(randomInt(100))
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
             .setNumberSamples(randomNonNegativeLong())
-            .setLeafValue(internalValue)
+            .setLeafValue(Collections.singletonList(internalValue))
             .build();
     }
 
@@ -66,7 +66,7 @@ public static TreeNode.Builder createRandom(int nodeId,
                                                 Integer featureIndex,
                                                 Operator operator) {
         return TreeNode.builder(nodeId)
-            .setLeafValue(left == null ? randomDouble() : null)
+            .setLeafValue(left == null ? Collections.singletonList(randomDouble()) : null)
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
             .setLeftChild(left)
             .setRightChild(right)
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java
index 4e0fe560210da..c86f1c38a08cf 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java
@@ -112,7 +112,7 @@ protected Writeable.Reader<Tree> instanceReader() {
 
     public void testInferWithStump() {
         Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
-        builder.setRoot(TreeNode.builder(0).setLeafValue(42.0));
+        builder.setRoot(TreeNode.builder(0).setLeafValue(Collections.singletonList(42.0)));
         builder.setFeatureNames(Collections.emptyList());
 
         Tree tree = builder.build();
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java
index cc99a19b38a73..ba2975c200d04 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java
@@ -16,18 +16,18 @@ public class StatisticsTests extends ESTestCase {
 
     public void testSoftMax() {
-        List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, 1.0, -0.5, null, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0);
-        List<Double> softMax = Statistics.softMax(values);
+        double[] values = new double[] {Double.NEGATIVE_INFINITY, 1.0, -0.5, Double.NaN, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0};
+        double[] softMax = Statistics.softMax(values);
 
-        List<Double> expected = Arrays.asList(0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042);
+        double[] expected = new double[] {0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042};
 
-        for(int i = 0; i < expected.size(); i++) {
-            assertThat(softMax.get(i), closeTo(expected.get(i), 0.000001));
+        for(int i = 0; i < expected.length; i++) {
+            assertThat(softMax[i], closeTo(expected[i], 0.000001));
         }
     }
 
     public void testSoftMaxWithNoValidValues() {
-        List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, null, Double.NaN, Double.POSITIVE_INFINITY);
+        double[] values = new double[] {Double.NEGATIVE_INFINITY, Double.NaN, Double.POSITIVE_INFINITY};
 
         expectThrows(IllegalArgumentException.class, () -> Statistics.softMax(values));
     }
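The updated StatisticsTests pin down the double[] softMax contract: non-finite entries (NaN and the infinities) contribute nothing and come back as 0.0, and an input with no finite values is rejected. A sketch consistent with those expectations, assuming this exclusion rule is how Statistics.softMax treats invalid values (the actual implementation is not part of this diff):

```java
import java.util.Arrays;

// Sketch only: a softmax over double[] where non-finite entries are excluded
// from the normalization and map to 0.0 in the output, consistent with
// testSoftMax and testSoftMaxWithNoValidValues above.
public final class SoftMaxSketch {

    static double[] softMax(double[] values) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            if (isValid(v)) {
                max = Math.max(max, v);
            }
        }
        if (isValid(max) == false) {
            // No finite values at all, e.g. {-Inf, NaN, +Inf}.
            throw new IllegalArgumentException("no valid values present");
        }
        double total = 0.0;
        double[] result = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            if (isValid(values[i])) {
                result[i] = Math.exp(values[i] - max);
                total += result[i];
            } // invalid entries stay 0.0
        }
        for (int i = 0; i < result.length; i++) {
            result[i] /= total;
        }
        return result;
    }

    private static boolean isValid(double v) {
        return Double.isNaN(v) == false && Double.isInfinite(v) == false;
    }

    public static void main(String[] args) {
        // Valid entries are {1.0, -0.5, 1.0, 5.0}; the rest map to 0.0,
        // matching the expected array in testSoftMax.
        double[] values = {Double.NEGATIVE_INFINITY, 1.0, -0.5, Double.NaN,
            Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0};
        System.out.println(Arrays.toString(softMax(values)));
    }
}
```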
diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java
index e993d523b5430..afa5f1c6dcb22 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java
@@ -211,14 +211,14 @@ private static TrainedModel buildRegression() {
                 .setRightChild(2)
                 .setSplitFeature(0)
                 .setThreshold(0.5),
-                TreeNode.builder(1).setLeafValue(0.3),
+                TreeNode.builder(1).setLeafValue(Collections.singletonList(0.3)),
                 TreeNode.builder(2)
                     .setThreshold(0.0)
                     .setSplitFeature(3)
                     .setLeftChild(3)
                     .setRightChild(4),
-                TreeNode.builder(3).setLeafValue(0.1),
-                TreeNode.builder(4).setLeafValue(0.2))
+                TreeNode.builder(3).setLeafValue(Collections.singletonList(0.1)),
+                TreeNode.builder(4).setLeafValue(Collections.singletonList(0.2)))
             .build();
         Tree tree2 = Tree.builder()
             .setFeatureNames(featureNames)
@@ -227,8 +227,8 @@ private static TrainedModel buildRegression() {
                 .setRightChild(2)
                 .setSplitFeature(2)
                 .setThreshold(1.0),
-                TreeNode.builder(1).setLeafValue(1.5),
-                TreeNode.builder(2).setLeafValue(0.9))
+                TreeNode.builder(1).setLeafValue(Collections.singletonList(1.5)),
+                TreeNode.builder(2).setLeafValue(Collections.singletonList(0.9)))
             .build();
         Tree tree3 = Tree.builder()
             .setFeatureNames(featureNames)
@@ -237,8 +237,8 @@ private static TrainedModel buildRegression() {
                 .setRightChild(2)
                 .setSplitFeature(1)
                 .setThreshold(0.2),
-                TreeNode.builder(1).setLeafValue(1.5),
-                TreeNode.builder(2).setLeafValue(0.9))
+                TreeNode.builder(1).setLeafValue(Collections.singletonList(1.5)),
+                TreeNode.builder(2).setLeafValue(Collections.singletonList(0.9)))
             .build();
         return Ensemble.builder()
             .setTargetType(TargetType.REGRESSION)
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
index f17ee697b660d..76b7665d1b3d0 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
@@ -97,6 +97,18 @@ protected void masterOperation(Task task,
             return;
         }
 
+        Version minCompatibilityVersion = request.getTrainedModelConfig()
+            .getModelDefinition()
+            .getTrainedModel()
+            .getMinimalCompatibilityVersion();
+        if (state.nodes().getMinNodeVersion().before(minCompatibilityVersion)) {
+            listener.onFailure(ExceptionsHelper.badRequestException(
+                "Definition for [{}] requires that all nodes are at least version [{}]",
+                request.getTrainedModelConfig().getModelId(),
+                minCompatibilityVersion.toString()));
+            return;
+        }
+
         TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
             .setVersion(Version.CURRENT)
             .setCreateTime(Instant.now())
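The new check above compares the cluster's minimum node version against a per-definition getMinimalCompatibilityVersion() before accepting the PUT. The implementation of that hook is not part of this diff; the following is only a guess at its shape, with hypothetical version constants and a hypothetical multi-value-leaf rule, to illustrate why a multi-class definition would raise the minimum:

```java
import org.elasticsearch.Version;

// Illustration only, not the PR's implementation: one plausible shape for the
// getMinimalCompatibilityVersion() hook consulted by TransportPutTrainedModelAction.
public interface VersionedTrainedModel {

    // Oldest node version that can parse and execute this model definition.
    default Version getMinimalCompatibilityVersion() {
        return Version.V_7_6_0; // hypothetical baseline
    }

    // A tree with multi-valued (multi-class) leaves could override this to demand
    // a newer minimum, since older nodes only understand scalar leaf values:
    //
    //     @Override
    //     public Version getMinimalCompatibilityVersion() {
    //         return nodes.stream().anyMatch(n -> n.getLeafValue().size() > 1)
    //             ? Version.V_7_7_0   // hypothetical: multi-class leaves are newer
    //             : Version.V_7_6_0;
    //     }
}
```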
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java
index 86d5f278c42a3..fc332c89a7727 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java
@@ -22,6 +22,12 @@
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
@@ -189,6 +195,109 @@ public void testInferModels() throws Exception {
         assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
     }
 
+    public void testInferModelMultiClassModel() throws Exception {
+        String modelId = "test-load-models-classification-multi";
+        Map<String, String> oneHotEncoding = new HashMap<>();
+        oneHotEncoding.put("cat", "animal_cat");
+        oneHotEncoding.put("dog", "animal_dog");
+        TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId)
+            .setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
+            .setParsedDefinition(new TrainedModelDefinition.Builder()
+                .setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
+                .setTrainedModel(buildMultiClassClassification()))
+            .setVersion(Version.CURRENT)
+            .setLicenseLevel(License.OperationMode.PLATINUM.description())
+            .setCreateTime(Instant.now())
+            .setEstimatedOperations(0)
+            .setEstimatedHeapMemory(0)
+            .build();
+        AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
+        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+
+        blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
+        assertThat(putConfigHolder.get(), is(true));
+        assertThat(exceptionHolder.get(), is(nullValue()));
+
+
+        List<Map<String, Object>> toInfer = new ArrayList<>();
+        toInfer.add(new HashMap<>() {{
+            put("field", new HashMap<>(){{
+                put("foo", 1.0);
+                put("bar", 0.5);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "dog");
+            }});
+        }});
+        toInfer.add(new HashMap<>() {{
+            put("field", new HashMap<>(){{
+                put("foo", 0.9);
+                put("bar", 1.5);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "cat");
+            }});
+        }});
+
+        List<Map<String, Object>> toInfer2 = new ArrayList<>();
+        toInfer2.add(new HashMap<>() {{
+            put("field", new HashMap<>(){{
+                put("foo", 0.0);
+                put("bar", 0.01);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "dog");
+            }});
+        }});
+        toInfer2.add(new HashMap<>() {{
+            put("field", new HashMap<>(){{
+                put("foo", 1.0);
+                put("bar", 0.0);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "cat");
+            }});
+        }});
+
+        // Test classification
+        InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId,
+            toInfer,
+            ClassificationConfig.EMPTY_PARAMS,
+            true);
+        InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
+        assertThat(response.getInferenceResults()
+                .stream()
+                .map(i -> ((SingleValueInferenceResults)i).valueAsString())
+                .collect(Collectors.toList()),
+            contains("option_0", "option_2"));
+
+        request = new InternalInferModelAction.Request(modelId, toInfer2, ClassificationConfig.EMPTY_PARAMS, true);
+        response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
+        assertThat(response.getInferenceResults()
+                .stream()
+                .map(i -> ((SingleValueInferenceResults)i).valueAsString())
+                .collect(Collectors.toList()),
+            contains("option_2", "option_0"));
+
+
+        // Get top classes
+        request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfig(3, null, null), true);
+        response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
+
+        ClassificationInferenceResults classificationInferenceResults =
+            (ClassificationInferenceResults)response.getInferenceResults().get(0);
+
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("option_0"));
+        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("option_2"));
+        assertThat(classificationInferenceResults.getTopClasses().get(2).getClassification(), equalTo("option_1"));
+
+        classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1);
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("option_2"));
+        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("option_0"));
+        assertThat(classificationInferenceResults.getTopClasses().get(2).getClassification(), equalTo("option_1"));
+    }
+
     public void testInferMissingModel() {
         String model = "test-infer-missing-model";
         InternalInferModelAction.Request request = new InternalInferModelAction.Request(
@@ -256,6 +365,54 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String
             .setModelId(modelId);
     }
 
+    public static TrainedModel buildMultiClassClassification() {
+        List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
+
+        Tree tree1 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(0)
+                .setThreshold(0.5))
+            .addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(1.0, 0.0, 2.0)))
+            .addNode(TreeNode.builder(2)
+                .setThreshold(0.8)
+                .setSplitFeature(1)
+                .setLeftChild(3)
+                .setRightChild(4))
+            .addNode(TreeNode.builder(3).setLeafValue(Arrays.asList(0.0, 1.0, 0.0)))
+            .addNode(TreeNode.builder(4).setLeafValue(Arrays.asList(0.0, 0.0, 1.0))).build();
+        Tree tree2 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(3)
+                .setThreshold(1.0))
+            .addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(2.0, 0.0, 0.0)))
+            .addNode(TreeNode.builder(2).setLeafValue(Arrays.asList(0.0, 2.0, 0.0)))
+            .build();
+        Tree tree3 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(0)
+                .setThreshold(1.0))
+            .addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(0.0, 0.0, 1.0)))
+            .addNode(TreeNode.builder(2).setLeafValue(Arrays.asList(0.0, 1.0, 0.0)))
+            .build();
+        return Ensemble.builder()
+            .setClassificationLabels(Arrays.asList("option_0", "option_1", "option_2"))
+            .setTargetType(TargetType.CLASSIFICATION)
+            .setFeatureNames(featureNames)
+            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
+            .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 3))
+            .build();
+    }
+
+
     @Override
     public NamedXContentRegistry xContentRegistry() {
         List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
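For readers tracing the multi-class expectations: with the first toInfer document, the three trees in buildMultiClassClassification() produce the leaf vectors below, and the WeightedMode weights {0.7, 0.5, 1.0} make option_0 the top class. This walkthrough assumes the common "feature <= threshold routes to the left child" traversal; it is an illustration, not code from the PR:

```java
import java.util.Arrays;

// Sketch only: hand-evaluates the ensemble for toInfer.get(0).
// Feature order is [field.foo, field.bar, animal_cat, animal_dog].
public final class MultiClassWalkthrough {

    public static void main(String[] args) {
        // toInfer.get(0): foo=1.0, bar=0.5, categorical="dog" -> animal_dog=1.0
        double[] features = {1.0, 0.5, 0.0, 1.0};

        // tree1: foo <= 0.5? no -> node 2; bar <= 0.8? yes -> leaf {0, 1, 0}
        double[] leaf1 = {0.0, 1.0, 0.0};
        // tree2: animal_dog <= 1.0? yes -> leaf {2, 0, 0}
        double[] leaf2 = {2.0, 0.0, 0.0};
        // tree3: foo <= 1.0? yes -> leaf {0, 0, 1}
        double[] leaf3 = {0.0, 0.0, 1.0};

        double[] weights = {0.7, 0.5, 1.0};
        double[] sums = new double[3];
        for (int j = 0; j < 3; j++) {
            sums[j] = weights[0] * leaf1[j] + weights[1] * leaf2[j] + weights[2] * leaf3[j];
        }
        // sums == {1.0, 0.7, 1.0}: option_0 and option_2 tie and option_1 trails,
        // consistent with the expected top-class order option_0, option_2, option_1
        // when argmax breaks ties toward the lower class index.
        System.out.println(Arrays.toString(sums));
    }
}
```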