diff --git a/README.md b/README.md index e7adda8..d572283 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ pip install pythaiasr ``` **For Wav2Vec2 with language model:** -if you want to use wannaphong/wav2vec2-large-xlsr-53-th-cv8-* model, you needs to install by the step. +If you want to use a wannaphong/wav2vec2-large-xlsr-53-th-cv8-* model with its language model, you need to install the package with the following step. ```sh pip install pythaiasr[lm] @@ -37,17 +37,19 @@ print(asr(file)) ### API ```python -asr(file: str, model: str = "airesearch/wav2vec2-large-xlsr-53-th") +asr(file: str, model: str = _model_name, lm: bool=False, device: str=None) ``` - file: path of sound file - model: The ASR model +- lm: use a language model for decoding (not supported by the *airesearch/wav2vec2-large-xlsr-53-th* model) +- device: device to run the model on (e.g. "cpu" or "cuda") - return: thai text from ASR **Options for model** - *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model -- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model -- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model +- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) +- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) You can read about models from the list: diff --git a/pythaiasr/__init__.py b/pythaiasr/__init__.py index 67a4221..60c6f79 100644 --- a/pythaiasr/__init__.py +++ b/pythaiasr/__init__.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import torch -from transformers import AutoProcessor, AutoModelForCTC import torchaudio import numpy as np @@ -8,9 +7,10 @@ class ASR: - def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=None) -> None: + def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm: bool=False, device: str=None) -> None: """ - :param str model: The ASR model
+ :param str model: The ASR model name + :param bool lm: use a language model for decoding (default is False; not supported by the *airesearch/wav2vec2-large-xlsr-53-th* model) :param str device: device **Options for model** @@ -18,9 +18,22 @@ def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=Non * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model """ - self.processor = AutoProcessor.from_pretrained(model) self.model_name = model - self.model = AutoModelForCTC.from_pretrained(model) + self.support_model =[ + "airesearch/wav2vec2-large-xlsr-53-th", + "wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm", + "wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut" + ] + assert self.model_name in self.support_model + self.lm =lm + if not self.lm: + from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor + self.processor = Wav2Vec2Processor.from_pretrained(self.model_name) + self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name) + else: + from transformers import AutoProcessor, AutoModelForCTC + self.processor = AutoProcessor.from_pretrained(self.model_name) + self.model = AutoModelForCTC.from_pretrained(self.model_name) if device!=None: self.device = torch.device(device) @@ -54,29 +67,33 @@ def __call__(self, file: str) -> str: pred_ids = torch.argmax(logits, dim=-1)[0] if self.model_name == "airesearch/wav2vec2-large-xlsr-53-th": txt = self.processor.decode(pred_ids) - else: + elif self.lm: txt = self.processor.batch_decode(logits.detach().numpy()).text[0] + else: + txt = self.processor.decode(pred_ids) return txt _model_name = "airesearch/wav2vec2-large-xlsr-53-th" _model = None -def asr(file: str, model: str = _model_name) -> str: +def asr(file: str, model: str = _model_name, lm: bool=False, device: str=None) -> str: """ :param str file: path of sound file - :param str 
model: The ASR model + :param str model: The ASR model name + :param bool lm: use a language model for decoding (not supported by the *airesearch/wav2vec2-large-xlsr-53-th* model) + :param str device: device to run the model on (e.g. cpu or cuda) :return: thai text from ASR :rtype: str **Options for model** * *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model - * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model - * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model + * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) (+ language model) + * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) (+ language model) """ global _model, _model_name if model!=_model or _model == None: - _model = ASR(model) + _model = ASR(model, lm=lm, device=device) _model_name = model return _model(file=file) diff --git a/setup.py b/setup.py index 75b3857..bf3b948 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ def read(*paths): requirements = [ 'datasets', - 'transformers', + 'transformers<5.0', 'torchaudio', 'soundfile', 'torch', @@ -27,7 +27,7 @@ def read(*paths): setup( name='pythaiasr', - version='1.1.2', + version='1.2.0', packages=['pythaiasr'], url='https://github.com/pythainlp/pythaiasr', license='Apache Software License 2.0',