
Commit

reverting 1064
SalmanMohammadi committed Jul 12, 2024
1 parent 06a125e commit 3e7e0ac
Showing 4 changed files with 1 addition and 129 deletions.
4 changes: 1 addition & 3 deletions tests/torchtune/utils/test_checkpointer.py
@@ -628,7 +628,6 @@ def state_dict(self, weight_dtype):
),
"model.norm.weight": randn(_DIM, dtype=weight_dtype),
}
state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
return state_dict

@pytest.fixture
@@ -704,8 +703,7 @@ def test_load_save_checkpoint_single_file(
# Converted state dict from the checkpointer

state_dict = single_file_checkpointer.load_checkpoint()
- # Check that we've loaded all the keys - we're loading one less key in: lm_head.weight
- assert len(state_dict["model"].keys()) == (len(orig_state_dict.keys()) - 1)
+ assert len(state_dict["model"].keys()) == len(orig_state_dict.keys())

# the keys in original state dict should match up with the keys in the weight_map
for key in orig_state_dict.keys():
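For context on the assertion change above: before this revert, the test fixture aliased lm_head.weight to model.embed_tokens.weight (Gemma ties its output projection to the input embedding), and the checkpointer dropped the tied key during conversion, so the converted state dict came back with one fewer key. A minimal sketch of that key-count relationship, using made-up tensor shapes and a hypothetical drop_tied_lm_head helper (not a torchtune API):

    import torch

    # Hypothetical HF-style checkpoint in which the output projection is tied:
    # lm_head.weight is the same tensor object as the input embedding.
    hf_state_dict = {
        "model.embed_tokens.weight": torch.randn(256, 64),
        "model.norm.weight": torch.randn(64),
    }
    hf_state_dict["lm_head.weight"] = hf_state_dict["model.embed_tokens.weight"]

    def drop_tied_lm_head(state_dict):
        # Sketch of the pre-revert behavior: the tied head is not kept as a
        # separate entry in the converted state dict.
        return {k: v for k, v in state_dict.items() if k != "lm_head.weight"}

    converted = drop_tied_lm_head(hf_state_dict)
    # Pre-revert expectation: one fewer key after conversion; after the revert
    # the key counts are expected to match instead.
    assert len(converted) == len(hf_state_dict) - 1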
1 change: 0 additions & 1 deletion torchtune/models/gemma/__init__.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

from ._component_builders import gemma # noqa
- from ._convert_weights import gemma_hf_to_tune, gemma_tune_to_hf # noqa
from ._model_builders import ( # noqa
gemma_2b,
gemma_7b,
108 changes: 0 additions & 108 deletions torchtune/models/gemma/_convert_weights.py

This file was deleted.
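The deleted module is not reproduced here. As a rough illustration of the kind of key remapping such a conversion module performs (the real gemma_hf_to_tune / gemma_tune_to_hf mapping lived in the deleted file and may differ), here is a hedged sketch with an invented mapping table and helper name:

    from typing import Dict

    import torch

    # Illustrative only: a made-up subset of an HF -> torchtune key mapping.
    _EXAMPLE_HF_TO_TUNE = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.norm.weight": "norm.scale",
    }

    def remap_keys(
        state_dict: Dict[str, torch.Tensor],
        mapping: Dict[str, str],
    ) -> Dict[str, torch.Tensor]:
        """Rename checkpoint keys per the mapping, leaving values untouched."""
        return {mapping.get(k, k): v for k, v in state_dict.items()}

    converted = remap_keys({"model.norm.weight": torch.ones(64)}, _EXAMPLE_HF_TO_TUNE)
    assert "norm.scale" in converted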

17 changes: 0 additions & 17 deletions torchtune/utils/_checkpointing/_checkpointer.py
@@ -16,7 +16,6 @@
from torchtune import utils

from torchtune.models import convert_weights
- from torchtune.models.gemma import gemma_hf_to_tune, gemma_tune_to_hf
from torchtune.models.mistral import (
mistral_reward_hf_to_tune,
mistral_reward_tune_to_hf,
@@ -427,14 +426,6 @@ def load_checkpoint(self) -> Dict[str, Any]:
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
)
- elif self._model_type == ModelType.GEMMA:
- converted_state_dict[utils.MODEL_KEY] = gemma_hf_to_tune(
- merged_state_dict,
- num_heads=self._config["num_attention_heads"],
- num_kv_heads=self._config["num_key_value_heads"],
- dim=self._config["hidden_size"],
- head_dim=self._config["head_dim"],
- )
else:
converted_state_dict[utils.MODEL_KEY] = convert_weights.hf_to_tune(
merged_state_dict,
@@ -485,14 +476,6 @@ def save_checkpoint(
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
)
- elif self._model_type == ModelType.GEMMA:
- state_dict[utils.MODEL_KEY] = gemma_tune_to_hf(
- state_dict[utils.MODEL_KEY],
- num_heads=self._config["num_attention_heads"],
- num_kv_heads=self._config["num_key_value_heads"],
- dim=self._config["hidden_size"],
- head_dim=self._config["head_dim"],
- )
else:
state_dict[utils.MODEL_KEY] = convert_weights.tune_to_hf(
state_dict[utils.MODEL_KEY],
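For readers tracing the control flow being reverted: load_checkpoint and save_checkpoint dispatch on self._model_type, and the removed branches routed ModelType.GEMMA through the Gemma-specific converters (which additionally take head_dim) rather than the generic convert_weights path; after this revert, Gemma checkpoints fall through to the generic converter. A simplified, standalone sketch of that dispatch shape, with stub converters standing in for the real torchtune functions:

    from enum import Enum
    from typing import Any, Dict

    class ModelType(Enum):
        # Stub enum for the sketch; only the GEMMA member matters here.
        GEMMA = "gemma"
        OTHER = "other"

    def generic_hf_to_tune(sd: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # Stand-in for convert_weights.hf_to_tune (the generic fallback path).
        return sd

    def gemma_specific_hf_to_tune(sd: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # Stand-in for the removed gemma_hf_to_tune, which also needed head_dim.
        return sd

    def convert_for_load(model_type, merged_state_dict, config):
        # Dispatch pattern shown in the hunks above: model-specific converter
        # first, generic converter as the fallback.
        if model_type == ModelType.GEMMA:
            return gemma_specific_hf_to_tune(
                merged_state_dict,
                num_heads=config["num_attention_heads"],
                num_kv_heads=config["num_key_value_heads"],
                dim=config["hidden_size"],
                head_dim=config["head_dim"],
            )
        return generic_hf_to_tune(
            merged_state_dict,
            num_heads=config["num_attention_heads"],
            num_kv_heads=config["num_key_value_heads"],
            dim=config["hidden_size"],
        )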
