-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
v0.0.1
- Loading branch information
Showing
11 changed files
with
793 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
name: Lint | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
|
||
jobs: | ||
lint-python: | ||
name: Pylint | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: actions/setup-python@v3 | ||
with: | ||
python-version: "3.9" | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install pylint | ||
python -m pip install -e . | ||
- name: Analysing the code with pylint | ||
run: | | ||
pylint --output-format=colorized $(git ls-files '*.py') | ||
lint-python-format: | ||
name: Python format | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: actions/setup-python@v3 | ||
with: | ||
python-version: "3.9" | ||
- uses: psf/black@stable | ||
with: | ||
options: "--check --diff" | ||
- uses: isort/isort-action@master | ||
with: | ||
configuration: | ||
--check | ||
--diff |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
__pycache__/ | ||
*.egg-info/ | ||
*.egg | ||
|
||
.idea* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
repos: | ||
- repo: https://github.com/psf/black | ||
rev: 23.3.0 | ||
hooks: | ||
- id: black | ||
- repo: https://github.com/PyCQA/isort | ||
rev: 5.12.0 | ||
hooks: | ||
- id: isort | ||
args: | ||
[ | ||
"--force-single-line-imports", | ||
"--ensure-newline-before-comments", | ||
"--line-length=120", | ||
] | ||
- repo: https://github.com/asottile/pyupgrade | ||
rev: v3.8.0 | ||
hooks: | ||
- id: pyupgrade | ||
- repo: https://github.com/PyCQA/docformatter | ||
rev: v1.7.3 | ||
hooks: | ||
- id: docformatter | ||
additional_dependencies: [tomli] | ||
args: | ||
[ | ||
"--in-place", | ||
"--config", | ||
"pyproject.toml", | ||
] | ||
- repo: https://github.com/executablebooks/mdformat | ||
rev: 0.7.16 | ||
hooks: | ||
- id: mdformat | ||
additional_dependencies: | ||
- mdformat-gfm | ||
- mdformat-black | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v4.4.0 | ||
hooks: | ||
- id: check-yaml | ||
- id: check-toml | ||
- id: check-json | ||
- id: check-ast | ||
- id: fix-byte-order-marker | ||
- id: end-of-file-fixer | ||
- id: trailing-whitespace | ||
- id: check-added-large-files | ||
- id: check-case-conflict | ||
- id: check-merge-conflict | ||
- id: detect-private-key | ||
- id: end-of-file-fixer | ||
- id: detect-private-key | ||
- id: no-commit-to-branch | ||
args: ["-b=main"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,24 @@ | ||
# ti-vit | ||
# TI-ViT | ||
|
||
The repository contains script for exporting PyTorch VIT model to ONNX format in the form that compatible with | ||
[edgeai-tidl-tools](https://github.com/TexasInstruments/edgeai-tidl-tools) (version 8.6.0.5). | ||
|
||
## Installation | ||
|
||
To install export script run the following command: | ||
```commandline | ||
pip3 install git+https://github.com/ENOT-AutoDL/ti-vit.git@main | ||
``` | ||
|
||
## Examples | ||
|
||
To export the model version with maximum performance, run the following command: | ||
```commandline | ||
export-ti-vit -o npu-max-perf.onnx -t npu-max-perf | ||
``` | ||
|
||
To export the model version with minimal loss of accuracy, run the following command: | ||
```commandline | ||
export-ti-vit -o npu-max-acc.onnx -t npu-max-acc | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
[project] | ||
name = 'ti-vit' | ||
version = '0.0.1' | ||
dependencies = [ | ||
'torch==1.13.1', | ||
'torchvision==0.14.1', | ||
] | ||
|
||
[project.scripts] | ||
export-ti-vit = "ti_vit.export:export_ti_compatible_vit" | ||
|
||
[tool.black] | ||
line-length = 120 | ||
target-version = ["py38", "py39"] | ||
include = '\.pyi?$' | ||
|
||
[tool.isort] | ||
profile = "black" | ||
line_length = 120 | ||
ensure_newline_before_comments = true | ||
force_single_line = true | ||
|
||
[tool.nbqa.mutate] | ||
pyupgrade = 1 | ||
|
||
[tool.nbqa.addopts] | ||
pyupgrade = ["--py38-plus"] | ||
|
||
[tool.docformatter] | ||
recursive = true | ||
wrap-summaries = 0 | ||
wrap-descriptions = 0 | ||
blank = true | ||
black = true | ||
pre-summary-newline = true | ||
|
||
[tool.pylint.format] | ||
max-line-length = 120 | ||
|
||
[tool.pylint.design] | ||
max-args = 12 | ||
max-locals = 30 | ||
max-attributes = 20 | ||
min-public-methods = 0 | ||
|
||
[tool.pylint.typecheck] | ||
generated-members = ["torch.*"] | ||
|
||
[tool.pylint.messages_control] | ||
disable = [ | ||
"logging-fstring-interpolation", | ||
"missing-module-docstring", | ||
"unnecessary-pass", | ||
] | ||
|
||
[tool.pylint.BASIC] | ||
good-names = ["B", "N", "C"] | ||
|
||
[tool.pyright] | ||
reportMissingImports = false | ||
reportMissingTypeStubs = false | ||
reportWildcardImportFromLibrary = false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from ti_vit.model import TICompatibleVitOrtMaxAcc | ||
from ti_vit.model import TICompatibleVitOrtMaxPerf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
import typing | ||
from enum import Enum | ||
from typing import Tuple | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from ti_vit.common import copy_weights | ||
from ti_vit.common import sync_device_and_mode | ||
|
||
|
||
class AttentionType(Enum): | ||
""" | ||
Type of attention block. | ||
- CONV_CONV - qkv projection and output projection is a convolution with 1x1 kernel | ||
- CONV_LINEAR - qkv projection is a convolution with 1x1 kernel, output projection is linear | ||
""" | ||
|
||
CONV_CONV = "CONV_CONV" | ||
CONV_LINEAR = "CONV_LINEAR" | ||
|
||
|
||
class TICompatibleAttention(nn.Module): | ||
"""TI compatible attention block.""" | ||
|
||
def __init__( | ||
self, | ||
dim: int, | ||
num_heads: int = 8, | ||
qkv_bias: bool = False, | ||
attention_type: AttentionType = AttentionType.CONV_LINEAR, | ||
): | ||
""" | ||
Parameters | ||
---------- | ||
dim : int | ||
Total dimension of the model. | ||
num_heads : int | ||
Number of parallel attention heads. | ||
qkv_bias : bool | ||
If True, adds a learnable bias to the qkv projection. Default value is False. | ||
attention_type : AttentionType | ||
Type of attention block (see ``AttentionType`` enum documentation). | ||
""" | ||
super().__init__() | ||
|
||
if dim % num_heads != 0: | ||
raise ValueError(f'"dim"={dim} should be divisible by "num_heads"={num_heads}') | ||
|
||
self.num_heads = num_heads | ||
head_dim = dim // num_heads | ||
self.scale = head_dim**-0.5 | ||
|
||
if attention_type == AttentionType.CONV_CONV: | ||
self.qkv_proj = nn.Conv2d(in_channels=dim, out_channels=dim * 3, kernel_size=(1, 1), bias=qkv_bias) | ||
self.out_proj = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(1, 1)) | ||
elif attention_type == AttentionType.CONV_LINEAR: | ||
self.qkv_proj = nn.Conv2d(in_channels=dim, out_channels=dim * 3, kernel_size=(1, 1), bias=qkv_bias) | ||
self.out_proj = nn.Linear(in_features=dim, out_features=dim) | ||
else: | ||
raise ValueError(f'Got unknown attention_type "{attention_type}"') | ||
|
||
self._attention_type = attention_type | ||
|
||
def forward( # pylint: disable=missing-function-docstring | ||
self, | ||
query: torch.Tensor, | ||
key: torch.Tensor, | ||
value: torch.Tensor, | ||
need_weights: bool = True, | ||
) -> Tuple[torch.Tensor, None]: | ||
del key, value | ||
|
||
assert not need_weights | ||
|
||
x = query | ||
B, N, C = x.shape | ||
|
||
# (B, N, C) -> (B, N, C, 1) -> (B, C, N, 1) | ||
x = x.unsqueeze(3).permute(0, 2, 1, 3) | ||
|
||
qkv = self.qkv_proj(x) | ||
qkv = qkv.reshape(B, 3, C, N) | ||
q, k, v = qkv.split(1, dim=1) | ||
|
||
# (B, 1, C, N) -> (B, H, C//H, N) -> (B, H, N, C//H) | ||
q = q.reshape(B, self.num_heads, C // self.num_heads, N).permute(0, 1, 3, 2) | ||
# (B, 1, C, N) -> (B, H, C//H, N) | ||
k = k.reshape(B, self.num_heads, C // self.num_heads, N) | ||
# (B, 1, C, N) -> (B, H, C//H, N) -> (B, H, N, C//H) | ||
v = v.reshape(B, self.num_heads, C // self.num_heads, N).permute(0, 1, 3, 2) | ||
|
||
attn = (q @ k) * self.scale | ||
attn = attn.softmax(dim=-1) | ||
|
||
x = attn @ v | ||
|
||
if self._attention_type == AttentionType.CONV_CONV: | ||
# (B, H, N, C//H) -> (B, H, C//H, N) -> (B, C, N, 1) | ||
x = x.permute(0, 1, 3, 2).reshape(B, C, N, 1) | ||
x = self.out_proj(x) | ||
x = x.permute(0, 2, 1, 3) | ||
x = x.squeeze(3) | ||
else: | ||
# (B, H, N, C//H) -> (B, N, H, C//H) -> (B, N, C) | ||
x = x.permute(0, 2, 1, 3).reshape(B, N, C) | ||
x = self.out_proj(x) | ||
|
||
return x, None | ||
|
||
@classmethod | ||
def from_module( | ||
cls, | ||
vit_attn: nn.Module, | ||
attention_type: AttentionType = AttentionType.CONV_CONV, | ||
) -> "TICompatibleAttention": | ||
""" | ||
Create TI compatible attention block from common ViT attention block. | ||
Parameters | ||
---------- | ||
vit_attn : nn.Module | ||
Source block. | ||
attention_type : AttentionType | ||
Attention type (see ``AttentionType`` enum documentation). | ||
Returns | ||
------- | ||
TICompatibleAttention | ||
Instance of ``TICompatibleAttention`` with appropriate weights, device and training mode. | ||
""" | ||
if hasattr(vit_attn, "qkv"): | ||
qkv_proj = typing.cast(nn.Linear, vit_attn.qkv) | ||
out_proj = typing.cast(nn.Linear, vit_attn.proj) | ||
else: | ||
in_proj_weight = typing.cast(nn.Parameter, vit_attn.in_proj_weight) | ||
out_features, in_features = in_proj_weight.shape | ||
qkv_proj = nn.Linear( | ||
in_features=in_features, | ||
out_features=out_features, | ||
bias=hasattr(vit_attn, "in_proj_bias"), | ||
device=in_proj_weight.device, | ||
dtype=in_proj_weight.dtype, | ||
) | ||
qkv_proj.weight = in_proj_weight | ||
qkv_proj.bias = vit_attn.in_proj_bias # pyright: ignore[reportAttributeAccessIssue] | ||
|
||
out_proj = typing.cast(nn.Linear, vit_attn.out_proj) | ||
|
||
ti_compatible_attn = cls( | ||
dim=qkv_proj.in_features, | ||
num_heads=typing.cast(int, vit_attn.num_heads), | ||
qkv_bias=qkv_proj.bias is not None, | ||
attention_type=attention_type, | ||
) | ||
sync_device_and_mode(src=vit_attn, dst=ti_compatible_attn) | ||
|
||
copy_weights(src=qkv_proj, dst=ti_compatible_attn.qkv_proj) | ||
copy_weights(src=out_proj, dst=ti_compatible_attn.out_proj) | ||
|
||
if hasattr(vit_attn, "scale"): | ||
ti_compatible_attn.scale = vit_attn.scale | ||
|
||
return ti_compatible_attn |
Oops, something went wrong.