Merge Main into mlperf/4.1 branch
anfals committed Aug 20, 2024
2 parents 90d7304 + 14379df commit 8cbea66
Showing 111 changed files with 4,781 additions and 775 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/UploadDockerImages.yml
@@ -37,6 +37,9 @@ jobs:
- name: build jax nightly image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_nightly MODE=nightly DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_nightly
- name: build jax stable stack image
run : |
bash docker_maxtext_jax_stable_stack_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxtext-jax-stable-stack IMAGE_TAG=jax0.4.30-rev1 MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt DELETE_LOCAL_IMAGE=true
gpu:
strategy:
fail-fast: false
239 changes: 123 additions & 116 deletions MaxText/accelerator_to_spec_map.py

Large diffs are not rendered by default.

24 changes: 5 additions & 19 deletions MaxText/checkpointing.py
@@ -21,7 +21,7 @@
from etils import epath
import orbax.checkpoint
from orbax.checkpoint.logging import abstract_logger, cloud_logger, standard_logger, composite_logger
from orbax.checkpoint import pytree_checkpoint_handler, type_handlers
from orbax.checkpoint import pytree_checkpoint_handler
from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions, PyTree
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import jax
@@ -80,25 +80,12 @@ def create_orbax_emergency_checkpoint_manager(
abstract_state: PyTree,
local_save_interval_steps: int,
persistent_save_interval_steps: int,
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
):
"""Returns an emergency checkpoint."""
flags.FLAGS.experimental_orbax_use_distributed_process_id = True
max_logging.log("Creating emergency checkpoint manager...")

local_registry = type_handlers.create_type_handler_registry(
(
jax.Array,
type_handlers.ArrayHandler(primary_host=None, replica_id=None),
),
)

local_checkpoint_handler = PyTreeCheckpointHandler(
use_ocdbt=True,
use_zarr3=True,
primary_host=None,
type_handler_registry=local_registry,
)

options = emergency_checkpoint_manager.CheckpointManagerOptions(
local=LocalCheckpointOptions(
save_interval_steps=local_save_interval_steps
@@ -107,14 +94,14 @@ def create_orbax_emergency_checkpoint_manager(
save_interval_steps=persistent_save_interval_steps
),
)

emergency_mngr = emergency_checkpoint_manager.CheckpointManager(
local_checkpoint_dir,
epath.Path(persistent_checkpoint_dir),
global_mesh=global_mesh,
abstract_state=abstract_state,
options=options,
local_state_handler=local_checkpoint_handler,
local_state_handler=emergency_checkpoint_manager.local_checkpoint_handler(),
logger=orbax_logger,
)

max_logging.log("Emergency checkpoint manager created!")
@@ -261,7 +248,7 @@ def map_to_pspec(data):
max_logging.log(f"restoring full state from {load_full_state_from_path=}")
p = epath.Path(load_full_state_from_path)
ckptr = orbax.checkpoint.StandardCheckpointer()
restored = ckptr.restore(p, args=orbax.checkpoint.args.StandardRestore(abstract_unboxed_pre_state))
restored = ckptr.restore(p, abstract_unboxed_pre_state)
return {"items": restored}, None

else:
@@ -333,4 +320,3 @@ def save_params_to_path(checkpoint_dir, params):
force=True
)
print(f"Quantized params checkpoint saved at: {checkpoint_dir}")

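Note on the checkpointing change above: the hand-built local PyTreeCheckpointHandler and type-handler registry are removed, and the local state handler now comes from emergency_checkpoint_manager.local_checkpoint_handler(). Below is a minimal, hedged sketch of calling the helper; the paths, mesh, and abstract state are placeholders, the keyword names of the first three arguments are inferred from the function body rather than read from the full signature, and a real run would be multi-host.

# Hedged sketch, not MaxText's actual call site.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh

import checkpointing  # MaxText/checkpointing.py, assuming MaxText/ is on sys.path

# Placeholder mesh over all local devices along a single "data" axis.
mesh = Mesh(np.array(jax.devices()), ("data",))

# Placeholder abstract train state: shapes/dtypes only, no real arrays.
abstract_state = jax.eval_shape(lambda: {"params": jnp.zeros((1024, 1024), jnp.float32)})

manager = checkpointing.create_orbax_emergency_checkpoint_manager(
    local_checkpoint_dir="/tmp/local_ckpt",           # fast node-local storage
    persistent_checkpoint_dir="gs://my-bucket/ckpt",  # durable storage, e.g. GCS
    global_mesh=mesh,
    abstract_state=abstract_state,
    local_save_interval_steps=20,
    persistent_save_interval_steps=100,
)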
1 change: 1 addition & 0 deletions MaxText/common_types.py
@@ -35,6 +35,7 @@

BATCH = "activation_batch"
LENGTH = "activation_length"
EMBED = "activation_embed"
HEAD = "activation_heads"
KV_BATCH = "activation_kv_batch"
KV_HEAD = "activation_kv_heads"
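Note on the new EMBED ("activation_embed") logical axis above: activation sharding in MaxText is expressed through logical axis names that are later mapped to physical mesh axes via sharding rules. The sketch below illustrates that pattern and is not code from this commit; the rule mapping and tensor shapes are assumptions.

# Hedged illustration of consuming a logical axis name with Flax logical partitioning.
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ("data", "model"))
rules = (
    ("activation_batch", "data"),
    ("activation_length", None),
    ("activation_embed", None),  # the logical axis name added above
)

x = jnp.zeros((8, 128, 512))  # [batch, length, embed]
with mesh, nn.logical_axis_rules(rules):
    # Annotate the activation with logical names; Flax maps them to mesh axes.
    x = nn.with_logical_constraint(
        x, ("activation_batch", "activation_length", "activation_embed")
    )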
7 changes: 3 additions & 4 deletions MaxText/configs/a3/llama_2_7b/16vm.sh
@@ -10,6 +10,7 @@ set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-16vm-$(date +%Y-%m-%d-%H-%M)"
export EXECUTABLE="train.py"

# Set environment variables
for ARGUMENT in "$@"; do
@@ -29,7 +30,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_disable_hlo_passes=rematerialization"

# 16 nodes
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=16 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs profiler=xplane
python MaxText/$EXECUTABLE MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
dcn_data_parallelism=16 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
6 changes: 2 additions & 4 deletions MaxText/configs/a3/llama_2_7b/1vm.sh
@@ -30,7 +30,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/


# 1 node, DATA_DP=1, ICI_FSDP=8
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu\
steps=30 dcn_data_parallelism=1 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=$OUTPUT_PATH
python MaxText/train.py MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
dcn_data_parallelism=1 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
7 changes: 2 additions & 5 deletions MaxText/configs/a3/llama_2_7b/2vm.sh
@@ -31,8 +31,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/


# 2 nodes
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=2 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs profiler=xplane

python MaxText/train.py MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
dcn_data_parallelism=2 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
7 changes: 2 additions & 5 deletions MaxText/configs/a3/llama_2_7b/4vm.sh
@@ -28,8 +28,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_disable_hlo_passes=rematerialization"

# 4 nodes
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=4 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs profiler=xplane

python MaxText/train.py MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
dcn_data_parallelism=4 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
7 changes: 2 additions & 5 deletions MaxText/configs/a3/llama_2_7b/8vm.sh
@@ -29,8 +29,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_disable_hlo_passes=rematerialization"

# 8 nodes
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=8 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs profiler=xplane

python MaxText/train.py MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
dcn_data_parallelism=8 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
52 changes: 46 additions & 6 deletions MaxText/configs/base.yml
@@ -59,8 +59,23 @@ save_config_to_gcs: False

# Activation dtypes.
dtype: "bfloat16"
quantization: "" #defaults to no quantization, i.e. bf16. possible alternative setting is 'int8' or use fp8 to run with 8-bit floating-point GeMMs on NVIDIA GPUs.
quantize_kvcache: False
# Used to configure quantization in the transformer layers, defaults to null implying bf16.
# Possible alternative settings are as follows:
# 'int8' for dynamic range quantization using 8-bits
# 'int8w' for weights only quantization using 8-bits
# 'int4w' for weights only quantization using 4-bits
# 'intmp' for mixed precision weight only quantization based on config file
# 'fp8' for 8-bit floating-point GeMMs on NVIDIA GPUs.
quantization: ""
# Path to file with quantization config - only used for mixed precision.
# Example configs in ../Maxtext/configs/quantization
# Allows us to configure different bits, tiling and scale for quantizing selected weights.
# Bits represents the number of bits to quantize to,
# tile-size represents the tiling size used in AQT tiled_dot_general,
# and scale is used to scale the abs_max value used for AQT quantization.
# Default values are 8 bits, tile-size=-1 (no tiling) and scale=1.
quant_cfg_path: ""
quantize_kvcache: False # Set to True to quantize KV Cache values, defaults to False
# Valid kv_quant_axis values:
# - "" is valid only when quantize_kvcache is False
# - "dkv" indicates quantize kv cache over the cache_kv, i.e. kv dimension axis
@@ -114,12 +129,22 @@ num_pipeline_repeats: -1
num_pipeline_microbatches: -1
scan_pipeline_iterations: True # This can be set independently of scan_layers, which is relevant when num_layers_per_pipeline_stage > 1.

# Choose 'remat_policy' between 'minimal', 'save_dot_except_mlpwi', 'save_dot_except_mlp', 'save_qkv_proj', 'qkv_proj_offloaded', 'minimal_offloaded' and 'full'.
# Choose 'remat_policy' between 'minimal', 'save_dot_except_mlpwi', 'save_dot_except_mlp', 'save_qkv_proj', 'qkv_proj_offloaded', 'minimal_offloaded', 'save_out_proj' and 'full'.
# These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest)
remat_policy: 'full'
scan_layers: True
param_scan_axis: 1

# The attention parameter dictates the specific algorithm/methodology used to compute the attention scores
# The attention_type parameter determines the variants of attention, e.g. global or local_sliding
attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te
attention_type: 'global' # Supported attention_type: global, local_sliding
sliding_window_size: 0
attn_logits_soft_cap: 0.0
final_logits_soft_cap: 0.0
use_post_attn_norm: False
use_post_ffw_norm: False


# Combine matmuls for QKV and MLP
fused_qkv: False
@@ -210,14 +235,21 @@ ici_pipeline_parallelism: 1
# you should set compile_topology_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
num_slices: -1

# Tokenizer and Dataset
# Replace with your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
# Tokenizer
vocab_size: 32_000 # powers of 2 for sharding
tokenizer_path: "assets/tokenizer.llama2"
tokenize_train_data: True # False if the dataset is pre-tokenized
tokenize_eval_data: True # False if the dataset is pre-tokenized
add_bos: True
add_eos: True

# Dataset
per_device_batch_size: 12.0
expansion_factor_real_data: -1 # if -1 then all hosts will load real data, else total_hosts//expansion_factor_real_data will pull data from GCS.
eval_per_device_batch_size: 0
max_corpus_chars: 10_000_000
train_data_column: 'text'
eval_data_column: 'text'
# dataset_type must be synthetic, hf, grain, tfds
# details in: https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md
dataset_type: tfds
@@ -283,6 +315,10 @@ init_weights_seed: 0
# You may disable clipping by setting gradient_clipping_threshold to zero.
gradient_clipping_threshold: 1.0

# Instead of updating the weights every step, you may effectively use a larger
# batch by accumulating the gradient over a set of steps.
gradient_accumulation_steps: 1

# AdamW optimizer parameters
# We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
opt_type: "adamw" # one of "adam_pax" or "adamw"
@@ -318,11 +354,13 @@ decode_sampling_top_k: 0 # set if you're doing top-k
decode_sampling_temperature: 1.

eval_interval: -1 # the number of train steps between eval runs
eval_batch_num: -1 # only run this number of batches for eval, for debugging use
eval_steps: -1 # only run this number of batches for eval, for debugging use
target_eval_loss: 0. # early stop once reaching target eval_loss

# Goodput parameters
enable_goodput_recording: False
monitor_goodput: False
goodput_upload_interval_seconds: 60

# Vertex AI Tensorboard Configurations - https://github.com/google/maxtext/tree/main/getting_started/Use_Vertex_AI_Tensorboard.md
# Set to True for GCE, False if running via XPK
@@ -343,6 +381,8 @@ inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
inference_microbenchmark_log_file_path: ""
inference_metadata_file: "" # path to a json file
enable_model_warmup: False


# KV Cache layout control
# Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
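Note on the new gradient_accumulation_steps option above: instead of applying the optimizer every step, gradients from several consecutive micro-batches are summed, so the effective batch per weight update grows without raising per-step memory. A small worked example (the device count and batch values are illustrative, not config defaults):

# Illustrative arithmetic only; values below are made up.
per_device_batch_size = 4
num_devices = 256                    # e.g. one v5e-256 slice (assumption)
gradient_accumulation_steps = 8      # new base.yml option, default 1

sequences_per_forward_pass = per_device_batch_size * num_devices
effective_batch_per_update = sequences_per_forward_pass * gradient_accumulation_steps

print(sequences_per_forward_pass)   # 1024 sequences per forward/backward pass
print(effective_batch_per_update)   # 8192 sequences per optimizer update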
2 changes: 2 additions & 0 deletions MaxText/configs/inference_jetstream.yml
@@ -2,3 +2,5 @@ base_config: "base.yml"

enable_jax_profiler: False
jax_profiler_port: 9999

enable_model_warmup: False
55 changes: 0 additions & 55 deletions MaxText/configs/llama2_70b_gpu.yml

This file was deleted.

57 changes: 0 additions & 57 deletions MaxText/configs/llama2_7b_gpu.yml

This file was deleted.

(Diffs for the remaining changed files are not shown.)