Add LmCloudSpmd2B training configuration
phoenix-meadowlark committed Aug 10, 2023
1 parent 217aca2 commit ad9508e
Showing 9 changed files with 169 additions and 1 deletion.
@@ -90,6 +90,14 @@
batch_sizes=[1, 64, 128],
)

NVIDIA1_3B_2G_TRAIN_BF16_JAX_CASES = utils.build_batch_benchmark_cases(
batch_models=model_definitions.NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCHES,
verify_parameters={
"absolute_tolerance": 0.5,
},
batch_sizes=[8],
)

ALL_BENCHMARKS = list(
itertools.chain(
T5_LARGE_FP32_JAX_512XI32_CASES.values(),
@@ -103,4 +111,5 @@
RESNET50_FP16_JAX_3X224X224XF16_CASES.values(),
RESNET50_BF16_JAX_3X224X224XBF16_CASES.values(),
GPT2LMHEAD_FP32_JAX_512XI32_CASES.values(),
NVIDIA1_3B_2G_TRAIN_BF16_JAX_CASES.values(),
))
@@ -295,6 +295,35 @@
template=GPT2LMHEAD_FP32_JAX_512XI32_BATCH_TEMPLATE,
batch_sizes=[1, 64, 128])

# DO_NOT_SUBMIT
NVIDIA1_3B_2G_GCS_DIR = "https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.13_1690046172/"
NVIDIA1_3B_2G_ARTIFACTS_DIR_URL_TEMPLATE = string.Template(NVIDIA1_3B_2G_GCS_DIR +
"${name}")
NVIDIA1_3B_2G_TRAIN_BF16_JAX_IMPL = def_types.ModelImplementation(
name="MODEL_NVIDIA1_3B_2G_TRAIN_BF16_JAX",
tags=["fp32", "transformer-decoder", "nvidia"],
framework_type=def_types.ModelFrameworkType.JAX,
module_path=f"{utils.MODELS_MODULE_PATH}.jax.pax.nvidia1_3b_2g_train_model",
source_info="PAX LLM",
)
NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCH_TEMPLATE = utils.ModelTemplate(
name=utils.BATCH_NAME("NVIDIA1_3B_2G_TRAIN_BF16_JAX"),
tags=[utils.BATCH_TAG],
model_impl=NVIDIA1_3B_2G_TRAIN_BF16_JAX_IMPL,
model_parameters={
"batch_size": utils.BATCH_SIZE_PARAM,
"data_type": "fp32",
},
artifacts_dir_url=NVIDIA1_3B_2G_ARTIFACTS_DIR_URL_TEMPLATE,
exported_model_types=[
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
],
)
NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCHES = utils.build_batch_models(
template=NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCH_TEMPLATE,
batch_sizes=[8])

ALL_MODELS = list(
itertools.chain(
T5_LARGE_FP32_JAX_512XI32_BATCHES.values(),
@@ -308,4 +337,5 @@
RESNET50_FP16_JAX_3X224X224XF16_BATCHES.values(),
RESNET50_BF16_JAX_3X224X224XBF16_BATCHES.values(),
GPT2LMHEAD_FP32_JAX_512XI32_BATCHES.values(),
NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCHES.values(),
))
Expand Up @@ -4,6 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from absl import flags
import argparse
import jax
import os
@@ -120,7 +121,7 @@ def main(output_dir: pathlib.Path, filter: str, iree_opt_path: pathlib.Path):

output_dir.mkdir(parents=True, exist_ok=True)
for model in models:
# We need to generate artifacts in a separate proces each time in order for
# We need to generate artifacts in a separate process each time in order for
# XLA to update the HLO dump directory.
p = multiprocessing.Process(target=_generate_artifacts,
args=(model, output_dir, iree_opt_path))
@@ -129,4 +130,6 @@ def main(output_dir: pathlib.Path, filter: str, iree_opt_path: pathlib.Path):


if __name__ == "__main__":
# PAX requires absl's flags to be initialized.
flags.FLAGS(sys.argv[:1])
main(**vars(_parse_arguments()))
Empty file.
@@ -0,0 +1,117 @@
# Copyright 2023 The OpenXLA Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from etils import epath
import tempfile
from typing import Any, Tuple

import jax
import jax.numpy as jnp
from paxml import partitioning
from paxml import programs
from paxml import trainer_lib
from paxml.tasks.lm.params import nvidia
from praxis import base_layer
from praxis import py_utils

from openxla.benchmark.models import model_interfaces

instantiate = base_layer.instantiate
NestedMap = py_utils.NestedMap


class NVIDIA1_3B2g(nvidia.NVIDIA1_3B):
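# ICI_MESH_SHAPE is the [replica, data, model] parallelism split; [2, 1, 1] presumably
# spreads training across the two devices implied by the "2g" suffix, with no model parallelism.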
ICI_MESH_SHAPE = [2, 1, 1]


class NVIDIA1_3B2gTrain(model_interfaces.InferenceModel):
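# Exposes a PAX training step through the benchmark's InferenceModel interface so the
# existing benchmark runner can drive training like any other model.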
batch_size: int

def __init__(self, batch_size: int):
self.batch_size = batch_size

self.experiment_config = NVIDIA1_3B2g()
self.task = instantiate(self.experiment_config.task())
self.partitioner = partitioning.create_partitioner(self.task)
prng_key = jax.random.PRNGKey(123)

train_input_p = self.experiment_config.datasets()[0]
train_input_p = self.partitioner.preprocess_input_config(train_input_p)
self.train_input = instantiate(train_input_p)

with tempfile.TemporaryDirectory() as d:
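# A job_log_dir is required by the setup calls below; a throwaway temporary directory
# is used since nothing from it is needed after initialization.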
job_log_dir = epath.Path(d)
prng_key, setup_key = jax.random.split(prng_key)
self.partitioner.setup(
self.task,
setup_key,
train_inputs_shape_dtype=None,
train_input_pipeline=self.train_input,
job_log_dir=job_log_dir,
)

# Initialize the partitioned train state.
prng_key, state_key = jax.random.split(prng_key)
_, self.train_state, _ = self.partitioner.initialize_prng_key_and_train_state(
state_key,
train_state=None,
checkpoint_type=None,
)

prng_key, train_prng_seed, eval_prng_seed = jax.random.split(prng_key, 3)
self.train_program = programs.SingleTaskTrainProgram()
self.train_program.setup(
self.task,
self.train_input,
self.partitioner,
job_log_dir,
train_prng_seed,
eval_prng_seed,
init_step=0,
)
self.partitioned_prng_key = self.partitioner.preprocess_prng_key(prng_key)

def generate_default_inputs(self) -> NestedMap:
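# Build an input pipeline with this model's batch size, pull one training batch, and
# shard it according to the train program's input partition spec.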
train_input_p = self.experiment_config.datasets()[0]
train_input_p = self.partitioner.preprocess_input_config(train_input_p)
train_input_p.input.batch_size = self.batch_size
train_input = instantiate(train_input_p)
train_batch = train_input.get_next()
train_batch = self.partitioner.preprocess_inputs(
train_input,
train_batch,
self.train_program.train_input_partition_spec(train_batch)
)
return train_batch

def preprocess(self, raw_input: Any) -> Any:
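# Batches from generate_default_inputs() are already preprocessed and sharded, so
# pre- and post-processing are pass-throughs.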
return raw_input

def forward(self, inputs: NestedMap) -> Tuple[NestedMap]:
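# Run a single training step (forward, backward, and parameter update); only the step
# outputs are returned and the updated train state is discarded.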
step, train_state, step_fn_output = self.train_program.train_step(
step=0,
state=self.train_state,
prng_key=self.partitioned_prng_key,
inputs=inputs,
static_args=trainer_lib.BaseStepFnStaticArgs(
unpadded_global_batch_size=self.batch_size)
)
return (step_fn_output,)

def postprocess(self, outputs: Any) -> Any:
return outputs


def create_model(batch_size: int = 1,
**_unused_params) -> NVIDIA1_3B2gTrain:
"""Configure and create a NVIDIA1_3B model instance.
Args:
batch_size: input batch size.
Returns:
A NVIDIA1_3B model.
"""
return NVIDIA1_3B2gTrain(batch_size=batch_size)
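
A minimal sketch of how a harness might exercise the new model, assuming utils.MODELS_MODULE_PATH resolves to openxla.benchmark.models (so the file above is importable as openxla.benchmark.models.jax.pax.nvidia1_3b_2g_train_model) and that jax, paxml, and praxis are installed with enough devices for the [2, 1, 1] mesh:

# Sketch only: drives the model the way the benchmark runner would.
from openxla.benchmark.models.jax.pax import nvidia1_3b_2g_train_model

model = nvidia1_3b_2g_train_model.create_model(batch_size=8)
inputs = model.generate_default_inputs()  # one sharded training batch
outputs = model.forward(inputs)           # runs a single training step
print(outputs)
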
@@ -0,0 +1,3 @@
jax
praxis @ git+https://github.com/google/praxis
paxml @ git+https://github.com/google/paxml
1 change: 1 addition & 0 deletions comparative_benchmark/jax_xla/benchmark_all.sh
@@ -39,6 +39,7 @@ declare -a GPU_BENCHMARK_NAMES=(
"models/T5_LARGE_FP32_JAX_.+"
"models/T5_4CG_LARGE_FP32_JAX_.+"
"models/GPT2LMHEAD_FP32_JAX_.+"
"models/NVIDIA1_3B_2G_TRAIN_BF16_JAX.+"
)

declare -a CPU_BENCHMARK_NAMES=(
2 changes: 2 additions & 0 deletions comparative_benchmark/jax_xla/requirements.txt
@@ -2,3 +2,5 @@ flax
jax
transformers
pillow
praxis @ git+https://github.com/google/praxis
paxml @ git+https://github.com/google/paxml
3 changes: 3 additions & 0 deletions comparative_benchmark/jax_xla/run_benchmarks.py
@@ -6,6 +6,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from absl import flags
import argparse
import jax
import numpy as np
@@ -141,4 +142,6 @@ def main(**kwargs):


if __name__ == "__main__":
# PAX requires absl's flags to be initialized.
flags.FLAGS(sys.argv[:1])
main(**vars(_parse_arguments()))
