Cherry-pick and merge reqd. changes from explosion#646
shadeMe committed May 17, 2022
1 parent 7e67c12 commit 3c43448
Showing 3 changed files with 19 additions and 3 deletions.
thinc/backends/cupy_ops.py (3 changes: 2 additions & 1 deletion)
@@ -20,7 +20,8 @@
 from . import _custom_kernels
 from ..types import DeviceTypes
 from ..util import torch2xp, tensorflow2xp, mxnet2xp
-from ..util import is_torch_array, is_tensorflow_array, is_mxnet_array
+from ..util import is_torch_gpu_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
+from ..util import is_cupy_array
 
 
 @registry.ops("CupyOps")
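
The import swap above narrows CupyOps to GPU-resident framework arrays: is_torch_array and friends are true for any tensor, while the new is_*_gpu_array predicates are only true when the data already lives on the device. The hunk truncates the CupyOps code that consumes these imports, so the following dispatch sketch is illustrative only (the function name asarray_gpu and the cupy.asarray fallback are assumptions, not code from this commit):

import cupy

from thinc.util import (
    is_cupy_array,
    is_mxnet_gpu_array,
    is_tensorflow_gpu_array,
    is_torch_gpu_array,
    mxnet2xp,
    tensorflow2xp,
    torch2xp,
)


def asarray_gpu(data):
    # Already a cupy array: return it unchanged.
    if is_cupy_array(data):
        return data
    # GPU-resident framework tensors: convert device-to-device,
    # avoiding a round-trip through host memory.
    if is_torch_gpu_array(data):
        return torch2xp(data)
    if is_tensorflow_gpu_array(data):
        return tensorflow2xp(data)
    if is_mxnet_gpu_array(data):
        return mxnet2xp(data)
    # Anything else (e.g. numpy arrays): copy to the current device.
    return cupy.asarray(data)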
thinc/tests/layers/test_tensorflow_wrapper.py (4 changes: 2 additions & 2 deletions)
@@ -2,7 +2,7 @@
 import pytest
 from thinc.api import Adam, ArgsKwargs, Linear, Model, TensorFlowWrapper
 from thinc.api import get_current_ops, keras_subclass, tensorflow2xp, xp2tensorflow
-from thinc.util import has_cupy, has_tensorflow, to_categorical
+from thinc.util import has_tensorflow, to_categorical, gpu_is_available
 
 from ..util import check_input_converters, make_tempdir
 
@@ -362,7 +362,7 @@ def test_tensorflow_wrapper_to_cpu(tf_model):
 
 
 @pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
-@pytest.mark.skipif(not has_cupy, reason="needs cupy")
+@pytest.mark.skipif(not gpu_is_available(), reason="needs GPU/cupy")
 def test_tensorflow_wrapper_to_gpu(model, X):
     model.to_gpu(0)
 
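
The marker change fixes a real gap: has_cupy only records that importing cupy succeeded, which also holds on machines without a visible CUDA device, where model.to_gpu(0) would still fail. gpu_is_available() additionally asks the CUDA runtime whether a device actually exists. A quick way to see the difference (illustrative snippet, not part of the diff):

from thinc.util import has_cupy, gpu_is_available

if has_cupy and not gpu_is_available():
    print("cupy is importable, but no CUDA device is visible")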
thinc/util.py (15 changes: 15 additions & 0 deletions)
@@ -71,6 +71,9 @@ def get_array_module(arr):  # pragma: no cover
 
 
 def gpu_is_available():
+    if not has_cupy:
+        return False
+
     try:
         cupy.cuda.runtime.getDeviceCount()
         return True
@@ -124,6 +127,10 @@ def is_torch_array(obj: Any) -> bool:  # pragma: no cover
         return False
 
 
+def is_torch_gpu_array(obj: Any) -> bool:  # pragma: no cover
+    return is_torch_array(obj) and obj.is_cuda
+
+
 def is_tensorflow_array(obj: Any) -> bool:  # pragma: no cover
     if not has_tensorflow:
         return False
@@ -133,6 +140,10 @@ def is_tensorflow_array(obj: Any) -> bool:  # pragma: no cover
         return False
 
 
+def is_tensorflow_gpu_array(obj: Any) -> bool:  # pragma: no cover
+    return is_tensorflow_array(obj) and "GPU:" in obj.device
+
+
 def is_mxnet_array(obj: Any) -> bool:  # pragma: no cover
     if not has_mxnet:
         return False
@@ -142,6 +153,10 @@ def is_mxnet_array(obj: Any) -> bool:  # pragma: no cover
         return False
 
 
+def is_mxnet_gpu_array(obj: Any) -> bool:  # pragma: no cover
+    return is_mxnet_array(obj) and obj.context.device_type != "cpu"
+
+
 def to_numpy(data):  # pragma: no cover
     if isinstance(data, numpy.ndarray):
         return data
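
The first hunk cuts off gpu_is_available() before its exception handler, so for context, the complete helper after this patch plausibly reads as follows (the CUDARuntimeError branch is inferred from cupy's behaviour when no driver or device is present; it is not shown in the hunk):

def gpu_is_available():
    # Without cupy there is no CUDA binding at all, so bail out early
    # instead of touching cupy.cuda and raising a NameError.
    if not has_cupy:
        return False

    try:
        # getDeviceCount() raises CUDARuntimeError when the driver is
        # missing or no device is visible.
        cupy.cuda.runtime.getDeviceCount()
        return True
    except cupy.cuda.runtime.CUDARuntimeError:
        return False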

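For reference, the new predicates degrade gracefully: each one first delegates to its is_*_array counterpart, which returns False when the framework is not installed. A small usage sketch (assumes thinc is importable; the torch lines additionally need torch and a CUDA device, so they are left commented):

import numpy
from thinc.util import is_torch_array, is_torch_gpu_array

x = numpy.zeros((2, 2))
assert not is_torch_array(x)      # not a torch tensor
assert not is_torch_gpu_array(x)  # hence not a torch GPU tensor either

# import torch
# t = torch.zeros(2, 2)
# assert is_torch_array(t) and not is_torch_gpu_array(t)  # CPU tensor
# assert is_torch_gpu_array(t.cuda())                     # moved to the GPU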