Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: ensure predicates are parsable #1690

Merged
merged 7 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
505 changes: 505 additions & 0 deletions rust/src/delta_datafusion/expr.rs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ use crate::{open_table, open_table_with_storage_options, DeltaTable, Invariant,

const PATH_COLUMN: &str = "__delta_rs_path";

pub mod expr;

impl From<DeltaTableError> for DataFusionError {
fn from(err: DeltaTableError) -> Self {
match err {
Expand Down
15 changes: 12 additions & 3 deletions rust/src/operations/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};

use crate::delta_datafusion::expr::fmt_expr_to_sql;
use crate::protocol::{Action, Add, Remove};
use datafusion::execution::context::{SessionContext, SessionState};
use datafusion::physical_expr::create_physical_expr;
Expand Down Expand Up @@ -263,7 +264,7 @@ async fn execute(
// Do not make a commit when there are zero updates to the state
if !actions.is_empty() {
let operation = DeltaOperation::Delete {
predicate: Some(predicate.canonical_name()),
predicate: Some(fmt_expr_to_sql(&predicate)?),
};
version = commit(
object_store.as_ref(),
Expand Down Expand Up @@ -298,7 +299,9 @@ impl std::future::IntoFuture for DeleteBuilder {
let predicate = match this.predicate {
Some(predicate) => match predicate {
Expression::DataFusion(expr) => Some(expr),
Expression::String(s) => Some(this.snapshot.parse_predicate_expression(s)?),
Expression::String(s) => {
Some(this.snapshot.parse_predicate_expression(s, &state)?)
}
},
None => None,
};
Expand Down Expand Up @@ -335,6 +338,7 @@ mod tests {
use arrow::record_batch::RecordBatch;
use datafusion::assert_batches_sorted_eq;
use datafusion::prelude::*;
use serde_json::json;
use std::sync::Arc;

async fn setup_table(partitions: Option<Vec<&str>>) -> DeltaTable {
Expand Down Expand Up @@ -456,7 +460,7 @@ mod tests {
assert_eq!(table.version(), 2);
assert_eq!(table.get_file_uris().count(), 2);

let (table, metrics) = DeltaOps(table)
let (mut table, metrics) = DeltaOps(table)
.delete()
.with_predicate(col("value").eq(lit(1)))
.await
Expand All @@ -470,6 +474,11 @@ mod tests {
assert_eq!(metrics.num_deleted_rows, Some(1));
assert_eq!(metrics.num_copied_rows, Some(3));

let commit_info = table.history(None).await.unwrap();
let last_commit = &commit_info[commit_info.len() - 1];
let parameters = last_commit.operation_parameters.clone().unwrap();
assert_eq!(parameters["predicate"], json!("value = 1"));

let expected = vec![
"+----+-------+------------+",
"| id | value | modified |",
Expand Down
32 changes: 23 additions & 9 deletions rust/src/operations/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ use serde_json::{Map, Value};

use super::datafusion_utils::{into_expr, maybe_into_expr, Expression};
use super::transaction::commit;
use crate::delta_datafusion::expr::fmt_expr_to_sql;
use crate::delta_datafusion::{parquet_scan_from_actions, register_store};
use crate::operations::datafusion_utils::MetricObserverExec;
use crate::operations::write::write_execution_plan;
Expand Down Expand Up @@ -171,6 +172,7 @@ impl MergeBuilder {
let builder = builder(UpdateBuilder::default());
let op = MergeOperation::try_new(
&self.snapshot,
&self.state.as_ref(),
builder.predicate,
builder.updates,
OperationType::Update,
Expand Down Expand Up @@ -204,6 +206,7 @@ impl MergeBuilder {
let builder = builder(DeleteBuilder::default());
let op = MergeOperation::try_new(
&self.snapshot,
&self.state.as_ref(),
builder.predicate,
HashMap::default(),
OperationType::Delete,
Expand Down Expand Up @@ -240,6 +243,7 @@ impl MergeBuilder {
let builder = builder(InsertBuilder::default());
let op = MergeOperation::try_new(
&self.snapshot,
&self.state.as_ref(),
builder.predicate,
builder.set,
OperationType::Insert,
Expand Down Expand Up @@ -278,6 +282,7 @@ impl MergeBuilder {
let builder = builder(UpdateBuilder::default());
let op = MergeOperation::try_new(
&self.snapshot,
&self.state.as_ref(),
builder.predicate,
builder.updates,
OperationType::Update,
Expand Down Expand Up @@ -311,6 +316,7 @@ impl MergeBuilder {
let builder = builder(DeleteBuilder::default());
let op = MergeOperation::try_new(
&self.snapshot,
&self.state.as_ref(),
builder.predicate,
HashMap::default(),
OperationType::Delete,
Expand Down Expand Up @@ -448,15 +454,21 @@ struct MergeOperation {
impl MergeOperation {
pub fn try_new(
snapshot: &DeltaTableState,
state: &Option<&SessionState>,
predicate: Option<Expression>,
operations: HashMap<Column, Expression>,
r#type: OperationType,
) -> DeltaResult<Self> {
let predicate = maybe_into_expr(predicate, snapshot)?;
let context = SessionContext::new();
let mut s = &context.state();
if let Some(df_state) = state {
s = df_state;
}
let predicate = maybe_into_expr(predicate, snapshot, s)?;
let mut _operations = HashMap::new();

for (column, expr) in operations {
_operations.insert(column, into_expr(expr, snapshot)?);
_operations.insert(column, into_expr(expr, snapshot, s)?);
}

Ok(MergeOperation {
Expand Down Expand Up @@ -518,7 +530,7 @@ async fn execute(

let predicate = match predicate {
Expression::DataFusion(expr) => expr,
Expression::String(s) => snapshot.parse_predicate_expression(s)?,
Expression::String(s) => snapshot.parse_predicate_expression(s, &state)?,
};

let schema = snapshot.input_schema()?;
Expand Down Expand Up @@ -675,7 +687,10 @@ async fn execute(
};

let action_type = action_type.to_string();
let predicate = op.predicate.map(|expr| expr.display_name().unwrap());
let predicate = op
.predicate
.map(|expr| fmt_expr_to_sql(&expr))
.transpose()?;

predicates.push(MergePredicate {
action_type,
Expand Down Expand Up @@ -1035,7 +1050,7 @@ async fn execute(
// Do not make a commit when there are zero updates to the state
if !actions.is_empty() {
let operation = DeltaOperation::Merge {
predicate: Some(predicate.canonical_name()),
predicate: Some(fmt_expr_to_sql(&predicate)?),
matched_predicates: match_operations,
not_matched_predicates: not_match_target_operations,
not_matched_by_source_predicates: not_match_source_operations,
Expand Down Expand Up @@ -1222,10 +1237,9 @@ mod tests {
parameters["notMatchedPredicates"],
json!(r#"[{"actionType":"insert"}]"#)
);
// Todo: Expected this predicate to actually be 'value = 1'. Predicate should contain a valid sql expression
assert_eq!(
parameters["notMatchedBySourcePredicates"],
json!(r#"[{"actionType":"update","predicate":"value = Int32(1)"}]"#)
json!(r#"[{"actionType":"update","predicate":"value = 1"}]"#)
);

let expected = vec![
Expand Down Expand Up @@ -1447,7 +1461,7 @@ mod tests {
assert_eq!(parameters["predicate"], json!("id = source.id"));
assert_eq!(
parameters["matchedPredicates"],
json!(r#"[{"actionType":"delete","predicate":"source.value <= Int32(10)"}]"#)
json!(r#"[{"actionType":"delete","predicate":"source.value <= 10"}]"#)
);

let expected = vec![
Expand Down Expand Up @@ -1579,7 +1593,7 @@ mod tests {
assert_eq!(parameters["predicate"], json!("id = source.id"));
assert_eq!(
parameters["notMatchedBySourcePredicates"],
json!(r#"[{"actionType":"delete","predicate":"modified > Utf8(\"2021-02-01\")"}]"#)
json!(r#"[{"actionType":"delete","predicate":"modified > '2021-02-01'"}]"#)
);

let expected = vec![
Expand Down
12 changes: 9 additions & 3 deletions rust/src/operations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ mod datafusion_utils {
use arrow_schema::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result as DataFusionResult;
use datafusion::execution::context::SessionState;
use datafusion::physical_plan::DisplayAs;
use datafusion::physical_plan::{
metrics::{ExecutionPlanMetricsSet, MetricsSet},
Expand Down Expand Up @@ -240,19 +241,24 @@ mod datafusion_utils {
}
}

pub(crate) fn into_expr(expr: Expression, snapshot: &DeltaTableState) -> DeltaResult<Expr> {
pub(crate) fn into_expr(
expr: Expression,
snapshot: &DeltaTableState,
df_state: &SessionState,
) -> DeltaResult<Expr> {
match expr {
Expression::DataFusion(expr) => Ok(expr),
Expression::String(s) => snapshot.parse_predicate_expression(s),
Expression::String(s) => snapshot.parse_predicate_expression(s, df_state),
}
}

pub(crate) fn maybe_into_expr(
expr: Option<Expression>,
snapshot: &DeltaTableState,
df_state: &SessionState,
) -> DeltaResult<Option<Expr>> {
Ok(match expr {
Some(predicate) => Some(into_expr(predicate, snapshot)?),
Some(predicate) => Some(into_expr(predicate, snapshot, df_state)?),
None => None,
})
}
Expand Down
5 changes: 4 additions & 1 deletion rust/src/operations/transaction/conflict_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,11 @@ impl<'a> TransactionInfo<'a> {
actions: &'a Vec<Action>,
read_whole_table: bool,
) -> DeltaResult<Self> {
use datafusion::prelude::SessionContext;

let session = SessionContext::new();
let read_predicates = read_predicates
.map(|pred| read_snapshot.parse_predicate_expression(pred))
.map(|pred| read_snapshot.parse_predicate_expression(pred, &session.state()))
.transpose()?;
Ok(Self {
txn_id: "".into(),
Expand Down
47 changes: 28 additions & 19 deletions rust/src/operations/transaction/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use arrow::datatypes::{
DataType, Field as ArrowField, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef,
};
use datafusion::datasource::physical_plan::wrap_partition_type_in_dict;
use datafusion::execution::context::SessionState;
use datafusion::optimizer::utils::conjunction;
use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
use datafusion_common::config::ConfigOptions;
Expand Down Expand Up @@ -104,7 +105,11 @@ impl DeltaTableState {
}

/// Parse an expression string into a datafusion [`Expr`]
pub fn parse_predicate_expression(&self, expr: impl AsRef<str>) -> DeltaResult<Expr> {
pub fn parse_predicate_expression(
&self,
expr: impl AsRef<str>,
df_state: &SessionState,
) -> DeltaResult<Expr> {
let dialect = &GenericDialect {};
let mut tokenizer = Tokenizer::new(dialect, expr.as_ref());
let tokens = tokenizer
Expand All @@ -121,7 +126,7 @@ impl DeltaTableState {

// TODO should we add the table name as qualifier when available?
let df_schema = DFSchema::try_from_qualified_schema("", self.arrow_schema()?.as_ref())?;
let context_provider = DummyContextProvider::default();
let context_provider = DeltaContextProvider { state: df_state };
let sql_to_rel = SqlToRel::new(&context_provider);

Ok(sql_to_rel.sql_to_expr(sql, &df_schema, &mut Default::default())?)
Expand Down Expand Up @@ -342,59 +347,63 @@ impl PruningStatistics for DeltaTableState {
}
}

#[derive(Default)]
struct DummyContextProvider {
options: ConfigOptions,
pub(crate) struct DeltaContextProvider<'a> {
state: &'a SessionState,
}

impl ContextProvider for DummyContextProvider {
impl<'a> ContextProvider for DeltaContextProvider<'a> {
fn get_table_provider(&self, _name: TableReference) -> DFResult<Arc<dyn TableSource>> {
unimplemented!()
}

fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
unimplemented!()
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.state.scalar_functions().get(name).cloned()
}

fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
unimplemented!()
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.state.aggregate_functions().get(name).cloned()
}

fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
fn get_variable_type(&self, _var: &[String]) -> Option<DataType> {
unimplemented!()
}

fn options(&self) -> &ConfigOptions {
&self.options
self.state.config_options()
}

fn get_window_meta(&self, _name: &str) -> Option<Arc<datafusion_expr::WindowUDF>> {
unimplemented!()
fn get_window_meta(&self, name: &str) -> Option<Arc<datafusion_expr::WindowUDF>> {
self.state.window_functions().get(name).cloned()
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::operations::transaction::test_utils::{create_add_action, init_table_actions};
use datafusion::prelude::SessionContext;
use datafusion_expr::{col, lit};

#[test]
fn test_parse_predicate_expression() {
let state = DeltaTableState::from_actions(init_table_actions(), 0).unwrap();
let snapshot = DeltaTableState::from_actions(init_table_actions(), 0).unwrap();
let session = SessionContext::new();
let state = session.state();

// parses simple expression
let parsed = state.parse_predicate_expression("value > 10").unwrap();
let parsed = snapshot
.parse_predicate_expression("value > 10", &state)
.unwrap();
let expected = col("value").gt(lit::<i64>(10));
assert_eq!(parsed, expected);

// fails for unknown column
let parsed = state.parse_predicate_expression("non_existent > 10");
let parsed = snapshot.parse_predicate_expression("non_existent > 10", &state);
assert!(parsed.is_err());

// parses complex expression
let parsed = state
.parse_predicate_expression("value > 10 OR value <= 0")
let parsed = snapshot
.parse_predicate_expression("value > 10 OR value <= 0", &state)
.unwrap();
let expected = col("value")
.gt(lit::<i64>(10))
Expand Down
Loading
Loading