diff --git a/MaxText/llama_mistral_mixtral_orbax_to_hf.py b/MaxText/llama_mistral_mixtral_orbax_to_hf.py
index 1c16fc31c..c4249f955 100644
--- a/MaxText/llama_mistral_mixtral_orbax_to_hf.py
+++ b/MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -45,7 +45,7 @@
 import checkpointing
 from generate_param_only_checkpoint import _read_train_checkpoint
 import llama_or_mistral_ckpt
-from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM
+from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM, AutoTokenizer
 
 
 def unpermute_from_match_maxtext_rope(arr):
@@ -231,6 +231,17 @@ def convert_orbax_hf(hf_model_path, config):
   print(f"Saving HuggingFace model to path = {hf_model_path}")
   hf_model.save_pretrained(hf_model_path, state_dict=new_hf_model_params)
 
+  # Load the converted HF checkpoint with the Mixtral-8x7B tokenizer and run a short generation to sanity-check it.
+  model_id = "mistralai/Mixtral-8x7B-v0.1"
+  tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+  converted_hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, device_map="auto")
+  text = "Harry Potter is "
+  inputs = tokenizer(text, return_tensors="pt")
+
+  outputs = converted_hf_model.generate(**inputs, max_new_tokens=20)
+  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+
 
 def main(argv: Sequence[str]):
   print(argv[:-1])
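
Note on the verification block added above: the tokenizer is hardcoded to mistralai/Mixtral-8x7B-v0.1, so the post-conversion smoke test only matches when a Mixtral checkpoint is being converted, even though the script also handles Llama and Mistral checkpoints. Below is a minimal sketch of a per-model variant; the verify_hf_checkpoint helper, the MODEL_ID_MAP table, and the model_size argument are illustrative assumptions, not part of the existing script.

# Sketch only: choose the tokenizer to match the model family being converted.
# verify_hf_checkpoint, MODEL_ID_MAP, and model_size are hypothetical names.
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID_MAP = {
  "llama2-7b": "meta-llama/Llama-2-7b-hf",
  "mistral-7b": "mistralai/Mistral-7B-v0.1",
  "mixtral-8x7b": "mistralai/Mixtral-8x7B-v0.1",
}


def verify_hf_checkpoint(hf_model_path, model_size, prompt="Harry Potter is "):
  """Loads the converted checkpoint and runs a short generation as a smoke test."""
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID_MAP[model_size])
  model = AutoModelForCausalLM.from_pretrained(hf_model_path, device_map="auto")
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  outputs = model.generate(**inputs, max_new_tokens=20)
  return tokenizer.decode(outputs[0], skip_special_tokens=True)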