Mistral configs #591
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/591
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
✅ No Failures as of commit ef54a33 with merge base d2e36ed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
lora_attn_modules: List[LORA_ATTN_MODULES],
apply_lora_to_mlp: bool = False,
apply_lora_to_output: bool = False,
*,
wooooo kwargs
Can't take credit for this. This comes from @ebsmothers
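For anyone skimming the diff, here's a minimal sketch of what the bare `*` buys; this is not the actual torchtune signature, just an illustration that everything after the `*` must be passed by keyword.

```python
from typing import List

# Minimal sketch, not the real torchtune builder: parameters after the bare "*"
# are keyword-only, which keeps long builder call sites explicit and readable.
def lora_builder(
    lora_attn_modules: List[str],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    *,
    lora_rank: int = 8,
    lora_alpha: float = 16,
) -> dict:
    return {"rank": lora_rank, "alpha": lora_alpha}

lora_builder(["q_proj", "v_proj"], lora_rank=16)        # OK
# lora_builder(["q_proj", "v_proj"], False, False, 16)  # TypeError: lora_rank must be a keyword
```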
recipes/configs/mistral/7B_lora.yaml
@@ -0,0 +1,69 @@
# This config is currently a WIP. Use it with caution |
Are these configs based on the paper and/or other results showing good performance?
Mostly from digging around on various forums like Reddit.
Can we update the comment to say something like "this config is based on a small set of experiments and is not intended to reproduce results from the Mistral paper or elsewhere"
output_dir: /tmp/Mistral-7B-v0.1
model_type: LLAMA2
output_dir: /tmp/Mistral-7B-v0.1/
model_type: MISTRAL
This technically varies from the HF way of defining model type b/c mistral is llama-based. What does model_type=Mistral give us?
Less confusion? Defining the Mistral model as a llama type is a bit confusing, I think.
But it's technically true that it is a llama-type model and this is the accepted standard in HF.
I definitely don't want to carry over any confusion from there.
Personally I like the separate model type. Yes, it's an identical architecture, but it just helps to be explicit. Main question: is this used as a hard constraint anywhere? E.g. is it going to prevent me from loading Mistral weights into a LLAMA2 model type?
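To make the question concrete, here is a hypothetical sketch of how a model_type switch could gate checkpoint conversion; the enum and helper below are illustrative assumptions, not torchtune's actual checkpointer API.

```python
from enum import Enum

class ModelType(Enum):
    LLAMA2 = "llama2"
    MISTRAL = "mistral"

def convert_hf_weights(state_dict: dict, model_type: ModelType) -> dict:
    # Hypothetical: Mistral is architecturally llama-like, so both types could
    # share one key mapping. MISTRAL only becomes a hard constraint if the
    # checkpointer refuses to load weights across model types.
    if model_type in (ModelType.LLAMA2, ModelType.MISTRAL):
        return {key: value for key, value in state_dict.items()}
    raise ValueError(f"Unsupported model_type: {model_type}")
```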
a7a24bc to 6fd95cf
@@ -40,7 +40,7 @@ The library currently supports the following models and fine-tuning methods.
|-----------------------------------------------|-----------|-----------------------------------------------------------|
| [Llama2](torchtune/models/llama2/_model_builders.py) | 7B | Full Finetuning [[single device](recipes/configs/llama2/7B_full_single_device.yaml), [distributed](recipes/configs/llama2/7B_full.yaml)] LoRA [[single device](recipes/configs/llama2/7B_lora_single_device.yaml), [distributed](recipes/configs/llama2/7B_lora.yaml)] QLoRA [single device](recipes/configs/llama2/7B_qlora_single_device.yaml) |
| [Llama2](torchtune/models/llama2/_model_builders.py) | 13B | [Full Finetuning](recipes/configs/llama2/13B_full.yaml), [LoRA](recipes/configs/llama2/13B_lora.yaml) |
| [Mistral](torchtune/models/mistral//_model_builders.py) | 7B | Full Finetuning and LoRA are WIP and will be added soon |
| [Mistral](torchtune/models/mistral//_model_builders.py) | 7B | [Full Finetuning](recipes/configs/mistral/7B_full.yaml), [LoRA](recipes/configs/mistral/7B_lora.yaml) |
Not necessary for this PR but I think this table needs a cleanup. We should either split finetuning methods into separate columns or rows grouped under model family
Agreed, this table does need an update. Let me follow up separately on this.
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
# TODO: quantize_base is not applied to final output_proj currently. |
We don't have quantize_base at all in mistral right now, right?
Yeah, I need to run some experiments before adding this. Training on Mistral is quite different from Llama2, so I'll do QLoRA as a follow-up.
I'd at least remove the TODO for now, then.
lora_alpha: float = 16,
) -> TransformerDecoder:
"""
Builder for creating a Llama2 7B model with LoRA enabled.
:D
num_kv_heads: int,
max_seq_len: int,
attn_dropout: float = 0.0,
rope_base: int = 10_000,
Why do we parametrize this here but not in llama2 builders?
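For context on what the knob controls, a minimal sketch of the standard RoPE frequency computation (not torchtune's exact module); rope_base is the `base` below, and the 10_000 default in the diff is the usual value.

```python
import torch

def rope_frequencies(head_dim: int, max_seq_len: int, base: int = 10_000) -> torch.Tensor:
    # Standard RoPE: theta_i = base ** (-2i / head_dim). The base only controls
    # how quickly rotation frequencies decay across dimensions, so exposing it
    # in the builder is a cheap, low-risk knob.
    theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len).float()
    return torch.outer(positions, theta)  # shape: [max_seq_len, head_dim // 2]
```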
Returns:
TransformerDecoder: Instantiation of Llama2 7B model with LoRA applied
"""
return lora_mistral(
Would be nice to document somewhere exactly what the differences are between this builder and the equivalent llama2 one. Doesn't have to be for this PR though; as an added bonus it nicely shows off how easily we can switch between the two.
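As a rough usage sketch (assuming the two LoRA builders expose the same keyword arguments), switching families is close to a one-line change:

```python
# Rough sketch; assumes both builders accept the same LoRA keyword arguments.
from torchtune.models.llama2 import lora_llama2_7b
from torchtune.models.mistral import lora_mistral_7b

common_kwargs = dict(
    lora_attn_modules=["q_proj", "k_proj", "v_proj"],
    lora_rank=8,
    lora_alpha=16,
)

model = lora_mistral_7b(**common_kwargs)  # or: lora_llama2_7b(**common_kwargs)
```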
recipes/configs/mistral/7B_lora.yaml
# Model Arguments
model:
_component_: torchtune.models.mistral.lora_mistral_7b
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj']
nit: order q, k, v for clarity
recipes/configs/mistral/7B_lora.yaml
# Distributed
cpu_offload: False
I don't think this is even used? (I see it's still lurking around in other configs too..)
Do we want to add a unit test for the new model? It doesn't have to cover all levels of components, but at least e.g. at the level of mistral or lora_mistral.
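Something along these lines could work as a smoke test; the tiny hyperparameters and exact builder kwargs below are assumptions and would need to match the mistral() signature:

```python
import torch
from torchtune.models.mistral import mistral  # path assumed from this PR's module layout

def test_mistral_forward_shape():
    # Tiny sizes keep the test fast; kwargs are assumed to mirror the builder signature.
    vocab_size, seq_len = 100, 16
    model = mistral(
        vocab_size=vocab_size,
        num_layers=2,
        num_heads=4,
        num_kv_heads=2,
        embed_dim=64,
        intermediate_dim=128,
        max_seq_len=seq_len,
    )
    tokens = torch.randint(0, vocab_size, (2, seq_len))
    logits = model(tokens)
    assert logits.shape == (2, seq_len, vocab_size)
```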
Context
#571 added support for Mistral 7B. In this PR, I add configs for Mistral 7B full finetuning and LoRA. These can be further improved, but they have decent results OOTB.
Full finetune: loss curve and eval on truthfulqa_mc2 attached.
LoRA: loss curve and eval on truthfulqa_mc2 attached.
FAQ
How did you come up with these configs?
Getting training to stabilize took some work. This model seems to need a smaller LR than the Llama2 7B/13B models, and for LoRA I needed to ramp up rank and alpha a bit. I don't claim this to be novel work, though; I just did a bunch of snooping around on localllama [example] and some other forums and blogs to come up with a config with reasonable results.
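Purely to show where those knobs live (the values below are placeholders, not the numbers shipped in the configs), a sketch of the relevant call sites:

```python
import torch
from torchtune.models.mistral import lora_mistral_7b

# Placeholder values for illustration only; the YAML configs in this PR are the
# source of truth for the actual LR, rank, and alpha.
model = lora_mistral_7b(
    lora_attn_modules=["q_proj", "k_proj", "v_proj"],
    lora_rank=64,    # placeholder; rank/alpha were ramped up a bit vs the Llama2 LoRA configs
    lora_alpha=16,   # placeholder
)
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=2e-5,         # placeholder; the point is just that Mistral wanted a smaller LR
)
```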
Changelog
Test plan