2/n - Make Gemma use regular TransformerDecoder #1553

Merged
6 changes: 3 additions & 3 deletions recipes/configs/gemma/2B_lora.yaml
@@ -15,7 +15,6 @@
#
# This config works only when the model is being fine-tuned on 2+ GPUs.


# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -33,7 +32,7 @@ model:
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 64
lora_alpha: 16
lora_alpha: 128
lora_dropout: 0.0

checkpointer:
@@ -47,6 +46,7 @@ checkpointer:
output_dir: /tmp/gemma-2b
model_type: GEMMA
resume_from_checkpoint: False

save_adapter_weights_only: False

optimizer:
@@ -55,7 +55,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100
num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
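For context on the `lora_alpha` change above: in the standard LoRA formulation the low-rank update is scaled by `alpha / rank`, so `lora_alpha: 128` with `lora_rank: 64` gives a scaling factor of 2.0, versus 0.25 with the previous value of 16. A minimal sketch of that relationship (the class name is a placeholder, not torchtune's `LoRALinear`):

```python
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Minimal LoRA linear layer: y = W x + (alpha / rank) * B(A(x))."""

    def __init__(self, in_dim: int, out_dim: int, rank: int = 64, alpha: float = 128.0):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)  # frozen pretrained weight
        self.lora_a = nn.Linear(in_dim, rank, bias=False)    # low-rank "A"
        self.lora_b = nn.Linear(rank, out_dim, bias=False)   # low-rank "B"
        self.scaling = alpha / rank                          # 128 / 64 = 2.0 with the values above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))
```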
4 changes: 2 additions & 2 deletions recipes/configs/gemma/2B_lora_single_device.yaml
@@ -32,7 +32,7 @@ model:
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 64
lora_alpha: 16
lora_alpha: 128
lora_dropout: 0.0

checkpointer:
@@ -54,7 +54,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100
num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
4 changes: 2 additions & 2 deletions recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -32,7 +32,7 @@ model:
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 64
lora_alpha: 16
lora_alpha: 128
lora_dropout: 0.0

checkpointer:
@@ -54,7 +54,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100
num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
4 changes: 2 additions & 2 deletions recipes/configs/gemma/7B_lora.yaml
@@ -33,7 +33,7 @@ model:
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 64
lora_alpha: 16
lora_alpha: 128
lora_dropout: 0.0

checkpointer:
@@ -57,7 +57,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100
num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_lora_single_device.yaml
@@ -56,7 +56,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100
num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
4 changes: 2 additions & 2 deletions recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -32,7 +32,7 @@ model:
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 64
lora_alpha: 16
lora_alpha: 128
lora_dropout: 0.0

checkpointer:
@@ -56,7 +56,7 @@ optimizer:

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100
num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
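The configs above also drop `num_warmup_steps` from 100 to 10 for `torchtune.modules.get_cosine_schedule_with_warmup`. As a rough illustration of what such a schedule does (linear warmup, then cosine decay), here is a sketch built on `torch.optim.lr_scheduler.LambdaLR`; it shows the common formulation, not necessarily torchtune's exact implementation, and the function name is a placeholder:

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def cosine_schedule_with_warmup_sketch(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
) -> LambdaLR:
    """Linear warmup followed by cosine decay (common formulation)."""

    def lr_lambda(step: int) -> float:
        if step < num_warmup_steps:
            # Ramp the LR multiplier linearly from 0 to 1 over the warmup steps.
            return step / max(1, num_warmup_steps)
        # Then decay it from 1 to 0 along a half cosine.
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda)


# With num_warmup_steps: 10 the ramp-up finishes after 10 optimizer steps instead of 100.
```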
36 changes: 19 additions & 17 deletions torchtune/models/gemma/_component_builders.py
@@ -16,9 +16,10 @@
RotaryPositionalEmbeddings,
TransformerSelfAttentionLayer,
)
from torchtune.models.gemma.rms_norm import GemmaRMSNorm
from torchtune.models.gemma.transformer import GemmaTransformerDecoder

from torchtune.models.gemma.rms_norm import GemmaRMSNorm
from torchtune.modules import TransformerDecoder, TiedLinear
from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings
from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear

"""
@@ -47,7 +48,7 @@ def gemma(
norm_eps: float = 1e-6,
rope_base: int = 10_000,
norm_embeddings: bool = True,
) -> GemmaTransformerDecoder:
) -> TransformerDecoder:
"""
Build the decoder associated with the gemma model. This includes:
- Token embeddings
@@ -76,7 +77,7 @@ def gemma(
and mlp layers. Default: True

Returns:
GemmaTransformerDecoder: Instantiation of gemma model.
TransformerDecoder: Instantiation of gemma model.
"""
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
self_att = MultiHeadAttention(
@@ -100,16 +101,17 @@
sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
)
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
model = GemmaTransformerDecoder(
tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
output_proj = TiedLinear(tok_embeddings)
model = TransformerDecoder(
tok_embeddings=tok_embeddings,
layer=layer,
layers=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
output=output_proj,
head_dim=head_dim,
norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
norm_embeddings=norm_embeddings,
norm=GemmaRMSNorm(embed_dim, eps=norm_eps)
)
return model

@@ -152,7 +154,7 @@ def lora_gemma(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
) -> GemmaTransformerDecoder:
) -> TransformerDecoder:
"""
Return a version of Gemma with LoRA applied based on the passed in configuration.
Note: output projection lora is not supported because it is tied to token embeddings
@@ -188,7 +190,7 @@ def lora_gemma(
supported for quantization currently.

Returns:
GemmaTransformerDecoder: Instantiation of Gemma model with LoRA applied to
TransformerDecoder: Instantiation of Gemma model with LoRA applied to
a subset of the attention projections in each layer.
"""
self_attn = lora_gemma_self_attention(
@@ -226,17 +228,17 @@
sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
)
tok_embeddings = nn.Embedding(vocab_size, embed_dim)

model = GemmaTransformerDecoder(
tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
output_proj = TiedLinear(tok_embeddings)
model = TransformerDecoder(
tok_embeddings=tok_embeddings,
layer=layer,
layers=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
output=output_proj,
head_dim=head_dim,
norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
norm_embeddings=norm_embeddings,
norm=GemmaRMSNorm(embed_dim, eps=norm_eps)
)

if quantize_base:
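The central change in this file: the builders now construct the generic `TransformerDecoder` with a `GemmaNormEmbeddings` token embedding and a `TiedLinear` output projection built from it. Conceptually, weight tying means the same `[vocab_size, embed_dim]` matrix embeds tokens on the way in and projects hidden states to logits on the way out; a rough sketch of that idea (the class name is a placeholder, not torchtune's `TiedLinear`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedLinearSketch:
    """Conceptual weight tying: the embedding matrix doubles as the output projection."""

    def __init__(self, tok_embeddings: nn.Embedding):
        self.tok_embeddings = tok_embeddings

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # [batch, seq, embed_dim] @ [vocab_size, embed_dim]^T -> [batch, seq, vocab_size]
        return F.linear(x, self.tok_embeddings.weight)
```

This is also why the docstring notes that output-projection LoRA is unsupported: the output weight is the embedding weight.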
18 changes: 9 additions & 9 deletions torchtune/models/gemma/_model_builders.py
@@ -6,7 +6,7 @@
from typing import List, Optional

from torchtune.models.gemma._component_builders import gemma, lora_gemma
from torchtune.models.gemma.transformer import GemmaTransformerDecoder
from torchtune.modules import TransformerDecoder

from torchtune.models.gemma._tokenizer import GemmaTokenizer
from torchtune.modules.peft import LORA_ATTN_MODULES
@@ -21,13 +21,13 @@
"""


def gemma_2b() -> GemmaTransformerDecoder:
def gemma_2b() -> TransformerDecoder:
"""
Builder for creating a Gemma 2B model initialized w/ the default 2b parameter values
from: https://blog.google/technology/developers/gemma-open-models/

Returns:
GemmaTransformerDecoder: Instantiation of Gemma 2B model
TransformerDecoder: Instantiation of Gemma 2B model
"""
return gemma(
vocab_size=256_000,
@@ -71,7 +71,7 @@ def lora_gemma_2b(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
) -> GemmaTransformerDecoder:
) -> TransformerDecoder:
"""
Builder for creating a Gemma 2B model with LoRA enabled.

@@ -93,7 +93,7 @@ def lora_gemma_2b(
quantize_base (bool): Whether to quantize base model weights

Returns:
GemmaTransformerDecoder: Instantiation of Gemma 2B model with LoRA applied
TransformerDecoder: Instantiation of Gemma 2B model with LoRA applied
"""
return lora_gemma(
lora_attn_modules=lora_attn_modules,
@@ -125,13 +125,13 @@ def lora_gemma_2b(



def gemma_7b() -> GemmaTransformerDecoder:
def gemma_7b() -> TransformerDecoder:
"""
Builder for creating a Gemma 7B model initialized w/ the default 7b parameter values
from: https://blog.google/technology/developers/gemma-open-models/

Returns:
GemmaTransformerDecoder: Instantiation of Gemma 7B model
TransformerDecoder: Instantiation of Gemma 7B model
"""
return gemma(
vocab_size=256_000,
@@ -155,7 +155,7 @@ def lora_gemma_7b(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
) -> GemmaTransformerDecoder:
) -> TransformerDecoder:
"""
Builder for creating a Gemma 7B model with LoRA enabled.

@@ -177,7 +177,7 @@ def lora_gemma_7b(
quantize_base (bool): Whether to quantize base model weights

Returns:
GemmaTransformerDecoder: Instantiation of Gemma 7B model with LoRA applied
TransformerDecoder: Instantiation of Gemma 7B model with LoRA applied
"""
return lora_gemma(
lora_attn_modules=lora_attn_modules,
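From the caller's perspective usage is unchanged; only the return type annotation moves to `TransformerDecoder`. A short usage sketch, assuming the builders are re-exported from `torchtune.models.gemma` as the configs imply, using the LoRA values from the updated 2B configs:

```python
from torchtune.models.gemma import gemma_2b, lora_gemma_2b

# Both builders now return the generic torchtune.modules.TransformerDecoder.
base_model = gemma_2b()

# LoRA variant, mirroring the updated 2B LoRA configs.
lora_model = lora_gemma_2b(
    lora_attn_modules=["q_proj", "k_proj", "v_proj"],
    apply_lora_to_mlp=True,
    lora_rank=64,
    lora_alpha=128,
)
print(type(lora_model).__name__)  # TransformerDecoder
```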
18 changes: 18 additions & 0 deletions torchtune/models/gemma/gemma_norm_embedding.py
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class GemmaNormEmbeddings(nn.Embedding):
def __init__(self, in_dim: int, out_dim: int):
Contributor review comment: docstrings, esp to explain why this is a separate class

super().__init__(in_dim, out_dim)
self.out_dim = out_dim

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = super().forward(x)
return x * torch.tensor(self.out_dim**0.5, dtype=x.dtype)
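Regarding the docstring request above: the subclass exists because Gemma scales token embeddings by `sqrt(embed_dim)` before the first decoder layer, which a plain `nn.Embedding` does not do. A small sanity-check sketch with hypothetical toy sizes:

```python
import torch

from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings

# Tiny check of the scaling behaviour (toy values chosen for illustration).
emb = GemmaNormEmbeddings(in_dim=8, out_dim=4)  # vocab_size=8, embed_dim=4
tokens = torch.tensor([[1, 2, 3]])
out = emb(tokens)                               # shape [1, 3, 4]

# Each embedding vector is multiplied by sqrt(embed_dim) = 2.0, matching Gemma's
# convention of scaling token embeddings before the first decoder layer.
assert torch.allclose(out, emb.weight[tokens] * 4 ** 0.5)
```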
5 changes: 5 additions & 0 deletions torchtune/models/gemma/transformer.py
@@ -12,8 +12,13 @@
from torchtune.modules import KVCache

from torchtune.modules.transformer import _get_clones, TransformerSelfAttentionLayer
from torchtune.utils.logging import deprecated


@deprecated(
msg="Please use torchtune.modules.TransformerDecoder instead. \
If you need an example, see torchtune.models.gemma._component_builders.py"
)
class GemmaTransformerDecoder(nn.Module):
"""
GemmaTransformer Decoder derived from Gemma architecture. A key difference between
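`GemmaTransformerDecoder` is kept for now but marked with a `deprecated` decorator that points users at `torchtune.modules.TransformerDecoder`. As a sketch of how such a class decorator can be implemented (an assumption for illustration only; `torchtune.utils.logging.deprecated` may differ):

```python
import warnings
from functools import wraps


def deprecated_sketch(msg: str = ""):
    """Class decorator that warns whenever the decorated class is instantiated."""

    def decorator(cls):
        original_init = cls.__init__

        @wraps(original_init)
        def wrapped_init(self, *args, **kwargs):
            # Emit a warning, then fall through to the original constructor.
            warnings.warn(f"{cls.__name__} is deprecated. {msg}", FutureWarning, stacklevel=2)
            original_init(self, *args, **kwargs)

        cls.__init__ = wrapped_init
        return cls

    return decorator
```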
11 changes: 3 additions & 8 deletions torchtune/training/_compile.py
@@ -5,32 +5,27 @@
# LICENSE file in the root directory of this source tree.

import os
from typing import Union

import torch
from torch import nn

from torchtune.modules import (
TiedEmbeddingTransformerDecoder,
TransformerDecoder,
TransformerSelfAttentionLayer,
)
from torchtune.modules import TransformerDecoder, TransformerSelfAttentionLayer
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.utils import get_logger, torch_version_ge

log = get_logger("INFO")


def compile_model(
model: Union[TransformerDecoder, TiedEmbeddingTransformerDecoder],
model: TransformerDecoder,
Contributor review comment: Why wasn't this handled in 1/n?

Contributor (author) reply: I forgot and left a comment that I was doing it in 2/n.

verbose: bool = True,
) -> None:
"""
Utility to compile a transformer model inplace. On PyTorch nightlies we use per-layer compile
to reduce compile times. Otherwise we compile the full model, which takes longer.

Args:
model (Union[TransformerDecoder, TiedEmbeddingTransformerDecoder]): A transformer model to compile.
model (TransformerDecoder): A transformer model to compile.
verbose (bool): Whether to log compile info. Default: True
Returns:
None
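With `TiedEmbeddingTransformerDecoder` dropped from the signature, `compile_model` only needs to accept `TransformerDecoder`. The docstring describes two strategies: per-layer compile on PyTorch nightlies, full-model compile otherwise. A rough sketch of that split (illustrative only, not torchtune's actual `compile_model`; assumes PyTorch >= 2.2 for `nn.Module.compile()`):

```python
from torch import nn


def compile_model_sketch(model: nn.Module, per_layer_compile: bool) -> None:
    """Rough sketch of per-layer vs. full-model compilation."""
    if per_layer_compile:
        # Compile each decoder layer separately to keep compile times down.
        for module in model.modules():
            if module.__class__.__name__ == "TransformerSelfAttentionLayer":
                module.compile()
    else:
        # Otherwise compile the whole model in one go (slower to compile).
        model.compile()
```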