synchronize post s3.get operation in providers
sanketkedia committed Aug 23, 2024
1 parent 69d1320 commit ee24084
Showing 3 changed files with 69 additions and 25 deletions.
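
The change applies one pattern in three places: after a cache miss, the value is fetched from S3 without holding any lock; then a tokio mutex is taken, the cache is re-checked, and the fetched value is inserted only if no concurrent task beat us to it. A minimal sketch of this double-checked insert, with hypothetical Block and Manager types standing in for the repository's own (the real Cache is concurrent; an RwLock<HashMap> fills in here):

    use std::collections::HashMap;
    use std::sync::RwLock;

    // Block stands in for the fetched value (a blockstore Block or an
    // Arc<RwLock<HnswIndex>>); write_mutex guards only the
    // check-then-insert window, never the fetch itself.
    #[derive(Clone)]
    struct Block(Vec<u8>);

    struct Manager {
        cache: RwLock<HashMap<u64, Block>>,
        write_mutex: tokio::sync::Mutex<()>,
    }

    impl Manager {
        async fn get(&self, id: u64) -> Block {
            if let Some(b) = self.cache.read().unwrap().get(&id) {
                return b.clone();
            }
            // Several tasks can reach this point for the same id; each
            // pays for its own (here simulated) S3 fetch.
            let fetched = Block(vec![0u8; 8]);
            // Double-checked insert: lock, re-check, insert only on miss.
            let _guard = self.write_mutex.lock().await;
            if let Some(b) = self.cache.read().unwrap().get(&id) {
                return b.clone(); // another task won the race; share its value
            }
            self.cache.write().unwrap().insert(id, fetched.clone());
            fetched
        }
    }

Keeping the mutex out of the fetch path means the S3 round-trip never serializes callers; the worst case under a race is a redundant fetch. For the HNSW provider, the re-check also guarantees that every caller ends up sharing a single Arc<RwLock<HnswIndex>> per collection instead of two tasks each installing their own.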
15 changes: 13 additions & 2 deletions rust/blockstore/src/arrow/provider.rs
@@ -18,6 +18,7 @@ use chroma_error::{ChromaError, ErrorCodes};
use chroma_storage::Storage;
use core::panic;
use futures::StreamExt;
use std::sync::Arc;
use thiserror::Error;
use tracing::{Instrument, Span};
use uuid::Uuid;
@@ -144,6 +145,7 @@ pub(super) struct BlockManager {
block_cache: Cache<Uuid, Block>,
storage: Storage,
max_block_size_bytes: usize,
write_mutex: Arc<tokio::sync::Mutex<()>>,
}

impl BlockManager {
@@ -156,6 +158,7 @@ impl BlockManager {
block_cache,
storage,
max_block_size_bytes,
write_mutex: Arc::new(tokio::sync::Mutex::new(())),
}
}

@@ -228,8 +231,16 @@ impl BlockManager {
deserialization_span.in_scope(|| Block::from_bytes(&bytes, *id));
match block {
Ok(block) => {
self.block_cache.insert(*id, block.clone());
Some(block)
let _guard = self.write_mutex.lock().await;
match self.block_cache.get(id) {
Some(b) => {
return Some(b);
}
None => {
self.block_cache.insert(*id, block.clone());
Some(block)
}
}
}
Err(e) => {
// TODO: Return an error to callsite instead of None.
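A note on the choice of tokio::sync::Mutex over std::sync::Mutex or parking_lot in both providers: lock() is awaited rather than blocking an executor thread, and the guard can be held across .await points, which copy_bytes_to_local_file (below) relies on when it serializes writes to the same file. A sketch of that write path; write_serialized is illustrative, not the repository's API:

    use std::path::Path;
    use std::sync::Arc;
    use tokio::io::AsyncWriteExt;
    use tokio::sync::Mutex;

    // The guard is held across the create/write awaits, so at most one
    // task writes through this mutex at a time.
    async fn write_serialized(
        write_mutex: Arc<Mutex<()>>,
        path: &Path,
        buf: &[u8],
    ) -> std::io::Result<()> {
        let _guard = write_mutex.lock().await;
        let mut file = tokio::fs::File::create(path).await?;
        file.write_all(buf).await?;
        file.flush().await?;
        Ok(())
    } // _guard drops here, admitting the next writer

A std::sync::MutexGuard held across an .await would make the future non-Send (so it could not be spawned on a multi-threaded runtime) and would block a worker thread while waiting; the async mutex avoids both.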
74 changes: 52 additions & 22 deletions rust/index/src/hnsw_provider.rs
@@ -5,16 +5,12 @@ use super::{
};
use crate::types::PersistentIndex;
use async_trait::async_trait;
use chroma_cache::cache;
use chroma_cache::cache::Cache;
use chroma_config::Configurable;
use chroma_error::ChromaError;
use chroma_error::ErrorCodes;
use chroma_storage::stream::ByteStreamItem;
use chroma_storage::Storage;
use chroma_types::Segment;
use futures::stream;
use futures::stream::StreamExt;
use parking_lot::RwLock;
use rand::seq::index;
use std::fmt::Debug;
@@ -53,6 +49,7 @@ pub struct HnswIndexProvider {
cache: Cache<Uuid, Arc<RwLock<HnswIndex>>>,
pub temporary_storage_path: PathBuf,
storage: Storage,
write_mutex: Arc<tokio::sync::Mutex<()>>,
}

impl Debug for HnswIndexProvider {
@@ -76,6 +73,7 @@ impl Configurable<(HnswProviderConfig, Storage)> for HnswIndexProvider {
cache,
storage: storage.clone(),
temporary_storage_path: PathBuf::from(&hnsw_config.hnsw_temporary_path),
write_mutex: Arc::new(tokio::sync::Mutex::new(())),
})
}
}
@@ -90,6 +88,7 @@ impl HnswIndexProvider {
cache,
storage,
temporary_storage_path: storage_path,
write_mutex: Arc::new(tokio::sync::Mutex::new(())),
}
}

@@ -119,7 +118,9 @@ impl HnswIndexProvider {
) -> Result<Arc<RwLock<HnswIndex>>, Box<HnswIndexProviderForkError>> {
let new_id = Uuid::new_v4();
let new_storage_path = self.temporary_storage_path.join(new_id.to_string());
match self.create_dir_all(&new_storage_path) {
// This is safe to call from multiple threads concurrently; see the
// documentation of tokio::fs::create_dir_all for why.
match self.create_dir_all(&new_storage_path).await {
Ok(_) => {}
Err(e) => {
return Err(Box::new(HnswIndexProviderForkError::FileError(*e)));
@@ -164,9 +165,17 @@ impl HnswIndexProvider {

match HnswIndex::load(storage_path_str, &index_config, new_id) {
Ok(index) => {
let index = Arc::new(RwLock::new(index));
self.cache.insert(segment.collection, index.clone());
Ok(index)
let _guard = self.write_mutex.lock().await;
match self.cache.get(&segment.collection) {
Some(index) => {
return Ok(index.clone());
}
None => {
let index = Arc::new(RwLock::new(index));
self.cache.insert(segment.collection, index.clone());
Ok(index)
}
}
}
Err(e) => Err(Box::new(HnswIndexProviderForkError::IndexLoadError(e))),
}
@@ -177,6 +186,8 @@ impl HnswIndexProvider {
file_path: &PathBuf,
buf: Arc<Vec<u8>>,
) -> Result<(), Box<HnswIndexProviderFileError>> {
// Synchronize concurrent writes to the same file.
let _guard = self.write_mutex.lock().await;
let file_handle = tokio::fs::File::create(&file_path).await;
let mut file_handle = match file_handle {
Ok(file) => file,
@@ -225,7 +236,6 @@ impl HnswIndexProvider {
}
};
let file_path = index_storage_path.join(file);
// For now, we never evict from the cache, so if the index is being loaded, the file does not exist
self.copy_bytes_to_local_file(&file_path, buf).instrument(tracing::info_span!(parent: Span::current(), "hnsw provider copy bytes to local file", file = file)).await?;
tracing::info!(
"Copied {} bytes from storage key: {} to file: {}",
@@ -246,7 +256,8 @@
) -> Result<Arc<RwLock<HnswIndex>>, Box<HnswIndexProviderOpenError>> {
let index_storage_path = self.temporary_storage_path.join(id.to_string());

match self.create_dir_all(&index_storage_path) {
// Creating directories is thread safe.
match self.create_dir_all(&index_storage_path).await {
Ok(_) => {}
Err(e) => {
return Err(Box::new(HnswIndexProviderOpenError::FileError(*e)));
@@ -263,6 +274,7 @@
}
}

// Thread safe.
let index_config = IndexConfig::from_segment(&segment, dimensionality);
let index_config = match index_config {
Ok(index_config) => index_config,
@@ -271,6 +283,7 @@
}
};

// Thread safe.
let hnsw_config = HnswIndexConfig::from_segment(segment, &index_storage_path);
let hnsw_config = match hnsw_config {
Ok(hnsw_config) => hnsw_config,
@@ -282,9 +295,17 @@
// TODO: don't unwrap path conv here
match HnswIndex::load(index_storage_path.to_str().unwrap(), &index_config, *id) {
Ok(index) => {
let index = Arc::new(RwLock::new(index));
self.cache.insert(segment.collection, index.clone());
Ok(index)
let _guard = self.write_mutex.lock().await;
match self.cache.get(&segment.collection) {
Some(index) => {
return Ok(index.clone());
}
None => {
let index = Arc::new(RwLock::new(index));
self.cache.insert(segment.collection, index.clone());
Ok(index)
}
}
}
Err(e) => Err(Box::new(HnswIndexProviderOpenError::IndexLoadError(e))),
}
@@ -300,7 +321,7 @@ impl HnswIndexProvider {
// Cases
// A query comes in and the index is in the cache -> we can query the index based on segment files id (Same as compactor case 3 where we have the index)
// A query comes in and the index is not in the cache -> we need to load the index from s3 based on the segment files id
pub fn create(
pub async fn create(
&self,
// TODO: This should not take Segment. The index layer should not know about the segment concept
segment: &Segment,
@@ -309,7 +330,7 @@
let id = Uuid::new_v4();
let index_storage_path = self.temporary_storage_path.join(id.to_string());

match self.create_dir_all(&index_storage_path) {
match self.create_dir_all(&index_storage_path).await {
Ok(_) => {}
Err(e) => {
return Err(Box::new(HnswIndexProviderCreateError::FileError(*e)));
@@ -336,9 +357,18 @@
return Err(Box::new(HnswIndexProviderCreateError::IndexInitError(e)));
}
};
let index = Arc::new(RwLock::new(index));
self.cache.insert(segment.collection, index.clone());
Ok(index)

let _guard = self.write_mutex.lock().await;
match self.cache.get(&segment.collection) {
Some(index) => {
return Ok(index.clone());
}
None => {
let index = Arc::new(RwLock::new(index));
self.cache.insert(segment.collection, index.clone());
Ok(index)
}
}
}

pub fn commit(&self, index: Arc<RwLock<HnswIndex>>) -> Result<(), Box<dyn ChromaError>> {
@@ -391,8 +421,8 @@ impl HnswIndexProvider {
tokio::fs::remove_dir_all(index_storage_path).await
}

fn create_dir_all(&self, path: &PathBuf) -> Result<(), Box<HnswIndexProviderFileError>> {
match std::fs::create_dir_all(path) {
async fn create_dir_all(&self, path: &PathBuf) -> Result<(), Box<HnswIndexProviderFileError>> {
match tokio::fs::create_dir_all(path).await {
Ok(_) => Ok(()),
Err(e) => return Err(Box::new(HnswIndexProviderFileError::IOError(e))),
}
@@ -532,7 +562,7 @@ mod tests {
let hnsw_tmp_path = storage_dir.join("hnsw");

// Create the directories needed
std::fs::create_dir_all(&hnsw_tmp_path).unwrap();
tokio::fs::create_dir_all(&hnsw_tmp_path).await.unwrap();

let storage = Storage::Local(LocalStorage::new(storage_dir.to_str().unwrap()));
let cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
@@ -547,7 +577,7 @@
};

let dimensionality = 128;
let created_index = provider.create(&segment, dimensionality).unwrap();
let created_index = provider.create(&segment, dimensionality).await.unwrap();
let created_index_id = created_index.read().id;

let forked_index = provider
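The switch from std::fs::create_dir_all to tokio::fs::create_dir_all moves the blocking syscall off the async worker threads, and concurrent calls on the same path are safe: create_dir_all treats an already-existing directory as success, and std's documentation (which tokio's version delegates to) guarantees that concurrent calls will not fail due to a race with the function itself. A small demonstration, assuming a tokio runtime and an illustrative path:

    use tokio::task::JoinSet;

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        let mut tasks = JoinSet::new();
        // Eight tasks race to create the same nested path; every call
        // returns Ok, whether it created the directories or found them.
        for _ in 0..8 {
            tasks.spawn(tokio::fs::create_dir_all("/tmp/hnsw_demo/a/b/c"));
        }
        while let Some(joined) = tasks.join_next().await {
            joined.expect("task panicked")?;
        }
        Ok(())
    }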
5 changes: 4 additions & 1 deletion rust/worker/src/segment/distributed_hnsw_segment.rs
@@ -156,7 +156,10 @@ impl DistributedHNSWSegmentWriter {
segment.id,
)))
} else {
let index = match hnsw_index_provider.create(segment, dimensionality as i32) {
let index = match hnsw_index_provider
.create(segment, dimensionality as i32)
.await
{
Ok(index) => index,
Err(e) => {
return Err(Box::new(
