
Fully quantize Fairseq transformer #1993

Closed
wants to merge 1 commit

Conversation

@cndn (Contributor) commented Apr 10, 2020

Summary:
F.linear -> nn.Linear, so that the FBGEMM backend can quantize the linear projection. We observed a 3x+ speedup.

Backward-compatibility code is added in upgrade_state_dict_named; it works locally.

Testing: loading OSS checkpoints.

Differential Revision: D20967830
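
For readers unfamiliar with the constraint: PyTorch dynamic quantization replaces module instances such as nn.Linear with quantized equivalents, so a projection implemented as a bare F.linear call inside forward() is invisible to it. A minimal sketch of the mechanism, using a toy module rather than the actual fairseq decoder:

import torch
import torch.nn as nn

class TinyHead(nn.Module):
    """Toy stand-in for the transformer output projection."""
    def __init__(self, embed_dim: int = 512, vocab_size: int = 32000):
        super().__init__()
        # As an nn.Linear module, this layer is visible to quantize_dynamic;
        # an equivalent F.linear(x, weight) call in forward() would not be swapped.
        self.output_projection = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, x):
        return self.output_projection(x)

torch.backends.quantized.engine = "fbgemm"  # select the FBGEMM int8 kernels
qmodel = torch.quantization.quantize_dynamic(TinyHead(), {nn.Linear}, dtype=torch.qint8)
print(type(qmodel.output_projection))  # a dynamically quantized Linear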

@facebook-github-bot commented

This pull request was exported from Phabricator. Differential Revision: D20967830

cndn added a commit to cndn/fairseq that referenced this pull request Apr 10, 2020
Summary:
Pull Request resolved: facebookresearch#1993

F.linear -> nn.Linear, so that the FBGEMM backend can quantize the linear projection. We observed a 3x+ speedup.

Backward-compatibility code is added in upgrade_state_dict_named; it works locally.

Testing: loading OSS checkpoints.

Differential Revision: D20967830

fbshipit-source-id: b00abab4e40facc52ccf1af6b3f830c036071bce
@facebook-github-bot commented

This pull request was exported from Phabricator. Differential Revision: D20967830

@erip (Contributor) commented Apr 11, 2020

Nicely done! This should resolve #1943, too!

@erip (Contributor) commented Apr 11, 2020

Looks like there might be some backward-compatibility issues. I tried loading the wmt14.en-fr.joined-dict.transformer pretrained model:

My test script:

$ cat test.py
#!/usr/bin/env python

import torch

from fairseq.sequence_generator import SequenceGenerator
from fairseq.models.transformer import TransformerModel
from fairseq.data import Dictionary

if __name__ == "__main__":
    tgt_dict = Dictionary.load(open('dict.fr.txt'))
    model = TransformerModel.from_pretrained('.', 'model.pt', '.', bpe='subword_nmt', bpe_vocab='bpecodes')

    generator = SequenceGenerator(model.models, tgt_dict)

    scripted_gen = torch.jit.script(generator)
    scripted_gen.save('generator.pt')

And the error:

$ ./test.py
Traceback (most recent call last):
  File "./test.py", line 11, in <module>
    model = TransformerModel.from_pretrained('.', 'model.pt', '.', bpe='subword_nmt', bpe_vocab='bpecodes')
  File "/Users/erippeth/miniconda3/envs/fairseq-dev/lib/python3.6/site-packages/fairseq/models/fairseq_model.py", line 218, in from_pretrained
    **kwargs,
  File "/Users/erippeth/miniconda3/envs/fairseq-dev/lib/python3.6/site-packages/fairseq/hub_utils.py", line 73, in from_pretrained
    arg_overrides=kwargs,
  File "/Users/erippeth/miniconda3/envs/fairseq-dev/lib/python3.6/site-packages/fairseq/checkpoint_utils.py", line 210, in load_model_ensemble_and_task
    model.load_state_dict(state["model"], strict=strict, args=args)
  File "/Users/erippeth/miniconda3/envs/fairseq-dev/lib/python3.6/site-packages/fairseq/models/fairseq_model.py", line 93, in load_state_dict
    return super().load_state_dict(new_state_dict, strict)
  File "/Users/erippeth/miniconda3/envs/fairseq-dev/lib/python3.6/site-packages/torch/nn/modules/module.py", line 855, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for TransformerModel:
	Missing key(s) in state_dict: "decoder.output_projection.weight", "decoder.output_projection.bias".

@erip (Contributor) commented Apr 11, 2020

It looks like there's at least one issue here -- specifically, I think the property is ...embed_tokens.weight, not ...embed_tokens.weights. Once I correct that, the stack trace is down to just the missing bias term:

  File "./test.py", line 11, in <module>
    model = TransformerModel.from_pretrained('.', 'model.pt', '.', bpe='subword_nmt', bpe_vocab='bpecodes')
  File "/Users/erippeth/miniconda3/envs/fairseq-dev/lib/python3.6/site-packages/fairseq/models/fairseq_model.py", line 218, in from_pretrained
    **kwargs,
  File "/Users/erippeth/miniconda3/envs/fairseq-dev/lib/python3.6/site-packages/fairseq/hub_utils.py", line 73, in from_pretrained
    arg_overrides=kwargs,
  File "/Users/erippeth/miniconda3/envs/fairseq-dev/lib/python3.6/site-packages/fairseq/checkpoint_utils.py", line 210, in load_model_ensemble_and_task
    model.load_state_dict(state["model"], strict=strict, args=args)
  File "/Users/erippeth/miniconda3/envs/fairseq-dev/lib/python3.6/site-packages/fairseq/models/fairseq_model.py", line 93, in load_state_dict
    return super().load_state_dict(new_state_dict, strict)
  File "/Users/erippeth/miniconda3/envs/fairseq-dev/lib/python3.6/site-packages/torch/nn/modules/module.py", line 855, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for TransformerModel:
	Missing key(s) in state_dict: "decoder.output_projection.bias".

Edit:

The original F.linear projection adds no bias (the bias argument defaults to None), so it's safe to add bias=False to the newly added linear projections, too. Since I can't push commits to your repo, here's the diff:

diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py
index 98c4ab5..6b74129 100644
--- a/fairseq/models/transformer.py
+++ b/fairseq/models/transformer.py
@@ -683,11 +683,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
 
         if self.share_input_output_embed:
             self.output_projection = nn.Linear(
-                self.embed_tokens.weight.shape[1], self.embed_tokens.weight.shape[0]
+                self.embed_tokens.weight.shape[1], self.embed_tokens.weight.shape[0], bias=False
             )
         else:
             self.output_projection = nn.Linear(
-                self.output_embed_dim, len(dictionary)
+                self.output_embed_dim, len(dictionary), bias=False
             )
 
     def build_decoder_layer(self, args, no_encoder_attn=False):
@@ -891,7 +891,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 "{}.embed_positions._float_tensor".format(name)
             ] = torch.FloatTensor(1)
 
-        embed_tokens_weights_key = f"{name}.embed_tokens.weights"
+        embed_tokens_weights_key = f"{name}.embed_tokens.weight"
         embed_out_key = f"{name}.embed_out"
         if embed_tokens_weights_key in state_dict:
             state_dict[f"{name}.output_projection.weight"] = state_dict[

@cndn (Author) commented Apr 14, 2020

Thanks @erip! I will address it very soon.

Summary:
Pull Request resolved: facebookresearch#1993

F.linear -> nn.Linear, so that the FBGEMM backend can quantize the linear projection. We observed a 3x+ speedup.

Add backward-compatibility code.

Reviewed By: jhcross

Differential Revision: D20967830

fbshipit-source-id: 5a4b4c41f9c46fc06a05c50f57c249e8fcd7b1c8
@facebook-github-bot commented

This pull request was exported from Phabricator. Differential Revision: D20967830

@facebook-github-bot commented

This pull request has been merged in 6379573.

myleott added a commit that referenced this pull request Apr 20, 2020
facebook-github-bot pushed a commit that referenced this pull request Apr 21, 2020
Summary:
This reverts commit 6379573.

It doesn't tie weights and breaks old checkpoints.
Pull Request resolved: #2032

Reviewed By: cndn, ngoyal2707

Differential Revision: D21141945

Pulled By: myleott

fbshipit-source-id: b2f2ce8092a1bf8bcd6a7e422a69306e342b8cdd
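
To spell out the revert reason: a freshly constructed nn.Linear allocates its own weight matrix, while the old F.linear(x, self.embed_tokens.weight) shared the embedding parameter by construction. A minimal illustration of the tying that the merged change lost (illustrative only, not the fairseq code):

import torch.nn as nn

embed_tokens = nn.Embedding(32000, 512)
output_projection = nn.Linear(512, 32000, bias=False)

# Without the next line the two layers hold independent weights and
# train separately; sharing the Parameter restores the old behavior of
# F.linear(x, embed_tokens.weight).
output_projection.weight = embed_tokens.weight
assert output_projection.weight is embed_tokens.weight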
myleott added a commit that referenced this pull request May 11, 2020
jahutwb referenced this pull request Jul 8, 2020 (fairinternal/fairseq-py#1190)

Summary:
The main changes are in fairseq_incremental_decoder.py. I made the base `reorder_incremental_state` implementation a no-op; callers (e.g., SequenceGenerator) are now expected to call `reorder_incremental_state_scripting` instead.

Pull Request resolved: fairinternal/fairseq-py#1190

Test Plan:
I ran unit tests in both PyTorch 1.5 and the nightly (1.6).

I also tested some of the pretrained translation models, but it'd be good to test with some prod runs.

Reviewed By: jhcross

Differential Revision: D22095614

Pulled By: myleott

fbshipit-source-id: 484b8d47b4feda4efe52233a3d46a207d0816766
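
To make the new calling convention concrete, here is a rough sketch of the pattern the summary describes (hypothetical class name and simplified signatures; the real fairseq code handles TorchScript constraints omitted here):

from typing import Dict, Optional

import torch
from torch import Tensor

class IncrementalDecoderBase(torch.nn.Module):
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ) -> None:
        # Now a no-op in the base class; layers that cache state override it.
        pass

    def reorder_incremental_state_scripting(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ) -> None:
        # Scripting-friendly entry point: callers such as SequenceGenerator
        # invoke this wrapper, which delegates to every submodule that
        # implements its own reorder_incremental_state.
        for module in self.modules():
            if module is not self and hasattr(module, "reorder_incremental_state"):
                module.reorder_incremental_state(incremental_state, new_order)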
moussaKam pushed a commit to moussaKam/language-adaptive-pretraining that referenced this pull request Sep 29, 2020
Summary:
Pull Request resolved: facebookresearch#1993

F.linear -> nn.Linear, so that the FBGEMM backend can quantize the linear projection. We observed a 3x+ speedup.

Add backward-compatibility code.

Reviewed By: jhcross

Differential Revision: D20967830

fbshipit-source-id: 11d2c98dd5c1965691d6df433e8428499c9c4dc0
moussaKam pushed a commit to moussaKam/language-adaptive-pretraining that referenced this pull request Sep 29, 2020
(facebookresearch#2032)
mgaido91 pushed a commit to mgaido91/FBK-fairseq-ST that referenced this pull request Jan 12, 2021