Add NVIDIA1_3B training configuration #118

Draft · wants to merge 1 commit into base: main
@@ -90,6 +90,14 @@
batch_sizes=[1, 64, 128],
)

NVIDIA1_3B_2G_TRAIN_BF16_JAX_CASES = utils.build_batch_benchmark_cases(
batch_models=model_definitions.NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCHES,
verify_parameters={
"absolute_tolerance": 0.5,
},
batch_sizes=[8],
)

ALL_BENCHMARKS = list(
itertools.chain(
T5_LARGE_FP32_JAX_512XI32_CASES.values(),
@@ -103,4 +111,5 @@
RESNET50_FP16_JAX_3X224X224XF16_CASES.values(),
RESNET50_BF16_JAX_3X224X224XBF16_CASES.values(),
GPT2LMHEAD_FP32_JAX_512XI32_CASES.values(),
NVIDIA1_3B_2G_TRAIN_BF16_JAX_CASES.values(),
))
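
The new case registers `verify_parameters` with an `absolute_tolerance` of 0.5. As a rough illustration only (the real comparison is performed by the suite's verification utilities, not this snippet), an absolute-tolerance check of that kind amounts to:

```python
import numpy as np

# Hypothetical sketch of an absolute-tolerance check with atol=0.5.
expected = np.array([0.125, 2.50, -1.00], dtype=np.float32)
actual = np.array([0.375, 2.10, -0.80], dtype=np.float32)

# Passes because every element differs from the reference by at most 0.5.
assert np.allclose(actual, expected, rtol=0.0, atol=0.5)
```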
@@ -295,6 +295,35 @@
template=GPT2LMHEAD_FP32_JAX_512XI32_BATCH_TEMPLATE,
batch_sizes=[1, 64, 128])

# DO_NOT_SUBMIT
NVIDIA1_3B_2G_GCS_DIR = "https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.13_1690046172/"
NVIDIA1_3B_2G_ARTIFACTS_DIR_URL_TEMPLATE = string.Template(NVIDIA1_3B_2G_GCS_DIR +
"${name}")
NVIDIA1_3B_2G_TRAIN_BF16_JAX_IMPL = def_types.ModelImplementation(
name="MODEL_NVIDIA1_3B_2G_TRAIN_BF16_JAX",
tags=["fp32", "transformer-decoder", "nvidia"],
framework_type=def_types.ModelFrameworkType.JAX,
module_path=f"{utils.MODELS_MODULE_PATH}.jax.pax.nvidia1_3b_2g_train_model",
source_info="PAX LLM",
)
NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCH_TEMPLATE = utils.ModelTemplate(
name=utils.BATCH_NAME("NVIDIA1_3B_2G_TRAIN_BF16_JAX"),
tags=[utils.BATCH_TAG],
model_impl=NVIDIA1_3B_2G_TRAIN_BF16_JAX_IMPL,
model_parameters={
"batch_size": utils.BATCH_SIZE_PARAM,
"data_type": "fp32",
},
artifacts_dir_url=NVIDIA1_3B_2G_ARTIFACTS_DIR_URL_TEMPLATE,
exported_model_types=[
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
],
)
NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCHES = utils.build_batch_models(
template=NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCH_TEMPLATE,
batch_sizes=[8])

ALL_MODELS = list(
itertools.chain(
T5_LARGE_FP32_JAX_512XI32_BATCHES.values(),
@@ -308,4 +337,5 @@
RESNET50_FP16_JAX_3X224X224XF16_BATCHES.values(),
RESNET50_BF16_JAX_3X224X224XBF16_BATCHES.values(),
GPT2LMHEAD_FP32_JAX_512XI32_BATCHES.values(),
NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCHES.values(),
))
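
For reference, the artifacts directory URL is produced by substituting the generated batch-model name into `NVIDIA1_3B_2G_ARTIFACTS_DIR_URL_TEMPLATE`. A minimal sketch of that expansion follows; the substituted name is hypothetical, since the real one comes from `utils.BATCH_NAME` and `utils.build_batch_models`:

```python
import string

# Same GCS prefix as in the definition above.
gcs_dir = ("https://storage.googleapis.com/iree-model-artifacts/jax/"
           "jax_models_0.4.13_1690046172/")
template = string.Template(gcs_dir + "${name}")

# Hypothetical batch-8 model name; the actual format is produced by utils.BATCH_NAME.
print(template.substitute(name="NVIDIA1_3B_2G_TRAIN_BF16_JAX_BATCH8"))
```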
@@ -4,6 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from absl import flags
import argparse
import jax
import os
@@ -120,7 +121,7 @@ def main(output_dir: pathlib.Path, filter: str, iree_opt_path: pathlib.Path):

output_dir.mkdir(parents=True, exist_ok=True)
for model in models:
# We need to generate artifacts in a separate proces each time in order for
# We need to generate artifacts in a separate process each time in order for
# XLA to update the HLO dump directory.
p = multiprocessing.Process(target=_generate_artifacts,
args=(model, output_dir, iree_opt_path))
@@ -129,4 +130,6 @@ def main(output_dir: pathlib.Path, filter: str, iree_opt_path: pathlib.Path):


if __name__ == "__main__":
# PAX requires absl's flags to be initialized.
flags.FLAGS(sys.argv[:1])
main(**vars(_parse_arguments()))
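
The `flags.FLAGS(sys.argv[:1])` call deserves a note: PAX/Praxis modules define absl flags, and reading them before they are parsed raises an unparsed-flag error. Passing only the program name marks absl's flags as parsed with their default values while leaving the script's own argparse arguments untouched. A standalone sketch of the pattern, with an illustrative `--output_dir` argument that is not the script's real interface:

```python
import argparse
import sys

from absl import flags

# The script keeps using argparse for its own CLI...
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", default="/tmp/artifacts")  # illustrative only
args = parser.parse_args()

# ...and initializes absl flags with their defaults so that PAX code reading
# flags.FLAGS at runtime does not fail before any absl parsing happens.
flags.FLAGS(sys.argv[:1])

print(args.output_dir)
```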
Empty file.
@@ -0,0 +1,117 @@
# Copyright 2023 The OpenXLA Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from etils import epath
import tempfile
from typing import Any, Tuple

import jax
import jax.numpy as jnp
from paxml import partitioning
from paxml import programs
from paxml import trainer_lib
from paxml.tasks.lm.params import nvidia
from praxis import base_layer
from praxis import py_utils

from openxla.benchmark.models import model_interfaces

instantiate = base_layer.instantiate
NestedMap = py_utils.NestedMap


class NVIDIA1_3B2g(nvidia.NVIDIA1_3B):
ICI_MESH_SHAPE = [2, 1, 1]


class NVIDIA1_3B2gTrain(model_interfaces.InferenceModel):
batch_size: int

def __init__(self, batch_size: int):
self.batch_size = batch_size

self.experiment_config = NVIDIA1_3B2g()
self.task = instantiate(self.experiment_config.task())
self.partitioner = partitioning.create_partitioner(self.task)
prng_key = jax.random.PRNGKey(123)

train_input_p = self.experiment_config.datasets()[0]
train_input_p = self.partitioner.preprocess_input_config(train_input_p)
self.train_input = instantiate(train_input_p)

with tempfile.TemporaryDirectory() as d:
job_log_dir = epath.Path(d)
prng_key, setup_key = jax.random.split(prng_key)
self.partitioner.setup(
self.task,
setup_key,
train_inputs_shape_dtype=None,
train_input_pipeline=self.train_input,
job_log_dir=job_log_dir,
)

# Initialize the partitioned train state.
prng_key, state_key = jax.random.split(prng_key)
_, self.train_state, _ = self.partitioner.initialize_prng_key_and_train_state(
state_key,
train_state=None,
checkpoint_type=None,
)

prng_key, train_prng_seed, eval_prng_seed = jax.random.split(prng_key, 3)
self.train_program = programs.SingleTaskTrainProgram()
self.train_program.setup(
self.task,
self.train_input,
self.partitioner,
job_log_dir,
train_prng_seed,
eval_prng_seed,
init_step=0,
)
self.partitioned_prng_key = self.partitioner.preprocess_prng_key(prng_key)

def generate_default_inputs(self) -> NestedMap:
train_input_p = self.experiment_config.datasets()[0]
train_input_p = self.partitioner.preprocess_input_config(train_input_p)
train_input_p.input.batch_size = self.batch_size
train_input = instantiate(train_input_p)
train_batch = train_input.get_next()
train_batch = self.partitioner.preprocess_inputs(
train_input,
train_batch,
self.train_program.train_input_partition_spec(train_batch)
)
return train_batch

def preprocess(self, raw_input: Any) -> Any:
return raw_input

def forward(self, inputs: NestedMap) -> Tuple[NestedMap]:
step, train_state, step_fn_output = self.train_program.train_step(
step=0,
state=self.train_state,
prng_key=self.partitioned_prng_key,
inputs=inputs,
static_args=trainer_lib.BaseStepFnStaticArgs(
unpadded_global_batch_size=self.batch_size)
)
return (step_fn_output,)

def postprocess(self, outputs: Any) -> Any:
return outputs


def create_model(batch_size: int = 1,
**_unused_params) -> NVIDIA1_3B2gTrain:
"""Configure and create a NVIDIA1_3B model instance.

Args:
batch_size: input batch size.
Returns:
An NVIDIA1_3B model.
"""
return NVIDIA1_3B2gTrain(batch_size=batch_size)
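
Taken together, the class follows the `model_interfaces.InferenceModel` contract used elsewhere in the suite. A hedged usage sketch, assuming the benchmark harness drives the model roughly in this order (running it requires a working PAX/JAX setup and enough devices for the 2-way mesh):

```python
# Sketch only: the real call sequence is owned by the benchmark harness.
model = create_model(batch_size=8)

inputs = model.generate_default_inputs()   # one partitioned training batch
inputs = model.preprocess(inputs)          # no-op for this model
outputs = model.forward(inputs)            # runs a single train step
results = model.postprocess(outputs)       # no-op for this model
```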
@@ -0,0 +1,3 @@
jax
praxis @ git+https://github.com/google/praxis
paxml @ git+https://github.com/google/paxml
1 change: 1 addition & 0 deletions comparative_benchmark/jax_xla/benchmark_all.sh
@@ -39,6 +39,7 @@ declare -a GPU_BENCHMARK_NAMES=(
"models/T5_LARGE_FP32_JAX_.+"
"models/T5_4CG_LARGE_FP32_JAX_.+"
"models/GPT2LMHEAD_FP32_JAX_.+"
"models/NVIDIA1_3B_2G_TRAIN_BF16_JAX.+"
)

declare -a CPU_BENCHMARK_NAMES=(
2 changes: 2 additions & 0 deletions comparative_benchmark/jax_xla/requirements.txt
@@ -2,3 +2,5 @@ flax
jax
transformers
pillow
praxis @ git+https://github.com/google/praxis
paxml @ git+https://github.com/google/paxml
3 changes: 3 additions & 0 deletions comparative_benchmark/jax_xla/run_benchmarks.py
@@ -6,6 +6,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from absl import flags
import argparse
import jax
import numpy as np
@@ -141,4 +142,6 @@ def main(**kwargs):


if __name__ == "__main__":
# PAX requires absl's flags to be initialized.
flags.FLAGS(sys.argv[:1])
main(**vars(_parse_arguments()))