From 7dad2d6d214db79cf15b030bcfbab4f05e680e50 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Thu, 12 Sep 2024 16:25:46 -0400
Subject: [PATCH] 2/n - Make Gemma use regular TransformerDecoder (#1553)

Co-authored-by: Felipe Mello
---
 recipes/configs/gemma/2B_lora.yaml             |  6 +--
 .../configs/gemma/2B_lora_single_device.yaml   |  4 +-
 .../configs/gemma/2B_qlora_single_device.yaml  |  4 +-
 recipes/configs/gemma/7B_lora.yaml             |  4 +-
 .../configs/gemma/7B_lora_single_device.yaml   |  2 +-
 .../configs/gemma/7B_qlora_single_device.yaml  |  4 +-
 torchtune/models/gemma/_component_builders.py  | 36 +++++++-------
 torchtune/models/gemma/_model_builders.py      | 18 +++----
 .../models/gemma/gemma_norm_embedding.py       | 47 +++++++++++++++++++
 torchtune/models/gemma/transformer.py          |  5 ++
 torchtune/training/_compile.py                 | 11 ++---
 11 files changed, 95 insertions(+), 46 deletions(-)
 create mode 100644 torchtune/models/gemma/gemma_norm_embedding.py

diff --git a/recipes/configs/gemma/2B_lora.yaml b/recipes/configs/gemma/2B_lora.yaml
index 2e345791da..8e67fe2168 100644
--- a/recipes/configs/gemma/2B_lora.yaml
+++ b/recipes/configs/gemma/2B_lora.yaml
@@ -15,7 +15,6 @@
 #
 # This config works only when the model is being fine-tuned on 2+ GPUs.
 
-
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.gemma.gemma_tokenizer
@@ -33,7 +32,7 @@ model:
   lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
   apply_lora_to_mlp: True
   lora_rank: 64
-  lora_alpha: 16
+  lora_alpha: 128
   lora_dropout: 0.0
 
 checkpointer:
@@ -47,6 +46,7 @@ checkpointer:
   output_dir: /tmp/gemma-2b
   model_type: GEMMA
 resume_from_checkpoint: False
+save_adapter_weights_only: False
 
 
 optimizer:
@@ -55,7 +55,7 @@ optimizer:
 
 lr_scheduler:
   _component_: torchtune.modules.get_cosine_schedule_with_warmup
-  num_warmup_steps: 100
+  num_warmup_steps: 10
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
diff --git a/recipes/configs/gemma/2B_lora_single_device.yaml b/recipes/configs/gemma/2B_lora_single_device.yaml
index 5b0af37ffa..8c322495ce 100644
--- a/recipes/configs/gemma/2B_lora_single_device.yaml
+++ b/recipes/configs/gemma/2B_lora_single_device.yaml
@@ -32,7 +32,7 @@ model:
   lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
   apply_lora_to_mlp: True
   lora_rank: 64
-  lora_alpha: 16
+  lora_alpha: 128
   lora_dropout: 0.0
 
 checkpointer:
@@ -54,7 +54,7 @@ optimizer:
 
 lr_scheduler:
   _component_: torchtune.modules.get_cosine_schedule_with_warmup
-  num_warmup_steps: 100
+  num_warmup_steps: 10
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
diff --git a/recipes/configs/gemma/2B_qlora_single_device.yaml b/recipes/configs/gemma/2B_qlora_single_device.yaml
index d67d79bfaa..7ed60ce180 100644
--- a/recipes/configs/gemma/2B_qlora_single_device.yaml
+++ b/recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -32,7 +32,7 @@ model:
   lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
   apply_lora_to_mlp: True
   lora_rank: 64
-  lora_alpha: 16
+  lora_alpha: 128
   lora_dropout: 0.0
 
 checkpointer:
@@ -54,7 +54,7 @@ optimizer:
 
 lr_scheduler:
   _component_: torchtune.modules.get_cosine_schedule_with_warmup
-  num_warmup_steps: 100
+  num_warmup_steps: 10
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
diff --git a/recipes/configs/gemma/7B_lora.yaml b/recipes/configs/gemma/7B_lora.yaml
index b79bbb6845..5d0bcdb08f 100644
--- a/recipes/configs/gemma/7B_lora.yaml
+++ b/recipes/configs/gemma/7B_lora.yaml
@@ -33,7 +33,7 @@ model:
   lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
   apply_lora_to_mlp: True
   lora_rank: 64
-  lora_alpha: 16
+  lora_alpha: 128
   lora_dropout: 0.0
 
 checkpointer:
@@ -57,7 +57,7 @@ optimizer:
 
 lr_scheduler:
   _component_: torchtune.modules.get_cosine_schedule_with_warmup
-  num_warmup_steps: 100
+  num_warmup_steps: 10
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
diff --git a/recipes/configs/gemma/7B_lora_single_device.yaml b/recipes/configs/gemma/7B_lora_single_device.yaml
index a8dfc1878e..aa69fa50f8 100644
--- a/recipes/configs/gemma/7B_lora_single_device.yaml
+++ b/recipes/configs/gemma/7B_lora_single_device.yaml
@@ -56,7 +56,7 @@ optimizer:
 
 lr_scheduler:
   _component_: torchtune.modules.get_cosine_schedule_with_warmup
-  num_warmup_steps: 100
+  num_warmup_steps: 10
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
diff --git a/recipes/configs/gemma/7B_qlora_single_device.yaml b/recipes/configs/gemma/7B_qlora_single_device.yaml
index acb044ece9..8a08c49b5c 100644
--- a/recipes/configs/gemma/7B_qlora_single_device.yaml
+++ b/recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -32,7 +32,7 @@ model:
   lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
   apply_lora_to_mlp: True
   lora_rank: 64
-  lora_alpha: 16
+  lora_alpha: 128
   lora_dropout: 0.0
 
 checkpointer:
@@ -56,7 +56,7 @@ optimizer:
 
 lr_scheduler:
   _component_: torchtune.modules.get_cosine_schedule_with_warmup
-  num_warmup_steps: 100
+  num_warmup_steps: 10
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
diff --git a/torchtune/models/gemma/_component_builders.py b/torchtune/models/gemma/_component_builders.py
index 0c1a723ab0..dbb4845f4a 100644
--- a/torchtune/models/gemma/_component_builders.py
+++ b/torchtune/models/gemma/_component_builders.py
@@ -16,9 +16,10 @@
     RotaryPositionalEmbeddings,
     TransformerSelfAttentionLayer,
 )
-from torchtune.models.gemma.rms_norm import GemmaRMSNorm
-from torchtune.models.gemma.transformer import GemmaTransformerDecoder
+from torchtune.models.gemma.rms_norm import GemmaRMSNorm
+from torchtune.modules import TransformerDecoder, TiedLinear
+from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings
 
 from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
 
 """
@@ -47,7 +48,7 @@ def gemma(
     norm_eps: float = 1e-6,
     rope_base: int = 10_000,
     norm_embeddings: bool = True,
-) -> GemmaTransformerDecoder:
+) -> TransformerDecoder:
     """
     Build the decoder associated with the gemma model. This includes:
     - Token embeddings
@@ -76,7 +77,7 @@
             and mlp layers. Default: True
 
     Returns:
-        GemmaTransformerDecoder: Instantiation of gemma model.
+        TransformerDecoder: Instantiation of gemma model.
""" rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) self_att = MultiHeadAttention( @@ -100,16 +101,17 @@ def gemma( sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), ) - tok_embeddings = nn.Embedding(vocab_size, embed_dim) - model = GemmaTransformerDecoder( + tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) + output_proj = TiedLinear(tok_embeddings) + model = TransformerDecoder( tok_embeddings=tok_embeddings, - layer=layer, + layers=layer, num_layers=num_layers, max_seq_len=max_seq_len, num_heads=num_heads, + output=output_proj, head_dim=head_dim, - norm=GemmaRMSNorm(embed_dim, eps=norm_eps), - norm_embeddings=norm_embeddings, + norm=GemmaRMSNorm(embed_dim, eps=norm_eps) ) return model @@ -152,7 +154,7 @@ def lora_gemma( lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, -) -> GemmaTransformerDecoder: +) -> TransformerDecoder: """ Return a version of Gemma with LoRA applied based on the passed in configuration. Note: output projection lora is not supported because it is tied to token embeddings @@ -188,7 +190,7 @@ def lora_gemma( supported for quantization currently. Returns: - GemmaTransformerDecoder: Instantiation of Gemma model with LoRA applied to + TransformerDecoder: Instantiation of Gemma model with LoRA applied to a subset of the attention projections in each layer. """ self_attn = lora_gemma_self_attention( @@ -226,17 +228,17 @@ def lora_gemma( sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), ) - tok_embeddings = nn.Embedding(vocab_size, embed_dim) - - model = GemmaTransformerDecoder( + tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) + output_proj = TiedLinear(tok_embeddings) + model = TransformerDecoder( tok_embeddings=tok_embeddings, - layer=layer, + layers=layer, num_layers=num_layers, max_seq_len=max_seq_len, num_heads=num_heads, + output=output_proj, head_dim=head_dim, - norm=GemmaRMSNorm(embed_dim, eps=norm_eps), - norm_embeddings=norm_embeddings, + norm=GemmaRMSNorm(embed_dim, eps=norm_eps) ) if quantize_base: diff --git a/torchtune/models/gemma/_model_builders.py b/torchtune/models/gemma/_model_builders.py index 3e59cb2ef2..9c13409ec1 100644 --- a/torchtune/models/gemma/_model_builders.py +++ b/torchtune/models/gemma/_model_builders.py @@ -6,7 +6,7 @@ from typing import List, Optional from torchtune.models.gemma._component_builders import gemma, lora_gemma -from torchtune.models.gemma.transformer import GemmaTransformerDecoder +from torchtune.modules import TransformerDecoder from torchtune.models.gemma._tokenizer import GemmaTokenizer from torchtune.modules.peft import LORA_ATTN_MODULES @@ -21,13 +21,13 @@ """ -def gemma_2b() -> GemmaTransformerDecoder: +def gemma_2b() -> TransformerDecoder: """ Builder for creating a Gemma 2B model initialized w/ the default 2b parameter values from: https://blog.google/technology/developers/gemma-open-models/ Returns: - GemmaTransformerDecoder: Instantiation of Gemma 2B model + TransformerDecoder: Instantiation of Gemma 2B model """ return gemma( vocab_size=256_000, @@ -71,7 +71,7 @@ def lora_gemma_2b( lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, -) -> GemmaTransformerDecoder: +) -> TransformerDecoder: """ Builder for creating a Gemma 2B model with LoRA enabled. 
@@ -93,7 +93,7 @@ def lora_gemma_2b(
         quantize_base (bool): Whether to quantize base model weights
 
     Returns:
-        GemmaTransformerDecoder: Instantiation of Gemma 2B model with LoRA applied
+        TransformerDecoder: Instantiation of Gemma 2B model with LoRA applied
     """
     return lora_gemma(
         lora_attn_modules=lora_attn_modules,
@@ -125,13 +125,13 @@
 
 
 
-def gemma_7b() -> GemmaTransformerDecoder:
+def gemma_7b() -> TransformerDecoder:
     """
     Builder for creating a Gemma 7B model initialized w/ the default 7b parameter values
     from: https://blog.google/technology/developers/gemma-open-models/
 
     Returns:
-        GemmaTransformerDecoder: Instantiation of Gemma 7B model
+        TransformerDecoder: Instantiation of Gemma 7B model
     """
     return gemma(
         vocab_size=256_000,
@@ -155,7 +155,7 @@ def lora_gemma_7b(
     lora_dropout: float = 0.0,
     use_dora: bool = False,
     quantize_base: bool = False,
-) -> GemmaTransformerDecoder:
+) -> TransformerDecoder:
     """
     Builder for creating a Gemma 7B model with LoRA enabled.
 
@@ -177,7 +177,7 @@ def lora_gemma_7b(
         quantize_base (bool): Whether to quantize base model weights
 
     Returns:
-        GemmaTransformerDecoder: Instantiation of Gemma 7B model with LoRA applied
+        TransformerDecoder: Instantiation of Gemma 7B model with LoRA applied
     """
     return lora_gemma(
         lora_attn_modules=lora_attn_modules,
diff --git a/torchtune/models/gemma/gemma_norm_embedding.py b/torchtune/models/gemma/gemma_norm_embedding.py
new file mode 100644
index 0000000000..9d3a696ea5
--- /dev/null
+++ b/torchtune/models/gemma/gemma_norm_embedding.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+
+class GemmaNormEmbeddings(nn.Embedding):
+    """Module with Embedding and normalization specific to Gemma.
+    Gemma requires normalization right after the embeddings. By merging both
+    steps in a single module, we can utilize directly
+    :class:`~torchtune.modules.TransformerDecoder`.
+
+    For more details about the embedding module, please see
+    https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
+
+    Args:
+        num_embeddings (int): size of the dictionary of embeddings.
+        embedding_dim (int): the size of each embedding vector.
+        *args: Variable length argument list to be passed to the Embedding module.
+        **kwargs: Arbitrary keyword arguments to be passed to the Embedding module.
+
+    Example:
+        >>> import torch
+        >>> from torchtune.models.gemma import GemmaNormEmbeddings
+        >>> embeddings = GemmaNormEmbeddings(2, 4)
+        >>> x = torch.randint(0, 2, (1, 3))  # ids can be 0 or 1
+        >>> print(x)
+        >>> print(embeddings(x))
+        >>> print(embeddings(x).shape)
+        tensor([[1, 0, 0]])
+        tensor([[[-0.2152, -2.1914,  2.8491, -0.4824],
+             [-3.6621, -1.0267,  1.5947, -1.7349],
+             [-3.6621, -1.0267,  1.5947, -1.7349]]], grad_fn=<MulBackward0>)
+        torch.Size([1, 3, 4])
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, *args, **kwargs):
+        super().__init__(num_embeddings, embedding_dim, *args, **kwargs)
+        self.embedding_dim = embedding_dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = super().forward(x)
+        return x * torch.tensor(self.embedding_dim**0.5, dtype=x.dtype)
diff --git a/torchtune/models/gemma/transformer.py b/torchtune/models/gemma/transformer.py
index 82f2847bd9..a932895d71 100644
--- a/torchtune/models/gemma/transformer.py
+++ b/torchtune/models/gemma/transformer.py
@@ -12,8 +12,13 @@
 from torchtune.modules import KVCache
 
 from torchtune.modules.transformer import _get_clones, TransformerSelfAttentionLayer
+from torchtune.utils.logging import deprecated
 
 
+@deprecated(
+    msg="Please use torchtune.modules.TransformerDecoder instead. \
+If you need an example, see torchtune.models.gemma._component_builders.py"
+)
 class GemmaTransformerDecoder(nn.Module):
     """
     GemmaTransformer Decoder derived from Gemma architecture. A key difference between
diff --git a/torchtune/training/_compile.py b/torchtune/training/_compile.py
index 3f8d8c279e..893093a753 100644
--- a/torchtune/training/_compile.py
+++ b/torchtune/training/_compile.py
@@ -5,16 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-from typing import Union
 
 import torch
 from torch import nn
-from torchtune.modules import (
-    TiedEmbeddingTransformerDecoder,
-    TransformerDecoder,
-    TransformerSelfAttentionLayer,
-)
+from torchtune.modules import TransformerDecoder, TransformerSelfAttentionLayer
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.utils import get_logger, torch_version_ge
 
 
@@ -22,7 +17,7 @@
 
 
 def compile_model(
-    model: Union[TransformerDecoder, TiedEmbeddingTransformerDecoder],
+    model: TransformerDecoder,
     verbose: bool = True,
 ) -> None:
     """
@@ -30,7 +25,7 @@ def compile_model(
     to reduce compile times. Otherwise we compile the full model, which takes longer.
 
     Args:
-        model (Union[TransformerDecoder, TiedEmbeddingTransformerDecoder]): A transformer model to compile.
+        model (TransformerDecoder): A transformer model to compile.
         verbose (bool): Whether to log compile info. Default: True
 
     Returns:
         None
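
Illustration (not part of the patch): the refactor above retires GemmaTransformerDecoder by moving Gemma's sqrt(embed_dim) embedding scaling into GemmaNormEmbeddings and reusing that same embedding table as the output head via TiedLinear, which is what lets the builders return a plain TransformerDecoder. A minimal sketch of the tying pattern, assuming only the import paths introduced in this diff; the toy sizes below are made up, the real builders use the Gemma values from _model_builders.py:

# Sketch only -- not applied by this diff.
import torch

from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings
from torchtune.modules import TiedLinear

vocab_size, embed_dim = 8, 4
tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
output_proj = TiedLinear(tok_embeddings)  # no new weights: reuses the embedding table

tokens = torch.randint(0, vocab_size, (1, 3))
h = tok_embeddings(tokens)   # [1, 3, embed_dim], already scaled by sqrt(embed_dim)
logits = output_proj(h)      # [1, 3, vocab_size], tied output projection

assert logits.shape == (1, 3, vocab_size)

This is also why the builders no longer pass norm_embeddings into the decoder: the scaling lives in the embedding module itself, and the tied output projection replaces the Gemma-specific decoder class.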