From 8b9b86235092ac7888901dadfc3aa255a61350a7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Mon, 14 Mar 2022 11:05:29 +0100
Subject: [PATCH 1/7] Fix compatibility with Torch without torch.cuda.amp.common

---
 thinc/util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/thinc/util.py b/thinc/util.py
index ec2930d16..faae1a307 100644
--- a/thinc/util.py
+++ b/thinc/util.py
@@ -31,7 +31,7 @@
 
     has_torch = True
     has_torch_gpu = torch.cuda.device_count() != 0
-    has_torch_amp = not torch.cuda.amp.common.amp_definitely_not_available()
+    has_torch_amp = hasattr(torch.cuda.amp, "common") and not torch.cuda.amp.common.amp_definitely_not_available()
 except ImportError:  # pragma: no cover
     has_torch = False
     has_torch_gpu = False

From ec0da84f96fb652a11086623b5c8a6c4688e76d3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Mon, 14 Mar 2022 12:22:39 +0100
Subject: [PATCH 2/7] Disable PyTorch-based activation tests pre-PyTorch 1.9.0

---
 thinc/tests/backends/test_ops.py | 4 +++-
 thinc/util.py                    | 8 +++++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py
index e1f9c8274..6fe8c32b0 100644
--- a/thinc/tests/backends/test_ops.py
+++ b/thinc/tests/backends/test_ops.py
@@ -5,9 +5,10 @@
 from hypothesis import given, settings
 from hypothesis.strategies import composite, integers
 from numpy.testing import assert_allclose
+from packaging.version import Version
 from thinc.api import NumpyOps, CupyOps, Ops, get_ops
 from thinc.api import get_current_ops, use_ops
-from thinc.util import has_torch, torch2xp, xp2torch
+from thinc.util import has_torch, torch2xp, xp2torch, torch_version
 from thinc.api import fix_random_seed
 from thinc.api import LSTM
 from thinc.types import Floats2d
@@ -1001,6 +1002,7 @@ def test_ngrams():
 
 
 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
+@pytest.mark.skipif(torch_version < Version("1.9.0"), reason="needs PyTorch 1.9.0")
 @pytest.mark.parametrize("ops", ALL_OPS)
 @pytest.mark.parametrize("dtype", ["float32", "float64"])
 @pytest.mark.parametrize("torch_func", TORCH_FUNCS)
diff --git a/thinc/util.py b/thinc/util.py
index faae1a307..591bed7f4 100644
--- a/thinc/util.py
+++ b/thinc/util.py
@@ -1,6 +1,7 @@
 from typing import Any, Union, Sequence, cast, Dict, Optional, Callable, TypeVar
 from typing import List, Tuple
 import numpy
+from packaging.version import Version
 import random
 import functools
 from wasabi import table
@@ -31,11 +32,16 @@
 
     has_torch = True
     has_torch_gpu = torch.cuda.device_count() != 0
-    has_torch_amp = hasattr(torch.cuda.amp, "common") and not torch.cuda.amp.common.amp_definitely_not_available()
+    has_torch_amp = (
+        hasattr(torch.cuda.amp, "common")
+        and not torch.cuda.amp.common.amp_definitely_not_available()
+    )
+    torch_version = Version(str(torch.__version__))
 except ImportError:  # pragma: no cover
     has_torch = False
     has_torch_gpu = False
     has_torch_amp = False
+    torch_version = Version("0.0.0")
 
 try:  # pragma: no cover
     import tensorflow.experimental.dlpack

From 6adb0b8b149fc01ac3aee4797e6d367b918e60d5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 15 Mar 2022 09:30:37 +0100
Subject: [PATCH 3/7] Don't use gradient scaling unconditionally in PyTorch wrapper test

---
 thinc/tests/layers/test_pytorch_wrapper.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/thinc/tests/layers/test_pytorch_wrapper.py b/thinc/tests/layers/test_pytorch_wrapper.py
