From d29e6474a2d6952cd221a15c86be91924ab1cbd8 Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Tue, 1 Oct 2024 06:50:19 -0400 Subject: [PATCH] Support kNN filter on nested metadata Current knn search over nested vectors only supports filtering on parent's meatadata. This adds support for filtering over nested metadata. Closes #106994 --- docs/reference/query-dsl/knn-query.asciidoc | 37 +++--- .../search-your-data/knn-search.asciidoc | 46 +++++-- .../search.vectors/100_knn_nested_search.yml | 77 ++++++++++++ .../130_knn_query_nested_search.yml | 87 ++++++++++++++ .../org/elasticsearch/TransportVersions.java | 1 + .../action/search/DfsQueryPhase.java | 3 +- .../query/ToChildBlockJoinQueryBuilder.java | 113 ++++++++++++++++++ .../action/search/SearchCapabilities.java | 4 +- .../elasticsearch/search/SearchModule.java | 4 + .../vectors/KnnScoreDocQueryBuilder.java | 48 +++++++- .../search/vectors/KnnVectorQueryBuilder.java | 33 +++-- .../action/search/DfsQueryPhaseTests.java | 6 +- .../ToChildBlockJoinQueryBuilderTests.java | 53 ++++++++ .../vectors/KnnScoreDocQueryBuilderTests.java | 45 ++++++- 14 files changed, 512 insertions(+), 45 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilder.java create mode 100644 server/src/test/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilderTests.java diff --git a/docs/reference/query-dsl/knn-query.asciidoc b/docs/reference/query-dsl/knn-query.asciidoc index daf9e9499a189..f64b4f3ae5628 100644 --- a/docs/reference/query-dsl/knn-query.asciidoc +++ b/docs/reference/query-dsl/knn-query.asciidoc @@ -240,26 +240,30 @@ to <>: * kNN search over nested dense_vectors diversifies the top results over the top-level document -* `filter` over the top-level document metadata is supported and acts as a -pre-filter -* `filter` over `nested` field metadata is not supported +* `filter` is supported both over the top-level document metadata and `nested` +field metadata. Filter acts as a pre-filter. -A sample query can look like below: +A sample query with filter over nested metadata can look like below: [source,js] ---- { - "query" : { - "nested" : { - "path" : "paragraph", - "query" : { - "knn": { - "query_vector": [ - 0.45, - 45 - ], - "field": "paragraph.vector", - "num_candidates": 2 + "query": { + "nested": { + "path": "paragraph", + "query": { + "knn": { + "query_vector": [ + 0.45, + 45 + ], + "field": "paragraph.vector", + "k": 10, + "filter": { + "match": { + "paragraph.language": "EN" + } + } } } } @@ -268,6 +272,9 @@ A sample query can look like below: ---- // NOTCONSOLE +The above query only considers vectors with `"paragraph.language": "EN"` +for scoring parents' documents. + [[knn-query-aggregations]] ==== Knn query with aggregations `knn` query calculates aggregations on top `k` documents from each shard. diff --git a/docs/reference/search/search-your-data/knn-search.asciidoc b/docs/reference/search/search-your-data/knn-search.asciidoc index 70cf9eec121d7..7531bb4703890 100644 --- a/docs/reference/search/search-your-data/knn-search.asciidoc +++ b/docs/reference/search/search-your-data/knn-search.asciidoc @@ -677,6 +677,9 @@ PUT passage_vectors "type": "hnsw" } }, + "language": { + "type" : "keyword" + }, "text": { "type": "text", "index": false @@ -695,9 +698,9 @@ With the above mapping, we can index multiple passage vectors along with storing ---- POST passage_vectors/_bulk?refresh=true { "index": { "_id": "1" } } -{ "full_text": "first paragraph another paragraph", "creation_time": "2019-05-04", "paragraph": [ { "vector": [ 0.45, 45 ], "text": "first paragraph", "paragraph_id": "1" }, { "vector": [ 0.8, 0.6 ], "text": "another paragraph", "paragraph_id": "2" } ] } +{ "full_text": "first paragraph another paragraph", "creation_time": "2019-05-04", "paragraph": [ { "vector": [ 0.45, 45 ], "text": "first paragraph", "paragraph_id": "1", "language": "EN" }, { "vector": [ 0.8, 0.6 ], "text": "another paragraph", "paragraph_id": "2", "language": "FR" } ] } { "index": { "_id": "2" } } -{ "full_text": "number one paragraph number two paragraph", "creation_time": "2020-05-04", "paragraph": [ { "vector": [ 1.2, 4.5 ], "text": "number one paragraph", "paragraph_id": "1" }, { "vector": [ -1, 42 ], "text": "number two paragraph", "paragraph_id": "2" } ] } +{ "full_text": "number one paragraph number two paragraph", "creation_time": "2020-05-04", "paragraph": [ { "vector": [ 1.2, 4.5 ], "text": "number one paragraph", "paragraph_id": "1", "language": "FR" }, { "vector": [ -1, 42 ], "text": "number two paragraph", "paragraph_id": "2", "language": "EN" } ] } ---- //TEST[continued] //TEST[s/\.\.\.//] @@ -776,12 +779,8 @@ scored by their nearest passage vector (e.g. `"paragraph.vector"`). ---- // TESTRESPONSE[s/"took": 4/"took" : "$body.took"/] -What if you wanted to filter by some top-level document metadata? You can do this by adding `filter` to your -`knn` clause. - - -NOTE: `filter` will always be over the top-level document metadata. This means you cannot filter based on `nested` - field metadata. +What if you wanted to filter by some document metadata? You can do this by adding `filter` to your +`knn` clause. `filter` can be run based on both: the top-level document metadata and `nested` field metadata. [source,console] ---- @@ -858,6 +857,37 @@ Now we have filtered based on the top level `"creation_time"` and only one docum ---- // TESTRESPONSE[s/"took": 4/"took" : "$body.took"/] + +Filtering by nested field metadata: `paragraph.language` makes kNN search only consider vectors with this metadata +for scoring parents' documents: + +[source,console] +---- +POST passage_vectors/_search +{ + "fields": [ + "creation_time", + "full_text" + ], + "_source": false, + "knn": { + "query_vector": [ + 0.45, + 45 + ], + "field": "paragraph.vector", + "k": 2, + "num_candidates": 2, + "filter": { + "match" : { + "paragraph.language" : "EN" + } + } + } +} +---- +//TEST[continued] + [discrete] [[nested-knn-search-inner-hits]] ==== Nested kNN Search with Inner hits diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/100_knn_nested_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/100_knn_nested_search.yml index d627be2fb15c3..3010b1681ec38 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/100_knn_nested_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/100_knn_nested_search.yml @@ -16,6 +16,8 @@ setup: nested: type: nested properties: + language: + type: keyword paragraph_id: type: keyword vector: @@ -37,8 +39,10 @@ setup: nested: - paragraph_id: 0 vector: [230.0, 300.33, -34.8988, 15.555, -200.0] + language: EN - paragraph_id: 1 vector: [240.0, 300, -3, 1, -20] + language: FR - do: index: @@ -49,10 +53,13 @@ setup: nested: - paragraph_id: 0 vector: [-0.5, 100.0, -13, 14.8, -156.0] + language: EN - paragraph_id: 2 vector: [0, 100.0, 0, 14.8, -156.0] + language: EN - paragraph_id: 3 vector: [0, 1.0, 0, 1.8, -15.0] + language: FR - do: index: @@ -63,6 +70,7 @@ setup: nested: - paragraph_id: 0 vector: [0.5, 111.3, -13.0, 14.8, -156.0] + language: FR - do: indices.refresh: {} @@ -461,3 +469,72 @@ setup: - match: {hits.hits.0._id: "2"} - length: {hits.hits.0.inner_hits.nested.hits.hits: 1} - match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"} + + +--- +"Test filter on nested fields": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ knn_filter_on_nested_fields ] + test_runner_features: ["capabilities", "close_to"] + reason: "Capability for filtering on nested fields required" + + - do: + search: + index: test + body: + _source: false + knn: + boost: 2 + field: nested.vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + num_candidates: 10 + filter: { match: { nested.language: "EN" } } + inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false } + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0.inner_hits.nested.hits.total.value: 2 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.language.0: "EN" } + - close_to: { hits.hits.0._score: { value: 0.0182, error: 0.0001 } } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0182, error: 0.0001 } } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" } + + - do: + search: + index: test + body: + _source: false + knn: + boost: 2 + field: nested.vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + num_candidates: 10 + filter: { match: { nested.language: "FR" } } + inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false } + + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + - close_to: { hits.hits.0._score: { value: 0.0043, error: 0.0001 } } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0043, error: 0.0001 } } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "3" } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + - match: { hits.hits.2._id: "1" } + - match: { hits.hits.2.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "1" } + - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml index 79ff3f61742f8..9848bbca9f014 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml @@ -16,6 +16,8 @@ setup: nested: type: nested properties: + language: + type: keyword paragraph_id: type: keyword vector: @@ -38,8 +40,10 @@ setup: nested: - paragraph_id: 0 vector: [230.0, 300.33, -34.8988, 15.555, -200.0] + language: EN - paragraph_id: 1 vector: [240.0, 300, -3, 1, -20] + language: FR - do: index: @@ -50,10 +54,13 @@ setup: nested: - paragraph_id: 0 vector: [-0.5, 100.0, -13, 14.8, -156.0] + language: EN - paragraph_id: 2 vector: [0, 100.0, 0, 14.8, -156.0] + language: EN - paragraph_id: 3 vector: [0, 1.0, 0, 1.8, -15.0] + language: FR - do: index: @@ -64,6 +71,7 @@ setup: nested: - paragraph_id: 0 vector: [0.5, 111.3, -13.0, 14.8, -156.0] + language: FR - do: indices.refresh: {} @@ -406,3 +414,82 @@ setup: - match: {hits.total.value: 1} - match: {hits.hits.0._id: "2"} + + +--- +"Test filter on nested fields": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ knn_filter_on_nested_fields ] + test_runner_features: ["capabilities", "close_to"] + reason: "Capability for filtering on nested fields required" + + - do: + search: + index: test + body: + _source: false + query: + nested: + path: nested + query: + knn: + boost: 2 + field: nested.vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 10 + filter: + match: + nested.language: "EN" + inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false } + + - match: {hits.total.value: 2} + - match: {hits.hits.0._id: "2"} + - match: { hits.hits.0.inner_hits.nested.hits.total.value: 2 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.language.0: "EN" } + - close_to: { hits.hits.0._score: { value: 0.02036, error: 0.0001 } } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.02036, error: 0.0001 } } + - match: {hits.hits.1._id: "1"} + - match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" } + + - do: + search: + index: test + body: + _source: false + query: + nested: + path: nested + query: + knn: + boost: 2 + field: nested.vector + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 10 + filter: + match: + nested.language: "FR" + inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language" ], _source: false } + + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + - close_to: { hits.hits.0._score: { value: 0.0041, error: 0.0001 } } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0041, error: 0.0001 } } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "3" } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } + - match: { hits.hits.2._id: "1" } + - match: { hits.hits.2.inner_hits.nested.hits.total.value: 1 } + - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "1" } + - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" } diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 0ced472ea310c..0d5b4618db2e5 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -231,6 +231,7 @@ static TransportVersion def(int id) { public static final TransportVersion CCS_REMOTE_TELEMETRY_STATS = def(8_755_00_0); public static final TransportVersion ESQL_CCS_EXECUTION_INFO = def(8_756_00_0); public static final TransportVersion REGEX_AND_RANGE_INTERVAL_QUERIES = def(8_757_00_0); + public static final TransportVersion TO_CHILD_BLOCK_JOIN_QUERY = def(8_758_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index e0e240be0377a..56157523cc8e6 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -157,7 +157,8 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { scoreDocs.toArray(Lucene.EMPTY_SCORE_DOCS), source.knnSearch().get(i).getField(), source.knnSearch().get(i).getQueryVector(), - source.knnSearch().get(i).getSimilarity() + source.knnSearch().get(i).getSimilarity(), + source.knnSearch().get(i).getFilterQueries() ).boost(source.knnSearch().get(i).boost()).queryName(source.knnSearch().get(i).queryName()); if (nestedPath != null) { query = new NestedQueryBuilder(nestedPath, query, ScoreMode.Max).innerHit(source.knnSearch().get(i).innerHit()); diff --git a/server/src/main/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilder.java new file mode 100644 index 0000000000000..1e6e6feee3f42 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilder.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.lucene.search.Queries; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +/** + * A query returns child documents whose parent matches the provided query. + * This query is used only for internal purposes and is not exposed to a user. + */ +public class ToChildBlockJoinQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "to_child_block_join"; + private final QueryBuilder parentQueryBuilder; + + public ToChildBlockJoinQueryBuilder(QueryBuilder parentQueryBuilder) { + this.parentQueryBuilder = parentQueryBuilder; + } + + public ToChildBlockJoinQueryBuilder(StreamInput in) throws IOException { + super(in); + parentQueryBuilder = in.readNamedWriteable(QueryBuilder.class); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(parentQueryBuilder); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.field("query"); + parentQueryBuilder.toXContent(builder, params); + boostAndQueryNameToXContent(builder); + builder.endObject(); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + QueryBuilder rewritten = parentQueryBuilder.rewrite(queryRewriteContext); + if (rewritten instanceof MatchNoneQueryBuilder) { + return rewritten; + } + if (rewritten != parentQueryBuilder) { + return new ToChildBlockJoinQueryBuilder(rewritten); + } + return this; + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + final Query parentFilter; + NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper(); + if (originalObjectMapper != null) { + try { + // we are in a nested context, to get the parent filter we need to go up one level + context.nestedScope().previousLevel(); + NestedObjectMapper objectMapper = context.nestedScope().getObjectMapper(); + parentFilter = objectMapper == null + ? Queries.newNonNestedFilter(context.indexVersionCreated()) + : objectMapper.nestedTypeFilter(); + } finally { + context.nestedScope().nextLevel(originalObjectMapper); + } + } else { + // we are NOT in a nested context, coming from the top level knn search + parentFilter = Queries.newNonNestedFilter(context.indexVersionCreated()); + } + final BitSetProducer parentBitSet = context.bitsetFilter(parentFilter); + Query parentQuery = parentQueryBuilder.toQuery(context); + // ensure that parentQuery only applies to parent docs by adding parentFilter + return new ToChildBlockJoinQuery(Queries.filtered(parentQuery, parentFilter), parentBitSet); + } + + @Override + protected boolean doEquals(ToChildBlockJoinQueryBuilder other) { + return Objects.equals(parentQueryBuilder, other.parentQueryBuilder); + } + + @Override + protected int doHashCode() { + return Objects.hash(parentQueryBuilder); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.TO_CHILD_BLOCK_JOIN_QUERY; + } +} diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 45fd6afe4fca6..6aa8a8ade4e66 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -21,5 +21,7 @@ private SearchCapabilities() {} /** Support regex and range match rules in interval queries. */ private static final String RANGE_REGEX_INTERVAL_QUERY_CAPABILITY = "range_regexp_interval_queries"; - public static final Set CAPABILITIES = Set.of(RANGE_REGEX_INTERVAL_QUERY_CAPABILITY); + private static final String KNN_FILTER_ON_NESTED_FIELDS_CAPABILITY = "knn_filter_on_nested_fields"; + + public static final Set CAPABILITIES = Set.of(RANGE_REGEX_INTERVAL_QUERY_CAPABILITY, KNN_FILTER_ON_NESTED_FIELDS_CAPABILITY); } diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index 6308b19358410..14f4c8627c9b0 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -67,6 +67,7 @@ import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.index.query.TermsSetQueryBuilder; +import org.elasticsearch.index.query.ToChildBlockJoinQueryBuilder; import org.elasticsearch.index.query.TypeQueryV7Builder; import org.elasticsearch.index.query.WildcardQueryBuilder; import org.elasticsearch.index.query.WrapperQueryBuilder; @@ -1203,6 +1204,9 @@ private void registerQueryParsers(List plugins) { registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> { throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly"); })); + registerQuery(new QuerySpec<>(ToChildBlockJoinQueryBuilder.NAME, ToChildBlockJoinQueryBuilder::new, parser -> { + throw new IllegalArgumentException("[to_child_block_join] queries cannot be provided directly"); + })); registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java index 15052fdad3818..f9aad2d774719 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java @@ -18,14 +18,17 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.index.query.ToChildBlockJoinQueryBuilder; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.Arrays; +import java.util.List; import java.util.Objects; /** @@ -39,6 +42,7 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder filterQueries; /** * Creates a query builder. @@ -46,11 +50,13 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder filterQueries) { this.scoreDocs = scoreDocs; this.fieldName = fieldName; this.queryVector = queryVector; this.vectorSimilarity = vectorSimilarity; + this.filterQueries = filterQueries; } public KnnScoreDocQueryBuilder(StreamInput in) throws IOException { @@ -77,6 +83,11 @@ public KnnScoreDocQueryBuilder(StreamInput in) throws IOException { } else { this.vectorSimilarity = null; } + if (in.getTransportVersion().onOrAfter(TransportVersions.TO_CHILD_BLOCK_JOIN_QUERY)){ + this.filterQueries = readQueries(in); + } else { + this.filterQueries = List.of(); + } } @Override @@ -100,6 +111,8 @@ Float vectorSimilarity() { return vectorSimilarity; } + + @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeArray(Lucene::writeScoreDoc, scoreDocs); @@ -120,6 +133,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { || out.getTransportVersion().isPatchFrom(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS_BACKPORT_8_15)) { out.writeOptionalFloat(vectorSimilarity); } + if (out.getTransportVersion().onOrAfter(TransportVersions.TO_CHILD_BLOCK_JOIN_QUERY)){ + writeQueries(out, filterQueries); + } } @Override @@ -139,6 +155,13 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep if (vectorSimilarity != null) { builder.field("similarity", vectorSimilarity); } + if (filterQueries.isEmpty() == false) { + builder.startArray("filter"); + for (QueryBuilder filterQuery : filterQueries) { + filterQuery.toXContent(builder, params); + } + builder.endArray(); + } boostAndQueryNameToXContent(builder); builder.endObject(); } @@ -164,7 +187,23 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return new MatchNoneQueryBuilder("The \"" + getName() + "\" query was rewritten to a \"match_none\" query."); } if (queryRewriteContext.convertToInnerHitsRewriteContext() != null && queryVector != null && fieldName != null) { - return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity); + QueryBuilder exactKnnQuery = new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity); + if (filterQueries.isEmpty()) { + return exactKnnQuery; + } else { + BoolQueryBuilder filterQueryChildren = new BoolQueryBuilder(); + for (QueryBuilder query : this.filterQueries) { + filterQueryChildren.filter(query); + } + // filter can be both over parents or nested docs, + // so add them as should clauses to a filter + BoolQueryBuilder boolQuery = new BoolQueryBuilder(); + boolQuery.must(exactKnnQuery); + boolQuery.filter(new BoolQueryBuilder() + .should(filterQueryChildren) + .should(new ToChildBlockJoinQueryBuilder(filterQueryChildren))); + return boolQuery; + } } return super.doRewrite(queryRewriteContext); } @@ -205,7 +244,8 @@ protected boolean doEquals(KnnScoreDocQueryBuilder other) { } return Objects.equals(fieldName, other.fieldName) && Objects.equals(queryVector, other.queryVector) - && Objects.equals(vectorSimilarity, other.vectorSimilarity); + && Objects.equals(vectorSimilarity, other.vectorSimilarity) + && Objects.equals(filterQueries, other.filterQueries); } @Override @@ -215,7 +255,7 @@ protected int doHashCode() { int hashCode = Objects.hash(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); result = 31 * result + hashCode; } - return Objects.hash(result, fieldName, vectorSimilarity, Objects.hashCode(queryVector)); + return Objects.hash(result, fieldName, vectorSimilarity, Objects.hashCode(queryVector), filterQueries); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 1d064210ae704..9196e223f2e96 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -28,10 +28,12 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType; import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.index.query.ToChildBlockJoinQueryBuilder; import org.elasticsearch.index.search.NestedHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; @@ -401,9 +403,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { boost ).queryName(queryName).addFilterQueries(filterQueries); } - if (ctx.convertToInnerHitsRewriteContext() != null) { - return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity).boost(boost).queryName(queryName); - } boolean changed = false; List rewrittenQueries = new ArrayList<>(filterQueries.size()); for (QueryBuilder query : filterQueries) { @@ -422,6 +421,25 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { .queryName(queryName) .addFilterQueries(rewrittenQueries); } + if (ctx.convertToInnerHitsRewriteContext() != null) { + QueryBuilder exactKnnQuery = new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity); + if (filterQueries.isEmpty()) { + return exactKnnQuery; + } else { + BoolQueryBuilder filterQueryChildren = new BoolQueryBuilder(); + for (QueryBuilder query : this.filterQueries) { + filterQueryChildren.filter(query); + } + // filter can be both over parents or nested docs, + // so add them as should clauses to a filter + BoolQueryBuilder boolQuery = new BoolQueryBuilder(); + boolQuery.must(exactKnnQuery); + boolQuery.filter(new BoolQueryBuilder() + .should(filterQueryChildren) + .should(new ToChildBlockJoinQueryBuilder(filterQueryChildren))); + return boolQuery; + } + } return this; } @@ -482,14 +500,13 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { parentBitSet = context.bitsetFilter(parentFilter); if (filterQuery != null) { NestedHelper nestedHelper = new NestedHelper(context.nestedLookup(), context::isFieldMapped); - // We treat the provided filter as a filter over PARENT documents, so if it might match nested documents - // we need to adjust it. - if (nestedHelper.mightMatchNestedDocs(filterQuery)) { + // If filter matches non-nested docs, we assume this is a filter over parents docs, + // so we will modify it accordingly: matching parents docs with join to its child docs + if (nestedHelper.mightMatchNonNestedDocs(filterQuery, parentPath)) { // Ensure that the query only returns parent documents matching `filterQuery` filterQuery = Queries.filtered(filterQuery, parentFilter); + filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet); } - // Now join the filterQuery & parentFilter to provide the matching blocks of children - filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet); } return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet); } diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index 90174a89209b8..3aa0f99f68572 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -355,13 +355,15 @@ public void testRewriteShardSearchRequestWithRank() { new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1) }, "vector", VectorData.fromFloats(new float[] { 0.0f }), - null + null, + List.of() ); KnnScoreDocQueryBuilder ksdqb1 = new KnnScoreDocQueryBuilder( new ScoreDoc[] { new ScoreDoc(1, 2.0f, 1) }, "vector2", VectorData.fromFloats(new float[] { 0.0f }), - null + null, + List.of() ); assertEquals( List.of(bm25, ksdqb0, ksdqb1), diff --git a/server/src/test/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilderTests.java new file mode 100644 index 0000000000000..add16a9ebf70b --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilderTests.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; +import org.elasticsearch.test.AbstractQueryTestCase; + +import java.io.IOException; + +import static org.hamcrest.CoreMatchers.instanceOf; + +public class ToChildBlockJoinQueryBuilderTests extends AbstractQueryTestCase< ToChildBlockJoinQueryBuilder> { + @Override + protected ToChildBlockJoinQueryBuilder doCreateTestQueryBuilder() { + String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME; + return new ToChildBlockJoinQueryBuilder(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10))); + } + + @Override + protected void doAssertLuceneQuery(ToChildBlockJoinQueryBuilder queryBuilder, Query query, + SearchExecutionContext context) throws IOException { + assertThat(query, instanceOf(ToChildBlockJoinQuery.class)); + } + + @Override + public void testUnknownField() throws IOException { + // Test isn't relevant, since query is never parsed from xContent + } + + @Override + public void testUnknownObjectException() { + // Test isn't relevant, since query is never parsed from xContent + } + + @Override + public void testFromXContent() throws IOException { + // Test isn't relevant, since query is never parsed from xContent + } + + @Override + public void testValidOutput() { + // Test isn't relevant, since query is never parsed from xContent + } + +} diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java index 18d5c8c85fbec..7b609c6101610 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java @@ -24,9 +24,11 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.InnerHitsRewriteContext; import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.test.AbstractQueryTestCase; @@ -54,11 +56,20 @@ protected KnnScoreDocQueryBuilder doCreateTestQueryBuilder() { for (int doc = 0; doc < numDocs; doc++) { scoreDocs.add(new ScoreDoc(doc, randomFloat())); } + List filters = new ArrayList<>(); + if (randomBoolean()) { + int numFilters = randomIntBetween(1, 5); + for (int i = 0; i < numFilters; i++) { + String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME; + filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10))); + } + } return new KnnScoreDocQueryBuilder( scoreDocs.toArray(new ScoreDoc[0]), randomBoolean() ? "field" : null, randomBoolean() ? VectorData.fromFloats(randomVector(10)) : null, - randomBoolean() ? randomFloat() : null + randomBoolean() ? randomFloat() : null, + filters ); } @@ -68,7 +79,8 @@ public void testValidOutput() { new ScoreDoc[] { new ScoreDoc(0, 4.25f), new ScoreDoc(5, 1.6f) }, "field", VectorData.fromFloats(new float[] { 1.0f, 2.0f }), - null + null, + List.of() ); String expected = """ { @@ -159,7 +171,8 @@ public void testRewriteToMatchNone() throws IOException { new ScoreDoc[0], randomBoolean() ? "field" : null, randomBoolean() ? VectorData.fromFloats(randomVector(10)) : null, - randomBoolean() ? randomFloat() : null + randomBoolean() ? randomFloat() : null, + List.of() ); QueryRewriteContext context = randomBoolean() ? new InnerHitsRewriteContext(createSearchExecutionContext().getParserConfig(), System::currentTimeMillis) @@ -170,15 +183,33 @@ public void testRewriteToMatchNone() throws IOException { public void testRewriteForInnerHits() throws IOException { SearchExecutionContext context = createSearchExecutionContext(); InnerHitsRewriteContext innerHitsRewriteContext = new InnerHitsRewriteContext(context.getParserConfig(), System::currentTimeMillis); + List filters = new ArrayList<>(); + boolean hasFilters = randomBoolean(); + if (hasFilters) { + int numFilters = randomIntBetween(1, 5); + for (int i = 0; i < numFilters; i++) { + String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME; + filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10))); + } + } + KnnScoreDocQueryBuilder queryBuilder = new KnnScoreDocQueryBuilder( new ScoreDoc[] { new ScoreDoc(0, 4.25f), new ScoreDoc(5, 1.6f) }, randomAlphaOfLength(10), VectorData.fromFloats(randomVector(10)), - randomBoolean() ? randomFloat() : null + randomBoolean() ? randomFloat() : null, + filters ); queryBuilder.boost(randomFloat()); queryBuilder.queryName(randomAlphaOfLength(10)); QueryBuilder rewritten = queryBuilder.rewrite(innerHitsRewriteContext); + + if (hasFilters) { + assertTrue(rewritten instanceof BoolQueryBuilder); + BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) rewritten; + rewritten = boolQueryBuilder.must().get(0); + } + assertTrue(rewritten instanceof ExactKnnQueryBuilder); ExactKnnQueryBuilder exactKnnQueryBuilder = (ExactKnnQueryBuilder) rewritten; assertEquals(queryBuilder.queryVector(), exactKnnQueryBuilder.getQuery()); @@ -228,7 +259,8 @@ public void testScoreDocQueryWeightCount() throws IOException { scoreDocs, "field", VectorData.fromFloats(randomVector(10)), - null + null, + List.of() ); Query query = queryBuilder.doToQuery(context); final Weight w = query.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -276,7 +308,8 @@ public void testScoreDocQuery() throws IOException { scoreDocs, "field", VectorData.fromFloats(randomVector(10)), - null + null, + List.of() ); final Query query = queryBuilder.doToQuery(context); final Weight w = query.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f);