diff --git a/tests/recipes/test_full_finetune.py b/tests/recipes/test_full_finetune.py index b4f125080e..1f68291397 100644 --- a/tests/recipes/test_full_finetune.py +++ b/tests/recipes/test_full_finetune.py @@ -127,7 +127,7 @@ def test_loss(self, capsys, pytestconfig, tmpdir, monkeypatch): --config {_CONFIG_PATH} \ output_dir={tmpdir} \ model=torchtune.models.{ckpt} \ - checkpointer=torchtune.utils.{checkpointer} + checkpointer._component_=torchtune.utils.{checkpointer} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.model_type=LLAMA2 @@ -174,7 +174,7 @@ def test_training_state_on_resume(self, capsys, tmpdir, monkeypatch): --config {_CONFIG_PATH} \ output_dir={ckpt_dir} \ model=torchtune.models.{model_ckpt} \ - checkpointer=torchtune.utils.FullModelHFCheckpointer \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.model_type=LLAMA2 \ @@ -194,7 +194,7 @@ def test_training_state_on_resume(self, capsys, tmpdir, monkeypatch): --config {_CONFIG_PATH} \ output_dir={tmpdir} \ model=torchtune.models.{model_ckpt} \ - checkpointer=torchtune.utils.FullModelHFCheckpointer \ + checkpointer._component_=torchtune.utils.FullModelHFCheckpointer \ checkpointer.checkpoint_dir={ckpt_dir} \ checkpointer.checkpoint_files=[{os.path.join(ckpt_dir, "hf_model_0001_2.pt")}]\ checkpointer.model_type=LLAMA2 \ @@ -235,7 +235,7 @@ def test_gradient_accumulation( tune full_finetune \ --config {_CONFIG_PATH} \ model=torchtune.models.{model_ckpt} \ - checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir={ckpt_dir} \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.model_type=LLAMA2 \ @@ -262,7 +262,7 @@ def test_gradient_accumulation( tune full_finetune \ --config {_CONFIG_PATH} \ 
model=torchtune.models.{model_ckpt} \ - checkpointer=torchtune.utils.FullModelTorchTuneCheckpointer \ + checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir={ckpt_dir} \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.model_type=LLAMA2 \ diff --git a/tests/recipes/test_lora_finetune.py b/tests/recipes/test_lora_finetune.py index 0112c9d44b..6d4d22de01 100644 --- a/tests/recipes/test_lora_finetune.py +++ b/tests/recipes/test_lora_finetune.py @@ -118,9 +118,8 @@ def test_save_and_load_merged_weights( cmd = f""" tune {recipe_name} --config {config_path} \ - --override \ enable_fsdp={enable_fsdp} \ - model._component_=torchtune.models.{ckpt} \ + model=torchtune.models.{ckpt} \ model_checkpoint={fetch_ckpt_model_path(ckpt)} \ model.lora_rank=8 \ model.lora_alpha=16 \