index fef27b7e7..ce3b6ae8d 100644
--- a/thinc/tests/layers/test_pytorch_wrapper.py
+++ b/thinc/tests/layers/test_pytorch_wrapper.py
@@ -80,7 +80,9 @@ def test_pytorch_wrapper_thinc_input(nN, nI, nO, mixed_precision):
         PyTorchWrapper_v2(
             pytorch_layer.cuda(),
             mixed_precision=mixed_precision,
-            grad_scaler=PyTorchGradScaler(enabled=True, init_scale=2.0 ** 16),
+            grad_scaler=PyTorchGradScaler(
+                enabled=mixed_precision, init_scale=2.0**16
+            ),
         ).initialize(),
     )
     # pytorch allocator is set in PyTorchShim

From 57da4ce7ec70332d9d37216c1b0a9aabd46c1f24 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 15 Mar 2022 09:30:37 +0100
Subject: [PATCH 4/7] Disable gradient scaling tests on older PyTorch versions

---
 thinc/tests/shims/test_pytorch_grad_scaler.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/thinc/tests/shims/test_pytorch_grad_scaler.py b/thinc/tests/shims/test_pytorch_grad_scaler.py
index 9d96bab23..8a02b7c80 100644
--- a/thinc/tests/shims/test_pytorch_grad_scaler.py
+++ b/thinc/tests/shims/test_pytorch_grad_scaler.py
@@ -2,7 +2,8 @@
 from hypothesis import given, settings
 from hypothesis.strategies import lists, one_of, tuples
 
-from thinc.util import has_torch, has_torch_gpu, is_torch_array
+from thinc.util import has_torch, has_torch_amp, has_torch_gpu
+from thinc.util import is_torch_array
 from thinc.api import PyTorchGradScaler
 
 from ..strategies import ndarrays
@@ -22,6 +23,9 @@ def tensors():
 
 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
 @pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
+@pytest.mark.skipif(
+    not has_torch_amp, reason="requires PyTorch with mixed-precision support"
+)
 @given(X=one_of(tensors(), lists(tensors()), tuples(tensors())))
 @settings(deadline=None)
 def test_scale_random_inputs(X):
@@ -32,16 +36,19 @@
     scaler.to_(device_id)
 
     if is_torch_array(X):
-        assert torch.allclose(scaler.scale(X), X * 2.0 ** 16)
+        assert torch.allclose(scaler.scale(X), X * 2.0**16)
     else:
         scaled1 = scaler.scale(X)
-        scaled2 = [t * 2.0 ** 16 for t in X]
+        scaled2 = [t * 2.0**16 for t in X]
         for t1, t2 in zip(scaled1, scaled2):
             assert torch.allclose(t1, t2)
 
 
 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
 @pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
+@pytest.mark.skipif(
+    not has_torch_amp, reason="requires PyTorch with mixed-precision support"
+)
 def test_grad_scaler():
     import torch
 
@@ -53,10 +60,10 @@
     # Test that scaling works as expected.
     t = torch.tensor([1.0], device=device_id)
     assert scaler.scale([torch.tensor([1.0], device=device_id)]) == [
-        torch.tensor([2.0 ** 16], device=device_id)
+        torch.tensor([2.0**16], device=device_id)
     ]
     assert scaler.scale(torch.tensor([1.0], device=device_id)) == torch.tensor(
-        [2.0 ** 16], device=device_id
+        [2.0**16], device=device_id
     )
     with pytest.raises(ValueError):
         scaler.scale("bogus")
@@ -65,7 +72,7 @@
 
     # Test infinity detection.
     g = [
-        torch.tensor([2.0 ** 16], device=device_id),
+        torch.tensor([2.0**16], device=device_id),
         torch.tensor([float("Inf")], device=device_id),
     ]
 
@@ -80,5 +87,5 @@
     # Since infinity was found, the scale should be halved from 2**16
     # to 2**15 for the next step.
     assert scaler.scale([torch.tensor([1.0], device=device_id)]) == [
-        torch.tensor([2.0 ** 15], device=device_id)
+        torch.tensor([2.0**15], device=device_id)
     ]

