Skip to content

Commit

Permalink
Add dotprod microbenchmark artifacts.
Browse files Browse the repository at this point in the history
AUTO_UPLOAD=1 ./common_benchmark_suite/openxla/benchmark/comparative_suite/jax/scripts/generate_model_artifacts.sh DOT_PRODUCT_JAX_.+

Signed-off-by: mariecwhite <[email protected]>
  • Loading branch information
mariecwhite committed Jun 27, 2024
1 parent 7a39538 commit 4edd71b
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,30 @@
verify_parameters={"absolute_tolerance": 0.5},
)

DOT_PRODUCT_JAX_1X256X2048XI8I8_CASE = def_types.BenchmarkCase.build(
model=model_definitions.DOT_PRODUCT_JAX_1X256X2048XI8I8,
input_data=testdata.INPUT_DATA_MODEL_DEFAULT,
verify_parameters={"absolute_tolerance": 0.5},
)

DOT_PRODUCT_JAX_1X256X2048XI8I4_CASE = def_types.BenchmarkCase.build(
model=model_definitions.DOT_PRODUCT_JAX_1X256X2048XI8I4,
input_data=testdata.INPUT_DATA_MODEL_DEFAULT,
verify_parameters={"absolute_tolerance": 0.5},
)

DOT_PRODUCT_JAX_256X256X2048XI8I8_CASE = def_types.BenchmarkCase.build(
model=model_definitions.DOT_PRODUCT_JAX_256X256X2048XI8I8,
input_data=testdata.INPUT_DATA_MODEL_DEFAULT,
verify_parameters={"absolute_tolerance": 0.5},
)

DOT_PRODUCT_JAX_256X256X2048XI8I4_CASE = def_types.BenchmarkCase.build(
model=model_definitions.DOT_PRODUCT_JAX_256X256X2048XI8I4,
input_data=testdata.INPUT_DATA_MODEL_DEFAULT,
verify_parameters={"absolute_tolerance": 0.5},
)

ALL_BENCHMARKS = list(
itertools.chain(
T5_LARGE_FP32_JAX_512XI32_CASES.values(),
Expand Down Expand Up @@ -207,4 +231,8 @@
GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32_CASE,
GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32_CASE,
GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32_CASE,
DOT_PRODUCT_JAX_1X256X2048XI8I8_CASE,
DOT_PRODUCT_JAX_1X256X2048XI8I4_CASE,
DOT_PRODUCT_JAX_256X256X2048XI8I8_CASE,
DOT_PRODUCT_JAX_256X256X2048XI8I4_CASE,
]
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,139 @@
f"{PARENT_GCS_DIR}/GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32",
)

# Dotprod microbenchmarks.
DOT_PRODUCT_JAX_IMPL = def_types.ModelImplementation(
name="DOT_PRODUCT_JAX",
tags=["microbenchmark"],
framework_type=def_types.ModelFrameworkType.JAX,
module_path=f"{utils.MODELS_MODULE_PATH}.jax.dotprod.dot_product",
source_info="",
)

DOT_PRODUCT_JAX_1X256X2048XI8I8 = def_types.Model(
name="DOT_PRODUCT_JAX_1X256X2048XI8I8",
tags=["i8i8i32"],
model_impl=DOT_PRODUCT_JAX_IMPL,
model_parameters={
"model_name": "dotprod",
"lhs_shape": (1, 256),
"lhs_type": "int8",
"rhs_shape": (256, 2048),
"rhs_type": "int8",
},
exported_model_types=[
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.LINALG_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
],
artifacts_dir_url=f"{PARENT_GCS_DIR}/DOT_PRODUCT_JAX_1X256X2048XI8I8",
)

DOT_PRODUCT_JAX_1X256X2048XI8I4 = def_types.Model(
name="DOT_PRODUCT_JAX_1X256X2048XI8I4",
tags=["i8i4i32"],
model_impl=DOT_PRODUCT_JAX_IMPL,
model_parameters={
"model_name": "dotprod",
"lhs_shape": (1, 256),
"lhs_type": "int8",
"rhs_shape": (256, 2048),
"rhs_type": "int4",
},
exported_model_types=[
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.LINALG_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
],
artifacts_dir_url=f"{PARENT_GCS_DIR}/DOT_PRODUCT_JAX_1X256X2048XI8I4",
)

