Skip to content

Commit

Permalink
simplify __array_ufunc__ check
Browse files Browse the repository at this point in the history
  • Loading branch information
TomNicholas committed Sep 12, 2022
1 parent 67d7efc commit cdcb3fb
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions xarray/core/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .common import ImplementsArrayReduce, ImplementsDatasetReduce
from .ops import IncludeCumMethods, IncludeNumpySameMethods, IncludeReduceMethods
from .options import OPTIONS, _get_keep_attrs
from .pycompat import dask_array_type
from .pycompat import is_duck_array


class SupportsArithmetic:
Expand All @@ -33,20 +33,21 @@ class SupportsArithmetic:

# TODO: allow extending this with some sort of registration system
_HANDLED_TYPES = (
np.ndarray,
np.generic,
numbers.Number,
bytes,
str,
) + dask_array_type
)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
from .computation import apply_ufunc

# See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin.
out = kwargs.get("out", ())
for x in inputs + out:
if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)):
if not is_duck_array(x) and not isinstance(
x, self._HANDLED_TYPES + (SupportsArithmetic,)
):
return NotImplemented

if ufunc.signature is not None:
Expand Down

0 comments on commit cdcb3fb

Please sign in to comment.