Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(expr): support shuffle approx_percentile #17814

Merged
merged 18 commits into from
Jul 30, 2024
Merged
102 changes: 102 additions & 0 deletions e2e_test/streaming/aggregate/approx_percentile.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Single phase approx percentile
statement ok
create table t(p_col double, grp_col int);

statement ok
insert into t select generate_series, 1 from generate_series(-1000, 1000);
kwannoel marked this conversation as resolved.
Show resolved Hide resolved

statement ok
flush;

query I
select
approx_percentile(0.01, 0.01) within group (order by p_col) as p01,
approx_percentile(0.1, 0.01) within group (order by p_col) as p10,
approx_percentile(0.5, 0.01) within group (order by p_col) as p50,
approx_percentile(0.9, 0.01) within group (order by p_col) as p90,
approx_percentile(0.99, 0.01) within group (order by p_col) as p99
from t group by grp_col;
----
-982.5779489474152 -804.4614206837127 0 804.4614206837127 982.5779489474152

query I
select
percentile_disc(0.01) within group (order by p_col) as p01,
percentile_disc(0.1) within group (order by p_col) as p10,
percentile_disc(0.5) within group (order by p_col) as p50,
percentile_disc(0.9) within group (order by p_col) as p90,
percentile_disc(0.99) within group (order by p_col) as p99
from t group by grp_col;
----
-980 -800 0 800 980

statement ok
create materialized view m1 as
select
approx_percentile(0.01, 0.01) within group (order by p_col) as p01,
approx_percentile(0.1, 0.01) within group (order by p_col) as p10,
approx_percentile(0.5, 0.01) within group (order by p_col) as p50,
approx_percentile(0.9, 0.01) within group (order by p_col) as p90,
approx_percentile(0.99, 0.01) within group (order by p_col) as p99
from t group by grp_col;

query I
select * from m1;
----
-982.5779489474152 -804.4614206837127 0 804.4614206837127 982.5779489474152

# statement ok
# recover;

# wait recovery
# sleep 10s
kwannoel marked this conversation as resolved.
Show resolved Hide resolved

query I
select * from m1;
----
-982.5779489474152 -804.4614206837127 0 804.4614206837127 982.5779489474152

# Test state encode / decode
onlyif can-use-recover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which cases will we set the label `can-use-recover`?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only for e2e tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recover should not be used in cases where tests can execute in parallel.

statement ok
recover;
kwannoel marked this conversation as resolved.
Show resolved Hide resolved

onlyif can-use-recover
sleep 10s

query I
select * from m1;
----
-982.5779489474152 -804.4614206837127 0 804.4614206837127 982.5779489474152

# Test 0<x<1 values
statement ok
insert into t select 0.001, 1 from generate_series(1, 500);

statement ok
insert into t select 0.0001, 1 from generate_series(1, 501);

statement ok
flush;

query I
select * from m1;
----
-963.1209598593477 -699.3618972397041 0.00009999833511933609 699.3618972397041 963.1209598593477

query I
select
percentile_disc(0.01) within group (order by p_col) as p01,
percentile_disc(0.1) within group (order by p_col) as p10,
percentile_disc(0.5) within group (order by p_col) as p50,
percentile_disc(0.9) within group (order by p_col) as p90,
percentile_disc(0.99) within group (order by p_col) as p99
from t group by grp_col;
----
-970 -700 0.0001 700 970

statement ok
drop materialized view m1;

statement ok
drop table t;
196 changes: 180 additions & 16 deletions src/expr/impl/src/aggregate/approx_percentile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,56 +12,220 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;
use std::mem::size_of;
use std::ops::Range;

use risingwave_common::array::*;
use risingwave_common::row::Row;
use risingwave_common::types::*;
use risingwave_common_estimate_size::EstimateSize;
use risingwave_expr::aggregate::{AggCall, AggStateDyn, AggregateFunction, AggregateState};
use risingwave_expr::{build_aggregate, Result};

/// TODO(kwannoel): for single phase agg, we can actually support `UDDSketch`.
/// For two phase agg, we still use `DDSketch`.
/// Then we also need to store the `relative_error` of the sketch, so we can report it
/// in an internal table, if it changes.
#[build_aggregate("approx_percentile(float8) -> float8", state = "bytea")]
fn build(agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
    // Both direct args were unwrapped unconditionally in the original —
    // presumably the frontend guarantees they are literals; `expect` states
    // that invariant instead of a bare `unwrap`.
    let quantile: f64 = agg.direct_args[0]
        .literal()
        .map(|x| (*x.as_float64()).into())
        .expect("quantile direct arg of approx_percentile must be a literal");
    let relative_error: f64 = agg.direct_args[1]
        .literal()
        .map(|x| (*x.as_float64()).into())
        .expect("relative_error direct arg of approx_percentile must be a literal");
    // DDSketch-style log base derived from the requested relative error:
    // bucket boundaries grow geometrically by `base = (1 + e) / (1 - e)`.
    let base = (1.0 + relative_error) / (1.0 - relative_error);
    Ok(Box::new(ApproxPercentile { quantile, base }))
}

