diff --git a/rust/src/delta_datafusion/expr.rs b/rust/src/delta_datafusion/expr.rs new file mode 100644 index 0000000000..d60fe6666c --- /dev/null +++ b/rust/src/delta_datafusion/expr.rs @@ -0,0 +1,505 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// This product includes software from the Datafusion project (Apache 2.0) +// https://github.com/apache/arrow-datafusion +// Display functions and required macros were pulled from https://github.com/apache/arrow-datafusion/blob/ddb95497e2792015d5a5998eec79aac8d37df1eb/datafusion/expr/src/expr.rs + +//! Utility functions for Datafusion's Expressions + +use std::fmt::{self, Display, Formatter, Write}; + +use datafusion_common::ScalarValue; +use datafusion_expr::{ + expr::{InList, ScalarUDF}, + Between, BinaryExpr, Expr, Like, +}; +use sqlparser::ast::escape_quoted_string; + +use crate::DeltaTableError; + +struct SqlFormat<'a> { + expr: &'a Expr, +} + +macro_rules! expr_vec_fmt { + ( $ARRAY:expr ) => {{ + $ARRAY + .iter() + .map(|e| format!("{}", SqlFormat { expr: e })) + .collect::>() + .join(", ") + }}; +} + +struct BinaryExprFormat<'a> { + expr: &'a BinaryExpr, +} + +impl<'a> Display for BinaryExprFormat<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Put parentheses around child binary expressions so that we can see the difference + // between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed, + // based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are + // equivalent and the parentheses are not necessary. + + fn write_child(f: &mut Formatter<'_>, expr: &Expr, precedence: u8) -> fmt::Result { + match expr { + Expr::BinaryExpr(child) => { + let p = child.op.precedence(); + if p == 0 || p < precedence { + write!(f, "({})", BinaryExprFormat { expr: child })?; + } else { + write!(f, "{}", BinaryExprFormat { expr: child })?; + } + } + _ => write!(f, "{}", SqlFormat { expr })?, + } + Ok(()) + } + + let precedence = self.expr.op.precedence(); + write_child(f, self.expr.left.as_ref(), precedence)?; + write!(f, " {} ", self.expr.op)?; + write_child(f, self.expr.right.as_ref(), precedence) + } +} + +impl<'a> Display for SqlFormat<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.expr { + Expr::Column(c) => write!(f, "{c}"), + Expr::Literal(v) => write!(f, "{}", ScalarValueFormat { scalar: v }), + Expr::Case(case) => { + write!(f, "CASE ")?; + if let Some(e) = &case.expr { + write!(f, "{} ", SqlFormat { expr: e })?; + } + for (w, t) in &case.when_then_expr { + write!( + f, + "WHEN {} THEN {} ", + SqlFormat { expr: w }, + SqlFormat { expr: t } + )?; + } + if let Some(e) = &case.else_expr { + write!(f, "ELSE {} ", SqlFormat { expr: e })?; + } + write!(f, "END") + } + Expr::Not(expr) => write!(f, "NOT {}", SqlFormat { expr }), + Expr::Negative(expr) => write!(f, "(- {})", SqlFormat { expr }), + Expr::IsNull(expr) => write!(f, "{} IS NULL", SqlFormat { expr }), + Expr::IsNotNull(expr) => write!(f, "{} IS NOT NULL", SqlFormat { expr }), + Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SqlFormat { expr }), + Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SqlFormat { expr }), + Expr::IsUnknown(expr) => write!(f, "{} IS UNKNOWN", SqlFormat { expr }), + Expr::IsNotTrue(expr) => write!(f, "{} IS NOT TRUE", SqlFormat { expr }), + Expr::IsNotFalse(expr) => write!(f, "{} IS NOT FALSE", SqlFormat { expr }), + Expr::IsNotUnknown(expr) => write!(f, "{} IS NOT UNKNOWN", SqlFormat { expr }), + Expr::BinaryExpr(expr) => write!(f, "{}", BinaryExprFormat { expr }), + Expr::ScalarFunction(func) => fmt_function(f, &func.fun.to_string(), false, &func.args), + Expr::ScalarUDF(ScalarUDF { fun, args }) => fmt_function(f, &fun.name, false, args), + Expr::Between(Between { + expr, + negated, + low, + high, + }) => { + if *negated { + write!( + f, + "{} NOT BETWEEN {} AND {}", + SqlFormat { expr }, + SqlFormat { expr: low }, + SqlFormat { expr: high } + ) + } else { + write!( + f, + "{} BETWEEN {} AND {}", + SqlFormat { expr }, + SqlFormat { expr: low }, + SqlFormat { expr: high } + ) + } + } + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => { + write!(f, "{}", SqlFormat { expr })?; + let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" }; + if *negated { + write!(f, " NOT")?; + } + if let Some(char) = escape_char { + write!( + f, + " {op_name} {} ESCAPE '{char}'", + SqlFormat { expr: pattern } + ) + } else { + write!(f, " {op_name} {}", SqlFormat { expr: pattern }) + } + } + Expr::SimilarTo(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive: _, + }) => { + write!(f, "{expr}")?; + if *negated { + write!(f, " NOT")?; + } + if let Some(char) = escape_char { + write!(f, " SIMILAR TO {pattern} ESCAPE '{char}'") + } else { + write!(f, " SIMILAR TO {pattern}") + } + } + Expr::InList(InList { + expr, + list, + negated, + }) => { + if *negated { + write!(f, "{expr} NOT IN ({})", expr_vec_fmt!(list)) + } else { + write!(f, "{expr} IN ({})", expr_vec_fmt!(list)) + } + } + _ => Err(fmt::Error), + } + } +} + +/// Format an `Expr` to a parsable SQL expression +pub fn fmt_expr_to_sql(expr: &Expr) -> Result { + let mut s = String::new(); + write!(&mut s, "{}", SqlFormat { expr }).map_err(|_| { + DeltaTableError::Generic("Unable to convert expression to string".to_owned()) + })?; + Ok(s) +} + +fn fmt_function(f: &mut fmt::Formatter, fun: &str, distinct: bool, args: &[Expr]) -> fmt::Result { + let args: Vec = args + .iter() + .map(|arg| format!("{}", SqlFormat { expr: arg })) + .collect(); + + let distinct_str = match distinct { + true => "DISTINCT ", + false => "", + }; + write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) +} + +macro_rules! format_option { + ($F:expr, $EXPR:expr) => {{ + match $EXPR { + Some(e) => write!($F, "{e}"), + None => write!($F, "NULL"), + } + }}; +} + +struct ScalarValueFormat<'a> { + scalar: &'a ScalarValue, +} + +impl<'a> fmt::Display for ScalarValueFormat<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.scalar { + ScalarValue::Boolean(e) => format_option!(f, e)?, + ScalarValue::Float32(e) => format_option!(f, e)?, + ScalarValue::Float64(e) => format_option!(f, e)?, + ScalarValue::Int8(e) => format_option!(f, e)?, + ScalarValue::Int16(e) => format_option!(f, e)?, + ScalarValue::Int32(e) => format_option!(f, e)?, + ScalarValue::Int64(e) => format_option!(f, e)?, + ScalarValue::UInt8(e) => format_option!(f, e)?, + ScalarValue::UInt16(e) => format_option!(f, e)?, + ScalarValue::UInt32(e) => format_option!(f, e)?, + ScalarValue::UInt64(e) => format_option!(f, e)?, + ScalarValue::Utf8(e) | ScalarValue::LargeUtf8(e) => match e { + Some(e) => write!(f, "'{}'", escape_quoted_string(e, '\''))?, + None => write!(f, "NULL")?, + }, + ScalarValue::Binary(e) + | ScalarValue::FixedSizeBinary(_, e) + | ScalarValue::LargeBinary(e) => match e { + Some(l) => write!( + f, + "decode('{}', 'hex')", + l.iter() + .map(|v| format!("{v:02x}")) + .collect::>() + .join("") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::Null => write!(f, "NULL")?, + _ => return Err(fmt::Error), + }; + Ok(()) + } +} + +#[cfg(test)] +mod test { + use std::collections::HashMap; + + use datafusion::prelude::SessionContext; + use datafusion_common::{DFSchema, ScalarValue}; + use datafusion_expr::{col, decode, lit, substring, Expr, ExprSchemable}; + + use crate::{DeltaOps, DeltaTable, Schema, SchemaDataType, SchemaField}; + + use super::fmt_expr_to_sql; + + struct ParseTest { + expr: Expr, + expected: String, + override_expected_expr: Option, + } + + macro_rules! simple { + ( $EXPR:expr, $STR:expr ) => {{ + ParseTest { + expr: $EXPR, + expected: $STR, + override_expected_expr: None, + } + }}; + } + + async fn setup_table() -> DeltaTable { + let schema = Schema::new(vec![ + SchemaField::new( + "id".to_string(), + SchemaDataType::primitive("string".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "value".to_string(), + SchemaDataType::primitive("integer".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "value2".to_string(), + SchemaDataType::primitive("integer".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "modified".to_string(), + SchemaDataType::primitive("string".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "active".to_string(), + SchemaDataType::primitive("boolean".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "money".to_string(), + SchemaDataType::primitive("decimal(12,2)".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "_date".to_string(), + SchemaDataType::primitive("date".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "_timestamp".to_string(), + SchemaDataType::primitive("timestamp".to_string()), + true, + HashMap::new(), + ), + SchemaField::new( + "_binary".to_string(), + SchemaDataType::primitive("binary".to_string()), + true, + HashMap::new(), + ), + ]); + + let table = DeltaOps::new_in_memory() + .create() + .with_columns(schema.get_fields().clone()) + .await + .unwrap(); + assert_eq!(table.version(), 0); + table + } + + #[tokio::test] + async fn test_expr_sql() { + let table = setup_table().await; + + // String expression that we output must be parsable for conflict resolution. + let tests = vec![ + simple!(col("value").eq(lit(3_i64)), "value = 3".to_string()), + simple!(col("active").is_true(), "active IS TRUE".to_string()), + simple!(col("active"), "active".to_string()), + simple!(col("active").eq(lit(true)), "active = true".to_string()), + simple!(col("active").is_null(), "active IS NULL".to_string()), + simple!( + col("modified").eq(lit("2021-02-03")), + "modified = '2021-02-03'".to_string() + ), + simple!( + col("modified").eq(lit("'validate ' escapi\\ng'")), + "modified = '''validate '' escapi\\ng'''".to_string() + ), + simple!(col("money").gt(lit(0.10)), "money > 0.1".to_string()), + ParseTest { + expr: col("_binary").eq(lit(ScalarValue::Binary(Some(vec![0xAA, 0x00, 0xFF])))), + expected: "_binary = decode('aa00ff', 'hex')".to_string(), + override_expected_expr: Some(col("_binary").eq(decode(lit("aa00ff"), lit("hex")))), + }, + simple!( + col("value").between(lit(20_i64), lit(30_i64)), + "value BETWEEN 20 AND 30".to_string() + ), + simple!( + col("value").not_between(lit(20_i64), lit(30_i64)), + "value NOT BETWEEN 20 AND 30".to_string() + ), + simple!( + col("modified").like(lit("abc%")), + "modified LIKE 'abc%'".to_string() + ), + simple!( + col("modified").not_like(lit("abc%")), + "modified NOT LIKE 'abc%'".to_string() + ), + simple!( + (((col("value") * lit(2_i64) + col("value2")) / lit(3_i64)) - col("value")) + .gt(lit(0_i64)), + "(value * 2 + value2) / 3 - value > 0".to_string() + ), + simple!( + col("modified").in_list(vec![lit("a"), lit("c")], false), + "modified IN ('a', 'c')".to_string() + ), + simple!( + col("modified").in_list(vec![lit("a"), lit("c")], true), + "modified NOT IN ('a', 'c')".to_string() + ), + // Validate order of operations is maintained + simple!( + col("modified") + .eq(lit("value")) + .and(col("value").eq(lit(1_i64))) + .or(col("modified") + .eq(lit("value2")) + .and(col("value").gt(lit(1_i64)))), + "modified = 'value' AND value = 1 OR modified = 'value2' AND value > 1".to_string() + ), + simple!( + col("modified") + .eq(lit("value")) + .or(col("value").eq(lit(1_i64))) + .and( + col("modified") + .eq(lit("value2")) + .or(col("value").gt(lit(1_i64))), + ), + "(modified = 'value' OR value = 1) AND (modified = 'value2' OR value > 1)" + .to_string() + ), + // Validate functions are correctly parsed + simple!( + substring(col("modified"), lit(0_i64), lit(4_i64)).eq(lit("2021")), + "substr(modified, 0, 4) = '2021'".to_string() + ), + ]; + + let session = SessionContext::new(); + + for test in tests { + let actual = fmt_expr_to_sql(&test.expr).unwrap(); + assert_eq!(test.expected, actual); + + let actual_expr = table + .state + .parse_predicate_expression(actual, &session.state()) + .unwrap(); + + match test.override_expected_expr { + None => assert_eq!(test.expr, actual_expr), + Some(expr) => assert_eq!(expr, actual_expr), + } + } + + let unsupported_types = vec![ + /* TODO: Determine proper way to display decimal values in an sql expression*/ + simple!( + col("money").gt(lit(ScalarValue::Decimal128(Some(100), 12, 2))), + "money > 0.1".to_string() + ), + simple!( + col("_timestamp").gt(lit(ScalarValue::TimestampMillisecond(Some(100), None))), + "".to_string() + ), + simple!( + col("_timestamp").gt(lit(ScalarValue::TimestampMillisecond( + Some(100), + Some("UTC".into()) + ))), + "".to_string() + ), + simple!( + col("value") + .cast_to::( + &arrow_schema::DataType::Utf8, + &table + .state + .input_schema() + .unwrap() + .as_ref() + .to_owned() + .try_into() + .unwrap() + ) + .unwrap() + .eq(lit("1")), + "CAST(value as STRING) = '1'".to_string() + ), + ]; + + for test in unsupported_types { + assert!(fmt_expr_to_sql(&test.expr).is_err()); + } + } +} diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion/mod.rs similarity index 99% rename from rust/src/delta_datafusion.rs rename to rust/src/delta_datafusion/mod.rs index e542413cfd..166996dddd 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion/mod.rs @@ -76,6 +76,8 @@ use crate::{open_table, open_table_with_storage_options, DeltaTable, Invariant, const PATH_COLUMN: &str = "__delta_rs_path"; +pub mod expr; + impl From for DataFusionError { fn from(err: DeltaTableError) -> Self { match err { diff --git a/rust/src/operations/delete.rs b/rust/src/operations/delete.rs index d7f908680d..f07c92e442 100644 --- a/rust/src/operations/delete.rs +++ b/rust/src/operations/delete.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; +use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::protocol::{Action, Add, Remove}; use datafusion::execution::context::{SessionContext, SessionState}; use datafusion::physical_expr::create_physical_expr; @@ -263,7 +264,7 @@ async fn execute( // Do not make a commit when there are zero updates to the state if !actions.is_empty() { let operation = DeltaOperation::Delete { - predicate: Some(predicate.canonical_name()), + predicate: Some(fmt_expr_to_sql(&predicate)?), }; version = commit( object_store.as_ref(), @@ -298,7 +299,9 @@ impl std::future::IntoFuture for DeleteBuilder { let predicate = match this.predicate { Some(predicate) => match predicate { Expression::DataFusion(expr) => Some(expr), - Expression::String(s) => Some(this.snapshot.parse_predicate_expression(s)?), + Expression::String(s) => { + Some(this.snapshot.parse_predicate_expression(s, &state)?) + } }, None => None, }; @@ -335,6 +338,7 @@ mod tests { use arrow::record_batch::RecordBatch; use datafusion::assert_batches_sorted_eq; use datafusion::prelude::*; + use serde_json::json; use std::sync::Arc; async fn setup_table(partitions: Option>) -> DeltaTable { @@ -456,7 +460,7 @@ mod tests { assert_eq!(table.version(), 2); assert_eq!(table.get_file_uris().count(), 2); - let (table, metrics) = DeltaOps(table) + let (mut table, metrics) = DeltaOps(table) .delete() .with_predicate(col("value").eq(lit(1))) .await @@ -470,6 +474,11 @@ mod tests { assert_eq!(metrics.num_deleted_rows, Some(1)); assert_eq!(metrics.num_copied_rows, Some(3)); + let commit_info = table.history(None).await.unwrap(); + let last_commit = &commit_info[commit_info.len() - 1]; + let parameters = last_commit.operation_parameters.clone().unwrap(); + assert_eq!(parameters["predicate"], json!("value = 1")); + let expected = vec![ "+----+-------+------------+", "| id | value | modified |", diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs index d088fbd3b7..d52dd26819 100644 --- a/rust/src/operations/merge.rs +++ b/rust/src/operations/merge.rs @@ -62,6 +62,7 @@ use serde_json::{Map, Value}; use super::datafusion_utils::{into_expr, maybe_into_expr, Expression}; use super::transaction::commit; +use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::delta_datafusion::{parquet_scan_from_actions, register_store}; use crate::operations::datafusion_utils::MetricObserverExec; use crate::operations::write::write_execution_plan; @@ -171,6 +172,7 @@ impl MergeBuilder { let builder = builder(UpdateBuilder::default()); let op = MergeOperation::try_new( &self.snapshot, + &self.state.as_ref(), builder.predicate, builder.updates, OperationType::Update, @@ -204,6 +206,7 @@ impl MergeBuilder { let builder = builder(DeleteBuilder::default()); let op = MergeOperation::try_new( &self.snapshot, + &self.state.as_ref(), builder.predicate, HashMap::default(), OperationType::Delete, @@ -240,6 +243,7 @@ impl MergeBuilder { let builder = builder(InsertBuilder::default()); let op = MergeOperation::try_new( &self.snapshot, + &self.state.as_ref(), builder.predicate, builder.set, OperationType::Insert, @@ -278,6 +282,7 @@ impl MergeBuilder { let builder = builder(UpdateBuilder::default()); let op = MergeOperation::try_new( &self.snapshot, + &self.state.as_ref(), builder.predicate, builder.updates, OperationType::Update, @@ -311,6 +316,7 @@ impl MergeBuilder { let builder = builder(DeleteBuilder::default()); let op = MergeOperation::try_new( &self.snapshot, + &self.state.as_ref(), builder.predicate, HashMap::default(), OperationType::Delete, @@ -448,15 +454,21 @@ struct MergeOperation { impl MergeOperation { pub fn try_new( snapshot: &DeltaTableState, + state: &Option<&SessionState>, predicate: Option, operations: HashMap, r#type: OperationType, ) -> DeltaResult { - let predicate = maybe_into_expr(predicate, snapshot)?; + let context = SessionContext::new(); + let mut s = &context.state(); + if let Some(df_state) = state { + s = df_state; + } + let predicate = maybe_into_expr(predicate, snapshot, s)?; let mut _operations = HashMap::new(); for (column, expr) in operations { - _operations.insert(column, into_expr(expr, snapshot)?); + _operations.insert(column, into_expr(expr, snapshot, s)?); } Ok(MergeOperation { @@ -518,7 +530,7 @@ async fn execute( let predicate = match predicate { Expression::DataFusion(expr) => expr, - Expression::String(s) => snapshot.parse_predicate_expression(s)?, + Expression::String(s) => snapshot.parse_predicate_expression(s, &state)?, }; let schema = snapshot.input_schema()?; @@ -675,7 +687,10 @@ async fn execute( }; let action_type = action_type.to_string(); - let predicate = op.predicate.map(|expr| expr.display_name().unwrap()); + let predicate = op + .predicate + .map(|expr| fmt_expr_to_sql(&expr)) + .transpose()?; predicates.push(MergePredicate { action_type, @@ -1035,7 +1050,7 @@ async fn execute( // Do not make a commit when there are zero updates to the state if !actions.is_empty() { let operation = DeltaOperation::Merge { - predicate: Some(predicate.canonical_name()), + predicate: Some(fmt_expr_to_sql(&predicate)?), matched_predicates: match_operations, not_matched_predicates: not_match_target_operations, not_matched_by_source_predicates: not_match_source_operations, @@ -1222,10 +1237,9 @@ mod tests { parameters["notMatchedPredicates"], json!(r#"[{"actionType":"insert"}]"#) ); - // Todo: Expected this predicate to actually be 'value = 1'. Predicate should contain a valid sql expression assert_eq!( parameters["notMatchedBySourcePredicates"], - json!(r#"[{"actionType":"update","predicate":"value = Int32(1)"}]"#) + json!(r#"[{"actionType":"update","predicate":"value = 1"}]"#) ); let expected = vec![ @@ -1447,7 +1461,7 @@ mod tests { assert_eq!(parameters["predicate"], json!("id = source.id")); assert_eq!( parameters["matchedPredicates"], - json!(r#"[{"actionType":"delete","predicate":"source.value <= Int32(10)"}]"#) + json!(r#"[{"actionType":"delete","predicate":"source.value <= 10"}]"#) ); let expected = vec![ @@ -1579,7 +1593,7 @@ mod tests { assert_eq!(parameters["predicate"], json!("id = source.id")); assert_eq!( parameters["notMatchedBySourcePredicates"], - json!(r#"[{"actionType":"delete","predicate":"modified > Utf8(\"2021-02-01\")"}]"#) + json!(r#"[{"actionType":"delete","predicate":"modified > '2021-02-01'"}]"#) ); let expected = vec![ diff --git a/rust/src/operations/mod.rs b/rust/src/operations/mod.rs index 7b6cb27ace..c07b81438b 100644 --- a/rust/src/operations/mod.rs +++ b/rust/src/operations/mod.rs @@ -205,6 +205,7 @@ mod datafusion_utils { use arrow_schema::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::Result as DataFusionResult; + use datafusion::execution::context::SessionState; use datafusion::physical_plan::DisplayAs; use datafusion::physical_plan::{ metrics::{ExecutionPlanMetricsSet, MetricsSet}, @@ -240,19 +241,24 @@ mod datafusion_utils { } } - pub(crate) fn into_expr(expr: Expression, snapshot: &DeltaTableState) -> DeltaResult { + pub(crate) fn into_expr( + expr: Expression, + snapshot: &DeltaTableState, + df_state: &SessionState, + ) -> DeltaResult { match expr { Expression::DataFusion(expr) => Ok(expr), - Expression::String(s) => snapshot.parse_predicate_expression(s), + Expression::String(s) => snapshot.parse_predicate_expression(s, df_state), } } pub(crate) fn maybe_into_expr( expr: Option, snapshot: &DeltaTableState, + df_state: &SessionState, ) -> DeltaResult> { Ok(match expr { - Some(predicate) => Some(into_expr(predicate, snapshot)?), + Some(predicate) => Some(into_expr(predicate, snapshot, df_state)?), None => None, }) } diff --git a/rust/src/operations/transaction/conflict_checker.rs b/rust/src/operations/transaction/conflict_checker.rs index d75e401def..d7a9d3fb86 100644 --- a/rust/src/operations/transaction/conflict_checker.rs +++ b/rust/src/operations/transaction/conflict_checker.rs @@ -114,8 +114,11 @@ impl<'a> TransactionInfo<'a> { actions: &'a Vec, read_whole_table: bool, ) -> DeltaResult { + use datafusion::prelude::SessionContext; + + let session = SessionContext::new(); let read_predicates = read_predicates - .map(|pred| read_snapshot.parse_predicate_expression(pred)) + .map(|pred| read_snapshot.parse_predicate_expression(pred, &session.state())) .transpose()?; Ok(Self { txn_id: "".into(), diff --git a/rust/src/operations/transaction/state.rs b/rust/src/operations/transaction/state.rs index 6fe1d65aee..5924609fb7 100644 --- a/rust/src/operations/transaction/state.rs +++ b/rust/src/operations/transaction/state.rs @@ -5,6 +5,7 @@ use arrow::datatypes::{ DataType, Field as ArrowField, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, }; use datafusion::datasource::physical_plan::wrap_partition_type_in_dict; +use datafusion::execution::context::SessionState; use datafusion::optimizer::utils::conjunction; use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; use datafusion_common::config::ConfigOptions; @@ -104,7 +105,11 @@ impl DeltaTableState { } /// Parse an expression string into a datafusion [`Expr`] - pub fn parse_predicate_expression(&self, expr: impl AsRef) -> DeltaResult { + pub fn parse_predicate_expression( + &self, + expr: impl AsRef, + df_state: &SessionState, + ) -> DeltaResult { let dialect = &GenericDialect {}; let mut tokenizer = Tokenizer::new(dialect, expr.as_ref()); let tokens = tokenizer @@ -121,7 +126,7 @@ impl DeltaTableState { // TODO should we add the table name as qualifier when available? let df_schema = DFSchema::try_from_qualified_schema("", self.arrow_schema()?.as_ref())?; - let context_provider = DummyContextProvider::default(); + let context_provider = DeltaContextProvider { state: df_state }; let sql_to_rel = SqlToRel::new(&context_provider); Ok(sql_to_rel.sql_to_expr(sql, &df_schema, &mut Default::default())?) @@ -342,34 +347,33 @@ impl PruningStatistics for DeltaTableState { } } -#[derive(Default)] -struct DummyContextProvider { - options: ConfigOptions, +pub(crate) struct DeltaContextProvider<'a> { + state: &'a SessionState, } -impl ContextProvider for DummyContextProvider { +impl<'a> ContextProvider for DeltaContextProvider<'a> { fn get_table_provider(&self, _name: TableReference) -> DFResult> { unimplemented!() } - fn get_function_meta(&self, _name: &str) -> Option> { - unimplemented!() + fn get_function_meta(&self, name: &str) -> Option> { + self.state.scalar_functions().get(name).cloned() } - fn get_aggregate_meta(&self, _name: &str) -> Option> { - unimplemented!() + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions().get(name).cloned() } - fn get_variable_type(&self, _: &[String]) -> Option { + fn get_variable_type(&self, _var: &[String]) -> Option { unimplemented!() } fn options(&self) -> &ConfigOptions { - &self.options + self.state.config_options() } - fn get_window_meta(&self, _name: &str) -> Option> { - unimplemented!() + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() } } @@ -377,24 +381,29 @@ impl ContextProvider for DummyContextProvider { mod tests { use super::*; use crate::operations::transaction::test_utils::{create_add_action, init_table_actions}; + use datafusion::prelude::SessionContext; use datafusion_expr::{col, lit}; #[test] fn test_parse_predicate_expression() { - let state = DeltaTableState::from_actions(init_table_actions(), 0).unwrap(); + let snapshot = DeltaTableState::from_actions(init_table_actions(), 0).unwrap(); + let session = SessionContext::new(); + let state = session.state(); // parses simple expression - let parsed = state.parse_predicate_expression("value > 10").unwrap(); + let parsed = snapshot + .parse_predicate_expression("value > 10", &state) + .unwrap(); let expected = col("value").gt(lit::(10)); assert_eq!(parsed, expected); // fails for unknown column - let parsed = state.parse_predicate_expression("non_existent > 10"); + let parsed = snapshot.parse_predicate_expression("non_existent > 10", &state); assert!(parsed.is_err()); // parses complex expression - let parsed = state - .parse_predicate_expression("value > 10 OR value <= 0") + let parsed = snapshot + .parse_predicate_expression("value > 10 OR value <= 0", &state) .unwrap(); let expected = col("value") .gt(lit::(10)) diff --git a/rust/src/operations/update.rs b/rust/src/operations/update.rs index b030bc5644..3891c04fd9 100644 --- a/rust/src/operations/update.rs +++ b/rust/src/operations/update.rs @@ -43,7 +43,9 @@ use parquet::file::properties::WriterProperties; use serde_json::{Map, Value}; use crate::{ - delta_datafusion::{find_files, parquet_scan_from_actions, register_store}, + delta_datafusion::{ + expr::fmt_expr_to_sql, find_files, parquet_scan_from_actions, register_store, + }, protocol::{Action, DeltaOperation, Remove}, storage::{DeltaObjectStore, ObjectStoreRef}, table::state::DeltaTableState, @@ -194,7 +196,7 @@ async fn execute( let predicate = match predicate { Some(predicate) => match predicate { Expression::DataFusion(expr) => Some(expr), - Expression::String(s) => Some(snapshot.parse_predicate_expression(s)?), + Expression::String(s) => Some(snapshot.parse_predicate_expression(s, &state)?), }, None => None, }; @@ -203,7 +205,9 @@ async fn execute( .into_iter() .map(|(key, expr)| match expr { Expression::DataFusion(e) => Ok((key, e)), - Expression::String(s) => snapshot.parse_predicate_expression(s).map(|e| (key, e)), + Expression::String(s) => snapshot + .parse_predicate_expression(s, &state) + .map(|e| (key, e)), }) .collect::, _>>()?; @@ -416,7 +420,7 @@ async fn execute( metrics.execution_time_ms = Instant::now().duration_since(exec_start).as_millis() as u64; let operation = DeltaOperation::Update { - predicate: Some(predicate.canonical_name()), + predicate: Some(fmt_expr_to_sql(&predicate)?), }; version = commit( object_store.as_ref(), @@ -481,6 +485,7 @@ mod tests { use arrow_array::Int32Array; use datafusion::assert_batches_sorted_eq; use datafusion::prelude::*; + use serde_json::json; use std::sync::Arc; async fn setup_table(partitions: Option>) -> DeltaTable { @@ -603,7 +608,7 @@ mod tests { assert_eq!(table.version(), 1); assert_eq!(table.get_file_uris().count(), 1); - let (table, metrics) = DeltaOps(table) + let (mut table, metrics) = DeltaOps(table) .update() .with_predicate(col("modified").eq(lit("2021-02-03"))) .with_update("modified", lit("2023-05-14")) @@ -617,6 +622,11 @@ mod tests { assert_eq!(metrics.num_updated_rows, 2); assert_eq!(metrics.num_copied_rows, 2); + let commit_info = table.history(None).await.unwrap(); + let last_commit = &commit_info[commit_info.len() - 1]; + let parameters = last_commit.operation_parameters.clone().unwrap(); + assert_eq!(parameters["predicate"], json!("modified = '2021-02-03'")); + let expected = vec![ "+----+-------+------------+", "| id | value | modified |",