Commit d0ccc3e: Add FairseqDecoder.reorder_incremental_state_scripting for TorchScript (#1190)

Summary:
The main changes are in fairseq_incremental_decoder.py. I made the base `reorder_incremental_state` implementation a no-op and instead we expect callers (e.g., SequenceGenerator) to call `reorder_incremental_state_scripting`.
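
For illustration, a minimal self-contained sketch of the dispatch pattern this commit introduces. The toy module names (`ToyAttention`, `ToyDecoder`) are hypothetical; only the loop inside `reorder_incremental_state_scripting` mirrors the new code in the diff below.

```python
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor


class ToyAttention(nn.Module):
    """Hypothetical submodule that caches per-batch state."""

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        for inner in incremental_state.values():
            for name, value in inner.items():
                if value is not None:
                    # Reorder the batch dimension of every cached tensor.
                    inner[name] = value.index_select(0, new_order)


class ToyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

    def reorder_incremental_state_scripting(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        # Same dispatch loop as the new method in fairseq_incremental_decoder.py:
        # forward the reorder request to every child module that implements it.
        for module in self.modules():
            if hasattr(module, "reorder_incremental_state") and module is not self:
                result = module.reorder_incremental_state(incremental_state, new_order)
                if result is not None:
                    incremental_state = result


decoder = ToyDecoder()
state: Dict[str, Dict[str, Optional[Tensor]]] = {"attn": {"prev_key": torch.randn(3, 4)}}
decoder.reorder_incremental_state_scripting(state, torch.tensor([2, 0, 1]))
```

In the real code path, `SequenceGenerator` makes this call once per model in the ensemble, as the one-line change in `fairseq/sequence_generator.py` below shows.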

Pull Request resolved: fairinternal/fairseq-py#1190

Test Plan:
I ran unit tests both in PyTorch 1.5 and nightly (1.6).

I also tested some of the pretrained translation models, but it'd be good to test with some prod runs.

Reviewed By: jhcross

Differential Revision: D22095614

Pulled By: myleott

fbshipit-source-id: 484b8d47b4feda4efe52233a3d46a207d0816766
myleott authored and facebook-github-bot committed Jun 23, 2020
1 parent d5d2cf3 commit d0ccc3e
Showing 5 changed files with 63 additions and 59 deletions.
34 changes: 25 additions & 9 deletions fairseq/models/fairseq_incremental_decoder.py
@@ -2,11 +2,17 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import logging
 from typing import Dict, Optional

+from torch import Tensor
+
 from fairseq.models import FairseqDecoder
 from fairseq.incremental_decoding_utils import with_incremental_state
-from torch import Tensor
+
+
+logger = logging.getLogger(__name__)


 @with_incremental_state
@@ -68,18 +74,28 @@ def reorder_incremental_state(
     ):
         """Reorder incremental state.

-        This should be called when the order of the input has changed from the
+        This will be called when the order of the input has changed from the
         previous time step. A typical use case is beam search, where the input
         order changes between time steps based on the selection of beams.
         """
-        seen: Dict[int, Optional[Tensor]] = {}
-        for _, module in self.named_modules():
+        pass
+
+    def reorder_incremental_state_scripting(
+        self,
+        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        new_order: Tensor,
+    ):
+        """Main entry point for reordering the incremental state.
+
+        Due to limitations in TorchScript, we call this function in
+        :class:`fairseq.sequence_generator.SequenceGenerator` instead of
+        calling :func:`reorder_incremental_state` directly.
+        """
+        for module in self.modules():
             if hasattr(module, 'reorder_incremental_state'):
-                if id(module) not in seen and module is not self:
-                    seen[id(module)] = None
-                    result = module.reorder_incremental_state(incremental_state, new_order)
-                    if result is not None:
-                        incremental_state = result
+                result = module.reorder_incremental_state(incremental_state, new_order)
+                if result is not None:
+                    incremental_state = result

     def set_beam_size(self, beam_size):
         """Sets the beam size in the decoder and all children."""
63 changes: 37 additions & 26 deletions fairseq/models/lstm.py
@@ -347,6 +347,7 @@ def forward(self, input, source_hids, encoder_padding_mask):
         x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1)))
         return x, attn_scores

+
 class LSTMDecoder(FairseqIncrementalDecoder):
     """LSTM decoder."""
     def __init__(
@@ -410,18 +411,6 @@ def __init__(
         elif not self.share_input_output_embed:
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

-    def get_cached_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
-        cached_state = self.get_incremental_state(incremental_state, 'cached_state')
-        assert cached_state is not None
-        prev_hiddens_ = cached_state["prev_hiddens"]
-        assert prev_hiddens_ is not None
-        prev_cells_ = cached_state["prev_cells"]
-        assert prev_cells_ is not None
-        prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)]
-        prev_cells = [prev_cells_[j] for j in range(self.num_layers)]
-        input_feed = cached_state["input_feed"]  # can be None for decoder-only language models
-        return prev_hiddens, prev_cells, input_feed
-
     def forward(
         self,
         prev_output_tokens,
@@ -529,9 +518,13 @@ def extract_features(
         prev_cells_tensor = torch.stack(prev_cells)
         cache_state = torch.jit.annotate(
             Dict[str, Optional[Tensor]],
-            {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": input_feed})
-        self.set_incremental_state(
-            incremental_state, 'cached_state', cache_state)
+            {
+                "prev_hiddens": prev_hiddens_tensor,
+                "prev_cells": prev_cells_tensor,
+                "input_feed": input_feed,
+            }
+        )
+        self.set_incremental_state(incremental_state, 'cached_state', cache_state)

         # collect outputs across time steps
         x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)
@@ -559,23 +552,41 @@ def output_layer(self, x):
                 x = self.fc_out(x)
         return x

-    def reorder_state(self, state: List[Tensor], new_order):
-        return [
-            state_i.index_select(0, new_order) if state_i is not None else None
-            for state_i in state
-        ]
+    def get_cached_state(
+        self,
+        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+    ) -> Tuple[List[Tensor], List[Tensor], Optional[Tensor]]:
+        cached_state = self.get_incremental_state(incremental_state, 'cached_state')
+        assert cached_state is not None
+        prev_hiddens_ = cached_state["prev_hiddens"]
+        assert prev_hiddens_ is not None
+        prev_cells_ = cached_state["prev_cells"]
+        assert prev_cells_ is not None
+        prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)]
+        prev_cells = [prev_cells_[j] for j in range(self.num_layers)]
+        input_feed = cached_state["input_feed"]  # can be None for decoder-only language models
+        return prev_hiddens, prev_cells, input_feed

-    def reorder_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_order):
+    def reorder_incremental_state(
+        self,
+        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        new_order: Tensor,
+    ):
         if incremental_state is None or len(incremental_state) == 0:
             return
         prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state)
-        cached_state = (prev_hiddens, prev_cells, [input_feed])
-        new_state = [self.reorder_state(state, new_order) for state in cached_state]
-        prev_hiddens_tensor = torch.stack(new_state[0])
-        prev_cells_tensor = torch.stack(new_state[1])
+        prev_hiddens = [p.index_select(0, new_order) for p in prev_hiddens]
+        prev_cells = [p.index_select(0, new_order) for p in prev_cells]
+        if input_feed is not None:
+            input_feed = input_feed.index_select(0, new_order)
         cached_state_new = torch.jit.annotate(
             Dict[str, Optional[Tensor]],
-            {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": new_state[2][0]})
+            {
+                "prev_hiddens": torch.stack(prev_hiddens),
+                "prev_cells": torch.stack(prev_cells),
+                "input_feed": input_feed,
+            }
+        )
         self.set_incremental_state(incremental_state, 'cached_state', cached_state_new),
         return

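The LSTM reordering above reduces to `index_select` over the batch dimension of each cached tensor. A tiny illustration with made-up values (not fairseq code):

```python
import torch

# During beam search, new_order maps each surviving hypothesis to the batch/beam
# row whose cached state it should inherit; index_select(0, new_order) applies
# that mapping to every cached tensor (hidden states, cells, input feed).
prev_hiddens = torch.arange(6.0).view(3, 2)   # 3 hypotheses, hidden size 2
new_order = torch.tensor([0, 0, 2])           # hypothesis 1 is replaced by a copy of 0
print(prev_hiddens.index_select(0, new_order))
# tensor([[0., 1.],
#         [0., 1.],
#         [4., 5.]])
```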
11 changes: 0 additions & 11 deletions fairseq/models/transformer.py
@@ -882,17 +882,6 @@ def upgrade_state_dict_named(self, state_dict, name):

         return state_dict

-    # Overwrite the method to temporaily support JIT scripting in Transformer
-    @torch.jit.export
-    def reorder_incremental_state(
-        self,
-        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
-        new_order: Tensor,
-    ):
-        """Scriptable reorder incremental state in the transformer."""
-        for layer in self.layers:
-            layer.reorder_incremental_state(incremental_state, new_order)
-

 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
12 changes: 0 additions & 12 deletions fairseq/modules/transformer_layer.py
@@ -379,18 +379,6 @@ def forward(
     def make_generation_fast_(self, need_attn: bool = False, **kwargs):
         self.need_attn = need_attn

-    @torch.jit.export
-    def reorder_incremental_state(
-        self,
-        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
-        new_order: Tensor,
-    ):
-        """Scriptable reorder incremental state in transformer layers."""
-        self.self_attn.reorder_incremental_state(incremental_state, new_order)
-
-        if self.encoder_attn is not None:
-            self.encoder_attn.reorder_incremental_state(incremental_state, new_order)
-

 def Linear(in_features, out_features, bias=True):
     m = nn.Linear(in_features, out_features, bias)
2 changes: 1 addition & 1 deletion fairseq/sequence_generator.py
@@ -794,7 +794,7 @@ def reorder_incremental_state(self, new_order):
         if not self.has_incremental_states():
             return
         for i, model in enumerate(self.models):
-            model.decoder.reorder_incremental_state(
+            model.decoder.reorder_incremental_state_scripting(
                 self.incremental_states[i], new_order
             )


4 comments on commit d0ccc3e

@jahutwb commented on d0ccc3e Jul 8, 2020

Hi!
Could you explain how you tested some of the pretrained translation models?
Were you able to export these models to TorchScript?
I tried @erip's test script from #1993, but it raised:

RuntimeError:
Arguments for call are not valid.

followed by a list of invalid argument signatures, and then:

The original call is:
  File "/data/home/j.borkowski/fairseq/fairseq/sequence_generator.py", line 729

            attn: Optional[Tensor] = None
            decoder_len = len(decoder_out)
                          ~~~ <--- HERE
            if decoder_len > 1 and decoder_out[1] is not None:
                if isinstance(decoder_out[1], Tensor):

Is it finally possible to export a transformer model to TorchScript?

@myleott (Contributor, Author) commented on d0ccc3e Jul 8, 2020

Yes, unfortunately full scripting depends on functionality that will be in the upcoming PyTorch 1.6 release (e.g., being able to script the len() function). You can use the PyTorch nightly build and it should work.
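
A small sketch of a version guard one could add before attempting to script the generator; the `packaging` dependency and the exact 1.6 cutoff are assumptions based on this reply, not a check fairseq itself performs.

```python
import torch
from packaging import version  # assumption: the packaging library is available

# Fail fast with a clear message instead of the TorchScript compiler error
# quoted above; per this thread, full scripting needs PyTorch >= 1.6.
installed = version.parse(torch.__version__.split("+")[0])
if installed < version.parse("1.6.0"):
    raise RuntimeError(
        "Scripting fairseq generation reportedly requires PyTorch >= 1.6, "
        "found %s" % torch.__version__
    )
```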

@jahutwb commented on d0ccc3e Jul 9, 2020

Thanks a lot!
It worked for @erip's example as well as for my iwslt14 de-en model, but it failed when I tried to script my quantized de-en model:

RuntimeError:
Module 'PQEmbedding' has no attribute 'weight'
(This attribute exists on the Python module, but it's an @property method. @property methods are not yet supported in TorchScript. File a feature request on Github):
  File "/data/home/j.borkowski/fairseq/fairseq/modules/quantization/pq/modules/qemb.py", line 70
    def forward(self, input):
        return F.embedding(
            input, self.weight, self.padding_idx, self.max_norm,
                   ~~~~~~~~~~~ <--- HERE
            self.norm_type, self.scale_grad_by_freq, self.sparse)
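
A minimal standalone illustration of this failure mode (a toy module, not fairseq's actual PQEmbedding implementation): TorchScript at the time could not resolve an @property accessed inside a scripted forward.

```python
import torch
import torch.nn as nn


class PropWeight(nn.Module):
    """Toy module that exposes its parameter through an @property, like the error above."""

    def __init__(self):
        super().__init__()
        self._packed = nn.Parameter(torch.randn(4, 3))

    @property
    def weight(self) -> torch.Tensor:
        return self._packed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight  # the property access is what scripting trips on


m = PropWeight()
print(m(torch.randn(2, 4)).shape)  # eager mode works: torch.Size([2, 3])
# torch.jit.script(m)  # expected to fail with an error like the one quoted above
```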

@kalyangvs (Contributor) commented on d0ccc3e

Used PyTorch 1.6 with the same script and encountered: torch.nn.modules.module.ModuleAttributeError: 'ModuleList' object has no attribute 'pad'
Probably caused here.
@myleott
