diff --git a/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py b/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py index 0f7327a7..5c048aea 100644 --- a/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py +++ b/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/model_definitions.py @@ -37,6 +37,7 @@ def_types.ModelArtifactType.STABLEHLO_MLIR, def_types.ModelArtifactType.XLA_HLO_DUMP, def_types.ModelArtifactType.TF_SAVEDMODEL_V2, + def_types.ModelArtifactType.TFLITE_FP32, ], ) T5_LARGE_FP32_TF_512XI32_BATCHES = utils.build_batch_models( @@ -69,6 +70,7 @@ def_types.ModelArtifactType.STABLEHLO_MLIR, def_types.ModelArtifactType.XLA_HLO_DUMP, def_types.ModelArtifactType.TF_SAVEDMODEL_V2, + def_types.ModelArtifactType.TFLITE_FP32, ], ) BERT_LARGE_FP32_TF_384XI32_BATCHES = utils.build_batch_models( @@ -100,6 +102,8 @@ def_types.ModelArtifactType.STABLEHLO_MLIR, def_types.ModelArtifactType.XLA_HLO_DUMP, def_types.ModelArtifactType.TF_SAVEDMODEL_V2, + def_types.ModelArtifactType.TFLITE_FP32, + def_types.ModelArtifactType.TFLITE_INT8, ], ) RESNET50_FP32_TF_224X224X3XF32_BATCHES = utils.build_batch_models( @@ -130,6 +134,7 @@ def_types.ModelArtifactType.STABLEHLO_MLIR, def_types.ModelArtifactType.XLA_HLO_DUMP, def_types.ModelArtifactType.TF_SAVEDMODEL_V2, + def_types.ModelArtifactType.TFLITE_FP32, ], ) diff --git a/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/scripts/generate_model_artifacts.py b/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/scripts/generate_model_artifacts.py index d17e1df0..69776f5a 100644 --- a/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/scripts/generate_model_artifacts.py +++ b/common_benchmark_suite/openxla/benchmark/comparative_suite/tf/scripts/generate_model_artifacts.py @@ -9,6 +9,7 @@ import pathlib import re import multiprocessing +import numpy as np import shutil import sys import tarfile @@ 
-61,6 +62,61 @@ def _generate_mlir(model_dir: pathlib.Path, saved_model_dir: pathlib.Path):
   write_bytecode(str(mlir_path), result)
 
 
+def _generate_tflite(inputs: Tuple[Any, ...], model_dir: pathlib.Path,
+                     saved_model_dir: pathlib.Path):
+  """Converts a SavedModel into fp32 and int8 TFLite flatbuffers.
+
+  Writes `model_fp32.tflite` and `model_int8.tflite` into `model_dir`. Each
+  conversion is best-effort: a failure is printed and does not propagate.
+
+  Args:
+    inputs: example model inputs; each must expose `dtype` and `shape`. They
+      are only used to size the random int8 calibration samples.
+    model_dir: directory that receives the generated `.tflite` files.
+    saved_model_dir: directory of the SavedModel to convert.
+  """
+  converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))
+
+  # Generate fp32 model.
+  try:
+    tflite_model = converter.convert()
+    tflite_model_path = model_dir.joinpath("model_fp32.tflite")
+    with open(tflite_model_path, 'wb') as f:
+      f.write(tflite_model)
+  except Exception as e:
+    print(f"Failed to generate fp32 TFLite model. Exception: {e}")
+
+  # Generate int8 model.
+  try:
+
+    def representative_examples():
+      # Yields two calibration samples, each drawn uniformly over the full
+      # value range of the corresponding input's dtype.
+      for _ in range(2):
+        yield [
+            np.random.uniform(low=tensor.dtype.min,
+                              high=tensor.dtype.max,
+                              size=tensor.shape).astype(
+                                  tensor.dtype.as_numpy_dtype)
+            for tensor in inputs
+        ]
+
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_ops = [
+        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.TFLITE_BUILTINS_INT8
+    ]
+    converter.representative_dataset = representative_examples
+    # NOTE(review): the TF2 converter documents `inference_input_type` /
+    # `inference_output_type`; confirm `inference_type` takes effect here.
+    converter.inference_type = tf.int8
+    tflite_model_int8 = converter.convert()
+    tflite_model_int8_path = model_dir.joinpath("model_int8.tflite")
+    with open(tflite_model_int8_path, 'wb') as f:
+      f.write(tflite_model_int8)
+  except Exception as e:
+    print(f"Failed to generate int8 TFLite model. Exception: {e}")
+
+
 def _generate_artifacts(model: def_types.Model, save_dir: pathlib.Path,
                         auto_upload: bool):
   model_dir = save_dir.joinpath(model.name)
@@ -87,6 +143,7 @@ def _generate_artifacts(model: def_types.Model, save_dir: pathlib.Path,
     saved_model_dir = _generate_saved_model(inputs, model_obj, model_dir)
 
     _generate_mlir(model_dir, saved_model_dir)
+    _generate_tflite(inputs, model_dir, saved_model_dir)
 
     with tarfile.open(model_dir.joinpath("tf-model.tgz"), "w:gz") as tar:
       tar.add(f"{saved_model_dir}/", arcname="")
diff --git a/common_benchmark_suite/openxla/benchmark/def_types.py b/common_benchmark_suite/openxla/benchmark/def_types.py
index 666d2564..25d45c2b 100644
--- a/common_benchmark_suite/openxla/benchmark/def_types.py
+++ b/common_benchmark_suite/openxla/benchmark/def_types.py
@@ -15,6 +15,7 @@ class ModelFrameworkType(Enum):
   """Type of framework a model is implemented in."""
   TF_V1 = "tensorflow_v1"
   TF_V2 = "tensorflow_v2"
+  TFLITE = "tflite"
   PYTORCH = "pytorch"
   JAX = "jax"
   GGML = "ggml"
@@ -42,6 +43,8 @@ class ModelArtifactType(Enum):
   """Type of derived model artifact."""
   TF_SAVEDMODEL_V1 = "tf_savedmodel_v1"
   TF_SAVEDMODEL_V2 = "tf_savedmodel_v2"
+  TFLITE_FP32 = "tflite_fp32"
+  TFLITE_INT8 = "tflite_int8"
   XLA_HLO_DUMP = "xla_hlo_dump"
   STABLEHLO_MLIR = "stablehlo_mlir"
   LINALG_MLIR = "linalg_mlir"