From 344b46c34533288f9f8797228133b72d96f45f0d Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Fri, 2 Aug 2024 17:00:00 +0000 Subject: [PATCH] Initial grad accumulate with scan Gradient accumulation manually tested globals are sus Add grad accumulation test Gradient accumulation config assert Gradient accumulation config assert pylint fixed tests global and micro batch size Clean up with microbatches Gradient accumulation Gradient accumulation Gradient accumulation Gradient accumulation Gradient accumulation Gradient accumulation Gradient accumulation grad acc grad acc --- MaxText/configs/base.yml | 4 + MaxText/layers/pipeline.py | 20 ++--- MaxText/max_utils.py | 4 +- MaxText/maxtext_utils.py | 4 +- MaxText/pyconfig.py | 18 ++-- MaxText/tests/gradient_accumulation_test.py | 92 +++++++++++++++++++++ MaxText/train.py | 50 +++++++++-- 7 files changed, 166 insertions(+), 26 deletions(-) create mode 100644 MaxText/tests/gradient_accumulation_test.py diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index cb0b3f99d..b6ea09beb 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -315,6 +315,10 @@ init_weights_seed: 0 # You may disable clipping by setting gradient_clipping_threshold to zero. gradient_clipping_threshold: 1.0 +# Instead of updating the weights every step, you may effectively use a larger +# batch by accumulating the gradient over a set of steps. +gradient_accumulation_steps: 1 + # AdamW optimizer parameters # We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 opt_type: "adamw" # one of "adam_pax" or "adamw" diff --git a/MaxText/layers/pipeline.py b/MaxText/layers/pipeline.py index ba69812bf..4a1456aa3 100644 --- a/MaxText/layers/pipeline.py +++ b/MaxText/layers/pipeline.py @@ -49,7 +49,7 @@ class Pipeline(nn.Module): def setup(self): self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism self.use_circ_storage = self.config.num_pipeline_repeats > 1 and self.config.num_pipeline_microbatches > self.num_stages - self.microbatch_size = self.config.global_batch_size_to_train_on // self.config.num_pipeline_microbatches + self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages self.microbatches_per_stage = microbatches_per_stage @@ -73,7 +73,7 @@ def init_states(self, inputs): # state_io (state input output) at first holds all of the input batches, but also will hold the outputs as the pipeline runs/finishes # state_io has shape [num_stages, microbatches/stages, micro_size, sequence, embed] state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:]) - # We shard the microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. + # We shard the pipeline_microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. state_io = nn.with_logical_constraint(state_io, ("activation_stage", None, "activation_batch", "activation_length", "activation_embed"),rules=self.config.logical_axis_rules, mesh=self.mesh) # circ_storage is used to hold the final pipeline stage outputs before it is used for the next repeat. It is only needed @@ -353,18 +353,18 @@ def __call__(self, inputs: jnp.ndarray, segment_ids: jnp.ndarray, positions:jnp. Has the same signature of a single decoder layer, and expects the same shapes, e.g. the inputs should have shape [global_batch], and internally this will be reshapped into microbatches. ''' - # Reshape inputs of [global_batch, ...] to [microbatches, microbatch_sizes, ...] - inputs = inputs.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length, self.config.emb_dim)) + # Reshape inputs of [global_batch, ...] to [microbatches, pipeline_microbatch_sizes, ...] + inputs = inputs.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length, self.config.emb_dim)) example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages]) # dummy inputs fed to initialize the module weights. if positions is not None: - positions = positions.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length)) + positions = positions.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) example_position = jax.lax.broadcast(positions[0], [self.num_stages]) position_idx = 0 else: example_position = None position_idx = None if segment_ids is not None: - segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length)) + segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages]) segment_idx = 0 else: @@ -414,11 +414,11 @@ def __call__(self, inputs: jnp.ndarray, segment_ids: jnp.ndarray, positions:jnp. stage_outputs = stage_outputs[0] # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output which has - # shape [microbatch_size, sequence, embed] + # shape [pipeline_microbatch_size, sequence, embed] if self.config.num_pipeline_repeats > 1: stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap - broadcasted_stage_outpus = jax.lax.broadcast(stage_outputs[0], [self.config.global_batch_size_to_train_on // self.microbatch_size]) - return jnp.reshape(broadcasted_stage_outpus, [self.config.global_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim]) + broadcasted_stage_outpus = jax.lax.broadcast(stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]) + return jnp.reshape(broadcasted_stage_outpus, [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim]) def run_iteration_scannable(model,loop_state, xs): # flax transforms like nn.scan and nn.remat can only be applied to nn.module classes or nn.module instances, so we explicitly wrap @@ -468,6 +468,6 @@ def run_iteration_scannable(model,loop_state, xs): final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"]) # reshape outputs to match input shape of total batch instead of microbatches [batch, sequence, embed] - final_output = jnp.reshape(final_output, (self.config.global_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim)) + final_output = jnp.reshape(final_output, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim)) return final_output \ No newline at end of file diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 71673a475..0ab01cdf9 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -479,7 +479,7 @@ def init_initial_state(model, tx, config, is_training, key): Args: model, tx, config, is_training, key """ - input_shape = (config.global_batch_size_to_load, config.max_target_length) + input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) model_vars = model.init( {"params": key, "dropout": key, "aqt": key}, jnp.ones(input_shape, dtype=jnp.int32), @@ -803,7 +803,7 @@ def get_kv_cache_annotations(model, config, rng, mesh): def init_kv_cache(model, config): input_shape = ( - config.global_batch_size_to_load, + config.micro_batch_size_to_train_on, config.max_prefill_predict_length, ) diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index cd6c5e7bc..6df998be0 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -94,7 +94,7 @@ def get_train_input_output_trees(func, input_args, input_kwargs): def calculate_tokens_training_per_device(config): """Calculate training Tokens per device""" - return config.max_target_length * config.per_device_batch_size + return config.max_target_length * config.per_device_batch_size * config.gradient_accumulation_steps def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops): """ @@ -171,6 +171,8 @@ def calculate_tflops_training_per_device(config, log=True): ) ) + learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps + attention_tflops = attention_tflops * config.gradient_accumulation_steps total_tflops = learnable_weight_tflops + attention_tflops if log: diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 17f75dc77..7ff4442c1 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -327,7 +327,7 @@ def user_init(raw_keys): raw_keys["mlp_dim"] = 2**mlp_dim_scale * raw_keys["base_mlp_dim"] raw_keys["num_decoder_layers"] = 2**layer_scale * raw_keys["base_num_decoder_layers"] - raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"] = calculate_global_batch_sizes(raw_keys) + raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"], raw_keys["micro_batch_size_to_train_on"] = calculate_global_batch_sizes(raw_keys) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) @@ -342,7 +342,7 @@ def user_init(raw_keys): if raw_keys['num_pipeline_microbatches'] == -1: raw_keys['num_pipeline_microbatches'] = num_stages assert raw_keys['num_pipeline_microbatches'] % num_stages == 0, f"The number of microbatches ({raw_keys['num_pipeline_microbatches']}) must be divisible by the number of stages ({num_stages})" - assert raw_keys['global_batch_size_to_train_on'] % raw_keys['num_pipeline_microbatches'] == 0, f"The global batch size ({raw_keys['global_batch_size_to_train_on']}) must be divisible by the number of microbatches ({raw_keys['num_pipeline_microbatches']})" + assert raw_keys['micro_batch_size_to_train_on'] % raw_keys['num_pipeline_microbatches'] == 0, f"The batch size ({raw_keys['micro_batch_size_to_train_on']}) must be divisible by the number of microbatches ({raw_keys['num_pipeline_microbatches']})" else: raw_keys["using_pipeline_parallelism"] = False @@ -476,17 +476,19 @@ def calculate_global_batch_sizes(raw_keys): if per_device_batch_size < 1.0: # For per_device_batch_size<1, we load the data as if per_device_batch_size=1 if expansion_factor_real_data != -1: - global_batch_size_to_load = num_devices * expansion_factor_real_data + micro_batch_size_to_load = num_devices * expansion_factor_real_data else: - global_batch_size_to_load = num_devices + micro_batch_size_to_load = num_devices else: if expansion_factor_real_data != -1: - global_batch_size_to_load = int(num_devices * per_device_batch_size * expansion_factor_real_data) + micro_batch_size_to_load = int(num_devices * per_device_batch_size * expansion_factor_real_data) else: - global_batch_size_to_load = int(num_devices * per_device_batch_size) + micro_batch_size_to_load = int(num_devices * per_device_batch_size) - global_batch_size_to_train_on = int(num_devices * per_device_batch_size) - return global_batch_size_to_load, global_batch_size_to_train_on + micro_batch_size_to_train_on = int(num_devices * per_device_batch_size) + global_batch_size_to_load = int(micro_batch_size_to_load * raw_keys["gradient_accumulation_steps"]) + global_batch_size_to_train_on = int(micro_batch_size_to_train_on * raw_keys["gradient_accumulation_steps"]) + return global_batch_size_to_load, global_batch_size_to_train_on, micro_batch_size_to_train_on def get_num_target_devices(raw_keys): diff --git a/MaxText/tests/gradient_accumulation_test.py b/MaxText/tests/gradient_accumulation_test.py new file mode 100644 index 000000000..009bc1110 --- /dev/null +++ b/MaxText/tests/gradient_accumulation_test.py @@ -0,0 +1,92 @@ +""" +Copyright 2024 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# pylint: disable=missing-module-docstring, missing-function-docstring +import numpy as np +import json +import unittest +import pytest +import string +import random +from train import main as train_main + +def generate_random_string(length=10): + characters = string.ascii_letters # Include letters, digits, and punctuation + return ''.join(random.choice(characters) for _ in range(length)) + +class GradientAccumulationTest(unittest.TestCase): + + + @pytest.mark.tpu + def test_grad_accumulate_same_loss(self): + random_suffix = generate_random_string() + run_accumulate_metrics_file = f"/tmp/runner_grad_accumulate_{random_suffix}.txt" + print(f"{run_accumulate_metrics_file=}") + run_regular_metrics_file = f"/tmp/runner_regular_{random_suffix}.txt" + print(f"{run_regular_metrics_file=}") + shared_maxtext_args = [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + r"dataset_path=gs://maxtext-dataset", + "gradient_clipping_threshold=0", # Ensures we are testing raw scales of gradients (clipping off) + "enable_checkpointing=False", + "base_emb_dim=256", + "base_num_decoder_layers=4", + "tokenizer_path=../assets/tokenizer.llama2", + "steps=50", + ] + # Run with gradient accumulation with accumulate_steps=10, per_device_batch=1 --> simulating per_device_batch=10 + train_main(shared_maxtext_args + [ + "run_name=runner_grad_accumulate", + f"metrics_file={run_accumulate_metrics_file}", + "per_device_batch_size=1", + "gradient_accumulation_steps=10", + ]) + + #Run without gradient accumulation with per_device_batch=10 + train_main(shared_maxtext_args + [ + "run_name=runner_grad_accumulate_regular", + f"metrics_file={run_regular_metrics_file}", + "per_device_batch_size=10", + "gradient_accumulation_steps=1", + ]) + + # Assert losses roughly equal + with open(run_accumulate_metrics_file, 'r', encoding='utf8') as accum_run,\ + open(run_regular_metrics_file, 'r', encoding='utf8') as regular_run: + accum_run_loss = json.loads(accum_run.readlines()[-1])["learning/loss"] + regular_run_loss = json.loads(regular_run.readlines()[-1])["learning/loss"] + print(f"[Gradient Accumulation Test] Loss with gradient accumulation: {accum_run_loss}", flush=True) + print(f"[Gradient Accumulation Test] Loss without gradient accumulation: {regular_run_loss}", flush=True) + # Not identical due to an epsilon addition in loss denominator. + np.testing.assert_allclose(accum_run_loss, regular_run_loss, rtol=0.01) + + # Assert grad norms roughly equal + with open(run_accumulate_metrics_file, 'r', encoding='utf8') as accum_run,\ + open(run_regular_metrics_file, 'r', encoding='utf8') as regular_run: + accum_run_grad_norm= json.loads(accum_run.readlines()[-1])["learning/raw_grad_norm"] + regular_run_grad_norm = json.loads(regular_run.readlines()[-1])["learning/raw_grad_norm"] + print(f"[Gradient Accumulation Test] Grad norm with gradient accumulation: {accum_run_grad_norm}", flush=True) + print(f"[Gradient Accumulation Test] Grad norm without gradient accumulation: {regular_run_grad_norm}", flush=True) + # Not identical due to an epsilon addition in loss denominator. + np.testing.assert_allclose(accum_run_grad_norm, regular_run_grad_norm, rtol=0.01) + + # Assert per device tflops are the same (10x smaller microbatch size, but 10x more microbatches) + with open(run_accumulate_metrics_file, 'r', encoding='utf8') as accum_run,\ + open(run_regular_metrics_file, 'r', encoding='utf8') as regular_run: + accum_device_tflops = json.loads(accum_run.readlines()[-1])["perf/per_device_tflops"] + regular_device_tflops = json.loads(regular_run.readlines()[-1])["perf/per_device_tflops"] + print(f"[Gradient Accumulation Test] per_device_tflops with gradient accumulation: {accum_device_tflops}", flush=True) + print(f"[Gradient Accumulation Test] per_device_tflops without gradient accumulation: {regular_device_tflops}", flush=True) + np.testing.assert_equal(accum_device_tflops, regular_device_tflops) diff --git a/MaxText/train.py b/MaxText/train.py index cc045e3f1..e85d580ad 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -77,6 +77,9 @@ def validate_train_config(config): if not config.base_output_directory.startswith("gs://"): max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system") assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer." + if config.quantization=='fp8': + # pylint: disable=line-too-long + assert config.gradient_accumulation_steps == 1, "fp8 can't be used with gradient_accumulation_steps right now. Please use other quantization or set gradient_accumulation_steps to 1" def get_first_step(state): @@ -156,6 +159,11 @@ def write_metrics_to_tensorboard(writer, metrics, step, config): max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'") writer.flush() +def clear_buffered_metrics(): + global _buffered_step + global _buffered_metrics + _buffered_step = None + _buffered_metrics = None def save_checkpoint(checkpoint_manager, step, state, dataset_type="c4", data_iterator=None): """Wrapper for saving checkpoint""" @@ -224,7 +232,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): # decimate proportion of data when per_device_batch_size<1 if is_train: for k, v in data.items(): - data[k] = v[: config.global_batch_size_to_train_on, :] + data[k] = v[: config.micro_batch_size_to_train_on, :] logits, intermediate_outputs = model.apply( params, @@ -266,9 +274,39 @@ def train_step(model, config, state, data, dropout_rng): rng2: A new rng key that can be used in future calls. """ - train_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=True) - grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True) - (loss, aux), raw_grads = grad_fn(state.params) + if config.gradient_accumulation_steps > 1: + def accumulate_gradient(acc_grad_and_loss, data): + grad_func = jax.value_and_grad(loss_fn, argnums=4, has_aux=True) + (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, state.params, is_train=True) + acc_grad_and_loss['loss'] += aux['total_loss'] + acc_grad_and_loss['grad'] = jax.tree_util.tree_map( + lambda x, y: x * aux['total_weights'] + y, + cur_batch_gradient, + acc_grad_and_loss['grad']) + acc_grad_and_loss['total_weights'] += aux['total_weights'] + return acc_grad_and_loss, aux + + def reshape_to_microbatch_accumulations(batch_arr): + ''' Reshape global batch to microbatches, assuming batch axis is leading.''' + microbatches = config.gradient_accumulation_steps + microbatch_shape = (microbatches, batch_arr.shape[0] // microbatches) + batch_arr.shape[1:] + return jnp.reshape(batch_arr, microbatch_shape) + + data = jax.tree_util.tree_map(reshape_to_microbatch_accumulations, data) + init_grad = jax.tree_util.tree_map(jnp.zeros_like, state.params) + init_grad_and_loss = {'loss': 0.0, 'grad': init_grad, 'total_weights':0} + + grad_and_loss, aux = jax.lax.scan( + accumulate_gradient, + init_grad_and_loss, + data, + length = config.gradient_accumulation_steps) + loss = grad_and_loss['loss'] / grad_and_loss['total_weights'] + raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss['total_weights'], grad_and_loss['grad']) + aux = jax.tree_map(lambda x: jnp.sum(x, axis=0), aux) + else: + grad_func = jax.value_and_grad(loss_fn, argnums=4, has_aux=True) + (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, state.params, is_train=True) intermediate_outputs = aux["intermediate_outputs"] total_weights = aux["total_weights"] @@ -589,13 +627,15 @@ def train_loop(config, state=None): write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics max_utils.close_summary_writer(writer) record_goodput(recorder, config, job_end=True) + clear_buffered_metrics() return state def main(argv: Sequence[str]) -> None: jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): + os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" pyconfig.initialize(argv) max_utils.print_system_information() config = pyconfig.config