[ENH] Introduce stream abstraction
Ishiihara committed Jul 16, 2024
1 parent 3bcf445 · commit a6e03f6
Showing 9 changed files with 412 additions and 129 deletions.
142 changes: 68 additions & 74 deletions rust/worker/src/blockstore/arrow/concurrency_test.rs
@@ -2,88 +2,82 @@
mod tests {
use crate::{
blockstore::arrow::{config::TEST_MAX_BLOCK_SIZE_BYTES, provider::ArrowBlockfileProvider},
storage::{local::LocalStorage, Storage},
cache::{
cache::Cache,
config::{CacheConfig, UnboundedCacheConfig},
},
storage::{sync_local::SyncLocalStorage, Storage},
};
use rand::Rng;
use shuttle::{future, thread};

#[test]
fn test_blockfile_shuttle() {
// shuttle::check_random(
// || {
// let tmp_dir = tempfile::tempdir().unwrap();
// let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
// let blockfile_provider = ArrowBlockfileProvider::new(storage);
// let writer = blockfile_provider.create::<&str, u32>().unwrap();
// let id = writer.id();
// // Generate N datapoints and then have T threads write them to the blockfile
// let range_min = 10;
// let range_max = 10000;
// let n = shuttle::rand::thread_rng().gen_range(range_min..range_max);
// // Make the max threads the number of cores * 2
// let max_threads = num_cpus::get() * 2;
// let t = shuttle::rand::thread_rng().gen_range(2..max_threads);
// let mut join_handles = Vec::with_capacity(t);
// for i in 0..t {
// let range_start = i * n / t;
// let range_end = (i + 1) * n / t;
// let writer = writer.clone();
// let handle = thread::spawn(move || {
// println!("Thread {} writing keys {} to {}", i, range_start, range_end);
// for j in range_start..range_end {
// let key_string = format!("key{}", j);
// future::block_on(async {
// writer
// .set::<&str, u32>("", key_string.as_str(), j as u32)
// .await
// .unwrap_or_else(|e| {
// println!(
// "Expect key to be set successfully, but got error: {:?}",
// e
// )
// });
// });
// }
// });
// join_handles.push(handle);
// }
shuttle::check_random(
|| {
let tmp_dir = tempfile::tempdir().unwrap();
let storage =
Storage::SyncLocal(SyncLocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache =
Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
block_cache,
sparse_index_cache,
);
let writer = blockfile_provider.create::<&str, u32>().unwrap();
let id = writer.id();
// Generate N datapoints and then have T threads write them to the blockfile
let range_min = 10;
let range_max = 10000;
let n = shuttle::rand::thread_rng().gen_range(range_min..range_max);
                // Cap the number of threads at twice the number of cores
let max_threads = num_cpus::get() * 2;
let t = shuttle::rand::thread_rng().gen_range(2..max_threads);
let mut join_handles = Vec::with_capacity(t);
for i in 0..t {
let range_start = i * n / t;
let range_end = (i + 1) * n / t;
let writer = writer.clone();
let handle = thread::spawn(move || {
for j in range_start..range_end {
let key_string = format!("key{}", j);
future::block_on(async {
writer
.set::<&str, u32>("", key_string.as_str(), j as u32)
.await
.unwrap();
});
}
});
join_handles.push(handle);
}

// for handle in join_handles {
// handle.join().unwrap();
// }
for handle in join_handles {
handle.join().unwrap();
}

// // commit the writer
// future::block_on(async {
// let flusher = writer.commit::<&str, u32>().unwrap();
// flusher.flush::<&str, u32>().await.unwrap();
// });
// commit the writer
future::block_on(async {
let flusher = writer.commit::<&str, u32>().unwrap();
flusher.flush::<&str, u32>().await.unwrap();
});

// let reader = future::block_on(async {
// blockfile_provider.open::<&str, u32>(&id).await.unwrap()
// });
// // Read the data back
// for i in 0..n {
// let key_string = format!("key{}", i);
// println!("Reading key {}", key_string);
// future::block_on(async {
// match reader.get("", key_string.as_str()).await {
// Ok(value) => {
// // value.expect("Expect key to exist and there to be no error");
// assert_eq!(value, i as u32);
// }
// Err(e) => {
// println!(
// "Expect key to exist and there to be no error, but got error: {:?}",
// e
// )
// }
// }
// });
// // let value = value.expect("Expect key to exist and there to be no error");
// // assert_eq!(value, i as u32);
// }
// },
// 100,
// );
let reader = future::block_on(async {
blockfile_provider.open::<&str, u32>(&id).await.unwrap()
});
// Read the data back
for i in 0..n {
let key_string = format!("key{}", i);
let value =
future::block_on(async { reader.get("", key_string.as_str()).await });
let value = value.expect("Expect key to exist and there to be no error");
assert_eq!(value, i as u32);
}
},
100,
);
}
}
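For readers unfamiliar with shuttle: `check_random` runs the closure under many randomly chosen thread interleavings, which is what gives the test above its coverage of concurrent write orderings. A minimal standalone sketch of that pattern (a hypothetical counter example, not taken from this commit):

```rust
use shuttle::sync::{Arc, Mutex};
use shuttle::{future, thread};

fn main() {
    shuttle::check_random(
        || {
            let counter = Arc::new(Mutex::new(0u32));
            let mut handles = Vec::new();
            for _ in 0..4 {
                let counter = counter.clone();
                // Each thread drives a small async task to completion,
                // mirroring the future::block_on usage in the test above.
                handles.push(thread::spawn(move || {
                    future::block_on(async {
                        *counter.lock().unwrap() += 1;
                    });
                }));
            }
            for handle in handles {
                handle.join().unwrap();
            }
            assert_eq!(*counter.lock().unwrap(), 4);
        },
        100, // number of random interleavings to explore
    );
}
```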
75 changes: 44 additions & 31 deletions rust/worker/src/blockstore/arrow/provider.rs
@@ -1,5 +1,5 @@
use super::{
block::{self, delta::BlockDelta, Block},
block::{delta::BlockDelta, Block},
blockfile::{ArrowBlockfileReader, ArrowBlockfileWriter},
config::ArrowBlockfileProviderConfig,
sparse_index::SparseIndex,
@@ -15,12 +15,12 @@ use crate::{
},
config::Configurable,
errors::{ChromaError, ErrorCodes},
storage::{config::StorageConfig, Storage},
storage::Storage,
};
use async_trait::async_trait;
use core::panic;
use futures::StreamExt;
use thiserror::Error;
use tokio::io::AsyncReadExt;
use tracing::{Instrument, Span};
use uuid::Uuid;

@@ -213,28 +213,37 @@ impl BlockManager {
None => {
async {
let key = format!("block/{}", id);
let bytes = self.storage.get(&key).instrument(
let stream = self.storage.get(&key).instrument(
tracing::trace_span!(parent: Span::current(), "BlockManager storage get"),
).await;
let mut buf: Vec<u8> = Vec::new();
match bytes {
match stream {
Ok(mut bytes) => {
let res = bytes.read_to_end(&mut buf).instrument(
tracing::trace_span!(parent: Span::current(), "BlockManager read bytes to end"),
let read_block_span = tracing::trace_span!(parent: Span::current(), "BlockManager read bytes to end");
let buf = read_block_span.in_scope(|| async {
let mut buf: Vec<u8> = Vec::new();
while let Some(res) = bytes.next().await {
match res {
Ok(chunk) => {
buf.extend(chunk);
}
Err(e) => {
tracing::error!("Error reading block from storage: {}", e);
return None;
}
}
}
Some(buf)
}
).await;
tracing::info!("Read {:?} bytes from s3", buf.len());
match res {
Ok(_) => {}
Err(e) => {
// TODO: Return an error to callsite instead of None.
tracing::error!(
"Error reading block {:?} from s3 {:?}",
key,
e
);
let buf = match buf {
Some(buf) => {
buf
}
None => {
return None;
}
}
};
tracing::info!("Read {:?} bytes from s3", buf.len());
let deserialization_span = tracing::trace_span!(parent: Span::current(), "BlockManager deserialize block");
let block = deserialization_span.in_scope(|| Block::from_bytes(&buf, *id));
match block {
@@ -252,10 +261,9 @@
None
}
}
}
},
Err(e) => {
// TODO: Return an error to callsite instead of None.
tracing::error!("Error reading block {:?} from s3 {:?}", key, e);
tracing::error!("Error reading block from storage: {}", e);
None
}
}
@@ -330,17 +338,22 @@ impl SparseIndexManager {
tracing::info!("Cache miss - fetching sparse index from storage");
let key = format!("sparse_index/{}", id);
tracing::debug!("Reading sparse index from storage with key: {}", key);
let bytes = self.storage.get(&key).await;
let stream = self.storage.get(&key).await;
let mut buf: Vec<u8> = Vec::new();
match bytes {
match stream {
Ok(mut bytes) => {
let res = bytes.read_to_end(&mut buf).await;
match res {
Ok(_) => {}
Err(e) => {
// TODO: return error
tracing::error!("Error reading sparse index from storage: {}", e);
return None;
while let Some(res) = bytes.next().await {
match res {
Ok(chunk) => {
buf.extend(chunk);
}
Err(e) => {
tracing::error!(
"Error reading sparse index from storage: {}",
e
);
return None;
}
}
}
let block = Block::from_bytes(&buf, *id);
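Both `BlockManager::get` and `SparseIndexManager` now drain the storage stream chunk by chunk into a buffer instead of calling `read_to_end` on an async reader. A minimal standalone sketch of that pattern, assuming the stream item is a `Result` over byte chunks (the concrete `ByteStreamItem` type lives in `storage::stream` and is not shown in this diff):

```rust
use futures::stream::{Stream, StreamExt};

/// Collect a fallible byte stream into one buffer; None on the first error.
async fn drain_stream<S, E>(mut stream: S) -> Option<Vec<u8>>
where
    S: Stream<Item = Result<Vec<u8>, E>> + Unpin,
    E: std::fmt::Display,
{
    let mut buf: Vec<u8> = Vec::new();
    while let Some(item) = stream.next().await {
        match item {
            Ok(chunk) => buf.extend(chunk),
            Err(e) => {
                eprintln!("Error reading from storage: {}", e);
                return None;
            }
        }
    }
    Some(buf)
}
```

Factoring the loop out along these lines would also remove the duplication between the block and sparse-index read paths.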
58 changes: 47 additions & 11 deletions rust/worker/src/index/hnsw_provider.rs
@@ -4,12 +4,16 @@ use super::{
};
use crate::errors::ErrorCodes;
use crate::index::types::PersistentIndex;
use crate::storage::stream::ByteStreamItem;
use crate::{errors::ChromaError, storage::Storage, types::Segment};
use futures::stream;
use futures::stream::StreamExt;
use parking_lot::RwLock;
use std::fmt::Debug;
use std::path::Path;
use std::{collections::HashMap, path::PathBuf, sync::Arc};
use thiserror::Error;
use tokio::io::AsyncWriteExt;
use tracing::{instrument, Instrument, Span};
use uuid::Uuid;

@@ -130,8 +134,9 @@ impl HnswIndexProvider {
// Fetch the files from storage and put them in the index storage path
for file in FILES.iter() {
let key = self.format_key(source_id, file);
let res = self.storage.get(&key).await;
let mut reader = match res {
tracing::info!("Loading hnsw index file: {}", key);
let stream = self.storage.get(&key).await;
let reader = match stream {
Ok(reader) => reader,
Err(e) => {
tracing::error!("Failed to load hnsw index file from storage: {}", e);
@@ -142,27 +147,58 @@
let file_path = index_storage_path.join(file);
// For now, we never evict from the cache, so if the index is being loaded, the file does not exist
let file_handle = tokio::fs::File::create(&file_path).await;
let mut file_handle = match file_handle {
let file_handle = match file_handle {
Ok(file) => file,
Err(e) => {
tracing::error!("Failed to create file: {}", e);
return Err(Box::new(HnswIndexProviderFileError::IOError(e)));
}
};
let copy_res = tokio::io::copy(&mut reader, &mut file_handle)
.instrument(tracing::info_span!(parent: Span::current(), "hnsw provider file read", file = file))
.await;
match copy_res {
Ok(bytes_read) => {
tracing::info!("Copied {} bytes to file {:?}", bytes_read, file_path);
let total_bytes_written = self.copy_stream_to_local_file(reader, file_handle).await?;
tracing::info!(
"Copied {} bytes from storage key: {} to file: {}",
total_bytes_written,
key,
file_path.to_str().unwrap()
);
        // The byte stream has been fully consumed and written to the local file above.
tracing::info!("Loaded hnsw index file: {}", file);
}
Ok(())
}

async fn copy_stream_to_local_file(
&self,
stream: Box<dyn stream::Stream<Item = ByteStreamItem> + Unpin + Send>,
file_handle: tokio::fs::File,
) -> Result<u64, Box<HnswIndexProviderFileError>> {
let mut total_bytes_written = 0;
let mut file_handle = file_handle;
let mut stream = stream;
while let Some(res) = stream.next().await {
let chunk = match res {
Ok(chunk) => chunk,
Err(e) => {
return Err(Box::new(HnswIndexProviderFileError::StorageGetError(e)));
}
};

let res = file_handle.write_all(&chunk).await;
match res {
Ok(_) => {
total_bytes_written += chunk.len() as u64;
}
Err(e) => {
tracing::error!("Failed to copy file: {}", e);
return Err(Box::new(HnswIndexProviderFileError::IOError(e)));
}
}
}
Ok(())
match file_handle.flush().await {
Ok(_) => Ok(total_bytes_written),
Err(e) => {
return Err(Box::new(HnswIndexProviderFileError::IOError(e)));
}
}
}

pub(crate) async fn open(
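The new `copy_stream_to_local_file` helper replaces the earlier `tokio::io::copy` call: instead of copying from an `AsyncRead`, it pulls chunks off the stream, writes them through `AsyncWriteExt`, and flushes before reporting the byte count. A self-contained sketch of the same pattern, with the chunk type assumed to be `Vec<u8>` for illustration:

```rust
use futures::stream::{Stream, StreamExt};
use tokio::io::AsyncWriteExt;

/// Stream byte chunks to a local file, returning the total bytes written.
async fn write_stream_to_file<S>(mut stream: S, path: &std::path::Path) -> std::io::Result<u64>
where
    S: Stream<Item = std::io::Result<Vec<u8>>> + Unpin,
{
    let mut file = tokio::fs::File::create(path).await?;
    let mut total: u64 = 0;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        file.write_all(&chunk).await?;
        total += chunk.len() as u64;
    }
    // Flush buffered writes before reporting success, as the provider code does.
    file.flush().await?;
    Ok(total)
}
```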
2 changes: 2 additions & 0 deletions rust/worker/src/storage/config.rs
@@ -13,6 +13,8 @@ pub(crate) enum StorageConfig {
S3(S3StorageConfig),
#[serde(alias = "local")]
Local(LocalStorageConfig),
#[serde(alias = "sync_local")]
SyncLocal(LocalStorageConfig),
}

#[derive(Deserialize, PartialEq, Debug)]
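The new `SyncLocal` variant reuses `LocalStorageConfig` and is selected by the `sync_local` alias during deserialization. A hedged sketch of how that alias resolves (the `root` field and the YAML format are assumptions for illustration; only the enum shape comes from the diff):

```rust
use serde::Deserialize;

#[derive(Deserialize, PartialEq, Debug)]
struct LocalStorageConfig {
    root: String, // hypothetical field, for illustration only
}

#[derive(Deserialize, PartialEq, Debug)]
enum StorageConfig {
    #[serde(alias = "local")]
    Local(LocalStorageConfig),
    #[serde(alias = "sync_local")]
    SyncLocal(LocalStorageConfig),
}

fn main() {
    let yaml = "sync_local:\n  root: /tmp/chroma";
    let config: StorageConfig = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(
        config,
        StorageConfig::SyncLocal(LocalStorageConfig {
            root: "/tmp/chroma".to_string()
        })
    );
}
```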