Skip to content

Commit

Permalink
Fix convert_checkpoints.py for hf and gemma
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Jun 10, 2024
1 parent 4535bdf commit 6d8eb3f
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions convert_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def _load_orig_llama_weight(input_ckpt_dir: epath.Path):

def _load_hf_llama_weight(input_ckpt_dir: epath.Path):
print(f"Loading checkpoint files from {input_ckpt_dir}.")
safetensors_files = input_ckpt_dir.glob("*.safetensors")
safetensors_files = list(input_ckpt_dir.glob("*.safetensors"))
if len(list(safetensors_files)) == 0:
raise ValueError(
f"No *.safetensors found in the input dir {input_ckpt_dir}"
Expand Down Expand Up @@ -418,6 +418,12 @@ def _get_llama_state_dict(input_ckpt_dir):
print(f"Merging weights takes {end - start} seconds")
return state_dict, params

def fix_json(text):
    """Repair almost-JSON text so that ``json.loads`` can parse it.

    Used on the Gemma ``config.json`` before it is parsed. Two defects are
    repaired: single-quoted strings and a trailing comma before a closing
    brace or bracket.

    Text that is already valid JSON is returned unchanged, so well-formed
    configs (including ones containing apostrophes or commas in values)
    are never corrupted.

    Args:
        text: Raw contents of the config file.

    Returns:
        The text with single quotes turned into double quotes and trailing
        commas removed, or the original text if it already parses as JSON.
    """
    import re

    try:
        json.loads(text)
    except json.JSONDecodeError:
        pass  # Not valid JSON yet; fall through to the repairs below.
    else:
        return text

    # Python-style single quotes are not legal JSON string delimiters.
    text = text.replace("'", '"')
    # Drop only commas that directly precede a closing brace/bracket,
    # instead of every comma on a fixed line (which raised IndexError on
    # short input and mangled lists when the layout differed).
    # NOTE: a string value that itself contains ",}" or ",]" would also be
    # altered; this is not expected in a model config file.
    return re.sub(r",(\s*[}\]])", r"\1", text)


def _get_gemma_state_dict(input_ckpt_dir):
ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
Expand All @@ -426,7 +432,9 @@ def _get_gemma_state_dict(input_ckpt_dir):
state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
"model_state_dict"
]
model_config = json.loads((input_ckpt_dir / "config.json").read_text())
config_text = fix_json((input_ckpt_dir / "config.json").read_text())
print('gemma config is', config_text)
model_config = json.loads(config_text)
for key in list(state_dict.keys()):
if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
assert (
Expand Down

0 comments on commit 6d8eb3f

Please sign in to comment.