From d0ccc3e02e1a9015d05cade8dfc61896948275c7 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Mon, 22 Jun 2020 18:53:19 -0700
Subject: [PATCH] Add FairseqDecoder.reorder_incremental_state_scripting for
 TorchScript (#1190)

Summary:
The main changes are in fairseq_incremental_decoder.py. I made the base
`reorder_incremental_state` implementation a no-op; callers (e.g.,
SequenceGenerator) are now expected to call
`reorder_incremental_state_scripting` instead (see the standalone usage
sketch after the diff below).

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1190

Test Plan:
I ran the unit tests in both PyTorch 1.5 and the 1.6 nightly. I also tested
some of the pretrained translation models, but it would be good to verify
with some production runs.

Reviewed By: jhcross

Differential Revision: D22095614

Pulled By: myleott

fbshipit-source-id: 484b8d47b4feda4efe52233a3d46a207d0816766
---
 fairseq/models/fairseq_incremental_decoder.py | 34 +++++++---
 fairseq/models/lstm.py                        | 63 +++++++++++--------
 fairseq/models/transformer.py                 | 11 ----
 fairseq/modules/transformer_layer.py          | 12 ----
 fairseq/sequence_generator.py                 |  2 +-
 5 files changed, 63 insertions(+), 59 deletions(-)

diff --git a/fairseq/models/fairseq_incremental_decoder.py b/fairseq/models/fairseq_incremental_decoder.py
index 51ab577288..68e583fea8 100644
--- a/fairseq/models/fairseq_incremental_decoder.py
+++ b/fairseq/models/fairseq_incremental_decoder.py
@@ -2,11 +2,17 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+
+import logging
 from typing import Dict, Optional
 
+from torch import Tensor
+
 from fairseq.models import FairseqDecoder
 from fairseq.incremental_decoding_utils import with_incremental_state
-from torch import Tensor
+
+
+logger = logging.getLogger(__name__)
 
 
 @with_incremental_state
@@ -68,18 +74,28 @@ def reorder_incremental_state(
     ):
         """Reorder incremental state.
 
-        This should be called when the order of the input has changed from the
+        This will be called when the order of the input has changed from the
         previous time step. A typical use case is beam search, where the input
         order changes between time steps based on the selection of beams.
         """
-        seen: Dict[int, Optional[Tensor]] = {}
-        for _, module in self.named_modules():
+        pass
+
+    def reorder_incremental_state_scripting(
+        self,
+        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        new_order: Tensor,
+    ):
+        """Main entry point for reordering the incremental state.
+
+        Due to limitations in TorchScript, we call this function in
+        :class:`fairseq.sequence_generator.SequenceGenerator` instead of
+        calling :func:`reorder_incremental_state` directly.
+        """
+        for module in self.modules():
             if hasattr(module, 'reorder_incremental_state'):
-                if id(module) not in seen and module is not self:
-                    seen[id(module)] = None
-                    result = module.reorder_incremental_state(incremental_state, new_order)
-                    if result is not None:
-                        incremental_state = result
+                result = module.reorder_incremental_state(incremental_state, new_order)
+                if result is not None:
+                    incremental_state = result
 
     def set_beam_size(self, beam_size):
         """Sets the beam size in the decoder and all children."""
diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py
index c2fbde33a4..83baf7f065 100644
--- a/fairseq/models/lstm.py
+++ b/fairseq/models/lstm.py
@@ -347,6 +347,7 @@ def forward(self, input, source_hids, encoder_padding_mask):
         x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1)))
         return x, attn_scores
 
+
 class LSTMDecoder(FairseqIncrementalDecoder):
     """LSTM decoder."""
     def __init__(
@@ -410,18 +411,6 @@ def __init__(
         elif not self.share_input_output_embed:
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
-    def get_cached_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
-        cached_state = self.get_incremental_state(incremental_state, 'cached_state')
-        assert cached_state is not None
-        prev_hiddens_ = cached_state["prev_hiddens"]
-        assert prev_hiddens_ is not None
-        prev_cells_ = cached_state["prev_cells"]
-        assert prev_cells_ is not None
-        prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)]
-        prev_cells = [prev_cells_[j] for j in range(self.num_layers)]
-        input_feed = cached_state["input_feed"]  # can be None for decoder-only language models
-        return prev_hiddens, prev_cells, input_feed
-
     def forward(
         self,
         prev_output_tokens,
@@ -529,9 +518,13 @@ def extract_features(
         prev_cells_tensor = torch.stack(prev_cells)
         cache_state = torch.jit.annotate(
             Dict[str, Optional[Tensor]],
-            {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": input_feed})
-        self.set_incremental_state(
-            incremental_state, 'cached_state', cache_state)
+            {
+                "prev_hiddens": prev_hiddens_tensor,
+                "prev_cells": prev_cells_tensor,
+                "input_feed": input_feed,
+            }
+        )
+        self.set_incremental_state(incremental_state, 'cached_state', cache_state)
 
         # collect outputs across time steps
         x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)
@@ -559,23 +552,41 @@ def output_layer(self, x):
                 x = self.fc_out(x)
         return x
 
-    def reorder_state(self, state: List[Tensor], new_order):
-        return [
-            state_i.index_select(0, new_order) if state_i is not None else None
-            for state_i in state
-        ]
+    def get_cached_state(
+        self,
+        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+    ) -> Tuple[List[Tensor], List[Tensor], Optional[Tensor]]:
+        cached_state = self.get_incremental_state(incremental_state, 'cached_state')
+        assert cached_state is not None
+        prev_hiddens_ = cached_state["prev_hiddens"]
+        assert prev_hiddens_ is not None
+        prev_cells_ = cached_state["prev_cells"]
+        assert prev_cells_ is not None
+        prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)]
+        prev_cells = [prev_cells_[j] for j in range(self.num_layers)]
+        input_feed = cached_state["input_feed"]  # can be None for decoder-only language models
+        return prev_hiddens, prev_cells, input_feed
 
-    def reorder_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_order):
+    def reorder_incremental_state(
+        self,
+        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        new_order: Tensor,
+    ):
         if incremental_state is None or len(incremental_state) == 0:
             return
         prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state)
-        cached_state = (prev_hiddens, prev_cells, [input_feed])
-        new_state = [self.reorder_state(state, new_order) for state in cached_state]
-        prev_hiddens_tensor = torch.stack(new_state[0])
-        prev_cells_tensor = torch.stack(new_state[1])
+        prev_hiddens = [p.index_select(0, new_order) for p in prev_hiddens]
+        prev_cells = [p.index_select(0, new_order) for p in prev_cells]
+        if input_feed is not None:
+            input_feed = input_feed.index_select(0, new_order)
         cached_state_new = torch.jit.annotate(
             Dict[str, Optional[Tensor]],
-            {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": new_state[2][0]})
+            {
+                "prev_hiddens": torch.stack(prev_hiddens),
+                "prev_cells": torch.stack(prev_cells),
+                "input_feed": input_feed,
+            }
+        )
         self.set_incremental_state(incremental_state, 'cached_state', cached_state_new),
         return
 
diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py
index 352db5a293..9171aaf4a2 100644
--- a/fairseq/models/transformer.py
+++ b/fairseq/models/transformer.py
@@ -882,17 +882,6 @@ def upgrade_state_dict_named(self, state_dict, name):
 
         return state_dict
 
-    # Overwrite the method to temporaily support JIT scripting in Transformer
-    @torch.jit.export
-    def reorder_incremental_state(
-        self,
-        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
-        new_order: Tensor,
-    ):
-        """Scriptable reorder incremental state in the transformer."""
-        for layer in self.layers:
-            layer.reorder_incremental_state(incremental_state, new_order)
-
 
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py
index cae8498315..8fb08b3aaf 100644
--- a/fairseq/modules/transformer_layer.py
+++ b/fairseq/modules/transformer_layer.py
@@ -379,18 +379,6 @@ def forward(
     def make_generation_fast_(self, need_attn: bool = False, **kwargs):
         self.need_attn = need_attn
 
-    @torch.jit.export
-    def reorder_incremental_state(
-        self,
-        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
-        new_order: Tensor,
-    ):
-        """Scriptable reorder incremental state in transformer layers."""
-        self.self_attn.reorder_incremental_state(incremental_state, new_order)
-
-        if self.encoder_attn is not None:
-            self.encoder_attn.reorder_incremental_state(incremental_state, new_order)
-
 
 def Linear(in_features, out_features, bias=True):
     m = nn.Linear(in_features, out_features, bias)
diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py
index 7ecdde869f..a523b1ea64 100644
--- a/fairseq/sequence_generator.py
+++ b/fairseq/sequence_generator.py
@@ -794,7 +794,7 @@ def reorder_incremental_state(self, new_order):
         if not self.has_incremental_states():
             return
         for i, model in enumerate(self.models):
-            model.decoder.reorder_incremental_state(
+            model.decoder.reorder_incremental_state_scripting(
                 self.incremental_states[i], new_order
             )
 
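
Usage sketch (not part of the patch): the standalone toy example below mirrors the calling convention introduced above, where the caller invokes `reorder_incremental_state_scripting` and the implementation walks `self.modules()` so that every child module defining `reorder_incremental_state` can reorder its own cached tensors after beam search selects new beams. The `ToyAttention`/`ToyDecoder` classes and the `"toy_attn"`/`"prev_key"` cache keys are invented for illustration and do not exist in fairseq; only the dispatch pattern and the `Dict[str, Dict[str, Optional[Tensor]]]` state type follow the diff.

from typing import Dict, Optional

import torch
from torch import Tensor
import torch.nn as nn


class ToyAttention(nn.Module):
    """Child module that caches one row of state per beam and can reorder it."""

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        state = incremental_state.get("toy_attn", {})
        for key, value in state.items():
            if value is not None:
                # Keep only the rows that correspond to the surviving beams.
                state[key] = value.index_select(0, new_order)
        incremental_state["toy_attn"] = state


class ToyDecoder(nn.Module):
    """Stands in for FairseqIncrementalDecoder in this sketch."""

    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

    def reorder_incremental_state_scripting(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        # Same dispatch loop as the patched base class: visit every submodule
        # and let each one that defines reorder_incremental_state fix up the
        # state it cached.
        for module in self.modules():
            if hasattr(module, "reorder_incremental_state"):
                result = module.reorder_incremental_state(incremental_state, new_order)
                if result is not None:
                    incremental_state = result


if __name__ == "__main__":
    decoder = ToyDecoder()
    # One cached row per beam (4 beams), values chosen so the reordering is visible.
    incremental_state: Dict[str, Dict[str, Optional[Tensor]]] = {
        "toy_attn": {"prev_key": torch.arange(4.0).unsqueeze(1)}
    }
    # Beam search kept beams [2, 2, 0, 1]; the caller (SequenceGenerator in
    # fairseq) invokes the scripting entry point, not reorder_incremental_state.
    new_order = torch.tensor([2, 2, 0, 1])
    decoder.reorder_incremental_state_scripting(incremental_state, new_order)
    print(incremental_state["toy_attn"]["prev_key"].squeeze(1))  # tensor([2., 2., 0., 1.])

In fairseq itself, modules such as MultiheadAttention and the LSTM decoder's cached hidden/cell state are reordered with the same `index_select(0, new_order)` pattern shown in the lstm.py hunk above.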