Skip to content

Commit

Permalink
fix convert_safetensors.py for rwkv6
Browse files (browse the repository at this point in the history)
  • Loading branch information
josStorer committed Mar 2, 2024
1 parent 753916c commit 3c632b3
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions convert_safetensors.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -54,19 +54,21 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=
loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])
)

for k in kk:
new_k = rename_key(rename, k).lower()
v = loaded[k].half()
del loaded[k]
for transpose_name in transpose_names:
if transpose_name in k:
v = v.transpose(0, 1)
print(f"{new_k}\t{v.shape}\t{v.dtype}")
loaded[new_k] = {
"dtype": str(v.dtype).split(".")[-1],
"shape": v.shape,
"data": v.detach().numpy().tobytes(),
}
with torch.no_grad():
for k in kk:
new_k = rename_key(rename, k).lower()
v = loaded[k].half()
del loaded[k]
for transpose_name in transpose_names:
if transpose_name in new_k:
dims = len(v.shape)
v = v.transpose(dims - 2, dims - 1)
print(f"{new_k}\t{v.shape}\t{v.dtype}")
loaded[new_k] = {
"dtype": str(v.dtype).split(".")[-1],
"shape": v.shape,
"data": v.numpy().tobytes(),
}

dirname = os.path.dirname(sf_filename)
os.makedirs(dirname, exist_ok=True)
Expand Down

0 comments on commit 3c632b3

Please sign in to comment.