From 8f1f8ccdf2350b8f4a1aa0c86276040442e66993 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 15 Mar 2022 09:30:37 +0100
Subject: [PATCH 5/7] Set minimum required PyTorch version to 1.6.0

---
 setup.cfg | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.cfg b/setup.cfg
index aa0457d7c..59b69419a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -86,7 +86,7 @@ cuda115 =
 datasets =
     ml_datasets>=0.2.0,<0.3.0
 torch =
-    torch>=1.5.0
+    torch>=1.6.0
 tensorflow =
     tensorflow>=2.0.0,<2.6.0
 mxnet =

From 53089c8868d28e3734b35210c2eaeb423fbd3d19 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 15 Mar 2022 09:30:37 +0100
Subject: [PATCH 6/7] Check that torch>=1.9.0 for mixed-precision training

Torch versions prior to 1.9.0 do not have the functionality that we
need for mixed-precision training and gradient scaling.
---
 thinc/shims/pytorch.py                        |  6 ++++++
 thinc/shims/pytorch_grad_scaler.py            |  9 +++++++--
 thinc/tests/layers/test_pytorch_wrapper.py    | 15 +++++++++++++++
 thinc/tests/shims/test_pytorch_grad_scaler.py |  9 +++++++++
 thinc/util.py                                 |  4 ++--
 5 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/thinc/shims/pytorch.py b/thinc/shims/pytorch.py
index ae3908cbb..1ed890ce2 100644
--- a/thinc/shims/pytorch.py
+++ b/thinc/shims/pytorch.py
@@ -13,6 +13,7 @@
     pass
 
 from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive
+from ..util import has_torch_amp
 from ..backends import get_current_ops, context_pools, CupyOps
 from ..backends import set_gpu_allocator
 from ..optimizers import Optimizer
@@ -42,6 +43,11 @@ def __init__(
         mixed_precision: bool = False,
         grad_scaler: Optional[PyTorchGradScaler] = None,
     ):
+        if mixed_precision and not has_torch_amp:
+            raise ValueError(
+                "Mixed-precision training is not supported, upgrade to torch>=1.9.0"
+            )
+
         super().__init__(model, config, optimizer)
 
         if grad_scaler is None:
diff --git a/thinc/shims/pytorch_grad_scaler.py b/thinc/shims/pytorch_grad_scaler.py
index 2a2f0ccf6..7eb63cb87 100644
--- a/thinc/shims/pytorch_grad_scaler.py
+++ b/thinc/shims/pytorch_grad_scaler.py
@@ -1,6 +1,6 @@
 from typing import Dict, Iterable, List, Union, cast
 
-from ..util import is_torch_array
+from ..util import has_torch_amp, is_torch_array
 
 try:
     import torch
