Torch backwards compatibility fixes (#610)
* Fix compatibility with Torch without torch.cuda.amp.common

* Disable PyTorch-based activation tests pre-PyTorch 1.9.0

* Don't use gradient scaling unconditionally in PyTorch wrapper test

* Disable gradient scaling tests on older PyTorch versions

* Set minimum required PyTorch version to 1.6.0

* Check that torch>=1.9.0 for mixed-precision training

Torch versions prior to 1.9.0 do not have the functionality that we need
for mixed-precision training and gradient scaling (see the sketch after this list).

* Refine exception message for mixed-precision training
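For context, a minimal sketch of the detection logic this commit introduces. The real implementation lives in thinc/util.py (shown in the last diff below) and additionally consults torch.cuda.amp.common to rule out GPUs that cannot do mixed precision:

```python
# Sketch of the feature probe added in this commit; names mirror thinc/util.py.
from packaging.version import Version

try:
    import torch

    torch_version = Version(str(torch.__version__))
    has_torch_amp = torch_version >= Version("1.9.0")
except ImportError:
    torch_version = Version("0.0.0")
    has_torch_amp = False
```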
danieldk committed Mar 16, 2022
1 parent 85b291d commit f50905f
Showing 7 changed files with 65 additions and 13 deletions.
setup.cfg (1 addition, 1 deletion)

```diff
@@ -86,7 +86,7 @@ cuda115 =
 datasets =
     ml_datasets>=0.2.0,<0.3.0
 torch =
-    torch>=1.5.0
+    torch>=1.6.0
 tensorflow =
     tensorflow>=2.0.0,<2.6.0
 mxnet =
```
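A quick way to check that an existing environment satisfies the raised floor, as a sketch (assumes packaging and torch are importable):

```python
# Sanity-check the installed torch against the new minimum from setup.cfg.
from packaging.version import Version
import torch

if Version(str(torch.__version__)) < Version("1.6.0"):
    raise RuntimeError("thinc's torch extra now requires torch>=1.6.0")
```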
thinc/shims/pytorch.py (6 additions, 0 deletions)

```diff
@@ -13,6 +13,7 @@
     pass
 
 from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive
+from ..util import has_torch_amp
 from ..backends import get_current_ops, context_pools, CupyOps
 from ..backends import set_gpu_allocator
 from ..optimizers import Optimizer
@@ -42,6 +43,11 @@ def __init__(
         mixed_precision: bool = False,
         grad_scaler: Optional[PyTorchGradScaler] = None,
     ):
+        if mixed_precision and not has_torch_amp:
+            raise ValueError(
+                "Mixed-precision training is not supported, requires capable GPU and torch>=1.9.0"
+            )
+
         super().__init__(model, config, optimizer)
 
         if grad_scaler is None:
```
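With this guard in place, callers can key mixed_precision off the new flag rather than catching the ValueError. A hedged usage sketch:

```python
# Usage sketch: request mixed precision only when the installed torch
# supports it, so PyTorchShim's new guard never fires.
import torch.nn
from thinc.api import PyTorchWrapper_v2
from thinc.util import has_torch_amp

wrapped = PyTorchWrapper_v2(
    torch.nn.Linear(4, 4),
    mixed_precision=has_torch_amp,  # False on torch < 1.9.0 or CPU-only builds
)
```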
thinc/shims/pytorch_grad_scaler.py (7 additions, 2 deletions)

```diff
@@ -1,6 +1,6 @@
 from typing import Dict, Iterable, List, Union, cast
 
-from ..util import is_torch_array
+from ..util import has_torch_amp, is_torch_array
 
 try:
     import torch
@@ -23,7 +23,7 @@ class PyTorchGradScaler:
     def __init__(
         self,
         enabled: bool = False,
-        init_scale: float = 2.0 ** 16,
+        init_scale: float = 2.0**16,
         backoff_factor: float = 0.5,
         growth_factor: float = 2.0,
         growth_interval: int = 2000,
@@ -50,6 +50,11 @@ def __init__(
         When no overflows were found for this number of steps, the scale will
         be multiplied by "growth_factor".
         """
+        if enabled and not has_torch_amp:
+            raise ValueError(
+                "Gradient scaling is not supported, requires capable GPU and torch>=1.9.0"
+            )
+
         self._enabled = enabled
         self._growth_factor = growth_factor
         self._backoff_factor = backoff_factor
```
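The same pattern applies when constructing the scaler directly; a sketch, assuming thinc.api re-exports PyTorchGradScaler (as the tests below do):

```python
# Enable gradient scaling only where AMP support was detected, so this
# constructor never raises the new ValueError on older torch builds.
from thinc.api import PyTorchGradScaler
from thinc.util import has_torch_amp

scaler = PyTorchGradScaler(enabled=has_torch_amp, init_scale=2.0**16)
```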
thinc/tests/backends/test_ops.py (3 additions, 1 deletion)

```diff
@@ -5,9 +5,10 @@
 from hypothesis import given, settings
 from hypothesis.strategies import composite, integers
 from numpy.testing import assert_allclose
+from packaging.version import Version
 from thinc.api import NumpyOps, CupyOps, Ops, get_ops
 from thinc.api import get_current_ops, use_ops
-from thinc.util import has_torch, torch2xp, xp2torch
+from thinc.util import has_torch, torch2xp, xp2torch, torch_version
 from thinc.api import fix_random_seed
 from thinc.api import LSTM
 from thinc.types import Floats2d
@@ -1001,6 +1002,7 @@ def test_ngrams():
 
 
 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
+@pytest.mark.skipif(torch_version < Version("1.9.0"), reason="needs PyTorch 1.9.0")
 @pytest.mark.parametrize("ops", ALL_OPS)
 @pytest.mark.parametrize("dtype", ["float32", "float64"])
 @pytest.mark.parametrize("torch_func", TORCH_FUNCS)
```
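The torch_version comparison used in this skipif generalizes into a reusable marker; a sketch under the same assumptions (requires_torch_1_9 is a hypothetical name, not part of this commit):

```python
# Reusable skip marker built from the same version gate as above.
import pytest
from packaging.version import Version
from thinc.util import torch_version

requires_torch_1_9 = pytest.mark.skipif(
    torch_version < Version("1.9.0"), reason="needs PyTorch 1.9.0"
)


@requires_torch_1_9
def test_something_torch_1_9_specific():
    ...
```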
thinc/tests/layers/test_pytorch_wrapper.py (18 additions, 1 deletion)

```diff
@@ -80,7 +80,9 @@ def test_pytorch_wrapper_thinc_input(nN, nI, nO, mixed_precision):
         PyTorchWrapper_v2(
             pytorch_layer.cuda(),
             mixed_precision=mixed_precision,
-            grad_scaler=PyTorchGradScaler(enabled=True, init_scale=2.0 ** 16),
+            grad_scaler=PyTorchGradScaler(
+                enabled=mixed_precision, init_scale=2.0**16
+            ),
         ).initialize(),
     )
     # pytorch allocator is set in PyTorchShim
@@ -146,3 +148,18 @@ def test_pytorch_convert_inputs(data, n_args, kwargs_keys):
     convert_inputs = model.attrs["convert_inputs"]
     Y, backprop = convert_inputs(model, data, is_train=True)
     check_input_converters(Y, backprop, data, n_args, kwargs_keys, torch.Tensor)
+
+
+@pytest.mark.skipif(not has_torch_gpu, reason="needs PyTorch with CUDA-capable GPU")
+@pytest.mark.skipif(
+    has_torch_amp, reason="needs PyTorch without mixed-precision support"
+)
+def test_raises_on_old_pytorch():
+    import torch.nn
+
+    pytorch_layer = torch.nn.Linear(5, 5)
+    with pytest.raises(ValueError, match=r"not supported.*1.9.0"):
+        PyTorchWrapper_v2(
+            pytorch_layer.cuda(),
+            mixed_precision=True,
+        )
```
thinc/tests/shims/test_pytorch_grad_scaler.py (23 additions, 7 deletions)

```diff
@@ -2,7 +2,8 @@
 
 from hypothesis import given, settings
 from hypothesis.strategies import lists, one_of, tuples
-from thinc.util import has_torch, has_torch_gpu, is_torch_array
+from thinc.util import has_torch, has_torch_amp, has_torch_gpu
+from thinc.util import is_torch_array
 from thinc.api import PyTorchGradScaler
 
 from ..strategies import ndarrays
@@ -22,6 +23,9 @@ def tensors():
 
 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
 @pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
+@pytest.mark.skipif(
+    not has_torch_amp, reason="requires PyTorch with mixed-precision support"
+)
 @given(X=one_of(tensors(), lists(tensors()), tuples(tensors())))
 @settings(deadline=None)
 def test_scale_random_inputs(X):
@@ -32,16 +36,19 @@ def test_scale_random_inputs(X):
     scaler.to_(device_id)
 
     if is_torch_array(X):
-        assert torch.allclose(scaler.scale(X), X * 2.0 ** 16)
+        assert torch.allclose(scaler.scale(X), X * 2.0**16)
     else:
         scaled1 = scaler.scale(X)
-        scaled2 = [t * 2.0 ** 16 for t in X]
+        scaled2 = [t * 2.0**16 for t in X]
         for t1, t2 in zip(scaled1, scaled2):
             assert torch.allclose(t1, t2)
 
 
 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
 @pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
+@pytest.mark.skipif(
+    not has_torch_amp, reason="requires PyTorch with mixed-precision support"
+)
 def test_grad_scaler():
     import torch
 
@@ -53,10 +60,10 @@ def test_grad_scaler():
     # Test that scaling works as expected.
     t = torch.tensor([1.0], device=device_id)
     assert scaler.scale([torch.tensor([1.0], device=device_id)]) == [
-        torch.tensor([2.0 ** 16], device=device_id)
+        torch.tensor([2.0**16], device=device_id)
     ]
     assert scaler.scale(torch.tensor([1.0], device=device_id)) == torch.tensor(
-        [2.0 ** 16], device=device_id
+        [2.0**16], device=device_id
    )
     with pytest.raises(ValueError):
         scaler.scale("bogus")
@@ -65,7 +72,7 @@ def test_grad_scaler():
 
     # Test infinity detection.
     g = [
-        torch.tensor([2.0 ** 16], device=device_id),
+        torch.tensor([2.0**16], device=device_id),
         torch.tensor([float("Inf")], device=device_id),
     ]
 
@@ -80,5 +87,14 @@ def test_grad_scaler():
     # Since infinity was found, the scale should be halved from 2**16
     # to 2**15 for the next step.
     assert scaler.scale([torch.tensor([1.0], device=device_id)]) == [
-        torch.tensor([2.0 ** 15], device=device_id)
+        torch.tensor([2.0**15], device=device_id)
     ]
+
+
+@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
+@pytest.mark.skipif(
+    has_torch_amp, reason="needs PyTorch without gradient scaling support"
+)
+def test_raises_on_old_pytorch():
+    with pytest.raises(ValueError, match=r"not supported.*1.9.0"):
+        PyTorchGradScaler(enabled=True)
```
thinc/util.py (7 additions, 1 deletion)

```diff
@@ -1,6 +1,7 @@
 from typing import Any, Union, Sequence, cast, Dict, Optional, Callable, TypeVar
 from typing import List, Tuple
 import numpy
+from packaging.version import Version
 import random
 import functools
 from wasabi import table
@@ -31,11 +32,16 @@
 
     has_torch = True
     has_torch_gpu = torch.cuda.device_count() != 0
-    has_torch_amp = not torch.cuda.amp.common.amp_definitely_not_available()
+    torch_version = Version(str(torch.__version__))
+    has_torch_amp = (
+        torch_version >= Version("1.9.0")
+        and not torch.cuda.amp.common.amp_definitely_not_available()
+    )
 except ImportError:  # pragma: no cover
     has_torch = False
     has_torch_gpu = False
     has_torch_amp = False
+    torch_version = Version("0.0.0")
 
 try:  # pragma: no cover
     import tensorflow.experimental.dlpack
```
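Parsing torch.__version__ through packaging.version.Version, rather than comparing strings, matters here; a small illustration:

```python
# Why PEP 440 Version comparison is used instead of string comparison:
from packaging.version import Version

assert "1.10.0" < "1.9.0"                    # lexicographic order misleads
assert Version("1.10.0") > Version("1.9.0")  # version order is correct

# Wrapping in str() also handles local-version builds such as "1.9.0+cu111",
# which still satisfy the minimum:
assert Version("1.9.0+cu111") >= Version("1.9.0")
```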
