Skip to content

Commit

Permalink
[ML] adds new n_gram_encoding custom processor (elastic#61578)
Browse files Browse the repository at this point in the history
This adds a new `n_gram_encoding` feature processor for analytics and inference.

The focus of this processor is simple ngram encodings that allow:
 - multiple ngrams [1..5]
 - Prefix, infix, suffix
  • Loading branch information
benwtrent committed Sep 3, 2020
1 parent 48870c6 commit f7a4def
Show file tree
Hide file tree
Showing 9 changed files with 1,006 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.elasticsearch.client.ml.inference;

import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.client.ml.inference.preprocessing.NGram;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
Expand Down Expand Up @@ -57,6 +58,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
FrequencyEncoding::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(CustomWordEmbedding.NAME),
CustomWordEmbedding::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(NGram.NAME),
NGram::fromXContent));

// Model
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Objects;


/**
* PreProcessor for n-gram encoding a string
*/
public class NGram implements PreProcessor {

public static final String NAME = "n_gram_encoding";
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_PREFIX = new ParseField("feature_prefix");
public static final ParseField NGRAMS = new ParseField("n_grams");
public static final ParseField START = new ParseField("start");
public static final ParseField LENGTH = new ParseField("length");
public static final ParseField CUSTOM = new ParseField("custom");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<NGram, Void> PARSER = new ConstructingObjectParser<NGram, Void>(
NAME,
true,
a -> new NGram((String)a[0],
(List<Integer>)a[1],
(Integer)a[2],
(Integer)a[3],
(Boolean)a[4],
(String)a[5]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
PARSER.declareIntArray(ConstructingObjectParser.constructorArg(), NGRAMS);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), START);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LENGTH);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), FEATURE_PREFIX);
}

public static NGram fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final String field;
private final String featurePrefix;
private final List<Integer> nGrams;
private final Integer start;
private final Integer length;
private final Boolean custom;

NGram(String field, List<Integer> nGrams, Integer start, Integer length, Boolean custom, String featurePrefix) {
this.field = field;
this.featurePrefix = featurePrefix;
this.nGrams = nGrams;
this.start = start;
this.length = length;
this.custom = custom;
}

@Override
public String getName() {
return NAME;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (field != null) {
builder.field(FIELD.getPreferredName(), field);
}
if (featurePrefix != null) {
builder.field(FEATURE_PREFIX.getPreferredName(), featurePrefix);
}
if (nGrams != null) {
builder.field(NGRAMS.getPreferredName(), nGrams);
}
if (start != null) {
builder.field(START.getPreferredName(), start);
}
if (length != null) {
builder.field(LENGTH.getPreferredName(), length);
}
if (custom != null) {
builder.field(CUSTOM.getPreferredName(), custom);
}
builder.endObject();
return builder;
}

public String getField() {
return field;
}

public String getFeaturePrefix() {
return featurePrefix;
}

public List<Integer> getnGrams() {
return nGrams;
}

public Integer getStart() {
return start;
}

public Integer getLength() {
return length;
}

public Boolean getCustom() {
return custom;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
NGram nGram = (NGram) o;
return Objects.equals(field, nGram.field) &&
Objects.equals(featurePrefix, nGram.featurePrefix) &&
Objects.equals(nGrams, nGram.nGrams) &&
Objects.equals(start, nGram.start) &&
Objects.equals(length, nGram.length) &&
Objects.equals(custom, nGram.custom);
}

@Override
public int hashCode() {
return Objects.hash(field, featurePrefix, start, length, custom, nGrams);
}

public static Builder builder(String field) {
return new Builder(field);
}

public static class Builder {

private String field;
private String featurePrefix;
private List<Integer> nGrams;
private Integer start;
private Integer length;
private Boolean custom;

public Builder(String field) {
this.field = field;
}

public Builder setField(String field) {
this.field = field;
return this;
}

public Builder setCustom(boolean custom) {
this.custom = custom;
return this;
}

public Builder setFeaturePrefix(String featurePrefix) {
this.featurePrefix = featurePrefix;
return this;
}

public Builder setnGrams(List<Integer> nGrams) {
this.nGrams = nGrams;
return this;
}

public Builder setStart(Integer start) {
this.start = start;
return this;
}

public Builder setLength(Integer length) {
this.length = length;
return this;
}

public Builder setCustom(Boolean custom) {
this.custom = custom;
return this;
}

public NGram build() {
return new NGram(field, nGrams, start, length, custom, featurePrefix);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.NGram;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
Expand Down Expand Up @@ -704,7 +705,7 @@ public void testDefaultNamedXContents() {

public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(69, namedXContents.size());
assertEquals(70, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
Expand Down Expand Up @@ -785,8 +786,9 @@ public void testProvidedNamedXContents() {
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
registeredMetricName(Regression.NAME, HuberMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
assertThat(names,
hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME, NGram.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
assertEquals(Integer.valueOf(4),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;


public class NGramTests extends AbstractXContentTestCase<NGram> {

@Override
protected NGram doParseInstance(XContentParser parser) throws IOException {
return NGram.fromXContent(parser);
}

@Override
protected boolean supportsUnknownFields() {
return true;
}

@Override
protected NGram createTestInstance() {
return createRandom();
}

public static NGram createRandom() {
return new NGram(randomAlphaOfLength(10),
IntStream.range(1, 5).limit(5).boxed().collect(Collectors.toList()),
randomBoolean() ? null : randomIntBetween(0, 10),
randomBoolean() ? null : randomIntBetween(1, 10),
randomBoolean() ? null : randomBoolean(),
randomBoolean() ? null : randomAlphaOfLength(10));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
Expand Down Expand Up @@ -39,12 +46,6 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -64,6 +65,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
(p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
(p, c) -> CustomWordEmbedding.fromXContentLenient(p)));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, NGram.NAME,
(p, c) -> NGram.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));

// PreProcessing Strict
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
Expand All @@ -74,6 +77,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
(p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
(p, c) -> CustomWordEmbedding.fromXContentStrict(p)));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, NGram.NAME,
(p, c) -> NGram.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));

// Model Lenient
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));
Expand Down Expand Up @@ -154,6 +159,8 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
FrequencyEncoding::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, CustomWordEmbedding.NAME.getPreferredName(),
CustomWordEmbedding::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, NGram.NAME.getPreferredName(),
NGram::new));

// Model
namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new));
Expand Down
Loading

0 comments on commit f7a4def

Please sign in to comment.