fix: remove casts of structs to record batch (delta-io#2033)
# Description
Fixes an issue where the writer attempts to convert an Arrow `Struct`
into a `RecordBatch`. This cannot be done, since the conversion drops the
struct-level validity array and therefore prevents structs with a value of
`null` from being stored correctly.
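
As a minimal sketch of the failure mode (the field name `a` is hypothetical, and `into_parts` stands in for what the old conversion effectively did):

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array, StructArray};
use arrow_buffer::NullBuffer;
use arrow_schema::{DataType, Field, Fields};

fn main() {
    let fields = Fields::from(vec![Field::new("a", DataType::Int32, true)]);
    let child = Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef;
    // Row 1 is null at the struct level, independent of the child values.
    let nulls = NullBuffer::from_iter(vec![true, false]);
    let structs = StructArray::new(fields, vec![child], Some(nulls));
    assert_eq!(structs.null_count(), 1);

    // A RecordBatch is only (schema, columns): flattening a StructArray
    // into one keeps the children but has nowhere to store the struct-level
    // validity buffer, which is exactly the information the writer dropped.
    let (_fields, _columns, nulls) = structs.into_parts();
    assert!(nulls.is_some()); // lost as soon as only the columns are kept
}
```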

This PR also extends the predicate representation to support struct field
access, list index access, and list range access.
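
For reference, the DataFusion expressions these map to and the predicate SQL they now serialize to, mirroring the tests added below (the `_struct`, `_list`, and `value` columns are just the test fixture's schema):

```rust
use datafusion_expr::{cardinality, col, lit, Expr};

fn examples() -> Vec<Expr> {
    vec![
        // _struct['a'] = 20
        col("_struct").field("a").eq(lit(20_i64)),
        // _struct['nested']['b'] = 20
        col("_struct").field("nested").field("b").eq(lit(20_i64)),
        // _list[1] = 20
        col("_list").index(lit(1_i64)).eq(lit(20_i64)),
        // cardinality(_list[value:10])
        cardinality(col("_list").range(col("value"), lit(10_i64))),
    ]
}
```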

# Related Issue(s)
- closes delta-io#2019
Blajda authored and r3stl355 committed Jan 10, 2024
1 parent fc6da94 commit f50e584
Showing 3 changed files with 139 additions and 17 deletions.
68 changes: 65 additions & 3 deletions crates/deltalake-core/src/delta_datafusion/expr.rs
@@ -31,7 +31,7 @@ use datafusion::execution::context::SessionState;
use datafusion_common::Result as DFResult;
use datafusion_common::{config::ConfigOptions, DFSchema, Result, ScalarValue, TableReference};
use datafusion_expr::{
-    expr::InList, AggregateUDF, Between, BinaryExpr, Cast, Expr, Like, TableSource,
+    expr::InList, AggregateUDF, Between, BinaryExpr, Cast, Expr, GetIndexedField, Like, TableSource,
};
use datafusion_sql::planner::{ContextProvider, SqlToRel};
use sqlparser::ast::escape_quoted_string;
@@ -263,6 +263,28 @@ impl<'a> Display for SqlFormat<'a> {
write!(f, "{expr} IN ({})", expr_vec_fmt!(list))
}
}
Expr::GetIndexedField(GetIndexedField { expr, field }) => match field {
datafusion_expr::GetFieldAccess::NamedStructField { name } => {
write!(
f,
"{}[{}]",
SqlFormat { expr },
ScalarValueFormat { scalar: name }
)
}
datafusion_expr::GetFieldAccess::ListIndex { key } => {
write!(f, "{}[{}]", SqlFormat { expr }, SqlFormat { expr: key })
}
datafusion_expr::GetFieldAccess::ListRange { start, stop } => {
write!(
f,
"{}[{}:{}]",
SqlFormat { expr },
SqlFormat { expr: start },
SqlFormat { expr: stop }
)
}
},
_ => Err(fmt::Error),
}
}
@@ -346,10 +368,10 @@ mod test {
use arrow_schema::DataType as ArrowDataType;
use datafusion::prelude::SessionContext;
use datafusion_common::{Column, DFSchema, ScalarValue};
-    use datafusion_expr::{col, decode, lit, substring, Cast, Expr, ExprSchemable};
+    use datafusion_expr::{cardinality, col, decode, lit, substring, Cast, Expr, ExprSchemable};

use crate::delta_datafusion::DeltaSessionContext;
-    use crate::kernel::{DataType, PrimitiveType, StructField, StructType};
+    use crate::kernel::{ArrayType, DataType, PrimitiveType, StructField, StructType};
use crate::{DeltaOps, DeltaTable};

use super::fmt_expr_to_sql;
@@ -422,6 +444,30 @@ mod test {
DataType::Primitive(PrimitiveType::Binary),
true,
),
StructField::new(
"_struct".to_string(),
DataType::Struct(Box::new(StructType::new(vec![
StructField::new("a", DataType::Primitive(PrimitiveType::Integer), true),
StructField::new(
"nested",
DataType::Struct(Box::new(StructType::new(vec![StructField::new(
"b",
DataType::Primitive(PrimitiveType::Integer),
true,
)]))),
true,
),
]))),
true,
),
StructField::new(
"_list".to_string(),
DataType::Array(Box::new(ArrayType::new(
DataType::Primitive(PrimitiveType::Integer),
true,
))),
true,
),
]);

let table = DeltaOps::new_in_memory()
@@ -541,6 +587,22 @@ mod test {
.eq(lit("1")),
"arrow_cast(value, 'Utf8') = '1'".to_string()
),
simple!(
col("_struct").field("a").eq(lit(20_i64)),
"_struct['a'] = 20".to_string()
),
simple!(
col("_struct").field("nested").field("b").eq(lit(20_i64)),
"_struct['nested']['b'] = 20".to_string()
),
simple!(
col("_list").index(lit(1_i64)).eq(lit(20_i64)),
"_list[1] = 20".to_string()
),
simple!(
cardinality(col("_list").range(col("value"), lit(10_i64))),
"cardinality(_list[value:10])".to_string()
),
];

let session: SessionContext = DeltaSessionContext::default().into();
32 changes: 18 additions & 14 deletions crates/deltalake-core/src/operations/cast.rs
@@ -8,29 +8,27 @@ use std::sync::Arc;

use crate::DeltaResult;

-fn cast_record_batch_columns(
-    batch: &RecordBatch,
+fn cast_struct(
+    struct_array: &StructArray,
fields: &Fields,
cast_options: &CastOptions,
) -> Result<Vec<Arc<(dyn Array)>>, arrow_schema::ArrowError> {
fields
.iter()
-        .map(|f| {
-            let col = batch.column_by_name(f.name()).unwrap();
-
+        .map(|field| {
+            let col = struct_array.column_by_name(field.name()).unwrap();
if let (DataType::Struct(_), DataType::Struct(child_fields)) =
-                (col.data_type(), f.data_type())
+                (col.data_type(), field.data_type())
{
-                let child_batch = RecordBatch::from(StructArray::from(col.into_data()));
-                let child_columns =
-                    cast_record_batch_columns(&child_batch, child_fields, cast_options)?;
+                let child_struct = StructArray::from(col.into_data());
+                let s = cast_struct(&child_struct, child_fields, cast_options)?;
Ok(Arc::new(StructArray::new(
child_fields.clone(),
-                    child_columns.clone(),
-                    None,
+                    s,
+                    child_struct.nulls().map(ToOwned::to_owned),
)) as ArrayRef)
-            } else if is_cast_required(col.data_type(), f.data_type()) {
-                cast_with_options(col, f.data_type(), cast_options)
+            } else if is_cast_required(col.data_type(), field.data_type()) {
+                cast_with_options(col, field.data_type(), cast_options)
} else {
Ok(col.clone())
}
@@ -59,7 +57,13 @@ pub fn cast_record_batch(
..Default::default()
};

-    let columns = cast_record_batch_columns(batch, target_schema.fields(), &cast_options)?;
+    let s = StructArray::new(
+        batch.schema().as_ref().to_owned().fields,
+        batch.columns().to_owned(),
+        None,
+    );
+
+    let columns = cast_struct(&s, target_schema.fields(), &cast_options)?;
Ok(RecordBatch::try_new(target_schema, columns)?)
}
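
A sketch of the property this rewrite guarantees; the `safe: bool` third parameter of `cast_record_batch` is an assumption about the crate-internal signature, which this diff does not show:

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array, RecordBatch, StructArray};
use arrow_buffer::NullBuffer;
use arrow_schema::{DataType, Field, Fields, Schema};

fn check_struct_nulls_survive() {
    let fields = Fields::from(vec![Field::new("a", DataType::Int32, true)]);
    let structs = StructArray::new(
        fields.clone(),
        vec![Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef],
        // Row 1 is null at the struct level.
        Some(NullBuffer::from_iter(vec![true, false])),
    );
    let schema = Arc::new(Schema::new(vec![Field::new(
        "props",
        DataType::Struct(fields),
        true,
    )]));
    let batch =
        RecordBatch::try_new(schema.clone(), vec![Arc::new(structs) as ArrayRef]).unwrap();

    // cast_struct now threads each child's null buffer through the recursion,
    // so the struct-level null must survive a cast to the same schema.
    let casted = cast_record_batch(&batch, schema, false).unwrap();
    assert_eq!(casted.column(0).null_count(), 1);
}
```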

56 changes: 56 additions & 0 deletions crates/deltalake-core/src/operations/delete.rs
@@ -338,6 +338,10 @@ mod tests {
use arrow::array::Int32Array;
use arrow::datatypes::{Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow_array::ArrayRef;
use arrow_array::StructArray;
use arrow_buffer::NullBuffer;
use arrow_schema::Fields;
use datafusion::assert_batches_sorted_eq;
use datafusion::prelude::*;
use serde_json::json;
@@ -728,6 +732,58 @@ mod tests {
assert_batches_sorted_eq!(&expected, &actual);
}

#[tokio::test]
async fn test_delete_nested() {
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
// Test Delete with a predicate that references struct fields
// See #2019
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", DataType::Utf8, true),
Field::new(
"props",
DataType::Struct(Fields::from(vec![Field::new("a", DataType::Utf8, true)])),
true,
),
]));

let struct_array = StructArray::new(
Fields::from(vec![Field::new("a", DataType::Utf8, true)]),
vec![Arc::new(arrow::array::StringArray::from(vec![
Some("2021-02-01"),
Some("2021-02-02"),
None,
None,
])) as ArrayRef],
Some(NullBuffer::from_iter(vec![true, true, true, false])),
);

let data = vec![
Arc::new(arrow::array::StringArray::from(vec!["A", "B", "C", "D"])) as ArrayRef,
Arc::new(struct_array) as ArrayRef,
];
let batches = vec![RecordBatch::try_new(schema.clone(), data).unwrap()];

let table = DeltaOps::new_in_memory().write(batches).await.unwrap();

let (table, _metrics) = DeltaOps(table)
.delete()
.with_predicate("props['a'] = '2021-02-02'")
.await
.unwrap();

let expected = [
"+----+-----------------+",
"| id | props |",
"+----+-----------------+",
"| A | {a: 2021-02-01} |",
"| C | {a: } |",
"| D | |",
"+----+-----------------+",
];
let actual = get_data(&table).await;
assert_batches_sorted_eq!(&expected, &actual);
}

#[tokio::test]
async fn test_failure_nondeterministic_query() {
// Deletion requires a deterministic predicate
