Skip to content

Commit

Permalink
Fix NPE & empty result handling in CountOnlyQueryPhaseResultConsumer (e…
Browse files Browse the repository at this point in the history
…lastic#103203)

Query results can be "null" in that they are a null instance, containing
no information from the shard. 

Additionally, if we see no results, its an empty reduce phase.

This code was introduced into 8.12, which has yet to be released, but
flagging as a bug for clarity.
  • Loading branch information
benwtrent authored Dec 8, 2023
1 parent 47b5753 commit bbbc58f
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/103203.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 103203
summary: Fix NPE & empty result handling in `CountOnlyQueryPhaseResultConsumer`
area: Search
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,17 @@ Stream<SearchPhaseResult> getSuccessfulResults() {
public void consumeResult(SearchPhaseResult result, Runnable next) {
assert results.contains(result.getShardIndex()) == false : "shardIndex: " + result.getShardIndex() + " is already set";
results.add(result.getShardIndex());
progressListener.notifyQueryResult(result.getShardIndex(), result.queryResult());
// We have an empty result, track that we saw it for this shard and continue;
if (result.queryResult().isNull()) {
next.run();
return;
}
// set the relation to the first non-equal relation
relationAtomicReference.compareAndSet(TotalHits.Relation.EQUAL_TO, result.queryResult().getTotalHits().relation);
totalHits.add(result.queryResult().getTotalHits().value);
terminatedEarly.compareAndSet(false, (result.queryResult().terminatedEarly() != null && result.queryResult().terminatedEarly()));
timedOut.compareAndSet(false, result.queryResult().searchTimedOut());
progressListener.notifyQueryResult(result.getShardIndex(), result.queryResult());
next.run();
}

Expand All @@ -80,7 +85,7 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
1,
0,
0,
false
results.isEmpty()
);
if (progressListener != SearchProgressListener.NOOP) {
progressListener.notifyFinalReduce(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
Expand Down Expand Up @@ -149,6 +150,7 @@ public void terminatedEarly(boolean terminatedEarly) {
this.terminatedEarly = terminatedEarly;
}

@Nullable
public Boolean terminatedEarly() {
return this.terminatedEarly;
}
Expand Down Expand Up @@ -204,10 +206,12 @@ public void setRankShardResult(RankShardResult rankShardResult) {
this.rankShardResult = rankShardResult;
}

@Nullable
public RankShardResult getRankShardResult() {
return rankShardResult;
}

@Nullable
public DocValueFormat[] sortValueFormats() {
return sortValueFormats;
}
Expand Down Expand Up @@ -252,6 +256,7 @@ public void aggregations(InternalAggregations aggregations) {
hasAggs = aggregations != null;
}

@Nullable
public DelayableWriteable<InternalAggregations> aggregations() {
return aggregations;
}
Expand Down Expand Up @@ -455,6 +460,7 @@ public void writeToNoId(StreamOutput out) throws IOException {
}
}

@Nullable
public TotalHits getTotalHits() {
return totalHits;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.action.search;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.test.ESTestCase;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class CountOnlyQueryPhaseResultConsumerTests extends ESTestCase {

public void testProgressListenerExceptionsAreCaught() throws Exception {
ThrowingSearchProgressListener searchProgressListener = new ThrowingSearchProgressListener();

List<SearchShard> searchShards = new ArrayList<>();
for (int i = 0; i < 10; i++) {
searchShards.add(new SearchShard(null, new ShardId("index", "uuid", i)));
}
long timestamp = randomLongBetween(1000, Long.MAX_VALUE - 1000);
TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider(
timestamp,
timestamp,
() -> timestamp + 1000
);
searchProgressListener.notifyListShards(searchShards, Collections.emptyList(), SearchResponse.Clusters.EMPTY, false, timeProvider);

CountOnlyQueryPhaseResultConsumer queryPhaseResultConsumer = new CountOnlyQueryPhaseResultConsumer(searchProgressListener, 10);
try {
AtomicInteger nextCounter = new AtomicInteger(0);
for (int i = 0; i < 10; i++) {
SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i), null);
QuerySearchResult querySearchResult = new QuerySearchResult();
TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(i);
queryPhaseResultConsumer.consumeResult(querySearchResult, nextCounter::incrementAndGet);
}

assertEquals(10, searchProgressListener.onQueryResult.get());
queryPhaseResultConsumer.reduce();
assertEquals(1, searchProgressListener.onFinalReduce.get());
assertEquals(10, nextCounter.get());
} finally {
queryPhaseResultConsumer.decRef();
}
}

public void testNullShardResultHandling() throws Exception {
CountOnlyQueryPhaseResultConsumer queryPhaseResultConsumer = new CountOnlyQueryPhaseResultConsumer(SearchProgressListener.NOOP, 10);
try {
AtomicInteger nextCounter = new AtomicInteger(0);
for (int i = 0; i < 10; i++) {
SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i), null);
QuerySearchResult querySearchResult = QuerySearchResult.nullInstance();
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(i);
queryPhaseResultConsumer.consumeResult(querySearchResult, nextCounter::incrementAndGet);
}
var reducePhase = queryPhaseResultConsumer.reduce();
assertEquals(0, reducePhase.totalHits().value);
assertEquals(TotalHits.Relation.EQUAL_TO, reducePhase.totalHits().relation);
assertFalse(reducePhase.isEmptyResult());
assertEquals(10, nextCounter.get());
} finally {
queryPhaseResultConsumer.decRef();
}
}

public void testEmptyResults() throws Exception {
CountOnlyQueryPhaseResultConsumer queryPhaseResultConsumer = new CountOnlyQueryPhaseResultConsumer(SearchProgressListener.NOOP, 10);
try {
var reducePhase = queryPhaseResultConsumer.reduce();
assertEquals(0, reducePhase.totalHits().value);
assertEquals(TotalHits.Relation.EQUAL_TO, reducePhase.totalHits().relation);
assertTrue(reducePhase.isEmptyResult());
} finally {
queryPhaseResultConsumer.decRef();
}
}

private static class ThrowingSearchProgressListener extends SearchProgressListener {
private final AtomicInteger onQueryResult = new AtomicInteger(0);
private final AtomicInteger onPartialReduce = new AtomicInteger(0);
private final AtomicInteger onFinalReduce = new AtomicInteger(0);

@Override
protected void onListShards(
List<SearchShard> shards,
List<SearchShard> skippedShards,
SearchResponse.Clusters clusters,
boolean fetchPhase,
TransportSearchAction.SearchTimeProvider timeProvider
) {
throw new UnsupportedOperationException();
}

@Override
protected void onQueryResult(int shardIndex, QuerySearchResult queryResult) {
onQueryResult.incrementAndGet();
throw new UnsupportedOperationException();
}

@Override
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
onPartialReduce.incrementAndGet();
throw new UnsupportedOperationException();
}

@Override
protected void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
onFinalReduce.incrementAndGet();
throw new UnsupportedOperationException();
}
}
}

0 comments on commit bbbc58f

Please sign in to comment.