diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index e0dd9f5639..68f630239d 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -3,9 +3,13 @@ use crate::kernel::{ ArrayType, DataType as DeltaDataType, MapType, MetadataValue, StructField, StructType, }; -use arrow_array::{new_null_array, Array, ArrayRef, RecordBatch, StructArray}; +use arrow_array::cast::AsArray; +use arrow_array::{ + new_null_array, Array, ArrayRef, GenericListArray, MapArray, OffsetSizeTrait, RecordBatch, + RecordBatchOptions, StructArray, +}; use arrow_cast::{cast_with_options, CastOptions}; -use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef}; +use arrow_schema::{ArrowError, DataType, FieldRef, Fields, SchemaRef as ArrowSchemaRef}; use std::collections::HashMap; use std::sync::Arc; @@ -53,7 +57,7 @@ pub(crate) fn merge_struct( field.is_nullable() || right_field.is_nullable(), ); - new_field.metadata = field.metadata.clone(); + new_field.metadata.clone_from(&field.metadata); try_merge_metadata(&mut new_field.metadata, &right_field.metadata)?; Ok(new_field) } @@ -130,40 +134,135 @@ fn cast_struct( fields: &Fields, cast_options: &CastOptions, add_missing: bool, -) -> Result>, arrow_schema::ArrowError> { - fields - .iter() - .map(|field| { - let col_or_not = struct_array.column_by_name(field.name()); - match col_or_not { - None => match add_missing { - true => Ok(new_null_array(field.data_type(), struct_array.len())), - false => Err(arrow_schema::ArrowError::SchemaError(format!( - "Could not find column {0}", - field.name() - ))), - }, - Some(col) => { - if let (DataType::Struct(_), DataType::Struct(child_fields)) = - (col.data_type(), field.data_type()) - { - let child_struct = StructArray::from(col.into_data()); - let s = - cast_struct(&child_struct, child_fields, cast_options, add_missing)?; - Ok(Arc::new(StructArray::new( - child_fields.clone(), - s, - child_struct.nulls().map(ToOwned::to_owned), - )) as ArrayRef) - } else if is_cast_required(col.data_type(), field.data_type()) { - cast_with_options(col, field.data_type(), cast_options) - } else { - Ok(col.clone()) - } +) -> Result { + StructArray::try_new( + fields.to_owned(), + fields + .iter() + .map(|field| { + let col_or_not = struct_array.column_by_name(field.name()); + match col_or_not { + None => match add_missing { + true if field.is_nullable() => { + Ok(new_null_array(field.data_type(), struct_array.len())) + } + _ => Err(ArrowError::SchemaError(format!( + "Could not find column {0}", + field.name() + ))), + }, + Some(col) => cast_field(col, field, cast_options, add_missing), } - } - }) - .collect::, _>>() + }) + .collect::, _>>()?, + struct_array.nulls().map(ToOwned::to_owned), + ) +} + +fn cast_list( + array: &GenericListArray, + field: &FieldRef, + cast_options: &CastOptions, + add_missing: bool, +) -> Result, ArrowError> { + let values = cast_field(array.values(), field, cast_options, add_missing)?; + GenericListArray::::try_new( + field.clone(), + array.offsets().clone(), + values, + array.nulls().cloned(), + ) +} + +fn cast_map( + array: &MapArray, + entries_field: &FieldRef, + sorted: bool, + cast_options: &CastOptions, + add_missing: bool, +) -> Result { + match entries_field.data_type() { + DataType::Struct(entry_fields) => { + let entries = cast_struct(array.entries(), entry_fields, cast_options, add_missing)?; + MapArray::try_new( + entries_field.clone(), + array.offsets().to_owned(), + entries, + array.nulls().cloned(), + sorted, + ) + } + _ => Err(ArrowError::CastError( + "Map entries must be a struct".to_string(), + )), + } +} + +fn cast_field( + col: &ArrayRef, + field: &FieldRef, + cast_options: &CastOptions, + add_missing: bool, +) -> Result { + if let (DataType::Struct(_), DataType::Struct(child_fields)) = + (col.data_type(), field.data_type()) + { + let child_struct = StructArray::from(col.into_data()); + Ok(Arc::new(cast_struct( + &child_struct, + child_fields, + cast_options, + add_missing, + )?) as ArrayRef) + } else if let (DataType::List(_), DataType::List(child_fields)) = + (col.data_type(), field.data_type()) + { + Ok(Arc::new(cast_list( + col.as_any() + .downcast_ref::>() + .ok_or(ArrowError::CastError(format!( + "Expected a list for {} but got {}", + field.name(), + col.data_type() + )))?, + child_fields, + cast_options, + add_missing, + )?) as ArrayRef) + } else if let (DataType::LargeList(_), DataType::LargeList(child_fields)) = + (col.data_type(), field.data_type()) + { + Ok(Arc::new(cast_list( + col.as_any() + .downcast_ref::>() + .ok_or(ArrowError::CastError(format!( + "Expected a list for {} but got {}", + field.name(), + col.data_type() + )))?, + child_fields, + cast_options, + add_missing, + )?) as ArrayRef) + } else if let (DataType::Map(_, _), DataType::Map(child_fields, sorted)) = + (col.data_type(), field.data_type()) + { + Ok(Arc::new(cast_map( + col.as_map_opt().ok_or(ArrowError::CastError(format!( + "Expected a map for {} but got {}", + field.name(), + col.data_type() + )))?, + child_fields, + *sorted, + cast_options, + add_missing, + )?) as ArrayRef) + } else if is_cast_required(col.data_type(), field.data_type()) { + cast_with_options(col, field.data_type(), cast_options) + } else { + Ok(col.clone()) + } } fn is_cast_required(a: &DataType, b: &DataType) -> bool { @@ -193,18 +292,26 @@ pub fn cast_record_batch( batch.columns().to_owned(), None, ); - let columns = cast_struct(&s, target_schema.fields(), &cast_options, add_missing)?; - Ok(RecordBatch::try_new(target_schema, columns)?) + let struct_array = cast_struct(&s, target_schema.fields(), &cast_options, add_missing)?; + Ok(RecordBatch::try_new_with_options( + target_schema, + struct_array.columns().to_vec(), + &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), + )?) } #[cfg(test)] mod tests { use std::collections::HashMap; + use std::ops::Deref; use std::sync::Arc; - use arrow::array::ArrayData; - use arrow_array::{Array, ArrayRef, ListArray, RecordBatch}; - use arrow_buffer::Buffer; + use arrow::array::types::Int32Type; + use arrow::array::{ + new_empty_array, new_null_array, Array, ArrayData, ArrayRef, AsArray, Int32Array, + ListArray, PrimitiveArray, RecordBatch, StringArray, StructArray, + }; + use arrow::buffer::{Buffer, NullBuffer}; use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; use itertools::Itertools; @@ -354,4 +461,303 @@ mod tests { assert!(is_cast_required(&field1, &field2)); } + + #[test] + fn test_add_missing_null_fields_with_no_missing_fields() { + let schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Int32, false), + Field::new("field2", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])), + ], + ) + .unwrap(); + let result = cast_record_batch(&batch, schema.clone(), false, true).unwrap(); + assert_eq!(result.schema(), schema); + assert_eq!(result.num_columns(), 2); + assert_eq!( + result.column(0).deref().as_primitive::(), + &PrimitiveArray::::from_iter([1, 2, 3]) + ); + assert_eq!( + result.column(1).deref().as_string(), + &StringArray::from(vec![Some("a"), None, Some("c")]) + ); + } + + #[test] + fn test_add_missing_null_fields_with_missing_beginning() { + let schema = Arc::new(Schema::new(vec![Field::new( + "field2", + DataType::Utf8, + true, + )])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(StringArray::from(vec![ + Some("a"), + None, + Some("c"), + ]))], + ) + .unwrap(); + + let new_schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Int32, true), + Field::new("field2", DataType::Utf8, true), + ])); + let result = cast_record_batch(&batch, new_schema.clone(), false, true).unwrap(); + assert_eq!(result.schema(), new_schema); + assert_eq!(result.num_columns(), 2); + assert_eq!( + result.column(0).deref().as_primitive::(), + new_null_array(&DataType::Int32, 3) + .deref() + .as_primitive::() + ); + assert_eq!( + result.column(1).deref().as_string(), + &StringArray::from(vec![Some("a"), None, Some("c")]) + ); + } + + #[test] + fn test_add_missing_null_fields_with_missing_end() { + let schema = Arc::new(Schema::new(vec![Field::new( + "field1", + DataType::Int32, + false, + )])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + let new_schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Int32, false), + Field::new("field2", DataType::Utf8, true), + ])); + let result = cast_record_batch(&batch, new_schema.clone(), false, true).unwrap(); + assert_eq!(result.schema(), new_schema); + assert_eq!(result.num_columns(), 2); + assert_eq!( + result.column(0).deref().as_primitive::(), + &PrimitiveArray::::from(vec![Some(1), Some(2), Some(3)]) + ); + assert_eq!( + result.column(1).deref().as_string::(), + new_null_array(&DataType::Utf8, 3).deref().as_string() + ); + } + + #[test] + fn test_add_missing_null_fields_error_on_missing_non_null() { + let schema = Arc::new(Schema::new(vec![Field::new( + "field1", + DataType::Int32, + false, + )])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + let new_schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Int32, false), + Field::new("field2", DataType::Utf8, false), + ])); + let result = cast_record_batch(&batch, new_schema.clone(), false, true); + assert!(result.is_err()); + } + + #[test] + fn test_add_missing_null_fields_nested_struct_missing() { + let nested_fields = Fields::from(vec![Field::new("nested1", DataType::Utf8, true)]); + let schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Int32, false), + Field::new("field2", DataType::Struct(nested_fields.clone()), true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StructArray::new( + nested_fields, + vec![Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])) as ArrayRef], + None, + )), + ], + ) + .unwrap(); + let new_schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Int32, false), + Field::new( + "field2", + DataType::Struct(Fields::from(vec![ + Field::new("nested1", DataType::Utf8, true), + Field::new("nested2", DataType::Utf8, true), + ])), + true, + ), + ])); + let result = cast_record_batch(&batch, new_schema.clone(), false, true).unwrap(); + assert_eq!(result.schema(), new_schema); + assert_eq!(result.num_columns(), 2); + assert_eq!( + result.column(0).deref().as_primitive::(), + &PrimitiveArray::::from_iter([1, 2, 3]) + ); + let struct_column = result.column(1).deref().as_struct(); + assert_eq!(struct_column.num_columns(), 2); + assert_eq!( + struct_column.column(0).deref().as_string(), + &StringArray::from(vec![Some("a"), None, Some("c")]) + ); + assert_eq!( + struct_column.column(1).deref().as_string::(), + new_null_array(&DataType::Utf8, 3).deref().as_string() + ); + } + + #[test] + fn test_add_missing_null_fields_nested_struct_missing_non_nullable() { + let nested_fields = Fields::from(vec![Field::new("nested1", DataType::Utf8, false)]); + let schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Int32, false), + Field::new("field2", DataType::Struct(nested_fields.clone()), true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StructArray::new( + nested_fields, + vec![new_null_array(&DataType::Utf8, 3)], + Some(NullBuffer::new_null(3)), + )), + ], + ) + .unwrap(); + let new_schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Int32, false), + Field::new( + "field2", + DataType::Struct(Fields::from(vec![ + Field::new("nested1", DataType::Utf8, false), + Field::new("nested2", DataType::Utf8, true), + ])), + true, + ), + ])); + let result = cast_record_batch(&batch, new_schema.clone(), false, true).unwrap(); + assert_eq!(result.schema(), new_schema); + assert_eq!(result.num_columns(), 2); + assert_eq!( + result.column(0).deref().as_primitive::(), + &PrimitiveArray::::from_iter([1, 2, 3]) + ); + let struct_column = result.column(1).deref().as_struct(); + assert_eq!(struct_column.num_columns(), 2); + let expected: [Option<&str>; 3] = Default::default(); + assert_eq!( + struct_column.column(0).deref().as_string(), + &StringArray::from(Vec::from(expected)) + ); + assert_eq!( + struct_column.column(1).deref().as_string::(), + new_null_array(&DataType::Utf8, 3).deref().as_string(), + ); + } + + #[test] + fn test_add_missing_null_fields_list_missing() { + let schema = Arc::new(Schema::new(vec![Field::new( + "field1", + DataType::Int32, + false, + )])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let new_schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Int32, false), + Field::new( + "field2", + DataType::List(Arc::new(Field::new("nested1", DataType::Utf8, true))), + true, + ), + ])); + let result = cast_record_batch(&batch, new_schema.clone(), false, true).unwrap(); + assert_eq!(result.schema(), new_schema); + assert_eq!(result.num_columns(), 2); + assert_eq!( + result.column(0).deref().as_primitive::(), + &PrimitiveArray::::from_iter([1, 2, 3]) + ); + let list_column = result.column(1).deref().as_list::(); + assert_eq!(list_column.len(), 3); + assert_eq!(list_column.value_offsets(), &[0, 0, 0, 0]); + assert_eq!( + list_column.values().deref().as_string::(), + new_empty_array(&DataType::Utf8).deref().as_string() + ) + } + + #[test] + fn test_add_missing_null_fields_map_missing() { + let schema = Arc::new(Schema::new(vec![Field::new( + "field1", + DataType::Int32, + false, + )])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let new_schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Int32, false), + Field::new( + "field2", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, true), + Field::new("value", DataType::Utf8, true), + ])), + true, + )), + false, + ), + true, + ), + ])); + let result = cast_record_batch(&batch, new_schema.clone(), false, true).unwrap(); + assert_eq!(result.schema(), new_schema); + assert_eq!(result.num_columns(), 2); + assert_eq!( + result.column(0).deref().as_primitive::(), + &PrimitiveArray::::from_iter([1, 2, 3]) + ); + let map_column = result.column(1).deref().as_map(); + assert_eq!(map_column.len(), 3); + assert_eq!(map_column.offsets().as_ref(), &[0; 4]); + assert_eq!( + map_column.keys().deref().as_string::(), + new_empty_array(&DataType::Utf8).deref().as_string() + ); + assert_eq!( + map_column.values().deref().as_string::(), + new_empty_array(&DataType::Utf8).deref().as_string() + ); + } }