From 2c382a85dbacdf3432e4542e8d7296a592a4847f Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Wed, 25 Sep 2024 11:04:30 -0700 Subject: [PATCH] [ENH] Propagate version through system (#2839) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Pass the request_version_context into the protobufs and to the rust code. In the rust orchestrators, we error on mismatches and propagate this error up to python. - New functionality - None ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes None --- chromadb/proto/convert.py | 19 +++++++ .../segment/impl/metadata/grpc_segment.py | 9 +++- chromadb/segment/impl/vector/grpc_segment.py | 11 ++++- rust/error/src/lib.rs | 8 +-- .../execution/operators/merge_knn_results.rs | 2 +- .../src/execution/orchestration/count.rs | 26 +++++++++- .../execution/orchestration/get_vectors.rs | 23 ++++++++- .../src/execution/orchestration/hnsw.rs | 24 ++++++++- .../src/execution/orchestration/metadata.rs | 28 ++++++++++- rust/worker/src/server.rs | 49 ++++++++++++++++++- 10 files changed, 184 insertions(+), 15 deletions(-) diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 3ca3ae23369..8f131c6c3a6 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -9,6 +9,7 @@ LogRecord, Metadata, Operation, + RequestVersionContext, ScalarEncoding, Segment, SegmentScope, @@ -295,3 +296,21 @@ def from_proto_vector_query_result( distance=vector_query_result.distance, embedding=from_proto_vector(vector_query_result.vector)[0], ) + + +def from_proto_request_version_context( + request_version_context: proto.RequestVersionContext, +) -> RequestVersionContext: + return RequestVersionContext( + collection_version=request_version_context.collection_version, + log_position=request_version_context.log_position, + ) + + +def to_proto_request_version_context( + request_version_context: RequestVersionContext, +) -> proto.RequestVersionContext: + return proto.RequestVersionContext( + collection_version=request_version_context["collection_version"], + log_position=request_version_context["log_position"], + ) diff --git a/chromadb/segment/impl/metadata/grpc_segment.py b/chromadb/segment/impl/metadata/grpc_segment.py index 03119d3c2ef..54718a0284f 100644 --- a/chromadb/segment/impl/metadata/grpc_segment.py +++ b/chromadb/segment/impl/metadata/grpc_segment.py @@ -1,4 +1,5 @@ from typing import Dict, List, Optional, Sequence +from chromadb.proto.convert import to_proto_request_version_context from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor from chromadb.segment import MetadataReader from chromadb.config import System @@ -52,9 +53,11 @@ def count(self, request_version_context: RequestVersionContext) -> int: request: pb.CountRecordsRequest = pb.CountRecordsRequest( segment_id=self._segment["id"].hex, collection_id=self._segment["collection"].hex, + version_context=to_proto_request_version_context(request_version_context), ) response: pb.CountRecordsResponse = self._metadata_reader_stub.CountRecords( - request, timeout=self._request_timeout_seconds + request, + timeout=self._request_timeout_seconds, ) return response.count @@ -104,10 +107,12 @@ def get_metadata( limit=limit, offset=offset, include_metadata=include_metadata, + version_context=to_proto_request_version_context(request_version_context), ) response: pb.QueryMetadataResponse = self._metadata_reader_stub.QueryMetadata( - request, timeout=self._request_timeout_seconds + request, + timeout=self._request_timeout_seconds, ) results: List[MetadataEmbeddingRecord] = [] for record in response.records: diff --git a/chromadb/segment/impl/vector/grpc_segment.py b/chromadb/segment/impl/vector/grpc_segment.py index 9f55a3b6fd0..576409b7bc1 100644 --- a/chromadb/segment/impl/vector/grpc_segment.py +++ b/chromadb/segment/impl/vector/grpc_segment.py @@ -4,6 +4,7 @@ from chromadb.proto.convert import ( from_proto_vector_embedding_record, from_proto_vector_query_result, + to_proto_request_version_context, to_proto_vector, ) from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor @@ -64,9 +65,11 @@ def get_vectors( ids=ids, segment_id=self._segment["id"].hex, collection_id=self._segment["collection"].hex, + version_context=to_proto_request_version_context(request_version_context), ) response: GetVectorsResponse = self._vector_reader_stub.GetVectors( - request, timeout=self._request_timeout_seconds + request, + timeout=self._request_timeout_seconds, ) results: List[VectorEmbeddingRecord] = [] for vector in response.records: @@ -89,9 +92,13 @@ def query_vectors( include_embeddings=query["include_embeddings"], segment_id=self._segment["id"].hex, collection_id=self._segment["collection"].hex, + version_context=to_proto_request_version_context( + query["request_version_context"] + ), ) response: QueryVectorsResponse = self._vector_reader_stub.QueryVectors( - request, timeout=self._request_timeout_seconds + request, + timeout=self._request_timeout_seconds, ) results: List[List[VectorQueryResult]] = [] for result in response.results: diff --git a/rust/error/src/lib.rs b/rust/error/src/lib.rs index 37f08c16a1d..32dbed38e9f 100644 --- a/rust/error/src/lib.rs +++ b/rust/error/src/lib.rs @@ -10,7 +10,7 @@ pub enum ErrorCodes { // CANCELLED indicates the operation was cancelled (typically by the caller). Cancelled = 1, // UNKNOWN indicates an unknown error. - UNKNOWN = 2, + Unknown = 2, // INVALID_ARGUMENT indicates client specified an invalid argument. InvalidArgument = 3, // DEADLINE_EXCEEDED means operation expired before completion. @@ -21,8 +21,6 @@ pub enum ErrorCodes { AlreadyExists = 6, // PERMISSION_DENIED indicates the caller does not have permission to execute the specified operation. PermissionDenied = 7, - // UNAUTHENTICATED indicates the request does not have valid authentication credentials for the operation. - UNAUTHENTICATED = 16, // RESOURCE_EXHAUSTED indicates some resource has been exhausted, perhaps a per-user quota, or perhaps the entire file system is out of space. ResourceExhausted = 8, // FAILED_PRECONDITION indicates operation was rejected because the system is not in a state required for the operation's execution. @@ -39,6 +37,10 @@ pub enum ErrorCodes { Unavailable = 14, // DATA_LOSS indicates unrecoverable data loss or corruption. DataLoss = 15, + // UNAUTHENTICATED indicates the request does not have valid authentication credentials for the operation. + Unauthenticated = 16, + // VERSION_MISMATCH indicates a version mismatch. This is not from the gRPC spec and is specific to Chroma. + VersionMismatch = 17, } pub trait ChromaError: Error + Send { diff --git a/rust/worker/src/execution/operators/merge_knn_results.rs b/rust/worker/src/execution/operators/merge_knn_results.rs index b328de9530c..05b418d48d8 100644 --- a/rust/worker/src/execution/operators/merge_knn_results.rs +++ b/rust/worker/src/execution/operators/merge_knn_results.rs @@ -63,7 +63,7 @@ pub enum MergeKnnResultsOperatorError {} impl ChromaError for MergeKnnResultsOperatorError { fn code(&self) -> ErrorCodes { - return ErrorCodes::UNKNOWN; + return ErrorCodes::Unknown; } } diff --git a/rust/worker/src/execution/orchestration/count.rs b/rust/worker/src/execution/orchestration/count.rs index 00627f4aeb5..c25c6e7f35f 100644 --- a/rust/worker/src/execution/orchestration/count.rs +++ b/rust/worker/src/execution/orchestration/count.rs @@ -36,6 +36,9 @@ pub(crate) struct CountQueryOrchestrator { blockfile_provider: BlockfileProvider, // Result channel result_channel: Option>>>, + // Request version context + collection_version: u32, + log_position: u64, } #[derive(Error, Debug)] @@ -52,6 +55,8 @@ enum CountQueryOrchestratorError { CollectionNotFound(Uuid), #[error("Get collection error: {0}")] GetCollectionError(#[from] GetCollectionsError), + #[error("Collection version mismatch")] + CollectionVersionMismatch, } impl ChromaError for CountQueryOrchestratorError { @@ -65,6 +70,7 @@ impl ChromaError for CountQueryOrchestratorError { CountQueryOrchestratorError::SystemTimeError(_) => ErrorCodes::Internal, CountQueryOrchestratorError::CollectionNotFound(_) => ErrorCodes::NotFound, CountQueryOrchestratorError::GetCollectionError(e) => e.code(), + CountQueryOrchestratorError::CollectionVersionMismatch => ErrorCodes::VersionMismatch, } } } @@ -78,6 +84,8 @@ impl CountQueryOrchestrator { sysdb: Box, dispatcher: ComponentHandle, blockfile_provider: BlockfileProvider, + collection_version: u32, + log_position: u64, ) -> Self { Self { system, @@ -90,6 +98,8 @@ impl CountQueryOrchestrator { dispatcher, blockfile_provider, result_channel: None, + collection_version, + log_position, } } @@ -140,8 +150,19 @@ impl CountQueryOrchestrator { } }; + // If the collection version does not match the request version then we terminate with an error + if collection.version as u32 != self.collection_version { + terminate_with_error( + self.result_channel.take(), + Box::new(CountQueryOrchestratorError::CollectionVersionMismatch), + ctx, + ); + return; + } + self.record_segment = Some(record_segment); self.collection = Some(collection); + self.pull_logs(ctx).await; } // shared @@ -170,7 +191,9 @@ impl CountQueryOrchestrator { let input = PullLogsInput::new( collection.id, // The collection log position is inclusive, and we want to start from the next log. - collection.log_position + 1, + // Note that we query using the incoming log position this is critical for correctness + // TODO: We should make all the log service code use u64 instead of i64 + (self.log_position as i64) + 1, 100, None, Some(end_timestamp), @@ -307,7 +330,6 @@ impl Component for CountQueryOrchestrator { async fn on_start(&mut self, ctx: &crate::system::ComponentContext) -> () { self.start(ctx).await; - self.pull_logs(ctx).await; } } diff --git a/rust/worker/src/execution/orchestration/get_vectors.rs b/rust/worker/src/execution/orchestration/get_vectors.rs index 33bc55cfe13..99066b07d29 100644 --- a/rust/worker/src/execution/orchestration/get_vectors.rs +++ b/rust/worker/src/execution/orchestration/get_vectors.rs @@ -43,6 +43,8 @@ enum GetVectorsError { TaskSendError(#[from] ChannelError), #[error("System time error")] SystemTimeError(#[from] std::time::SystemTimeError), + #[error("Collection version mismatch")] + CollectionVersionMismatch, } impl ChromaError for GetVectorsError { @@ -50,6 +52,7 @@ impl ChromaError for GetVectorsError { match self { GetVectorsError::TaskSendError(e) => e.code(), GetVectorsError::SystemTimeError(_) => ErrorCodes::Internal, + GetVectorsError::CollectionVersionMismatch => ErrorCodes::VersionMismatch, } } } @@ -74,6 +77,8 @@ pub struct GetVectorsOrchestrator { // Result channel result_channel: Option>>>, + collection_version: u32, + log_position: u64, } impl GetVectorsOrchestrator { @@ -86,6 +91,8 @@ impl GetVectorsOrchestrator { sysdb: Box, dispatcher: ComponentHandle, blockfile_provider: BlockfileProvider, + collection_version: u32, + log_position: u64, ) -> Self { Self { state: ExecutionState::Pending, @@ -100,6 +107,8 @@ impl GetVectorsOrchestrator { record_segment: None, collection: None, result_channel: None, + collection_version, + log_position, } } @@ -132,7 +141,9 @@ impl GetVectorsOrchestrator { let input = PullLogsInput::new( collection.id, // The collection log position is inclusive, and we want to start from the next log - collection.log_position + 1, + // Note that we query using the incoming log position this is critical for correctness + // TODO: We should make all the log service code use u64 instead of i64 + (self.log_position as i64) + 1, 100, None, Some(end_timestamp), @@ -241,6 +252,16 @@ impl Component for GetVectorsOrchestrator { } }; + // If the collection version does not match the request version then we terminate with an error + if collection.version as u32 != self.collection_version { + terminate_with_error( + self.result_channel.take(), + Box::new(GetVectorsError::CollectionVersionMismatch), + ctx, + ); + return; + } + let record_segment = match get_record_segment_by_collection_id(self.sysdb.clone(), collection_id).await { Ok(segment) => segment, diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs index 5ea356984d5..757f95501d2 100644 --- a/rust/worker/src/execution/orchestration/hnsw.rs +++ b/rust/worker/src/execution/orchestration/hnsw.rs @@ -88,6 +88,8 @@ enum HnswSegmentQueryError { RecordSegmentNotFound(Uuid), #[error("Collection has no dimension set")] CollectionHasNoDimension, + #[error("Collection version mismatch")] + CollectionVersionMismatch, } impl ChromaError for HnswSegmentQueryError { @@ -99,6 +101,7 @@ impl ChromaError for HnswSegmentQueryError { HnswSegmentQueryError::GetCollectionError(_) => ErrorCodes::Internal, HnswSegmentQueryError::RecordSegmentNotFound(_) => ErrorCodes::NotFound, HnswSegmentQueryError::CollectionHasNoDimension => ErrorCodes::InvalidArgument, + HnswSegmentQueryError::CollectionVersionMismatch => ErrorCodes::VersionMismatch, } } } @@ -143,6 +146,9 @@ pub(crate) struct HnswQueryOrchestrator { result_channel: Option< tokio::sync::oneshot::Sender>, Box>>, >, + // Request version context + collection_version: u32, + log_position: u64, } impl HnswQueryOrchestrator { @@ -159,6 +165,8 @@ impl HnswQueryOrchestrator { hnsw_index_provider: HnswIndexProvider, blockfile_provider: BlockfileProvider, dispatcher: ComponentHandle, + collection_version: u32, + log_position: u64, ) -> Self { // Set the merge dependency count to the number of query vectors * 2 // N for the HNSW query and N for the Brute force query @@ -203,6 +211,8 @@ impl HnswQueryOrchestrator { hnsw_index_provider, blockfile_provider, result_channel: None, + collection_version, + log_position, } } @@ -230,7 +240,9 @@ impl HnswQueryOrchestrator { let input = PullLogsInput::new( collection.id, // The collection log position is inclusive, and we want to start from the next log - collection.log_position + 1, + // Note that we query using the incoming log position this is critical for correctness + // TODO: We should make all the log service code use u64 instead of i64 + (self.log_position as i64) + 1, 100, None, Some(end_timestamp), @@ -571,6 +583,16 @@ impl Component for HnswQueryOrchestrator { } }; + // If the collection version does not match the request version then we terminate with an error + if collection.version as u32 != self.collection_version { + terminate_with_error( + self.result_channel.take(), + Box::new(HnswSegmentQueryError::CollectionVersionMismatch), + ctx, + ); + return; + } + // If segment is uninitialized and dimension is not set then we assume // that this is a query before any add so return empty response. if hnsw_segment.file_path.len() <= 0 && collection.dimension.is_none() { diff --git a/rust/worker/src/execution/orchestration/metadata.rs b/rust/worker/src/execution/orchestration/metadata.rs index f3b81525895..ef47ea9e468 100644 --- a/rust/worker/src/execution/orchestration/metadata.rs +++ b/rust/worker/src/execution/orchestration/metadata.rs @@ -65,6 +65,9 @@ pub(crate) struct MetadataQueryOrchestrator { include_metadata: bool, // Result channel result_channel: Option>, + // Request version context + collection_version: u32, + log_position: u64, } #[derive(Error, Debug)] @@ -81,6 +84,8 @@ enum MetadataQueryOrchestratorError { CollectionNotFound(Uuid), #[error("Get collection error: {0}")] GetCollectionError(#[from] GetCollectionsError), + #[error("Collection version mismatch")] + CollectionVersionMismatch, } impl ChromaError for MetadataQueryOrchestratorError { @@ -94,6 +99,9 @@ impl ChromaError for MetadataQueryOrchestratorError { MetadataQueryOrchestratorError::SystemTimeError(_) => ErrorCodes::Internal, MetadataQueryOrchestratorError::CollectionNotFound(_) => ErrorCodes::NotFound, MetadataQueryOrchestratorError::GetCollectionError(e) => e.code(), + MetadataQueryOrchestratorError::CollectionVersionMismatch => { + ErrorCodes::VersionMismatch + } } } } @@ -113,6 +121,8 @@ impl MetadataQueryOrchestrator { offset: Option, limit: Option, include_metadata: bool, + collection_version: u32, + log_position: u64, ) -> Self { Self { state: ExecutionState::Pending, @@ -134,6 +144,8 @@ impl MetadataQueryOrchestrator { limit, include_metadata, result_channel: None, + collection_version, + log_position, } } @@ -182,8 +194,19 @@ impl MetadataQueryOrchestrator { } }; + // If the collection version does not match the version in the request, return an error + if collection.version as u32 != self.collection_version { + terminate_with_error( + self.result_channel.take(), + Box::new(MetadataQueryOrchestratorError::CollectionVersionMismatch), + ctx, + ); + return; + } + self.record_segment = Some(record_segment); self.collection = Some(collection); + self.pull_logs(ctx).await; } async fn pull_logs(&mut self, ctx: &ComponentContext) { @@ -211,7 +234,9 @@ impl MetadataQueryOrchestrator { let input = PullLogsInput::new( collection.id, // The collection log position is inclusive, and we want to start from the next log. - collection.log_position + 1, + // Note: We use the log position sent in the request for transactionality + // TODO: Change log service to use u64 instead of i64 + (self.log_position as i64) + 1, 100, None, Some(end_timestamp), @@ -385,7 +410,6 @@ impl Component for MetadataQueryOrchestrator { async fn on_start(&mut self, ctx: &crate::system::ComponentContext) -> () { self.start(ctx).await; - self.pull_logs(ctx).await; } } diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index b55d0821492..927e94fd01b 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -146,6 +146,16 @@ impl WorkerServer { } }; + let (collection_version, log_position) = match request.version_context { + Some(version_context) => ( + version_context.collection_version, + version_context.log_position, + ), + None => { + return Err(Status::invalid_argument("No version context provided")); + } + }; + let mut proto_results_for_all = Vec::new(); let mut query_vectors = Vec::new(); @@ -182,6 +192,8 @@ impl WorkerServer { self.hnsw_index_provider.clone(), self.blockfile_provider.clone(), dispatcher, + collection_version, + log_position, ); orchestrator.run().await } @@ -256,6 +268,16 @@ impl WorkerServer { } }; + let (collection_version, log_position) = match request.version_context { + Some(version_context) => ( + version_context.collection_version, + version_context.log_position, + ), + None => { + return Err(Status::invalid_argument("No version context provided")); + } + }; + let dispatcher = match self.dispatcher { Some(ref dispatcher) => dispatcher.clone(), None => { @@ -279,6 +301,8 @@ impl WorkerServer { self.sysdb.clone(), dispatcher, self.blockfile_provider.clone(), + collection_version, + log_position, ); let result = orchestrator.run().await; let mut result = match result { @@ -335,6 +359,16 @@ impl WorkerServer { } }; + let (collection_version, log_position) = match request.version_context { + Some(version_context) => ( + version_context.collection_version, + version_context.log_position, + ), + None => { + return Err(Status::invalid_argument("No version context provided")); + } + }; + let dispatcher = match self.dispatcher { Some(ref dispatcher) => dispatcher, None => { @@ -392,6 +426,8 @@ impl WorkerServer { request.offset, request.limit, request.include_metadata, + collection_version, + log_position, ); let result = orchestrator.run().await; @@ -506,7 +542,16 @@ impl chroma_proto::metadata_reader_server::MetadataReader for WorkerServer { } }; - println!("Querying count for segment {}", segment_uuid); + let (collection_version, log_position) = match request.version_context { + Some(version_context) => ( + version_context.collection_version, + version_context.log_position, + ), + None => { + return Err(Status::invalid_argument("No version context provided")); + } + }; + let dispatcher = match self.dispatcher { Some(ref dispatcher) => dispatcher, None => { @@ -529,6 +574,8 @@ impl chroma_proto::metadata_reader_server::MetadataReader for WorkerServer { self.sysdb.clone(), dispatcher.clone(), self.blockfile_provider.clone(), + collection_version, + log_position, ); let result = orchestrator.run().await;