Skip to content

Commit

Permalink
Merge pull request #96 from Masao-Someki/dev_v2
Browse files Browse the repository at this point in the history
[WIP] upgrade to v2
  • Loading branch information
Masao-Someki authored Oct 22, 2023
2 parents 949064a + 9ee5d06 commit 2fa5b7f
Show file tree
Hide file tree
Showing 47 changed files with 503 additions and 319 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

![](https://circleci.com/gh/espnet/espnet_onnx.svg?style=shield)
![](https://img.shields.io/badge/licence-MIT-blue)
[![](https://img.shields.io/badge/pypi-0.1.11-brightgreen)](https://pypi.org/project/espnet-onnx/)
[![](https://img.shields.io/badge/pypi-0.2.0-brightgreen)](https://pypi.org/project/espnet-onnx/)

**ESPnet without PyTorch!**

Expand Down Expand Up @@ -322,6 +322,10 @@ ASR: [Supported architecture for ASR](./docs/ASRSupported.md)

TTS: [Supported architecture for TTS](./docs/TTSSupported.md)

## Developer's Guide

ASR: [Developer's Guide](./docs/DeveloperGuide.md)

## References

- [ESPNet: end-to-end speech processing toolkit](https://github.com/espnet/espnet)
Expand Down
3 changes: 3 additions & 0 deletions docs/ASRSupported.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
| FairseqHubertEncoder | x |
| FairseqHubertPretrainEncoder | x |
| LongformerEncoder | x |
| BranchformerEncoder | ◯ |
| E-BranchformerEncoder | ◯ |


**Decoder**

Expand Down
20 changes: 20 additions & 0 deletions docs/DeveloperGuide.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Developer's Guide for Speech-to-Text Models

**Converting Your Own Model**

To begin, please check if your model can be successfully converted into ONNX format without any special treatment using `ASRModelExport`. It's important to note that if your model contains operations not supported by the ONNX exporter, you may encounter errors. In such cases, follow these steps to successfully convert your model:

1. Create a new class that is ONNX-compatible, excluding any unsupported operations.

2. Integrate your newly created class into the `espnet_onnx.export.convert_map.yml` file. This file will help ESPnet-ONNX identify the conversion between incompatible and compatible classes. Here's an example of how to add your class to the YAML file:

```yaml
asr:
...

# Add your new class here
- from: <incompatible class>
to: <compatible class>
```
3. After adding your class to the `convert_map.yml` file, check if you can successfully convert your model into the ONNX format. ESPnet-ONNX will automatically identify the incompatible classes and replace them with the compatible ones, ensuring a seamless conversion process.
2 changes: 1 addition & 1 deletion espnet_onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from espnet_onnx.asr.asr_streaming import StreamingSpeech2Text
from espnet_onnx.tts.tts_model import Text2Speech

__version__ = "0.1.11"
__version__ = "0.2.0"
8 changes: 4 additions & 4 deletions espnet_onnx/asr/model/encoders/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def forward_encoder(self, feats, feat_length):
["encoder_out", "encoder_out_lens"], {"feats": feats}
)

if self.config.enc_type == "RNNEncoder":
encoder_out = mask_fill(
encoder_out, make_pad_mask(feat_length, encoder_out, 1), 0.0
)
# if self.config.enc_type == "RNNEncoder":
# encoder_out = mask_fill(
# encoder_out, make_pad_mask(feat_length, encoder_out, 1), 0.0
# )

return encoder_out, encoder_out_lens
8 changes: 7 additions & 1 deletion espnet_onnx/export/asr/export_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,17 @@


class ASRModelExport:
def __init__(self, cache_dir: Union[Path, str] = None):
def __init__(self, cache_dir: Union[Path, str] = None, convert_map: Union[str, Path] = None):
assert check_argument_types()
if cache_dir is None:
cache_dir = Path.home() / ".cache" / "espnet_onnx"

if convert_map is None:
convert_map = Path(os.path.dirname(__file__)).parent / "convert_map.yml"

self.cache_dir = Path(cache_dir)
self.convert_map = convert_map

# Use opset_version=12 to avoid optimization error.
# When using the original onnxruntime, 'axes' is moved to input from opset_version=13
# so optimized model will be invalid for onnxruntime<=1.14.1 (latest in 2023/05)
Expand Down Expand Up @@ -75,6 +80,7 @@ def export(
model.asr_model.frontend,
model.asr_model.preencoder,
self.export_config,
self.convert_map,
)
enc_out_size = enc_model.get_output_size()
self._export_encoder(enc_model, export_dir, verbose)
Expand Down
45 changes: 23 additions & 22 deletions espnet_onnx/export/asr/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,18 @@
# decoder
from espnet2.asr.decoder.rnn_decoder import RNNDecoder as espnetRNNDecoder
from espnet2.asr.encoder.conformer_encoder import \
ConformerEncoder as espnetConformerEncoder
from espnet2.asr.encoder.contextual_block_conformer_encoder import \
ContextualBlockConformerEncoder as espnetContextualConformer
from espnet2.asr.encoder.contextual_block_transformer_encoder import \
ContextualBlockTransformerEncoder as espnetContextualTransformer
# encoder
from espnet2.asr.encoder.rnn_encoder import RNNEncoder as espnetRNNEncoder
from espnet2.asr.encoder.transformer_encoder import \
TransformerEncoder as espnetTransformerEncoder
from espnet2.asr.encoder.vgg_rnn_encoder import \
VGGRNNEncoder as espnetVGGRNNEncoder

from espnet_onnx.export.asr.models.ctc import CTC

from espnet_onnx.export.asr.models.layers.ctc import CTC
from espnet_onnx.export.asr.models.decoders.rnn import RNNDecoder
from espnet_onnx.export.asr.models.decoders.transducer import TransducerDecoder
from espnet_onnx.export.asr.models.decoders.xformer import XformerDecoder
from espnet_onnx.export.asr.models.encoders.conformer import ConformerEncoder
from espnet_onnx.export.asr.models.encoders.contextual_block_xformer import \
ContextualBlockXformerEncoder
from espnet_onnx.export.asr.models.encoders.rnn import RNNEncoder
from espnet_onnx.export.asr.models.encoders.transformer import \
TransformerEncoder
from espnet_onnx.export.asr.models.joint_network import JointNetwork
from espnet_onnx.export.asr.models.layers.joint_network import JointNetwork

try:
from espnet2.asr.transducer.transducer_decoder import \
Expand All @@ -46,17 +35,29 @@
TransformerLM


def get_encoder(model, frontend, preencoder, export_config):
if isinstance(model, espnetRNNEncoder) or isinstance(model, espnetVGGRNNEncoder):
return RNNEncoder(model, frontend, preencoder, **export_config)
elif isinstance(model, espnetContextualTransformer) or isinstance(
# conversion
from espnet_onnx.utils.export_function import (
replace_modules,
get_replace_modules
)
from espnet_onnx.export.asr.models.encoder_wrapper import DefaultEncoder

def get_encoder(model, frontend, preencoder, export_config, convert_map):
if isinstance(model, espnetContextualTransformer) or isinstance(
model, espnetContextualConformer
):
return ContextualBlockXformerEncoder(model, **export_config)
elif isinstance(model, espnetTransformerEncoder):
return TransformerEncoder(model, frontend, preencoder, **export_config)
elif isinstance(model, espnetConformerEncoder):
return ConformerEncoder(model, frontend, preencoder, **export_config)
else:
_model = replace_modules(
get_replace_modules(
convert_map,
"asr_optimization" if export_config.get("optimize", False) else "asr"
),
model,
preencoder=preencoder,
export_config=export_config
)
return DefaultEncoder(_model, frontend, **export_config)


def get_decoder(model, export_config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn as nn
from espnet.nets.pytorch_backend.rnn.attentions import NoAtt

from espnet_onnx.export.layers.attention import require_tanh
from espnet_onnx.export.asr.models.layers.attention import require_tanh
from espnet_onnx.utils.abs_model import AbsExportModel


Expand Down
4 changes: 2 additions & 2 deletions espnet_onnx/export/asr/models/decoders/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import torch.nn.functional as F
from espnet.nets.pytorch_backend.rnn.attentions import NoAtt

from espnet_onnx.export.layers.attention import OnnxNoAtt, get_attention
from espnet_onnx.export.layers.predecoder import PreDecoder
from espnet_onnx.export.asr.models.layers.attention import OnnxNoAtt, get_attention
from espnet_onnx.export.asr.models.decoders.predecoder import PreDecoder
from espnet_onnx.utils.abs_model import AbsExportModel
from espnet_onnx.utils.function import make_pad_mask

Expand Down
2 changes: 1 addition & 1 deletion espnet_onnx/export/asr/models/decoders/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import torch.nn as nn

from espnet_onnx.export.asr.models.language_models.embed import Embedding
from espnet_onnx.export.asr.models.layers.embed import Embedding
from espnet_onnx.utils.abs_model import AbsExportModel


Expand Down
6 changes: 3 additions & 3 deletions espnet_onnx/export/asr/models/decoders/xformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from espnet.nets.pytorch_backend.transformer.attention import \
MultiHeadedAttention

from espnet_onnx.export.asr.models.decoder_layer import OnnxDecoderLayer
from espnet_onnx.export.asr.models.language_models.embed import Embedding
from espnet_onnx.export.asr.models.multihead_att import \
from espnet_onnx.export.asr.models.layers.decoder_layer import OnnxDecoderLayer
from espnet_onnx.export.asr.models.layers.embed import Embedding
from espnet_onnx.export.asr.models.layers.multihead_att import \
OnnxMultiHeadedAttention
from espnet_onnx.utils.abs_model import AbsExportModel
from espnet_onnx.utils.function import subsequent_mask
Expand Down
84 changes: 84 additions & 0 deletions espnet_onnx/export/asr/models/encoder_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
import torch
import torch.nn as nn
from espnet_onnx.utils.abs_model import AbsExportModel
from espnet_onnx.export.asr.get_config import (
get_frontend_config,
get_norm_config,
)


class DefaultEncoder(nn.Module, AbsExportModel):
    """Generic ONNX-export wrapper for encoders needing no special handling.

    Bundles an encoder module (plain espnet2 or an espnet_onnx-optimized
    replacement) with its frontend and exposes the metadata the export
    pipeline expects from an ``AbsExportModel``: dummy inputs, input/output
    names, dynamic axes, and the runtime model configuration.
    """

    def __init__(self, model, frontend, feats_dim=80, **kwargs):
        """Wrap ``model`` for export.

        Args:
            model: Encoder module to trace and export.
            frontend: Frontend module; may yield an exportable submodel.
            feats_dim: Input feature dimension (overridden by the frontend's
                output dimension when a frontend model is built).
            **kwargs: Extra export options; each is also set as an attribute.
        """
        super().__init__()
        self.model = model
        self.model_name = "default_encoder"
        self.frontend = frontend
        self.feats_dim = feats_dim
        self.get_frontend(kwargs)
        # Expose every remaining export option directly as an attribute.
        for option, value in kwargs.items():
            setattr(self, option, value)

        if self.is_optimizable():
            # Optimized espnet_onnx encoders carry these fields, which later
            # optimization passes read from the wrapper.
            self.num_heads = self.model.num_heads
            self.hidden_size = self.model.hidden_size

    def get_frontend(self, kwargs):
        """Build the exportable frontend model, if any, and register it."""
        # Imported lazily to avoid a circular import with the models package.
        from espnet_onnx.export.asr.models import get_frontend_models

        self.frontend_model = get_frontend_models(self.frontend, kwargs)
        if self.frontend_model is None:
            return
        # Register the frontend as a submodel so it is exported alongside the
        # encoder, and adopt its output dimension as the feature size.
        self.submodel = [self.frontend_model]
        self.feats_dim = self.frontend_model.output_dim

    def forward(self, feats):
        """Run the wrapped encoder on ``feats`` (batch, time, dim)."""
        # The traced input has no padding, so every utterance's length is
        # simply its full time dimension.
        ones = torch.ones(feats[:, :, 0].shape)
        feats_length = ones.sum(dim=-1).type(torch.long)
        return self.model(feats, feats_length)

    def get_output_size(self):
        """Return the encoder's output feature dimension."""
        origin = type(self.model).__module__
        if "RNNEncoder" in origin:
            # RNN encoders expose their projection size directly.
            return self.model.model_output_size
        if "espnet2" in origin:
            # Plain espnet2 encoder: width of the first encoder layer.
            return self.model.encoders[0].size
        # Optimized espnet_onnx wrapper: the espnet2 model sits one level down.
        return self.model.model.encoders[0].size

    def is_optimizable(self):
        """True when the wrapped model is a non-RNN espnet_onnx replacement."""
        origin = type(self.model).__module__
        return ("espnet_onnx" in origin) and ("rnn" not in origin)

    def get_dummy_inputs(self):
        """Return a random (1, 100, feats_dim) feature tensor for tracing."""
        return torch.randn(1, 100, self.feats_dim)

    def get_input_names(self):
        """ONNX graph input names."""
        return ["feats"]

    def get_output_names(self):
        """ONNX graph output names."""
        return ["encoder_out", "encoder_out_lens"]

    def get_dynamic_axes(self):
        """Axes whose size may vary at inference time (the time dimension)."""
        return {
            "feats": {1: "feats_length"},
            "encoder_out": {1: "enc_out_length"},
        }

    def get_model_config(self, asr_model=None, path=None):
        """Build the runtime configuration dict for this exported encoder."""
        origin = type(self.model).__module__
        # VGG-RNN encoders need a flag so the runtime applies the matching
        # subsampling when reconstructing output lengths.
        is_vggrnn = "rnn" in origin and any(
            "OnnxVGG2l" in type(m).__name__ for m in asr_model.encoder.modules()
        )

        ret = {
            "enc_type": "DefaultEncoder",
            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
            "is_vggrnn": is_vggrnn,
            "frontend": get_frontend_config(
                asr_model.frontend, self.frontend_model, path=path
            ),
            "do_normalize": asr_model.normalize is not None,
            "do_postencoder": asr_model.postencoder is not None,
        }
        if ret["do_normalize"]:
            ret["normalize"] = get_norm_config(asr_model.normalize, path)
        return ret
Loading

0 comments on commit 2fa5b7f

Please sign in to comment.