From ac8a71a66f2352165df350d685298c02ad84d3d8 Mon Sep 17 00:00:00 2001 From: Laura Hanu <32672979+laurahanu@users.noreply.github.com> Date: Fri, 8 Apr 2022 16:56:53 +0100 Subject: [PATCH] only load state dict when the checkpoint is not None --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c81ef06ebed853..7f1b12386202fa 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1792,7 +1792,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # load pt weights early so that we know which dtype to init the model under if from_pt: - if not is_sharded: + if not is_sharded and state_dict is None: # Time to load the checkpoint state_dict = load_state_dict(resolved_archive_file) # set dtype to instantiate the model under: