From 1567e7f9df0a4f851f549a7ced537f2e81745b18 Mon Sep 17 00:00:00 2001 From: Sicheng Pan Date: Wed, 25 Sep 2024 11:39:48 -0700 Subject: [PATCH] Fix logic bug and add unit test --- .../execution/operators/metadata_filtering.rs | 466 +++++++++++++++++- 1 file changed, 449 insertions(+), 17 deletions(-) diff --git a/rust/worker/src/execution/operators/metadata_filtering.rs b/rust/worker/src/execution/operators/metadata_filtering.rs index 9930be04e25..22661cfc6f4 100644 --- a/rust/worker/src/execution/operators/metadata_filtering.rs +++ b/rust/worker/src/execution/operators/metadata_filtering.rs @@ -187,7 +187,7 @@ impl<'me> MetadataLogReader<'me> { GreaterThan => (Excluded(val), Unbounded), GreaterThanOrEqual => (Included(val), Unbounded), LessThan => (Unbounded, Excluded(val)), - LessThanOrEqual => (Unbounded, Excluded(val)), + LessThanOrEqual => (Unbounded, Included(val)), }; Ok(btm .range(bounds) @@ -345,7 +345,7 @@ impl<'me> RoaringMetadataFilter<'me> for DirectWhereComparison { WhereComparison::Set(set_operator, metadata_set_value) => match set_operator { In => { Box::pin( - Where::conjunction( + Where::disjunction( metadata_set_value .into_vec() .into_iter() @@ -485,6 +485,8 @@ impl Operator for MetadataFilte } (Some(c), None) | (None, Some(c)) => c, _ => { + // User does not provide any filter, which is interpreted as a full scan + // Create a trivially true where clause, which will lead to a full scan conjunction = Where::conjunction(vec![]); &conjunction } @@ -572,9 +574,11 @@ mod test { use chroma_cache::{cache::Cache, config::CacheConfig, config::UnboundedCacheConfig}; use chroma_storage::{local::LocalStorage, Storage}; use chroma_types::{ - Chunk, DirectDocumentComparison, DirectWhereComparison, LogRecord, MetadataValue, - Operation, OperationRecord, PrimitiveOperator, UpdateMetadataValue, Where, WhereComparison, + BooleanOperator, Chunk, DirectDocumentComparison, DirectWhereComparison, DocumentOperator, + LogRecord, MetadataSetValue, MetadataValue, 
Operation, OperationRecord, PrimitiveOperator, + SetOperator, UpdateMetadataValue, Where, WhereChildren, WhereComparison, }; + use roaring::RoaringBitmap; use std::{collections::HashMap, str::FromStr}; use uuid::Uuid; @@ -653,7 +657,7 @@ mod test { }, ]; let data: Chunk = Chunk::new(data.into()); - let mut record_segment_reader: Option = None; + let record_segment_reader: Option; match RecordSegmentReader::from_segment(&record_segment, &blockfile_provider).await { Ok(reader) => { record_segment_reader = Some(reader); @@ -859,7 +863,7 @@ mod test { }, ]; let data: Chunk = Chunk::new(data.into()); - let mut record_segment_reader: Option = None; + let record_segment_reader: Option; match RecordSegmentReader::from_segment(&record_segment, &blockfile_provider).await { Ok(reader) => { record_segment_reader = Some(reader); @@ -1044,7 +1048,7 @@ mod test { }, ]; let data: Chunk = Chunk::new(data.into()); - let mut record_segment_reader: Option = None; + let record_segment_reader: Option; match RecordSegmentReader::from_segment(&record_segment, &blockfile_provider).await { Ok(reader) => { record_segment_reader = Some(reader); @@ -1205,7 +1209,7 @@ mod test { .await .expect("Error creating segment writer"); let mut logs = Vec::new(); - for i in 0..60 { + for i in 1..=60 { let mut meta = HashMap::new(); if i % 2 == 0 { meta.insert("even".to_string(), UpdateMetadataValue::Bool(i % 4 == 0)); @@ -1229,7 +1233,7 @@ mod test { }); } let data: Chunk = Chunk::new(logs.into()); - let mut record_segment_reader: Option = None; + let record_segment_reader: Option; match RecordSegmentReader::from_segment(&record_segment, &blockfile_provider).await { Ok(reader) => { record_segment_reader = Some(reader); @@ -1283,7 +1287,7 @@ mod test { .expect("Flush metadata segment writer failed"); } let mut logs = Vec::new(); - for i in 60..120 { + for i in 61..=120 { let mut meta = HashMap::new(); if i % 2 == 0 { meta.insert("even".to_string(), UpdateMetadataValue::Bool(i % 4 == 0)); @@ -1306,7 
+1310,7 @@ mod test { }, }); } - for i in 0..20 { + for i in 1..=20 { logs.push(LogRecord { log_offset: 120 + i, record: OperationRecord { @@ -1321,14 +1325,340 @@ mod test { } let data: Chunk = Chunk::new(logs.into()); let operator = MetadataFilteringOperator::new(); - let where_clause: Where = Where::DirectWhereComparison(DirectWhereComparison { - key: String::from("bye"), + + // Test set summary: + // Total records count: 120, with id 1-120 + // Records 1-60 are compacted + // Records 61-120 are in the log + // Records with id % 6 == 0 are deleted + // Record metadata has the following keys + // - even: only exists for even ids, value is a boolean matching id % 4 == 0 + // - mod_three_{id % 3}: a floating point value converted from id + // - mod_five: an integer value matching id % 5 + // Record document has format "-->{id}<--" + + let existing = (1..=120).filter(|i| i % 6 != 0); + + // A full scan should yield all existing records that are not yet deleted + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + None, + None, + None, + None, + None, + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!(res.offset_ids, existing.clone().collect()); + + // A full scan within the user specified ids should yield matching records + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + Some((31..=90).map(|i| format!("id_{}", i)).collect()), + None, + None, + None, + None, + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!( + res.offset_ids, + existing.clone().filter(|i| &31 <= i && i <= &90).collect() + ); + + // A $eq check on metadata should yield matching records + let where_clause = Where::DirectWhereComparison(DirectWhereComparison { + key: "mod_five".to_string(), + comp: 
WhereComparison::Primitive(PrimitiveOperator::Equal, MetadataValue::Int(2)), + }); + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + None, + Some(where_clause), + None, + None, + None, + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!( + res.offset_ids, + existing.clone().filter(|i| i % 5 == 2).collect() + ); + + // A $ne check on metadata should yield matching records + let where_clause = Where::DirectWhereComparison(DirectWhereComparison { + key: "even".to_string(), comp: WhereComparison::Primitive( - PrimitiveOperator::Equal, - MetadataValue::Str(String::from("world")), + PrimitiveOperator::NotEqual, + MetadataValue::Bool(false), + ), + }); + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + None, + Some(where_clause), + None, + None, + None, + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!( + res.offset_ids, + existing + .clone() + .filter(|i| i % 2 == 1 || i % 4 == 0) + .collect() + ); + + // A $lte check on metadata should yield matching records + let where_clause = Where::DirectWhereComparison(DirectWhereComparison { + key: "mod_three_2".to_string(), + comp: WhereComparison::Primitive( + PrimitiveOperator::LessThanOrEqual, + MetadataValue::Float(50.0), ), }); + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + None, + Some(where_clause), + None, + None, + None, + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!( + res.offset_ids, + existing + .clone() + .filter(|i| i % 3 == 2 && i <= &50) + .collect() + ); + + // A $contains check on document should yield matching records + let where_doc_clause = 
Where::DirectWhereDocumentComparison(DirectDocumentComparison { + operator: DocumentOperator::Contains, + document: String::from("6<-"), + }); + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + None, + None, + Some(where_doc_clause), + None, + None, + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!( + res.offset_ids, + existing.clone().filter(|i| i % 10 == 6).collect() + ); + + // A $not_contains check on document should yield matching records + let where_doc_clause = Where::DirectWhereDocumentComparison(DirectDocumentComparison { + operator: DocumentOperator::NotContains, + document: String::from("3<-"), + }); + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + None, + None, + Some(where_doc_clause), + None, + None, + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!( + res.offset_ids, + existing.clone().filter(|i| i % 10 != 3).collect() + ); + + // A $in check on metadata should yield matching records + let where_clause = Where::DirectWhereComparison(DirectWhereComparison { + key: "mod_five".to_string(), + comp: WhereComparison::Set(SetOperator::In, MetadataSetValue::Int(vec![1, 3])), + }); + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + None, + Some(where_clause), + None, + None, + None, + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!( + res.offset_ids, + existing + .clone() + .filter(|i| i % 5 == 1 || i % 5 == 3) + .collect() + ); + + // A $in should behave like a disjunction of $eq + let contain_res = res.offset_ids; + let where_clause = Where::WhereChildren(WhereChildren { + operator: 
BooleanOperator::Or, + children: vec![ + Where::DirectWhereComparison(DirectWhereComparison { + key: "mod_five".to_string(), + comp: WhereComparison::Primitive( + PrimitiveOperator::Equal, + MetadataValue::Int(1), + ), + }), + Where::DirectWhereComparison(DirectWhereComparison { + key: "mod_five".to_string(), + comp: WhereComparison::Primitive( + PrimitiveOperator::Equal, + MetadataValue::Int(3), + ), + }), + ], + }); + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + None, + Some(where_clause), + None, + None, + None, + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!(res.offset_ids, contain_res); + + // A $nin check on metadata should yield matching records + let where_clause = Where::DirectWhereComparison(DirectWhereComparison { + key: "mod_five".to_string(), + comp: WhereComparison::Set(SetOperator::NotIn, MetadataSetValue::Int(vec![1, 3])), + }); + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + None, + Some(where_clause), + None, + None, + None, + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!( + res.offset_ids, + existing + .clone() + .filter(|i| i % 5 != 1 && i % 5 != 3) + .collect() + ); + + // A $nin should behave like a conjunction of $ne + let contain_res = res.offset_ids; + let where_clause = Where::WhereChildren(WhereChildren { + operator: BooleanOperator::And, + children: vec![ + Where::DirectWhereComparison(DirectWhereComparison { + key: "mod_five".to_string(), + comp: WhereComparison::Primitive( + PrimitiveOperator::NotEqual, + MetadataValue::Int(1), + ), + }), + Where::DirectWhereComparison(DirectWhereComparison { + key: "mod_five".to_string(), + comp: WhereComparison::Primitive( + PrimitiveOperator::NotEqual, + MetadataValue::Int(3), + ), 
+ }), + ], + }); + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + None, + Some(where_clause), + None, + None, + None, + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!(res.offset_ids, contain_res); + // offset and limit should yield the correct chunk of records let input = MetadataFilteringInput::new( blockfile_provider.clone(), record_segment.clone(), @@ -1337,15 +1667,117 @@ mod test { None, None, None, + Some(36), + Some(54), + ); + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!(res.offset_ids, existing.clone().skip(36).take(54).collect()); + + // A large offset should yield no record + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), None, None, + None, + Some(200), + None, + ); + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!(res.offset_ids, RoaringBitmap::new()); + + // A large limit should yield all records + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + None, + None, + None, + None, + Some(200), ); let res = operator .run(&input) .await .expect("Error during running of operator"); - assert_eq!(100, res.offset_ids.len()); - assert_eq!(res.offset_ids, (1..=120).filter(|i| i % 6 != 1).collect()); + assert_eq!(res.offset_ids, existing.clone().collect()); + + // Finally, test a composite filter with limit and offset + let where_clause = Where::WhereChildren(WhereChildren { + operator: BooleanOperator::And, + children: vec![ + Where::DirectWhereComparison(DirectWhereComparison { + key: "mod_three_0".to_string(), + comp: WhereComparison::Primitive( + PrimitiveOperator::GreaterThanOrEqual, + 
MetadataValue::Float(12.0), + ), + }), + Where::DirectWhereComparison(DirectWhereComparison { + key: "mod_five".to_string(), + comp: WhereComparison::Set( + SetOperator::NotIn, + MetadataSetValue::Int(vec![0, 3]), + ), + }), + Where::WhereChildren(WhereChildren { + operator: BooleanOperator::Or, + children: vec![ + Where::DirectWhereDocumentComparison(DirectDocumentComparison { + operator: DocumentOperator::NotContains, + document: "6<-".to_string(), + }), + Where::DirectWhereComparison(DirectWhereComparison { + key: "even".to_string(), + comp: WhereComparison::Primitive( + PrimitiveOperator::Equal, + MetadataValue::Bool(true), + ), + }), + ], + }), + ], + }); + + let input = MetadataFilteringInput::new( + blockfile_provider.clone(), + record_segment.clone(), + metadata_segment.clone(), + data.clone(), + Some((0..90).map(|i| format!("id_{}", i)).collect()), + Some(where_clause), + None, + Some(2), + Some(7), + ); + + let res = operator + .run(&input) + .await + .expect("Error during running of operator"); + assert_eq!( + res.offset_ids, + existing + .filter(|i| i % 3 == 0 + && i >= &12 + && i <= &90 + && i % 5 != 0 + && i % 5 != 3 + && (i % 10 != 6 || i % 4 == 0)) + .skip(2) + .take(7) + .collect() + ); } }