diff --git a/Cargo.lock b/Cargo.lock index 8a26c003391..822ff9a89c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1275,6 +1275,7 @@ dependencies = [ "chroma-error", "prost 0.12.3", "prost-types", + "roaring", "thiserror", "tonic 0.10.2", "tonic-build", diff --git a/rust/types/Cargo.toml b/rust/types/Cargo.toml index 6e703a1c3d6..6cead6343c3 100644 --- a/rust/types/Cargo.toml +++ b/rust/types/Cargo.toml @@ -7,11 +7,12 @@ edition = "2021" path = "src/lib.rs" [dependencies] -tonic = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } -uuid = { workspace = true } +roaring = { workspace = true } thiserror = { workspace = true } +tonic = { workspace = true } +uuid = { workspace = true } chroma-error = { workspace = true } diff --git a/rust/types/src/metadata.rs b/rust/types/src/metadata.rs index cdb2073881c..0311156e81e 100644 --- a/rust/types/src/metadata.rs +++ b/rust/types/src/metadata.rs @@ -1,14 +1,20 @@ -use crate::chroma_proto; use chroma_error::{ChromaError, ErrorCodes}; -use std::collections::{HashMap, HashSet}; +use roaring::RoaringBitmap; +use std::{ + cmp::Ordering, + collections::{HashMap, HashSet}, + ops::{BitAnd, BitOr}, +}; use thiserror::Error; +use crate::chroma_proto; + #[derive(Clone, Debug, PartialEq)] pub enum UpdateMetadataValue { - Int(i32), + Bool(bool), + Int(i64), Float(f64), Str(String), - Bool(bool), None, } @@ -31,8 +37,11 @@ impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue { fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result { match &value.value { + Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => { + Ok(UpdateMetadataValue::Bool(*value)) + } Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => { - Ok(UpdateMetadataValue::Int(*value as i32)) + Ok(UpdateMetadataValue::Int(*value)) } Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => { Ok(UpdateMetadataValue::Float(*value)) @@ -40,12 +49,8 @@ impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue { Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => { Ok(UpdateMetadataValue::Str(value.clone())) } - Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => { - Ok(UpdateMetadataValue::Bool(*value)) - } // Used to communicate that the user wants to delete this key. None => Ok(UpdateMetadataValue::None), - _ => Err(UpdateMetadataValueConversionError::InvalidValue), } } } @@ -53,6 +58,9 @@ impl TryFrom<&chroma_proto::UpdateMetadataValue> for UpdateMetadataValue { impl From for chroma_proto::UpdateMetadataValue { fn from(value: UpdateMetadataValue) -> Self { let proto_value = match value { + UpdateMetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue { + value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)), + }, UpdateMetadataValue::Int(value) => chroma_proto::UpdateMetadataValue { value: Some(chroma_proto::update_metadata_value::Value::IntValue( value as i64, @@ -68,9 +76,6 @@ impl From for chroma_proto::UpdateMetadataValue { value, )), }, - UpdateMetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue { - value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)), - }, UpdateMetadataValue::None => chroma_proto::UpdateMetadataValue { value: None }, }; proto_value @@ -82,10 +87,10 @@ impl TryFrom<&UpdateMetadataValue> for MetadataValue { fn try_from(value: &UpdateMetadataValue) -> Result { match value { + UpdateMetadataValue::Bool(value) => Ok(MetadataValue::Bool(*value)), UpdateMetadataValue::Int(value) => Ok(MetadataValue::Int(*value)), UpdateMetadataValue::Float(value) => Ok(MetadataValue::Float(*value)), UpdateMetadataValue::Str(value) => Ok(MetadataValue::Str(value.clone())), - UpdateMetadataValue::Bool(value) => Ok(MetadataValue::Bool(*value)), UpdateMetadataValue::None => Err(MetadataValueConversionError::InvalidValue), } } @@ -97,42 +102,50 @@ MetadataValue =========================================== */ -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum MetadataValue { - Int(i32), + Bool(bool), + Int(i64), Float(f64), Str(String), - Bool(bool), } -impl TryFrom<&MetadataValue> for i32 { +impl Eq for MetadataValue {} + +impl Ord for MetadataValue { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap_or(Ordering::Equal) + } +} + +impl TryFrom<&MetadataValue> for bool { type Error = MetadataValueConversionError; fn try_from(value: &MetadataValue) -> Result { match value { - MetadataValue::Int(value) => Ok(*value), + MetadataValue::Bool(value) => Ok(*value), _ => Err(MetadataValueConversionError::InvalidValue), } } } -impl TryFrom<&MetadataValue> for f64 { +impl TryFrom<&MetadataValue> for i64 { type Error = MetadataValueConversionError; fn try_from(value: &MetadataValue) -> Result { match value { - MetadataValue::Float(value) => Ok(*value), + MetadataValue::Int(value) => Ok(*value), _ => Err(MetadataValueConversionError::InvalidValue), } } } -impl TryFrom<&MetadataValue> for bool { +impl TryFrom<&MetadataValue> for f64 { type Error = MetadataValueConversionError; fn try_from(value: &MetadataValue) -> Result { match value { - MetadataValue::Bool(value) => Ok(*value), + MetadataValue::Float(value) => Ok(*value), _ => Err(MetadataValueConversionError::InvalidValue), } } @@ -168,8 +181,11 @@ impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue { fn try_from(value: &chroma_proto::UpdateMetadataValue) -> Result { match &value.value { + Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => { + Ok(MetadataValue::Bool(*value)) + } Some(chroma_proto::update_metadata_value::Value::IntValue(value)) => { - Ok(MetadataValue::Int(*value as i32)) + Ok(MetadataValue::Int(*value)) } Some(chroma_proto::update_metadata_value::Value::FloatValue(value)) => { Ok(MetadataValue::Float(*value)) @@ -177,9 +193,6 @@ impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue { Some(chroma_proto::update_metadata_value::Value::StringValue(value)) => { Ok(MetadataValue::Str(value.clone())) } - Some(chroma_proto::update_metadata_value::Value::BoolValue(value)) => { - Ok(MetadataValue::Bool(*value)) - } _ => Err(MetadataValueConversionError::InvalidValue), } } @@ -188,6 +201,9 @@ impl TryFrom<&chroma_proto::UpdateMetadataValue> for MetadataValue { impl From for chroma_proto::UpdateMetadataValue { fn from(value: MetadataValue) -> Self { let proto_value = match value { + MetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue { + value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)), + }, MetadataValue::Int(value) => chroma_proto::UpdateMetadataValue { value: Some(chroma_proto::update_metadata_value::Value::IntValue( value as i64, @@ -203,9 +219,6 @@ impl From for chroma_proto::UpdateMetadataValue { value, )), }, - MetadataValue::Bool(value) => chroma_proto::UpdateMetadataValue { - value: Some(chroma_proto::update_metadata_value::Value::BoolValue(value)), - }, }; proto_value } @@ -315,42 +328,40 @@ Metadata queries #[derive(Clone, Debug, PartialEq)] pub enum Where { - DirectWhereComparison(DirectComparison), + DirectWhereComparison(DirectWhereComparison), + DirectWhereDocumentComparison(DirectDocumentComparison), WhereChildren(WhereChildren), } +impl Where { + pub fn conjunction(children: Vec) -> Self { + Self::WhereChildren(WhereChildren { + operator: BooleanOperator::And, + children, + }) + } + pub fn disjunction(children: Vec) -> Self { + Self::WhereChildren(WhereChildren { + operator: BooleanOperator::Or, + children, + }) + } +} + #[derive(Clone, Debug, PartialEq)] -pub struct DirectComparison { +pub struct DirectWhereComparison { pub key: String, - pub comparison: WhereComparison, + pub comp: WhereComparison, } #[derive(Clone, Debug, PartialEq)] pub enum WhereComparison { - SingleStringComparison(String, WhereClauseComparator), - SingleIntComparison(u32, WhereClauseComparator), - SingleDoubleComparison(f64, WhereClauseComparator), - StringListComparison(Vec, WhereClauseListOperator), - IntListComparison(Vec, WhereClauseListOperator), - DoubleListComparison(Vec, WhereClauseListOperator), - BoolListComparison(Vec, WhereClauseListOperator), - SingleBoolComparison(bool, WhereClauseComparator), -} - -#[derive(Debug)] -pub enum MetadataType { - StringType, - IntType, - DoubleType, - StringListType, - IntListType, - DoubleListType, - BoolListType, - BoolType, + Primitive(PrimitiveOperator, MetadataValue), + Set(SetOperator, MetadataSetValue), } #[derive(Clone, Debug, PartialEq)] -pub enum WhereClauseComparator { +pub enum PrimitiveOperator { Equal, NotEqual, GreaterThan, @@ -360,15 +371,35 @@ pub enum WhereClauseComparator { } #[derive(Clone, Debug, PartialEq)] -pub enum WhereClauseListOperator { +pub enum SetOperator { In, NotIn, } +#[derive(Clone, Debug, PartialEq)] +pub enum MetadataSetValue { + Bool(Vec), + Int(Vec), + Float(Vec), + Str(Vec), +} + +impl MetadataSetValue { + pub fn into_vec(&self) -> Vec { + use MetadataSetValue::*; + match self { + Bool(vec) => vec.iter().map(|b| MetadataValue::Bool(*b)).collect(), + Int(vec) => vec.iter().map(|i| MetadataValue::Int(*i)).collect(), + Float(vec) => vec.iter().map(|f| MetadataValue::Float(*f)).collect(), + Str(vec) => vec.iter().map(|s| MetadataValue::Str(s.clone())).collect(), + } + } +} + #[derive(Clone, Debug, PartialEq)] pub struct WhereChildren { - pub children: Vec, pub operator: BooleanOperator, + pub children: Vec, } #[derive(Clone, Debug, PartialEq)] @@ -377,30 +408,18 @@ pub enum BooleanOperator { Or, } -#[derive(Clone, Debug, PartialEq)] -pub enum WhereDocument { - DirectWhereDocumentComparison(DirectDocumentComparison), - WhereDocumentChildren(WhereDocumentChildren), -} - #[derive(Clone, Debug, PartialEq)] pub struct DirectDocumentComparison { + pub operator: DocumentOperator, pub document: String, - pub operator: WhereDocumentOperator, } #[derive(Clone, Debug, PartialEq)] -pub enum WhereDocumentOperator { +pub enum DocumentOperator { Contains, NotContains, } -#[derive(Clone, Debug, PartialEq)] -pub struct WhereDocumentChildren { - pub children: Vec, - pub operator: BooleanOperator, -} - #[derive(Clone, Debug, PartialEq)] pub enum WhereConversionError { InvalidWhere, @@ -414,9 +433,9 @@ impl TryFrom for Where { fn try_from(proto_where: chroma_proto::Where) -> Result { match proto_where.r#where { Some(chroma_proto::r#where::Where::DirectComparison(proto_comparison)) => { - let comparison = DirectComparison { + let comparison = DirectWhereComparison { key: proto_comparison.key.clone(), - comparison: proto_comparison.try_into()?, + comp: proto_comparison.try_into()?, }; Ok(Where::DirectWhereComparison(comparison)) } @@ -446,181 +465,118 @@ impl TryFrom for WhereComparison { type Error = WhereConversionError; fn try_from(proto_comparison: chroma_proto::DirectComparison) -> Result { - match proto_comparison.r#comparison { - Some(chroma_proto::direct_comparison::Comparison::SingleStringOperand( - proto_string, - )) => { - let comparator = match TryInto::::try_into( - proto_string.comparator, - ) { - Ok(comparator) => comparator, - Err(_) => return Err(WhereConversionError::InvalidWhereComparison), - }; - Ok(WhereComparison::SingleStringComparison( - proto_string.value, - comparator.try_into()?, - )) - } - Some(chroma_proto::direct_comparison::Comparison::SingleBoolOperand(proto_bool)) => { - let comparator = match TryInto::::try_into( - proto_bool.comparator, - ) { - Ok(comparator) => comparator, - Err(_) => return Err(WhereConversionError::InvalidWhereComparison), - }; - Ok(WhereComparison::SingleBoolComparison( - proto_bool.value, - comparator.try_into()?, - )) - } - Some(chroma_proto::direct_comparison::Comparison::SingleIntOperand(proto_int)) => { - let comparator: WhereClauseComparator = match proto_int.comparator { - Some(comparator) => match comparator { - chroma_proto::single_int_comparison::Comparator::NumberComparator( - proto_comparator, - ) => { - match TryInto::::try_into( - proto_comparator, - ) { - Ok(comparator) => comparator.try_into()?, - Err(_) => return Err(WhereConversionError::InvalidWhereComparison), - } - } - chroma_proto::single_int_comparison::Comparator::GenericComparator( - proto_comparator, - ) => { - match TryInto::::try_into( - proto_comparator, - ) { - Ok(comparator) => comparator.try_into()?, - Err(_) => return Err(WhereConversionError::InvalidWhereComparison), - } - } + let id_to_generic_comparator = |id| { + TryInto::::try_into(id) + .map_err(|_| WhereConversionError::InvalidWhereComparison)? + .try_into() + }; + let id_to_number_comparator = |id| { + TryInto::::try_into(id) + .map_err(|_| WhereConversionError::InvalidWhereComparison)? + .try_into() + }; + let id_to_set_comparator = |id| { + TryInto::::try_into(id) + .map_err(|_| WhereConversionError::InvalidWhereComparison)? + .try_into() + }; + if let Some(proto_comp) = proto_comparison.r#comparison { + use chroma_proto::direct_comparison::Comparison::*; + match proto_comp { + SingleBoolOperand(single_bool_comparison) => Ok(WhereComparison::Primitive( + id_to_generic_comparator(single_bool_comparison.comparator)?, + MetadataValue::Bool(single_bool_comparison.value), + )), + SingleStringOperand(single_string_comparison) => Ok(WhereComparison::Primitive( + id_to_generic_comparator(single_string_comparison.comparator)?, + MetadataValue::Str(single_string_comparison.value), + )), + SingleIntOperand(single_int_comparison) => Ok(WhereComparison::Primitive( + match single_int_comparison.comparator { + Some( + chroma_proto::single_int_comparison::Comparator::GenericComparator( + proto_generic_comparator, + ), + ) => id_to_generic_comparator(proto_generic_comparator)?, + Some( + chroma_proto::single_int_comparison::Comparator::NumberComparator( + proto_number_comparator, + ), + ) => id_to_number_comparator(proto_number_comparator)?, + None => PrimitiveOperator::Equal, }, - None => WhereClauseComparator::Equal, - }; - Ok(WhereComparison::SingleIntComparison( - proto_int.value as u32, - comparator, - )) - } - Some(chroma_proto::direct_comparison::Comparison::SingleDoubleOperand( - proto_double, - )) => { - let comparator: WhereClauseComparator = match proto_double.comparator { - Some(comparator) => match comparator { - chroma_proto::single_double_comparison::Comparator::NumberComparator( - proto_comparator, - ) => { - match TryInto::::try_into( - proto_comparator, - ) { - Ok(comparator) => comparator.try_into()?, - Err(_) => return Err(WhereConversionError::InvalidWhereComparison), - } - } - chroma_proto::single_double_comparison::Comparator::GenericComparator( - proto_comparator, - ) => { - match TryInto::::try_into( - proto_comparator, - ) { - Ok(comparator) => comparator.try_into()?, - Err(_) => return Err(WhereConversionError::InvalidWhereComparison), - } - } + MetadataValue::Int(single_int_comparison.value), + )), + SingleDoubleOperand(single_double_comparison) => Ok(WhereComparison::Primitive( + match single_double_comparison.comparator { + Some( + chroma_proto::single_double_comparison::Comparator::GenericComparator( + proto_generic_comparator, + ), + ) => id_to_generic_comparator(proto_generic_comparator)?, + Some( + chroma_proto::single_double_comparison::Comparator::NumberComparator( + proto_number_comparator, + ), + ) => id_to_number_comparator(proto_number_comparator)?, + None => PrimitiveOperator::Equal, }, - None => WhereClauseComparator::Equal, - }; - Ok(WhereComparison::SingleDoubleComparison( - proto_double.value, - comparator, - )) - } - Some(chroma_proto::direct_comparison::Comparison::StringListOperand(proto_list)) => { - let list_operator = - match TryInto::::try_into(proto_list.list_operator) - { - Ok(list_operator) => list_operator, - Err(_) => return Err(WhereConversionError::InvalidWhereComparison), - }; - Ok(WhereComparison::StringListComparison( - proto_list.values, - list_operator.try_into()?, - )) - } - Some(chroma_proto::direct_comparison::Comparison::IntListOperand(proto_list)) => { - let list_operator = - match TryInto::::try_into(proto_list.list_operator) - { - Ok(list_operator) => list_operator, - Err(_) => return Err(WhereConversionError::InvalidWhereComparison), - }; - Ok(WhereComparison::IntListComparison( - proto_list.values.into_iter().map(|v| v as u32).collect(), - list_operator.try_into()?, - )) - } - Some(chroma_proto::direct_comparison::Comparison::DoubleListOperand(proto_list)) => { - let list_operator = - match TryInto::::try_into(proto_list.list_operator) - { - Ok(list_operator) => list_operator, - Err(_) => return Err(WhereConversionError::InvalidWhereComparison), - }; - Ok(WhereComparison::DoubleListComparison( - proto_list.values, - list_operator.try_into()?, - )) - } - Some(chroma_proto::direct_comparison::Comparison::BoolListOperand(proto_list)) => { - let list_operator = - match TryInto::::try_into(proto_list.list_operator) - { - Ok(list_operator) => list_operator, - Err(_) => return Err(WhereConversionError::InvalidWhereComparison), - }; - Ok(WhereComparison::BoolListComparison( - proto_list.values, - list_operator.try_into()?, - )) + MetadataValue::Float(single_double_comparison.value), + )), + BoolListOperand(bool_list_comparison) => Ok(WhereComparison::Set( + id_to_set_comparator(bool_list_comparison.list_operator)?, + MetadataSetValue::Bool(bool_list_comparison.values), + )), + StringListOperand(string_list_comparison) => Ok(WhereComparison::Set( + id_to_set_comparator(string_list_comparison.list_operator)?, + MetadataSetValue::Str(string_list_comparison.values), + )), + IntListOperand(int_list_comparison) => Ok(WhereComparison::Set( + id_to_set_comparator(int_list_comparison.list_operator)?, + MetadataSetValue::Int(int_list_comparison.values), + )), + DoubleListOperand(double_list_comparison) => Ok(WhereComparison::Set( + id_to_set_comparator(double_list_comparison.list_operator)?, + MetadataSetValue::Float(double_list_comparison.values), + )), } - None => Err(WhereConversionError::InvalidWhereComparison), + } else { + Err(WhereConversionError::InvalidWhereComparison) } } } -impl TryFrom for WhereClauseComparator { +impl TryFrom for PrimitiveOperator { type Error = WhereConversionError; fn try_from(proto_comparator: chroma_proto::NumberComparator) -> Result { match proto_comparator { - chroma_proto::NumberComparator::Gt => Ok(WhereClauseComparator::GreaterThan), - chroma_proto::NumberComparator::Gte => Ok(WhereClauseComparator::GreaterThanOrEqual), - chroma_proto::NumberComparator::Lt => Ok(WhereClauseComparator::LessThan), - chroma_proto::NumberComparator::Lte => Ok(WhereClauseComparator::LessThanOrEqual), + chroma_proto::NumberComparator::Gt => Ok(PrimitiveOperator::GreaterThan), + chroma_proto::NumberComparator::Gte => Ok(PrimitiveOperator::GreaterThanOrEqual), + chroma_proto::NumberComparator::Lt => Ok(PrimitiveOperator::LessThan), + chroma_proto::NumberComparator::Lte => Ok(PrimitiveOperator::LessThanOrEqual), } } } -impl TryFrom for WhereClauseComparator { +impl TryFrom for PrimitiveOperator { type Error = WhereConversionError; fn try_from(proto_comparator: chroma_proto::GenericComparator) -> Result { match proto_comparator { - chroma_proto::GenericComparator::Eq => Ok(WhereClauseComparator::Equal), - chroma_proto::GenericComparator::Ne => Ok(WhereClauseComparator::NotEqual), + chroma_proto::GenericComparator::Eq => Ok(PrimitiveOperator::Equal), + chroma_proto::GenericComparator::Ne => Ok(PrimitiveOperator::NotEqual), } } } -impl TryFrom for WhereClauseListOperator { +impl TryFrom for SetOperator { type Error = WhereConversionError; fn try_from(proto_operator: chroma_proto::ListOperator) -> Result { match proto_operator { - chroma_proto::ListOperator::In => Ok(WhereClauseListOperator::In), - chroma_proto::ListOperator::Nin => Ok(WhereClauseListOperator::NotIn), + chroma_proto::ListOperator::In => Ok(SetOperator::In), + chroma_proto::ListOperator::Nin => Ok(SetOperator::NotIn), } } } @@ -654,7 +610,7 @@ impl TryFrom for BooleanOperator { } } -impl TryFrom for WhereDocument { +impl TryFrom for Where { type Error = WhereConversionError; fn try_from(proto_document: chroma_proto::WhereDocument) -> Result { @@ -670,7 +626,7 @@ impl TryFrom for WhereDocument { document: proto_comparison.document, operator: operator.try_into()?, }; - Ok(WhereDocument::DirectWhereDocumentComparison(comparison)) + Ok(Where::DirectWhereDocumentComparison(comparison)) } Some(chroma_proto::where_document::WhereDocument::Children(proto_children)) => { let operator = match TryInto::::try_into( @@ -679,30 +635,80 @@ impl TryFrom for WhereDocument { Ok(operator) => operator, Err(_) => return Err(WhereConversionError::InvalidWhereChildren), }; - let children = WhereDocumentChildren { + let children = WhereChildren { children: proto_children .children .into_iter() .map(|child| child.try_into()) - .collect::, WhereConversionError>>()?, + .collect::>()?, operator: operator.try_into()?, }; - Ok(WhereDocument::WhereDocumentChildren(children)) + Ok(Where::WhereChildren(children)) } None => Err(WhereConversionError::InvalidWhere), } } } -impl TryFrom for WhereDocumentOperator { +impl TryFrom for DocumentOperator { type Error = WhereConversionError; fn try_from(proto_operator: chroma_proto::WhereDocumentOperator) -> Result { match proto_operator { - chroma_proto::WhereDocumentOperator::Contains => Ok(WhereDocumentOperator::Contains), - chroma_proto::WhereDocumentOperator::NotContains => { - Ok(WhereDocumentOperator::NotContains) - } + chroma_proto::WhereDocumentOperator::Contains => Ok(DocumentOperator::Contains), + chroma_proto::WhereDocumentOperator::NotContains => Ok(DocumentOperator::NotContains), + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum SignedRoaringBitmap { + Include(RoaringBitmap), + Exclude(RoaringBitmap), +} + +impl SignedRoaringBitmap { + pub fn empty() -> Self { + Self::Include(RoaringBitmap::new()) + } + + pub fn full() -> Self { + Self::Exclude(RoaringBitmap::new()) + } + + pub fn flip(self) -> Self { + use SignedRoaringBitmap::*; + match self { + Include(rbm) => Exclude(rbm), + Exclude(rbm) => Include(rbm), + } + } +} + +impl BitAnd for SignedRoaringBitmap { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self { + use SignedRoaringBitmap::*; + match (self, rhs) { + (Include(lhs), Include(rhs)) => Include(lhs & rhs), + (Include(lhs), Exclude(rhs)) => Include(lhs - rhs), + (Exclude(lhs), Include(rhs)) => Include(rhs - lhs), + (Exclude(lhs), Exclude(rhs)) => Exclude(lhs | rhs), + } + } +} + +impl BitOr for SignedRoaringBitmap { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + use SignedRoaringBitmap::*; + match (self, rhs) { + (Include(lhs), Include(rhs)) => Include(lhs | rhs), + (Include(lhs), Exclude(rhs)) => Exclude(rhs - lhs), + (Exclude(lhs), Include(rhs)) => Exclude(lhs - rhs), + (Exclude(lhs), Exclude(rhs)) => Exclude(lhs & rhs), } } } @@ -814,9 +820,9 @@ mod tests { match where_clause { Where::DirectWhereComparison(comparison) => { assert_eq!(comparison.key, "foo"); - match comparison.comparison { - WhereComparison::SingleIntComparison(value, _) => { - assert_eq!(value, 42); + match comparison.comp { + WhereComparison::Primitive(_, value) => { + assert_eq!(value, MetadataValue::Int(42)); } _ => panic!("Invalid comparison type"), } @@ -888,11 +894,11 @@ mod tests { }, )), }; - let where_document: WhereDocument = proto_where.try_into().unwrap(); + let where_document: Where = proto_where.try_into().unwrap(); match where_document { - WhereDocument::DirectWhereDocumentComparison(comparison) => { + Where::DirectWhereDocumentComparison(comparison) => { assert_eq!(comparison.document, "foo"); - assert_eq!(comparison.operator, WhereDocumentOperator::Contains); + assert_eq!(comparison.operator, DocumentOperator::Contains); } _ => panic!("Invalid where document type"), } @@ -933,9 +939,9 @@ mod tests { }, )), }; - let where_document: WhereDocument = proto_where.try_into().unwrap(); + let where_document: Where = proto_where.try_into().unwrap(); match where_document { - WhereDocument::WhereDocumentChildren(children) => { + Where::WhereChildren(children) => { assert_eq!(children.children.len(), 2); assert_eq!(children.operator, BooleanOperator::And); }