From 78de1797fd192fd781bcd67fc0c0b9818479d3a3 Mon Sep 17 00:00:00 2001 From: foxstarius <116147736+foxstarius@users.noreply.github.com> Date: Sun, 5 May 2024 14:08:57 +0200 Subject: [PATCH] feat: Add kNN Query (#198) kNN search was added to Elasticsearch in v8.0 Co-authored-by: kennylindahl Co-authored-by: Andreas Franzon --- src/core/index.js | 2 + src/core/knn.js | 138 +++++++++++++++++++++ src/core/request-body-search.js | 25 +++- src/index.d.ts | 103 +++++++++++++-- src/index.js | 8 ++ test/core-test/knn.test.js | 130 +++++++++++++++++++ test/core-test/request-body-search.test.js | 90 +++++++++++++- 7 files changed, 485 insertions(+), 11 deletions(-) create mode 100644 src/core/knn.js create mode 100644 test/core-test/knn.test.js diff --git a/src/core/index.js b/src/core/index.js index 80013be..1849a51 100644 --- a/src/core/index.js +++ b/src/core/index.js @@ -8,6 +8,8 @@ exports.Aggregation = require('./aggregation'); exports.Query = require('./query'); +exports.KNN = require('./knn'); + exports.Suggester = require('./suggester'); exports.Script = require('./script'); diff --git a/src/core/knn.js b/src/core/knn.js new file mode 100644 index 0000000..4c1c346 --- /dev/null +++ b/src/core/knn.js @@ -0,0 +1,138 @@ +'use strict'; + +const { recursiveToJSON, checkType } = require('./util'); +const Query = require('./query'); + +/** + * Class representing a k-Nearest Neighbors (k-NN) query. + * This class extends the Query class to support the specifics of k-NN search, including setting up the field, + * query vector, number of neighbors (k), and number of candidates. + * + * @example + * const qry = esb.kNN('my_field', 100, 1000).vector([1,2,3]); + * const qry = esb.kNN('my_field', 100, 1000).queryVectorBuilder('model_123', 'Sample model text'); + * + * NOTE: kNN search was added to Elasticsearch in v8.0 + * + * [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) + */ +class KNN { + // eslint-disable-next-line require-jsdoc + constructor(field, k, numCandidates) { + if (k > numCandidates) + throw new Error('KNN numCandidates cannot be less than k'); + this._body = {}; + this._body.field = field; + this._body.k = k; + this._body.filter = []; + this._body.num_candidates = numCandidates; + } + + /** + * Sets the query vector for the k-NN search. + * @param {Array} vector - The query vector. + * @returns {KNN} Returns the instance of KNN for method chaining. + */ + queryVector(vector) { + if (this._body.query_vector_builder) + throw new Error( + 'cannot provide both query_vector_builder and query_vector' + ); + this._body.query_vector = vector; + return this; + } + + /** + * Sets the query vector builder for the k-NN search. + * This method configures a query vector builder using a specified model ID and model text. + * It's important to note that either a direct query vector or a query vector builder can be + * provided, but not both. + * + * @param {string} modelId - The ID of the model to be used for generating the query vector. + * @param {string} modelText - The text input based on which the query vector is generated. + * @returns {KNN} Returns the instance of KNN for method chaining. + * @throws {Error} Throws an error if both query_vector_builder and query_vector are provided. + * + * @example + * let knn = new esb.KNN().queryVectorBuilder('model_123', 'Sample model text'); + */ + queryVectorBuilder(modelId, modelText) { + if (this._body.query_vector) + throw new Error( + 'cannot provide both query_vector_builder and query_vector' + ); + this._body.query_vector_builder = { + text_embeddings: { + model_id: modelId, + model_text: modelText + } + }; + return this; + } + + /** + * Adds one or more filter queries to the k-NN search. + * + * This method is designed to apply filters to the k-NN search. It accepts either a single + * query or an array of queries. Each query acts as a filter, refining the search results + * according to the specified conditions. These queries must be instances of the `Query` class. + * If any provided query is not an instance of `Query`, a TypeError is thrown. + * + * @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering. + * @returns {KNN} Returns `this` to allow method chaining. + * @throws {TypeError} If any of the provided queries is not an instance of `Query`. + * + * @example + * let knn = new esb.KNN().filter(new esb.TermQuery('field', 'value')); // Applying a single filter query + * + * @example + * let knn = new esb.KNN().filter([ + * new esb.TermQuery('field1', 'value1'), + * new esb.TermQuery('field2', 'value2') + * ]); // Applying multiple filter queries + */ + filter(queries) { + const queryArray = Array.isArray(queries) ? queries : [queries]; + queryArray.forEach(query => { + checkType(query, Query); + this._body.filter.push(query); + }); + return this; + } + + /** + * Sets the field to perform the k-NN search on. + * @param {number} boost - The number of the boost + * @returns {KNN} Returns the instance of KNN for method chaining. + */ + boost(boost) { + this._body.boost = boost; + return this; + } + + /** + * Sets the field to perform the k-NN search on. + * @param {number} similarity - The number of the similarity + * @returns {KNN} Returns the instance of KNN for method chaining. + */ + similarity(similarity) { + this._body.similarity = similarity; + return this; + } + + /** + * Override default `toJSON` to return DSL representation for the `query` + * + * @override + * @returns {Object} returns an Object which maps to the elasticsearch query DSL + */ + toJSON() { + if (!this._body.query_vector && !this._body.query_vector_builder) + throw new Error( + 'either query_vector_builder or query_vector must be provided' + ); + return recursiveToJSON(this._body); + } +} + +module.exports = KNN; diff --git a/src/core/request-body-search.js b/src/core/request-body-search.js index e69a3c2..87a9232 100644 --- a/src/core/request-body-search.js +++ b/src/core/request-body-search.js @@ -10,7 +10,8 @@ const Query = require('./query'), Rescore = require('./rescore'), Sort = require('./sort'), Highlight = require('./highlight'), - InnerHits = require('./inner-hits'); + InnerHits = require('./inner-hits'), + KNN = require('./knn'); const { checkType, setDefault, recursiveToJSON } = require('./util'); const RuntimeField = require('./runtime-field'); @@ -70,6 +71,7 @@ class RequestBodySearch { constructor() { // Maybe accept some optional parameter? this._body = {}; + this._knn = []; this._aggs = []; this._suggests = []; this._suggestText = null; @@ -88,6 +90,21 @@ class RequestBodySearch { return this; } + /** + * Sets knn on the search request body. + * + * @param {Knn|Knn[]} knn + * @returns {RequestBodySearch} returns `this` so that calls can be chained. + */ + kNN(knn) { + const knns = Array.isArray(knn) ? knn : [knn]; + knns.forEach(_knn => { + checkType(_knn, KNN); + this._knn.push(_knn); + }); + return this; + } + /** * Sets aggregation on the request body. * Alias for method `aggregation` @@ -867,6 +884,12 @@ class RequestBodySearch { toJSON() { const dsl = recursiveToJSON(this._body); + if (!isEmpty(this._knn)) + dsl.knn = + this._knn.length == 1 + ? recMerge(this._knn) + : this._knn.map(knn => recursiveToJSON(knn)); + if (!isEmpty(this._aggs)) dsl.aggs = recMerge(this._aggs); if (!isEmpty(this._suggests) || !isNil(this._suggestText)) { diff --git a/src/index.d.ts b/src/index.d.ts index e9d7d48..c14842e 100644 --- a/src/index.d.ts +++ b/src/index.d.ts @@ -18,6 +18,13 @@ declare namespace esb { */ query(query: Query): this; + /** + * Sets knn on the request body. + * + * @param {KNN|KNN[]} knn + */ + kNN(knn: KNN | KNN[]): this; + /** * Sets aggregation on the request body. * Alias for method `aggregation` @@ -3141,7 +3148,7 @@ declare namespace esb { /** * Sets the script used to compute the score of documents returned by the query. - * + * * @param {Script} script A valid `Script` object */ script(script: Script): this; @@ -3761,6 +3768,84 @@ declare namespace esb { spanQry?: SpanQueryBase ): SpanFieldMaskingQuery; + /** + * Knn performs k-nearest neighbor (KNN) searches. + * This class allows configuring the KNN search with various parameters such as field, query vector, + * number of nearest neighbors (k), number of candidates, boost factor, and similarity metric. + * + * NOTE: Only available in Elasticsearch v8.0+ + */ + export class KNN { + /** + * Creates an instance of Knn, initializing the internal state for the k-NN search. + * + * @param {string} field - (Optional) The field against which to perform the k-NN search. + * @param {number} k - (Optional) The number of nearest neighbors to retrieve. + * @param {number} numCandidates - (Optional) The number of candidate neighbors to consider during the search. + * @throws {Error} If the number of candidates (numCandidates) is less than the number of neighbors (k). + */ + constructor(field: string, k: number, numCandidates: number); + + /** + * Sets the query vector for the KNN search, an array of numbers representing the reference point. + * + * @param {number[]} vector + */ + queryVector(vector: number[]): this; + + /** + * Sets the query vector builder for the k-NN search. + * This method configures a query vector builder using a specified model ID and model text. + * Note that either a direct query vector or a query vector builder can be provided, but not both. + * + * @param {string} modelId - The ID of the model used for generating the query vector. + * @param {string} modelText - The text input based on which the query vector is generated. + * @returns {KNN} Returns the instance of Knn for method chaining. + * @throws {Error} If both query_vector_builder and query_vector are provided. + */ + queryVectorBuilder(modelId: string, modelText: string): this; + + /** + * Adds one or more filter queries to the k-NN search. + * This method is designed to apply filters to the k-NN search. It accepts either a single + * query or an array of queries. Each query acts as a filter, refining the search results + * according to the specified conditions. These queries must be instances of the `Query` class. + * + * @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering. + * @returns {KNN} Returns `this` to allow method chaining. + * @throws {TypeError} If any of the provided queries is not an instance of `Query`. + */ + filter(queries: Query | Query[]): this; + + /** + * Applies a boost factor to the query to influence the relevance score of returned documents. + * + * @param {number} boost + */ + boost(boost: number): this; + + /** + * Sets the similarity metric used in the KNN algorithm to calculate similarity. + * + * @param {number} similarity + */ + similarity(similarity: number): this; + + /** + * Override default `toJSON` to return DSL representation for the `query` + * + * @override + */ + toJSON(): object; + } + + /** + * Factory function to instantiate a new Knn object. + * + * @returns {KNN} + */ + export function kNN(field: string, k: number, numCandidates: number): KNN; + /** * Base class implementation for all aggregation types. * @@ -3913,9 +3998,9 @@ declare namespace esb { /** * A single-value metrics aggregation that computes the weighted average of numeric values that are extracted from the aggregated documents. * These values can be extracted either from specific numeric fields in the documents. - * + * * [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-metrics-weight-avg-aggregation.html) - * + * * Added in Elasticsearch v6.4.0 * [Release notes](https://www.elastic.co/guide/en/elasticsearch/reference/6.4/release-notes-6.4.0.html) * @@ -3929,7 +4014,7 @@ declare namespace esb { /** * Sets the value - * + * * @param {string | Script} value Field name or script to be used as the value * @param {number=} missing Sets the missing parameter which defines how documents * that are missing a value should be treated. @@ -3939,7 +4024,7 @@ declare namespace esb { /** * Sets the weight - * + * * @param {string | Script} weight Field name or script to be used as the weight * @param {number=} missing Sets the missing parameter which defines how documents * that are missing a value should be treated. @@ -3969,9 +4054,9 @@ declare namespace esb { /** * A single-value metrics aggregation that computes the weighted average of numeric values that are extracted from the aggregated documents. * These values can be extracted either from specific numeric fields in the documents. - * + * * [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-metrics-weight-avg-aggregation.html) - * + * * Added in Elasticsearch v6.4.0 * [Release notes](https://www.elastic.co/guide/en/elasticsearch/reference/6.4/release-notes-6.4.0.html) * @@ -8922,7 +9007,7 @@ declare namespace esb { /** * Sets the type of the runtime field. - * + * * @param {string} type One of `boolean`, `composite`, `date`, `double`, `geo_point`, `ip`, `keyword`, `long`, `lookup`. * @returns {void} */ @@ -8930,7 +9015,7 @@ declare namespace esb { /** * Sets the source of the script. - * + * * @param {string} script * @returns {void} */ diff --git a/src/index.js b/src/index.js index 77c49df..01f5c85 100644 --- a/src/index.js +++ b/src/index.js @@ -15,6 +15,7 @@ const { RuntimeField, SearchTemplate, Query, + KNN, util: { constructorWrapper } } = require('./core'); @@ -343,6 +344,13 @@ exports.spanWithinQuery = constructorWrapper(SpanWithinQuery); exports.SpanFieldMaskingQuery = SpanFieldMaskingQuery; exports.spanFieldMaskingQuery = constructorWrapper(SpanFieldMaskingQuery); + +/* ============ ============ ============ */ +/* ======== KNN ======== */ +/* ============ ============ ============ */ +exports.KNN = KNN; +exports.kNN = constructorWrapper(KNN); + /* ============ ============ ============ */ /* ======== Metrics Aggregations ======== */ /* ============ ============ ============ */ diff --git a/test/core-test/knn.test.js b/test/core-test/knn.test.js new file mode 100644 index 0000000..25f5719 --- /dev/null +++ b/test/core-test/knn.test.js @@ -0,0 +1,130 @@ +import test from 'ava'; +import { KNN, TermQuery } from '../../src'; + +test('knn can be instantiated', t => { + const knn = new KNN('my_field', 5, 10).queryVector([1, 2, 3]); + const json = knn.toJSON(); + t.truthy(json); +}); + +test('knn throws error if numCandidates is less than k', t => { + const error = t.throws(() => + new KNN('my_field', 10, 5).queryVector([1, 2, 3]) + ); + t.is(error.message, 'KNN numCandidates cannot be less than k'); +}); + +test('knn queryVector sets correctly', t => { + const vector = [1, 2, 3]; + const knn = new KNN('my_field', 5, 10).queryVector(vector); + const json = knn.toJSON(); + t.deepEqual(json.query_vector, vector); +}); + +test('knn queryVectorBuilder sets correctly', t => { + const modelId = 'model_123'; + const modelText = 'Sample model text'; + const knn = new KNN('my_field', 5, 10).queryVectorBuilder( + modelId, + modelText + ); + const json = knn.toJSON(); + t.deepEqual(json.query_vector_builder.text_embeddings, { + model_id: modelId, + model_text: modelText + }); +}); + +test('knn filter method adds queries correctly', t => { + const knn = new KNN('my_field', 5, 10).queryVector([1, 2, 3]); + const query = new TermQuery('field', 'value'); + knn.filter(query); + const json = knn.toJSON(); + t.deepEqual(json.filter, [query.toJSON()]); +}); + +test('knn filter method adds queries as array correctly', t => { + const knn = new KNN('my_field', 5, 10).queryVector([1, 2, 3]); + const query1 = new TermQuery('field1', 'value1'); + const query2 = new TermQuery('field2', 'value2'); + knn.filter([query1, query2]); + const json = knn.toJSON(); + t.deepEqual(json.filter, [query1.toJSON(), query2.toJSON()]); +}); + +test('knn boost method sets correctly', t => { + const boostValue = 1.5; + const knn = new KNN('my_field', 5, 10) + .boost(boostValue) + .queryVector([1, 2, 3]); + const json = knn.toJSON(); + t.is(json.boost, boostValue); +}); + +test('knn similarity method sets correctly', t => { + const similarityValue = 0.8; + const knn = new KNN('my_field', 5, 10) + .similarity(similarityValue) + .queryVector([1, 2, 3]); + const json = knn.toJSON(); + t.is(json.similarity, similarityValue); +}); + +test('knn toJSON method returns correct DSL', t => { + const knn = new KNN('my_field', 5, 10) + .queryVector([1, 2, 3]) + .filter(new TermQuery('field', 'value')); + + const expectedDSL = { + field: 'my_field', + k: 5, + num_candidates: 10, + query_vector: [1, 2, 3], + filter: [{ term: { field: 'value' } }] + }; + + t.deepEqual(knn.toJSON(), expectedDSL); +}); + +test('knn toJSON throws error if neither query_vector nor query_vector_builder is provided', t => { + const knn = new KNN('my_field', 5, 10); + const error = t.throws(() => knn.toJSON()); + t.is( + error.message, + 'either query_vector_builder or query_vector must be provided' + ); +}); + +test('knn throws error when first queryVector and then queryVectorBuilder are set', t => { + const knn = new KNN('my_field', 5, 10).queryVector([1, 2, 3]); + const error = t.throws(() => { + knn.queryVectorBuilder('model_123', 'Sample model text'); + }); + t.is( + error.message, + 'cannot provide both query_vector_builder and query_vector' + ); +}); + +test('knn throws error when first queryVectorBuilder and then queryVector are set', t => { + const knn = new KNN('my_field', 5, 10).queryVectorBuilder( + 'model_123', + 'Sample model text' + ); + const error = t.throws(() => { + knn.queryVector([1, 2, 3]); + }); + t.is( + error.message, + 'cannot provide both query_vector_builder and query_vector' + ); +}); + +test('knn filter throws TypeError if non-Query type is passed', t => { + const knn = new KNN('my_field', 5, 10).queryVector([1, 2, 3]); + const error = t.throws(() => { + knn.filter('not_a_query'); + }, TypeError); + + t.is(error.message, 'Argument must be an instance of Query'); +}); diff --git a/test/core-test/request-body-search.test.js b/test/core-test/request-body-search.test.js index 47b1191..2327606 100644 --- a/test/core-test/request-body-search.test.js +++ b/test/core-test/request-body-search.test.js @@ -16,7 +16,8 @@ import { Highlight, Rescore, InnerHits, - RuntimeField + RuntimeField, + KNN } from '../../src'; import { illegalParamType, makeSetsOptionMacro } from '../_macros'; @@ -71,6 +72,11 @@ const innerHits = new InnerHits() .name('last_tweets') .size(5) .sort(new Sort('date', 'desc')); +const kNNVectorBuilder = new KNN('my_field', 5, 10) + .similarity(0.6) + .filter(new TermQuery('field', 'value')) + .queryVectorBuilder('model_123', 'Sample model text'); +const kNNVector = new KNN('my_field', 5, 10).queryVector([1, 2, 3]); const instance = new RequestBodySearch(); @@ -83,9 +89,11 @@ test(illegalParamType, instance, 'scriptFields', 'Object'); test(illegalParamType, instance, 'highlight', 'Highlight'); test(illegalParamType, instance, 'rescore', 'Rescore'); test(illegalParamType, instance, 'postFilter', 'Query'); +test(illegalParamType, instance, 'kNN', 'KNN'); test(setsOption, 'query', { param: searchQry }); test(setsOption, 'aggregation', { param: aggA, keyName: 'aggs' }); test(setsOption, 'agg', { param: aggA, keyName: 'aggs' }); +test(setsOption, 'kNN', { param: kNNVectorBuilder, keyName: 'knn' }); test(setsOption, 'suggest', { param: suggest }); test(setsOption, 'suggestText', { param: 'suggest-text', @@ -347,3 +355,83 @@ test('sets multiple indices_boost', t => { }; t.deepEqual(value, expected); }); + +test('kNN setup query vector builder', t => { + const value = new RequestBodySearch().kNN(kNNVectorBuilder).toJSON(); + const expected = { + knn: { + field: 'my_field', + k: 5, + filter: [ + { + term: { + field: 'value' + } + } + ], + num_candidates: 10, + query_vector_builder: { + text_embeddings: { + model_id: 'model_123', + model_text: 'Sample model text' + } + }, + similarity: 0.6 + } + }; + + t.deepEqual(value, expected); +}); + +test('kNN setup query vector', t => { + const value = new RequestBodySearch().kNN(kNNVector).toJSON(); + const expected = { + knn: { + field: 'my_field', + k: 5, + filter: [], + num_candidates: 10, + query_vector: [1, 2, 3] + } + }; + + t.deepEqual(value, expected); +}); + +test('kNN setup query vector array', t => { + const value = new RequestBodySearch() + .kNN([kNNVector, kNNVectorBuilder]) + .toJSON(); + const expected = { + knn: [ + { + field: 'my_field', + k: 5, + filter: [], + num_candidates: 10, + query_vector: [1, 2, 3] + }, + { + field: 'my_field', + filter: [ + { + term: { + field: 'value' + } + } + ], + k: 5, + num_candidates: 10, + query_vector_builder: { + text_embeddings: { + model_id: 'model_123', + model_text: 'Sample model text' + } + }, + similarity: 0.6 + } + ] + }; + + t.deepEqual(value, expected); +});