Check that torch>=1.9.0 for mixed-precision training
Torch versions prior to 1.9.0 do not have the functionality that we need
for mixed-precision training and gradient scaling.
danieldk committed Mar 15, 2022 (1 parent 8f1f8cc, commit 85ab7a5)
Showing 5 changed files with 39 additions and 4 deletions.
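The commit message above comes down to a version gate: detect the installed torch version once, and refuse to enable mixed precision or gradient scaling when it is too old. A minimal standalone sketch of that pattern, using packaging.version for the comparison; the names here are illustrative and not part of thinc's API:

from packaging.version import Version

import torch

# Mixed precision and gradient scaling rely on torch.cuda.amp functionality
# that is only complete from torch 1.9.0 onwards.
TORCH_SUPPORTS_AMP = Version(str(torch.__version__)) >= Version("1.9.0")


def require_amp() -> None:
    # Fail early, mirroring the error the shims below raise.
    if not TORCH_SUPPORTS_AMP:
        raise ValueError(
            "Mixed-precision training is not supported, upgrade to torch>=1.9.0"
        )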
thinc/shims/pytorch.py (6 changes: 6 additions & 0 deletions)
@@ -13,6 +13,7 @@
     pass
 
 from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive
+from ..util import has_torch_amp
 from ..backends import get_current_ops, context_pools, CupyOps
 from ..backends import set_gpu_allocator
 from ..optimizers import Optimizer
@@ -42,6 +43,11 @@ def __init__(
         mixed_precision: bool = False,
         grad_scaler: Optional[PyTorchGradScaler] = None,
     ):
+        if mixed_precision and not has_torch_amp:
+            raise ValueError(
+                "Mixed-precision training is not supported, upgrade to torch>=1.9.0"
+            )
+
         super().__init__(model, config, optimizer)
 
         if grad_scaler is None:
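With the guard in place, asking for mixed precision on an old torch install now fails at shim construction time. A hedged sketch of how calling code can stay compatible, assuming the usual thinc.api entry point PyTorchWrapper_v2 and the has_torch_amp flag from thinc.util; the fallback policy (silently dropping to full precision) is illustrative, not part of this commit:

import torch.nn

from thinc.api import PyTorchWrapper_v2
from thinc.util import has_torch_amp

# Request mixed precision only when the installed torch can provide it;
# otherwise the shim above would raise ValueError at construction time.
model = PyTorchWrapper_v2(
    torch.nn.Linear(5, 5),
    mixed_precision=has_torch_amp,
)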
thinc/shims/pytorch_grad_scaler.py (9 changes: 7 additions & 2 deletions)
@@ -1,6 +1,6 @@
 from typing import Dict, Iterable, List, Union, cast
 
-from ..util import is_torch_array
+from ..util import has_torch_amp, is_torch_array
 
 try:
     import torch
@@ -23,7 +23,7 @@ class PyTorchGradScaler:
     def __init__(
         self,
         enabled: bool = False,
-        init_scale: float = 2.0 ** 16,
+        init_scale: float = 2.0**16,
         backoff_factor: float = 0.5,
         growth_factor: float = 2.0,
         growth_interval: int = 2000,
@@ -50,6 +50,11 @@ def __init__(
         When no overflows were found for this number of steps, the scale will
         be multiplied by "growth_factor".
         """
+        if enabled and not has_torch_amp:
+            raise ValueError(
+                "Gradient scaling is not supported, upgrade to torch>=1.9.0"
+            )
+
         self._enabled = enabled
         self._growth_factor = growth_factor
         self._backoff_factor = backoff_factor
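The scaler gets the same treatment: constructing PyTorchGradScaler(enabled=True) on an old torch now raises. A small usage sketch that gates on has_torch_amp, assuming (as the existing implementation suggests) that a disabled scaler passes tensors through unchanged; the gradient values are placeholders:

import torch

from thinc.shims.pytorch_grad_scaler import PyTorchGradScaler
from thinc.util import has_torch_amp

# Enable scaling only where AMP is available; enabled=False avoids the new
# ValueError on older torch versions.
scaler = PyTorchGradScaler(enabled=has_torch_amp)

# scale() takes a list of tensors, as the existing test below exercises.
grads = [torch.tensor([1.0], device="cuda" if has_torch_amp else "cpu")]
scaled = scaler.scale(grads)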
thinc/tests/layers/test_pytorch_wrapper.py (15 changes: 15 additions & 0 deletions)
@@ -148,3 +148,18 @@ def test_pytorch_convert_inputs(data, n_args, kwargs_keys):
     convert_inputs = model.attrs["convert_inputs"]
     Y, backprop = convert_inputs(model, data, is_train=True)
     check_input_converters(Y, backprop, data, n_args, kwargs_keys, torch.Tensor)
+
+
+@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
+@pytest.mark.skipif(
+    has_torch_amp, reason="needs PyTorch without mixed-precision support"
+)
+def test_raises_on_old_pytorch():
+    import torch.nn
+
+    pytorch_layer = torch.nn.Linear(5, 5)
+    with pytest.raises(ValueError, match=r"not supported.*1.9.0"):
+        PyTorchWrapper_v2(
+            pytorch_layer.cuda(),
+            mixed_precision=True,
+        )
thinc/tests/shims/test_pytorch_grad_scaler.py (9 changes: 9 additions & 0 deletions)
@@ -89,3 +89,12 @@ def test_grad_scaler():
     assert scaler.scale([torch.tensor([1.0], device=device_id)]) == [
         torch.tensor([2.0**15], device=device_id)
     ]
+
+
+@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
+@pytest.mark.skipif(
+    has_torch_amp, reason="needs PyTorch without gradient scaling support"
+)
+def test_raises_on_old_pytorch():
+    with pytest.raises(ValueError, match=r"not supported.*1.9.0"):
+        PyTorchGradScaler(enabled=True)
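Both new tests hinge on feature flags exported by thinc.util. A quick way to inspect what they report on the current install; the printed values will of course depend on the local torch build:

from thinc.util import has_torch, has_torch_amp, has_torch_gpu

# has_torch: torch imported successfully.
# has_torch_gpu: at least one CUDA device is visible.
# has_torch_amp: torch>=1.9.0 and AMP is not definitely unavailable
# (this is the check added in thinc/util.py below).
print(has_torch, has_torch_gpu, has_torch_amp)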
thinc/util.py (4 changes: 2 additions & 2 deletions)
@@ -32,11 +32,11 @@
 
     has_torch = True
     has_torch_gpu = torch.cuda.device_count() != 0
+    torch_version = Version(str(torch.__version__))
     has_torch_amp = (
-        hasattr(torch.cuda.amp, "common")
+        torch_version >= Version("1.9.0")
         and not torch.cuda.amp.common.amp_definitely_not_available()
     )
-    torch_version = Version(str(torch.__version__))
 except ImportError:  # pragma: no cover
     has_torch = False
     has_torch_gpu = False
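For readability, this is roughly how the detection block in thinc/util.py reads after the commit, reassembled from the hunk above. The try/import scaffolding, the Version import, and the tail of the ImportError fallback are assumed from context rather than shown in the diff:

from packaging.version import Version

try:
    import torch

    has_torch = True
    has_torch_gpu = torch.cuda.device_count() != 0
    torch_version = Version(str(torch.__version__))
    has_torch_amp = (
        # The version gate added by this commit: AMP needs torch>=1.9.0.
        torch_version >= Version("1.9.0")
        and not torch.cuda.amp.common.amp_definitely_not_available()
    )
except ImportError:  # pragma: no cover
    has_torch = False
    has_torch_gpu = False
    has_torch_amp = False  # assumed fallback; elided in the diff above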
