From 07eb36e2cfcbedf2f4731dc8640f8b00b78d4dd6 Mon Sep 17 00:00:00 2001 From: ivkalgin Date: Mon, 5 Feb 2024 23:32:08 +0100 Subject: [PATCH] v0.0.1 (#1) v0.0.1 --- .github/workflows/lint.yml | 42 ++++++++++ .gitignore | 5 ++ .pre-commit-config.yaml | 55 ++++++++++++ README.md | 25 +++++- pyproject.toml | 62 ++++++++++++++ src/ti_vit/__init__.py | 2 + src/ti_vit/attention.py | 167 +++++++++++++++++++++++++++++++++++++ src/ti_vit/common.py | 45 ++++++++++ src/ti_vit/export.py | 111 ++++++++++++++++++++++++ src/ti_vit/mlp.py | 156 ++++++++++++++++++++++++++++++++++ src/ti_vit/model.py | 124 +++++++++++++++++++++++++++ 11 files changed, 793 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/lint.yml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml create mode 100644 src/ti_vit/__init__.py create mode 100644 src/ti_vit/attention.py create mode 100644 src/ti_vit/common.py create mode 100644 src/ti_vit/export.py create mode 100644 src/ti_vit/mlp.py create mode 100644 src/ti_vit/model.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..9bb1ca9 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,42 @@ +name: Lint + +on: + push: + branches: + - main + pull_request: + +jobs: + lint-python: + name: Pylint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + with: + python-version: "3.9" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pylint + python -m pip install -e . + - name: Analysing the code with pylint + run: | + pylint --output-format=colorized $(git ls-files '*.py') + + lint-python-format: + name: Python format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + with: + python-version: "3.9" + - uses: psf/black@stable + with: + options: "--check --diff" + - uses: isort/isort-action@master + with: + configuration: + --check + --diff diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2b6e1f8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +__pycache__/ +*.egg-info/ +*.egg + +.idea* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..30a0ae1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,55 @@ +repos: +- repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black +- repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: + [ + "--force-single-line-imports", + "--ensure-newline-before-comments", + "--line-length=120", + ] +- repo: https://github.com/asottile/pyupgrade + rev: v3.8.0 + hooks: + - id: pyupgrade +- repo: https://github.com/PyCQA/docformatter + rev: v1.7.3 + hooks: + - id: docformatter + additional_dependencies: [tomli] + args: + [ + "--in-place", + "--config", + "pyproject.toml", + ] +- repo: https://github.com/executablebooks/mdformat + rev: 0.7.16 + hooks: + - id: mdformat + additional_dependencies: + - mdformat-gfm + - mdformat-black +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-yaml + - id: check-toml + - id: check-json + - id: check-ast + - id: fix-byte-order-marker + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + - id: detect-private-key + - id: end-of-file-fixer + - id: detect-private-key + - id: no-commit-to-branch + args: ["-b=main"] diff --git a/README.md b/README.md index 7cef31e..dd6a7aa 100644 --- a/README.md +++ b/README.md @@ -1 +1,24 @@ -# ti-vit \ No newline at end of file +# TI-ViT + +The repository contains script for exporting PyTorch VIT model to ONNX format in the form that compatible with +[edgeai-tidl-tools](https://github.com/TexasInstruments/edgeai-tidl-tools) (version 8.6.0.5). + +## Installation + +To install export script run the following command: +```commandline +pip3 install git+https://github.com/ENOT-AutoDL/ti-vit.git@main +``` + +## Examples + +To export the model version with maximum performance, run the following command: +```commandline +export-ti-vit -o npu-max-perf.onnx -t npu-max-perf +``` + +To export the model version with minimal loss of accuracy, run the following command: +```commandline +export-ti-vit -o npu-max-acc.onnx -t npu-max-acc +``` + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..854aa48 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,62 @@ +[project] +name = 'ti-vit' +version = '0.0.1' +dependencies = [ + 'torch==1.13.1', + 'torchvision==0.14.1', +] + +[project.scripts] +export-ti-vit = "ti_vit.export:export_ti_compatible_vit" + +[tool.black] +line-length = 120 +target-version = ["py38", "py39"] +include = '\.pyi?$' + +[tool.isort] +profile = "black" +line_length = 120 +ensure_newline_before_comments = true +force_single_line = true + +[tool.nbqa.mutate] +pyupgrade = 1 + +[tool.nbqa.addopts] +pyupgrade = ["--py38-plus"] + +[tool.docformatter] +recursive = true +wrap-summaries = 0 +wrap-descriptions = 0 +blank = true +black = true +pre-summary-newline = true + +[tool.pylint.format] +max-line-length = 120 + +[tool.pylint.design] +max-args = 12 +max-locals = 30 +max-attributes = 20 +min-public-methods = 0 + +[tool.pylint.typecheck] +generated-members = ["torch.*"] + +[tool.pylint.messages_control] +disable = [ + "logging-fstring-interpolation", + "missing-module-docstring", + "unnecessary-pass", +] + +[tool.pylint.BASIC] +good-names = ["B", "N", "C"] + +[tool.pyright] +reportMissingImports = false +reportMissingTypeStubs = false +reportWildcardImportFromLibrary = false diff --git a/src/ti_vit/__init__.py b/src/ti_vit/__init__.py new file mode 100644 index 0000000..b32ecd8 --- /dev/null +++ b/src/ti_vit/__init__.py @@ -0,0 +1,2 @@ +from ti_vit.model import TICompatibleVitOrtMaxAcc +from ti_vit.model import TICompatibleVitOrtMaxPerf diff --git a/src/ti_vit/attention.py b/src/ti_vit/attention.py new file mode 100644 index 0000000..d41b71b --- /dev/null +++ b/src/ti_vit/attention.py @@ -0,0 +1,167 @@ +import typing +from enum import Enum +from typing import Tuple + +import torch +from torch import nn + +from ti_vit.common import copy_weights +from ti_vit.common import sync_device_and_mode + + +class AttentionType(Enum): + """ + Type of attention block. + + - CONV_CONV - qkv projection and output projection is a convolution with 1x1 kernel + - CONV_LINEAR - qkv projection is a convolution with 1x1 kernel, output projection is linear + + """ + + CONV_CONV = "CONV_CONV" + CONV_LINEAR = "CONV_LINEAR" + + +class TICompatibleAttention(nn.Module): + """TI compatible attention block.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attention_type: AttentionType = AttentionType.CONV_LINEAR, + ): + """ + Parameters + ---------- + dim : int + Total dimension of the model. + num_heads : int + Number of parallel attention heads. + qkv_bias : bool + If True, adds a learnable bias to the qkv projection. Default value is False. + attention_type : AttentionType + Type of attention block (see ``AttentionType`` enum documentation). + """ + super().__init__() + + if dim % num_heads != 0: + raise ValueError(f'"dim"={dim} should be divisible by "num_heads"={num_heads}') + + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + if attention_type == AttentionType.CONV_CONV: + self.qkv_proj = nn.Conv2d(in_channels=dim, out_channels=dim * 3, kernel_size=(1, 1), bias=qkv_bias) + self.out_proj = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(1, 1)) + elif attention_type == AttentionType.CONV_LINEAR: + self.qkv_proj = nn.Conv2d(in_channels=dim, out_channels=dim * 3, kernel_size=(1, 1), bias=qkv_bias) + self.out_proj = nn.Linear(in_features=dim, out_features=dim) + else: + raise ValueError(f'Got unknown attention_type "{attention_type}"') + + self._attention_type = attention_type + + def forward( # pylint: disable=missing-function-docstring + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + need_weights: bool = True, + ) -> Tuple[torch.Tensor, None]: + del key, value + + assert not need_weights + + x = query + B, N, C = x.shape + + # (B, N, C) -> (B, N, C, 1) -> (B, C, N, 1) + x = x.unsqueeze(3).permute(0, 2, 1, 3) + + qkv = self.qkv_proj(x) + qkv = qkv.reshape(B, 3, C, N) + q, k, v = qkv.split(1, dim=1) + + # (B, 1, C, N) -> (B, H, C//H, N) -> (B, H, N, C//H) + q = q.reshape(B, self.num_heads, C // self.num_heads, N).permute(0, 1, 3, 2) + # (B, 1, C, N) -> (B, H, C//H, N) + k = k.reshape(B, self.num_heads, C // self.num_heads, N) + # (B, 1, C, N) -> (B, H, C//H, N) -> (B, H, N, C//H) + v = v.reshape(B, self.num_heads, C // self.num_heads, N).permute(0, 1, 3, 2) + + attn = (q @ k) * self.scale + attn = attn.softmax(dim=-1) + + x = attn @ v + + if self._attention_type == AttentionType.CONV_CONV: + # (B, H, N, C//H) -> (B, H, C//H, N) -> (B, C, N, 1) + x = x.permute(0, 1, 3, 2).reshape(B, C, N, 1) + x = self.out_proj(x) + x = x.permute(0, 2, 1, 3) + x = x.squeeze(3) + else: + # (B, H, N, C//H) -> (B, N, H, C//H) -> (B, N, C) + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + x = self.out_proj(x) + + return x, None + + @classmethod + def from_module( + cls, + vit_attn: nn.Module, + attention_type: AttentionType = AttentionType.CONV_CONV, + ) -> "TICompatibleAttention": + """ + Create TI compatible attention block from common ViT attention block. + + Parameters + ---------- + vit_attn : nn.Module + Source block. + attention_type : AttentionType + Attention type (see ``AttentionType`` enum documentation). + + Returns + ------- + TICompatibleAttention + Instance of ``TICompatibleAttention`` with appropriate weights, device and training mode. + + """ + if hasattr(vit_attn, "qkv"): + qkv_proj = typing.cast(nn.Linear, vit_attn.qkv) + out_proj = typing.cast(nn.Linear, vit_attn.proj) + else: + in_proj_weight = typing.cast(nn.Parameter, vit_attn.in_proj_weight) + out_features, in_features = in_proj_weight.shape + qkv_proj = nn.Linear( + in_features=in_features, + out_features=out_features, + bias=hasattr(vit_attn, "in_proj_bias"), + device=in_proj_weight.device, + dtype=in_proj_weight.dtype, + ) + qkv_proj.weight = in_proj_weight + qkv_proj.bias = vit_attn.in_proj_bias # pyright: ignore[reportAttributeAccessIssue] + + out_proj = typing.cast(nn.Linear, vit_attn.out_proj) + + ti_compatible_attn = cls( + dim=qkv_proj.in_features, + num_heads=typing.cast(int, vit_attn.num_heads), + qkv_bias=qkv_proj.bias is not None, + attention_type=attention_type, + ) + sync_device_and_mode(src=vit_attn, dst=ti_compatible_attn) + + copy_weights(src=qkv_proj, dst=ti_compatible_attn.qkv_proj) + copy_weights(src=out_proj, dst=ti_compatible_attn.out_proj) + + if hasattr(vit_attn, "scale"): + ti_compatible_attn.scale = vit_attn.scale + + return ti_compatible_attn diff --git a/src/ti_vit/common.py b/src/ti_vit/common.py new file mode 100644 index 0000000..0579e2b --- /dev/null +++ b/src/ti_vit/common.py @@ -0,0 +1,45 @@ +from typing import Union + +import torch +from torch import nn + + +def copy_weights(src: nn.Linear, dst: Union[nn.Linear, nn.Conv2d]) -> None: + """ + Update weights and bias parameters of the destination module with values from the source module. + + Parameters + ---------- + src : nn.Linear + The source module. + dst : Union[nn.Linear, nn.Conv2d] + The destination module. + + """ + with torch.no_grad(): + if isinstance(dst, nn.Linear): + dst.weight.copy_(src.weight) + elif isinstance(dst, nn.Conv2d): + dst.weight.copy_(src.weight.unsqueeze(-1).unsqueeze(-1)) + else: + raise TypeError(f"dst must be nn.Linear or nn.Conv2d (type(dst)={type(dst)})") + + if src.bias is not None: + dst.bias.copy_(src.bias) # pyright: ignore[reportOptionalMemberAccess] + + +def sync_device_and_mode(src: nn.Module, dst: nn.Module) -> None: + """ + Update device and training mode parameters of the destination module with values from the source module. + + Parameters + ---------- + src : nn.Module + The source module. + dst : nn.Module + The destination module. + + """ + device = next(src.parameters()).device + dst.to(device=device) + dst.train(mode=src.training) diff --git a/src/ti_vit/export.py b/src/ti_vit/export.py new file mode 100644 index 0000000..04b79e9 --- /dev/null +++ b/src/ti_vit/export.py @@ -0,0 +1,111 @@ +import argparse +import logging +import sys +import warnings +from pathlib import Path +from typing import Optional +from typing import Union + +import torch +from torchvision.models import ViT_B_16_Weights +from torchvision.models import vit_b_16 + +from ti_vit.model import TICompatibleVitOrtMaxAcc +from ti_vit.model import TICompatibleVitOrtMaxPerf + + +def export( + output_onnx_path: Union[str, Path], + model_type: str, + checkpoint_path: Optional[Union[str, Path]] = None, + resolution: int = 224, +) -> None: + """ + Parameters + ---------- + output_onnx_path : Union[str, Path] + Path to the output ONNX file. + model_type : str + Type of the final model. Possible values are "npu-max-acc", "npu-max-perf" or "cpu". + checkpoint_path : Optional[Union[str, Path]] = None + Path to the PyTorch model checkpoint. If value is None, then ViT_B_16 pretrained torchvision model is used. + Default value is None. + resolution : int + Resolution of input image. Default value is 224. + """ + if checkpoint_path is None: + model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT, progress=True) + else: + checkpoint = torch.load(str(checkpoint_path)) + model = checkpoint["model_ckpt"] + + model.cpu().eval() + + try: + transform_model_func = { + "cpu": lambda model: model, + "npu-max-acc": TICompatibleVitOrtMaxAcc, + "npu-max-perf": lambda model: TICompatibleVitOrtMaxPerf(model=model, ignore_tidl_errors=False), + "npu-max-perf-experimental": lambda model: TICompatibleVitOrtMaxPerf(model=model, ignore_tidl_errors=True), + }[model_type] + except KeyError as exc: + raise ValueError(f"Got unknown transformation type ('{model_type}')") from exc + + model = transform_model_func(model) + + device = next(model.parameters()).device + dummy_data = torch.ones([1, 3, resolution, resolution], dtype=torch.float32, device=device) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # disable export warnings + torch.onnx.export( + model=model, + f=str(output_onnx_path), + args=dummy_data, + input_names=["input"], + output_names=["output"], + opset_version=9, + ) + + +def export_ti_compatible_vit() -> None: # pylint: disable=missing-function-docstring + logger = logging.getLogger("ti_vit") + logger.addHandler(logging.StreamHandler(sys.stdout)) + logger.setLevel(logging.INFO) + + parser = argparse.ArgumentParser() + parser.add_argument("-o", "--output-onnx", type=str, required=True, help="Path to the output onnx.") + parser.add_argument( + "-t", + "--model-type", + type=str, + required=False, + default="npu-max-perf", + help='Type of the final model (optional argument). Possible values are "npu-max-acc", "npu-max-perf", or "cpu".' + ' Default value is "npu-max-perf".', + ) + parser.add_argument( + "-c", + "--checkpoint", + type=str, + required=False, + help="Path to the ViT checkpoint (optional argument). By default torchvision checkpoint is downloaded." + "(VIT_B_16).", + default=None, + ) + parser.add_argument( + "-r", + "--resolution", + type=int, + required=False, + default=224, + help="Resolution of input images (optional argument). Default value is 224.", + ) + args = parser.parse_args() + + export( + checkpoint_path=args.checkpoint, + output_onnx_path=args.output_onnx, + model_type=args.model_type, + resolution=args.resolution, + ) diff --git a/src/ti_vit/mlp.py b/src/ti_vit/mlp.py new file mode 100644 index 0000000..37f8184 --- /dev/null +++ b/src/ti_vit/mlp.py @@ -0,0 +1,156 @@ +from enum import Enum + +import torch +from torch import nn +from torchvision.models.vision_transformer import MLPBlock + +from ti_vit.common import copy_weights +from ti_vit.common import sync_device_and_mode + + +class MLPType(Enum): + """ + Type of MLP block. + + - CONV_CONV - MLP block assembled as ``convolution_with_kernel_1x1 + activation + convolution_with_kernel_1x1`` + - LINEAR_CONV - MLP block assembled as ``linear + activation + convolution_with_kernel_1x1`` + + """ + + CONV_CONV = "CONV_CONV" + LINEAR_CONV = "LINEAR_CONV" + + +class GeluApproximationType(Enum): + """ + GELU approximation type. + + - NONE - disable approximation + - SIGMOID - approximate as ``x * sigmoid(1.702 * x)`` + - TANH - approximate as ``0.5 * x * (tanh(0.7978845834732056 * (x + 0.044715 * x * x * x)) + 1.0)`` + + """ + + NONE = "NONE" + SIGMOID = "SIGMOID" + TANH = "TANH" + + +class TICompatibleMLP(nn.Module): + """TI compatible MLP block.""" + + def __init__( + self, + dims: int, + hidden_dims: int, + mlp_type: MLPType = MLPType.CONV_CONV, + gelu_approx_type: GeluApproximationType = GeluApproximationType.NONE, + ): + """ + dims : int + Number of channels of the input. + hidden_dims : int + Number of channels of the expanded tensor. + mlp_type : MLPType + MLP type (see ``MLPType`` enum documentation). + gelu_approx_type : GeluApproximationType + GELU approximation type (see ``GeluApproximationType`` enum documentation). + """ + super().__init__() + + try: + self.gelu = { + GeluApproximationType.NONE: nn.GELU(), + GeluApproximationType.SIGMOID: self._gelu_approx_sigmoid, + GeluApproximationType.TANH: self._gelu_approx_tanh, + }[gelu_approx_type] + self._gelu_approx_type = gelu_approx_type + except KeyError as exc: + raise ValueError(f'Got unknown type of gelu approximation "{gelu_approx_type}"') from exc + + if mlp_type == MLPType.CONV_CONV: + self.expand = nn.Conv2d(in_channels=dims, out_channels=hidden_dims, kernel_size=(1, 1)) + self.shrink = nn.Conv2d(in_channels=hidden_dims, out_channels=dims, kernel_size=(1, 1)) + elif mlp_type == MLPType.LINEAR_CONV: + self.expand = nn.Linear(in_features=dims, out_features=hidden_dims) + self.shrink = nn.Conv2d(in_channels=hidden_dims, out_channels=dims, kernel_size=(1, 1)) + else: + raise ValueError(f'Got unknown mlp_type "{mlp_type}"') + + self._mlp_type = mlp_type + + @staticmethod + def _gelu_approx_tanh(x: torch.Tensor) -> torch.Tensor: + # This is default torch approximation (0.5 * x * (tanh(0.7978845834732056 * (x + 0.044715 * x * x * x)) + 1.0)), + # where tanh replaced by (2.0 * nn.functional.sigmoid(2.0 * x) - 1.0) + return x * torch.sigmoid(1.5957691669464111 * (x + 0.044715 * x * x * x)) + + @staticmethod + def _gelu_approx_sigmoid(x: torch.Tensor) -> torch.Tensor: + # simplified torch approximation + return x * torch.sigmoid(1.702 * x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring + if self._mlp_type == MLPType.CONV_CONV: + x = x.unsqueeze(3).permute(0, 2, 1, 3) + x = self.expand(x) + x = self.gelu(x) + else: + x = self.expand(x) + if self._gelu_approx_type == GeluApproximationType.NONE: + x = self.gelu(x) + x = x.unsqueeze(3).permute(0, 2, 1, 3) + else: + x = x.unsqueeze(3).permute(0, 2, 1, 3) + x = self.gelu(x) + + x = self.shrink(x) + x = x.permute(0, 2, 1, 3).squeeze(3) + + return x + + @classmethod + def from_module( + cls, + vit_mlp: MLPBlock, + mlp_type: MLPType = MLPType.CONV_CONV, + gelu_approx_type: GeluApproximationType = GeluApproximationType.NONE, + ) -> "TICompatibleMLP": + """ + Create TI compatible MLP block from common ViT MLP block. + + Parameters + ---------- + vit_mlp : MLPBlock + Source block. + mlp_type : MLPType + MLP type (see ``MLPType`` enum documentation). + gelu_approx_type : GeluApproximationType + GELU approximation type (see ``GeluApproximationType`` enum documentation). + + Returns + ------- + TICompatibleMLP + Instance of ``TICompatibleMLP`` with appropriate weights, device and training mode. + + """ + expand, shrink = vit_mlp[0], vit_mlp[3] + if not isinstance(expand, nn.Linear) or not isinstance(shrink, nn.Linear): + raise ValueError('Got unknown type of vit_mlp. Cannot find "Linear" layers.') + if not isinstance(vit_mlp[1], nn.GELU): + raise ValueError('Got unknown type of vit_mlp. Cannot find "GELU" layer.') + if not isinstance(vit_mlp[2], nn.Dropout) or not isinstance(vit_mlp[4], nn.Dropout): + raise ValueError('Got unknown type of vit_mlp. Cannot find "dropout" layers.') + + ti_compatible_mlp = cls( + dims=expand.in_features, + hidden_dims=expand.out_features, + mlp_type=mlp_type, + gelu_approx_type=gelu_approx_type, + ) + sync_device_and_mode(src=vit_mlp, dst=ti_compatible_mlp) + + copy_weights(src=expand, dst=ti_compatible_mlp.expand) + copy_weights(src=shrink, dst=ti_compatible_mlp.shrink) + + return ti_compatible_mlp diff --git a/src/ti_vit/model.py b/src/ti_vit/model.py new file mode 100644 index 0000000..01399e9 --- /dev/null +++ b/src/ti_vit/model.py @@ -0,0 +1,124 @@ +import logging +import typing +from typing import Any +from typing import Dict +from typing import NamedTuple +from typing import Optional + +import torch +from torch import nn +from torchvision.models.vision_transformer import EncoderBlock +from torchvision.models.vision_transformer import VisionTransformer + +from ti_vit.attention import AttentionType +from ti_vit.attention import TICompatibleAttention +from ti_vit.mlp import GeluApproximationType +from ti_vit.mlp import MLPType +from ti_vit.mlp import TICompatibleMLP + +_LOGGER = logging.getLogger(__name__) + + +class _BlockCfg(NamedTuple): + attention_cfg: Optional[Dict[str, Any]] + mlp_cfg: Optional[Dict[str, Any]] + + +class _TICompatibleVit(nn.Module): + def __init__(self, model: VisionTransformer, cfg: Dict[int, _BlockCfg]): + super().__init__() + + self._model = model + + attn_counter, mlp_counter = 0, 0 + for block_index, block_cfg in cfg.items(): + block: EncoderBlock = typing.cast(EncoderBlock, model.encoder.layers[block_index]) + + if block_cfg.attention_cfg is not None: + self_attention = TICompatibleAttention.from_module(block.self_attention, **block_cfg.attention_cfg) + setattr(block, "self_attention", self_attention) + _LOGGER.debug( + f"REPLACE {type(block.self_attention)} => {type(self_attention)} " + f"(BLOCK={block_index}, CFG={block_cfg.attention_cfg})" + ) + attn_counter += 1 + + if block_cfg.mlp_cfg is not None: + mlp = TICompatibleMLP.from_module(block.mlp, **block_cfg.mlp_cfg) + setattr(block, "mlp", mlp) + _LOGGER.debug( + f"REPLACE {type(block.mlp)} => {type(mlp)} " f"(BLOCK={block_index}, CFG={block_cfg.mlp_cfg})" + ) + mlp_counter += 1 + + _LOGGER.info(f"{attn_counter} attentions replaced") + _LOGGER.info(f"{mlp_counter} MLPs replaced") + + def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring + return self._model(x) + + +class TICompatibleVitOrtMaxPerf(_TICompatibleVit): + """TI compatible ViT model with maximum performance.""" + + def __init__(self, model: VisionTransformer, ignore_tidl_errors: bool = False): + """ + Parameters + ---------- + model : VisionTransformer + Source ViT model. + ignore_tidl_errors : bool + Experimental option. + """ + if ignore_tidl_errors: + cfg = {i: self._mlp_perf_block_cfg() if i < 8 else self._attn_mlp_perf_block_cfg() for i in range(12)} + else: + cfg = {i: self._mlp_perf_block_cfg() for i in range(12)} + + super().__init__(model=model, cfg=cfg) + + @staticmethod + def _attn_mlp_perf_block_cfg() -> _BlockCfg: + return _BlockCfg( + attention_cfg={ + "attention_type": AttentionType.CONV_LINEAR, + }, + mlp_cfg={"mlp_type": MLPType.CONV_CONV, "gelu_approx_type": GeluApproximationType.TANH}, + ) + + @staticmethod + def _mlp_perf_block_cfg() -> _BlockCfg: + return _BlockCfg( + attention_cfg=None, + mlp_cfg={"mlp_type": MLPType.CONV_CONV, "gelu_approx_type": GeluApproximationType.TANH}, + ) + + +class TICompatibleVitOrtMaxAcc(_TICompatibleVit): + """TI compatible ViT model with minimal accuracy drop.""" + + def __init__(self, model: VisionTransformer): + """ + Parameters + ---------- + model : VisionTransformer + Source ViT model. + """ + super().__init__( + model=model, + cfg={i: self._mlp_lc_block_cfg() if i < 8 else self._mlp_cc_block_cfg() for i in range(12)}, + ) + + @staticmethod + def _mlp_lc_block_cfg() -> _BlockCfg: + return _BlockCfg( + attention_cfg=None, + mlp_cfg={"mlp_type": MLPType.LINEAR_CONV, "gelu_approx_type": GeluApproximationType.NONE}, + ) + + @staticmethod + def _mlp_cc_block_cfg() -> _BlockCfg: + return _BlockCfg( + attention_cfg=None, + mlp_cfg={"mlp_type": MLPType.CONV_CONV, "gelu_approx_type": GeluApproximationType.NONE}, + )