Skip to content

Commit

Permalink
Feat: deny list and simplifier (#2)
Browse files Browse the repository at this point in the history
* added simplifier step
* added deny-list export step
* updated README
  • Loading branch information
ivkalgin authored Feb 6, 2024
1 parent 07eb36e commit 6d17ee1
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TI-ViT

The repository contains script for exporting PyTorch VIT model to ONNX format in the form that compatible with
The repository contains script for exporting PyTorch VIT model to ONNX format in the form that is compatible with
[edgeai-tidl-tools](https://github.com/TexasInstruments/edgeai-tidl-tools) (version 8.6.0.5).

## Installation
Expand All @@ -12,13 +12,24 @@ pip3 install git+https://github.com/ENOT-AutoDL/ti-vit.git@main

## Examples

### MLP blocks on TI DSP (maximum performance variant)

To export the model version with maximum performance, run the following command:
```commandline
export-ti-vit -o npu-max-perf.onnx -t npu-max-perf
```
This variant of model contains MLP blocks that can be run on TI DSP. GELU operation is approximated.

### MLP blocks partially on TI DSP (minimal loss of accuracy)

To export the model version with minimal loss of accuracy, run the following command:
```commandline
export-ti-vit -o npu-max-acc.onnx -t npu-max-acc
```
This variant of model contains MLP blocks that partially can be run on TI DSP. GELU operation is not approximated.

## Compilation of the exported model

It is important to disable compilation of all nodes except nodes from MLP blocks ("Squeeze" node from MLP must be
disabled too). The list of operations for ["deny_list:layer_name"](https://github.com/TexasInstruments/edgeai-tidl-tools/blob/08_06_00_05/examples/osrt_python/README.md#options-to-enable-control-on-layer-level-delegation-to-ti-dsparm)
compiler option can be found in the file "output-onnx-dir/output-onnx-name.deny_list", that is generated with onnx file.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
name = 'ti-vit'
version = '0.0.1'
dependencies = [
'onnx',
'onnx-simplifier',
'torch==1.13.1',
'torchvision==0.14.1',
]
Expand Down
29 changes: 27 additions & 2 deletions src/ti_vit/export.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import argparse
import json
import logging
import sys
import warnings
from pathlib import Path
from typing import Optional
from typing import Union

import onnx
import torch
from onnxsim import simplify
from torchvision.models import ViT_B_16_Weights
from torchvision.models import vit_b_16

from ti_vit.model import TICompatibleVitOrtMaxAcc
from ti_vit.model import TICompatibleVitOrtMaxPerf

_LOGGER = logging.getLogger(__name__)


def export(
output_onnx_path: Union[str, Path],
Expand Down Expand Up @@ -56,21 +61,41 @@ def export(
device = next(model.parameters()).device
dummy_data = torch.ones([1, 3, resolution, resolution], dtype=torch.float32, device=device)

output_onnx_path = Path(output_onnx_path)
with warnings.catch_warnings():
warnings.simplefilter("ignore") # disable export warnings
torch.onnx.export(
model=model,
f=str(output_onnx_path),
f=str(output_onnx_path.resolve()),
args=dummy_data,
input_names=["input"],
output_names=["output"],
opset_version=9,
)
_LOGGER.info(f'model exported to onnx (path = "{output_onnx_path}")')

onnx_model = onnx.load(output_onnx_path)
onnx_model, ok = simplify(onnx_model)
if not ok:
_LOGGER.error("onnx-simplifier step is failed")
else:
onnx.save_model(onnx_model, f=output_onnx_path)
_LOGGER.info("onnx simplified")

if model_type != "cpu":
deny_list = [node.name for node in onnx_model.graph.node if "mlp" not in node.name or node.op_type == "Squeeze"]
deny_list_path = output_onnx_path.with_suffix(".deny_list")
with deny_list_path.open("wt") as deny_list_file: # pylint: disable=unspecified-encoding
json.dump(deny_list, fp=deny_list_file, indent=4)
_LOGGER.info(f'deny list created (path = "{output_onnx_path}")')


def export_ti_compatible_vit() -> None: # pylint: disable=missing-function-docstring
logger = logging.getLogger("ti_vit")
logger.addHandler(logging.StreamHandler(sys.stdout))
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(fmt="%(levelname)s: %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

parser = argparse.ArgumentParser()
Expand Down

0 comments on commit 6d17ee1

Please sign in to comment.