diff --git a/MaxText/llama_mistral_mixtral_orbax_to_hf.py b/MaxText/llama_mistral_mixtral_orbax_to_hf.py
index b861b5989..1c16fc31c 100644
--- a/MaxText/llama_mistral_mixtral_orbax_to_hf.py
+++ b/MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -125,7 +125,7 @@ def convert_state_to_hf(training_state, model_size):
   # Port the embedding weights
   hf_model_params["model.embed_tokens.weight"] = torch.tensor(np.asarray(
       training_state.params['params']['token_embedder']['embedding']),
-      dtype=torch.float32)
+      dtype=torch.float16)
 
   for layer_int in tqdm(range(base_num_decoder_layers),desc="Porting parameters layerwise"):
     print("Converting weights for layer {}".format(layer_int))
@@ -138,79 +138,79 @@ def convert_state_to_hf(training_state, model_size):
           ,head_dim
         )
       ).reshape(base_num_query_heads * head_dim,base_num_query_heads * head_dim).T),
-      dtype=torch.float32
+      dtype=torch.float16
     )
     hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = torch.tensor(np.asarray(
       unpermute_from_match_maxtext_rope(
         training_state.params['params']["decoder"]["layers"]["self_attention"]["key"]["kernel"][:, layer_int, :, :]
       ).reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T),
-      dtype=torch.float32
+      dtype=torch.float16
     )
     hf_model_params[f"model.layers.{layer_int}.self_attn.v_proj.weight"] = torch.tensor(np.asarray(
       training_state.params['params']["decoder"]["layers"]["self_attention"]["value"]["kernel"][:, layer_int, :, :]
       .reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T),
-      dtype=torch.float32
+      dtype=torch.float16
     )
     hf_model_params[f"model.layers.{layer_int}.self_attn.o_proj.weight"] = torch.tensor(np.asarray(
       training_state.params['params']["decoder"]["layers"]["self_attention"]["out"]["kernel"][:, layer_int, :, :]
       .reshape(base_num_query_heads * head_dim,base_num_query_heads * head_dim).T),
-      dtype=torch.float32
+      dtype=torch.float16
     )
 
     # MLP Layers
     if num_experts is None:
       hf_model_params[f"model.layers.{layer_int}.mlp.gate_proj.weight"] = torch.tensor(np.asarray(
         training_state.params['params']["decoder"]["layers"]["mlp"]["wi_0"]["kernel"][:, layer_int, :].T),
-        dtype=torch.float32
+        dtype=torch.float16
       )
       hf_model_params[f"model.layers.{layer_int}.mlp.up_proj.weight"] = torch.tensor(np.asarray(
         training_state.params['params']["decoder"]["layers"]["mlp"]["wi_1"]["kernel"][:, layer_int, :].T),
-        dtype=torch.float32
+        dtype=torch.float16
       )
       hf_model_params[f"model.layers.{layer_int}.mlp.down_proj.weight"] = torch.tensor(np.asarray(
         training_state.params['params']["decoder"]["layers"]["mlp"]["wo"]["kernel"][:, layer_int, :].T),
-        dtype=torch.float32
+        dtype=torch.float16
       )
     else:
       hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.gate.weight"] = torch.tensor(np.asarray(
         training_state.params['params']['decoder']['layers']['MoeBlock_0']['gate']['kernel'][:,layer_int,:].T
-      ), dtype=torch.float32)
+      ), dtype=torch.float16)
       for k in range(num_experts):
         print("Coverting MoeBlock expert {} weights".format(k))
         hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w1.weight"] = torch.tensor(np.asarray(
           training_state.params['params']['decoder']['layers']['MoeBlock_0']['wi_0'][k, layer_int, :, :].T),
-          dtype=torch.float32
+          dtype=torch.float16
         )
         hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w2.weight"] = torch.tensor(np.asarray(
           training_state.params['params']['decoder']['layers']['MoeBlock_0']['wo'][k, layer_int, :, :].T),
-          dtype=torch.float32
+          dtype=torch.float16
         )
         hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w3.weight"] = torch.tensor(np.asarray(
           training_state.params['params']['decoder']['layers']['MoeBlock_0']['wi_1'][k, layer_int, :, :].T),
-          dtype=torch.float32
+          dtype=torch.float16
         )
 
     # Pre/post attention layer norm
     hf_model_params[f"model.layers.{layer_int}.input_layernorm.weight"] = torch.tensor(np.asarray(
       training_state.params['params']["decoder"]["layers"]["pre_self_attention_layer_norm"]["scale"][:, layer_int]
       .reshape(base_num_query_heads * head_dim)),
-      dtype=torch.float32
+      dtype=torch.float16
     )
     hf_model_params[f"model.layers.{layer_int}.post_attention_layernorm.weight"] = torch.tensor(np.asarray(
       training_state.params['params']["decoder"]["layers"]["post_self_attention_layer_norm"]["scale"][:, layer_int]
       .reshape(base_num_query_heads * head_dim)),
-      dtype=torch.float32
+      dtype=torch.float16
     )
 
   # LM head and layernorm
   hf_model_params["lm_head.weight"] = torch.tensor(np.asarray(
     training_state.params['params']["decoder"]["logits_dense"]["kernel"].T),
-    dtype=torch.float32
+    dtype=torch.float16
   )
   hf_model_params["model.norm.weight"] = torch.tensor(np.asarray(
     training_state.params['params']["decoder"]["decoder_norm"]["scale"].reshape(base_num_query_heads * head_dim)),
-    dtype=torch.float32
+    dtype=torch.float16
   )
 
   return hf_model_params
@@ -225,7 +225,7 @@ def convert_orbax_hf(hf_model_path, config):
   print("MoeBlock gate shape: {}".format(training_state.params['params']['decoder']['layers']['MoeBlock_0']['gate']['kernel'].shape))
   print("MoeBlock wi_0 shape: {}".format(training_state.params['params']['decoder']['layers']['MoeBlock_0']['wi_0'].shape))
   print("MoeBlock w1_1 shape: {}".format(training_state.params['params']['decoder']['layers']['MoeBlock_0']['wi_1'].shape))
-  print("MoeBlock w0 shape: {}".format(training_state.params['params']['decoder']['layers']['MoeBlock_0']['wo'].shape))
+  print("MoeBlock wo shape: {}".format(training_state.params['params']['decoder']['layers']['MoeBlock_0']['wo'].shape))
 
   new_hf_model_params = convert_state_to_hf(training_state, config.model_name)
   print(f"Saving HuggingFace model to path = {hf_model_path}")
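Reviewer note: every weight in `convert_state_to_hf` is exported through the same `torch.tensor(np.asarray(...), dtype=...)` call, and this patch flips that dtype from `torch.float32` to `torch.float16` in each occurrence. The sketch below is not part of the patch; `to_hf_tensor` and `check_dtypes` are hypothetical helpers, shown only to illustrate the repeated cast pattern and one way to confirm the exported state dict is uniformly float16.

```python
# Minimal sketch, not in the MaxText script: factor out the repeated cast and
# verify the result. Helper names are hypothetical.
import numpy as np
import torch


def to_hf_tensor(arr, dtype=torch.float16):
  """Mirror the repeated torch.tensor(np.asarray(...), dtype=...) pattern."""
  return torch.tensor(np.asarray(arr), dtype=dtype)


def check_dtypes(hf_model_params, expected=torch.float16):
  """Raise if any converted weight deviates from the expected dtype."""
  for name, tensor in hf_model_params.items():
    if tensor.dtype != expected:
      raise ValueError(f"{name} has dtype {tensor.dtype}, expected {expected}")
```

Exporting in float16 halves the size of the resulting HuggingFace checkpoint relative to float32; the cast happens on the host when each parameter is materialized with `np.asarray` and wrapped in a torch tensor.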