diff --git a/torchtune/utils/_checkpointing/_checkpointer.py b/torchtune/utils/_checkpointing/_checkpointer.py index 89c11dae7..107154230 100644 --- a/torchtune/utils/_checkpointing/_checkpointer.py +++ b/torchtune/utils/_checkpointing/_checkpointer.py @@ -522,7 +522,7 @@ def save_checkpoint( self._output_dir, f"model-0{cpt_idx}-of-0{list(split_state_dicts.keys())[-1]}_{epoch}", ).with_suffix(".safetensors") - save_file(model_state_dict, output_path, metadata={'format': 'pt'}) + save_file(model_state_dict, output_path, metadata={"format": "pt"}) logger.info( "Model checkpoint of size " f"{os.path.getsize(output_path) / 1000**3:.2f} GB "