DOT_PRODUCT_JAX_1X256X2048XF32F32 = def_types.Model(
name="DOT_PRODUCT_JAX_1X256X2048XF32F32",
tags=["f32f32f32"],
model_impl=DOT_PRODUCT_JAX_IMPL,
model_parameters={
"model_name": "dotprod",
"lhs_shape": (1, 256),
"lhs_type": "fp32",
"rhs_shape": (256, 2048),
"rhs_type": "fp32",
},
exported_model_types=[
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.LINALG_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
def_types.ModelArtifactType.TFLITE_FP32,
def_types.ModelArtifactType.TFLITE_FP32_STABLEHLO,
def_types.ModelArtifactType.TFLITE_FP16,
def_types.ModelArtifactType.TFLITE_DYNAMIC_RANGE_QUANT,
def_types.ModelArtifactType.TFLITE_INT8,
],
artifacts_dir_url=f"{PARENT_GCS_DIR}/DOT_PRODUCT_JAX_1X256X2048XF32F32",
)

DOT_PRODUCT_JAX_256X256X2048XI8I8 = def_types.Model(
name="DOT_PRODUCT_JAX_256X256X2048XI8I8",
tags=["i8i8i32"],
model_impl=DOT_PRODUCT_JAX_IMPL,
model_parameters={
"model_name": "dotprod",
"lhs_shape": (256, 256),
"lhs_type": "int8",
"rhs_shape": (256, 2048),
"rhs_type": "int8",
},
exported_model_types=[
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.LINALG_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
],
artifacts_dir_url=f"{PARENT_GCS_DIR}/DOT_PRODUCT_JAX_256X256X2048XI8I8",
)

DOT_PRODUCT_JAX_256X256X2048XI8I4 = def_types.Model(
name="DOT_PRODUCT_JAX_256X256X2048XI8I4",
tags=["i8i4i32"],
model_impl=DOT_PRODUCT_JAX_IMPL,
model_parameters={
"model_name": "dotprod",
"lhs_shape": (256, 256),
"lhs_type": "int8",
"rhs_shape": (256, 2048),
"rhs_type": "int4",
},
exported_model_types=[
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.LINALG_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
],
artifacts_dir_url=f"{PARENT_GCS_DIR}/DOT_PRODUCT_JAX_256X256X2048XI8I4",
)

DOT_PRODUCT_JAX_256X256X2048XF32F32 = def_types.Model(
name="DOT_PRODUCT_JAX_256X256X2048XF32F32",
tags=["f32f32f32"],
model_impl=DOT_PRODUCT_JAX_IMPL,
model_parameters={
"model_name": "dotprod",
"lhs_shape": (1, 256),
"lhs_type": "fp32",
"rhs_shape": (256, 2048),
"rhs_type": "fp32",
},
exported_model_types=[
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.LINALG_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
def_types.ModelArtifactType.TFLITE_FP32,
def_types.ModelArtifactType.TFLITE_FP32_STABLEHLO,
def_types.ModelArtifactType.TFLITE_FP16,
def_types.ModelArtifactType.TFLITE_DYNAMIC_RANGE_QUANT,
def_types.ModelArtifactType.TFLITE_INT8,
],
artifacts_dir_url=f"{PARENT_GCS_DIR}/DOT_PRODUCT_JAX_256X256X2048XF32F32",
)

ALL_MODELS = list(
itertools.chain(
# Models with different batch sizes.
Expand All @@ -638,10 +771,12 @@
SD_PIPELINE_FP16_JAX_64XI32_BATCHES.values(),
SD_PIPELINE_BF16_JAX_64XI32_BATCHES.values(),
)) + [
GPT2LMHEAD_PIPELINE_JAX_1X4XI32,
T5_SMALL_FP32_JAX_1X128XI32,
GPT2LMHEAD_PIPELINE_JAX_1X4XI32, T5_SMALL_FP32_JAX_1X128XI32,
VIT_CLASSIFICATION_JAX_3X224X224XF32,
GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32,
GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32,
GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32,
DOT_PRODUCT_JAX_1X256X2048XI8I8, DOT_PRODUCT_JAX_1X256X2048XI8I4,
DOT_PRODUCT_JAX_256X256X2048XI8I8, DOT_PRODUCT_JAX_256X256X2048XI8I4,
DOT_PRODUCT_JAX_1X256X2048XF32F32, DOT_PRODUCT_JAX_256X256X2048XF32F32
]
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