pub struct ApproxPercentile {
    /// Target quantile in `[0, 1]` (e.g. 0.5 for the median), read in `get_result`.
    quantile: f64,
    /// Log base of the exponential bucket boundaries, `(1 + e) / (1 - e)` for
    /// relative error `e`; used for bucketing in `add_datum` and for the
    /// bucket-midpoint reconstruction in `get_result`.
    base: f64,
}

type BucketCount = u64;
type BucketId = i32;
type Count = u64;

/// Sketch state: per-bucket counts of `|value|` on an exponential scale,
/// with positive and negative values kept in separate maps and exact zeros
/// counted separately (log of 0 has no bucket).
#[derive(Debug, Default)]
struct State {
    /// Total number of values currently aggregated (inserts minus deletes).
    count: BucketCount,
    /// Bucket id -> count for positive values.
    pos_buckets: BTreeMap<BucketId, Count>,
    /// Exact count of values equal to 0.0.
    zeros: Count,
    /// Bucket id -> count for negative values, keyed by the bucket of `|value|`.
    neg_buckets: BTreeMap<BucketId, Count>,
}

impl EstimateSize for State {
fn estimated_heap_size(&self) -> usize {
let count_size = 1;
let pos_buckets_size = self.pos_buckets.len() * 2;
let zero_bucket_size = size_of::<Count>();
let neg_buckets_size = self.pos_buckets.len() * 2;
kwannoel marked this conversation as resolved.
Show resolved Hide resolved
count_size + pos_buckets_size + zero_bucket_size + neg_buckets_size
}
}

impl AggStateDyn for State {}

impl ApproxPercentile {
    /// Applies one input datum (insert or delete) to the sketch state.
    ///
    /// NULL datums are skipped. Non-zero values are bucketed by
    /// `ceil(log_base(|value|))`; positive and negative values go to separate
    /// maps and exact zeros use the dedicated `zeros` counter. Inserts
    /// increment the matching bucket and the total count, deletes decrement.
    fn add_datum(&self, state: &mut State, op: Op, datum: DatumRef<'_>) {
        if let Some(value) = datum {
            let prim_value = value.into_float64().into_inner();
            // Split the sign off so both signs share one bucketing scheme on |value|.
            let (non_neg, abs_value) = if prim_value < 0.0 {
                (false, -prim_value)
            } else {
                (true, prim_value)
            };
            // For abs_value == 0.0 the log is -inf and the cast saturates, but
            // the result is never used on the zero path below.
            let bucket_id = abs_value.log(self.base).ceil() as BucketId;
            match op {
                Op::Delete | Op::UpdateDelete => {
                    // NOTE(review): decrementing a freshly-inserted 0 count would
                    // underflow the u64 — presumably upstream only retracts rows
                    // that were previously inserted; confirm that invariant.
                    if abs_value == 0.0 {
                        state.zeros -= 1;
                    } else if non_neg {
                        let count = state.pos_buckets.entry(bucket_id).or_insert(0);
                        *count -= 1;
                    } else {
                        let count = state.neg_buckets.entry(bucket_id).or_insert(0);
                        *count -= 1;
                    }
                    state.count -= 1;
                }
                Op::Insert | Op::UpdateInsert => {
                    if abs_value == 0.0 {
                        state.zeros += 1;
                    } else if non_neg {
                        let count = state.pos_buckets.entry(bucket_id).or_insert(0);
                        *count += 1;
                    } else {
                        let count = state.neg_buckets.entry(bucket_id).or_insert(0);
                        *count += 1;
                    }
                    state.count += 1;
                }
            }
        };
    }
}

