diff --git a/docs/changelog/103203.yaml b/docs/changelog/103203.yaml new file mode 100644 index 0000000000000..d2aa3e9961c6a --- /dev/null +++ b/docs/changelog/103203.yaml @@ -0,0 +1,5 @@ +pr: 103203 +summary: Fix NPE & empty result handling in `CountOnlyQueryPhaseResultConsumer` +area: Search +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumer.java index 1e67522f6a671..13972ea2bf64a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumer.java +++ b/server/src/main/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumer.java @@ -49,12 +49,17 @@ Stream getSuccessfulResults() { public void consumeResult(SearchPhaseResult result, Runnable next) { assert results.contains(result.getShardIndex()) == false : "shardIndex: " + result.getShardIndex() + " is already set"; results.add(result.getShardIndex()); + progressListener.notifyQueryResult(result.getShardIndex(), result.queryResult()); + // We have an empty result, track that we saw it for this shard and continue; + if (result.queryResult().isNull()) { + next.run(); + return; + } // set the relation to the first non-equal relation relationAtomicReference.compareAndSet(TotalHits.Relation.EQUAL_TO, result.queryResult().getTotalHits().relation); totalHits.add(result.queryResult().getTotalHits().value); terminatedEarly.compareAndSet(false, (result.queryResult().terminatedEarly() != null && result.queryResult().terminatedEarly())); timedOut.compareAndSet(false, result.queryResult().searchTimedOut()); - progressListener.notifyQueryResult(result.getShardIndex(), result.queryResult()); next.run(); } @@ -80,7 +85,7 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { 1, 0, 0, - false + results.isEmpty() ); if (progressListener != SearchProgressListener.NOOP) { progressListener.notifyFinalReduce( diff --git a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java index 40d4e37045016..7bcafe7005047 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @@ -149,6 +150,7 @@ public void terminatedEarly(boolean terminatedEarly) { this.terminatedEarly = terminatedEarly; } + @Nullable public Boolean terminatedEarly() { return this.terminatedEarly; } @@ -204,10 +206,12 @@ public void setRankShardResult(RankShardResult rankShardResult) { this.rankShardResult = rankShardResult; } + @Nullable public RankShardResult getRankShardResult() { return rankShardResult; } + @Nullable public DocValueFormat[] sortValueFormats() { return sortValueFormats; } @@ -252,6 +256,7 @@ public void aggregations(InternalAggregations aggregations) { hasAggs = aggregations != null; } + @Nullable public DelayableWriteable aggregations() { return aggregations; } @@ -455,6 +460,7 @@ public void writeToNoId(StreamOutput out) throws IOException { } } + @Nullable public TotalHits getTotalHits() { return totalHits; } diff --git a/server/src/test/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumerTests.java b/server/src/test/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumerTests.java new file mode 100644 index 0000000000000..33e6096bab763 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/search/CountOnlyQueryPhaseResultConsumerTests.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.action.search; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.aggregations.InternalAggregations; +import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.test.ESTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +public class CountOnlyQueryPhaseResultConsumerTests extends ESTestCase { + + public void testProgressListenerExceptionsAreCaught() throws Exception { + ThrowingSearchProgressListener searchProgressListener = new ThrowingSearchProgressListener(); + + List searchShards = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + searchShards.add(new SearchShard(null, new ShardId("index", "uuid", i))); + } + long timestamp = randomLongBetween(1000, Long.MAX_VALUE - 1000); + TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( + timestamp, + timestamp, + () -> timestamp + 1000 + ); + searchProgressListener.notifyListShards(searchShards, Collections.emptyList(), SearchResponse.Clusters.EMPTY, false, timeProvider); + + CountOnlyQueryPhaseResultConsumer queryPhaseResultConsumer = new CountOnlyQueryPhaseResultConsumer(searchProgressListener, 10); + try { + AtomicInteger nextCounter = new AtomicInteger(0); + for (int i = 0; i < 10; i++) { + SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i), null); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(i); + queryPhaseResultConsumer.consumeResult(querySearchResult, nextCounter::incrementAndGet); + } + + assertEquals(10, searchProgressListener.onQueryResult.get()); + queryPhaseResultConsumer.reduce(); + assertEquals(1, searchProgressListener.onFinalReduce.get()); + assertEquals(10, nextCounter.get()); + } finally { + queryPhaseResultConsumer.decRef(); + } + } + + public void testNullShardResultHandling() throws Exception { + CountOnlyQueryPhaseResultConsumer queryPhaseResultConsumer = new CountOnlyQueryPhaseResultConsumer(SearchProgressListener.NOOP, 10); + try { + AtomicInteger nextCounter = new AtomicInteger(0); + for (int i = 0; i < 10; i++) { + SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i), null); + QuerySearchResult querySearchResult = QuerySearchResult.nullInstance(); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(i); + queryPhaseResultConsumer.consumeResult(querySearchResult, nextCounter::incrementAndGet); + } + var reducePhase = queryPhaseResultConsumer.reduce(); + assertEquals(0, reducePhase.totalHits().value); + assertEquals(TotalHits.Relation.EQUAL_TO, reducePhase.totalHits().relation); + assertFalse(reducePhase.isEmptyResult()); + assertEquals(10, nextCounter.get()); + } finally { + queryPhaseResultConsumer.decRef(); + } + } + + public void testEmptyResults() throws Exception { + CountOnlyQueryPhaseResultConsumer queryPhaseResultConsumer = new CountOnlyQueryPhaseResultConsumer(SearchProgressListener.NOOP, 10); + try { + var reducePhase = queryPhaseResultConsumer.reduce(); + assertEquals(0, reducePhase.totalHits().value); + assertEquals(TotalHits.Relation.EQUAL_TO, reducePhase.totalHits().relation); + assertTrue(reducePhase.isEmptyResult()); + } finally { + queryPhaseResultConsumer.decRef(); + } + } + + private static class ThrowingSearchProgressListener extends SearchProgressListener { + private final AtomicInteger onQueryResult = new AtomicInteger(0); + private final AtomicInteger onPartialReduce = new AtomicInteger(0); + private final AtomicInteger onFinalReduce = new AtomicInteger(0); + + @Override + protected void onListShards( + List shards, + List skippedShards, + SearchResponse.Clusters clusters, + boolean fetchPhase, + TransportSearchAction.SearchTimeProvider timeProvider + ) { + throw new UnsupportedOperationException(); + } + + @Override + protected void onQueryResult(int shardIndex, QuerySearchResult queryResult) { + onQueryResult.incrementAndGet(); + throw new UnsupportedOperationException(); + } + + @Override + protected void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + onPartialReduce.incrementAndGet(); + throw new UnsupportedOperationException(); + } + + @Override + protected void onFinalReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { + onFinalReduce.incrementAndGet(); + throw new UnsupportedOperationException(); + } + } +}