From 68d39431e41384f4c8a46f9de96a32093507896d Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 25 Jul 2023 13:16:02 -0400 Subject: [PATCH 01/10] add scan to ChunkManager ABC --- xarray/core/parallelcompat.py | 36 +++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/xarray/core/parallelcompat.py b/xarray/core/parallelcompat.py index 26efc5fc412..fefc8517520 100644 --- a/xarray/core/parallelcompat.py +++ b/xarray/core/parallelcompat.py @@ -403,6 +403,42 @@ def reduction( """ raise NotImplementedError() + def scan( + self, + func: Callable, + binop: Callable, + ident: float, + arr: T_ChunkedArray, + axis: int | None = None, + dtype: np.dtype | None = None, + ): + """ + General version of a 1D scan, also known as a cumulative array reduction. + + Used in ``ffill`` and ``bfill` in xarray. + + Parameters + ---------- + func: callable + Cumulative function like np.cumsum or np.cumprod + binop: callable + Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul`` + ident: Number + Associated identity like ``np.cumsum->0`` or ``np.cumprod->1`` + arr: dask Array + axis: int, optional + dtype: dtype + + Returns + ------- + Chunked array + + See also + -------- + dask.array.cumreduction + """ + raise NotImplementedError() + @abstractmethod def apply_gufunc( self, From dff3131b639efd5e238bf7456bdf01e274c3fd16 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 25 Jul 2023 13:16:24 -0400 Subject: [PATCH 02/10] implement scan for dask using cumreduction --- xarray/core/daskmanager.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/xarray/core/daskmanager.py b/xarray/core/daskmanager.py index 56d8dc9e23a..a7757dbb382 100644 --- a/xarray/core/daskmanager.py +++ b/xarray/core/daskmanager.py @@ -97,6 +97,26 @@ def reduction( keepdims=keepdims, ) + def scan( + self, + func: Callable, + binop: Callable, + ident: float, + arr: T_ChunkedArray, + axis: int | None = None, + dtype: np.dtype | None = None, + ): + from dask.array import cumreduction + + return cumreduction( + func, + binop, + ident, + arr, + axis=axis, + dtype=dtype, + ) + def apply_gufunc( self, func: Callable, From 9cd95236e7e5d0eb8655568d3d1fe31226fc0bb5 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 25 Jul 2023 13:18:23 -0400 Subject: [PATCH 03/10] generalize push to work for non-dask chunked arrays --- xarray/core/dask_array_ops.py | 39 ----------------------------- xarray/core/duck_array_ops.py | 47 +++++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index d2d3e4a6d1c..dd6d25045d5 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -53,42 +53,3 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): # See issue dask/dask#6516 coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs) return coeffs, residuals - - -def push(array, n, axis): - """ - Dask-aware bottleneck.push - """ - import bottleneck - import dask.array as da - import numpy as np - - def _fill_with_last_one(a, b): - # cumreduction apply the push func over all the blocks first so, the only missing part is filling - # the missing values using the last data of the previous chunk - return np.where(~np.isnan(b), b, a) - - if n is not None and 0 < n < array.shape[axis] - 1: - arange = da.broadcast_to( - da.arange( - array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype - ).reshape( - tuple(size if i == axis else 1 for i, size in enumerate(array.shape)) - ), - array.shape, - array.chunks, - ) - valid_arange = da.where(da.notnull(array), arange, np.nan) - valid_limits = (arange - push(valid_arange, None, axis)) <= n - # omit the forward fill that violate the limit - return da.where(valid_limits, push(array, None, axis), np.nan) - - # The method parameter makes that the tests for python 3.7 fails. - return da.reductions.cumreduction( - func=bottleneck.push, - binop=_fill_with_last_one, - ident=np.nan, - x=array, - axis=axis, - dtype=array.dtype, - ) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 4f245e59f73..676e8bf0210 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -671,11 +671,54 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) +def _chunked_push(array, n, axis): + """ + Chunk-aware bottleneck.push + """ + import bottleneck + import numpy as np + + chunkmanager = get_chunked_array_type(array) + xp = chunkmanager.array_api + + def _fill_with_last_one(a, b): + # cumreduction apply the push func over all the blocks first so, the only missing part is filling + # the missing values using the last data of the previous chunk + return np.where(~np.isnan(b), b, a) + + if n is not None and 0 < n < array.shape[axis] - 1: + arange = xp.arange( + array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype + ) + broadcasted_arange = xp.broadcast_to( + xp.reshape( + arange, + tuple(size if i == axis else 1 for i, size in enumerate(array.shape)), + ), + array.shape, + array.chunks, + ) + valid_arange = xp.where(xp.notnull(array), broadcasted_arange, np.nan) + valid_limits = (arange - push(valid_arange, None, axis)) <= n + # omit the forward fill that violate the limit + return xp.where(valid_limits, push(array, None, axis), np.nan) + + # The method parameter makes that the tests for python 3.7 fails. + return chunkmanager.scan( + func=bottleneck.push, + binop=_fill_with_last_one, + ident=np.nan, + arr=array, + axis=axis, + dtype=array.dtype, + ) + + def push(array, n, axis): from bottleneck import push - if is_duck_dask_array(array): - return dask_array_ops.push(array, n, axis) + if is_chunked_array(array): + return _chunked_push(array, n, axis) else: return push(array, n, axis) From ca8a58aa8bd7e0d670e75ac5f627e2d13b6d4f19 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 25 Jul 2023 13:24:26 -0400 Subject: [PATCH 04/10] whatsnew --- doc/whats-new.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6b1b8b4c69b..235a4d9ef44 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -73,6 +73,10 @@ Internal Changes - :py:func:`as_variable` now consistently includes the variable name in any exceptions raised. (:pull:`7995`). By `Peter Hill `_ +- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`, + potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.ffill` to + use non-dask chunked array types. + (:pull:`8019`) By `Tom Nicholas `_. .. _whats-new.2023.07.0: From b96b112a59ab32bf8a648ede4ace5192421f1815 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Tue, 25 Jul 2023 13:59:51 -0400 Subject: [PATCH 05/10] fix importerror --- xarray/core/daskmanager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/daskmanager.py b/xarray/core/daskmanager.py index a7757dbb382..c1567cedd8c 100644 --- a/xarray/core/daskmanager.py +++ b/xarray/core/daskmanager.py @@ -106,7 +106,7 @@ def scan( axis: int | None = None, dtype: np.dtype | None = None, ): - from dask.array import cumreduction + from dask.array.reductions import cumreduction return cumreduction( func, From 828321b100b02d795b2db3265d8e50385a5d38ea Mon Sep 17 00:00:00 2001 From: Tom Nicholas Date: Wed, 26 Jul 2023 09:40:32 -0400 Subject: [PATCH 06/10] Allow arbitrary kwargs Co-authored-by: Deepak Cherian --- xarray/core/daskmanager.py | 2 ++ xarray/core/parallelcompat.py | 1 + 2 files changed, 3 insertions(+) diff --git a/xarray/core/daskmanager.py b/xarray/core/daskmanager.py index c1567cedd8c..153c2825426 100644 --- a/xarray/core/daskmanager.py +++ b/xarray/core/daskmanager.py @@ -105,6 +105,7 @@ def scan( arr: T_ChunkedArray, axis: int | None = None, dtype: np.dtype | None = None, + **kwargs, ): from dask.array.reductions import cumreduction @@ -115,6 +116,7 @@ def scan( arr, axis=axis, dtype=dtype, + **kwargs, ) def apply_gufunc( diff --git a/xarray/core/parallelcompat.py b/xarray/core/parallelcompat.py index fefc8517520..4d02f623272 100644 --- a/xarray/core/parallelcompat.py +++ b/xarray/core/parallelcompat.py @@ -411,6 +411,7 @@ def scan( arr: T_ChunkedArray, axis: int | None = None, dtype: np.dtype | None = None, + **kwargs, ): """ General version of a 1D scan, also known as a cumulative array reduction. From b5eb57d1b1d50646d9d2757dafba024b85a4e9b2 Mon Sep 17 00:00:00 2001 From: Tom Nicholas Date: Wed, 26 Jul 2023 09:41:17 -0400 Subject: [PATCH 07/10] Type hint return value of T_ChunkedArray Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/parallelcompat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/parallelcompat.py b/xarray/core/parallelcompat.py index 4d02f623272..12afa44d9b5 100644 --- a/xarray/core/parallelcompat.py +++ b/xarray/core/parallelcompat.py @@ -412,7 +412,7 @@ def scan( axis: int | None = None, dtype: np.dtype | None = None, **kwargs, - ): + ) -> T_ChunkedArray: """ General version of a 1D scan, also known as a cumulative array reduction. From 9fb84611c336543987676a6068b9c14797055808 Mon Sep 17 00:00:00 2001 From: Tom Nicholas Date: Wed, 26 Jul 2023 09:42:09 -0400 Subject: [PATCH 08/10] Type hint return value of Dask array --- xarray/core/daskmanager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/daskmanager.py b/xarray/core/daskmanager.py index 153c2825426..efa04bc3df2 100644 --- a/xarray/core/daskmanager.py +++ b/xarray/core/daskmanager.py @@ -106,7 +106,7 @@ def scan( axis: int | None = None, dtype: np.dtype | None = None, **kwargs, - ): + ) -> DaskArray: from dask.array.reductions import cumreduction return cumreduction( From 1c3b59d496deb3a6a832ef6ef17a1ab9b9325887 Mon Sep 17 00:00:00 2001 From: Tom Nicholas Date: Tue, 12 Dec 2023 08:30:38 -0700 Subject: [PATCH 09/10] ffill -> bfill in doc/whats-new.rst Co-authored-by: Deepak Cherian --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 235a4d9ef44..88351694498 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -74,7 +74,7 @@ Internal Changes - :py:func:`as_variable` now consistently includes the variable name in any exceptions raised. (:pull:`7995`). By `Peter Hill `_ - Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`, - potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.ffill` to + potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to use non-dask chunked array types. (:pull:`8019`) By `Tom Nicholas `_. From 43551ee15dd87d259b82b92c8f32901aee4a993e Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Wed, 13 Dec 2023 17:34:40 -0500 Subject: [PATCH 10/10] hopefully fix docs warning --- xarray/core/parallelcompat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/parallelcompat.py b/xarray/core/parallelcompat.py index 0cadfbe3409..37542925dde 100644 --- a/xarray/core/parallelcompat.py +++ b/xarray/core/parallelcompat.py @@ -416,7 +416,7 @@ def scan( """ General version of a 1D scan, also known as a cumulative array reduction. - Used in ``ffill`` and ``bfill` in xarray. + Used in ``ffill`` and ``bfill`` in xarray. Parameters ----------