#[async_trait::async_trait]
impl AggregateFunction for ApproxPercentile {
fn return_type(&self) -> DataType {
DataType::Float64
}

fn create_state(&self) -> Result<AggregateState> {
todo!()
Ok(AggregateState::Any(Box::<State>::default()))
}

async fn update(&self, _state: &mut AggregateState, _input: &StreamChunk) -> Result<()> {
todo!()
async fn update(&self, state: &mut AggregateState, input: &StreamChunk) -> Result<()> {
let state: &mut State = state.downcast_mut();
for (op, row) in input.rows() {
let datum = row.datum_at(0);
self.add_datum(state, op, datum);
}
Ok(())
}

async fn update_range(
&self,
_state: &mut AggregateState,
_input: &StreamChunk,
_range: Range<usize>,
state: &mut AggregateState,
input: &StreamChunk,
range: Range<usize>,
) -> Result<()> {
todo!()
let state = state.downcast_mut();
for (op, row) in input.rows_in(range) {
self.add_datum(state, op, row.datum_at(0));
}
Ok(())
}

// TODO(kwannoel): Instead of iterating over all buckets, we can maintain the
// approximate quantile bucket on the fly.
async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
let state = state.downcast_ref::<State>();
let quantile_count = (state.count as f64 * self.quantile) as u64;
let mut acc_count = 0;
for (bucket_id, count) in state.neg_buckets.iter().rev() {
acc_count += count;
if acc_count > quantile_count {
// approx value = -2 * y^i / (y + 1)
let approx_percentile = -2.0 * self.base.powi(*bucket_id) / (self.base + 1.0);
let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
return Ok(Datum::from(approx_percentile));
}
}
acc_count += state.zeros;
if acc_count > quantile_count {
return Ok(Datum::from(ScalarImpl::Float64(0.0.into())));
}
for (bucket_id, count) in &state.pos_buckets {
acc_count += count;
if acc_count > quantile_count {
// approx value = 2 * y^i / (y + 1)
let approx_percentile = 2.0 * self.base.powi(*bucket_id) / (self.base + 1.0);
let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
return Ok(Datum::from(approx_percentile));
}
}
return Ok(None);
}

fn encode_state(&self, state: &AggregateState) -> Result<Datum> {
let state = state.downcast_ref::<State>();
let mut encoded_state = Vec::with_capacity(state.estimated_heap_size());
encoded_state.extend_from_slice(&state.count.to_be_bytes());
encoded_state.extend_from_slice(&state.zeros.to_be_bytes());
let neg_buckets_size =
state.neg_buckets.len() * (size_of::<BucketId>() + size_of::<Count>());
encoded_state.extend_from_slice(&neg_buckets_size.to_be_bytes());
for (bucket_id, count) in &state.neg_buckets {
encoded_state.extend_from_slice(&bucket_id.to_be_bytes());
encoded_state.extend_from_slice(&count.to_be_bytes());
}
for (bucket_id, count) in &state.pos_buckets {
encoded_state.extend_from_slice(&bucket_id.to_be_bytes());
encoded_state.extend_from_slice(&count.to_be_bytes());
}
let encoded_scalar = ScalarImpl::Bytea(encoded_state.into());
Ok(Datum::from(encoded_scalar))
}

async fn get_result(&self, _state: &AggregateState) -> Result<Datum> {
todo!()
fn decode_state(&self, datum: Datum) -> Result<AggregateState> {
let mut state = State::default();
let Some(scalar_state) = datum else {
return Ok(AggregateState::Any(Box::new(state)));
};
let encoded_state: Box<[u8]> = scalar_state.into_bytea();
let mut cursor = 0;
state.count = u64::from_be_bytes(encoded_state[cursor..cursor + 8].try_into().unwrap());
cursor += 8;
kwannoel marked this conversation as resolved.
Show resolved Hide resolved
state.zeros = u64::from_be_bytes(encoded_state[cursor..cursor + 8].try_into().unwrap());
cursor += 8;
let neg_buckets_size =
usize::from_be_bytes(encoded_state[cursor..cursor + 8].try_into().unwrap());
let neg_buckets_end = cursor + neg_buckets_size;
cursor += 8;
while cursor < neg_buckets_end {
let bucket_id =
i32::from_be_bytes(encoded_state[cursor..cursor + 4].try_into().unwrap());
cursor += 4;
let count = u64::from_be_bytes(encoded_state[cursor..cursor + 8].try_into().unwrap());
cursor += 8;
state.neg_buckets.insert(bucket_id, count);
}
let pos_buckets_end = encoded_state.len();
while cursor < pos_buckets_end {
let bucket_id =
i32::from_be_bytes(encoded_state[cursor..cursor + 4].try_into().unwrap());
cursor += 4;
let count = u64::from_be_bytes(encoded_state[cursor..cursor + 8].try_into().unwrap());
cursor += 8;
state.pos_buckets.insert(bucket_id, count);
}
Ok(AggregateState::Any(Box::new(state)))
}
}
3 changes: 2 additions & 1 deletion src/frontend/src/optimizer/plan_node/generic/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,8 @@ impl<PlanRef: stream::StreamPlanRef> Agg<PlanRef> {
agg_kinds::single_value_state_iff_in_append_only!() if in_append_only => {
AggCallState::Value
}
agg_kinds::single_value_state!() => AggCallState::Value,
agg_kinds::single_value_state!()
| AggKind::Builtin(PbAggKind::ApproxPercentile) => AggCallState::Value,
kwannoel marked this conversation as resolved.
Show resolved Hide resolved
AggKind::Builtin(
PbAggKind::Min
| PbAggKind::Max
Expand Down
Loading