diff --git a/README.md b/README.md
index 8baae3ce..89b426dd 100644
--- a/README.md
+++ b/README.md
@@ -124,13 +124,14 @@ please read "User's Guide".
 
 Pipelines may include more than one model or control flow.
 
-| Framework  | Pipeline          | Data Type        | Variations                                    |
-|------------|-------------------|------------------|----------------------------------------------|
-| JAX        | T5-Small          | FP32, FP16, BF16 | Token generation sizes: 16, 32, 64, 128, 256 |
-| JAX        | Stable Diffusion  | FP32, FP16, BF16 | Input sequence 64 tokens                     |
-| JAX        | GPT-2 with LMHead | FP32             | Generates 200 tokens                         |
-| Tensorflow | GPT-2 with LMHead | FP32             | Generates 200 tokens                         |
-| GGML       | GPT-2 with LMHead | FP32, FP16       | Generates 200 tokens                         |
+| Framework  | Pipeline          | Data Type        | Variations                                      |
+|------------|-------------------|------------------|------------------------------------------------|
+| JAX        | Gemma-2B-IT       | FP32, FP16, BF16 | Input sequence 1024 tokens. Max new tokens 256 |
+| JAX        | T5-Small          | FP32, FP16, BF16 | Token generation sizes: 16, 32, 64, 128, 256   |
+| JAX        | Stable Diffusion  | FP32, FP16, BF16 | Input sequence 64 tokens                       |
+| JAX        | GPT-2 with LMHead | FP32             | Generates 200 tokens                           |
+| Tensorflow | GPT-2 with LMHead | FP32             | Generates 200 tokens                           |
+| GGML       | GPT-2 with LMHead | FP32, FP16       | Generates 200 tokens                           |
 
 ## Dashboards
 
diff --git a/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/benchmark_definitions.py b/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/benchmark_definitions.py
index cf3be2e2..6c37ed19 100644
--- a/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/benchmark_definitions.py
+++ b/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/benchmark_definitions.py
@@ -159,6 +159,24 @@
     batch_sizes=[1, 8],
 )
 
+GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32_CASE = def_types.BenchmarkCase.build(
+    model=model_definitions.GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32,
+    input_data=testdata.INPUT_DATA_MODEL_DEFAULT,
+    verify_parameters={"absolute_tolerance": 0.5},
+)
+
+GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32_CASE = def_types.BenchmarkCase.build(
+    model=model_definitions.GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32,
+    input_data=testdata.INPUT_DATA_MODEL_DEFAULT,
+    verify_parameters={"absolute_tolerance": 0.5},
+)
+
+GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32_CASE = def_types.BenchmarkCase.build(
+    model=model_definitions.GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32,
+    input_data=testdata.INPUT_DATA_MODEL_DEFAULT,
+    verify_parameters={"absolute_tolerance": 0.5},
+)
+
 ALL_BENCHMARKS = list(
     itertools.chain(
         T5_LARGE_FP32_JAX_512XI32_CASES.values(),
@@ -186,4 +204,7 @@
     GPT2LMHEAD_PIPELINE_JAX_1X4XI32_CASE,
     T5_SMALL_FP32_JAX_1X128XI32_CASE,
     GPT2LMHEAD_PIPELINE_JAX_1X4XI32_CASE,
+    GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32_CASE,
+    GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32_CASE,
+    GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32_CASE,
 ]
diff --git a/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/model_definitions.py b/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/model_definitions.py
index e9e81b7b..d88bcac6 100644
--- a/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/model_definitions.py
+++ b/common_benchmark_suite/openxla/benchmark/comparative_suite/jax/model_definitions.py
@@ -545,6 +545,74 @@
     artifacts_dir_url=f"{PARENT_GCS_DIR}/VIT_CLASSIFICATION_JAX_3X224X224XF32",
 )
 
+# Gemma models.
+# Model implementation from https://huggingface.co/docs/transformers/model_doc/gemma#transformers.FlaxGemmaForCausalLM.
+GEMMA_PIPELINE_JAX_IMPL = def_types.ModelImplementation(
+    name="GEMMA_PIPELINE_JAX",
+    tags=["gemma", "pipeline"],
+    framework_type=def_types.ModelFrameworkType.JAX,
+    module_path=f"{utils.MODELS_MODULE_PATH}.jax.gemma.gemma_pipeline",
+    source_info=
+    "https://huggingface.co/docs/transformers/model_doc/gemma#transformers.FlaxGemmaForCausalLM",
+)
+
+GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32 = def_types.Model(
+    name="GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32",
+    tags=["fp32", "batch-1", "seqlen-1024"],
+    model_impl=GEMMA_PIPELINE_JAX_IMPL,
+    model_parameters={
+        "batch_size": 1,
+        "data_type": "fp32",
+        "seq_len": 1024,
+        "max_new_tokens": 256,
+        "model_name": "google/gemma-2b-it",
+    },
+    exported_model_types=[
+        def_types.ModelArtifactType.STABLEHLO_MLIR,
+        def_types.ModelArtifactType.XLA_HLO_DUMP,
+    ],
+    artifacts_dir_url=
+    f"{PARENT_GCS_DIR}/GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32",
+)
+
+GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32 = def_types.Model(
+    name="GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32",
+    tags=["bf16", "batch-1", "seqlen-1024"],
+    model_impl=GEMMA_PIPELINE_JAX_IMPL,
+    model_parameters={
+        "batch_size": 1,
+        "data_type": "bf16",
+        "seq_len": 1024,
+        "max_new_tokens": 256,
+        "model_name": "google/gemma-2b-it",
+    },
+    exported_model_types=[
+        def_types.ModelArtifactType.STABLEHLO_MLIR,
+        def_types.ModelArtifactType.XLA_HLO_DUMP,
+    ],
+    artifacts_dir_url=
+    f"{PARENT_GCS_DIR}/GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32",
+)
+
+GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32 = def_types.Model(
+    name="GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32",
+    tags=["fp16", "batch-1", "seqlen-1024"],
+    model_impl=GEMMA_PIPELINE_JAX_IMPL,
+    model_parameters={
+        "batch_size": 1,
+        "data_type": "fp16",
+        "seq_len": 1024,
+        "max_new_tokens": 256,
+        "model_name": "google/gemma-2b-it",
+    },
+    exported_model_types=[
+        def_types.ModelArtifactType.STABLEHLO_MLIR,
+        def_types.ModelArtifactType.XLA_HLO_DUMP,
+    ],
+    artifacts_dir_url=
+    f"{PARENT_GCS_DIR}/GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32",
+)
+
 ALL_MODELS = list(
     itertools.chain(
         # Models with different batch sizes.
@@ -573,4 +641,7 @@
     GPT2LMHEAD_PIPELINE_JAX_1X4XI32,
     T5_SMALL_FP32_JAX_1X128XI32,
     VIT_CLASSIFICATION_JAX_3X224X224XF32,
+    GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32,
+    GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32,
+    GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32,
 ]
diff --git a/common_benchmark_suite/openxla/benchmark/models/jax/gemma/__init__.py b/common_benchmark_suite/openxla/benchmark/models/jax/gemma/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/common_benchmark_suite/openxla/benchmark/models/jax/gemma/gemma_pipeline.py b/common_benchmark_suite/openxla/benchmark/models/jax/gemma/gemma_pipeline.py
new file mode 100644
index 00000000..2890e1ea
--- /dev/null
+++ b/common_benchmark_suite/openxla/benchmark/models/jax/gemma/gemma_pipeline.py
@@ -0,0 +1,113 @@
+# Copyright 2023 The OpenXLA Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import jax.numpy as jnp
+
+from transformers import AutoTokenizer, GemmaTokenizer, FlaxGemmaForCausalLM, GenerationConfig
+from typing import Any, Dict, List, Tuple
+
+from openxla.benchmark.models.jax import jax_model_interface
+
+
+class GemmaPipeline(jax_model_interface.JaxInferenceModel):
+  """See https://huggingface.co/docs/transformers/model_doc/gemma for more information."""
+
+  batch_size: int
+  seq_len: int
+  model: FlaxGemmaForCausalLM
+  params: Dict[str, Any]
+  model_name: str
+  tokenizer: GemmaTokenizer
+  tokenization_kwargs: Dict[str, Any]
+
+  def __init__(
+      self,
+      batch_size: int,
+      dtype: Any,
+      seq_len: int,
+      max_new_tokens: int,
+      model_name: str,
+  ):
+    self.model, self.params = FlaxGemmaForCausalLM.from_pretrained(
+        model_name, revision="flax", _do_init=False)
+
+    if dtype == jnp.float32:
+      self.params = self.model.to_fp32(self.params)
+    elif dtype == jnp.float16:
+      self.params = self.model.to_fp16(self.params)
+    elif dtype == jnp.bfloat16:
+      self.params = self.model.to_bf16(self.params)
+    else:
+      raise ValueError(f"Unsupported data type '{dtype}'.")
+
+    self.model_name = model_name
+    self.batch_size = batch_size
+    self.seq_len = seq_len
+    self.tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        model_max_length=self.seq_len,
+        padding_side="left",
+    )
+    self.tokenizer.pad_token = self.tokenizer.eos_token
+    self.tokenization_kwargs = {
+        "return_tensors": "jax",
+    }
+
+    self.generation_config = GenerationConfig.from_pretrained(
+        model_name,
+        max_new_tokens=max_new_tokens,
+        do_sample=False,
+        use_cache=True)
+
+  def generate_default_inputs(self) -> str:
+    return "Once upon a time"
+
+  def preprocess(self, input_text: str) -> Tuple[Any,]:
+    batch_input_text = [input_text] * self.batch_size
+    inputs = self.tokenizer(text=batch_input_text, **self.tokenization_kwargs)
+    return (inputs["input_ids"],)
+
+  def forward(self, input_text: Any) -> Any:
+    output = self.model.generate(input_text,
+                                 params=self.params,
+                                 generation_config=self.generation_config)
+    return output.sequences
+
+  def postprocess(self, output: Any) -> List[str]:
+    return self.tokenizer.batch_decode(output, skip_special_tokens=True)
+
+  def apply(self, input_text: Any) -> Any:
+    raise NotImplementedError("Not implemented.")
+
+
+DTYPE_MAP = {
+    "fp32": jnp.float32,
+    "fp16": jnp.float16,
+    "bf16": jnp.bfloat16,
+}
+
+
+def create_model(batch_size: int = 1,
+                 data_type: str = "fp32",
+                 seq_len: int = 1024,
+                 max_new_tokens: int = 256,
+                 model_name: str = "google/gemma-2b-it",
+                 **_unused_params) -> GemmaPipeline:
+  """Configure and create a JAX Gemma pipeline.
+  Args:
+    batch_size: input batch size.
+    seq_len: input sequence length.
+    max_new_tokens: the maximum number of new tokens to generate.
+    model_name: The name of the Gemma variant to use. Supported variants include:
+      google/gemma-2b-it, google/gemma-7b-it.
+  Returns:
+    A JAX GemmaPipeline.
+  """
+  return GemmaPipeline(batch_size=batch_size,
+                       dtype=DTYPE_MAP[data_type],
+                       seq_len=seq_len,
+                       max_new_tokens=max_new_tokens,
+                       model_name=model_name)
diff --git a/common_benchmark_suite/openxla/benchmark/models/jax/gemma/requirements.txt b/common_benchmark_suite/openxla/benchmark/models/jax/gemma/requirements.txt
new file mode 100644
index 00000000..9f588849
--- /dev/null
+++ b/common_benchmark_suite/openxla/benchmark/models/jax/gemma/requirements.txt
@@ -0,0 +1,4 @@
+flax
+jax
+torch
+transformers
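
For reference, a minimal usage sketch of the pipeline added above. This is illustrative only and not part of the diff: the benchmark harness normally drives these calls, and running the snippet assumes local Hugging Face access to the gated google/gemma-2b-it checkpoint plus the packages from the new requirements.txt.

# Illustrative sketch: exercise GemmaPipeline end to end with the same
# parameters as GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32.
from openxla.benchmark.models.jax.gemma import gemma_pipeline

pipeline = gemma_pipeline.create_model(batch_size=1,
                                       data_type="fp32",
                                       seq_len=1024,
                                       max_new_tokens=256,
                                       model_name="google/gemma-2b-it")

prompt = pipeline.generate_default_inputs()  # "Once upon a time"
inputs = pipeline.preprocess(prompt)         # (input_ids,) tuple, batch of 1
sequences = pipeline.forward(*inputs)        # greedy decode, up to 256 new tokens
texts = pipeline.postprocess(sequences)      # decoded strings, one per batch entry
print(texts[0])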