diff --git a/rust/worker/chroma_config.yaml b/rust/worker/chroma_config.yaml index 60c8731e383..e9a110090d6 100644 --- a/rust/worker/chroma_config.yaml +++ b/rust/worker/chroma_config.yaml @@ -49,6 +49,11 @@ query_service: sparse_index_cache_config: lru: capacity: 1000 + hnsw_provider: + hnsw_temporary_path: "~/tmp" + hnsw_cache_config: + lru: + capacity: 1000 compaction_service: service_name: "compaction-service" @@ -101,3 +106,8 @@ compaction_service: sparse_index_cache_config: lru: capacity: 1000 + hnsw_provider: + hnsw_temporary_path: "~/tmp" + hnsw_cache_config: + lru: + capacity: 1000 diff --git a/rust/worker/src/compactor/compaction_manager.rs b/rust/worker/src/compactor/compaction_manager.rs index 4b6957cfc74..ccdc7d11294 100644 --- a/rust/worker/src/compactor/compaction_manager.rs +++ b/rust/worker/src/compactor/compaction_manager.rs @@ -222,21 +222,23 @@ impl Configurable for CompactionManager { assignment_policy, ); - // TODO: real path - let path = PathBuf::from("~/tmp"); // TODO: hnsw index provider should be injected somehow let blockfile_provider = BlockfileProvider::try_from_config(&( config.blockfile_provider.clone(), storage.clone(), )) .await?; + let hnsw_index_provider = + HnswIndexProvider::try_from_config(&(config.hnsw_provider.clone(), storage.clone())) + .await?; + Ok(CompactionManager::new( scheduler, log, sysdb, storage.clone(), blockfile_provider, - HnswIndexProvider::new(storage.clone(), path), + hnsw_index_provider, compaction_manager_queue_size, Duration::from_secs(compaction_interval_sec), min_compaction_size, @@ -497,6 +499,7 @@ mod tests { let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {})); let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {})); + let hnsw_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {})); let mut manager = CompactionManager::new( scheduler, log, @@ -508,7 +511,11 @@ mod tests { block_cache, sparse_index_cache, ), - 
HnswIndexProvider::new(storage, PathBuf::from(tmpdir.path().to_str().unwrap())), + HnswIndexProvider::new( + storage, + PathBuf::from(tmpdir.path().to_str().unwrap()), + hnsw_cache, + ), compaction_manager_queue_size, compaction_interval, min_compaction_size, diff --git a/rust/worker/src/config.rs b/rust/worker/src/config.rs index 0855793a38a..02e67219742 100644 --- a/rust/worker/src/config.rs +++ b/rust/worker/src/config.rs @@ -104,6 +104,7 @@ pub(crate) struct QueryServiceConfig { pub(crate) log: crate::log::config::LogConfig, pub(crate) dispatcher: crate::execution::config::DispatcherConfig, pub(crate) blockfile_provider: crate::blockstore::config::BlockfileProviderConfig, + pub(crate) hnsw_provider: crate::index::config::HnswProviderConfig, } #[derive(Deserialize)] @@ -130,6 +131,7 @@ pub(crate) struct CompactionServiceConfig { pub(crate) dispatcher: crate::execution::config::DispatcherConfig, pub(crate) compactor: crate::compactor::config::CompactorConfig, pub(crate) blockfile_provider: crate::blockstore::config::BlockfileProviderConfig, + pub(crate) hnsw_provider: crate::index::config::HnswProviderConfig, } /// # Description @@ -203,6 +205,11 @@ mod tests { sparse_index_cache_config: lru: capacity: 1000 + hnsw_provider: + hnsw_temporary_path: "~/tmp" + hnsw_cache_config: + lru: + capacity: 1000 compaction_service: service_name: "compaction-service" @@ -255,6 +262,11 @@ mod tests { sparse_index_cache_config: lru: capacity: 1000 + hnsw_provider: + hnsw_temporary_path: "~/tmp" + hnsw_cache_config: + lru: + capacity: 1000 "#, ); let config = RootConfig::load(); @@ -323,6 +335,11 @@ mod tests { sparse_index_cache_config: lru: capacity: 1000 + hnsw_provider: + hnsw_temporary_path: "~/tmp" + hnsw_cache_config: + lru: + capacity: 1000 compaction_service: service_name: "compaction-service" @@ -375,6 +392,11 @@ mod tests { sparse_index_cache_config: lru: capacity: 1000 + hnsw_provider: + hnsw_temporary_path: "~/tmp" + hnsw_cache_config: + lru: + capacity: 1000 "#, ); 
let config = RootConfig::load_from_path("random_path.yaml"); @@ -461,6 +483,11 @@ mod tests { sparse_index_cache_config: lru: capacity: 1000 + hnsw_provider: + hnsw_temporary_path: "~/tmp" + hnsw_cache_config: + lru: + capacity: 1000 compaction_service: service_name: "compaction-service" @@ -513,6 +540,11 @@ mod tests { sparse_index_cache_config: lru: capacity: 1000 + hnsw_provider: + hnsw_temporary_path: "~/tmp" + hnsw_cache_config: + lru: + capacity: 1000 "#, ); let config = RootConfig::load(); @@ -593,6 +625,11 @@ mod tests { sparse_index_cache_config: lru: capacity: 1000 + hnsw_provider: + hnsw_temporary_path: "~/tmp" + hnsw_cache_config: + lru: + capacity: 1000 compaction_service: service_name: "compaction-service" @@ -637,6 +674,11 @@ mod tests { sparse_index_cache_config: lru: capacity: 1000 + hnsw_provider: + hnsw_temporary_path: "~/tmp" + hnsw_cache_config: + lru: + capacity: 1000 "#, ); let config = RootConfig::load(); diff --git a/rust/worker/src/index/config.rs b/rust/worker/src/index/config.rs new file mode 100644 index 00000000000..91023a367a7 --- /dev/null +++ b/rust/worker/src/index/config.rs @@ -0,0 +1,8 @@ +use crate::cache::config::CacheConfig; +use serde::Deserialize; + +#[derive(Deserialize, Debug, Clone)] +pub(crate) struct HnswProviderConfig { + pub(crate) hnsw_temporary_path: String, + pub(crate) hnsw_cache_config: CacheConfig, +} diff --git a/rust/worker/src/index/hnsw_provider.rs b/rust/worker/src/index/hnsw_provider.rs index 9e55f12d8b8..0788ee4fc1c 100644 --- a/rust/worker/src/index/hnsw_provider.rs +++ b/rust/worker/src/index/hnsw_provider.rs @@ -1,17 +1,21 @@ +use super::config::HnswProviderConfig; use super::{ HnswIndex, HnswIndexConfig, HnswIndexFromSegmentError, Index, IndexConfig, IndexConfigFromSegmentError, }; +use crate::cache::cache::Cache; +use crate::config::Configurable; use crate::errors::ErrorCodes; use crate::index::types::PersistentIndex; use crate::storage::stream::ByteStreamItem; use crate::{errors::ChromaError, 
storage::Storage, types::Segment}; +use async_trait::async_trait; use futures::stream; use futures::stream::StreamExt; use parking_lot::RwLock; use std::fmt::Debug; use std::path::Path; -use std::{collections::HashMap, path::PathBuf, sync::Arc}; +use std::{path::PathBuf, sync::Arc}; use thiserror::Error; use tokio::io::AsyncWriteExt; use uuid::Uuid; @@ -28,7 +32,7 @@ const FILES: [&'static str; 4] = [ #[derive(Clone)] pub(crate) struct HnswIndexProvider { - cache: Arc<RwLock<HashMap<Uuid, Arc<RwLock<HnswIndex>>>>>, + cache: Cache<Uuid, Arc<RwLock<HnswIndex>>>, pub(crate) temporary_storage_path: PathBuf, storage: Storage, } @@ -37,25 +41,45 @@ impl Debug for HnswIndexProvider { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "HnswIndexProvider {{ temporary_storage_path: {:?}, cache: {} }}", + "HnswIndexProvider {{ temporary_storage_path: {:?} }}", self.temporary_storage_path, - self.cache.read().len(), ) } } +#[async_trait] +impl Configurable<(HnswProviderConfig, Storage)> for HnswIndexProvider { + async fn try_from_config( + config: &(HnswProviderConfig, Storage), + ) -> Result<Self, Box<dyn ChromaError>> { + let (hnsw_config, storage) = config; + let cache = Cache::new(&hnsw_config.hnsw_cache_config); + Ok(Self { + cache, + storage: storage.clone(), + temporary_storage_path: PathBuf::from(&hnsw_config.hnsw_temporary_path), + }) + } +} + impl HnswIndexProvider { - pub(crate) fn new(storage: Storage, storage_path: PathBuf) -> Self { + pub(crate) fn new( + storage: Storage, + storage_path: PathBuf, + cache: Cache<Uuid, Arc<RwLock<HnswIndex>>>, + ) -> Self { Self { - cache: Arc::new(RwLock::new(HashMap::new())), + cache, storage, temporary_storage_path: storage_path, } } pub(crate) fn get(&self, id: &Uuid) -> Option<Arc<RwLock<HnswIndex>>> { - let cache = self.cache.read(); - cache.get(id).cloned() + match self.cache.get(id) { + Some(index) => Some(index.clone()), + None => None, + } } fn format_key(&self, id: &Uuid, file: &str) -> String { @@ -97,7 +121,7 @@ impl HnswIndexProvider { }; let hnsw_config = HnswIndexConfig::from_segment(segment, &new_storage_path); - let hnsw_config 
= match hnsw_config { + match hnsw_config { Ok(hnsw_config) => hnsw_config, Err(e) => { return Err(Box::new(HnswIndexProviderForkError::HnswConfigError(*e))); @@ -116,8 +140,7 @@ impl HnswIndexProvider { match HnswIndex::load(storage_path_str, &index_config, new_id) { Ok(index) => { let index = Arc::new(RwLock::new(index)); - let mut cache = self.cache.write(); - cache.insert(new_id, index.clone()); + self.cache.insert(new_id, index.clone()); Ok(index) } Err(e) => Err(Box::new(HnswIndexProviderForkError::IndexLoadError(e))), @@ -244,8 +267,7 @@ impl HnswIndexProvider { match HnswIndex::load(index_storage_path.to_str().unwrap(), &index_config, *id) { Ok(index) => { let index = Arc::new(RwLock::new(index)); - let mut cache = self.cache.write(); - cache.insert(*id, index.clone()); + self.cache.insert(*id, index.clone()); Ok(index) } Err(e) => Err(Box::new(HnswIndexProviderOpenError::IndexLoadError(e))), @@ -262,7 +284,6 @@ impl HnswIndexProvider { // Cases // A query comes in and the index is in the cache -> we can query the index based on segment files id (Same as compactor case 3 where we have the index) // A query comes in and the index is not in the cache -> we need to load the index from s3 based on the segment files id - pub(crate) fn create( &self, // TODO: This should not take Segment. 
The index layer should not know about the segment concept @@ -293,7 +314,6 @@ impl HnswIndexProvider { } }; - let mut cache = self.cache.write(); let index = match HnswIndex::init(&index_config, Some(&hnsw_config), id) { Ok(index) => index, Err(e) => { @@ -301,13 +321,12 @@ impl HnswIndexProvider { } }; let index = Arc::new(RwLock::new(index)); - cache.insert(id, index.clone()); + self.cache.insert(id, index.clone()); Ok(index) } pub(crate) fn commit(&self, id: &Uuid) -> Result<(), Box<dyn ChromaError>> { - let cache = self.cache.read(); - let index = match cache.get(id) { + let index = match self.cache.get(id) { Some(index) => index, None => { return Err(Box::new(HnswIndexProviderCommitError::NoIndexFound(*id))); @@ -328,8 +347,8 @@ impl HnswIndexProvider { // Scope to drop the cache lock before we await to write to s3 // TODO: since we commit(), we don't need to save the index here { - let cache = self.cache.read(); - let index = match cache.get(id) { + // let cache = self.cache.read(); + let index = match self.cache.get(id) { Some(index) => index, None => { return Err(Box::new(HnswIndexProviderFlushError::NoIndexFound(*id))); @@ -494,9 +513,11 @@ pub(crate) enum HnswIndexProviderFileError { mod tests { use super::*; use crate::{ + cache::config::{CacheConfig, UnboundedCacheConfig}, storage::{local::LocalStorage, Storage}, types::SegmentType, }; + use std::collections::HashMap; #[tokio::test] async fn test_fork() { @@ -507,8 +528,8 @@ mod tests { std::fs::create_dir_all(&hnsw_tmp_path).unwrap(); let storage = Storage::Local(LocalStorage::new(storage_dir.to_str().unwrap())); - - let provider = HnswIndexProvider::new(storage, hnsw_tmp_path); + let cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {})); + let provider = HnswIndexProvider::new(storage, hnsw_tmp_path, cache); let segment = Segment { id: Uuid::new_v4(), r#type: SegmentType::HnswDistributed, diff --git a/rust/worker/src/index/mod.rs b/rust/worker/src/index/mod.rs index cd26321f352..11cbabfbe23 100644 --- 
a/rust/worker/src/index/mod.rs +++ b/rust/worker/src/index/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod config; pub(crate) mod fulltext; mod hnsw; pub(crate) mod hnsw_provider; diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index 3c5bdddc48c..c30b7f9ee4b 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -21,7 +21,6 @@ use crate::types::MetadataValue; use crate::types::ScalarEncoding; use async_trait::async_trait; use std::collections::HashMap; -use std::path::PathBuf; use tokio::signal::unix::{signal, SignalKind}; use tonic::{transport::Server, Request, Response, Status}; use tracing::{trace, trace_span, Instrument}; @@ -73,15 +72,15 @@ impl Configurable for WorkerServer { storage.clone(), )) .await?; - // TODO: inject hnsw index provider somehow - // TODO: real path - let path = PathBuf::from("~/tmp"); + let hnsw_index_provider = + HnswIndexProvider::try_from_config(&(config.hnsw_provider.clone(), storage.clone())) + .await?; Ok(WorkerServer { dispatcher: None, system: None, sysdb, log, - hnsw_index_provider: HnswIndexProvider::new(storage.clone(), path), + hnsw_index_provider, blockfile_provider, port: config.my_port, }) @@ -587,6 +586,7 @@ mod tests { let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {})); let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {})); + let hnsw_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {})); let port = random_port::PortPicker::new().pick().unwrap(); let mut server = WorkerServer { dispatcher: None, @@ -596,6 +596,7 @@ mod tests { hnsw_index_provider: HnswIndexProvider::new( storage.clone(), tmp_dir.path().to_path_buf(), + hnsw_index_cache, ), blockfile_provider: BlockfileProvider::new_arrow( storage,