diff --git a/README.md b/README.md
index dd6a7aa..31b5670 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # TI-ViT
 
-The repository contains script for exporting PyTorch VIT model to ONNX format in the form that compatible with
+The repository contains a script for exporting the PyTorch ViT model to ONNX format in a form that is compatible with
 [edgeai-tidl-tools](https://github.com/TexasInstruments/edgeai-tidl-tools) (version 8.6.0.5).
 
 ## Installation
@@ -12,13 +12,24 @@ pip3 install git+https://github.com/ENOT-AutoDL/ti-vit.git@main
 ## Examples
 
+### MLP blocks on TI DSP (maximum performance variant)
+
 To export the model version with maximum performance, run the following command:
 
 ```commandline
 export-ti-vit -o npu-max-perf.onnx -t npu-max-perf
 ```
 
+This variant of the model contains MLP blocks that can be run on the TI DSP. The GELU operation is approximated.
+
+### MLP blocks partially on TI DSP (minimal loss of accuracy)
 To export the model version with minimal loss of accuracy, run the following command:
 
 ```commandline
 export-ti-vit -o npu-max-acc.onnx -t npu-max-acc
 ```
+This variant of the model contains MLP blocks that can be partially run on the TI DSP. The GELU operation is not approximated.
+
+## Compilation of the exported model
+It is important to disable compilation of all nodes except the nodes from the MLP blocks (the "Squeeze" node from the MLP
+must be disabled too). The list of node names for the ["deny_list:layer_name"](https://github.com/TexasInstruments/edgeai-tidl-tools/blob/08_06_00_05/examples/osrt_python/README.md#options-to-enable-control-on-layer-level-delegation-to-ti-dsparm)
+compiler option can be found in the file "output-onnx-dir/output-onnx-name.deny_list", which is generated alongside the ONNX file.
diff --git a/pyproject.toml b/pyproject.toml
index 854aa48..15113a5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,6 +2,8 @@
 name = 'ti-vit'
 version = '0.0.1'
 dependencies = [
+    'onnx',
+    'onnx-simplifier',
     'torch==1.13.1',
     'torchvision==0.14.1',
 ]
diff --git a/src/ti_vit/export.py b/src/ti_vit/export.py
index 04b79e9..a708876 100644
--- a/src/ti_vit/export.py
+++ b/src/ti_vit/export.py
@@ -1,4 +1,5 @@
 import argparse
+import json
 import logging
 import sys
 import warnings
@@ -6,13 +7,17 @@
 from typing import Optional
 from typing import Union
 
+import onnx
 import torch
+from onnxsim import simplify
 from torchvision.models import ViT_B_16_Weights
 from torchvision.models import vit_b_16
 
 from ti_vit.model import TICompatibleVitOrtMaxAcc
 from ti_vit.model import TICompatibleVitOrtMaxPerf
 
+_LOGGER = logging.getLogger(__name__)
+
 
 def export(
     output_onnx_path: Union[str, Path],
@@ -56,21 +61,41 @@
     device = next(model.parameters()).device
     dummy_data = torch.ones([1, 3, resolution, resolution], dtype=torch.float32, device=device)
 
+    output_onnx_path = Path(output_onnx_path)
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")  # disable export warnings
         torch.onnx.export(
            model=model,
-            f=str(output_onnx_path),
+            f=str(output_onnx_path.resolve()),
            args=dummy_data,
            input_names=["input"],
            output_names=["output"],
            opset_version=9,
        )
+    _LOGGER.info(f'model exported to onnx (path = "{output_onnx_path}")')
+
+    onnx_model = onnx.load(output_onnx_path)
+    onnx_model, ok = simplify(onnx_model)
+    if not ok:
+        _LOGGER.error("onnx-simplifier step failed")
+    else:
+        onnx.save_model(onnx_model, f=output_onnx_path)
+        _LOGGER.info("onnx simplified")
+
+    if model_type != "cpu":
+        deny_list = [node.name for node in onnx_model.graph.node if "mlp" not in node.name or node.op_type == "Squeeze"]
+        deny_list_path = output_onnx_path.with_suffix(".deny_list")
+        with deny_list_path.open("wt") as deny_list_file:  # pylint: disable=unspecified-encoding
+            json.dump(deny_list, fp=deny_list_file, indent=4)
+        _LOGGER.info(f'deny list created (path = "{deny_list_path}")')
 
 
 def export_ti_compatible_vit() -> None:  # pylint: disable=missing-function-docstring
     logger = logging.getLogger("ti_vit")
-    logger.addHandler(logging.StreamHandler(sys.stdout))
+    handler = logging.StreamHandler(sys.stdout)
+    formatter = logging.Formatter(fmt="%(levelname)s: %(message)s")
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
     logger.setLevel(logging.INFO)
 
     parser = argparse.ArgumentParser()
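
For reference, a minimal sketch of how the generated ".deny_list" file might be consumed when building the TIDL compile options. The file name "npu-max-perf.deny_list", the `compile_options` dict, and the assumption that the "deny_list:layer_name" option accepts a comma-separated string of node names are illustrative only and not defined by this repository; consult the edgeai-tidl-tools osrt_python README linked above for the exact option semantics.

```python
import json

# Load the node names that export() wrote next to the exported ONNX file
# (file name assumed from the "npu-max-perf" example above).
with open("npu-max-perf.deny_list", "rt") as deny_list_file:
    deny_list = json.load(deny_list_file)

# Assumption: "deny_list:layer_name" takes the denied layer names as a single
# comma-separated string, alongside whatever other TIDL compile options
# (tidl_tools_path, artifacts_folder, ...) the compilation script already uses.
compile_options = {
    "deny_list:layer_name": ",".join(deny_list),
}

print(f"{len(deny_list)} nodes excluded from TI DSP offload")
```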