Could it support Gemma? #616

Closed
solitude-alive opened this issue Mar 29, 2024 · 14 comments · Fixed by #630

Comments

@solitude-alive
Contributor

Google's Gemma family includes a 2B model, so it seems we could fine-tune it with full parameters on fewer than 4x24GB GPUs. Do you plan to support it?

@joecummings
Contributor

We are considering many new model additions and will keep you posted!

@kartikayk
Contributor

@solitude-alive would you be open to adding this model? I'm happy to help share specific pointers and review code if you're interested. We'd love the contribution.

@solitude-alive
Contributor Author

@kartikayk Yeah, I'm happy to do that. I'll give it a try.

@kartikayk
Contributor

@solitude-alive awesome!

For a starting point, take a look at the Mistral 7B model builder in:
https://github.com/pytorch/torchtune/blob/main/torchtune/models/mistral/_model_builders.py

We expose specific models through model builders, which basically stitch together components (e.g. Attention, RoPE, RMSNorm, etc.). You can find some examples here:
https://github.com/pytorch/torchtune/blob/main/torchtune/models/mistral/_component_builders.py#L36

I think adding support for gemma_2b would be similar. You just need to make sure the components line up with what Gemma is doing.
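For a concrete (hypothetical) starting point, a gemma_2b builder could look roughly like the sketch below, mirroring mistral_7b. The gemma component builder doesn't exist yet, and the hyperparameters are taken from the published Gemma 2B config, so treat all of this as an assumption rather than the final API.

# Hypothetical sketch of a gemma_2b model builder, mirroring mistral_7b.
# The gemma component builder and these hyperparameters (from the published
# Gemma 2B config) are assumptions, not the final TorchTune API.
from torchtune.models.gemma._component_builders import gemma  # hypothetical, to be added
from torchtune.modules import TransformerDecoder

def gemma_2b() -> TransformerDecoder:
    return gemma(
        vocab_size=256_000,
        num_layers=18,
        num_heads=8,
        head_dim=256,
        num_kv_heads=1,       # Gemma 2B uses multi-query attention
        embed_dim=2048,
        intermediate_dim=16_384,
        max_seq_len=8192,
        norm_eps=1e-6,
    )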

@solitude-alive
Contributor Author

@kartikayk Hi,

_model_builders.py and _component_builders.py have been mostly completed, except for some components that need to be confirmed.

Is there documentation on how to load the weights file? It seems that Gemma only provides safetensors shards ([model-00001-of-00002.safetensors, model-00002-of-00002.safetensors]) rather than .bin or .pth files.

@joecummings
Contributor

@solitude-alive great catch! Right now, TorchTune supports only PyTorch-native .bin or .pt formats.

In order to add Gemma, we need to add functionality for loading safetensors. Hugging Face has a great library and resources for this here: https://huggingface.co/docs/safetensors/index#usage and it probably makes sense to take a look at how we handle loading in TorchTune here:

def safe_torch_load(checkpoint_path: Path) -> Dict[str, Any]:
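For illustration, loading a sharded safetensors checkpoint into a single state dict with that library could look roughly like the sketch below; the helper name is made up, while safetensors.torch.load_file is the actual HF API.

# Minimal sketch: merge sharded safetensors files into one state dict.
# load_safetensors_checkpoint is a hypothetical helper name.
from pathlib import Path
from typing import Any, Dict, List

from safetensors.torch import load_file

def load_safetensors_checkpoint(checkpoint_paths: List[Path]) -> Dict[str, Any]:
    state_dict: Dict[str, Any] = {}
    for path in checkpoint_paths:
        # Each shard holds a disjoint subset of the model's tensors
        state_dict.update(load_file(str(path), device="cpu"))
    return state_dict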

Is this something you feel comfortable adding? This would be an incredible feature because there are a lot of other models on the HF Hub that only ship safetensors, too.

@joecummings
Contributor

Also, @solitude-alive - would love for you to join the Discord channel (see our README for invite link) so we can quickly answer any questions you may have as you work on this!

@solitude-alive
Contributor Author

@joecummings Yeah, thanks.

@kartikayk
Contributor

@solitude-alive Awesome! As @joecummings said, it would be great to add safetensors support to TorchTune's FullModelHFCheckpointer.

I verified that safetensors.safe_open produces the same state_dict with safetensor files as the TorchTune HF Checkpointer does with bin files for the llama-13B model. Here's a minimal validation:

 

# Examine safetensors

import torch
from safetensors import safe_open
from torchtune.models import convert_weights
from torchtune.utils import FullModelHFCheckpointer, ModelType

checkpoint_dir = '/data/users/kartikayk/cpts/Llama-2-13b-hf/'
safetensor_files = ['model-00001-of-00003.safetensors', 'model-00002-of-00003.safetensors', 'model-00003-of-00003.safetensors']
pytorch_files = ['pytorch_model-00001-of-00003.bin', 'pytorch_model-00002-of-00003.bin', 'pytorch_model-00003-of-00003.bin']

safetensor_sd = {}

for file in safetensor_files:
    file_path = checkpoint_dir + file
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            safetensor_sd[key] = f.get_tensor(key)

# convert the state_dict from HF format to TorchTune format
# hf_to_tune needs to know some params for correct conversion
safetensor_sd_torchtune = convert_weights.hf_to_tune(safetensor_sd, num_heads=40, num_kv_heads=40, dim=5120)

# Use TorchTune's HF Checkpointer to get the state_dict
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=checkpoint_dir,
    checkpoint_files=pytorch_files,
    output_dir='/data/users/kartikayk/cpts/Llama-2-13b-hf/',
    model_type=ModelType.LLAMA2
)

torchtune_sd = checkpointer.load_checkpoint()

# assert that we get the same keys and values
# torchtune checkpointer adds an additional 'model' key
for key in torchtune_sd['model'].keys():
    assert torch.equal(torchtune_sd['model'][key], safetensor_sd_torchtune[key])

And here's the output:

[screenshot of the script output omitted]

Given that these are numerically equivalent, I think the best way forward would be for you to add a flag to FullModelHFCheckpointer - something like is_safetensor - and when this is True, just use safetensors.safe_open instead of safe_torch_load to get the state_dict. Everything else, including the conversion to TorchTune's format, should be the same. This is the relevant function: https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_checkpointing/_checkpointer.py#L323
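For illustration, the branch inside the checkpointer might look roughly like the sketch below; the is_safetensor flag is just the suggestion above, the function name is made up, and torch.load stands in for the existing safe_torch_load wrapper.

# Hypothetical sketch of the proposed is_safetensor branch.
# _load_checkpoint_file is a made-up name.
from pathlib import Path
from typing import Any, Dict

import torch
from safetensors import safe_open

def _load_checkpoint_file(checkpoint_path: Path, is_safetensor: bool = False) -> Dict[str, Any]:
    if is_safetensor:
        state_dict: Dict[str, Any] = {}
        with safe_open(str(checkpoint_path), framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
        return state_dict
    # Existing .bin / .pt path
    return torch.load(str(checkpoint_path), map_location="cpu", weights_only=True)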

Does this make sense to you?

@solitude-alive
Contributor Author

@kartikayk Thank you for your suggestion.

@solitude-alive
Contributor Author

Hi, I'm hitting errors on my device when I tie output_proj.weight = tok_embedding.weight for Gemma. Is there any way to fix it?

[rank1]:[E ProcessGroupNCCL.cpp:523] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=21, OpType=BROADCAST, NumelIn=524290048, NumelOut=524290048, Timeout(ms)=600000) ran for 600431 milliseconds before timing out.

@kartikayk
Contributor

Seems like this is actively being discussed on Discord. Once the discussion is over, we can come back and summarize it here.

cc: @ebsmothers

@solitude-alive
Contributor Author

solitude-alive commented Apr 1, 2024

Hi, I'm hitting errors on my device when I tie output_proj.weight = tok_embedding.weight for Gemma. Is there any way to fix it?

[rank1]:[E ProcessGroupNCCL.cpp:523] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=21, OpType=BROADCAST, NumelIn=524290048, NumelOut=524290048, Timeout(ms)=600000) ran for 600431 milliseconds before timing out.

Thanks for the discussion. There is a temporary solution: remove any weight tying that occurs before FSDP wrapping and apply the weight tying in the recipe after FSDP wrapping instead.

@ebsmothers
Contributor

Yeah, to summarize the discussion on Discord: when training with FSDP, the way we initialize the model undoes the weight tying. Specifically, I suspect it's because we initialize on the meta device. Not only that, but we cannot tie weights prior to FSDP wrapping or else we will hit a hang at our first sync point. You can see e.g. here for some discussion on the topic.

We can get around this by instead tying the weights after FSDP wrapping. I believe @solitude-alive already has a tie_weight utility defined on their fork; we just need to call it in the recipe instead of in the model builder. This way we can control when it gets executed: we can execute it anytime in our single-device recipes, but need to execute it after FSDP wrapping in our distributed recipes. (Another option would be some kind of post-init hook, but I'm not sure offhand how to implement it.)
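As a rough sketch of the ordering in a distributed recipe (the tie_weights helper below and the tok_embeddings/output attribute names are illustrative assumptions, not the actual recipe code):

# Hypothetical sketch: tie weights only after FSDP wrapping in distributed recipes.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def tie_weights(model: torch.nn.Module) -> None:
    # Point the output projection at the token embedding weight (Gemma-style tying)
    model.output.weight = model.tok_embeddings.weight

def setup_distributed_model(model: torch.nn.Module) -> torch.nn.Module:
    # Tying before wrapping can hang at the first collective, so wrap first...
    wrapped = FSDP(model, device_id=torch.cuda.current_device())
    # ...then re-tie the weights on the wrapped module.
    tie_weights(wrapped)
    return wrapped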

@solitude-alive solitude-alive mentioned this issue Apr 2, 2024
@joecummings joecummings linked a pull request Apr 2, 2024 that will close this issue