From bbbc58f37a073880ec7923f31ad599efd94f64bc Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 8 Dec 2023 15:34:07 -0500 Subject: [PATCH] Fix NPE & empty result handling in CountOnlyQueryPhaseResultConsumer (#103203) Query results can be "null" in that they are a null instance, containing no information from the shard. Additionally, if we see no results, its an empty reduce phase. This code was introduced into 8.12, which has yet to be released, but flagging as a bug for clarity. --- docs/changelog/103203.yaml | 5 + .../CountOnlyQueryPhaseResultConsumer.java | 9 +- .../search/query/QuerySearchResult.java | 6 + ...ountOnlyQueryPhaseResultConsumerTests.java | 133 ++++++++++++++++++ 4 files changed, 151 insertions(+), 2 deletions(-) create mode 100644 docs/changelog/103203.yaml create mode 100644 server/src/test/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumerTests.java diff --git a/docs/changelog/103203.yaml b/docs/changelog/103203.yaml new file mode 100644 index 0000000000000..d2aa3e9961c6a --- /dev/null +++ b/docs/changelog/103203.yaml @@ -0,0 +1,5 @@ +pr: 103203 +summary: Fix NPE & empty result handling in `CountOnlyQueryPhaseResultConsumer` +area: Search +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumer.java index 1e67522f6a671..13972ea2bf64a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumer.java +++ b/server/src/main/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumer.java @@ -49,12 +49,17 @@ Stream getSuccessfulResults() { public void consumeResult(SearchPhaseResult result, Runnable next) { assert results.contains(result.getShardIndex()) == false : "shardIndex: " + result.getShardIndex() + " is already set"; results.add(result.getShardIndex()); + progressListener.notifyQueryResult(result.getShardIndex(), result.queryResult()); + // We have an empty result, track that we saw it for this shard and continue; + if (result.queryResult().isNull()) { + next.run(); + return; + } // set the relation to the first non-equal relation relationAtomicReference.compareAndSet(TotalHits.Relation.EQUAL_TO, result.queryResult().getTotalHits().relation); totalHits.add(result.queryResult().getTotalHits().value); terminatedEarly.compareAndSet(false, (result.queryResult().terminatedEarly() != null && result.queryResult().terminatedEarly())); timedOut.compareAndSet(false, result.queryResult().searchTimedOut()); - progressListener.notifyQueryResult(result.getShardIndex(), result.queryResult()); next.run(); } @@ -80,7 +85,7 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { 1, 0, 0, - false + results.isEmpty() ); if (progressListener != SearchProgressListener.NOOP) { progressListener.notifyFinalReduce( diff --git a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java index 40d4e37045016..7bcafe7005047 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @@ -149,6 +150,7 @@ public void terminatedEarly(boolean terminatedEarly) { this.terminatedEarly = terminatedEarly; } + @Nullable public Boolean terminatedEarly() { return this.terminatedEarly; } @@ -204,10 +206,12 @@ public void setRankShardResult(RankShardResult rankShardResult) { this.rankShardResult = rankShardResult; } + @Nullable public RankShardResult getRankShardResult() { return rankShardResult; } + @Nullable public DocValueFormat[] sortValueFormats() { return sortValueFormats; } @@ -252,6 +256,7 @@ public void aggregations(InternalAggregations aggregations) { hasAggs = aggregations != null; } + @Nullable public DelayableWriteable aggregations() { return aggregations; } @@ -455,6 +460,7 @@ public void writeToNoId(StreamOutput out) throws IOException { } } + @Nullable public TotalHits getTotalHits() { return totalHits; } diff --git a/server/src/test/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumerTests.java b/server/src/test/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumerTests.java new file mode 100644 index 0000000000000..33e6096bab763 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumerTests.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.action.search; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.aggregations.InternalAggregations; +import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.test.ESTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +public class CountOnlyQueryPhaseResultConsumerTests extends ESTestCase { + + public void testProgressListenerExceptionsAreCaught() throws Exception { + ThrowingSearchProgressListener searchProgressListener = new ThrowingSearchProgressListener(); + + List searchShards = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + searchShards.add(new SearchShard(null, new ShardId("index", "uuid", i))); + } + long timestamp = randomLongBetween(1000, Long.MAX_VALUE - 1000); + TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( + timestamp, + timestamp, + () -> timestamp + 1000 + ); + searchProgressListener.notifyListShards(searchShards, Collections.emptyList(), SearchResponse.Clusters.EMPTY, false, timeProvider); + + CountOnlyQueryPhaseResultConsumer queryPhaseResultConsumer = new CountOnlyQueryPhaseResultConsumer(searchProgressListener, 10); + try { + AtomicInteger nextCounter = new AtomicInteger(0); + for (int i = 0; i < 10; i++) { + SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i), null); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(i); + queryPhaseResultConsumer.consumeResult(querySearchResult, nextCounter::incrementAndGet); + } + + assertEquals(10, searchProgressListener.onQueryResult.get()); + queryPhaseResultConsumer.reduce(); + assertEquals(1, searchProgressListener.onFinalReduce.get()); + assertEquals(10, nextCounter.get()); + } finally { + queryPhaseResultConsumer.decRef(); + } + } + + public void testNullShardResultHandling() throws Exception { + CountOnlyQueryPhaseResultConsumer queryPhaseResultConsumer = new CountOnlyQueryPhaseResultConsumer(SearchProgressListener.NOOP, 10); + try { + AtomicInteger nextCounter = new AtomicInteger(0); + for (int i = 0; i < 10; i++) { + SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i), null); + QuerySearchResult querySearchResult = QuerySearchResult.nullInstance(); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(i); + queryPhaseResultConsumer.consumeResult(querySearchResult, nextCounter::incrementAndGet); + } + var reducePhase = queryPhaseResultConsumer.reduce(); + assertEquals(0, reducePhase.totalHits().value); + assertEquals(TotalHits.Relation.EQUAL_TO, reducePhase.totalHits().relation); + assertFalse(reducePhase.isEmptyResult()); + assertEquals(10, nextCounter.get()); + } finally { + queryPhaseResultConsumer.decRef(); + } + } + + public void testEmptyResults() throws Exception { + CountOnlyQueryPhaseResultConsumer queryPhaseResultConsumer = new CountOnlyQueryPhaseResultConsumer(SearchProgressListener.NOOP, 10); + try { + var reducePhase = queryPhaseResultConsumer.reduce(); + assertEquals(0, reducePhase.totalHits().value); + assertEquals(TotalHits.Relation.EQUAL_TO, reducePhase.totalHits().relation); + assertTrue(reducePhase.isEmptyResult()); + } finally { + queryPhaseResultConsumer.decRef(); + } + } + + private static class ThrowingSearchProgressListener extends SearchProgressListener { + private final AtomicInteger onQueryResult = new AtomicInteger(0); + private final AtomicInteger onPartialReduce = new AtomicInteger(0); + private final AtomicInteger onFinalReduce = new AtomicInteger(0); + + @Override + protected void onListShards( + List shards, + List skippedShards, + SearchResponse.Clusters clusters, + boolean fetchPhase, + TransportSearchAction.SearchTimeProvider timeProvider + ) { + throw new UnsupportedOperationException(); + } + + @Override + protected void onQueryResult(int shardIndex, QuerySearchResult queryResult) { + onQueryResult.incrementAndGet(); + throw new UnsupportedOperationException(); + } + + @Override + protected void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + onPartialReduce.incrementAndGet(); + throw new UnsupportedOperationException(); + } + + @Override + protected void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + onFinalReduce.incrementAndGet(); + throw new UnsupportedOperationException(); + } + } +}