Generalize configs and add Llama2 13B + Mistral 7B #571
@@ -0,0 +1,85 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama2 13B model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download --repo-id meta-llama/Llama-2-13b-hf \
#     --hf-token <HF_TOKEN> \
#     --output-dir /tmp/llama2-13b-hf
#
# To launch on 4 devices, run the following command from root:
#   tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
#     --config llama2/13B_full
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
#   tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
#     --config llama2/13B_full \
#     checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config should be used with 2+ GPUs. Single device full fine-tuning
# requires several memory optimizations which are exposed through
# 7B_full_single_device.yaml. Please update the model and checkpoints to 13B
# in that config.

Comment on lines +21 to +24:

I guess implicit in our choice of naming here is that >1 device is kind of now the "default", right? While I understand that we are doing more memory optimizations in the single device recipes now, we've obviously seen that FSDP comes with its own nuances too. So I do wonder if it's now hard for someone to just come in and say "give me a simple single-device recipe to get started on". This is also a bit weird for QLoRA imo where we currently only support single device.

Reply: yeh this is a good question. I did this primarily for two reasons:

Does this make sense?

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/llama2/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama2.llama2_13b

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/llama2-13b-hf/
  checkpoint_files: [
    pytorch_model-00001-of-00003.bin,
    pytorch_model-00002-of-00003.bin,
    pytorch_model-00003-of-00003.bin
  ]
  recipe_checkpoint: null
  output_dir: /tmp/llama2-13b-hf/
  model_type: LLAMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training env
device: cuda

# Distributed
cpu_offload: False

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: null
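
A note on how the batch-related knobs in the config above compose. This is hedged arithmetic under the usual data-parallel convention (batch_size is per device and gradients are averaged across ranks); the PR itself doesn't state the convention:

    # Global batch size per optimizer step when this config is launched
    # with --nproc_per_node 4, assuming a per-device batch_size.
    batch_size = 2                    # from the config
    gradient_accumulation_steps = 1   # from the config
    num_devices = 4                   # from the launch command
    global_batch_size = batch_size * gradient_accumulation_steps * num_devices
    print(global_batch_size)          # -> 8 samples per optimizer step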
@@ -0,0 +1,90 @@
# Config for multi-device LoRA in lora_finetune_distributed.py
# using a Llama2 13B model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download --repo-id meta-llama/Llama-2-13b-hf \
#     --hf-token <HF_TOKEN> \
#     --output-dir /tmp/llama2-13b-hf
#
# To launch on 4 devices, run the following command from root:
#   tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \
#     --config llama2/13B_lora
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
#   tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \
#     --config llama2/13B_lora \
#     checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device LoRA finetuning please use 7B_lora_single_device.yaml
# or 7B_qlora_single_device.yaml and update the model and checkpoints to
# the 13B model.

# Model Arguments
model:
  _component_: torchtune.models.llama2.lora_llama2_13b
  lora_attn_modules: ['q_proj', 'v_proj', 'k_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: True
  lora_rank: 8
  lora_alpha: 16

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/llama2-13b-hf/
  checkpoint_files: [
    pytorch_model-00001-of-00003.bin,
    pytorch_model-00002-of-00003.bin,
    pytorch_model-00003-of-00003.bin
  ]
  adapter_checkpoint: null
  recipe_checkpoint: null
  output_dir: /tmp/llama2-13b-hf/
  model_type: LLAMA2
resume_from_checkpoint: False

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/llama2/tokenizer.model

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
  use_clean: True
seed: null
shuffle: True
batch_size: 32

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 2e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null

# Logging
output_dir: /tmp/lora_finetune_output
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: null

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False
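
For readers unfamiliar with the lora_rank and lora_alpha knobs in the model block above, here is a minimal sketch of the mechanism they control. It is illustrative only: the class name, dimensions, and structure are mine, not torchtune's actual module, and it assumes the standard LoRA formulation of a frozen weight plus a scaled low-rank update:

    import torch
    import torch.nn as nn

    class LoRALinearSketch(nn.Module):
        """Illustrative stand-in for a LoRA-adapted projection (e.g. q_proj)."""

        def __init__(self, in_dim: int, out_dim: int, rank: int = 8, alpha: float = 16.0):
            super().__init__()
            # Frozen pretrained weight: untouched during fine-tuning.
            self.base = nn.Linear(in_dim, out_dim, bias=False)
            self.base.weight.requires_grad = False
            # Trainable low-rank factors: only rank * (in_dim + out_dim)
            # parameters, tiny compared to in_dim * out_dim.
            self.lora_a = nn.Linear(in_dim, rank, bias=False)
            self.lora_b = nn.Linear(rank, out_dim, bias=False)
            nn.init.zeros_(self.lora_b.weight)  # update starts as a no-op
            self.scaling = alpha / rank          # lora_alpha / lora_rank = 16 / 8 = 2

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))

    # Dimensions are illustrative, not the 13B model's actual hidden size.
    proj = LoRALinearSketch(4096, 4096, rank=8, alpha=16.0)
    out = proj(torch.randn(1, 4096))  # identical to the frozen base at init

Raising lora_rank increases adapter capacity (and parameter count), while lora_alpha rescales the update; the alpha/rank ratio is what effectively scales the learned correction.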
@@ -0,0 +1,80 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama2 7B model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download --repo-id meta-llama/Llama-2-7b \
#     --hf-token <HF_TOKEN> \
#     --output-dir /tmp/llama2
#
# To launch on 4 devices, run the following command from root:
#   tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
#     --config llama2/7B_full
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
#   tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
#     --config llama2/7B_full \
#     checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# Single device full finetuning requires more memory optimizations. It's
# best to use 7B_full_single_device.yaml for those cases.

Comment on this header block:

Why add this huge block in the config?

Reply: Currently there's no documentation on the configs at all. Once we have live docs available, we can add these to the docs. But for now, I'd like to give some understanding to users about when and how to use each config.

Reply: I think there's still some information in the README on configs and we can make that more clear. I think cluttering up the configs can be overwhelming.

Reply: Why would it be overwhelming? Isn't it just documentation? I don't think we would be able to add config-level info to the README?

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/llama2/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/llama2
  checkpoint_files: [consolidated.00.pth]
  recipe_checkpoint: null
  output_dir: /tmp/llama2
  model_type: LLAMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training env
device: cuda

# Distributed
cpu_offload: False

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: null

One thing to watch out for: there can be lots of issues when filenames start with a number. I've definitely seen it as a problem with Python imports, maybe it will be ok with YAML files? But something to keep in mind.

Reply: This is interesting - what sort of issues? But yeh I don't expect us to be importing the configs anymore?
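
For concreteness on the import concern, a hedged sketch (the file paths below are hypothetical): Python module names must be valid identifiers, so a file whose name starts with a digit cannot be imported with an import statement, though importlib can still load it by path. YAML files are opened by path, so a leading digit is harmless there, which supports the reply above:

    # `import 7B_full` is a SyntaxError because module names must be
    # valid identifiers. importlib sidesteps this by loading via path:
    import importlib.util

    spec = importlib.util.spec_from_file_location("config_7b_full", "/tmp/7B_full.py")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # YAML, by contrast, is always loaded by path, so the filename's
    # leading digit never matters (assumes PyYAML is installed):
    import yaml
    with open("/tmp/13B_full.yaml") as f:
        cfg = yaml.safe_load(f)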