Skip to content

Commit

Permalink
NumpyOps: Better type-casting in asarray (#656)
Browse files Browse the repository at this point in the history
* `NumpyOps`: Better type-casting in `asarray`

* Simplify `dtype` check

* Update thinc/backends/numpy_ops.pyx

Co-authored-by: Adriane Boyd <[email protected]>

* Simplify casting further, avoid copies if possible

* Remove no-op

Co-authored-by: Adriane Boyd <[email protected]>
  • Loading branch information
shadeMe and adrianeboyd authored May 17, 2022
1 parent abbe0ff commit d2d7917
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions thinc/backends/numpy_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,20 @@ class NumpyOps(Ops):

def asarray(self, data, dtype=None):
if isinstance(data, self.xp.ndarray):
if dtype is not None:
return self.xp.asarray(data, dtype=dtype)
else:
return self.xp.asarray(data)
array = data
elif hasattr(data, 'numpy'):
# Handles PyTorch Tensor
return data.numpy()
array = data.numpy()
elif hasattr(data, "get"):
return data.get()
elif dtype is not None:
return self.xp.array(data, dtype=dtype)
array = data.get()
else:
return self.xp.array(data)
array = self.xp.array(data)

if dtype is not None:
array = array.astype(dtype=dtype, copy=False)

return array


def alloc(self, shape: Shape, *, dtype: Optional[DTypes] = "float32", zeros: bool = True) -> ArrayXd:
if zeros:
Expand Down

0 comments on commit d2d7917

Please sign in to comment.