Add Gemma 2B IT models
Signed-off-by: mariecwhite <[email protected]>
mariecwhite committed Jun 26, 2024
1 parent e97a45a commit 7a39538
Showing 6 changed files with 217 additions and 7 deletions.
15 changes: 8 additions & 7 deletions README.md
@@ -124,13 +124,14 @@ please read "User's Guide".

Pipelines may include more than one model or control flow.

| Framework | Pipeline | Data Type | Variations |
|------------|-------------------|------------------|----------------------------------------------|
| JAX | T5-Small | FP32, FP16, BF16 | Token generation sizes: 16, 32, 64, 128, 256 |
| JAX | Stable Diffusion | FP32, FP16, BF16 | Input sequence 64 tokens |
| JAX | GPT-2 with LMHead | FP32 | Generates 200 tokens |
| Tensorflow | GPT-2 with LMHead | FP32 | Generates 200 tokens |
| GGML | GPT-2 with LMHead | FP32, FP16 | Generates 200 tokens |
| Framework | Pipeline | Data Type | Variations |
|------------|-------------------|------------------|------------------------------------------------|
| JAX | Gemma-2B-IT | FP32, FP16, BF16 | Input sequence 1024 tokens. Max new tokens 256 |
| JAX | T5-Small | FP32, FP16, BF16 | Token generation sizes: 16, 32, 64, 128, 256 |
| JAX | Stable Diffusion | FP32, FP16, BF16 | Input sequence 64 tokens |
| JAX | GPT-2 with LMHead | FP32 | Generates 200 tokens |
| Tensorflow | GPT-2 with LMHead | FP32 | Generates 200 tokens |
| GGML | GPT-2 with LMHead | FP32, FP16 | Generates 200 tokens |
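
The new Gemma-2B-IT entry runs greedy decoding through the Hugging Face Flax Gemma model. Below is a minimal sketch of the underlying flow, using the same `transformers` calls as the pipeline added in this commit; actually running it assumes access to the gated `google/gemma-2b-it` checkpoint on Hugging Face.

```python
from transformers import AutoTokenizer, FlaxGemmaForCausalLM, GenerationConfig

# Load the Flax weights; `_do_init=False` makes from_pretrained return (model, params).
model, params = FlaxGemmaForCausalLM.from_pretrained(
    "google/gemma-2b-it", revision="flax", _do_init=False)
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2b-it", model_max_length=1024, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Greedy decoding: up to 256 new tokens on top of the 1024-token input budget.
config = GenerationConfig.from_pretrained(
    "google/gemma-2b-it", max_new_tokens=256, do_sample=False, use_cache=True)

inputs = tokenizer(["Once upon a time"], return_tensors="jax")
output = model.generate(inputs["input_ids"],
                        params=params,
                        generation_config=config)
print(tokenizer.batch_decode(output.sequences, skip_special_tokens=True))
```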

## Dashboards

@@ -159,6 +159,24 @@
batch_sizes=[1, 8],
)

GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32_CASE = def_types.BenchmarkCase.build(
model=model_definitions.GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32,
input_data=testdata.INPUT_DATA_MODEL_DEFAULT,
verify_parameters={"absolute_tolerance": 0.5},
)

GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32_CASE = def_types.BenchmarkCase.build(
model=model_definitions.GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32,
input_data=testdata.INPUT_DATA_MODEL_DEFAULT,
verify_parameters={"absolute_tolerance": 0.5},
)

GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32_CASE = def_types.BenchmarkCase.build(
model=model_definitions.GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32,
input_data=testdata.INPUT_DATA_MODEL_DEFAULT,
verify_parameters={"absolute_tolerance": 0.5},
)

ALL_BENCHMARKS = list(
itertools.chain(
T5_LARGE_FP32_JAX_512XI32_CASES.values(),
@@ -186,4 +204,7 @@
GPT2LMHEAD_PIPELINE_JAX_1X4XI32_CASE,
T5_SMALL_FP32_JAX_1X128XI32_CASE,
GPT2LMHEAD_PIPELINE_JAX_1X4XI32_CASE,
GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32_CASE,
GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32_CASE,
GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32_CASE,
]
@@ -545,6 +545,74 @@
artifacts_dir_url=f"{PARENT_GCS_DIR}/VIT_CLASSIFICATION_JAX_3X224X224XF32",
)

# Gemma models.
# Model implementation from https://huggingface.co/docs/transformers/model_doc/gemma#transformers.FlaxGemmaForCausalLM.
GEMMA_PIPELINE_JAX_IMPL = def_types.ModelImplementation(
name="GEMMA_PIPELINE_JAX",
tags=["gemma", "pipeline"],
framework_type=def_types.ModelFrameworkType.JAX,
module_path=f"{utils.MODELS_MODULE_PATH}.jax.gemma.gemma_pipeline",
source_info=
"https://huggingface.co/docs/transformers/model_doc/gemma#transformers.FlaxGemmaForCausalLM",
)

GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32 = def_types.Model(
name="GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32",
tags=["fp32", "batch-1", "seqlen-1024"],
model_impl=GEMMA_PIPELINE_JAX_IMPL,
model_parameters={
"batch_size": 1,
"data_type": "fp32",
"seq_len": 1024,
"max_new_tokens": 256,
"model_name": "google/gemma-2b-it",
},
exported_model_types=[
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
],
artifacts_dir_url=
f"{PARENT_GCS_DIR}/GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32",
)

GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32 = def_types.Model(
name="GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32",
tags=["bf16", "batch-1", "seqlen-1024"],
model_impl=GEMMA_PIPELINE_JAX_IMPL,
model_parameters={
"batch_size": 1,
"data_type": "bf16",
"seq_len": 1024,
"max_new_tokens": 256,
"model_name": "google/gemma-2b-it",
},
exported_model_types=[
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
],
artifacts_dir_url=
f"{PARENT_GCS_DIR}/GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32",
)

GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32 = def_types.Model(
name="GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32",
tags=["fp16", "batch-1", "seqlen-1024"],
model_impl=GEMMA_PIPELINE_JAX_IMPL,
model_parameters={
"batch_size": 1,
"data_type": "fp16",
"seq_len": 1024,
"max_new_tokens": 256,
"model_name": "google/gemma-2b-it",
},
exported_model_types=[
def_types.ModelArtifactType.STABLEHLO_MLIR,
def_types.ModelArtifactType.XLA_HLO_DUMP,
],
artifacts_dir_url=
f"{PARENT_GCS_DIR}/GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32",
)
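
Each definition above names the pipeline module through `module_path` and carries the constructor arguments in `model_parameters`. The benchmark framework's own loader is not part of this diff; the sketch below only illustrates the expected contract, and both the `importlib` call and the `load_model` helper name are illustrative assumptions, not the framework's API.

```python
import importlib

def load_model(model=GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32):
  # Resolve the implementation module named by the definition, e.g.
  # "<MODELS_MODULE_PATH>.jax.gemma.gemma_pipeline".
  module = importlib.import_module(model.model_impl.module_path)
  # Forward batch_size, data_type, seq_len, max_new_tokens and model_name to
  # the module's create_model() entry point.
  return module.create_model(**model.model_parameters)
```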

ALL_MODELS = list(
itertools.chain(
# Models with different batch sizes.
@@ -573,4 +641,7 @@
GPT2LMHEAD_PIPELINE_JAX_1X4XI32,
T5_SMALL_FP32_JAX_1X128XI32,
VIT_CLASSIFICATION_JAX_3X224X224XF32,
GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32,
GEMMA2BIT_GREEDY_BF16_JAX_1X1024XI32_256XI32,
GEMMA2BIT_GREEDY_FP16_JAX_1X1024XI32_256XI32,
]
Empty file.
@@ -0,0 +1,113 @@
# Copyright 2023 The OpenXLA Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import jax.numpy as jnp

from transformers import AutoTokenizer, GemmaTokenizer, FlaxPreTrainedModel, FlaxGemmaForCausalLM, GenerationConfig
from typing import Any, List, Tuple

from openxla.benchmark.models.jax import jax_model_interface


class GemmaPipeline(jax_model_interface.JaxInferenceModel):
"""See https://huggingface.co/docs/transformers/model_doc/gemma for more information."""

batch_size: int
seq_len: int
model: FlaxGemmaForCausalLM
params: dict[str, Any]
model_name: str
tokenizer: GemmaTokenizer
tokenization_kwargs: dict[str, Any]

def __init__(
self,
batch_size: int,
dtype: Any,
seq_len: int,
max_new_tokens: int,
model_name: str,
):
self.model, self.params = FlaxGemmaForCausalLM.from_pretrained(
model_name, revision="flax", _do_init=False)

if dtype == jnp.float32:
self.params = self.model.to_fp32(self.params)
elif dtype == jnp.float16:
self.params = self.model.to_fp16(self.params)
elif dtype == jnp.bfloat16:
self.params = self.model.to_bf16(self.params)
else:
raise ValueError(f"Unsupported data type '{dtype}'.")

self.model_name = model_name
self.batch_size = batch_size
self.seq_len = seq_len
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
model_max_length=self.seq_len,
padding_side="left",
)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenization_kwargs = {
"return_tensors": "jax",
}

self.generation_config = GenerationConfig.from_pretrained(
model_name,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True)

def generate_default_inputs(self) -> str:
return "Once upon a time"

def preprocess(self, input_text: str) -> Tuple[Any,]:
batch_input_text = [input_text] * self.batch_size
inputs = self.tokenizer(text=batch_input_text, **self.tokenization_kwargs)
return (inputs["input_ids"],)

  def forward(self, input_text: Any) -> Any:
    output = self.model.generate(input_text,
                                 params=self.params,
                                 generation_config=self.generation_config)
    # Flax `generate` returns an output object; pass the generated token ids on
    # to `postprocess` for decoding.
    return output.sequences

def postprocess(self, output: Any) -> List[str]:
return self.tokenizer.batch_decode(output, skip_special_tokens=True)

  def apply(self, input_text: Any) -> Any:
    raise NotImplementedError("GemmaPipeline does not implement apply().")


DTYPE_MAP = {
"fp32": jnp.float32,
"fp16": jnp.float16,
"bf16": jnp.bfloat16,
}


def create_model(batch_size: int = 1,
data_type: str = "fp32",
seq_len: int = 1024,
max_new_tokens: int = 256,
model_name: str = "google/gemma-2b-it",
**_unused_params) -> GemmaPipeline:
"""Configure and create a JAX Gemma pipeline.
Args:
batch_size: input batch size.
seq_len: input sequence length.
max_new_tokens: the maximum number of new tokens to generate.
model_name: The name of the Gemma variant to use. Supported variants include:
google/gemma-2b-it, google/gemma-7b-it.
Returns:
A JAX GemmaPipeline.
"""
return GemmaPipeline(batch_size=batch_size,
dtype=DTYPE_MAP[data_type],
seq_len=seq_len,
max_new_tokens=max_new_tokens,
model_name=model_name)
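
For reference, a minimal usage sketch of the pipeline above. The `__main__` guard is illustrative and not part of the committed file, and downloading `google/gemma-2b-it` requires a Hugging Face account that has accepted the Gemma terms.

```python
if __name__ == "__main__":
  # Illustrative only: build the FP32 variant and run one greedy generation.
  pipeline = create_model(data_type="fp32")
  (input_ids,) = pipeline.preprocess(pipeline.generate_default_inputs())
  sequences = pipeline.forward(input_ids)
  print(pipeline.postprocess(sequences))
```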
@@ -0,0 +1,4 @@
flax
jax
torch
transformers
