diff --git a/MaxText/layers/pipeline.py b/MaxText/layers/pipeline.py
index 078db5400..5d8d8dbe0 100644
--- a/MaxText/layers/pipeline.py
+++ b/MaxText/layers/pipeline.py
@@ -23,8 +23,10 @@
 import common_types
 import functools
 from typing import Any
+from jax.experimental import shard_map
+from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
 
-sdclass Pipeline(nn.Module):
+class Pipeline(nn.Module):
   """Module that implements pipelining across stages.
 
   This module will loop over microbatches and execute the main body with a vmap for both the inputs and weights.
@@ -324,13 +326,25 @@ def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positi
             "x_times": self.num_stages}
     )
     return vmap_func
-  
+
   def rotate_right_shmap(self, arr):
-    # Use lax.slice to avoid generating a gather.
-    last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0)
-    except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0)
-    return jnp.concatenate([last, except_last], axis=0)
+    axis_names = nn.logical_to_mesh_axes(("activation_stage", "activation_batch", "activation_length", "activation_embed"), rules=self.config.logical_axis_rules)
+    print(f"{axis_names=}")
+    axis_names = P(*("stage", "data", "sequence", "tensor"))
+    print(f"{axis_names=}")
+    @functools.partial(
+        shard_map.shard_map,
+        mesh=self.mesh,
+        in_specs=axis_names,
+        out_specs=axis_names,
+        check_rep=False,
+    )
+    def rotate_shmap(arr):
+      arr = jax.lax.ppermute(arr, 'stage', [(i, (i+1) % self.num_stages) for i in range(self.num_stages)])
+      return arr
+    #return arr
+    return rotate_shmap(arr)
 
   def rotate_right(self, arr):
     # Use lax.slice to avoid generating a gather.
@@ -347,7 +361,8 @@ def run_one_iteration(self, loop_state, positions, segment_ids, deterministic, m
     prev_outputs = loop_state["prev_outputs"]
 
     # rotate prev outputs before doing the computation
-    new_shift = self.rotate_right(prev_outputs)
+    #new_shift = self.rotate_right(prev_outputs)
+    new_shift = self.rotate_right_shmap(prev_outputs)
 
     microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration)
diff --git a/MaxText/rotate_via_shamp.py b/MaxText/rotate_via_shamp.py
new file mode 100644
index 000000000..0bc165822
--- /dev/null
+++ b/MaxText/rotate_via_shamp.py
@@ -0,0 +1,68 @@
+import jax
+import jax.ad_checkpoint
+import numpy as np
+from jax import numpy as jnp
+from flax.core import meta
+from flax import linen as nn
+import common_types
+import functools
+from typing import Any
+from jax.experimental import shard_map
+from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
+from jax.experimental import mesh_utils
+
+NUM_STAGES=2
+BATCH_SHARD=2
+SEQUENCE_SHARD=1
+EMBED_SHARD=1
+
+
+BATCH=2
+SEQUENCE=6
+EMBED=3
+
+AXIS_NAMES = ('stage','batch', 'sequence', 'embed')
+ici_parallelism = [NUM_STAGES, BATCH_SHARD, SEQUENCE_SHARD ,EMBED_SHARD]
+devices_array = mesh_utils.create_device_mesh(ici_parallelism)
+global mesh
+mesh = Mesh(devices_array, AXIS_NAMES)
+
+def rotate_right_shmap(arr):
+  partition_spec = P(*AXIS_NAMES)
+  print(f"{partition_spec=}", flush=True)
+  @functools.partial(
+      shard_map.shard_map,
+      mesh=mesh,
+      in_specs=partition_spec,
+      out_specs=partition_spec,
+      check_rep=False,
+  )
+  def rotate_shmap(arr):
+    arr = jax.lax.ppermute(arr, 'stage', [(i, (i+1) % NUM_STAGES) for i in range(NUM_STAGES)])
+    return arr
+  return rotate_shmap(arr)
+
+def rotate_right(arr):
+  # Use lax.slice to avoid generating a gather.
+  last = jax.lax.slice_in_dim(arr, NUM_STAGES - 1, NUM_STAGES, axis=0)
+  except_last = jax.lax.slice_in_dim(arr, 0, NUM_STAGES - 1, axis=0)
+  return jnp.concatenate([last, except_last], axis=0)
+
+def create_random_arr():
+  shape = (NUM_STAGES, BATCH, SEQUENCE, EMBED)
+  total_elements = np.prod(shape) # Calculate the total number of elements
+  sequential_values = jnp.arange(1, total_elements + 1) # Create a 1D array with sequential values
+  return jnp.reshape(sequential_values, shape)
+
+arr1 = create_random_arr()
+arr2 = create_random_arr()
+
+print(f"{jnp.linalg.norm(arr1)=}",flush=True)
+
+rot_shmap = rotate_right_shmap(arr1)
+rot_regular = rotate_right(arr2)
+diff = rot_shmap - rot_regular
+
+print(f"{jnp.linalg.norm(rot_shmap)=}",flush=True)
+print(f"{jnp.linalg.norm(rot_regular)=}",flush=True)
+print(f"{jnp.linalg.norm(diff)=}",flush=True)
\ No newline at end of file
diff --git a/MaxText/tests/pipeline_parallelism_test.py b/MaxText/tests/pipeline_parallelism_test.py
index 8bedf3dbc..6500129ca 100644
--- a/MaxText/tests/pipeline_parallelism_test.py
+++ b/MaxText/tests/pipeline_parallelism_test.py
@@ -131,60 +131,61 @@ def regular_sequential_layers_dummy_loss(params, inputs, inputs_position, inputs
     assert_same_output_and_grad(regular_sequential_layers_dummy_loss, pipeline_parallelism_dummy_loss, init_pipeline_params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode, dummy_targets)
 
-  @pytest.mark.tpu
-  def test_circular_minimum_microbatches_same_output_and_grad(self):
-    # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches
-    pyconfig.initialize(
-      [sys.argv[0], "configs/base.yml"],
-      enable_checkpointing=False,
-      run_name="circular_minimum_microbatches",
-      max_target_length=128,
-      base_emb_dim=28,
-      ici_pipeline_parallelism=4,
-      base_num_decoder_layers=8,
-      num_pipeline_microbatches=4,
-      per_device_batch_size=4
-    )
-    config = pyconfig.config
-    self.assert_pipeline_same_output_and_grad(config)
-
-  @pytest.mark.tpu
-  def test_circular_extra_microbatches_same_output_and_grad(self):
-    # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches
-    pyconfig.initialize(
-      [sys.argv[0], "configs/base.yml"],
-      enable_checkpointing=False,
-      run_name="circular_extra_microbatches",
-      max_target_length=128,
-      base_emb_dim=28,
-      ici_pipeline_parallelism=4,
-      base_num_decoder_layers=8,
-      num_pipeline_microbatches=8,
-      per_device_batch_size=4
-    )
-    config = pyconfig.config
-    self.assert_pipeline_same_output_and_grad(config)
-
-  @pytest.mark.tpu
-  def test_non_circular_same_output_and_grad(self):
-    # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches
-    pyconfig.initialize(
-      [sys.argv[0], "configs/base.yml"],
-      enable_checkpointing=False,
-      run_name="non_circular",
-      max_target_length=128,
-      base_emb_dim=28,
-      ici_pipeline_parallelism=4,
-      base_num_decoder_layers=4,
-      num_pipeline_microbatches=4,
-      per_device_batch_size=4
-    )
-    config = pyconfig.config
-    self.assert_pipeline_same_output_and_grad(config)
+  # @pytest.mark.tpu
+  # def test_circular_minimum_microbatches_same_output_and_grad(self):
+  #   # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches
+  #   pyconfig.initialize(
+  #     [sys.argv[0], "configs/base.yml"],
+  #     enable_checkpointing=False,
+  #     run_name="circular_minimum_microbatches",
+  #     max_target_length=128,
+  #     base_emb_dim=28,
+  #     ici_pipeline_parallelism=4,
+  #     base_num_decoder_layers=8,
+  #     num_pipeline_microbatches=4,
+  #     per_device_batch_size=4
+  #   )
+  #   config = pyconfig.config
+  #   self.assert_pipeline_same_output_and_grad(config)
+
+  # @pytest.mark.tpu
+  # def test_circular_extra_microbatches_same_output_and_grad(self):
+  #   # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches
+  #   pyconfig.initialize(
+  #     [sys.argv[0], "configs/base.yml"],
+  #     enable_checkpointing=False,
+  #     run_name="circular_extra_microbatches",
+  #     max_target_length=128,
+  #     base_emb_dim=28,
+  #     ici_pipeline_parallelism=4,
+  #     base_num_decoder_layers=8,
+  #     num_pipeline_microbatches=8,
+  #     per_device_batch_size=4
+  #   )
+  #   config = pyconfig.config
+  #   self.assert_pipeline_same_output_and_grad(config)
+
+  # @pytest.mark.tpu
+  # def test_non_circular_same_output_and_grad(self):
+  #   # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches
+  #   pyconfig.initialize(
+  #     [sys.argv[0], "configs/base.yml"],
+  #     enable_checkpointing=False,
+  #     run_name="non_circular",
+  #     max_target_length=128,
+  #     base_emb_dim=28,
+  #     ici_pipeline_parallelism=4,
+  #     base_num_decoder_layers=4,
+  #     num_pipeline_microbatches=4,
+  #     per_device_batch_size=4
+  #   )
+  #   config = pyconfig.config
+  #   self.assert_pipeline_same_output_and_grad(config)
 
   @pytest.mark.tpu
   def test_activation_forwarding_same_output_and_grad(self):
     # 4 stages, activation forwarding, 8 layers (2 repeats, 1 layer per stage), 8 microbatches
+
     pyconfig.initialize(
       [sys.argv[0], "configs/base.yml"],
       enable_checkpointing=False,
@@ -198,63 +199,64 @@ def test_activation_forwarding_same_output_and_grad(self):
       pipeline_activation_forwarding=True
     )
     config = pyconfig.config
+    #with nn_partitioning.axis_rules(config.logical_axis_rules):
     self.assert_pipeline_same_output_and_grad(config)
 
-  @pytest.mark.tpu
-  def test_full_train_circular(self):
-    # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), 8 microbatches
-    train_main([
-      None,
-      "configs/base.yml",
-      r"base_output_directory=gs://runner-maxtext-logs",
-      "run_name=runner_pipeline_parallelism_test",
-      r"dataset_path=gs://maxtext-dataset",
-      "base_emb_dim=28",
-      "base_num_query_heads=4",
-      "base_num_kv_heads=4",
-      "base_mlp_dim=32",
-      "base_num_decoder_layers=32",
-      "head_dim=128",
-      "per_device_batch_size=2",
-      "max_target_length=1024",
-      "vocab_size=32",
-      "dataset_type=synthetic",
-      "steps=3",
-      "enable_checkpointing=False",
-      "ici_pipeline_parallelism=4",
-      "num_layers_per_pipeline_stage=2",
-      "num_pipeline_microbatches=8",
-      "tokenizer_path=../assets/tokenizer.llama2",
-
-    ])
-
-  @pytest.mark.tpu
-  def test_full_train_non_circular(self):
-    # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches
-    train_main([
-      None,
-      "configs/base.yml",
-      r"base_output_directory=gs://runner-maxtext-logs",
-      "run_name=runner_pipeline_parallelism_test",
-      r"dataset_path=gs://maxtext-dataset",
-      "base_emb_dim=28",
-      "base_num_query_heads=4",
-      "base_num_kv_heads=4",
-      "base_mlp_dim=32",
-      "base_num_decoder_layers=32",
-      "head_dim=128",
-      "per_device_batch_size=2",
-      "max_target_length=1024",
-      "vocab_size=32",
-      "dataset_type=synthetic",
-      "steps=3",
-      "enable_checkpointing=False",
-      "ici_pipeline_parallelism=4",
-      "num_layers_per_pipeline_stage=8",
-      "num_pipeline_microbatches=8",
-      "tokenizer_path=../assets/tokenizer.llama2",
+  # @pytest.mark.tpu
+  # def test_full_train_circular(self):
+  #   # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), 8 microbatches
+  #   train_main([
+  #     None,
+  #     "configs/base.yml",
+  #     r"base_output_directory=gs://runner-maxtext-logs",
+  #     "run_name=runner_pipeline_parallelism_test",
+  #     r"dataset_path=gs://maxtext-dataset",
+  #     "base_emb_dim=28",
+  #     "base_num_query_heads=4",
+  #     "base_num_kv_heads=4",
+  #     "base_mlp_dim=32",
+  #     "base_num_decoder_layers=32",
+  #     "head_dim=128",
+  #     "per_device_batch_size=2",
+  #     "max_target_length=1024",
+  #     "vocab_size=32",
+  #     "dataset_type=synthetic",
+  #     "steps=3",
+  #     "enable_checkpointing=False",
+  #     "ici_pipeline_parallelism=4",
+  #     "num_layers_per_pipeline_stage=2",
+  #     "num_pipeline_microbatches=8",
+  #     "tokenizer_path=../assets/tokenizer.llama2",
+
+  #   ])
+
+  # @pytest.mark.tpu
+  # def test_full_train_non_circular(self):
+  #   # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches
+  #   train_main([
+  #     None,
+  #     "configs/base.yml",
+  #     r"base_output_directory=gs://runner-maxtext-logs",
+  #     "run_name=runner_pipeline_parallelism_test",
+  #     r"dataset_path=gs://maxtext-dataset",
+  #     "base_emb_dim=28",
+  #     "base_num_query_heads=4",
+  #     "base_num_kv_heads=4",
+  #     "base_mlp_dim=32",
+  #     "base_num_decoder_layers=32",
+  #     "head_dim=128",
+  #     "per_device_batch_size=2",
+  #     "max_target_length=1024",
+  #     "vocab_size=32",
+  #     "dataset_type=synthetic",
+  #     "steps=3",
+  #     "enable_checkpointing=False",
+  #     "ici_pipeline_parallelism=4",
+  #     "num_layers_per_pipeline_stage=8",
+  #     "num_pipeline_microbatches=8",
+  #     "tokenizer_path=../assets/tokenizer.llama2",
 
-    ])
+  #   ])
 
 if __name__ == "__main__":
   unittest.main()