Remove TiedEmbeddingTransformerDecoder (and GemmaTransformerDecoder) #1454
Labels
better engineering: Tasks which help improve eng productivity, e.g. building tools, cleaning up code, writing docs
community help wanted: We would love the community's help completing this issue
Credit to @pbontrager for this.
To quote from #1447:
Also just going to throw this out there: @pbontrager had suggested that we can do away with TiedEmbeddingTransformerDecoder entirely and instead make a small change in our model builders to build tied-embedding models directly with our TransformerDecoder class (though possibly without a lambda and maybe with a proper function). The main open question is whether this works with FSDP and checkpointing.
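To make the idea concrete, here is a minimal sketch of what such a builder change could look like. This is not the diff referenced in #1447; the names output_proj and TiedOutputProjection, and the Qwen2-ish sizes, are made up for illustration:

```python
import torch
from torch import nn
import torch.nn.functional as F

# Illustrative sizes only (roughly Qwen2 0.5B-shaped)
embed_dim, vocab_size = 896, 151936
tok_embeddings = nn.Embedding(vocab_size, embed_dim)

# Option A: define the output projection as a lambda closing over the embedding weights
output_proj = lambda x: F.linear(x, tok_embeddings.weight)

# Option B: a proper module that shares the embedding weights (likely friendlier
# to FSDP and checkpointing, which is the open question above)
class TiedOutputProjection(nn.Module):
    def __init__(self, embedding: nn.Embedding):
        super().__init__()
        self.embedding = embedding  # shared reference, not a copy

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.embedding.weight)

output_proj = TiedOutputProjection(tok_embeddings)

# A model builder would then pass tok_embeddings and output_proj into the
# regular TransformerDecoder instead of constructing TiedEmbeddingTransformerDecoder.
```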
To provide some background: currently, for models which have tied embedding-output projection weights, such as Gemma and Qwen2, we define an entirely separate decoder class (GemmaTransformerDecoder / TiedEmbeddingTransformerDecoder). This class is identical to TransformerDecoder except for the final lines of forward, where the output projection reuses the token embedding weights instead of a dedicated output layer:
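Roughly, the difference between the two classes comes down to something like this (a simplified, self-contained sketch rather than the exact source):

```python
import torch
from torch import nn
import torch.nn.functional as F

embed_dim, vocab_size = 16, 100  # toy sizes
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
h = torch.randn(2, 8, embed_dim)  # hidden states after the final norm

# TransformerDecoder: project with a dedicated output linear layer
output = nn.Linear(embed_dim, vocab_size, bias=False)
logits = output(h).float()

# TiedEmbeddingTransformerDecoder / GemmaTransformerDecoder: reuse the token
# embedding weights as the output projection
tied_logits = F.linear(h, tok_embeddings.weight).float()
```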
Making this change should be conceptually very straightforward, but it will touch several parts of the codebase. Off the top of my head, and in no particular order:
1) Update one component/model builder so that it constructs a plain TransformerDecoder directly, with the tied output projection defined either inline as a lambda or as a proper function (along the lines of the sketch above). Just for a single model (e.g. Qwen2 0.5B).
2) Test out a recipe with this builder and make sure things work OK - they should! Also make sure one of our distributed recipes works okay (we can help out here). A minimal smoke-test sketch is given after this list.
3) Extend the change to any other models using either GemmaTransformerDecoder or TiedEmbeddingTransformerDecoder. All of these component and model builders should now just construct and return a TransformerDecoder.
4) CTRL+F GemmaTransformerDecoder and TiedEmbeddingTransformerDecoder. Eradicate. Docs, tests, __init__.pys. Be ruthless.
5) Probably something else here.
6) $$$
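For step 2, a quick smoke test of the new builder might look something like this (a sketch only; the qwen2_0_5b import path is an assumption about the current builder name, so adjust it to whichever model you pick):

```python
import torch
from torchtune.models.qwen2 import qwen2_0_5b  # assumed builder name

model = qwen2_0_5b()
# After the change this should be a plain TransformerDecoder, not a tied-embedding subclass
print(type(model).__name__)

tokens = torch.randint(0, 100, (2, 16))
logits = model(tokens)
print(logits.shape)  # expect (2, 16, vocab_size)
```

Running one of the existing single-device and distributed fine-tuning recipes via the tune CLI with the corresponding Qwen2 0.5B config would then exercise the full training and checkpointing path.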