@@ -23,7 +23,7 @@ class PyTorchGradScaler:
     def __init__(
         self,
         enabled: bool = False,
-        init_scale: float = 2.0 ** 16,
+        init_scale: float = 2.0**16,
         backoff_factor: float = 0.5,
         growth_factor: float = 2.0,
         growth_interval: int = 2000,
@@ -50,6 +50,11 @@
             When no overflows were found for this number of steps, the
             scale will be multiplied by "growth_factor".
         """
+        if enabled and not has_torch_amp:
+            raise ValueError(
+                "Gradient scaling is not supported, upgrade to torch>=1.9.0"
+            )
+
         self._enabled = enabled
         self._growth_factor = growth_factor
         self._backoff_factor = backoff_factor
diff --git a/thinc/tests/layers/test_pytorch_wrapper.py b/thinc/tests/layers/test_pytorch_wrapper.py
index ce3b6ae8d..6028c8390 100644
--- a/thinc/tests/layers/test_pytorch_wrapper.py
+++ b/thinc/tests/layers/test_pytorch_wrapper.py
@@ -148,3 +148,18 @@ def test_pytorch_convert_inputs(data, n_args, kwargs_keys):
     convert_inputs = model.attrs["convert_inputs"]
     Y, backprop = convert_inputs(model, data, is_train=True)
     check_input_converters(Y, backprop, data, n_args, kwargs_keys, torch.Tensor)
+
+
+@pytest.mark.skipif(not has_torch_gpu, reason="needs PyTorch with CUDA-capable GPU")
+@pytest.mark.skipif(
+    has_torch_amp, reason="needs PyTorch without mixed-precision support"
+)
+def test_raises_on_old_pytorch():
+    import torch.nn
+
+    pytorch_layer = torch.nn.Linear(5, 5)
+    with pytest.raises(ValueError, match=r"not supported.*1.9.0"):
+        PyTorchWrapper_v2(
+            pytorch_layer.cuda(),
+            mixed_precision=True,
+        )
diff --git a/thinc/tests/shims/test_pytorch_grad_scaler.py b/thinc/tests/shims/test_pytorch_grad_scaler.py
index 8a02b7c80..fd4c3e132 100644
--- a/thinc/tests/shims/test_pytorch_grad_scaler.py
+++ b/thinc/tests/shims/test_pytorch_grad_scaler.py
@@ -89,3 +89,12 @@ def test_grad_scaler():
     assert scaler.scale([torch.tensor([1.0], device=device_id)]) == [
         torch.tensor([2.0**15], device=device_id)
     ]
+
+
+@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
+@pytest.mark.skipif(
+    has_torch_amp, reason="needs PyTorch without gradient scaling support"
+)
+def test_raises_on_old_pytorch():
+    with pytest.raises(ValueError, match=r"not supported.*1.9.0"):
+        PyTorchGradScaler(enabled=True)
diff --git a/thinc/util.py b/thinc/util.py
index 591bed7f4..3c6416efa 100644
--- a/thinc/util.py
+++ b/thinc/util.py
@@ -32,11 +32,11 @@
 
     has_torch = True
     has_torch_gpu = torch.cuda.device_count() != 0
+    torch_version = Version(str(torch.__version__))
     has_torch_amp = (
-        hasattr(torch.cuda.amp, "common")
+        torch_version >= Version("1.9.0")
         and not torch.cuda.amp.common.amp_definitely_not_available()
     )
-    torch_version = Version(str(torch.__version__))
 except ImportError:  # pragma: no cover
     has_torch = False
     has_torch_gpu = False

From 60ec8318e7dac6690ea56c9cc56d750e861017ae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 15 Mar 2022 09:30:37 +0100
Subject: [PATCH 7/7] Refine exception message for mixed-precision training

---
 thinc/shims/pytorch.py             | 2 +-
 thinc/shims/pytorch_grad_scaler.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/thinc/shims/pytorch.py b/thinc/shims/pytorch.py
index 1ed890ce2..efa3d0467 100644
--- a/thinc/shims/pytorch.py
+++ b/thinc/shims/pytorch.py
@@ -45,7 +45,7 @@ def __init__(
     ):
         if mixed_precision and not has_torch_amp:
             raise ValueError(
-                "Mixed-precision training is not supported, upgrade to torch>=1.9.0"
+                "Mixed-precision training is not supported, requires capable GPU and torch>=1.9.0"
             )
 
         super().__init__(model, config, optimizer)
diff --git a/thinc/shims/pytorch_grad_scaler.py b/thinc/shims/pytorch_grad_scaler.py
index 7eb63cb87..1b7d6ffff 100644
--- a/thinc/shims/pytorch_grad_scaler.py
+++ b/thinc/shims/pytorch_grad_scaler.py
@@ -52,7 +52,7 @@ def __init__(
         """
         if enabled and not has_torch_amp:
             raise ValueError(
-                "Gradient scaling is not supported, upgrade to torch>=1.9.0"
+                "Gradient scaling is not supported, requires capable GPU and torch>=1.9.0"
            )
 
         self._enabled = enabled