Initial grad accumulate with scan
Gradient accumulation

manually tested

globals are sus

Add grad accumulation test

Gradient accumulation config assert

pylint

fixed tests

global and micro batch size

Clean up with microbatches

gobbleturk committed Aug 16, 2024
1 parent 0a09954 commit 344b46c
Showing 7 changed files with 166 additions and 26 deletions.
4 changes: 4 additions & 0 deletions MaxText/configs/base.yml
@@ -315,6 +315,10 @@ init_weights_seed: 0
# You may disable clipping by setting gradient_clipping_threshold to zero.
gradient_clipping_threshold: 1.0

# Instead of updating the weights every step, you may use an effectively larger
# batch by accumulating the gradient over a set of steps.
gradient_accumulation_steps: 1

# AdamW optimizer parameters
# We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
opt_type: "adamw" # one of "adam_pax" or "adamw"
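With the new key in place, the effective optimizer-step batch scales linearly with gradient_accumulation_steps. A minimal sketch of the arithmetic, with illustrative numbers (not values from this diff):

# Illustrative only: how gradient accumulation changes the effective batch.
per_device_batch_size = 4
num_devices = 8
gradient_accumulation_steps = 4  # the base.yml default of 1 disables accumulation

micro_batch = per_device_batch_size * num_devices            # per forward/backward pass
effective_batch = micro_batch * gradient_accumulation_steps  # per optimizer update
print(micro_batch, effective_batch)  # 32 128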
20 changes: 10 additions & 10 deletions MaxText/layers/pipeline.py
@@ -49,7 +49,7 @@ class Pipeline(nn.Module):
def setup(self):
self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism
self.use_circ_storage = self.config.num_pipeline_repeats > 1 and self.config.num_pipeline_microbatches > self.num_stages
self.microbatch_size = self.config.global_batch_size_to_train_on // self.config.num_pipeline_microbatches
self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches
microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages
self.microbatches_per_stage = microbatches_per_stage

@@ -73,7 +73,7 @@ def init_states(self, inputs):
# state_io (state input output) at first holds all of the input batches, but also will hold the outputs as the pipeline runs/finishes
# state_io has shape [num_stages, microbatches/stages, micro_size, sequence, embed]
state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:])
# We shard the microbatch_size axis by data/fsdp, not num_microbatches since those are looped over.
# We shard the pipeline_microbatch_size axis by data/fsdp, not num_microbatches since those are looped over.
state_io = nn.with_logical_constraint(state_io, ("activation_stage", None, "activation_batch", "activation_length", "activation_embed"),rules=self.config.logical_axis_rules, mesh=self.mesh)

# circ_storage is used to hold the final pipeline stage outputs before it is used for the next repeat. It is only needed
@@ -353,18 +353,18 @@ def __call__(self, inputs: jnp.ndarray, segment_ids: jnp.ndarray, positions:jnp.
Has the same signature as a single decoder layer and expects the same shapes, e.g. the inputs should have shape [global_batch]; internally
they are reshaped into microbatches.
'''
# Reshape inputs of [global_batch, ...] to [microbatches, microbatch_sizes, ...]
inputs = inputs.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length, self.config.emb_dim))
# Reshape inputs of [global_batch, ...] to [microbatches, pipeline_microbatch_sizes, ...]
inputs = inputs.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length, self.config.emb_dim))
example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages]) # dummy inputs fed to initialize the module weights.
if positions is not None:
positions = positions.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length))
positions = positions.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length))
example_position = jax.lax.broadcast(positions[0], [self.num_stages])
position_idx = 0
else:
example_position = None
position_idx = None
if segment_ids is not None:
segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.microbatch_size, self.config.max_target_length))
segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length))
example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages])
segment_idx = 0
else:
@@ -414,11 +414,11 @@ def __call__(self, inputs: jnp.ndarray, segment_ids: jnp.ndarray, positions:jnp.
stage_outputs = stage_outputs[0]

# We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output which has
# shape [microbatch_size, sequence, embed]
# shape [pipeline_microbatch_size, sequence, embed]
if self.config.num_pipeline_repeats > 1:
stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap
broadcasted_stage_outpus = jax.lax.broadcast(stage_outputs[0], [self.config.global_batch_size_to_train_on // self.microbatch_size])
return jnp.reshape(broadcasted_stage_outpus, [self.config.global_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim])
broadcasted_stage_outpus = jax.lax.broadcast(stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size])
return jnp.reshape(broadcasted_stage_outpus, [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim])

def run_iteration_scannable(model,loop_state, xs):
# flax transforms like nn.scan and nn.remat can only be applied to nn.module classes or nn.module instances, so we explicitly wrap
@@ -468,6 +468,6 @@ def run_iteration_scannable(model,loop_state, xs):
final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"])

# reshape outputs to match input shape of total batch instead of microbatches [batch, sequence, embed]
final_output = jnp.reshape(final_output, (self.config.global_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim))
final_output = jnp.reshape(final_output, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim))

return final_output
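The pipeline thus splits each gradient-accumulation microbatch further into pipeline microbatches. A standalone sketch of the reshape above, with made-up sizes (not values from any config in this diff):

# Sketch of the pipeline reshape, with hypothetical sizes.
import jax.numpy as jnp

micro_batch_size_to_train_on = 8  # batch entering the pipeline per accumulation step
num_pipeline_microbatches = 4
max_target_length, emb_dim = 16, 32

pipeline_microbatch_size = micro_batch_size_to_train_on // num_pipeline_microbatches
inputs = jnp.ones((micro_batch_size_to_train_on, max_target_length, emb_dim))
inputs = inputs.reshape((num_pipeline_microbatches, pipeline_microbatch_size,
                         max_target_length, emb_dim))
print(inputs.shape)  # (4, 2, 16, 32)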
4 changes: 2 additions & 2 deletions MaxText/max_utils.py
@@ -479,7 +479,7 @@ def init_initial_state(model, tx, config, is_training, key)
Args: model, tx, config, is_training, key
"""
input_shape = (config.global_batch_size_to_load, config.max_target_length)
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
model_vars = model.init(
{"params": key, "dropout": key, "aqt": key},
jnp.ones(input_shape, dtype=jnp.int32),
@@ -803,7 +803,7 @@ def get_kv_cache_annotations(model, config, rng, mesh):

def init_kv_cache(model, config):
input_shape = (
config.global_batch_size_to_load,
config.micro_batch_size_to_train_on,
config.max_prefill_predict_length,
)

4 changes: 3 additions & 1 deletion MaxText/maxtext_utils.py
@@ -94,7 +94,7 @@ def get_train_input_output_trees(func, input_args, input_kwargs):

def calculate_tokens_training_per_device(config):
"""Calculate training Tokens per device"""
return config.max_target_length * config.per_device_batch_size
return config.max_target_length * config.per_device_batch_size * config.gradient_accumulation_steps

def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops):
"""
@@ -171,6 +171,8 @@ def calculate_tflops_training_per_device(config, log=True):
)
)

learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps
attention_tflops = attention_tflops * config.gradient_accumulation_steps
total_tflops = learnable_weight_tflops + attention_tflops

if log:
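Scaling by gradient_accumulation_steps keeps the per-device work accounted per optimizer step, since each step now runs that many forward/backward passes. A toy check of the token arithmetic (illustrative numbers; max_target_length is an assumed value, not taken from this diff):

# Illustrative: per-device tokens per optimizer step under accumulation.
max_target_length = 2048  # assumed for illustration
per_device_batch_size = 1
gradient_accumulation_steps = 10

tokens_per_device_step = (max_target_length * per_device_batch_size
                          * gradient_accumulation_steps)
print(tokens_per_device_step)  # 20480, same as per_device_batch_size=10 without accumulation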
18 changes: 10 additions & 8 deletions MaxText/pyconfig.py
@@ -327,7 +327,7 @@ def user_init(raw_keys):
raw_keys["mlp_dim"] = 2**mlp_dim_scale * raw_keys["base_mlp_dim"]
raw_keys["num_decoder_layers"] = 2**layer_scale * raw_keys["base_num_decoder_layers"]

raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"] = calculate_global_batch_sizes(raw_keys)
raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"], raw_keys["micro_batch_size_to_train_on"] = calculate_global_batch_sizes(raw_keys)
raw_keys["num_slices"] = get_num_slices(raw_keys)
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)

if raw_keys['num_pipeline_microbatches'] == -1:
raw_keys['num_pipeline_microbatches'] = num_stages
assert raw_keys['num_pipeline_microbatches'] % num_stages == 0, f"The number of microbatches ({raw_keys['num_pipeline_microbatches']}) must be divisible by the number of stages ({num_stages})"
assert raw_keys['global_batch_size_to_train_on'] % raw_keys['num_pipeline_microbatches'] == 0, f"The global batch size ({raw_keys['global_batch_size_to_train_on']}) must be divisible by the number of microbatches ({raw_keys['num_pipeline_microbatches']})"
assert raw_keys['micro_batch_size_to_train_on'] % raw_keys['num_pipeline_microbatches'] == 0, f"The batch size ({raw_keys['micro_batch_size_to_train_on']}) must be divisible by the number of microbatches ({raw_keys['num_pipeline_microbatches']})"
else:
raw_keys["using_pipeline_parallelism"] = False

@@ -476,17 +476,19 @@ def calculate_global_batch_sizes(raw_keys):
if per_device_batch_size < 1.0:
# For per_device_batch_size<1, we load the data as if per_device_batch_size=1
if expansion_factor_real_data != -1:
global_batch_size_to_load = num_devices * expansion_factor_real_data
micro_batch_size_to_load = num_devices * expansion_factor_real_data
else:
global_batch_size_to_load = num_devices
micro_batch_size_to_load = num_devices
else:
if expansion_factor_real_data != -1:
global_batch_size_to_load = int(num_devices * per_device_batch_size * expansion_factor_real_data)
micro_batch_size_to_load = int(num_devices * per_device_batch_size * expansion_factor_real_data)
else:
global_batch_size_to_load = int(num_devices * per_device_batch_size)
micro_batch_size_to_load = int(num_devices * per_device_batch_size)

global_batch_size_to_train_on = int(num_devices * per_device_batch_size)
return global_batch_size_to_load, global_batch_size_to_train_on
micro_batch_size_to_train_on = int(num_devices * per_device_batch_size)
global_batch_size_to_load = int(micro_batch_size_to_load * raw_keys["gradient_accumulation_steps"])
global_batch_size_to_train_on = int(micro_batch_size_to_train_on * raw_keys["gradient_accumulation_steps"])
return global_batch_size_to_load, global_batch_size_to_train_on, micro_batch_size_to_train_on


def get_num_target_devices(raw_keys):
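A worked example of the three sizes returned above, mirroring the arithmetic rather than calling pyconfig, and assuming 8 devices, per_device_batch_size=1, gradient_accumulation_steps=10, and expansion_factor_real_data unset (-1) — all hypothetical values:

# Hypothetical standalone mirror of calculate_global_batch_sizes' arithmetic.
num_devices = 8
per_device_batch_size = 1.0
gradient_accumulation_steps = 10

# For per_device_batch_size < 1, data is loaded as if it were 1.
micro_batch_size_to_load = int(num_devices * max(per_device_batch_size, 1.0))
micro_batch_size_to_train_on = int(num_devices * per_device_batch_size)
global_batch_size_to_load = micro_batch_size_to_load * gradient_accumulation_steps
global_batch_size_to_train_on = micro_batch_size_to_train_on * gradient_accumulation_steps
print(global_batch_size_to_load, global_batch_size_to_train_on, micro_batch_size_to_train_on)
# 80 80 8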
92 changes: 92 additions & 0 deletions MaxText/tests/gradient_accumulation_test.py
@@ -0,0 +1,92 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# pylint: disable=missing-module-docstring, missing-function-docstring
import numpy as np
import json
import unittest
import pytest
import string
import random
from train import main as train_main

def generate_random_string(length=10):
characters = string.ascii_letters  # ASCII letters only
return ''.join(random.choice(characters) for _ in range(length))

class GradientAccumulationTest(unittest.TestCase):


@pytest.mark.tpu
def test_grad_accumulate_same_loss(self):
random_suffix = generate_random_string()
run_accumulate_metrics_file = f"/tmp/runner_grad_accumulate_{random_suffix}.txt"
print(f"{run_accumulate_metrics_file=}")
run_regular_metrics_file = f"/tmp/runner_regular_{random_suffix}.txt"
print(f"{run_regular_metrics_file=}")
shared_maxtext_args = [
None,
"configs/base.yml",
r"base_output_directory=gs://runner-maxtext-logs",
r"dataset_path=gs://maxtext-dataset",
"gradient_clipping_threshold=0", # Ensures we are testing raw scales of gradients (clipping off)
"enable_checkpointing=False",
"base_emb_dim=256",
"base_num_decoder_layers=4",
"tokenizer_path=../assets/tokenizer.llama2",
"steps=50",
]
# Run with gradient accumulation with accumulate_steps=10, per_device_batch=1 --> simulating per_device_batch=10
train_main(shared_maxtext_args + [
"run_name=runner_grad_accumulate",
f"metrics_file={run_accumulate_metrics_file}",
"per_device_batch_size=1",
"gradient_accumulation_steps=10",
])

# Run without gradient accumulation with per_device_batch=10
train_main(shared_maxtext_args + [
"run_name=runner_grad_accumulate_regular",
f"metrics_file={run_regular_metrics_file}",
"per_device_batch_size=10",
"gradient_accumulation_steps=1",
])

# Assert losses roughly equal
with open(run_accumulate_metrics_file, 'r', encoding='utf8') as accum_run,\
open(run_regular_metrics_file, 'r', encoding='utf8') as regular_run:
accum_run_loss = json.loads(accum_run.readlines()[-1])["learning/loss"]
regular_run_loss = json.loads(regular_run.readlines()[-1])["learning/loss"]
print(f"[Gradient Accumulation Test] Loss with gradient accumulation: {accum_run_loss}", flush=True)
print(f"[Gradient Accumulation Test] Loss without gradient accumulation: {regular_run_loss}", flush=True)
# Not identical due to an epsilon addition in loss denominator.
np.testing.assert_allclose(accum_run_loss, regular_run_loss, rtol=0.01)

# Assert grad norms roughly equal
with open(run_accumulate_metrics_file, 'r', encoding='utf8') as accum_run,\
open(run_regular_metrics_file, 'r', encoding='utf8') as regular_run:
accum_run_grad_norm = json.loads(accum_run.readlines()[-1])["learning/raw_grad_norm"]
regular_run_grad_norm = json.loads(regular_run.readlines()[-1])["learning/raw_grad_norm"]
print(f"[Gradient Accumulation Test] Grad norm with gradient accumulation: {accum_run_grad_norm}", flush=True)
print(f"[Gradient Accumulation Test] Grad norm without gradient accumulation: {regular_run_grad_norm}", flush=True)
# Not identical due to an epsilon addition in loss denominator.
np.testing.assert_allclose(accum_run_grad_norm, regular_run_grad_norm, rtol=0.01)

# Assert per device tflops are the same (10x smaller microbatch size, but 10x more microbatches)
with open(run_accumulate_metrics_file, 'r', encoding='utf8') as accum_run,\
open(run_regular_metrics_file, 'r', encoding='utf8') as regular_run:
accum_device_tflops = json.loads(accum_run.readlines()[-1])["perf/per_device_tflops"]
regular_device_tflops = json.loads(regular_run.readlines()[-1])["perf/per_device_tflops"]
print(f"[Gradient Accumulation Test] per_device_tflops with gradient accumulation: {accum_device_tflops}", flush=True)
print(f"[Gradient Accumulation Test] per_device_tflops without gradient accumulation: {regular_device_tflops}", flush=True)
np.testing.assert_equal(accum_device_tflops, regular_device_tflops)
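The rtol=0.01 tolerance above allows for the epsilon mentioned in the comments. A toy illustration of how an epsilon in the loss denominator can make the two runs close but not bit-identical (hypothetical numbers; this is not MaxText's actual loss code):

# Toy example: epsilon placement makes the runs differ slightly.
eps = 1e-6
losses = [10.0] * 10    # hypothetical summed loss per microbatch
weights = [100.0] * 10  # hypothetical token counts per microbatch

accum = sum(l / (w + eps) for l, w in zip(losses, weights)) / len(losses)
regular = sum(losses) / (sum(weights) + eps)
print(abs(accum - regular) <= 0.01 * regular)  # True: within the test's rtol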
50 changes: 45 additions & 5 deletions MaxText/train.py
@@ -77,6 +77,9 @@ def validate_train_config(config):
if not config.base_output_directory.startswith("gs://"):
max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system")
assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer."
if config.quantization=='fp8':
# pylint: disable=line-too-long
assert config.gradient_accumulation_steps == 1, "fp8 can't be used with gradient_accumulation_steps right now. Please use other quantization or set gradient_accumulation_steps to 1"


def get_first_step(state):
@@ -156,6 +159,11 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
writer.flush()

def clear_buffered_metrics():
global _buffered_step
global _buffered_metrics
_buffered_step = None
_buffered_metrics = None

def save_checkpoint(checkpoint_manager, step, state, dataset_type="c4", data_iterator=None):
"""Wrapper for saving checkpoint"""
@@ -224,7 +232,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
# decimate proportion of data when per_device_batch_size<1
if is_train:
for k, v in data.items():
data[k] = v[: config.global_batch_size_to_train_on, :]
data[k] = v[: config.micro_batch_size_to_train_on, :]

logits, intermediate_outputs = model.apply(
params,
@@ -266,9 +274,39 @@ def train_step(model, config, state, data, dropout_rng):
rng2: A new rng key that can be used in future calls.
"""
train_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=True)
grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True)
(loss, aux), raw_grads = grad_fn(state.params)
if config.gradient_accumulation_steps > 1:
def accumulate_gradient(acc_grad_and_loss, data):
grad_func = jax.value_and_grad(loss_fn, argnums=4, has_aux=True)
(_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, state.params, is_train=True)
acc_grad_and_loss['loss'] += aux['total_loss']
acc_grad_and_loss['grad'] = jax.tree_util.tree_map(
lambda x, y: x * aux['total_weights'] + y,
cur_batch_gradient,
acc_grad_and_loss['grad'])
acc_grad_and_loss['total_weights'] += aux['total_weights']
return acc_grad_and_loss, aux

def reshape_to_microbatch_accumulations(batch_arr):
''' Reshape global batch to microbatches, assuming batch axis is leading.'''
microbatches = config.gradient_accumulation_steps
microbatch_shape = (microbatches, batch_arr.shape[0] // microbatches) + batch_arr.shape[1:]
return jnp.reshape(batch_arr, microbatch_shape)

data = jax.tree_util.tree_map(reshape_to_microbatch_accumulations, data)
init_grad = jax.tree_util.tree_map(jnp.zeros_like, state.params)
init_grad_and_loss = {'loss': 0.0, 'grad': init_grad, 'total_weights':0}

grad_and_loss, aux = jax.lax.scan(
accumulate_gradient,
init_grad_and_loss,
data,
length = config.gradient_accumulation_steps)
loss = grad_and_loss['loss'] / grad_and_loss['total_weights']
raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss['total_weights'], grad_and_loss['grad'])
aux = jax.tree_map(lambda x: jnp.sum(x, axis=0), aux)
else:
grad_func = jax.value_and_grad(loss_fn, argnums=4, has_aux=True)
(loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, state.params, is_train=True)
intermediate_outputs = aux["intermediate_outputs"]
total_weights = aux["total_weights"]
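For intuition, here is a self-contained sketch (not MaxText code) showing that accumulating per-microbatch gradients with jax.lax.scan reproduces the full-batch gradient when microbatches are equally weighted; all names and sizes are made up:

# Standalone sketch: gradient accumulation via jax.lax.scan on a toy model.
import jax
import jax.numpy as jnp

def toy_loss(params, batch):
  x, y = batch
  return jnp.mean((x @ params - y) ** 2)

def accumulated_grad(params, x, y, num_micro):
  # Reshape [batch, ...] -> [num_micro, batch // num_micro, ...], as above.
  xs = (x.reshape(num_micro, -1, x.shape[-1]), y.reshape(num_micro, -1))
  def step(acc, micro):
    g = jax.grad(toy_loss)(params, micro)
    return jax.tree_util.tree_map(jnp.add, acc, g), None
  acc, _ = jax.lax.scan(step, jnp.zeros_like(params), xs)
  return acc / num_micro  # equal-sized microbatches -> simple average

key = jax.random.PRNGKey(0)
x, y = jax.random.normal(key, (8, 4)), jax.random.normal(key, (8,))
params = jnp.ones(4)
full = jax.grad(toy_loss)(params, (x, y))
print(jnp.allclose(full, accumulated_grad(params, x, y, 4), atol=1e-6))  # True

The committed code above differs from this sketch in one respect: each microbatch gradient is weighted by its token count (total_weights) before the final normalization, so unevenly filled microbatches are handled correctly.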

@@ -589,13 +627,15 @@ def train_loop(config, state=None):
write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics
max_utils.close_summary_writer(writer)
record_goodput(recorder, config, job_end=True)
clear_buffered_metrics()
return state


def main(argv: Sequence[str]) -> None:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
pyconfig.initialize(argv)
max_utils.print_system_information()
config = pyconfig.config
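The guard added in main() makes the LIBTPU_INIT_ARGS append idempotent across repeated invocations in one process (e.g. the new test calling train_main twice). A minimal sketch of the pattern, with a hypothetical helper name:

# Sketch: append an XLA flag only once, however many times this runs.
import os

def append_flag_once(var, flag):
  current = os.environ.get(var, "")
  if flag not in current:
    os.environ[var] = current + " " + flag

for _ in range(2):
  append_flag_once("LIBTPU_INIT_ARGS", "--xla_tpu_spmd_rng_bit_generator_unsafe=true")
print(os.environ["LIBTPU_INIT_ARGS"].count("--xla_tpu_spmd_rng_bit_generator_unsafe"))  # 1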
