Remove TiedEmbeddingTransformerDecoder (and GemmaTransformerDecoder) #1454

Closed

SalmanMohammadi opened this issue Aug 29, 2024 · 2 comments

Labels: better engineering, community help wanted

SalmanMohammadi (Collaborator) commented on Aug 29, 2024

Credit to @pbontrager for this.

To quote from #1447:


Also just going to throw this out there: @pbontrager had suggested that we can do away with TiedEmbeddingTransformerDecoder entirely and instead do something like (crappy diff snippet from one of our model builders as an example)

    tok_embeddings = nn.Embedding(vocab_size, embed_dim)

    # How we do things for non-tied-embedding models (e.g. Llama)
    # output_proj = nn.Linear(embed_dim, vocab_size, bias=False)

    # What things would look like for Gemma or Qwen2 instead
    output_proj = lambda x: F.linear(x, tok_embeddings.weight)

    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_proj,
    )

to build tied-embedding models directly in our TransformerDecoder class (though possibly without a lambda, and maybe with a proper function). The main open question is whether this works with FSDP and checkpointing.
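For what it's worth, here is one hedged sketch of the "proper function" route: a tiny nn.Module that holds a reference to the embedding and reuses its weight in forward. The TiedLinear name and the whole class are illustrative, not an existing API, and whether this composes cleanly with FSDP and checkpointing is exactly the open question above.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TiedLinear(nn.Module):
        """Illustrative output projection that reuses an embedding's weight."""

        def __init__(self, tok_embeddings: nn.Embedding):
            super().__init__()
            # Keep a reference to the embedding module (not a copy of the weight)
            # so the projection always sees the current, shared parameter.
            self.tok_embeddings = tok_embeddings

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # logits = x @ E.T, same as the F.linear call in the snippet above
            return F.linear(x, self.tok_embeddings.weight)

    tok_embeddings = nn.Embedding(32000, 256)
    output_proj = TiedLinear(tok_embeddings)
    logits = output_proj(torch.randn(2, 8, 256))  # shape: [2, 8, 32000]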


To provide some background: currently, for models with tied embedding and output projection weights, such as Gemma and Qwen2, we define an entirely separate decoder class. This class is identical to TransformerDecoder except for the final lines of its forward method:

    # in TransformerDecoder

    # shape: [b, s, out_dim] - out_dim is usually the vocab size
    output = self.output(h).float()
    return output

    # in TiedEmbeddingTransformerDecoder and GemmaTransformerDecoder

    # shape: [b, s, out_dim] - out_dim is usually the vocab size
    output = F.linear(h, self.tok_embeddings.weight).float()
    ...
    return output
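As a quick standalone illustration of what the tied path computes (a sketch, not code from the repo): F.linear(h, E) is just h @ E.T, so the logits reuse the embedding matrix as the output projection weight.

    import torch
    import torch.nn.functional as F

    b, s, embed_dim, vocab_size = 2, 4, 16, 100
    tok_embeddings = torch.nn.Embedding(vocab_size, embed_dim)
    h = torch.randn(b, s, embed_dim)  # final decoder hidden states

    # Tied projection: reuse the embedding matrix E as the output weight,
    # giving logits = h @ E.T with shape [b, s, vocab_size]
    logits = F.linear(h, tok_embeddings.weight)
    assert logits.shape == (b, s, vocab_size)
    assert torch.allclose(logits, h @ tok_embeddings.weight.T)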

Making this change should be conceptually very straightforward, but will touch several parts of the codebase. Off the top of my head, and in no particular order:

  1. Try parameterizing the tied output projection using a callable: either a lambda as above, something like

         output_proj = partial(torch.nn.functional.linear, weight=tok_embeddings.weight)

     (binding the weight by keyword so the positional input slot of F.linear stays free), or a proper function. Do this for a single model first (e.g. Qwen2 0.5B); see the sketch after this list.
  2. Test out a recipe with this builder and make sure things work OK - they should! Also make sure one of our distributed recipes works okay (we can help out here).
  3. Extend the change to any other models using either GemmaTransformerDecoder or TiedEmbeddingTransformerDecoder. All of these component and model builders should now just construct and return a TransformerDecoder.
  4. CTRL+F GemmaTransformerDecoder and TiedEmbeddingTransformerDecoder. Eradicate. Docs, tests, __init__.pys. Be ruthless.
  5. Probably something else here.
  6. $$$
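As a rough sketch of step 1: the builder name qwen2_tied_builder and its parameter list below are made up for illustration, and the TransformerDecoder/RMSNorm keyword arguments are assumed to match the snippet quoted above rather than any particular torchtune version.

    from functools import partial

    import torch.nn as nn
    import torch.nn.functional as F

    from torchtune.modules import RMSNorm, TransformerDecoder

    def qwen2_tied_builder(
        vocab_size, embed_dim, num_layers, max_seq_len, num_heads, head_dim, norm_eps, layer
    ):
        tok_embeddings = nn.Embedding(vocab_size, embed_dim)

        # Bind the tied weight by keyword so the positional input slot of
        # F.linear stays free: output_proj(h) == F.linear(h, tok_embeddings.weight)
        output_proj = partial(F.linear, weight=tok_embeddings.weight)

        return TransformerDecoder(
            tok_embeddings=tok_embeddings,
            layers=layer,
            num_layers=num_layers,
            max_seq_len=max_seq_len,
            num_heads=num_heads,
            head_dim=head_dim,
            norm=RMSNorm(embed_dim, eps=norm_eps),
            output=output_proj,
        )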

SalmanMohammadi added the good first issue, community help wanted, and better engineering labels on Aug 29, 2024
SalmanMohammadi changed the title from "Remove TiedEmbeddingTransformerDecoder" to "Remove TiedEmbeddingTransformerDecoder (and GemmaTransformerDecoder)" on Aug 29, 2024
SalmanMohammadi removed the good first issue label on Aug 29, 2024
felipemello1 (Contributor) commented:

#1547
#1553

felipemello1 (Contributor) commented:

PRs landed. Thanks for the issue!! :)
