diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryCollectorContext.java b/server/src/main/java/org/elasticsearch/search/query/QueryCollectorContext.java index b63739df76bfe..baf6889ba07f6 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryCollectorContext.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryCollectorContext.java @@ -24,6 +24,7 @@ import org.apache.lucene.search.MultiCollector; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.SimpleCollector; import org.apache.lucene.search.Weight; import org.elasticsearch.common.lucene.MinimumScoreCollector; import org.elasticsearch.common.lucene.search.FilteredCollector; @@ -41,6 +42,17 @@ import static org.elasticsearch.search.profile.query.CollectorResult.REASON_SEARCH_TERMINATE_AFTER_COUNT; abstract class QueryCollectorContext { + private static final Collector EMPTY_COLLECTOR = new SimpleCollector() { + @Override + public void collect(int doc) { + } + + @Override + public ScoreMode scoreMode() { + return ScoreMode.COMPLETE_NO_SCORES; + } + }; + private String profilerName; QueryCollectorContext(String profilerName) { @@ -124,7 +136,7 @@ Collector create(Collector in ) throws IOException { static QueryCollectorContext createMultiCollectorContext(Collection subs) { return new QueryCollectorContext(REASON_SEARCH_MULTI) { @Override - Collector create(Collector in) throws IOException { + Collector create(Collector in) { List subCollectors = new ArrayList<> (); subCollectors.add(in); subCollectors.addAll(subs); @@ -132,7 +144,7 @@ Collector create(Collector in) throws IOException { } @Override - protected InternalProfileCollector createWithProfiler(InternalProfileCollector in) throws IOException { + protected InternalProfileCollector createWithProfiler(InternalProfileCollector in) { final List subCollectors = new ArrayList<> (); subCollectors.add(in); if (subs.stream().anyMatch((col) -> col instanceof InternalProfileCollector == false)) { @@ -152,12 +164,20 @@ protected InternalProfileCollector createWithProfiler(InternalProfileCollector i */ static QueryCollectorContext createEarlyTerminationCollectorContext(int numHits) { return new QueryCollectorContext(REASON_SEARCH_TERMINATE_AFTER_COUNT) { - private EarlyTerminatingCollector collector; + private Collector collector; + /** + * Creates a {@link MultiCollector} to ensure that the {@link EarlyTerminatingCollector} + * can terminate the collection independently of the provided in {@link Collector}. + */ @Override - Collector create(Collector in) throws IOException { + Collector create(Collector in) { assert collector == null; - this.collector = new EarlyTerminatingCollector(in, numHits, true); + + List subCollectors = new ArrayList<> (); + subCollectors.add(new EarlyTerminatingCollector(EMPTY_COLLECTOR, numHits, true)); + subCollectors.add(in); + this.collector = MultiCollector.wrap(subCollectors); return collector; } }; diff --git a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java index ea513ccaaccb1..456e937e99d96 100644 --- a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java @@ -452,6 +452,40 @@ public void testTerminateAfterEarlyTermination() throws Exception { assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); assertThat(collector.getTotalHits(), equalTo(1)); } + + // tests with trackTotalHits and terminateAfter + context.terminateAfter(10); + context.setSize(0); + for (int trackTotalHits : new int[] { -1, 3, 76, 100}) { + context.trackTotalHitsUpTo(trackTotalHits); + TotalHitCountCollector collector = new TotalHitCountCollector(); + context.queryCollectors().put(TotalHitCountCollector.class, collector); + QueryPhase.executeInternal(context); + assertTrue(context.queryResult().terminatedEarly()); + if (trackTotalHits == -1) { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L)); + } else { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) Math.min(trackTotalHits, 10))); + } + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0)); + assertThat(collector.getTotalHits(), equalTo(10)); + } + + context.terminateAfter(7); + context.setSize(10); + for (int trackTotalHits : new int[] { -1, 3, 75, 100}) { + context.trackTotalHitsUpTo(trackTotalHits); + EarlyTerminatingCollector collector = new EarlyTerminatingCollector(new TotalHitCountCollector(), 1, false); + context.queryCollectors().put(EarlyTerminatingCollector.class, collector); + QueryPhase.executeInternal(context); + assertTrue(context.queryResult().terminatedEarly()); + if (trackTotalHits == -1) { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L)); + } else { + assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(7L)); + } + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(7)); + } reader.close(); dir.close(); }