From 5d1be250e9f8e8fca8a593e22f5afc44733006f5 Mon Sep 17 00:00:00 2001
From: Dimitris Athanasiou
Date: Fri, 4 Sep 2020 11:45:05 +0300
Subject: [PATCH] [ML] Add incremental id during data frame analytics reindexing (#61943)

Previously, we added a copy of the `_id` during reindexing and sorted the
destination index on that. This allowed us to traverse the docs in the
destination index in a stable order, multiple times, and efficiently.
However, a sorted destination index cannot have `nested` typed fields.
This is a problem, as it prevents us from providing a good experience
with our evaluate API when it comes to computing metrics for specific
classes, features, etc.

This commit changes the approach in order to produce a destination index
that allows nested fields. Instead of adding a copy of the `_id` field,
we now add an incremental id that we can use to traverse the docs in a
stable order. We also ensure we always assign the same incremental id to
the same doc from the source indices by sorting on `_seq_no` during
reindexing. That, in combination with the fact that the reindex API uses
scroll internally, gives us a stable order, as scroll uses the
(`_index`, `_doc`, shard_id) tuple to resolve ties.

The extractor no longer needs to scroll. Instead, we sort on the
incremental id and perform ranged searches, which avoids the overhead of
sorting all docs. Finally, the `TestDocsIterator` is simply changed to
use search_after on the incremental id.

With these changes, data frame analytics jobs do not use scroll at any
point. With all of this in place, the commit adds the `nested` type to
the necessary fields of the `classification` and `regression` analyses
results.
---
 .../ml/dataframe/analyses/Classification.java |  29 +++-
 .../core/ml/dataframe/analyses/MapUtils.java  |  64 -------
 .../ml/dataframe/analyses/Regression.java     |  18 +-
 .../analyses/ClassificationTests.java         |  12 +-
 .../dataframe/analyses/RegressionTests.java   |   2 +-
 .../dataframe/DataFrameAnalyticsManager.java  |  25 ++-
 .../xpack/ml/dataframe/DestinationIndex.java  |  10 +-
 .../extractor/DataFrameDataExtractor.java     |  80 +++------
 .../extractor/ExtractedFieldsDetector.java    |   2 +-
 .../dataframe/inference/TestDocsIterator.java |   6 +-
 .../ml/dataframe/DestinationIndexTests.java   |   9 +-
 .../DataFrameDataExtractorTests.java          | 162 +++++-------------
 .../xpack/ml/test/SearchHitBuilder.java       |   4 +-
 13 files changed, 151 insertions(+), 272 deletions(-)
 delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java

diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 5c3a9bead9d15..94291bb6e0099 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -15,6 +15,9 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.mapper.FieldAliasMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; +import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.index.mapper.ObjectMapper; import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import 
org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; @@ -123,6 +126,30 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU ) ); + static final Map<String, Object> FEATURE_IMPORTANCE_MAPPING; + static { + Map<String, Object> classesProperties = new HashMap<>(); + classesProperties.put("class_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE)); + classesProperties.put("importance", Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); + + Map<String, Object> classesMapping = new HashMap<>(); + classesMapping.put("dynamic", false); + classesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE); + classesMapping.put("properties", classesProperties); + + Map<String, Object> properties = new HashMap<>(); + properties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE)); + properties.put("importance", Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); + properties.put("classes", classesMapping); + + Map<String, Object> mapping = new HashMap<>(); + mapping.put("dynamic", false); + mapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE); + mapping.put("properties", properties); + + FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(mapping); + } + private final String dependentVariable; private final BoostedTreeParams boostedTreeParams; private final String predictionFieldName; @@ -347,7 +374,7 @@ public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() { @Override public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) { Map<String, Object> additionalProperties = new HashMap<>(); - additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.classificationFeatureImportanceMapping()); + additionalProperties.put(resultsFieldName + ".feature_importance", FEATURE_IMPORTANCE_MAPPING); Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties); if ((dependentVariableMapping instanceof Map) == false) { return additionalProperties; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java deleted file mode 100644 index 3cc8825944f28..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - *//* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. 
- */ -package org.elasticsearch.xpack.core.ml.dataframe.analyses; - -import org.elasticsearch.index.mapper.KeywordFieldMapper; -import org.elasticsearch.index.mapper.NumberFieldMapper; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -final class MapUtils { - - private static Map<String, Object> createFeatureImportanceMapping(Map<String, Object> featureImportanceMappingProperties){ - featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE)); - Map<String, Object> featureImportanceMapping = new HashMap<>(); - // TODO sorted indices don't support nested types - //featureImportanceMapping.put("dynamic", true); - //featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE); - featureImportanceMapping.put("properties", featureImportanceMappingProperties); - return featureImportanceMapping; - } - - private static final Map<String, Object> CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING; - static { - Map<String, Object> classImportancePropertiesMapping = new HashMap<>(); - // TODO sorted indices don't support nested types - //classImportancePropertiesMapping.put("dynamic", true); - //classImportancePropertiesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE); - classImportancePropertiesMapping.put("class_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE)); - classImportancePropertiesMapping.put("importance", - Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); - Map<String, Object> featureImportancePropertiesMapping = new HashMap<>(); - featureImportancePropertiesMapping.put("classes", Collections.singletonMap("properties", classImportancePropertiesMapping)); - CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING = - Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping)); - } - - private static final Map<String, Object> REGRESSION_FEATURE_IMPORTANCE_MAPPING; - static { - Map<String, Object> featureImportancePropertiesMapping = new HashMap<>(); - featureImportancePropertiesMapping.put("importance", - Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); - REGRESSION_FEATURE_IMPORTANCE_MAPPING = - Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping)); - } - - static Map<String, Object> regressionFeatureImportanceMapping() { - return REGRESSION_FEATURE_IMPORTANCE_MAPPING; - } - - static Map<String, Object> classificationFeatureImportanceMapping() { - return CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING; - } - - private MapUtils() {} -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 02ae159d26b24..694abec5ce9f7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -14,7 +14,9 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.index.mapper.ObjectMapper; import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import 
org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; @@ -97,6 +99,20 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno ) ); + static final Map<String, Object> FEATURE_IMPORTANCE_MAPPING; + static { + Map<String, Object> properties = new HashMap<>(); + properties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE)); + properties.put("importance", Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); + + Map<String, Object> mapping = new HashMap<>(); + mapping.put("dynamic", false); + mapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE); + mapping.put("properties", properties); + + FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(mapping); + } + private final String dependentVariable; private final BoostedTreeParams boostedTreeParams; private final String predictionFieldName; @@ -269,7 +285,7 @@ public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() { @Override public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) { Map<String, Object> additionalProperties = new HashMap<>(); - additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.regressionFeatureImportanceMapping()); + additionalProperties.put(resultsFieldName + ".feature_importance", FEATURE_IMPORTANCE_MAPPING); // Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of // high (over 10M) values of dependent variable. additionalProperties.put(resultsFieldName + "." + predictionFieldName, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index b4f48df4e40ee..2bc5b2d66f906 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -352,12 +352,12 @@ public void testFieldCardinalityLimitsIsNonEmpty() { public void testGetExplicitlyMappedFields() { assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), - equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()))); + equalTo(Collections.singletonMap("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING))); assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), - equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()))); + equalTo(Collections.singletonMap("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING))); assertThat( new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"), - equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()))); + equalTo(Collections.singletonMap("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING))); Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields( Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")), "results"); @@ -365,7 +365,7 @@ public void testGetExplicitlyMappedFields() { allOf( hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")), hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", 
"baz")))); - assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())); + assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING)); explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields( new HashMap<>() {{ @@ -380,7 +380,7 @@ public void testGetExplicitlyMappedFields() { allOf( hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")), hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long")))); - assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())); + assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING)); assertThat( new Classification("foo").getExplicitlyMappedFields( @@ -389,7 +389,7 @@ public void testGetExplicitlyMappedFields() { put("path", "missing"); }}), "results"), - equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()))); + equalTo(Collections.singletonMap("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING))); } public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index a839e4cbdc5e9..7e248190fa174 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -317,7 +317,7 @@ public void testFieldCardinalityLimitsIsEmpty() { public void testGetExplicitlyMappedFields() { Map explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results"); assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double"))); - assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.regressionFeatureImportanceMapping())); + assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", Regression.FEATURE_IMPORTANCE_MAPPING)); } public void testGetStateDocId() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index f66e6af29de14..aa84fda1f70fc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -28,10 +28,12 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.mapper.SeqNoFieldMapper; import org.elasticsearch.index.reindex.BulkByScrollResponse; import org.elasticsearch.index.reindex.ReindexAction; import org.elasticsearch.index.reindex.ReindexRequest; import org.elasticsearch.script.Script; +import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskCancelledException; import 
org.elasticsearch.xpack.core.ClientHelper; @@ -49,6 +51,9 @@ import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import java.time.Clock; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; @@ -263,11 +268,27 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF reindexRequest.setRefresh(true); reindexRequest.setSourceIndices(config.getSource().getIndex()); reindexRequest.setSourceQuery(config.getSource().getParsedQuery()); + reindexRequest.getSearchRequest().allowPartialSearchResults(false); reindexRequest.getSearchRequest().source().fetchSource(config.getSource().getSourceFiltering()); + reindexRequest.getSearchRequest().source().sort(SeqNoFieldMapper.NAME, SortOrder.ASC); reindexRequest.setDestIndex(config.getDest().getIndex()); - reindexRequest.setScript(new Script("ctx._source." + DestinationIndex.ID_COPY + " = ctx._id")); + + // We explicitly set slices to 1 as we cannot parallelize in order to have the incremental id + reindexRequest.setSlices(1); + Map<String, Object> counterValueParam = new HashMap<>(); + counterValueParam.put("value", -1); + reindexRequest.setScript( + new Script( + Script.DEFAULT_SCRIPT_TYPE, + Script.DEFAULT_SCRIPT_LANG, + // We use indirection here because top level params are immutable. + // This is a work around at the moment but the plan is to make this a feature of reindex API. + "ctx._source." + DestinationIndex.INCREMENTAL_ID + " = ++params.counter.value", + Collections.singletonMap("counter", counterValueParam) + ) + ); + reindexRequest.setParentTask(task.getParentTaskId()); - reindexRequest.getSearchRequest().allowPartialSearchResults(false); final ThreadContext threadContext = parentTaskClient.threadPool().getThreadContext(); final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndex.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndex.java index fe20ef2cbad09..c5fec1ca44e37 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndex.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndex.java @@ -23,9 +23,7 @@ import org.elasticsearch.cluster.metadata.MappingMetadata; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.IndexSortConfig; -import org.elasticsearch.index.mapper.KeywordFieldMapper; -import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; @@ -46,7 +44,7 @@ */ public final class DestinationIndex { - public static final String ID_COPY = "ml__id_copy"; + public static final String INCREMENTAL_ID = "ml__incremental_id"; /** * The field that indicates whether a doc was used for training or not */ @@ -136,8 +134,6 @@ private static Settings settings(GetSettingsResponse settingsResponse) { Integer maxNumberOfReplicas = findMaxSettingValue(settingsResponse, IndexMetadata.SETTING_NUMBER_OF_REPLICAS); Settings.Builder settingsBuilder = Settings.builder(); - settingsBuilder.put(IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey(), ID_COPY); - 
settingsBuilder.put(IndexSortConfig.INDEX_SORT_ORDER_SETTING.getKey(), SortOrder.ASC); if (maxNumberOfShards != null) { settingsBuilder.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, maxNumberOfShards); } @@ -163,7 +159,7 @@ private static Integer findMaxSettingValue(GetSettingsResponse settingsResponse, private static Map<String, Object> createAdditionalMappings(DataFrameAnalyticsConfig config, Map<String, Object> mappingsProperties) { Map<String, Object> properties = new HashMap<>(); - properties.put(ID_COPY, Map.of("type", KeywordFieldMapper.CONTENT_TYPE)); + properties.put(INCREMENTAL_ID, Map.of("type", NumberFieldMapper.NumberType.LONG.typeName())); properties.putAll(config.getAnalysis().getExplicitlyMappedFields(mappingsProperties, config.getDest().getResultsField())); return properties; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index fd24eb5d268de..211b25b4a9988 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -9,16 +9,11 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.search.ClearScrollAction; -import org.elasticsearch.action.search.ClearScrollRequest; import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.search.SearchScrollAction; -import org.elasticsearch.action.search.SearchScrollRequestBuilder; import org.elasticsearch.client.Client; import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.CachedSupplier; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; @@ -46,28 +41,28 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; /** - * An implementation that extracts data from elasticsearch using search and scroll on a client. - * It supports safe and responsive cancellation by continuing the scroll until a new timestamp - * is seen. + * An implementation that extracts data from elasticsearch using ranged searches + * on the incremental id. + * We detect the end of the extraction by doing an additional search at the end + * which should return empty results. + * It supports safe and responsive cancellation by continuing from the latest + * incremental id that was seen. * Note that this class is NOT thread-safe. 
*/ public class DataFrameDataExtractor { private static final Logger LOGGER = LogManager.getLogger(DataFrameDataExtractor.class); - private static final TimeValue SCROLL_TIMEOUT = new TimeValue(30, TimeUnit.MINUTES); public static final String NULL_VALUE = "\0"; private final Client client; private final DataFrameDataExtractorContext context; - private String scrollId; - private String lastSortKey; + private long lastSortKey = -1; private boolean isCancelled; private boolean hasNext; private boolean searchHasShardFailure; @@ -122,7 +117,7 @@ public Optional<List<Row>> next() throws IOException { throw new NoSuchElementException(); } - Optional<List<Row>> hits = scrollId == null ? Optional.ofNullable(initScroll()) : Optional.ofNullable(continueScroll()); + Optional<List<Row>> hits = Optional.ofNullable(nextSearch()); if (hits.isPresent() && hits.get().isEmpty() == false) { lastSortKey = hits.get().get(hits.get().size() - 1).getSortKey(); } else { @@ -131,8 +126,7 @@ return hits; } - protected List<Row> initScroll() throws IOException { - LOGGER.debug("[{}] Initializing scroll", context.jobId); + protected List<Row> nextSearch() throws IOException { return tryRequestWithSearchResponse(() -> executeSearchRequest(buildSearchRequest())); } @@ -154,7 +148,7 @@ private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request) } LOGGER.warn(new ParameterizedMessage("[{}] Search resulted to failure; retrying once", context.jobId), e); markScrollAsErrored(); - return initScroll(); + return nextSearch(); } } @@ -163,24 +157,24 @@ protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequest } private SearchRequestBuilder buildSearchRequest() { + long from = lastSortKey + 1; + long to = from + context.scrollSize; + + LOGGER.debug(() -> new ParameterizedMessage( + "[{}] Searching docs with [{}] in [{}, {})", context.jobId, DestinationIndex.INCREMENTAL_ID, from, to)); + SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) - .setScroll(SCROLL_TIMEOUT) // This ensures the search throws if there are failures and the scroll context gets cleared automatically .setAllowPartialSearchResults(false) - .addSort(DestinationIndex.ID_COPY, SortOrder.ASC) + .addSort(DestinationIndex.INCREMENTAL_ID, SortOrder.ASC) .setIndices(context.indices) .setSize(context.scrollSize); - if (lastSortKey == null) { - searchRequestBuilder.setQuery(context.query); - } else { - LOGGER.debug(() -> new ParameterizedMessage("[{}] Searching docs with [{}] greater than [{}]", - context.jobId, DestinationIndex.ID_COPY, lastSortKey)); - QueryBuilder queryPlusLastSortKey = QueryBuilders.boolQuery() + searchRequestBuilder.setQuery( + QueryBuilders.boolQuery() .filter(context.query) - .filter(QueryBuilders.rangeQuery(DestinationIndex.ID_COPY).gt(lastSortKey)); - searchRequestBuilder.setQuery(queryPlusLastSortKey); - } + .filter(QueryBuilders.rangeQuery(DestinationIndex.INCREMENTAL_ID).gte(from).lt(to)) + ); setFetchSource(searchRequestBuilder); @@ -206,10 +200,8 @@ private void setFetchSource(SearchRequestBuilder searchRequestBuilder) { } private List<Row> processSearchResponse(SearchResponse searchResponse) { - scrollId = searchResponse.getScrollId(); if (searchResponse.getHits().getHits().length == 0) { hasNext = false; - clearScroll(scrollId); return null; } @@ -218,7 +210,6 @@ private List<Row> processSearchResponse(SearchResponse searchResponse) { for (SearchHit hit : hits) { if (isCancelled) { hasNext = false; - clearScroll(scrollId); break; } rows.add(createRow(hit)); @@ -301,35 
+292,12 @@ private Row createRow(SearchHit hit) { return new Row(extractedValues, hit, isTraining); } - private List<Row> continueScroll() throws IOException { - LOGGER.debug("[{}] Continuing scroll with id [{}]", context.jobId, scrollId); - return tryRequestWithSearchResponse(() -> executeSearchScrollRequest(scrollId)); - } - private void markScrollAsErrored() { // This could be a transient error with the scroll Id. // Reinitialise the scroll and try again but only once. - scrollId = null; searchHasShardFailure = true; } - protected SearchResponse executeSearchScrollRequest(String scrollId) { - return ClientHelper.executeWithHeaders(context.headers, ClientHelper.ML_ORIGIN, client, - () -> new SearchScrollRequestBuilder(client, SearchScrollAction.INSTANCE) - .setScroll(SCROLL_TIMEOUT) - .setScrollId(scrollId) - .get()); - } - - private void clearScroll(String scrollId) { - if (scrollId != null) { - ClearScrollRequest request = new ClearScrollRequest(); - request.addScrollId(scrollId); - ClientHelper.executeWithHeaders(context.headers, ClientHelper.ML_ORIGIN, client, - () -> client.execute(ClearScrollAction.INSTANCE, request).actionGet()); - } - } - public List<String> getFieldNames() { return Stream.concat(Arrays.stream(organicFeatures), Arrays.stream(processedFeatures)).collect(Collectors.toList()); } @@ -439,11 +407,11 @@ public boolean isTraining() { } public int getChecksum() { - return Arrays.hashCode(values); + return (int) getSortKey(); } - public String getSortKey() { - return (String) hit.getSortValues()[0]; + public long getSortKey() { + return (long) hit.getSortValues()[0]; } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index 1b03544d015e5..4f8731f72ad65 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -58,7 +58,7 @@ public class ExtractedFieldsDetector { * Fields to ignore. These are mostly internal meta fields. 
*/ private static final List<String> IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no", - "_source", "_type", "_uid", "_version", "_feature", "_ignored", "_nested_path", DestinationIndex.ID_COPY, + "_source", "_type", "_uid", "_version", "_feature", "_ignored", "_nested_path", DestinationIndex.INCREMENTAL_ID, "_data_stream_timestamp"); private final DataFrameAnalyticsConfig config; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/TestDocsIterator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/TestDocsIterator.java index a6b0d5e0e9315..5e86655f6a4ed 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/TestDocsIterator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/TestDocsIterator.java @@ -29,7 +29,7 @@ public class TestDocsIterator extends SearchAfterDocumentsIterator<SearchHit> { private final DataFrameAnalyticsConfig config; - private String lastDocId; + private Long lastDocId; private final Map<String, String> docValueFieldAndFormatPairs; TestDocsIterator(OriginSettingClient client, DataFrameAnalyticsConfig config, ExtractedFields extractedFields) { @@ -54,7 +54,7 @@ protected QueryBuilder getQuery() { @Override protected FieldSortBuilder sortField() { - return SortBuilders.fieldSort(DestinationIndex.ID_COPY).order(SortOrder.ASC); + return SortBuilders.fieldSort(DestinationIndex.INCREMENTAL_ID).order(SortOrder.ASC); } @Override @@ -69,7 +69,7 @@ protected Object[] searchAfterFields() { @Override protected void extractSearchAfterFields(SearchHit lastSearchHit) { - lastDocId = lastSearchHit.getId(); + lastDocId = (long) lastSearchHit.getSortValues()[0]; } @Override diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndexTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndexTests.java index bee8a4370be56..1d6856af6fa58 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndexTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndexTests.java @@ -161,16 +161,13 @@ private Map<String, Object> testCreateDestinationIndex(DataFrameAnalysis analysi CreateIndexRequest createIndexRequest = createIndexRequestCaptor.getValue(); - assertThat(createIndexRequest.settings().keySet(), - containsInAnyOrder("index.number_of_shards", "index.number_of_replicas", "index.sort.field", "index.sort.order")); + assertThat(createIndexRequest.settings().keySet(), containsInAnyOrder("index.number_of_shards", "index.number_of_replicas")); assertThat(createIndexRequest.settings().getAsInt("index.number_of_shards", -1), equalTo(5)); assertThat(createIndexRequest.settings().getAsInt("index.number_of_replicas", -1), equalTo(1)); - assertThat(createIndexRequest.settings().get("index.sort.field"), equalTo("ml__id_copy")); - assertThat(createIndexRequest.settings().get("index.sort.order"), equalTo("asc")); try (XContentParser parser = createParser(JsonXContent.jsonXContent, createIndexRequest.mappings())) { Map<String, Object> map = parser.map(); - assertThat(extractValue("_doc.properties.ml__id_copy.type", map), equalTo("keyword")); + assertThat(extractValue("_doc.properties.ml__incremental_id.type", map), equalTo("long")); assertThat(extractValue("_doc.properties.field_1", map), equalTo("field_1_mappings")); assertThat(extractValue("_doc.properties.field_2", map), equalTo("field_2_mappings")); 
assertThat(extractValue("_doc.properties.numerical-field.type", map), equalTo("integer")); @@ -280,7 +277,7 @@ private Map testUpdateMappingsToDestIndex(DataFrameAnalysis anal assertThat(putMappingRequest.indices(), arrayContaining(DEST_INDEX)); try (XContentParser parser = createParser(JsonXContent.jsonXContent, putMappingRequest.source())) { Map map = parser.map(); - assertThat(extractValue("properties.ml__id_copy.type", map), equalTo("keyword")); + assertThat(extractValue("properties.ml__incremental_id.type", map), equalTo("long")); return map; } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index ba7affcef0204..b327073b97028 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -6,10 +6,6 @@ package org.elasticsearch.xpack.ml.dataframe.extractor; import org.apache.lucene.search.TotalHits; -import org.elasticsearch.action.ActionFuture; -import org.elasticsearch.action.search.ClearScrollAction; -import org.elasticsearch.action.search.ClearScrollRequest; -import org.elasticsearch.action.search.ClearScrollResponse; import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.ShardSearchFailure; @@ -39,7 +35,6 @@ import org.elasticsearch.xpack.ml.extractor.SourceField; import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import org.junit.Before; -import org.mockito.ArgumentCaptor; import java.io.IOException; import java.util.ArrayList; @@ -58,10 +53,8 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; -import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -76,9 +69,7 @@ public class DataFrameDataExtractorTests extends ESTestCase { private int scrollSize; private Map headers; private TrainTestSplitterFactory trainTestSplitterFactory; - private ArgumentCaptor capturedClearScrollRequests; - private ActionFuture clearScrollFuture; - private int searchHitCounter; + private long searchHitCounter; @Before @SuppressWarnings("unchecked") @@ -100,10 +91,6 @@ public void setUpTests() { trainTestSplitterFactory = mock(TrainTestSplitterFactory.class); when(trainTestSplitterFactory.create()).thenReturn(row -> true); - - clearScrollFuture = mock(ActionFuture.class); - capturedClearScrollRequests = ArgumentCaptor.forClass(ClearScrollRequest.class); - when(client.execute(same(ClearScrollAction.INSTANCE), capturedClearScrollRequests.capture())).thenReturn(clearScrollFuture); } public void testTwoPageExtraction() throws IOException { @@ -144,67 +131,25 @@ public void testTwoPageExtraction() throws IOException { assertThat(rows.isEmpty(), is(true)); assertThat(dataExtractor.hasNext(), is(false)); - // Now let's assert we're sending the expected search request - assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + // Now let's assert we're sending the expected search requests + assertThat(dataExtractor.capturedSearchRequests.size(), 
equalTo(3)); String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); assertThat(searchRequest, containsString("allowPartialSearchResults=false")); assertThat(searchRequest, containsString("indices=[index-1,index-2]")); assertThat(searchRequest, containsString("\"size\":1000")); - assertThat(searchRequest, containsString("\"query\":{\"match_all\":{\"boost\":1.0}}")); + assertThat(searchRequest, containsString("\"query\":{\"bool\":{\"filter\":[{\"match_all\":{\"boost\":1.0}},{\"range\":" + + "{\"ml__incremental_id\":{\"from\":0,\"to\":1000,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}}]")); assertThat(searchRequest, containsString("\"docvalue_fields\":[{\"field\":\"field_1\"},{\"field\":\"field_2\"}]")); assertThat(searchRequest, containsString("\"_source\":{\"includes\":[],\"excludes\":[]}")); - assertThat(searchRequest, containsString("\"sort\":[{\"ml__id_copy\":{\"order\":\"asc\"}}]")); + assertThat(searchRequest, containsString("\"sort\":[{\"ml__incremental_id\":{\"order\":\"asc\"}}]")); - // Check continue scroll requests had correct ids - assertThat(dataExtractor.capturedContinueScrollIds.size(), equalTo(2)); - assertThat(dataExtractor.capturedContinueScrollIds.get(0), equalTo(response1.getScrollId())); - assertThat(dataExtractor.capturedContinueScrollIds.get(1), equalTo(response2.getScrollId())); + searchRequest = dataExtractor.capturedSearchRequests.get(1).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"query\":{\"bool\":{\"filter\":[{\"match_all\":{\"boost\":1.0}},{\"range\":" + + "{\"ml__incremental_id\":{\"from\":3,\"to\":1003,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}}]")); - // Check we cleared the scroll with the latest scroll id - List<String> capturedClearScrollRequests = getCapturedClearScrollIds(); - assertThat(capturedClearScrollRequests.size(), equalTo(1)); - assertThat(capturedClearScrollRequests.get(0), equalTo(lastAndEmptyResponse.getScrollId())); - } - - public void testRecoveryFromErrorOnSearchAfterRetry() throws IOException { - TestExtractor dataExtractor = createExtractor(true, false); - - // First search will fail - dataExtractor.setNextResponse(createResponseWithShardFailures()); - - // Next one will succeed - SearchResponse response = createSearchResponse(Arrays.asList(1_1), Arrays.asList(2_1)); - dataExtractor.setNextResponse(response); - - // Last one - SearchResponse lastAndEmptyResponse = createEmptySearchResponse(); - dataExtractor.setNextResponse(lastAndEmptyResponse); - - assertThat(dataExtractor.hasNext(), is(true)); - - // First batch expected as normally since we'll retry after the error - Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next(); - assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(1)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); - assertThat(dataExtractor.hasNext(), is(true)); - - // Next batch should return empty - rows = dataExtractor.next(); - assertThat(rows.isEmpty(), is(true)); - assertThat(dataExtractor.hasNext(), is(false)); - - // Check we cleared the scroll with the latest scroll id - List<String> capturedClearScrollRequests = getCapturedClearScrollIds(); - assertThat(capturedClearScrollRequests.size(), equalTo(1)); - assertThat(capturedClearScrollRequests.get(0), equalTo(lastAndEmptyResponse.getScrollId())); - - // Notice we've done two searches here - assertThat(dataExtractor.capturedSearchRequests, hasSize(2)); - - // Assert the second search did not include a 
range query as the failure happened on the very first search - String searchRequest = dataExtractor.capturedSearchRequests.get(1).request().toString().replaceAll("\\s", ""); - assertThat(searchRequest, containsString("\"query\":{\"match_all\":{\"boost\":1.0}}")); + searchRequest = dataExtractor.capturedSearchRequests.get(2).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"query\":{\"bool\":{\"filter\":[{\"match_all\":{\"boost\":1.0}},{\"range\":" + + "{\"ml__incremental_id\":{\"from\":4,\"to\":1004,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}}]")); } public void testErrorOnSearchTwiceLeadsToFailure() { @@ -220,14 +165,14 @@ ... expectThrows(RuntimeException.class, () -> dataExtractor.next()); } - public void testRecoveryFromErrorOnContinueScrollAfterRetry() throws IOException { + public void testRecoveryFromErrorOnSearch() throws IOException { TestExtractor dataExtractor = createExtractor(true, false); - // Search will succeed + // First search will succeed SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, 1_2), Arrays.asList(2_1, 2_2)); dataExtractor.setNextResponse(response1); - // But the first continue scroll fails + // But the second search fails dataExtractor.setNextResponse(createResponseWithShardFailures()); // The next one succeeds and we shall recover @@ -260,45 +205,34 @@ public void testRecoveryFromErrorOnContinueScrollAfterRetry() throws IOException assertThat(rows.isEmpty(), is(true)); assertThat(dataExtractor.hasNext(), is(false)); - // Notice we've done two searches and two continues here - assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(2)); - assertThat(dataExtractor.capturedContinueScrollIds.size(), equalTo(2)); + // Notice we've done 4 searches + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(4)); - // Assert the second search continued from the latest successfully processed doc - String searchRequest = dataExtractor.capturedSearchRequests.get(1).request().toString().replaceAll("\\s", ""); + String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); assertThat(searchRequest, containsString("\"query\":{\"bool\":{")); assertThat(searchRequest, containsString("{\"match_all\":{\"boost\":1.0}")); - assertThat(searchRequest, containsString("{\"range\":{\"ml__id_copy\":{\"from\":\"1\",\"to\":null,\"include_lower\":false")); - - // Check we cleared the scroll with the latest scroll id - List<String> capturedClearScrollRequests = getCapturedClearScrollIds(); - assertThat(capturedClearScrollRequests.size(), equalTo(1)); - assertThat(capturedClearScrollRequests.get(0), equalTo(lastAndEmptyResponse.getScrollId())); - } - - public void testErrorOnContinueScrollTwiceLeadsToFailure() throws IOException { - TestExtractor dataExtractor = createExtractor(true, false); - - // Search will succeed - SearchResponse response1 = createSearchResponse(Arrays.asList(1_1), Arrays.asList(2_1)); - dataExtractor.setNextResponse(response1); - - // But the first continue scroll fails - dataExtractor.setNextResponse(createResponseWithShardFailures()); - // As well as the second - dataExtractor.setNextResponse(createResponseWithShardFailures()); + assertThat(searchRequest, containsString( + "{\"range\":{\"ml__incremental_id\":{\"from\":0,\"to\":1000,\"include_lower\":true,\"include_upper\":false")); - assertThat(dataExtractor.hasNext(), is(true)); + // Assert the second search continued from the latest 
successfully processed doc + searchRequest = dataExtractor.capturedSearchRequests.get(1).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"query\":{\"bool\":{")); + assertThat(searchRequest, containsString("{\"match_all\":{\"boost\":1.0}")); + assertThat(searchRequest, containsString( + "{\"range\":{\"ml__incremental_id\":{\"from\":2,\"to\":1002,\"include_lower\":true,\"include_upper\":false")); - // First batch expected as normally since we'll retry after the error - Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next(); - assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(1)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); - assertThat(dataExtractor.hasNext(), is(true)); + // Assert the third search continued from the latest successfully processed doc + searchRequest = dataExtractor.capturedSearchRequests.get(2).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"query\":{\"bool\":{")); + assertThat(searchRequest, containsString("{\"match_all\":{\"boost\":1.0}")); + assertThat(searchRequest, containsString( + "{\"range\":{\"ml__incremental_id\":{\"from\":2,\"to\":1002,\"include_lower\":true,\"include_upper\":false")); - // We get second batch as we retried after the error - expectThrows(RuntimeException.class, () -> dataExtractor.next()); + searchRequest = dataExtractor.capturedSearchRequests.get(3).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"query\":{\"bool\":{")); + assertThat(searchRequest, containsString("{\"match_all\":{\"boost\":1.0}")); + assertThat(searchRequest, containsString( + "{\"range\":{\"ml__incremental_id\":{\"from\":3,\"to\":1003,\"include_lower\":true,\"include_upper\":false")); } public void testIncludeSourceIsFalseAndNoSourceFields() throws IOException { @@ -319,7 +253,7 @@ public void testIncludeSourceIsFalseAndNoSourceFields() throws IOException { assertThat(dataExtractor.next().isEmpty(), is(true)); assertThat(dataExtractor.hasNext(), is(false)); - assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(2)); String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); assertThat(searchRequest, containsString("\"docvalue_fields\":[{\"field\":\"field_1\"},{\"field\":\"field_2\"}]")); assertThat(searchRequest, containsString("\"_source\":false")); @@ -350,7 +284,7 @@ public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOExceptio assertThat(dataExtractor.next().isEmpty(), is(true)); assertThat(dataExtractor.hasNext(), is(false)); - assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(2)); String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); assertThat(searchRequest, containsString("\"docvalue_fields\":[{\"field\":\"field_1\"}]")); assertThat(searchRequest, containsString("\"_source\":{\"includes\":[\"field_2\"],\"excludes\":[]}")); @@ -592,14 +526,13 @@ private static PreProcessor buildPreProcessor(String inputField, String... 
outpu private SearchResponse createSearchResponse(List<Number> field1Values, List<Number> field2Values) { assertThat(field1Values.size(), equalTo(field2Values.size())); SearchResponse searchResponse = mock(SearchResponse.class); - when(searchResponse.getScrollId()).thenReturn(randomAlphaOfLength(1000)); List<SearchHit> hits = new ArrayList<>(); for (int i = 0; i < field1Values.size(); i++) { SearchHitBuilder searchHitBuilder = new SearchHitBuilder(randomInt()); addField(searchHitBuilder, "field_1", field1Values.get(i)); addField(searchHitBuilder, "field_2", field2Values.get(i)); searchHitBuilder.setSource("{\"field_1\":" + field1Values.get(i) + ",\"field_2\":" + field2Values.get(i) + "}"); - searchHitBuilder.setStringSortValue(String.valueOf(searchHitCounter++)); + searchHitBuilder.setLongSortValue(searchHitCounter++); hits.add(searchHitBuilder.build()); } SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[0]), new TotalHits(hits.size(), TotalHits.Relation.EQUAL_TO), 1); @@ -625,15 +558,10 @@ private SearchResponse createResponseWithShardFailures() { return searchResponse; } - private List<String> getCapturedClearScrollIds() { - return capturedClearScrollRequests.getAllValues().stream().map(r -> r.getScrollIds().get(0)).collect(Collectors.toList()); - } - private static class TestExtractor extends DataFrameDataExtractor { private Queue<SearchResponse> responses = new LinkedList<>(); private List<SearchRequestBuilder> capturedSearchRequests = new ArrayList<>(); - private List<String> capturedContinueScrollIds = new ArrayList<>(); TestExtractor(Client client, DataFrameDataExtractorContext context) { super(client, context); } @@ -652,16 +580,6 @@ protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequest } return searchResponse; } - - @Override - protected SearchResponse executeSearchScrollRequest(String scrollId) { - capturedContinueScrollIds.add(scrollId); - SearchResponse searchResponse = responses.remove(); - if (searchResponse.getShardFailures() != null) { - throw new RuntimeException(searchResponse.getShardFailures()[0].getCause()); - } - return searchResponse; - } } private static class CategoricalPreProcessor implements PreProcessor { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/SearchHitBuilder.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/SearchHitBuilder.java index b5a85e4106e3c..504fc56f5ec8e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/SearchHitBuilder.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/SearchHitBuilder.java @@ -42,8 +42,8 @@ public SearchHitBuilder setSource(String sourceJson) { return this; } - public SearchHitBuilder setStringSortValue(String sortValue) { - hit.sortValues(new String[] { sortValue }, new DocValueFormat[] { DocValueFormat.RAW }); + public SearchHitBuilder setLongSortValue(Long sortValue) { + hit.sortValues(new Long[] { sortValue }, new DocValueFormat[] { DocValueFormat.RAW }); return this; }
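
For illustration, here is a rough Dev Console sketch of the two requests that the code in this patch builds programmatically: the reindex call set up in DataFrameAnalyticsManager, and one ranged page built by DataFrameDataExtractor.buildSearchRequest. This is a sketch, not part of the patch: the index names "my-source-index" and "my-dest-index" are hypothetical placeholders, and the page size of 1000 simply mirrors the scroll_size the tests assert on.

# Reindex with a single slice so the painless counter stays monotonic;
# sorting on _seq_no means a re-run assigns the same id to the same doc.
# The counter is wrapped in a map because top level params are immutable.
POST _reindex?slices=1&refresh=true
{
  "source": {
    "index": "my-source-index",
    "sort": { "_seq_no": "asc" }
  },
  "dest": { "index": "my-dest-index" },
  "script": {
    "lang": "painless",
    "source": "ctx._source.ml__incremental_id = ++params.counter.value",
    "params": { "counter": { "value": -1 } }
  }
}

# One extractor page: filter to [0, 1000) on the incremental id and sort
# on it; the next page starts at the last sort key + 1.
GET my-dest-index/_search?allow_partial_search_results=false
{
  "size": 1000,
  "sort": [ { "ml__incremental_id": "asc" } ],
  "query": {
    "bool": {
      "filter": [
        { "match_all": {} },
        { "range": { "ml__incremental_id": { "gte": 0, "lt": 1000 } } }
      ]
    }
  }
}

Because the single-slice reindex makes the incremental id dense and monotonically increasing, each page can be fetched with a plain ranged search, so no scroll context has to be kept alive between pages and the destination index no longer needs an index sort, which is what frees it up for nested mappings.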