From d82ec42bae0aef034afac3cf149fa6f65a1fcd96 Mon Sep 17 00:00:00 2001 From: Dipannita Shaw Date: Mon, 19 Aug 2024 23:59:07 +0000 Subject: [PATCH 1/2] Integrate Badput monitoring with MaxText --- MaxText/train.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/MaxText/train.py b/MaxText/train.py index 56a2d3181..5e5885467 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -416,14 +416,15 @@ def create_goodput_recorder(config): return None -def record_goodput(recorder, config, step=None, job_start=False, job_end=False): +def record_goodput( + recorder, + config, + record_func, + *args, + ): + """Record data for Goodput and Badput computation.""" if recorder and config.enable_goodput_recording: - if job_start and step is None: - recorder.record_job_start_time() - if job_end and step is None: - recorder.record_job_end_time() - if step is not None: - recorder.record_step_start_time(step) + record_func(*args) def check_example_batch(config, example_batch): if config.max_checkify: @@ -511,7 +512,11 @@ def setup_train_loop(config): data_iterator: state: the initialized train state """ + recorder = create_goodput_recorder(config) + record_goodput(recorder, config, recorder.record_tpu_init_start_time if recorder else None) init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(config) + record_goodput(recorder, config, recorder.record_tpu_init_end_time if recorder else None) + record_goodput(recorder, config, recorder.record_training_preparation_start_time if recorder else None) data_iterator, eval_data_iterator = create_data_iterator(config, mesh) state, state_mesh_annotations, data_iterator = max_utils.setup_training_state( @@ -521,7 +526,7 @@ def setup_train_loop(config): if not config.using_pipeline_parallelism: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, tolerance=0.02) - + record_goodput(recorder, config, recorder.record_training_preparation_end_time if recorder else None) return ( init_rng, writer, @@ -546,7 +551,7 @@ def train_loop(config, state=None): """ # Create a GoodputRecorder to log information recorder = create_goodput_recorder(config) - record_goodput(recorder, config, job_start=True) + record_goodput(recorder, config, recorder.record_job_start_time if recorder else None) ( init_rng, @@ -634,10 +639,12 @@ def train_loop(config, state=None): prof.activate() with jax.profiler.StepTraceAnnotation("train", step_num=step): + record_goodput(recorder, config, recorder.record_data_loading_start_time if recorder else None) example_batch = load_next_batch(data_iterator, example_batch, config) + record_goodput(recorder, config, recorder.record_data_loading_end_time if recorder else None) check_example_batch(config, example_batch=example_batch) nextrng = jax.jit(jax.random.fold_in)(init_rng, step) - record_goodput(recorder, config, step=step) + record_goodput(recorder, config, recorder.record_step_start_time if recorder else None, step) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): state, metrics = p_train_step(state, example_batch, nextrng) @@ -693,7 +700,7 @@ def train_loop(config, state=None): checkpoint_manager.wait_until_finished() write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics max_utils.close_summary_writer(writer) - record_goodput(recorder, config, job_end=True) + 
record_goodput(recorder, config, recorder.record_job_end_time if recorder else None) clear_buffered_metrics() return state @@ -719,7 +726,8 @@ def main(argv: Sequence[str]) -> None: logger_name=logger_name, tensorboard_dir=config.tensorboard_dir, upload_interval=config.goodput_upload_interval_seconds, - monitoring_enabled=True + monitoring_enabled=True, + include_badput_breakdown=True, ) goodput_monitor.start_goodput_uploader() max_logging.log("Started Goodput upload to Tensorboard in the background!") From b8bbe21c8bbdfd2fca88d1def1f0e6b928b964e5 Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Thu, 19 Sep 2024 03:58:13 +0000 Subject: [PATCH 2/2] Fix lint errors Fix pyink and pylint errors so that code_style.sh passes. --- MaxText/accelerator_to_spec_map.py | 2 +- MaxText/checkpointing.py | 57 +- MaxText/convert_gemma2_chkpt.py | 157 +++--- MaxText/convert_gemma_chkpt.py | 39 +- MaxText/decode.py | 4 +- MaxText/generate_param_only_checkpoint.py | 6 +- MaxText/inference_microbenchmark.py | 96 ++-- MaxText/inference_microbenchmark_sweep.py | 118 ++-- .../input_pipeline/_grain_data_processing.py | 54 +- MaxText/input_pipeline/_hf_data_processing.py | 95 ++-- .../input_pipeline/_input_pipeline_utils.py | 64 ++- .../input_pipeline/_tfds_data_processing.py | 65 ++- .../_tfds_data_processing_c4_mlperf.py | 52 +- .../input_pipeline_interface.py | 8 +- MaxText/kernels/ragged_attention.py | 238 ++++---- MaxText/layers/attentions.py | 186 ++++--- MaxText/layers/embeddings.py | 4 +- MaxText/layers/gpt3.py | 3 +- MaxText/layers/linears.py | 225 ++++---- MaxText/layers/mistral.py | 8 +- MaxText/layers/models.py | 93 ++-- MaxText/layers/pipeline.py | 519 ++++++++++-------- MaxText/layers/quantizations.py | 118 ++-- MaxText/layers/simple_layer.py | 26 +- MaxText/llama_mistral_mixtral_orbax_to_hf.py | 161 +++--- MaxText/llama_or_mistral_ckpt.py | 111 ++-- MaxText/max_utils.py | 184 ++----- MaxText/maxengine.py | 48 +- MaxText/maxengine_server.py | 8 +- MaxText/maxtext_utils.py | 38 +- MaxText/profiler.py | 13 +- MaxText/pyconfig.py | 132 +++-- .../scratch_code/golden_gemma-2b_export.ipynb | 106 ++-- .../golden_gemma2-27b_export-flax.ipynb | 97 ++-- .../golden_gemma2-2b_export-flax.ipynb | 97 ++-- .../golden_gemma2-2b_export.ipynb | 63 ++- .../golden_gemma2-9b_export-flax.ipynb | 97 ++-- .../golden_gemma2-9b_export.ipynb | 61 +- .../scratch_code/golden_llama2-70b_export.py | 73 ++- .../golden_llama2-7b_export.ipynb | 63 +-- .../scratch_code/golden_llama3-70b_export.py | 73 ++- .../golden_llama3-8b_export.ipynb | 59 +- .../golden_mixtral-8x22b_export.ipynb | 59 +- .../golden_mixtral-8x7b_export.ipynb | 61 +- MaxText/standalone_dataloader.py | 2 +- MaxText/tests/aot_hlo_identical_test.py | 171 +++--- MaxText/tests/attention_test.py | 81 +-- MaxText/tests/forward_pass_logit_checker.py | 41 +- MaxText/tests/gradient_accumulation_test.py | 114 ++-- MaxText/tests/grain_data_processing_test.py | 4 +- MaxText/tests/hf_data_processing_test.py | 4 +- .../inference_microbenchmark_smoke_test.py | 24 +- MaxText/tests/kernels_test.py | 54 +- MaxText/tests/llama_test.py | 14 +- MaxText/tests/maxtext_utils_test.py | 104 ++-- MaxText/tests/moe_test.py | 135 +++-- MaxText/tests/pipeline_parallelism_test.py | 289 +++++----- MaxText/tests/profiler_test.py | 1 - MaxText/tests/pyconfig_test.py | 85 ++- MaxText/tests/simple_decoder_layer_test.py | 50 +- MaxText/tests/standalone_dl_ckpt_test.py | 90 +-- MaxText/tests/tfds_data_processing_test.py | 6 +- MaxText/tests/tokenizer_test.py | 21 +- MaxText/tests/train_compile_test.py | 
262 +++++---- MaxText/tests/train_gpu_smoke_test.py | 20 +- MaxText/tests/train_int8_smoke_test.py | 42 +- MaxText/tests/train_smoke_test.py | 40 +- MaxText/tokenizer.py | 75 ++- MaxText/train.py | 164 +++--- MaxText/train_compile.py | 6 +- MaxText/train_tokenizer.py | 16 +- 71 files changed, 3032 insertions(+), 2724 deletions(-) diff --git a/MaxText/accelerator_to_spec_map.py b/MaxText/accelerator_to_spec_map.py index 5dfe7dff2..7d56182f4 100644 --- a/MaxText/accelerator_to_spec_map.py +++ b/MaxText/accelerator_to_spec_map.py @@ -155,7 +155,7 @@ class SystemCharacteristics: # across hosts will occur over DCN. This makes the "slice" topology of A3 fixed to a single host. # To use AoT compilation with multihost, the `compile_topology_num_slices` flag should be # specified to the number of hosts. - "a3": SystemCharacteristics("gpu", None, None, None, 8, None) + "a3": SystemCharacteristics("gpu", None, None, None, 8, None), } diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py index f4203aa9e..9ce8464a3 100644 --- a/MaxText/checkpointing.py +++ b/MaxText/checkpointing.py @@ -32,9 +32,7 @@ CheckpointManagerOptions = ocp.CheckpointManagerOptions PyTreeCheckpointHandler = ocp.PyTreeCheckpointHandler LocalCheckpointOptions = emergency_checkpoint_manager.LocalCheckpointOptions -PersistentCheckpointOptions = ( - emergency_checkpoint_manager.PersistentCheckpointOptions -) +PersistentCheckpointOptions = emergency_checkpoint_manager.PersistentCheckpointOptions abstract_logger = ocp.logging.abstract_logger cloud_logger = ocp.logging.cloud_logger @@ -76,7 +74,7 @@ def create_orbax_checkpoint_manager( save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async, ), - logger=orbax_logger + logger=orbax_logger, ) max_logging.log("Checkpoint manager created!") return mngr @@ -96,12 +94,8 @@ def create_orbax_emergency_checkpoint_manager( max_logging.log("Creating emergency checkpoint manager...") options = emergency_checkpoint_manager.CheckpointManagerOptions( - local=LocalCheckpointOptions( - save_interval_steps=local_save_interval_steps - ), - persistent=PersistentCheckpointOptions( - save_interval_steps=persistent_save_interval_steps - ), + local=LocalCheckpointOptions(save_interval_steps=local_save_interval_steps), + persistent=PersistentCheckpointOptions(save_interval_steps=persistent_save_interval_steps), ) emergency_mngr = emergency_checkpoint_manager.CheckpointManager( local_checkpoint_dir, @@ -191,16 +185,13 @@ def map_to_pspec(data): replica_axis_index = 0 replica_devices = _replica_devices(mesh.devices, replica_axis_index) replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) - single_replica_sharding = jax.sharding.NamedSharding( - replica_mesh, pspec) + single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) array_handler = ocp.type_handlers.SingleReplicaArrayHandler( replica_axis_index=0, broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit ) - ocp.type_handlers.register_type_handler( - jax.Array, array_handler, override=True - ) + ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) return ocp.type_handlers.SingleReplicaArrayRestoreArgs( sharding=jax.sharding.NamedSharding(mesh, pspec), @@ -218,9 +209,7 @@ def map_to_pspec(data): return ( checkpoint_manager.restore( latest_step, - args=ocp.args.PyTreeRestore( - item=abstract_unboxed_pre_state, restore_args=restore_args - ), + args=ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args), ), None, ) @@ -234,9 +223,7 @@ 
def map_to_pspec(data): item=abstract_unboxed_pre_state, restore_args=restore_args, ), - iter=grain.PyGrainCheckpointRestore( - data_iterator.local_iterator - ), + iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator), ), ), None, @@ -282,9 +269,7 @@ def setup_checkpoint_logger(config) -> composite_logger.CompositeLogger | None: max_logging.log("Setting up checkpoint logger...") if config.enable_checkpoint_cloud_logger: logger_name = f"checkpoint_{config.run_name}" - options = cloud_logger.CloudLoggerOptions( - job_name=config.run_name, logger_name=logger_name - ) + options = cloud_logger.CloudLoggerOptions(job_name=config.run_name, logger_name=logger_name) orbax_cloud_logger = cloud_logger.CloudLogger(options=options) max_logging.log("Successfully set up checkpoint cloud logger.") @@ -294,9 +279,7 @@ def setup_checkpoint_logger(config) -> composite_logger.CompositeLogger | None: orbax_logger = None if orbax_cloud_logger is not None and orbax_standard_logger is not None: - orbax_logger = composite_logger.CompositeLogger( - orbax_cloud_logger, orbax_standard_logger - ) + orbax_logger = composite_logger.CompositeLogger(orbax_cloud_logger, orbax_standard_logger) max_logging.log("Successfully set up checkpoint composite logger.") return orbax_logger @@ -312,15 +295,10 @@ def load_params_from_path(load_parameters_from_path, abstract_unboxed_params): # Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste # memory, we instead specify here that we are just restoring the params field of the checkpoint # (which itself may be a dictionary containing a key named 'params'). - restore_args = ocp.checkpoint_utils.construct_restore_args( - abstract_unboxed_params - ) + restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_unboxed_params) restored = ckptr.restore( - ckpt, - item={"params": abstract_unboxed_params}, - transforms={}, - restore_args={"params": restore_args} - ) + ckpt, item={"params": abstract_unboxed_params}, transforms={}, restore_args={"params": restore_args} + ) return restored["params"] @@ -328,11 +306,6 @@ def save_params_to_path(checkpoint_dir, params): """Save decode params in checkpoint at specified path.""" assert checkpoint_dir, "checkpoint_dir is not defined." orbax_checkpointer = ocp.PyTreeCheckpointer() - save_args = orbax_utils.save_args_from_target({"params":params}) - orbax_checkpointer.save( - checkpoint_dir, - {"params":params}, - save_args=save_args, - force=True - ) + save_args = orbax_utils.save_args_from_target({"params": params}) + orbax_checkpointer.save(checkpoint_dir, {"params": params}, save_args=save_args, force=True) print(f"Quantized params checkpoint saved at: {checkpoint_dir}") diff --git a/MaxText/convert_gemma2_chkpt.py b/MaxText/convert_gemma2_chkpt.py index e53a37dbf..33ddf73e2 100644 --- a/MaxText/convert_gemma2_chkpt.py +++ b/MaxText/convert_gemma2_chkpt.py @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + # pylint: disable=line-too-long """ Convert orbax Gemma checkpoint to MaxText compatible checkpoint. 
@@ -36,6 +37,7 @@ Params = dict[str, Any] + def nest_params(params: Params) -> Params: """Nests params as a dict of dicts rather than a flat dict.""" nested_params = {} @@ -75,7 +77,7 @@ def main(raw_args=None) -> None: if args.model_size in ("2b", "9b"): query_pre_attn_scalar = head_dim**-0.5 elif args.model_size in ("27b"): - query_pre_attn_scalar = (embed_dim // num_heads)**-0.5 + query_pre_attn_scalar = (embed_dim // num_heads) ** -0.5 transpose_gating_einsum = True if args.model_size in ("2b"): @@ -89,44 +91,49 @@ def main(raw_args=None) -> None: }, "token_embedder": {"embedding": params["transformer"]["embedder"]["input_embedding"] * jnp.sqrt(embed_dim)}, } - self_attention_local = dict({ - "query": {"kernel": []}, - "key": {"kernel": []}, - "value": {"kernel": []}, - "out": {"kernel": []}, - }) - self_attention_global = dict({ - "query": {"kernel": []}, - "key": {"kernel": []}, - "value": {"kernel": []}, - "out": {"kernel": []}, - }) - - layer_weight = dict({ - "mlp_local": { - "wi_0": {"kernel": []}, - "wi_1": {"kernel": []}, - "wo": {"kernel": []}, - }, - "mlp_global": { - "wi_0": {"kernel": []}, - "wi_1": {"kernel": []}, - "wo": {"kernel": []}, - }, - "pre_self_attention_norm_local": {"scale": []}, - "pre_ffw_norm_local": {"scale": []}, - "post_self_attention_norm_local": {"scale": []}, - "post_ffw_norm_local": {"scale": []}, - "pre_self_attention_norm_global": {"scale": []}, - "pre_ffw_norm_global": {"scale": []}, - "post_self_attention_norm_global": {"scale": []}, - "post_ffw_norm_global": {"scale": []}, - }) + self_attention_local = dict( + { + "query": {"kernel": []}, + "key": {"kernel": []}, + "value": {"kernel": []}, + "out": {"kernel": []}, + } + ) + self_attention_global = dict( + { + "query": {"kernel": []}, + "key": {"kernel": []}, + "value": {"kernel": []}, + "out": {"kernel": []}, + } + ) + layer_weight = dict( + { + "mlp_local": { + "wi_0": {"kernel": []}, + "wi_1": {"kernel": []}, + "wo": {"kernel": []}, + }, + "mlp_global": { + "wi_0": {"kernel": []}, + "wi_1": {"kernel": []}, + "wo": {"kernel": []}, + }, + "pre_self_attention_norm_local": {"scale": []}, + "pre_ffw_norm_local": {"scale": []}, + "post_self_attention_norm_local": {"scale": []}, + "post_ffw_norm_local": {"scale": []}, + "pre_self_attention_norm_global": {"scale": []}, + "pre_ffw_norm_global": {"scale": []}, + "post_self_attention_norm_global": {"scale": []}, + "post_ffw_norm_global": {"scale": []}, + } + ) for layer_idx in range(0, num_layers, 2): in_layer_name_local = "layer_" + str(layer_idx) - in_layer_name_global = "layer_" + str(layer_idx+1) + in_layer_name_global = "layer_" + str(layer_idx + 1) ######################## layer local attention ######################## self_attention_local["query"]["kernel"].append( @@ -142,23 +149,35 @@ def main(raw_args=None) -> None: # mlp if transpose_gating_einsum: - layer_weight["mlp_local"]["wi_0"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][0])) - layer_weight["mlp_local"]["wi_1"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][1])) + layer_weight["mlp_local"]["wi_0"]["kernel"].append( + np.transpose(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][0]) + ) + layer_weight["mlp_local"]["wi_1"]["kernel"].append( + np.transpose(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][1]) + ) else: - 
layer_weight["mlp_local"]["wi_0"]["kernel"].append(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][0]) - layer_weight["mlp_local"]["wi_1"]["kernel"].append(params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][1]) + layer_weight["mlp_local"]["wi_0"]["kernel"].append( + params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][0] + ) + layer_weight["mlp_local"]["wi_1"]["kernel"].append( + params["transformer"][in_layer_name_local]["mlp"]["gating_einsum"]["w"][1] + ) layer_weight["mlp_local"]["wo"]["kernel"].append(params["transformer"][in_layer_name_local]["mlp"]["linear"]["w"]) layer_weight["pre_self_attention_norm_local"]["scale"].append( params["transformer"][in_layer_name_local]["pre_attention_norm"]["scale"] + 1 ) - layer_weight["pre_ffw_norm_local"]["scale"].append(params["transformer"][in_layer_name_local]["pre_ffw_norm"]["scale"] + 1) + layer_weight["pre_ffw_norm_local"]["scale"].append( + params["transformer"][in_layer_name_local]["pre_ffw_norm"]["scale"] + 1 + ) layer_weight["post_self_attention_norm_local"]["scale"].append( - params["transformer"][in_layer_name_local]["post_attention_norm"]["scale"] + 1 + params["transformer"][in_layer_name_local]["post_attention_norm"]["scale"] + 1 + ) + layer_weight["post_ffw_norm_local"]["scale"].append( + params["transformer"][in_layer_name_local]["post_ffw_norm"]["scale"] + 1 ) - layer_weight["post_ffw_norm_local"]["scale"].append(params["transformer"][in_layer_name_local]["post_ffw_norm"]["scale"] + 1) ######################## layer global attention ######################## @@ -171,27 +190,41 @@ def main(raw_args=None) -> None: self_attention_global["value"]["kernel"].append( params["transformer"][in_layer_name_global]["attn"]["kv_einsum"]["w"][1].transpose((1, 0, 2)) ) - self_attention_global["out"]["kernel"].append(params["transformer"][in_layer_name_global]["attn"]["attn_vec_einsum"]["w"]) + self_attention_global["out"]["kernel"].append( + params["transformer"][in_layer_name_global]["attn"]["attn_vec_einsum"]["w"] + ) # mlp if transpose_gating_einsum: - layer_weight["mlp_global"]["wi_0"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][0])) - layer_weight["mlp_global"]["wi_1"]["kernel"].append(np.transpose(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][1])) + layer_weight["mlp_global"]["wi_0"]["kernel"].append( + np.transpose(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][0]) + ) + layer_weight["mlp_global"]["wi_1"]["kernel"].append( + np.transpose(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][1]) + ) else: - layer_weight["mlp_global"]["wi_0"]["kernel"].append(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][0]) - layer_weight["mlp_global"]["wi_1"]["kernel"].append(params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][1]) + layer_weight["mlp_global"]["wi_0"]["kernel"].append( + params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][0] + ) + layer_weight["mlp_global"]["wi_1"]["kernel"].append( + params["transformer"][in_layer_name_global]["mlp"]["gating_einsum"]["w"][1] + ) layer_weight["mlp_global"]["wo"]["kernel"].append(params["transformer"][in_layer_name_global]["mlp"]["linear"]["w"]) layer_weight["pre_self_attention_norm_global"]["scale"].append( params["transformer"][in_layer_name_global]["pre_attention_norm"]["scale"] + 1 ) - 
layer_weight["pre_ffw_norm_global"]["scale"].append(params["transformer"][in_layer_name_global]["pre_ffw_norm"]["scale"] + 1) + layer_weight["pre_ffw_norm_global"]["scale"].append( + params["transformer"][in_layer_name_global]["pre_ffw_norm"]["scale"] + 1 + ) layer_weight["post_self_attention_norm_global"]["scale"].append( - params["transformer"][in_layer_name_global]["post_attention_norm"]["scale"] + 1 + params["transformer"][in_layer_name_global]["post_attention_norm"]["scale"] + 1 + ) + layer_weight["post_ffw_norm_global"]["scale"].append( + params["transformer"][in_layer_name_global]["post_ffw_norm"]["scale"] + 1 ) - layer_weight["post_ffw_norm_global"]["scale"].append(params["transformer"][in_layer_name_global]["post_ffw_norm"]["scale"] + 1) self_attention_local["query"]["kernel"] = np.array(self_attention_local["query"]["kernel"]).transpose((1, 0, 2, 3)) self_attention_local["key"]["kernel"] = np.array(self_attention_local["key"]["kernel"]).transpose((1, 0, 2, 3)) @@ -211,22 +244,22 @@ def main(raw_args=None) -> None: layer_weight["mlp_global"]["wi_1"]["kernel"] = np.array(layer_weight["mlp_global"]["wi_1"]["kernel"]).transpose((1, 0, 2)) layer_weight["mlp_global"]["wo"]["kernel"] = np.array(layer_weight["mlp_global"]["wo"]["kernel"]).transpose((1, 0, 2)) - layer_weight["pre_self_attention_norm_local"]["scale"] = np.array(layer_weight["pre_self_attention_norm_local"]["scale"]).transpose( - (1, 0) - ) + layer_weight["pre_self_attention_norm_local"]["scale"] = np.array( + layer_weight["pre_self_attention_norm_local"]["scale"] + ).transpose((1, 0)) layer_weight["pre_ffw_norm_local"]["scale"] = np.array(layer_weight["pre_ffw_norm_local"]["scale"]).transpose((1, 0)) - layer_weight["post_self_attention_norm_local"]["scale"] = np.array(layer_weight["post_self_attention_norm_local"]["scale"]).transpose( - (1, 0) - ) + layer_weight["post_self_attention_norm_local"]["scale"] = np.array( + layer_weight["post_self_attention_norm_local"]["scale"] + ).transpose((1, 0)) layer_weight["post_ffw_norm_local"]["scale"] = np.array(layer_weight["post_ffw_norm_local"]["scale"]).transpose((1, 0)) - layer_weight["pre_self_attention_norm_global"]["scale"] = np.array(layer_weight["pre_self_attention_norm_global"]["scale"]).transpose( - (1, 0) - ) + layer_weight["pre_self_attention_norm_global"]["scale"] = np.array( + layer_weight["pre_self_attention_norm_global"]["scale"] + ).transpose((1, 0)) layer_weight["pre_ffw_norm_global"]["scale"] = np.array(layer_weight["pre_ffw_norm_global"]["scale"]).transpose((1, 0)) - layer_weight["post_self_attention_norm_global"]["scale"] = np.array(layer_weight["post_self_attention_norm_global"]["scale"]).transpose( - (1, 0) - ) + layer_weight["post_self_attention_norm_global"]["scale"] = np.array( + layer_weight["post_self_attention_norm_global"]["scale"] + ).transpose((1, 0)) layer_weight["post_ffw_norm_global"]["scale"] = np.array(layer_weight["post_ffw_norm_global"]["scale"]).transpose((1, 0)) layer_weight["self_attention_local"] = copy.deepcopy(self_attention_local) diff --git a/MaxText/convert_gemma_chkpt.py b/MaxText/convert_gemma_chkpt.py index 8bf000b87..38881ac43 100644 --- a/MaxText/convert_gemma_chkpt.py +++ b/MaxText/convert_gemma_chkpt.py @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + # pylint: disable=line-too-long """ Convert orbax Gemma checkpoint to MaxText compatible checkpoint. 
@@ -78,27 +79,31 @@ def main(raw_args=None) -> None: }, "token_embedder": {"embedding": params["transformer"]["embedder"]["input_embedding"] * jnp.sqrt(embed_dim)}, } - self_attention = dict({ - "query": {"kernel": []}, - "key": {"kernel": []}, - "value": {"kernel": []}, - "out": {"kernel": []}, - }) - - layer_weight = dict({ - "mlp": { - "wi_0": {"kernel": []}, - "wi_1": {"kernel": []}, - "wo": {"kernel": []}, - }, - "pre_self_attention_norm": {"scale": []}, - "pre_ffw_norm": {"scale": []}, - }) + self_attention = dict( + { + "query": {"kernel": []}, + "key": {"kernel": []}, + "value": {"kernel": []}, + "out": {"kernel": []}, + } + ) + + layer_weight = dict( + { + "mlp": { + "wi_0": {"kernel": []}, + "wi_1": {"kernel": []}, + "wo": {"kernel": []}, + }, + "pre_self_attention_norm": {"scale": []}, + "pre_ffw_norm": {"scale": []}, + } + ) for layer_idx in range(num_layers): in_layer_name = "layer_" + str(layer_idx) # attention block - if args.model_size in ("2b","9b"): # MQA + if args.model_size in ("2b", "9b"): # MQA self_attention["query"]["kernel"].append( params["transformer"][in_layer_name]["attn"]["q_einsum"]["w"].transpose((1, 0, 2)) * head_dim**-0.5 ) diff --git a/MaxText/decode.py b/MaxText/decode.py index add37e746..344d3fb97 100644 --- a/MaxText/decode.py +++ b/MaxText/decode.py @@ -31,9 +31,7 @@ def main(config): text = config.prompt metadata = engine.get_tokenizer() tokenizer_model = engine.build_tokenizer(metadata) - tokens, true_length = tokenizer_model.encode( - text, is_bos=True, prefill_lengths=[config.max_prefill_predict_length] - ) + tokens, true_length = tokenizer_model.encode(text, is_bos=True, prefill_lengths=[config.max_prefill_predict_length]) assert true_length <= config.max_prefill_predict_length, "can't take too many tokens" assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet" prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length) diff --git a/MaxText/generate_param_only_checkpoint.py b/MaxText/generate_param_only_checkpoint.py index 8969430b9..80c97361d 100644 --- a/MaxText/generate_param_only_checkpoint.py +++ b/MaxText/generate_param_only_checkpoint.py @@ -56,7 +56,8 @@ def new_pspec(x): new_per_layer_state_annotation = jax.tree_util.tree_map(new_pspec, training_state_annotations_layers) new_per_layer_state_sharding = jax.tree_util.tree_map( - lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation) + lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation + ) for i in range(config.num_decoder_layers): @@ -92,7 +93,8 @@ def _save_decode_checkpoint(config, state, checkpoint_manager): """Generate checkpoint for decode from the training_state.""" with jax.spmd_mode("allow_all"): decode_state = max_utils.init_decode_state( - None, jax.tree_util.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params)) + None, jax.tree_util.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params) + ) if checkpoint_manager is not None: if save_checkpoint(checkpoint_manager, 0, decode_state): max_logging.log(f"saved an decode checkpoint at {config.checkpoint_dir}") diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py index 6276a907b..171bd7af9 100644 --- a/MaxText/inference_microbenchmark.py +++ b/MaxText/inference_microbenchmark.py @@ -46,9 +46,7 @@ def prefill_benchmark_loop(engine, params, tokens, true_length, iters): return (end - start).total_seconds() -def prefill_benchmark( - config, engine, params, 
tokens, true_length, num_model_params, iters -): +def prefill_benchmark(config, engine, params, tokens, true_length, num_model_params, iters): """Handles warmup, running prefill benchmark, and printing results.""" for _ in range(_WARMUP_ITERS): prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length) @@ -59,7 +57,7 @@ def prefill_benchmark( time_in_s = prefill_benchmark_loop(engine, params, tokens, true_length, iters) prefill_average_ms = 1000 * time_in_s / iters prefill_tflops_per_device, _, _ = maxtext_utils.calculate_prefill_tflops_per_device(num_model_params, tokens.size, config) - tflops_per_sec_per_device = prefill_tflops_per_device / prefill_average_ms * 1000.0 + tflops_per_sec_per_device = prefill_tflops_per_device / prefill_average_ms * 1000.0 print( f"\tPrefill step average time: {prefill_average_ms:.3f} ms\n" f"\tPrefill total TFLOPs/device: {prefill_tflops_per_device:.3f}\n" @@ -75,7 +73,7 @@ def prefill_benchmark( def prefill_insert_benchmark_loop( config, engine, decode_state, params, total_slots, tokens, true_length, iters, profile_name - ): +): """Inner loop for benchmarking prefill and insert step.""" prof = profiler.Profiler(config, profile_name) prof.activate() @@ -90,9 +88,7 @@ def prefill_insert_benchmark_loop( return (end - start).total_seconds(), decode_state -def prefill_insert_benchmark( - config, engine, decode_state, params, total_slots, tokens, true_length, iters - ): +def prefill_insert_benchmark(config, engine, decode_state, params, total_slots, tokens, true_length, iters): """Handles warmup, running insert benchmark, and printing results.""" for i in range(_WARMUP_ITERS): @@ -103,14 +99,11 @@ def prefill_insert_benchmark( print(f"Prefill and insert benchmark results for length {tokens.size}:\n") time_in_s, decode_state = prefill_insert_benchmark_loop( - config, engine, decode_state, params, total_slots, tokens, true_length, iters, f"prefill_insert_{tokens.size}") - prefill_insert_average_ms = time_in_s / iters * 1000.0 - print( - f"\tPrefill + Insert step average time: {prefill_insert_average_ms:.3f} ms\n\n\n\n" + config, engine, decode_state, params, total_slots, tokens, true_length, iters, f"prefill_insert_{tokens.size}" ) - result_dict = { - "time_in_ms": prefill_insert_average_ms - } + prefill_insert_average_ms = time_in_s / iters * 1000.0 + print(f"\tPrefill + Insert step average time: {prefill_insert_average_ms:.3f} ms\n\n\n\n") + result_dict = {"time_in_ms": prefill_insert_average_ms} return result_dict, decode_state @@ -173,7 +166,7 @@ def collate_results(config, results, model_size, cache_size, num_model_params, i return results -def flatten_dict(dictionary, prefix='', sep='_'): +def flatten_dict(dictionary, prefix="", sep="_"): results = [] for k, v in dictionary.items(): new_key = str(prefix) + sep + str(k) if prefix else k @@ -187,7 +180,7 @@ def flatten_dict(dictionary, prefix='', sep='_'): def write_results(results, filename, flatten_microbenchmark_results): """Write the results microbenchmark results to a json file.""" if flatten_microbenchmark_results: - results['flattened_results'] = flatten_dict(results) + results["flattened_results"] = flatten_dict(results) if filename != "": with open(filename, "w", encoding="utf-8") as f: json.dump(results, f, indent=2) @@ -219,20 +212,20 @@ def summarize_prefill_result(engine, params, tokens, true_length): print(f"Prefill result of length {tokens.size}:\n") prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length) 
jax.block_until_ready(prefill_result) - num_prefill_logits_params, total_prefill_logits_size, avg_prefill_logits_param_size = ( - max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True) + num_prefill_logits_params, total_prefill_logits_size, avg_prefill_logits_param_size = max_utils.summarize_pytree_data( + prefill_result["logits"], name="Prefill Logits", raw=True ) - num_prefill_cache_params, total_prefill_cache_size, avg_prefill_cache_param_size = ( - max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache") + num_prefill_cache_params, total_prefill_cache_size, avg_prefill_cache_param_size = max_utils.summarize_pytree_data( + prefill_result["cache"], name="Prefill Cache" ) del prefill_result return { - "num_logits_params": num_prefill_logits_params, - "total_logits_size": total_prefill_logits_size, - "avg_logits_param_size": avg_prefill_logits_param_size, - "num_cache_params": num_prefill_cache_params, - "total_cache_size": total_prefill_cache_size, - "avg_cache_param_size": avg_prefill_cache_param_size, + "num_logits_params": num_prefill_logits_params, + "total_logits_size": total_prefill_logits_size, + "avg_logits_param_size": avg_prefill_logits_param_size, + "num_cache_params": num_prefill_cache_params, + "total_cache_size": total_prefill_cache_size, + "avg_cache_param_size": avg_prefill_cache_param_size, } @@ -262,54 +255,55 @@ def main(config, inference_metadata: Optional[Dict[str, Any]] = None): for prefill_length in prefill_lengths: prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = token_utils.tokenize_and_pad( - text, vocab, is_bos=True, prefill_lengths=[prefill_length] + text, vocab, is_bos=True, prefill_lengths=[prefill_length] ) benchmark_results["prefill-result-sizes"][prefill_length] = summarize_prefill_result( - engine, params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] + engine, params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] ) for prefill_length in prefill_lengths: benchmark_results["prefill"][prefill_length] = prefill_benchmark( - config, - engine, - params, - prefill_tokens[prefill_length], - prefill_true_lengths[prefill_length], - num_model_params, - benchmark_loop_iters + config, + engine, + params, + prefill_tokens[prefill_length], + prefill_true_lengths[prefill_length], + num_model_params, + benchmark_loop_iters, ) prefill_insert_time, decode_state = prefill_insert_benchmark( - config, - engine, - decode_state, - params, - engine.max_concurrent_decodes, - prefill_tokens[prefill_length], - prefill_true_lengths[prefill_length], - benchmark_loop_iters + config, + engine, + decode_state, + params, + engine.max_concurrent_decodes, + prefill_tokens[prefill_length], + prefill_true_lengths[prefill_length], + benchmark_loop_iters, ) benchmark_results["insert"][prefill_length] = {} benchmark_results["insert"][prefill_length]["time_in_ms"] = ( - prefill_insert_time["time_in_ms"] - benchmark_results["prefill"][prefill_length]["time_in_ms"] + prefill_insert_time["time_in_ms"] - benchmark_results["prefill"][prefill_length]["time_in_ms"] ) if "generate" in stages_to_benchmark: benchmark_results["autoregressive"], decode_state = ar_benchmark( - config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters) + config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters + ) results = collate_results(config, benchmark_results, model_size, cache_size, 
num_model_params) print_results_for_analyze(results) if inference_metadata: flatten_microbenchmark_results = pyconfig.string_to_bool( - inference_metadata.get('flatten_microbenchmark_results', 'false') + inference_metadata.get("flatten_microbenchmark_results", "false") ) else: - flatten_microbenchmark_results = 'false' + flatten_microbenchmark_results = "false" results = write_results( - results, - filename=config.inference_microbenchmark_log_file_path, - flatten_microbenchmark_results=flatten_microbenchmark_results + results, + filename=config.inference_microbenchmark_log_file_path, + flatten_microbenchmark_results=flatten_microbenchmark_results, ) return results diff --git a/MaxText/inference_microbenchmark_sweep.py b/MaxText/inference_microbenchmark_sweep.py index 8f7ddda5b..b641e3fd4 100644 --- a/MaxText/inference_microbenchmark_sweep.py +++ b/MaxText/inference_microbenchmark_sweep.py @@ -42,99 +42,101 @@ def main(): config = pyconfig.config base_run_name = config.run_name - with open(config.inference_metadata_file, encoding='utf-8') as json_file: + with open(config.inference_metadata_file, encoding="utf-8") as json_file: inference_metadata = json.load(json_file) print(f"inference_metadata: {inference_metadata}") - two_axis_order_product_id_list = inference_metadata['two_axis_order_product_id_list'].split(':') - prefill_cache_axis_order_list = inference_metadata['prefill_cache_axis_order_list'].split(':') - ar_cache_axis_order_list = inference_metadata['ar_cache_axis_order_list'].split(':') + two_axis_order_product_id_list = inference_metadata["two_axis_order_product_id_list"].split(":") + prefill_cache_axis_order_list = inference_metadata["prefill_cache_axis_order_list"].split(":") + ar_cache_axis_order_list = inference_metadata["ar_cache_axis_order_list"].split(":") start_two_axis_order_product_id = two_axis_order_product_id_list[0] end_two_axis_order_product_id = two_axis_order_product_id_list[-1] results = [] for ( - two_axis_order_product_id, - prefill_cache_axis_order, - ar_cache_axis_order, + two_axis_order_product_id, + prefill_cache_axis_order, + ar_cache_axis_order, ) in zip( - two_axis_order_product_id_list, - prefill_cache_axis_order_list, - ar_cache_axis_order_list, + two_axis_order_product_id_list, + prefill_cache_axis_order_list, + ar_cache_axis_order_list, ): print(f"two_axis_order_product_id {two_axis_order_product_id}") print(f"prefill_cache_axis_order {prefill_cache_axis_order}") print(f"ar_cache_axis_order {ar_cache_axis_order}") - run_tag = ( - f"{two_axis_order_product_id}-{prefill_cache_axis_order.replace(',','')}-{ar_cache_axis_order.replace(',','')}" - ) + run_tag = f"{two_axis_order_product_id}-{prefill_cache_axis_order.replace(',','')}-{ar_cache_axis_order.replace(',','')}" run_name = f"{base_run_name}/{run_tag}" tensorboard_dir = os.path.join(config.base_output_directory, run_name, "tensorboard", "") - pyconfig._config.keys['prefill_cache_axis_order'] = prefill_cache_axis_order # pylint: disable=protected-access - pyconfig._config.keys['ar_cache_axis_order'] = ar_cache_axis_order # pylint: disable=protected-access - pyconfig._config.keys['tensorboard_dir'] = tensorboard_dir # pylint: disable=protected-access - pyconfig._config.keys['run_name'] = run_name # pylint: disable=protected-access + pyconfig._config.keys["prefill_cache_axis_order"] = prefill_cache_axis_order # pylint: disable=protected-access + pyconfig._config.keys["ar_cache_axis_order"] = ar_cache_axis_order # pylint: disable=protected-access + pyconfig._config.keys["tensorboard_dir"] = 
tensorboard_dir # pylint: disable=protected-access + pyconfig._config.keys["run_name"] = run_name # pylint: disable=protected-access # Prepare metadata (dimensions) json for XLML dimensions_json = { - "base_output_directory": config.base_output_directory, - "model_name": config.model_name, - "tokenizer": config.tokenizer_path, - "weight_dtype": config.weight_dtype, - "inference_microbenchmark_prefill_lengths": f"{config.inference_microbenchmark_prefill_lengths}", - "inference_microbenchmark_stages": config.inference_microbenchmark_stages, - "inference_microbenchmark_loop_iters": f"{config.inference_microbenchmark_loop_iters}", - "max_prefill_predict_length": f"{config.max_prefill_predict_length}", - "max_target_length": f"{config.max_target_length}", - "per_device_batch_size": f"{config.per_device_batch_size}", - "ici_fsdp_parallelism": f"{config.ici_fsdp_parallelism}", - "ici_autoregressive_parallelism": f"{config.ici_autoregressive_parallelism}", - "ici_tensor_parallelism": f"{config.ici_tensor_parallelism}", - "profiler": f"{config.profiler}", - "scan_layers": f"{config.scan_layers}", - "quantization": config.quantization, - "quantize_kvcache": f"{config.quantize_kvcache}", - "attention": config.attention, - "two_axis_order_product_id": f"{two_axis_order_product_id}", - "prefill_cache_axis_order": f"{prefill_cache_axis_order}", - "ar_cache_axis_order": f"{ar_cache_axis_order}", - "compute_axis_order": f"{config.compute_axis_order}", - "reshape_q": f"{config.reshape_q}", - "kv_quant_axis": f"{config.kv_quant_axis}", - "run_name": f"{run_name}", - "run_tag": f"{run_tag}", - "config_json_string": json.dumps( - pyconfig._config.keys, # pylint: disable=protected-access - default=lambda x: f"<>" - ) + "base_output_directory": config.base_output_directory, + "model_name": config.model_name, + "tokenizer": config.tokenizer_path, + "weight_dtype": config.weight_dtype, + "inference_microbenchmark_prefill_lengths": f"{config.inference_microbenchmark_prefill_lengths}", + "inference_microbenchmark_stages": config.inference_microbenchmark_stages, + "inference_microbenchmark_loop_iters": f"{config.inference_microbenchmark_loop_iters}", + "max_prefill_predict_length": f"{config.max_prefill_predict_length}", + "max_target_length": f"{config.max_target_length}", + "per_device_batch_size": f"{config.per_device_batch_size}", + "ici_fsdp_parallelism": f"{config.ici_fsdp_parallelism}", + "ici_autoregressive_parallelism": f"{config.ici_autoregressive_parallelism}", + "ici_tensor_parallelism": f"{config.ici_tensor_parallelism}", + "profiler": f"{config.profiler}", + "scan_layers": f"{config.scan_layers}", + "quantization": config.quantization, + "quantize_kvcache": f"{config.quantize_kvcache}", + "attention": config.attention, + "two_axis_order_product_id": f"{two_axis_order_product_id}", + "prefill_cache_axis_order": f"{prefill_cache_axis_order}", + "ar_cache_axis_order": f"{ar_cache_axis_order}", + "compute_axis_order": f"{config.compute_axis_order}", + "reshape_q": f"{config.reshape_q}", + "kv_quant_axis": f"{config.kv_quant_axis}", + "run_name": f"{run_name}", + "run_tag": f"{run_tag}", + "config_json_string": json.dumps( + pyconfig._config.keys, # pylint: disable=protected-access + default=lambda x: f"<>", + ), } dimensions_json = { - **dimensions_json, - **inference_metadata, + **dimensions_json, + **inference_metadata, } try: microbenchmark_results = inference_microbenchmark.main(config, inference_metadata=inference_metadata) - metrics = microbenchmark_results['flattened_results'] + metrics = 
microbenchmark_results["flattened_results"] metrics = {k.lower(): v for k, v in metrics.items()} - dimensions_json['oom'] = 'False' - print(f"Completed run {two_axis_order_product_id} out of: " - f"{start_two_axis_order_product_id} to {end_two_axis_order_product_id}") + dimensions_json["oom"] = "False" + print( + f"Completed run {two_axis_order_product_id} out of: " + f"{start_two_axis_order_product_id} to {end_two_axis_order_product_id}" + ) except xla_extension.XlaRuntimeError: # OOM metrics = {} - dimensions_json['oom'] = 'True' - print(f"Failed at run {two_axis_order_product_id} out of: " - f"{start_two_axis_order_product_id} to {end_two_axis_order_product_id}") + dimensions_json["oom"] = "True" + print( + f"Failed at run {two_axis_order_product_id} out of: " + f"{start_two_axis_order_product_id} to {end_two_axis_order_product_id}" + ) - final = {'metrics': metrics, 'dimensions': dimensions_json} + final = {"metrics": metrics, "dimensions": dimensions_json} print(f"Result: {final}") results.append(final) print(f"All results {results}") - path = 'inference_microbenchmark_sweep_results.jsonl' + path = "inference_microbenchmark_sweep_results.jsonl" with jsonlines.open(path, mode="w") as writer: writer.write_all(results) diff --git a/MaxText/input_pipeline/_grain_data_processing.py b/MaxText/input_pipeline/_grain_data_processing.py index 0171fd765..69efc3a11 100644 --- a/MaxText/input_pipeline/_grain_data_processing.py +++ b/MaxText/input_pipeline/_grain_data_processing.py @@ -63,13 +63,16 @@ def preprocessing_pipeline( operations.append(_input_pipeline_utils.NormalizeFeatures(data_column, tokenize)) if tokenize: - operations.append(_grain_tokenizer.TokenizeAndTrim(["inputs", "targets"], max_target_length, tokenizer_path, add_bos, add_eos)) + operations.append( + _grain_tokenizer.TokenizeAndTrim(["inputs", "targets"], max_target_length, tokenizer_path, add_bos, add_eos) + ) # Pack and Batch examples. 
if packing: operations.append( grain.experimental.PackAndBatchOperation( - batch_size=global_batch_size // jax.process_count(), length_struct={"inputs": max_target_length, "targets": max_target_length} + batch_size=global_batch_size // jax.process_count(), + length_struct={"inputs": max_target_length, "targets": max_target_length}, ) ) operations.append(_input_pipeline_utils.ReformatPacking()) @@ -103,6 +106,7 @@ def preprocessing_pipeline( # Return multi-host jax.Array prep iterator return multihost_gen + def make_grain_iterator( config: ml_collections.ConfigDict, global_mesh, @@ -111,26 +115,7 @@ def make_grain_iterator( """Load, preprocess dataset and return iterators""" train_ds = get_datasets(config.grain_train_files) train_iter = preprocessing_pipeline( - dataset=train_ds, - tokenizer_path=config.tokenizer_path, - global_batch_size=config.global_batch_size_to_load, - global_mesh=global_mesh, - max_target_length=config.max_target_length, - grain_worker_count=config.grain_worker_count, - dataloading_host_index=process_indices.index(jax.process_index()), - dataloading_host_count=len(process_indices), - data_column=config.train_data_column, - shuffle=config.enable_data_shuffling, - data_shuffle_seed=config.data_shuffle_seed, - tokenize=config.tokenize_train_data, - add_bos=config.add_bos, - add_eos=config.add_eos, - ) - - if config.eval_interval > 0: - eval_ds = get_datasets(config.grain_eval_files) - eval_iter = preprocessing_pipeline( - dataset=eval_ds, + dataset=train_ds, tokenizer_path=config.tokenizer_path, global_batch_size=config.global_batch_size_to_load, global_mesh=global_mesh, @@ -138,12 +123,31 @@ def make_grain_iterator( grain_worker_count=config.grain_worker_count, dataloading_host_index=process_indices.index(jax.process_index()), dataloading_host_count=len(process_indices), - data_column=config.eval_data_column, - shuffle=False, + data_column=config.train_data_column, + shuffle=config.enable_data_shuffling, data_shuffle_seed=config.data_shuffle_seed, - tokenize=config.tokenize_eval_data, + tokenize=config.tokenize_train_data, add_bos=config.add_bos, add_eos=config.add_eos, + ) + + if config.eval_interval > 0: + eval_ds = get_datasets(config.grain_eval_files) + eval_iter = preprocessing_pipeline( + dataset=eval_ds, + tokenizer_path=config.tokenizer_path, + global_batch_size=config.global_batch_size_to_load, + global_mesh=global_mesh, + max_target_length=config.max_target_length, + grain_worker_count=config.grain_worker_count, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + data_column=config.eval_data_column, + shuffle=False, + data_shuffle_seed=config.data_shuffle_seed, + tokenize=config.tokenize_eval_data, + add_bos=config.add_bos, + add_eos=config.add_eos, ) else: eval_iter = None diff --git a/MaxText/input_pipeline/_hf_data_processing.py b/MaxText/input_pipeline/_hf_data_processing.py index e3bffd704..76ef7ce08 100644 --- a/MaxText/input_pipeline/_hf_data_processing.py +++ b/MaxText/input_pipeline/_hf_data_processing.py @@ -73,13 +73,15 @@ def preprocessing_pipeline( else: dataset = dataset.select_columns([data_column_name]) - dataset = _input_pipeline_utils.HFDataSource(dataset, - dataloading_host_index, - dataloading_host_count, - num_threads, - generate_padding_example, - max_target_length, - data_column_name) + dataset = _input_pipeline_utils.HFDataSource( + dataset, + dataloading_host_index, + dataloading_host_count, + num_threads, + generate_padding_example, + max_target_length, + 
data_column_name, + ) operations = [] operations.append(_input_pipeline_utils.HFNormalizeFeatures(data_column_name)) @@ -125,11 +127,12 @@ def preprocessing_pipeline( # Return multi-host jax.Array prep iterator return multihost_gen + def make_hf_iterator( config: ml_collections.ConfigDict, global_mesh, process_indices, - ): +): """Load, preprocess dataset and return iterators""" train_ds = datasets.load_dataset( config.hf_path, @@ -140,31 +143,31 @@ def make_hf_iterator( token=config.hf_access_token, ) train_iter = preprocessing_pipeline( - dataloading_host_index=process_indices.index(jax.process_index()), - dataloading_host_count=len(process_indices), - global_mesh=global_mesh, - dataset=train_ds, - data_column_name=config.train_data_column, - tokenize=config.tokenize_train_data, - tokenizer_path=config.tokenizer_path, - hf_access_token=config.hf_access_token, - global_batch_size=config.global_batch_size_to_load, - max_target_length=config.max_target_length, - shuffle=config.enable_data_shuffling, - data_shuffle_seed=config.data_shuffle_seed, - add_bos=config.add_bos, - add_eos=config.add_eos, - generate_padding_example=True, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + global_mesh=global_mesh, + dataset=train_ds, + data_column_name=config.train_data_column, + tokenize=config.tokenize_train_data, + tokenizer_path=config.tokenizer_path, + hf_access_token=config.hf_access_token, + global_batch_size=config.global_batch_size_to_load, + max_target_length=config.max_target_length, + shuffle=config.enable_data_shuffling, + data_shuffle_seed=config.data_shuffle_seed, + add_bos=config.add_bos, + add_eos=config.add_eos, + generate_padding_example=True, ) if config.eval_interval > 0: eval_ds = datasets.load_dataset( - config.hf_path, - data_dir=config.hf_data_dir, - data_files=config.hf_eval_files, - split=config.hf_eval_split, - streaming=True, - token=config.hf_access_token, + config.hf_path, + data_dir=config.hf_data_dir, + data_files=config.hf_eval_files, + split=config.hf_eval_split, + streaming=True, + token=config.hf_access_token, ) if config.eval_per_device_batch_size > 0: eval_batch_size = config.eval_per_device_batch_size * global_mesh.size @@ -172,25 +175,25 @@ def make_hf_iterator( eval_batch_size = config.global_batch_size_to_load if config.eval_steps > 0: - eval_generate_padding_example=True + eval_generate_padding_example = True else: - eval_generate_padding_example=False + eval_generate_padding_example = False eval_iter = preprocessing_pipeline( - dataloading_host_index=process_indices.index(jax.process_index()), - dataloading_host_count=len(process_indices), - global_mesh=global_mesh, - dataset=eval_ds, - data_column_name=config.eval_data_column, - tokenize=config.tokenize_eval_data, - tokenizer_path=config.tokenizer_path, - hf_access_token=config.hf_access_token, - global_batch_size=eval_batch_size, - max_target_length=config.max_target_length, - shuffle=False, - data_shuffle_seed=config.data_shuffle_seed, - add_bos=config.add_bos, - add_eos=config.add_eos, - generate_padding_example=eval_generate_padding_example, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + global_mesh=global_mesh, + dataset=eval_ds, + data_column_name=config.eval_data_column, + tokenize=config.tokenize_eval_data, + tokenizer_path=config.tokenizer_path, + hf_access_token=config.hf_access_token, + global_batch_size=eval_batch_size, + 
max_target_length=config.max_target_length, + shuffle=False, + data_shuffle_seed=config.data_shuffle_seed, + add_bos=config.add_bos, + add_eos=config.add_eos, + generate_padding_example=eval_generate_padding_example, ) else: eval_iter = None diff --git a/MaxText/input_pipeline/_input_pipeline_utils.py b/MaxText/input_pipeline/_input_pipeline_utils.py index 4b34c802e..ddaec7e3c 100644 --- a/MaxText/input_pipeline/_input_pipeline_utils.py +++ b/MaxText/input_pipeline/_input_pipeline_utils.py @@ -33,26 +33,32 @@ ########## Functions used by TFDS pipeline + def normalize_features(x, column_name): return {"inputs": x[column_name], "targets": x[column_name]} + def get_tokenizer(tokenizer_path, add_bos, add_eos): # Load tokenizer tokenizer_model = tokenizer.build_tokenizer(tokenizer_path, add_bos, add_eos) return tokenizer_model + def truncate_to_max_allowable_length(x, max_length): x["inputs"] = x["inputs"][:max_length] x["targets"] = x["targets"][:max_length] return x + def shift_data_by_truncation(x): x["inputs"] = x["inputs"][:-1] x["targets"] = x["targets"][1:] return x + ########## Functions used by HF pipeline + def tokenization(example, hf_tokenizer, max_length, column_name): """Tokenize a HuggingFace dataset""" return hf_tokenizer(example[column_name], truncation=True, max_length=max_length) @@ -61,8 +67,10 @@ def tokenization(example, hf_tokenizer, max_length, column_name): @dataclasses.dataclass class HFNormalizeFeatures(grain.MapTransform): """Normalize feature keys for HuggingFace input""" + def __init__(self, column_name): self.column_name = column_name + def map(self, features): return { "inputs": np.asarray(features[self.column_name], dtype=np.int32), @@ -73,15 +81,16 @@ def map(self, features): class HFDataSource(grain.RandomAccessDataSource): """A class that makes HuggingFace IterableDataset a grain datasource without random access support""" - def __init__(self, - dataset: datasets.IterableDataset, - dataloading_host_index: int, - dataloading_host_count: int, - num_threads: int, - generate_padding_example: bool, - max_target_length: int, - data_column_name: str - ): + def __init__( + self, + dataset: datasets.IterableDataset, + dataloading_host_index: int, + dataloading_host_count: int, + num_threads: int, + generate_padding_example: bool, + max_target_length: int, + data_column_name: str, + ): self.dataset = dataset self.num_threads = num_threads self.dataloading_host_count = dataloading_host_count @@ -94,14 +103,15 @@ def __init__(self, self.dataset_shards = [dataloading_host_index * self.num_threads + i for i in range(self.num_threads)] self.datasets = [split_dataset_by_node(dataset, world_size=self.n_shards, rank=x) for x in self.dataset_shards] self.data_iters = [] - self.out_of_data =False + self.out_of_data = False def _check_shard_count(self): if self.n_shards < (self.dataloading_host_count * self.num_threads): - warnings.warn(f"WARNING: Inefficient dataloading. Your train or eval dataset contains {self.n_shards} shards, " - "smaller than number of host loading data. This is known to lead to inefficient dataloading. " - "see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice" - ) + warnings.warn( + f"WARNING: Inefficient dataloading. Your train or eval dataset contains {self.n_shards} shards, " + "smaller than number of host loading data. This is known to lead to inefficient dataloading. 
" + "see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice" + ) self.n_shards = self.dataloading_host_count * self.num_threads def _update_shard(self, idx): @@ -113,11 +123,14 @@ def _update_shard(self, idx): self.datasets[idx] = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.dataset_shards[idx]) self.data_iters[idx] = iter(self.datasets[idx]) else: - max_logging.log(f"Run out of shards on host {self.dataloading_host_index}, shard {self.dataset_shards[idx]} is not available") + max_logging.log( + f"Run out of shards on host {self.dataloading_host_index}, shard {self.dataset_shards[idx]} is not available" + ) self.out_of_data = True if self.generate_padding_example: - max_logging.log(f"Host {self.dataloading_host_index} will start generating all-0 padding examples until step number is met.") - + max_logging.log( + f"Host {self.dataloading_host_index} will start generating all-0 padding examples until step number is met." + ) def __len__(self): """Return length of the HF dataset. Since HuggingFace IterableDataset does not have length, @@ -143,11 +156,14 @@ def __getitem__(self, index): except StopIteration: self._update_shard(idx) + ########## Functions used by Grain pipeline + @dataclasses.dataclass class ParseFeatures(grain.MapTransform): """Parse serialized example""" + def __init__(self, data_column, tokenize): self.data_column = data_column if tokenize: @@ -157,23 +173,28 @@ def __init__(self, data_column, tokenize): def map(self, features): def _parse(example): - parsed = tf.io.parse_example(example, { - self.data_column: tf.io.FixedLenSequenceFeature([], dtype=self.dtype, allow_missing=True) - }) + parsed = tf.io.parse_example( + example, {self.data_column: tf.io.FixedLenSequenceFeature([], dtype=self.dtype, allow_missing=True)} + ) return parsed return _parse(features) + @dataclasses.dataclass class NormalizeFeatures(grain.MapTransform): """Normalize text feature keys.""" + def __init__(self, column_name, tokenize): self.column_name = column_name self.tokenize = tokenize def map(self, features): if self.tokenize: - return {"inputs": features[self.column_name].numpy()[0].decode(), "targets": features[self.column_name].numpy()[0].decode()} + return { + "inputs": features[self.column_name].numpy()[0].decode(), + "targets": features[self.column_name].numpy()[0].decode(), + } else: return {"inputs": features[self.column_name].numpy(), "targets": features[self.column_name].numpy()} @@ -252,4 +273,3 @@ def __init__(self, axis=1): def map(self, data): return shift_and_refine(data, axis=self.axis) - diff --git a/MaxText/input_pipeline/_tfds_data_processing.py b/MaxText/input_pipeline/_tfds_data_processing.py index 46e5db05d..5fdd6c16c 100644 --- a/MaxText/input_pipeline/_tfds_data_processing.py +++ b/MaxText/input_pipeline/_tfds_data_processing.py @@ -31,6 +31,7 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE + def get_datasets( dataset_name, data_split, @@ -54,10 +55,11 @@ def get_datasets( ) ds = ds_builder.as_dataset(split=data_split, read_config=read_config, shuffle_files=shuffle_files) else: - warnings.warn(f"WARNING: Inefficient dataloading. Your {dataset_name} contains {ds_builder.info.splits[data_split].num_shards} shards, " - f"smaller than {dataloading_host_count=}. This is known to lead to inefficient dataloading." 
- "see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice" - ) + warnings.warn( + f"WARNING: Inefficient dataloading. Your {dataset_name} contains {ds_builder.info.splits[data_split].num_shards} shards, " + f"smaller than {dataloading_host_count=}. This is known to lead to inefficient dataloading." + "see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice" + ) ds = ds_builder.as_dataset(split=data_split, read_config=read_config, shuffle_files=shuffle_files) ds = ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) @@ -93,7 +95,10 @@ def preprocessing_pipeline( if max_target_length > 0: # We can take upto max_length+1 because there would be truncation by 1 token # for both inputs and targets - dataset = dataset.map(lambda x: _input_pipeline_utils.truncate_to_max_allowable_length(x, max_target_length + 1), num_parallel_calls=AUTOTUNE) + dataset = dataset.map( + lambda x: _input_pipeline_utils.truncate_to_max_allowable_length(x, max_target_length + 1), + num_parallel_calls=AUTOTUNE, + ) # Shuffle and repeat. if shuffle: @@ -103,7 +108,9 @@ def preprocessing_pipeline( # Shift inputs for teacher-forced training if shift: - dataset = dataset.map(_input_pipeline_utils.shift_data_by_truncation, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True) + dataset = dataset.map( + _input_pipeline_utils.shift_data_by_truncation, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True + ) # Perform greedy sequence packing and batching assert global_batch_size % global_mesh.size == 0, "Batch size should be divisible number of global devices." @@ -135,35 +142,35 @@ def make_tfds_iterator( ): """load dataset, preprocess and return iterators""" train_ds = get_datasets( - dataset_name=config.dataset_name, - data_split='train', - shuffle_files=config.enable_data_shuffling, - shuffle_seed=config.data_shuffle_seed, - dataloading_host_index=process_indices.index(jax.process_index()), - dataloading_host_count=len(process_indices), + dataset_name=config.dataset_name, + data_split="train", + shuffle_files=config.enable_data_shuffling, + shuffle_seed=config.data_shuffle_seed, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), ) train_iter = preprocessing_pipeline( - dataset=train_ds, - tokenizer_path=config.tokenizer_path, - global_batch_size=config.global_batch_size_to_load, - global_mesh=global_mesh, - max_target_length=config.max_target_length, - data_column_name=config.train_data_column, - shuffle=config.enable_data_shuffling, - data_shuffle_seed=config.data_shuffle_seed, - tokenize=config.tokenize_train_data, - add_bos=config.add_bos, - add_eos=config.add_eos, + dataset=train_ds, + tokenizer_path=config.tokenizer_path, + global_batch_size=config.global_batch_size_to_load, + global_mesh=global_mesh, + max_target_length=config.max_target_length, + data_column_name=config.train_data_column, + shuffle=config.enable_data_shuffling, + data_shuffle_seed=config.data_shuffle_seed, + tokenize=config.tokenize_train_data, + add_bos=config.add_bos, + add_eos=config.add_eos, ) if config.eval_interval > 0: eval_ds = get_datasets( - dataset_name=config.eval_dataset_name, - data_split=config.eval_split, - shuffle_files=False, - shuffle_seed=config.data_shuffle_seed, - dataloading_host_index=process_indices.index(jax.process_index()), - dataloading_host_count=len(process_indices), + 
dataset_name=config.eval_dataset_name, + data_split=config.eval_split, + shuffle_files=False, + shuffle_seed=config.data_shuffle_seed, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), ) if config.eval_per_device_batch_size > 0: diff --git a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py index 9ce2df0bd..d238800cd 100644 --- a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py +++ b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py @@ -221,11 +221,11 @@ def get_dataset( if shard_in_read: # shard dataset in reading read_config = tfds.ReadConfig( - shuffle_seed = data_shuffle_seed, - input_context = tf.distribute.InputContext( - input_pipeline_id=dataloading_host_index, - num_input_pipelines=dataloading_host_count, - ), + shuffle_seed=data_shuffle_seed, + input_context=tf.distribute.InputContext( + input_pipeline_id=dataloading_host_index, + num_input_pipelines=dataloading_host_count, + ), ) ds_builder = tfds.builder(dataset_name) ds_builder.download_and_prepare() @@ -260,7 +260,9 @@ def preprocess_train_dataset( data_shuffle_seed: int, ) -> tf.data.Dataset: """Preprocess the training dataset.""" - train_ds = train_ds.map(lambda x: tokenizer.TokenizeOp(tokenizer=sp_tokenizer, features=x, data_keys=("targets",)), num_parallel_calls=AUTOTUNE) + train_ds = train_ds.map( + lambda x: tokenizer.TokenizeOp(tokenizer=sp_tokenizer, features=x, data_keys=("targets",)), num_parallel_calls=AUTOTUNE + ) train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096) train_ds = split_tokens_to_targets_length(train_ds, max_target_length) @@ -299,20 +301,20 @@ def preprocess_eval_dataset( def make_c4_mlperf_train_iterator( - config: ml_collections.ConfigDict, - global_mesh, - add_bos, - add_eos, - process_indices, + config: ml_collections.ConfigDict, + global_mesh, + add_bos, + add_eos, + process_indices, ): """Make train iterator of customized C4 dataset for mlperf gpt3 training.""" train_ds = get_dataset( - dataset_name=config.dataset_name, - split="train2", - dataloading_host_index=process_indices.index(jax.process_index()), - dataloading_host_count=len(process_indices), - enable_data_shuffling=config.enable_data_shuffling, - data_shuffle_seed=config.data_shuffle_seed, + dataset_name=config.dataset_name, + split="train2", + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + enable_data_shuffling=config.enable_data_shuffling, + data_shuffle_seed=config.data_shuffle_seed, ) train_ds = rekey(train_ds, {"inputs": None, "targets": "text"}) @@ -330,17 +332,17 @@ def make_c4_mlperf_train_iterator( def make_c4_mlperf_eval_iterator( - config: ml_collections.ConfigDict, - global_mesh, - process_indices, + config: ml_collections.ConfigDict, + global_mesh, + process_indices, ): """Make eval iterator of customized C4 dataset for mlperf gpt3 training.""" eval_ds = get_dataset( - dataset_name=config.eval_dataset_name, - split="validation_tokenized_5662seqs", - dataloading_host_index=process_indices.index(jax.process_index()), - dataloading_host_count=len(process_indices), - enable_data_shuffling=False, + dataset_name=config.eval_dataset_name, + split="validation_tokenized_5662seqs", + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + enable_data_shuffling=False, ) # note validation_tokenized_5662seqs split is pre tokenized, 
reduce_concated and split to target_length # mainly to avoid eval sequences change depending on the number of hosts diff --git a/MaxText/input_pipeline/input_pipeline_interface.py b/MaxText/input_pipeline/input_pipeline_interface.py index 67b7dfb12..b499cc839 100644 --- a/MaxText/input_pipeline/input_pipeline_interface.py +++ b/MaxText/input_pipeline/input_pipeline_interface.py @@ -144,7 +144,9 @@ def make_c4_mlperf_iterator(config, mesh): assert len(process_indices) == jax.process_count() // config.expansion_factor_real_data print("Overwrite both add_bos and add_eos to False") if jax.process_index() in process_indices: - train_iterator = make_c4_mlperf_train_iterator(config, mesh, add_bos=False, add_eos=False, process_indices=process_indices) + train_iterator = make_c4_mlperf_train_iterator( + config, mesh, add_bos=False, add_eos=False, process_indices=process_indices + ) else: train_iterator = BadSyntheticDataIterator(config, mesh) @@ -153,7 +155,9 @@ def make_c4_mlperf_iterator(config, mesh): else: effective_eval_per_device_batch_size = config.per_device_batch_size - assert effective_eval_per_device_batch_size >= 1.0, f"{effective_eval_per_device_batch_size=} is less than 1, which is not supported." + assert ( + effective_eval_per_device_batch_size >= 1.0 + ), f"{effective_eval_per_device_batch_size=} is less than 1, which is not supported." # Use all processes for evaluation until split is handled independently eval_process_indices = list(range(jax.process_count())) eval_iterator = make_c4_mlperf_eval_iterator(config, mesh, eval_process_indices) diff --git a/MaxText/kernels/ragged_attention.py b/MaxText/kernels/ragged_attention.py index 21cab43d8..3a5778145 100644 --- a/MaxText/kernels/ragged_attention.py +++ b/MaxText/kernels/ragged_attention.py @@ -35,10 +35,10 @@ @functools.partial(jax.jit, static_argnames=["mask_value"]) def reference_mqa( - q: jax.Array, - k: jax.Array, - v: jax.Array, - lengths: jax.Array, + q: jax.Array, + k: jax.Array, + v: jax.Array, + lengths: jax.Array, *, mask_value: float = DEFAULT_MASK_VALUE, ) -> tuple[jax.Array, jax.Array, jax.Array]: @@ -57,22 +57,18 @@ def reference_mqa( max logit ([batch_size, num_heads]) and softmax denominator ([batch_size, num_heads]). 
""" - logits = jnp.einsum( - "bhd,btd->bht", q.astype(jnp.float32), k.astype(jnp.float32) - ) - mask = jnp.arange(k.shape[1])[None] < lengths[:, None] + logits = jnp.einsum("bhd,btd->bht", q.astype(jnp.float32), k.astype(jnp.float32)) + mask = jnp.arange(k.shape[1])[None] < lengths[:, None] - logits = logits + jnp.where(mask, 0.0, mask_value)[:, None] - logits_max = logits.max(axis=-1) + logits = logits + jnp.where(mask, 0.0, mask_value)[:, None] + logits_max = logits.max(axis=-1) - unnormalized = jnp.exp(logits - logits_max[..., None]) - denominator = unnormalized.sum(axis=-1) - o = ( - jnp.einsum("bht,btd->bhd", unnormalized.astype(v.dtype), v) - / denominator[..., None] - ) + unnormalized = jnp.exp(logits - logits_max[..., None]) + denominator = unnormalized.sum(axis=-1) + o = jnp.einsum("bht,btd->bhd", unnormalized.astype(v.dtype), v) / denominator[..., None] return o, logits_max[..., None], denominator[..., None] + @jax.jit def reference_mha( q: jax.Array, @@ -100,11 +96,9 @@ def reference_mha( q = jnp.swapaxes(q, 1, 2) k = jnp.swapaxes(k, 1, 2) v = jnp.swapaxes(v, 1, 2) - return jax.vmap(functools.partial( - reference_mqa, - mask_value=mask_value), - in_axes=(1, 1, 1, None), - out_axes=2)(q, k, v, lengths) + return jax.vmap(functools.partial(reference_mqa, mask_value=mask_value), in_axes=(1, 1, 1, None), out_axes=2)( + q, k, v, lengths + ) @functools.partial(jax.jit, static_argnames=["mask_value"]) @@ -137,23 +131,19 @@ def reference_gqa( q = q.reshape(batch_size, num_heads_kv, num_heads_q // num_heads_kv, head_dim) - logits = jnp.einsum( - "bhgd,bhtd->bhgt", q.astype(jnp.float32), k.astype(jnp.float32) - ) + logits = jnp.einsum("bhgd,bhtd->bhgt", q.astype(jnp.float32), k.astype(jnp.float32)) mask = jnp.arange(seq_len)[None] < lengths[:, None] logits = logits + jnp.where(mask, 0.0, mask_value)[:, None, None, :] logits_max = logits.max(axis=-1) unnormalized = jnp.exp(logits - logits_max[..., None]) denominator = unnormalized.sum(axis=-1) - o = ( - jnp.einsum("bhgt,bhtd->bhgd", unnormalized.astype(v.dtype), v) - / denominator[..., None] - ) + o = jnp.einsum("bhgt,bhtd->bhgd", unnormalized.astype(v.dtype), v) / denominator[..., None] logits_max = logits_max.reshape(batch_size, 1, num_heads_q, 1) denominator = denominator.reshape(batch_size, 1, num_heads_q, 1) o = o.reshape(batch_size, 1, num_heads_q, head_dim) return o, logits_max, denominator + def ragged_flash_attention_kernel( lengths_ref, q_ref, @@ -184,10 +174,8 @@ def run(): v = v_ref[...].astype(jnp.float32) m_prev, l_prev = m_ref[...], l_ref[...] - qk = lax.dot_general( - q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 - ) - + qk = lax.dot_general(q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32) + mask = i * block_size + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length qk = qk + jnp.where(mask, 0.0, mask_value) m_curr = qk.max(axis=-1) @@ -204,9 +192,7 @@ def run(): l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) m_ref[...], l_ref[...] = m_next, l_next_safe - o_ref[...] = ( - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe - ).astype(o_ref.dtype) + o_ref[...] = ((l_prev * alpha * o_ref[...] 
+ beta * o_curr_times_l_curr) / l_next_safe).astype(o_ref.dtype) def ragged_mqa( @@ -214,7 +200,7 @@ def ragged_mqa( k: jax.Array, v: jax.Array, lengths: jax.Array, - *, + *, block_size: int = 256, mask_value: float = DEFAULT_MASK_VALUE, cost_estimate: pltpu.CostEstimate | None = None, @@ -235,10 +221,10 @@ def ragged_mqa( max logit ([batch_size, num_heads, 1]) and softmax denominator ([batch_size, num_heads, 1]). """ - batch_size, num_heads, head_dim = q.shape + batch_size, num_heads, head_dim = q.shape assert lengths.shape == (batch_size,) assert lengths.dtype == jnp.int32 - seq_len = k.shape[1] + seq_len = k.shape[1] def compute_ragged_block_indices(b, i, lengths_ref): length = lengths_ref[b] @@ -258,34 +244,22 @@ def compute_ragged_block_indices(b, i, lengths_ref): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, in_specs=[ - pl.BlockSpec( - (None, num_heads, head_dim), - lambda b, i, _: (b, 0, 0)), - pl.BlockSpec( - (None, block_size, head_dim), - compute_ragged_block_indices), - pl.BlockSpec( - (None, block_size, head_dim), - compute_ragged_block_indices), + pl.BlockSpec((None, num_heads, head_dim), lambda b, i, _: (b, 0, 0)), + pl.BlockSpec((None, block_size, head_dim), compute_ragged_block_indices), + pl.BlockSpec((None, block_size, head_dim), compute_ragged_block_indices), ], out_specs=[ - pl.BlockSpec( - (None, num_heads, head_dim), - lambda b, i, _: (b, 0, 0)), - pl.BlockSpec( - (None, num_heads, head_dim), - lambda b, i, _: (b, 0, 0)), - pl.BlockSpec( - (None, num_heads, head_dim), - lambda b, i, _: (b, 0, 0)), + pl.BlockSpec((None, num_heads, head_dim), lambda b, i, _: (b, 0, 0)), + pl.BlockSpec((None, num_heads, head_dim), lambda b, i, _: (b, 0, 0)), + pl.BlockSpec((None, num_heads, head_dim), lambda b, i, _: (b, 0, 0)), ], grid=(batch_size, seq_len // block_size), ), compiler_params=dict( - mosaic=dict( - dimension_semantics=("parallel", "arbitrary"), - cost_estimate=cost_estimate, - ) + mosaic=dict( + dimension_semantics=("parallel", "arbitrary"), + cost_estimate=cost_estimate, + ) ), out_shape=[ jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), @@ -297,20 +271,20 @@ def compute_ragged_block_indices(b, i, lengths_ref): @functools.partial( - jax.jit, - static_argnames=[ - "block_size", - "mask_value", - ], + jax.jit, + static_argnames=[ + "block_size", + "mask_value", + ], ) def ragged_mha( - query: jax.Array, - key: jax.Array, - value: jax.Array, - lengths: jax.Array, - *, - block_size: int = 256, - mask_value: float = DEFAULT_MASK_VALUE, + query: jax.Array, + key: jax.Array, + value: jax.Array, + lengths: jax.Array, + *, + block_size: int = 256, + mask_value: float = DEFAULT_MASK_VALUE, ) -> tuple[jax.Array, jax.Array, jax.Array]: """Ragged multi head attention. @@ -329,56 +303,56 @@ def ragged_mha( num_heads, 1]). 
""" cost_analysis = ( - reference_mha.lower( - query, - key, - value, - lengths, - mask_value=mask_value, - ) - .compile() - .cost_analysis()[0] + reference_mha.lower( + query, + key, + value, + lengths, + mask_value=mask_value, + ) + .compile() + .cost_analysis()[0] ) cost_estimate = pltpu.CostEstimate( - flops=int(cost_analysis["flops"]), - transcendentals=int(cost_analysis["transcendentals"]), - bytes_accessed=int(cost_analysis["bytes accessed"]), + flops=int(cost_analysis["flops"]), + transcendentals=int(cost_analysis["transcendentals"]), + bytes_accessed=int(cost_analysis["bytes accessed"]), ) query = jnp.swapaxes(query, 1, 2) key = jnp.swapaxes(key, 1, 2) value = jnp.swapaxes(value, 1, 2) - o, m, l = jax.vmap( - functools.partial( - ragged_mqa, - block_size=block_size, - mask_value=mask_value, - cost_estimate=cost_estimate, - ), - in_axes=(1, 1, 1, None), - out_axes=2, + o, m, l = jax.vmap( + functools.partial( + ragged_mqa, + block_size=block_size, + mask_value=mask_value, + cost_estimate=cost_estimate, + ), + in_axes=(1, 1, 1, None), + out_axes=2, )(query, key, value, lengths) m = jnp.expand_dims(m, axis=-1) l = jnp.expand_dims(l, axis=-1) - o = o * l + o = o * l return o, m, l @functools.partial( - jax.jit, - static_argnames=[ - "block_size", - "mask_value", - ], + jax.jit, + static_argnames=[ + "block_size", + "mask_value", + ], ) def ragged_gqa( - query: jax.Array, - key: jax.Array, - value: jax.Array, - lengths: jax.Array, - *, - block_size: int = 256, - mask_value: float = DEFAULT_MASK_VALUE, + query: jax.Array, + key: jax.Array, + value: jax.Array, + lengths: jax.Array, + *, + block_size: int = 256, + mask_value: float = DEFAULT_MASK_VALUE, ) -> tuple[jax.Array, jax.Array, jax.Array]: """Ragged group query attention. @@ -397,40 +371,40 @@ def ragged_gqa( num_heads, 1]). 
""" cost_analysis = ( - reference_gqa.lower( - jnp.squeeze(query), - jnp.swapaxes(key, 1, 2), - jnp.swapaxes(value, 1, 2), - lengths, - mask_value=mask_value, - ) - .compile() - .cost_analysis()[0] + reference_gqa.lower( + jnp.squeeze(query), + jnp.swapaxes(key, 1, 2), + jnp.swapaxes(value, 1, 2), + lengths, + mask_value=mask_value, + ) + .compile() + .cost_analysis()[0] ) cost_estimate = pltpu.CostEstimate( - flops=int(cost_analysis["flops"]), - transcendentals=int(cost_analysis["transcendentals"]), - bytes_accessed=int(cost_analysis["bytes accessed"]), + flops=int(cost_analysis["flops"]), + transcendentals=int(cost_analysis["transcendentals"]), + bytes_accessed=int(cost_analysis["bytes accessed"]), ) batch_size, _, num_heads_q, head_dim = query.shape _, _, num_heads_kv, _ = key.shape - + query = query.reshape(batch_size, num_heads_kv, num_heads_q // num_heads_kv, head_dim) # (b, n_kv, n_q // n_kv, d) - key = jnp.swapaxes(key, 1, 2) # (b, n_kv, s, d) - value = jnp.swapaxes(value, 1, 2) # (b, n_kv, s, d) - o, m, l = jax.vmap( - functools.partial( - ragged_mqa, - block_size=block_size, - mask_value=mask_value, - cost_estimate=cost_estimate, - ), - in_axes=(1, 1, 1, None), - out_axes=1, + key = jnp.swapaxes(key, 1, 2) # (b, n_kv, s, d) + value = jnp.swapaxes(value, 1, 2) # (b, n_kv, s, d) + o, m, l = jax.vmap( + functools.partial( + ragged_mqa, + block_size=block_size, + mask_value=mask_value, + cost_estimate=cost_estimate, + ), + in_axes=(1, 1, 1, None), + out_axes=1, )(query, key, value, lengths) m = jnp.reshape(m, (batch_size, 1, num_heads_q, 1)) l = jnp.reshape(l, (batch_size, 1, num_heads_q, 1)) o = jnp.reshape(o, (batch_size, 1, num_heads_q, head_dim)) - o = o * l - return o, m, l \ No newline at end of file + o = o * l + return o, m, l diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index cb75afbaf..4ea8cfb3e 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -89,7 +89,7 @@ class AttentionType(enum.Enum): def validate_compute_axis_order(s: AxisIdxes) -> None: - valid_compute_axis_order = ((0,1,2,3), (0,2,1,3)) + valid_compute_axis_order = ((0, 1, 2, 3), (0, 2, 1, 3)) if s not in valid_compute_axis_order: # currently supported compute_axis_order raise ValueError("Invalid compute_axis_order was passed. Valid options ", valid_compute_axis_order) @@ -148,7 +148,7 @@ class AttentionOp(nn.Module): use_ragged_attention: bool = False ragged_block_size: int = 256 - def check_attention_inputs(self, query: Array, key: Array| KVTensor, value: Array| KVTensor) -> None: + def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None: """Check attention inputs.""" assert key.ndim == value.ndim, "k, v must have same rank." 
@@ -189,19 +189,24 @@ def generate_attention_mask(self, query, key, decoder_segment_ids: Array | None, if self.attention_type == AttentionType.LOCAL_SLIDING and output_mask is not None: if self.sliding_window_size is None: - raise ValueError( - 'Sliding_window_size must be set if Local Sliding attention type' - ) + raise ValueError("Sliding_window_size must be set if Local Sliding attention type") all_ones = jnp.ones_like(output_mask) - sliding_mask = jnp.triu( - all_ones, -1 * self.sliding_window_size + 1 - ) * jnp.tril(all_ones, self.sliding_window_size - 1) + sliding_mask = jnp.triu(all_ones, -1 * self.sliding_window_size + 1) * jnp.tril(all_ones, self.sliding_window_size - 1) output_mask = sliding_mask * output_mask return jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE) if output_mask is not None else None - def apply_attention(self, query: Array, key: Array | KVTensor, value: Array | KVTensor, decoder_segment_ids: Array | None, lengths: Array | None, model_mode: str, use_ragged_attention: bool = False): + def apply_attention( + self, + query: Array, + key: Array | KVTensor, + value: Array | KVTensor, + decoder_segment_ids: Array | None, + lengths: Array | None, + model_mode: str, + use_ragged_attention: bool = False, + ): self.check_attention_inputs(query, key, value) length = query.shape[-3] if use_ragged_attention and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: @@ -241,13 +246,15 @@ def apply_attention(self, query: Array, key: Array | KVTensor, value: Array | KV else: raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.") - - def ragged_attention(self, query: Array, key: Array | KVTensor, value: Array | KVTensor, lengths: Array, block_size: int) -> tuple[Array, Array, Array]: + def ragged_attention( + self, query: Array, key: Array | KVTensor, value: Array | KVTensor, lengths: Array, block_size: int + ) -> tuple[Array, Array, Array]: """Ragged Attention.""" if isinstance(query, KVTensor) or isinstance(query, KVTensor): raise TypeError("Ragged attention does not currently support quantized tensors.") b = nn.logical_to_mesh_axes(self.ragged_lengths_names) bsnd = nn.logical_to_mesh_axes(self.cache_logical_axis_names) + @functools.partial( shard_map, mesh=self.mesh, @@ -269,7 +276,14 @@ def wrap_ragged_attention(query, key, value, lengths, block_size): return wrap_ragged_attention(query, key, value, lengths, block_size) - def tpu_flash_attention(self, query: Array, key: Array, value: Array, decoder_segment_ids: Array | None, attn_logits_soft_cap: float | None = None) -> Array: + def tpu_flash_attention( + self, + query: Array, + key: Array, + value: Array, + decoder_segment_ids: Array | None, + attn_logits_soft_cap: float | None = None, + ) -> Array: """TPU Flash Attention.""" # Transpose to ('batch', 'heads', 'length', 'kv') query = jnp.transpose(query, axes=(0, 2, 1, 3)) @@ -306,11 +320,9 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids): block_q=min(global_block_q, query.shape[2]), block_kv_compute=min(global_block_q, key.shape[2]), block_kv=min(global_block_q, key.shape[2]), - block_q_dkv=min(global_block_q_dkv, query.shape[2]), block_kv_dkv=min(global_block_q_dkv, key.shape[2]), block_kv_dkv_compute=min(global_block_q_dkv, query.shape[2]), - block_q_dq=min(global_block_q_dq, query.shape[2]), block_kv_dq=min(global_block_q_dq, query.shape[2]), ) @@ -320,9 +332,7 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids): # Apply local masking if local sliding attention is enabled. 
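# Illustrative example (not from the original patch): both the dense sliding mask in the
# generate_attention_mask hunk above and the LocalMask constructed below keep only keys
# within the sliding window of each query position. For instance, with
# sliding_window_size=2 on a 4x4 all-ones mask,
# jnp.triu(all_ones, -1) * jnp.tril(all_ones, 1) keeps exactly the band |i - j| <= 1:
#   [[1, 1, 0, 0],
#    [1, 1, 1, 0],
#    [0, 1, 1, 1],
#    [0, 0, 1, 1]]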
if self.attention_type == AttentionType.LOCAL_SLIDING: if self.sliding_window_size is None: - raise ValueError( - 'Sliding_window_size must be set if Local Sliding attention type' - ) + raise ValueError("Sliding_window_size must be set if Local Sliding attention type") mask &= splash_attention_mask.LocalMask( shape=(query.shape[2], query.shape[2]), window_size=(self.sliding_window_size, self.sliding_window_size), @@ -330,9 +340,13 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids): ) # Create multi-head mask - multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes, attn_logits_soft_cap=attn_logits_soft_cap, + mask=multi_head_mask, + head_shards=1, + q_seq_shards=1, + block_sizes=block_sizes, + attn_logits_soft_cap=attn_logits_soft_cap, ) return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids) @@ -381,7 +395,9 @@ def cudnn_flash_attention( ) return dpa_layer(query, key, value, mask=attn_mask) - def compute_local_attention(self, attn_weights: Array, value: Array | KVTensor, q_seq_len: int, model_mode: str) -> tuple[Array, Array, Array]: + def compute_local_attention( + self, attn_weights: Array, value: Array | KVTensor, q_seq_len: int, model_mode: str + ) -> tuple[Array, Array, Array]: """Computes the attention of a local subset of the kv cache. Local attention results will need to be combined with any other local attentions and normalized Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py @@ -410,17 +426,17 @@ def compute_local_attention(self, attn_weights: Array, value: Array | KVTensor, local_out = self.wv_product(local_exps, value, model_mode) if self.reshape_q and q_seq_len == 1: - local_max = local_max[:,0:1,:,:] - local_sum = local_sum[:,0:1,:,:] - local_out = local_out[:,0:1,:,:] + local_max = local_max[:, 0:1, :, :] + local_sum = local_sum[:, 0:1, :, :] + local_out = local_out[:, 0:1, :, :] return local_out, local_max, local_sum def apply_attention_dot( self, query: Array, - key: Array| KVTensor, - value: Array| KVTensor, + key: Array | KVTensor, + value: Array | KVTensor, decoder_segment_ids: Array | None, model_mode: str = common_types.MODEL_MODE_TRAIN, ): @@ -448,7 +464,7 @@ def apply_attention_dot( attn_weights = apply_mask_to_logits(attn_weights, attn_mask) return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode) - def qk_product(self, query: Array, key: Array| KVTensor, q_seq_len: int, model_mode: str) -> Array: + def qk_product(self, query: Array, key: Array | KVTensor, q_seq_len: int, model_mode: str) -> Array: """Query-Key product. 
Args: @@ -473,12 +489,12 @@ def qk_product(self, query: Array, key: Array| KVTensor, q_seq_len: int, model_m b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads - if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0,1,2,3): + if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0, 1, 2, 3): query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d)) if self.reshape_q and q_seq_len == 1: query = jnp.broadcast_to(query, (b, 2, n_kv, n // n_kv, d)) result = einsum("btkgd,bskd->bkgts", query, key) - elif self.compute_axis_order == (0,2,1,3): + elif self.compute_axis_order == (0, 2, 1, 3): query = jnp.transpose(query, axes=self.compute_axis_order) key = jax.tree.map(lambda x: jnp.transpose(x, axes=self.compute_axis_order), key) query = jnp.reshape(query, (b, n_kv, n // n_kv, t, d)) @@ -510,11 +526,11 @@ def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: s einsum = jnp.einsum if self.kv_quant: einsum = self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(value) - if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0,1,2,3): + if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0, 1, 2, 3): out = einsum("bkgts,bskd->btkgd", attn_weights, value) b, t, n_kv, g, d = out.shape result = jnp.reshape(out, (b, t, n_kv * g, d)) - elif self.compute_axis_order == (0,2,1,3): + elif self.compute_axis_order == (0, 2, 1, 3): value = jax.tree.map(lambda x: jnp.transpose(x, axes=self.compute_axis_order), value) out = einsum("bkgts,bksd->bkgtd", attn_weights, value) b, n_kv, g, t, d = out.shape @@ -539,7 +555,6 @@ def _get_cache_scale_logical_shape(self, batch, heads): return (batch, self.max_prefill_predict_length, 1, 1) raise f"Invalid config for kv_quant_axis:{self.kv_quant.axis_cfg}" - def _get_prefill_cache_vars(self, batch, heads, kv_head_size): dtype = self._get_cached_kv_dtype(self.dtype) @@ -642,8 +657,8 @@ def _get_ar_cache_vars(self, batch, heads, kv_head_size): cached_lengths_var = self.variable( "cache", "cached_ar_lengths", - nn.with_logical_partitioning(jnp.zeros, (CACHE_BATCH, )), - (cache_logical_shape[0], ), + nn.with_logical_partitioning(jnp.zeros, (CACHE_BATCH,)), + (cache_logical_shape[0],), jnp.int32, ) @@ -670,8 +685,7 @@ def _get_ar_cache_vars(self, batch, heads, kv_head_size): cached_key_scale_var = None cached_value_scale_var = None - cache_index_var = self.variable( - "cache", "cache_ar_index", nn.with_logical_partitioning(jnp.zeros, ()), (1,), jnp.int32) + cache_index_var = self.variable("cache", "cache_ar_index", nn.with_logical_partitioning(jnp.zeros, ()), (1,), jnp.int32) key_vars = (cached_key_var, cached_key_scale_var) value_vars = (cached_value_var, cached_value_scale_var) return key_vars, value_vars, cached_segment_id_var, cache_index_var, cached_lengths_var @@ -697,19 +711,20 @@ def kv_cache_prefill( batch, _, heads, kv_head_size = key.shape assert key.dtype == value.dtype, "Key and Value Dtypes should match." 
- cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars(batch, heads, kv_head_size) + cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars( + batch, heads, kv_head_size + ) _ = self._get_ar_cache_vars(batch, heads, kv_head_size) # initialize it now key_shaped_for_cache = jnp.transpose(key, self.prefill_cache_axis_order) value_shaped_for_cache = jnp.transpose(value, self.prefill_cache_axis_order) if self.kv_quant: - prefill_key_axis_names = self.transpose_tuple( - self.cache_logical_axis_names, self.prefill_cache_axis_order) - key_shaped_for_cache, key_scale_shaped_for_cache = self.kv_quant.quantize( - key_shaped_for_cache, prefill_key_axis_names) + prefill_key_axis_names = self.transpose_tuple(self.cache_logical_axis_names, self.prefill_cache_axis_order) + key_shaped_for_cache, key_scale_shaped_for_cache = self.kv_quant.quantize(key_shaped_for_cache, prefill_key_axis_names) value_shaped_for_cache, value_scale_shaped_for_cache = self.kv_quant.quantize( - value_shaped_for_cache, prefill_key_axis_names) + value_shaped_for_cache, prefill_key_axis_names + ) cached_prefill_key_vars[1].value = key_scale_shaped_for_cache cached_prefill_value_vars[1].value = value_scale_shaped_for_cache @@ -755,10 +770,11 @@ def update_ar_key_value( ar_cache_axis_names = self.transpose_tuple(self.cache_logical_axis_names, self.ar_cache_axis_order) if self.kv_quant: one_token_key_shaped_for_cache, one_token_key_scale_shaped_for_cache = self.kv_quant.quantize( - one_token_key_shaped_for_cache, ar_cache_axis_names) + one_token_key_shaped_for_cache, ar_cache_axis_names + ) one_token_value_shaped_for_cache, one_token_value_scale_shaped_for_cache = self.kv_quant.quantize( - one_token_value_shaped_for_cache, ar_cache_axis_names) - + one_token_value_shaped_for_cache, ar_cache_axis_names + ) ar_cache_update_idx = jnp.squeeze(one_hot_indices) ar_cache_sequence_axis = ar_cache_update_axis = ar_cache_axis_names.index(CACHE_SEQUENCE) @@ -781,27 +797,37 @@ def value_body(i, val): new_token_locations[ar_cache_batch_axis] = i return val.at[tuple(cache_locations)].set(one_token_value_shaped_for_cache[tuple(new_token_locations)]) - cached_key_var.value = jax.lax.fori_loop(0, one_token_key_shaped_for_cache.shape[0], key_body, cached_key_var.value, unroll=8) - cached_value_var.value = jax.lax.fori_loop(0, one_token_value_shaped_for_cache.shape[0], value_body, cached_value_var.value, unroll=8) + cached_key_var.value = jax.lax.fori_loop( + 0, one_token_key_shaped_for_cache.shape[0], key_body, cached_key_var.value, unroll=8 + ) + cached_value_var.value = jax.lax.fori_loop( + 0, one_token_value_shaped_for_cache.shape[0], value_body, cached_value_var.value, unroll=8 + ) else: one_hot_indices = one_hot_indices.astype(int) cached_key_var.value = jax.lax.dynamic_update_index_in_dim( - cached_key_var.value, one_token_key_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis) + cached_key_var.value, one_token_key_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis + ) cached_value_var.value = jax.lax.dynamic_update_index_in_dim( - cached_value_var.value, one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis) + cached_value_var.value, one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis + ) cached_key_var.value = nn.with_logical_constraint(cached_key_var.value, ar_cache_axis_names) cached_value_var.value = nn.with_logical_constraint(cached_value_var.value, 
ar_cache_axis_names) - if self.kv_quant: ar_cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.ar_cache_axis_order) ar_cache_scale_update_axis = ar_cache_scale_axis_names.index(CACHE_SCALE_SEQUENCE) cached_key_scale_var.value = jax.lax.dynamic_update_index_in_dim( - cached_key_scale_var.value, one_token_key_scale_shaped_for_cache, ar_cache_update_idx, ar_cache_scale_update_axis) + cached_key_scale_var.value, one_token_key_scale_shaped_for_cache, ar_cache_update_idx, ar_cache_scale_update_axis + ) cached_value_scale_var.value = jax.lax.dynamic_update_index_in_dim( - cached_value_scale_var.value, one_token_value_scale_shaped_for_cache, ar_cache_update_idx, ar_cache_scale_update_axis) + cached_value_scale_var.value, + one_token_value_scale_shaped_for_cache, + ar_cache_update_idx, + ar_cache_scale_update_axis, + ) return @@ -816,12 +842,7 @@ def get_cached_values(self, cache_vars, target_dtype, cache_axis_order) -> jax.A elif dtype == jnp.int4: scale_value /= quantizations.MAX_INT4 - cache_value = KVTensor( - qvalue=cache_value, - scale=[scale_value], - scale_t=None, - dequant_dtype=target_dtype - ) + cache_value = KVTensor(qvalue=cache_value, scale=[scale_value], scale_t=None, dequant_dtype=target_dtype) cache_value_in_logical_shape = jax.tree.map(lambda x: self.reverse_transepose(x, cache_axis_order), cache_value) return cache_value_in_logical_shape @@ -851,18 +872,32 @@ def kv_cache_autoregressive( if not is_initialized: raise ValueError("Error, we can't do autoregression if we haven't seeded the KV Cache.") - cached_ar_key_vars, cached_ar_value_vars, cached_ar_segment_id_var, cache_ar_index_var, cache_ar_lengths_var = self._get_ar_cache_vars(batch, heads, kv_head_size) + cached_ar_key_vars, cached_ar_value_vars, cached_ar_segment_id_var, cache_ar_index_var, cache_ar_lengths_var = ( + self._get_ar_cache_vars(batch, heads, kv_head_size) + ) - self.update_ar_key_value(key, value, cached_ar_key_vars, cached_ar_value_vars, cache_ar_index_var.value, cache_ar_lengths_var.value, use_ragged_attention) + self.update_ar_key_value( + key, + value, + cached_ar_key_vars, + cached_ar_value_vars, + cache_ar_index_var.value, + cache_ar_lengths_var.value, + use_ragged_attention, + ) active_indicator = jnp.zeros((batch, 1), dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR cached_ar_segment_id_var.value = jax.lax.dynamic_update_index_in_dim( cached_ar_segment_id_var.value, active_indicator, jnp.squeeze(cache_ar_index_var.value), 1 ) - cache_ar_index_var.value = jnp.mod(cache_ar_index_var.value + 1, self.max_target_length - self.max_prefill_predict_length) + cache_ar_index_var.value = jnp.mod( + cache_ar_index_var.value + 1, self.max_target_length - self.max_prefill_predict_length + ) cache_ar_lengths_var.value = cache_ar_lengths_var.value.at[:].add(1) # The below retrieves the existing prefill cache variables, not creating new ones - cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars(batch, heads, kv_head_size) + cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars( + batch, heads, kv_head_size + ) cached_prefill = ( self.get_cached_values(cached_prefill_key_vars, key.dtype, self.prefill_cache_axis_order), @@ -874,11 +909,13 @@ def kv_cache_autoregressive( self.get_cached_values(cached_ar_key_vars, key.dtype, self.ar_cache_axis_order), self.get_cached_values(cached_ar_value_vars, value.dtype, self.ar_cache_axis_order), 
cached_ar_segment_id_var.value, - cache_ar_lengths_var.value + cache_ar_lengths_var.value, ) return cached_prefill, cached_ar - def kv_cache(self, key: Array, value: Array, decoder_segment_ids: Array, model_mode: str, use_ragged_attention: bool = False) -> tuple: + def kv_cache( + self, key: Array, value: Array, decoder_segment_ids: Array, model_mode: str, use_ragged_attention: bool = False + ) -> tuple: """KV cache takes the current state and updates the state accordingly. The key and value have dimension [b, s, n_kv, d], @@ -933,7 +970,9 @@ def normalize_attention(self, local_outs, local_maxes, local_sums): @nn.compact def __call__(self, query, key, value, decoder_segment_ids, model_mode): - prefill_kv_cache, ar_kv_cache = self.kv_cache(key, value, decoder_segment_ids, model_mode, use_ragged_attention=self.use_ragged_attention) + prefill_kv_cache, ar_kv_cache = self.kv_cache( + key, value, decoder_segment_ids, model_mode, use_ragged_attention=self.use_ragged_attention + ) prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention( query=query, @@ -1029,7 +1068,6 @@ class Attention(nn.Module): compute_axis_order: AxisIdxes = (0, 1, 2, 3) reshape_q: bool = False - def query_projection(self, inputs_q: Array) -> Array: """Query projection.""" @@ -1121,8 +1159,12 @@ def out_projection(self, output_dim: int, out: Array) -> Array: def key_rotary(self, key: Array, inputs_positions: Array): """Apply Rotary Embedding to key.""" - key = RotaryEmbedding(min_timescale=self.config.rope_min_timescale, max_timescale = self.config.rope_max_timescale, - embedding_dims=self.head_dim, name="key_rotary")(inputs=key, position=inputs_positions) + key = RotaryEmbedding( + min_timescale=self.config.rope_min_timescale, + max_timescale=self.config.rope_max_timescale, + embedding_dims=self.head_dim, + name="key_rotary", + )(inputs=key, position=inputs_positions) return key @nn.compact @@ -1167,8 +1209,12 @@ def __call__( value = self.kv_projection(inputs_kv, proj_name="value") # apply ROPE - query = RotaryEmbedding(min_timescale=self.config.rope_min_timescale, max_timescale = self.config.rope_max_timescale, - embedding_dims=self.head_dim, name="query_rotary")(inputs=query, position=inputs_positions) + query = RotaryEmbedding( + min_timescale=self.config.rope_min_timescale, + max_timescale=self.config.rope_max_timescale, + embedding_dims=self.head_dim, + name="query_rotary", + )(inputs=query, position=inputs_positions) key = self.key_rotary(key, inputs_positions) # annotate with sharding constraint. 
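The hunk above applies RotaryEmbedding to both query and key using rope_min_timescale and rope_max_timescale. The following is a minimal standalone sketch of rotary position embedding (illustrative only, not part of this patch; the geometric timescale spacing and the split-half rotation are assumptions, not a copy of MaxText's RotaryEmbedding):

# Minimal rotary position embedding sketch in NumPy (assumptions noted above).
import numpy as np

def apply_rope(x, positions, min_timescale=1.0, max_timescale=10000.0):
  # x: [batch, length, heads, head_dim], positions: [batch, length]
  head_dim = x.shape[-1]
  fraction = 2.0 * np.arange(head_dim // 2) / head_dim
  timescale = min_timescale * (max_timescale / min_timescale) ** fraction  # one timescale per dim pair
  angle = positions[:, :, None, None] / timescale                          # [batch, length, 1, head_dim//2]
  sin, cos = np.sin(angle), np.cos(angle)
  first, second = x[..., : head_dim // 2], x[..., head_dim // 2 :]
  # rotate each (first, second) pair by an angle that grows with position
  return np.concatenate([first * cos - second * sin, second * cos + first * sin], axis=-1)

q = np.random.randn(1, 6, 2, 8)               # [batch, length, heads, head_dim]
q_rot = apply_rope(q, np.arange(6)[None, :])  # same shape, positions encoded as rotations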
@@ -1181,7 +1227,7 @@ def __call__( assert not self.config.quantize_kvcache or self.kv_quant attention_op = AttentionOp( - config = self.config, + config=self.config, mesh=self.mesh, attention_kernel=self.attention_kernel, max_target_length=self.max_target_length, diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py index 1be1ce7a4..aee132419 100644 --- a/MaxText/layers/embeddings.py +++ b/MaxText/layers/embeddings.py @@ -82,7 +82,9 @@ def __call__(self, inputs: Array) -> Array: output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) else: output = jnp.asarray(self.embedding, self.dtype)[inputs] - output = nn.with_logical_constraint(output, ("activation_embed_and_logits_batch", "activation_length", "activation_embed")) + output = nn.with_logical_constraint( + output, ("activation_embed_and_logits_batch", "activation_length", "activation_embed") + ) return output def attend(self, query: Array) -> Array: diff --git a/MaxText/layers/gpt3.py b/MaxText/layers/gpt3.py index ea645c249..3d4204c54 100644 --- a/MaxText/layers/gpt3.py +++ b/MaxText/layers/gpt3.py @@ -153,7 +153,6 @@ class Gpt3MultiHeadAttention(nn.Module): value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) - def qkv_projection(self, inputs: Array, proj_name: str): """Fused QKV projection""" @@ -313,7 +312,7 @@ def __call__( fused_qkv=cfg.fused_qkv, use_bias=True, quant=self.quant, - kv_quant=quantizations.configure_kv_quant(cfg) + kv_quant=quantizations.configure_kv_quant(cfg), ) attention_lnx = attention_layer( diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 496436a3e..4d9c72b10 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -49,6 +49,7 @@ RMSNorm = normalizations.RMSNorm Quant = quantizations.AqtQuantization + def _convert_to_activation_function(fn_or_string: Union[str, Callable[..., Any]]) -> Callable[..., Any]: """Convert a string to an activation function.""" if fn_or_string == "linear": @@ -296,41 +297,41 @@ class MoeBlock(nn.Module): quant: Optional[Quant] = None # The first axes is expert - wi_kernel_axes = ('exp', 'embed_no_exp', 'mlp') - wo_kernel_axes = ('exp', 'mlp', 'embed_no_exp') + wi_kernel_axes = ("exp", "embed_no_exp", "mlp") + wo_kernel_axes = ("exp", "mlp", "embed_no_exp") def generate_kernels(self, num_experts, emb_dim, mlp_dim): kernel_in_axis = np.arange(1) kernel_out_axis = np.arange(1, 2) - kernel_init = nd_dense_init(1.0, 'fan_in', 'truncated_normal') + kernel_init = nd_dense_init(1.0, "fan_in", "truncated_normal") w0_kernel = self.param( - 'wi_0', + "wi_0", nn.with_logical_partitioning(kernel_init, self.wi_kernel_axes), (num_experts, emb_dim, mlp_dim), self.weight_dtype, kernel_in_axis, kernel_out_axis, - ) + ) w0_kernel = jnp.asarray(w0_kernel, self.dtype) w1_kernel = self.param( - 'wi_1', + "wi_1", nn.with_logical_partitioning(kernel_init, self.wi_kernel_axes), (num_experts, emb_dim, mlp_dim), self.weight_dtype, kernel_in_axis, kernel_out_axis, - ) + ) w1_kernel = jnp.asarray(w1_kernel, self.dtype) wo_kernel = self.param( - 'wo', + "wo", nn.with_logical_partitioning(kernel_init, self.wo_kernel_axes), (num_experts, mlp_dim, emb_dim), self.weight_dtype, kernel_in_axis, kernel_out_axis, - ) + ) wo_kernel = jnp.asarray(wo_kernel, self.dtype) return w0_kernel, w1_kernel, wo_kernel @@ -356,32 +357,38 @@ def unpermute(self, intermediate, sorted_selected_experts, weights): unsort_intermediate = jnp.take(intermediate, indices=jnp.argsort(sorted_selected_experts), axis=0) 
reshaped_weights = jnp.reshape(weights, (-1, self.num_experts_per_tok)) tensor_parallelism = self.config.ici_tensor_parallelism * self.config.dcn_tensor_parallelism - reshaped_intermediate = jnp.reshape(unsort_intermediate, (-1, self.num_experts_per_tok, self.config.emb_dim // tensor_parallelism)) + reshaped_intermediate = jnp.reshape( + unsort_intermediate, (-1, self.num_experts_per_tok, self.config.emb_dim // tensor_parallelism) + ) with jax.named_scope("weight_sum"): matmul_precision = lax.Precision(self.config.matmul_precision) - output = jnp.einsum("BKE,BK -> BE", reshaped_intermediate.astype(jnp.float32), reshaped_weights.astype(jnp.float32), precision=matmul_precision) + output = jnp.einsum( + "BKE,BK -> BE", + reshaped_intermediate.astype(jnp.float32), + reshaped_weights.astype(jnp.float32), + precision=matmul_precision, + ) return output.reshape(-1, self.config.max_target_length, self.config.emb_dim // tensor_parallelism).astype(self.dtype) def megablox(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): tile_size = (512, 1024, 1024) + def gmm(inputs, kernel, group_sizes): hs_shape = inputs.shape # pad length is the 1st dimension of tiling size in gmm call pad_length = 512 if hs_shape[0] % pad_length: pad_length = pad_length - hs_shape[0] % pad_length - inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0,0,0)]) + inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0, 0, 0)]) inputs = inputs.astype(self.dtype) kernel = kernel.astype(self.dtype) - output = mblx.gmm(lhs=inputs, - rhs=kernel, - group_sizes=group_sizes, - preferred_element_type=jnp.bfloat16, - tiling=tile_size) + output = mblx.gmm( + lhs=inputs, rhs=kernel, group_sizes=group_sizes, preferred_element_type=jnp.bfloat16, tiling=tile_size + ) if hs_shape[0] % pad_length: - output = output[:hs_shape[0]] + output = output[: hs_shape[0]] return output # Currently, we only support data and tensor parallelism with Megablox. 
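The megablox path above sorts tokens by their assigned expert, runs a grouped matmul per expert group, and restores the original token order with argsort in unpermute. A small NumPy sketch of that sort / group / unsort pattern (illustrative only, not part of this patch; the per-expert scaling is a stand-in for the real expert MLPs, and the group-size bookkeeping is an assumption):

# Sort tokens by expert, process each expert's group, then undo the sort.
import numpy as np

tokens = np.arange(6, dtype=np.float32).reshape(6, 1)  # 6 tokens, embed dim 1
expert_ids = np.array([2, 0, 1, 0, 2, 1])              # chosen expert per token

order = np.argsort(expert_ids, kind="stable")          # sort tokens by expert
sorted_tokens = tokens[order]
group_sizes = np.bincount(expert_ids, minlength=3)     # tokens per expert, here [2, 2, 2]
starts = np.cumsum(group_sizes) - group_sizes          # start offset of each expert's group

# Stand-in for the grouped matmul: each expert applies a different transform to its group.
scales = np.array([1.0, 10.0, 100.0])
processed = np.concatenate(
    [sorted_tokens[s : s + n] * scales[e] for e, (s, n) in enumerate(zip(starts, group_sizes))]
)

restored = processed[np.argsort(order)]                # undo the sort, back to original token order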
@@ -391,12 +398,12 @@ def gmm(inputs, kernel, group_sizes): shard_map.shard_map, mesh=self.mesh, in_specs=( - (nn.logical_to_mesh_axes(("activation_batch", None, None))), - (nn.logical_to_mesh_axes(("activation_batch", None, None))), - (nn.logical_to_mesh_axes((None, None, "mlp"))), - (nn.logical_to_mesh_axes((None, None, "mlp"))), - (nn.logical_to_mesh_axes((None, "mlp", None))), - ), + (nn.logical_to_mesh_axes(("activation_batch", None, None))), + (nn.logical_to_mesh_axes(("activation_batch", None, None))), + (nn.logical_to_mesh_axes((None, None, "mlp"))), + (nn.logical_to_mesh_axes((None, None, "mlp"))), + (nn.logical_to_mesh_axes((None, "mlp", None))), + ), out_specs=(nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed"))), check_rep=False, ) @@ -409,11 +416,10 @@ def wrapper(x, logits, w0, w1, wo): intermediate_output = gmm(intermediate_layer, wo, group_sizes) tensor_parallelism = self.config.ici_tensor_parallelism * self.config.dcn_tensor_parallelism if tensor_parallelism > 1: - intermediate_output = jax.lax.psum_scatter(intermediate_output, 'tensor', scatter_dimension=1, tiled=True) - output = self.unpermute(intermediate_output, - sorted_selected_experts, - weights) + intermediate_output = jax.lax.psum_scatter(intermediate_output, "tensor", scatter_dimension=1, tiled=True) + output = self.unpermute(intermediate_output, sorted_selected_experts, weights) return output, None + return wrapper(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel) def reshape_and_update_weights(self, weights, indices): @@ -425,46 +431,54 @@ def reshape_and_update_weights(self, weights, indices): return update_weights def generate_masks(self, top_k_indices, softmax_probs): - # calculate expert_capacity = (tokens_per_batch / num_experts) * capacity_factor - batch_size, seq_len, _ = top_k_indices.shape - tokens_per_batch = seq_len * self.num_experts_per_tok - expert_capacity_per_batch = int((tokens_per_batch / self.num_experts) * self.config.capacity_factor) - max_logging.log(f"Applying potential token dropping with a batch expert_capacity of {expert_capacity_per_batch}") - - # calculate expert mask and drop tokens if needed - # shape of output expert mask: (batch, sequence, num_experts_per_tok) - # - # A small example: - # give num_experts=4 & num_experts_per_tok=2, and two tokens are routed to expert [0, 1] & [1, 3], - # then expert_mask becomes [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 1, 0, 0],[0, 0, 0, 1]]]], - # after cumsum, expert_token_count becomes [[[[1, 0, 0, 0],[1, 1, 0, 0]], [[1, 2, 0, 0],[1, 2, 0, 1]]]], - # if we set expert_capacity=1, - # trunc_expert_mask becomes [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 0, 0, 0],[0, 0, 0, 1]]]], - # so the 2nd token for expert #1 ([0, 1] & [1, 3]) is dropped, output of updated_expert_mask is [[[1, 1],[0, 1]]]. 
- expert_mask = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32) - expert_mask_fused = jnp.reshape(expert_mask, (batch_size, seq_len * self.num_experts_per_tok, self.num_experts)) - expert_mask_fused = nn.with_logical_constraint(expert_mask_fused, ("activation_batch", None, None)) - expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1) - expert_token_count = jnp.reshape(expert_token_count_fused, ((batch_size, seq_len, self.num_experts_per_tok, self.num_experts))) - expert_token_count = nn.with_logical_constraint(expert_token_count, ("activation_batch", "activation_length", None, None)) - trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) - combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2) - - # reshape & update weights - softmax_probs *= combined_expert_mask - - # calculate token position in expert capacity dimension - expert_token_position_fused = expert_mask_fused * expert_token_count_fused - expert_token_position = jnp.reshape(expert_token_position_fused, (batch_size, seq_len, self.num_experts_per_tok, self.num_experts)) - combined_expert_token_position = jnp.sum(expert_token_position, axis=2) * combined_expert_mask - expert_token_position_in_capacity = jax.nn.one_hot(combined_expert_token_position, num_classes=expert_capacity_per_batch+1, dtype=jnp.int32) - - # shape of combine_mask is (batch_size, seq_len, num_experts, expert_capacity_per_batch + 1), - # and cut 0-dimension which is always 0 - combine_mask = (softmax_probs[..., None] * expert_token_position_in_capacity) - combine_mask = combine_mask[..., 1:] - dispatch_mask = combine_mask.astype(bool) - return dispatch_mask, combine_mask + # calculate expert_capacity = (tokens_per_batch / num_experts) * capacity_factor + batch_size, seq_len, _ = top_k_indices.shape + tokens_per_batch = seq_len * self.num_experts_per_tok + expert_capacity_per_batch = int((tokens_per_batch / self.num_experts) * self.config.capacity_factor) + max_logging.log(f"Applying potential token dropping with a batch expert_capacity of {expert_capacity_per_batch}") + + # calculate expert mask and drop tokens if needed + # shape of output expert mask: (batch, sequence, num_experts_per_tok) + # + # A small example: + # give num_experts=4 & num_experts_per_tok=2, and two tokens are routed to expert [0, 1] & [1, 3], + # then expert_mask becomes [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 1, 0, 0],[0, 0, 0, 1]]]], + # after cumsum, expert_token_count becomes [[[[1, 0, 0, 0],[1, 1, 0, 0]], [[1, 2, 0, 0],[1, 2, 0, 1]]]], + # if we set expert_capacity=1, + # trunc_expert_mask becomes [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 0, 0, 0],[0, 0, 0, 1]]]], + # so the 2nd token for expert #1 ([0, 1] & [1, 3]) is dropped, output of updated_expert_mask is [[[1, 1],[0, 1]]]. 
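# Illustrative, standalone sketch (not part of the patch) of the worked example in the
# comment above: batch=1, seq_len=2, num_experts=4, num_experts_per_tok=2, capacity=1.
import jax
import jax.numpy as jnp

top_k_indices = jnp.array([[[0, 1], [1, 3]]])                              # [batch, seq, k]
expert_mask = jax.nn.one_hot(top_k_indices, num_classes=4, dtype=jnp.int32)
expert_token_count = jnp.cumsum(expert_mask.reshape(1, 4, 4), axis=1).reshape(1, 2, 2, 4)
trunc_expert_mask = expert_mask * (expert_token_count <= 1)                # enforce capacity of 1
combined_expert_mask = trunc_expert_mask.sum(axis=2)
# combined_expert_mask[0, 1, 1] == 0, i.e. the 2nd token's assignment to expert 1 was
# dropped, consistent with updated_expert_mask [[[1, 1], [0, 1]]] in the comment above.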
+ expert_mask = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32) + expert_mask_fused = jnp.reshape(expert_mask, (batch_size, seq_len * self.num_experts_per_tok, self.num_experts)) + expert_mask_fused = nn.with_logical_constraint(expert_mask_fused, ("activation_batch", None, None)) + expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1) + expert_token_count = jnp.reshape( + expert_token_count_fused, ((batch_size, seq_len, self.num_experts_per_tok, self.num_experts)) + ) + expert_token_count = nn.with_logical_constraint( + expert_token_count, ("activation_batch", "activation_length", None, None) + ) + trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) + combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2) + + # reshape & update weights + softmax_probs *= combined_expert_mask + + # calculate token position in expert capacity dimension + expert_token_position_fused = expert_mask_fused * expert_token_count_fused + expert_token_position = jnp.reshape( + expert_token_position_fused, (batch_size, seq_len, self.num_experts_per_tok, self.num_experts) + ) + combined_expert_token_position = jnp.sum(expert_token_position, axis=2) * combined_expert_mask + expert_token_position_in_capacity = jax.nn.one_hot( + combined_expert_token_position, num_classes=expert_capacity_per_batch + 1, dtype=jnp.int32 + ) + + # shape of combine_mask is (batch_size, seq_len, num_experts, expert_capacity_per_batch + 1), + # and cut 0-dimension which is always 0 + combine_mask = softmax_probs[..., None] * expert_token_position_in_capacity + combine_mask = combine_mask[..., 1:] + dispatch_mask = combine_mask.astype(bool) + return dispatch_mask, combine_mask # See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. def load_balance_loss(self, top_k_indices, logits): @@ -474,7 +488,7 @@ def load_balance_loss(self, top_k_indices, logits): density = jnp.mean(summed_expert_mask, axis=1) # get fraction of probability allocated to each expert density_prob = jnp.mean(logits, axis=1) - loss = jnp.mean(density * density_prob) * (self.num_experts ** 2) * self.config.load_balance_loss_weight + loss = jnp.mean(density * density_prob) * (self.num_experts**2) * self.config.load_balance_loss_weight return loss def get_einsum(self, rhs_mesh_axes: Tuple[Optional[str], ...] 
= ()): @@ -500,43 +514,69 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): loss = self.load_balance_loss(top_k_indices, softmax_probs) inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) with jax.named_scope("dispatch"): - dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", inputs, dispatch_mask, precision=matmul_precision) - dispatch = nn.with_logical_constraint(dispatch, ("activation_exp", "activation_batch_no_exp", None, "activation_embed")) + dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)( + "BSM,BSEC -> EBCM", inputs, dispatch_mask, precision=matmul_precision + ) + dispatch = nn.with_logical_constraint( + dispatch, ("activation_exp", "activation_batch_no_exp", None, "activation_embed") + ) with jax.named_scope("wi_0"): w0_kernel_axes = ("exp", None, None) w0_kernel = nn.with_logical_constraint(w0_kernel, w0_kernel_axes) - layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w0_kernel, precision=matmul_precision).astype(jnp.float32) - layer_w0 = nn.with_logical_constraint(layer_w0, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp")) + layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)( + "EBCM,EMH -> EBCH", dispatch, w0_kernel, precision=matmul_precision + ).astype(jnp.float32) + layer_w0 = nn.with_logical_constraint( + layer_w0, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp") + ) with jax.named_scope("wi_1"): w1_kernel_axes = ("exp", None, None) w1_kernel = nn.with_logical_constraint(w1_kernel, w1_kernel_axes) - layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w1_kernel, precision=matmul_precision).astype(jnp.float32) - layer_w1 = nn.with_logical_constraint(layer_w1, ("activation_exp", "activation_batch_no_exp",None, "activation_mlp")) + layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)( + "EBCM,EMH -> EBCH", dispatch, w1_kernel, precision=matmul_precision + ).astype(jnp.float32) + layer_w1 = nn.with_logical_constraint( + layer_w1, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp") + ) layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) layer_multiply = jnp.multiply(layer_w0_act, layer_w1).astype(self.dtype) with jax.named_scope("wo"): wo_kernel_axes = ("exp", None, None) wo_kernel = nn.with_logical_constraint(wo_kernel, wo_kernel_axes) - intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)("EBCH,EHM -> EBCM", layer_multiply, wo_kernel, precision=matmul_precision) - intermediate_layer = nn.with_logical_constraint(intermediate_layer, ("activation_exp", "activation_batch_no_exp", None, "activation_embed")) + intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)( + "EBCH,EHM -> EBCM", layer_multiply, wo_kernel, precision=matmul_precision + ) + intermediate_layer = nn.with_logical_constraint( + intermediate_layer, ("activation_exp", "activation_batch_no_exp", None, "activation_embed") + ) with jax.named_scope("combine"): # Matmul & element wise operation - output = self.get_einsum(rhs_mesh_axes=mask_axes)("EBCM,BSEC -> BSM", intermediate_layer, combine_mask, precision=matmul_precision) + output = self.get_einsum(rhs_mesh_axes=mask_axes)( + "EBCM,BSEC -> BSM", intermediate_layer, combine_mask, precision=matmul_precision + ) return output, loss else: weights = self.reshape_and_update_weights(top_k_weights, top_k_indices) inputs = nn.with_logical_constraint(inputs, 
("activation_batch", "activation_length", "activation_embed")) with jax.named_scope("wi_0"): - layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)("BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision).astype(jnp.float32) + layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( + "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision + ).astype(jnp.float32) with jax.named_scope("wi_1"): - layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)("BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision).astype(jnp.float32) + layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( + "BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision + ).astype(jnp.float32) layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) layer_multiply = jnp.multiply(layer_w0_act, layer_w1).astype(self.dtype) with jax.named_scope("wo"): - intermediate_layer = self.get_einsum(rhs_mesh_axes=self.wo_kernel_axes)("BSEH,EHM -> BSEM", layer_multiply, wo_kernel, precision=matmul_precision) + intermediate_layer = self.get_einsum(rhs_mesh_axes=self.wo_kernel_axes)( + "BSEH,EHM -> BSEM", layer_multiply, wo_kernel, precision=matmul_precision + ) with jax.named_scope("w_sum"): weights_axis = ("activation_batch", "activation_length", "activation_exp") - output = self.get_einsum(rhs_mesh_axes=weights_axis)("BSEM,BSE -> BSM", intermediate_layer.astype(jnp.float32), weights.astype(jnp.float32)).astype(self.dtype) + output = self.get_einsum(rhs_mesh_axes=weights_axis)( + "BSEM,BSE -> BSM", intermediate_layer.astype(jnp.float32), weights.astype(jnp.float32) + ).astype(self.dtype) return output, None @nn.compact @@ -544,18 +584,17 @@ def __call__(self, inputs): cfg = self.config inputs = inputs.astype(cfg.dtype) gate_logits = DenseGeneral( - self.num_experts, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - quant=self.quant, - kernel_init=self.kernel_init, - kernel_axes=self.kernel_axes, - name="gate", - matmul_precision=self.config.matmul_precision)(inputs) + self.num_experts, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + kernel_init=self.kernel_init, + kernel_axes=self.kernel_axes, + name="gate", + matmul_precision=self.config.matmul_precision, + )(inputs) - w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, - cfg.emb_dim, - cfg.mlp_dim) + w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, cfg.emb_dim, cfg.mlp_dim) if cfg.megablox: max_logging.log("Running MoE megablox implementation.") diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py index 8ac08787b..9f89ec06e 100644 --- a/MaxText/layers/mistral.py +++ b/MaxText/layers/mistral.py @@ -129,15 +129,13 @@ def __call__( num_experts=cfg.num_experts, num_experts_per_tok=cfg.num_experts_per_tok, mesh=mesh, - kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'), - kernel_axes=('embed', 'mlp'), + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", "mlp"), dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, quant=self.quant, )(hidden_states) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) else: mlp_lnx = linears.MlpBlock( intermediate_dim=cfg.mlp_dim, diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 82cc47294..7739dcb06 100644 --- 
a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -146,8 +146,10 @@ def __call__( return layer_output, None if cfg.scan_layers else layer_output + class SequentialBlockDecoderLayers(nn.Module): """Sequential unscanned series of decoder layers.""" + decoder_layer: Any num_decoder_layers: int config: Config @@ -158,14 +160,15 @@ class SequentialBlockDecoderLayers(nn.Module): def __call__(self, inputs: jnp.ndarray, decoder_segment_ids, decoder_positions, deterministic, model_mode) -> jnp.ndarray: for lyr in range(self.num_decoder_layers): inputs = self.decoder_layer(config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant)( - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) return inputs + class Decoder(nn.Module): """A stack of decoder layers as a part of an encoder-decoder architecture.""" @@ -224,26 +227,26 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mes params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) cache_spec = 0 scan_fn = nn.scan( - decoder_layer, - variable_axes={ - "params": params_spec, - "cache": cache_spec, - "intermediates": 0, - "aqt": 0, - "_overwrite_with_gradient": 0, - }, - split_rngs={ - "params": True, - "dropout": cfg.enable_dropout, - }, - in_axes=( - nn.broadcast, - nn.broadcast, - nn.broadcast, - nn.broadcast, - ), - length=length, - metadata_params={nn.PARTITION_NAME: metdata_axis_name}, + decoder_layer, + variable_axes={ + "params": params_spec, + "cache": cache_spec, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={ + "params": True, + "dropout": cfg.enable_dropout, + }, + in_axes=( + nn.broadcast, + nn.broadcast, + nn.broadcast, + nn.broadcast, + ), + length=length, + metadata_params={nn.PARTITION_NAME: metdata_axis_name}, ) return scan_fn(config=cfg, mesh=mesh, name="layers", quant=self.quant) @@ -338,20 +341,28 @@ def __call__( static_argnums=(-1, -2, -3, -4, -5), ) if cfg.using_pipeline_parallelism: - if cfg.num_layers_per_pipeline_stage == 1: - stage_module = BlockLayer(config=cfg, mesh=mesh, quant=self.quant) - elif cfg.scan_layers: - stage_module = self.scan_decoder_layers(cfg, RemattedBlockLayer, cfg.num_layers_per_pipeline_stage, "layers_per_stage", mesh) - elif not cfg.scan_layers: - stage_module=SequentialBlockDecoderLayers(decoder_layer=RemattedBlockLayer, num_decoder_layers=cfg.num_layers_per_pipeline_stage, config=cfg, mesh=mesh,quant=self.quant) - - y = pipeline.Pipeline(config=cfg, mesh=mesh, layers=stage_module, remat_policy=policy)( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, + if cfg.num_layers_per_pipeline_stage == 1: + stage_module = BlockLayer(config=cfg, mesh=mesh, quant=self.quant) + elif cfg.scan_layers: + stage_module = self.scan_decoder_layers( + cfg, RemattedBlockLayer, cfg.num_layers_per_pipeline_stage, "layers_per_stage", mesh + ) + elif not cfg.scan_layers: + stage_module = SequentialBlockDecoderLayers( + decoder_layer=RemattedBlockLayer, + num_decoder_layers=cfg.num_layers_per_pipeline_stage, + config=cfg, + mesh=mesh, + quant=self.quant, ) + + y = pipeline.Pipeline(config=cfg, mesh=mesh, layers=stage_module, remat_policy=policy)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) else: if cfg.scan_layers: y, _ = self.scan_decoder_layers(cfg, RemattedBlockLayer, cfg.num_decoder_layers, "layers", mesh)( @@ 
-401,7 +412,9 @@ def __call__( )( y ) # We do not quantize the logits matmul. - logits = nn.with_logical_constraint(logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")) + logits = nn.with_logical_constraint( + logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab") + ) logits = logits.astype(jnp.float32) return logits diff --git a/MaxText/layers/pipeline.py b/MaxText/layers/pipeline.py index 62aea086c..f5601ef9f 100644 --- a/MaxText/layers/pipeline.py +++ b/MaxText/layers/pipeline.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -''' Pipeline layer wrapping a decoder layer(s). Supports circular pipelining ''' +""" Pipeline layer wrapping a decoder layer(s). Supports circular pipelining """ import jax import jax.ad_checkpoint @@ -24,9 +24,10 @@ import functools from typing import Any + class Pipeline(nn.Module): """Module that implements pipelining across stages. - + This module will loop over microbatches and execute the main body with a vmap for both the inputs and weights. This will produce a pipeline pattern if the stage dimension is sharded. @@ -40,9 +41,9 @@ class Pipeline(nn.Module): mesh: The device mesh of the system. remat_policy: Remat policy to use for the loop iterations """ - + config: common_types.Config - layers: nn.Module # The name of this property (layers) is reflected in the state pytree and thus also checkpoints. + layers: nn.Module # The name of this property (layers) is reflected in the state pytree and thus also checkpoints. mesh: common_types.Mesh remat_policy: Any = None @@ -55,7 +56,10 @@ def setup(self): self.use_circ_storage = self.need_circ_storage() def need_circ_storage(self): - return self.config.num_pipeline_repeats > 1 and self.config.num_pipeline_microbatches > self.num_stages * self.forwarding_delay + return ( + self.config.num_pipeline_repeats > 1 + and self.config.num_pipeline_microbatches > self.num_stages * self.forwarding_delay + ) def iterations_to_complete_first_microbatch_one_repeat(self): # Return the number of iterations it takes for microbatch 0 to finish a repeat @@ -63,30 +67,43 @@ def iterations_to_complete_first_microbatch_one_repeat(self): def iterations_to_complete_first_microbatch(self): # Return the number of iterations it takes for microbatch 0 to finish the last stage of the last repeat - return self.config.num_pipeline_microbatches * (self.config.num_pipeline_repeats - 1) + self.iterations_to_complete_first_microbatch_one_repeat() + return ( + self.config.num_pipeline_microbatches * (self.config.num_pipeline_repeats - 1) + + self.iterations_to_complete_first_microbatch_one_repeat() + ) def init_states(self, inputs): - '''Initialize components of state: state_io, shift, circular_storage and circular_storage_mover - Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed] - - Returns a dictionary with properties - shift: zeros shape [num_stages, micro_size, sequence, embed] - prev_outputs: same shape as shift, only used when pipeline_delay_activation_forwarding is set to true, else None - state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed] - circ_storage: zeros [num_stages, microbatches, micro_size, sequence, embed] when needed, else None - circ_storage_mover: zeros[num_stages, micro_size, sequence, embed] when needed, else None - loop_iteration: scalar set initially to 0. 
- ''' + """Initialize components of state: state_io, shift, circular_storage and circular_storage_mover + Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed] + + Returns a dictionary with properties + shift: zeros shape [num_stages, micro_size, sequence, embed] + prev_outputs: same shape as shift, only used when pipeline_delay_activation_forwarding is set to true, else None + state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed] + circ_storage: zeros [num_stages, microbatches, micro_size, sequence, embed] when needed, else None + circ_storage_mover: zeros[num_stages, micro_size, sequence, embed] when needed, else None + loop_iteration: scalar set initially to 0. + """ # Shift is used to rotate the output of each pipeline into the input of the next # shift has shape [num_stages, micro_size, sequence, embed] shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) - shift = nn.with_logical_constraint(shift, ("activation_stage", "activation_batch", "activation_length", "activation_embed"),rules=self.config.logical_axis_rules,mesh=self.mesh) + shift = nn.with_logical_constraint( + shift, + ("activation_stage", "activation_batch", "activation_length", "activation_embed"), + rules=self.config.logical_axis_rules, + mesh=self.mesh, + ) # Prev outputs has the same shape of the output (and shift) if self.config.pipeline_delay_activation_forwarding: prev_outputs = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) - prev_outputs = nn.with_logical_constraint(prev_outputs, ("activation_stage", "activation_batch", "activation_length", "activation_embed"),rules=self.config.logical_axis_rules,mesh=self.mesh) + prev_outputs = nn.with_logical_constraint( + prev_outputs, + ("activation_stage", "activation_batch", "activation_length", "activation_embed"), + rules=self.config.logical_axis_rules, + mesh=self.mesh, + ) else: prev_outputs = None @@ -94,7 +111,12 @@ def init_states(self, inputs): # state_io has shape [num_stages, microbatches/stages, micro_size, sequence, embed] state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:]) # We shard the pipeline_microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. - state_io = nn.with_logical_constraint(state_io, ("activation_stage", None, "activation_batch", "activation_length", "activation_embed"),rules=self.config.logical_axis_rules, mesh=self.mesh) + state_io = nn.with_logical_constraint( + state_io, + ("activation_stage", None, "activation_batch", "activation_length", "activation_embed"), + rules=self.config.logical_axis_rules, + mesh=self.mesh, + ) # circ_storage is used to hold the final pipeline stage outputs before it is used for the next repeat. It is only needed # when num_microbatches > num_stages, else instead the final stage will immediately pass to the first without additional storage. @@ -104,45 +126,45 @@ def init_states(self, inputs): # fine as long as there is some amount of additional sharding axes, e.g. FSDP, TP, DP (e.g. there are many devices that shard stage 0) # We may look into alternatives using less storage if this becomes an issue (ideas in b/347603101). 
if self.use_circ_storage: - circ_storage = jnp.zeros((self.num_stages,) + inputs.shape , dtype=inputs.dtype) + circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype) else: - circ_storage = None + circ_storage = None # circ_storage_mover is used to push the microbatches from the pipeline into circ_storage with one buffer iteration of delay # circ_storage_mover shape is same as shift: [num_stages, micro_size, sequence, embed] if self.use_circ_storage: - circ_storage_mover = shift + circ_storage_mover = shift else: - circ_storage_mover = None + circ_storage_mover = None init_loop_state = { - "state_io": state_io, - "shift": shift, - "circ_storage": circ_storage, - "circ_storage_mover": circ_storage_mover, - "loop_iteration": 0, - "prev_outputs": prev_outputs + "state_io": state_io, + "shift": shift, + "circ_storage": circ_storage, + "circ_storage_mover": circ_storage_mover, + "loop_iteration": 0, + "prev_outputs": prev_outputs, } return init_loop_state def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): - ''' + """ Construct stages_in: the global array that is operated on for this iteration, shape same as shift=[stages, micro_size, sequence, embed] This is almost a rotated version of the last outputs, except for the first stage which must grab a new batch from state_io or an old one from circ_storage - ''' + """ # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) state_io_batch_idx = loop_iteration % self.microbatches_per_stage - state_io_slice = state_io[:,state_io_batch_idx] + state_io_slice = state_io[:, state_io_batch_idx] if self.use_circ_storage: - # Setup potential input from circ_storage, which also has a rotating index for microbatch, size of num_microbatches - circ_storage_batch_idx = loop_iteration % self.config.num_pipeline_microbatches - circular_stage_in = circ_storage[:,circ_storage_batch_idx] + # Setup potential input from circ_storage, which also has a rotating index for microbatch, size of num_microbatches + circ_storage_batch_idx = loop_iteration % self.config.num_pipeline_microbatches + circular_stage_in = circ_storage[:, circ_storage_batch_idx] else: - # The last stage immediately flows into the first stage, use this rotated shift instead of circular storage - circular_stage_in = shift - + # The last stage immediately flows into the first stage, use this rotated shift instead of circular storage + circular_stage_in = shift + # For early loop iterations we grab a new input for stage 0 from the state_io. Once each microbatch has left state_io # we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. from circ_storage). first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) @@ -151,15 +173,19 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): # However these bubble computation results remain in the shift buffer (do not make it back to state_io) and are thus discarded / not returned. # The final returned output is stored in the state_io, which has the appropriate total size of num_microbatches. The state_io will not contain bubble results # at the end of the last iteration. 
-    def select_state_or_input(first_stage_in, shift):
-      # Selects input for stage 0, shift for other stages
-      return jnp.where(jax.lax.broadcasted_iota('int32', shift.shape, 0) == 0, first_stage_in, shift)
+    def select_state_or_input(first_stage_in, shift):
+      # Selects input for stage 0, shift for other stages
+      return jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_stage_in, shift)

     # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output)
     stages_in = select_state_or_input(first_stage_in, shift)
-    stages_in = nn.with_logical_constraint(stages_in, ("activation_stage", "activation_batch", "activation_length", "activation_embed"), rules=self.config.logical_axis_rules, mesh=self.mesh)
+    stages_in = nn.with_logical_constraint(
+        stages_in,
+        ("activation_stage", "activation_batch", "activation_length", "activation_embed"),
+        rules=self.config.logical_axis_rules,
+        mesh=self.mesh,
+    )
     return stages_in

   def shard_dim_by_stages(self, x, dim: int):
@@ -172,9 +198,9 @@ def shard_dim_by_stages(self, x, dim: int):
     return jax.lax.with_sharding_constraint(x, sharding)

   def get_microbatch_and_repeat_ids(self, loop_iteration):
-    '''Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and non-circular'''
+    """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and non-circular"""
     # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is one behind due to bubble, etc for other stages
-    microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) 
+    microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0)
     microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches
     repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
     return microbatch_ids, repeat_ids
@@ -192,13 +218,16 @@ def vmap_parallel_gather(self, weights, repeat_ids, repeat_dim_in_weights, stage
       The per-stage gathered values. The shape is weights.shape but with repeat_dim_in_weights removed.
     """
+
     def _gather_one(x, repeat_id):
       return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights)

     gathered_weights_stage_dim = 0
     repeat_ids = self.shard_dim_by_stages(repeat_ids, 0)
     weights = self.shard_dim_by_stages(weights, stages_dim_in_weights)
-    stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)(weights, repeat_ids)
+    stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)(
+        weights, repeat_ids
+    )
     stage_weights = self.shard_dim_by_stages(stage_weights, gathered_weights_stage_dim)
     return stage_weights

@@ -217,37 +246,39 @@ def vmap_gather(self, xs, ids, ids_dim):
       The per-stage gathered values. The shape is xs.shape but with ids_dim size replaced with [num_stages].
""" + def _gather_one(x, i): - return jnp.squeeze( - jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim) + return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim) ids = self.shard_dim_by_stages(ids, 0) outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) return self.shard_dim_by_stages(outs, 0) - def get_new_loop_state(self,output, loop_state): - ''' - Update the various buffers given the output of the most recent iteration - * state_io: rotates left/up by 1 (the whole created in the last slot is filled with the most recent pipeline output) - * Pushing inputs up from top of state_io into first stage of shift - * Pulling outputs up from last stage of shift into bottom of state_io - * shift: rotate output (or prev_outputs if using delay) right/down by 1 - we imagine the pipeline moves to right/down - * circ_storage: pushes circ_storage_mover (the output of the previous iteration) into rotating index of circ_storage - * circ_storage_mover: assigned to rotated output and pushed into circ_storage on the next iteration - * prev_outputs: is set to the current output - ''' - - old_state_io = loop_state['state_io'] + def get_new_loop_state(self, output, loop_state): + """ + Update the various buffers given the output of the most recent iteration + * state_io: rotates left/up by 1 (the whole created in the last slot is filled with the most recent pipeline output) + * Pushing inputs up from top of state_io into first stage of shift + * Pulling outputs up from last stage of shift into bottom of state_io + * shift: rotate output (or prev_outputs if using delay) right/down by 1 - we imagine the pipeline moves to right/down + * circ_storage: pushes circ_storage_mover (the output of the previous iteration) into rotating index of circ_storage + * circ_storage_mover: assigned to rotated output and pushed into circ_storage on the next iteration + * prev_outputs: is set to the current output + """ + + old_state_io = loop_state["state_io"] old_circ_storage = loop_state["circ_storage"] old_circ_storage_mover = loop_state["circ_storage_mover"] loop_iteration = loop_state["loop_iteration"] old_prev_outputs = loop_state["prev_outputs"] + # Shift becomes a rotated-right version of the previous output def _rotate_right(output_in): # Use lax.slice to avoid generating a gather. last = jax.lax.slice_in_dim(output_in, self.num_stages - 1, self.num_stages, axis=0) except_last = jax.lax.slice_in_dim(output_in, 0, self.num_stages - 1, axis=0) return jnp.concatenate([last, except_last], axis=0) + if self.config.pipeline_delay_activation_forwarding: new_shift = _rotate_right(old_prev_outputs) new_prev_outputs = output @@ -259,48 +290,53 @@ def _rotate_right(output_in): # Insert the circ_storage_mover into new_circ_storage at a microbatch-rotating index. 
# circ_storage_mover still points to the output of PREVIOUS iteration, which should aid in allowing overlapped compute/async transfers def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): - rotated = _rotate_right(circ_storage_mover_in) - rotated = jnp.expand_dims(rotated, 1) - # We rotate the pushing index into circ storage, and ensure that microbatch 0 lands in index 0 - offset = (loop_iteration - self.iterations_to_complete_first_microbatch_one_repeat() - 1) % self.config.num_pipeline_microbatches # Note extra -1 b/c grabbing from the previous output - using circ_storage_mover before it is updated - return jax.lax.dynamic_update_slice_in_dim(circ_storage_in, rotated, offset, axis=1) + rotated = _rotate_right(circ_storage_mover_in) + rotated = jnp.expand_dims(rotated, 1) + # We rotate the pushing index into circ storage, and ensure that microbatch 0 lands in index 0 + offset = ( + loop_iteration - self.iterations_to_complete_first_microbatch_one_repeat() - 1 + ) % self.config.num_pipeline_microbatches # Note extra -1 b/c grabbing from the previous output - using circ_storage_mover before it is updated + return jax.lax.dynamic_update_slice_in_dim(circ_storage_in, rotated, offset, axis=1) + new_circ_storage = _rotate_right_and_update(old_circ_storage_mover, old_circ_storage) new_circ_storage_mover = output else: - new_circ_storage = None - new_circ_storage_mover = None + new_circ_storage = None + new_circ_storage_mover = None # Rotate stream_io left/up by 1 on rotating micro/stage index (stream_buf_idx), replacing the last/bottom with the last stage output stream_buf_idx = loop_iteration % self.microbatches_per_stage stream_slice = old_state_io[:, stream_buf_idx] + def _update_state_io(state_in, stream_slice, output): - # Shift the current slice to the left, then fill the last stage with the final output. - padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) - stream_slice = jax.lax.slice_in_dim( - jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0) - stream_slice = jnp.where( - jax.lax.broadcasted_iota('int32', stream_slice.shape, 0) == self.num_stages - 1, output, - stream_slice) - stream_slice = jnp.expand_dims(stream_slice, 1) - return jax.lax.dynamic_update_slice_in_dim( - state_in, stream_slice, stream_buf_idx, axis=1) + # Shift the current slice to the left, then fill the last stage with the final output. 
-      padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1)
-      stream_slice = jax.lax.slice_in_dim(
-          jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0)
-      stream_slice = jnp.where(
-          jax.lax.broadcasted_iota('int32', stream_slice.shape, 0) == self.num_stages - 1, output,
-          stream_slice)
-      stream_slice = jnp.expand_dims(stream_slice, 1)
-      return jax.lax.dynamic_update_slice_in_dim(
-          state_in, stream_slice, stream_buf_idx, axis=1)
+      # Shift the current slice to the left, then fill the last stage with the final output.
+ padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) + stream_slice = jax.lax.slice_in_dim(jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0) + stream_slice = jnp.where( + jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice + ) + stream_slice = jnp.expand_dims(stream_slice, 1) + return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1) + new_state = _update_state_io(old_state_io, stream_slice, output) - + new_loop_state = { - "state_io": new_state, - "shift": new_shift, - "circ_storage": new_circ_storage, - "circ_storage_mover": new_circ_storage_mover, - "loop_iteration": loop_iteration + 1, - "prev_outputs": new_prev_outputs + "state_io": new_state, + "shift": new_shift, + "circ_storage": new_circ_storage, + "circ_storage_mover": new_circ_storage_mover, + "loop_iteration": loop_iteration + 1, + "prev_outputs": new_prev_outputs, } return new_loop_state - + def permute_output_micro_per_stage_dim(self, output): # The first real output (microbatch 0) takes a certain amount of loop iterations to finish and be pushed to state_io - it will land on a different index of state_io depending on the number of iterations. microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage - permutation = (np.arange(self.microbatches_per_stage) + microbatch_0_idx) % self.microbatches_per_stage # permute so the value in land_idx is moved into idx 0, and (land_idx + 1) appear in idx 1, etc - output = output[:,permutation] + permutation = ( + np.arange(self.microbatches_per_stage) + microbatch_0_idx + ) % self.microbatches_per_stage # permute so the value in land_idx is moved into idx 0, and (land_idx + 1) appear in idx 1, etc + output = output[:, permutation] return output def get_main_vmap_func(self): @@ -309,88 +345,114 @@ def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positi return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) vmap_func = nn.vmap( - func_to_vmap, - in_axes=(0, 0, 0, None, None), - spmd_axis_name='stage', - variable_axes={'params': 0}, - split_rngs={'params': self.is_initializing()}, - metadata_params={ - nn.PARTITION_NAME: "layers", - 'sub_weight_split_dims_mapping': (None), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages} + func_to_vmap, + in_axes=(0, 0, 0, None, None), + spmd_axis_name="stage", + variable_axes={"params": 0}, + split_rngs={"params": self.is_initializing()}, + metadata_params={ + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, ) return vmap_func def run_one_iteration(self, loop_state, positions, segment_ids, deterministic, model_mode, decoder_layer_instance): - '''Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, and update the loop state.''' - state_io = loop_state['state_io'] - shift = loop_state["shift"] - circ_storage = loop_state["circ_storage"] - loop_iteration = loop_state["loop_iteration"] - - microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) - - stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) - # We checkpoint stages_inputs since we are grabbing only one slice of the state_io, don't need to save the entire buffer. 
- stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, 'iteration_input') - stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None - stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None - - vmap_func = self.get_main_vmap_func() - - if self.config.num_pipeline_repeats > 1: - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) - - def prepare_vars_for_main_vmap(weights): - def gather_weights_for_stages_in(weights): - return jax.tree.map( - functools.partial( - self.vmap_parallel_gather, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1), - weights) - circular_metadata_params={ - nn.PARTITION_NAME: "circular_repeats", - 'sub_weight_split_dims_mapping': (None,), - "is_initializing": self.is_initializing(), - "x_times": self.config.num_pipeline_repeats, - 'optimizer_dims_mapping': None, - } - weights = meta.remove_axis(weights, 0, circular_metadata_params) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular entry per stage. - weights = gather_weights_for_stages_in(weights) - return weights - - vmap_func = nn.map_variables( - vmap_func, - mapped_collections=["params", "non_trainable", "summaries", "intermediates"], - mutable=True, - trans_in_fn=prepare_vars_for_main_vmap, - ) + """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, and update the loop state.""" + state_io = loop_state["state_io"] + shift = loop_state["shift"] + circ_storage = loop_state["circ_storage"] + loop_iteration = loop_state["loop_iteration"] + + microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) + + stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) + # We checkpoint stages_inputs since we are grabbing only one slice of the state_io, don't need to save the entire buffer. + stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") + stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None + stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None - stages_output = vmap_func(decoder_layer_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) - if self.config.scan_layers: - stages_output = stages_output[0] + vmap_func = self.get_main_vmap_func() - new_state = self.get_new_loop_state(stages_output, loop_state) - return new_state + if self.config.num_pipeline_repeats > 1: + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def prepare_vars_for_main_vmap(weights): + def gather_weights_for_stages_in(weights): + return jax.tree.map( + functools.partial( + self.vmap_parallel_gather, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1 + ), + weights, + ) + + circular_metadata_params = { + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + } + weights = meta.remove_axis( + weights, 0, circular_metadata_params + ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular entry per stage. 
+ weights = gather_weights_for_stages_in(weights) + return weights + + vmap_func = nn.map_variables( + vmap_func, + mapped_collections=["params", "non_trainable", "summaries", "intermediates"], + mutable=True, + trans_in_fn=prepare_vars_for_main_vmap, + ) + + stages_output = vmap_func( + decoder_layer_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode + ) + if self.config.scan_layers: + stages_output = stages_output[0] + + new_state = self.get_new_loop_state(stages_output, loop_state) + return new_state @nn.compact - def __call__(self, inputs: jnp.ndarray, segment_ids: jnp.ndarray, positions:jnp.ndarray, deterministic: bool, model_mode=common_types.MODEL_MODE_TRAIN) -> jnp.ndarray: - ''' The main method that maps the series of decoder layer inputs to final layer outputs. + def __call__( + self, + inputs: jnp.ndarray, + segment_ids: jnp.ndarray, + positions: jnp.ndarray, + deterministic: bool, + model_mode=common_types.MODEL_MODE_TRAIN, + ) -> jnp.ndarray: + """The main method that maps the series of decoder layer inputs to final layer outputs. Has the same signature of a single decoder layer, and expects the same shapes, e.g. the inputs should have shape [global_batch], and internally this will be reshapped into microbatches. - ''' + """ # Reshape inputs of [global_batch, ...] to [microbatches, pipeline_microbatch_sizes, ...] - inputs = inputs.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length, self.config.emb_dim)) - example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages]) # dummy inputs fed to initialize the module weights. + inputs = inputs.reshape( + ( + self.config.num_pipeline_microbatches, + self.pipeline_microbatch_size, + self.config.max_target_length, + self.config.emb_dim, + ) + ) + example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages]) # dummy inputs fed to initialize the module weights. if positions is not None: - positions = positions.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) + positions = positions.reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) example_position = jax.lax.broadcast(positions[0], [self.num_stages]) position_idx = 0 else: example_position = None position_idx = None if segment_ids is not None: - segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) + segment_ids = segment_ids.reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages]) segment_idx = 0 else: @@ -398,105 +460,120 @@ def __call__(self, inputs: jnp.ndarray, segment_ids: jnp.ndarray, positions:jnp. segment_idx = None loop_state = self.init_states(inputs) - + # Each microbatch should go through each stage (with repeats) - so there is num_micro * (num_stages * repeats) compute to perform # Each iteration is vmapped by num_stages, so the number of iterations should be num_micro * num_stages * repeats / num_stages = num_micro * repeats # However due to the pipeline bubble some iterations process less than num_stages microbatches. It takes # num_micro * repeat iterations for the last microbatch to start the final repeat, then an additional num_stages - 1 to finish the final repeat. 
# Thus the total iterations is num_micro * repeat + num_stages - 1, and we may consider the num_stages - 1 as bubble. - # The bubble doubles when we use forwarding delay. - bubble_iterations = self.forwarding_delay * (self.num_stages - 1) + # The bubble doubles when we use forwarding delay. + bubble_iterations = self.forwarding_delay * (self.num_stages - 1) real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats total_iterations = real_iterations + bubble_iterations - if self.is_initializing(): - vmap_func = self.get_main_vmap_func() + if self.is_initializing(): + vmap_func = self.get_main_vmap_func() + + if self.config.num_pipeline_repeats > 1: + # To shard the weights on initialization for the circular pipeline we create weights of + # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis. + # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization. + vmap_func = nn.vmap( + vmap_func, + in_axes=(0, segment_idx, position_idx, None, None), + variable_axes={ + "params": 0, + "non_trainable": 0, + "hyper_params": 0, + }, + split_rngs={"params": True}, + metadata_params={ + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": True, + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + }, + ) - if self.config.num_pipeline_repeats > 1: - # To shard the weights on initialization for the circular pipeline we create weights of - # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis. - # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization. - vmap_func= nn.vmap( - vmap_func, - in_axes=(0, segment_idx, position_idx, None, None), - variable_axes={ - 'params': 0, - "non_trainable": 0, - "hyper_params": 0, - }, - split_rngs={'params': True}, - metadata_params={ - nn.PARTITION_NAME: "circular_repeats", - 'sub_weight_split_dims_mapping': (None,), - "is_initializing": True, - "x_times": self.config.num_pipeline_repeats, - 'optimizer_dims_mapping': None, - } + example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats]) + example_segmentation = ( + jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats]) + if example_segmentation is not None + else None + ) + example_position = ( + jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) if example_position is not None else None ) + # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for the full total_iterations. 
+ stage_outputs = vmap_func( + self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode + ) + if self.config.scan_layers: + stage_outputs = stage_outputs[0] + + # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output which has + # shape [pipeline_microbatch_size, sequence, embed] + if self.config.num_pipeline_repeats > 1: + stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap + broadcasted_stage_outpus = jax.lax.broadcast( + stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size] + ) + return jnp.reshape( + broadcasted_stage_outpus, + [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim], + ) + + def run_iteration_scannable(model, loop_state, xs): + # flax transforms like nn.scan and nn.remat can only be applied to nn.module classes or nn.module instances, so we explicitly wrap + # the run_one_iteration in this method - the first argument model (i.e. self) is a nn.module instance. + return model.run_one_iteration(loop_state, positions, segment_ids, deterministic, model_mode, model.layers), None - example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats]) - example_segmentation = jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats]) if example_segmentation is not None else None - example_position = jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) if example_position is not None else None - # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for the full total_iterations. - stage_outputs = vmap_func(self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode) - if self.config.scan_layers: - stage_outputs = stage_outputs[0] - - # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output which has - # shape [pipeline_microbatch_size, sequence, embed] - if self.config.num_pipeline_repeats > 1: - stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap - broadcasted_stage_outpus = jax.lax.broadcast(stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]) - return jnp.reshape(broadcasted_stage_outpus, [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim]) - - def run_iteration_scannable(model,loop_state, xs): - # flax transforms like nn.scan and nn.remat can only be applied to nn.module classes or nn.module instances, so we explicitly wrap - # the run_one_iteration in this method - the first argument model (i.e. self) is a nn.module instance. 
- return model.run_one_iteration(loop_state, positions, segment_ids, deterministic, model_mode, model.layers), None if self.remat_policy is not None: remat_policy = jax.checkpoint_policies.save_from_both_policies( - self.remat_policy, - jax.checkpoint_policies.save_only_these_names('iteration_input') + self.remat_policy, jax.checkpoint_policies.save_only_these_names("iteration_input") ) else: - remat_policy = jax.checkpoint_policies.save_only_these_names('iteration_input') + remat_policy = jax.checkpoint_policies.save_only_these_names("iteration_input") run_one_iteration_rematted = nn.remat( - run_iteration_scannable, - prevent_cse=not self.config.scan_pipeline_iterations, # prevent_cse not used with scan - policy=remat_policy + run_iteration_scannable, + prevent_cse=not self.config.scan_pipeline_iterations, # prevent_cse not used with scan + policy=remat_policy, ) # The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized. if self.config.scan_pipeline_iterations: variable_carry = [] - variable_broadcast = ["params"] # All loop iterations need the weights for the full pipeline. + variable_broadcast = ["params"] # All loop iterations need the weights for the full pipeline. if self.is_mutable_collection("non_trainable"): variable_carry.append("non_trainable") else: variable_broadcast.append("non_trainable") run_all_iterations_scanned = nn.scan( - run_one_iteration_rematted, - variable_axes={ - "summaries": 0, - "aux_loss": 0, - "intermediates": 0, - "hyper_params": 0, - }, - variable_broadcast=variable_broadcast, - variable_carry=variable_carry, - # Dropout/aqt keys will be split for each iteration. - split_rngs={"random": True}, - length=total_iterations, - ) + run_one_iteration_rematted, + variable_axes={ + "summaries": 0, + "aux_loss": 0, + "intermediates": 0, + "hyper_params": 0, + }, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + # Dropout/aqt keys will be split for each iteration. 
+ split_rngs={"random": True}, + length=total_iterations, + ) loop_state, _ = run_all_iterations_scanned(self, loop_state, None) else: - for loop_iteration in range(total_iterations): - loop_state, _ = run_one_iteration_rematted(self, loop_state, None) + for loop_iteration in range(total_iterations): + loop_state, _ = run_one_iteration_rematted(self, loop_state, None) # The final output is located in the input/output array, however the output microbatches may be permuted relative to the input final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"]) # reshape outputs to match input shape of total batch instead of microbatches [batch, sequence, embed] - final_output = jnp.reshape(final_output, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim)) - - return final_output \ No newline at end of file + final_output = jnp.reshape( + final_output, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim) + ) + + return final_output diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index 5b5fcf050..b1325df5d 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -62,19 +62,15 @@ def _tiling_fn(lhs, rhs, dimension_numbers, tile_size): ) for lhs_idx, rhs_idx in zip(lhs_ca, rhs_ca): - ret.lhs.contraction_axes.append( - tiled_dot_general.AxisTiling(axis=lhs_idx, tile_size=tile_size, tile_count=None) - ) - ret.rhs.contraction_axes.append( - tiled_dot_general.AxisTiling( - axis=rhs_idx, tile_size=tile_size, tile_count=None - ) - ) + ret.lhs.contraction_axes.append(tiled_dot_general.AxisTiling(axis=lhs_idx, tile_size=tile_size, tile_count=None)) + ret.rhs.contraction_axes.append(tiled_dot_general.AxisTiling(axis=rhs_idx, tile_size=tile_size, tile_count=None)) return ret -def _rhs_axis_metadata_wrapper(x: jnp.ndarray, tile_map, no_sharding_axis: Sequence[int], mesh_axes: Tuple[str, ...], is_tiled: bool): +def _rhs_axis_metadata_wrapper( + x: jnp.ndarray, tile_map, no_sharding_axis: Sequence[int], mesh_axes: Tuple[str, ...], is_tiled: bool +): mesh_axes = list(mesh_axes) if is_tiled: # tile_map is a mapping between original rank and a list of new, tiled rank. @@ -103,16 +99,16 @@ class AqtQuantization: def _get_mixed_precision_cfg(self): quant_dg = None - is_tiled=False - tiling_fn=None - module_path = '/'.join(nn.module._context.module_stack[-1].path) + is_tiled = False + tiling_fn = None + module_path = "/".join(nn.module._context.module_stack[-1].path) for layer_name_re, layer_quant_dg in self.quant_dg.items(): if re.fullmatch(layer_name_re, module_path): quant_dg, tile_size = layer_quant_dg if quant_dg is None: - quant_dg, tile_size = self.quant_dg['default'] + quant_dg, tile_size = self.quant_dg["default"] if tile_size != -1: - is_tiled=True + is_tiled = True tiling_fn = functools.partial(_tiling_fn, tile_size=tile_size) return quant_dg, is_tiled, tiling_fn @@ -126,9 +122,8 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] 
= ()): if isinstance(self.quant_dg, dict): quant_dg, is_tiled, tiling_fn = self._get_mixed_precision_cfg() else: - quant_dg, is_tiled, tiling_fn = self.quant_dg, False, None - rhs_axis_metadata_wrapper=self._get_rhs_axis_metadata_wrapper( - mesh_axes, is_tiled) + quant_dg, is_tiled, tiling_fn = self.quant_dg, False, None + rhs_axis_metadata_wrapper = self._get_rhs_axis_metadata_wrapper(mesh_axes, is_tiled) aqt_dg_cls = functools.partial( aqt_flax.AqtDotGeneral, quant_dg, @@ -137,14 +132,13 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE, rhs_axis_metadata_wrapper=rhs_axis_metadata_wrapper, use_legacy_freezer=False, - tiling_fn=tiling_fn + tiling_fn=tiling_fn, ) return aqt_dg_cls def einsum(self, mesh_axes: Tuple[str, ...] = ()): """Returns einsum configured with aqt params.""" - rhs_axis_metadata_wrapper = self._get_rhs_axis_metadata_wrapper( - mesh_axes) + rhs_axis_metadata_wrapper = self._get_rhs_axis_metadata_wrapper(mesh_axes) aqt_einsum = functools.partial( aqt_flax.AqtEinsum( cfg=self.quant_dg, @@ -168,6 +162,7 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): """Returns dot_general configured with aqt params.""" return nn.Fp8DotGeneralOp + def _get_int8_quant_config(config): drhs_bits = None drhs_accumulator_dtype = None @@ -175,20 +170,18 @@ def _get_int8_quant_config(config): if config.quantization_local_shard_count != 0: drhs_bits = 8 drhs_accumulator_dtype = jnp.int32 - drhs_local_aqt = aqt_config.LocalAqt( - contraction_axis_shard_count=config.quantization_local_shard_count - ) + drhs_local_aqt = aqt_config.LocalAqt(contraction_axis_shard_count=config.quantization_local_shard_count) return aqt_config.config_v3( - fwd_bits=8, - dlhs_bits=8, - drhs_bits=drhs_bits, - rng_type="jax.uniform", - dlhs_local_aqt=None, - drhs_local_aqt=drhs_local_aqt, - fwd_accumulator_dtype=jnp.int32, - dlhs_accumulator_dtype=jnp.int32, - drhs_accumulator_dtype=drhs_accumulator_dtype, - ) + fwd_bits=8, + dlhs_bits=8, + drhs_bits=drhs_bits, + rng_type="jax.uniform", + dlhs_local_aqt=None, + drhs_local_aqt=drhs_local_aqt, + fwd_accumulator_dtype=jnp.int32, + dlhs_accumulator_dtype=jnp.int32, + drhs_accumulator_dtype=drhs_accumulator_dtype, + ) def _get_weight_only_quant_config(lhs_bits=None, rhs_bits=None): @@ -207,8 +200,7 @@ def _get_mixed_precision_quant_config(config, config_file): scale = layer_quantization_config.get("scale", 1.0) aqt_dg = aqt_config.dot_general_make(lhs_bits=None, rhs_bits=rhs_num_bits) if scale < 1.0: - aqt_dg.fwd.dg_quantizer.rhs.calibration = functools.partial( - calibration.AbsMaxCalibration, scale=scale) + aqt_dg.fwd.dg_quantizer.rhs.calibration = functools.partial(calibration.AbsMaxCalibration, scale=scale) ret_config[layer_name_re] = [aqt_dg, tile_size] return ret_config @@ -290,14 +282,16 @@ def remove_quantized_params(params, aqt_vars): tree_flat[i] = v return tree_unflatten(tree_struct, tree_flat) + def configure_kv_quant(config): return None if not config.quantize_kvcache else KVQuant(config) + class KVQuant: axis_cfg = "" dtype = None - def __init__(self, config:Config): + def __init__(self, config: Config): assert config.quantize_kvcache self.axis_cfg = config.kv_quant_axis self.dtype = self._get_dtype(config.kv_quant_dtype) @@ -313,15 +307,12 @@ def _get_max_axis(self, axis_names: AxisNames): if self.axis_cfg == "dkv": return axis_names.index(CACHE_KV) if self.axis_cfg == "heads_and_dkv": - return ( - axis_names.index(CACHE_HEADS), - axis_names.index(CACHE_KV) - ) + return 
(axis_names.index(CACHE_HEADS), axis_names.index(CACHE_KV)) raise ValueError(f"Invalid KV quant axis cfg: {self.axis_cfg}") def quantize(self, kv: Array, axis_names: AxisNames): """Quantize key/values stored in kvcache.""" - assert self.axis_cfg, 'KV quant axis cannot be None' + assert self.axis_cfg, "KV quant axis cannot be None" max_axis = self._get_max_axis(axis_names) scale = jnp.max(jnp.abs(kv), axis=max_axis, keepdims=True) if self.dtype == jnp.int8: @@ -332,42 +323,35 @@ def quantize(self, kv: Array, axis_names: AxisNames): return value, scale raise ValueError(f"Invalid KV quant dtype:{self.dtype}.") - def einsum_fn_with_rhs_qtensor( - self, - kv: Array| aqt_tensor.QTensor, - rhs_dequant_mode=None, - rhs_calibration_mode=None - ): + def einsum_fn_with_rhs_qtensor(self, kv: Array | aqt_tensor.QTensor, rhs_dequant_mode=None, rhs_calibration_mode=None): # Assumes kv is already quantized. einsum = jnp.einsum if isinstance(kv, aqt_tensor.QTensor): num_bits = 4 if kv.qvalue.dtype == jnp.int4 else 8 kv_cfg = aqt_config.dot_general_make( - lhs_bits=None, - rhs_bits=num_bits, - bwd_bits=None, - use_fwd_quant=False, - ) + lhs_bits=None, + rhs_bits=num_bits, + bwd_bits=None, + use_fwd_quant=False, + ) if rhs_dequant_mode: - aqt_config.set_fwd_dequant_mode( - kv_cfg, rhs_dequant_mode=rhs_dequant_mode - ) + aqt_config.set_fwd_dequant_mode(kv_cfg, rhs_dequant_mode=rhs_dequant_mode) if rhs_calibration_mode: aqt_config.set_fwd_calibration_mode( - kv_cfg, - rhs_calibration_mode=rhs_calibration_mode, - ) - einsum = aqt_flax.AqtEinsum( - rhs_quant_mode=aqt_flax.QuantMode.TRAIN, - lhs_freeze_mode=aqt_flax.FreezerMode.NONE, - rhs_freeze_mode=aqt_flax.FreezerMode.NONE, - cfg=kv_cfg + kv_cfg, + rhs_calibration_mode=rhs_calibration_mode, ) + einsum = aqt_flax.AqtEinsum( + rhs_quant_mode=aqt_flax.QuantMode.TRAIN, + lhs_freeze_mode=aqt_flax.FreezerMode.NONE, + rhs_freeze_mode=aqt_flax.FreezerMode.NONE, + cfg=kv_cfg, + ) return einsum def einsum_fn_with_rhs_qtensor_and_dequant(self, value): return self.einsum_fn_with_rhs_qtensor( - value, - rhs_dequant_mode=aqt_config.DequantMode.OTHER_INPUT, - rhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS - ) + value, + rhs_dequant_mode=aqt_config.DequantMode.OTHER_INPUT, + rhs_calibration_mode=aqt_config.CalibrationMode.REMAINING_AXIS, + ) diff --git a/MaxText/layers/simple_layer.py b/MaxText/layers/simple_layer.py index f52e76a56..a46ba9861 100644 --- a/MaxText/layers/simple_layer.py +++ b/MaxText/layers/simple_layer.py @@ -22,17 +22,19 @@ # pytype: disable=attribute-error + class SimpleDecoderLayer(nn.Module): - """ Decoder layer consisting of a single [embed, embed] weight matrix """ + """Decoder layer consisting of a single [embed, embed] weight matrix""" + config: common_types.Config mesh: Mesh quant: Optional[quantizations.AqtQuantization] = None def setup(self): self.weight_mat = self.param( - 'weights', - nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - (self.config.emb_dim, self.config.emb_dim) + "weights", + nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + (self.config.emb_dim, self.config.emb_dim), ) def __call__(self, inputs: jnp.ndarray, positions, segmentation, deterministic, model_mode): @@ -41,22 +43,24 @@ def __call__(self, inputs: jnp.ndarray, positions, segmentation, deterministic, else: return inputs @ self.weight_mat.astype(inputs.dtype) + class SimpleMlpDecoderLayer(nn.Module): - """ Decoder layer consisting of [embed,mlp] followed by an [mlp,embed] matmul. 
""" + """Decoder layer consisting of [embed,mlp] followed by an [mlp,embed] matmul.""" + config: common_types.Config mesh: Mesh quant: Optional[quantizations.AqtQuantization] = None def setup(self): self.ff_1 = self.param( - 'ff_1', - nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - (self.config.emb_dim, self.config.mlp_dim) + "ff_1", + nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + (self.config.emb_dim, self.config.mlp_dim), ) self.ff_2 = self.param( - 'ff_2', - nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - (self.config.mlp_dim, self.config.emb_dim) + "ff_2", + nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + (self.config.mlp_dim, self.config.emb_dim), ) def __call__(self, inputs: jnp.ndarray, positions, segmentation, deterministic, model_mode): diff --git a/MaxText/llama_mistral_mixtral_orbax_to_hf.py b/MaxText/llama_mistral_mixtral_orbax_to_hf.py index 8e39ec180..ec126bd4a 100644 --- a/MaxText/llama_mistral_mixtral_orbax_to_hf.py +++ b/MaxText/llama_mistral_mixtral_orbax_to_hf.py @@ -1,4 +1,3 @@ - """ Copyright 2023 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); @@ -56,12 +55,12 @@ def unpermute_from_match_maxtext_rope(arr): split_size = arr.shape[-1] // 2 # Assuming half for evens, half for odds evens = arr[..., :split_size] odds = arr[..., split_size:] - return jax.numpy.concatenate((evens, odds), axis=arr.ndim-1) + return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1) -def reverse_scale(arr,scale): +def reverse_scale(arr, scale): """ - MaxText has the scaling factor included into the weights, + MaxText has the scaling factor included into the weights, we reverse it when writing out the HuggingFace checkpoint """ return arr * np.sqrt(scale) @@ -114,102 +113,128 @@ def convert_state_to_hf(training_state, model_size): # Load the model specific parameters model_params = llama_or_mistral_ckpt.MODEL_PARAMS_DICT[model_size] - base_num_decoder_layers = model_params['num_layers'] - base_num_query_heads = model_params['num_heads'] - head_dim = model_params['dims_per_head'] - base_num_kv_heads = model_params['num_kv_heads'] - num_experts = model_params['num_experts'] if 'num_experts' in model_params else None + base_num_decoder_layers = model_params["num_layers"] + base_num_query_heads = model_params["num_heads"] + head_dim = model_params["dims_per_head"] + base_num_kv_heads = model_params["num_kv_heads"] + num_experts = model_params["num_experts"] if "num_experts" in model_params else None hf_model_params = {} # Port the embedding weights - hf_model_params["model.embed_tokens.weight"] = torch.tensor(np.asarray( - training_state.params['params']['token_embedder']['embedding']), - dtype=torch.float16) + hf_model_params["model.embed_tokens.weight"] = torch.tensor( + np.asarray(training_state.params["params"]["token_embedder"]["embedding"]), dtype=torch.float16 + ) - for layer_int in tqdm(range(base_num_decoder_layers),desc="Porting parameters layerwise"): + for layer_int in tqdm(range(base_num_decoder_layers), desc="Porting parameters layerwise"): print(f"Converting weights for layer {layer_int}") # Attention layers - hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"] = torch.tensor(np.asarray( - unpermute_from_match_maxtext_rope( - reverse_scale( - training_state.params['params']["decoder"]["layers"]["self_attention"]["query"]["kernel"][:, layer_int, :, :] - ,head_dim + 
hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"] = torch.tensor( + np.asarray( + unpermute_from_match_maxtext_rope( + reverse_scale( + training_state.params["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"][ + :, layer_int, :, : + ], + head_dim, + ) ) - ).reshape(base_num_query_heads * head_dim,base_num_query_heads * head_dim).T), - dtype=torch.float16 + .reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim) + .T + ), + dtype=torch.float16, ) - hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = torch.tensor(np.asarray( - unpermute_from_match_maxtext_rope( - training_state.params['params']["decoder"]["layers"]["self_attention"]["key"]["kernel"][:, layer_int, :, :] - ).reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T), - dtype=torch.float16 + hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = torch.tensor( + np.asarray( + unpermute_from_match_maxtext_rope( + training_state.params["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"][:, layer_int, :, :] + ) + .reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim) + .T + ), + dtype=torch.float16, ) - hf_model_params[f"model.layers.{layer_int}.self_attn.v_proj.weight"] = torch.tensor(np.asarray( - training_state.params['params']["decoder"]["layers"]["self_attention"]["value"]["kernel"][:, layer_int, :, :] - .reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T), - dtype=torch.float16 + hf_model_params[f"model.layers.{layer_int}.self_attn.v_proj.weight"] = torch.tensor( + np.asarray( + training_state.params["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"][:, layer_int, :, :] + .reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim) + .T + ), + dtype=torch.float16, ) - hf_model_params[f"model.layers.{layer_int}.self_attn.o_proj.weight"] = torch.tensor(np.asarray( - training_state.params['params']["decoder"]["layers"]["self_attention"]["out"]["kernel"][:, layer_int, :, :] - .reshape(base_num_query_heads * head_dim,base_num_query_heads * head_dim).T), - dtype=torch.float16 + hf_model_params[f"model.layers.{layer_int}.self_attn.o_proj.weight"] = torch.tensor( + np.asarray( + training_state.params["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"][:, layer_int, :, :] + .reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim) + .T + ), + dtype=torch.float16, ) # MLP Layers if num_experts is None: - hf_model_params[f"model.layers.{layer_int}.mlp.gate_proj.weight"] = torch.tensor(np.asarray( - training_state.params['params']["decoder"]["layers"]["mlp"]["wi_0"]["kernel"][:, layer_int, :].T), - dtype=torch.float16 + hf_model_params[f"model.layers.{layer_int}.mlp.gate_proj.weight"] = torch.tensor( + np.asarray(training_state.params["params"]["decoder"]["layers"]["mlp"]["wi_0"]["kernel"][:, layer_int, :].T), + dtype=torch.float16, ) - hf_model_params[f"model.layers.{layer_int}.mlp.up_proj.weight"] = torch.tensor(np.asarray( - training_state.params['params']["decoder"]["layers"]["mlp"]["wi_1"]["kernel"][:, layer_int, :].T), - dtype=torch.float16 + hf_model_params[f"model.layers.{layer_int}.mlp.up_proj.weight"] = torch.tensor( + np.asarray(training_state.params["params"]["decoder"]["layers"]["mlp"]["wi_1"]["kernel"][:, layer_int, :].T), + dtype=torch.float16, ) - hf_model_params[f"model.layers.{layer_int}.mlp.down_proj.weight"] = torch.tensor(np.asarray( - 
training_state.params['params']["decoder"]["layers"]["mlp"]["wo"]["kernel"][:, layer_int, :].T), - dtype=torch.float16 + hf_model_params[f"model.layers.{layer_int}.mlp.down_proj.weight"] = torch.tensor( + np.asarray(training_state.params["params"]["decoder"]["layers"]["mlp"]["wo"]["kernel"][:, layer_int, :].T), + dtype=torch.float16, ) else: - hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.gate.weight"] = torch.tensor(np.asarray( - training_state.params['params']['decoder']['layers']['MoeBlock_0']['gate']['kernel'][:,layer_int,:].T - ), dtype=torch.float16) + hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.gate.weight"] = torch.tensor( + np.asarray( + training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["gate"]["kernel"][:, layer_int, :].T + ), + dtype=torch.float16, + ) for k in range(num_experts): - hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w1.weight"] = torch.tensor(np.asarray( - training_state.params['params']['decoder']['layers']['MoeBlock_0']['wi_0'][k, layer_int, :, :].T), - dtype=torch.float16 + hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w1.weight"] = torch.tensor( + np.asarray(training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["wi_0"][k, layer_int, :, :].T), + dtype=torch.float16, ) - hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w2.weight"] = torch.tensor(np.asarray( - training_state.params['params']['decoder']['layers']['MoeBlock_0']['wo'][k, layer_int, :, :].T), - dtype=torch.float16 + hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w2.weight"] = torch.tensor( + np.asarray(training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["wo"][k, layer_int, :, :].T), + dtype=torch.float16, ) - hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w3.weight"] = torch.tensor(np.asarray( - training_state.params['params']['decoder']['layers']['MoeBlock_0']['wi_1'][k, layer_int, :, :].T), - dtype=torch.float16 + hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w3.weight"] = torch.tensor( + np.asarray(training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["wi_1"][k, layer_int, :, :].T), + dtype=torch.float16, ) # Pre/post attention layer norm - hf_model_params[f"model.layers.{layer_int}.input_layernorm.weight"] = torch.tensor(np.asarray( - training_state.params['params']["decoder"]["layers"]["pre_self_attention_layer_norm"]["scale"][:, layer_int] - .reshape(base_num_query_heads * head_dim)), - dtype=torch.float16 + hf_model_params[f"model.layers.{layer_int}.input_layernorm.weight"] = torch.tensor( + np.asarray( + training_state.params["params"]["decoder"]["layers"]["pre_self_attention_layer_norm"]["scale"][ + :, layer_int + ].reshape(base_num_query_heads * head_dim) + ), + dtype=torch.float16, ) - hf_model_params[f"model.layers.{layer_int}.post_attention_layernorm.weight"] = torch.tensor(np.asarray( - training_state.params['params']["decoder"]["layers"]["post_self_attention_layer_norm"]["scale"][:, layer_int] - .reshape(base_num_query_heads * head_dim)), - dtype=torch.float16 + hf_model_params[f"model.layers.{layer_int}.post_attention_layernorm.weight"] = torch.tensor( + np.asarray( + training_state.params["params"]["decoder"]["layers"]["post_self_attention_layer_norm"]["scale"][ + :, layer_int + ].reshape(base_num_query_heads * head_dim) + ), + dtype=torch.float16, ) # LM head and layernorm - hf_model_params["lm_head.weight"] = torch.tensor(np.asarray( - 
training_state.params['params']["decoder"]["logits_dense"]["kernel"].T), - dtype=torch.float16 + hf_model_params["lm_head.weight"] = torch.tensor( + np.asarray(training_state.params["params"]["decoder"]["logits_dense"]["kernel"].T), dtype=torch.float16 ) - hf_model_params["model.norm.weight"] = torch.tensor(np.asarray( - training_state.params['params']["decoder"]["decoder_norm"]["scale"].reshape(base_num_query_heads * head_dim)), - dtype=torch.float16 + hf_model_params["model.norm.weight"] = torch.tensor( + np.asarray( + training_state.params["params"]["decoder"]["decoder_norm"]["scale"].reshape(base_num_query_heads * head_dim) + ), + dtype=torch.float16, ) return hf_model_params diff --git a/MaxText/llama_or_mistral_ckpt.py b/MaxText/llama_or_mistral_ckpt.py index 61f4c7f23..356529934 100644 --- a/MaxText/llama_or_mistral_ckpt.py +++ b/MaxText/llama_or_mistral_ckpt.py @@ -122,36 +122,37 @@ SIMULATED_CPU_DEVICES_COUNT = 16 + def _hf_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict: + # pylint: disable=line-too-long return { - "tok_embeddings.weight": "model.embed_tokens.weight", - "norm.weight": "model.norm.weight", - "output.weight": "lm_head.weight", - # MOE model - f"layers.{layer_idx}.attention_norm.weight": f"model.layers.{layer_idx}.input_layernorm.weight", - f"layers.{layer_idx}.ffn_norm.weight": f"model.layers.{layer_idx}.post_attention_layernorm.weight", - f"layers.{layer_idx}.attention.wq.weight": f"model.layers.{layer_idx}.self_attn.q_proj.weight", - f"layers.{layer_idx}.attention.wk.weight": f"model.layers.{layer_idx}.self_attn.k_proj.weight", - f"layers.{layer_idx}.attention.wv.weight": f"model.layers.{layer_idx}.self_attn.v_proj.weight", - f"layers.{layer_idx}.attention.wo.weight": f"model.layers.{layer_idx}.self_attn.o_proj.weight", - f"layers.{layer_idx}.feed_forward.gate.weight": f"model.layers.{layer_idx}.block_sparse_moe.gate.weight", - f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w1.weight": - f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w1.weight", - f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w2.weight": - f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w2.weight", - f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w3.weight": - f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w3.weight", - # dense model - f"layers.{layer_idx}.feed_forward.w1.weight": f"model.layers.{layer_idx}.mlp.gate_proj.weight", - f"layers.{layer_idx}.feed_forward.w2.weight": f"model.layers.{layer_idx}.mlp.down_proj.weight", - f"layers.{layer_idx}.feed_forward.w3.weight": f"model.layers.{layer_idx}.mlp.up_proj.weight", + "tok_embeddings.weight": "model.embed_tokens.weight", + "norm.weight": "model.norm.weight", + "output.weight": "lm_head.weight", + # MOE model + f"layers.{layer_idx}.attention_norm.weight": f"model.layers.{layer_idx}.input_layernorm.weight", + f"layers.{layer_idx}.ffn_norm.weight": f"model.layers.{layer_idx}.post_attention_layernorm.weight", + f"layers.{layer_idx}.attention.wq.weight": f"model.layers.{layer_idx}.self_attn.q_proj.weight", + f"layers.{layer_idx}.attention.wk.weight": f"model.layers.{layer_idx}.self_attn.k_proj.weight", + f"layers.{layer_idx}.attention.wv.weight": f"model.layers.{layer_idx}.self_attn.v_proj.weight", + f"layers.{layer_idx}.attention.wo.weight": f"model.layers.{layer_idx}.self_attn.o_proj.weight", + f"layers.{layer_idx}.feed_forward.gate.weight": f"model.layers.{layer_idx}.block_sparse_moe.gate.weight", + 
f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w1.weight": f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w1.weight", + f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w2.weight": f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w2.weight", + f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w3.weight": f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w3.weight", + # dense model + f"layers.{layer_idx}.feed_forward.w1.weight": f"model.layers.{layer_idx}.mlp.gate_proj.weight", + f"layers.{layer_idx}.feed_forward.w2.weight": f"model.layers.{layer_idx}.mlp.down_proj.weight", + f"layers.{layer_idx}.feed_forward.w3.weight": f"model.layers.{layer_idx}.mlp.up_proj.weight", } + @dataclass class _HFNamespaceMapper: - """A class to dynamically map Mistral/Llama weight names to Huggingface weights + """A class to dynamically map Mistral/Llama weight names to Huggingface weights if the checkpoint is from HF. """ + collection: dict delimiter: str = "." @@ -193,7 +194,7 @@ def convert_to_jax_weights(base_model_path, model_size): vocab_size = model_params["vocab"] num_experts = model_params["num_experts"] if "num_experts" in model_params else None mem_info = psutil.Process() - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) max_logging.log(f"Loading the base model from {base_model_path}") # Skip any hidden files for checkpoints @@ -207,7 +208,7 @@ def convert_to_jax_weights(base_model_path, model_size): # map weight names if they use HuggingFace instead of PyTorch convention chkpt_vars = [_HFNamespaceMapper(var) for var in chkpt_vars] - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) # initialize the data structure for storing jax_weights layer_key = "MoeBlock_0" if num_experts else "mlp" @@ -230,29 +231,27 @@ def convert_to_jax_weights(base_model_path, model_size): decoder_norm_scale = chkpt_vars[0]["norm.weight"].type(torch.float16).numpy() jax_weights["decoder"]["decoder_norm"]["scale"] = decoder_norm_scale - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) # logits dense ################################################# max_logging.log("Processing logits dense") - logits_dense = np.concatenate([var["output.weight"].type( - torch.float16).numpy() for var in chkpt_vars], - axis=0).transpose()[:, :vocab_size] + logits_dense = np.concatenate( + [var["output.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=0 + ).transpose()[:, :vocab_size] jax_weights["decoder"]["logits_dense"]["kernel"] = logits_dense - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) # token embedding ############################################## max_logging.log("Processing token embeddings") - if model_size[:6] == 'llama3': - token_embedder = np.concatenate( - [var["tok_embeddings.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=0 - ) + if model_size[:6] == "llama3": + token_embedder = np.concatenate([var["tok_embeddings.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=0) else: token_embedder = np.concatenate( - [var["tok_embeddings.weight"].type(torch.float16).numpy() for var in 
chkpt_vars], axis=1 + [var["tok_embeddings.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=1 )[:vocab_size, :] jax_weights["token_embedder"]["embedding"] = token_embedder - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) # self attention ############################################### max_logging.log("Processing self attention") @@ -293,10 +292,10 @@ def convert_to_jax_weights(base_model_path, model_size): self_attention["value"]["kernel"] = np.zeros(stack_shape + wv.shape, dtype=np.float16) self_attention["out"]["kernel"] = np.zeros(stack_shape + w_post.shape, dtype=np.float16) - self_attention["query"]["kernel"][layer_idx, ...] = wq # pylint: disable=E1137 - self_attention["key"]["kernel"][layer_idx, ...] = wk # pylint: disable=E1137 - self_attention["value"]["kernel"][layer_idx, ...] = wv # pylint: disable=E1137 - self_attention["out"]["kernel"][layer_idx, ...] = w_post # pylint: disable=E1137 + self_attention["query"]["kernel"][layer_idx, ...] = wq # pylint: disable=E1137 + self_attention["key"]["kernel"][layer_idx, ...] = wk # pylint: disable=E1137 + self_attention["value"]["kernel"][layer_idx, ...] = wv # pylint: disable=E1137 + self_attention["out"]["kernel"][layer_idx, ...] = w_post # pylint: disable=E1137 self_attention["query"]["kernel"] = np.transpose(self_attention["query"]["kernel"], axes=(1, 0, 2, 3)) self_attention["key"]["kernel"] = np.transpose(self_attention["key"]["kernel"], axes=(1, 0, 2, 3)) @@ -309,12 +308,11 @@ def convert_to_jax_weights(base_model_path, model_size): self_attention["query"]["kernel"] = self_attention["query"]["kernel"] / np.sqrt(head_dim) jax_weights["decoder"]["layers"]["self_attention"] = self_attention - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) # layer weight pre and post self attention norm ################ max_logging.log("Processing pre and post self attention norms") - layer_weight = {"pre_self_attention_layer_norm": {"scale": None}, - "post_self_attention_layer_norm": {"scale": None}} + layer_weight = {"pre_self_attention_layer_norm": {"scale": None}, "post_self_attention_layer_norm": {"scale": None}} # self attention layer norm and swap the layer index for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False): @@ -323,11 +321,13 @@ def convert_to_jax_weights(base_model_path, model_size): if layer_weight["pre_self_attention_layer_norm"]["scale"] is None: stack_shape = (base_num_decoder_layers,) layer_weight["pre_self_attention_layer_norm"]["scale"] = np.zeros( - stack_shape + pre_self_attention_layernorm.shape, dtype=np.float16) + stack_shape + pre_self_attention_layernorm.shape, dtype=np.float16 + ) layer_weight["post_self_attention_layer_norm"]["scale"] = np.zeros( - stack_shape + post_self_attention_layernorm.shape, dtype=np.float16) - layer_weight["pre_self_attention_layer_norm"]["scale"][layer_idx, ...] = pre_self_attention_layernorm # pylint: disable=E1137 - layer_weight["post_self_attention_layer_norm"]["scale"][layer_idx, ...] = post_self_attention_layernorm # pylint: disable=E1137 + stack_shape + post_self_attention_layernorm.shape, dtype=np.float16 + ) + layer_weight["pre_self_attention_layer_norm"]["scale"][layer_idx, ...] = pre_self_attention_layernorm # pylint: disable=E1137 + layer_weight["post_self_attention_layer_norm"]["scale"][layer_idx, ...] 
= post_self_attention_layernorm # pylint: disable=E1137 layer_weight["pre_self_attention_layer_norm"]["scale"] = np.transpose( layer_weight["pre_self_attention_layer_norm"]["scale"], axes=(1, 0) @@ -338,7 +338,7 @@ def convert_to_jax_weights(base_model_path, model_size): jax_weights["decoder"]["layers"]["pre_self_attention_layer_norm"] = layer_weight["pre_self_attention_layer_norm"] jax_weights["decoder"]["layers"]["post_self_attention_layer_norm"] = layer_weight["post_self_attention_layer_norm"] - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) # layer weights ################################################ max_logging.log("Processing layer weights") @@ -418,8 +418,7 @@ def convert_to_jax_weights(base_model_path, model_size): layer_weight["mlp"]["wi_1"]["kernel"][ei, li, ...] = wi_1 layer_weight["mlp"]["wo"]["kernel"][ei, li, ...] = wo gc.collect() - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) - + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) if num_experts is None: # swap the layer index @@ -435,15 +434,14 @@ def convert_to_jax_weights(base_model_path, model_size): jax_weights["decoder"]["layers"]["MoeBlock_0"]["wi_0"] = layer_weight["mlp"]["wi_0"]["kernel"] jax_weights["decoder"]["layers"]["MoeBlock_0"]["wi_1"] = layer_weight["mlp"]["wi_1"]["kernel"] jax_weights["decoder"]["layers"]["MoeBlock_0"]["wo"] = layer_weight["mlp"]["wo"]["kernel"] - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) del chkpt_vars gc.collect() - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) return jax_weights - def save_jax_weights_to_checkpoint(maxtext_model_path, jax_weights): """ Function to save jax_weights ready for MaxText to a parameters checkpoint @@ -455,7 +453,7 @@ def save_jax_weights_to_checkpoint(maxtext_model_path, jax_weights): """Save maxtext parameter checkpoint.""" mem_info = psutil.Process() - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) gc.collect() mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis") s1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis")) # shards first axis @@ -481,7 +479,7 @@ def checkpoint_device_put(arr): jax_weights_new.append(checkpoint_device_put(jax_weight)) del jax_weight gc.collect() - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) jax_weights = tree.unflatten(jax_weights_struct, jax_weights_new) @@ -499,7 +497,7 @@ def checkpoint_device_put(arr): step=0, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={} # type: ignore ) - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024 ** 3)) + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) if checkpoint_manager is not None: if save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state_new): max_logging.log(f"saved a checkpoint at step {step_number_to_save_new_ckpt}") @@ -520,5 +518,4 @@ def checkpoint_device_put(arr): os.environ["XLA_FLAGS"] = 
f"--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}" - save_jax_weights_to_checkpoint(args.maxtext_model_path, - convert_to_jax_weights(args.base_model_path, args.model_size)) + save_jax_weights_to_checkpoint(args.maxtext_model_path, convert_to_jax_weights(args.base_model_path, args.model_size)) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 93641fdb4..045ed5d36 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -58,11 +58,7 @@ def finder(x): def l2norm_pytree(x): """L2 norm of a pytree of arrays.""" - return jnp.sqrt( - jax.tree_util.tree_reduce( - lambda x, y: x + jnp.sum(jnp.square(y)), x, initializer=0.0 - ) - ) + return jnp.sqrt(jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(jnp.square(y)), x, initializer=0.0)) def calculate_num_params_from_pytree(params): @@ -74,16 +70,13 @@ def calculate_num_params_from_pytree(params): def calculate_total_params_per_chip(params): """Calculate total paramsper chip.""" + def calculate_leaf_params_per_chip(arr): shard = arr.addressable_shards[0] return np.prod(shard.data.shape) - params_sizes_per_chip = jax.tree_util.tree_map( - calculate_leaf_params_per_chip, params - ) - total_parameters_per_chip = jax.tree_util.tree_reduce( - lambda x, y: x + y, params_sizes_per_chip - ) + params_sizes_per_chip = jax.tree_util.tree_map(calculate_leaf_params_per_chip, params) + total_parameters_per_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes_per_chip) return total_parameters_per_chip @@ -101,11 +94,7 @@ def summarize_size_from_pytree(params): def initialize_summary_writer(config): summary_writer_path = os.path.join(config.tensorboard_dir, config.run_name) - return ( - writer.SummaryWriter(summary_writer_path) - if jax.process_index() == 0 - else None - ) + return writer.SummaryWriter(summary_writer_path) if jax.process_index() == 0 else None def close_summary_writer(summary_writer): @@ -180,9 +169,7 @@ def write_config_raw_keys_for_gcs(raw_keys): yaml.dump(raw_keys_dict, config_for_gcs) config_for_gcs.close() - gcs_filename = os.path.join( - raw_keys["base_output_directory"], raw_keys["run_name"], filename - ) + gcs_filename = os.path.join(raw_keys["base_output_directory"], raw_keys["run_name"], filename) max_logging.log(f"Moving file {filename} to GCS...") upload_blob(gcs_filename, filename) max_logging.log(f"File {filename} moved successfully!") @@ -216,15 +203,11 @@ def maybe_initialize_jax_distributed_system(raw_keys): # Don't initialize jax distributed with AOT compilation return if is_gpu_backend(raw_keys): - max_logging.log( - "Attempting to initialize the jax distributed system for GPU backend..." - ) + max_logging.log("Attempting to initialize the jax distributed system for GPU backend...") initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") elif is_cpu_backend(raw_keys): - max_logging.log( - "Attempting to initialize the jax distributed system for CPU backend..." 
- ) + max_logging.log("Attempting to initialize the jax distributed system for CPU backend...") initialize_jax_for_cpu() max_logging.log("Jax distributed system initialized on CPUs!") elif ( @@ -234,7 +217,7 @@ def maybe_initialize_jax_distributed_system(raw_keys): and not raw_keys["enable_single_controller"] ) or raw_keys["hardware"] == "gpu_multiprocess": max_logging.log("Attempting to initialize the jax distributed system...") - if not raw_keys['enable_emergency_checkpoint']: + if not raw_keys["enable_emergency_checkpoint"]: jax.distributed.initialize() else: initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys) @@ -257,9 +240,7 @@ def initialize_jax_for_gpu(): def initialize_jax_for_cpu(): """Jax distributed initialize for CPUs. Includes retries until the coordinator is ready.""" coordinator_ip_address = get_coordinator_ip_address() - coordinator_address = ( - coordinator_ip_address + ":1234" - ) # JAX coordinator port used in XPK + coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK # Env variables to be set in XPK or otherwise job_index = int(os.environ.get("JOB_INDEX")) job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) @@ -283,12 +264,16 @@ def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): process_id, coordinator_address = _retrieve_jax_init_info(raw_keys) if process_id != "" and coordinator_address != "": - max_logging.log(f"Using {process_id} as the process_id and {coordinator_address} as the" - " coordinator_address to initialize JAX distributed runtime...") + max_logging.log( + f"Using {process_id} as the process_id and {coordinator_address} as the" + " coordinator_address to initialize JAX distributed runtime..." + ) jax.distributed.initialize(coordinator_address=coordinator_address, process_id=int(process_id)) else: - max_logging.log("Initializing JAX distributed runtime without args when emergency checkpointing is" - " enabled. This should not happen and your workload may have unexpected behavior.") + max_logging.log( + "Initializing JAX distributed runtime without args when emergency checkpointing is" + " enabled. This should not happen and your workload may have unexpected behavior." + ) jax.distributed.initialize() ocp.multihost.utils.initialize_runtime_to_distributed_ids() @@ -305,11 +290,12 @@ def _retrieve_jax_init_info(raw_keys): # "repair" time is longer. for i in range(900): if local_jax_init_info_file.exists(): - return local_jax_init_info_file.read_text().split('\n')[:2] + return local_jax_init_info_file.read_text().split("\n")[:2] max_logging.log(f"Unable to locate {JAX_INIT_INFO_FILE} after {i} seconds, sleeping for 1 second before retrying...") time.sleep(1) - max_logging.log(f"Unable to locate {JAX_INIT_INFO_FILE} after 900 seconds," - "returning empty process id and coordinator address.") + max_logging.log( + f"Unable to locate {JAX_INIT_INFO_FILE} after 900 seconds," "returning empty process id and coordinator address." 
+ ) return "", "" @@ -346,9 +332,7 @@ def get_coordinator_ip_address(): return coordinator_ip_address -def fill_unspecified_mesh_axes( - parallelism_vals, target_product, parallelism_type -): +def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type): """Evaluates unspecified DCN/ICI parallelism values""" if -1 in parallelism_vals: assert ( @@ -406,20 +390,12 @@ def create_device_mesh(config, devices=None): ] # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes( - ici_parallelism, num_devices_per_slice, "ICI" - ) + ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") - allow_split_physical_axes = ( - config.allow_split_physical_axes - if config.allow_split_physical_axes - else False - ) + allow_split_physical_axes = config.allow_split_physical_axes if config.allow_split_physical_axes else False if multi_slice_env: - dcn_parallelism = fill_unspecified_mesh_axes( - dcn_parallelism, num_slices, "DCN" - ) + dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN") mesh = mesh_utils.create_hybrid_device_mesh( ici_parallelism, dcn_parallelism, @@ -435,9 +411,9 @@ def create_device_mesh(config, devices=None): ) else: mesh = mesh_utils.create_device_mesh( - ici_parallelism, - devices, - ) + ici_parallelism, + devices, + ) max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") @@ -454,9 +430,7 @@ def unbox_logicallypartioned(boxed_pytree): a pytree where all all LogicallyPartitioned leaves have been unboxed. """ return jax.tree_util.tree_map( - lambda x: x.unbox() - if isinstance(x, flax.linen.spmd.LogicallyPartitioned) - else x, + lambda x: x.unbox() if isinstance(x, flax.linen.spmd.LogicallyPartitioned) else x, boxed_pytree, is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned), ) @@ -493,7 +467,6 @@ def init_initial_state(model, tx, config, is_training, key): return init_decode_state(model.apply, model_vars) - def setup_decode_state(model, config, rng, mesh, checkpoint_manager): """Setup decode state by loading params from a checkpoint. Args: @@ -509,32 +482,21 @@ def setup_decode_state(model, config, rng, mesh, checkpoint_manager): """ if not config.load_parameters_path: # generate random params - max_logging.log( - "No decode checkpoint specified - generating random weights." 
- ) - state, state_mesh_annotations, _ = setup_initial_state( - model, None, None, config, rng, mesh, checkpoint_manager, False - ) + max_logging.log("No decode checkpoint specified - generating random weights.") + state, state_mesh_annotations, _ = setup_initial_state(model, None, None, config, rng, mesh, checkpoint_manager, False) else: # Load params from checkpoint max_logging.log(f"Loading decode params from {config.load_parameters_path}") - unboxed_abstract_state, state_mesh_annotations, _ = ( - get_abstract_state(model, None, config, rng, mesh, False) - ) + unboxed_abstract_state, state_mesh_annotations, _ = get_abstract_state(model, None, config, rng, mesh, False) with nn_partitioning.axis_rules(config.logical_axis_rules): - params = checkpointing.load_params_from_path( - config.load_parameters_path, - unboxed_abstract_state.params - ) + params = checkpointing.load_params_from_path(config.load_parameters_path, unboxed_abstract_state.params) state = init_decode_state(None, params) state = unbox_logicallypartioned(state) return state, state_mesh_annotations -def setup_training_state( - model, data_iterator, tx, config, rng, mesh, checkpoint_manager -): +def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager): is_training = True return setup_initial_state( model, @@ -575,8 +537,8 @@ def setup_initial_state( state_mesh_annotations: the mesh annotations for the train state """ - unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = ( - get_abstract_state(model, tx, config, rng, mesh, is_training) + unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state( + model, tx, config, rng, mesh, is_training ) # Initialization @@ -599,9 +561,7 @@ def setup_initial_state( data_iterator.local_iterator = restored["iter"] state = restored["items"] else: - init_state_partial = functools.partial( - init_initial_state, model, tx, config, is_training - ) + init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) state = jax.jit( init_state_partial, in_shardings=None, @@ -640,15 +600,11 @@ def schedule(step): lr = config.learning_rate cos_final_lr = lr * config.cosine_learning_rate_final_fraction - warmup_steps = int( - config.learning_rate_schedule_steps * config.warmup_steps_fraction - ) + warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction) cos_steps = config.learning_rate_schedule_steps - warmup_steps constant_zero_steps = config.steps - config.learning_rate_schedule_steps - warmup_schedule = optax.linear_schedule( - init_value=0.0, end_value=lr, transition_steps=warmup_steps - ) + warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps) cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps) constant_schedule = optax.constant_schedule(0.0) @@ -668,9 +624,7 @@ def schedule(step): # Cross entropy implementation is taken from original T5X codebase: # https://github.com/google-research/t5x/blob/ace831eea1e2742b4299cd1a9af7e4f302038351/t5x/losses.py#L25-L101 @jax.custom_vjp -def cross_entropy_with_logits( - logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float -) -> Tuple[jnp.ndarray, jnp.ndarray]: +def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float) -> Tuple[jnp.ndarray, jnp.ndarray]: """Computes cross entropy loss with stable custom gradient. 
Computes a stabilized-gradient version of: -jnp.sum(targets * nn.log_softmax(logits), axis=-1) @@ -699,9 +653,7 @@ def cross_entropy_with_logits( return loss, total_z_loss -def _cross_entropy_with_logits_fwd( - logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float = 0.0 -) -> Tuple[ +def _cross_entropy_with_logits_fwd(logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float = 0.0) -> Tuple[ Tuple[jnp.ndarray, jnp.ndarray], Tuple[ jnp.ndarray, @@ -751,10 +703,7 @@ def _cross_entropy_with_logits_bwd( g = g[0] # Ignore z_loss component as that is only used for logging. logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res # z-loss term adds the (2 * z_loss * log_z) factor. - deriv = ( - jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - - targets - ) + deriv = jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - targets g_logits = jnp.expand_dims(g, axis=-1) * deriv g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax return ( @@ -764,33 +713,23 @@ def _cross_entropy_with_logits_bwd( ) # sets z-loss coeff gradient to 0 -cross_entropy_with_logits.defvjp( - _cross_entropy_with_logits_fwd, _cross_entropy_with_logits_bwd -) +cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, _cross_entropy_with_logits_bwd) def get_abstract_state(model, tx, config, rng, mesh, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" - init_state_partial = functools.partial( - init_initial_state, model, tx, config, is_training - ) + init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) with nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial, rng) state_logical_annotations = nn.get_partition_spec(abstract_state) - state_mesh_shardings = nn.logical_to_mesh_sharding( - state_logical_annotations, mesh, config.logical_axis_rules - ) + state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, config.logical_axis_rules) - abstract_sharded_state = jax.jit( - init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings - ).eval_shape(rng) + abstract_sharded_state = jax.jit(init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings).eval_shape(rng) - unboxed_abstract_sharded_state = unbox_logicallypartioned( - abstract_sharded_state - ) + unboxed_abstract_sharded_state = unbox_logicallypartioned(abstract_sharded_state) # Initialization with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) @@ -841,14 +780,10 @@ def print_model_vars(print_str, model_vars): def get_project(): """Get project""" - completed_command = subprocess.run( - ["gcloud", "config", "get", "project"], check=True, capture_output=True - ) + completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True) project_outputs = completed_command.stdout.decode().strip().split("\n") if len(project_outputs) < 1 or project_outputs[-1] == "": - max_logging.log( - "You must specify config.vertex_tensorboard_project or set 'gcloud config set project '" - ) + max_logging.log("You must specify config.vertex_tensorboard_project or set 'gcloud config set project '") return None return project_outputs[-1] @@ -864,9 +799,7 @@ def delete_leaf(leaf): def summarize_pytree_data(params, name="Params", raw=False): """Generate basic metrics of a given Pytree.""" - num_params, total_param_size, avg_param_size = 
summarize_size_from_pytree( - params - ) + num_params, total_param_size, avg_param_size = summarize_size_from_pytree(params) if not raw: num_params_in_billions = num_params / 1e9 total_param_size_in_gb = total_param_size / 1e9 @@ -887,27 +820,28 @@ def summarize_pytree_data(params, name="Params", raw=False): def save_quantized_checkpoint_if_configured(config, params): - assert config.quantization, 'quantization must be configured' + assert config.quantization, "quantization must be configured" if config.save_quantized_params_path: checkpointing.save_params_to_path(config.save_quantized_params_path, params) else: "Skipping saving quantized checkpoint as save_quantized_params_path is null." -def print_mem_stats(label:str): - print(f'\nMemstats: {label}:') +def print_mem_stats(label: str): + print(f"\nMemstats: {label}:") try: for d in jax.local_devices(): stats = d.memory_stats() - used = round(stats['bytes_in_use']/2**30, 2) - limit = round(stats['bytes_limit']/2**30, 2) + used = round(stats["bytes_in_use"] / 2**30, 2) + limit = round(stats["bytes_limit"] / 2**30, 2) print(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) on {d}") except (RuntimeError, KeyError): print("\tMemstats unavailable.") + def print_system_information(): - """ Print system information of the current environment. - Note that this will initialize the JAX backend. """ + """Print system information of the current environment. + Note that this will initialize the JAX backend.""" max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") max_logging.log(f"System Information: Jax Backend: {jax.lib.xla_bridge.get_backend().platform_version}") diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py index f4e0a1d4d..b4a4c0b8e 100644 --- a/MaxText/maxengine.py +++ b/MaxText/maxengine.py @@ -92,30 +92,32 @@ def load_params(self, *args, **kwargs) -> Params: ) self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, self.rng, self._mesh) self.kv_cache_shardings = jax.tree_util.tree_map( - lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations) + lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations + ) if self.model.quant and not self.config.checkpoint_is_quantized: params = self.quantize_params(state) else: params = state.params - max_utils.print_mem_stats('After load_params') + max_utils.print_mem_stats("After load_params") return params def quantize_params(self, state): """Forward pass to quantize decode params.""" self.model.quant.quant_mode = quantizations.get_quant_mode("convert") + @jax.jit def model_apply(_p, _rng): return self.model.apply( - _p | {"aqt": {}}, - jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32), - jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32), - decoder_segment_ids=jnp.zeros((1, self.config.max_prefill_predict_length), dtype=jnp.int32), - enable_dropout=False, - model_mode=common_types.MODEL_MODE_PREFILL, - rngs={"params": _rng}, - mutable=True, - ) + _p | {"aqt": {}}, + jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32), + jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32), + decoder_segment_ids=jnp.zeros((1, self.config.max_prefill_predict_length), dtype=jnp.int32), + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"params": _rng}, + mutable=True, + ) _, new_vars = model_apply(state.params, self.rng) # Remove param values 
which have corresponding qtensors in aqt to save memory. @@ -123,8 +125,8 @@ def model_apply(_p, _rng): params["aqt"] = new_vars["aqt"] params["params"] = quantizations.remove_quantized_params(state.params["params"], new_vars["aqt"]) self.abstract_params = jax.tree_util.tree_map( - lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), params - ) + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), params + ) max_utils.save_quantized_checkpoint_if_configured(self.config, params) self.model.quant.quant_mode = quantizations.get_quant_mode("serve") return params @@ -192,9 +194,7 @@ def prefill( all_valid = jnp.ones(first_generated_token.shape, dtype=jnp.int8) result = engine_api.ResultTokens( - data=jnp.concatenate( - (first_generated_token, all_valid, generated_tokens), axis=1 - ), + data=jnp.concatenate((first_generated_token, all_valid, generated_tokens), axis=1), # Tokens are shape [batch, speculations], so when we concatenate # tokens, validity and length along their index 1 dimension then they # occupy 0:speculations. @@ -211,7 +211,7 @@ def prefill( "cache": new_vars["cache"], "next_pos": next_pos, "generated_tokens": generated_tokens, - "tokens": first_generated_token + "tokens": first_generated_token, }, result @functools.partial(jax.jit, static_argnums=(0,), donate_argnums=(2,)) @@ -246,9 +246,7 @@ def generate(self, params: Params, decode_state: DecodeState) -> Tuple[DecodeSta all_valid = jnp.ones(new_token.shape, dtype=jnp.int8) result = engine_api.ResultTokens( - data=jnp.concatenate( - (new_token, all_valid, decode_state["generated_tokens"]), axis=1 - ), + data=jnp.concatenate((new_token, all_valid, decode_state["generated_tokens"]), axis=1), # Tokens are shape [batch, speculations], so when we concatenate # tokens, validity and length along their index 1 dimension then they # occupy 0:speculations. 
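
Note: the two ResultTokens constructions in the maxengine.py hunks above pack tokens, validity flags, and generated-token counts into a single array along axis 1. The following is a minimal sketch of that layout using only jax.numpy with made-up shapes (it is not MaxText or JetStream code), matching the "occupy 0:speculations" comment in the hunk:

import jax.numpy as jnp

# Hypothetical batch of 2 decode slots with 1 speculated token each.
new_token = jnp.array([[7], [9]], dtype=jnp.int32)      # [batch, speculations]
all_valid = jnp.ones(new_token.shape, dtype=jnp.int8)   # validity flags, same shape
generated = jnp.array([[3], [5]], dtype=jnp.int32)      # generated-token counts

# Concatenate along axis 1 as in the hunk: columns 0:speculations hold tokens,
# the next block holds validity, and the final block holds the counts.
data = jnp.concatenate((new_token, all_valid, generated), axis=1)
# data[:, 0:1] -> tokens, data[:, 1:2] -> validity, data[:, 2:3] -> counts
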
@@ -265,7 +263,7 @@ def generate(self, params: Params, decode_state: DecodeState) -> Tuple[DecodeSta "cache": new_cache, "next_pos": decode_state["next_pos"] + 1, "generated_tokens": decode_state["generated_tokens"] + 1, - "tokens": new_token + "tokens": new_token, }, result @functools.partial( @@ -334,9 +332,7 @@ def copy(path, partial_cache, full_cache, annotations): inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim( decode_state["generated_tokens"], unboxed_prefix["generated_tokens"], slot, 0 ) - inserted_tokens = jax.lax.dynamic_update_index_in_dim( - decode_state["tokens"], unboxed_prefix["tokens"], slot, 0 - ) + inserted_tokens = jax.lax.dynamic_update_index_in_dim(decode_state["tokens"], unboxed_prefix["tokens"], slot, 0) inserted_logits = jax.lax.with_sharding_constraint(inserted_logits, self.replicated_sharding) inserted_generated_tokens = jax.lax.with_sharding_constraint(inserted_generated_tokens, self.replicated_sharding) @@ -349,7 +345,7 @@ def copy(path, partial_cache, full_cache, annotations): "cache": inserted_cache, "next_pos": inserted_next_pos, "generated_tokens": inserted_generated_tokens, - "tokens": inserted_tokens + "tokens": inserted_tokens, } def get_prefix_destination_sharding(self) -> Any: @@ -394,7 +390,7 @@ def init(abstract_params): "cache": cache["cache"], "next_pos": next_pos, "generated_tokens": generated_tokens, - "tokens": tokens + "tokens": tokens, } with nn_partitioning.axis_rules(self.config.logical_axis_rules): diff --git a/MaxText/maxengine_server.py b/MaxText/maxengine_server.py index 741af8e74..6bc3791b2 100644 --- a/MaxText/maxengine_server.py +++ b/MaxText/maxengine_server.py @@ -38,11 +38,9 @@ def main(config): devices = server_lib.get_devices() server_config = maxengine_config.get_server_config("MaxtextInterleavedServer", config) - metrics_server_config : config_lib.MetricsServerConfig | None = None + metrics_server_config: config_lib.MetricsServerConfig | None = None if config.prometheus_port != 0: - metrics_server_config = config_lib.MetricsServerConfig( - port=config.prometheus_port - ) + metrics_server_config = config_lib.MetricsServerConfig(port=config.prometheus_port) # We separate credential from run so that we can unit test it with # local credentials. 
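
Note: the maxengine.py insert() hunk above drops a prefilled prefix into one decode slot with jax.lax.dynamic_update_index_in_dim. A minimal sketch of that call with invented shapes (not real decode-state values):

import jax.numpy as jnp
from jax import lax

# Hypothetical decode state: 4 concurrent slots, 1 token column each.
tokens = jnp.zeros((4, 1), dtype=jnp.int32)
prefix_token = jnp.array([[42]], dtype=jnp.int32)  # [1, 1], from a finished prefill

# Write the prefix into slot 2 along axis 0; the other slots are untouched.
tokens = lax.dynamic_update_index_in_dim(tokens, prefix_token, 2, 0)
# tokens is now [[0], [0], [42], [0]]
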
@@ -55,7 +53,7 @@ def main(config): metrics_server_config=metrics_server_config, enable_jax_profiler=config.enable_jax_profiler if config.enable_jax_profiler else False, jax_profiler_port=config.jax_profiler_port if config.jax_profiler_port else 9999, - enable_model_warmup=config.enable_model_warmup if config.enable_model_warmup else False + enable_model_warmup=config.enable_model_warmup if config.enable_model_warmup else False, ) jetstream_server.wait_for_termination() diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index 475c72b09..7e93ad068 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -30,6 +30,7 @@ OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" + def get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, config): """Get the shardings (both state and data) for train_step""" functional_train = get_functional_train_step(train_step, model, config) @@ -92,10 +93,12 @@ def get_train_input_output_trees(func, input_args, input_kwargs): p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree) return p_train_step + def calculate_tokens_training_per_device(config): """Calculate training Tokens per device""" return config.max_target_length * config.per_device_batch_size * config.gradient_accumulation_steps + def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops): """ Calculate training TFLOP for Gemma2 as in Gemma2 we combine [local_attention, global_attention] into one decoder @@ -106,12 +109,14 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo 4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim + # local attention - 4 * config.per_device_batch_size * config.max_target_length * min(config.sliding_window_size, config.max_target_length) - * config.num_query_heads * config.head_dim - ) - attention_tflops = ( - attention_flops * config.num_decoder_layers * 3 / 10**12 + 4 + * config.per_device_batch_size + * config.max_target_length + * min(config.sliding_window_size, config.max_target_length) + * config.num_query_heads + * config.head_dim ) + attention_tflops = attention_flops * config.num_decoder_layers * 3 / 10**12 # multiply num_decoder_layers by 2 because we combine [local_attention, global_attention] into one decoder layer learnable_weight_tflops = ( @@ -120,6 +125,7 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo return attention_tflops, learnable_weight_tflops + def calculate_tflops_training_per_device(config, log=True): """Calculate training TFLOP""" ffn1_flops = ( @@ -146,9 +152,7 @@ def calculate_tflops_training_per_device(config, log=True): * (config.num_query_heads + 2 * config.num_kv_heads) * config.head_dim ) - attention_flops = ( - 4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim - ) + attention_flops = 4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim projection_flops = ( 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_query_heads * config.head_dim ) @@ -159,16 +163,12 @@ def calculate_tflops_training_per_device(config, log=True): ((total_ffn_flops + qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12 ) # megatron tflops calculation does not account for causality in attention - attention_tflops = ( - 
attention_flops * config.num_decoder_layers * 3 / 10**12 - ) + attention_tflops = attention_flops * config.num_decoder_layers * 3 / 10**12 # override for gemma2 decoder tflop calculation - if config.decoder_block == 'gemma2': - attention_tflops, learnable_weight_tflops = ( - calculate_gemma2_tflops_training_per_device( - config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops - ) + if config.decoder_block == "gemma2": + attention_tflops, learnable_weight_tflops = calculate_gemma2_tflops_training_per_device( + config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops ) learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps @@ -243,6 +243,7 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance=0.02): f"Number of unsharded parameters exceeds tolerance {tolerance * 100}% " "of total parameters." ) + def apply_gradient_clipping(raw_grads, state, clipping_threshold): """Applies gradient clipping to raw gradients, with special handing for FLAX fp8 stats. @@ -259,13 +260,14 @@ def apply_gradient_clipping(raw_grads, state, clipping_threshold): # Scales + Amax History for Delayed Tensor Scaling SHOULD NOT be clipped or affect clipping fp8_stats = raw_grads.pop(OVERWRITE_WITH_GRADIENT) grads, _ = gradient_clip_transformation.update(raw_grads, state, None) - grads[OVERWRITE_WITH_GRADIENT] = fp8_stats # pytype: disable=unsupported-operands - raw_grads[OVERWRITE_WITH_GRADIENT] = fp8_stats # pytype: disable=unsupported-operands + grads[OVERWRITE_WITH_GRADIENT] = fp8_stats # pytype: disable=unsupported-operands + raw_grads[OVERWRITE_WITH_GRADIENT] = fp8_stats # pytype: disable=unsupported-operands else: grads, _ = gradient_clip_transformation.update(raw_grads, state, None) return grads + def get_nested_value(dictionary, nested_key, default=None): """ Retrieves a value from a nested key in a dictionary. 
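
Note: to make the gemma2 attention-TFLOPs expression in the hunk above concrete, here is a small worked sketch with invented config values (not a real MaxText configuration); it simply re-evaluates the formula shown, where the factor of 3 covers forward plus backward as in the hunk:

# Invented values, purely to illustrate the formula reformatted above.
per_device_batch_size = 4
max_target_length = 8192
sliding_window_size = 4096
num_query_heads = 16
head_dim = 256
num_decoder_layers = 46  # gemma2 counts a [local, global] pair as one decoder layer

global_attention = 4 * per_device_batch_size * max_target_length**2 * num_query_heads * head_dim
local_attention = (
    4 * per_device_batch_size * max_target_length
    * min(sliding_window_size, max_target_length)
    * num_query_heads * head_dim
)
attention_flops = global_attention + local_attention
attention_tflops = attention_flops * num_decoder_layers * 3 / 10**12
print(f"{attention_tflops=:.1f}")
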
diff --git a/MaxText/profiler.py b/MaxText/profiler.py index faee511c5..e5df875dd 100644 --- a/MaxText/profiler.py +++ b/MaxText/profiler.py @@ -23,6 +23,7 @@ import jax + class Profiler: """Activate/deactivate a profiler based on the 'profiler' config""" @@ -40,10 +41,9 @@ def activate(self): return if self.mode == "nsys": try: - self.libcudart = cdll.LoadLibrary('libcudart.so') - except Exception as e: # pylint: disable=broad-except - max_logging.log(f"WARNING: Failed to load library for nsys: {e}\n" - "profiler has no effect") + self.libcudart = cdll.LoadLibrary("libcudart.so") + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"WARNING: Failed to load library for nsys: {e}\n" "profiler has no effect") return self.libcudart.cudaProfilerStart() elif self.mode == "xplane": @@ -58,10 +58,9 @@ def deactivate(self): if self.libcudart is not None: self.libcudart.cudaProfilerStop() else: - max_logging.log("WARNING: library for nsys was not loaded \n" - "profiler has no effect") + max_logging.log("WARNING: library for nsys was not loaded \n" "profiler has no effect") return # Popen() instead of run() for non-blocking behavior - subprocess.Popen(["gsutil", "cp", "*nsys-rep", self.output_path]) # pylint: disable=consider-using-with + subprocess.Popen(["gsutil", "cp", "*nsys-rep", self.output_path]) # pylint: disable=consider-using-with elif self.mode == "xplane": jax.profiler.stop_trace() diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 8ea48e7e9..2ff2926e1 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -58,6 +58,7 @@ def validate_compute_axis_order(s: str) -> None: if s not in valid_compute_axis_order: # currently supported compute_axis_order raise ValueError("Invalid compute_axis_order was passed. Valid options ", valid_compute_axis_order) + def validate_kv_quant_axis(s: str, quantize_kvcache: bool) -> None: valid_kv_quant_axis = ("", "dkv", "heads_and_dkv") if s not in valid_kv_quant_axis: # currently supported kv_quant_axis @@ -65,11 +66,13 @@ def validate_kv_quant_axis(s: str, quantize_kvcache: bool) -> None: if quantize_kvcache and s == "": raise ValueError("kv_quant_axis can not be '' when quantize_kvcache is True") + def validate_attention_kernel(s: str) -> None: valid_attention_kernels = ("autoselected", "dot_product", "flash", "cudnn_flash_te") if s not in valid_attention_kernels: # currently supported attention raise ValueError("Invalid attention kernel was passed. 
Valid options ", valid_attention_kernels) + def validate_attention_type(s: str) -> None: valid_attention_types = (attention_type.value for attention_type in AttentionType) if s not in valid_attention_types: # currently supported attention @@ -96,8 +99,12 @@ def validate_keys(keys): keys["load_parameters_path"] == "" or keys["load_full_state_path"] == "" ), "At most one of `load_parameters_path` or `load_full_state_path` should be set" if keys["enable_emergency_checkpoint"]: - assert keys["local_checkpoint_directory"] != "", "A local checkpoint directory must be specified when using emergency checkpoint" - assert keys["local_checkpoint_period"] > 0, "A positive local checkpoint period must be specified when using emergency checkpoint" + assert ( + keys["local_checkpoint_directory"] != "" + ), "A local checkpoint directory must be specified when using emergency checkpoint" + assert ( + keys["local_checkpoint_period"] > 0 + ), "A positive local checkpoint period must be specified when using emergency checkpoint" else: max_logging.log("Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period") if keys["num_experts"] > 1: @@ -111,30 +118,29 @@ def validate_data_input(keys): f"dataset_type set to hf, will use {keys['hf_path']=}, {keys['hf_data_dir']=} and {keys['hf_train_files']=} to read data" ) assert keys["hf_path"] != "", "hf_path can't be empty when dataset_type=hf" - if not keys['hf_train_files']: - keys['hf_train_files'] = None - if not keys['hf_eval_files']: - keys['hf_eval_files'] = None - if keys['hf_eval_files']: - keys['hf_eval_split'] = 'train' - if keys['eval_interval'] > 0: - assert keys['hf_eval_split'], "Please specify hf_eval_split or set eval_interval to <=0." + if not keys["hf_train_files"]: + keys["hf_train_files"] = None + if not keys["hf_eval_files"]: + keys["hf_eval_files"] = None + if keys["hf_eval_files"]: + keys["hf_eval_split"] = "train" + if keys["eval_interval"] > 0: + assert keys["hf_eval_split"], "Please specify hf_eval_split or set eval_interval to <=0." elif keys["dataset_type"] == "grain": max_logging.log( f"dataset_type set to grain, will use {keys['grain_train_files']=} as data files, and {keys['grain_worker_count']} workers" ) - assert keys['grain_train_files'] != "", "grain_train_files can't be empty when dataset_type=grain" - if keys['eval_interval'] > 0: - assert keys['grain_eval_files'], "Please specify grain_eval_files or set eval_interval to <=0." + assert keys["grain_train_files"] != "", "grain_train_files can't be empty when dataset_type=grain" + if keys["eval_interval"] > 0: + assert keys["grain_eval_files"], "Please specify grain_eval_files or set eval_interval to <=0." elif keys["dataset_type"] == "tfds": - max_logging.log( - f"dataset_type set to tfds, will use {keys['dataset_path']=} and {keys['dataset_name']=}" - ) - assert keys['dataset_name'] != "", "dataset_name can't be empty when dataset_type=tfds" - if keys['eval_interval'] > 0: + max_logging.log(f"dataset_type set to tfds, will use {keys['dataset_path']=} and {keys['dataset_name']=}") + assert keys["dataset_name"] != "", "dataset_name can't be empty when dataset_type=tfds" + if keys["eval_interval"] > 0: assert keys["eval_split"], "Please specify eval_split or set eval_interval to <=0." 
+ def validate_model_name(s: str) -> bool: """Validate provided model name.""" # currently supported models @@ -329,27 +335,44 @@ def user_init(raw_keys): raw_keys["mlp_dim"] = 2**mlp_dim_scale * raw_keys["base_mlp_dim"] raw_keys["num_decoder_layers"] = 2**layer_scale * raw_keys["base_num_decoder_layers"] - raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"], raw_keys["micro_batch_size_to_train_on"] = calculate_global_batch_sizes(raw_keys) + ( + raw_keys["global_batch_size_to_load"], + raw_keys["global_batch_size_to_train_on"], + raw_keys["micro_batch_size_to_train_on"], + ) = calculate_global_batch_sizes(raw_keys) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) if using_pipeline_parallelism(raw_keys): raw_keys["using_pipeline_parallelism"] = True - num_stages = int(raw_keys['ici_pipeline_parallelism'] * raw_keys['dcn_pipeline_parallelism']) - if raw_keys['num_pipeline_repeats'] == -1: - num_pipeline_repeats, remainder = divmod(raw_keys['num_decoder_layers'], num_stages * raw_keys['num_layers_per_pipeline_stage']) - assert not remainder, f"The number of layers per stage ({raw_keys['num_layers_per_pipeline_stage']}) times the number of stages ({num_stages}) must divide the number of decoder layers ({raw_keys['num_decoder_layers']}) " - raw_keys['num_pipeline_repeats'] = num_pipeline_repeats - assert num_stages * raw_keys['num_pipeline_repeats'] * raw_keys['num_layers_per_pipeline_stage'] == raw_keys['num_decoder_layers'], f"The product of pipeline stages ({num_stages}), repeats ({raw_keys['num_pipeline_repeats']}), and layers per stage ({raw_keys['num_layers_per_pipeline_stage']}) must be equal to the number of layers ({raw_keys['num_decoder_layers']})" - if raw_keys['num_pipeline_microbatches'] == -1: - if raw_keys['pipeline_delay_activation_forwarding']: - raw_keys['num_pipeline_microbatches'] = 2 * num_stages + num_stages = int(raw_keys["ici_pipeline_parallelism"] * raw_keys["dcn_pipeline_parallelism"]) + if raw_keys["num_pipeline_repeats"] == -1: + num_pipeline_repeats, remainder = divmod( + raw_keys["num_decoder_layers"], num_stages * raw_keys["num_layers_per_pipeline_stage"] + ) + assert ( + not remainder + ), f"The number of layers per stage ({raw_keys['num_layers_per_pipeline_stage']}) times the number of stages ({num_stages}) must divide the number of decoder layers ({raw_keys['num_decoder_layers']}) " + raw_keys["num_pipeline_repeats"] = num_pipeline_repeats + assert ( + num_stages * raw_keys["num_pipeline_repeats"] * raw_keys["num_layers_per_pipeline_stage"] + == raw_keys["num_decoder_layers"] + ), f"The product of pipeline stages ({num_stages}), repeats ({raw_keys['num_pipeline_repeats']}), and layers per stage ({raw_keys['num_layers_per_pipeline_stage']}) must be equal to the number of layers ({raw_keys['num_decoder_layers']})" + if raw_keys["num_pipeline_microbatches"] == -1: + if raw_keys["pipeline_delay_activation_forwarding"]: + raw_keys["num_pipeline_microbatches"] = 2 * num_stages else: - raw_keys['num_pipeline_microbatches'] = num_stages - assert raw_keys['num_pipeline_microbatches'] % num_stages == 0, f"The number of microbatches ({raw_keys['num_pipeline_microbatches']}) must be divisible by the number of stages ({num_stages})" - assert raw_keys['micro_batch_size_to_train_on'] % raw_keys['num_pipeline_microbatches'] == 0, f"The batch size ({raw_keys['micro_batch_size_to_train_on']}) must be divisible by the number of microbatches 
({raw_keys['num_pipeline_microbatches']})" + raw_keys["num_pipeline_microbatches"] = num_stages + assert ( + raw_keys["num_pipeline_microbatches"] % num_stages == 0 + ), f"The number of microbatches ({raw_keys['num_pipeline_microbatches']}) must be divisible by the number of stages ({num_stages})" + assert ( + raw_keys["micro_batch_size_to_train_on"] % raw_keys["num_pipeline_microbatches"] == 0 + ), f"The batch size ({raw_keys['micro_batch_size_to_train_on']}) must be divisible by the number of microbatches ({raw_keys['num_pipeline_microbatches']})" if raw_keys["pipeline_delay_activation_forwarding"]: - assert raw_keys['num_pipeline_microbatches'] >= 2 * num_stages, f"Delayed activation forwarding requires at least 2 * num_stages microbatches, but {num_stages} stages are used with {raw_keys['num_pipeline_microbatches']} microbatches" + assert ( + raw_keys["num_pipeline_microbatches"] >= 2 * num_stages + ), f"Delayed activation forwarding requires at least 2 * num_stages microbatches, but {num_stages} stages are used with {raw_keys['num_pipeline_microbatches']} microbatches" else: raw_keys["using_pipeline_parallelism"] = False @@ -403,14 +426,17 @@ def update_model_vars(base_config_path, raw_keys, config_name: str): raw_keys = validate_and_update_keys(raw_keys, model_vars, config_name) return updated_keys + def validate_megablox_parallelism(raw_keys): - if raw_keys["megablox"] and (using_sequence_parallelism(raw_keys) or - using_pipeline_parallelism(raw_keys) or - using_expert_parallelism(raw_keys)): + if raw_keys["megablox"] and ( + using_sequence_parallelism(raw_keys) or using_pipeline_parallelism(raw_keys) or using_expert_parallelism(raw_keys) + ): raise ValueError("Currently we only support Megablox with data and tensor parallelism.") tensor_parallelism = raw_keys["ici_tensor_parallelism"] * raw_keys["dcn_tensor_parallelism"] if raw_keys["megablox"] and using_tensor_parallelism(raw_keys) and (raw_keys["emb_dim"] % tensor_parallelism): - raise ValueError(f"The embedding dimension {raw_keys['emb_dim']} is not divisible by tensor parallelism setting {tensor_parallelism}.") + raise ValueError( + f"The embedding dimension {raw_keys['emb_dim']} is not divisible by tensor parallelism setting {tensor_parallelism}." + ) def create_new_logical_axis_rules(old_logical_axis_rules, new_logical_axis_rules): @@ -422,26 +448,29 @@ def create_new_logical_axis_rules(old_logical_axis_rules, new_logical_axis_rules continue replacements.append((logical_axis, mesh_axes)) new_logical_axis.add(logical_axis) - old_logical_rules_filtered = [(old_logical_axis, _lists_to_tuples(old_mesh_axes)) for old_logical_axis, old_mesh_axes - in old_logical_axis_rules if old_logical_axis not in new_logical_axis] + old_logical_rules_filtered = [ + (old_logical_axis, _lists_to_tuples(old_mesh_axes)) + for old_logical_axis, old_mesh_axes in old_logical_axis_rules + if old_logical_axis not in new_logical_axis + ] return old_logical_rules_filtered + replacements def update_model_keys(raw_keys, model_keys, key): - """Update `key` value in `raw_keys` from the value in `model_keys`. 
""" + """Update `key` value in `raw_keys` from the value in `model_keys`.""" assert key in model_keys and key in raw_keys - if key == 'logical_axis_rules': + if key == "logical_axis_rules": raw_keys[key] = create_new_logical_axis_rules( - old_logical_axis_rules=raw_keys[key], - new_logical_axis_rules=model_keys[key]) + old_logical_axis_rules=raw_keys[key], new_logical_axis_rules=model_keys[key] + ) return raw_keys[key] = model_keys[key] + def validate_and_update_keys(raw_keys, model_keys, config_name: str): """Validate and update model specific config keys""" max_logging.log("Updating following parameters in config\n") - for k in model_keys: max_logging.log(f"{k}: {model_keys[k]}") if k not in raw_keys: @@ -509,8 +538,8 @@ def get_num_target_devices(raw_keys): def get_num_slices(raw_keys): - """ Calculate num_slices based on number of devices. """ - if raw_keys['hardware'] == 'cpu': + """Calculate num_slices based on number of devices.""" + if raw_keys["hardware"] == "cpu": max_logging.log(" Setting num_slices=1 for CPU hardware type") return 1 if int(raw_keys["compile_topology_num_slices"]) > 0: @@ -529,17 +558,22 @@ def get_quantization_local_shard_count(raw_keys): else: return raw_keys["quantization_local_shard_count"] + def using_pipeline_parallelism(raw_keys) -> bool: - return int(raw_keys['ici_pipeline_parallelism']) > 1 or int(raw_keys['dcn_pipeline_parallelism']) > 1 + return int(raw_keys["ici_pipeline_parallelism"]) > 1 or int(raw_keys["dcn_pipeline_parallelism"]) > 1 + def using_tensor_parallelism(raw_keys) -> bool: - return int(raw_keys['ici_tensor_parallelism']) > 1 or int(raw_keys['dcn_tensor_parallelism']) > 1 + return int(raw_keys["ici_tensor_parallelism"]) > 1 or int(raw_keys["dcn_tensor_parallelism"]) > 1 + def using_sequence_parallelism(raw_keys) -> bool: - return int(raw_keys['ici_sequence_parallelism']) > 1 or int(raw_keys['dcn_sequence_parallelism']) > 1 + return int(raw_keys["ici_sequence_parallelism"]) > 1 or int(raw_keys["dcn_sequence_parallelism"]) > 1 + def using_expert_parallelism(raw_keys) -> bool: - return int(raw_keys['ici_expert_parallelism']) > 1 or int(raw_keys['dcn_expert_parallelism']) > 1 + return int(raw_keys["ici_expert_parallelism"]) > 1 or int(raw_keys["dcn_expert_parallelism"]) > 1 + class HyperParameters: # pylint: disable=missing-class-docstring diff --git a/MaxText/scratch_code/golden_gemma-2b_export.ipynb b/MaxText/scratch_code/golden_gemma-2b_export.ipynb index e35f0094d..8b03a161e 100644 --- a/MaxText/scratch_code/golden_gemma-2b_export.ipynb +++ b/MaxText/scratch_code/golden_gemma-2b_export.ipynb @@ -43,11 +43,11 @@ "source": [ "import os\n", "\n", - "VARIANT = '2b' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", + "VARIANT = \"2b\" # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", "\n", "\n", - "ckpt_path = '/local/path/gemma-2b-flax/2b/'\n", - "vocab_path = '/local/path/gemma-2b-flax/tokenizer.model'" + "ckpt_path = \"/local/path/gemma-2b-flax/2b/\"\n", + "vocab_path = \"/local/path/gemma-2b-flax/tokenizer.model\"" ] }, { @@ -59,6 +59,7 @@ "source": [ "# Load parameters\n", "from gemma import params as params_lib\n", + "\n", "params = params_lib.load_and_format_params(ckpt_path)" ] }, @@ -70,6 +71,7 @@ "outputs": [], "source": [ "import sentencepiece as spm\n", + "\n", "vocab = spm.SentencePieceProcessor()\n", "vocab.Load(vocab_path)" ] @@ -89,8 +91,7 @@ "from gemma import transformer as transformer_lib\n", "\n", "config_2b = transformer_lib.TransformerConfig.from_params(\n", - " params,\n", - " cache_size=30 # 
Number of time steps in the transformer's cache\n", + " params, cache_size=30 # Number of time steps in the transformer's cache\n", ")\n", "model_2b = transformer_lib.Transformer(config=config_2b)" ] @@ -107,7 +108,7 @@ "sampler = sampler_lib.Sampler(\n", " transformer=model_2b,\n", " vocab=vocab,\n", - " params=params['transformer'],\n", + " params=params[\"transformer\"],\n", ")" ] }, @@ -118,17 +119,17 @@ "metadata": {}, "outputs": [], "source": [ - "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\" ]\n", + "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\"]\n", "\n", "out_data = sampler(\n", " input_strings=prompt_texts,\n", " total_generation_steps=6, # number of steps performed when generating\n", - " )\n", + ")\n", "\n", "for input_string, out_string in zip(prompt_texts, out_data.text):\n", " print(f\"Prompt:\\n{input_string}\\nOutput:\\n{out_string}\")\n", " print()\n", - " print(10*'#')" + " print(10 * \"#\")" ] }, { @@ -140,16 +141,18 @@ "source": [ "import jax\n", "\n", - "def get_attention_mask_and_positions(example: jax.Array,\n", - " pad_id : int,\n", - " )-> tuple[jax.Array, jax.Array]:\n", - " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n", "\n", - " pad_mask = example != pad_id\n", + "def get_attention_mask_and_positions(\n", + " example: jax.Array,\n", + " pad_id: int,\n", + ") -> tuple[jax.Array, jax.Array]:\n", + " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n", "\n", - " current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n", - " attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n", - " return current_token_position, attention_mask" + " pad_mask = example != pad_id\n", + "\n", + " current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n", + " attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n", + " return current_token_position, attention_mask" ] }, { @@ -166,42 +169,40 @@ "\n", "params = params_lib.load_and_format_params(ckpt_path)\n", "\n", - "output_path = \"golden_data_gemma-2b.jsonl\" \n", + "output_path = \"golden_data_gemma-2b.jsonl\"\n", "all_data_to_save = []\n", "\n", "for prompt_index in range(len(prompt_texts)):\n", - " prompt_text = prompt_texts[prompt_index]\n", - " one_sample_input = np.array([2]+vocab.encode(prompt_text))\n", - " expanded_one_sample_input = jnp.expand_dims(one_sample_input, axis=0)\n", - " pad_id = vocab.pad_id\n", - " get_attention_mask_and_positions(one_sample_input, pad_id)\n", - " # Build the position and attention mask vectors.\n", - " positions, attention_mask = get_attention_mask_and_positions(one_sample_input, pad_id)\n", - " print(f\"{expanded_one_sample_input=}, {positions=}, {attention_mask=}\")\n", - "\n", - "\n", - "\n", - " # Foward pass on the input data.\n", - " # No attention cache is needed here.\n", - "\n", - " logits, _ = model_2b.apply(\n", - " # params,\n", - " {'params': params['transformer']},\n", - " expanded_one_sample_input,\n", - " positions,\n", - " None, # Attention cache is None.\n", - " attention_mask,\n", - " )\n", - " print(f\"{logits=}\")\n", - " \n", - " # Prepare data to be saved \n", - " data_to_save = { \n", - " \"prompt\": prompt_texts[prompt_index], \n", - " \"completion\": out_data.text[prompt_index], \n", - " \"tokens\": [2]+vocab.encode(prompt_texts[prompt_index]), \n", - " \"logits\": logits[0].tolist() #remove the batch dim and then tolist() for json serialization\n", - " } \n", - " 
all_data_to_save.append(data_to_save)\n" + " prompt_text = prompt_texts[prompt_index]\n", + " one_sample_input = np.array([2] + vocab.encode(prompt_text))\n", + " expanded_one_sample_input = jnp.expand_dims(one_sample_input, axis=0)\n", + " pad_id = vocab.pad_id\n", + " get_attention_mask_and_positions(one_sample_input, pad_id)\n", + " # Build the position and attention mask vectors.\n", + " positions, attention_mask = get_attention_mask_and_positions(one_sample_input, pad_id)\n", + " print(f\"{expanded_one_sample_input=}, {positions=}, {attention_mask=}\")\n", + "\n", + " # Foward pass on the input data.\n", + " # No attention cache is needed here.\n", + "\n", + " logits, _ = model_2b.apply(\n", + " # params,\n", + " {\"params\": params[\"transformer\"]},\n", + " expanded_one_sample_input,\n", + " positions,\n", + " None, # Attention cache is None.\n", + " attention_mask,\n", + " )\n", + " print(f\"{logits=}\")\n", + "\n", + " # Prepare data to be saved\n", + " data_to_save = {\n", + " \"prompt\": prompt_texts[prompt_index],\n", + " \"completion\": out_data.text[prompt_index],\n", + " \"tokens\": [2] + vocab.encode(prompt_texts[prompt_index]),\n", + " \"logits\": logits[0].tolist(), # remove the batch dim and then tolist() for json serialization\n", + " }\n", + " all_data_to_save.append(data_to_save)" ] }, { @@ -211,9 +212,8 @@ "metadata": {}, "outputs": [], "source": [ - "with jsonlines.open(output_path,'w') as f: \n", - " f.write_all(all_data_to_save)\n", - "\n", + "with jsonlines.open(output_path, \"w\") as f:\n", + " f.write_all(all_data_to_save)\n", "\n", "\n", "print(f\"Data saved to {output_path}\")" diff --git a/MaxText/scratch_code/golden_gemma2-27b_export-flax.ipynb b/MaxText/scratch_code/golden_gemma2-27b_export-flax.ipynb index b90645298..359644bc7 100644 --- a/MaxText/scratch_code/golden_gemma2-27b_export-flax.ipynb +++ b/MaxText/scratch_code/golden_gemma2-27b_export-flax.ipynb @@ -31,11 +31,11 @@ "source": [ "import os\n", "\n", - "VARIANT = '27b' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", + "VARIANT = \"27b\" # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", "\n", "\n", - "ckpt_path = '/home/zhaoyuec/data/gemma2/gemma2-27b/ckpt/'\n", - "vocab_path = '/home/zhaoyuec/data/gemma2/gemma2-27b/tokenizer.model'" + "ckpt_path = \"/home/zhaoyuec/data/gemma2/gemma2-27b/ckpt/\"\n", + "vocab_path = \"/home/zhaoyuec/data/gemma2/gemma2-27b/tokenizer.model\"" ] }, { @@ -47,6 +47,7 @@ "source": [ "# Load parameters\n", "from gemma import params as params_lib\n", + "\n", "params = params_lib.load_and_format_params(ckpt_path)" ] }, @@ -69,6 +70,7 @@ ], "source": [ "import sentencepiece as spm\n", + "\n", "vocab = spm.SentencePieceProcessor()\n", "vocab.Load(vocab_path)" ] @@ -96,8 +98,7 @@ "from gemma import transformer as transformer_lib\n", "\n", "config_27b = transformer_lib.TransformerConfig.from_params(\n", - " params,\n", - " cache_size=30 # Number of time steps in the transformer's cache\n", + " params, cache_size=30 # Number of time steps in the transformer's cache\n", ")\n", "model_27b = transformer_lib.Transformer(config=config_27b)" ] @@ -114,7 +115,7 @@ "sampler = sampler_lib.Sampler(\n", " transformer=model_27b,\n", " vocab=vocab,\n", - " params=params['transformer'],\n", + " params=params[\"transformer\"],\n", ")" ] }, @@ -125,7 +126,7 @@ "metadata": {}, "outputs": [], "source": [ - "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\" ]\n", + "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\"]\n", "# prompt_texts = [\"I 
love to\"]\n", "\n", "# out_data = sampler(\n", @@ -148,16 +149,18 @@ "source": [ "import jax\n", "\n", - "def get_attention_mask_and_positions(example: jax.Array,\n", - " pad_id : int,\n", - " )-> tuple[jax.Array, jax.Array]:\n", - " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n", "\n", - " pad_mask = example != pad_id\n", + "def get_attention_mask_and_positions(\n", + " example: jax.Array,\n", + " pad_id: int,\n", + ") -> tuple[jax.Array, jax.Array]:\n", + " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n", + "\n", + " pad_mask = example != pad_id\n", "\n", - " current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n", - " attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n", - " return current_token_position, attention_mask" + " current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n", + " attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n", + " return current_token_position, attention_mask" ] }, { @@ -388,41 +391,40 @@ "\n", "params = params_lib.load_and_format_params(ckpt_path)\n", "\n", - "output_path = \"golden_data_gemma2-27b.jsonl\" \n", + "output_path = \"golden_data_gemma2-27b.jsonl\"\n", "all_data_to_save = []\n", "\n", "for prompt_index in range(len(prompt_texts)):\n", - " prompt_text = prompt_texts[prompt_index]\n", - " one_sample_input = np.array([2]+vocab.encode(prompt_text))\n", - " expanded_one_sample_input = jnp.expand_dims(one_sample_input, axis=0)\n", - " pad_id = vocab.pad_id\n", - " get_attention_mask_and_positions(one_sample_input, pad_id)\n", - " # Build the position and attention mask vectors.\n", - " positions, attention_mask = get_attention_mask_and_positions(one_sample_input, pad_id)\n", - " print(f\"{expanded_one_sample_input=}, {positions=}, {attention_mask=}\")\n", + " prompt_text = prompt_texts[prompt_index]\n", + " one_sample_input = np.array([2] + vocab.encode(prompt_text))\n", + " expanded_one_sample_input = jnp.expand_dims(one_sample_input, axis=0)\n", + " pad_id = vocab.pad_id\n", + " get_attention_mask_and_positions(one_sample_input, pad_id)\n", + " # Build the position and attention mask vectors.\n", + " positions, attention_mask = get_attention_mask_and_positions(one_sample_input, pad_id)\n", + " print(f\"{expanded_one_sample_input=}, {positions=}, {attention_mask=}\")\n", "\n", + " # Foward pass on the input data.\n", + " # No attention cache is needed here.\n", "\n", - " # Foward pass on the input data.\n", - " # No attention cache is needed here.\n", - "\n", - " logits, _ = model_27b.apply(\n", - " # params,\n", - " {'params': params['transformer']},\n", - " expanded_one_sample_input,\n", - " positions,\n", - " None, # Attention cache is None.\n", - " attention_mask,\n", - " )\n", - " print(f\"{logits=}\")\n", - " print(logits.shape)\n", - " # Prepare data to be saved \n", - " data_to_save = { \n", - " \"prompt\": prompt_texts[prompt_index], \n", - " # \"completion\": out_data.text[prompt_index], \n", - " \"tokens\": [2]+vocab.encode(prompt_texts[prompt_index]), \n", - " \"logits\": logits[0].tolist() #remove the batch dim and then tolist() for json serialization\n", - " } \n", - " all_data_to_save.append(data_to_save)\n" + " logits, _ = model_27b.apply(\n", + " # params,\n", + " {\"params\": params[\"transformer\"]},\n", + " expanded_one_sample_input,\n", + " positions,\n", + " None, # Attention cache is None.\n", + " attention_mask,\n", + " )\n", + " print(f\"{logits=}\")\n", + " 
print(logits.shape)\n", + " # Prepare data to be saved\n", + " data_to_save = {\n", + " \"prompt\": prompt_texts[prompt_index],\n", + " # \"completion\": out_data.text[prompt_index],\n", + " \"tokens\": [2] + vocab.encode(prompt_texts[prompt_index]),\n", + " \"logits\": logits[0].tolist(), # remove the batch dim and then tolist() for json serialization\n", + " }\n", + " all_data_to_save.append(data_to_save)" ] }, { @@ -440,9 +442,8 @@ } ], "source": [ - "with jsonlines.open(output_path,'w') as f: \n", - " f.write_all(all_data_to_save)\n", - "\n", + "with jsonlines.open(output_path, \"w\") as f:\n", + " f.write_all(all_data_to_save)\n", "\n", "\n", "print(f\"Data saved to {output_path}\")" diff --git a/MaxText/scratch_code/golden_gemma2-2b_export-flax.ipynb b/MaxText/scratch_code/golden_gemma2-2b_export-flax.ipynb index df7c6c516..f666a309c 100644 --- a/MaxText/scratch_code/golden_gemma2-2b_export-flax.ipynb +++ b/MaxText/scratch_code/golden_gemma2-2b_export-flax.ipynb @@ -31,11 +31,11 @@ "source": [ "import os\n", "\n", - "VARIANT = '2b' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", + "VARIANT = \"2b\" # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", "\n", "\n", - "ckpt_path = '/home/zhaoyuec/workdir/gemma2-2b/ckpt/'\n", - "vocab_path = '/home/zhaoyuec/workdir/gemma2-2b/tokenizer.model'" + "ckpt_path = \"/home/zhaoyuec/workdir/gemma2-2b/ckpt/\"\n", + "vocab_path = \"/home/zhaoyuec/workdir/gemma2-2b/tokenizer.model\"" ] }, { @@ -47,6 +47,7 @@ "source": [ "# Load parameters\n", "from gemma import params as params_lib\n", + "\n", "params = params_lib.load_and_format_params(ckpt_path)" ] }, @@ -69,6 +70,7 @@ ], "source": [ "import sentencepiece as spm\n", + "\n", "vocab = spm.SentencePieceProcessor()\n", "vocab.Load(vocab_path)" ] @@ -96,8 +98,7 @@ "from gemma import transformer as transformer_lib\n", "\n", "config_2b = transformer_lib.TransformerConfig.from_params(\n", - " params,\n", - " cache_size=30 # Number of time steps in the transformer's cache\n", + " params, cache_size=30 # Number of time steps in the transformer's cache\n", ")\n", "model_2b = transformer_lib.Transformer(config=config_2b)" ] @@ -114,7 +115,7 @@ "sampler = sampler_lib.Sampler(\n", " transformer=model_2b,\n", " vocab=vocab,\n", - " params=params['transformer'],\n", + " params=params[\"transformer\"],\n", ")" ] }, @@ -125,7 +126,7 @@ "metadata": {}, "outputs": [], "source": [ - "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\" ]\n", + "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\"]\n", "# prompt_texts = [\"I love to\"]\n", "\n", "# out_data = sampler(\n", @@ -148,16 +149,18 @@ "source": [ "import jax\n", "\n", - "def get_attention_mask_and_positions(example: jax.Array,\n", - " pad_id : int,\n", - " )-> tuple[jax.Array, jax.Array]:\n", - " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n", "\n", - " pad_mask = example != pad_id\n", + "def get_attention_mask_and_positions(\n", + " example: jax.Array,\n", + " pad_id: int,\n", + ") -> tuple[jax.Array, jax.Array]:\n", + " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n", + "\n", + " pad_mask = example != pad_id\n", "\n", - " current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n", - " attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n", - " return current_token_position, attention_mask" + " current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n", + " attention_mask = 
transformer_lib.make_causal_attn_mask(pad_mask)\n", + " return current_token_position, attention_mask" ] }, { @@ -328,41 +331,40 @@ "\n", "params = params_lib.load_and_format_params(ckpt_path)\n", "\n", - "output_path = \"golden_data_gemma2-2b.jsonl\" \n", + "output_path = \"golden_data_gemma2-2b.jsonl\"\n", "all_data_to_save = []\n", "\n", "for prompt_index in range(len(prompt_texts)):\n", - " prompt_text = prompt_texts[prompt_index]\n", - " one_sample_input = np.array([2]+vocab.encode(prompt_text))\n", - " expanded_one_sample_input = jnp.expand_dims(one_sample_input, axis=0)\n", - " pad_id = vocab.pad_id\n", - " get_attention_mask_and_positions(one_sample_input, pad_id)\n", - " # Build the position and attention mask vectors.\n", - " positions, attention_mask = get_attention_mask_and_positions(one_sample_input, pad_id)\n", - " print(f\"{expanded_one_sample_input=}, {positions=}, {attention_mask=}\")\n", + " prompt_text = prompt_texts[prompt_index]\n", + " one_sample_input = np.array([2] + vocab.encode(prompt_text))\n", + " expanded_one_sample_input = jnp.expand_dims(one_sample_input, axis=0)\n", + " pad_id = vocab.pad_id\n", + " get_attention_mask_and_positions(one_sample_input, pad_id)\n", + " # Build the position and attention mask vectors.\n", + " positions, attention_mask = get_attention_mask_and_positions(one_sample_input, pad_id)\n", + " print(f\"{expanded_one_sample_input=}, {positions=}, {attention_mask=}\")\n", "\n", + " # Foward pass on the input data.\n", + " # No attention cache is needed here.\n", "\n", - " # Foward pass on the input data.\n", - " # No attention cache is needed here.\n", - "\n", - " logits, _ = model_2b.apply(\n", - " # params,\n", - " {'params': params['transformer']},\n", - " expanded_one_sample_input,\n", - " positions,\n", - " None, # Attention cache is None.\n", - " attention_mask,\n", - " )\n", - " print(f\"{logits=}\")\n", - " print(logits.shape)\n", - " # Prepare data to be saved \n", - " data_to_save = { \n", - " \"prompt\": prompt_texts[prompt_index], \n", - " # \"completion\": out_data.text[prompt_index], \n", - " \"tokens\": [2]+vocab.encode(prompt_texts[prompt_index]), \n", - " \"logits\": logits[0].tolist() #remove the batch dim and then tolist() for json serialization\n", - " } \n", - " all_data_to_save.append(data_to_save)\n" + " logits, _ = model_2b.apply(\n", + " # params,\n", + " {\"params\": params[\"transformer\"]},\n", + " expanded_one_sample_input,\n", + " positions,\n", + " None, # Attention cache is None.\n", + " attention_mask,\n", + " )\n", + " print(f\"{logits=}\")\n", + " print(logits.shape)\n", + " # Prepare data to be saved\n", + " data_to_save = {\n", + " \"prompt\": prompt_texts[prompt_index],\n", + " # \"completion\": out_data.text[prompt_index],\n", + " \"tokens\": [2] + vocab.encode(prompt_texts[prompt_index]),\n", + " \"logits\": logits[0].tolist(), # remove the batch dim and then tolist() for json serialization\n", + " }\n", + " all_data_to_save.append(data_to_save)" ] }, { @@ -380,9 +382,8 @@ } ], "source": [ - "with jsonlines.open(output_path,'w') as f: \n", - " f.write_all(all_data_to_save)\n", - "\n", + "with jsonlines.open(output_path, \"w\") as f:\n", + " f.write_all(all_data_to_save)\n", "\n", "\n", "print(f\"Data saved to {output_path}\")" diff --git a/MaxText/scratch_code/golden_gemma2-2b_export.ipynb b/MaxText/scratch_code/golden_gemma2-2b_export.ipynb index b20233c10..20ce6f195 100644 --- a/MaxText/scratch_code/golden_gemma2-2b_export.ipynb +++ b/MaxText/scratch_code/golden_gemma2-2b_export.ipynb @@ -84,7 
+84,7 @@ "source": [ "!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n", "!pip3 install tokenizers -U\n", - "!pip3 install transformers -U\n" + "!pip3 install transformers -U" ] }, { @@ -103,8 +103,8 @@ } ], "source": [ - "import torch \n", - "from transformers import AutoTokenizer, AutoModelForCausalLM \n", + "import torch\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "import jsonlines" ] }, @@ -124,15 +124,15 @@ } ], "source": [ - "# Load the tokenizer and model from Hugging Face \n", - " \n", + "# Load the tokenizer and model from Hugging Face\n", + "\n", "model_id = \"google/gemma-2-2b\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_id,\n", " torch_dtype=torch.float32,\n", - ")\n" + ")" ] }, { @@ -161,43 +161,42 @@ } ], "source": [ - "# Save to disk \n", - "output_path = \"golden_data_gemma2-2b.jsonl\" \n", - " \n", - " \n", - "# Your prompt text \n", - "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\" ]\n", + "# Save to disk\n", + "output_path = \"golden_data_gemma2-2b.jsonl\"\n", + "\n", + "\n", + "# Your prompt text\n", + "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\"]\n", "all_data_to_save = []\n", "\n", "\n", "for prompt_text in prompt_texts:\n", - " # Encode the prompt text \n", - " input_ids = tokenizer.encode(prompt_text, return_tensors='pt') \n", + " # Encode the prompt text\n", + " input_ids = tokenizer.encode(prompt_text, return_tensors=\"pt\")\n", "\n", - " # Get the logits for the prompt + completion \n", - " with torch.no_grad(): \n", - " outputs = model(input_ids) \n", - " logits = outputs.logits \n", + " # Get the logits for the prompt + completion\n", + " with torch.no_grad():\n", + " outputs = model(input_ids)\n", + " logits = outputs.logits\n", "\n", - " # Convert logits to fp32 \n", - " logits = logits.cpu().numpy().astype('float32') \n", + " # Convert logits to fp32\n", + " logits = logits.cpu().numpy().astype(\"float32\")\n", "\n", - " print(logits.shape)\n", + " print(logits.shape)\n", "\n", - " # Prepare data to be saved \n", - " data_to_save = { \n", - " \"prompt\": prompt_text, \n", - " \"tokens\": input_ids.tolist()[0], \n", - " \"logits\": logits.tolist()[0] # Convert numpy array to list for JSON serialization \n", - " } \n", - " all_data_to_save.append(data_to_save)\n", - " \n", - "with jsonlines.open(output_path,'w') as f: \n", - " f.write_all(all_data_to_save)\n", + " # Prepare data to be saved\n", + " data_to_save = {\n", + " \"prompt\": prompt_text,\n", + " \"tokens\": input_ids.tolist()[0],\n", + " \"logits\": logits.tolist()[0], # Convert numpy array to list for JSON serialization\n", + " }\n", + " all_data_to_save.append(data_to_save)\n", "\n", + "with jsonlines.open(output_path, \"w\") as f:\n", + " f.write_all(all_data_to_save)\n", "\n", "\n", - "print(f\"Data saved to {output_path}\") " + "print(f\"Data saved to {output_path}\")" ] } ], diff --git a/MaxText/scratch_code/golden_gemma2-9b_export-flax.ipynb b/MaxText/scratch_code/golden_gemma2-9b_export-flax.ipynb index c8c4d4198..671cd2f22 100644 --- a/MaxText/scratch_code/golden_gemma2-9b_export-flax.ipynb +++ b/MaxText/scratch_code/golden_gemma2-9b_export-flax.ipynb @@ -31,11 +31,11 @@ "source": [ "import os\n", "\n", - "VARIANT = '9b' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", + "VARIANT = \"9b\" # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", "\n", "\n", - "ckpt_path = 
'/home/zhaoyuec/data/gemma2/gemma2-9b/ckpt/'\n", - "vocab_path = '/home/zhaoyuec/data/gemma2/gemma2-9b/tokenizer.model'" + "ckpt_path = \"/home/zhaoyuec/data/gemma2/gemma2-9b/ckpt/\"\n", + "vocab_path = \"/home/zhaoyuec/data/gemma2/gemma2-9b/tokenizer.model\"" ] }, { @@ -47,6 +47,7 @@ "source": [ "# Load parameters\n", "from gemma import params as params_lib\n", + "\n", "params = params_lib.load_and_format_params(ckpt_path)" ] }, @@ -69,6 +70,7 @@ ], "source": [ "import sentencepiece as spm\n", + "\n", "vocab = spm.SentencePieceProcessor()\n", "vocab.Load(vocab_path)" ] @@ -96,8 +98,7 @@ "from gemma import transformer as transformer_lib\n", "\n", "config_9b = transformer_lib.TransformerConfig.from_params(\n", - " params,\n", - " cache_size=30 # Number of time steps in the transformer's cache\n", + " params, cache_size=30 # Number of time steps in the transformer's cache\n", ")\n", "model_9b = transformer_lib.Transformer(config=config_9b)" ] @@ -114,7 +115,7 @@ "sampler = sampler_lib.Sampler(\n", " transformer=model_9b,\n", " vocab=vocab,\n", - " params=params['transformer'],\n", + " params=params[\"transformer\"],\n", ")" ] }, @@ -125,7 +126,7 @@ "metadata": {}, "outputs": [], "source": [ - "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\" ]\n", + "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\"]\n", "# prompt_texts = [\"I love to\"]\n", "\n", "# out_data = sampler(\n", @@ -148,16 +149,18 @@ "source": [ "import jax\n", "\n", - "def get_attention_mask_and_positions(example: jax.Array,\n", - " pad_id : int,\n", - " )-> tuple[jax.Array, jax.Array]:\n", - " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n", "\n", - " pad_mask = example != pad_id\n", + "def get_attention_mask_and_positions(\n", + " example: jax.Array,\n", + " pad_id: int,\n", + ") -> tuple[jax.Array, jax.Array]:\n", + " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n", + "\n", + " pad_mask = example != pad_id\n", "\n", - " current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n", - " attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n", - " return current_token_position, attention_mask" + " current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n", + " attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n", + " return current_token_position, attention_mask" ] }, { @@ -376,41 +379,40 @@ "\n", "params = params_lib.load_and_format_params(ckpt_path)\n", "\n", - "output_path = \"golden_data_gemma2-9b.jsonl\" \n", + "output_path = \"golden_data_gemma2-9b.jsonl\"\n", "all_data_to_save = []\n", "\n", "for prompt_index in range(len(prompt_texts)):\n", - " prompt_text = prompt_texts[prompt_index]\n", - " one_sample_input = np.array([2]+vocab.encode(prompt_text))\n", - " expanded_one_sample_input = jnp.expand_dims(one_sample_input, axis=0)\n", - " pad_id = vocab.pad_id\n", - " get_attention_mask_and_positions(one_sample_input, pad_id)\n", - " # Build the position and attention mask vectors.\n", - " positions, attention_mask = get_attention_mask_and_positions(one_sample_input, pad_id)\n", - " print(f\"{expanded_one_sample_input=}, {positions=}, {attention_mask=}\")\n", + " prompt_text = prompt_texts[prompt_index]\n", + " one_sample_input = np.array([2] + vocab.encode(prompt_text))\n", + " expanded_one_sample_input = jnp.expand_dims(one_sample_input, axis=0)\n", + " pad_id = vocab.pad_id\n", + " get_attention_mask_and_positions(one_sample_input, pad_id)\n", + " # Build 
the position and attention mask vectors.\n", + " positions, attention_mask = get_attention_mask_and_positions(one_sample_input, pad_id)\n", + " print(f\"{expanded_one_sample_input=}, {positions=}, {attention_mask=}\")\n", "\n", + " # Foward pass on the input data.\n", + " # No attention cache is needed here.\n", "\n", - " # Foward pass on the input data.\n", - " # No attention cache is needed here.\n", - "\n", - " logits, _ = model_9b.apply(\n", - " # params,\n", - " {'params': params['transformer']},\n", - " expanded_one_sample_input,\n", - " positions,\n", - " None, # Attention cache is None.\n", - " attention_mask,\n", - " )\n", - " print(f\"{logits=}\")\n", - " print(logits.shape)\n", - " # Prepare data to be saved \n", - " data_to_save = { \n", - " \"prompt\": prompt_texts[prompt_index], \n", - " # \"completion\": out_data.text[prompt_index], \n", - " \"tokens\": [2]+vocab.encode(prompt_texts[prompt_index]), \n", - " \"logits\": logits[0].tolist() #remove the batch dim and then tolist() for json serialization\n", - " } \n", - " all_data_to_save.append(data_to_save)\n" + " logits, _ = model_9b.apply(\n", + " # params,\n", + " {\"params\": params[\"transformer\"]},\n", + " expanded_one_sample_input,\n", + " positions,\n", + " None, # Attention cache is None.\n", + " attention_mask,\n", + " )\n", + " print(f\"{logits=}\")\n", + " print(logits.shape)\n", + " # Prepare data to be saved\n", + " data_to_save = {\n", + " \"prompt\": prompt_texts[prompt_index],\n", + " # \"completion\": out_data.text[prompt_index],\n", + " \"tokens\": [2] + vocab.encode(prompt_texts[prompt_index]),\n", + " \"logits\": logits[0].tolist(), # remove the batch dim and then tolist() for json serialization\n", + " }\n", + " all_data_to_save.append(data_to_save)" ] }, { @@ -428,9 +430,8 @@ } ], "source": [ - "with jsonlines.open(output_path,'w') as f: \n", - " f.write_all(all_data_to_save)\n", - "\n", + "with jsonlines.open(output_path, \"w\") as f:\n", + " f.write_all(all_data_to_save)\n", "\n", "\n", "print(f\"Data saved to {output_path}\")" diff --git a/MaxText/scratch_code/golden_gemma2-9b_export.ipynb b/MaxText/scratch_code/golden_gemma2-9b_export.ipynb index 3f69664da..9ef1bd593 100644 --- a/MaxText/scratch_code/golden_gemma2-9b_export.ipynb +++ b/MaxText/scratch_code/golden_gemma2-9b_export.ipynb @@ -84,7 +84,7 @@ "source": [ "!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n", "!pip3 install tokenizers -U\n", - "!pip3 install transformers -U\n" + "!pip3 install transformers -U" ] }, { @@ -103,8 +103,8 @@ } ], "source": [ - "import torch \n", - "from transformers import AutoTokenizer, AutoModelForCausalLM \n", + "import torch\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "import jsonlines" ] }, @@ -124,15 +124,15 @@ } ], "source": [ - "# Load the tokenizer and model from Hugging Face \n", - " \n", + "# Load the tokenizer and model from Hugging Face\n", + "\n", "model_id = \"google/gemma-2-9b\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_id,\n", " torch_dtype=torch.float32,\n", - ")\n" + ")" ] }, { @@ -158,41 +158,40 @@ } ], "source": [ - "# Save to disk \n", - "output_path = \"golden_data_gemma2-9b.jsonl\" \n", - " \n", - " \n", - "# Your prompt text \n", - "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\" ]\n", + "# Save to disk\n", + "output_path = \"golden_data_gemma2-9b.jsonl\"\n", + "\n", + "\n", + "# Your prompt text\n", + "prompt_texts = 
[\"I love to\", \"Today is a\", \"What is the\"]\n", "all_data_to_save = []\n", "\n", "\n", "for prompt_text in prompt_texts:\n", - " # Encode the prompt text \n", - " input_ids = tokenizer.encode(prompt_text, return_tensors='pt') \n", + " # Encode the prompt text\n", + " input_ids = tokenizer.encode(prompt_text, return_tensors=\"pt\")\n", "\n", - " # Get the logits for the prompt + completion \n", - " with torch.no_grad(): \n", - " outputs = model(input_ids) \n", - " logits = outputs.logits \n", + " # Get the logits for the prompt + completion\n", + " with torch.no_grad():\n", + " outputs = model(input_ids)\n", + " logits = outputs.logits\n", "\n", - " # Convert logits to fp32 \n", - " logits = logits.cpu().numpy().astype('float32') \n", + " # Convert logits to fp32\n", + " logits = logits.cpu().numpy().astype(\"float32\")\n", "\n", - " # Prepare data to be saved \n", - " data_to_save = { \n", - " \"prompt\": prompt_text, \n", - " \"tokens\": input_ids.tolist()[0], \n", - " \"logits\": logits.tolist()[0] # Convert numpy array to list for JSON serialization \n", - " } \n", - " all_data_to_save.append(data_to_save)\n", - " \n", - "with jsonlines.open(output_path,'w') as f: \n", - " f.write_all(all_data_to_save)\n", + " # Prepare data to be saved\n", + " data_to_save = {\n", + " \"prompt\": prompt_text,\n", + " \"tokens\": input_ids.tolist()[0],\n", + " \"logits\": logits.tolist()[0], # Convert numpy array to list for JSON serialization\n", + " }\n", + " all_data_to_save.append(data_to_save)\n", "\n", + "with jsonlines.open(output_path, \"w\") as f:\n", + " f.write_all(all_data_to_save)\n", "\n", "\n", - "print(f\"Data saved to {output_path}\") " + "print(f\"Data saved to {output_path}\")" ] } ], diff --git a/MaxText/scratch_code/golden_llama2-70b_export.py b/MaxText/scratch_code/golden_llama2-70b_export.py index 4f58296ac..c51fdcf92 100644 --- a/MaxText/scratch_code/golden_llama2-70b_export.py +++ b/MaxText/scratch_code/golden_llama2-70b_export.py @@ -14,13 +14,13 @@ limitations under the License. 
""" -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM import jsonlines from google.cloud import storage -# Load the tokenizer and model from Hugging Face - +# Load the tokenizer and model from Hugging Face + model_id = "meta-llama/Llama-2-70b-hf" tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -29,47 +29,46 @@ torch_dtype=torch.float32, ) - -# Your prompt text + +# Your prompt text prompt_texts = ["I love to", "Today is a", "What is the"] all_data_to_save = [] -output_path = 'golden_data_llama2-70b.jsonl' +output_path = "golden_data_llama2-70b.jsonl" for prompt_text in prompt_texts: - # Encode the prompt text - input_ids = tokenizer.encode(prompt_text, return_tensors='pt') - - # Get the logits for the prompt + completion - with torch.no_grad(): - outputs = model(input_ids) - logits = outputs.logits - - # Convert logits to fp32 - logits = logits.cpu().numpy().astype('float32') - - # Prepare data to be saved - data_to_save = { - "prompt": prompt_text, - "tokens": input_ids.tolist()[0], - "logits": logits.tolist()[0] # Convert numpy array to list for JSON serialization - } - all_data_to_save.append(data_to_save) - -with jsonlines.open(output_path,'w') as f: - f.write_all(all_data_to_save) + # Encode the prompt text + input_ids = tokenizer.encode(prompt_text, return_tensors="pt") + + # Get the logits for the prompt + completion + with torch.no_grad(): + outputs = model(input_ids) + logits = outputs.logits + + # Convert logits to fp32 + logits = logits.cpu().numpy().astype("float32") + + # Prepare data to be saved + data_to_save = { + "prompt": prompt_text, + "tokens": input_ids.tolist()[0], + "logits": logits.tolist()[0], # Convert numpy array to list for JSON serialization + } + all_data_to_save.append(data_to_save) + +with jsonlines.open(output_path, "w") as f: + f.write_all(all_data_to_save) + def upload_blob(bucket_name, source_file_name, destination_blob_name): - """Uploads a file to the bucket.""" - storage_client = storage.Client() - bucket = storage_client.get_bucket(bucket_name) - blob = bucket.blob(destination_blob_name) + """Uploads a file to the bucket.""" + storage_client = storage.Client() + bucket = storage_client.get_bucket(bucket_name) + blob = bucket.blob(destination_blob_name) - blob.upload_from_filename(source_file_name) + blob.upload_from_filename(source_file_name) -upload_blob('maxtext-llama', output_path, 'llama2-70b/golden-logits/' + output_path) -print('File {} uploaded to {}.'.format( - output_path, - 'llama2-70b/golden-logits/' + output_path)) +upload_blob("maxtext-llama", output_path, "llama2-70b/golden-logits/" + output_path) +print("File {} uploaded to {}.".format(output_path, "llama2-70b/golden-logits/" + output_path)) diff --git a/MaxText/scratch_code/golden_llama2-7b_export.ipynb b/MaxText/scratch_code/golden_llama2-7b_export.ipynb index 34242b252..d520d7b49 100644 --- a/MaxText/scratch_code/golden_llama2-7b_export.ipynb +++ b/MaxText/scratch_code/golden_llama2-7b_export.ipynb @@ -9,7 +9,7 @@ "source": [ "!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n", "!pip3 install tokenizers -U\n", - "!pip3 install transformers -U\n" + "!pip3 install transformers -U" ] }, { @@ -19,8 +19,8 @@ "metadata": {}, "outputs": [], "source": [ - "import torch \n", - "from transformers import AutoTokenizer, AutoModelForCausalLM \n", + "import torch\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "import 
jsonlines" ] }, @@ -31,15 +31,15 @@ "metadata": {}, "outputs": [], "source": [ - "# Load the tokenizer and model from Hugging Face \n", - " \n", + "# Load the tokenizer and model from Hugging Face\n", + "\n", "model_id = \"meta-llama/Llama-2-7b-hf\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_id,\n", " torch_dtype=torch.float32,\n", - ")\n" + ")" ] }, { @@ -57,43 +57,40 @@ "metadata": {}, "outputs": [], "source": [ - "# Save to disk \n", - "output_path = \"golden_data_llama2-7b.jsonl\" \n", - " \n", - " \n", - "# Your prompt text \n", - "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\" ]\n", + "# Save to disk\n", + "output_path = \"golden_data_llama2-7b.jsonl\"\n", + "\n", + "\n", + "# Your prompt text\n", + "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\"]\n", "all_data_to_save = []\n", "\n", "\n", "for prompt_text in prompt_texts:\n", - " # Encode the prompt text \n", - " input_ids = tokenizer.encode(prompt_text, return_tensors='pt') \n", - "\n", - " # Get the logits for the prompt + completion \n", - " with torch.no_grad(): \n", - " outputs = model(input_ids) \n", - " logits = outputs.logits \n", + " # Encode the prompt text\n", + " input_ids = tokenizer.encode(prompt_text, return_tensors=\"pt\")\n", "\n", - " # Convert logits to fp32 \n", - " logits = logits.cpu().numpy().astype('float32') \n", + " # Get the logits for the prompt + completion\n", + " with torch.no_grad():\n", + " outputs = model(input_ids)\n", + " logits = outputs.logits\n", "\n", - " # Prepare data to be saved \n", - " data_to_save = { \n", - " \"prompt\": prompt_text, \n", - " \"tokens\": input_ids.tolist()[0], \n", - " \"logits\": logits.tolist()[0] # Convert numpy array to list for JSON serialization \n", - " } \n", - " all_data_to_save.append(data_to_save)\n", - " \n", - "with jsonlines.open(output_path,'w') as f: \n", - " f.write_all(all_data_to_save)\n", + " # Convert logits to fp32\n", + " logits = logits.cpu().numpy().astype(\"float32\")\n", "\n", + " # Prepare data to be saved\n", + " data_to_save = {\n", + " \"prompt\": prompt_text,\n", + " \"tokens\": input_ids.tolist()[0],\n", + " \"logits\": logits.tolist()[0], # Convert numpy array to list for JSON serialization\n", + " }\n", + " all_data_to_save.append(data_to_save)\n", "\n", + "with jsonlines.open(output_path, \"w\") as f:\n", + " f.write_all(all_data_to_save)\n", "\n", - "print(f\"Data saved to {output_path}\") \n", "\n", - " \n" + "print(f\"Data saved to {output_path}\")" ] } ], diff --git a/MaxText/scratch_code/golden_llama3-70b_export.py b/MaxText/scratch_code/golden_llama3-70b_export.py index 05851f923..c314e384e 100644 --- a/MaxText/scratch_code/golden_llama3-70b_export.py +++ b/MaxText/scratch_code/golden_llama3-70b_export.py @@ -14,13 +14,13 @@ limitations under the License. 
""" -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM import jsonlines from google.cloud import storage -# Load the tokenizer and model from Hugging Face - +# Load the tokenizer and model from Hugging Face + model_id = "meta-llama/Meta-Llama-3-70B" tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -29,47 +29,46 @@ torch_dtype=torch.float32, ) - -# Your prompt text + +# Your prompt text prompt_texts = ["I love to"] all_data_to_save = [] -output_path = 'golden_data_llama3-70b.jsonl' +output_path = "golden_data_llama3-70b.jsonl" for prompt_text in prompt_texts: - # Encode the prompt text - input_ids = tokenizer.encode(prompt_text, return_tensors='pt') - - # Get the logits for the prompt + completion - with torch.no_grad(): - outputs = model(input_ids) - logits = outputs.logits - - # Convert logits to fp32 - logits = logits.cpu().numpy().astype('float32') - - # Prepare data to be saved - data_to_save = { - "prompt": prompt_text, - "tokens": input_ids.tolist()[0], - "logits": logits.tolist()[0] # Convert numpy array to list for JSON serialization - } - all_data_to_save.append(data_to_save) - -with jsonlines.open(output_path,'w') as f: - f.write_all(all_data_to_save) + # Encode the prompt text + input_ids = tokenizer.encode(prompt_text, return_tensors="pt") + + # Get the logits for the prompt + completion + with torch.no_grad(): + outputs = model(input_ids) + logits = outputs.logits + + # Convert logits to fp32 + logits = logits.cpu().numpy().astype("float32") + + # Prepare data to be saved + data_to_save = { + "prompt": prompt_text, + "tokens": input_ids.tolist()[0], + "logits": logits.tolist()[0], # Convert numpy array to list for JSON serialization + } + all_data_to_save.append(data_to_save) + +with jsonlines.open(output_path, "w") as f: + f.write_all(all_data_to_save) + def upload_blob(bucket_name, source_file_name, destination_blob_name): - """Uploads a file to the bucket.""" - storage_client = storage.Client() - bucket = storage_client.get_bucket(bucket_name) - blob = bucket.blob(destination_blob_name) + """Uploads a file to the bucket.""" + storage_client = storage.Client() + bucket = storage_client.get_bucket(bucket_name) + blob = bucket.blob(destination_blob_name) - blob.upload_from_filename(source_file_name) + blob.upload_from_filename(source_file_name) -upload_blob('maxtext-llama', output_path, 'llama3-70b/golden-logits/' + output_path) -print('File {} uploaded to {}.'.format( - output_path, - 'llama3-70b/golden-logits/' + output_path)) +upload_blob("maxtext-llama", output_path, "llama3-70b/golden-logits/" + output_path) +print("File {} uploaded to {}.".format(output_path, "llama3-70b/golden-logits/" + output_path)) diff --git a/MaxText/scratch_code/golden_llama3-8b_export.ipynb b/MaxText/scratch_code/golden_llama3-8b_export.ipynb index b41ce44b0..9868cf114 100644 --- a/MaxText/scratch_code/golden_llama3-8b_export.ipynb +++ b/MaxText/scratch_code/golden_llama3-8b_export.ipynb @@ -8,7 +8,7 @@ "source": [ "!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n", "!pip3 install tokenizers -U\n", - "!pip3 install transformers -U\n" + "!pip3 install transformers -U" ] }, { @@ -17,8 +17,8 @@ "metadata": {}, "outputs": [], "source": [ - "import torch \n", - "from transformers import AutoTokenizer, AutoModelForCausalLM \n", + "import torch\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "import jsonlines" ] }, @@ -28,15 
+28,15 @@ "metadata": {}, "outputs": [], "source": [ - "# Load the tokenizer and model from Hugging Face \n", - " \n", + "# Load the tokenizer and model from Hugging Face\n", + "\n", "model_id = \"meta-llama/Meta-Llama-3-8B\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_id,\n", " torch_dtype=torch.float32,\n", - ")\n" + ")" ] }, { @@ -45,41 +45,40 @@ "metadata": {}, "outputs": [], "source": [ - "# Save to disk \n", - "output_path = \"golden_data_llama3-8b.jsonl\" \n", - " \n", - " \n", - "# Your prompt text \n", + "# Save to disk\n", + "output_path = \"golden_data_llama3-8b.jsonl\"\n", + "\n", + "\n", + "# Your prompt text\n", "prompt_texts = [\"I love to\"]\n", "all_data_to_save = []\n", "\n", "\n", "for prompt_text in prompt_texts:\n", - " # Encode the prompt text \n", - " input_ids = tokenizer.encode(prompt_text, return_tensors='pt') \n", + " # Encode the prompt text\n", + " input_ids = tokenizer.encode(prompt_text, return_tensors=\"pt\")\n", "\n", - " # Get the logits for the prompt + completion \n", - " with torch.no_grad(): \n", - " outputs = model(input_ids)\n", - " logits = outputs.logits \n", + " # Get the logits for the prompt + completion\n", + " with torch.no_grad():\n", + " outputs = model(input_ids)\n", + " logits = outputs.logits\n", "\n", - " # Convert logits to fp32 \n", - " logits = logits.cpu().numpy().astype('float32') \n", + " # Convert logits to fp32\n", + " logits = logits.cpu().numpy().astype(\"float32\")\n", "\n", - " # Prepare data to be saved \n", - " data_to_save = { \n", - " \"prompt\": prompt_text, \n", - " \"tokens\": input_ids.tolist()[0], \n", - " \"logits\": logits.tolist()[0] # Convert numpy array to list for JSON serialization \n", - " } \n", - " all_data_to_save.append(data_to_save)\n", - " \n", - "with jsonlines.open(output_path,'w') as f: \n", - " f.write_all(all_data_to_save)\n", + " # Prepare data to be saved\n", + " data_to_save = {\n", + " \"prompt\": prompt_text,\n", + " \"tokens\": input_ids.tolist()[0],\n", + " \"logits\": logits.tolist()[0], # Convert numpy array to list for JSON serialization\n", + " }\n", + " all_data_to_save.append(data_to_save)\n", "\n", + "with jsonlines.open(output_path, \"w\") as f:\n", + " f.write_all(all_data_to_save)\n", "\n", "\n", - "print(f\"Data saved to {output_path}\") \n" + "print(f\"Data saved to {output_path}\")" ] }, { diff --git a/MaxText/scratch_code/golden_mixtral-8x22b_export.ipynb b/MaxText/scratch_code/golden_mixtral-8x22b_export.ipynb index 1b615d3be..a025e8d96 100644 --- a/MaxText/scratch_code/golden_mixtral-8x22b_export.ipynb +++ b/MaxText/scratch_code/golden_mixtral-8x22b_export.ipynb @@ -28,8 +28,8 @@ } ], "source": [ - "import torch \n", - "from transformers import AutoTokenizer, AutoModelForCausalLM \n", + "import torch\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "import jsonlines" ] }, @@ -48,15 +48,15 @@ } ], "source": [ - "# Load the tokenizer and model from Hugging Face \n", - " \n", + "# Load the tokenizer and model from Hugging Face\n", + "\n", "model_id = \"mistralai/Mixtral-8x22B-Instruct-v0.1\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_id,\n", " torch_dtype=torch.float16,\n", - ")\n" + ")" ] }, { @@ -89,43 +89,40 @@ } ], "source": [ - "# Save to disk \n", - "output_path = \"golden_data_mixtral-8x22b.jsonl\" \n", - " \n", - " \n", - "# Your prompt text \n", + "# Save to disk\n", + "output_path = 
\"golden_data_mixtral-8x22b.jsonl\"\n", + "\n", + "\n", + "# Your prompt text\n", "prompt_texts = [\"[INST] I love to [/INST]\", \"[INST] Today is a [/INST]\", \"[INST] What is the [/INST]\"]\n", "all_data_to_save = []\n", "\n", "\n", "for prompt_text in prompt_texts:\n", - " # Encode the prompt text \n", - " input_ids = tokenizer.encode(prompt_text, return_tensors='pt') \n", - "\n", - " # Get the logits for the prompt + completion \n", - " with torch.no_grad(): \n", - " outputs = model(input_ids) \n", - " logits = outputs.logits \n", + " # Encode the prompt text\n", + " input_ids = tokenizer.encode(prompt_text, return_tensors=\"pt\")\n", "\n", - " # Convert logits to fp32 \n", - " logits = logits.cpu().numpy().astype('float32') \n", + " # Get the logits for the prompt + completion\n", + " with torch.no_grad():\n", + " outputs = model(input_ids)\n", + " logits = outputs.logits\n", "\n", - " # Prepare data to be saved \n", - " data_to_save = { \n", - " \"prompt\": prompt_text, \n", - " \"tokens\": input_ids.tolist()[0], \n", - " \"logits\": logits.tolist()[0] # Convert numpy array to list for JSON serialization \n", - " } \n", - " all_data_to_save.append(data_to_save)\n", - " \n", - "with jsonlines.open(output_path,'w') as f: \n", - " f.write_all(all_data_to_save)\n", + " # Convert logits to fp32\n", + " logits = logits.cpu().numpy().astype(\"float32\")\n", "\n", + " # Prepare data to be saved\n", + " data_to_save = {\n", + " \"prompt\": prompt_text,\n", + " \"tokens\": input_ids.tolist()[0],\n", + " \"logits\": logits.tolist()[0], # Convert numpy array to list for JSON serialization\n", + " }\n", + " all_data_to_save.append(data_to_save)\n", "\n", + "with jsonlines.open(output_path, \"w\") as f:\n", + " f.write_all(all_data_to_save)\n", "\n", - "print(f\"Data saved to {output_path}\") \n", "\n", - " \n" + "print(f\"Data saved to {output_path}\")" ] }, { diff --git a/MaxText/scratch_code/golden_mixtral-8x7b_export.ipynb b/MaxText/scratch_code/golden_mixtral-8x7b_export.ipynb index 449404ce5..ad0cca877 100644 --- a/MaxText/scratch_code/golden_mixtral-8x7b_export.ipynb +++ b/MaxText/scratch_code/golden_mixtral-8x7b_export.ipynb @@ -9,7 +9,7 @@ "source": [ "!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n", "!pip3 install tokenizers -U\n", - "!pip3 install transformers -U\n" + "!pip3 install transformers -U" ] }, { @@ -19,8 +19,8 @@ "metadata": {}, "outputs": [], "source": [ - "import torch \n", - "from transformers import AutoTokenizer, AutoModelForCausalLM \n", + "import torch\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "import jsonlines" ] }, @@ -31,15 +31,15 @@ "metadata": {}, "outputs": [], "source": [ - "# Load the tokenizer and model from Hugging Face \n", - " \n", + "# Load the tokenizer and model from Hugging Face\n", + "\n", "model_id = \"mistralai/Mixtral-8x7B-Instruct-v0.1\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_id,\n", " torch_dtype=torch.float32,\n", - ")\n" + ")" ] }, { @@ -57,43 +57,40 @@ "metadata": {}, "outputs": [], "source": [ - "# Save to disk \n", - "output_path = \"golden_data_mixtral-8x7b-2.jsonl\" \n", - " \n", - " \n", - "# Your prompt text \n", + "# Save to disk\n", + "output_path = \"golden_data_mixtral-8x7b-2.jsonl\"\n", + "\n", + "\n", + "# Your prompt text\n", "prompt_texts = [\"[INST] I love to [/INST]\", \"[INST] Today is a [/INST]\", \"[INST] What is the [/INST]\"]\n", "all_data_to_save = []\n", "\n", 
"\n", "for prompt_text in prompt_texts:\n", - " # Encode the prompt text \n", - " input_ids = tokenizer.encode(prompt_text, return_tensors='pt') \n", - "\n", - " # Get the logits for the prompt + completion \n", - " with torch.no_grad(): \n", - " outputs = model(input_ids) \n", - " logits = outputs.logits \n", + " # Encode the prompt text\n", + " input_ids = tokenizer.encode(prompt_text, return_tensors=\"pt\")\n", "\n", - " # Convert logits to fp32 \n", - " logits = logits.cpu().numpy().astype('float32') \n", + " # Get the logits for the prompt + completion\n", + " with torch.no_grad():\n", + " outputs = model(input_ids)\n", + " logits = outputs.logits\n", "\n", - " # Prepare data to be saved \n", - " data_to_save = { \n", - " \"prompt\": prompt_text, \n", - " \"tokens\": input_ids.tolist()[0], \n", - " \"logits\": logits.tolist()[0] # Convert numpy array to list for JSON serialization \n", - " } \n", - " all_data_to_save.append(data_to_save)\n", - " \n", - "with jsonlines.open(output_path,'w') as f: \n", - " f.write_all(all_data_to_save)\n", + " # Convert logits to fp32\n", + " logits = logits.cpu().numpy().astype(\"float32\")\n", "\n", + " # Prepare data to be saved\n", + " data_to_save = {\n", + " \"prompt\": prompt_text,\n", + " \"tokens\": input_ids.tolist()[0],\n", + " \"logits\": logits.tolist()[0], # Convert numpy array to list for JSON serialization\n", + " }\n", + " all_data_to_save.append(data_to_save)\n", "\n", + "with jsonlines.open(output_path, \"w\") as f:\n", + " f.write_all(all_data_to_save)\n", "\n", - "print(f\"Data saved to {output_path}\") \n", "\n", - " \n" + "print(f\"Data saved to {output_path}\")" ] } ], diff --git a/MaxText/standalone_dataloader.py b/MaxText/standalone_dataloader.py index f90f77c37..5e7e3447f 100644 --- a/MaxText/standalone_dataloader.py +++ b/MaxText/standalone_dataloader.py @@ -66,7 +66,7 @@ def main(argv: Sequence[str]) -> None: max_logging.log(f"Found {jax.device_count()} devices.") max_logging.log(f"Found {jax.process_count()} processes.") max_logging.log(f"Found {jax.devices()} devices.") - if config.dataset_type in ('tfds', 'c4_mlperf'): + if config.dataset_type in ("tfds", "c4_mlperf"): os.environ["TFDS_DATA_DIR"] = config.dataset_path data_load_loop(config) diff --git a/MaxText/tests/aot_hlo_identical_test.py b/MaxText/tests/aot_hlo_identical_test.py index d4f47310d..e742dacf9 100644 --- a/MaxText/tests/aot_hlo_identical_test.py +++ b/MaxText/tests/aot_hlo_identical_test.py @@ -21,88 +21,89 @@ class AotHloIdenticalTest(unittest.TestCase): - def run_compile_and_real(self, script_path, aot_dump_dir, real_dump_dir, extra_config_args): - """Executes a Bash script and returns the completed process object.""" - if extra_config_args is not None: - cmd = ["bash", script_path, aot_dump_dir, real_dump_dir, extra_config_args] - else: - cmd = ["bash", script_path, aot_dump_dir, real_dump_dir] - try: - result = subprocess.run( - cmd, # Command to run the script - check=True, # Raise an exception if the script fails - stdout=sys.stdout, # Stream to stdout - stderr=sys.stdout, # Stream to stdout - text=True # Decode output and error as text - ) - return result - except subprocess.CalledProcessError as e: - print(f"Error running script: {e.returncode}") - print(f"Output: {e.stdout}") - print(f"Error: {e.stderr}") - - def find_file_by_substring(self, directory, substring): - for filename in os.listdir(directory): - if substring in filename: - return os.path.join(directory,filename) - raise FileNotFoundError(f"Could not find a file in directory {directory} 
with substring {substring}") - - def delete_dir(self, dir): - if os.path.exists(dir): - shutil.rmtree(dir) - - def check_large_files_equal(self, file_path1, file_path2): - """Asserts that two potentially large text files have identical content.""" - - hasher1 = hashlib.sha256() - hasher2 = hashlib.sha256() - - with open(file_path1, "rb") as f1, open(file_path2, "rb") as f2: - # Read files in chunks for memory efficiency - while True: - chunk1 = f1.read(8192) # 8 KB chunks - chunk2 = f2.read(8192) - - if not chunk1 and not chunk2: # Reached the end of both files - break - hasher1.update(chunk1) - hasher2.update(chunk2) - - # Handle potential empty files - if not hasher1.digest() or not hasher2.digest(): - # One or both files are empty - return False - - if hasher1.hexdigest() != hasher2.hexdigest(): - # Files have different contents - return False - return True - - def assert_compile_and_real_match_hlo(self, test_name, extra_config_args): - hlo_filename_substring="jit_train_step.after_optimizations_after_buffer_assignment.txt" - compile_dump_dir=f"/tmp/compile_test_xla_dump/{test_name}/aot/" - train_dump_dir=f"/tmp/compile_test_xla_dump/{test_name}/real/" - self.delete_dir(compile_dump_dir) # Ensure directories empty before use - self.delete_dir(train_dump_dir) - - self.run_compile_and_real("tests/aot_hlo_identical_script.sh", compile_dump_dir, train_dump_dir, extra_config_args) - - compile_hlo_file = self.find_file_by_substring(compile_dump_dir, hlo_filename_substring) - train_hlo_file = self.find_file_by_substring(train_dump_dir, hlo_filename_substring) - print(f"AOT compiled HLO file for test {test_name}: {compile_hlo_file}", flush=True) - print(f"Real runs HLO file for test {test_name}: {train_hlo_file}", flush=True) - - files_equal = self.check_large_files_equal(compile_hlo_file, train_hlo_file) - self.delete_dir(compile_dump_dir) # Cleanup directories after use - self.delete_dir(train_dump_dir) - assert files_equal, f"AOT Compiled and real HLO files are not identical for test {test_name}!" - print("AOT Compiled and train HLO files are identical for test {test_name}!") - - # TODO (mattdavidow) - @pytest.mark.skip(reason="Issue w/ kernels_test. Error: The TPU is already in use by process...") - def test_default_hlo_match(self): - self.assert_compile_and_real_match_hlo("default_run", None) - - @pytest.mark.skip(reason="Issue w/ kernels_test. 
Error: The TPU is already in use by process...") - def test_int8_hlo_match(self): - self.assert_compile_and_real_match_hlo("int8", "quantization=int8") + + def run_compile_and_real(self, script_path, aot_dump_dir, real_dump_dir, extra_config_args): + """Executes a Bash script and returns the completed process object.""" + if extra_config_args is not None: + cmd = ["bash", script_path, aot_dump_dir, real_dump_dir, extra_config_args] + else: + cmd = ["bash", script_path, aot_dump_dir, real_dump_dir] + try: + result = subprocess.run( + cmd, # Command to run the script + check=True, # Raise an exception if the script fails + stdout=sys.stdout, # Stream to stdout + stderr=sys.stdout, # Stream to stdout + text=True, # Decode output and error as text + ) + return result + except subprocess.CalledProcessError as e: + print(f"Error running script: {e.returncode}") + print(f"Output: {e.stdout}") + print(f"Error: {e.stderr}") + + def find_file_by_substring(self, directory, substring): + for filename in os.listdir(directory): + if substring in filename: + return os.path.join(directory, filename) + raise FileNotFoundError(f"Could not find a file in directory {directory} with substring {substring}") + + def delete_dir(self, dir): + if os.path.exists(dir): + shutil.rmtree(dir) + + def check_large_files_equal(self, file_path1, file_path2): + """Asserts that two potentially large text files have identical content.""" + + hasher1 = hashlib.sha256() + hasher2 = hashlib.sha256() + + with open(file_path1, "rb") as f1, open(file_path2, "rb") as f2: + # Read files in chunks for memory efficiency + while True: + chunk1 = f1.read(8192) # 8 KB chunks + chunk2 = f2.read(8192) + + if not chunk1 and not chunk2: # Reached the end of both files + break + hasher1.update(chunk1) + hasher2.update(chunk2) + + # Handle potential empty files + if not hasher1.digest() or not hasher2.digest(): + # One or both files are empty + return False + + if hasher1.hexdigest() != hasher2.hexdigest(): + # Files have different contents + return False + return True + + def assert_compile_and_real_match_hlo(self, test_name, extra_config_args): + hlo_filename_substring = "jit_train_step.after_optimizations_after_buffer_assignment.txt" + compile_dump_dir = f"/tmp/compile_test_xla_dump/{test_name}/aot/" + train_dump_dir = f"/tmp/compile_test_xla_dump/{test_name}/real/" + self.delete_dir(compile_dump_dir) # Ensure directories empty before use + self.delete_dir(train_dump_dir) + + self.run_compile_and_real("tests/aot_hlo_identical_script.sh", compile_dump_dir, train_dump_dir, extra_config_args) + + compile_hlo_file = self.find_file_by_substring(compile_dump_dir, hlo_filename_substring) + train_hlo_file = self.find_file_by_substring(train_dump_dir, hlo_filename_substring) + print(f"AOT compiled HLO file for test {test_name}: {compile_hlo_file}", flush=True) + print(f"Real runs HLO file for test {test_name}: {train_hlo_file}", flush=True) + + files_equal = self.check_large_files_equal(compile_hlo_file, train_hlo_file) + self.delete_dir(compile_dump_dir) # Cleanup directories after use + self.delete_dir(train_dump_dir) + assert files_equal, f"AOT Compiled and real HLO files are not identical for test {test_name}!" + print("AOT Compiled and train HLO files are identical for test {test_name}!") + + # TODO (mattdavidow) + @pytest.mark.skip(reason="Issue w/ kernels_test. 
Error: The TPU is already in use by process...") + def test_default_hlo_match(self): + self.assert_compile_and_real_match_hlo("default_run", None) + + @pytest.mark.skip(reason="Issue w/ kernels_test. Error: The TPU is already in use by process...") + def test_int8_hlo_match(self): + self.assert_compile_and_real_match_hlo("int8", "quantization=int8") diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py index 1447f672e..94c2af633 100644 --- a/MaxText/tests/attention_test.py +++ b/MaxText/tests/attention_test.py @@ -77,7 +77,7 @@ def setUp(self): dtype=self.dtype, dropout_rate=self.cfg.dropout_rate, name="self_attention", - attention_type=self.attention_type + attention_type=self.attention_type, ) self._attention_as_mha_generic_variable = self._attention_as_mha_generic.init( @@ -263,18 +263,15 @@ def tpu_kernel_attention_helper(self, num_kv_heads): def test_dot_product_cache_axis_order(self): all_axis_orders = [axis_order for axis_order in itertools.permutations(range(4))] for axis_order in random.choices(all_axis_orders, k=4): - self.dot_product_attention_helper( - prefill_cache_axis_order=axis_order, - ar_cache_axis_order=axis_order - ) + self.dot_product_attention_helper(prefill_cache_axis_order=axis_order, ar_cache_axis_order=axis_order) print(f"passed test for {axis_order=}") def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order): - for compute_axis_order in [(0,1,2,3), (0,2,1,3)]: + for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: self._dot_product_attention( - prefill_cache_axis_order, - ar_cache_axis_order, - compute_axis_order=compute_axis_order, + prefill_cache_axis_order, + ar_cache_axis_order, + compute_axis_order=compute_axis_order, ) print(f"passed subtest for {compute_axis_order=}") @@ -357,7 +354,7 @@ def _dot_product_attention( lnx_idx = lnx[:, idx : idx + 1, :] decoder_positions_idx = decoder_positions[:, idx : idx + 1] - + attention_w_layout_variable.update(attention_w_layout_output_cache) attention_w_layout_idx, attention_w_layout_output_cache = attention_w_layout.apply( attention_w_layout_variable, @@ -372,13 +369,15 @@ def _dot_product_attention( attention_w_layout_full_this_idx = attention_w_layout_full[:, idx : idx + 1, :] self.assertTrue(attention_w_layout_full_this_idx.shape == attention_w_layout_idx.shape) - self.assertTrue(jax.numpy.allclose(attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=rtol, atol=atol, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose(attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=rtol, atol=atol, equal_nan=False) + ) @pytest.mark.tpu def test_dot_product_reshape_q(self): - for compute_axis_order in [(0,1,2,3), (0,2,1,3)]: + for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: self._dot_product_attention_reshape_q( - compute_axis_order=compute_axis_order, + compute_axis_order=compute_axis_order, ) print(f"test passed for compute_axis_order: {compute_axis_order}") @@ -480,7 +479,9 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): mutable=["cache"], ) self.assertTrue( - jax.numpy.allclose(attention_wo_reshape_q_full[:, :prefill_length, :], attention_wo_reshape_q_prefill, equal_nan=False) + jax.numpy.allclose( + attention_wo_reshape_q_full[:, :prefill_length, :], attention_wo_reshape_q_prefill, equal_nan=False + ) ) attention_w_reshape_q_prefill, attention_w_reshape_q_output_cache = attention_w_reshape_q.apply( @@ -498,18 +499,20 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): 
jax.numpy.allclose(attention_w_reshape_q_full[:, :prefill_length, :], attention_w_reshape_q_prefill, equal_nan=False) ) + self.assertTrue(jax.numpy.allclose(attention_wo_reshape_q_prefill, attention_w_reshape_q_prefill, equal_nan=False)) self.assertTrue( - jax.numpy.allclose(attention_wo_reshape_q_prefill, attention_w_reshape_q_prefill, equal_nan=False) - ) - self.assertTrue( - jax.numpy.allclose(attention_wo_reshape_q_full[:, :prefill_length, :], attention_w_reshape_q_full[:, :prefill_length, :], equal_nan=False) + jax.numpy.allclose( + attention_wo_reshape_q_full[:, :prefill_length, :], + attention_w_reshape_q_full[:, :prefill_length, :], + equal_nan=False, + ) ) for idx in range(prefill_length, decode_total_length): lnx_idx = lnx[:, idx : idx + 1, :] decoder_positions_idx = decoder_positions[:, idx : idx + 1] - + attention_wo_reshape_q_variable.update(attention_wo_reshape_q_output_cache) attention_wo_reshape_q_idx, attention_wo_reshape_q_output_cache = attention_wo_reshape_q.apply( attention_wo_reshape_q_variable, @@ -524,7 +527,11 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): attention_wo_reshape_q_full_this_idx = attention_wo_reshape_q_full[:, idx : idx + 1, :] self.assertTrue(attention_wo_reshape_q_full_this_idx.shape == attention_wo_reshape_q_idx.shape) - self.assertTrue(jax.numpy.allclose(attention_wo_reshape_q_full_this_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + attention_wo_reshape_q_full_this_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False + ) + ) attention_w_reshape_q_variable.update(attention_w_reshape_q_output_cache) attention_w_reshape_q_idx, attention_w_reshape_q_output_cache = attention_w_reshape_q.apply( @@ -540,9 +547,15 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): attention_w_reshape_q_full_this_idx = attention_w_reshape_q_full[:, idx : idx + 1, :] self.assertTrue(attention_w_reshape_q_full_this_idx.shape == attention_w_reshape_q_idx.shape) - self.assertTrue(jax.numpy.allclose(attention_w_reshape_q_full_this_idx, attention_w_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + attention_w_reshape_q_full_this_idx, attention_w_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False + ) + ) - self.assertTrue(jax.numpy.allclose(attention_w_reshape_q_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose(attention_w_reshape_q_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False) + ) def test_sliding_window_attention(self): """Test sliding window attention""" @@ -581,14 +594,16 @@ def test_sliding_window_attention(self): attention_type=attentions.AttentionType.LOCAL_SLIDING, sliding_window_size=8, ) - + # Use freeze to fix the parameters to facilitate the comparison of sliding and global attention. 
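+    # Intuition for the checks further down (a rough sketch, not the exact mask construction in MaxText):
+    # with AttentionType.LOCAL_SLIDING and sliding_window_size=w, each position attends only to
+    # roughly the previous w positions, so its output should differ from full causal (global)
+    # attention; once w >= max_target_length the window covers the whole causal mask, and the two
+    # outputs should match, which is what the two allclose assertions below verify.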
- attn_variable = freeze(sliding_attn.init( - {"params": self.rng, "aqt": self.rng}, - jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones((self.global_batch_size, self.max_target_length)), - )) + attn_variable = freeze( + sliding_attn.init( + {"params": self.rng, "aqt": self.rng}, + jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), + jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), + jnp.ones((self.global_batch_size, self.max_target_length)), + ) + ) global_attn_output = global_attn.apply( attn_variable, @@ -614,7 +629,9 @@ def test_sliding_window_attention(self): # Test if sliding window attention is different from global attention self.assertFalse( - jax.numpy.allclose(sliding_window_output.astype(jnp.bfloat16), global_attn_output.astype(jnp.bfloat16), rtol=1e-04, atol=1e-04) + jax.numpy.allclose( + sliding_window_output.astype(jnp.bfloat16), global_attn_output.astype(jnp.bfloat16), rtol=1e-04, atol=1e-04 + ) ) # Attention with sliding window of size max_target_length @@ -648,7 +665,9 @@ def test_sliding_window_attention(self): # Test if sliding window attention with max_target_length size is the same as global attention self.assertTrue( - jax.numpy.allclose(sliding_window_output.astype(jnp.bfloat16), global_attn_output.astype(jnp.bfloat16), rtol=1e-04, atol=1e-04) + jax.numpy.allclose( + sliding_window_output.astype(jnp.bfloat16), global_attn_output.astype(jnp.bfloat16), rtol=1e-04, atol=1e-04 + ) ) diff --git a/MaxText/tests/forward_pass_logit_checker.py b/MaxText/tests/forward_pass_logit_checker.py index 80d71fdcb..eb5c64a3a 100644 --- a/MaxText/tests/forward_pass_logit_checker.py +++ b/MaxText/tests/forward_pass_logit_checker.py @@ -38,6 +38,7 @@ sys.path.append(maxtext_parent_dir) import max_logging + max_logging.log(f"Added parent directory = {maxtext_parent_dir}") import common_types @@ -50,33 +51,31 @@ def get_data(golden_data, golden_data_index, config): - """ Get the golden data for the test indexed at golden_data_index""" + """Get the golden data for the test indexed at golden_data_index""" max_logging.log(f"Comparing forward pass for golden data index = {golden_data_index} ") max_logging.log(f"config.global_batch_size_to_train_on={config.global_batch_size_to_train_on}") s = (config.global_batch_size_to_train_on, config.max_target_length) - ids = np.asarray(golden_data[golden_data_index]['tokens'], dtype=np.int32) + ids = np.asarray(golden_data[golden_data_index]["tokens"], dtype=np.int32) - logits = np.asarray(golden_data[golden_data_index]['logits'], dtype=np.float32) + logits = np.asarray(golden_data[golden_data_index]["logits"], dtype=np.float32) max_logging.log(f" prompt=\"{golden_data[golden_data_index]['prompt']}\" raw ids={ids}, logits.shape = {logits.shape}") - decoder_segment_ids = jax.numpy.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR decoder_positions = jnp.stack( [jnp.arange(config.max_target_length, dtype=jnp.int32) for _ in range(config.global_batch_size_to_train_on)] ) - ids = jnp.stack( - [ids for _ in range(config.global_batch_size_to_train_on)] - ) + ids = jnp.stack([ids for _ in range(config.global_batch_size_to_train_on)]) max_logging.log(f"ids={ids}, decoder_segment_ids = {decoder_segment_ids}, decoder_positions= {decoder_positions}") return ids, decoder_segment_ids, decoder_positions, logits + def main(config, test_args): """Test the Whole Model of model_name""" - #initialize 
the model with weights from reference ckpt + # initialize the model with weights from reference ckpt ( init_rng, _, @@ -88,15 +87,14 @@ def main(config, test_args): _, _, state, - ) = train.setup_train_loop(config) + ) = train.setup_train_loop(config) - input_golden_data_path = "MaxText/test_assets/golden_data_"+config.model_name+".jsonl" - with jsonlines.open(input_golden_data_path, 'r') as f: + input_golden_data_path = "MaxText/test_assets/golden_data_" + config.model_name + ".jsonl" + with jsonlines.open(input_golden_data_path, "r") as f: golden_data = list(f) - for golden_data_index in range(len(golden_data)): - ids, decoder_segment_ids, decoder_positions, golden_logits = get_data(golden_data,golden_data_index,config) + ids, decoder_segment_ids, decoder_positions, golden_logits = get_data(golden_data, golden_data_index, config) full_train_logits = model.apply( state.params, @@ -110,7 +108,9 @@ def main(config, test_args): max_logging.log(f"{golden_logits[0]=}") max_logging.log(f"{full_train_logits[0, 0, :]=}") token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0] - max_logging.log(f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}") + max_logging.log( + f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}" + ) model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1) golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1) @@ -123,15 +123,20 @@ def main(config, test_args): if test_args.max_kl_div is not None: max_logging.log("Checking KL Divergence between train distribution and golden distribution") - assert jax.numpy.all(kl_div < test_args.max_kl_div), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}" + assert jax.numpy.all( + kl_div < test_args.max_kl_div + ), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}" else: max_logging.log("Checking Numerical Differences between train logits and golden logits") assert jax.numpy.allclose( - full_train_logits[0, :token_size, :], golden_logits[:token_size, :], rtol=float(test_args.rtol), atol=float(test_args.atol), equal_nan=False + full_train_logits[0, :token_size, :], + golden_logits[:token_size, :], + rtol=float(test_args.rtol), + atol=float(test_args.atol), + equal_nan=False, ), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}." 
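As context for the checks above: a minimal standalone sketch of the KL-divergence comparison, assuming `golden_logits` and `model_logits` are [token_size, vocab] arrays already sliced as in the checker (the checker's own `kl_div` computation may differ slightly; `check_kl_divergence` and its arguments are illustrative names, not part of this file):

  import jax
  import jax.numpy as jnp

  def check_kl_divergence(golden_logits, model_logits, max_kl_div):
    # Turn both sets of logits into per-token probability distributions over the vocab.
    golden_probs = jax.nn.softmax(golden_logits, axis=-1)
    model_probs = jax.nn.softmax(model_logits, axis=-1)
    # Per-token KL(golden || model); the small constant guards against log(0).
    kl_div = jnp.sum(golden_probs * (jnp.log(golden_probs + 1e-20) - jnp.log(model_probs + 1e-20)), axis=-1)
    assert jnp.all(kl_div < max_kl_div), f"Max KL divergence {jnp.max(kl_div)} exceeds threshold {max_kl_div}"
    return kl_div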
- if __name__ == "__main__": jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" @@ -149,6 +154,6 @@ def main(config, test_args): for arg in to_remove_args: model_args = [s for s in model_args if not s.startswith(arg)] - pyconfig.initialize(model_args) + pyconfig.initialize(model_args) cfg = pyconfig.config main(cfg, test_args) diff --git a/MaxText/tests/gradient_accumulation_test.py b/MaxText/tests/gradient_accumulation_test.py index 009bc1110..e2730fe97 100644 --- a/MaxText/tests/gradient_accumulation_test.py +++ b/MaxText/tests/gradient_accumulation_test.py @@ -20,13 +20,14 @@ import random from train import main as train_main + def generate_random_string(length=10): - characters = string.ascii_letters # Include letters, digits, and punctuation - return ''.join(random.choice(characters) for _ in range(length)) + characters = string.ascii_letters # Include letters, digits, and punctuation + return "".join(random.choice(characters) for _ in range(length)) + -class GradientAccumulationTest(unittest.TestCase): +class GradientAccumulationTest(unittest.TestCase): - @pytest.mark.tpu def test_grad_accumulate_same_loss(self): random_suffix = generate_random_string() @@ -35,58 +36,73 @@ def test_grad_accumulate_same_loss(self): run_regular_metrics_file = f"/tmp/runner_regular_{random_suffix}.txt" print(f"{run_regular_metrics_file=}") shared_maxtext_args = [ - None, + None, "configs/base.yml", - r"base_output_directory=gs://runner-maxtext-logs", - r"dataset_path=gs://maxtext-dataset", - "gradient_clipping_threshold=0", # Ensures we are testing raw scales of gradients (clipping off) - "enable_checkpointing=False", - "base_emb_dim=256", - "base_num_decoder_layers=4", - "tokenizer_path=../assets/tokenizer.llama2", - "steps=50", + r"base_output_directory=gs://runner-maxtext-logs", + r"dataset_path=gs://maxtext-dataset", + "gradient_clipping_threshold=0", # Ensures we are testing raw scales of gradients (clipping off) + "enable_checkpointing=False", + "base_emb_dim=256", + "base_num_decoder_layers=4", + "tokenizer_path=../assets/tokenizer.llama2", + "steps=50", ] # Run with gradient accumulation with accumulate_steps=10, per_device_batch=1 --> simulating per_device_batch=10 - train_main(shared_maxtext_args + [ - "run_name=runner_grad_accumulate", - f"metrics_file={run_accumulate_metrics_file}", - "per_device_batch_size=1", - "gradient_accumulation_steps=10", - ]) + train_main( + shared_maxtext_args + + [ + "run_name=runner_grad_accumulate", + f"metrics_file={run_accumulate_metrics_file}", + "per_device_batch_size=1", + "gradient_accumulation_steps=10", + ] + ) - #Run without gradient accumulation with per_device_batch=10 - train_main(shared_maxtext_args + [ - "run_name=runner_grad_accumulate_regular", - f"metrics_file={run_regular_metrics_file}", - "per_device_batch_size=10", - "gradient_accumulation_steps=1", - ]) + # Run without gradient accumulation with per_device_batch=10 + train_main( + shared_maxtext_args + + [ + "run_name=runner_grad_accumulate_regular", + f"metrics_file={run_regular_metrics_file}", + "per_device_batch_size=10", + "gradient_accumulation_steps=1", + ] + ) # Assert losses roughly equal - with open(run_accumulate_metrics_file, 'r', encoding='utf8') as accum_run,\ - open(run_regular_metrics_file, 'r', encoding='utf8') as regular_run: - accum_run_loss = json.loads(accum_run.readlines()[-1])["learning/loss"] - regular_run_loss = json.loads(regular_run.readlines()[-1])["learning/loss"] - print(f"[Gradient Accumulation Test] Loss with 
gradient accumulation: {accum_run_loss}", flush=True) - print(f"[Gradient Accumulation Test] Loss without gradient accumulation: {regular_run_loss}", flush=True) - # Not identical due to an epsilon addition in loss denominator. - np.testing.assert_allclose(accum_run_loss, regular_run_loss, rtol=0.01) + with ( + open(run_accumulate_metrics_file, "r", encoding="utf8") as accum_run, + open(run_regular_metrics_file, "r", encoding="utf8") as regular_run, + ): + accum_run_loss = json.loads(accum_run.readlines()[-1])["learning/loss"] + regular_run_loss = json.loads(regular_run.readlines()[-1])["learning/loss"] + print(f"[Gradient Accumulation Test] Loss with gradient accumulation: {accum_run_loss}", flush=True) + print(f"[Gradient Accumulation Test] Loss without gradient accumulation: {regular_run_loss}", flush=True) + # Not identical due to an epsilon addition in loss denominator. + np.testing.assert_allclose(accum_run_loss, regular_run_loss, rtol=0.01) # Assert grad norms roughly equal - with open(run_accumulate_metrics_file, 'r', encoding='utf8') as accum_run,\ - open(run_regular_metrics_file, 'r', encoding='utf8') as regular_run: - accum_run_grad_norm= json.loads(accum_run.readlines()[-1])["learning/raw_grad_norm"] - regular_run_grad_norm = json.loads(regular_run.readlines()[-1])["learning/raw_grad_norm"] - print(f"[Gradient Accumulation Test] Grad norm with gradient accumulation: {accum_run_grad_norm}", flush=True) - print(f"[Gradient Accumulation Test] Grad norm without gradient accumulation: {regular_run_grad_norm}", flush=True) - # Not identical due to an epsilon addition in loss denominator. - np.testing.assert_allclose(accum_run_grad_norm, regular_run_grad_norm, rtol=0.01) + with ( + open(run_accumulate_metrics_file, "r", encoding="utf8") as accum_run, + open(run_regular_metrics_file, "r", encoding="utf8") as regular_run, + ): + accum_run_grad_norm = json.loads(accum_run.readlines()[-1])["learning/raw_grad_norm"] + regular_run_grad_norm = json.loads(regular_run.readlines()[-1])["learning/raw_grad_norm"] + print(f"[Gradient Accumulation Test] Grad norm with gradient accumulation: {accum_run_grad_norm}", flush=True) + print(f"[Gradient Accumulation Test] Grad norm without gradient accumulation: {regular_run_grad_norm}", flush=True) + # Not identical due to an epsilon addition in loss denominator. 
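+      # Why the two runs should agree (a rough sketch of the reasoning, not MaxText's exact math):
+      # gradient_accumulation_steps=10 with per_device_batch_size=1 accumulates ten microbatch
+      # gradients before the optimizer update, so the effective batch matches the single
+      # per_device_batch_size=10 run; loss and raw grad norm therefore agree up to the epsilon
+      # term noted above, hence rtol=0.01 rather than exact equality.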
+ np.testing.assert_allclose(accum_run_grad_norm, regular_run_grad_norm, rtol=0.01) # Assert per device tflops are the same (10x smaller microbatch size, but 10x more microbatches) - with open(run_accumulate_metrics_file, 'r', encoding='utf8') as accum_run,\ - open(run_regular_metrics_file, 'r', encoding='utf8') as regular_run: - accum_device_tflops = json.loads(accum_run.readlines()[-1])["perf/per_device_tflops"] - regular_device_tflops = json.loads(regular_run.readlines()[-1])["perf/per_device_tflops"] - print(f"[Gradient Accumulation Test] per_device_tflops with gradient accumulation: {accum_device_tflops}", flush=True) - print(f"[Gradient Accumulation Test] per_device_tflops without gradient accumulation: {regular_device_tflops}", flush=True) - np.testing.assert_equal(accum_device_tflops, regular_device_tflops) + with ( + open(run_accumulate_metrics_file, "r", encoding="utf8") as accum_run, + open(run_regular_metrics_file, "r", encoding="utf8") as regular_run, + ): + accum_device_tflops = json.loads(accum_run.readlines()[-1])["perf/per_device_tflops"] + regular_device_tflops = json.loads(regular_run.readlines()[-1])["perf/per_device_tflops"] + print(f"[Gradient Accumulation Test] per_device_tflops with gradient accumulation: {accum_device_tflops}", flush=True) + print( + f"[Gradient Accumulation Test] per_device_tflops without gradient accumulation: {regular_device_tflops}", + flush=True, + ) + np.testing.assert_equal(accum_device_tflops, regular_device_tflops) diff --git a/MaxText/tests/grain_data_processing_test.py b/MaxText/tests/grain_data_processing_test.py index 4e41cce03..f85aaa126 100644 --- a/MaxText/tests/grain_data_processing_test.py +++ b/MaxText/tests/grain_data_processing_test.py @@ -14,7 +14,6 @@ limitations under the License. """ - import subprocess import sys import jax @@ -61,8 +60,7 @@ def setUp(self): self.train_iter = self._get_train_iterator() def _get_train_iterator(self): - train_iter, _ = _grain_data_processing.make_grain_iterator( - self.config, self.mesh, self.process_indices) + train_iter, _ = _grain_data_processing.make_grain_iterator(self.config, self.mesh, self.process_indices) return train_iter def test_train_ds(self): diff --git a/MaxText/tests/hf_data_processing_test.py b/MaxText/tests/hf_data_processing_test.py index f2bb3ee8c..941d10a65 100644 --- a/MaxText/tests/hf_data_processing_test.py +++ b/MaxText/tests/hf_data_processing_test.py @@ -51,8 +51,7 @@ def setUp(self): self.train_iter = self._get_train_iterator() def _get_train_iterator(self): - train_iter, _ = _hf_data_processing.make_hf_iterator( - self.config, self.mesh, self.process_indices) + train_iter, _ = _hf_data_processing.make_hf_iterator(self.config, self.mesh, self.process_indices) return train_iter def test_train_ds(self): @@ -95,5 +94,6 @@ def get_first_batch(iterator): self.assertTrue((train_batch1["inputs"] == train_batch2["inputs"]).all()) self.assertTrue((train_batch1["targets"] == train_batch2["targets"]).all()) + if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/inference_microbenchmark_smoke_test.py b/MaxText/tests/inference_microbenchmark_smoke_test.py index b305562c2..c28de3dcc 100644 --- a/MaxText/tests/inference_microbenchmark_smoke_test.py +++ b/MaxText/tests/inference_microbenchmark_smoke_test.py @@ -25,17 +25,19 @@ class Inference_Microbenchmark(unittest.TestCase): @pytest.mark.tpu def test(self): - pyconfig.initialize([ - None, - "configs/tpu_smoke_test.yml", - "tokenizer_path=../assets/tokenizer.llama2", - "ici_autoregressive_parallelism=-1", - 
"ici_fsdp_parallelism=1", - "max_prefill_predict_length=1024", - "max_target_length=2048", - "scan_layers=false", - "weight_dtype=bfloat16", - ]) + pyconfig.initialize( + [ + None, + "configs/tpu_smoke_test.yml", + "tokenizer_path=../assets/tokenizer.llama2", + "ici_autoregressive_parallelism=-1", + "ici_fsdp_parallelism=1", + "max_prefill_predict_length=1024", + "max_target_length=2048", + "scan_layers=false", + "weight_dtype=bfloat16", + ] + ) inference_microbenchmark_main(pyconfig.config) diff --git a/MaxText/tests/kernels_test.py b/MaxText/tests/kernels_test.py index e1aac2f94..5ec2d1c17 100644 --- a/MaxText/tests/kernels_test.py +++ b/MaxText/tests/kernels_test.py @@ -26,6 +26,7 @@ class RaggedAttentionTest(unittest.TestCase): """Tests for ragged attention kernel.""" + batch_size = 4 num_kv_heads = 8 num_query_heads = 32 @@ -37,7 +38,6 @@ class RaggedAttentionTest(unittest.TestCase): key = jax.random.key(0) k1, k2, k3 = jax.random.split(key, 3) - @pytest.mark.tpu def test_ragged_mqa(self): q = jax.random.normal(self.k1, (self.batch_size, 1, self.head_dim), dtype=self.dtype) @@ -47,36 +47,62 @@ def test_ragged_mqa(self): ragged_out, ragged_max, ragged_denom = ragged_mqa(q, k, v, lengths) reference_out, reference_max, reference_denom = reference_mqa(q, k, v, lengths) - self.assertTrue(jnp.max(abs(ragged_out - reference_out)) < 1e-1, msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1") - self.assertTrue(jnp.average(abs(ragged_out - reference_out)) < 1e-2, msg=f"Avg difference: {jnp.average(abs(ragged_out - reference_out))} > 1e-2") - + self.assertTrue( + jnp.max(abs(ragged_out - reference_out)) < 1e-1, + msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1", + ) + self.assertTrue( + jnp.average(abs(ragged_out - reference_out)) < 1e-2, + msg=f"Avg difference: {jnp.average(abs(ragged_out - reference_out))} > 1e-2", + ) @pytest.mark.tpu def test_ragged_mha(self): q = jax.random.normal(self.k1, (self.batch_size, 1, self.num_query_heads, self.head_dim), dtype=self.dtype) - k = jax.random.normal(self.k2, (self.batch_size, self.max_target_length, self.num_query_heads, self.head_dim), dtype=self.dtype) - v = jax.random.normal(self.k3, (self.batch_size, self.max_target_length, self.num_query_heads, self.head_dim), dtype=self.dtype) + k = jax.random.normal( + self.k2, (self.batch_size, self.max_target_length, self.num_query_heads, self.head_dim), dtype=self.dtype + ) + v = jax.random.normal( + self.k3, (self.batch_size, self.max_target_length, self.num_query_heads, self.head_dim), dtype=self.dtype + ) lengths = jnp.array(np.random.randint(1, self.max_target_length, self.batch_size), dtype=jnp.int32) ragged_out, ragged_max, ragged_denom = ragged_mha(q, k, v, lengths) ragged_out = ragged_out / ragged_denom reference_out, reference_max, reference_denom = reference_mha(q, k, v, lengths) - self.assertTrue(jnp.max(abs(ragged_out - reference_out)) < 1e-1, msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1") - self.assertTrue(jnp.average(abs(ragged_out - reference_out)) < 1e-2, msg=f"Avg difference: {jnp.average(abs(ragged_out - reference_out))} > 1e-2") - + self.assertTrue( + jnp.max(abs(ragged_out - reference_out)) < 1e-1, + msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1", + ) + self.assertTrue( + jnp.average(abs(ragged_out - reference_out)) < 1e-2, + msg=f"Avg difference: {jnp.average(abs(ragged_out - reference_out))} > 1e-2", + ) @pytest.mark.tpu def test_ragged_gqa(self): q = jax.random.normal(self.k1, 
(self.batch_size, 1, self.num_query_heads, self.head_dim), dtype=self.dtype) - k = jax.random.normal(self.k2, (self.batch_size, self.max_target_length, self.num_kv_heads, self.head_dim), dtype=self.dtype) - v = jax.random.normal(self.k3, (self.batch_size, self.max_target_length, self.num_kv_heads, self.head_dim), dtype=self.dtype) + k = jax.random.normal( + self.k2, (self.batch_size, self.max_target_length, self.num_kv_heads, self.head_dim), dtype=self.dtype + ) + v = jax.random.normal( + self.k3, (self.batch_size, self.max_target_length, self.num_kv_heads, self.head_dim), dtype=self.dtype + ) lengths = jnp.array(np.random.randint(1, self.max_target_length, self.batch_size), dtype=jnp.int32) ragged_out, ragged_max, ragged_denom = ragged_gqa(q, k, v, lengths) ragged_out = ragged_out / ragged_denom - reference_out, reference_max, reference_denom = reference_gqa(jnp.squeeze(q), jnp.swapaxes(k, 1, 2), jnp.swapaxes(v, 1, 2), lengths) - self.assertTrue(jnp.max(abs(ragged_out - reference_out)) < 1e-1, msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1") - self.assertTrue(jnp.average(abs(ragged_out - reference_out)) < 1e-2, msg=f"Avg difference: {jnp.average(abs(ragged_out - reference_out))} > 1e-2") + reference_out, reference_max, reference_denom = reference_gqa( + jnp.squeeze(q), jnp.swapaxes(k, 1, 2), jnp.swapaxes(v, 1, 2), lengths + ) + self.assertTrue( + jnp.max(abs(ragged_out - reference_out)) < 1e-1, + msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1", + ) + self.assertTrue( + jnp.average(abs(ragged_out - reference_out)) < 1e-2, + msg=f"Avg difference: {jnp.average(abs(ragged_out - reference_out))} > 1e-2", + ) if __name__ == "__main__": diff --git a/MaxText/tests/llama_test.py b/MaxText/tests/llama_test.py index a958a8a13..03d09a593 100644 --- a/MaxText/tests/llama_test.py +++ b/MaxText/tests/llama_test.py @@ -94,10 +94,12 @@ def test_rope(self): position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :] # Calculate RoPE embeddings from MaxText implementation - query_proj = embeddings.RotaryEmbedding(min_timescale=1, max_timescale = 10_000, embedding_dims=dim_per_head)( + query_proj = embeddings.RotaryEmbedding(min_timescale=1, max_timescale=10_000, embedding_dims=dim_per_head)( permute_to_match_maxtext_rope(x_q), position=position ) - key_proj = embeddings.RotaryEmbedding(min_timescale=1, max_timescale = 10_000, embedding_dims=dim_per_head)(permute_to_match_maxtext_rope(x_k), position=position) + key_proj = embeddings.RotaryEmbedding(min_timescale=1, max_timescale=10_000, embedding_dims=dim_per_head)( + permute_to_match_maxtext_rope(x_k), position=position + ) # Compare results self.assertTrue( @@ -118,12 +120,16 @@ def test_scaling_rope(self): position = jnp.arange(seq_len, dtype=jnp.float32)[jnp.newaxis, :] # Calculate RoPE embeddings and then scale - query_proj_1 = embeddings.RotaryEmbedding(min_timescale=1, max_timescale = 10_000, embedding_dims=dim_per_head)(x_q, position=position) + query_proj_1 = embeddings.RotaryEmbedding(min_timescale=1, max_timescale=10_000, embedding_dims=dim_per_head)( + x_q, position=position + ) query_proj_1 = query_proj_1 * (dim_per_head**-0.5) # scale first and then apply RoPE query_proj_2 = x_q * (dim_per_head**-0.5) - query_proj_2 = embeddings.RotaryEmbedding(min_timescale=1, max_timescale = 10_000, embedding_dims=dim_per_head)(query_proj_2, position=position) + query_proj_2 = embeddings.RotaryEmbedding(min_timescale=1, max_timescale=10_000, embedding_dims=dim_per_head)( + query_proj_2, 
position=position + ) self.assertTrue(jax.numpy.allclose(query_proj_2, query_proj_1, rtol=1e-01, atol=1e-04, equal_nan=False)) diff --git a/MaxText/tests/maxtext_utils_test.py b/MaxText/tests/maxtext_utils_test.py index 5b15a4827..9362a5326 100644 --- a/MaxText/tests/maxtext_utils_test.py +++ b/MaxText/tests/maxtext_utils_test.py @@ -20,58 +20,66 @@ import maxtext_utils + class TestGradientClipping(unittest.TestCase): - def test_grad_clipping_with_no_fp8_stats(self): - raw_grads = {"params": jnp.array([3.0, -4.0]), "wi_0": jnp.array([5.0, -6.0])} - clipped_grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, 1.0) - for param_name in raw_grads.keys(): - # The grads should all be clipped and not equal to what they were before - self.assertFalse(jnp.array_equal(raw_grads[param_name], clipped_grads[param_name])) - def test_fp8_stats_not_clipped_but_others_are(self): - raw_grads = {"params": {"wi_0":jnp.array([5.0, -6.0]), "wi_1":jnp.array([7.0, -8.0])}} - # Create the reference for how the params would be clipped if no fp8 stats were present - expected_clipped_grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, 1.0) + def test_grad_clipping_with_no_fp8_stats(self): + raw_grads = {"params": jnp.array([3.0, -4.0]), "wi_0": jnp.array([5.0, -6.0])} + clipped_grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, 1.0) + for param_name in raw_grads.keys(): + # The grads should all be clipped and not equal to what they were before + self.assertFalse(jnp.array_equal(raw_grads[param_name], clipped_grads[param_name])) + + def test_fp8_stats_not_clipped_but_others_are(self): + raw_grads = {"params": {"wi_0": jnp.array([5.0, -6.0]), "wi_1": jnp.array([7.0, -8.0])}} + # Create the reference for how the params would be clipped if no fp8 stats were present + expected_clipped_grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, 1.0) + + raw_grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = { + "amax_history_wi_0": jnp.array([3.0, -4.0]), + "scale_wi_0": jnp.array([13.2, -4.4]), + } + clipped_grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, 1.0) - raw_grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = {"amax_history_wi_0": jnp.array([3.0, -4.0]), "scale_wi_0": jnp.array([13.2, -4.4])} - clipped_grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, 1.0) + # Check all non-fp8 parameters have been clipped in a manner as if the fp8 stats were not present at all + for param_name in raw_grads["params"].keys(): + self.assertTrue(jnp.array_equal(expected_clipped_grads["params"][param_name], clipped_grads["params"][param_name])) - # Check all non-fp8 parameters have been clipped in a manner as if the fp8 stats were not present at all - for param_name in raw_grads['params'].keys(): - self.assertTrue(jnp.array_equal(expected_clipped_grads['params'][param_name], clipped_grads['params'][param_name])) + # Then check all fp8 parameters were not clipped at all + for param_name, raw_value in raw_grads[maxtext_utils.OVERWRITE_WITH_GRADIENT].items(): + self.assertTrue(jnp.array_equal(raw_value, clipped_grads[maxtext_utils.OVERWRITE_WITH_GRADIENT][param_name])) - # Then check all fp8 parameters were not clipped at all - for param_name, raw_value in raw_grads[maxtext_utils.OVERWRITE_WITH_GRADIENT].items(): - self.assertTrue(jnp.array_equal(raw_value, clipped_grads[maxtext_utils.OVERWRITE_WITH_GRADIENT][param_name])) class TestNestedValueRetrieval(unittest.TestCase): - def setUp(self): - self.test_dict = { - "level1": { - "level2": { - "key": 0.1, - } - }, - "empty_level": {} - 
} - - def test_valid_nested_key(self): - nested_key = ("level1", "level2", "key") - expected_value = 0.1 - result = maxtext_utils.get_nested_value(self.test_dict, nested_key, 0.0) - self.assertEqual(result, expected_value) - - def test_invalid_nested_key(self): - nested_key = ("level1", "nonexistent", "key") - expected_value = 0.0 - result = maxtext_utils.get_nested_value(self.test_dict, nested_key, 0.0) - self.assertEqual(result, expected_value) - - def test_empty_level(self): - nested_key = ("empty_level", "key") - expected_value = None - result = maxtext_utils.get_nested_value(self.test_dict, nested_key) - self.assertEqual(result, expected_value) - -if __name__ == '__main__': - unittest.main() + + def setUp(self): + self.test_dict = { + "level1": { + "level2": { + "key": 0.1, + } + }, + "empty_level": {}, + } + + def test_valid_nested_key(self): + nested_key = ("level1", "level2", "key") + expected_value = 0.1 + result = maxtext_utils.get_nested_value(self.test_dict, nested_key, 0.0) + self.assertEqual(result, expected_value) + + def test_invalid_nested_key(self): + nested_key = ("level1", "nonexistent", "key") + expected_value = 0.0 + result = maxtext_utils.get_nested_value(self.test_dict, nested_key, 0.0) + self.assertEqual(result, expected_value) + + def test_empty_level(self): + nested_key = ("empty_level", "key") + expected_value = None + result = maxtext_utils.get_nested_value(self.test_dict, nested_key) + self.assertEqual(result, expected_value) + + +if __name__ == "__main__": + unittest.main() diff --git a/MaxText/tests/moe_test.py b/MaxText/tests/moe_test.py index 2c4a08610..4e8b522ed 100644 --- a/MaxText/tests/moe_test.py +++ b/MaxText/tests/moe_test.py @@ -29,52 +29,68 @@ class TokenDroppingTest(unittest.TestCase): def setUp(self): super().setUp() pyconfig.initialize( - [None, 'configs/base.yml'], - run_name='moe_test', - enable_checkpointing=False, - model_name='mixtral-8x7b', - dtype='bfloat16', - megablox=False, - max_target_length=4, - per_device_batch_size=1, - capacity_factor=2, + [None, "configs/base.yml"], + run_name="moe_test", + enable_checkpointing=False, + model_name="mixtral-8x7b", + dtype="bfloat16", + megablox=False, + max_target_length=4, + per_device_batch_size=1, + capacity_factor=2, ) self.cfg = pyconfig.config self.rng = jax.random.PRNGKey(42) devices_array = max_utils.create_device_mesh(self.cfg) self.model = linears.MoeBlock( - config=self.cfg, - num_experts=self.cfg.num_experts, - num_experts_per_tok=self.cfg.num_experts_per_tok, - mesh=Mesh(devices_array, self.cfg.mesh_axes), - kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'), - kernel_axes=('embed', 'mlp'), - dtype=self.cfg.dtype, - ) + config=self.cfg, + num_experts=self.cfg.num_experts, + num_experts_per_tok=self.cfg.num_experts_per_tok, + mesh=Mesh(devices_array, self.cfg.mesh_axes), + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", "mlp"), + dtype=self.cfg.dtype, + ) def test_generate_masks(self): # expert_capacity = (tokens_per_batch / num_experts) * capacity_factor # expert_capacity_in_batch = (4 * 2 / 8) * 2 = 2 - top_k_indices = jnp.array([[[0, 5],[0, 4],[1, 0],[3, 5]], - [[1, 2],[4, 1],[5, 0],[7, 1]], - [[6, 2],[2, 3],[4, 2],[1, 2]], - [[4, 1],[0, 7],[5, 0],[4, 7]]]) - softmax_probs = jnp.array([[[0.20, 0, 0, 0, 0, 0.80, 0, 0], - [0.68, 0, 0, 0, 0.32, 0, 0, 0], - [0.22, 0.78, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0.32, 0, 0.68, 0, 0]], - [[0, 0.26, 0.74, 0, 0, 0, 0, 0], - [0, 0.79, 0, 0, 0.21, 0, 0, 0], - [0.89, 0, 0, 0, 0, 
0.11, 0, 0], - [0, 0.11, 0, 0, 0, 0, 0, 0.89]], - [[0, 0, 0.26, 0, 0, 0, 0.74, 0], - [0, 0, 0.88, 0.12, 0, 0, 0, 0], - [0, 0, 0.17, 0, 0.83, 0, 0, 0], - [0, 0.35, 0.65, 0, 0, 0, 0, 0]], - [[0, 0.47, 0, 0, 0.53, 0, 0, 0], - [0.36, 0, 0, 0, 0, 0, 0, 0.64], - [0.15, 0, 0, 0, 0, 0.85, 0, 0], - [0, 0, 0, 0, 0.18, 0, 0, 0.82]]]) + top_k_indices = jnp.array( + [ + [[0, 5], [0, 4], [1, 0], [3, 5]], + [[1, 2], [4, 1], [5, 0], [7, 1]], + [[6, 2], [2, 3], [4, 2], [1, 2]], + [[4, 1], [0, 7], [5, 0], [4, 7]], + ] + ) + softmax_probs = jnp.array( + [ + [ + [0.20, 0, 0, 0, 0, 0.80, 0, 0], + [0.68, 0, 0, 0, 0.32, 0, 0, 0], + [0.22, 0.78, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0.32, 0, 0.68, 0, 0], + ], + [ + [0, 0.26, 0.74, 0, 0, 0, 0, 0], + [0, 0.79, 0, 0, 0.21, 0, 0, 0], + [0.89, 0, 0, 0, 0, 0.11, 0, 0], + [0, 0.11, 0, 0, 0, 0, 0, 0.89], + ], + [ + [0, 0, 0.26, 0, 0, 0, 0.74, 0], + [0, 0, 0.88, 0.12, 0, 0, 0, 0], + [0, 0, 0.17, 0, 0.83, 0, 0, 0], + [0, 0.35, 0.65, 0, 0, 0, 0, 0], + ], + [ + [0, 0.47, 0, 0, 0.53, 0, 0, 0], + [0.36, 0, 0, 0, 0, 0, 0, 0.64], + [0.15, 0, 0, 0, 0, 0.85, 0, 0], + [0, 0, 0, 0, 0.18, 0, 0, 0.82], + ], + ] + ) # As expert_capacity_in_batch=2, so updated softmax_probs become (4 tokens were dropped): # softmax_probs = jnp.array([[[0.20, 0, 0, 0, 0, 0.80, 0, 0], @@ -95,22 +111,35 @@ def test_generate_masks(self): # [0, 0, 0, 0, 0.18, 0, 0, 0.82]]]) # shape of dispatch_mask & combine_mask: (batch_size, seq_len, num_experts, expert_capacity_per_batch) - expected_combine_mask = jnp.array([[[[0.2,0],[0,0],[0,0],[0,0],[0,0],[0.8,0],[0,0],[0,0]], - [[0,0.68],[0,0],[0,0],[0,0],[0.32,0],[0,0],[0,0],[0,0]], - [[0,0],[0.78,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0]], - [[0,0],[0,0],[0,0],[0.32,0],[0,0],[0,0.68],[0,0],[0,0]]], - [[[0,0],[0.26,0],[0.74,0],[0,0],[0,0],[0,0],[0,0],[0,0]], - [[0,0],[0,0.79],[0,0],[0,0],[0.21,0],[0,0],[0,0],[0,0]], - [[0.89,0],[0,0],[0,0],[0,0],[0,0],[0.11,0],[0,0],[0,0]], - [[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0.89,0]]], - [[[0,0],[0,0],[0.26,0],[0,0],[0,0],[0,0],[0.74,0],[0,0]], - [[0,0],[0,0],[0,0.88],[0.12,0],[0,0],[0,0],[0,0],[0,0]], - [[0,0],[0,0],[0,0],[0,0],[0.83,0],[0,0],[0,0],[0,0]], - [[0,0],[0.35,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0]]], - [[[0,0],[0.47,0],[0,0],[0,0],[0.53,0],[0,0],[0,0],[0,0]], - [[0.36,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0,0],[0.64,0]], - [[0,0.15],[0,0],[0,0],[0,0],[0,0],[0.85,0],[0,0],[0,0]], - [[0,0],[0,0],[0,0],[0,0],[0,0.18],[0,0],[0,0],[0,0.82]]]], dtype=jnp.float32) + expected_combine_mask = jnp.array( + [ + [ + [[0.2, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.8, 0], [0, 0], [0, 0]], + [[0, 0.68], [0, 0], [0, 0], [0, 0], [0.32, 0], [0, 0], [0, 0], [0, 0]], + [[0, 0], [0.78, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0], [0.32, 0], [0, 0], [0, 0.68], [0, 0], [0, 0]], + ], + [ + [[0, 0], [0.26, 0], [0.74, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0.79], [0, 0], [0, 0], [0.21, 0], [0, 0], [0, 0], [0, 0]], + [[0.89, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.11, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.89, 0]], + ], + [ + [[0, 0], [0, 0], [0.26, 0], [0, 0], [0, 0], [0, 0], [0.74, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0.88], [0.12, 0], [0, 0], [0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0], [0, 0], [0.83, 0], [0, 0], [0, 0], [0, 0]], + [[0, 0], [0.35, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], + ], + [ + [[0, 0], [0.47, 0], [0, 0], [0, 0], [0.53, 0], [0, 0], [0, 0], [0, 0]], + [[0.36, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], 
[0.64, 0]], + [[0, 0.15], [0, 0], [0, 0], [0, 0], [0, 0], [0.85, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0.18], [0, 0], [0, 0], [0, 0.82]], + ], + ], + dtype=jnp.float32, + ) expected_dispatch_mask = expected_combine_mask.astype(bool) actual_dispatch_mask, actual_combine_mask = self.model.generate_masks(top_k_indices, softmax_probs) diff --git a/MaxText/tests/pipeline_parallelism_test.py b/MaxText/tests/pipeline_parallelism_test.py index 731b98d61..299cfe3fb 100644 --- a/MaxText/tests/pipeline_parallelism_test.py +++ b/MaxText/tests/pipeline_parallelism_test.py @@ -40,7 +40,6 @@ from train import main as train_main - def assert_same_output_and_grad(f1, f2, *inputs): f1_value, f1_grad = jax.value_and_grad(f1)(*inputs) f2_value, f2_grad = jax.value_and_grad(f2)(*inputs) @@ -49,6 +48,7 @@ def pytree_ravel(pytree): ravelled_tree = jax.tree.map(jnp.ravel, pytree) ravelled_leaves, _ = jax.tree_util.tree_flatten(ravelled_tree) return jnp.concatenate(ravelled_leaves) + f1_grad = pytree_ravel(f1_grad) f2_grad = pytree_ravel(f2_grad) @@ -63,78 +63,97 @@ def assert_pipeline_same_output_and_grad(self, config): mesh = Mesh(devices_array, config.mesh_axes) def get_inputs(batch_size, sequence, features): - '''Get random inputs, and random dummy targets - Returns - inputs: [batch_size, sequence, features] - targets: [batch_size, sequence, features] - positions: [batch_size, sequence] - segmentations: [batch_size, segmentation] - ''' - input_shape = [batch_size, sequence, features] - inputs = jax.random.normal(jax.random.PRNGKey(2), input_shape, dtype=jnp.float32) - - # dummy targets same shape as inputs to use for a dummy loss function to check gradient correctness - dummy_targets = jax.random.normal(jax.random.PRNGKey(3),input_shape, dtype=jnp.float32) - - inputs_position = jnp.array([jnp.arange(sequence, dtype=jnp.int32) for _ in range(batch_size)], dtype=jnp.int32) - inputs_segmentation = jnp.ones((batch_size, sequence), dtype=jnp.int32) - return inputs, dummy_targets, inputs_position, inputs_segmentation - - inputs, dummy_targets, inputs_position, inputs_segmentation = get_inputs(config.global_batch_size_to_train_on, config.max_target_length, config.emb_dim) + """Get random inputs, and random dummy targets + Returns + inputs: [batch_size, sequence, features] + targets: [batch_size, sequence, features] + positions: [batch_size, sequence] + segmentations: [batch_size, segmentation] + """ + input_shape = [batch_size, sequence, features] + inputs = jax.random.normal(jax.random.PRNGKey(2), input_shape, dtype=jnp.float32) + + # dummy targets same shape as inputs to use for a dummy loss function to check gradient correctness + dummy_targets = jax.random.normal(jax.random.PRNGKey(3), input_shape, dtype=jnp.float32) + + inputs_position = jnp.array([jnp.arange(sequence, dtype=jnp.int32) for _ in range(batch_size)], dtype=jnp.int32) + inputs_segmentation = jnp.ones((batch_size, sequence), dtype=jnp.int32) + return inputs, dummy_targets, inputs_position, inputs_segmentation + + inputs, dummy_targets, inputs_position, inputs_segmentation = get_inputs( + config.global_batch_size_to_train_on, config.max_target_length, config.emb_dim + ) deterministic = True model_mode = common_types.MODEL_MODE_TRAIN - # We use a simpler single matmul decoder layer for fast compilation in these tests. + # We use a simpler single matmul decoder layer for fast compilation in these tests. 
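+    # Test strategy (sketch): run the same randomly initialized stage weights once through the
+    # Pipeline wrapper and once as a plain sequential stack of layers (regular_sequential_layers
+    # below), then require matching outputs and gradients of a dummy L2 loss against random targets.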
single_pipeline_stage = simple_layer.SimpleDecoderLayer(config=config, mesh=mesh) - my_pipeline = pipeline.Pipeline( - config=config, - layers=single_pipeline_stage, - mesh=mesh + my_pipeline = pipeline.Pipeline(config=config, layers=single_pipeline_stage, mesh=mesh) + init_pipeline_params = my_pipeline.init( + jax.random.PRNGKey(0), inputs, inputs_position, inputs_segmentation, deterministic, model_mode ) - init_pipeline_params = my_pipeline.init(jax.random.PRNGKey(0), inputs, inputs_position, inputs_segmentation, deterministic, model_mode) # Create a dummy scalar loss function so we may take the gradient wrt weights - def pipeline_parallelism_dummy_loss(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, dummy_targets): - outputs = my_pipeline.apply(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode) - loss = jnp.linalg.norm(outputs - dummy_targets) - return loss - - def regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode): - def get_cur_layer_params(params, layer_idx): - def get_cur_layer_params_arr(leaf): - # Reshape layers into a linear list of layers, e.g. [repeat, stage] into [layers] - if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage == 1: - new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] - leaf = jnp.reshape(leaf, new_shape) # [repeat, stage] -> [layers] - elif config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: - new_shape = (leaf.shape[0] * leaf.shape[1] * leaf.shape[2],) + leaf.shape[3:] - leaf = jnp.reshape(leaf, new_shape) # [repeat, stage, layers_per_stage] -> [layers] - elif config.num_pipeline_repeats == 1 and config.num_layers_per_pipeline_stage > 1: - new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] - leaf = jnp.reshape(leaf, new_shape) # [stage, layers_per_stage] -> [layers] - return leaf[layer_idx] - return jax.tree.map(get_cur_layer_params_arr, params) - - reg_layer_activations = inputs - for layer in range(config.num_decoder_layers): - cur_layer_params = get_cur_layer_params(params, layer) - cur_layer_params['params'] = cur_layer_params['params']['layers'] - if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: - cur_layer_params['params'] = meta.remove_axis(cur_layer_params['params'], 0, {nn.PARTITION_NAME:"circular_repeats"}) - cur_layer_params['params'] = meta.remove_axis(cur_layer_params['params'], 0, {nn.PARTITION_NAME:"layers"}) - reg_layer_activations, _ = single_pipeline_stage.apply(cur_layer_params, reg_layer_activations, inputs_position, inputs_segmentation, deterministic, model_mode) - return reg_layer_activations - - def regular_sequential_layers_dummy_loss(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, dummy_targets): - outputs = regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode) - loss = jnp.linalg.norm(outputs - dummy_targets) - return loss - - assert_same_output_and_grad(regular_sequential_layers_dummy_loss, pipeline_parallelism_dummy_loss, init_pipeline_params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode, dummy_targets) + def pipeline_parallelism_dummy_loss( + params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, dummy_targets + ): + outputs = my_pipeline.apply(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode) + loss = jnp.linalg.norm(outputs - dummy_targets) + 
return loss + + def regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode): + def get_cur_layer_params(params, layer_idx): + def get_cur_layer_params_arr(leaf): + # Reshape layers into a linear list of layers, e.g. [repeat, stage] into [layers] + if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage == 1: + new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] + leaf = jnp.reshape(leaf, new_shape) # [repeat, stage] -> [layers] + elif config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: + new_shape = (leaf.shape[0] * leaf.shape[1] * leaf.shape[2],) + leaf.shape[3:] + leaf = jnp.reshape(leaf, new_shape) # [repeat, stage, layers_per_stage] -> [layers] + elif config.num_pipeline_repeats == 1 and config.num_layers_per_pipeline_stage > 1: + new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] + leaf = jnp.reshape(leaf, new_shape) # [stage, layers_per_stage] -> [layers] + return leaf[layer_idx] + + return jax.tree.map(get_cur_layer_params_arr, params) + + reg_layer_activations = inputs + for layer in range(config.num_decoder_layers): + cur_layer_params = get_cur_layer_params(params, layer) + cur_layer_params["params"] = cur_layer_params["params"]["layers"] + if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: + cur_layer_params["params"] = meta.remove_axis( + cur_layer_params["params"], 0, {nn.PARTITION_NAME: "circular_repeats"} + ) + cur_layer_params["params"] = meta.remove_axis(cur_layer_params["params"], 0, {nn.PARTITION_NAME: "layers"}) + reg_layer_activations, _ = single_pipeline_stage.apply( + cur_layer_params, reg_layer_activations, inputs_position, inputs_segmentation, deterministic, model_mode + ) + return reg_layer_activations + + def regular_sequential_layers_dummy_loss( + params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, dummy_targets + ): + outputs = regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode) + loss = jnp.linalg.norm(outputs - dummy_targets) + return loss + + assert_same_output_and_grad( + regular_sequential_layers_dummy_loss, + pipeline_parallelism_dummy_loss, + init_pipeline_params, + inputs, + inputs_segmentation, + inputs_position, + deterministic, + model_mode, + dummy_targets, + ) @pytest.mark.tpu def test_circular_minimum_microbatches_same_output_and_grad(self): - # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches - pyconfig.initialize( + # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches + pyconfig.initialize( [sys.argv[0], "configs/base.yml"], enable_checkpointing=False, run_name="circular_minimum_microbatches", @@ -143,15 +162,15 @@ def test_circular_minimum_microbatches_same_output_and_grad(self): ici_pipeline_parallelism=4, base_num_decoder_layers=8, num_pipeline_microbatches=4, - per_device_batch_size=4 - ) - config = pyconfig.config - self.assert_pipeline_same_output_and_grad(config) + per_device_batch_size=4, + ) + config = pyconfig.config + self.assert_pipeline_same_output_and_grad(config) @pytest.mark.tpu def test_circular_extra_microbatches_same_output_and_grad(self): - # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches - pyconfig.initialize( + # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches + pyconfig.initialize( [sys.argv[0], "configs/base.yml"], enable_checkpointing=False, run_name="circular_extra_microbatches", @@ -160,15 +179,15 @@ def 
test_circular_extra_microbatches_same_output_and_grad(self): ici_pipeline_parallelism=4, base_num_decoder_layers=8, num_pipeline_microbatches=8, - per_device_batch_size=4 - ) - config = pyconfig.config - self.assert_pipeline_same_output_and_grad(config) + per_device_batch_size=4, + ) + config = pyconfig.config + self.assert_pipeline_same_output_and_grad(config) @pytest.mark.tpu def test_non_circular_same_output_and_grad(self): - # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches - pyconfig.initialize( + # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches + pyconfig.initialize( [sys.argv[0], "configs/base.yml"], enable_checkpointing=False, run_name="non_circular", @@ -177,42 +196,44 @@ def test_non_circular_same_output_and_grad(self): ici_pipeline_parallelism=4, base_num_decoder_layers=4, num_pipeline_microbatches=4, - per_device_batch_size=4 - ) - config = pyconfig.config - self.assert_pipeline_same_output_and_grad(config) + per_device_batch_size=4, + ) + config = pyconfig.config + self.assert_pipeline_same_output_and_grad(config) @pytest.mark.tpu def test_full_train_circular(self): # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), 8 microbatches - train_main([ - None, - "configs/base.yml", - r"base_output_directory=gs://runner-maxtext-logs", - "run_name=runner_pipeline_parallelism_test", - r"dataset_path=gs://maxtext-dataset", - "base_emb_dim=28", - "base_num_query_heads=4", - "base_num_kv_heads=4", - "base_mlp_dim=32", - "base_num_decoder_layers=32", - "head_dim=128", - "per_device_batch_size=2", - "max_target_length=1024", - "vocab_size=32", - "dataset_type=synthetic", - "steps=3", - "enable_checkpointing=False", - "ici_pipeline_parallelism=4", - "num_layers_per_pipeline_stage=2", - "num_pipeline_microbatches=8", - "tokenizer_path=../assets/tokenizer.llama2", - ]) + train_main( + [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_pipeline_parallelism_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=28", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=32", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "vocab_size=32", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "ici_pipeline_parallelism=4", + "num_layers_per_pipeline_stage=2", + "num_pipeline_microbatches=8", + "tokenizer_path=../assets/tokenizer.llama2", + ] + ) @pytest.mark.tpu def test_delay_activation_forwarding_same_output_and_grad(self): - # 4 stages, delayed activation forwarding, 8 layers (2 repeats, 1 layer per stage), 8 microbatches - pyconfig.initialize( + # 4 stages, delayed activation forwarding, 8 layers (2 repeats, 1 layer per stage), 8 microbatches + pyconfig.initialize( [sys.argv[0], "configs/base.yml"], enable_checkpointing=False, run_name="activation_forwarding", @@ -222,38 +243,40 @@ def test_delay_activation_forwarding_same_output_and_grad(self): base_num_decoder_layers=8, num_pipeline_microbatches=8, per_device_batch_size=4, - pipeline_delay_activation_forwarding=True - ) - config = pyconfig.config - self.assert_pipeline_same_output_and_grad(config) + pipeline_delay_activation_forwarding=True, + ) + config = pyconfig.config + self.assert_pipeline_same_output_and_grad(config) @pytest.mark.tpu def test_full_train_non_circular(self): # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches - train_main([ - None, - 
"configs/base.yml", - r"base_output_directory=gs://runner-maxtext-logs", - "run_name=runner_pipeline_parallelism_test", - r"dataset_path=gs://maxtext-dataset", - "base_emb_dim=28", - "base_num_query_heads=4", - "base_num_kv_heads=4", - "base_mlp_dim=32", - "base_num_decoder_layers=32", - "head_dim=128", - "per_device_batch_size=2", - "max_target_length=1024", - "vocab_size=32", - "dataset_type=synthetic", - "steps=3", - "enable_checkpointing=False", - "ici_pipeline_parallelism=4", - "num_layers_per_pipeline_stage=8", - "num_pipeline_microbatches=8", - "tokenizer_path=../assets/tokenizer.llama2", - - ]) + train_main( + [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_pipeline_parallelism_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=28", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=32", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "vocab_size=32", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "ici_pipeline_parallelism=4", + "num_layers_per_pipeline_stage=8", + "num_pipeline_microbatches=8", + "tokenizer_path=../assets/tokenizer.llama2", + ] + ) + if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/profiler_test.py b/MaxText/tests/profiler_test.py index c4ef6b59e..04ddeff5c 100644 --- a/MaxText/tests/profiler_test.py +++ b/MaxText/tests/profiler_test.py @@ -24,7 +24,6 @@ class TpuJAXTest(unittest.TestCase): - """Test for profile collected with JAX.""" def _get_session_snapshot(self): diff --git a/MaxText/tests/pyconfig_test.py b/MaxText/tests/pyconfig_test.py index 710dc281d..bae6600dd 100644 --- a/MaxText/tests/pyconfig_test.py +++ b/MaxText/tests/pyconfig_test.py @@ -11,78 +11,55 @@ limitations under the License. 
""" - import unittest import pyconfig + class PyconfigTest(unittest.TestCase): """Tests for pyconfig.py""" def test_basic_override(self): - raw_keys = { - 'megablox': None, - 'foo': ['bar', 'baz'] - } - model_keys = { - 'foo': ['x', 'y'] - } + raw_keys = {"megablox": None, "foo": ["bar", "baz"]} + model_keys = {"foo": ["x", "y"]} - pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name='config') + pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name="config") - self.assertEqual(raw_keys, { - 'megablox': None, - 'foo': ['x', 'y'] - }) + self.assertEqual(raw_keys, {"megablox": None, "foo": ["x", "y"]}) def test_logical_axis_override(self): raw_keys = { - 'megablox': None, - 'foo': ['bar', 'baz'], - 'logical_axis_rules': [ - ['activation', ['data', 'fsdp']], - ['norm', 'tensor'] - ] - } - model_keys = { - 'logical_axis_rules': [ - ['activation', ['data', 'fsdp_transpose']], - ['norm', 'fsdp'] - ] + "megablox": None, + "foo": ["bar", "baz"], + "logical_axis_rules": [["activation", ["data", "fsdp"]], ["norm", "tensor"]], } + model_keys = {"logical_axis_rules": [["activation", ["data", "fsdp_transpose"]], ["norm", "fsdp"]]} - pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name='config') + pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name="config") - self.assertEqual(raw_keys, { - 'megablox': None, - 'foo': ['bar', 'baz'], - 'logical_axis_rules': [ - ('activation', ['data', 'fsdp_transpose']), - ('norm', 'fsdp') - ] - }) + self.assertEqual( + raw_keys, + { + "megablox": None, + "foo": ["bar", "baz"], + "logical_axis_rules": [("activation", ["data", "fsdp_transpose"]), ("norm", "fsdp")], + }, + ) def test_logical_axis_partial_override(self): raw_keys = { - 'megablox': None, - 'foo': ['bar', 'baz'], - 'logical_axis_rules': [ - ['activation', ['data', 'fsdp']], - ['norm', 'tensor'] - ] - } - model_keys = { - 'logical_axis_rules': [ - ['norm', 'fsdp'] - ] + "megablox": None, + "foo": ["bar", "baz"], + "logical_axis_rules": [["activation", ["data", "fsdp"]], ["norm", "tensor"]], } + model_keys = {"logical_axis_rules": [["norm", "fsdp"]]} - pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name='config') + pyconfig.validate_and_update_keys(raw_keys, model_keys, config_name="config") - self.assertEqual(raw_keys, { - 'megablox': None, - 'foo': ['bar', 'baz'], - 'logical_axis_rules': [ - ('activation', ('data', 'fsdp')), - ('norm', 'fsdp') - ] - }) + self.assertEqual( + raw_keys, + { + "megablox": None, + "foo": ["bar", "baz"], + "logical_axis_rules": [("activation", ("data", "fsdp")), ("norm", "fsdp")], + }, + ) diff --git a/MaxText/tests/simple_decoder_layer_test.py b/MaxText/tests/simple_decoder_layer_test.py index 253071de4..afa2d0aeb 100644 --- a/MaxText/tests/simple_decoder_layer_test.py +++ b/MaxText/tests/simple_decoder_layer_test.py @@ -17,33 +17,39 @@ class SimpleDecoderLayerTest(unittest.TestCase): + @pytest.mark.tpu def test_simple_decoder_layer(self): - train_main([ - None, - "configs/base.yml", - r"base_output_directory=gs://runner-maxtext-logs", - "run_name=runner_simple_decoder_layer_test", - r"dataset_path=gs://maxtext-dataset", - "decoder_block=simple", - "enable_checkpointing=False", - "tokenizer_path=../assets/tokenizer.llama2", - "steps=3" - ]) + train_main( + [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_simple_decoder_layer_test", + r"dataset_path=gs://maxtext-dataset", + "decoder_block=simple", + "enable_checkpointing=False", + 
"tokenizer_path=../assets/tokenizer.llama2", + "steps=3", + ] + ) @pytest.mark.tpu def test_mlp_decoder_layer(self): - train_main([ - None, - "configs/base.yml", - r"base_output_directory=gs://runner-maxtext-logs", - "run_name=runner_simple_decoder_layer_test", - r"dataset_path=gs://maxtext-dataset", - "decoder_block=simple_mlp", - "enable_checkpointing=False", - "tokenizer_path=../assets/tokenizer.llama2", - "steps=3" - ]) + train_main( + [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_simple_decoder_layer_test", + r"dataset_path=gs://maxtext-dataset", + "decoder_block=simple_mlp", + "enable_checkpointing=False", + "tokenizer_path=../assets/tokenizer.llama2", + "steps=3", + ] + ) + if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/standalone_dl_ckpt_test.py b/MaxText/tests/standalone_dl_ckpt_test.py index d9befd1e8..1bd774946 100644 --- a/MaxText/tests/standalone_dl_ckpt_test.py +++ b/MaxText/tests/standalone_dl_ckpt_test.py @@ -37,54 +37,60 @@ def _get_random_test_name(self, test_name): @pytest.mark.tpu def test_standalone_dataloader(self): random_run_name = self._get_random_test_name("standalone_dataloader") - sdl_main(( - None, - "configs/base.yml", - "run_name=" + random_run_name, - "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset", - "steps=100", - "enable_checkpointing=false", - "tokenizer_path=../assets/tokenizer.llama2", - )) # need to pass relative path to tokenizer + sdl_main( + ( + None, + "configs/base.yml", + "run_name=" + random_run_name, + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "steps=100", + "enable_checkpointing=false", + "tokenizer_path=../assets/tokenizer.llama2", + ) + ) # need to pass relative path to tokenizer @pytest.mark.tpu def test_standalone_checkpointer(self): random_run_name = self._get_random_test_name("standalone_checkpointer") # checkpoint at 50 - sckpt_main(( - None, - "configs/base.yml", - f"run_name={random_run_name}", - "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset", - "base_emb_dim=128", - "base_num_query_heads=4", - "base_num_kv_heads=4", - "base_mlp_dim=128", - "base_num_decoder_layers=2", - "steps=60", - "enable_checkpointing=True", - "checkpoint_period=50", - "async_checkpointing=False", - )) + sckpt_main( + ( + None, + "configs/base.yml", + f"run_name={random_run_name}", + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=128", + "base_num_decoder_layers=2", + "steps=60", + "enable_checkpointing=True", + "checkpoint_period=50", + "async_checkpointing=False", + ) + ) # restore at 50 and checkpoint at 100 - sckpt_main(( - None, - "configs/base.yml", - f"run_name={random_run_name}", - "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset", - "base_emb_dim=128", - "base_num_query_heads=4", - "base_num_kv_heads=4", - "base_mlp_dim=128", - "base_num_decoder_layers=2", - "steps=110", - "enable_checkpointing=True", - "checkpoint_period=50", - "async_checkpointing=False", - )) + sckpt_main( + ( + None, + "configs/base.yml", + f"run_name={random_run_name}", + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=128", + "base_num_decoder_layers=2", + "steps=110", + 
"enable_checkpointing=True", + "checkpoint_period=50", + "async_checkpointing=False", + ) + ) if __name__ == "__main__": diff --git a/MaxText/tests/tfds_data_processing_test.py b/MaxText/tests/tfds_data_processing_test.py index 74e6a585e..7958cc588 100644 --- a/MaxText/tests/tfds_data_processing_test.py +++ b/MaxText/tests/tfds_data_processing_test.py @@ -65,14 +65,12 @@ def _get_datasets(self): input_pipeline_id=jax.process_index(), num_input_pipelines=jax.process_count(), ) - ds = ds_builder.as_dataset(split="train", read_config=self.read_config, - shuffle_files=self.config.enable_data_shuffling) + ds = ds_builder.as_dataset(split="train", read_config=self.read_config, shuffle_files=self.config.enable_data_shuffling) return ds def _get_train_iterator(self): - train_iter, eval_iter = _tfds_data_processing.make_tfds_iterator( - self.config, self.mesh, self.process_indices) + train_iter, eval_iter = _tfds_data_processing.make_tfds_iterator(self.config, self.mesh, self.process_indices) return train_iter, eval_iter def test_train_ds(self): diff --git a/MaxText/tests/tokenizer_test.py b/MaxText/tests/tokenizer_test.py index 8d2ec7e90..c5222f0de 100644 --- a/MaxText/tests/tokenizer_test.py +++ b/MaxText/tests/tokenizer_test.py @@ -60,16 +60,14 @@ def tearDownClass(cls): @pytest.mark.skip(reason="mohitkhatwani@ will fix this") @pytest.mark.tpu def test_tokenize(self): - text = 'This is a test' - self.assertTrue(np.array_equal(self.source_tokenizer.encode(text).numpy(), - self.test_tokenizer.encode(text).numpy())) + text = "This is a test" + self.assertTrue(np.array_equal(self.source_tokenizer.encode(text).numpy(), self.test_tokenizer.encode(text).numpy())) @pytest.mark.tpu def test_detokenize(self): tokens = [66, 12, 10, 698] - self.assertEqual(np.asarray(self.source_tokenizer.decode(tokens)), - np.asarray(self.test_tokenizer.decode(tokens))) - + self.assertEqual(np.asarray(self.source_tokenizer.decode(tokens)), np.asarray(self.test_tokenizer.decode(tokens))) + class TikTokenTest(unittest.TestCase): """Tests for train_tokenizer.py""" @@ -78,7 +76,9 @@ class TikTokenTest(unittest.TestCase): def setUpClass(cls): dataset_name = "c4/en:3.0.1" dataset_path = "gs://maxtext-dataset" - cls.source_tokenizer = _input_pipeline_utils.get_tokenizer("../assets/tokenizer_llama3.tiktoken", add_bos=False, add_eos=False) + cls.source_tokenizer = _input_pipeline_utils.get_tokenizer( + "../assets/tokenizer_llama3.tiktoken", add_bos=False, add_eos=False + ) os.environ["TFDS_DATA_DIR"] = dataset_path read_config = tfds.ReadConfig( shuffle_seed=0, @@ -88,16 +88,15 @@ def setUpClass(cls): @pytest.mark.tpu def test_tokenize(self): - text = 'This is a test' + text = "This is a test" tokens = [2028, 374, 264, 1296] self.assertTrue(np.array_equal(self.source_tokenizer.encode(text), tokens)) @pytest.mark.tpu def test_detokenize(self): tokens = [2028, 374, 264, 1296] - text = 'This is a test' - self.assertEqual(np.asarray(self.source_tokenizer.decode(tokens)), - np.asarray(text)) + text = "This is a test" + self.assertEqual(np.asarray(self.source_tokenizer.decode(tokens)), np.asarray(text)) if __name__ == "__main__": diff --git a/MaxText/tests/train_compile_test.py b/MaxText/tests/train_compile_test.py index d3db2c3b0..17fcf5bc7 100644 --- a/MaxText/tests/train_compile_test.py +++ b/MaxText/tests/train_compile_test.py @@ -27,158 +27,176 @@ class TrainCompile(unittest.TestCase): @pytest.mark.tpu def test_save_compiled_v4(self): compiled_trainstep_file = "/tmp/test_compiled_v4.pickle" - train_compile_main(( - None, - 
"configs/base.yml", - f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v4-8", - "compile_topology_num_slices=1", - "base_emb_dim=256", - "base_mlp_dim=256", - "base_num_decoder_layers=2", - )) + train_compile_main( + ( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v4-8", + "compile_topology_num_slices=1", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + ) + ) @pytest.mark.tpu def test_save_compiled_v5e(self): compiled_trainstep_file = "/tmp/test_compiled_v5e.pickle" - train_compile_main(( - None, - "configs/base.yml", - f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-16", - "compile_topology_num_slices=1", - "base_emb_dim=256", - "base_mlp_dim=256", - "base_num_decoder_layers=2", - )) + train_compile_main( + ( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-16", + "compile_topology_num_slices=1", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + ) + ) # TODO (b/366200617) : This tests fails in AOT, but config works fine on real hardware @pytest.mark.skip(reason="Issue w/ kernels_test. Error: The TPU is already in use by process...") def test_minimal_offloaded_v5e(self): compiled_trainstep_file = "/tmp/test_compiled_v5e_offload.pickle" - train_compile_main(( - None, - "configs/base.yml", - f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", - "compile_topology_num_slices=1", - "per_device_batch_size=1", - "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", - "max_target_length=2048", - "fused_qkv=true", - "fused_mlp=true", - "remat_policy=minimal_offloaded", - "use_iota_embed=true", - "global_parameter_scale=128", - )) + train_compile_main( + ( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=1", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=minimal_offloaded", + "use_iota_embed=true", + "global_parameter_scale=128", + ) + ) @pytest.mark.tpu def test_save_compiled_v5p_two_slices(self): compiled_trainstep_file = "/tmp/test_compiled_v5p_two_slices.pickle" - train_compile_main(( - None, - "configs/base.yml", - f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5p-8", - "compile_topology_num_slices=2", - "base_emb_dim=256", - "base_mlp_dim=256", - "base_num_decoder_layers=2", - )) + train_compile_main( + ( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-8", + "compile_topology_num_slices=2", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + ) + ) @pytest.mark.tpu def test_sequence_parallelism(self): compiled_trainstep_file = "/tmp/test_compiled.pickle" - train_compile_main(( - None, - "configs/base.yml", - f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", - "use_iota_embed=true", - "compile_topology_num_slices=1", - "ici_sequence_parallelism=16", - "global_parameter_scale=32", - "per_device_batch_size=0.0625", - "max_target_length=65536", - )) + train_compile_main( + ( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "use_iota_embed=true", + "compile_topology_num_slices=1", + 
"ici_sequence_parallelism=16", + "global_parameter_scale=32", + "per_device_batch_size=0.0625", + "max_target_length=65536", + ) + ) @pytest.mark.tpu def test_remat_save_dot_except_mlpwi(self): compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlpwi.pickle" - train_compile_main(( - None, - "configs/base.yml", - f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", - "compile_topology_num_slices=1", - "per_device_batch_size=0.125", - "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", - "max_target_length=2048", - "fused_qkv=true", - "fused_mlp=true", - "remat_policy=save_dot_except_mlpwi", - "use_iota_embed=true", - "global_parameter_scale=128", - )) + train_compile_main( + ( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.125", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=save_dot_except_mlpwi", + "use_iota_embed=true", + "global_parameter_scale=128", + ) + ) @pytest.mark.tpu def test_remat_save_dot_except_mlp(self): compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlp.pickle" - train_compile_main(( - None, - "configs/base.yml", - f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", - "compile_topology_num_slices=1", - "per_device_batch_size=0.25", - "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", - "max_target_length=2048", - "fused_qkv=true", - "fused_mlp=true", - "remat_policy=save_dot_except_mlp", - "use_iota_embed=true", - "global_parameter_scale=128", - )) + train_compile_main( + ( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.25", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=save_dot_except_mlp", + "use_iota_embed=true", + "global_parameter_scale=128", + ) + ) @pytest.mark.tpu def test_remat_save_qkv_proj(self): compiled_trainstep_file = "/tmp/test_remat_save_qkv_proj.pickle" - train_compile_main(( - None, - "configs/base.yml", - f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", - "compile_topology_num_slices=1", - "per_device_batch_size=0.375", - "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", - "max_target_length=2048", - "fused_qkv=true", - "fused_mlp=true", - "remat_policy=save_qkv_proj", - "use_iota_embed=true", - "global_parameter_scale=128", - )) + train_compile_main( + ( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.375", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=save_qkv_proj", + "use_iota_embed=true", + "global_parameter_scale=128", + ) + ) @pytest.mark.tpu def test_remat_full(self): compiled_trainstep_file = "/tmp/test_remat_full.pickle" - train_compile_main(( - None, - "configs/base.yml", - f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", - "compile_topology_num_slices=1", - "per_device_batch_size=1", - "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", - "max_target_length=2048", - "fused_qkv=true", - 
"fused_mlp=true", - "remat_policy=full", - "use_iota_embed=true", - "global_parameter_scale=128", - )) + train_compile_main( + ( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=1", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=full", + "use_iota_embed=true", + "global_parameter_scale=128", + ) + ) diff --git a/MaxText/tests/train_gpu_smoke_test.py b/MaxText/tests/train_gpu_smoke_test.py index 5af635c4a..dde3c00dc 100644 --- a/MaxText/tests/train_gpu_smoke_test.py +++ b/MaxText/tests/train_gpu_smoke_test.py @@ -25,15 +25,17 @@ class Train(unittest.TestCase): def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - train_main([ - None, - "third_party/py/maxtext/configs/gpu_smoke_test.yml", - f"base_output_directory=gs://runner-maxtext-logs", - "run_name=runner_test", - r"dataset_path=gs://maxtext-dataset", - "enable_checkpointing=False", - "tokenizer_path=../assets/tokenizer.llama2", - ]) + train_main( + [ + None, + "third_party/py/maxtext/configs/gpu_smoke_test.yml", + f"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "enable_checkpointing=False", + "tokenizer_path=../assets/tokenizer.llama2", + ] + ) if __name__ == "__main__": diff --git a/MaxText/tests/train_int8_smoke_test.py b/MaxText/tests/train_int8_smoke_test.py index 387e9ecd9..854e3a5ec 100644 --- a/MaxText/tests/train_int8_smoke_test.py +++ b/MaxText/tests/train_int8_smoke_test.py @@ -26,26 +26,28 @@ class Train(unittest.TestCase): def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - train_main([ - None, - "third_party/py/maxtext/configs/base.yml", - f"base_output_directory=gs://runner-maxtext-logs", - "run_name=runner_test", - r"dataset_path=gs://maxtext-dataset", - "base_emb_dim=8", - "base_num_query_heads=4", - "base_num_kv_heads=4", - "base_mlp_dim=32", - "base_num_decoder_layers=8", - "head_dim=128", - "per_device_batch_size=2", - "max_target_length=1024", - "dataset_type=synthetic", - "steps=10", - "enable_checkpointing=False", - "quantization=int8", - "tokenizer_path=../assets/tokenizer.llama2", - ]) + train_main( + [ + None, + "third_party/py/maxtext/configs/base.yml", + f"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=8", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=8", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "dataset_type=synthetic", + "steps=10", + "enable_checkpointing=False", + "quantization=int8", + "tokenizer_path=../assets/tokenizer.llama2", + ] + ) if __name__ == "__main__": diff --git a/MaxText/tests/train_smoke_test.py b/MaxText/tests/train_smoke_test.py index f2d8ef10d..474025564 100644 --- a/MaxText/tests/train_smoke_test.py +++ b/MaxText/tests/train_smoke_test.py @@ -26,25 +26,27 @@ class Train(unittest.TestCase): def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - train_main([ - None, - "third_party/py/maxtext/configs/base.yml", - f"base_output_directory=gs://runner-maxtext-logs", - "run_name=runner_test", - r"dataset_path=gs://maxtext-dataset", - "base_emb_dim=8", - "base_num_query_heads=4", - "base_num_kv_heads=4", - "base_mlp_dim=32", - "base_num_decoder_layers=8", - "head_dim=128", 
- "per_device_batch_size=2", - "max_target_length=1024", - "dataset_type=synthetic", - "steps=10", - "enable_checkpointing=False", - "tokenizer_path=../assets/tokenizer.llama2", - ]) + train_main( + [ + None, + "third_party/py/maxtext/configs/base.yml", + f"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=8", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=8", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "dataset_type=synthetic", + "steps=10", + "enable_checkpointing=False", + "tokenizer_path=../assets/tokenizer.llama2", + ] + ) if __name__ == "__main__": diff --git a/MaxText/tokenizer.py b/MaxText/tokenizer.py index e6f389569..98f3b186a 100644 --- a/MaxText/tokenizer.py +++ b/MaxText/tokenizer.py @@ -37,7 +37,9 @@ class TikTokenTokenizer: num_reserved_special_tokens = 256 - pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # pylint: disable=line-too-long + pat_str = ( + r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # pylint: disable=line-too-long + ) def __init__(self, model_path: str, add_bos: bool, add_eos: bool): """ @@ -50,28 +52,23 @@ def __init__(self, model_path: str, add_bos: bool, add_eos: bool): mergeable_ranks = load_tiktoken_bpe(model_path) num_base_tokens = len(mergeable_ranks) special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|reserved_special_token_2|>", - "<|reserved_special_token_3|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|reserved_special_token_4|>", - "<|eot_id|>", # end of turn - ] + [ - f"<|reserved_special_token_{i}|>" - for i in range(5, self.num_reserved_special_tokens - 5) - ] - self.special_tokens = { - token: num_base_tokens + i for i, token in enumerate(special_tokens) - } + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [f"<|reserved_special_token_{i}|>" for i in range(5, self.num_reserved_special_tokens - 5)] + self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)} self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, ) self.eos = add_eos self.bos = add_bos @@ -86,9 +83,7 @@ def __init__(self, model_path: str, add_bos: bool, add_eos: bool): self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"], } - max_logging.log( - f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" - ) + max_logging.log(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") def encode( self, @@ -130,20 +125,20 @@ def encode( MAX_NO_WHITESPACES_CHARS = 25_000 substrs = ( - substr - for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) - for substr in self._split_whitespaces_or_nonwhitespaces( - s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS - ) + substr + for i in range(0, len(s), 
TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) ) t: List[int] = [] for substr in substrs: t.extend( - self.model.encode( - substr, - allowed_special=set(allowed_special), - disallowed_special=disallowed_special, - ) + self.model.encode( + substr, + allowed_special=set(allowed_special), + disallowed_special=disallowed_special, + ) ) if self.bos: t.insert(0, self.bos_id) @@ -165,9 +160,7 @@ def decode(self, t) -> str: return self.model.decode(t) @staticmethod - def _split_whitespaces_or_nonwhitespaces( - s: str, max_consecutive_slice_len: int - ): + def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int): """ Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` consecutive whitespaces or consecutive non-whitespaces. @@ -195,6 +188,7 @@ class SentencePieceTokenizer: """ Tokenizing and encoding/decoding text using the Sentencepiece tokenizer. """ + def __init__(self, model_path: str, add_bos: bool, add_eos: bool): max_logging.log(f"Tokenizer path: {model_path}") with tf.io.gfile.GFile(model_path, "rb") as model_fp: @@ -207,6 +201,7 @@ def encode(self, s: str) -> List[int]: def decode(self, t: Sequence[int]) -> str: return self.sp_tokenizer.detokenize(t) + def build_tokenizer(tokenizer_path, add_bos, add_eos): """Loads the tokenizer at `tokenizer_path`""" max_logging.log(f"Tokenizer path: {tokenizer_path}") @@ -218,12 +213,14 @@ def build_tokenizer(tokenizer_path, add_bos, add_eos): def TokenizeOp(tokenizer, features: Features, data_keys: Iterable[str] = ("inputs", "targets")) -> Features: """Op for tokenization""" + def _process_string(string_tensor): # Extract string value and decode it if necessary - string_value = string_tensor.numpy().decode('utf-8') + string_value = string_tensor.numpy().decode("utf-8") # encode and extract the tokenized integers modified_string = tokenizer.encode(string_value) return [modified_string] + for k in data_keys: if isinstance(tokenizer, TikTokenTokenizer): features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0] diff --git a/MaxText/train.py b/MaxText/train.py index 5e5885467..d1a1c53fb 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -66,7 +66,7 @@ Transformer = models.Transformer EPS = 1e-8 -_CHUNK_BYTE_SIZE = 2 * 1024 **3 +_CHUNK_BYTE_SIZE = 2 * 1024**3 def validate_train_config(config): @@ -78,9 +78,11 @@ def validate_train_config(config): if not config.base_output_directory.startswith("gs://"): max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system") assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer." - if config.quantization=='fp8': + if config.quantization == "fp8": # pylint: disable=line-too-long - assert config.gradient_accumulation_steps == 1, "fp8 can't be used with gradient_accumulation_steps right now. Please use other quantization or set gradient_accumulation_steps to 1" + assert ( + config.gradient_accumulation_steps == 1 + ), "fp8 can't be used with gradient_accumulation_steps right now. 
Please use other quantization or set gradient_accumulation_steps to 1" def get_first_step(state): @@ -139,13 +141,15 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step max_utils.write_metrics_locally(metrics_to_write, steps_to_write, config, local_metrics_file, is_training) if config.gcs_metrics and jax.process_index() == 0: - running_gcs_metrics = max_utils.write_metrics_for_gcs(metrics_to_write, steps_to_write, config, - running_gcs_metrics, is_training) + running_gcs_metrics = max_utils.write_metrics_for_gcs( + metrics_to_write, steps_to_write, config, running_gcs_metrics, is_training + ) if is_training: _buffered_step = step _buffered_metrics = metrics + def write_metrics_to_tensorboard(writer, metrics, step, config, is_training=True): """Writes metrics to tensorboard""" with jax.spmd_mode("allow_all"): @@ -170,12 +174,14 @@ def write_metrics_to_tensorboard(writer, metrics, step, config, is_training=True max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'") writer.flush() + def clear_buffered_metrics(): global _buffered_step global _buffered_metrics _buffered_step = None _buffered_metrics = None + def save_checkpoint( checkpoint_manager, step, @@ -187,8 +193,7 @@ def save_checkpoint( """Wrapper for saving checkpoint.""" if config and config.enable_checkpointing: if (step % config.checkpoint_period == 0) or ( - config.enable_emergency_checkpoint - and step % config.local_checkpoint_period == 0 + config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0 ): blocking_until_ready_start = time.time() max_logging.log(f"Waiting for step {step} to finish before checkpoint...") @@ -201,16 +206,13 @@ def save_checkpoint( ) # specify chunk_byte_size to force orbax to control maximum file size in checkpoint - save_args = jax.tree.map( - lambda _: orbax.checkpoint.SaveArgs(chunk_byte_size=_CHUNK_BYTE_SIZE), state - ) + save_args = jax.tree.map(lambda _: orbax.checkpoint.SaveArgs(chunk_byte_size=_CHUNK_BYTE_SIZE), state) if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager): return checkpoint_manager.save( - step, args=orbax.checkpoint.args.PyTreeSave( - item=state, save_args=save_args, ocdbt_target_data_file_size=_CHUNK_BYTE_SIZE - ) - ) + step, + args=orbax.checkpoint.args.PyTreeSave(item=state, save_args=save_args, ocdbt_target_data_file_size=_CHUNK_BYTE_SIZE), + ) if dataset_type == "grain": return checkpoint_manager.save( @@ -224,11 +226,12 @@ def save_checkpoint( ) else: return checkpoint_manager.save( - step, args=orbax.checkpoint.args.Composite( + step, + args=orbax.checkpoint.args.Composite( items=orbax.checkpoint.args.PyTreeSave( item=state, save_args=save_args, ocdbt_target_data_file_size=_CHUNK_BYTE_SIZE ) - ) + ), ) @@ -329,36 +332,36 @@ def train_step(model, config, state, data, dropout_rng): """ if config.gradient_accumulation_steps > 1: + def accumulate_gradient(acc_grad_and_loss, data): grad_func = jax.value_and_grad(loss_fn, argnums=4, has_aux=True) (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, state.params, is_train=True) - acc_grad_and_loss['loss'] += aux['total_loss'] - acc_grad_and_loss['moe_lb_loss'] += aux['moe_lb_loss'] - acc_grad_and_loss['grad'] = jax.tree_util.tree_map( - lambda x, y: x * aux['total_weights'] + y, - cur_batch_gradient, - acc_grad_and_loss['grad']) - acc_grad_and_loss['total_weights'] += aux['total_weights'] + acc_grad_and_loss["loss"] += aux["total_loss"] + acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] 
+ acc_grad_and_loss["grad"] = jax.tree_util.tree_map( + lambda x, y: x * aux["total_weights"] + y, cur_batch_gradient, acc_grad_and_loss["grad"] + ) + acc_grad_and_loss["total_weights"] += aux["total_weights"] return acc_grad_and_loss, aux def reshape_to_microbatch_accumulations(batch_arr): - ''' Reshape global batch to microbatches, assuming batch axis is leading.''' + """Reshape global batch to microbatches, assuming batch axis is leading.""" microbatches = config.gradient_accumulation_steps - microbatch_shape = (microbatches, batch_arr.shape[0] // microbatches) + batch_arr.shape[1:] + microbatch_shape = (microbatches, batch_arr.shape[0] // microbatches) + batch_arr.shape[1:] return jnp.reshape(batch_arr, microbatch_shape) data = jax.tree_util.tree_map(reshape_to_microbatch_accumulations, data) init_grad = jax.tree_util.tree_map(jnp.zeros_like, state.params) - init_grad_and_loss = {'loss': 0.0, 'grad': init_grad, 'total_weights':0, 'moe_lb_loss':0.0} + init_grad_and_loss = {"loss": 0.0, "grad": init_grad, "total_weights": 0, "moe_lb_loss": 0.0} grad_and_loss, aux = jax.lax.scan( - accumulate_gradient, - init_grad_and_loss, - data, - length = config.gradient_accumulation_steps) - loss = (grad_and_loss['loss'] / grad_and_loss['total_weights'] - + grad_and_loss['moe_lb_loss'] / config.gradient_accumulation_steps) - raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss['total_weights'], grad_and_loss['grad']) + accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps + ) + loss = ( + grad_and_loss["loss"] / grad_and_loss["total_weights"] + + grad_and_loss["moe_lb_loss"] / config.gradient_accumulation_steps + ) + raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], grad_and_loss["grad"]) aux = jax.tree_map(lambda x: jnp.sum(x, axis=0), aux) else: grad_func = jax.value_and_grad(loss_fn, argnums=4, has_aux=True) @@ -398,11 +401,12 @@ def eval_step(model, config, state, data, dropout_rng): total_weights = aux["total_weights"] moe_lb_loss = aux["moe_lb_loss"] metrics = { - "scalar": {"evaluation/loss": loss, - "evaluation/total_loss": total_loss, - "evaluation/total_weights": total_weights, - "evaluation/moe_lb_loss": moe_lb_loss}, - + "scalar": { + "evaluation/loss": loss, + "evaluation/total_loss": total_loss, + "evaluation/total_weights": total_weights, + "evaluation/moe_lb_loss": moe_lb_loss, + }, } return metrics @@ -417,24 +421,24 @@ def create_goodput_recorder(config): def record_goodput( - recorder, - config, - record_func, - *args, - ): + recorder, + config, + record_func, + *args, +): """Record data for Goodput and Badput computation.""" if recorder and config.enable_goodput_recording: record_func(*args) + def check_example_batch(config, example_batch): if config.max_checkify: - jittable_f = checkify.checkify( - lambda x: checkify.check(jnp.any(x > -1), "Batch contains bad synthetic data!") - ) + jittable_f = checkify.checkify(lambda x: checkify.check(jnp.any(x > -1), "Batch contains bad synthetic data!")) # Check if inputs in batch contains bad synthetic data. 
- err, _ = jax.jit(jittable_f)(example_batch['inputs'][: config.global_batch_size_to_train_on, :]) + err, _ = jax.jit(jittable_f)(example_batch["inputs"][: config.global_batch_size_to_train_on, :]) err.throw() + def setup_mesh_and_model(config): """Set up the mesh and the model for training @@ -466,19 +470,15 @@ def setup_mesh_and_model(config): tx = optimizers.get_optimizer(config, learning_rate_schedule) logger = checkpointing.setup_checkpoint_logger(config) if config.enable_emergency_checkpoint: - abstract_state, _, _ = max_utils.get_abstract_state( - model, tx, config, init_rng, mesh, is_training=True - ) - checkpoint_manager = ( - checkpointing.create_orbax_emergency_checkpoint_manager( - config.local_checkpoint_directory, - config.checkpoint_dir, - mesh, - abstract_state, - config.local_checkpoint_period, - config.checkpoint_period, - logger, - ) + abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) + checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager( + config.local_checkpoint_directory, + config.checkpoint_dir, + mesh, + abstract_state, + config.local_checkpoint_period, + config.checkpoint_period, + logger, ) else: checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( @@ -649,7 +649,9 @@ def train_loop(config, state=None): state, metrics = p_train_step(state, example_batch, nextrng) new_time = datetime.datetime.now() - record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step), per_device_tokens) + record_scalar_metrics( + metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step), per_device_tokens + ) last_step_completion = new_time if checkpoint_manager is not None: @@ -666,12 +668,12 @@ def train_loop(config, state=None): if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: assert eval_data_iterator cumulative_eval_metrics = { - "scalar": { - "eval/total_loss": 0.0, - "eval/total_weights": 0.0, - "eval/avg_loss": 0.0, - "eval/moe_lb_loss": 0.0, - } + "scalar": { + "eval/total_loss": 0.0, + "eval/total_weights": 0.0, + "eval/avg_loss": 0.0, + "eval/moe_lb_loss": 0.0, + } } eval_step_count = 0 for eval_batch in eval_data_iterator: @@ -684,10 +686,18 @@ def train_loop(config, state=None): cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(eval_metrics["scalar"]["evaluation/moe_lb_loss"]) max_logging.log(f"Completed eval step {eval_step_count}") eval_step_count += 1 - eval_loss = cumulative_eval_metrics["scalar"]["eval/total_loss"] / (cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS) + cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count + eval_loss = ( + cumulative_eval_metrics["scalar"]["eval/total_loss"] + / (cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS) + + cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count + ) cumulative_eval_metrics["scalar"]["eval/avg_loss"] = eval_loss - write_metrics(writer, local_metrics_file, running_gcs_metrics, cumulative_eval_metrics, step, config, is_training=False) - max_logging.log(f"average loss after {step=}: {eval_step_count=}, {eval_loss=}, total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}") + write_metrics( + writer, local_metrics_file, running_gcs_metrics, cumulative_eval_metrics, step, config, is_training=False + ) + max_logging.log( + f"average loss after {step=}: {eval_step_count=}, {eval_loss=}, 
total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}" + ) if eval_loss <= config.target_eval_loss: max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}") prof.deactivate() @@ -720,14 +730,14 @@ def main(argv: Sequence[str]) -> None: vertex_tensorboard_manager.configure_vertex_tensorboard(config) if config.monitor_goodput and jax.process_index() == 0: - logger_name = f'goodput_{config.run_name}' + logger_name = f"goodput_{config.run_name}" goodput_monitor = monitoring.GoodputMonitor( - job_name=config.run_name, - logger_name=logger_name, - tensorboard_dir=config.tensorboard_dir, - upload_interval=config.goodput_upload_interval_seconds, - monitoring_enabled=True, - include_badput_breakdown=True, + job_name=config.run_name, + logger_name=logger_name, + tensorboard_dir=config.tensorboard_dir, + upload_interval=config.goodput_upload_interval_seconds, + monitoring_enabled=True, + include_badput_breakdown=True, ) goodput_monitor.start_goodput_uploader() max_logging.log("Started Goodput upload to Tensorboard in the background!") diff --git a/MaxText/train_compile.py b/MaxText/train_compile.py index dbeda4a0a..5c6ab82d0 100644 --- a/MaxText/train_compile.py +++ b/MaxText/train_compile.py @@ -56,11 +56,11 @@ def validate_config(config): def get_topology_mesh(config): """Get the target hardware devices, and create configured mesh with them""" target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology) - if target_hardware.platform == 'gpu': + if target_hardware.platform == "gpu": # Disable sharded autotuning. This is an optimization to distribute # autotuning across the fleet, but can cause hangs with AoT compilation. - os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_gpu_shard_autotuning=false' - jax.config.update('mock_num_gpu_processes', config.compile_topology_num_slices) + os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false" + jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices) topology_devices = jax.devices() else: topology_devices = get_topology_desc( diff --git a/MaxText/train_tokenizer.py b/MaxText/train_tokenizer.py index 0a440a388..cb0eab5cf 100644 --- a/MaxText/train_tokenizer.py +++ b/MaxText/train_tokenizer.py @@ -89,13 +89,15 @@ def _train_sentencepiece( fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys) with tempfile.NamedTemporaryFile(delete=False, prefix="/tmp/sp_tmp") as model_fp: pass # we just want a prefix'd tmp-filename - argstr = " ".join([ - f"--input={fname}", - f"--vocab_size={vocab_size}", - f"--character_coverage={character_coverage}", - f"--model_prefix={model_fp.name}", - f"--model_type={model_type}", - ]) + argstr = " ".join( + [ + f"--input={fname}", + f"--vocab_size={vocab_size}", + f"--character_coverage={character_coverage}", + f"--model_prefix={model_fp.name}", + f"--model_type={model_type}", + ] + ) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: # Use an intermediate filename that is renamed to the target name to address