diff --git a/crates/deltalake-core/src/delta_datafusion/expr.rs b/crates/deltalake-core/src/delta_datafusion/expr.rs
index 5a4cda6bd7..347d093658 100644
--- a/crates/deltalake-core/src/delta_datafusion/expr.rs
+++ b/crates/deltalake-core/src/delta_datafusion/expr.rs
@@ -31,7 +31,7 @@ use datafusion::execution::context::SessionState;
 use datafusion_common::Result as DFResult;
 use datafusion_common::{config::ConfigOptions, DFSchema, Result, ScalarValue, TableReference};
 use datafusion_expr::{
-    expr::InList, AggregateUDF, Between, BinaryExpr, Cast, Expr, Like, TableSource,
+    expr::InList, AggregateUDF, Between, BinaryExpr, Cast, Expr, GetIndexedField, Like, TableSource,
 };
 use datafusion_sql::planner::{ContextProvider, SqlToRel};
 use sqlparser::ast::escape_quoted_string;
@@ -263,6 +263,28 @@ impl<'a> Display for SqlFormat<'a> {
                     write!(f, "{expr} IN ({})", expr_vec_fmt!(list))
                 }
             }
+            Expr::GetIndexedField(GetIndexedField { expr, field }) => match field {
+                datafusion_expr::GetFieldAccess::NamedStructField { name } => {
+                    write!(
+                        f,
+                        "{}[{}]",
+                        SqlFormat { expr },
+                        ScalarValueFormat { scalar: name }
+                    )
+                }
+                datafusion_expr::GetFieldAccess::ListIndex { key } => {
+                    write!(f, "{}[{}]", SqlFormat { expr }, SqlFormat { expr: key })
+                }
+                datafusion_expr::GetFieldAccess::ListRange { start, stop } => {
+                    write!(
+                        f,
+                        "{}[{}:{}]",
+                        SqlFormat { expr },
+                        SqlFormat { expr: start },
+                        SqlFormat { expr: stop }
+                    )
+                }
+            },
             _ => Err(fmt::Error),
         }
     }
@@ -346,10 +368,10 @@ mod test {
     use arrow_schema::DataType as ArrowDataType;
     use datafusion::prelude::SessionContext;
     use datafusion_common::{Column, DFSchema, ScalarValue};
-    use datafusion_expr::{col, decode, lit, substring, Cast, Expr, ExprSchemable};
+    use datafusion_expr::{cardinality, col, decode, lit, substring, Cast, Expr, ExprSchemable};
 
     use crate::delta_datafusion::DeltaSessionContext;
-    use crate::kernel::{DataType, PrimitiveType, StructField, StructType};
+    use crate::kernel::{ArrayType, DataType, PrimitiveType, StructField, StructType};
     use crate::{DeltaOps, DeltaTable};
 
     use super::fmt_expr_to_sql;
@@ -422,6 +444,30 @@ mod test {
                 DataType::Primitive(PrimitiveType::Binary),
                 true,
            ),
+            StructField::new(
+                "_struct".to_string(),
+                DataType::Struct(Box::new(StructType::new(vec![
+                    StructField::new("a", DataType::Primitive(PrimitiveType::Integer), true),
+                    StructField::new(
+                        "nested",
+                        DataType::Struct(Box::new(StructType::new(vec![StructField::new(
+                            "b",
+                            DataType::Primitive(PrimitiveType::Integer),
+                            true,
+                        )]))),
+                        true,
+                    ),
+                ]))),
+                true,
+            ),
+            StructField::new(
+                "_list".to_string(),
+                DataType::Array(Box::new(ArrayType::new(
+                    DataType::Primitive(PrimitiveType::Integer),
+                    true,
+                ))),
+                true,
+            ),
         ]);
 
         let table = DeltaOps::new_in_memory()
@@ -541,6 +587,22 @@ mod test {
                 .eq(lit("1")),
                 "arrow_cast(value, 'Utf8') = '1'".to_string()
             ),
+            simple!(
+                col("_struct").field("a").eq(lit(20_i64)),
+                "_struct['a'] = 20".to_string()
+            ),
+            simple!(
+                col("_struct").field("nested").field("b").eq(lit(20_i64)),
+                "_struct['nested']['b'] = 20".to_string()
+            ),
+            simple!(
+                col("_list").index(lit(1_i64)).eq(lit(20_i64)),
+                "_list[1] = 20".to_string()
+            ),
+            simple!(
+                cardinality(col("_list").range(col("value"), lit(10_i64))),
+                "cardinality(_list[value:10])".to_string()
+            ),
         ];
 
         let session: SessionContext = DeltaSessionContext::default().into();
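Note for reviewers: the three GetFieldAccess variants handled by the new arm correspond to the Expr::field, Expr::index, and Expr::range builders the added tests use. A minimal standalone sketch (only datafusion_expr is needed; the column names are illustrative) of the expressions the arm renders as _struct['a'], _list[1], and _list[1:10]:

use datafusion_expr::{col, lit, Expr};

fn main() {
    // NamedStructField { name }: rendered by the new arm as `_struct['a']`.
    let named: Expr = col("_struct").field("a");

    // ListIndex { key }: rendered as `_list[1]`.
    let indexed: Expr = col("_list").index(lit(1_i64));

    // ListRange { start, stop }: rendered as `_list[1:10]`.
    let sliced: Expr = col("_list").range(lit(1_i64), lit(10_i64));

    println!("{named:?}\n{indexed:?}\n{sliced:?}");
}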
diff --git a/crates/deltalake-core/src/operations/cast.rs b/crates/deltalake-core/src/operations/cast.rs
index e697c06d54..6e77552286 100644
--- a/crates/deltalake-core/src/operations/cast.rs
+++ b/crates/deltalake-core/src/operations/cast.rs
@@ -8,29 +8,27 @@ use std::sync::Arc;
 
 use crate::DeltaResult;
 
-fn cast_record_batch_columns(
-    batch: &RecordBatch,
+fn cast_struct(
+    struct_array: &StructArray,
     fields: &Fields,
     cast_options: &CastOptions,
 ) -> Result<Vec<ArrayRef>, arrow_schema::ArrowError> {
     fields
         .iter()
-        .map(|f| {
-            let col = batch.column_by_name(f.name()).unwrap();
-
+        .map(|field| {
+            let col = struct_array.column_by_name(field.name()).unwrap();
             if let (DataType::Struct(_), DataType::Struct(child_fields)) =
-                (col.data_type(), f.data_type())
+                (col.data_type(), field.data_type())
             {
-                let child_batch = RecordBatch::from(StructArray::from(col.into_data()));
-                let child_columns =
-                    cast_record_batch_columns(&child_batch, child_fields, cast_options)?;
+                let child_struct = StructArray::from(col.into_data());
+                let s = cast_struct(&child_struct, child_fields, cast_options)?;
                 Ok(Arc::new(StructArray::new(
                     child_fields.clone(),
-                    child_columns.clone(),
-                    None,
+                    s,
+                    child_struct.nulls().map(ToOwned::to_owned),
                 )) as ArrayRef)
-            } else if is_cast_required(col.data_type(), f.data_type()) {
-                cast_with_options(col, f.data_type(), cast_options)
+            } else if is_cast_required(col.data_type(), field.data_type()) {
+                cast_with_options(col, field.data_type(), cast_options)
             } else {
                 Ok(col.clone())
             }
@@ -59,7 +57,13 @@ pub fn cast_record_batch(
         ..Default::default()
     };
 
-    let columns = cast_record_batch_columns(batch, target_schema.fields(), &cast_options)?;
+    let s = StructArray::new(
+        batch.schema().as_ref().to_owned().fields,
+        batch.columns().to_owned(),
+        None,
+    );
+
+    let columns = cast_struct(&s, target_schema.fields(), &cast_options)?;
     Ok(RecordBatch::try_new(target_schema, columns)?)
 }
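The None -> child_struct.nulls() change above is the substance of the cast fix: the old code rebuilt nested StructArrays with no validity buffer, so a row where the struct itself was null came back as a valid row of all-null children. A self-contained sketch of the difference, using the same arrow-rs APIs as the patch (the values are illustrative):

use std::sync::Arc;

use arrow_array::{Array, ArrayRef, StringArray, StructArray};
use arrow_buffer::NullBuffer;
use arrow_schema::{DataType, Field, Fields};

fn main() {
    let fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
    let values: ArrayRef = Arc::new(StringArray::from(vec![Some("x"), None]));

    // Row 1 of the struct itself is null.
    let nulls = NullBuffer::from_iter(vec![true, false]);
    let original = StructArray::new(fields.clone(), vec![values.clone()], Some(nulls));
    assert_eq!(original.null_count(), 1);

    // What the old code did: rebuild with None, dropping the parent's
    // validity buffer and silently resurrecting the null row.
    let rebuilt = StructArray::new(fields.clone(), vec![values.clone()], None);
    assert_eq!(rebuilt.null_count(), 0);

    // What the patch does: carry the parent's null buffer through the recursion.
    let fixed = StructArray::new(fields, vec![values], original.nulls().map(ToOwned::to_owned));
    assert_eq!(fixed.null_count(), 1);
}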
"+----+-----------------+", + "| id | props |", + "+----+-----------------+", + "| A | {a: 2021-02-01} |", + "| C | {a: } |", + "| D | |", + "+----+-----------------+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + #[tokio::test] async fn test_failure_nondeterministic_query() { // Deletion requires a deterministic predicate