diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py
index 1513f364a..8a443e6ca 100644
--- a/sharktank/sharktank/layers/configs/llm_configs.py
+++ b/sharktank/sharktank/layers/configs/llm_configs.py
@@ -18,6 +18,8 @@
 from typing import Any, Optional
 
 import torch
 
+from ...types.tensors import serialized_name_to_dtype, dtype_to_serialized_name
+
 __all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"]
 
@@ -287,9 +289,10 @@ class ClipTextConfig:
     output_attentions: bool = False
     output_hidden_states: bool = False
     use_return_dict: bool = True
+    dtype: torch.dtype = torch.float32
 
     @staticmethod
-    def from_transformers_clip_text_config(
+    def from_hugging_face_clip_text_model_config(
         config: "transformers.CLIPTextConfig",
     ) -> "ClipTextConfig":
         return ClipTextConfig(
@@ -308,7 +311,30 @@ def from_transformers_clip_text_config(
             output_attentions=config.output_attentions,
             output_hidden_states=config.output_hidden_states,
             use_return_dict=config.use_return_dict,
+            dtype=config.torch_dtype or torch.float32,
         )
 
-    def as_properties(self) -> dict[str, Any]:
-        return asdict(self)
+    def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig":
+        kwargs = self.to_properties()
+        kwargs["torch_dtype"] = kwargs["dtype"]
+        del kwargs["dtype"]
+        kwargs["return_dict"] = kwargs["use_return_dict"]
+        del kwargs["use_return_dict"]
+        from transformers import CLIPTextConfig
+
+        return CLIPTextConfig(**kwargs)
+
+    @staticmethod
+    def from_properties(properties: dict[str, Any]) -> "ClipTextConfig":
+        kwargs = dict(properties)
+        kwargs.pop("SHARK_DATASET_VERSION")
+        if "dtype" in kwargs and kwargs["dtype"] is not None:
+            kwargs["dtype"] = serialized_name_to_dtype(kwargs["dtype"])
+
+        return ClipTextConfig(**kwargs)
+
+    def to_properties(self) -> dict[str, Any]:
+        res = asdict(self)
+        if self.dtype is not None:
+            res["dtype"] = dtype_to_serialized_name(self.dtype)
+        return res
diff --git a/sharktank/sharktank/models/clip/clip.py b/sharktank/sharktank/models/clip/clip.py
index 29734e9f1..0593c940b 100644
--- a/sharktank/sharktank/models/clip/clip.py
+++ b/sharktank/sharktank/models/clip/clip.py
@@ -21,10 +21,10 @@
 )
 from collections import OrderedDict
 
-from ...layers import BaseLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer
+from ...layers import ThetaLayer, LinearLayer, LayerNorm, TokenEmbeddingLayer
 from ... import ops
 from ...types.theta import Theta, Dataset
-from ...types.tensors import DefaultPrimitiveTensor
+from ...types.tensors import AnyTensor, DefaultPrimitiveTensor
 from ...layers.configs import ClipTextConfig
 from ...layers.activations import ACT2FN
 
@@ -68,11 +68,11 @@ def forward(
         return embeddings
 
 
-class ClipAttention(BaseLayer):
+class ClipAttention(ThetaLayer):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
@@ -182,9 +182,9 @@ def forward(
         return attn_output, attn_weights_reshaped
 
 
-class ClipMlp(BaseLayer):
+class ClipMlp(ThetaLayer):
     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.config = config
         self.activation_fn = ACT2FN[config.hidden_act]
         self.fc1 = LinearLayer(theta("fc1"))
@@ -197,9 +197,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
-class ClipEncoderLayer(BaseLayer):
+class ClipEncoderLayer(ThetaLayer):
     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.embed_dim = config.hidden_size
         self.self_attn = ClipAttention(theta=theta("self_attn"), config=config)
         self.layer_norm1 = LayerNorm(
@@ -251,14 +251,14 @@ def forward(
         return outputs
 
 
-class ClipEncoder(BaseLayer):
+class ClipEncoder(ThetaLayer):
     """
     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
     [`ClipEncoderLayer`].
     """
 
     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.config = config
         self.layers = nn.ModuleList(
             [
@@ -356,9 +356,9 @@ def forward(
         )
 
 
-class ClipTextTransformer(nn.Module):
+class ClipTextTransformer(ThetaLayer):
     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.config = config
         embed_dim = config.hidden_size
         self.embeddings = ClipTextEmbeddings(theta=theta("embeddings"), config=config)
@@ -475,9 +475,9 @@ def forward(
         )
 
 
-class ClipTextModel(BaseLayer):
+class ClipTextModel(ThetaLayer):
     def __init__(self, theta: Theta, config: ClipTextConfig):
-        super().__init__()
+        super().__init__(theta)
         self.config = config
         self.text_model = ClipTextTransformer(theta=theta("text_model"), config=config)
 
@@ -487,6 +487,25 @@ def get_input_embeddings(self) -> nn.Module:
     def set_input_embeddings(self, value):
         self.text_model.embeddings.token_embedding = value
 
+    def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]:
+        input_ids = (
+            torch.arange(
+                start=0,
+                end=batch_size * self.config.max_position_embeddings,
+                dtype=torch.long,
+            )
+            % self.config.vocab_size
+        )
+        input_ids = input_ids.reshape([batch_size, self.config.max_position_embeddings])
+        return OrderedDict(
+            [
+                (
+                    "input_ids",
+                    input_ids,
+                )
+            ]
+        )
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
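
The config properties round-trip and the new `sample_inputs` hook are meant to be used together when preparing an export. A minimal sketch, assuming a previously saved dataset at a hypothetical IRPA path:

    # Load a saved dataset, rebuild the model, and ask it for export-ready
    # sample arguments. "/tmp/clip_text.irpa" is a hypothetical path.
    from sharktank.types import Dataset
    from sharktank.layers.configs import ClipTextConfig
    from sharktank.models.clip import ClipTextModel

    dataset = Dataset.load("/tmp/clip_text.irpa")
    config = ClipTextConfig.from_properties(dataset.properties)
    model = ClipTextModel(theta=dataset.root_theta, config=config)

    # One entry named "input_ids" with shape [4, max_position_embeddings].
    sample_args = model.sample_inputs(batch_size=4)
    print(sample_args["input_ids"].shape)
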
diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py
index 83bda2cbe..3cae3f4c4 100644
--- a/sharktank/sharktank/models/clip/export.py
+++ b/sharktank/sharktank/models/clip/export.py
@@ -4,54 +4,98 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Union
+from typing import Optional, Union
 import transformers
 from transformers.models.clip.modeling_clip import (
-    CLIPAttention as TransformersCLIPAttention,
-    CLIPEncoderLayer as TransformersCLIPEncoderLayer,
-    CLIPEncoder as TransformersCLIPEncoder,
+    CLIPAttention as HfCLIPAttention,
+    CLIPEncoderLayer as HfCLIPEncoderLayer,
+    CLIPEncoder as HfCLIPEncoder,
 )
 from os import PathLike
 import torch
 
 from ...types.theta import Theta, Dataset, torch_module_to_theta
-from ...types.tensors import DefaultPrimitiveTensor
 from ...layers.configs import ClipTextConfig
+from .clip import ClipTextModel
+from iree.turbine.aot import FxProgramsBuilder, export
 
 
-def transformers_clip_attention_to_theta(model: TransformersCLIPAttention) -> Theta:
+def hugging_face_clip_attention_to_theta(model: HfCLIPAttention) -> Theta:
     return torch_module_to_theta(model)
 
 
-def transformers_clip_encoder_layer_to_theta(model: TransformersCLIPEncoder) -> Theta:
+def hugging_face_clip_encoder_layer_to_theta(model: HfCLIPEncoderLayer) -> Theta:
     return torch_module_to_theta(model)
 
 
-def transformers_clip_encoder_to_theta(model: TransformersCLIPEncoderLayer) -> Theta:
+def hugging_face_clip_encoder_to_theta(model: HfCLIPEncoder) -> Theta:
     return torch_module_to_theta(model)
 
 
-def transformers_clip_text_model_to_theta(model: transformers.CLIPTextModel) -> Theta:
+def hugging_face_clip_text_model_to_theta(model: transformers.CLIPTextModel) -> Theta:
     return torch_module_to_theta(model)
 
 
-def transformers_clip_text_model_to_dataset(
+def hugging_face_clip_text_model_to_dataset(
     model: transformers.CLIPTextModel,
 ) -> Dataset:
-    config = ClipTextConfig.from_transformers_clip_text_config(model.config)
-    properties = config.as_properties()
-    theta = transformers_clip_text_model_to_theta(model)
+    config = ClipTextConfig.from_hugging_face_clip_text_model_config(model.config)
+    properties = config.to_properties()
+    theta = hugging_face_clip_text_model_to_theta(model)
     theta.rename_tensors_to_paths()
     return Dataset(properties, theta)
 
 
+def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset:
+    return Dataset(properties=model.config.to_properties(), root_theta=model.theta)
+
+
 def export_clip_text_model_dataset_from_hugging_face(
     model_or_name_or_path: Union[str, PathLike, transformers.CLIPTextModel],
     output_path: Union[str, PathLike],
+    dtype: Optional[torch.dtype] = None,
 ):
     if isinstance(model_or_name_or_path, transformers.CLIPTextModel):
+        assert dtype is None
         model = model_or_name_or_path
     else:
-        model = transformers.CLIPTextModel.from_pretrained(model_or_name_or_path)
-    dataset = transformers_clip_text_model_to_dataset(model)
+        model = transformers.CLIPTextModel.from_pretrained(
+            model_or_name_or_path, torch_dtype=dtype
+        )
+    dataset = hugging_face_clip_text_model_to_dataset(model)
     dataset.save(output_path)
+
+
+def export_clip_text_model_mlir(
+    model: Union[ClipTextModel, PathLike],
+    batch_sizes: list[int],
+    mlir_output_path: str,
+):
+    """
+    Args:
+      model: either the torch module or a path to a GGUF/IRPA file.
+    """
+    if not isinstance(model, ClipTextModel):
+        dataset = Dataset.load(model)
+        config = ClipTextConfig.from_properties(dataset.properties)
+        model = ClipTextModel(theta=dataset.root_theta, config=config)
+
+    fxb = FxProgramsBuilder(model)
+
+    for batch_size in batch_sizes:
+        sample_inputs = model.sample_inputs(batch_size)
+
+        @fxb.export_program(
+            name=f"forward_bs{batch_size}",
+            args=tuple(sample_inputs.values()),
+            dynamic_shapes=None,
+            strict=False,
+        )
+        def _(
+            model,
+            input_ids,
+        ):
+            return model(input_ids)
+
+    output = export(fxb, import_symbolic_shape_expressions=True)
+    output.save_mlir(mlir_output_path)
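
A minimal sketch of the two new export entry points chained together; the repository id matches the one used in the tests, while the output paths are only illustrative:

    from sharktank.models.clip.export import (
        export_clip_text_model_dataset_from_hugging_face,
        export_clip_text_model_mlir,
    )

    # Save the Hugging Face checkpoint as an IRPA dataset, then export MLIR with
    # one entry point per batch size (named forward_bs<N>).
    export_clip_text_model_dataset_from_hugging_face(
        "openai/clip-vit-large-patch14", "/tmp/clip_text.irpa"
    )
    export_clip_text_model_mlir(
        "/tmp/clip_text.irpa",
        batch_sizes=[4],
        mlir_output_path="/tmp/clip_text.mlir",
    )
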
diff --git a/sharktank/sharktank/models/clip/testing.py b/sharktank/sharktank/models/clip/testing.py
new file mode 100644
index 000000000..87634c220
--- /dev/null
+++ b/sharktank/sharktank/models/clip/testing.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ...layers.configs.llm_configs import ClipTextConfig
+from ...types.theta import Theta
+from .export import hugging_face_clip_text_model_to_theta
+import torch
+
+
+def make_clip_text_model_random_theta(config: ClipTextConfig) -> Theta:
+    from transformers import CLIPTextConfig as HfCLIPTextConfig
+    from transformers import CLIPTextModel as HfCLIPTextModel
+
+    hf_config = config.to_hugging_face_clip_text_model_config()
+    model = HfCLIPTextModel(hf_config)
+    return hugging_face_clip_text_model_to_theta(model)
+
+
+def make_random_input_token_sequences(
+    batch_size: int, config: ClipTextConfig
+) -> torch.LongTensor:
+    sequence_lens = torch.randint(
+        low=1, high=config.max_position_embeddings + 1, size=(batch_size,)
+    )
+    sequences = torch.full(
+        size=(batch_size, config.max_position_embeddings),
+        fill_value=config.eos_token_id,
+        dtype=torch.long,
+    )
+    for batch_idx, l in enumerate(sequence_lens):
+        sequences[batch_idx][0:l] = torch.randint(
+            low=0, high=config.vocab_size - 1, size=(l,), dtype=torch.long
+        )
+    return sequences
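
A short sketch of how the new testing helpers are combined to get a toy eager model and matching inputs; the configuration values mirror the toy config used in clip_test.py:

    import torch
    from sharktank.layers.configs import ClipTextConfig
    from sharktank.models.clip import ClipTextModel
    from sharktank.models.clip.testing import (
        make_clip_text_model_random_theta,
        make_random_input_token_sequences,
    )

    config = ClipTextConfig(
        vocab_size=11,
        hidden_size=13 * 5,
        intermediate_size=7,
        projection_dim=3,
        num_attention_heads=5,
        max_position_embeddings=17,
        layer_norm_eps=1e-4,
        num_hidden_layers=2,
        bos_token_id=9,
        eos_token_id=10,
        dtype=torch.float32,
    )
    # Random weights and random, eos-padded token sequences for the toy model.
    model = ClipTextModel(theta=make_clip_text_model_random_theta(config), config=config)
    input_ids = make_random_input_token_sequences(batch_size=4, config=config)
    outputs = model(input_ids)
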
diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py
index 143ede184..021925169 100644
--- a/sharktank/sharktank/types/theta.py
+++ b/sharktank/sharktank/types/theta.py
@@ -214,12 +214,14 @@ def rename_tensors_to_paths(self):
 
 
 def torch_module_to_theta(module: torch.nn.Module) -> Theta:
-    return Theta(
+    res = Theta(
         {
             name: DefaultPrimitiveTensor(data=param)
             for name, param in module.named_parameters()
         }
     )
+    res.rename_tensors_to_paths()
+    return res
 
 
 def flat_to_nested_dict(flat: dict[str, Any]) -> dict[str, Any]:
diff --git a/sharktank/sharktank/utils/math.py b/sharktank/sharktank/utils/math.py
index 3723f67dd..639f559d2 100644
--- a/sharktank/sharktank/utils/math.py
+++ b/sharktank/sharktank/utils/math.py
@@ -19,7 +19,7 @@ def round_up_to_multiple_of(x: Number, multiple: Number) -> Number:
 
 def cosine_similarity(
     a: torch.Tensor, b: torch.Tensor, /, *, dim: Optional[Union[int, tuple[int]]] = None
-) -> float:
+) -> torch.Tensor:
     """Compute cosine similarity over dimensions dim.
 
     If dim is none computes over all dimensions."""
 
     dot_product = torch.sum(a * b, dim=dim)
diff --git a/sharktank/sharktank/utils/testing.py b/sharktank/sharktank/utils/testing.py
index 6c81acf9e..d3cf08fd6 100644
--- a/sharktank/sharktank/utils/testing.py
+++ b/sharktank/sharktank/utils/testing.py
@@ -18,6 +18,7 @@
 import gc
 
 from ..types import *
+from .math import cosine_similarity
 
 # Range of torch.rand() is [0,1)
 # Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values
@@ -184,6 +185,36 @@ def assert_iterables_equal(
     ), f"Iterables not equal at index {i} for elements {v1} and {v2}"
 
 
+def assert_text_encoder_state_close(
+    actual: torch.Tensor, expected: torch.Tensor, atol: float
+):
+    """The cosine similarity has been suggested to compare encoder states.
+
+    Dehua Peng, Zhipeng Gui, Huayi Wu -
+    Interpreting the Curse of Dimensionality from Distance Concentration and Manifold
+    Effect (2023)
+
+    shows that cosine and all Minkowski distances suffer from the curse of
+    dimensionality.
+    The cosine similarity ignores the vector magnitudes. We can probably come up with a
+    better metric, but this is maybe good enough.
+
+    The function expects that the last dimension is the features per token.
+    It will compute the cosine similarity for each token.
+    """
+    cosine_similarity_per_token = cosine_similarity(
+        actual,
+        expected,
+        dim=-1,
+    )
+    torch.testing.assert_close(
+        cosine_similarity_per_token,
+        torch.ones_like(cosine_similarity_per_token),
+        atol=atol,
+        rtol=0,
+    )
+
+
 SHARKTANK_TEST_SKIP_ENV_VAR = "SHARKTANK_TEST_SKIP"
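
The new assertion compares encoder states through a per-token cosine similarity rather than element-wise closeness. A small sketch with made-up tensors standing in for hidden states:

    import torch
    from sharktank.utils.math import cosine_similarity
    from sharktank.utils.testing import assert_text_encoder_state_close

    # [batch, tokens, features]; the values here are synthetic.
    expected = torch.randn(2, 8, 16)
    actual = expected + 1e-4 * torch.randn_like(expected)

    # One similarity value per token, all expected to be close to 1.0.
    per_token = cosine_similarity(actual, expected, dim=-1)
    assert per_token.shape == (2, 8)

    assert_text_encoder_state_close(actual, expected, atol=1e-3)
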
diff --git a/sharktank/tests/models/clip/clip_test.py b/sharktank/tests/models/clip/clip_test.py
index 409999797..99af4ba6f 100644
--- a/sharktank/tests/models/clip/clip_test.py
+++ b/sharktank/tests/models/clip/clip_test.py
@@ -4,37 +4,61 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from collections import OrderedDict
 import functools
+import iree.compiler
+import os
 from parameterized import parameterized
+from copy import copy
 import pytest
 import torch
 from torch.utils._pytree import tree_map
 from typing import Optional
 from unittest import TestCase
 import transformers
-from transformers import CLIPTextModel as TransformersCLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel as HfCLIPTextModel, CLIPTokenizer
 from transformers.models.clip.modeling_clip import (
-    CLIPAttention as TransformersCLIPAttention,
-    CLIPEncoderLayer as TransformersCLIPEncoderLayer,
-    CLIPEncoder as TransformersCLIPEncoder,
+    CLIPAttention as HfCLIPAttention,
+    CLIPEncoderLayer as HfCLIPEncoderLayer,
+    CLIPEncoder as HfCLIPEncoder,
 )
 
-from sharktank.types import DefaultPrimitiveTensor
+from sharktank.utils.iree import (
+    get_iree_devices,
+    load_iree_module,
+    run_iree_module_function,
+    prepare_iree_module_function_args,
+    call_torch_module_function,
+    flatten_for_iree_signature,
+    iree_to_torch,
+)
+from sharktank.types import (
+    DefaultPrimitiveTensor,
+    dtype_to_serialized_short_name,
+    Dataset,
+)
 from sharktank.transforms.dataset import set_float_dtype
 from sharktank.utils.hf_datasets import get_dataset
-from sharktank.utils.math import cosine_similarity
 from sharktank.utils.testing import (
+    assert_text_encoder_state_close,
     make_rand_torch,
     make_random_mask,
     TempDirTestBase,
     test_prompts,
 )
 from sharktank.models.clip.export import (
+    export_clip_text_model_mlir,
     export_clip_text_model_dataset_from_hugging_face,
-    transformers_clip_attention_to_theta,
-    transformers_clip_encoder_layer_to_theta,
-    transformers_clip_encoder_to_theta,
-    transformers_clip_text_model_to_theta,
+    hugging_face_clip_attention_to_theta,
+    hugging_face_clip_encoder_layer_to_theta,
+    hugging_face_clip_encoder_to_theta,
+    hugging_face_clip_text_model_to_dataset,
+    hugging_face_clip_text_model_to_theta,
+    clip_text_model_to_dataset,
+)
+from sharktank.models.clip.testing import (
+    make_random_input_token_sequences,
+    make_clip_text_model_random_theta,
 )
 from sharktank.models.clip import (
     ClipAttention,
@@ -48,21 +72,244 @@
 with_clip_data = pytest.mark.skipif("not config.getoption('with_clip_data')")
 
 
-@pytest.mark.usefixtures("path_prefix")
-class ClipExportTest(TempDirTestBase):
+@pytest.mark.usefixtures("caching", "path_prefix")
+class ClipTextIreeTest(TempDirTestBase):
     def setUp(self):
         super().setUp()
+        torch.random.manual_seed(12345)
         if self.path_prefix is None:
             self.path_prefix = f"{self._temp_dir}/"
 
     @with_clip_data
     def testSmokeExportLargeF32FromHuggingFace(self):
-        repo_id = "openai/clip-vit-large-patch14"
+        huggingface_repo_id = "openai/clip-vit-large-patch14"
+        huggingface_repo_id_as_path = (
+            f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}"
+        )
+        get_dataset(
+            huggingface_repo_id,
+        ).download()
+        target_dtype_name = dtype_to_serialized_short_name(torch.float32)
+        target_model_path_prefix = f"{self.path_prefix}{huggingface_repo_id_as_path}_text_model_{target_dtype_name}"
+        output_path = f"{target_model_path_prefix}.irpa"
+        export_clip_text_model_dataset_from_hugging_face(
+            huggingface_repo_id, output_path
+        )
+
+    @with_clip_data
+    def testCompareLargeIreeF32AgainstTorchEagerF32(self):
+        self.runTestCompareIreeAgainstPretrainedTorchEager(
+            "openai/clip-vit-large-patch14",
+            reference_dtype=torch.float32,
+            target_dtype=torch.float32,
+            atol=1e-5,
+        )
+
+    @with_clip_data
+    def testCompareLargeIreeBf16AgainstTorchEagerF32(self):
+        self.runTestCompareIreeAgainstPretrainedTorchEager(
+            "openai/clip-vit-large-patch14",
+            reference_dtype=torch.float32,
+            target_dtype=torch.bfloat16,
+            # The observed error is 1.43e-2. We leave a bit of margin.
+            atol=3e-3,
+        )
+
+    @with_clip_data
+    def testCompareToyModelIreeF32AgainstTorchEagerF32(self):
+        self.runTestCompareToyModelIreeAgainstTorch(
+            reference_dtype=torch.float32, target_dtype=torch.float32, atol=1e-5
+        )
+
+    @with_clip_data
+    def testCompareToyModelIreeBf16AgainstTorchEagerF32(self):
+        self.runTestCompareToyModelIreeAgainstTorch(
+            reference_dtype=torch.float32, target_dtype=torch.bfloat16, atol=1e-3
+        )
+
+    @torch.no_grad()
+    def runTestCompareIreeAgainstTorchEagerWithInputTokens(
+        self,
+        reference_model: ClipTextModel,
+        target_dtype: torch.dtype,
+        input_ids: torch.LongTensor,
+        atol: float,
+        file_artifact_prefix_name: str,
+    ):
+        reference_dtype_name = dtype_to_serialized_short_name(
+            reference_model.config.dtype
+        )
+        target_dtype_name = dtype_to_serialized_short_name(target_dtype)
+        reference_model_path_prefix = (
+            f"{self.path_prefix}{file_artifact_prefix_name}_{reference_dtype_name}"
+        )
+        target_model_path_prefix = (
+            f"{self.path_prefix}{file_artifact_prefix_name}_{target_dtype_name}"
+        )
+
+        target_config = copy(reference_model.config)
+        target_config.dtype = target_dtype
+        reference_dataset = clip_text_model_to_dataset(reference_model)
+        target_dataset = Dataset(
+            root_theta=reference_dataset.root_theta.transform(
+                functools.partial(set_float_dtype, dtype=target_config.dtype)
+            ),
+            properties=target_config.to_properties(),
+        )
+
+        parameters_path = f"{target_model_path_prefix}.irpa"
+        if not self.caching or not os.path.exists(parameters_path):
+            target_dataset.save(parameters_path)
+
+        dataset = Dataset.load(parameters_path)
+        target_config = ClipTextConfig.from_properties(dataset.properties)
+        input_args = OrderedDict([("input_ids", input_ids)])
+        batch_size = input_ids.shape[0]
+
+        mlir_path = f"{target_model_path_prefix}.mlir"
+        if not self.caching or not os.path.exists(mlir_path):
+            export_clip_text_model_mlir(
+                parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path
+            )
+        iree_module_path = f"{target_model_path_prefix}.vmfb"
+        if not self.caching or not os.path.exists(iree_module_path):
+            iree.compiler.compile_file(
+                mlir_path,
+                output_file=iree_module_path,
+                extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"],
+            )
+
+        reference_result_dict = call_torch_module_function(
+            module=reference_model,
+            function_name="forward",
+            kwargs=input_args,
+            trace_path_prefix=f"{reference_model_path_prefix}_torch_",
+        )
+        expected_outputs = flatten_for_iree_signature(reference_result_dict)
+
+        iree_devices = get_iree_devices(driver="hip", device_count=1)
+        iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
+            module_path=iree_module_path,
+            devices=iree_devices,
+            parameters_path=parameters_path,
+        )
+        iree_args = prepare_iree_module_function_args(
+            args=flatten_for_iree_signature(input_args), devices=iree_devices
+        )
+        iree_result = iree_to_torch(
+            *run_iree_module_function(
+                module=iree_module,
+                vm_context=iree_vm_context,
+                args=iree_args,
+                driver="hip",
+                function_name=f"forward_bs{batch_size}",
+                trace_path_prefix=f"{target_model_path_prefix}_iree_",
+            )
+        )
+        actual_outputs = [
+            ops.to(iree_result[i], dtype=expected_outputs[i].dtype)
+            for i in range(len(expected_outputs))
+        ]
+
+        actual_last_hidden_states = actual_outputs[0]
+        expected_last_hidden_states = expected_outputs[0]
+
+        assert_text_encoder_state_close(
+            actual_last_hidden_states, expected_last_hidden_states, atol
+        )
+
+    def runTestCompareRandomModelIreeAgainstTorch(
+        self,
+        reference_config: ClipTextConfig,
+        target_dtype: torch.dtype,
+        batch_size: int,
+        atol: float,
+        file_artifact_prefix_name: str,
+    ):
+        input_ids = make_random_input_token_sequences(
+            batch_size=batch_size, config=reference_config
+        )
+        reference_theta = make_clip_text_model_random_theta(reference_config)
+        reference_model = ClipTextModel(theta=reference_theta, config=reference_config)
+        self.runTestCompareIreeAgainstTorchEagerWithInputTokens(
+            reference_model=reference_model,
+            target_dtype=target_dtype,
+            input_ids=input_ids,
+            atol=atol,
+            file_artifact_prefix_name=file_artifact_prefix_name,
+        )
+
+    def runTestCompareToyModelIreeAgainstTorch(
+        self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float
+    ):
+        batch_size = 4
+        num_attention_heads = 5
+        vocab_size = 11
+        reference_config = ClipTextConfig(
+            vocab_size=vocab_size,
+            hidden_size=13 * num_attention_heads,
+            intermediate_size=7,
+            projection_dim=3,
+            num_attention_heads=num_attention_heads,
+            max_position_embeddings=17,
+            layer_norm_eps=1e-4,
+            num_hidden_layers=2,
+            bos_token_id=vocab_size - 2,
+            eos_token_id=vocab_size - 1,
+            dtype=reference_dtype,
+        )
+        file_artifact_prefix_name = "clip_text_model_toy"
+        self.runTestCompareRandomModelIreeAgainstTorch(
+            reference_config=reference_config,
+            target_dtype=target_dtype,
+            batch_size=batch_size,
+            atol=atol,
+            file_artifact_prefix_name=file_artifact_prefix_name,
+        )
+
+    def runTestCompareIreeAgainstPretrainedTorchEager(
+        self,
+        huggingface_repo_id: str,
+        reference_dtype: torch.dtype,
+        target_dtype: torch.dtype,
+        atol: Optional[float] = None,
+    ):
         get_dataset(
-            repo_id,
+            huggingface_repo_id,
         ).download()
-        output_path = f"{self.path_prefix}{repo_id.replace('/', '--')}.irpa"
-        export_clip_text_model_dataset_from_hugging_face(repo_id, output_path)
+
+        huggingface_repo_id_as_path = (
+            f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}"
+        )
+        file_artifact_prefix_name = f"{huggingface_repo_id_as_path}_text_model"
+
+        hf_model: HfCLIPTextModel = HfCLIPTextModel.from_pretrained(
+            huggingface_repo_id, torch_dtype=reference_dtype
+        )
+        reference_dataset = hugging_face_clip_text_model_to_dataset(hf_model)
+        config = ClipTextConfig.from_hugging_face_clip_text_model_config(
+            hf_model.config
+        )
+        reference_model = ClipTextModel(
+            theta=reference_dataset.root_theta, config=config
+        )
+
+        tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(huggingface_repo_id)
+        input_ids = tokenizer(
+            test_prompts,
+            truncation=True,
+            max_length=reference_model.config.max_position_embeddings,
+            padding="max_length",
+            return_tensors="pt",
+        )["input_ids"]
+
+        self.runTestCompareIreeAgainstTorchEagerWithInputTokens(
+            reference_model=reference_model,
+            target_dtype=target_dtype,
+            input_ids=input_ids,
+            atol=atol,
+            file_artifact_prefix_name=file_artifact_prefix_name,
+        )
 
 
 @pytest.mark.usefixtures("get_model_artifacts")
@@ -70,7 +317,6 @@ class ClipTextEagerTest(TestCase):
     def setUp(self):
         super().setUp()
         torch.random.manual_seed(12345)
-        torch.no_grad()
 
     def runTestCompareTorchEagerAgainstHuggingFace(
         self,
@@ -86,16 +332,14 @@ def runTestCompareTorchEagerAgainstHuggingFace(
             huggingface_repo_id,
         ).download()
 
-        reference_model: TransformersCLIPTextModel = (
-            TransformersCLIPTextModel.from_pretrained(
-                huggingface_repo_id, torch_dtype=reference_dtype
-            )
+        reference_model: HfCLIPTextModel = HfCLIPTextModel.from_pretrained(
+            huggingface_repo_id, torch_dtype=reference_dtype
         )
-        theta = transformers_clip_text_model_to_theta(reference_model)
+        theta = hugging_face_clip_text_model_to_theta(reference_model)
         theta.rename_tensors_to_paths()
         theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype))
-        config = ClipTextConfig.from_transformers_clip_text_config(
+        config = ClipTextConfig.from_hugging_face_clip_text_model_config(
             reference_model.config
         )
         model = ClipTextModel(theta, config)
 
@@ -119,16 +363,10 @@ def runTestCompareTorchEagerAgainstHuggingFace(
             actual_outputs,
         )
 
-        cosine_similarity_per_token = cosine_similarity(
+        assert_text_encoder_state_close(
             actual_outputs["last_hidden_state"],
             expected_outputs["last_hidden_state"],
-            dim=-1,
-        )
-        torch.testing.assert_close(
-            cosine_similarity_per_token,
-            torch.ones_like(cosine_similarity_per_token),
             atol=atol,
-            rtol=0,
         )
 
     @with_clip_data
@@ -146,6 +384,7 @@ def testLargeCompareTorchEagerBf16AgainstHuggingFaceF32(self):
             "openai/clip-vit-large-patch14",
             reference_dtype=torch.float32,
             target_dtype=torch.bfloat16,
+            # The observed error is 3.66e-4. We leave a bit of margin.
             atol=1e-3,
         )
 
@@ -180,15 +419,17 @@ def testCompareEagerToySizedModelAgainstTransformers(
             bos_token_id=vocab_size - 2,
             eos_token_id=vocab_size - 1,
         )
-        reference_model = TransformersCLIPTextModel(
+        reference_model = HfCLIPTextModel(
             reference_config,
         )
         reference_model.eval()
 
-        theta = transformers_clip_text_model_to_theta(reference_model)
+        theta = hugging_face_clip_text_model_to_theta(reference_model)
         theta.rename_tensors_to_paths()
         theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype))
-        config = ClipTextConfig.from_transformers_clip_text_config(reference_config)
+        config = ClipTextConfig.from_hugging_face_clip_text_model_config(
+            reference_config
+        )
         model = ClipTextModel(theta, config)
 
         input_ids = torch.randint(low=0, high=vocab_size, size=[batch_size, tgt_len])
@@ -210,7 +451,6 @@ class ClipAttentionTest(TestCase):
     def setUp(self):
         super().setUp()
         torch.random.manual_seed(12345)
-        torch.no_grad()
 
     @parameterized.expand(
         [
@@ -241,15 +481,17 @@ def testCompareEagerToySizedModelAgainstTransformers(
             projection_dim=3,
             num_attention_heads=num_attention_heads,
         )
-        reference_model = TransformersCLIPAttention(
+        reference_model = HfCLIPAttention(
            reference_config,
         )
         reference_model.eval()
 
-        theta = transformers_clip_attention_to_theta(reference_model)
+        theta = hugging_face_clip_attention_to_theta(reference_model)
         theta.rename_tensors_to_paths()
         theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype))
-        config = ClipTextConfig.from_transformers_clip_text_config(reference_config)
+        config = ClipTextConfig.from_hugging_face_clip_text_model_config(
+            reference_config
+        )
         model = ClipAttention(theta, config)
 
         reference_hidden_states = make_rand_torch(
@@ -292,7 +534,6 @@ class ClipEncoderLayerTest(TestCase):
     def setUp(self):
         super().setUp()
         torch.random.manual_seed(12345)
-        torch.no_grad()
 
     @parameterized.expand(
         [
@@ -321,15 +562,17 @@ def testCompareEagerToySizedModelAgainstTransformers(
             num_attention_heads=num_attention_heads,
             layer_norm_eps=1e-4,
         )
-        reference_model = TransformersCLIPEncoderLayer(
+        reference_model = HfCLIPEncoderLayer(
             reference_config,
         )
         reference_model.eval()
 
-        theta = transformers_clip_encoder_layer_to_theta(reference_model)
+        theta = hugging_face_clip_encoder_layer_to_theta(reference_model)
         theta.rename_tensors_to_paths()
         theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype))
-        config = ClipTextConfig.from_transformers_clip_text_config(reference_config)
+        config = ClipTextConfig.from_hugging_face_clip_text_model_config(
+            reference_config
+        )
         model = ClipEncoderLayer(theta, config)
 
         reference_hidden_states = make_rand_torch(
@@ -372,7 +615,6 @@ class ClipEncoderTest(TestCase):
     def setUp(self):
         super().setUp()
         torch.random.manual_seed(12345)
-        torch.no_grad()
 
     @parameterized.expand(
        [
@@ -402,15 +644,17 @@ def testCompareEagerToySizedModelAgainstTransformers(
             layer_norm_eps=1e-4,
             num_hidden_layers=2,
         )
-        reference_model = TransformersCLIPEncoder(
+        reference_model = HfCLIPEncoder(
             reference_config,
         )
         reference_model.eval()
 
-        theta = transformers_clip_encoder_to_theta(reference_model)
+        theta = hugging_face_clip_encoder_to_theta(reference_model)
         theta.rename_tensors_to_paths()
         theta = theta.transform(functools.partial(set_float_dtype, dtype=target_dtype))
-        config = ClipTextConfig.from_transformers_clip_text_config(reference_config)
+        config = ClipTextConfig.from_hugging_face_clip_text_model_config(
+            reference_config
+        )
         model = ClipEncoder(theta, config)
 
         reference_inputs_embeds = make_rand_torch(