Skip to content

Commit

Permalink
[CLN] Cleanup codebase with the refactored metadata filtering pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Sicheng Pan committed Sep 25, 2024
1 parent 25b4e7a commit 61bd467
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 1,124 deletions.
144 changes: 38 additions & 106 deletions rust/index/src/fulltext/types.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use crate::fulltext::tokenizer::ChromaTokenizer;
use crate::metadata::types::MetadataIndexError;
use crate::utils::{merge_sorted_vecs_conjunction, merge_sorted_vecs_disjunction};
use chroma_blockstore::{BlockfileFlusher, BlockfileReader, BlockfileWriter};
use chroma_error::{ChromaError, ErrorCodes};
use chroma_types::{BooleanOperator, WhereDocument, WhereDocumentOperator};
use parking_lot::Mutex;
use roaring::RoaringBitmap;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tantivy::tokenizer::Token;
Expand Down Expand Up @@ -403,7 +400,7 @@ impl<'me> FullTextIndexReader<'me> {
self.tokenizer.encode(document)
}

pub async fn search(&self, query: &str) -> Result<Vec<i32>, FullTextIndexError> {
pub async fn search(&self, query: &str) -> Result<RoaringBitmap, FullTextIndexError> {
let binding = self.encode_tokens(query);
let tokens = binding.get_tokens();

Expand All @@ -416,7 +413,7 @@ impl<'me> FullTextIndexReader<'me> {
.get_by_prefix(token.text.as_str())
.await?;
if res.len() == 0 {
return Ok(vec![]);
return Ok(RoaringBitmap::new());
}
if res.len() > 1 {
panic!("Invariant violation. Multiple frequency values found for a token.");
Expand All @@ -430,7 +427,7 @@ impl<'me> FullTextIndexReader<'me> {
}

if token_frequencies.len() == 0 {
return Ok(vec![]);
return Ok(RoaringBitmap::new());
}
// TODO sort by frequency. This adds an additional layer of complexity
// with repeat characters where we need to keep track of which positions
Expand Down Expand Up @@ -492,16 +489,12 @@ impl<'me> FullTextIndexReader<'me> {
}
}
if new_candidates.is_empty() {
return Ok(vec![]);
return Ok(RoaringBitmap::new());
}
candidates = new_candidates;
}

let mut results = vec![];
for (doc_id, _) in candidates.drain() {
results.push(doc_id as i32);
}
return Ok(results);
return Ok(candidates.into_keys().collect());
}

// We use this to implement deletes in the Writer. A delete() is implemented
Expand Down Expand Up @@ -539,56 +532,6 @@ impl<'me> FullTextIndexReader<'me> {
}
}

/// Evaluates a `where_document` clause tree, delegating leaf lookups to `callback`.
///
/// * A direct `Contains` comparison invokes `callback` with the comparison's
///   document string and `WhereDocumentOperator::Contains`, then converts each
///   returned `i32` id to `usize` with an `as` cast.
/// * A node with children evaluates every child recursively and folds the child
///   results together, left to right, with `merge_sorted_vecs_conjunction`
///   (`BooleanOperator::And`) or `merge_sorted_vecs_disjunction`
///   (`BooleanOperator::Or`). A node without children yields an empty vector.
///
/// The returned ids are sorted in ascending order.
///
/// # Errors
///
/// Returns the first error produced while evaluating a child clause instead of
/// silently treating that child as an empty result set.
///
/// # Panics
///
/// `WhereDocumentOperator::NotContains` is not implemented and hits `todo!()`.
pub fn process_where_document_clause_with_callback<
    F: Fn(&str, WhereDocumentOperator) -> Vec<i32>,
