Revert "Fully quantize Fairseq transformer (#1993)" (#2032)
Summary:
This reverts commit 6379573.

It doesn't tie weights and breaks old checkpoints.
Pull Request resolved: #2032

Reviewed By: cndn, ngoyal2707

Differential Revision: D21141945

Pulled By: myleott

fbshipit-source-id: b2f2ce8092a1bf8bcd6a7e422a69306e342b8cdd
myleott authored and facebook-github-bot committed Apr 21, 2020
1 parent 6379573 commit ec57664
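
For context on the summary above: "tying" here means the decoder's output projection and its input embedding share one weight tensor, which is what the restored F.linear(features, self.embed_tokens.weight) path in the diff below does. A minimal PyTorch sketch of the difference, with hypothetical names and sizes (not fairseq code):

import torch.nn as nn
import torch.nn.functional as F

embed_tokens = nn.Embedding(1000, 64)  # hypothetical vocab size and embedding dim

# Tied: the projection reads the embedding parameter directly, so there is
# exactly one weight tensor shared by the input and output sides.
def tied_output_layer(features):
    return F.linear(features, embed_tokens.weight)

# Not tied: nn.Linear registers its own, independently initialised weight --
# which appears to be what the reverted code did even when
# share_input_output_embed was set.
output_projection = nn.Linear(64, 1000, bias=False)
assert output_projection.weight is not embed_tokens.weight
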
Showing 1 changed file with 9 additions and 23 deletions.
32 changes: 9 additions & 23 deletions fairseq/models/transformer.py
@@ -669,6 +669,11 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
                 factor=args.adaptive_softmax_factor,
                 tie_proj=args.tie_adaptive_proj,
             )
+        elif not self.share_input_output_embed:
+            self.embed_out = nn.Parameter(
+                torch.Tensor(len(dictionary), self.output_embed_dim)
+            )
+            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)

         if args.decoder_normalize_before and not getattr(
             args, "no_decoder_final_norm", False
@@ -681,16 +686,6 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
         else:
             self.layernorm_embedding = None

-        if self.share_input_output_embed:
-            self.output_projection = nn.Linear(
-                self.embed_tokens.weight.shape[1], self.embed_tokens.weight.shape[0], bias=False
-            )
-        else:
-            self.output_projection = nn.Linear(
-                self.output_embed_dim, len(dictionary), bias=False
-            )
-            nn.init.normal_(self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5)
-
     def build_decoder_layer(self, args, no_encoder_attn=False):
         return TransformerDecoderLayer(args, no_encoder_attn)

@@ -857,7 +852,10 @@ def output_layer(self, features):
         """Project features to the vocabulary size."""
         if self.adaptive_softmax is None:
             # project back to size of vocabulary
-            return self.output_projection(features)
+            if self.share_input_output_embed:
+                return F.linear(features, self.embed_tokens.weight)
+            else:
+                return F.linear(features, self.embed_out)
         else:
             return features

@@ -892,18 +890,6 @@ def upgrade_state_dict_named(self, state_dict, name):
                 "{}.embed_positions._float_tensor".format(name)
             ] = torch.FloatTensor(1)

-        embed_tokens_weights_key = f"{name}.embed_tokens.weights"
-        embed_out_key = f"{name}.embed_out"
-        if embed_tokens_weights_key in state_dict:
-            state_dict[f"{name}.output_projection.weight"] = state_dict[
-                embed_tokens_weights_key
-            ]
-        if embed_out_key in state_dict:
-            state_dict[f"{name}.output_projection.weight"] = state_dict[
-                embed_out_key
-            ]
-            del state_dict[embed_out_key]
-
         for i in range(self.num_layers):
             # update layer norms
             layer_norm_map = {
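
A companion sketch of the "breaks old checkpoints" part, assuming older checkpoints were written by a decoder that tied weights and therefore stored no separate output-projection parameter (hypothetical classes, not fairseq code):

import torch.nn as nn
import torch.nn.functional as F

VOCAB, DIM = 8, 4  # hypothetical sizes

class TiedDecoder(nn.Module):
    """Old behaviour: the output projection reuses embed_tokens.weight."""
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(VOCAB, DIM)

    def output_layer(self, features):
        return F.linear(features, self.embed_tokens.weight)

class UntiedDecoder(nn.Module):
    """Reverted behaviour: a separate Linear module with its own weight."""
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(VOCAB, DIM)
        self.output_projection = nn.Linear(DIM, VOCAB, bias=False)

    def output_layer(self, features):
        return self.output_projection(features)

# A checkpoint saved with tied weights only contains 'embed_tokens.weight';
# loading it into the untied module leaves output_projection.weight at its
# random initialisation unless the state dict is remapped first.
old_ckpt = TiedDecoder().state_dict()
missing, unexpected = UntiedDecoder().load_state_dict(old_ckpt, strict=False)
print(missing)  # ['output_projection.weight']

Incidentally, the deleted upgrade hunk above checks for f"{name}.embed_tokens.weights", while nn.Embedding registers its parameter as embed_tokens.weight, so the tied-weight branch of that remap appears never to fire.
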
