How to convert parallel state_dict to normal state_dict? #122

Open

JinchaoLove opened this issue Sep 18, 2023 · 3 comments

Comments

JinchaoLove commented Sep 18, 2023

Hi there! I saved a parallel state_dict (only the parameters with requires_grad=True) across 8 GPUs on a remote machine. How can I load these state_dicts and merge them into a single one locally? Thanks in advance.

collie_dp0_pp0_tp0.pt  collie_zero_dp0_pp0_tp0.pt  collie_zero_dp2_pp0_tp0.pt  collie_zero_dp4_pp0_tp0.pt  collie_zero_dp6_pp0_tp0.pt
collie.json            collie_zero_dp1_pp0_tp0.pt  collie_zero_dp3_pp0_tp0.pt  collie_zero_dp5_pp0_tp0.pt  collie_zero_dp7_pp0_tp0.pt

KaiLv69 (Collaborator) commented Sep 18, 2023

Hi, the model weights should be saved in files like pytorch_model.bin with the CheckpointCallback below.

callbacks = [CheckpointCallback(your_path, every_n_batches=1600, model_only=False, peft_only=False)]

BTW, are you using the main branch or the dev branch? I recommend using dev now.
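
For reference, here is a fuller sketch of how the callback wires into a training script. It is a minimal sketch assuming the dev branch; the import paths, the model/config setup, and the Trainer arguments are assumptions based on the usual CoLLiE pattern, not verbatim library code.

# Hedged sketch: wiring CheckpointCallback into a CoLLiE Trainer.
# Import paths and the model/config construction are assumptions.
from collie import CollieConfig, Trainer
from collie.callbacks import CheckpointCallback
from collie.models import LlamaForCausalLM

your_path = "./checkpoints"  # placeholder output directory
config = CollieConfig.from_pretrained("huggyllama/llama-7b")  # hypothetical base model
model = LlamaForCausalLM.from_pretrained("huggyllama/llama-7b", config=config)

callbacks = [
    CheckpointCallback(
        your_path,             # where pytorch_model.bin and trainer state land
        every_n_batches=1600,  # save every 1600 training batches
        model_only=False,      # also keep optimizer/trainer state
        peft_only=False,       # save full weights, not only PEFT adapters
    )
]

trainer = Trainer(model=model, config=config, callbacks=callbacks)
trainer.train()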

JinchaoLove (Author) commented Sep 18, 2023

Got it. I'm using the dev branch. So the files listed above are all trainer state (not model weights), as defined in the Trainer. The issue was caused by my filtering on requires_grad, which is always False for tensors returned by state_dict().

self.checkpoint_file = "collie_dp{}_pp{}_tp{}.pt".format(env.dp_rank, env.pp_rank, env.tp_rank)  # trainer state, not model weights
state_dict = {n: p.detach().cpu() for n, p in model.state_dict().items() if p.requires_grad}  # always empty: state_dict() tensors are detached
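
A minimal, runnable plain-PyTorch sketch of the pitfall and the fix (not CoLLiE-specific):

import torch

model = torch.nn.Linear(4, 2)
model.bias.requires_grad_(False)  # freeze one parameter for the demo

# Buggy filter: state_dict() returns detached tensors, so requires_grad
# is always False and the resulting dict is empty.
buggy = {n: p for n, p in model.state_dict().items() if p.requires_grad}
assert buggy == {}

# Correct filter: consult named_parameters(), where requires_grad survives.
trainable = {n: p.detach().cpu() for n, p in model.named_parameters() if p.requires_grad}
print(list(trainable))  # ['weight']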

JinchaoLove (Author) commented Sep 19, 2023

The topk argument of CheckpointCallback defaults to 0, which means no model is saved at all... I think it would be better to default it to 1 or -1, or to raise a warning, to guard against this misconfiguration.
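
Until the default changes, passing topk explicitly avoids the silent no-op; a minimal sketch (reading topk=-1 as "keep every checkpoint" is my assumption):

callbacks = [
    CheckpointCallback(
        your_path,
        every_n_batches=1600,
        model_only=False,
        peft_only=False,
        topk=1,  # default 0 silently skips saving; -1 is assumed to keep all
    )
]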

JinchaoLove reopened this Sep 19, 2023