Skip to content

Commit

Permalink
[ENH] Propagate version through system (#2839)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Pass the request_version_context into the protobufs and to the rust
code. In the rust orchestrators, we error on mismatches and propagate
this error up to python.
 - New functionality
	 - None

## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored Sep 25, 2024
1 parent 7a29ad4 commit 2c382a8
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 15 deletions.
19 changes: 19 additions & 0 deletions chromadb/proto/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
LogRecord,
Metadata,
Operation,
RequestVersionContext,
ScalarEncoding,
Segment,
SegmentScope,
Expand Down Expand Up @@ -295,3 +296,21 @@ def from_proto_vector_query_result(
distance=vector_query_result.distance,
embedding=from_proto_vector(vector_query_result.vector)[0],
)


def from_proto_request_version_context(
request_version_context: proto.RequestVersionContext,
) -> RequestVersionContext:
return RequestVersionContext(
collection_version=request_version_context.collection_version,
log_position=request_version_context.log_position,
)


def to_proto_request_version_context(
request_version_context: RequestVersionContext,
) -> proto.RequestVersionContext:
return proto.RequestVersionContext(
collection_version=request_version_context["collection_version"],
log_position=request_version_context["log_position"],
)
9 changes: 7 additions & 2 deletions chromadb/segment/impl/metadata/grpc_segment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, List, Optional, Sequence
from chromadb.proto.convert import to_proto_request_version_context
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
from chromadb.segment import MetadataReader
from chromadb.config import System
Expand Down Expand Up @@ -52,9 +53,11 @@ def count(self, request_version_context: RequestVersionContext) -> int:
request: pb.CountRecordsRequest = pb.CountRecordsRequest(
segment_id=self._segment["id"].hex,
collection_id=self._segment["collection"].hex,
version_context=to_proto_request_version_context(request_version_context),
)
response: pb.CountRecordsResponse = self._metadata_reader_stub.CountRecords(
request, timeout=self._request_timeout_seconds
request,
timeout=self._request_timeout_seconds,
)
return response.count

Expand Down Expand Up @@ -104,10 +107,12 @@ def get_metadata(
limit=limit,
offset=offset,
include_metadata=include_metadata,
version_context=to_proto_request_version_context(request_version_context),
)

response: pb.QueryMetadataResponse = self._metadata_reader_stub.QueryMetadata(
request, timeout=self._request_timeout_seconds
request,
timeout=self._request_timeout_seconds,
)
results: List[MetadataEmbeddingRecord] = []
for record in response.records:
Expand Down
11 changes: 9 additions & 2 deletions chromadb/segment/impl/vector/grpc_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from chromadb.proto.convert import (
from_proto_vector_embedding_record,
from_proto_vector_query_result,
to_proto_request_version_context,
to_proto_vector,
)
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
Expand Down Expand Up @@ -64,9 +65,11 @@ def get_vectors(
ids=ids,
segment_id=self._segment["id"].hex,
collection_id=self._segment["collection"].hex,
version_context=to_proto_request_version_context(request_version_context),
)
response: GetVectorsResponse = self._vector_reader_stub.GetVectors(
request, timeout=self._request_timeout_seconds
request,
timeout=self._request_timeout_seconds,
)
results: List[VectorEmbeddingRecord] = []
for vector in response.records:
Expand All @@ -89,9 +92,13 @@ def query_vectors(
include_embeddings=query["include_embeddings"],
segment_id=self._segment["id"].hex,
collection_id=self._segment["collection"].hex,
version_context=to_proto_request_version_context(
query["request_version_context"]
),
)
response: QueryVectorsResponse = self._vector_reader_stub.QueryVectors(
request, timeout=self._request_timeout_seconds
request,
timeout=self._request_timeout_seconds,
)
results: List[List[VectorQueryResult]] = []
for result in response.results:
Expand Down
8 changes: 5 additions & 3 deletions rust/error/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub enum ErrorCodes {
// CANCELLED indicates the operation was cancelled (typically by the caller).
Cancelled = 1,
// UNKNOWN indicates an unknown error.
UNKNOWN = 2,
Unknown = 2,
// INVALID_ARGUMENT indicates client specified an invalid argument.
InvalidArgument = 3,
// DEADLINE_EXCEEDED means operation expired before completion.
Expand All @@ -21,8 +21,6 @@ pub enum ErrorCodes {
AlreadyExists = 6,
// PERMISSION_DENIED indicates the caller does not have permission to execute the specified operation.
PermissionDenied = 7,
// UNAUTHENTICATED indicates the request does not have valid authentication credentials for the operation.
UNAUTHENTICATED = 16,
// RESOURCE_EXHAUSTED indicates some resource has been exhausted, perhaps a per-user quota, or perhaps the entire file system is out of space.
ResourceExhausted = 8,
// FAILED_PRECONDITION indicates operation was rejected because the system is not in a state required for the operation's execution.
Expand All @@ -39,6 +37,10 @@ pub enum ErrorCodes {
Unavailable = 14,
// DATA_LOSS indicates unrecoverable data loss or corruption.
DataLoss = 15,
// UNAUTHENTICATED indicates the request does not have valid authentication credentials for the operation.
Unauthenticated = 16,
// VERSION_MISMATCH indicates a version mismatch. This is not from the gRPC spec and is specific to Chroma.
VersionMismatch = 17,
}

pub trait ChromaError: Error + Send {
Expand Down
2 changes: 1 addition & 1 deletion rust/worker/src/execution/operators/merge_knn_results.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub enum MergeKnnResultsOperatorError {}

impl ChromaError for MergeKnnResultsOperatorError {
fn code(&self) -> ErrorCodes {
return ErrorCodes::UNKNOWN;
return ErrorCodes::Unknown;
}
}

Expand Down
26 changes: 24 additions & 2 deletions rust/worker/src/execution/orchestration/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ pub(crate) struct CountQueryOrchestrator {
blockfile_provider: BlockfileProvider,
// Result channel
result_channel: Option<tokio::sync::oneshot::Sender<Result<usize, Box<dyn ChromaError>>>>,
// Request version context
collection_version: u32,
log_position: u64,
}

#[derive(Error, Debug)]
Expand All @@ -52,6 +55,8 @@ enum CountQueryOrchestratorError {
CollectionNotFound(Uuid),
#[error("Get collection error: {0}")]
GetCollectionError(#[from] GetCollectionsError),
#[error("Collection version mismatch")]
CollectionVersionMismatch,
}

impl ChromaError for CountQueryOrchestratorError {
Expand All @@ -65,6 +70,7 @@ impl ChromaError for CountQueryOrchestratorError {
CountQueryOrchestratorError::SystemTimeError(_) => ErrorCodes::Internal,
CountQueryOrchestratorError::CollectionNotFound(_) => ErrorCodes::NotFound,
CountQueryOrchestratorError::GetCollectionError(e) => e.code(),
CountQueryOrchestratorError::CollectionVersionMismatch => ErrorCodes::VersionMismatch,
}
}
}
Expand All @@ -78,6 +84,8 @@ impl CountQueryOrchestrator {
sysdb: Box<SysDb>,
dispatcher: ComponentHandle<Dispatcher>,
blockfile_provider: BlockfileProvider,
collection_version: u32,
log_position: u64,
) -> Self {
Self {
system,
Expand All @@ -90,6 +98,8 @@ impl CountQueryOrchestrator {
dispatcher,
blockfile_provider,
result_channel: None,
collection_version,
log_position,
}
}

Expand Down Expand Up @@ -140,8 +150,19 @@ impl CountQueryOrchestrator {
}
};

// If the collection version does not match the request version then we terminate with an error
if collection.version as u32 != self.collection_version {
terminate_with_error(
self.result_channel.take(),
Box::new(CountQueryOrchestratorError::CollectionVersionMismatch),
ctx,
);
return;
}

self.record_segment = Some(record_segment);
self.collection = Some(collection);
self.pull_logs(ctx).await;
}

// shared
Expand Down Expand Up @@ -170,7 +191,9 @@ impl CountQueryOrchestrator {
let input = PullLogsInput::new(
collection.id,
// The collection log position is inclusive, and we want to start from the next log.
collection.log_position + 1,
// Note that we query using the incoming log position this is critical for correctness
// TODO: We should make all the log service code use u64 instead of i64
(self.log_position as i64) + 1,
100,
None,
Some(end_timestamp),
Expand Down Expand Up @@ -307,7 +330,6 @@ impl Component for CountQueryOrchestrator {

async fn on_start(&mut self, ctx: &crate::system::ComponentContext<Self>) -> () {
self.start(ctx).await;
self.pull_logs(ctx).await;
}
}

Expand Down
23 changes: 22 additions & 1 deletion rust/worker/src/execution/orchestration/get_vectors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ enum GetVectorsError {
TaskSendError(#[from] ChannelError),
#[error("System time error")]
SystemTimeError(#[from] std::time::SystemTimeError),
#[error("Collection version mismatch")]
CollectionVersionMismatch,
}

impl ChromaError for GetVectorsError {
fn code(&self) -> ErrorCodes {
match self {
GetVectorsError::TaskSendError(e) => e.code(),
GetVectorsError::SystemTimeError(_) => ErrorCodes::Internal,
GetVectorsError::CollectionVersionMismatch => ErrorCodes::VersionMismatch,
}
}
}
Expand All @@ -74,6 +77,8 @@ pub struct GetVectorsOrchestrator {
// Result channel
result_channel:
Option<tokio::sync::oneshot::Sender<Result<GetVectorsResult, Box<dyn ChromaError>>>>,
collection_version: u32,
log_position: u64,
}

impl GetVectorsOrchestrator {
Expand All @@ -86,6 +91,8 @@ impl GetVectorsOrchestrator {
sysdb: Box<SysDb>,
dispatcher: ComponentHandle<Dispatcher>,
blockfile_provider: BlockfileProvider,
collection_version: u32,
log_position: u64,
) -> Self {
Self {
state: ExecutionState::Pending,
Expand All @@ -100,6 +107,8 @@ impl GetVectorsOrchestrator {
record_segment: None,
collection: None,
result_channel: None,
collection_version,
log_position,
}
}

Expand Down Expand Up @@ -132,7 +141,9 @@ impl GetVectorsOrchestrator {
let input = PullLogsInput::new(
collection.id,
// The collection log position is inclusive, and we want to start from the next log
collection.log_position + 1,
// Note that we query using the incoming log position this is critical for correctness
// TODO: We should make all the log service code use u64 instead of i64
(self.log_position as i64) + 1,
100,
None,
Some(end_timestamp),
Expand Down Expand Up @@ -241,6 +252,16 @@ impl Component for GetVectorsOrchestrator {
}
};

// If the collection version does not match the request version then we terminate with an error
if collection.version as u32 != self.collection_version {
terminate_with_error(
self.result_channel.take(),
Box::new(GetVectorsError::CollectionVersionMismatch),
ctx,
);
return;
}

let record_segment =
match get_record_segment_by_collection_id(self.sysdb.clone(), collection_id).await {
Ok(segment) => segment,
Expand Down
24 changes: 23 additions & 1 deletion rust/worker/src/execution/orchestration/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ enum HnswSegmentQueryError {
RecordSegmentNotFound(Uuid),
#[error("Collection has no dimension set")]
CollectionHasNoDimension,
#[error("Collection version mismatch")]
CollectionVersionMismatch,
}

impl ChromaError for HnswSegmentQueryError {
Expand All @@ -99,6 +101,7 @@ impl ChromaError for HnswSegmentQueryError {
HnswSegmentQueryError::GetCollectionError(_) => ErrorCodes::Internal,
HnswSegmentQueryError::RecordSegmentNotFound(_) => ErrorCodes::NotFound,
HnswSegmentQueryError::CollectionHasNoDimension => ErrorCodes::InvalidArgument,
HnswSegmentQueryError::CollectionVersionMismatch => ErrorCodes::VersionMismatch,
}
}
}
Expand Down Expand Up @@ -143,6 +146,9 @@ pub(crate) struct HnswQueryOrchestrator {
result_channel: Option<
tokio::sync::oneshot::Sender<Result<Vec<Vec<VectorQueryResult>>, Box<dyn ChromaError>>>,
>,
// Request version context
collection_version: u32,
log_position: u64,
}

impl HnswQueryOrchestrator {
Expand All @@ -159,6 +165,8 @@ impl HnswQueryOrchestrator {
hnsw_index_provider: HnswIndexProvider,
blockfile_provider: BlockfileProvider,
dispatcher: ComponentHandle<Dispatcher>,
collection_version: u32,
log_position: u64,
) -> Self {
// Set the merge dependency count to the number of query vectors * 2
// N for the HNSW query and N for the Brute force query
Expand Down Expand Up @@ -203,6 +211,8 @@ impl HnswQueryOrchestrator {
hnsw_index_provider,
blockfile_provider,
result_channel: None,
collection_version,
log_position,
}
}

Expand Down Expand Up @@ -230,7 +240,9 @@ impl HnswQueryOrchestrator {
let input = PullLogsInput::new(
collection.id,
// The collection log position is inclusive, and we want to start from the next log
collection.log_position + 1,
// Note that we query using the incoming log position this is critical for correctness
// TODO: We should make all the log service code use u64 instead of i64
(self.log_position as i64) + 1,
100,
None,
Some(end_timestamp),
Expand Down Expand Up @@ -571,6 +583,16 @@ impl Component for HnswQueryOrchestrator {
}
};

// If the collection version does not match the request version then we terminate with an error
if collection.version as u32 != self.collection_version {
terminate_with_error(
self.result_channel.take(),
Box::new(HnswSegmentQueryError::CollectionVersionMismatch),
ctx,
);
return;
}

// If segment is uninitialized and dimension is not set then we assume
// that this is a query before any add so return empty response.
if hnsw_segment.file_path.len() <= 0 && collection.dimension.is_none() {
Expand Down
Loading

0 comments on commit 2c382a8

Please sign in to comment.