diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py
index 3cae3f4c4..20e1468da 100644
--- a/sharktank/sharktank/models/clip/export.py
+++ b/sharktank/sharktank/models/clip/export.py
@@ -11,11 +11,11 @@
     CLIPEncoderLayer as HfCLIPEncoderLayer,
     CLIPEncoder as HfCLIPEncoder,
 )
-from os import PathLike
 import torch
 
 from ...types.theta import Theta, Dataset, torch_module_to_theta
 from ...layers.configs import ClipTextConfig
+from ...utils.typing import AnyPath
 from .clip import ClipTextModel
 from iree.turbine.aot import FxProgramsBuilder, export
 
@@ -50,9 +50,14 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset:
     return Dataset(properties=model.config.to_properties(), root_theta=model.theta)
 
 
+def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: AnyPath):
+    dataset = clip_text_model_to_dataset(model)
+    dataset.save(output_path)
+
+
 def export_clip_text_model_dataset_from_hugging_face(
-    model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel],
-    output_path: Union[str, PathLike],
+    model_or_name_or_path: Union[AnyPath, transformers.CLIPTextModel],
+    output_path: AnyPath,
     dtype: Optional[torch.dtype] = None,
 ):
     if isinstance(model_or_name_or_path, transformers.CLIPTextModel):
@@ -67,7 +72,7 @@ def export_clip_text_model_dataset_from_hugging_face(
 
 
 def export_clip_text_model_mlir(
-    model: Union[ClipTextModel, PathLike],
+    model: Union[ClipTextModel, AnyPath],
     batch_sizes: list[int],
     mlir_output_path: str,
 ):
@@ -99,3 +104,17 @@ def _(
 
     output = export(fxb, import_symbolic_shape_expressions=True)
     output.save_mlir(mlir_output_path)
+
+
+def export_clip_text_model_to_iree(
+    model: ClipTextModel,
+    batch_sizes: list[int],
+    mlir_output_path: AnyPath,
+    parameters_output_path: AnyPath,
+):
+    export_clip_text_model_iree_parameters(model, parameters_output_path)
+    export_clip_text_model_mlir(
+        model=parameters_output_path,
+        batch_sizes=batch_sizes,
+        mlir_output_path=mlir_output_path,
+    )
diff --git a/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py b/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py
new file mode 100644
index 000000000..b4b5ec77d
--- /dev/null
+++ b/sharktank/sharktank/models/clip/export_toy_text_model_iree_test_data.py
@@ -0,0 +1,30 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from argparse import ArgumentParser
+from typing import Optional
+
+from .testing import export_clip_toy_text_model_default_iree_test_data
+
+
+def main(args: Optional[list[str]] = None):
+    parser = ArgumentParser(
+        description=(
+            "Export test data for a toy-sized CLIP text model."
+            " This program exports MLIR, parameters, sample input and expected output."
+            " Exports float32 and bfloat16 model variants."
+            " The expected output is always in float32 precision."
+        )
+    )
+    parser.add_argument(
+        "--output-path-prefix", type=str, default="clip_toy_text_model"
+    )
+    args = parser.parse_args(args=args)
+    export_clip_toy_text_model_default_iree_test_data(args.output_path_prefix)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sharktank/sharktank/models/clip/testing.py b/sharktank/sharktank/models/clip/testing.py
index 87634c220..0ffa3b594 100644
--- a/sharktank/sharktank/models/clip/testing.py
+++ b/sharktank/sharktank/models/clip/testing.py
@@ -4,14 +4,161 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from ...layers.configs.llm_configs import ClipTextConfig
-from ...types.theta import Theta
-from .export import hugging_face_clip_text_model_to_theta
+import functools
 import torch
+from os import PathLike
+from typing import Union, Optional
+from copy import copy
+from iree.turbine.aot.params import ParameterArchiveBuilder
+
+from ...layers.configs.llm_configs import ClipTextConfig
+from .clip import ClipTextModel
+from ...types.theta import Theta, Dataset
+from ...types.tensors import dtype_to_serialized_short_name
+from ...utils.typing import AnyPath
+from ...utils.io import save_tensor_as_irpa
+from .export import (
+    clip_text_model_to_dataset,
+    hugging_face_clip_text_model_to_theta,
+    export_clip_text_model_to_iree,
+)
+from ...transforms.dataset import set_float_dtype
+
+
+def clip_toy_text_model_config(dtype: torch.dtype) -> ClipTextConfig:
+    num_attention_heads = 5
+    vocab_size = 11
+    return ClipTextConfig(
+        vocab_size=vocab_size,
+        hidden_size=13 * num_attention_heads,
+        intermediate_size=7,
+        projection_dim=3,
+        num_attention_heads=num_attention_heads,
+        max_position_embeddings=17,
+        layer_norm_eps=1e-4,
+        num_hidden_layers=2,
+        bos_token_id=vocab_size - 2,
+        eos_token_id=vocab_size - 1,
+        dtype=dtype,
+    )
+
+
+def export_clip_toy_text_model_default_iree_test_data(output_path_prefix: str):
+    # We want to always export the same test data without interfering with the
+    # RNG state of the rest of the program.
+    rng_state = torch.get_rng_state()
+    torch.random.manual_seed(12345)
+
+    reference_dtype = torch.float32
+    target_dtypes = [torch.float32, torch.bfloat16]
+    target_iree_parameters_output_paths = []
+    target_mlir_output_paths = []
+    batch_size = 4
+    for dtype in target_dtypes:
+        prefix = f"{output_path_prefix}_{dtype_to_serialized_short_name(dtype)}"
+        target_iree_parameters_output_paths.append(f"{prefix}_parameters.irpa")
+        target_mlir_output_paths.append(f"{prefix}.mlir")
+    call_prefix = f"{output_path_prefix}_forward_bs{batch_size}"
+    input_ids_output_path = f"{call_prefix}_arg0_input_ids.irpa"
+    expected_last_hidden_state_output_path = (
+        f"{call_prefix}_expected_result0_last_hidden_state_"
+        f"{dtype_to_serialized_short_name(reference_dtype)}.irpa"
+    )
+    export_clip_toy_text_model_iree_test_data(
+        reference_dtype=reference_dtype,
+        target_dtypes=target_dtypes,
+        batch_size=batch_size,
+        input_ids_output_path=input_ids_output_path,
+        expected_last_hidden_state_output_path=expected_last_hidden_state_output_path,
+        target_iree_parameters_output_paths=target_iree_parameters_output_paths,
+        target_mlir_output_paths=target_mlir_output_paths,
+    )
+
+    torch.set_rng_state(rng_state)
+
+
+def export_clip_toy_text_model_iree_test_data(
+    reference_dtype: torch.dtype,
+    target_dtypes: list[torch.dtype],
+    batch_size: int,
+    target_iree_parameters_output_paths: list[AnyPath],
+    target_mlir_output_paths: list[AnyPath],
+    input_ids_output_path: AnyPath,
+    expected_last_hidden_state_output_path: AnyPath,
+):
+    reference_config = clip_toy_text_model_config(reference_dtype)
+    input_ids = make_random_input_token_sequences(
+        batch_size=batch_size, config=reference_config
+    )
+    reference_theta = make_clip_text_model_random_theta(reference_config)
+    reference_model = ClipTextModel(theta=reference_theta, config=reference_config)
+    for i, (
+        target_dtype,
+        target_iree_parameters_output_path,
+        target_mlir_output_path,
+    ) in enumerate(
+        zip(
+            target_dtypes,
+            target_iree_parameters_output_paths,
+            target_mlir_output_paths,
+            strict=True,
+        )
+    ):
+        export_clip_text_model_iree_test_data(
+            reference_model=reference_model,
+            target_dtype=target_dtype,
+            input_ids=input_ids,
+            target_iree_parameters_output_path=target_iree_parameters_output_path,
+            target_mlir_output_path=target_mlir_output_path,
+            input_ids_output_path=input_ids_output_path if i == 0 else None,
+            expected_last_hidden_state_output_path=expected_last_hidden_state_output_path
+            if i == 0
+            else None,
+        )
+
+
+def export_clip_text_model_iree_test_data(
+    reference_model: ClipTextModel,
+    target_dtype: torch.dtype,
+    input_ids: torch.LongTensor,
+    target_mlir_output_path: AnyPath,
+    target_iree_parameters_output_path: AnyPath,
+    input_ids_output_path: Optional[AnyPath] = None,
+    expected_last_hidden_state_output_path: Optional[AnyPath] = None,
+):
+    batch_size = input_ids.shape[0]
+    reference_dataset = clip_text_model_to_dataset(reference_model)
+    target_config = copy(reference_model.config)
+    target_config.dtype = target_dtype
+    target_dataset = Dataset(
+        root_theta=reference_dataset.root_theta.transform(
+            functools.partial(set_float_dtype, dtype=target_dtype)
+        ),
+        properties=target_config.to_properties(),
+    )
+    target_model = ClipTextModel(theta=target_dataset.root_theta, config=target_config)
+    export_clip_text_model_to_iree(
+        target_model,
+        batch_sizes=[batch_size],
+        mlir_output_path=target_mlir_output_path,
+        parameters_output_path=target_iree_parameters_output_path,
+    )
+
+    if input_ids_output_path is not None:
+        save_tensor_as_irpa(input_ids, input_ids_output_path)
+
+    if expected_last_hidden_state_output_path is None:
+        return
+
+    expected_last_hidden_state = reference_model(input_ids=input_ids)[
+        "last_hidden_state"
+    ]
+    save_tensor_as_irpa(
+        expected_last_hidden_state, expected_last_hidden_state_output_path
+    )
 
 
 def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta:
-    from transformers import CLIPTextConfig as HfCLIPTextConfig
     from transformers import CLIPTextModel as HfCLIPTextModel
 
     hf_config = config.to_hugging_face_clip_text_model_config()
diff --git a/sharktank/sharktank/utils/io.py b/sharktank/sharktank/utils/io.py
index ac2480846..5f1f3905d 100644
--- a/sharktank/sharktank/utils/io.py
+++ b/sharktank/sharktank/utils/io.py
@@ -5,10 +5,11 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from pathlib import Path
+import torch
 
-from iree.turbine.aot import (
-    ParameterArchiveBuilder,
-)
+from iree.turbine.aot import ParameterArchiveBuilder, ParameterArchive
+
+from .typing import AnyPath
 
 
 class ShardedArchiveBuilder(ParameterArchiveBuilder):
@@ -49,3 +50,22 @@ def path_for_rank(path: Path, rank: int):
     /tmp/foobar.rank0.irpa
     """
     return path.with_suffix(f".rank{rank}{path.suffix}")
+
+
+def save_tensor_as_irpa(tensor: torch.Tensor, path: AnyPath):
+    """Save a single tensor into an IRPA file."""
+    param_builder = ParameterArchiveBuilder()
+    param_builder.add_tensor("", tensor)
+    param_builder.save(path)
+
+
+def load_irpa_as_tensor(path: AnyPath, **kwargs) -> torch.Tensor:
+    """Load a tensor from an IRPA file that holds only one tensor."""
+    params = ParameterArchive(path, **kwargs)
+    items = params.items()
+    if len(items) != 1:
+        raise ValueError(
+            f'Expected exactly one tensor in IRPA file "{path}",'
+            f" but found {len(items)} entries."
+        )
+    return items[0][1].as_tensor()
diff --git a/sharktank/sharktank/utils/typing.py b/sharktank/sharktank/utils/typing.py
new file mode 100644
index 000000000..386a05032
--- /dev/null
+++ b/sharktank/sharktank/utils/typing.py
@@ -0,0 +1,10 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Union
+from os import PathLike
+
+AnyPath = Union[str, PathLike]
diff --git a/sharktank/tests/models/clip/clip_test.py b/sharktank/tests/models/clip/clip_test.py
index 99af4ba6f..40dbab985 100644
--- a/sharktank/tests/models/clip/clip_test.py
+++ b/sharktank/tests/models/clip/clip_test.py
@@ -59,6 +59,7 @@
 from sharktank.models.clip.testing import (
     make_random_input_token_sequences,
     make_clip_text_model_random_theta,
+    export_clip_text_model_iree_test_data,
 )
 from sharktank.models.clip import (
     ClipAttention,
@@ -96,6 +97,11 @@ def testSmokeExportLargeF32FromHuggingFace(self):
             huggingface_repo_id, output_path
         )
 
+    def testSmokeExportToyIreeTestData(self):
+        from sharktank.models.clip.export_toy_text_model_iree_test_data import main
+
+        main([f"--output-path-prefix={self.path_prefix}clip_toy_text_model"])
+
     @with_clip_data
     def testCompareLargeIreeF32AgainstTorchEagerF32(self):
         self.runTestCompareIreeAgainstPretrainedTorchEager(
@@ -147,30 +153,24 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
             f"{self.path_prefix}{file_artifact_prefix_name}_{target_dtype_name}"
         )
 
-        target_config = copy(reference_model.config)
-        target_config.dtype = target_dtype
-        reference_dataset = clip_text_model_to_dataset(reference_model)
-        target_dataset = Dataset(
-            root_theta=reference_dataset.root_theta.transform(
-                functools.partial(set_float_dtype, dtype=target_config.dtype)
-            ),
-            properties=target_config.to_properties(),
-        )
-
         parameters_path = f"{target_model_path_prefix}.irpa"
-        if not self.caching or not os.path.exists(parameters_path):
-            target_dataset.save(parameters_path)
-
-        dataset = Dataset.load(parameters_path)
-        target_config = ClipTextConfig.from_properties(dataset.properties)
         input_args = OrderedDict([("input_ids", input_ids)])
         batch_size = input_ids.shape[0]
-
         mlir_path = f"{target_model_path_prefix}.mlir"
-        if not self.caching or not os.path.exists(mlir_path):
-            export_clip_text_model_mlir(
-                parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path
+
+        if (
+            not self.caching
+            or not os.path.exists(mlir_path)
+            or not os.path.exists(parameters_path)
+        ):
+            export_clip_text_model_iree_test_data(
+                reference_model=reference_model,
+                target_dtype=target_dtype,
+                input_ids=input_ids,
+                target_mlir_output_path=mlir_path,
+                target_iree_parameters_output_path=parameters_path,
             )
+
         iree_module_path = f"{target_model_path_prefix}.vmfb"
         if not self.caching or not os.path.exists(iree_module_path):
             iree.compiler.compile_file(
@@ -211,11 +211,11 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
             for i in range(len(expected_outputs))
         ]
 
-        actual_last_hidden_states = actual_outputs[0]
-        expected_last_hidden_states = expected_outputs[0]
+        actual_last_hidden_state = actual_outputs[0]
+        expected_last_hidden_state = expected_outputs[0]
 
         assert_text_encoder_state_close(
-            actual_last_hidden_states, expected_last_hidden_states, atol
+            actual_last_hidden_state, expected_last_hidden_state, atol
         )
 
     def runTestCompareRandomModelIreeAgainstTorch(
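
Usage sketch (illustrative, not part of the patch): the new pieces compose as follows. Build the toy text model with the helpers in sharktank.models.clip.testing, emit IRPA parameters plus MLIR with export_clip_text_model_to_iree, and write standalone tensors with save_tensor_as_irpa from sharktank.utils.io. The module paths follow the imports in this patch; the output file names and batch size below are made up for the example.

    import torch

    from sharktank.models.clip.clip import ClipTextModel
    from sharktank.models.clip.export import export_clip_text_model_to_iree
    from sharktank.models.clip.testing import (
        clip_toy_text_model_config,
        make_clip_text_model_random_theta,
    )
    from sharktank.utils.io import save_tensor_as_irpa

    # Toy-sized model, the same configuration the test helpers use.
    config = clip_toy_text_model_config(torch.float32)
    model = ClipTextModel(theta=make_clip_text_model_random_theta(config), config=config)

    # One call saves the IRPA parameters first, then exports MLIR from those saved
    # parameters, with one exported program per requested batch size.
    export_clip_text_model_to_iree(
        model,
        batch_sizes=[4],
        mlir_output_path="clip_toy_text_model.mlir",
        parameters_output_path="clip_toy_text_model_parameters.irpa",
    )

    # Individual tensors (sample inputs, expected outputs) are written as
    # single-entry IRPA files through the new utils.io helper.
    input_ids = torch.zeros(4, config.max_position_embeddings, dtype=torch.long)
    save_tensor_as_irpa(input_ids, "clip_toy_text_model_input_ids.irpa")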