diff --git a/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/benchmark_definitions.py b/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/benchmark_definitions.py
index 89b3c439..de49ccb9 100644
--- a/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/benchmark_definitions.py
+++ b/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/benchmark_definitions.py
@@ -90,6 +90,14 @@
     batch_sizes=[1, 64, 128],
 )
 
+NVIDIA1_3B_2G_TRAIN_BF16_JAX_CASES = utils.build_batch_benchmark_cases(
+    batch_models=model_definitions.NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCHES,
+    verify_parameters={
+        "absolute_tolerance": 0.5,
+    },
+    batch_sizes=[8],
+)
+
 ALL_BENCHMARKS = list(
     itertools.chain(
         T5_LARGE_FP32_JAX_512XI32_CASES.values(),
@@ -103,4 +111,5 @@
         RESNET50_FP16_JAX_3X224X224XF16_CASES.values(),
         RESNET50_BF16_JAX_3X224X224XBF16_CASES.values(),
         GPT2LMHEAD_FP32_JAX_512XI32_CASES.values(),
+        NVIDIA1_3B_2G_TRAIN_BF16_JAX_CASES.values(),
     ))
diff --git a/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/model_definitions.py b/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/model_definitions.py
index 09fb8d76..d9c9ef69 100644
--- a/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/model_definitions.py
+++ b/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/model_definitions.py
@@ -295,6 +295,35 @@
     template=GPT2LMHEAD_FP32_JAX_512XI32_BATCH_TEMPLATE,
     batch_sizes=[1, 64, 128])
 
+# DO_NOT_SUBMIT
+NVIDIA1_3B_2G_GCS_DIR = "https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.13_1690046172/"
+NVIDIA1_3B_2G_ARTIFACTS_DIR_URL_TEMPLATE = string.Template(NVIDIA1_3B_2G_GCS_DIR +
+                                                           "${name}")
+NVIDIA1_3B_2G_TRAIN_BF16_JAX_IMPL = def_types.ModelImplementation(
+    name="MODEL_NVIDIA1_3B_2G_TRAIN_BF16_JAX",
+    tags=["bf16", "transformer-decoder", "nvidia"],
+    framework_type=def_types.ModelFrameworkType.JAX,
+    module_path=f"{utils.MODELS_MODULE_PATH}.jax.pax.nvidia1_3b_2g_train_model",
+    source_info="PAX LLM",
+)
+NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCH_TEMPLATE = utils.ModelTemplate(
+    name=utils.BATCH_NAME("NVIDIA1_3B_2G_TRAIN_BF16_JAX"),
+    tags=[utils.BATCH_TAG],
+    model_impl=NVIDIA1_3B_2G_TRAIN_BF16_JAX_IMPL,
+    model_parameters={
+        "batch_size": utils.BATCH_SIZE_PARAM,
+        "data_type": "bf16",
+    },
+    artifacts_dir_url=NVIDIA1_3B_2G_ARTIFACTS_DIR_URL_TEMPLATE,
+    exported_model_types=[
+        def_types.ModelArtifactType.STABLEHLO_MLIR,
+        def_types.ModelArtifactType.XLA_HLO_DUMP,
+    ],
+)
+NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCHES = utils.build_batch_models(
+    template=NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCH_TEMPLATE,
+    batch_sizes=[8])
+
 ALL_MODELS = list(
     itertools.chain(
         T5_LARGE_FP32_JAX_512XI32_BATCHES.values(),
@@ -308,4 +337,5 @@
         RESNET50_FP16_JAX_3X224X224XF16_BATCHES.values(),
         RESNET50_BF16_JAX_3X224X224XBF16_BATCHES.values(),
         GPT2LMHEAD_FP32_JAX_512XI32_BATCHES.values(),
+        NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCHES.values(),
     ))
diff --git a/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/scripts/generate_model_artifacts.py b/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/scripts/generate_model_artifacts.py
index 086d43d8..8bab78cd 100644
--- a/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/scripts/generate_model_artifacts.py
+++ b/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/scripts/generate_model_artifacts.py
@@ -4,6 +4,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from absl import flags
 import argparse
 import jax
 import os
@@ -120,7 +121,7 @@ def main(output_dir: pathlib.Path, filter: str, iree_opt_path: pathlib.Path):
   output_dir.mkdir(parents=True, exist_ok=True)
 
   for model in models:
-    # We need to generate artifacts in a separate proces each time in order for
+    # We need to generate artifacts in a separate process each time in order for
     # XLA to update the HLO dump directory.
     p = multiprocessing.Process(target=_generate_artifacts,
                                 args=(model, output_dir, iree_opt_path))
@@ -129,4 +130,6 @@
 
 
 if __name__ == "__main__":
+  # PAX requires absl's flags to be initialized.
+  flags.FLAGS(sys.argv[:1])
   main(**vars(_parse_arguments()))
diff --git a/common_benchmark_suite/openxla/benchmark/models/jax/pax/__init__.py b/common_benchmark_suite/openxla/benchmark/models/jax/pax/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/common_benchmark_suite/openxla/benchmark/models/jax/pax/nvidia1_3b_2g_train_model.py b/common_benchmark_suite/openxla/benchmark/models/jax/pax/nvidia1_3b_2g_train_model.py
new file mode 100644
index 00000000..4e824674
--- /dev/null
+++ b/common_benchmark_suite/openxla/benchmark/models/jax/pax/nvidia1_3b_2g_train_model.py
@@ -0,0 +1,117 @@
+# Copyright 2023 The OpenXLA Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from etils import epath
+import tempfile
+from typing import Any, Tuple
+
+import jax
+import jax.numpy as jnp
+from paxml import partitioning
+from paxml import programs
+from paxml import trainer_lib
+from paxml.tasks.lm.params import nvidia
+from praxis import base_layer
+from praxis import py_utils
+
+from openxla.benchmark.models import model_interfaces
+
+instantiate = base_layer.instantiate
+NestedMap = py_utils.NestedMap
+
+
+class NVIDIA1_3B2g(nvidia.NVIDIA1_3B):
+  ICI_MESH_SHAPE = [2, 1, 1]
+
+
+class NVIDIA1_3B2gTrain(model_interfaces.InferenceModel):
+  batch_size: int
+
+  def __init__(self, batch_size: int):
+    self.batch_size = batch_size
+
+    self.experiment_config = NVIDIA1_3B2g()
+    self.task = instantiate(self.experiment_config.task())
+    self.partitioner = partitioning.create_partitioner(self.task)
+    prng_key = jax.random.PRNGKey(123)
+
+    train_input_p = self.experiment_config.datasets()[0]
+    train_input_p = self.partitioner.preprocess_input_config(train_input_p)
+    self.train_input = instantiate(train_input_p)
+
+    with tempfile.TemporaryDirectory() as d:
+      job_log_dir = epath.Path(d)
+      prng_key, setup_key = jax.random.split(prng_key)
+      self.partitioner.setup(
+          self.task,
+          setup_key,
+          train_inputs_shape_dtype=None,
+          train_input_pipeline=self.train_input,
+          job_log_dir=job_log_dir,
+      )
+
+      # Initialize the partitioned train state.
+      prng_key, state_key = jax.random.split(prng_key)
+      _, self.train_state, _ = self.partitioner.initialize_prng_key_and_train_state(
+          state_key,
+          train_state=None,
+          checkpoint_type=None,
+      )
+
+      prng_key, train_prng_seed, eval_prng_seed = jax.random.split(prng_key, 3)
+      self.train_program = programs.SingleTaskTrainProgram()
+      self.train_program.setup(
+          self.task,
+          self.train_input,
+          self.partitioner,
+          job_log_dir,
+          train_prng_seed,
+          eval_prng_seed,
+          init_step=0,
+      )
+      self.partitioned_prng_key = self.partitioner.preprocess_prng_key(prng_key)
+
+  def generate_default_inputs(self) -> NestedMap:
+    train_input_p = self.experiment_config.datasets()[0]
+    train_input_p = self.partitioner.preprocess_input_config(train_input_p)
+    train_input_p.input.batch_size = self.batch_size
+    train_input = instantiate(train_input_p)
+    train_batch = train_input.get_next()
+    train_batch = self.partitioner.preprocess_inputs(
+        train_input,
+        train_batch,
+        self.train_program.train_input_partition_spec(train_batch)
+    )
+    return train_batch
+
+  def preprocess(self, raw_input: Any) -> Any:
+    return raw_input
+
+  def forward(self, inputs: NestedMap) -> Tuple[NestedMap]:
+    step, train_state, step_fn_output = self.train_program.train_step(
+        step=0,
+        state=self.train_state,
+        prng_key=self.partitioned_prng_key,
+        inputs=inputs,
+        static_args=trainer_lib.BaseStepFnStaticArgs(
+            unpadded_global_batch_size=self.batch_size)
+    )
+    return (step_fn_output,)
+
+  def postprocess(self, outputs: Any) -> Any:
+    return outputs
+
+
+def create_model(batch_size: int = 1,
+                 **_unused_params) -> NVIDIA1_3B2gTrain:
+  """Configure and create a NVIDIA1_3B model instance.
+
+  Args:
+    batch_size: input batch size.
+  Returns:
+    A NVIDIA1_3B model.
+  """
+  return NVIDIA1_3B2gTrain(batch_size=batch_size)
diff --git a/common_benchmark_suite/openxla/benchmark/models/jax/pax/requirements.txt b/common_benchmark_suite/openxla/benchmark/models/jax/pax/requirements.txt
new file mode 100644
index 00000000..a852edaa
--- /dev/null
+++ b/common_benchmark_suite/openxla/benchmark/models/jax/pax/requirements.txt
@@ -0,0 +1,3 @@
+jax
+praxis @ git+https://github.com/google/praxis
+paxml @ git+https://github.com/google/paxml
diff --git a/comparative_benchmark/jax_xla/benchmark_all.sh b/comparative_benchmark/jax_xla/benchmark_all.sh
index b24ef026..0d2c48df 100755
--- a/comparative_benchmark/jax_xla/benchmark_all.sh
+++ b/comparative_benchmark/jax_xla/benchmark_all.sh
@@ -39,6 +39,7 @@ declare -a GPU_BENCHMARK_NAMES=(
   "models/T5_LARGE_FP32_JAX_.+"
   "models/T5_4CG_LARGE_FP32_JAX_.+"
   "models/GPT2LMHEAD_FP32_JAX_.+"
+  "models/NVIDIA1_3B_2G_TRAIN_BF16_JAX.+"
 )
 
 declare -a CPU_BENCHMARK_NAMES=(
diff --git a/comparative_benchmark/jax_xla/requirements.txt b/comparative_benchmark/jax_xla/requirements.txt
index 905a4a6f..bd785996 100644
--- a/comparative_benchmark/jax_xla/requirements.txt
+++ b/comparative_benchmark/jax_xla/requirements.txt
@@ -2,3 +2,5 @@ flax
 jax
 transformers
 pillow
+praxis @ git+https://github.com/google/praxis
+paxml @ git+https://github.com/google/paxml
diff --git a/comparative_benchmark/jax_xla/run_benchmarks.py b/comparative_benchmark/jax_xla/run_benchmarks.py
index 9bd6a403..c0f81da4 100755
--- a/comparative_benchmark/jax_xla/run_benchmarks.py
+++ b/comparative_benchmark/jax_xla/run_benchmarks.py
@@ -6,6 +6,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from absl import flags
 import argparse
 import jax
 import numpy as np
@@ -141,4 +142,6 @@ def main(**kwargs):
 
 
 if __name__ == "__main__":
+  # PAX requires absl's flags to be initialized.
+  flags.FLAGS(sys.argv[:1])
   main(**vars(_parse_arguments()))
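
Not part of the diff: a rough smoke test showing how the new PAX training module is intended to be driven. The calls (create_model, generate_default_inputs, preprocess, forward, postprocess) and the flag-initialization step come from the changes above; the import path openxla.benchmark.models.jax.pax.nvidia1_3b_2g_train_model, the two-device requirement, and the installed-requirements assumption are inferred, so treat this as an illustrative sketch rather than a supported entry point.

# Illustrative sketch only (not part of the diff): drive the new module end to end.
# Assumes the packages from models/jax/pax/requirements.txt are installed, two
# accelerator devices are visible (ICI_MESH_SHAPE = [2, 1, 1]), and that
# utils.MODELS_MODULE_PATH resolves to "openxla.benchmark.models".
import sys

from absl import flags

from openxla.benchmark.models.jax.pax import nvidia1_3b_2g_train_model

# PAX requires absl's flags to be initialized before the experiment is built,
# mirroring the flags.FLAGS(sys.argv[:1]) calls added to the scripts above.
flags.FLAGS(sys.argv[:1])

model = nvidia1_3b_2g_train_model.create_model(batch_size=8)
batch = model.preprocess(model.generate_default_inputs())
(step_output,) = model.forward(batch)
print(model.postprocess(step_output))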