v0.0.1 (#1)
ivkalgin authored Feb 5, 2024
1 parent 33df177 commit 07eb36e
Showing 11 changed files with 793 additions and 1 deletion.
42 changes: 42 additions & 0 deletions .github/workflows/lint.yml
@@ -0,0 +1,42 @@
name: Lint

on:
  push:
    branches:
      - main
  pull_request:

jobs:
  lint-python:
    name: Pylint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: "3.9"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pylint
          python -m pip install -e .
      - name: Analysing the code with pylint
        run: |
          pylint --output-format=colorized $(git ls-files '*.py')
  lint-python-format:
    name: Python format
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: "3.9"
      - uses: psf/black@stable
        with:
          options: "--check --diff"
      - uses: isort/isort-action@master
        with:
          configuration:
            --check
            --diff
5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
__pycache__/
*.egg-info/
*.egg

.idea*
55 changes: 55 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,55 @@
repos:
  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
      - id: isort
        args:
          [
            "--force-single-line-imports",
            "--ensure-newline-before-comments",
            "--line-length=120",
          ]
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.8.0
    hooks:
      - id: pyupgrade
  - repo: https://github.com/PyCQA/docformatter
    rev: v1.7.3
    hooks:
      - id: docformatter
        additional_dependencies: [tomli]
        args:
          [
            "--in-place",
            "--config",
            "pyproject.toml",
          ]
  - repo: https://github.com/executablebooks/mdformat
    rev: 0.7.16
    hooks:
      - id: mdformat
        additional_dependencies:
          - mdformat-gfm
          - mdformat-black
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-yaml
      - id: check-toml
      - id: check-json
      - id: check-ast
      - id: fix-byte-order-marker
      - id: end-of-file-fixer
      - id: trailing-whitespace
      - id: check-added-large-files
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: detect-private-key
      - id: no-commit-to-branch
        args: ["-b=main"]
25 changes: 24 additions & 1 deletion README.md
@@ -1 +1,24 @@
# TI-ViT

The repository contains a script for exporting the PyTorch ViT model to ONNX format in a form that is compatible with
[edgeai-tidl-tools](https://github.com/TexasInstruments/edgeai-tidl-tools) (version 8.6.0.5).

## Installation

To install the export script, run the following command:
```commandline
pip3 install git+https://github.com/ENOT-AutoDL/ti-vit.git@main
```

## Examples

To export the model version with maximum performance, run the following command:
```commandline
export-ti-vit -o npu-max-perf.onnx -t npu-max-perf
```

To export the model version with minimal loss of accuracy, run the following command:
```commandline
export-ti-vit -o npu-max-acc.onnx -t npu-max-acc
```
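
After export, you can quickly confirm that the resulting ONNX file is well-formed and inspect its input/output signature. The sketch below is illustrative only: it assumes `onnx` and `onnxruntime` are installed and that the export produced `npu-max-perf.onnx` in the current directory; it does not check TIDL-specific compatibility.
```python
import onnx
import onnxruntime as ort

# Structural validation of the exported graph.
model = onnx.load("npu-max-perf.onnx")
onnx.checker.check_model(model)

# Inspect the input/output signature with the CPU execution provider.
session = ort.InferenceSession("npu-max-perf.onnx", providers=["CPUExecutionProvider"])
for tensor in session.get_inputs():
    print("input:", tensor.name, tensor.shape)
for tensor in session.get_outputs():
    print("output:", tensor.name, tensor.shape)
```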

62 changes: 62 additions & 0 deletions pyproject.toml
@@ -0,0 +1,62 @@
[project]
name = 'ti-vit'
version = '0.0.1'
dependencies = [
    'torch==1.13.1',
    'torchvision==0.14.1',
]

[project.scripts]
export-ti-vit = "ti_vit.export:export_ti_compatible_vit"

[tool.black]
line-length = 120
target-version = ["py38", "py39"]
include = '\.pyi?$'

[tool.isort]
profile = "black"
line_length = 120
ensure_newline_before_comments = true
force_single_line = true

[tool.nbqa.mutate]
pyupgrade = 1

[tool.nbqa.addopts]
pyupgrade = ["--py38-plus"]

[tool.docformatter]
recursive = true
wrap-summaries = 0
wrap-descriptions = 0
blank = true
black = true
pre-summary-newline = true

[tool.pylint.format]
max-line-length = 120

[tool.pylint.design]
max-args = 12
max-locals = 30
max-attributes = 20
min-public-methods = 0

[tool.pylint.typecheck]
generated-members = ["torch.*"]

[tool.pylint.messages_control]
disable = [
    "logging-fstring-interpolation",
    "missing-module-docstring",
    "unnecessary-pass",
]

[tool.pylint.BASIC]
good-names = ["B", "N", "C"]

[tool.pyright]
reportMissingImports = false
reportMissingTypeStubs = false
reportWildcardImportFromLibrary = false
2 changes: 2 additions & 0 deletions src/ti_vit/__init__.py
@@ -0,0 +1,2 @@
from ti_vit.model import TICompatibleVitOrtMaxAcc
from ti_vit.model import TICompatibleVitOrtMaxPerf
167 changes: 167 additions & 0 deletions src/ti_vit/attention.py
@@ -0,0 +1,167 @@
import typing
from enum import Enum
from typing import Tuple

import torch
from torch import nn

from ti_vit.common import copy_weights
from ti_vit.common import sync_device_and_mode


class AttentionType(Enum):
    """
    Type of attention block.

    - CONV_CONV - qkv projection and output projection are convolutions with a 1x1 kernel
    - CONV_LINEAR - qkv projection is a convolution with a 1x1 kernel, output projection is linear
    """

    CONV_CONV = "CONV_CONV"
    CONV_LINEAR = "CONV_LINEAR"


class TICompatibleAttention(nn.Module):
    """TI compatible attention block."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attention_type: AttentionType = AttentionType.CONV_LINEAR,
    ):
        """
        Parameters
        ----------
        dim : int
            Total dimension of the model.
        num_heads : int
            Number of parallel attention heads.
        qkv_bias : bool
            If True, adds a learnable bias to the qkv projection. Default value is False.
        attention_type : AttentionType
            Type of attention block (see ``AttentionType`` enum documentation).
        """
        super().__init__()

        if dim % num_heads != 0:
            raise ValueError(f'"dim"={dim} should be divisible by "num_heads"={num_heads}')

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        if attention_type == AttentionType.CONV_CONV:
            self.qkv_proj = nn.Conv2d(in_channels=dim, out_channels=dim * 3, kernel_size=(1, 1), bias=qkv_bias)
            self.out_proj = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(1, 1))
        elif attention_type == AttentionType.CONV_LINEAR:
            self.qkv_proj = nn.Conv2d(in_channels=dim, out_channels=dim * 3, kernel_size=(1, 1), bias=qkv_bias)
            self.out_proj = nn.Linear(in_features=dim, out_features=dim)
        else:
            raise ValueError(f'Got unknown attention_type "{attention_type}"')

        self._attention_type = attention_type

    def forward(  # pylint: disable=missing-function-docstring
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        need_weights: bool = True,
    ) -> Tuple[torch.Tensor, None]:
        del key, value

        assert not need_weights

        x = query
        B, N, C = x.shape

        # (B, N, C) -> (B, N, C, 1) -> (B, C, N, 1)
        x = x.unsqueeze(3).permute(0, 2, 1, 3)

        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(B, 3, C, N)
        q, k, v = qkv.split(1, dim=1)

        # (B, 1, C, N) -> (B, H, C//H, N) -> (B, H, N, C//H)
        q = q.reshape(B, self.num_heads, C // self.num_heads, N).permute(0, 1, 3, 2)
        # (B, 1, C, N) -> (B, H, C//H, N)
        k = k.reshape(B, self.num_heads, C // self.num_heads, N)
        # (B, 1, C, N) -> (B, H, C//H, N) -> (B, H, N, C//H)
        v = v.reshape(B, self.num_heads, C // self.num_heads, N).permute(0, 1, 3, 2)

        attn = (q @ k) * self.scale
        attn = attn.softmax(dim=-1)

        x = attn @ v

        if self._attention_type == AttentionType.CONV_CONV:
            # (B, H, N, C//H) -> (B, H, C//H, N) -> (B, C, N, 1)
            x = x.permute(0, 1, 3, 2).reshape(B, C, N, 1)
            x = self.out_proj(x)
            x = x.permute(0, 2, 1, 3)
            x = x.squeeze(3)
        else:
            # (B, H, N, C//H) -> (B, N, H, C//H) -> (B, N, C)
            x = x.permute(0, 2, 1, 3).reshape(B, N, C)
            x = self.out_proj(x)

        return x, None

    @classmethod
    def from_module(
        cls,
        vit_attn: nn.Module,
        attention_type: AttentionType = AttentionType.CONV_CONV,
    ) -> "TICompatibleAttention":
        """
        Create a TI-compatible attention block from a common ViT attention block.

        Parameters
        ----------
        vit_attn : nn.Module
            Source block.
        attention_type : AttentionType
            Attention type (see ``AttentionType`` enum documentation).

        Returns
        -------
        TICompatibleAttention
            Instance of ``TICompatibleAttention`` with appropriate weights, device and training mode.
        """
        if hasattr(vit_attn, "qkv"):
            qkv_proj = typing.cast(nn.Linear, vit_attn.qkv)
            out_proj = typing.cast(nn.Linear, vit_attn.proj)
        else:
            in_proj_weight = typing.cast(nn.Parameter, vit_attn.in_proj_weight)
            out_features, in_features = in_proj_weight.shape
            qkv_proj = nn.Linear(
                in_features=in_features,
                out_features=out_features,
                bias=hasattr(vit_attn, "in_proj_bias"),
                device=in_proj_weight.device,
                dtype=in_proj_weight.dtype,
            )
            qkv_proj.weight = in_proj_weight
            qkv_proj.bias = vit_attn.in_proj_bias  # pyright: ignore[reportAttributeAccessIssue]

            out_proj = typing.cast(nn.Linear, vit_attn.out_proj)

        ti_compatible_attn = cls(
            dim=qkv_proj.in_features,
            num_heads=typing.cast(int, vit_attn.num_heads),
            qkv_bias=qkv_proj.bias is not None,
            attention_type=attention_type,
        )
        sync_device_and_mode(src=vit_attn, dst=ti_compatible_attn)

        copy_weights(src=qkv_proj, dst=ti_compatible_attn.qkv_proj)
        copy_weights(src=out_proj, dst=ti_compatible_attn.out_proj)

        if hasattr(vit_attn, "scale"):
            ti_compatible_attn.scale = vit_attn.scale

        return ti_compatible_attn
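
A minimal sanity check one might run against this block is sketched below. It is an assumption-laden example, not part of the repository: it presumes that `copy_weights` and `sync_device_and_mode` in `ti_vit.common` (not shown in this diff) correctly transfer the linear projection weights into the 1x1 convolutions, so the tolerance used here is a guess rather than a guarantee.
```python
import torch
from torch import nn

from ti_vit.attention import AttentionType
from ti_vit.attention import TICompatibleAttention

# Reference block: a standard batch-first multi-head attention, as used by torchvision ViT.
dim, num_heads, tokens = 192, 3, 197
mha = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True).eval()

# Convert to the TI-compatible form (qkv projection as 1x1 conv, output projection as linear).
ti_attn = TICompatibleAttention.from_module(mha, attention_type=AttentionType.CONV_LINEAR).eval()

x = torch.randn(1, tokens, dim)
with torch.no_grad():
    reference, _ = mha(x, x, x, need_weights=False)
    converted, _ = ti_attn(x, x, x, need_weights=False)

print(torch.allclose(reference, converted, atol=1e-5))
```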