
add Gemma2 support to MaxText #814

Merged
merged 1 commit into main from gemma2-2b on Aug 10, 2024
Conversation

@ZhaoyueCheng (Collaborator) commented Aug 6, 2024

Add Gemma2 support to MaxText, on top of a prior PR.

  • Add a Gemma2 decoder block that merges each [local_sliding_attention, global_attention] pair into one decoder layer, with post_attn_norm and post_ffw_norm support.
  • Add convert_gemma2_chkpt.py to convert checkpoints from the Gemma2 architecture:
    • merges each [local_sliding_attention, global_attention] pair into one decoder layer
    • adds support for post_attn_norm, post_ffw_norm, transpose_gating_einsum, and query_pre_attn_scalar, which are new relative to Gemma1
  • Enable post_attn_norm and post_ffw_norm in the Gemma2 decoder architecture.
  • Add final_logits_soft_cap support (see the sketch after this list).
  • Add gemma2-2b, gemma2-9b, and gemma2-27b config YAML files, a 2b end_to_end test script, and test golden logits dumped from Flax.
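For reference, final_logits_soft_cap bounds the output logits with a scaled tanh rather than a hard clip, so gradients stay nonzero near the cap. A minimal JAX sketch (the helper name is ours; the cap value of 30.0 follows the published Gemma2 config and should be treated as an assumption here):

import jax.numpy as jnp

def apply_soft_cap(logits: jnp.ndarray, cap: float) -> jnp.ndarray:
  """Squash logits smoothly into (-cap, cap) via a scaled tanh."""
  return cap * jnp.tanh(logits / cap)

# Example: with final_logits_soft_cap = 30.0, extreme logits saturate near ±30.
logits = jnp.array([-100.0, 0.0, 100.0])
capped = apply_soft_cap(logits, cap=30.0)  # ≈ [-29.9, 0.0, 29.9]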

@gagika (Collaborator) left a comment:

Some quick comments

Review threads on:
  • MaxText/configs/base.yml
  • MaxText/configs/models/gemma2-9b.yml
  • end_to_end/tpu/gemma2/9b/1_test_gemma.sh
  • end_to_end/tpu/gemma2/9b/2_test_gemma.sh
  • MaxText/pyconfig.py
On a new notebook file (@@ -0,0 +1,331 @@), a collaborator asked:

Do we want to open source all the notebooks?

@ZhaoyueCheng (Collaborator, Author) replied:

I think we had the original Gemma export notebook and other notebooks (Llama, Mixtral, etc.) open sourced, so I added the Gemma2 notebook to the scratch_code folder alongside the other open-source notebooks.

@salrowili commented:

Please note that if we merge local_sliding_attention and global_attention and assign half the decoder-layer count to "base_num_decoder_layers", the TFLOPS calculation will report half the actual TFLOPS. This is due to how maxtext_utils calculates the TFLOPS:
https://github.com/google/maxtext/blob/644eb87ae90dd8b210ce17f1c16ca7a54e80fceb/MaxText/maxtext_utils.py#L139

@gobbleturk (Collaborator) replied:

Good catch! Thank you for noting this.

@ZhaoyueCheng (Collaborator, Author) replied:

Updated, thanks for the comment!
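To make salrowili's undercounting point concrete: if per-layer FLOPs are scaled by num_decoder_layers alone, the second attention sub-layer inside each fused Gemma2 layer goes uncounted. A hypothetical sketch (function and argument names are illustrative, not the actual maxtext_utils code):

def estimated_tflops(flops_per_sublayer: float,
                     num_decoder_layers: int,
                     sublayers_per_layer: int = 1) -> float:
  """Estimate training TFLOPS from per-sub-layer FLOPs.

  Gemma2 fuses each [local_sliding_attention, global_attention] pair into
  one "decoder layer", so num_decoder_layers counts only half the attention
  sub-layers actually executed; sublayers_per_layer must be 2 for the
  estimate to be right.
  """
  return flops_per_sublayer * num_decoder_layers * sublayers_per_layer / 1e12

naive = estimated_tflops(1.0e12, 13)                          # undercounts by 2x
fixed = estimated_tflops(1.0e12, 13, sublayers_per_layer=2)   # correct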

@khatwanimohit (Collaborator) left a comment:

Thanks for adding this! Left a few nits.

QQ: Are we going to add a logit checker for Gemma2-9B in a separate PR? Since you have the code for golden logits here, can you also add the golden-logits JSONL file for Gemma2-9B in this PR?
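For context, a golden-logits file just stores reference outputs from the original Flax model so tests can compare against them later. A minimal, hypothetical sketch of producing one (the JSONL keys and the tokenize/reference_model callables are assumptions, not MaxText's actual format):

import json
import numpy as np

def dump_golden_logits(prompts, tokenize, reference_model, path):
  """Write one JSONL record per prompt: the prompt, its tokens, and logits."""
  with open(path, "w") as f:
    for prompt in prompts:
      tokens = tokenize(prompt)                        # list[int]
      logits = reference_model(np.array([tokens]))[0]  # [seq_len, vocab]
      record = {"prompt": prompt,
                "tokens": tokens,
                "logits": np.asarray(logits).tolist()}
      f.write(json.dumps(record) + "\n")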

On a new file (@@ -0,0 +1,246 @@), whose header reads:

    """
    Copyright 2023 Google LLC

A collaborator commented:

nit: Update the copyright to 2024

@ZhaoyueCheng (Collaborator, Author) replied:

Updated, thanks for the review!

Review thread on MaxText/pyconfig.py (resolved).
@khatwanimohit removed their assignment Aug 7, 2024
@gobbleturk (Collaborator) left a comment:

Overall looks very good, thank you for adding this! Just want to clean up the flops calculation.

Review threads on MaxText/maxtext_utils.py and MaxText/pyconfig.py (resolved).
@gobbleturk (Collaborator) left a comment:

Awesome, thank you for fixing the tflops calculation!

@gobbleturk assigned ZhaoyueCheng and unassigned themselves Aug 9, 2024
@ZhaoyueCheng (Collaborator, Author) replied:

Thanks for the detailed review and suggestions!

Commit (title truncated): …erter, Config Files, Flop Calculation and Run Scripts
@copybara-service bot merged commit da50760 into main Aug 10, 2024 (13 of 14 checks passed)
@copybara-service bot deleted the gemma2-2b branch August 10, 2024 00:24
@salrowili commented:

@ZhaoyueCheng I would like to thank you for adding Gemma 2 support; it is really appreciated.

I have pre-trained Gemma 2 using MaxText, but now I am stuck since I could not convert the MaxText checkpoint to Hugging Face format for the SFT stage. Gemma2, unlike Llama, Gemma, and Mistral, uses local and global attention. Is there anything you can do to the "MaxText/llama_or_mistral_ckpt.py" script to include Gemma 2 checkpoint conversion to HF format? The global and local attention are used in the following weights (a rough un-fusing sketch follows after this list):

mlp_global
mlp_local
post_ffw_norm_global
post_ffw_norm_local
post_self_attention_norm_global
post_self_attention_norm_local
pre_ffw_norm_global
pre_ffw_norm_local
pre_self_attention_norm_global
pre_self_attention_norm_local
self_attention_global

This is related to #829.
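For what it's worth, such a converter would mostly need to un-fuse each merged MaxText layer back into the two Hugging Face layers it came from. A rough, hypothetical sketch (the leading layer axis, the even/odd local/global interleaving, and the HF key names are all assumptions, not MaxText's or HF's actual checkpoint structure):

import numpy as np

FUSED_WEIGHT_NAMES = ("mlp", "pre_ffw_norm", "post_ffw_norm",
                      "pre_self_attention_norm", "post_self_attention_norm",
                      "self_attention")

def unfuse_gemma2_layers(fused_params: dict, num_fused_layers: int) -> dict:
  """Split fused [local, global] MaxText layers into per-layer HF weights.

  Assumes each fused weight is stacked along a leading layer axis of length
  num_fused_layers, and that the *_local / *_global halves map to even / odd
  Hugging Face layer indices respectively.
  """
  hf_params = {}
  for i in range(num_fused_layers):
    for suffix, offset in (("local", 0), ("global", 1)):
      hf_layer = 2 * i + offset
      for name in FUSED_WEIGHT_NAMES:
        key = f"{name}_{suffix}"
        if key in fused_params:
          # Slice out this fused layer's copy and assign it to its HF layer.
          hf_params[f"model.layers.{hf_layer}.{name}"] = (
              np.asarray(fused_params[key])[i])
  return hf_params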