>(
    where_document_clause: &WhereDocument,
    callback: &F,
) -> Result<Vec<usize>, MetadataIndexError> {
    let mut results: Vec<usize> = match where_document_clause {
        WhereDocument::DirectWhereDocumentComparison(direct_document_comparison) => {
            match &direct_document_comparison.operator {
                WhereDocumentOperator::Contains => callback(
                    &direct_document_comparison.document,
                    WhereDocumentOperator::Contains,
                )
                .iter()
                .map(|id| *id as usize)
                .collect(),
                WhereDocumentOperator::NotContains => todo!(),
            }
        }
        WhereDocument::WhereDocumentChildren(where_document_children) => {
            // `None` until the first child has been evaluated; the first child's
            // result seeds the accumulator and later children are merged into it.
            let mut merged: Option<Vec<usize>> = None;
            for child in where_document_children.children.iter() {
                // Each recursive call returns a sorted vector, which is what the
                // sorted-merge helpers are given as input.
                let child_results =
                    process_where_document_clause_with_callback(child, callback)?;
                merged = Some(match merged {
                    None => child_results,
                    Some(acc) => match where_document_children.operator {
                        BooleanOperator::And => {
                            merge_sorted_vecs_conjunction(&acc, &child_results)
                        }
                        BooleanOperator::Or => {
                            merge_sorted_vecs_disjunction(&acc, &child_results)
                        }
                    },
                });
            }
            merged.unwrap_or_default()
        }
    };
    results.sort();
    Ok(results)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -666,13 +609,13 @@ mod tests {
FullTextIndexReader::new(pl_blockfile_reader, freq_blockfile_reader, tokenizer);

let res = index_reader.search("hello").await.unwrap();
assert_eq!(res, vec![1]);
assert_eq!(res, RoaringBitmap::from([1]));

let res = index_reader.search("world").await.unwrap();
assert_eq!(res, vec![1]);
assert_eq!(res, RoaringBitmap::from([1]));

let res = index_reader.search("hello world").await.unwrap();
assert_eq!(res, vec![1]);
assert_eq!(res, RoaringBitmap::from([1]));
}

#[tokio::test]
Expand Down Expand Up @@ -739,7 +682,7 @@ mod tests {
FullTextIndexReader::new(pl_blockfile_reader, freq_blockfile_reader, tokenizer);

let res = index_reader.search("aaaa").await.unwrap();
assert_eq!(res, vec![2]);
assert_eq!(res, RoaringBitmap::from([2]));
}

#[tokio::test]
Expand Down Expand Up @@ -841,12 +784,11 @@ mod tests {
let index_reader =
FullTextIndexReader::new(pl_blockfile_reader, freq_blockfile_reader, tokenizer);

let mut res = index_reader.search("hello").await.unwrap();
res.sort();
assert_eq!(res, vec![1, 2]);
let res = index_reader.search("hello").await.unwrap();
assert_eq!(res, RoaringBitmap::from([1, 2]));

let res = index_reader.search("hello world").await.unwrap();
assert_eq!(res, vec![1]);
assert_eq!(res, RoaringBitmap::from([1]));
}

#[tokio::test]
Expand Down Expand Up @@ -879,12 +821,11 @@ mod tests {
let index_reader =
FullTextIndexReader::new(pl_blockfile_reader, freq_blockfile_reader, tokenizer);

let mut res = index_reader.search("hello").await.unwrap();
res.sort();
assert_eq!(res, vec![1, 2]);
let res = index_reader.search("hello").await.unwrap();
assert_eq!(res, RoaringBitmap::from([1, 2]));

let res = index_reader.search("world").await.unwrap();
assert_eq!(res, vec![1]);
assert_eq!(res, RoaringBitmap::from([1]));
}

#[tokio::test]
Expand Down Expand Up @@ -919,21 +860,17 @@ mod tests {
let index_reader =
FullTextIndexReader::new(pl_blockfile_reader, freq_blockfile_reader, tokenizer);

let mut res = index_reader.search("hello").await.unwrap();
res.sort();
assert_eq!(res, vec![1, 2, 4]);
let res = index_reader.search("hello").await.unwrap();
assert_eq!(res, RoaringBitmap::from([1, 2, 4]));

let mut res = index_reader.search("world").await.unwrap();
res.sort();
assert_eq!(res, vec![1, 3, 4]);
let res = index_reader.search("world").await.unwrap();
assert_eq!(res, RoaringBitmap::from([1, 3, 4]));

let mut res = index_reader.search("hello world").await.unwrap();
res.sort();
assert_eq!(res, vec![1]);
let res = index_reader.search("hello world").await.unwrap();
assert_eq!(res, RoaringBitmap::from([1]));

let mut res = index_reader.search("world hello").await.unwrap();
res.sort();
assert_eq!(res, vec![4]);
let res = index_reader.search("world hello").await.unwrap();
assert_eq!(res, RoaringBitmap::from([4]));
}

#[tokio::test]
Expand Down Expand Up @@ -972,17 +909,14 @@ mod tests {
let index_reader =
FullTextIndexReader::new(pl_blockfile_reader, freq_blockfile_reader, tokenizer);

let mut res = index_reader.search("aaa").await.unwrap();
res.sort();
assert_eq!(res, vec![1, 2, 4, 5]);
let res = index_reader.search("aaa").await.unwrap();
assert_eq!(res, RoaringBitmap::from([1, 2, 4, 5]));

let mut res = index_reader.search("bbb").await.unwrap();
res.sort();
assert_eq!(res, vec![3, 4, 5]);
let res = index_reader.search("bbb").await.unwrap();
assert_eq!(res, RoaringBitmap::from([3, 4, 5]));

let mut res = index_reader.search("aaabbb").await.unwrap();
res.sort();
assert_eq!(res, vec![4, 5]);
let res = index_reader.search("aaabbb").await.unwrap();
assert_eq!(res, RoaringBitmap::from([4, 5]));
}

#[tokio::test]
Expand Down Expand Up @@ -1020,14 +954,13 @@ mod tests {
FullTextIndexReader::new(pl_blockfile_reader, freq_blockfile_reader, tokenizer);

let res = index_reader.search("!!!!!").await.unwrap();
assert_eq!(res, vec![1]);
assert_eq!(res, RoaringBitmap::from([1]));

let mut res = index_reader.search("!!!").await.unwrap();
res.sort();
assert_eq!(res, vec![1, 2]);
let res = index_reader.search("!!!").await.unwrap();
assert_eq!(res, RoaringBitmap::from([1, 2]));

let res = index_reader.search(".!.").await.unwrap();
assert_eq!(res, vec![3]);
assert_eq!(res, RoaringBitmap::from([3]));
}

#[tokio::test]
Expand Down Expand Up @@ -1153,12 +1086,11 @@ mod tests {
let index_reader =
FullTextIndexReader::new(pl_blockfile_reader, freq_blockfile_reader, tokenizer);

let mut res = index_reader.search("hello").await.unwrap();
res.sort();
assert_eq!(res, vec![1, 2, 3]);
let res = index_reader.search("hello").await.unwrap();
assert_eq!(res, RoaringBitmap::from([1, 2, 3]));

let res = index_reader.search("world").await.unwrap();
assert_eq!(res, vec![1]);
assert_eq!(res, RoaringBitmap::from([1]));
}

#[tokio::test]
Expand Down Expand Up @@ -1196,6 +1128,6 @@ mod tests {
FullTextIndexReader::new(pl_blockfile_reader, freq_blockfile_reader, tokenizer);

let res = index_reader.search("world").await.unwrap();
assert_eq!(res, vec![1]);
assert_eq!(res, RoaringBitmap::from([1]));
}
}
12 changes: 6 additions & 6 deletions rust/index/src/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ impl HnswIndexConfig {
}
}

