Skip to content

Commit

Permalink
Merge pull request #10 from wannaphong/add-wav2vec2-mode
Browse files Browse the repository at this point in the history
PyThaiASR v1.2.0
  • Loading branch information
wannaphong authored Oct 15, 2022
2 parents 4071489 + c0ab28c commit de40473
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 17 deletions.
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pip install pythaiasr
```

**For Wav2Vec2 with language model:**
If you want to use a wannaphong/wav2vec2-large-xlsr-53-th-cv8-* model, you need to install the extra dependencies with the following step.
If you want to use a wannaphong/wav2vec2-large-xlsr-53-th-cv8-* model together with its language model, you need to install the extra dependencies with the following step.

```sh
pip install pythaiasr[lm]
Expand All @@ -37,17 +37,19 @@ print(asr(file))
### API

```python
asr(file: str, model: str = "airesearch/wav2vec2-large-xlsr-53-th")
asr(file: str, model: str = _model_name, lm: bool=False, device: str=None)
```

- file: path of sound file
- model: The ASR model
- lm: Whether to decode with a language model (default: False; not available for the *airesearch/wav2vec2-large-xlsr-53-th* model)
- device: Name of the torch device to run the model on, e.g. "cpu" or "cuda" (passed to `torch.device`)
- return: The Thai text transcribed from the sound file

**Options for model**
- *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model
- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model
- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model
- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer)
- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer)

You can read about models from the list:

Expand Down
39 changes: 28 additions & 11 deletions pythaiasr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,39 @@
# -*- coding: utf-8 -*-
import torch
from transformers import AutoProcessor, AutoModelForCTC
import torchaudio
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class ASR:
def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=None) -> None:
def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm: bool=False, device: str=None) -> None:
"""
:param str model: The ASR model
:param str model: The ASR model name
:param bool lm: Use language model (default is False and except *airesearch/wav2vec2-large-xlsr-53-th* model)
:param str device: device
**Options for model**
* *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model
"""
self.processor = AutoProcessor.from_pretrained(model)
self.model_name = model
self.model = AutoModelForCTC.from_pretrained(model)
self.support_model =[
"airesearch/wav2vec2-large-xlsr-53-th",
"wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm",
"wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut"
]
assert self.model_name in self.support_model
self.lm =lm
if not self.lm:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
else:
from transformers import AutoProcessor, AutoModelForCTC
self.processor = AutoProcessor.from_pretrained(self.model_name)
self.model = AutoModelForCTC.from_pretrained(self.model_name)
if device!=None:
self.device = torch.device(device)

Expand Down Expand Up @@ -54,29 +67,33 @@ def __call__(self, file: str) -> str:
pred_ids = torch.argmax(logits, dim=-1)[0]
if self.model_name == "airesearch/wav2vec2-large-xlsr-53-th":
txt = self.processor.decode(pred_ids)
else:
elif self.lm:
txt = self.processor.batch_decode(logits.detach().numpy()).text[0]
else:
txt = self.processor.decode(pred_ids)
return txt

_model_name = "airesearch/wav2vec2-large-xlsr-53-th"
_model = None


def asr(file: str, model: str = _model_name) -> str:
def asr(file: str, model: str = _model_name, lm: bool=False, device: str=None) -> str:
"""
:param str file: path of sound file
:param str model: The ASR model
:param str model: The ASR model name
:param bool lm: Use language model (except *airesearch/wav2vec2-large-xlsr-53-th* model)
:param str device: device
:return: thai text from ASR
:rtype: str
**Options for model**
* *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) (+ language model)
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) (+ language model)
"""
global _model, _model_name
if model!=_model or _model == None:
_model = ASR(model)
_model = ASR(model, lm=lm, device=device)
_model_name = model

return _model(file=file)
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def read(*paths):

requirements = [
'datasets',
'transformers',
'transformers<5.0',
'torchaudio',
'soundfile',
'torch',
Expand All @@ -27,7 +27,7 @@ def read(*paths):

setup(
name='pythaiasr',
version='1.1.2',
version='1.2.0',
packages=['pythaiasr'],
url='https://github.com/pythainlp/pythaiasr',
license='Apache Software License 2.0',
Expand Down

0 comments on commit de40473

Please sign in to comment.