Skip to content

Commit

Permalink
Apply suggestions from code review (docstrings)
Browse files Browse the repository at this point in the history
Co-authored-by: Arseny <[email protected]>
  • Loading branch information
ivkalgin and senysenyseny16 authored Feb 5, 2024
1 parent 30008a5 commit 6b2f6b8
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TI-ViT

The repository contains script for export pytorch VIT model to onnx format in form that compatible with
The repository contains a script for exporting a PyTorch ViT model to ONNX format in a form that is compatible with
[edgeai-tidl-tools](https://github.com/TexasInstruments/edgeai-tidl-tools) (version 8.6.0.5).

## Installation
Expand Down
2 changes: 1 addition & 1 deletion src/ti_vit/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def from_module(
attention_type: AttentionType = AttentionType.CONV_CONV,
) -> "TICompatibleAttention":
"""
Create TI compatible attention block from common Vit attention block.
Create TI compatible attention block from common ViT attention block.
Parameters
----------
Expand Down
6 changes: 3 additions & 3 deletions src/ti_vit/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ def export(
Parameters
----------
output_onnx_path : Union[str, Path]
Path to the output onnx.
Path to the output ONNX file.
model_type : str
Type of the final model. Possible values are "npu-max-acc", "npu-max-perf" or "cpu".
checkpoint_path : Optional[Union[str, Path]] = None
Path to the pytorch model checkpoint. If value is None, then ViT_B_16 pretrained torchvision model is used.
Path to the PyTorch model checkpoint. If value is None, then ViT_B_16 pretrained torchvision model is used.
Default value is None.
resolution : int
Resolution of input image. Default value is 224.
Expand Down Expand Up @@ -89,7 +89,7 @@ def export_ti_compatible_vit() -> None: # pylint: disable=missing-function-docs
"--checkpoint",
type=str,
required=False,
help="Path to the Vit checkpoint (optional argument). By default we download the torchvision checkpoint "
help="Path to the ViT checkpoint (optional argument). By default, the torchvision checkpoint "
"(ViT_B_16) is downloaded.",
default=None,
)
Expand Down
2 changes: 1 addition & 1 deletion src/ti_vit/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def from_module(
gelu_approx_type: GeluApproximationType = GeluApproximationType.NONE,
) -> "TICompatibleMLP":
"""
Create TI compatible MLP block from common Vit MLP block.
Create TI compatible MLP block from common ViT MLP block.
Parameters
----------
Expand Down
8 changes: 4 additions & 4 deletions src/ti_vit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-f


class TICompatibleVitOrtMaxPerf(_TICompatibleVit):
"""TI compatible Vit model with maximal performance."""
"""TI compatible ViT model with maximum performance."""

def __init__(self, model: VisionTransformer, ignore_tidl_errors: bool = False):
"""
Parameters
----------
model : VisionTransformer
Source Vit model.
Source ViT model.
ignore_tidl_errors : bool
Experimental option.
"""
Expand Down Expand Up @@ -95,14 +95,14 @@ def _mlp_perf_block_cfg() -> _BlockCfg:


class TICompatibleVitOrtMaxAcc(_TICompatibleVit):
"""TI compatible Vit model with minimal accuracy drop."""
"""TI compatible ViT model with minimal accuracy drop."""

def __init__(self, model: VisionTransformer):
"""
Parameters
----------
model : VisionTransformer
Source Vit model.
Source ViT model.
"""
super().__init__(
model=model,
Expand Down

0 comments on commit 6b2f6b8

Please sign in to comment.