diff --git a/README.md b/README.md index a66f69e73..6c8a987ac 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,10 @@ The library currently supports the following models and fine-tuning methods. | Model | Sizes | Finetuning Methods | |-----------------------------------------------|-----------|-----------------------------------------------------------| -| [Llama2](torchtune/models/llama2.py) | 7B | Full Finetuning [[single device](recipes/full_finetune_single_device.py), [distributed](recipes/full_finetune_distributed.py)], LoRA [[single device](recipes/lora_finetune_single_device.py), [distributed](recipes/lora_finetune_distributed.py)], QLoRA [single device](recipes/lora_finetune_single_device.py) | +| [Llama2](torchtune/models/llama2/_model_builders.py) | 7B | Full Finetuning [[single device](recipes/configs/llama2/7B_full_single_device.yaml), [distributed](recipes/configs/llama2/7B_full.yaml)], LoRA [[single device](recipes/configs/llama2/7B_lora_single_device.yaml), [distributed](recipes/configs/llama2/7B_lora.yaml)], QLoRA [single device](recipes/configs/llama2/7B_qlora_single_device.yaml) | +| [Llama2](torchtune/models/llama2/_model_builders.py) | 13B | [Full Finetuning](recipes/configs/llama2/13B_full.yaml), [LoRA](recipes/configs/llama2/13B_lora.yaml) | +| [Mistral](torchtune/models/mistral/_model_builders.py) | 7B | Full Finetuning and LoRA are WIP and will be added soon | +   @@ -49,11 +52,11 @@ experience different peak memory utilization based on changes made in configurat | Example HW Resources | Finetuning Method | Config | Model Size | Peak Memory per GPU |--------------|-------------------|---------|------------|---------------------| -| 1 x RTX 4090 | QLoRA | [qlora_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/qlora_finetune_single_device.yaml) | 7B | 9.29 GB * | -| 2 x RTX 4090 | LoRA | [lora_finetune_distributed](https://github.com/pytorch/torchtune/blob/main/recipes/configs/lora_finetune_distributed.yaml) | 7B | 14.17 GB * | -| 1 x RTX 4090 | LoRA | [lora_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/lora_finetune_single_device.yaml) | 7B | 17.18 GB * | -| 1 x A6000 | Full finetune | [full_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/full_finetune_single_device.yaml) | 7B | 27.15 GB * | -| 4 x RTX 4090 | Full finetune | [full_finetune_distributed](https://github.com/pytorch/torchtune/blob/main/recipes/configs/full_finetune_distributed.yaml) | 7B | 12.01 GB * | +| 1 x RTX 4090 | QLoRA | [qlora_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml) | 7B | 9.29 GB * | +| 2 x RTX 4090 | LoRA | [lora_finetune_distributed](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_lora.yaml) | 7B | 14.17 GB * | +| 1 x RTX 4090 | LoRA | [lora_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_lora_single_device.yaml) | 7B | 17.18 GB * | +| 1 x A6000 | Full finetune | [full_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full_single_device.yaml) | 7B | 27.15 GB * | +| 4 x RTX 4090 | Full finetune | [full_finetune_distributed](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full.yaml) | 7B | 12.01 GB * | NOTE: * indicates an estimated metric based on experiments conducted on A100 GPUs with GPU memory artificially limited using 
[torch.cuda.set_per_process_memory_fraction API](https://pytorch.org/docs/stable/generated/torch.cuda.set_per_process_memory_fraction.html). Peak memory per GPU is as reported by `torch.cuda.max_memory_reserved()`. Please file an issue if you are not able to reproduce these results when running TorchTune on certain hardware. @@ -117,24 +120,24 @@ Note: While the ``tune download`` command allows you to download *any* model fro TorchTune contains recipes for: - Full finetuning on [single device](https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_single_device.py) and on [multiple devices with FSDP](https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py) - LoRA finetuning on [single device](https://github.com/pytorch/torchtune/blob/main/recipes/lora_finetune_single_device.py) and on [multiple devices with FSDP](https://github.com/pytorch/torchtune/blob/main/recipes/lora_finetune_distributed.py). -- QLoRA finetuning on [single device](https://github.com/pytorch/torchtune/blob/main/recipes/lora_finetune_single_device.py), with a QLoRA specific [configuration](https://github.com/pytorch/torchtune/blob/main/recipes/configs/qlora_finetune_single_device.yaml) +- QLoRA finetuning on [single device](https://github.com/pytorch/torchtune/blob/main/recipes/lora_finetune_single_device.py), with a QLoRA specific [configuration](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml) -To run a full finetune on two devices on the Alpaca dataset using FSDP: +To run a full finetune on two devices on the Alpaca dataset using the Llama2 7B model and FSDP: ``` tune --nnodes 1 --nproc_per_node 2 \ full_finetune_distributed \ ---config full_finetune_distributed +--config llama2/7B_full ``` The argument passed to `--nproc_per_node` can be varied depending on how many GPUs you have. A full finetune can be memory-intensive, so make sure you are running on enough devices. See [this table](https://github.com/pytorch/torchtune/blob/main/README.md#finetuning-resource-requirements) for resource requirements on common hardware setups. -Similarly, you can finetune with LoRA on the Alpaca dataset on two devices via the following. +Similarly, you can finetune with LoRA on the Alpaca dataset using the Llama2 13B model on two devices via the following. ``` tune --nnodes 1 --nproc_per_node 2 \ lora_finetune_distributed \ ---config lora_finetune_distributed +--config llama2/13B_lora ``` Again, the argument to `--nproc_per_node` can be varied subject to memory constraints of your device(s). 
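The `--config <name>` argument resolves to one of the YAML files listed above, and any trailing `key=value` arguments override individual fields. The sketch below is illustrative only (it is not the `tune` CLI's actual parsing code, and the config path is assumed relative to the repo root), but it shows how a config and its overrides compose using OmegaConf, which is what the recipes ultimately receive as a `DictConfig`:

```python
# Illustrative sketch of config + command-line override composition.
from omegaconf import OmegaConf

# Load the base config shipped with the recipe (path assumed relative to repo root).
base_cfg = OmegaConf.load("recipes/configs/llama2/7B_full.yaml")

# Overrides in `key=value` form, as they would appear after the recipe name on the CLI.
overrides = ["batch_size=4", "checkpointer.checkpoint_dir=/tmp/llama2"]
override_cfg = OmegaConf.from_dotlist(overrides)

# Later values win, so command-line overrides take precedence over the YAML file.
cfg = OmegaConf.merge(base_cfg, override_cfg)
print(cfg.batch_size, cfg.checkpointer.checkpoint_dir)
```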
@@ -142,7 +145,7 @@ Again, the argument to `--nproc_per_node` can be varied subject to memory constr An example to run QLoRA on a single device can be achieved with the following: ``` -tune lora_finetune_single_device --config recipes/configs/qlora_finetune_single_device.yaml +tune lora_finetune_single_device --config recipes/configs/llama2/7B_qlora_single_device ```   @@ -152,8 +155,8 @@ tune lora_finetune_single_device --config recipes/configs/qlora_finetune_single_ To copy a recipe to customize it yourself and then run ``` tune cp full_finetune_distributed.py my_recipe/full_finetune_distributed.py -tune cp full_finetune_distributed.yaml my_recipe/full_finetune_distributed.yaml -tune my_recipe/full_finetune_distributed.py --config my_recipe/full_finetune_distributed.yaml +tune cp llama2/7B_full.yaml my_recipe/7B_full.yaml +tune my_recipe/full_finetune_distributed.py --config my_recipe/7B_full.yaml ```   diff --git a/docs/source/examples/configs.rst b/docs/source/examples/configs.rst index 648d11395..ef564fa1a 100644 --- a/docs/source/examples/configs.rst +++ b/docs/source/examples/configs.rst @@ -161,7 +161,7 @@ will list out all the locations where an error was found. .. code-block:: bash - tune validate --config recipes/configs/full_finetune_single_device.yaml batch_size=4 + tune validate --config recipes/configs/llama2/7B_full.yaml batch_size=4 Best practices for writing configs ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/examples/first_finetune_tutorial.rst b/docs/source/examples/first_finetune_tutorial.rst index cd6319615..cdb86d350 100644 --- a/docs/source/examples/first_finetune_tutorial.rst +++ b/docs/source/examples/first_finetune_tutorial.rst @@ -78,7 +78,7 @@ It looks like there's already a config called :code:`alpaca_llama_full_finetune` .. code-block:: bash - tune cp full_finetune_distributed.yaml custom_config.yaml + tune cp llama2/7B_full.yaml custom_config.yaml Now you can update the custom YAML config to point to your model and tokenizer. While you're at it, you can make some other changes, like setting the random seed in order to make replication easier, diff --git a/docs/source/examples/lora_finetune.rst b/docs/source/examples/lora_finetune.rst index e3393b6e9..5f14b236a 100644 --- a/docs/source/examples/lora_finetune.rst +++ b/docs/source/examples/lora_finetune.rst @@ -258,7 +258,7 @@ You can then run the following command to perform a LoRA finetune of Llama2-7B u .. note:: Make sure to point to the location of your Llama2 weights and tokenizer. This can be done either by adding :code:`checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path` - or by directly modifying the :code:`lora_finetune_distributed.yaml` file. See our :ref:`config_tutorial_label` + or by directly modifying the :code:`7B_lora.yaml` file. See our :ref:`config_tutorial_label` for more details on how you can easily clone and modify TorchTune configs. .. note:: diff --git a/docs/source/examples/recipe_deepdive.rst b/docs/source/examples/recipe_deepdive.rst index b80a09736..f1a0bcc37 100644 --- a/docs/source/examples/recipe_deepdive.rst +++ b/docs/source/examples/recipe_deepdive.rst @@ -43,7 +43,7 @@ Each recipe consists of three components: In the following sections, we'll take a closer look at each of these components. For a complete working example, refer to the `full finetuning recipe `_ in TorchTune and the associated -`config `_. +`config `_. What Recipes are not? 
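The docs changes above walk through copying a config with `tune cp` and then editing it to point at your own model, tokenizer, and seed. That edit can also be scripted; here is a small sketch (file name, paths, and values are placeholders, and the keys mirror the Llama2 7B full-finetune config that appears later in this diff):

```python
# Hypothetical edit of a copied config, e.g. custom_config.yaml produced by
# `tune cp llama2/7B_full.yaml custom_config.yaml`.
import yaml

with open("custom_config.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["seed"] = 42  # pin the seed to make replication easier
cfg["tokenizer"]["path"] = "/data/llama2/tokenizer.model"      # placeholder path
cfg["checkpointer"]["checkpoint_dir"] = "/data/llama2"         # placeholder path

with open("custom_config.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```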
diff --git a/recipes/README.md b/recipes/README.md index 198042189..53b0a2ffd 100644 --- a/recipes/README.md +++ b/recipes/README.md @@ -6,7 +6,7 @@ Recipes are the primary entry points for TorchTune users. These can be thought of as end-to-end pipelines for training and optionally evaluating LLMs. Each recipe consists of three components: -- **Configurable parameters**, specified through yaml configs [example](https://github.com/pytorch/torchtune/blob/main/recipes/configs/full_finetune_distributed.yaml) and command-line overrides +- **Configurable parameters**, specified through yaml configs [example](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full.yaml) and command-line overrides - **Recipe class**, core logic needed for training, exposed to users through a set of APIs [interface](https://github.com/pytorch/torchtune/blob/main/recipes/interfaces.py) - **Recipe script**, puts everything together including parsing and validating configs, setting up the environment, and correctly using the recipe class diff --git a/recipes/configs/llama2/13B_full.yaml b/recipes/configs/llama2/13B_full.yaml new file mode 100644 index 000000000..872533ad6 --- /dev/null +++ b/recipes/configs/llama2/13B_full.yaml @@ -0,0 +1,85 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Llama2 13B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download --repo-id meta-llama/Llama-2-13b-hf \ +# --hf-token \ +# --output-dir /tmp/llama2-13b-hf +# +# To launch on 4 devices, run the following command from root: +# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ +# --config llama2/13B_full \ +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ +# --config llama2/13B_full \ +# checkpointer.checkpoint_dir= +# +# This config should be used with 2+ GPUs. Single device full fine-tuning +# requires several memory optimizations which are exposed through +# 7B_full_single_device.yaml. Please update the model and checkpoints to 13B +# in that config. 
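The model section of this config (below) points at the new `torchtune.models.llama2.llama2_13b` builder added later in this diff. A quick, hedged sanity check of that builder: instantiating on the meta device keeps it cheap, since no 13B worth of weights is actually allocated or loaded (assumes this branch is installed and a PyTorch version where `torch.device` acts as a context manager):

```python
# Cheap architecture check for the new 13B builder; no real weights allocated.
import torch
from torchtune.models.llama2 import llama2_13b

with torch.device("meta"):
    model = llama2_13b()

n_params = sum(p.numel() for p in model.parameters())
print(f"Llama2 13B parameters: {n_params / 1e9:.1f}B")  # roughly 13B
```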
+ + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/llama2/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset + train_on_input: True +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama2.llama2_13b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/llama2-13b-hf/ + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + recipe_checkpoint: null + output_dir: /tmp/llama2-13b-hf/ + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + + +# Training env +device: cuda + +# Distributed +cpu_offload: False + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-llama2-finetune +log_every_n_steps: null diff --git a/recipes/configs/llama2/13B_lora.yaml b/recipes/configs/llama2/13B_lora.yaml new file mode 100644 index 000000000..2a338272d --- /dev/null +++ b/recipes/configs/llama2/13B_lora.yaml @@ -0,0 +1,90 @@ +# Config for multi-device LoRA in lora_finetune_distributed.py +# using a Llama2 13B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download --repo-id meta-llama/Llama-2-13b-hf \ +# --hf-token \ +# --output-dir /tmp/llama2-13b-hf +# +# To launch on 4 devices, run the following command from root: +# tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \ +# --config llama2/13B_lora \ +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \ +# --config llama2/13B_lora \ +# checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device lora finetuning please use 7B_lora_single_device.yaml +# or 7B_qlora_single_device.yaml and update the model and checkpoints to +# the 13B model. 
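As in the other configs, the body that follows wires components together through `_component_` dotted paths plus sibling keyword arguments. The snippet below is a simplified stand-in for how such an entry can be resolved; it is not TorchTune's actual config-instantiation utility, just an illustration of the convention:

```python
# Simplified stand-in for resolving a `_component_` config entry into an object.
import importlib

def instantiate(node: dict):
    """Resolve `_component_` to a callable and call it with the remaining keys."""
    module_name, _, attr = node["_component_"].rpartition(".")
    fn = getattr(importlib.import_module(module_name), attr)
    kwargs = {k: v for k, v in node.items() if k != "_component_"}
    return fn(**kwargs)

# Example mirroring the model section of 13B_lora.yaml below:
model_cfg = {
    "_component_": "torchtune.models.llama2.lora_llama2_13b",
    "lora_attn_modules": ["q_proj", "v_proj", "k_proj"],
    "apply_lora_to_mlp": True,
    "apply_lora_to_output": True,
    "lora_rank": 8,
    "lora_alpha": 16,
}
# model = instantiate(model_cfg)  # uncomment to build the LoRA-wrapped 13B model
```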
+ + +# Model Arguments +model: + _component_: torchtune.models.llama2.lora_llama2_13b + lora_attn_modules: ['q_proj', 'v_proj', 'k_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: True + lora_rank: 8 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/llama2-13b-hf/ + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/llama2-13b-hf/ + model_type: LLAMA2 +resume_from_checkpoint: False + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/llama2/tokenizer.model + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_dataset + train_on_input: True + use_clean: True +seed: null +shuffle: True +batch_size: 32 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 2e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + +# Training +epochs: 1 +max_steps_per_epoch: null + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: null + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False diff --git a/recipes/configs/llama2/7B_full.yaml b/recipes/configs/llama2/7B_full.yaml new file mode 100644 index 000000000..74b5d5a8f --- /dev/null +++ b/recipes/configs/llama2/7B_full.yaml @@ -0,0 +1,80 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download --repo-id meta-llama/Llama-2-7b \ +# --hf-token \ +# --output-dir /tmp/llama2 +# +# To launch on 4 devices, run the following command from root: +# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ +# --config llama2/7B_full \ +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \ +# --config llama2/7B_full \ +# checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. 
It's +# best to use 7B_full_single_device.yaml for those cases + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/llama2/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset + train_on_input: True +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: /tmp/llama2 + checkpoint_files: [consolidated.00.pth] + recipe_checkpoint: null + output_dir: /tmp/llama2 + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + + +# Training env +device: cuda + +# Distributed +cpu_offload: False + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-llama2-finetune +log_every_n_steps: null diff --git a/recipes/configs/full_finetune_single_device.yaml b/recipes/configs/llama2/7B_full_single_device.yaml similarity index 58% rename from recipes/configs/full_finetune_single_device.yaml rename to recipes/configs/llama2/7B_full_single_device.yaml index 561b85e45..e657e5fc7 100644 --- a/recipes/configs/full_finetune_single_device.yaml +++ b/recipes/configs/llama2/7B_full_single_device.yaml @@ -1,9 +1,25 @@ -# Config for FullFinetuneRecipe in full_finetune_single_device.py +# Config for single device full finetuning in full_finetune_single_device.py +# using a Llama2 7B model # -# To launch, run the following command from root: +# This config assumes that you've run the following command before launching +# this run: +# tune download --repo-id meta-llama/Llama-2-7b \ +# --hf-token \ +# --output-dir /tmp/llama2 +# +# To launch on a single device, run the following command from root: +# tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device \ +# --config llama2/7B_full_single_device \ +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: # tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device \ -# --config full_finetune_single_device \ -# checkpointer.checkpoint_dir== ... +# --config llama2/7B_full_single_device \ +# checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + # Tokenizer tokenizer: diff --git a/recipes/configs/lora_finetune_distributed.yaml b/recipes/configs/llama2/7B_lora.yaml similarity index 58% rename from recipes/configs/lora_finetune_distributed.yaml rename to recipes/configs/llama2/7B_lora.yaml index 4797db08a..25175af97 100644 --- a/recipes/configs/lora_finetune_distributed.yaml +++ b/recipes/configs/llama2/7B_lora.yaml @@ -1,7 +1,27 @@ -# Config for LoRAFinetuneDistributedRecipe in lora_finetune_distributed.py +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a Llama2 7B model # -# To launch, run the following command from root: -# tune --nnodes 1 --nproc_per_node 1 lora_finetune_distributed --config alpaca_llama2_lora_finetune_distributed model_checkpoint= ... 
+# This config assumes that you've run the following command before launching +# this run: +# tune download --repo-id meta-llama/Llama-2-7b \ +# --hf-token \ +# --output-dir /tmp/llama2 +# +# To launch on 4 devices, run the following command from root: +# tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \ +# --config llama2/7B_lora \ +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \ +# --config llama2/7B_lora \ +# checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device lora finetuning please use 7B_lora_single_device.yaml +# or 7B_qlora_single_device.yaml + # Model Arguments model: diff --git a/recipes/configs/lora_finetune_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml similarity index 61% rename from recipes/configs/lora_finetune_single_device.yaml rename to recipes/configs/llama2/7B_lora_single_device.yaml index f86e61551..6c950d4e9 100644 --- a/recipes/configs/lora_finetune_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -1,9 +1,25 @@ -# Config for LoRAFinetuneRecipeSingleDevice in lora_finetune_single_device.py +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Llama2 7B model # -# To launch, run the following command from root: -# tune lora_finetune_single_device \ -# --config lora_finetune_single_device \ -# checkpointer.checkpoint_dir= ... +# This config assumes that you've run the following command before launching +# this run: +# tune download --repo-id meta-llama/Llama-2-7b \ +# --hf-token \ +# --output-dir /tmp/llama2 +# +# To launch on a single device, run the following command from root: +# tune --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \ +# --config llama2/7B_lora_single_device \ +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \ +# --config 7B_lora_single_device \ +# checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + # Model Arguments model: diff --git a/recipes/configs/qlora_finetune_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml similarity index 62% rename from recipes/configs/qlora_finetune_single_device.yaml rename to recipes/configs/llama2/7B_qlora_single_device.yaml index 361b85723..1ae29e57a 100644 --- a/recipes/configs/qlora_finetune_single_device.yaml +++ b/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -1,9 +1,24 @@ -# QLoRA specific config for LoRAFinetuneRecipeSingleDevice in lora_finetune_single_device.py +# Config for single device QLoRA with lora_finetune_single_device.py +# using a Llama2 7B model # -# To launch, run the following command from root: -# tune lora_finetune_single_device \ -# --config qlora_finetune_single_device \ -# checkpointer.checkpoint_dir= ... 
+# This config assumes that you've run the following command before launching +# this run: +# tune download --repo-id meta-llama/Llama-2-7b \ +# --hf-token \ +# --output-dir /tmp/llama2 +# +# To launch on a single device, run the following command from root: +# tune --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \ +# --config 7B_qlora_single_device \ +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune --nnodes 1 --nproc_per_node 1 lora_finetune_single_device \ +# --config 7B_qlora_single_device \ +# checkpointer.checkpoint_dir= +# +# This config works only for training on single device. # Model Arguments model: diff --git a/recipes/configs/full_finetune_distributed.yaml b/recipes/configs/mistral/7B_full.yaml similarity index 57% rename from recipes/configs/full_finetune_distributed.yaml rename to recipes/configs/mistral/7B_full.yaml index ac4588256..ba85749eb 100644 --- a/recipes/configs/full_finetune_distributed.yaml +++ b/recipes/configs/mistral/7B_full.yaml @@ -1,13 +1,9 @@ -# Config for FullFinetuneRecipe in full_finetune_distributed.py -# -# To launch, run the following command from root: -# tune --nnodes 1 --nproc_per_node 1 full_finetune_distributed \ -# --config full_finetune_distributed \ -# checkpointer.checkpoint_dir= ... +# This config is currently a WIP. Use it with caution + # Tokenizer tokenizer: - _component_: torchtune.models.llama2.llama2_tokenizer + _component_: torchtune.models.mistral.mistral_tokenizer path: /tmp/llama2/tokenizer.model # Dataset @@ -19,22 +15,25 @@ shuffle: True # Model Arguments model: - _component_: torchtune.models.llama2.llama2_7b + _component_: torchtune.models.mistral.mistral_7b checkpointer: - _component_: torchtune.utils.FullModelMetaCheckpointer - checkpoint_dir: /tmp/llama2 - checkpoint_files: [consolidated.00.pth] + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Mistral-7B-v0.1 + checkpoint_files: [ + pytorch_model-00001-of-00002.bin, + pytorch_model-00002-of-00002.bin + ] recipe_checkpoint: null - output_dir: /tmp/llama2 + output_dir: /tmp/Mistral-7B-v0.1 model_type: LLAMA2 resume_from_checkpoint: False # Fine-tuning arguments -batch_size: 2 +batch_size: 32 epochs: 3 optimizer: - _component_: torch.optim.SGD + _component_: torch.optim.AdamW lr: 2e-5 loss: _component_: torch.nn.CrossEntropyLoss diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 40512e5ff..6c3102662 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -59,7 +59,8 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): The following configs can be used to run this recipe: >>> tune ls RECIPE CONFIG - full_finetune_distributed full_finetune_distributed + full_finetune_distributed llama2/7B_full + llama2/13B_full Args: cfg (DictConfig): OmegaConf object parsed from yaml file @@ -479,7 +480,7 @@ def recipe_main(cfg: DictConfig) -> None: Entry point for the recipe. 
Configurable parameters are read in the following order: - - Parameters specified in ``full_finetune_distributed.yaml`` + - Parameters specified in config (see available configs through ``tune ls``) - Overwritten by arguments from the command-line """ if not utils.is_distributed(): diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 7ec4be248..dbe44d6ea 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -49,7 +49,7 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): The following configs can be used to run this recipe: >>> tune ls RECIPE CONFIG - full_finetune_single_device full_finetune_single_device + full_finetune_single_device llama2/7B_full_single_device Args: cfg (DictConfig): OmegaConf object parsed from yaml file @@ -377,7 +377,7 @@ def recipe_main(cfg: DictConfig) -> None: Entry point for the recipe. Configurable parameters are read in the following order: - - Parameters specified in ``full_finetune_single_device.yaml`` + - Parameters specified in config (see available configs through ``tune ls``) - Overwritten by arguments from the command-line """ recipe = FullFinetuneRecipeSingleDevice(cfg=cfg) diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 3a2bf2ac5..2451c8bf4 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -60,7 +60,8 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): The following configs can be used to run this recipe: >>> tune ls RECIPE CONFIG - lora_finetune_distributed lora_finetune_distributed + lora_finetune_distributed llama2/7B_lora + llama2/13B_lora Args: cfg (DictConfig): OmegaConf object parsed from yaml file @@ -80,20 +81,7 @@ def __init__(self, cfg: DictConfig) -> None: "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." ) - # For CUDA devices, check if the HW supports bf16 if bf16 is specified. - if ( - self._dtype == torch.bfloat16 - and self._device != torch.device("cpu") - and not torch.cuda.is_bf16_supported() - ): - raise RuntimeError("Full bf16 training is not supported on this hardware.") - - world_size, rank = utils.get_world_size_and_rank() - if world_size == 1: - raise ValueError( - "This recipe doesn't support training with world_size = 1." - "Please use the single device version of the recipe instead." - ) + _, rank = utils.get_world_size_and_rank() # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this @@ -103,6 +91,7 @@ def __init__(self, cfg: DictConfig) -> None: self._output_dir = cfg.output_dir self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1 self._log_peak_memory_every_n_steps = 100 + # training attributes self._enable_activation_checkpointing = cfg.enable_activation_checkpointing @@ -558,7 +547,7 @@ def recipe_main(cfg: DictConfig) -> None: Entry point for the recipe. 
Configurable parameters are read in the following order: - - Parameters specified in ``lora_finetune_distributed.yaml`` + - Parameters specified in config (see available configs through ``tune ls``) - Overwritten by arguments from the command-line """ if not utils.is_distributed(): diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 310bed641..ec7a301f6 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -53,7 +53,8 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): The following configs can be used to run this recipe: >>> tune ls RECIPE CONFIG - lora_finetune_single_device lora_finetune_single_device + lora_finetune_single_device llama2/7B_lora_single_device + llama2/7B_qlora_single_device Args: cfg (DictConfig): OmegaConf object parsed from yaml file @@ -442,7 +443,7 @@ def recipe_main(cfg: DictConfig) -> None: Entry point for the recipe. Configurable parameters are read in the following order: - - Parameters specified in ``lora_finetune_single_device.yaml`` + - Parameters specified in config (see available configs through ``tune ls``) - Overwritten by arguments from the command-line """ recipe = LoRAFinetuneRecipeSingleDevice(cfg=cfg) diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index 293fcddae..3afd37c1c 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -59,7 +59,7 @@ def test_loss(self, tmpdir, monkeypatch): cmd = f""" tune --nnodes 1 --nproc_per_node 2 full_finetune_distributed - --config full_finetune_distributed \ + --config llama2/7B_full \ output_dir={tmpdir} \ checkpointer._component_=torchtune.utils.FullModelHFCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 419c40223..1d58a140e 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -59,7 +59,7 @@ def test_loss(self, tmpdir, monkeypatch): cmd = f""" tune full_finetune_single_device - --config full_finetune_single_device \ + --config llama2/7B_full_single_device \ output_dir={tmpdir} \ checkpointer._component_=torchtune.utils.FullModelMetaCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -105,7 +105,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Train for two epochs cmd_1 = f""" tune full_finetune_single_device - --config full_finetune_single_device \ + --config llama2/7B_full_single_device \ output_dir={tmpdir} \ checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -124,7 +124,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Resume training cmd_2 = f""" tune full_finetune_single_device - --config full_finetune_single_device \ + --config llama2/7B_full_single_device \ output_dir={tmpdir} \ checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ @@ -187,7 +187,7 @@ def test_gradient_accumulation(self, tmpdir, monkeypatch): cmd_1 = f""" tune full_finetune_single_device \ - --config full_finetune_single_device \ + --config llama2/7B_full_single_device \ checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir={ckpt_dir} \ checkpointer.checkpoint_files=[{ckpt_path}]\ @@ -213,7 +213,7 @@ def 
test_gradient_accumulation(self, tmpdir, monkeypatch): # Update the cmd with new values for gradient accumulation cmd_2 = f""" tune full_finetune_single_device \ - --config full_finetune_single_device \ + --config llama2/7B_full_single_device \ checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir={ckpt_dir} \ checkpointer.checkpoint_files=[{ckpt_path}]\ diff --git a/tests/recipes/test_lora_finetune_distributed.py b/tests/recipes/test_lora_finetune_distributed.py index 71de15e87..7027a4008 100644 --- a/tests/recipes/test_lora_finetune_distributed.py +++ b/tests/recipes/test_lora_finetune_distributed.py @@ -57,7 +57,7 @@ def test_loss(self, tmpdir, monkeypatch): log_file = gen_log_file_name(tmpdir) cmd = f""" tune --nnodes 1 --nproc_per_node 2 lora_finetune_distributed - --config lora_finetune_distributed \ + --config llama2/7B_lora \ output_dir={tmpdir} \ checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -110,7 +110,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Train for two epochs cmd_1 = f""" tune --nnodes 1 --nproc_per_node 2 lora_finetune_distributed - --config lora_finetune_distributed \ + --config llama2/7B_lora \ output_dir={tmpdir} \ checkpointer=torchtune.utils.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -134,7 +134,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Resume training cmd_2 = f""" tune --nnodes 1 --nproc_per_node 2 lora_finetune_distributed - --config lora_finetune_distributed \ + --config llama2/7B_lora \ output_dir={tmpdir} \ checkpointer=torchtune.utils.FullModelHFCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ @@ -167,7 +167,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): ckpt_dir = ckpt_path.parent cmd = f""" tune --nnodes 1 --nproc_per_node 2 lora_finetune_distributed - --config lora_finetune_distributed \ + --config llama2/7B_lora \ output_dir={tmpdir} \ model=torchtune.models.lora_small_test_model \ checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py index d9f86f1b9..034ca679d 100644 --- a/tests/recipes/test_lora_finetune_single_device.py +++ b/tests/recipes/test_lora_finetune_single_device.py @@ -55,7 +55,7 @@ def test_loss(self, tmpdir, monkeypatch): cmd = f""" tune lora_finetune_single_device - --config lora_finetune_single_device \ + --config llama2/7B_lora_single_device \ output_dir={tmpdir} \ checkpointer=torchtune.utils.FullModelMetaCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -107,7 +107,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Train for two epochs cmd_1 = f""" tune lora_finetune_single_device - --config lora_finetune_single_device \ + --config llama2/7B_lora_single_device \ output_dir={tmpdir} \ checkpointer=torchtune.utils.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -132,7 +132,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Resume training cmd_2 = f""" tune lora_finetune_single_device - --config lora_finetune_single_device \ + --config llama2/7B_lora_single_device \ output_dir={tmpdir} \ checkpointer=torchtune.utils.FullModelHFCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ @@ -165,7 +165,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): cmd = f""" tune lora_finetune_single_device - --config 
lora_finetune_single_device \ + --config llama2/7B_lora_single_device \ output_dir={tmpdir} \ checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ diff --git a/tests/regression_tests/test_llama2_7b.py b/tests/regression_tests/test_llama2_7b.py index e6de5e9cb..a4743711c 100644 --- a/tests/regression_tests/test_llama2_7b.py +++ b/tests/regression_tests/test_llama2_7b.py @@ -56,7 +56,7 @@ def test_loss(self, tmpdir, monkeypatch): cmd = f""" tune --nnodes 1 --nproc_per_node 2 full_finetune_distributed - --config full_finetune_distributed \ + --config llama2/7B_full \ output_dir={tmpdir} \ checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -100,7 +100,7 @@ def test_finetune_and_eval(self, tmpdir, capsys, monkeypatch): # Run on prod LoRA FT config but with only 10 steps for now ft_cmd = f""" tune --nnodes 1 --nproc_per_node 2 lora_finetune_distributed - --config lora_finetune_distributed \ + --config llama2/7B_lora \ output_dir={tmpdir} \ checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ diff --git a/tests/torchtune/_cli/test_cp.py b/tests/torchtune/_cli/test_cp.py index 060efb22c..499ec4bce 100644 --- a/tests/torchtune/_cli/test_cp.py +++ b/tests/torchtune/_cli/test_cp.py @@ -23,7 +23,7 @@ def test_copy_successful(self, capsys, monkeypatch, tmpdir, already_exists): if already_exists: dest.touch() - args = f"tune cp full_finetune_single_device.yaml {dest}".split() + args = f"tune cp llama2/7B_full.yaml {dest}".split() monkeypatch.setattr(sys, "argv", args) runpy.run_path(TUNE_PATH, run_name="__main__") @@ -41,7 +41,7 @@ def test_copy_skips_when_dest_already_exists_and_no_clobber_is_true( existing_file = tmpdir_path / "existing_file.yaml" existing_file.touch() - args = f"tune cp full_finetune_single_device.yaml {existing_file} -n".split() + args = f"tune cp llama2/7B_full_single_device.yaml {existing_file} -n".split() monkeypatch.setattr(sys, "argv", args) runpy.run_path(TUNE_PATH, run_name="__main__") diff --git a/torchtune/__init__.py b/torchtune/__init__.py index 556b79d89..00754930e 100644 --- a/torchtune/__init__.py +++ b/torchtune/__init__.py @@ -15,10 +15,13 @@ "eleuther_eval.py", ] _CONFIG_LISTS = { - "full_finetune_single_device.py": ["full_finetune_single_device.yaml"], - "full_finetune_distributed.py": ["full_finetune_distributed.yaml"], - "lora_finetune_single_device.py": ["lora_finetune_single_device.yaml"], - "lora_finetune_distributed.py": ["lora_finetune_distributed.yaml"], + "full_finetune_single_device.py": ["llama2/7B_full_single_device.yaml"], + "full_finetune_distributed.py": ["llama2/7B_full.yaml", "llama2/13B_full.yaml"], + "lora_finetune_single_device.py": [ + "llama2/7B_lora_single_device.yaml", + "llama2/7B_qlora_single_device.yaml", + ], + "lora_finetune_distributed.py": ["llama2/7B_lora.yaml", "llama2/13B_lora.yaml"], "alpaca_generate.py": ["alpaca_generate.yaml"], "eleuther_eval.py": ["eleuther_eval.yaml"], } diff --git a/torchtune/_cli/cp.py b/torchtune/_cli/cp.py index 1ae135dbb..a403527de 100644 --- a/torchtune/_cli/cp.py +++ b/torchtune/_cli/cp.py @@ -66,7 +66,7 @@ def main(parser): epilog=textwrap.dedent( """\ examples: - $ tune cp lora_finetune_distributed.yaml ./my_custom_llama2_lora.yaml + $ tune cp llama2/7B_lora.yaml ./my_custom_llama2_lora.yaml $ tune cp full_finetune_distributed.py ./my_custom_full_finetune.py $ tune cp full_finetune_distributed.py ./new_dir/my_custom_full_finetune.py --make-parents 
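The recipe-to-config mapping above (`_CONFIG_LISTS`) is what `tune ls` renders, and the wider columns in the `ls.py` change below make room for the longer nested `llama2/...` names. A rough, abbreviated sketch of that table-printing logic (config names shortened for illustration):

```python
# Rough sketch of the recipe/config table that `tune ls` prints.
recipe_to_configs = {
    "full_finetune_single_device": ["llama2/7B_full_single_device"],
    "full_finetune_distributed": ["llama2/7B_full", "llama2/13B_full"],
    "lora_finetune_single_device": [
        "llama2/7B_lora_single_device",
        "llama2/7B_qlora_single_device",
    ],
    "lora_finetune_distributed": ["llama2/7B_lora", "llama2/13B_lora"],
}

print(f"{'RECIPE':<40} {'CONFIG':<40}")
for recipe, configs in recipe_to_configs.items():
    for i, config in enumerate(configs):
        # Only print the recipe name on its first config, as ls.py does.
        name = recipe if i == 0 else ""
        print(f"{name:<40} {config:<40}")
```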
diff --git a/torchtune/_cli/ls.py b/torchtune/_cli/ls.py index ce1369c2b..3e7f24c0e 100644 --- a/torchtune/_cli/ls.py +++ b/torchtune/_cli/ls.py @@ -16,7 +16,7 @@ def main(): # Print table header - header = f"{'RECIPE':<20} {'CONFIG':<15}" + header = f"{'RECIPE':<40} {'CONFIG':<40}" print(header) # Print recipe/config pairs @@ -24,14 +24,14 @@ def main(): configs = list_configs(recipe) # If there are no configs for a recipe, print a blank config if len(configs) == 0: - row = f"{recipe:<20} {_NULL_VALUE:<15}" + row = f"{recipe:<40} {_NULL_VALUE:<40}" print(row) for i, config in enumerate(configs): # If there are multiple configs for a single recipe, omit the recipe name # on latter configs if i > 0: recipe = "" - row = f"{recipe:<20} {config:<15}" + row = f"{recipe:<40} {config:<40}" print(row) @@ -44,12 +44,14 @@ def main(): examples: $ tune ls RECIPE CONFIG - full_finetune_distributed.py full_finetune_distributed.yaml - lora_finetune_distributed.py lora_finetune_distributed.yaml + full_finetune_distributed.py llama2/7B_full, + llama2/13B_full + lora_finetune_distributed.py llama2/7B_lora, + llama2/13B_lora alpaca_generate.py alpaca_generate.yaml To run one of these recipes: - $ tune full_finetune_single_device --config full_finetune_single_device + $ tune full_finetune_single_device --config llama2/7B_full_single_device """ ), formatter_class=argparse.RawTextHelpFormatter, diff --git a/torchtune/_cli/validate.py b/torchtune/_cli/validate.py index 8ac864c7d..1c054d866 100644 --- a/torchtune/_cli/validate.py +++ b/torchtune/_cli/validate.py @@ -25,7 +25,7 @@ def main(cfg: DictConfig): epilog=textwrap.dedent( """\ examples: - $ tune validate --config recipes/configs/full_finetune_distributed.yaml + $ tune validate --config recipes/configs/llama2/7B_lora.yaml Config is well-formed! """ ), diff --git a/torchtune/models/__init__.py b/torchtune/models/__init__.py index 6e6d39e27..8b19aba83 100644 --- a/torchtune/models/__init__.py +++ b/torchtune/models/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.models import llama2 # noqa +from torchtune.models import convert_weights, llama2, mistral # noqa diff --git a/torchtune/models/llama2/_convert_weights.py b/torchtune/models/convert_weights.py similarity index 54% rename from torchtune/models/llama2/_convert_weights.py rename to torchtune/models/convert_weights.py index a41b7fecd..9ac65b0b2 100644 --- a/torchtune/models/llama2/_convert_weights.py +++ b/torchtune/models/convert_weights.py @@ -64,14 +64,23 @@ def _get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str: return new_key -# =========== Convertors for Llama2 7B =========== +def meta_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from Meta's format to TorchTune's format. State dicts + from multiple checkpoint files should be consolidated into a single state dict + before calling this function. + Eg of Meta-format state dict can be found in the ``meta-llama/Llama-2-7b`` + repo in HF (https://huggingface.co/meta-llama/Llama-2-7b). -def meta_to_tune_llama2_7b( - original_state_dict: Dict[str, torch.Tensor] -) -> Dict[str, torch.Tensor]: + Args: + state_dict (Dict[str, torch.Tensor]): State dict in Meta's format. + + Returns: + Dict[str, torch.Tensor]: State dict in TorchTune's format. 
+ """ converted_state_dict = {} - for key, value in original_state_dict.items(): + for key, value in state_dict.items(): if key not in ["rope.freqs"]: # Skip loading the position embeddings new_key = _get_mapped_key(key, _FROM_META) converted_state_dict[new_key] = value @@ -79,71 +88,109 @@ def meta_to_tune_llama2_7b( return converted_state_dict -def tune_to_meta_llama2_7b( - original_state_dict: Dict[str, torch.Tensor] -) -> Dict[str, torch.Tensor]: +def tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from TorchTune's format to Meta's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in TorchTune's format. + + Returns: + Dict[str, torch.Tensor]: State dict in Meta's format. + """ converted_state_dict = {} inverted_mapping_dict = {v: k for k, v in _FROM_META.items()} - for key, value in original_state_dict.items(): + for key, value in state_dict.items(): new_key = _get_mapped_key(key, inverted_mapping_dict) converted_state_dict[new_key] = value return converted_state_dict -def hf_to_tune_llama2_7b( - original_state_dict, - num_heads=32, - num_kv_heads=32, - dim=4096, -): +def hf_to_tune( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, +) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from HF's format to TorchTune's format. State dicts + from multiple checkpoint files should be consolidated into a single state dict + before calling this function. + + Eg of HF-format state dict can be found in the ``meta-llama/Llama-2-7b-hf`` + repo in HF (https://huggingface.co/meta-llama/Llama-2-7b-hf). + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in Meta's format. + num_heads (int): Number of heads in the model. + num_kv_heads (int): Number of heads in the key/value projection layers. + dim (int): Dimension of the model. + + Returns: + Dict[str, torch.Tensor]: State dict in TorchTune's format. + """ converted_state_dict = {} head_dim = dim // num_heads - for key, value in original_state_dict.items(): + def _permute(t, n_heads): + return ( + t.view(n_heads, 2, head_dim // 2, dim) + .transpose(1, 2) + .reshape((head_dim * n_heads), dim) + ) + + for key, value in state_dict.items(): if "rotary_emb.inv_freq" not in key: # Skip loading the position embeddings new_key = _get_mapped_key(key, _FROM_HF) if "q_proj" in key: - value = ( - value.view(num_heads, 2, head_dim // 2, dim) - .transpose(1, 2) - .reshape((head_dim * num_heads), dim) - ) + value = _permute(value, num_heads) elif "k_proj" in key: - value = ( - value.view(num_kv_heads, 2, head_dim // 2, dim) - .transpose(1, 2) - .reshape((head_dim * num_kv_heads), dim) - ) + value = _permute(value, num_kv_heads) converted_state_dict[new_key] = value return converted_state_dict -def tune_to_hf_llama2_7b( - original_state_dict, - num_heads=32, - num_kv_heads=32, - dim=4096, +def tune_to_hf( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, ): + """ + Convert a state dict from TorchTune's format to HF's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in TorchTune's format. + num_heads (int): Number of heads in the model. 
+ num_kv_heads (int): Number of heads in the key/value projection layers. + dim (int): Dimension of the model. + + Returns: + Dict[str, torch.Tensor]: State dict in Meta's format. + """ converted_state_dict = {} inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()} head_dim = dim // num_heads - for key, value in original_state_dict.items(): + def _permute(t, n_heads): + return ( + t.view(n_heads, head_dim // 2, 2, dim) + .transpose(1, 2) + .reshape((head_dim * n_heads), dim) + ) + + for key, value in state_dict.items(): new_key = _get_mapped_key(key, inverted_mapping_dict) if "q_proj" in key: - value = ( - value.view(num_heads, head_dim // 2, 2, dim) - .transpose(1, 2) - .reshape((head_dim * num_heads), dim) - ) + value = _permute(value, num_heads) elif "k_proj" in key: - value = ( - value.view(num_kv_heads, head_dim // 2, 2, dim) - .transpose(1, 2) - .reshape((head_dim * num_kv_heads), dim) - ) + value = _permute(value, num_kv_heads) converted_state_dict[new_key] = value return converted_state_dict diff --git a/torchtune/models/llama2/__init__.py b/torchtune/models/llama2/__init__.py index 92053793c..b66f6d9fc 100644 --- a/torchtune/models/llama2/__init__.py +++ b/torchtune/models/llama2/__init__.py @@ -6,15 +6,12 @@ from ._checkpoint_utils import convert_llama2_fair_format from ._component_builders import llama2, lora_llama2 -from ._convert_weights import ( # noqa - hf_to_tune_llama2_7b, - meta_to_tune_llama2_7b, - tune_to_hf_llama2_7b, - tune_to_meta_llama2_7b, -) -from ._model_builders import ( + +from ._model_builders import ( # noqa + llama2_13b, llama2_7b, llama2_tokenizer, + lora_llama2_13b, lora_llama2_7b, qlora_llama2_7b, ) diff --git a/torchtune/models/llama2/_component_builders.py b/torchtune/models/llama2/_component_builders.py index a7a4cc286..22d5a9602 100644 --- a/torchtune/models/llama2/_component_builders.py +++ b/torchtune/models/llama2/_component_builders.py @@ -151,6 +151,7 @@ def lora_llama2( num_kv_heads: int, embed_dim: int, max_seq_len: int, + intermediate_dim: Optional[int] = None, attn_dropout: float = 0.0, max_batch_size: Optional[int] = None, norm_eps: float = 1e-5, @@ -185,6 +186,8 @@ def lora_llama2( by :func:`~torchtune.modules.KVCache` attn_dropout (float): dropout value passed onto scaled_dot_product_attention. Default: 0.0 + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` max_batch_size (Optional[int]): maximum batch size to be passed to :func:`~torchtune.modules.KVCache` norm_eps (float): epsilon in RMS norms. lora_rank (int): rank of each low-rank approximation @@ -214,7 +217,7 @@ def lora_llama2( quantize_base=quantize_base, ) - hidden_dim = scale_hidden_dim_for_mlp(embed_dim) + hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim) if apply_lora_to_mlp: mlp = lora_llama2_mlp( dim=embed_dim, diff --git a/torchtune/models/llama2/_model_builders.py b/torchtune/models/llama2/_model_builders.py index 6de2b048a..87f58db76 100644 --- a/torchtune/models/llama2/_model_builders.py +++ b/torchtune/models/llama2/_model_builders.py @@ -24,8 +24,8 @@ def llama2_7b(max_batch_size: Optional[int] = None) -> TransformerDecoder: """ - Builder for creating a Llama2 model initialized w/ the default 7b parameter values. 
- From https://arxiv.org/abs/2307.09288, these default values are: + Builder for creating a Llama2 model initialized w/ the default 7b parameter values + from https://arxiv.org/abs/2307.09288 Args: max_batch_size (Optional[int]): Maximum batch size to be passed to KVCache. @@ -63,7 +63,7 @@ def lora_llama2_7b( quantize_base: bool = False, ) -> TransformerDecoder: """ - Builder for creating a Llama2 model with LoRA enabled. + Builder for creating a Llama2 7B model with LoRA enabled. The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.llama2_7b`, while LoRA default params are based on @@ -111,3 +111,82 @@ def lora_llama2_7b( that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. Please see `lora_llama2_7b` for full API arguments. """ + + +def llama2_13b(max_batch_size: Optional[int] = None) -> TransformerDecoder: + """ + Builder for creating a Llama2 model initialized w/ the default 13b parameter values + from https://arxiv.org/abs/2307.09288 + + Args: + max_batch_size (Optional[int]): Maximum batch size to be passed to KVCache. + + Returns: + TransformerDecoder: Instantiation of Llama2 13B model + """ + return llama2( + vocab_size=32_000, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + intermediate_dim=13824, + max_seq_len=4096, + max_batch_size=max_batch_size, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_llama2_13b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + max_batch_size: Optional[int] = None, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Llama2 13B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.llama2_13b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + max_batch_size (Optional[int]): Maximum batch size to be passed to KVCache. + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Llama2 13B model with LoRA applied + """ + + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_000, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + max_seq_len=4096, + intermediate_dim=13824, + max_batch_size=max_batch_size, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=0.05, + quantize_base=False, + ) diff --git a/torchtune/models/mistral/__init__.py b/torchtune/models/mistral/__init__.py new file mode 100644 index 000000000..b362e792c --- /dev/null +++ b/torchtune/models/mistral/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._component_builders import mistral # noqa +from ._model_builders import mistral_7b, mistral_tokenizer # noqa diff --git a/torchtune/models/mistral/_component_builders.py b/torchtune/models/mistral/_component_builders.py new file mode 100644 index 000000000..eaaf8cf71 --- /dev/null +++ b/torchtune/models/mistral/_component_builders.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import List, Literal, Optional + +from torch import nn + +from torchtune.modules import ( + CausalSelfAttention, + FeedForward, + KVCache, + RMSNorm, + RotaryPositionalEmbeddings, + TransformerDecoder, + TransformerDecoderLayer, +) + +""" +Component builders for the Mistral 7B models and popular variants such as LoRA. + +TorchTune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. This design has +two benefits: +- The building blocks themselves are very flexible. For example, ``CausalSelfAttention`` +can take either nn.Linear or nn.LoRALinear for ``q_proj``. +- Builder functions expose a set of configurable params which keep the constructors of +the building blocks simple. +""" + +def mistral( + vocab_size: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-5, + rope_base: int = 10_000, +) -> TransformerDecoder: + """ + Build the decoder assoicated with the mistral model. This includes: + - Token embeddings + - num_layers number of TransformerDecoderLayer blocks + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + This does NOT currently include inference-time optimizations such as + sliding-window attention + + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. If specified, + user should ensure `num_heads` % `num_kv_heads` == 0. Default value is + `None`, in which case this is the same as MHA + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp` + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + max_batch_size (Optional[int]): maximum batch size to be passed to :func:`~torchtune.modules.KVCache` + norm_eps (float): epsilon in RMS norms. + + Returns: + TransformerDecoder: Instantiation of mistral model. 
+ """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + self_attn = CausalSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + ) + mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + layer = TransformerDecoderLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + output_proj = nn.Linear(embed_dim, vocab_size, bias=False) + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layer=layer, + num_layers=num_layers, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + +def mistral_mlp(dim: int, hidden_dim: int) -> FeedForward: + """ + Build the MLP layer associated with the Mistral model. + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) diff --git a/torchtune/models/mistral/_model_builders.py b/torchtune/models/mistral/_model_builders.py new file mode 100644 index 000000000..9f2ef3780 --- /dev/null +++ b/torchtune/models/mistral/_model_builders.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List, Optional +from functools import partial + +from torch import nn + +from torchtune.models.mistral._component_builders import mistral + +from torchtune.modules import Tokenizer, TransformerDecoder + + +""" +Model builders build specific instantiations using component builders. For example +the ``mistral_7b`` model builder uses the ``mistral`` component builder. 
+""" + + +def mistral_7b() -> TransformerDecoder: + """ + Builder for creating a Mistral 7B model initialized w/ the default 7b parameter values + from https://mistral.ai/news/announcing-mistral-7b/ + + + Returns: + TransformerDecoder: Instantiation of Mistral 7B model + """ + return mistral( + vocab_size=32_000, + num_layers=32, + num_heads=32, + num_kv_heads=8, + embed_dim=4096, + intermediate_dim=14336, + max_seq_len=32768, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def mistral_tokenizer(path: str) -> Tokenizer: + tokenizer = Tokenizer.from_file(path) + # Original tokenizer has no pad_id, which causes indexing errors when batch training + tokenizer.pad_id = 0 + return tokenizer diff --git a/torchtune/utils/_checkpointing/_checkpointer.py b/torchtune/utils/_checkpointing/_checkpointer.py index e11be4be3..6c273a11b 100644 --- a/torchtune/utils/_checkpointing/_checkpointer.py +++ b/torchtune/utils/_checkpointing/_checkpointer.py @@ -14,7 +14,7 @@ import torch from torchtune import utils -from torchtune.models import llama2 +from torchtune.models import convert_weights from torchtune.utils._checkpointing._checkpointer_utils import ( get_path, ModelType, @@ -376,7 +376,7 @@ def load_checkpoint(self) -> Dict[str, Any]: del state_dict gc.collect() - converted_state_dict[utils.MODEL_KEY] = llama2.hf_to_tune_llama2_7b( + converted_state_dict[utils.MODEL_KEY] = convert_weights.hf_to_tune( merged_state_dict, num_heads=self._config["num_attention_heads"], num_kv_heads=self._config["num_key_value_heads"], @@ -424,7 +424,7 @@ def save_checkpoint( self._output_dir.mkdir(exist_ok=True) # convert the state_dict back to hf format; do this inplace - state_dict[utils.MODEL_KEY] = llama2.tune_to_hf_llama2_7b( + state_dict[utils.MODEL_KEY] = convert_weights.tune_to_hf( state_dict[utils.MODEL_KEY], num_heads=self._config["num_attention_heads"], num_kv_heads=self._config["num_key_value_heads"], @@ -553,7 +553,7 @@ def load_checkpoint(self) -> Dict[str, Any]: """ state_dict: Dict[str:Any] = {} model_state_dict = safe_torch_load(self._checkpoint_path) - state_dict[utils.MODEL_KEY] = llama2.meta_to_tune_llama2_7b(model_state_dict) + state_dict[utils.MODEL_KEY] = convert_weights.meta_to_tune(model_state_dict) if self._adapter_checkpoint: adapter_state_dict = safe_torch_load(self._adapter_checkpoint) @@ -597,7 +597,7 @@ def save_checkpoint( """ self._output_dir.mkdir(exist_ok=True) model_state_dict = state_dict[utils.MODEL_KEY] - state_dict[utils.MODEL_KEY] = llama2.tune_to_meta_llama2_7b(model_state_dict) + state_dict[utils.MODEL_KEY] = convert_weights.tune_to_meta(model_state_dict) # Output file is always a .pt file with the epoch number in the name checkpoint_file = Path.joinpath(
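The checkpointer changes above swap the llama2-specific converters for the new model-agnostic `convert_weights` helpers. A minimal usage sketch of a round trip on an HF-format Llama2 7B checkpoint follows; the directory and shard names are placeholders, and the shards are assumed to be merged into a single state dict before converting:

```python
# Minimal sketch of the refactored converters used by the checkpointers above.
import torch
from torchtune.models import convert_weights

# Merge the sharded HF checkpoint files into one state dict (paths are placeholders).
hf_state_dict = {}
for shard_path in [
    "/tmp/llama2-7b-hf/pytorch_model-00001-of-00002.bin",
    "/tmp/llama2-7b-hf/pytorch_model-00002-of-00002.bin",
]:
    hf_state_dict.update(torch.load(shard_path, map_location="cpu"))

# HF -> TorchTune: q/k projections are permuted, rotary inv_freq buffers dropped.
tune_sd = convert_weights.hf_to_tune(hf_state_dict, num_heads=32, num_kv_heads=32, dim=4096)

# TorchTune -> HF, e.g. before saving a checkpoint that HF tooling can load.
hf_sd = convert_weights.tune_to_hf(tune_sd, num_heads=32, num_kv_heads=32, dim=4096)
```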