Skip to content

Commit

Permalink
Add Timm support in ORTModelForImageClassification (#1578)
Browse files Browse the repository at this point in the history
* add timm support

* add timm inf

* updated file

* added tests

* formatted

* added _export method

* improved code

* add deprecation comment

* update tests

* update tests

* fix test erros

* update test list

* update test req

* update test req

* added abstract method
  • Loading branch information
mht-sharma authored Jan 8, 2024
1 parent 5017d06 commit 8ac9763
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 24 deletions.
16 changes: 6 additions & 10 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from transformers import AutoConfig, PretrainedConfig, is_tf_available, is_torch_available
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging

from ..utils import CONFIG_NAME
from ..utils.import_utils import is_onnx_available


Expand Down Expand Up @@ -1554,7 +1555,7 @@ def infer_task_from_model(
@classmethod
def infer_library_from_model(
cls,
model_name_or_path: str,
model_name_or_path: Union[str, Path],
subfolder: str = "",
revision: Optional[str] = None,
cache_dir: str = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE,
Expand Down Expand Up @@ -1597,15 +1598,10 @@ def infer_library_from_model(

if "model_index.json" in all_files:
library_name = "diffusers"
elif "config.json" in all_files:
config_path = full_model_path / "config.json"

if not full_model_path.is_dir():
config_path = huggingface_hub.hf_hub_download(
model_name_or_path, "config.json", subfolder=subfolder, revision=revision
)

model_config = PretrainedConfig.from_json_file(config_path)
elif CONFIG_NAME in all_files:
model_config = PretrainedConfig.from_pretrained(
model_name_or_path, subfolder=subfolder, revision=revision
)

if hasattr(model_config, "pretrained_cfg") or hasattr(model_config, "architecture"):
library_name = "timm"
Expand Down
44 changes: 36 additions & 8 deletions optimum/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
from typing import TYPE_CHECKING, Optional, Union

from huggingface_hub import HfApi, HfFolder
from transformers import AutoConfig, add_start_docstrings
from transformers import AutoConfig, PretrainedConfig, add_start_docstrings

from .exporters import TasksManager
from .utils import CONFIG_NAME


if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
from transformers import PreTrainedModel, TFPreTrainedModel


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -80,7 +81,7 @@ class OptimizedModel(PreTrainedModel):
base_model_prefix = "optimized_model"
config_name = CONFIG_NAME

def __init__(self, model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "PretrainedConfig"):
def __init__(self, model: Union["PreTrainedModel", "TFPreTrainedModel"], config: PretrainedConfig):
super().__init__()
self.model = model
self.config = config
Expand Down Expand Up @@ -224,7 +225,7 @@ def _load_config(
force_download: bool = False,
subfolder: str = "",
trust_remote_code: bool = False,
) -> "PretrainedConfig":
) -> PretrainedConfig:
try:
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=config_name_or_path,
Expand Down Expand Up @@ -257,7 +258,7 @@ def _load_config(
def _from_pretrained(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
Expand All @@ -273,7 +274,7 @@ def _from_pretrained(
def _from_transformers(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
Expand All @@ -285,7 +286,28 @@ def _from_transformers(
) -> "OptimizedModel":
"""Overwrite this method in subclass to define how to load your model from vanilla transformers model"""
raise NotImplementedError(
"Overwrite this method in subclass to define how to load your model from vanilla transformers model"
"`_from_transformers` method will be deprecated in a future release. Please override `_export` instead"
"to define how to load your model from vanilla transformers model"
)

@classmethod
@abstractmethod
def _export(
cls,
model_id: Union[str, Path],
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
subfolder: str = "",
local_files_only: bool = False,
trust_remote_code: bool = False,
**kwargs,
) -> "OptimizedModel":
"""Overwrite this method in subclass to define how to load your model from vanilla hugging face model"""
raise NotImplementedError(
"Overwrite this method in subclass to define how to load your model from vanilla hugging face model"
)

@classmethod
Expand All @@ -298,7 +320,7 @@ def from_pretrained(
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
subfolder: str = "",
config: Optional["PretrainedConfig"] = None,
config: Optional[PretrainedConfig] = None,
local_files_only: bool = False,
trust_remote_code: bool = False,
revision: Optional[str] = None,
Expand All @@ -325,6 +347,11 @@ def from_pretrained(
)
model_id, revision = model_id.split("@")

library_name = TasksManager.infer_library_from_model(model_id, subfolder, revision, cache_dir)

if library_name == "timm":
config = PretrainedConfig.from_pretrained(model_id, subfolder, revision)

if config is None:
if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME:
if CONFIG_NAME in os.listdir(os.path.join(model_id, subfolder)):
Expand Down Expand Up @@ -369,6 +396,7 @@ def from_pretrained(
trust_remote_code = False

from_pretrained_method = cls._from_transformers if export else cls._from_pretrained

return from_pretrained_method(
model_id=model_id,
config=config,
Expand Down
36 changes: 36 additions & 0 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,42 @@ def _from_transformers(
provider_options: Optional[Dict[str, Any]] = None,
use_io_binding: Optional[bool] = None,
task: Optional[str] = None,
) -> "ORTModel":
"""The method will be deprecated in future releases."""
return cls._export(
model_id=model_id,
config=config,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
use_auth_token=use_auth_token,
subfolder=subfolder,
local_files_only=local_files_only,
trust_remote_code=trust_remote_code,
provider=provider,
session_options=session_options,
provider_options=provider_options,
use_io_binding=use_io_binding,
task=task,
)

@classmethod
def _export(
cls,
model_id: str,
config: "PretrainedConfig",
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
subfolder: str = "",
local_files_only: bool = False,
trust_remote_code: bool = False,
provider: str = "CPUExecutionProvider",
session_options: Optional[ort.SessionOptions] = None,
provider_options: Optional[Dict[str, Any]] = None,
use_io_binding: Optional[bool] = None,
task: Optional[str] = None,
) -> "ORTModel":
if task is None:
task = cls._auto_model_to_task(cls.auto_model_class)
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
"torchaudio",
"einops",
"invisible-watermark",
"timm",
"scikit-learn",
]

QUALITY_REQUIRE = ["black~=23.1", "ruff==0.1.5"]
Expand Down
6 changes: 0 additions & 6 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,6 @@
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"regnet": "facebook/regnet-y-040",
"resnet": "microsoft/resnet-50",
"resnext26ts": "timm/resnext26ts.ra2_in1k",
"resnext50-32x4d": "timm/resnext50_32x4d.tv2_in1k",
"resnext50d-32x4d": "timm/resnext50d_32x4d.bt_in1k",
"resnext101-32x4d": "timm/resnext101_32x4d.gluon_in1k",
"resnext101-32x8d": "timm/resnext101_32x8d.tv_in1k",
"resnext101-64x4d": "timm/resnext101_64x4d.c1_in1k",
"roberta": "roberta-base",
"roformer": "junnyu/roformer_chinese_base",
"sam": "facebook/sam-vit-base",
Expand Down
66 changes: 66 additions & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import onnxruntime
import pytest
import requests
import timm
import torch
from huggingface_hub.constants import default_cache_path
from parameterized import parameterized
Expand Down Expand Up @@ -2714,16 +2715,81 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin):
"vit",
]

TIMM_SUPPORTED_ARCHITECTURES = ["default-timm-config"]

FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
ORTMODEL_CLASS = ORTModelForImageClassification
TASK = "image-classification"

def _get_model_ids(self, model_arch):
model_ids = MODEL_NAMES[model_arch]
if isinstance(model_ids, dict):
model_ids = list(model_ids.keys())
else:
model_ids = [model_ids]
return model_ids

def _get_onnx_model_dir(self, model_id, model_arch, test_name):
onnx_model_dir = self.onnx_model_dirs[test_name]
if isinstance(MODEL_NAMES[model_arch], dict):
onnx_model_dir = onnx_model_dir[model_id]

return onnx_model_dir

def test_load_vanilla_transformers_which_is_not_supported(self):
with self.assertRaises(Exception) as context:
_ = ORTModelForImageClassification.from_pretrained(MODEL_NAMES["t5"], export=True)

self.assertIn("only supports the tasks", str(context.exception))

@parameterized.expand(TIMM_SUPPORTED_ARCHITECTURES)
@pytest.mark.run_slow
@pytest.mark.timm_test
@slow
def test_compare_to_timm(self, model_arch):
model_args = {"test_name": model_arch, "model_arch": model_arch}

self._setup(model_args)

model_ids = self._get_model_ids(model_arch)
for model_id in model_ids:
onnx_model = ORTModelForImageClassification.from_pretrained(
self._get_onnx_model_dir(model_id, model_arch, model_arch)
)

self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession)
self.assertIsInstance(onnx_model.config, PretrainedConfig)

set_seed(SEED)
timm_model = timm.create_model(model_id, pretrained=True)
timm_model = timm_model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(timm_model)
transforms = timm.data.create_transform(**data_config, is_training=False)

url = (
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
)
image = Image.open(requests.get(url, stream=True).raw)
inputs = transforms(image).unsqueeze(0)

with torch.no_grad():
timm_outputs = timm_model(inputs)

for input_type in ["pt", "np"]:
if input_type == "np":
inputs = inputs.cpu().detach().numpy()
onnx_outputs = onnx_model(inputs)

self.assertIn("logits", onnx_outputs)
self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])

# compare tensor outputs
self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), timm_outputs, atol=1e-4))

gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
model_args = {"test_name": model_arch, "model_arch": model_arch}
Expand Down
37 changes: 37 additions & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,43 @@
"data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
"deberta": "hf-internal-testing/tiny-random-DebertaModel",
"deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
"default-timm-config": {
"timm/inception_v3.tf_adv_in1k": ["image-classification"],
"timm/tf_efficientnet_b0.in1k": ["image-classification"],
"timm/resnetv2_50x1_bit.goog_distilled_in1k": ["image-classification"],
"timm/cspdarknet53.ra_in1k": ["image-classification"],
"timm/cspresnet50.ra_in1k": ["image-classification"],
"timm/cspresnext50.ra_in1k": ["image-classification"],
"timm/densenet121.ra_in1k": ["image-classification"],
"timm/dla102.in1k": ["image-classification"],
"timm/dpn107.mx_in1k": ["image-classification"],
"timm/ecaresnet101d.miil_in1k": ["image-classification"],
"timm/efficientnet_b1_pruned.in1k": ["image-classification"],
"timm/inception_resnet_v2.tf_ens_adv_in1k": ["image-classification"],
"timm/fbnetc_100.rmsp_in1k": ["image-classification"],
"timm/xception41.tf_in1k": ["image-classification"],
"timm/senet154.gluon_in1k": ["image-classification"],
"timm/seresnext26d_32x4d.bt_in1k": ["image-classification"],
"timm/hrnet_w18.ms_aug_in1k": ["image-classification"],
"timm/inception_v3.gluon_in1k": ["image-classification"],
"timm/inception_v4.tf_in1k": ["image-classification"],
"timm/mixnet_s.ft_in1k": ["image-classification"],
"timm/mnasnet_100.rmsp_in1k": ["image-classification"],
"timm/mobilenetv2_100.ra_in1k": ["image-classification"],
"timm/mobilenetv3_small_050.lamb_in1k": ["image-classification"],
"timm/nasnetalarge.tf_in1k": ["image-classification"],
"timm/tf_efficientnet_b0.ns_jft_in1k": ["image-classification"],
"timm/pnasnet5large.tf_in1k": ["image-classification"],
"timm/regnetx_002.pycls_in1k": ["image-classification"],
"timm/regnety_002.pycls_in1k": ["image-classification"],
"timm/res2net101_26w_4s.in1k": ["image-classification"],
"timm/res2next50.in1k": ["image-classification"],
"timm/resnest101e.in1k": ["image-classification"],
"timm/spnasnet_100.rmsp_in1k": ["image-classification"],
"timm/resnet18.fb_swsl_ig1b_ft_in1k": ["image-classification"],
"timm/wide_resnet101_2.tv_in1k": ["image-classification"],
"timm/tresnet_l.miil_in1k": ["image-classification"],
},
"deit": "hf-internal-testing/tiny-random-DeiTModel",
"donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
"detr": "hf-internal-testing/tiny-random-detr",
Expand Down

0 comments on commit 8ac9763

Please sign in to comment.