1/n - remove TiedEmbeddingTransformerDecoder from qwen (pytorch#1547)
Co-authored-by: Felipe Mello <[email protected]>
felipemello1 and Felipe Mello authored Sep 12, 2024
1 parent 01619ce commit 7c51100
Showing 7 changed files with 94 additions and 72 deletions.
2 changes: 1 addition & 1 deletion docs/source/api_ref_modules.rst
@@ -19,10 +19,10 @@ Modeling Components and Building Blocks
RMSNorm
Fp32LayerNorm
TanhGate
TiedLinear
TransformerSelfAttentionLayer
TransformerCrossAttentionLayer
TransformerDecoder
TiedEmbeddingTransformerDecoder
VisionTransformer

Base Tokenizers
12 changes: 6 additions & 6 deletions recipes/configs/qwen2/1.5B_lora.yaml
@@ -1,5 +1,5 @@
# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
# using a Qwen2 0.5B model
# using a Qwen2 1.5B model
#
# This config assumes that you've run the following command before launching
# this run:
@@ -27,18 +27,18 @@ model:

tokenizer:
_component_: torchtune.models.qwen2.qwen2_tokenizer
path: /tmp/Qwen2-0.5B-Instruct/vocab.json
merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
path: /tmp/Qwen2-1.5B-Instruct/vocab.json
merges_file: /tmp/Qwen2-1.5B-Instruct/merges.txt
max_seq_len: null

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
checkpoint_dir: /tmp/Qwen2-1.5B-Instruct
checkpoint_files: [
model.safetensors
]
recipe_checkpoint: null
output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune
output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
model_type: QWEN2
resume_from_checkpoint: False

@@ -67,7 +67,7 @@ max_steps_per_epoch: null
gradient_accumulation_steps: 8

# Logging
output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune
output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}
77 changes: 29 additions & 48 deletions torchtune/models/qwen2/_component_builders.py
@@ -5,19 +5,19 @@
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import List, Union
from typing import List
from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook

from torch import nn

from torchtune.modules.transformer import TransformerDecoder, TiedEmbeddingTransformerDecoder
from torchtune.modules.transformer import TransformerDecoder
from torchtune.models.qwen2._positional_embeddings import Qwen2RotaryPositionalEmbeddings

from torchtune.modules import (
MultiHeadAttention,
FeedForward,
RMSNorm,
TransformerSelfAttentionLayer,
TiedLinear
)


@@ -48,7 +48,7 @@ def qwen2(
norm_eps: float = 1e-5,
rope_base: float = 1_000_000.0,
tie_word_embeddings: bool = False,
) -> Union[TransformerDecoder, TiedEmbeddingTransformerDecoder]:
) -> TransformerDecoder:
"""
Build the decoder associated with the Qwen2 model. This includes:
- Token embeddings
@@ -104,28 +104,20 @@ def qwen2(
mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
)
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
output_proj = None if tie_word_embeddings else nn.Linear(embed_dim, vocab_size, bias=False)
if output_proj is None:
return TiedEmbeddingTransformerDecoder(
tok_embeddings=tok_embeddings,
layers=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=head_dim,
norm=RMSNorm(embed_dim, eps=norm_eps),
)
if tie_word_embeddings:
output_proj = TiedLinear(tok_embeddings)
else:
return TransformerDecoder(
tok_embeddings=tok_embeddings,
layers=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=head_dim,
norm=RMSNorm(embed_dim, eps=norm_eps),
output=output_proj,
)
output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
return TransformerDecoder(
tok_embeddings=tok_embeddings,
layers=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=head_dim,
norm=RMSNorm(embed_dim, eps=norm_eps),
output=output_proj,
)


def qwen2_mlp(dim: int, hidden_dim: int) -> FeedForward:
@@ -162,7 +154,7 @@ def lora_qwen2(
use_dora: bool = False,
# Quantization args
quantize_base: bool = False,
) -> Union[TransformerDecoder, TiedEmbeddingTransformerDecoder]:
) -> TransformerDecoder:
"""
Return a version of Qwen2 (an instance of :func:`~torchtune.models.qwen2.transformer.Qwen2TransformerDecoder`)
with LoRA applied based on the passed in configuration.
@@ -251,7 +243,7 @@ def lora_qwen2(
"apply_lora_to_output is incompatible with tie_word_embeddings,"
" as there would be no output to apply lora to!"
)
output_proj = None
output_proj = TiedLinear(tok_embeddings)
else:
# TODO: quantize_base is not applied to final output_proj currently.
adapter_cls = DoRALinear if use_dora else LoRALinear
@@ -260,27 +252,16 @@
if apply_lora_to_output
else nn.Linear(embed_dim, vocab_size, bias=False)
)
if output_proj is None:
model = TiedEmbeddingTransformerDecoder(
tok_embeddings=tok_embeddings,
layers=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=(embed_dim // num_heads),
norm=RMSNorm(embed_dim, eps=norm_eps),
)
else:
model = TransformerDecoder(
tok_embeddings=tok_embeddings,
layers=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=(embed_dim // num_heads),
norm=RMSNorm(embed_dim, eps=norm_eps),
output=output_proj,
)
model = TransformerDecoder(
tok_embeddings=tok_embeddings,
layers=layer,
num_layers=num_layers,
max_seq_len=max_seq_len,
num_heads=num_heads,
head_dim=(embed_dim // num_heads),
norm=RMSNorm(embed_dim, eps=norm_eps),
output=output_proj,
)

if quantize_base:
# For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly
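
With `TiedEmbeddingTransformerDecoder` gone, the builder's only branch is the output projection: `TiedLinear(tok_embeddings)` when `tie_word_embeddings` is set, otherwise a separate `nn.Linear`, and a single `TransformerDecoder` covers both cases. A minimal sketch of that choice (dimensions are made up for illustration, not real Qwen2 sizes):

```python
import torch.nn as nn

from torchtune.modules import TiedLinear

vocab_size, embed_dim = 32_000, 256
tok_embeddings = nn.Embedding(vocab_size, embed_dim)

tied_proj = TiedLinear(tok_embeddings)                      # reuses the embedding weight
untied_proj = nn.Linear(embed_dim, vocab_size, bias=False)  # allocates its own weight

# Tying adds no new parameters: the output projection is the embedding table itself.
n_embedding = sum(p.numel() for p in tok_embeddings.parameters())
n_untied = sum(p.numel() for p in untied_proj.parameters())
assert n_embedding == n_untied == vocab_size * embed_dim
```
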
26 changes: 13 additions & 13 deletions torchtune/models/qwen2/_model_builders.py
@@ -7,7 +7,7 @@

from torchtune.models.qwen2._component_builders import qwen2, lora_qwen2
from torchtune.models.qwen2._tokenizer import Qwen2Tokenizer
from torchtune.modules import TransformerDecoder, TiedEmbeddingTransformerDecoder
from torchtune.modules import TransformerDecoder
from torchtune.modules.peft import LORA_ATTN_MODULES
from torchtune.modules.tokenizers import parse_hf_tokenizer_json
from torchtune.data._prompt_templates import _TemplateType
@@ -42,17 +42,17 @@ def qwen2_7b() -> TransformerDecoder:
)


def qwen2_0_5b() -> TiedEmbeddingTransformerDecoder:
def qwen2_0_5b() -> TransformerDecoder:
"""
Builder for creating a Qwen2 model initialized w/ the default 0.5B parameter values
from https://huggingface.co/Qwen/Qwen2-0.5B-Instruct
Returns:
TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 0.5B model
TransformerDecoder: Instantiation of Qwen2 0.5B model
Note:
Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
and returns an instance of `TiedEmbeddingTransformerDecoder`.
and returns an instance of `TransformerDecoder`.
"""
return qwen2(
vocab_size=151936,
@@ -69,17 +69,17 @@ def qwen2_0_5b() -> TiedEmbeddingTransformerDecoder:
)


def qwen2_1_5b() -> TiedEmbeddingTransformerDecoder:
def qwen2_1_5b() -> TransformerDecoder:
"""
Builder for creating a Qwen2 model initialized w/ the default 1.5B parameter values
from https://huggingface.co/Qwen/Qwen2-1.5B-Instruct
Returns:
TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 1.5B model
TransformerDecoder: Instantiation of Qwen2 1.5B model
Note:
Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
and returns an instance of `TiedEmbeddingTransformerDecoder`.
and returns an instance of `TransformerDecoder`.
"""
return qwen2(
vocab_size=151936,
Expand Down Expand Up @@ -191,7 +191,7 @@ def lora_qwen2_0_5b(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
) -> TiedEmbeddingTransformerDecoder:
) -> TransformerDecoder:
"""
Builder for creating a Qwen2 0.5B model with LoRA enabled.
@@ -211,11 +211,11 @@
quantize_base (bool): Whether to quantize base model weights
Returns:
TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 0.5B model with LoRA applied
TransformerDecoder: Instantiation of Qwen2 0.5B model with LoRA applied
Note:
Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
and returns an instance of `TiedEmbeddingTransformerDecoder`.
and returns an instance of `TransformerDecoder`.
"""
return lora_qwen2(
lora_attn_modules=lora_attn_modules,
@@ -248,7 +248,7 @@ def lora_qwen2_1_5b(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
) -> TiedEmbeddingTransformerDecoder:
) -> TransformerDecoder:
"""
Builder for creating a Qwen2 1.5B model with LoRA enabled.
@@ -268,11 +268,11 @@
quantize_base (bool): Whether to quantize base model weights
Returns:
TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 1.5B model with LoRA applied
TransformerDecoder: Instantiation of Qwen2 1.5B model with LoRA applied
Note:
Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
and returns an instance of `TiedEmbeddingTransformerDecoder`.
and returns an instance of `TransformerDecoder`.
"""
return lora_qwen2(
lora_attn_modules=lora_attn_modules,
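
Per the updated docstrings, the 0.5B and 1.5B builders still tie word embeddings by default but now return a plain `TransformerDecoder`. A quick sanity check, assuming the attribute names shown in this diff (note that this instantiates the full 0.5B model):

```python
from torchtune.models.qwen2 import qwen2_0_5b
from torchtune.modules import TiedLinear, TransformerDecoder

model = qwen2_0_5b()  # tie_word_embeddings is enabled by default for this size
assert isinstance(model, TransformerDecoder)
assert isinstance(model.output, TiedLinear)
# The output projection reads its weight straight from the token embedding.
assert model.output.tied_module is model.tok_embeddings
```
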
2 changes: 2 additions & 0 deletions torchtune/modules/__init__.py
@@ -15,6 +15,7 @@
from .position_embeddings import RotaryPositionalEmbeddings # noqa
from .rms_norm import RMSNorm # noqa
from .tanh_gate import TanhGate # noqa
from .tied_linear import TiedLinear # noqa
from .transformer import ( # noqa
TiedEmbeddingTransformerDecoder,
TransformerCrossAttentionLayer,
@@ -32,6 +33,7 @@
"KVCache",
"RotaryPositionalEmbeddings",
"RMSNorm",
"TiedLinear",
"Fp32LayerNorm",
"VisionTransformer",
"TransformerDecoder",
34 changes: 34 additions & 0 deletions torchtune/modules/tied_linear.py
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedLinear:
"""
A tied linear layer, without bias, that shares the same weight as another linear layer.
This is useful for models that use tied weights, such as qwen and gemma.
It requires as input an nn.Module, instead of the weight of the module, so it
can work with FSDP. Otherwise, the memory reference will be lost after FSDP is applied.
Args:
tied_module (nn.Module): The module whose weight is shared. Only
the weight is used. The bias is ignored.
Raises:
AttributeError: If the provided module does not have an attribute 'weight'.
"""

def __init__(self, tied_module: nn.Module):
self.tied_module = tied_module
if not hasattr(tied_module, "weight"):
raise AttributeError(
"Provided module does not have attribute 'weight'. Please check your tied_module."
)

def __call__(self, x: torch.tensor) -> torch.tensor:
return F.linear(x, self.tied_module.weight)
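
A short usage sketch of the new module (shapes are illustrative): the wrapper simply applies `F.linear` with the tied module's `weight`, and it rejects modules that have no `weight` attribute up front.

```python
import torch
import torch.nn as nn

from torchtune.modules import TiedLinear

embed_dim, vocab_size = 64, 1000
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
output_proj = TiedLinear(tok_embeddings)

hidden = torch.randn(2, 8, embed_dim)  # [batch_size, seq_len, embed_dim]
logits = output_proj(hidden)           # [batch_size, seq_len, vocab_size]
assert logits.shape == (2, 8, vocab_size)

# A module without a `weight` attribute raises immediately.
try:
    TiedLinear(nn.ReLU())
except AttributeError as err:
    print(err)
```
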
13 changes: 9 additions & 4 deletions torchtune/modules/transformer.py
@@ -4,14 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from torchtune.modules import MultiHeadAttention

from torchtune.modules.attention_utils import _MaskType
from torchtune.utils.logging import deprecated


class TransformerSelfAttentionLayer(nn.Module):
@@ -295,7 +295,7 @@ class TransformerDecoder(nn.Module):
to setup the :func:`~torchtune.modules.KVCache`
norm (nn.Module): Callable that applies normalization to the output of the decoder,
before final MLP.
output (nn.Linear): Callable that applies a linear transformation to the output of
output (Union[nn.Linear, Callable]): Callable that applies a linear transformation to the output of
the decoder.
num_layers (Optional[int]): Number of Transformer Decoder layers, only define when
layers is not a list.
@@ -320,7 +320,7 @@ def __init__(
num_heads: int,
head_dim: int,
norm: nn.Module,
output: nn.Linear,
output: Union[nn.Linear, Callable],
num_layers: Optional[int] = None,
output_hidden_states: Optional[List[int]] = None,
) -> None:
@@ -516,6 +516,11 @@ def forward(
return output


@deprecated(
msg="Please use torchtune.modules.TransformerDecoder instead. \
If you need an example, see torchtune.models.qwen2._component_builders.py \
and how to implement torchtune.modules.TiedLinear for the output projection."
)
class TiedEmbeddingTransformerDecoder(nn.Module):
"""
Transformer Decoder with tied embedding weight. A key difference between
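
The `output` annotation is widened to `Union[nn.Linear, Callable]` because `TiedLinear` is not an `nn.Linear` (or an `nn.Module` at all), just a callable, so `TransformerDecoder` only needs `output` to be callable. A tiny sketch of that contract:

```python
import torch.nn as nn

from torchtune.modules import TiedLinear

proj = TiedLinear(nn.Embedding(100, 16))
assert callable(proj)
assert not isinstance(proj, nn.Module)  # hence Union[nn.Linear, Callable]
```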
