From 20c876efff6f85e170f499bd2148d3471deea8df Mon Sep 17 00:00:00 2001 From: Wannaphong Phatthiyaphaibun Date: Sun, 16 Oct 2022 01:34:46 +0700 Subject: [PATCH 1/6] Update __init__.py --- pythaiasr/__init__.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/pythaiasr/__init__.py b/pythaiasr/__init__.py index 67a4221..3e04584 100644 --- a/pythaiasr/__init__.py +++ b/pythaiasr/__init__.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import torch -from transformers import AutoProcessor, AutoModelForCTC import torchaudio import numpy as np @@ -8,9 +7,10 @@ class ASR: - def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=None) -> None: + def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm=False, device=None) -> None: """ - :param str model: The ASR model + :param str model: The ASR model name + :param bool lm: Use language model (default is False and except *airesearch/wav2vec2-large-xlsr-53-th* model) :param str device: device **Options for model** @@ -18,9 +18,21 @@ def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=Non * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model """ - self.processor = AutoProcessor.from_pretrained(model) self.model_name = model - self.model = AutoModelForCTC.from_pretrained(model) + self.support_model =[ + "airesearch/wav2vec2-large-xlsr-53-th", + "wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm", + "wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut" + ] + assert self.model_name in self.support_model + if not lm: + from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor + self.processor = Wav2Vec2Processor.from_pretrained(self.model_name) + self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name) + else: + from transformers import AutoProcessor, AutoModelForCTC + self.processor = AutoProcessor.from_pretrained(self.model_name) + self.model = AutoModelForCTC.from_pretrained(self.model_name) if device!=None: self.device = torch.device(device) From 2c761ffb5418d0e2891c0e03ca410bcc76797cf3 Mon Sep 17 00:00:00 2001 From: Wannaphong Phatthiyaphaibun Date: Sun, 16 Oct 2022 01:45:58 +0700 Subject: [PATCH 2/6] Update __init__.py --- pythaiasr/__init__.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pythaiasr/__init__.py b/pythaiasr/__init__.py index 3e04584..0c477ed 100644 --- a/pythaiasr/__init__.py +++ b/pythaiasr/__init__.py @@ -7,7 +7,7 @@ class ASR: - def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm=False, device=None) -> None: + def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm: bool=False, device: str=None) -> None: """ :param str model: The ASR model name :param bool lm: Use language model (default is False and except *airesearch/wav2vec2-large-xlsr-53-th* model) @@ -74,21 +74,23 @@ def __call__(self, file: str) -> str: _model = None -def asr(file: str, model: str = _model_name) -> str: +def asr(file: str, model: str = _model_name, lm: bool=False, device: str=None) -> str: """ :param str file: path of sound file - :param str model: The ASR model + :param str model: The ASR model name + :param bool lm: Use language model (except *airesearch/wav2vec2-large-xlsr-53-th* model) + :param str device: device :return: thai text from ASR :rtype: str **Options for model** * *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model - * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model - * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model + * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) (+ language model) + * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) (+ language model) """ global _model, _model_name if model!=_model or _model == None: - _model = ASR(model) + _model = ASR(model, lm=lm, device=device) _model_name = model return _model(file=file) From c2bc0102a67e28780a5be7bf14959fc0d885f3b1 Mon Sep 17 00:00:00 2001 From: Wannaphong Phatthiyaphaibun Date: Sun, 16 Oct 2022 01:46:15 +0700 Subject: [PATCH 3/6] Update README.md --- README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e7adda8..d572283 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ pip install pythaiasr ``` **For Wav2Vec2 with language model:** -if you want to use wannaphong/wav2vec2-large-xlsr-53-th-cv8-* model, you needs to install by the step. +if you want to use wannaphong/wav2vec2-large-xlsr-53-th-cv8-* model with language model, you needs to install by the step. ```sh pip install pythaiasr[lm] @@ -37,17 +37,19 @@ print(asr(file)) ### API ```python -asr(file: str, model: str = "airesearch/wav2vec2-large-xlsr-53-th") +asr(file: str, model: str = _model_name, lm: bool=False, device: str=None) ``` - file: path of sound file - model: The ASR model +- lm: Use language model (except *airesearch/wav2vec2-large-xlsr-53-th* model) +- device: device - return: thai text from ASR **Options for model** - *airesearch/wav2vec2-large-xlsr-53-th* (default) - AI RESEARCH - PyThaiNLP model -- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) + language model -- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model +- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-newmm* - Thai Wav2Vec2 with CommonVoice V8 (newmm tokenizer) +- *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) You can read about models from the list: From fa22d7dc17777ed02f58ba4a6599b993aa390fad Mon Sep 17 00:00:00 2001 From: Wannaphong Phatthiyaphaibun Date: Sun, 16 Oct 2022 01:48:43 +0700 Subject: [PATCH 4/6] Update __init__.py --- pythaiasr/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pythaiasr/__init__.py b/pythaiasr/__init__.py index 0c477ed..dbf8888 100644 --- a/pythaiasr/__init__.py +++ b/pythaiasr/__init__.py @@ -25,7 +25,8 @@ def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", lm: bool=F "wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut" ] assert self.model_name in self.support_model - if not lm: + self.lm =lm + if not self.lm: from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor self.processor = Wav2Vec2Processor.from_pretrained(self.model_name) self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name) @@ -64,7 +65,7 @@ def __call__(self, file: str) -> str: input_dict = self.processor(a["input_values"][0], return_tensors="pt", padding=True) logits = self.model(input_dict.input_values).logits pred_ids = torch.argmax(logits, dim=-1)[0] - if self.model_name == "airesearch/wav2vec2-large-xlsr-53-th": + if self.model_name == "airesearch/wav2vec2-large-xlsr-53-th" or self.lm: txt = self.processor.decode(pred_ids) else: txt = self.processor.batch_decode(logits.detach().numpy()).text[0] From 74e27b14902ec9e6956c2136ea20b863473b47c1 Mon Sep 17 00:00:00 2001 From: Wannaphong Phatthiyaphaibun Date: Sun, 16 Oct 2022 01:52:51 +0700 Subject: [PATCH 5/6] Update __init__.py --- pythaiasr/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pythaiasr/__init__.py b/pythaiasr/__init__.py index dbf8888..60c6f79 100644 --- a/pythaiasr/__init__.py +++ b/pythaiasr/__init__.py @@ -65,10 +65,12 @@ def __call__(self, file: str) -> str: input_dict = self.processor(a["input_values"][0], return_tensors="pt", padding=True) logits = self.model(input_dict.input_values).logits pred_ids = torch.argmax(logits, dim=-1)[0] - if self.model_name == "airesearch/wav2vec2-large-xlsr-53-th" or self.lm: + if self.model_name == "airesearch/wav2vec2-large-xlsr-53-th": txt = self.processor.decode(pred_ids) - else: + elif self.lm: txt = self.processor.batch_decode(logits.detach().numpy()).text[0] + else: + txt = self.processor.decode(pred_ids) return txt _model_name = "airesearch/wav2vec2-large-xlsr-53-th" From c0ab28c5a9c9d05d5728527dbb5148e40ba63a9a Mon Sep 17 00:00:00 2001 From: Wannaphong Phatthiyaphaibun Date: Sun, 16 Oct 2022 01:53:21 +0700 Subject: [PATCH 6/6] Update setup.py --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 75b3857..bf3b948 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ def read(*paths): requirements = [ 'datasets', - 'transformers', + 'transformers<5.0', 'torchaudio', 'soundfile', 'torch', @@ -27,7 +27,7 @@ def read(*paths): setup( name='pythaiasr', - version='1.1.2', + version='1.2.0', packages=['pythaiasr'], url='https://github.com/pythainlp/pythaiasr', license='Apache Software License 2.0',