def _generate_mlir(jit_function: Any, jit_inputs: Any, model_dir: pathlib.Path,
iree_ir_tool: Optional[pathlib.Path]):
iree_ir_tool: Optional[pathlib.Path]) -> pathlib.Path:
mlir = jit_function.lower(*jit_inputs).compiler_ir(dialect="stablehlo")
mlir_path = model_dir / "stablehlo.mlir"
print(f"Saving mlir to {mlir_path}")
Expand All @@ -48,6 +48,30 @@ def _generate_mlir(jit_function: Any, jit_inputs: Any, model_dir: pathlib.Path,
check=True,
)
mlir_path.unlink()
return binary_mlir_path

return mlir_path


def _generate_linalg_mlir(stablehlo_mlir_path: pathlib.Path,
iree_compile_path: pathlib.Path,
iree_ir_tool: Optional[pathlib.Path]):
linalg_mlir_path = stablehlo_mlir_path.parent / "linalg.mlir"
subprocess.run(
f"{iree_compile_path} {stablehlo_mlir_path} --compile-to=preprocessing > {linalg_mlir_path}",
shell=True,
check=True)

if iree_ir_tool:
binary_mlir_path = stablehlo_mlir_path.parent / "linalg.mlirbc"
subprocess.run(
[
iree_ir_tool, "cp", "--emit-bytecode", linalg_mlir_path, "-o",
binary_mlir_path
],
check=True,
)
linalg_mlir_path.unlink()


def _generate_tf_function(model_obj: Any, inputs: Any):
Expand Down Expand Up @@ -207,6 +231,7 @@ def _generate_tflite(model_obj: Any, inputs: Any, model_dir: pathlib.Path,


def _generate_artifacts(model: def_types.Model, save_dir: pathlib.Path,
iree_compile_path: pathlib.Path,
iree_ir_tool: Optional[pathlib.Path],
auto_upload: bool):
model_dir = save_dir / model.name
Expand Down Expand Up @@ -240,10 +265,14 @@ def _generate_artifacts(model: def_types.Model, save_dir: pathlib.Path,
os.unsetenv("XLA_FLAGS")

if def_types.ModelArtifactType.STABLEHLO_MLIR in model.exported_model_types:
_generate_mlir(jit_function=jit_function,
jit_inputs=jit_inputs,
model_dir=model_dir,
iree_ir_tool=iree_ir_tool)
stablehlo_mlir_path = _generate_mlir(jit_function=jit_function,
jit_inputs=jit_inputs,
model_dir=model_dir,
iree_ir_tool=iree_ir_tool)

if def_types.ModelArtifactType.LINALG_MLIR in model.exported_model_types:
_generate_linalg_mlir(stablehlo_mlir_path, iree_compile_path,
iree_ir_tool)

if def_types.ModelArtifactType.TFLITE_FP32 in model.exported_model_types:
_generate_tflite(model_obj=model_obj,
Expand Down Expand Up @@ -279,6 +308,12 @@ def _parse_arguments() -> argparse.Namespace:
nargs="+",
default=[".*"],
help="The regex patterns to filter model names.")
parser.add_argument(
"--iree-compile-path",
"--iree_compile_path",
type=pathlib.Path,
default=None,
help="Path to `iree-compile`. Used to generate linalg mlir.")
parser.add_argument("--iree-ir-tool",
"--iree_ir_tool",
type=pathlib.Path,
Expand All @@ -295,7 +330,8 @@ def _parse_arguments() -> argparse.Namespace:


def main(output_dir: pathlib.Path, filters: List[str],
iree_ir_tool: pathlib.Path, auto_upload: bool):
iree_compile_path: pathlib.Path, iree_ir_tool: pathlib.Path,
auto_upload: bool):
combined_filters = "|".join(f"({name_filter})" for name_filter in filters)
name_pattern = re.compile(f"^{combined_filters}$")
models = [
Expand All @@ -317,8 +353,8 @@ def main(output_dir: pathlib.Path, filters: List[str],
# of HLO dumps here - otherwise multiple processes will dump to the same HLO
# directory.
p = multiprocessing.Process(target=_generate_artifacts,
args=(model, output_dir, iree_ir_tool,
auto_upload))
args=(model, output_dir, iree_compile_path,
iree_ir_tool, auto_upload))
p.start()
p.join()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pip list > "${VERSION_DIR}/models_version_info.txt"

declare -a args=(
-o "${VERSION_DIR}"
--iree_compile_path="$(which iree-compile)"
--iree_ir_tool="$(which iree-ir-tool)"
)

Expand Down
Empty file.
Loading

0 comments on commit 4edd71b

Please sign in to comment.