Commit

add simple accuracy check.
jwyang-google committed Aug 20, 2024
1 parent 54f2f19 commit 4e43965
Showing 1 changed file with 12 additions and 1 deletion.
MaxText/llama_mistral_mixtral_orbax_to_hf.py: 12 additions, 1 deletion
@@ -45,7 +45,7 @@
 import checkpointing
 from generate_param_only_checkpoint import _read_train_checkpoint
 import llama_or_mistral_ckpt
-from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM
+from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM, AutoTokenizer


 def unpermute_from_match_maxtext_rope(arr):

@@ -231,6 +231,17 @@ def convert_orbax_hf(hf_model_path, config):
   print(f"Saving HuggingFace model to path = {hf_model_path}")
   hf_model.save_pretrained(hf_model_path, state_dict=new_hf_model_params)

+  # Load the saved HF checkpoint and run a short generation to sanity-check the conversion.
+  model_id = "mistralai/Mixtral-8x7B-v0.1"
+  tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+  converted_hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, device_map="auto")
+  text = "Harry potter is "
+  inputs = tokenizer(text, return_tensors="pt")
+
+  outputs = converted_hf_model.generate(**inputs, max_new_tokens=20)
+  print(tokenizer.decode(outputs[0], skip_special_tokens=True))


 def main(argv: Sequence[str]):
   print(argv[:-1])
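The generation above only confirms that the converted checkpoint loads and produces plausible text. A stricter check would compare the converted model's logits against the reference checkpoint on the same prompt. A minimal sketch, assuming both models fit in available memory; the hf_model_path value below is a hypothetical stand-in for the directory written by save_pretrained:

# Sketch: compare converted-checkpoint logits against the reference model.
# Assumptions: both checkpoints fit in memory; hf_model_path is a
# hypothetical stand-in for the directory written by save_pretrained.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-v0.1"
hf_model_path = "/tmp/converted_mixtral"  # hypothetical output directory

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Harry potter is ", return_tensors="pt")

reference = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
converted = AutoModelForCausalLM.from_pretrained(hf_model_path, device_map="auto", torch_dtype=torch.bfloat16)

with torch.no_grad():
    ref_logits = reference(**inputs).logits.float().cpu()
    conv_logits = converted(**inputs).logits.float().cpu()

# Exact equality is not expected across conversion (dtype casts, op ordering),
# so check agreement within a loose tolerance instead.
max_diff = (ref_logits - conv_logits).abs().max().item()
print(f"max |logit diff| = {max_diff:.4e}")
assert max_diff < 1e-1, "converted checkpoint diverges from reference"

A tolerance-based comparison like this catches weight-mapping mistakes (for example, a mis-permuted RoPE layout) that a short free-form generation can easily miss.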
