[ENH] Allow cache eviction for HNSW provider
Ishiihara committed Jul 9, 2024
1 parent a7f5120 commit 04f0db4
Showing 7 changed files with 121 additions and 31 deletions.
10 changes: 10 additions & 0 deletions rust/worker/chroma_config.yaml
@@ -49,6 +49,11 @@ query_service:
sparse_index_cache_config:
lru:
capacity: 1000
hnsw_provider:
hnsw_temporary_path: "~/tmp"
hnsw_cache_config:
lru:
capacity: 1000

compaction_service:
service_name: "compaction-service"
@@ -101,3 +106,8 @@ compaction_service:
sparse_index_cache_config:
lru:
capacity: 1000
hnsw_provider:
hnsw_temporary_path: "~/tmp"
hnsw_cache_config:
lru:
capacity: 1000
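
The new hnsw_provider block gives the provider a scratch directory for index files and a bounded LRU cache: capacity is the maximum number of HNSW indices kept in memory before the least-recently-used one is evicted. A minimal sketch of that eviction behavior follows, using the lru crate purely as a stand-in for the worker's internal Cache type (an assumption for illustration; the real cache lives in the worker's cache module):

// Sketch only: LRU eviction semantics. The `lru` crate stands in for the
// worker's Cache<Uuid, Arc<RwLock<HnswIndex>>>; requires the `lru` dependency.
use lru::LruCache;
use std::num::NonZeroUsize;

fn main() {
    // Capacity of 2 for brevity; the config above uses 1000.
    let mut cache: LruCache<u64, &'static str> =
        LruCache::new(NonZeroUsize::new(2).unwrap());

    cache.put(1, "index-1");
    cache.put(2, "index-2");
    cache.get(&1);           // touch 1, so 2 becomes least recently used
    cache.put(3, "index-3"); // over capacity: 2 is evicted

    assert!(cache.get(&2).is_none()); // evicted; would be reloaded from storage
    assert!(cache.get(&1).is_some());
    assert!(cache.get(&3).is_some());
}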
15 changes: 11 additions & 4 deletions rust/worker/src/compactor/compaction_manager.rs
@@ -222,21 +222,23 @@ impl Configurable<CompactionServiceConfig> for CompactionManager
assignment_policy,
);

// TODO: real path
let path = PathBuf::from("~/tmp");
// TODO: hnsw index provider should be injected somehow
let blockfile_provider = BlockfileProvider::try_from_config(&(
config.blockfile_provider.clone(),
storage.clone(),
))
.await?;
let hnsw_index_provider =
HnswIndexProvider::try_from_config(&(config.hnsw_provider.clone(), storage.clone()))
.await?;

Ok(CompactionManager::new(
scheduler,
log,
sysdb,
storage.clone(),
blockfile_provider,
HnswIndexProvider::new(storage.clone(), path),
hnsw_index_provider,
compaction_manager_queue_size,
Duration::from_secs(compaction_interval_sec),
min_compaction_size,
@@ -497,6 +499,7 @@ mod tests {

let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let hnsw_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let mut manager = CompactionManager::new(
scheduler,
log,
@@ -508,7 +511,11 @@
block_cache,
sparse_index_cache,
),
HnswIndexProvider::new(storage, PathBuf::from(tmpdir.path().to_str().unwrap())),
HnswIndexProvider::new(
storage,
PathBuf::from(tmpdir.path().to_str().unwrap()),
hnsw_cache,
),
compaction_manager_queue_size,
compaction_interval,
min_compaction_size,
42 changes: 42 additions & 0 deletions rust/worker/src/config.rs
@@ -104,6 +104,7 @@ pub(crate) struct QueryServiceConfig {
pub(crate) log: crate::log::config::LogConfig,
pub(crate) dispatcher: crate::execution::config::DispatcherConfig,
pub(crate) blockfile_provider: crate::blockstore::config::BlockfileProviderConfig,
pub(crate) hnsw_provider: crate::index::config::HnswProviderConfig,
}

#[derive(Deserialize)]
@@ -130,6 +131,7 @@ pub(crate) struct CompactionServiceConfig {
pub(crate) dispatcher: crate::execution::config::DispatcherConfig,
pub(crate) compactor: crate::compactor::config::CompactorConfig,
pub(crate) blockfile_provider: crate::blockstore::config::BlockfileProviderConfig,
pub(crate) hnsw_provider: crate::index::config::HnswProviderConfig,
}

/// # Description
@@ -203,6 +205,11 @@ mod tests {
sparse_index_cache_config:
lru:
capacity: 1000
hnsw_provider:
hnsw_temporary_path: "~/tmp"
hnsw_cache_config:
lru:
capacity: 1000
compaction_service:
service_name: "compaction-service"
@@ -255,6 +262,11 @@ mod tests {
sparse_index_cache_config:
lru:
capacity: 1000
hnsw_provider:
hnsw_temporary_path: "~/tmp"
hnsw_cache_config:
lru:
capacity: 1000
"#,
);
let config = RootConfig::load();
@@ -323,6 +335,11 @@ mod tests {
sparse_index_cache_config:
lru:
capacity: 1000
hnsw_provider:
hnsw_temporary_path: "~/tmp"
hnsw_cache_config:
lru:
capacity: 1000
compaction_service:
service_name: "compaction-service"
@@ -375,6 +392,11 @@ mod tests {
sparse_index_cache_config:
lru:
capacity: 1000
hnsw_provider:
hnsw_temporary_path: "~/tmp"
hnsw_cache_config:
lru:
capacity: 1000
"#,
);
let config = RootConfig::load_from_path("random_path.yaml");
@@ -461,6 +483,11 @@ mod tests {
sparse_index_cache_config:
lru:
capacity: 1000
hnsw_provider:
hnsw_temporary_path: "~/tmp"
hnsw_cache_config:
lru:
capacity: 1000
compaction_service:
service_name: "compaction-service"
@@ -513,6 +540,11 @@ mod tests {
sparse_index_cache_config:
lru:
capacity: 1000
hnsw_provider:
hnsw_temporary_path: "~/tmp"
hnsw_cache_config:
lru:
capacity: 1000
"#,
);
let config = RootConfig::load();
@@ -593,6 +625,11 @@ mod tests {
sparse_index_cache_config:
lru:
capacity: 1000
hnsw_provider:
hnsw_temporary_path: "~/tmp"
hnsw_cache_config:
lru:
capacity: 1000
compaction_service:
service_name: "compaction-service"
@@ -637,6 +674,11 @@ mod tests {
sparse_index_cache_config:
lru:
capacity: 1000
hnsw_provider:
hnsw_temporary_path: "~/tmp"
hnsw_cache_config:
lru:
capacity: 1000
"#,
);
let config = RootConfig::load();
8 changes: 8 additions & 0 deletions rust/worker/src/index/config.rs
@@ -0,0 +1,8 @@
use crate::cache::config::CacheConfig;
use serde::Deserialize;

#[derive(Deserialize, Debug, Clone)]
pub(crate) struct HnswProviderConfig {
pub(crate) hnsw_temporary_path: String,
pub(crate) hnsw_cache_config: CacheConfig,
}
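
For reference, the YAML block added above maps onto this struct roughly as in the self-contained sketch below. CacheConfig here is a stand-in written to match the YAML shape (the real definition in the crate's cache config module may differ), and serde plus serde_yaml are assumed as dependencies:

use serde::Deserialize;

// Stand-ins mirroring the `lru: / capacity:` shape used in chroma_config.yaml.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
enum CacheConfig {
    Lru(LruConfig),
}

#[derive(Deserialize, Debug)]
struct LruConfig {
    capacity: usize,
}

#[derive(Deserialize, Debug)]
struct HnswProviderConfig {
    hnsw_temporary_path: String,
    hnsw_cache_config: CacheConfig,
}

fn main() {
    let yaml = r#"
hnsw_temporary_path: "~/tmp"
hnsw_cache_config:
  lru:
    capacity: 1000
"#;
    let cfg: HnswProviderConfig = serde_yaml::from_str(yaml).unwrap();
    println!("{cfg:?}");
}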
65 changes: 43 additions & 22 deletions rust/worker/src/index/hnsw_provider.rs
@@ -1,17 +1,21 @@
use super::config::HnswProviderConfig;
use super::{
HnswIndex, HnswIndexConfig, HnswIndexFromSegmentError, Index, IndexConfig,
IndexConfigFromSegmentError,
};
use crate::cache::cache::Cache;
use crate::config::Configurable;
use crate::errors::ErrorCodes;
use crate::index::types::PersistentIndex;
use crate::storage::stream::ByteStreamItem;
use crate::{errors::ChromaError, storage::Storage, types::Segment};
use async_trait::async_trait;
use futures::stream;
use futures::stream::StreamExt;
use parking_lot::RwLock;
use std::fmt::Debug;
use std::path::Path;
use std::{collections::HashMap, path::PathBuf, sync::Arc};
use std::{path::PathBuf, sync::Arc};
use thiserror::Error;
use tokio::io::AsyncWriteExt;
use uuid::Uuid;
@@ -28,7 +32,7 @@ const FILES: [&'static str; 4] = [

#[derive(Clone)]
pub(crate) struct HnswIndexProvider {
cache: Arc<RwLock<HashMap<Uuid, Arc<RwLock<HnswIndex>>>>>,
cache: Cache<Uuid, Arc<RwLock<HnswIndex>>>,
pub(crate) temporary_storage_path: PathBuf,
storage: Storage,
}
@@ -37,25 +41,45 @@ impl Debug for HnswIndexProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"HnswIndexProvider {{ temporary_storage_path: {:?}, cache: {} }}",
"HnswIndexProvider {{ temporary_storage_path: {:?} }}",
self.temporary_storage_path,
self.cache.read().len(),
)
}
}

#[async_trait]
impl Configurable<(HnswProviderConfig, Storage)> for HnswIndexProvider {
async fn try_from_config(
config: &(HnswProviderConfig, Storage),
) -> Result<Self, Box<dyn ChromaError>> {
let (hnsw_config, storage) = config;
let cache = Cache::new(&hnsw_config.hnsw_cache_config);
Ok(Self {
cache,
storage: storage.clone(),
temporary_storage_path: PathBuf::from(&hnsw_config.hnsw_temporary_path),
})
}
}

impl HnswIndexProvider {
pub(crate) fn new(storage: Storage, storage_path: PathBuf) -> Self {
pub(crate) fn new(
storage: Storage,
storage_path: PathBuf,
cache: Cache<Uuid, Arc<RwLock<HnswIndex>>>,
) -> Self {
Self {
cache: Arc::new(RwLock::new(HashMap::new())),
cache,
storage,
temporary_storage_path: storage_path,
}
}

pub(crate) fn get(&self, id: &Uuid) -> Option<Arc<RwLock<HnswIndex>>> {
let cache = self.cache.read();
cache.get(id).cloned()
match self.cache.get(id) {
Some(index) => Some(index.clone()),
None => None,
}
}

fn format_key(&self, id: &Uuid, file: &str) -> String {
@@ -97,7 +121,7 @@ impl HnswIndexProvider {
};

let hnsw_config = HnswIndexConfig::from_segment(segment, &new_storage_path);
let hnsw_config = match hnsw_config {
match hnsw_config {
Ok(hnsw_config) => hnsw_config,
Err(e) => {
return Err(Box::new(HnswIndexProviderForkError::HnswConfigError(*e)));
@@ -116,8 +140,7 @@ impl HnswIndexProvider {
match HnswIndex::load(storage_path_str, &index_config, new_id) {
Ok(index) => {
let index = Arc::new(RwLock::new(index));
let mut cache = self.cache.write();
cache.insert(new_id, index.clone());
self.cache.insert(new_id, index.clone());
Ok(index)
}
Err(e) => Err(Box::new(HnswIndexProviderForkError::IndexLoadError(e))),
@@ -244,8 +267,7 @@ impl HnswIndexProvider {
match HnswIndex::load(index_storage_path.to_str().unwrap(), &index_config, *id) {
Ok(index) => {
let index = Arc::new(RwLock::new(index));
let mut cache = self.cache.write();
cache.insert(*id, index.clone());
self.cache.insert(*id, index.clone());
Ok(index)
}
Err(e) => Err(Box::new(HnswIndexProviderOpenError::IndexLoadError(e))),
@@ -262,7 +284,6 @@ impl HnswIndexProvider {
// Cases
// A query comes in and the index is in the cache -> we can query the index based on segment files id (Same as compactor case 3 where we have the index)
// A query comes in and the index is not in the cache -> we need to load the index from s3 based on the segment files id

pub(crate) fn create(
&self,
// TODO: This should not take Segment. The index layer should not know about the segment concept
@@ -293,21 +314,19 @@ impl HnswIndexProvider {
}
};

let mut cache = self.cache.write();
let index = match HnswIndex::init(&index_config, Some(&hnsw_config), id) {
Ok(index) => index,
Err(e) => {
return Err(Box::new(HnswIndexProviderCreateError::IndexInitError(e)));
}
};
let index = Arc::new(RwLock::new(index));
cache.insert(id, index.clone());
self.cache.insert(id, index.clone());
Ok(index)
}

pub(crate) fn commit(&self, id: &Uuid) -> Result<(), Box<HnswIndexProviderCommitError>> {
let cache = self.cache.read();
let index = match cache.get(id) {
let index = match self.cache.get(id) {
Some(index) => index,
None => {
return Err(Box::new(HnswIndexProviderCommitError::NoIndexFound(*id)));
@@ -328,8 +347,8 @@ impl HnswIndexProvider {
// Scope to drop the cache lock before we await to write to s3
// TODO: since we commit(), we don't need to save the index here
{
let cache = self.cache.read();
let index = match cache.get(id) {
// let cache = self.cache.read();
let index = match self.cache.get(id) {
Some(index) => index,
None => {
return Err(Box::new(HnswIndexProviderFlushError::NoIndexFound(*id)));
@@ -494,9 +513,11 @@ pub(crate) enum HnswIndexProviderFileError {
mod tests {
use super::*;
use crate::{
cache::config::{CacheConfig, UnboundedCacheConfig},
storage::{local::LocalStorage, Storage},
types::SegmentType,
};
use std::collections::HashMap;

#[tokio::test]
async fn test_fork() {
@@ -507,8 +528,8 @@ mod tests {
std::fs::create_dir_all(&hnsw_tmp_path).unwrap();

let storage = Storage::Local(LocalStorage::new(storage_dir.to_str().unwrap()));

let provider = HnswIndexProvider::new(storage, hnsw_tmp_path);
let cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let provider = HnswIndexProvider::new(storage, hnsw_tmp_path, cache);
let segment = Segment {
id: Uuid::new_v4(),
r#type: SegmentType::HnswDistributed,
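
The substantive change in this file is swapping the hand-rolled Arc<RwLock<HashMap<Uuid, ...>>> for the worker's Cache type, so entries can be evicted and callers no longer take a map-wide lock. The sketch below is a minimal stand-in (MiniCache, get_or_open, and load_from_storage are illustrative names, not the crate's API) for the get-or-reload pattern that makes eviction safe: a miss, whether the index was never loaded or was evicted, falls back to loading from storage and re-inserting, much as fork() and open() do above.

use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};

struct HnswIndex; // stand-in for the real index type

// Clonable handle whose interior owns its own synchronization, like Cache.
#[derive(Clone, Default)]
struct MiniCache {
    inner: Arc<Mutex<HashMap<u64, Arc<RwLock<HnswIndex>>>>>,
}

impl MiniCache {
    fn get(&self, id: &u64) -> Option<Arc<RwLock<HnswIndex>>> {
        self.inner.lock().unwrap().get(id).cloned()
    }
    fn insert(&self, id: u64, index: Arc<RwLock<HnswIndex>>) {
        self.inner.lock().unwrap().insert(id, index);
    }
}

fn load_from_storage(_id: &u64) -> HnswIndex {
    HnswIndex // placeholder for the S3 download + HnswIndex::load path
}

// Hit: reuse the in-memory index. Miss (never loaded, or evicted): reload
// from storage and cache it again.
fn get_or_open(cache: &MiniCache, id: &u64) -> Arc<RwLock<HnswIndex>> {
    if let Some(index) = cache.get(id) {
        return index;
    }
    let index = Arc::new(RwLock::new(load_from_storage(id)));
    cache.insert(*id, index.clone());
    index
}

fn main() {
    let cache = MiniCache::default();
    let id = 42u64;
    let first = get_or_open(&cache, &id);
    let second = get_or_open(&cache, &id); // served from the cache
    assert!(Arc::ptr_eq(&first, &second));
}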
1 change: 1 addition & 0 deletions rust/worker/src/index/mod.rs
@@ -1,3 +1,4 @@
pub(crate) mod config;
pub(crate) mod fulltext;
mod hnsw;
pub(crate) mod hnsw_provider;