add Gemma2 support to MaxText #814
Conversation
Some quick comments
[Comment on a newly added notebook file]
Do we want to open source all the notebooks?
I think we had the original gemma export notebook and other notebooks (llama, mixtral, etc.) open sourced, so I added the gemma2 notebook to the scratch_code folder as well, alongside the other open-source notebooks.
Please note that if we merge local_sliding_attention and global_attention, and assign half the decoder layers to "base_num_decoder_layers", the TFLOPS calculation will report half the actual TFLOPS. This is due to how maxtext_utils calculates the TFLOPS.
Good catch! Thank you for noting this.
updated, thanks for the comment!
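To make the issue concrete, here is a minimal sketch (not MaxText's actual maxtext_utils code; the function and argument names are hypothetical) of why halving base_num_decoder_layers halves the reported TFLOPS when each merged layer really contains two attention sublayers:

```python
# Hypothetical sketch, not the real maxtext_utils implementation.
def estimated_tflops(num_decoder_layers, flops_per_sublayer, sublayers_per_layer=1):
    """Per-step TFLOPs: per-sublayer cost times effective transformer depth."""
    return num_decoder_layers * sublayers_per_layer * flops_per_sublayer / 1e12

# Gemma2 merges [local_sliding_attention, global_attention] into one decoder
# layer, so base_num_decoder_layers is half the true depth. Counting each
# merged layer as a single sublayer under-reports compute by 2x:
naive = estimated_tflops(num_decoder_layers=13, flops_per_sublayer=2e12)
fixed = estimated_tflops(num_decoder_layers=13, flops_per_sublayer=2e12,
                         sublayers_per_layer=2)
assert fixed == 2 * naive
```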
Thanks for adding this! Left a few nits.
QQ: Are we going to add a logit checker for Gemma2-9B in a separate PR? Since you have the code for golden logits here, can you also add the golden logits jsonl file for Gemma2-9B in this PR?
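For context, dumping such a file could look like the following sketch (the JSONL schema, file name, and forward_fn helper are assumptions for illustration, not the repo's actual golden-logits format):

```python
import json

import numpy as np

def dump_golden_logits(prompts, forward_fn, path="golden_data_gemma2-9b.jsonl"):
    """Write one JSON record per prompt containing the model's output logits."""
    with open(path, "w") as f:
        for prompt in prompts:
            logits = np.asarray(forward_fn(prompt))  # shape: (seq_len, vocab_size)
            f.write(json.dumps({"prompt": prompt, "logits": logits.tolist()}) + "\n")
```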
MaxText/layers/gemma2.py (outdated):
"""
Copyright 2023 Google LLC
nit: Update the copyright to 2024
updated, thanks for the review!
Overall looks very good, thank you for adding this! Just want to clean up the flops calculation.
Awesome, thank you for fixing the tflops calculation!
Thanks for the detailed review and suggestions!!
…erter, Config Files, Flop Calculation and Run Scripts
@ZhaoyueCheng I would like to thank you for adding Gemma 2 support; it is really appreciated. I have pre-trained Gemma 2 using MaxText, but now I am stuck because I could not convert the MaxText checkpoint to Hugging Face format for the SFT stage. Gemma2, unlike Llama, Gemma, and Mistral, uses local and global attention. Is there anything you can do to the "MaxText/llama_or_mistral_ckpt.py" script to include Gemma 2 checkpoint conversion to HF format? The global and local attention are used in the following weights:
This is related to #829.
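Such a converter would mainly need to unfold each merged MaxText layer back into two Hugging Face decoder layers. A hedged sketch of that index mapping (assuming the HF Gemma2 implementation places sliding-window attention on even layer indices and global attention on odd ones; the function name is illustrative):

```python
def fused_to_hf_layer_indices(fused_idx):
    """Map one merged MaxText Gemma2 layer to its two HF decoder layers."""
    local_sliding_idx = 2 * fused_idx   # HF layer with sliding-window attention
    global_idx = 2 * fused_idx + 1      # HF layer with global attention
    return local_sliding_idx, global_idx

for i in range(3):
    print(i, "->", fused_to_hf_layer_indices(i))  # 0 -> (0, 1), 1 -> (2, 3), ...
```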
add Gemma2 support to MaxText on top of PR
- `gemma` Decoder block which merges the [local_sliding_attention, global_attention] combination into one Decoder layer, with `post_attn_norm`, `post_ffw_norm` support
- `convert_gemma2_chkpt.py` to convert checkpoints from the Gemma2 architecture
- `post_attn_norm`, `post_ffw_norm`, `transpose_gating_einsum` and `query_pre_attn_scalar` support, which are new compared to Gemma1
- `post_attn_norm`, `post_ffw_norm` in the Gemma2 Decoder architecture
- `final_logits_soft_cap` support (see the sketch after this list)
- `gemma2-2b`, `gemma2-9b`, `gemma2-27b` config yaml files, a 2b end_to_end test script, and test golden logits dumped from Flax
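For readers unfamiliar with the two new scalars, here is a hedged sketch of how they are typically applied (the constants shown are the published Gemma2 defaults, e.g. 30.0 for the final soft cap and 256 for the 2b/9b query scalar; the function names are illustrative, not MaxText's actual API):

```python
import jax.numpy as jnp

def soft_cap_logits(logits, final_logits_soft_cap=30.0):
    """Squash the final logits into (-cap, cap) with a tanh, as Gemma2 does."""
    return final_logits_soft_cap * jnp.tanh(logits / final_logits_soft_cap)

def scale_query(query, query_pre_attn_scalar=256.0):
    """Scale queries by query_pre_attn_scalar**-0.5 instead of head_dim**-0.5."""
    return query * query_pre_attn_scalar**-0.5
```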