
Update some models for AMP training #104

Merged: 13 commits, Aug 10, 2020
CHANGELOG.md: 4 changes (4 additions, 0 deletions)
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Fixed
+
+- `CopyNet` and `SimpleSeq2Seq` models now work with AMP.
+
 ## [v1.1.0rc2](https://github.com/allenai/allennlp-models/releases/tag/v1.1.0rc2) - 2020-07-31
 
 ### Changed
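For context on the changelog entry above: the user-facing switch these fixes target is the trainer's `use_amp` option, which the new tests below exercise through `train_model_from_file`. A minimal, hedged sketch of the equivalent call (the config and output paths are placeholders, not files from this repo):

```python
from allennlp.commands.train import train_model_from_file

# Hedged sketch, not part of this PR: train a seq2seq config with AMP enabled on GPU 0.
# "path/to/config.jsonnet" and "path/to/output_dir" are placeholder paths.
train_model_from_file(
    "path/to/config.jsonnet",
    "path/to/output_dir",
    overrides="{'trainer.use_amp': true, 'trainer.cuda_device': 0}",
)
```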
allennlp_models/generation/models/copynet_seq2seq.py: 10 changes (7 additions, 3 deletions)
@@ -336,9 +336,13 @@ def _decoder_step(
         # shape: (group_size, decoder_input_dim)
         projected_decoder_input = self._input_projection_layer(decoder_input)
 
-        state["decoder_hidden"], state["decoder_context"] = self._decoder_cell(
-            projected_decoder_input, (state["decoder_hidden"], state["decoder_context"])
-        )
+        # TODO (epwalsh): remove the autocast(False) once torch's AMP is working for LSTMCells.
+        with torch.cuda.amp.autocast(False):
+            state["decoder_hidden"], state["decoder_context"] = self._decoder_cell(
+                projected_decoder_input.float(),
+                (state["decoder_hidden"].float(), state["decoder_context"].float()),
+            )
 
         return state
 
     def _get_generation_scores(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
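Reviewer note: a standalone sketch of the pattern this hunk applies, assuming a CUDA device and torch 1.6-style AMP. The projection and cell here are stand-ins, not the actual CopyNet modules; only the autocast(False) plus `.float()` structure mirrors the diff.

```python
import torch

# Minimal illustration (assumed shapes/modules, not CopyNet itself): run one LSTMCell step
# in float32 inside an otherwise-autocast forward pass.
proj = torch.nn.Linear(16, 16).cuda()   # stands in for the model's input projection layer
cell = torch.nn.LSTMCell(16, 32).cuda()
x = torch.randn(4, 16, device="cuda")
h = torch.zeros(4, 32, device="cuda")
c = torch.zeros(4, 32, device="cuda")

with torch.cuda.amp.autocast():          # the surrounding forward pass may run in float16
    projected = proj(x)                   # produced in half precision under autocast
    with torch.cuda.amp.autocast(False):  # locally disable autocast for the LSTMCell
        h, c = cell(projected.float(), (h.float(), c.float()))  # cast inputs back to float32
```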
allennlp_models/generation/models/simple_seq2seq.py: 9 changes (6 additions, 3 deletions)
@@ -433,9 +433,12 @@ def _prepare_output_projections(
 
         # shape (decoder_hidden): (batch_size, decoder_output_dim)
         # shape (decoder_context): (batch_size, decoder_output_dim)
-        decoder_hidden, decoder_context = self._decoder_cell(
-            decoder_input, (decoder_hidden, decoder_context)
-        )
+
+        # TODO (epwalsh): remove the autocast(False) once torch's AMP is working for RNNs.
+        with torch.cuda.amp.autocast(False):
+            decoder_hidden, decoder_context = self._decoder_cell(
+                decoder_input.float(), (decoder_hidden.float(), decoder_context.float())
+            )
 
         state["decoder_hidden"] = decoder_hidden
         state["decoder_context"] = decoder_context
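The same change is applied to the decoder net in lstm_cell.py below. As a side note on why the explicit `.float()` casts are needed at all: disabling autocast does not change the dtype of tensors that were already computed in half precision, so they have to be cast back by hand. A small sketch of that behavior (assumes a CUDA device; the module and names are illustrative):

```python
import torch

# Illustration (assumed module/names, needs CUDA): autocast(False) stops *new* ops from
# running in float16, but tensors already produced under autocast stay half precision,
# hence the manual .float() casts in these diffs.
linear = torch.nn.Linear(8, 8).cuda()
x = torch.randn(2, 8, device="cuda")

with torch.cuda.amp.autocast():
    y = linear(x)                         # autocast runs this in float16
    assert y.dtype == torch.float16

    with torch.cuda.amp.autocast(False):
        assert y.dtype == torch.float16   # still half: disabling autocast does not cast back
        y32 = y.float()                    # the explicit cast the models perform
        assert y32.dtype == torch.float32
```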
allennlp_models/generation/modules/decoder_nets/lstm_cell.py: 8 changes (5 additions, 3 deletions)
@@ -126,9 +126,11 @@ def forward(
 
         # shape (decoder_hidden): (batch_size, decoder_output_dim)
         # shape (decoder_context): (batch_size, decoder_output_dim)
-        decoder_hidden, decoder_context = self._decoder_cell(
-            decoder_input, (decoder_hidden, decoder_context)
-        )
+        # TODO (epwalsh): remove the autocast(False) once torch's AMP is working for RNNs.
+        with torch.cuda.amp.autocast(False):
+            decoder_hidden, decoder_context = self._decoder_cell(
+                decoder_input.float(), (decoder_hidden.float(), decoder_context.float())
+            )
 
         return (
             {"decoder_hidden": decoder_hidden, "decoder_context": decoder_context},
tests/generation/models/copynet_test.py: 32 changes (30 additions, 2 deletions)
@@ -1,8 +1,10 @@
 import numpy as np
 from scipy.special import logsumexp
 import torch
+import pytest
 
-from allennlp.common.testing import ModelTestCase
+from allennlp.commands.train import train_model_from_file
+from allennlp.common.testing import ModelTestCase, requires_gpu
 
 from allennlp_models.generation import CopyNetDatasetReader, CopyNetSeq2Seq  # noqa: F401
 from tests import FIXTURES_ROOT
@@ -16,9 +18,35 @@ def setup_method(self):
             FIXTURES_ROOT / "generation" / "copynet" / "data" / "copyover.tsv",
         )
 
-    def test_model_can_train_save_load_predict(self):
+    def test_model_can_train_save_load(self):
         self.ensure_model_can_train_save_and_load(self.param_file, tolerance=1e-2)
 
+    @requires_gpu
+    def test_model_can_train_with_amp(self):
+        train_model_from_file(
+            self.param_file,
+            self.TEST_DIR,
+            overrides="{'trainer.use_amp':true,'trainer.cuda_device':0}",
+        )
+
+        # NOTE: as of writing this test, AMP does not work with RNNs and LSTMCells. Hence we had
+        # to wrap the call to LSTMCell() in CopyNet (and other models) within an autocast(False) context.
+        # But if this part of the test fails, i.e. a RuntimeError is never raised,
+        # that means AMP may be working now with RNNs, in which case we can remove
+        # any calls to `autocast(False)` around RNNs like we do in CopyNet.
+        # So just do a grep search for uses of 'autocast(False)' or 'autocast(enabled=False)'
+        # in the library.
+        # If you're still confused, contact @epwalsh.
+        with pytest.raises(RuntimeError, match="expected scalar type Half but found Float"):
+            rnn = torch.nn.LSTMCell(10, 20).cuda()
+
+            hx = torch.rand((3, 20), device="cuda")
+            cx = torch.rand((3, 20), device="cuda")
+            inp = torch.rand((3, 10), device="cuda")
+
+            with torch.cuda.amp.autocast(True):
+                hx, cx = rnn(inp, (hx, cx))
+
     def test_vocab(self):
         vocab = self.model.vocab
         assert vocab.get_vocab_size(self.model._target_namespace) == 8
tests/generation/models/simple_seq2seq_test.py: 11 changes (10 additions, 1 deletion)
@@ -3,7 +3,8 @@
 import numpy
 import torch
 
-from allennlp.common.testing import ModelTestCase
+from allennlp.commands.train import train_model_from_file
+from allennlp.common.testing import ModelTestCase, requires_gpu
 from allennlp.nn.beam_search import BeamSearch
 from allennlp.nn.util import sequence_cross_entropy_with_logits
 
@@ -21,6 +22,14 @@ def setup_method(self):
     def test_model_can_train_save_and_load(self):
         self.ensure_model_can_train_save_and_load(self.param_file, tolerance=1e-2)
 
+    @requires_gpu
+    def test_model_can_train_with_amp(self):
+        train_model_from_file(
+            self.param_file,
+            self.TEST_DIR,
+            overrides="{'trainer.use_amp':true,'trainer.cuda_device':0}",
+        )
+
     def test_bidirectional_model_can_train_save_and_load(self):
         param_overrides = json.dumps({"model": {"encoder": {"bidirectional": True}}})
         self.ensure_model_can_train_save_and_load(