Commit

Add scan.

dcherian committed Jul 31, 2024
1 parent 88c5dc4 commit 179cbce
Showing 1 changed file with 119 additions and 1 deletion.

cubed/core/ops.py: 119 additions & 1 deletion
@@ -5,7 +5,7 @@
from itertools import product
from numbers import Integral, Number
from operator import add
from typing import TYPE_CHECKING, Any, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, Sequence, Union
from warnings import warn

import ndindex
@@ -22,6 +22,7 @@
from cubed.core.plan import Plan, new_temp_path
from cubed.primitive.blockwise import blockwise as primitive_blockwise
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
from cubed.primitive.blockwise import key_to_slices
from cubed.primitive.rechunk import rechunk as primitive_rechunk
from cubed.spec import spec_from_config
from cubed.storage.backend import open_backend_array
@@ -1442,3 +1443,120 @@ def smallest_blockdim(blockdims):
m = ntd[0]
out = ntd
return out


def wrapper_binop(
out: np.ndarray,
left: Array,
right: Array,
*,
binop: Callable,
block_id: tuple[int, ...],
axis: int,
identity: Any,
) -> np.ndarray:
# Combine one block of the blockwise scan (`left`) with the cumulative
# increment for that block (`right`) using `binop`.
left_slicer = key_to_slices(block_id, left)
right_slicer = list(left_slicer)

# For the first block, we add the identity element
# For all other blocks `k`, we add the `k-1` element along `axis`
right_slicer[axis] = slice(block_id[axis] - 1, block_id[axis])
right_slicer = tuple(right_slicer)
right_ = right[right_slicer] if block_id[axis] > 0 else identity
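# e.g. for block_id == (0, 3) and axis == 1, `right_slicer` picks element 2 of
# `right` along the scanned axis: the increment accumulated over blocks 0..2.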
# print("left", left[left_slicer].shape)
# print("right", right_.shape)
return binop(left[left_slicer], right_)


def scan(
array: "Array",
func: Callable,
*,
preop: Callable,
binop: Callable,
identity: Any,
axis: int,
dtype=None,
) -> Array:
"""
Generic parallel scan.
Parameters
----------
x: Cubed Array
func: callable
Scan or cumulative function like np.cumsum or np.cumprod
preop: callable
Function applied blockwise that reduces each block to a single value
along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``.
binop: callable
Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
identity: Any
Associated identity element more scan like 0 for ``np.cumsum`` and 1 for ``np.cumprod``.
axis: int
dtype: dtype
Notes
-----
This method uses a variant of the Blelloch (1989) alogrithm.
Returns
-------
Array
See also
--------
cumsum
cumprod
"""
# Blelloch (1990) out-of-core algorithm.
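# Worked example (one row, blocks of width 2 along the scanned axis,
# func=np.cumsum, preop=np.sum, binop=add, identity=0):
#   values              [1, 2 | 3, 4 | 5, 6]
#   1. blockwise scan   [1, 3 | 3, 7 | 5, 11]
#   2. per-block reduce [3 | 7 | 11]
#   3. scan of the sums [3, 10, 21]   (the per-block increments)
#   4. add increment k-1 to block k -> [1, 3 | 6, 10 | 15, 21]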
# 1. First, scan blockwise
scanned = blockwise(func, "ij", array, "ij", axis=axis)
# If there is only a single chunk, we are done
if array.numblocks[-1] == 1:
return scanned

# 2. Calculate the blockwise reduction using `preop`
# TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned`
reduced = blockwise(
preop, "ij", array, "ij", axis=axis, adjust_chunks={"j": 1}, keepdims=True
)
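# `reduced` keeps one block per block of `array` along `axis`, each collapsed
# to length 1 and holding that block's `preop` result (e.g. its sum).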

# 3. Now scan `reduced` to generate the increments for each block of `scanned`.
# Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan.
# Instead we recursively apply the scan to `reduced`.
# 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1
new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * 5)
new_chunks = reduced.chunksize[:-1] + (new_chunksize,)
merged = merge_chunks(reduced, new_chunks)
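# e.g. with 20 blocks along `axis`, `reduced` has 20 length-1 chunks; merging
# yields 4 chunks of length 5, so the recursion bottoms out after a few levels.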

# 3b. Recursively scan this merged array to generate the increment for each block of `scanned`
increment = scan(
merged, func, preop=preop, binop=binop, identity=identity, axis=axis
)
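# `increment` element k is the running total of the per-block `preop` results
# for blocks 0..k, so block k of `scanned` needs element k-1 added (block 0
# gets the identity instead, see `wrapper_binop`).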

# 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`.
# Use map_direct since the chunks of increment and scanned aren't aligned anymore.
assert increment.shape[axis] == scanned.numblocks[axis]
# 5. Bada-bing, bada-boom.
return map_direct(
partial(wrapper_binop, binop=binop, axis=axis, identity=identity),
scanned,
increment,
shape=scanned.shape,
dtype=scanned.dtype,
chunks=scanned.chunks,
extra_projected_mem=scanned.chunkmem * 2, # arbitrary
)


# Example usage (cumulative sum along the last axis):
# result = scan(
#     array, func=np.cumsum, preop=np.sum, binop=np.add, identity=0, axis=-1
# )
# np.testing.assert_equal(result.compute(), np.cumsum(array.compute(), axis=-1))
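# A public wrapper could be layered on top of ``scan`` along these lines (a
# sketch only, not part of this commit; the name ``cumsum`` is assumed):
#
# def cumsum(array, axis=-1):
#     return scan(
#         array, np.cumsum, preop=np.sum, binop=np.add, identity=0, axis=axis
#     )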
