
deep_speed initialization for models in the transformers library #85

Open
DesperateExplorer opened this issue Jul 19, 2023 · 6 comments
Labels: help wanted (Extra attention is needed)

Comments

@DesperateExplorer

Dear authors,

I found that collie can not initialize DeepSpeed when using models in the transformers library. For example, when replace this line of script with the from_pretrained interface of the transformers library, to which any config of the type CollieConfig can not be passed, even the monitors can not be registered correctly since ds is not initialized (DeepSpeed backend not set, please initialize it using init_process_group()). Is there any workaround of this issue or Collie can only support training the internally reimplemented models?

@00INDEX
Collaborator

00INDEX commented Jul 19, 2023

Hi @DesperateExplorer, Collie can use models from transformers in the case of ZeRO parallelism, but you need to call setup_distribution manually:

from collie import setup_distribution, CollieConfig
from transformers import AutoModelForCausalLM
model_name = "openlm-research/open_llama_7b_v2"
config = CollieConfig.from_pretrained(model_name)
setup_distribution(config)
model = AutoModelForCausalLM.from_pretrained(model_name)

@DesperateExplorer
Author

Why is the memory consumption of the LLaMA-7B from transformers much larger than that of Collie's internal implementation? Taking LLaMA-7B with AdamW as an example: with the internal implementation, train_micro_batch_size_per_gpu can be 2 without causing OOM on a V100 with the ShareGPT dataset (max context = 2048), whereas with the transformers implementation, even train_micro_batch_size_per_gpu = 1 causes OOM. Even after switching to LOMO, I cannot fit a single sample (train_micro_batch_size_per_gpu = 1) into the 32 GB of memory without OOM.
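
For reference, the DeepSpeed settings involved look roughly like this (the exact CollieConfig fields and ZeRO stage are assumed, not taken from this thread):

from collie import CollieConfig

config = CollieConfig.from_pretrained("openlm-research/open_llama_7b_v2")  # illustrative
# train_micro_batch_size_per_gpu is the knob compared above:
# 2 fits with Collie's internal LLaMA on a 32 GB V100, while 1 already OOMs with transformers.
config.ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "zero_optimization": {"stage": 3},   # assumed ZeRO stage, not stated above
    "fp16": {"enabled": True},
}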

@x54-729
Contributor

x54-729 commented Jul 20, 2023

Collie's LLaMA uses flash attention for MHA, which can reduce memory usage. If use_flash is True, memory usage is lower than with the transformers implementation.
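
For example (assuming use_flash is set directly on the CollieConfig, as in the repo's examples):

from collie import CollieConfig

config = CollieConfig.from_pretrained("openlm-research/open_llama_7b_v2")  # illustrative
config.use_flash = True  # enable flash attention in Collie's internal LLaMA implementation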

@DesperateExplorer
Author

Collie's LLaMA uses flash attention for MHA, which can reduce memory usage. If use_flash is True, memory usage is lower than with the transformers implementation.

Actually, no. On V100 (Volta architecture), flash attention is not supported at all.

@Carol-gutianle
Collaborator

Carol-gutianle commented Jul 24, 2023

You can try setting pretrained_config.gradient_checkpointing to True, like this: [screenshot omitted]
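
Since the screenshot is not preserved, here is a sketch of what it presumably showed, using the standard transformers API (the exact attribute path in a Collie setup is an assumption):

from transformers import AutoConfig, AutoModelForCausalLM

model_name = "openlm-research/open_llama_7b_v2"  # illustrative
pretrained_config = AutoConfig.from_pretrained(model_name)
pretrained_config.gradient_checkpointing = True  # trade extra compute for lower activation memory
model = AutoModelForCausalLM.from_pretrained(model_name, config=pretrained_config)

# Equivalent, newer-style API on the model itself:
model.gradient_checkpointing_enable()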

@x54-729
Contributor

x54-729 commented Jul 25, 2023

You can try setting pretrained_config.gradient_checkpointing to True, like this: [screenshot omitted]

config.checkpointing=True also works now.

00INDEX added the help wanted label on Aug 1, 2023