2/n - Make Gemma use regular TransformerDecoder (#1553)
Co-authored-by: Felipe Mello <[email protected]>
felipemello1 and Felipe Mello authored Sep 12, 2024
1 parent 7c51100 commit 7dad2d6
Showing 11 changed files with 95 additions and 46 deletions.
6 changes: 3 additions & 3 deletions recipes/configs/gemma/2B_lora.yaml
@@ -15,7 +15,6 @@
#
# This config works only when the model is being fine-tuned on 2+ GPUs.


# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -33,7 +32,7 @@ model:
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 64
-lora_alpha: 16
+lora_alpha: 128
lora_dropout: 0.0

checkpointer:
@@ -47,6 +46,7 @@ checkpointer:
output_dir: /tmp/gemma-2b
model_type: GEMMA
resume_from_checkpoint: False

save_adapter_weights_only: False

optimizer:
@@ -55,7 +55,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
-num_warmup_steps: 100
+num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
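Note: the LoRA hyperparameters above map directly onto the Gemma LoRA builder this PR touches below. A minimal sketch of the equivalent programmatic call (not part of this commit; assumes the `lora_gemma_2b` builder exported from `torchtune.models.gemma`):

```python
from torchtune.models.gemma import lora_gemma_2b

# Mirrors the updated config values: rank 64 with alpha 128 (2x rank),
# LoRA on the q/k/v projections plus the MLP, no dropout.
model = lora_gemma_2b(
    lora_attn_modules=["q_proj", "k_proj", "v_proj"],
    apply_lora_to_mlp=True,
    lora_rank=64,
    lora_alpha=128,
    lora_dropout=0.0,
)
```

The same rank/alpha and warmup changes repeat across the remaining Gemma configs below.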
4 changes: 2 additions & 2 deletions recipes/configs/gemma/2B_lora_single_device.yaml
@@ -32,7 +32,7 @@ model:
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 64
-lora_alpha: 16
+lora_alpha: 128
lora_dropout: 0.0

checkpointer:
@@ -54,7 +54,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
-num_warmup_steps: 100
+num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
4 changes: 2 additions & 2 deletions recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -32,7 +32,7 @@ model:
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 64
-lora_alpha: 16
+lora_alpha: 128
lora_dropout: 0.0

checkpointer:
@@ -54,7 +54,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
-num_warmup_steps: 100
+num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
4 changes: 2 additions & 2 deletions recipes/configs/gemma/7B_lora.yaml
@@ -33,7 +33,7 @@ model:
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 64
-lora_alpha: 16
+lora_alpha: 128
lora_dropout: 0.0

checkpointer:
@@ -57,7 +57,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
-num_warmup_steps: 100
+num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_lora_single_device.yaml
@@ -56,7 +56,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
-num_warmup_steps: 100
+num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
4 changes: 2 additions & 2 deletions recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -32,7 +32,7 @@ model:
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 64
-lora_alpha: 16
+lora_alpha: 128
lora_dropout: 0.0

checkpointer:
@@ -56,7 +56,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
-num_warmup_steps: 100
+num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
36 changes: 19 additions & 17 deletions torchtune/models/gemma/_component_builders.py
@@ -16,9 +16,10 @@
RotaryPositionalEmbeddings,
TransformerSelfAttentionLayer,
)
-from torchtune.models.gemma.rms_norm import GemmaRMSNorm
-from torchtune.models.gemma.transformer import GemmaTransformerDecoder

+from torchtune.models.gemma.rms_norm import GemmaRMSNorm
+from torchtune.modules import TransformerDecoder, TiedLinear
+from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings
from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear

"""
@@ -47,7 +48,7 @@ def gemma(
norm_eps: float = 1e-6,
rope_base: int = 10_000,
norm_embeddings: bool = True,
-) -> GemmaTransformerDecoder:
+) -> TransformerDecoder:
"""
Build the decoder associated with the gemma model. This includes:
- Token embeddings
@@ -76,7 +77,7 @@ def gemma(
and mlp layers. Default: True
Returns:
-GemmaTransformerDecoder: Instantiation of gemma model.
+TransformerDecoder: Instantiation of gemma model.
"""
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
self_att = MultiHeadAttention(
@@ -100,16 +101,17 @@
sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
)
-tok_embeddings = nn.Embedding(vocab_size, embed_dim)
-model = GemmaTransformerDecoder(
+tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
+output_proj = TiedLinear(tok_embeddings)
+model = TransformerDecoder(
tok_embeddings=tok_embeddings,
-layer=layer,
+layers=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
+output=output_proj,
head_dim=head_dim,
-norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
-norm_embeddings=norm_embeddings,
+norm=GemmaRMSNorm(embed_dim, eps=norm_eps)
)
return model

@@ -152,7 +154,7 @@ def lora_gemma(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
-) -> GemmaTransformerDecoder:
+) -> TransformerDecoder:
"""
Return a version of Gemma with LoRA applied based on the passed in configuration.
Note: output projection lora is not supported because it is tied to token embeddings
@@ -188,7 +190,7 @@ def lora_gemma(
supported for quantization currently.
Returns:
-GemmaTransformerDecoder: Instantiation of Gemma model with LoRA applied to
+TransformerDecoder: Instantiation of Gemma model with LoRA applied to
a subset of the attention projections in each layer.
"""
self_attn = lora_gemma_self_attention(
@@ -226,17 +228,17 @@
sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
)
-tok_embeddings = nn.Embedding(vocab_size, embed_dim)
-
-model = GemmaTransformerDecoder(
+tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
+output_proj = TiedLinear(tok_embeddings)
+model = TransformerDecoder(
tok_embeddings=tok_embeddings,
-layer=layer,
+layers=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
+output=output_proj,
head_dim=head_dim,
-norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
-norm_embeddings=norm_embeddings,
+norm=GemmaRMSNorm(embed_dim, eps=norm_eps)
)

if quantize_base:
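The swap from `GemmaTransformerDecoder` to the generic `TransformerDecoder` works because the two Gemma-specific behaviors are now isolated in small modules: `GemmaNormEmbeddings` scales the token embeddings, and `TiedLinear` reuses the embedding weight as the output projection. A rough, self-contained sketch of the tying idea (an illustration only, not torchtune's `TiedLinear` implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedOutputSketch:
    """Illustrative stand-in for TiedLinear: produce vocabulary logits with
    the same weight matrix the token embedding already owns (no new parameters)."""

    def __init__(self, tied_embedding: nn.Embedding) -> None:
        self.tied_embedding = tied_embedding

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # [batch, seq, embed_dim] -> [batch, seq, vocab_size]
        return F.linear(x, self.tied_embedding.weight)


emb = nn.Embedding(16, 8)      # vocab_size=16, embed_dim=8
hidden = torch.randn(1, 3, 8)  # [batch, seq, embed_dim]
print(TiedOutputSketch(emb)(hidden).shape)  # torch.Size([1, 3, 16])
```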
18 changes: 9 additions & 9 deletions torchtune/models/gemma/_model_builders.py
@@ -6,7 +6,7 @@
from typing import List, Optional

from torchtune.models.gemma._component_builders import gemma, lora_gemma
-from torchtune.models.gemma.transformer import GemmaTransformerDecoder
+from torchtune.modules import TransformerDecoder

from torchtune.models.gemma._tokenizer import GemmaTokenizer
from torchtune.modules.peft import LORA_ATTN_MODULES
@@ -21,13 +21,13 @@
"""


-def gemma_2b() -> GemmaTransformerDecoder:
+def gemma_2b() -> TransformerDecoder:
"""
Builder for creating a Gemma 2B model initialized w/ the default 2b parameter values
from: https://blog.google/technology/developers/gemma-open-models/
Returns:
-GemmaTransformerDecoder: Instantiation of Gemma 2B model
+TransformerDecoder: Instantiation of Gemma 2B model
"""
return gemma(
vocab_size=256_000,
@@ -71,7 +71,7 @@ def lora_gemma_2b(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
-) -> GemmaTransformerDecoder:
+) -> TransformerDecoder:
"""
Builder for creating a Gemma 2B model with LoRA enabled.
@@ -93,7 +93,7 @@ def lora_gemma_2b(
quantize_base (bool): Whether to quantize base model weights
Returns:
-GemmaTransformerDecoder: Instantiation of Gemma 2B model with LoRA applied
+TransformerDecoder: Instantiation of Gemma 2B model with LoRA applied
"""
return lora_gemma(
lora_attn_modules=lora_attn_modules,
@@ -125,13 +125,13 @@ def lora_gemma_2b(



-def gemma_7b() -> GemmaTransformerDecoder:
+def gemma_7b() -> TransformerDecoder:
"""
Builder for creating a Gemma 7B model initialized w/ the default 7b parameter values
from: https://blog.google/technology/developers/gemma-open-models/
Returns:
-GemmaTransformerDecoder: Instantiation of Gemma 7B model
+TransformerDecoder: Instantiation of Gemma 7B model
"""
return gemma(
vocab_size=256_000,
@@ -155,7 +155,7 @@ def lora_gemma_7b(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
-) -> GemmaTransformerDecoder:
+) -> TransformerDecoder:
"""
Builder for creating a Gemma 7B model with LoRA enabled.
@@ -177,7 +177,7 @@ def lora_gemma_7b(
quantize_base (bool): Whether to quantize base model weights
Returns:
-GemmaTransformerDecoder: Instantiation of Gemma 7B model with LoRA applied
+TransformerDecoder: Instantiation of Gemma 7B model with LoRA applied
"""
return lora_gemma(
lora_attn_modules=lora_attn_modules,
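With the builders returning the stock decoder class, downstream code can type-check against it directly. A quick sanity sketch (assumes the builders are exported from `torchtune.models.gemma` as shown in this diff; note it instantiates the full, randomly initialized 2B architecture):

```python
from torchtune.models.gemma import gemma_2b
from torchtune.modules import TransformerDecoder

model = gemma_2b()
assert isinstance(model, TransformerDecoder)
print(type(model).__name__)  # TransformerDecoder
```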
47 changes: 47 additions & 0 deletions torchtune/models/gemma/gemma_norm_embedding.py
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class GemmaNormEmbeddings(nn.Embedding):
"""Module with Embedding and normalization specific to Gemma.
Gemma requires normalization right after the embeddings. By merging both
steps in a single module, we can utilize directly
:class:`~torchtune.modules.TransformerDecoder`.
For more details about the embedding module, please see
https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
Args:
num_embeddings (int): size of the dictionary of embeddings.
embedding_dim (int): the size of each embedding vector.
*args: Variable length argument list to be passed to the Embedding module.
**kwargs: Arbitrary keyword arguments to be passed to the Embedding module.
Example:
>>> import torch
>>> from torchtune.models.gemma import GemmaNormEmbeddings
>>> embeddings = GemmaNormEmbeddings(2, 4)
>>> x = torch.randint(0, 2, (1, 3)) # ids can be 0 or 1
>>> print(x)
>>> print(embeddings(x))
>>> print(embeddings(x).shape)
tensor([[1, 0, 0]])
tensor([[[-0.2152, -2.1914, 2.8491, -0.4824],
[-3.6621, -1.0267, 1.5947, -1.7349],
[-3.6621, -1.0267, 1.5947, -1.7349]]], grad_fn=<MulBackward0>)
torch.Size([1, 3, 4])
"""

def __init__(self, num_embeddings: int, embedding_dim: int, *args, **kwargs):
super().__init__(num_embeddings, embedding_dim, *args, **kwargs)
self.embedding_dim = embedding_dim

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = super().forward(x)
return x * torch.tensor(self.embedding_dim**0.5, dtype=x.dtype)
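The only behavioral difference from a plain `nn.Embedding` is the multiplication by sqrt(embedding_dim) in `forward`. A small check of that scaling (a sketch; the import path follows the class docstring above):

```python
import torch
from torchtune.models.gemma import GemmaNormEmbeddings

emb = GemmaNormEmbeddings(num_embeddings=8, embedding_dim=4)
x = torch.tensor([[1, 2, 3]])

raw = torch.nn.Embedding.forward(emb, x)  # plain lookup, no scaling
scaled = emb(x)                           # Gemma lookup, scaled by sqrt(4) = 2.0
assert torch.allclose(scaled, raw * 2.0)
print(scaled.shape)  # torch.Size([1, 3, 4])
```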
5 changes: 5 additions & 0 deletions torchtune/models/gemma/transformer.py
@@ -12,8 +12,13 @@
from torchtune.modules import KVCache

from torchtune.modules.transformer import _get_clones, TransformerSelfAttentionLayer
+from torchtune.utils.logging import deprecated


+@deprecated(
+msg="Please use torchtune.modules.TransformerDecoder instead. \
+If you need an example, see torchtune.models.gemma._component_builders.py"
+)
class GemmaTransformerDecoder(nn.Module):
"""
GemmaTransformer Decoder derived from Gemma architecture. A key difference between
11 changes: 3 additions & 8 deletions torchtune/training/_compile.py
@@ -5,32 +5,27 @@
# LICENSE file in the root directory of this source tree.

import os
-from typing import Union

import torch
from torch import nn

-from torchtune.modules import (
-TiedEmbeddingTransformerDecoder,
-TransformerDecoder,
-TransformerSelfAttentionLayer,
-)
+from torchtune.modules import TransformerDecoder, TransformerSelfAttentionLayer
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.utils import get_logger, torch_version_ge

log = get_logger("INFO")


def compile_model(
-model: Union[TransformerDecoder, TiedEmbeddingTransformerDecoder],
+model: TransformerDecoder,
verbose: bool = True,
) -> None:
"""
Utility to compile a transformer model inplace. On PyTorch nightlies we use per-layer compile
to reduce compile times. Otherwise we compile the full model, which takes longer.
Args:
-model (Union[TransformerDecoder, TiedEmbeddingTransformerDecoder]): A transformer model to compile.
+model (TransformerDecoder): A transformer model to compile.
verbose (bool): Whether to log compile info. Default: True
Returns:
None
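Dropping the union type means `compile_model` now accepts any `TransformerDecoder`, including the Gemma builders above. A hedged usage sketch (assumes `compile_model` is re-exported via `torchtune.training`):

```python
from torchtune.models.gemma import gemma_2b
from torchtune.training import compile_model

model = gemma_2b()
compile_model(model, verbose=False)  # compiles per layer (or the full model) in place
```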
