From b71e9069affbfb65645a229c14b579dd990b004b Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Mon, 9 Dec 2024 15:57:33 +0000 Subject: [PATCH] Add CLI script exporting CLIP Toy model IREE test data This is required to have an easy way of exporting test data that will be used in IREE to guard against regressions. E.g. ``` python -m sharktank.models.clip.export_toy_text_model_iree_test_data \ --output-path-prefix=clip_toy_text_model ``` Refactor some of the existing tests to reuse the new export logic. --- sharktank/sharktank/models/clip/export.py | 27 ++- .../export_toy_text_model_iree_test_data.py | 30 ++++ sharktank/sharktank/models/clip/testing.py | 155 +++++++++++++++++- sharktank/sharktank/utils/io.py | 26 ++- sharktank/sharktank/utils/typing.py | 10 ++ sharktank/tests/models/clip/clip_test.py | 44 ++--- 6 files changed, 259 insertions(+), 33 deletions(-) create mode 100644 sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py create mode 100644 sharktank/sharktank/utils/typing.py diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py index 3cae3f4c4..20e1468da 100644 --- a/sharktank/sharktank/models/clip/export.py +++ b/sharktank/sharktank/models/clip/export.py @@ -11,11 +11,11 @@ CLIPEncoderLayer as HfCLIPEncoderLayer, CLIPEncoder as HfCLIPEncoder, ) -from os import PathLike import torch from ...types.theta import Theta, Dataset, torch_module_to_theta from ...layers.configs import ClipTextConfig +from ...utils.typing import AnyPath from .clip import ClipTextModel from iree.turbine.aot import FxProgramsBuilder, export @@ -50,9 +50,14 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset: return Dataset(properties=model.config.to_properties(), root_theta=model.theta) +def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: AnyPath): + dataset = clip_text_model_to_dataset(model) + dataset.save(output_path) + + def export_clip_text_model_dataset_from_hugging_face( - model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel], - output_path: Union[str, PathLike], + model_or_name_or_path: Union[AnyPath, transformers.CLIPTextModel], + output_path: AnyPath, dtype: Optional[torch.dtype] = None, ): if isinstance(model_or_name_or_path, transformers.CLIPTextModel): @@ -67,7 +72,7 @@ def export_clip_text_model_dataset_from_hugging_face( def export_clip_text_model_mlir( - model: Union[ClipTextModel, PathLike], + model: Union[ClipTextModel, AnyPath], batch_sizes: list[int], mlir_output_path: str, ): @@ -99,3 +104,17 @@ def _( output = export(fxb, import_symbolic_shape_expressions=True) output.save_mlir(mlir_output_path) + + +def export_clip_text_model_to_iree( + model: ClipTextModel, + batch_sizes: list[int], + mlir_output_path: AnyPath, + parameters_output_path: AnyPath, +): + export_clip_text_model_iree_parameters(model, parameters_output_path) + export_clip_text_model_mlir( + model=parameters_output_path, + batch_sizes=batch_sizes, + mlir_output_path=mlir_output_path, + ) diff --git a/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py b/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py new file mode 100644 index 000000000..b4b5ec77d --- /dev/null +++ b/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py @@ -0,0 +1,30 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from argparse import ArgumentParser +from typing import Optional + +from .testing import export_clip_toy_text_model_default_iree_test_data + + +def main(args: Optional[list[str]] = None): + parser = ArgumentParser( + description=( + "Export test data for toy-sized CLIP text model." + " This program MLIR, parameters sample input and expected output." + " Exports float32 and bfloat16 model variants." + " The expected output is always in float32 precision." + ) + ) + parser.add_argument( + "--output-path-prefix", type=str, default=f"clip_toy_text_model" + ) + args = parser.parse_args(args=args) + export_clip_toy_text_model_default_iree_test_data(args.output_path_prefix) + + +if __name__ == "__main__": + main() diff --git a/sharktank/sharktank/models/clip/testing.py b/sharktank/sharktank/models/clip/testing.py index 87634c220..0ffa3b594 100644 --- a/sharktank/sharktank/models/clip/testing.py +++ b/sharktank/sharktank/models/clip/testing.py @@ -4,14 +4,161 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ...layers.configs.llm_configs import ClipTextConfig -from ...types.theta import Theta -from .export import hugging_face_clip_text_model_to_theta +import functools import torch +from os import PathLike +from typing import Union, Optional +from copy import copy +from iree.turbine.aot.params import ParameterArchiveBuilder + +from ...layers.configs.llm_configs import ClipTextConfig +from .clip import ClipTextModel +from ...types.theta import Theta, Dataset +from ...types.tensors import dtype_to_serialized_short_name +from ...utils.typing import AnyPath +from ...utils.io import save_tensor_as_irpa +from .export import ( + clip_text_model_to_dataset, + hugging_face_clip_text_model_to_theta, + export_clip_text_model_to_iree, +) +from ...transforms.dataset import set_float_dtype + + +def clip_toy_text_model_config(dtype: torch.dtype) -> ClipTextConfig: + num_attention_heads = 5 + vocab_size = 11 + return ClipTextConfig( + vocab_size=vocab_size, + hidden_size=13 * num_attention_heads, + intermediate_size=7, + projection_dim=3, + num_attention_heads=num_attention_heads, + max_position_embeddings=17, + layer_norm_eps=1e-4, + num_hidden_layers=2, + bos_token_id=vocab_size - 2, + eos_token_id=vocab_size - 1, + dtype=dtype, + ) + + +def export_clip_toy_text_model_default_iree_test_data(output_path_prefix: str): + # We want to always export the same without interfering with RNG for the rest of + # the program. + rng_state = torch.get_rng_state() + torch.random.manual_seed(12345) + + reference_dtype = torch.float32 + target_dtypes = [torch.float32, torch.bfloat16] + target_iree_parameters_output_paths = [] + target_mlir_output_paths = [] + batch_size = 4 + for dtype in target_dtypes: + prefix = f"{output_path_prefix}_{dtype_to_serialized_short_name(dtype)}" + target_iree_parameters_output_paths.append(f"{prefix}_parameters.irpa") + target_mlir_output_paths.append(f"{prefix}.mlir") + call_prefix = f"{output_path_prefix}_forward_bs{batch_size}" + input_ids_output_path = f"{call_prefix}_arg0_input_ids.irpa" + expected_last_hidden_state_output_path = ( + f"{call_prefix}_expected_result0_last_hidden_state_" + f"{dtype_to_serialized_short_name(reference_dtype)}.irpa" + ) + export_clip_toy_text_model_iree_test_data( + reference_dtype=reference_dtype, + target_dtypes=target_dtypes, + batch_size=batch_size, + input_ids_output_path=input_ids_output_path, + expected_last_hidden_state_output_path=expected_last_hidden_state_output_path, + target_iree_parameters_output_paths=target_iree_parameters_output_paths, + target_mlir_output_paths=target_mlir_output_paths, + ) + + torch.set_rng_state(rng_state) + + +def export_clip_toy_text_model_iree_test_data( + reference_dtype: torch.dtype, + target_dtypes: list[torch.dtype], + batch_size: int, + target_iree_parameters_output_paths: list[AnyPath], + target_mlir_output_paths: list[AnyPath], + input_ids_output_path: AnyPath, + expected_last_hidden_state_output_path: AnyPath, +): + reference_config = clip_toy_text_model_config(reference_dtype) + input_ids = make_random_input_token_sequences( + batch_size=batch_size, config=reference_config + ) + reference_theta = make_clip_text_model_random_theta(reference_config) + reference_model = ClipTextModel(theta=reference_theta, config=reference_config) + for i, ( + target_dtype, + target_iree_parameters_output_path, + target_mlir_output_path, + ) in enumerate( + zip( + target_dtypes, + target_iree_parameters_output_paths, + target_mlir_output_paths, + strict=True, + ) + ): + export_clip_text_model_iree_test_data( + reference_model=reference_model, + target_dtype=target_dtype, + input_ids=input_ids, + target_iree_parameters_output_path=target_iree_parameters_output_path, + target_mlir_output_path=target_mlir_output_path, + input_ids_output_path=input_ids_output_path if i == 0 else None, + expected_last_hidden_state_output_path=expected_last_hidden_state_output_path + if i == 0 + else None, + ) + + +def export_clip_text_model_iree_test_data( + reference_model: ClipTextModel, + target_dtype: torch.dtype, + input_ids: torch.LongTensor, + target_mlir_output_path: AnyPath, + target_iree_parameters_output_path: AnyPath, + input_ids_output_path: Optional[AnyPath] = None, + expected_last_hidden_state_output_path: Optional[AnyPath] = None, +): + batch_size = input_ids.shape[0] + reference_dataset = clip_text_model_to_dataset(reference_model) + target_config = copy(reference_model.config) + target_config.dtype = target_dtype + target_dataset = Dataset( + root_theta=reference_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=torch.bfloat16) + ), + properties=target_config.to_properties(), + ) + target_model = ClipTextModel(theta=target_dataset.root_theta, config=target_config) + export_clip_text_model_to_iree( + target_model, + batch_sizes=[batch_size], + mlir_output_path=target_mlir_output_path, + parameters_output_path=target_iree_parameters_output_path, + ) + + if input_ids_output_path is not None: + save_tensor_as_irpa(input_ids, input_ids_output_path) + + if expected_last_hidden_state_output_path is None: + return + + expected_last_hidden_state = reference_model(input_ids=input_ids)[ + "last_hidden_state" + ] + save_tensor_as_irpa( + expected_last_hidden_state, expected_last_hidden_state_output_path + ) def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta: - from transformers import CLIPTextConfig as HfCLIPTextConfig from transformers import CLIPTextModel as HfCLIPTextModel hf_config = config.to_hugging_face_clip_text_model_config() diff --git a/sharktank/sharktank/utils/io.py b/sharktank/sharktank/utils/io.py index ac2480846..5f1f3905d 100644 --- a/sharktank/sharktank/utils/io.py +++ b/sharktank/sharktank/utils/io.py @@ -5,10 +5,11 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from pathlib import Path +import torch -from iree.turbine.aot import ( - ParameterArchiveBuilder, -) +from iree.turbine.aot import ParameterArchiveBuilder, ParameterArchive + +from .typing import AnyPath class ShardedArchiveBuilder(ParameterArchiveBuilder): @@ -49,3 +50,22 @@ def path_for_rank(path: Path, rank: int): /tmp/foobar.rank0.irpa """ return path.with_suffix(f".rank{rank}{path.suffix}") + + +def save_tensor_as_irpa(tensor: torch.Tensor, path: AnyPath): + """Save a single tensor into an IRPA file.""" + param_builder = ParameterArchiveBuilder() + param_builder.add_tensor("", tensor) + param_builder.save(path) + + +def load_irpa_as_tensor(tensor: torch.Tensor, path: AnyPath, **kwargs): + """Load a tensor form an IRPA file that holds only one tensor.""" + params = ParameterArchive(path, **kwargs) + items = params.items() + if len(items) != 1: + raise ValueError( + f'Too many items {len(items)} in IRPA file "{path}".' + " Only a single tensor was expected." + ) + return items[0][1].as_tensor() diff --git a/sharktank/sharktank/utils/typing.py b/sharktank/sharktank/utils/typing.py new file mode 100644 index 000000000..386a05032 --- /dev/null +++ b/sharktank/sharktank/utils/typing.py @@ -0,0 +1,10 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Union +from os import PathLike + +AnyPath = Union[str, PathLike] diff --git a/sharktank/tests/models/clip/clip_test.py b/sharktank/tests/models/clip/clip_test.py index 99af4ba6f..40dbab985 100644 --- a/sharktank/tests/models/clip/clip_test.py +++ b/sharktank/tests/models/clip/clip_test.py @@ -59,6 +59,7 @@ from sharktank.models.clip.testing import ( make_random_input_token_sequences, make_clip_text_model_random_theta, + export_clip_text_model_iree_test_data, ) from sharktank.models.clip import ( ClipAttention, @@ -96,6 +97,11 @@ def testSmokeExportLargeF32FromHuggingFace(self): huggingface_repo_id, output_path ) + def testSmokeExportToyIreeTestData(self): + from sharktank.models.clip.export_toy_text_model_iree_test_data import main + + main([f"--output-path-prefix={self.path_prefix}clip_toy_text_model"]) + @with_clip_data def testCompareLargeIreeF32AgainstTorchEagerF32(self): self.runTestCompareIreeAgainstPretrainedTorchEager( @@ -147,30 +153,24 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens( f"{self.path_prefix}{file_artifact_prefix_name}_{target_dtype_name}" ) - target_config = copy(reference_model.config) - target_config.dtype = target_dtype - reference_dataset = clip_text_model_to_dataset(reference_model) - target_dataset = Dataset( - root_theta=reference_dataset.root_theta.transform( - functools.partial(set_float_dtype, dtype=target_config.dtype) - ), - properties=target_config.to_properties(), - ) - parameters_path = f"{target_model_path_prefix}.irpa" - if not self.caching or not os.path.exists(parameters_path): - target_dataset.save(parameters_path) - - dataset = Dataset.load(parameters_path) - target_config = ClipTextConfig.from_properties(dataset.properties) input_args = OrderedDict([("input_ids", input_ids)]) batch_size = input_ids.shape[0] - mlir_path = f"{target_model_path_prefix}.mlir" - if not self.caching or not os.path.exists(mlir_path): - export_clip_text_model_mlir( - parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path + + if ( + not self.caching + or not os.path.exists(mlir_path) + or not os.path.exists(parameters_path) + ): + export_clip_text_model_iree_test_data( + reference_model=reference_model, + target_dtype=target_dtype, + input_ids=input_ids, + target_mlir_output_path=mlir_path, + target_iree_parameters_output_path=parameters_path, ) + iree_module_path = f"{target_model_path_prefix}.vmfb" if not self.caching or not os.path.exists(iree_module_path): iree.compiler.compile_file( @@ -211,11 +211,11 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens( for i in range(len(expected_outputs)) ] - actual_last_hidden_states = actual_outputs[0] - expected_last_hidden_states = expected_outputs[0] + actual_last_hidden_state = actual_outputs[0] + expected_last_hidden_state = expected_outputs[0] assert_text_encoder_state_close( - actual_last_hidden_states, expected_last_hidden_states, atol + actual_last_hidden_state, expected_last_hidden_state, atol ) def runTestCompareRandomModelIreeAgainstTorch(