let m = get_metadata_value_as::<i32>(metadata, "hnsw:M").unwrap_or(DEFAULT_HNSW_M as i32);
let ef_construction = get_metadata_value_as::<i32>(metadata, "hnsw:construction_ef")
.unwrap_or(DEFAULT_HNSW_EF_CONSTRUCTION as i32);
let ef_search = get_metadata_value_as::<i32>(metadata, "hnsw:search_ef")
.unwrap_or(DEFAULT_HNSW_EF_SEARCH as i32);
let m = get_metadata_value_as::<i64>(metadata, "hnsw:M").unwrap_or(DEFAULT_HNSW_M as i64);
let ef_construction = get_metadata_value_as::<i64>(metadata, "hnsw:construction_ef")
.unwrap_or(DEFAULT_HNSW_EF_CONSTRUCTION as i64);
let ef_search = get_metadata_value_as::<i64>(metadata, "hnsw:search_ef")
.unwrap_or(DEFAULT_HNSW_EF_SEARCH as i64);
return Ok(HnswIndexConfig {
max_elements: DEFAULT_MAX_ELEMENTS,
m: m as usize,
Expand Down Expand Up @@ -848,7 +848,7 @@ pub mod test {

// Try partial metadata
let mut metadata = HashMap::new();
metadata.insert("hnsw:M".to_string(), MetadataValue::Int(10 as i32));
metadata.insert("hnsw:M".to_string(), MetadataValue::Int(10 as i64));

let segment = Segment {
id: Uuid::new_v4(),
Expand Down
Loading

0 comments on commit 61bd467

Please sign in to comment.