diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index d1fc616359..d9adbe7cbc 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -95,6 +95,25 @@ class RawDeltaTable: writer_properties: Optional[Dict[str, int]], safe_cast: bool = False, ) -> str: ... + def merge_execute( + self, + source: pa.RecordBatchReader, + predicate: str, + source_alias: Optional[str], + target_alias: Optional[str], + writer_properties: Optional[Dict[str, int | None]], + safe_cast: bool, + matched_update_updates: Optional[Dict[str, str]], + matched_update_predicate: Optional[str], + matched_delete_predicate: Optional[str], + matched_delete_all: Optional[bool], + not_matched_insert_updates: Optional[Dict[str, str]], + not_matched_insert_predicate: Optional[str], + not_matched_by_source_update_updates: Optional[Dict[str, str]], + not_matched_by_source_update_predicate: Optional[str], + not_matched_by_source_delete_predicate: Optional[str], + not_matched_by_source_delete_all: Optional[bool], + ) -> str: ... def get_active_partitions( self, partitions_filters: Optional[FilterType] = None ) -> Any: ... diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 80a48f619e..e913eb7622 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: import pandas +from ._internal import DeltaDataChecker as _DeltaDataChecker from ._internal import RawDeltaTable from ._util import encode_partition_value from .data_catalog import DataCatalog @@ -525,6 +526,61 @@ def optimize( ) -> "TableOptimizer": return TableOptimizer(self) + def merge( + self, + source: Union[pyarrow.Table, pyarrow.RecordBatch, pyarrow.RecordBatchReader], + predicate: str, + source_alias: Optional[str] = None, + target_alias: Optional[str] = None, + error_on_type_mismatch: bool = True, + ) -> "TableMerger": + """Pass the source data which you want to merge into the target Delta table, providing a + predicate in SQL-like format. You can also specify how to handle source data types that do not + match those of the underlying table. + + Args: + source (pyarrow.Table | pyarrow.RecordBatch | pyarrow.RecordBatchReader): source data + predicate (str): SQL like predicate on how to merge + source_alias (str): Alias for the source table + target_alias (str): Alias for the target table + error_on_type_mismatch (bool): specify whether merge should raise an error on mismatching data types. Defaults to True. + + Returns: + TableMerger: TableMerger Object + """ + invariants = self.schema().invariants + checker = _DeltaDataChecker(invariants) + + if isinstance(source, pyarrow.RecordBatchReader): + schema = source.schema + elif isinstance(source, pyarrow.RecordBatch): + schema = source.schema + source = [source] + elif isinstance(source, pyarrow.Table): + schema = source.schema + source = source.to_reader() + else: + raise TypeError( + f"{type(source).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch or Table are valid inputs for source."
+ ) + + def validate_batch(batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch: + checker.check_batch(batch) + return batch + + source = pyarrow.RecordBatchReader.from_batches( + schema, (validate_batch(batch) for batch in source) + ) + + return TableMerger( + self, + source=source, + predicate=predicate, + source_alias=source_alias, + target_alias=target_alias, + safe_cast=not error_on_type_mismatch, + ) + def pyarrow_schema(self) -> pyarrow.Schema: """ Get the current schema of the DeltaTable with the Parquet PyArrow format. @@ -747,6 +803,308 @@ def delete(self, predicate: Optional[str] = None) -> Dict[str, Any]: return json.loads(metrics) +class TableMerger: + """API for various table MERGE commands.""" + + def __init__( + self, + table: DeltaTable, + source: pyarrow.RecordBatchReader, + predicate: str, + source_alias: Optional[str] = None, + target_alias: Optional[str] = None, + safe_cast: bool = True, + ): + self.table = table + self.source = source + self.predicate = predicate + self.source_alias = source_alias + self.target_alias = target_alias + self.safe_cast = safe_cast + self.writer_properties: Optional[Dict[str, Optional[int]]] = None + self.matched_update_updates: Optional[Dict[str, str]] = None + self.matched_update_predicate: Optional[str] = None + self.matched_delete_predicate: Optional[str] = None + self.matched_delete_all: Optional[bool] = None + self.not_matched_insert_updates: Optional[Dict[str, str]] = None + self.not_matched_insert_predicate: Optional[str] = None + self.not_matched_by_source_update_updates: Optional[Dict[str, str]] = None + self.not_matched_by_source_update_predicate: Optional[str] = None + self.not_matched_by_source_delete_predicate: Optional[str] = None + self.not_matched_by_source_delete_all: Optional[bool] = None + + def with_writer_properties( + self, + data_page_size_limit: Optional[int] = None, + dictionary_page_size_limit: Optional[int] = None, + data_page_row_count_limit: Optional[int] = None, + write_batch_size: Optional[int] = None, + max_row_group_size: Optional[int] = None, + ) -> "TableMerger": + """Pass writer properties to the Rust parquet writer; see available options at https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html: + + Args: + data_page_size_limit (int|None, optional): Limit the data page size to this number of bytes. Defaults to None. + dictionary_page_size_limit (int|None, optional): Limit each dictionary page to this number of bytes. Defaults to None. + data_page_row_count_limit (int|None, optional): Limit the number of rows in each DataPage. Defaults to None. + write_batch_size (int|None, optional): Split writes internally into batches of this size. Defaults to None. + max_row_group_size (int|None, optional): Maximum number of rows in a row group. Defaults to None. + + Returns: + TableMerger: TableMerger Object + """ + writer_properties = { + "data_page_size_limit": data_page_size_limit, + "dictionary_page_size_limit": dictionary_page_size_limit, + "data_page_row_count_limit": data_page_row_count_limit, + "write_batch_size": write_batch_size, + "max_row_group_size": max_row_group_size, + } + self.writer_properties = writer_properties + return self + + def when_matched_update( + self, updates: Dict[str, str], predicate: Optional[str] = None + ) -> "TableMerger": + """Update a matched table row based on the rules defined by ``updates``. + If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. + + Args: + updates (dict): a mapping of column name to update SQL expression.
+ predicate (str | None, optional): SQL like predicate on when to update. Defaults to None. + + Returns: + TableMerger: TableMerger Object + + Examples: + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ + ... .when_matched_update( + ... updates = { + ... "x": "source.x", + ... "y": "source.y" + ... } + ... ).execute() + """ + self.matched_update_updates = updates + self.matched_update_predicate = predicate + return self + + def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger": + """Update all source fields to target fields; source and target are required to have the same field names. + If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. + + Args: + predicate (str | None, optional): SQL like predicate on when to update all columns. Defaults to None. + + Returns: + TableMerger: TableMerger Object + + Examples: + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ + ... .when_matched_update_all().execute() + """ + + src_alias = (self.source_alias + ".") if self.source_alias is not None else "" + trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" + + self.matched_update_updates = { + f"{trgt_alias}{col.name}": f"{src_alias}{col.name}" + for col in self.source.schema + } + self.matched_update_predicate = predicate + return self + + def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": + """Delete a matched row from the table only if the given ``predicate`` (if specified) is + true for the matched row. If not specified it deletes all matches. + + Args: + predicate (str | None, optional): SQL like predicate on when to delete. Defaults to None. + + Returns: + TableMerger: TableMerger Object + + Examples: + + Delete on a predicate + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ + ... .when_matched_delete(predicate = "source.deleted = true") \ + ... .execute() + + Delete all records that were matched + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ + ... .when_matched_delete() \ + ... .execute() + """ + + if predicate is None: + self.matched_delete_all = True + else: + self.matched_delete_predicate = predicate + return self + + def when_not_matched_insert( + self, updates: Dict[str, str], predicate: Optional[str] = None + ) -> "TableMerger": + """Insert a new row to the target table based on the rules defined by ``updates``. If a + ``predicate`` is specified, then it must evaluate to true for the new row to be inserted. + + Args: + updates (dict): a mapping of column name to insert SQL expression. + predicate (str | None, optional): SQL like predicate on when to insert. Defaults to None.
+ + Returns: + TableMerger: TableMerger Object + + Examples: + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ + ... .when_not_matched_insert( + ... updates = { + ... "x": "source.x", + ... "y": "source.y" + ... } + ... ).execute() + """ + + self.not_matched_insert_updates = updates + self.not_matched_insert_predicate = predicate + + return self + + def when_not_matched_insert_all( + self, predicate: Optional[str] = None + ) -> "TableMerger": + """Insert a new row into the target table, copying all source fields into the corresponding target fields. Source and target are + required to have the same field names. If a ``predicate`` is specified, then it must evaluate to true for + the new row to be inserted. + + Args: + predicate (str | None, optional): SQL like predicate on when to insert. Defaults to None. + + Returns: + TableMerger: TableMerger Object + + Examples: + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ + ... .when_not_matched_insert_all().execute() + """ + + src_alias = (self.source_alias + ".") if self.source_alias is not None else "" + trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" + self.not_matched_insert_updates = { + f"{trgt_alias}{col.name}": f"{src_alias}{col.name}" + for col in self.source.schema + } + self.not_matched_insert_predicate = predicate + return self + + def when_not_matched_by_source_update( + self, updates: Dict[str, str], predicate: Optional[str] = None + ) -> "TableMerger": + """Update a target row that has no matches in the source based on the rules defined by ``updates``. + If a ``predicate`` is specified, then it must evaluate to true for the row to be updated. + + Args: + updates (dict): a mapping of column name to update SQL expression. + predicate (str | None, optional): SQL like predicate on when to update. Defaults to None. + + Returns: + TableMerger: TableMerger Object + + Examples: + + >>> from deltalake import DeltaTable + >>> import pyarrow as pa + >>> data = pa.table({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> dt = DeltaTable("tmp") + >>> dt.merge(source=data, predicate='target.x = source.x', source_alias='source', target_alias='target') \ + ... .when_not_matched_by_source_update( + ... updates = { + ... "y": "0", + ... }, + ... predicate = "y > 3" + ... ).execute() + """ + self.not_matched_by_source_update_updates = updates + self.not_matched_by_source_update_predicate = predicate + return self + + def when_not_matched_by_source_delete( + self, predicate: Optional[str] = None + ) -> "TableMerger": + """Delete a target row that has no matches in the source from the table only if the given + ``predicate`` (if specified) is true for the target row. + + Args: + predicate (str | None, optional): SQL like predicate on when to delete when not matched by source. Defaults to None.
+ + Returns: + TableMerger: TableMerger Object + """ + + if predicate is None: + self.not_matched_by_source_delete_all = True + else: + self.not_matched_by_source_delete_predicate = predicate + return self + + def execute(self) -> Dict[str, Any]: + """Executes MERGE with the previously provided settings in Rust through the Apache DataFusion query engine. + + Returns: + Dict[str, Any]: metrics + """ + metrics = self.table._table.merge_execute( + source=self.source, + predicate=self.predicate, + source_alias=self.source_alias, + target_alias=self.target_alias, + safe_cast=self.safe_cast, + writer_properties=self.writer_properties, + matched_update_updates=self.matched_update_updates, + matched_update_predicate=self.matched_update_predicate, + matched_delete_predicate=self.matched_delete_predicate, + matched_delete_all=self.matched_delete_all, + not_matched_insert_updates=self.not_matched_insert_updates, + not_matched_insert_predicate=self.not_matched_insert_predicate, + not_matched_by_source_update_updates=self.not_matched_by_source_update_updates, + not_matched_by_source_update_predicate=self.not_matched_by_source_update_predicate, + not_matched_by_source_delete_predicate=self.not_matched_by_source_delete_predicate, + not_matched_by_source_delete_all=self.not_matched_by_source_delete_all, + ) + self.table.update_incremental() + return json.loads(metrics) + + + class TableOptimizer: """API for various table optimization commands.""" diff --git a/python/src/lib.rs b/python/src/lib.rs index 2dc1ab5dd7..2f46436984 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -15,13 +15,18 @@ use std::time::{SystemTime, UNIX_EPOCH}; use arrow::pyarrow::PyArrowType; use chrono::{DateTime, Duration, FixedOffset, Utc}; use deltalake::arrow::compute::concat_batches; +use deltalake::arrow::ffi_stream::ArrowArrayStreamReader; use deltalake::arrow::record_batch::RecordBatch; +use deltalake::arrow::record_batch::RecordBatchReader; use deltalake::arrow::{self, datatypes::Schema as ArrowSchema}; use deltalake::checkpoints::create_checkpoint; +use deltalake::datafusion::datasource::memory::MemTable; +use deltalake::datafusion::datasource::provider::TableProvider; use deltalake::datafusion::prelude::SessionContext; use deltalake::delta_datafusion::DeltaDataChecker; use deltalake::errors::DeltaTableError; use deltalake::operations::delete::DeleteBuilder; +use deltalake::operations::merge::MergeBuilder; use deltalake::operations::optimize::{OptimizeBuilder, OptimizeType}; use deltalake::operations::restore::RestoreBuilder; use deltalake::operations::transaction::commit; @@ -395,6 +400,185 @@ impl RawDeltaTable { Ok(serde_json::to_string(&metrics).unwrap()) } + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (source, + predicate, + source_alias = None, + target_alias = None, + safe_cast = false, + writer_properties = None, + matched_update_updates = None, + matched_update_predicate = None, + matched_delete_predicate = None, + matched_delete_all = None, + not_matched_insert_updates = None, + not_matched_insert_predicate = None, + not_matched_by_source_update_updates = None, + not_matched_by_source_update_predicate = None, + not_matched_by_source_delete_predicate = None, + not_matched_by_source_delete_all = None, + ))] + pub fn merge_execute( + &mut self, + source: PyArrowType<ArrowArrayStreamReader>, + predicate: String, + source_alias: Option<String>, + target_alias: Option<String>, + safe_cast: bool, + writer_properties: Option<HashMap<String, usize>>, + matched_update_updates: Option<HashMap<String, String>>, + matched_update_predicate: Option<String>, + matched_delete_predicate: Option<String>, +
matched_delete_all: Option<bool>, + not_matched_insert_updates: Option<HashMap<String, String>>, + not_matched_insert_predicate: Option<String>, + not_matched_by_source_update_updates: Option<HashMap<String, String>>, + not_matched_by_source_update_predicate: Option<String>, + not_matched_by_source_delete_predicate: Option<String>, + not_matched_by_source_delete_all: Option<bool>, + ) -> PyResult<String> { + let ctx = SessionContext::new(); + let schema = source.0.schema(); + let batches = vec![source.0.map(|batch| batch.unwrap()).collect::<Vec<_>>()]; + let table_provider: Arc<dyn TableProvider> = + Arc::new(MemTable::try_new(schema, batches).unwrap()); + let source_df = ctx.read_table(table_provider).unwrap(); + + let mut cmd = MergeBuilder::new( + self._table.object_store(), + self._table.state.clone(), + predicate, + source_df, + ) + .with_safe_cast(safe_cast); + + if let Some(src_alias) = source_alias { + cmd = cmd.with_source_alias(src_alias); + } + + if let Some(trgt_alias) = target_alias { + cmd = cmd.with_target_alias(trgt_alias); + } + + if let Some(writer_props) = writer_properties { + let mut properties = WriterProperties::builder(); + let data_page_size_limit = writer_props.get("data_page_size_limit"); + let dictionary_page_size_limit = writer_props.get("dictionary_page_size_limit"); + let data_page_row_count_limit = writer_props.get("data_page_row_count_limit"); + let write_batch_size = writer_props.get("write_batch_size"); + let max_row_group_size = writer_props.get("max_row_group_size"); + + if let Some(data_page_size) = data_page_size_limit { + properties = properties.set_data_page_size_limit(*data_page_size); + } + if let Some(dictionary_page_size) = dictionary_page_size_limit { + properties = properties.set_dictionary_page_size_limit(*dictionary_page_size); + } + if let Some(data_page_row_count) = data_page_row_count_limit { + properties = properties.set_data_page_row_count_limit(*data_page_row_count); + } + if let Some(batch_size) = write_batch_size { + properties = properties.set_write_batch_size(*batch_size); + } + if let Some(row_group_size) = max_row_group_size { + properties = properties.set_max_row_group_size(*row_group_size); + } + cmd = cmd.with_writer_properties(properties.build()); + } + + if let Some(mu_updates) = matched_update_updates { + if let Some(mu_predicate) = matched_update_predicate { + cmd = cmd + .when_matched_update(|mut update| { + for (col_name, expression) in mu_updates { + update = update.update(col_name.clone(), expression.clone()); + } + update.predicate(mu_predicate) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_matched_update(|mut update| { + for (col_name, expression) in mu_updates { + update = update.update(col_name.clone(), expression.clone()); + } + update + }) + .map_err(PythonError::from)?; + } + } + + if let Some(_md_delete_all) = matched_delete_all { + cmd = cmd + .when_matched_delete(|delete| delete) + .map_err(PythonError::from)?; + } else if let Some(md_predicate) = matched_delete_predicate { + cmd = cmd + .when_matched_delete(|delete| delete.predicate(md_predicate)) + .map_err(PythonError::from)?; + } + + if let Some(nmi_updates) = not_matched_insert_updates { + if let Some(nmi_predicate) = not_matched_insert_predicate { + cmd = cmd + .when_not_matched_insert(|mut insert| { + for (col_name, expression) in nmi_updates { + insert = insert.set(col_name.clone(), expression.clone()); + } + insert.predicate(nmi_predicate) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_not_matched_insert(|mut insert| { + for (col_name, expression) in nmi_updates { + insert = insert.set(col_name.clone(), expression.clone());
+ } + insert + }) + .map_err(PythonError::from)?; + } + } + + if let Some(nmbsu_updates) = not_matched_by_source_update_updates { + if let Some(nmbsu_predicate) = not_matched_by_source_update_predicate { + cmd = cmd + .when_not_matched_by_source_update(|mut update| { + for (col_name, expression) in nmbsu_updates { + update = update.update(col_name.clone(), expression.clone()); + } + update.predicate(nmbsu_predicate) + }) + .map_err(PythonError::from)?; + } else { + cmd = cmd + .when_not_matched_by_source_update(|mut update| { + for (col_name, expression) in nmbsu_updates { + update = update.update(col_name.clone(), expression.clone()); + } + update + }) + .map_err(PythonError::from)?; + } + } + + if let Some(_nmbs_delete_all) = not_matched_by_source_delete_all { + cmd = cmd + .when_not_matched_by_source_delete(|delete| delete) + .map_err(PythonError::from)?; + } else if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate { + cmd = cmd + .when_not_matched_by_source_delete(|delete| delete.predicate(nmbs_predicate)) + .map_err(PythonError::from)?; + } + + let (table, metrics) = rt()? + .block_on(cmd.into_future()) + .map_err(PythonError::from)?; + self._table.state = table.state; + Ok(serde_json::to_string(&metrics).unwrap()) + } + // Run the restore command on the Delta Table: restore table to a given version or datetime #[pyo3(signature = (target, *, ignore_missing_files = false, protocol_downgrade_allowed = false))] pub fn restore( diff --git a/python/stubs/pyarrow/__init__.pyi b/python/stubs/pyarrow/__init__.pyi index fb01f5796d..f8c9d152aa 100644 --- a/python/stubs/pyarrow/__init__.pyi +++ b/python/stubs/pyarrow/__init__.pyi @@ -4,6 +4,7 @@ __version__: str Schema: Any Table: Any RecordBatch: Any +RecordBatchReader: Any Field: Any DataType: Any ListType: Any diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 6ddb68a526..d10de6bc00 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -235,3 +235,16 @@ def existing_table(tmp_path: pathlib.Path, sample_data: pa.Table): path = str(tmp_path) write_deltalake(path, sample_data) return DeltaTable(path) + + +@pytest.fixture() +def sample_table(): + nrows = 5 + return pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array(list(range(nrows)), pa.int64()), + "sold": pa.array(list(range(nrows)), pa.int32()), + "deleted": pa.array([False] * nrows), + } + ) diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py new file mode 100644 index 0000000000..fc08563443 --- /dev/null +++ b/python/tests/test_merge.py @@ -0,0 +1,492 @@ +import pathlib + +import pyarrow as pa + +from deltalake import DeltaTable, write_deltalake + + +def test_merge_when_matched_delete_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["5"]), + "weight": pa.array([105], pa.int32()), + } + ) + + dt.merge( + source=source_table, + predicate="t.id = s.id", + source_alias="s", + target_alias="t", + ).when_matched_delete().execute() + + nrows = 4 + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4"]), + "price": pa.array(list(range(nrows)), pa.int64()), + "sold": pa.array(list(range(nrows)), pa.int32()), + "deleted": pa.array([False] * nrows), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def 
test_merge_when_matched_delete_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["5", "4"]), + "weight": pa.array([1, 2], pa.int64()), + "sold": pa.array([1, 2], pa.int32()), + "deleted": pa.array([True, False]), + "customer": pa.array(["Adam", "Patrick"]), + } + ) + + dt.merge( + source=source_table, + predicate="t.id = s.id", + source_alias="s", + target_alias="t", + ).when_matched_delete("s.deleted = True").execute() + + nrows = 4 + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4"]), + "price": pa.array(list(range(nrows)), pa.int64()), + "sold": pa.array(list(range(nrows)), pa.int32()), + "deleted": pa.array([False] * nrows), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_matched_update_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + } + ) + + dt.merge( + source=source_table, + predicate="t.id = s.id", + source_alias="s", + target_alias="t", + ).when_matched_update({"price": "s.price", "sold": "s.sold"}).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([0, 1, 2, 10, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 10, 20], pa.int32()), + "deleted": pa.array([False] * 5), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_matched_update_all_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([True, True]), + "weight": pa.array([10, 15], pa.int64()), + } + ) + + dt.merge( + source=source_table, + predicate="t.id = s.id", + source_alias="s", + target_alias="t", + ).when_matched_update_all().execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([0, 1, 2, 10, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 10, 20], pa.int32()), + "deleted": pa.array([False, False, False, True, True]), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_matched_update_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, True]), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_matched_update( + updates={"price": "source.price", "sold": "source.sold"}, + predicate="source.deleted = False", + 
).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([0, 1, 2, 10, 4], pa.int64()), + "sold": pa.array([0, 1, 2, 10, 4], pa.int32()), + "deleted": pa.array([False] * 5), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_not_matched_insert_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["4", "6"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_insert( + updates={ + "id": "source.id", + "price": "source.price", + "sold": "source.sold", + "deleted": "False", + } + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5", "6"]), + "price": pa.array([0, 1, 2, 3, 4, 100], pa.int64()), + "sold": pa.array([0, 1, 2, 3, 4, 20], pa.int32()), + "deleted": pa.array([False] * 6), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_not_matched_insert_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "10"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_insert( + updates={ + "id": "source.id", + "price": "source.price", + "sold": "source.sold", + "deleted": "False", + }, + predicate="source.price < bigint'50'", + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5", "6"]), + "price": pa.array([0, 1, 2, 3, 4, 10], pa.int64()), + "sold": pa.array([0, 1, 2, 3, 4, 10], pa.int32()), + "deleted": pa.array([False] * 6), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_not_matched_insert_all_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "10"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([None, None], pa.bool_()), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_insert_all( + predicate="source.price < bigint'50'", + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5", "6"]), + "price": pa.array([0, 1, 2, 3, 4, 10], pa.int64()), + "sold": pa.array([0, 1, 2, 3, 4, 10], pa.int32()), + "deleted": pa.array([False, False, False, False, False, None]), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + 
last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_not_matched_by_source_update_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "7"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_by_source_update( + updates={ + "sold": "int'10'", + } + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([0, 1, 2, 3, 4], pa.int64()), + "sold": pa.array([10, 10, 10, 10, 10], pa.int32()), + "deleted": pa.array([False] * 5), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_not_matched_by_source_update_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "7"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_by_source_update( + updates={ + "sold": "int'10'", + }, + predicate="target.price > bigint'3'", + ).execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4", "5"]), + "price": pa.array([0, 1, 2, 3, 4], pa.int64()), + "sold": pa.array([0, 1, 2, 3, 10], pa.int32()), + "deleted": pa.array([False] * 5), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_not_matched_by_source_delete_with_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + { + "id": pa.array(["6", "7"]), + "price": pa.array([10, 100], pa.int64()), + "sold": pa.array([10, 20], pa.int32()), + "deleted": pa.array([False, False]), + } + ) + + dt.merge( + source=source_table, + source_alias="source", + target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_by_source_delete(predicate="target.price > bigint'3'").execute() + + expected = pa.table( + { + "id": pa.array(["1", "2", "3", "4"]), + "price": pa.array( + [0, 1, 2, 3], + pa.int64(), + ), + "sold": pa.array([0, 1, 2, 3], pa.int32()), + "deleted": pa.array([False] * 4), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected + + +def test_merge_when_not_matched_by_source_delete_wo_predicate( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append") + + dt = DeltaTable(tmp_path) + + source_table = pa.table( + {"id": pa.array(["4", "5"]), "weight": pa.array([1.5, 1.6], pa.float64())} + ) + + dt.merge( + source=source_table, + source_alias="source", + 
target_alias="target", + predicate="target.id = source.id", + ).when_not_matched_by_source_delete().execute() + + expected = pa.table( + { + "id": pa.array(["4", "5"]), + "price": pa.array( + [3, 4], + pa.int64(), + ), + "sold": pa.array([3, 4], pa.int32()), + "deleted": pa.array([False] * 2), + } + ) + result = dt.to_pyarrow_table().sort_by([("id", "ascending")]) + last_action = dt.history(1)[0] + + assert last_action["operation"] == "MERGE" + assert result == expected diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs index fa6f586ad0..46a2c540bf 100644 --- a/rust/src/operations/merge.rs +++ b/rust/src/operations/merge.rs @@ -60,6 +60,7 @@ use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr use datafusion_physical_expr::{create_physical_expr, expressions, PhysicalExpr}; use futures::future::BoxFuture; use parquet::file::properties::WriterProperties; +use serde::Serialize; use serde_json::{Map, Value}; use super::datafusion_utils::{into_expr, maybe_into_expr, Expression}; @@ -119,12 +120,13 @@ pub struct MergeBuilder { impl MergeBuilder { /// Create a new [`MergeBuilder`] - pub fn new( + pub fn new>( object_store: ObjectStoreRef, snapshot: DeltaTableState, - predicate: Expression, + predicate: E, source: DataFrame, ) -> Self { + let predicate = predicate.into(); Self { predicate, source, @@ -527,7 +529,7 @@ impl MergeOperationConfig { } } -#[derive(Default)] +#[derive(Default, Serialize)] /// Metrics for the Merge Operation pub struct MergeMetrics { /// Number of rows in the source data