Skip to content

Commit

Permalink
logical axis rules explicit instead of rely on context
Browse files Browse the repository at this point in the history
  • Loading branch information
gobbleturk committed Aug 14, 2024
1 parent 9a1f389 commit 38ab9c4
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 111 deletions.
29 changes: 22 additions & 7 deletions MaxText/layers/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import common_types
import functools
from typing import Any
from jax.experimental import shard_map
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P

sdclass Pipeline(nn.Module):
class Pipeline(nn.Module):
"""Module that implements pipelining across stages.
This module will loop over microbatches and execute the main body with a vmap for both the inputs and weights.
Expand Down Expand Up @@ -324,13 +326,25 @@ def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positi
"x_times": self.num_stages}
)
return vmap_func


def rotate_right_shmap(self, arr):
# Use lax.slice to avoid generating a gather.
last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0)
except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0)
return jnp.concatenate([last, except_last], axis=0)
axis_names = nn.logical_to_mesh_axes(("activation_stage", "activation_batch", "activation_length", "activation_embed"), rules=self.config.logical_axis_rules)
print(f"{axis_names=}")
axis_names = P(*("stage", "data", "sequence", "tensor"))
print(f"{axis_names=}")
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=axis_names,
out_specs=axis_names,
check_rep=False,
)
def rotate_shmap(arr):
arr = jax.lax.ppermute(arr, 'stage', [(i, (i+1) % self.num_stages) for i in range(self.num_stages)])
return arr
#return arr
return rotate_shmap(arr)

def rotate_right(self, arr):
# Use lax.slice to avoid generating a gather.
Expand All @@ -347,7 +361,8 @@ def run_one_iteration(self, loop_state, positions, segment_ids, deterministic, m
prev_outputs = loop_state["prev_outputs"]

# rotate prev outputs before doing the computation
new_shift = self.rotate_right(prev_outputs)
#new_shift = self.rotate_right(prev_outputs)
new_shift = self.rotate_right_shmap(prev_outputs)

microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration)

Expand Down
68 changes: 68 additions & 0 deletions MaxText/rotate_via_shamp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import jax
import jax.ad_checkpoint
import numpy as np
from jax import numpy as jnp
from flax.core import meta
from flax import linen as nn
import common_types
import functools
from typing import Any
from jax.experimental import shard_map
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental import mesh_utils

NUM_STAGES=2
BATCH_SHARD=2
SEQUENCE_SHARD=1
EMBED_SHARD=1


BATCH=2
SEQUENCE=6
EMBED=3

AXIS_NAMES = ('stage','batch', 'sequence', 'embed')
ici_parallelism = [NUM_STAGES, BATCH_SHARD, SEQUENCE_SHARD ,EMBED_SHARD]
devices_array = mesh_utils.create_device_mesh(ici_parallelism)
global mesh
mesh = Mesh(devices_array, AXIS_NAMES)

def rotate_right_shmap(arr):
partition_spec = P(*AXIS_NAMES)
print(f"{partition_spec=}", flush=True)
@functools.partial(
shard_map.shard_map,
mesh=mesh,
in_specs=partition_spec,
out_specs=partition_spec,
check_rep=False,
)
def rotate_shmap(arr):
arr = jax.lax.ppermute(arr, 'stage', [(i, (i+1) % NUM_STAGES) for i in range(NUM_STAGES)])
return arr
return rotate_shmap(arr)

def rotate_right(arr):
# Use lax.slice to avoid generating a gather.
last = jax.lax.slice_in_dim(arr, NUM_STAGES - 1, NUM_STAGES, axis=0)
except_last = jax.lax.slice_in_dim(arr, 0, NUM_STAGES - 1, axis=0)
return jnp.concatenate([last, except_last], axis=0)

def create_random_arr():
shape = (NUM_STAGES, BATCH, SEQUENCE, EMBED)
total_elements = np.prod(shape) # Calculate the total number of elements
sequential_values = jnp.arange(1, total_elements + 1) # Create a 1D array with sequential values
return jnp.reshape(sequential_values, shape)

arr1 = create_random_arr()
arr2 = create_random_arr()

print(f"{jnp.linalg.norm(arr1)=}",flush=True)

rot_shmap = rotate_right_shmap(arr1)
rot_regular = rotate_right(arr2)
diff = rot_shmap - rot_regular

print(f"{jnp.linalg.norm(rot_shmap)=}",flush=True)
print(f"{jnp.linalg.norm(rot_regular)=}",flush=True)
print(f"{jnp.linalg.norm(diff)=}",flush=True)
210 changes: 106 additions & 104 deletions MaxText/tests/pipeline_parallelism_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,60 +131,61 @@ def regular_sequential_layers_dummy_loss(params, inputs, inputs_position, inputs

assert_same_output_and_grad(regular_sequential_layers_dummy_loss, pipeline_parallelism_dummy_loss, init_pipeline_params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode, dummy_targets)

@pytest.mark.tpu
def test_circular_minimum_microbatches_same_output_and_grad(self):
# 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches
pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
enable_checkpointing=False,
run_name="circular_minimum_microbatches",
max_target_length=128,
base_emb_dim=28,
ici_pipeline_parallelism=4,
base_num_decoder_layers=8,
num_pipeline_microbatches=4,
per_device_batch_size=4
)
config = pyconfig.config
self.assert_pipeline_same_output_and_grad(config)

@pytest.mark.tpu
def test_circular_extra_microbatches_same_output_and_grad(self):
# 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches
pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
enable_checkpointing=False,
run_name="circular_extra_microbatches",
max_target_length=128,
base_emb_dim=28,
ici_pipeline_parallelism=4,
base_num_decoder_layers=8,
num_pipeline_microbatches=8,
per_device_batch_size=4
)
config = pyconfig.config
self.assert_pipeline_same_output_and_grad(config)

@pytest.mark.tpu
def test_non_circular_same_output_and_grad(self):
# 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches
pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
enable_checkpointing=False,
run_name="non_circular",
max_target_length=128,
base_emb_dim=28,
ici_pipeline_parallelism=4,
base_num_decoder_layers=4,
num_pipeline_microbatches=4,
per_device_batch_size=4
)
config = pyconfig.config
self.assert_pipeline_same_output_and_grad(config)
# @pytest.mark.tpu
# def test_circular_minimum_microbatches_same_output_and_grad(self):
# # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches
# pyconfig.initialize(
# [sys.argv[0], "configs/base.yml"],
# enable_checkpointing=False,
# run_name="circular_minimum_microbatches",
# max_target_length=128,
# base_emb_dim=28,
# ici_pipeline_parallelism=4,
# base_num_decoder_layers=8,
# num_pipeline_microbatches=4,
# per_device_batch_size=4
# )
# config = pyconfig.config
# self.assert_pipeline_same_output_and_grad(config)

# @pytest.mark.tpu
# def test_circular_extra_microbatches_same_output_and_grad(self):
# # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches
# pyconfig.initialize(
# [sys.argv[0], "configs/base.yml"],
# enable_checkpointing=False,
# run_name="circular_extra_microbatches",
# max_target_length=128,
# base_emb_dim=28,
# ici_pipeline_parallelism=4,
# base_num_decoder_layers=8,
# num_pipeline_microbatches=8,
# per_device_batch_size=4
# )
# config = pyconfig.config
# self.assert_pipeline_same_output_and_grad(config)

# @pytest.mark.tpu
# def test_non_circular_same_output_and_grad(self):
# # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches
# pyconfig.initialize(
# [sys.argv[0], "configs/base.yml"],
# enable_checkpointing=False,
# run_name="non_circular",
# max_target_length=128,
# base_emb_dim=28,
# ici_pipeline_parallelism=4,
# base_num_decoder_layers=4,
# num_pipeline_microbatches=4,
# per_device_batch_size=4
# )
# config = pyconfig.config
# self.assert_pipeline_same_output_and_grad(config)

@pytest.mark.tpu
def test_activation_forwarding_same_output_and_grad(self):
# 4 stages, activation forwarding, 8 layers (2 repeats, 1 layer per stage), 8 microbatches

pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
enable_checkpointing=False,
Expand All @@ -198,63 +199,64 @@ def test_activation_forwarding_same_output_and_grad(self):
pipeline_activation_forwarding=True
)
config = pyconfig.config
#with nn_partitioning.axis_rules(config.logical_axis_rules):
self.assert_pipeline_same_output_and_grad(config)

@pytest.mark.tpu
def test_full_train_circular(self):
# Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), 8 microbatches
train_main([
None,
"configs/base.yml",
r"base_output_directory=gs://runner-maxtext-logs",
"run_name=runner_pipeline_parallelism_test",
r"dataset_path=gs://maxtext-dataset",
"base_emb_dim=28",
"base_num_query_heads=4",
"base_num_kv_heads=4",
"base_mlp_dim=32",
"base_num_decoder_layers=32",
"head_dim=128",
"per_device_batch_size=2",
"max_target_length=1024",
"vocab_size=32",
"dataset_type=synthetic",
"steps=3",
"enable_checkpointing=False",
"ici_pipeline_parallelism=4",
"num_layers_per_pipeline_stage=2",
"num_pipeline_microbatches=8",
"tokenizer_path=../assets/tokenizer.llama2",

])

@pytest.mark.tpu
def test_full_train_non_circular(self):
# Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches
train_main([
None,
"configs/base.yml",
r"base_output_directory=gs://runner-maxtext-logs",
"run_name=runner_pipeline_parallelism_test",
r"dataset_path=gs://maxtext-dataset",
"base_emb_dim=28",
"base_num_query_heads=4",
"base_num_kv_heads=4",
"base_mlp_dim=32",
"base_num_decoder_layers=32",
"head_dim=128",
"per_device_batch_size=2",
"max_target_length=1024",
"vocab_size=32",
"dataset_type=synthetic",
"steps=3",
"enable_checkpointing=False",
"ici_pipeline_parallelism=4",
"num_layers_per_pipeline_stage=8",
"num_pipeline_microbatches=8",
"tokenizer_path=../assets/tokenizer.llama2",
# @pytest.mark.tpu
# def test_full_train_circular(self):
# # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), 8 microbatches
# train_main([
# None,
# "configs/base.yml",
# r"base_output_directory=gs://runner-maxtext-logs",
# "run_name=runner_pipeline_parallelism_test",
# r"dataset_path=gs://maxtext-dataset",
# "base_emb_dim=28",
# "base_num_query_heads=4",
# "base_num_kv_heads=4",
# "base_mlp_dim=32",
# "base_num_decoder_layers=32",
# "head_dim=128",
# "per_device_batch_size=2",
# "max_target_length=1024",
# "vocab_size=32",
# "dataset_type=synthetic",
# "steps=3",
# "enable_checkpointing=False",
# "ici_pipeline_parallelism=4",
# "num_layers_per_pipeline_stage=2",
# "num_pipeline_microbatches=8",
# "tokenizer_path=../assets/tokenizer.llama2",

# ])

# @pytest.mark.tpu
# def test_full_train_non_circular(self):
# # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches
# train_main([
# None,
# "configs/base.yml",
# r"base_output_directory=gs://runner-maxtext-logs",
# "run_name=runner_pipeline_parallelism_test",
# r"dataset_path=gs://maxtext-dataset",
# "base_emb_dim=28",
# "base_num_query_heads=4",
# "base_num_kv_heads=4",
# "base_mlp_dim=32",
# "base_num_decoder_layers=32",
# "head_dim=128",
# "per_device_batch_size=2",
# "max_target_length=1024",
# "vocab_size=32",
# "dataset_type=synthetic",
# "steps=3",
# "enable_checkpointing=False",
# "ici_pipeline_parallelism=4",
# "num_layers_per_pipeline_stage=8",
# "num_pipeline_microbatches=8",
# "tokenizer_path=../assets/tokenizer.llama2",

])
# ])

if __name__ == "__main__":
unittest.main()

0 comments on commit 38ab9c4

Please sign in to comment.