Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DON'T MERGE] GCS Checkpointing Testing Workload modification #782

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

bernardhan33
Copy link
Collaborator

This is created as a draft PR for GCS internal members to comment. This will not be merged to main.

Checkpointing a 64B model through MaxText

  • Read and Write times to be collected and sent to GCS buckets before a separate Python program aggregates and uploads to BQ. I've created b/353631904 to track the improvement of letting each pod to write directly to BQ, which is currently blocked by needed nodepool recreation.
  • A sample YAML file is provided for code review purposes.

@@ -390,3 +390,6 @@ enable_single_controller: False

# Split physical axes for https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.mesh_utils.create_device_mesh.html
allow_split_physical_axes: False

# Attributes needed for the MaxText checkpointing loop using standalone_checkpointer.
gcs_metrics_bucket: distributed-checkpointing-metrics
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be / is there any ability to provide a folder or path instead of just the bucket?

MattIrv and others added 7 commits August 20, 2024 14:03
* Configurable number of parameters and save/restore loop.

This adds the following features to the gcs-checkpointing setup:
- Configurable number of parameters by setting the $PARAMETERS env
  variable
- Reloads all checkpoints immediately after saving them
- Adds a configurable minimum step time (when saving checkpoints)

* Address code review comments

* Address round 2 of code review
…g workload (#850)

* Use quantized kv values in einsum.

* Add support for proxy backend when running maxengine_server

* Add tp support for MoE

* Incorporate GKE's stopgap solution in the absence of in-place pod restart

* Add enable_model_warmup flag for AOT compilation at model server start

* Install Transformer Engine from github for stable and nightly mode

* Add support for cloud logging in emergency checkpoint manager.

* Adding support for creating maxtext docker images with jax-ss

* Copybara import of the project:

--
a3ed121cfbebe350ba60ab06d92327b8a3dc76b6 by RissyRan <[email protected]>:

Update tp sharding annotation for MoE block

COPYBARA_INTEGRATE_REVIEW=#788 from google:tp_update a3ed121cfbebe350ba60ab06d92327b8a3dc76b6
PiperOrigin-RevId: 653803904

* Llama2 forward pass checker

* support pre-tokenized dataset and fix hf issues

* Fix a Goodput upload to TB bug

* Temp fix for c4_mlperf pipeline

* Adds new remat policy to save out_proj activations during training.

PiperOrigin-RevId: 655745941

* Add wrap argument when creating XAOT topology mesh

* Refactor code

* Fix maxtext profiler failures

* Changes to Copybara to include renamed files

PiperOrigin-RevId: 656444119

* Copybara import of the project:

--
adea1b1e8e4006de77f09e3cf09ba661a929bb93 by Param Bole <[email protected]>:

Refactoring Maxtext-JAX-Stable-Stack Build Process

COPYBARA_INTEGRATE_REVIEW=#796 from google:pb_jax_ss_maxtext adea1b1e8e4006de77f09e3cf09ba661a929bb93
PiperOrigin-RevId: 656474261

* Refactoring Maxtext-JAX-Stable-Stack Build Process

* Remove pre-commit install from setup.sh

* Internal change

PiperOrigin-RevId: 657990080

* Pin jax stable version > 0.4

* Revert internal change to maxengine

* Add unique_indices to jnp.take

* Removing unused arg PLATFORM

* Add Gemma2 9B Config, Window Attention and attention logic soft cap.

Fixing attention_type type annotation

Fixing trailing-whitespace

Updating 9B test_gemma.sh scripts

Update gemma 9B 2_test_gemma.sh

Updating gemma 9B checkpoint directory

Update default BASE_OUTPUT_PATH in gemma 9B

Remove finetunning script for now as it OOMs on TPU v4

Updating gemma-9B to gemma2-9B

Fixing comment in gemma2-9b

Updating comments on attention kernel and attention_type.

* Removing unused agument PLATFORM from preflight.sh

* Revert "Add unique_indices to jnp.take"

* Consolidate usages of `MultiprocessingOptions` and `AsyncOptions`. Formalize `StandardCheckpointer` as an `AsyncCheckpointer` that doesn't require a `CheckpointArgs` object, and instead allows directly passing the state and extra args.

PiperOrigin-RevId: 660588942

* Remove unused checkpoint import

* Add A3 AoT compilation support using mock GPU client

Support XAOT on GPU

Update mock GPU client API and avoid initializing backend in pyconfig

Allow NoneType in SystemCharacteristics

Fix spacing

Use Optional instead of Union

Make topology_name optional and improve readme

Use dot_product attention in README example

* generates padding batch in hf pipeline to improve data utilization

* Add Gemma2 Support to MaxText: Gemma2 Decoder Layers, Checkpoint Converter, Config Files, Flop Calculation and Run Scripts

* Change scripts for llama2 7b a3 GPU to use llama2_7b_gpu.yml

PiperOrigin-RevId: 662216695

* Copybara import of the project:

--
80a49ff7a7239b09433d49c7501c4c558727bc30 by RissyRan <[email protected]>:

Add instruction for Mixtral

COPYBARA_INTEGRATE_REVIEW=#822 from google:mixtral_instruct 80a49ff7a7239b09433d49c7501c4c558727bc30
PiperOrigin-RevId: 662219295

* [MLPerf][GPT3] Bypass setting eval_interval in using synthetic dataset

* Removing the resgistration of the proxy backend used by Pathways.

If needed, this is handled by a separate, Pathways specific utilities package.

* Support AoT in 16-vm GPU Llama2 train script

* Initial grad accumulate with scan

Gradient accumulation

manually tested

globals are sus

Add grad accumulation test

Gradient accumulation config assert

Gradient accumulation config assert

pylint

fixed tests

global and micro batch size

Clean up with microbatches

Gradient accumulation

Gradient accumulation

Gradient accumulation

Gradient accumulation

Gradient accumulation

Gradient accumulation

Gradient accumulation

grad acc

grad acc

* Add support for local sliding window attention in TPU splash_attention

* add kl divergence and more logging for forward_pass_logit_checker

* Add dropping strategy

* add node attributes; fix GCS upload; add checkpointID

* add header to the checkpointing workload

* add rank to node attributes list

---------

Co-authored-by: Mitali Singh <[email protected]>
Co-authored-by: Shaurya Gupta <[email protected]>
Co-authored-by: maxtext authors <[email protected]>
Co-authored-by: RissyRan <[email protected]>
Co-authored-by: Xuefeng Gu <[email protected]>
Co-authored-by: Vivian Wu <[email protected]>
Co-authored-by: michelle-yooh <[email protected]>
Co-authored-by: Abhinav Singh <[email protected]>
Co-authored-by: Param Bole <[email protected]>
Co-authored-by: khatwanimohit <[email protected]>
Co-authored-by: aireenmei <[email protected]>
Co-authored-by: Dipannita Shaw <[email protected]>
Co-authored-by: ZhiyuLi-goog <[email protected]>
Co-authored-by: Raymond Zou <[email protected]>
Co-authored-by: hengtaoguo <[email protected]>
Co-authored-by: Gagik Amirkhanyan <[email protected]>
Co-authored-by: Colin Gaffney <[email protected]>
Co-authored-by: gobbleturk <[email protected]>
Co-authored-by: Jon Bolin <[email protected]>
Co-authored-by: Zhaoyue Cheng <[email protected]>
Co-authored-by: Luke Baumann <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants