change to torch bf16 type.
jwyang-google committed Aug 20, 2024
1 parent 05149fc commit 54f2f19
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -125,7 +125,7 @@ def convert_state_to_hf(training_state, model_size):
# Port the embedding weights
hf_model_params["model.embed_tokens.weight"] = torch.tensor(np.asarray(
training_state.params['params']['token_embedder']['embedding']),
-dtype=torch.float32)
+dtype=torch.bfloat16)

for layer_int in tqdm(range(base_num_decoder_layers),desc="Porting parameters layerwise"):
print("Converting weights for layer {}".format(layer_int))
@@ -138,79 +138,79 @@ def convert_state_to_hf(training_state, model_size):
,head_dim
)
).reshape(base_num_query_heads * head_dim,base_num_query_heads * head_dim).T),
-dtype=torch.float32
+dtype=torch.bfloat16
)

hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = torch.tensor(np.asarray(
unpermute_from_match_maxtext_rope(
training_state.params['params']["decoder"]["layers"]["self_attention"]["key"]["kernel"][:, layer_int, :, :]
).reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T),
-dtype=torch.float32
+dtype=torch.bfloat16
)
hf_model_params[f"model.layers.{layer_int}.self_attn.v_proj.weight"] = torch.tensor(np.asarray(
training_state.params['params']["decoder"]["layers"]["self_attention"]["value"]["kernel"][:, layer_int, :, :]
.reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T),
-dtype=torch.float32
+dtype=torch.bfloat16
)
hf_model_params[f"model.layers.{layer_int}.self_attn.o_proj.weight"] = torch.tensor(np.asarray(
training_state.params['params']["decoder"]["layers"]["self_attention"]["out"]["kernel"][:, layer_int, :, :]
.reshape(base_num_query_heads * head_dim,base_num_query_heads * head_dim).T),
-dtype=torch.float32
+dtype=torch.bfloat16
)

# MLP Layers
if num_experts is None:
hf_model_params[f"model.layers.{layer_int}.mlp.gate_proj.weight"] = torch.tensor(np.asarray(
training_state.params['params']["decoder"]["layers"]["mlp"]["wi_0"]["kernel"][:, layer_int, :].T),
-dtype=torch.float32
+dtype=torch.bfloat16
)
hf_model_params[f"model.layers.{layer_int}.mlp.up_proj.weight"] = torch.tensor(np.asarray(
training_state.params['params']["decoder"]["layers"]["mlp"]["wi_1"]["kernel"][:, layer_int, :].T),
-dtype=torch.float32
+dtype=torch.bfloat16
)
hf_model_params[f"model.layers.{layer_int}.mlp.down_proj.weight"] = torch.tensor(np.asarray(
training_state.params['params']["decoder"]["layers"]["mlp"]["wo"]["kernel"][:, layer_int, :].T),
-dtype=torch.float32
+dtype=torch.bfloat16
)
else:
hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.gate.weight"] = torch.tensor(np.asarray(
training_state.params['params']['decoder']['layers']['MoeBlock_0']['gate']['kernel'][:,layer_int,:].T
-), dtype=torch.float32)
+), dtype=torch.bfloat16)
for k in range(num_experts):
print("Coverting MoeBlock expert {} weights".format(k))
hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w1.weight"] = torch.tensor(np.asarray(
training_state.params['params']['decoder']['layers']['MoeBlock_0']['wi_0'][k, layer_int, :, :].T),
-dtype=torch.float32
+dtype=torch.bfloat16
)
hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w2.weight"] = torch.tensor(np.asarray(
training_state.params['params']['decoder']['layers']['MoeBlock_0']['wo'][k, layer_int, :, :].T),
-dtype=torch.float32
+dtype=torch.bfloat16
)
hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w3.weight"] = torch.tensor(np.asarray(
training_state.params['params']['decoder']['layers']['MoeBlock_0']['wi_1'][k, layer_int, :, :].T),
-dtype=torch.float32
+dtype=torch.bfloat16
)

# Pre/post attention layer norm
hf_model_params[f"model.layers.{layer_int}.input_layernorm.weight"] = torch.tensor(np.asarray(
training_state.params['params']["decoder"]["layers"]["pre_self_attention_layer_norm"]["scale"][:, layer_int]
.reshape(base_num_query_heads * head_dim)),
-dtype=torch.float32
+dtype=torch.bfloat16
)
hf_model_params[f"model.layers.{layer_int}.post_attention_layernorm.weight"] = torch.tensor(np.asarray(
training_state.params['params']["decoder"]["layers"]["post_self_attention_layer_norm"]["scale"][:, layer_int]
.reshape(base_num_query_heads * head_dim)),
-dtype=torch.float32
+dtype=torch.bfloat16
)

# LM head and layernorm
hf_model_params["lm_head.weight"] = torch.tensor(np.asarray(
training_state.params['params']["decoder"]["logits_dense"]["kernel"].T),
-dtype=torch.float32
+dtype=torch.bfloat16
)
hf_model_params["model.norm.weight"] = torch.tensor(np.asarray(
training_state.params['params']["decoder"]["decoder_norm"]["scale"].reshape(base_num_query_heads * head_dim)),
-dtype=torch.float32
+dtype=torch.bfloat16
)

return hf_model_params
Expand All @@ -225,7 +225,7 @@ def convert_orbax_hf(hf_model_path, config):
print("MoeBlock gate shape: {}".format(training_state.params['params']['decoder']['layers']['MoeBlock_0']['gate']['kernel'].shape))
print("MoeBlock wi_0 shape: {}".format(training_state.params['params']['decoder']['layers']['MoeBlock_0']['wi_0'].shape))
print("MoeBlock w1_1 shape: {}".format(training_state.params['params']['decoder']['layers']['MoeBlock_0']['wi_1'].shape))
print("MoeBlock w0 shape: {}".format(training_state.params['params']['decoder']['layers']['MoeBlock_0']['wo'].shape))
print("MoeBlock wo shape: {}".format(training_state.params['params']['decoder']['layers']['MoeBlock_0']['wo'].shape))

new_hf_model_params = convert_state_to_hf(training_state, config.model_name)
print(f"Saving HuggingFace model to path = {hf_model_path}")
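
For context, every changed line in this diff performs the same cast: a checkpoint weight is materialized with np.asarray and then wrapped in a torch tensor of the target dtype. The snippet below is a minimal standalone sketch of that pattern, assuming the target dtype is torch.bfloat16 as the commit title states; the weight array here is a random placeholder, not a real MaxText checkpoint tensor.

import numpy as np
import torch

# Placeholder for a checkpoint weight; in the converter the source is a JAX
# array pulled out of training_state.params, and np.asarray brings it to host
# as a NumPy array (on this already-NumPy placeholder the call is a no-op).
maxtext_weight = np.random.rand(8, 8).astype(np.float32)

# torch.tensor copies the data and downcasts it to bfloat16; NumPy has no
# native bfloat16 dtype, so the cast has to happen on the torch side.
hf_weight = torch.tensor(np.asarray(maxtext_weight), dtype=torch.bfloat16)

print(hf_weight.dtype)   # torch.bfloat16
print(hf_weight.shape)   # torch.Size([8, 8])

Exporting in bf16 halves the size of the saved checkpoint relative to float32, and bfloat16 keeps float32's exponent range, which makes the downcast safer than fp16 for large-magnitude weights.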
