Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(common): add MemcmpEncoded struct to represent memcmp encoded data #10319

Merged
merged 3 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/batch/src/executor/top_n.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use risingwave_common::estimate_size::EstimateSize;
use risingwave_common::memory::MemoryContext;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
use risingwave_common::util::memcmp_encoding::encode_chunk;
use risingwave_common::util::memcmp_encoding::{encode_chunk, MemcmpEncoded};
use risingwave_common::util::sort_util::ColumnOrder;
use risingwave_pb::batch_plan::plan_node::NodeBody;

Expand Down Expand Up @@ -200,7 +200,7 @@ impl TopNHeap {

#[derive(Clone, EstimateSize)]
pub struct HeapElem {
encoded_row: Vec<u8>,
encoded_row: MemcmpEncoded,
row: OwnedRow,
}

Expand All @@ -225,7 +225,7 @@ impl Ord for HeapElem {
}

impl HeapElem {
pub fn new(encoded_row: Vec<u8>, row: impl Row) -> Self {
pub fn new(encoded_row: MemcmpEncoded, row: impl Row) -> Self {
Self {
encoded_row,
row: row.into_owned_row(),
Expand Down
3 changes: 2 additions & 1 deletion src/common/benches/bench_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::types::{
DataType, Date, Datum, Interval, ScalarImpl, StructType, Time, Timestamp,
};
use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
use risingwave_common::util::sort_util::OrderType;
use risingwave_common::util::{memcmp_encoding, value_encoding};

Expand All @@ -42,7 +43,7 @@ impl Case {
}
}

fn key_serialization(datum: &Datum) -> Vec<u8> {
fn key_serialization(datum: &Datum) -> MemcmpEncoded {
let result = memcmp_encoding::encode_value(
datum.as_ref().map(ScalarImpl::as_scalar_ref_impl),
OrderType::default(),
Expand Down
109 changes: 95 additions & 14 deletions src/common/src/util/memcmp_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::ops::Deref;

use bytes::{Buf, BufMut};
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use super::iter_util::{ZipEqDebug, ZipEqFast};
use crate::array::{ArrayImpl, DataChunk};
use crate::estimate_size::EstimateSize;
use crate::row::{OwnedRow, Row};
use crate::types::{
DataType, Date, Datum, Int256, ScalarImpl, Serial, Time, Timestamp, ToDatumRef, F32, F64,
Expand Down Expand Up @@ -180,12 +183,83 @@ fn calculate_encoded_size_inner(
Ok(deserializer.position() - base_position)
}

pub fn encode_value(value: impl ToDatumRef, order: OrderType) -> memcomparable::Result<Vec<u8>> {
/// Owned buffer of bytes in memcomparable (memcmp-encoded) format.
///
/// The bytes are kept in a boxed slice. Use [`MemcmpEncoded::as_inner`] to borrow them or
/// [`MemcmpEncoded::into_inner`] to take ownership of the underlying allocation.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, EstimateSize)]
pub struct MemcmpEncoded(Box<[u8]>);

impl MemcmpEncoded {
    /// Borrows the encoded bytes as a plain byte slice.
    pub fn as_inner(&self) -> &[u8] {
        let Self(bytes) = self;
        bytes
    }

    /// Consumes the wrapper and hands back the boxed slice it was holding.
    pub fn into_inner(self) -> Box<[u8]> {
        let Self(bytes) = self;
        bytes
    }
}

/// Cheap reference conversion to the encoded bytes.
impl AsRef<[u8]> for MemcmpEncoded {
    fn as_ref(&self) -> &[u8] {
        self.as_inner()
    }
}

/// Dereferences to the encoded bytes, so a `&MemcmpEncoded` can be passed where a `&[u8]`
/// is expected.
impl Deref for MemcmpEncoded {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        // Explicit reborrow through the box down to the slice it owns.
        &*self.0
    }
}

/// By-value iteration over the encoded bytes.
impl IntoIterator for MemcmpEncoded {
    type Item = u8;
    type IntoIter = std::vec::IntoIter<u8>;

    fn into_iter(self) -> Self::IntoIter {
        // Turn the boxed slice into a `Vec` first so the standard owning iterator can be used.
        let bytes = Vec::from(self.0);
        bytes.into_iter()
    }
}

/// Builds a `MemcmpEncoded` by collecting an iterator of bytes.
impl FromIterator<u8> for MemcmpEncoded {
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = u8>,
    {
        Self(Box::<[u8]>::from_iter(iter))
    }
}

/// Wraps an owned byte vector, converting it into a boxed slice.
impl From<Vec<u8>> for MemcmpEncoded {
    fn from(bytes: Vec<u8>) -> Self {
        let boxed: Box<[u8]> = Box::from(bytes);
        Self(boxed)
    }
}

impl From<Box<[u8]>> for MemcmpEncoded {
fn from(v: Box<[u8]>) -> Self {
Self(v)
}
}

/// Unwraps the encoded bytes into an owned `Vec<u8>`.
impl From<MemcmpEncoded> for Vec<u8> {
    fn from(encoded: MemcmpEncoded) -> Self {
        encoded.0.into_vec()
    }
}

/// Unwraps the encoded bytes into the boxed slice that backs them.
impl From<MemcmpEncoded> for Box<[u8]> {
    fn from(encoded: MemcmpEncoded) -> Self {
        encoded.into_inner()
    }
}

/// Encode a datum into memcomparable format.
pub fn encode_value(
value: impl ToDatumRef,
order: OrderType,
) -> memcomparable::Result<MemcmpEncoded> {
let mut serializer = memcomparable::Serializer::new(vec![]);
serialize_datum(value, order, &mut serializer)?;
Ok(serializer.into_inner())
Ok(serializer.into_inner().into())
}

/// Decode a datum from memcomparable format.
pub fn decode_value(
ty: &DataType,
encoded_value: &[u8],
Expand All @@ -195,21 +269,23 @@ pub fn decode_value(
deserialize_datum(ty, order, &mut deserializer)
}

pub fn encode_array(array: &ArrayImpl, order: OrderType) -> memcomparable::Result<Vec<Vec<u8>>> {
/// Encode an array into memcomparable format.
pub fn encode_array(
array: &ArrayImpl,
order: OrderType,
) -> memcomparable::Result<Vec<MemcmpEncoded>> {
let mut data = Vec::with_capacity(array.len());
for datum in array.iter() {
data.push(encode_value(datum, order)?);
}
Ok(data)
}

/// This function is used to accelerate the comparison of tuples. It takes datachunk and
/// user-defined order as input, yield encoded binary string with order preserved for each tuple in
/// the datachunk.
/// Encode a chunk into memcomparable format.
pub fn encode_chunk(
chunk: &DataChunk,
column_orders: &[ColumnOrder],
) -> memcomparable::Result<Vec<Vec<u8>>> {
) -> memcomparable::Result<Vec<MemcmpEncoded>> {
let encoded_columns: Vec<_> = column_orders
.iter()
.map(|o| encode_array(chunk.column_at(o.column_index), o.order_type))
Expand All @@ -222,18 +298,22 @@ pub fn encode_chunk(
}
}

Ok(encoded_chunk)
Ok(encoded_chunk.into_iter().map(Into::into).collect())
}

/// Encode a row into memcomparable format.
pub fn encode_row(row: impl Row, order_types: &[OrderType]) -> memcomparable::Result<Vec<u8>> {
pub fn encode_row(
row: impl Row,
order_types: &[OrderType],
) -> memcomparable::Result<MemcmpEncoded> {
let mut serializer = memcomparable::Serializer::new(vec![]);
row.iter()
.zip_eq_debug(order_types)
.try_for_each(|(datum, order)| serialize_datum(datum, *order, &mut serializer))?;
Ok(serializer.into_inner())
Ok(serializer.into_inner().into())
}

/// Decode a row from memcomparable format.
pub fn decode_row(
encoded_row: &[u8],
data_types: &[DataType],
Expand All @@ -259,11 +339,12 @@ mod tests {
use crate::array::{DataChunk, ListValue, StructValue};
use crate::row::{OwnedRow, RowExt};
use crate::types::{DataType, FloatExt, ScalarImpl, F32};
use crate::util::iter_util::ZipEqFast;
use crate::util::sort_util::{ColumnOrder, OrderType};

#[test]
fn test_memcomparable() {
fn encode_num(num: Option<i32>, order_type: OrderType) -> Vec<u8> {
fn encode_num(num: Option<i32>, order_type: OrderType) -> MemcmpEncoded {
encode_value(num.map(ScalarImpl::from), order_type).unwrap()
}

Expand Down Expand Up @@ -465,11 +546,11 @@ mod tests {
use num_traits::*;
use rand::seq::SliceRandom;

fn serialize(f: F32) -> Vec<u8> {
fn serialize(f: F32) -> MemcmpEncoded {
encode_value(&Some(ScalarImpl::from(f)), OrderType::default()).unwrap()
}

fn deserialize(data: Vec<u8>) -> F32 {
fn deserialize(data: MemcmpEncoded) -> F32 {
decode_value(&DataType::Float32, &data, OrderType::default())
.unwrap()
.unwrap()
Expand Down Expand Up @@ -539,7 +620,7 @@ mod tests {
let concated_encoded_row1 = encoded_v10
.into_iter()
.chain(encoded_v11.into_iter())
.collect_vec();
.collect();
assert_eq!(encoded_row1, concated_encoded_row1);

let encoded_row2 = encode_row(row2.project(&order_col_indices), &order_types).unwrap();
Expand Down
5 changes: 3 additions & 2 deletions src/stream/src/executor/aggregation/agg_state_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ use risingwave_common::array::{ArrayImpl, Op};
use risingwave_common::buffer::Bitmap;
use risingwave_common::estimate_size::EstimateSize;
use risingwave_common::types::{Datum, DatumRef};
use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
use risingwave_common::util::row_serde::OrderedRowSerde;
use smallvec::SmallVec;

use super::minput_agg_impl::MInputAggregator;
use crate::common::cache::{StateCache, StateCacheFiller};

/// Cache key type.
type CacheKey = Vec<u8>;
type CacheKey = MemcmpEncoded;

// TODO(yuchao): May extract common logic here to `struct [Data/Stream]ChunkRef` if there's other
// usage in the future. https://github.com/risingwavelabs/risingwave/pull/5908#discussion_r1002896176
Expand Down Expand Up @@ -76,7 +77,7 @@ impl<'a> Iterator for StateCacheInputBatch<'a> {
.map(|col_idx| self.columns[*col_idx].value_at(self.idx)),
&mut key,
);
key
key.into()
};
let value = self
.arg_col_indices
Expand Down
2 changes: 1 addition & 1 deletion src/stream/src/executor/aggregation/minput.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ impl<S: StateStore> MaterializedInputState<S> {
.project(&self.state_table_order_col_indices),
&mut cache_key,
);
cache_key
cache_key.into()
};
let cache_value = self
.state_table_arg_col_indices
Expand Down
12 changes: 4 additions & 8 deletions src/stream/src/executor/over_window/eowc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@ use risingwave_common::estimate_size::EstimateSize;
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::types::{DataType, ToDatumRef, ToOwnedDatum};
use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast};
use risingwave_common::util::memcmp_encoding;
use risingwave_common::util::memcmp_encoding::{self, MemcmpEncoded};
use risingwave_common::util::sort_util::OrderType;
use risingwave_common::{must_match, row};
use risingwave_expr::function::window::WindowFuncCall;
use risingwave_storage::store::PrefetchOptions;
use risingwave_storage::StateStore;

use super::state::{create_window_state, EstimatedVecDeque, WindowState};
use super::MemcmpEncoded;
use crate::cache::{new_unbounded, ManagedLruCache};
use crate::common::table::state_table::StateTable;
use crate::executor::over_window::state::{StateEvictHint, StateKey};
Expand Down Expand Up @@ -241,8 +240,7 @@ impl<S: StateStore> EowcOverWindowExecutor<S> {
let encoded_pk = memcmp_encoding::encode_row(
(&row).project(&this.input_pk_indices),
&vec![OrderType::ascending(); this.input_pk_indices.len()],
)?
.into_boxed_slice();
)?;
let key = StateKey {
order_key: order_key.into(),
encoded_pk,
Expand Down Expand Up @@ -292,8 +290,7 @@ impl<S: StateStore> EowcOverWindowExecutor<S> {
let encoded_partition_key = memcmp_encoding::encode_row(
&partition_key,
&vec![OrderType::ascending(); this.partition_key_indices.len()],
)?
.into_boxed_slice();
)?;

// Get the partition.
Self::ensure_key_in_cache(
Expand All @@ -316,8 +313,7 @@ impl<S: StateStore> EowcOverWindowExecutor<S> {
let encoded_pk = memcmp_encoding::encode_row(
input_row.project(&this.input_pk_indices),
&vec![OrderType::ascending(); this.input_pk_indices.len()],
)?
.into_boxed_slice();
)?;
let key = StateKey {
order_key: order_key.into(),
encoded_pk,
Expand Down
2 changes: 0 additions & 2 deletions src/stream/src/executor/over_window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,3 @@ mod eowc;
mod state;

pub use eowc::{EowcOverWindowExecutor, EowcOverWindowExecutorArgs};

type MemcmpEncoded = Box<[u8]>;
2 changes: 1 addition & 1 deletion src/stream/src/executor/over_window/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ use std::collections::{BTreeSet, VecDeque};
use educe::Educe;
use risingwave_common::estimate_size::{EstimateSize, KvSize};
use risingwave_common::types::{Datum, DefaultOrdered, ScalarImpl};
use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
use risingwave_expr::function::window::{WindowFuncCall, WindowFuncKind};
use smallvec::SmallVec;

use super::MemcmpEncoded;
use crate::executor::{StreamExecutorError, StreamExecutorResult};

mod buffer;
Expand Down
6 changes: 2 additions & 4 deletions src/stream/src/executor/sort_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use risingwave_common::row::{self, OwnedRow, Row, RowExt};
use risingwave_common::types::{
DefaultOrd, DefaultOrdered, ScalarImpl, ScalarRefImpl, ToOwnedDatum,
};
use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
use risingwave_storage::row_serde::row_serde_util::deserialize_pk_with_vnode;
use risingwave_storage::store::PrefetchOptions;
use risingwave_storage::StateStore;
Expand All @@ -35,9 +36,6 @@ use super::{StreamExecutorError, StreamExecutorResult};
use crate::common::cache::{OrderedStateCache, StateCache, StateCacheFiller};
use crate::common::table::state_table::StateTable;

// TODO(rc): This should be a struct in `memcmp_encoding` module. See #8606.
type MemcmpEncoded = Box<[u8]>;

type CacheKey = (
DefaultOrdered<ScalarImpl>, // sort (watermark) column value
MemcmpEncoded, // memcmp-encoded pk
Expand All @@ -56,7 +54,7 @@ fn row_to_cache_key<S: StateStore>(
buffer_table
.pk_serde()
.serialize((&row).project(buffer_table.pk_indices()), &mut pk);
(timestamp_val.into(), pk.into_boxed_slice())
(timestamp_val.into(), pk.into())
}

/// [`SortBuffer`] is a common component that consume an unordered stream and produce an ordered
Expand Down