
Mistral configs #591

Merged
merged 6 commits into main from mistral_configs on Mar 27, 2024
Conversation

@kartikayk (Contributor) commented Mar 26, 2024

Context

#571 added support for Mistral 7B. In this PR, I add configs for Mistral 7B full finetuning and LoRA. These can be further improved, but have decent results OOTB.

Full finetune:

 tune --nnodes 1 --nproc_per_node 4 full_finetune_distributed \
--config mistral/7B_full \
metric_logger=torchtune.utils.metric_logging.WandBLogger \
metric_logger.project=test

Loss Curve:

[screenshot: loss curve for full finetune]

Eval on truthfulqa_mc2

[screenshot: truthfulqa_mc2 eval results for full finetune]

LoRA:

tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed \
--config mistral/7B_lora  \
metric_logger=torchtune.utils.metric_logging.WandBLogger \
metric_logger.project=test

Loss Curve:

[screenshot: loss curve for LoRA finetune]

Eval on truthfulqa_mc2

[screenshot: truthfulqa_mc2 eval results for LoRA finetune]

FAQ

How did you come up with these configs?

Getting training to stabilize took some work. This model seems to need a smaller LR than the Llama2 7B/13B models, and for LoRA I needed to ramp up the rank and alpha a bit. I don't claim this is novel work; I just did a bunch of snooping around on localllama [example] and some other forums and blogs to come up with configs with reasonable results. A rough sketch of the resulting shape is below.
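
For concreteness, the adjustments described above look roughly like the following sketch. The numbers are illustrative rather than the exact values shipped in the configs, and the lora_rank keyword is assumed to be exposed the same way as in the llama2 LoRA builders.

import torch
from torchtune.models.mistral import lora_mistral_7b  # builder added in this PR

# Rank and alpha ramped up a bit relative to the llama2 7B recipes (values illustrative).
model = lora_mistral_7b(
    lora_attn_modules=["q_proj", "k_proj", "v_proj"],
    lora_rank=64,
    lora_alpha=16,
)

# Mistral 7B seemed to want a smaller LR than the llama2 7B/13B recipes (value illustrative).
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)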

Changelog

  • Add component and model builders for LoRA with Mistral 7B
  • Add configs for LoRA and full finetuning for Mistral 7B

Test plan

  • Tests: pytest tests
  • E2E runs: see above

pytorch-bot bot commented Mar 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/591

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit ef54a33 with merge base d2e36ed:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Mar 26, 2024
lora_attn_modules: List[LORA_ATTN_MODULES],
apply_lora_to_mlp: bool = False,
apply_lora_to_output: bool = False,
*,
Contributor:

wooooo kwargs

Contributor (Author):

Can't take credit for this. This comes from @ebsmothers
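
(For readers unfamiliar with the pattern: the bare * in the signature above makes every following argument keyword-only. A tiny self-contained illustration with a made-up builder name:)

def toy_builder(lora_attn_modules, apply_lora_to_mlp=False, *, lora_rank=8, lora_alpha=16.0):
    return lora_attn_modules, apply_lora_to_mlp, lora_rank, lora_alpha

toy_builder(["q_proj"], lora_rank=32)       # OK: keyword-only args passed by name
# toy_builder(["q_proj"], False, 32, 16.0)  # TypeError: lora_rank cannot be passed positionally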

@@ -0,0 +1,69 @@
# This config is currently a WIP. Use it with caution
Contributor:

Are these configs based on the paper and/or other results showing good performance?

Contributor (Author):

Mostly digging around on various forums like Reddit.

Contributor:

Can we update the comment to say something like "this config is based on a small set of experiments and is not intended to reproduce results from the Mistral paper or elsewhere"

- output_dir: /tmp/Mistral-7B-v0.1
- model_type: LLAMA2
+ output_dir: /tmp/Mistral-7B-v0.1/
+ model_type: MISTRAL
Contributor:

This technically varies from the HF way of defining model type b/c mistral is llama-based. What does model_type=Mistral give us?

Contributor (Author):

Less confusion? Defining the mistral model to be llama type is a bit confusing, I think.

Contributor:

But it's technically true that it is a llama-type model and this is the accepted standard in HF.

Contributor (Author):

I definitely don't want to carry over any confusion from there.

Contributor:

Personally I like the separate model type. Yes, it's an identical architecture, but it just helps to be explicit. Main question: is this used as a hard constraint anywhere? E.g. is it going to prevent me from loading Mistral weights into a LLAMA2 model type?
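
(One way to picture what a separate model_type buys, sketched here with made-up names rather than torchtune's actual checkpointer code:)

from enum import Enum

class ModelType(Enum):  # illustrative enum, not torchtune's definition
    LLAMA2 = "llama2"
    MISTRAL = "mistral"

def convert_checkpoint_keys(state_dict: dict, model_type: ModelType) -> dict:
    # The two architectures are identical today, so the key mapping is shared;
    # a distinct MISTRAL value leaves room to diverge later without touching
    # LLAMA2 checkpoints, though as written it would not block cross-loading.
    if model_type in (ModelType.LLAMA2, ModelType.MISTRAL):
        return state_dict
    raise ValueError(f"Unsupported model_type: {model_type}")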

@@ -40,7 +40,7 @@ The library currently supports the following models and fine-tuning methods.
|-----------------------------------------------|-----------|-----------------------------------------------------------|
| [Llama2](torchtune/models/llama2/_model_builders.py) | 7B | Full Finetuning [[single device](recipes/configs/llama2/7B_full_single_device.yaml), [distributed](recipes/configs/llama2/7B_full.yaml)] LoRA [[single device](recipes/configs/llama2/7B_lora_single_device.yaml), [distributed](recipes/configs/llama2/7B_lora.yaml)] QLoRA [single device](recipes/configs/llama2/7B_qlora_single_device.yaml) |
| [Llama2](torchtune/models/llama2/_model_builders.py) | 13B | [Full Finetuning](recipes/configs/llama2/13B_full.yaml), [LoRA](recipes/configs/llama2/13B_lora.yaml)
- | [Mistral](torchtune/models/mistral//_model_builders.py) | 7B | Full Finetuning and LoRA are WIP and will be added soon
+ | [Mistral](torchtune/models/mistral//_model_builders.py) | 7B | [Full Finetuning](recipes/configs/mistral/7B_full.yaml), [LoRA](recipes/configs/mistral/7B_lora.yaml)
Contributor:

Not necessary for this PR but I think this table needs a cleanup. We should either split finetuning methods into separate columns or rows grouped under model family

Contributor (Author):

Agreed, this table does need an update. Let me follow up separately on this.


tok_embeddings = nn.Embedding(vocab_size, embed_dim)

# TODO: quantize_base is not applied to final output_proj currently.
Member:

We don't have quantize_base at all in mistral right now, right?

Contributor (Author):

Yeah, I need to run some experiments before I add this. Training on mistral is quite different from llama2, so I'll do QLoRA as a follow-up.

Contributor:

Would at least remove the TODO for now then.

lora_alpha: float = 16,
) -> TransformerDecoder:
"""
Builder for creating a Llama2 7B model with LoRA enabled.
Member:

:D

num_kv_heads: int,
max_seq_len: int,
attn_dropout: float = 0.0,
rope_base: int = 10_000,
Contributor:

Why do we parametrize this here but not in llama2 builders?
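
(For context, parametrizing it buys roughly the following; this sketch assumes the rotary embedding module accepts a base argument:)

from torchtune.modules import RotaryPositionalEmbeddings

def build_rope(embed_dim: int, num_heads: int, max_seq_len: int, rope_base: int = 10_000):
    # Thread the configurable base through instead of hard-coding 10_000.
    head_dim = embed_dim // num_heads
    return RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
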

Returns:
TransformerDecoder: Instantiation of Llama2 7B model with LoRA applied
"""
return lora_mistral(
Contributor:

Would be nice to document somewhere exactly what the differences are between this builder and the equivalent llama2 one. Doesn't have to be for this PR though; as an added bonus it nicely shows off how easily we can switch between the two.
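
(A small illustration of that last point, assuming the two LoRA builders share the same keyword interface, which the diffs above suggest:)

from torchtune.models.llama2 import lora_llama2_7b
from torchtune.models.mistral import lora_mistral_7b

common = dict(lora_attn_modules=["q_proj", "k_proj", "v_proj"], apply_lora_to_mlp=False)

llama2_model = lora_llama2_7b(**common)    # same call site...
mistral_model = lora_mistral_7b(**common)  # ...different model family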

# Model Arguments
model:
_component_: torchtune.models.mistral.lora_mistral_7b
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj']
Contributor:

nit: order q, k, v for clarity

Comment on lines 55 to 56
# Distributed
cpu_offload: False
Contributor:

I don't think this is even used? (I see it's still lurking around in other configs too..)

@ebsmothers (Contributor) left a comment:

Do we wanna add a unit test for the new model? Doesn't have to be for all levels of components but at least e.g. at the level of mistral or lora_mistral
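
(Something along these lines is presumably what's being suggested; the keyword names below are assumptions about the mistral builder's signature rather than the test that actually landed:)

import torch
from torchtune.models.mistral import mistral

def test_mistral_forward_shape():
    model = mistral(
        vocab_size=100,
        num_layers=2,
        num_heads=4,
        num_kv_heads=2,
        embed_dim=64,
        intermediate_dim=128,
        max_seq_len=128,
    )
    tokens = torch.randint(0, 100, (2, 16))
    out = model(tokens)
    assert out.shape == (2, 16, 100)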

@kartikayk merged commit bbac98c into main on Mar 27, 2024
20 checks passed
@kartikayk deleted the mistral_configs branch on March 27, 2024 at 02:06