-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dev(narugo): add whisper langid model
- Loading branch information
1 parent
374aa15
commit bbdf092
Showing
8 changed files
with
274 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .whisper import whisper_langid, whisper_langid_score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import json | ||
from functools import lru_cache | ||
from typing import List, Dict, Tuple | ||
|
||
import numpy as np | ||
from hbutils.string import plural_word | ||
from huggingface_hub import hf_hub_download | ||
|
||
from ..data import SoundTyping, Sound | ||
from ..preprocess.transformers import WhisperFeatureExtractor | ||
from ..utils import open_onnx_model, softmax | ||
|
||
|
||
@lru_cache() | ||
def _preprocessor_config(repo_id: str): | ||
with open(hf_hub_download( | ||
repo_id=repo_id, | ||
repo_type='model', | ||
filename='preprocessor_config.json' | ||
), 'r') as f: | ||
return json.load(f) | ||
|
||
|
||
@lru_cache() | ||
def _config(repo_id: str): | ||
with open(hf_hub_download( | ||
repo_id=repo_id, | ||
repo_type='model', | ||
filename='config.json' | ||
), 'r') as f: | ||
return json.load(f) | ||
|
||
|
||
@lru_cache() | ||
def _open_model(repo_id: str): | ||
return open_onnx_model(hf_hub_download( | ||
repo_id=repo_id, | ||
repo_type='model', | ||
filename='model.onnx' | ||
)) | ||
|
||
|
||
_DEFAULT_MODEL = 'deepghs/whisper-medium-fleurs-lang-id-onnx' | ||
|
||
|
||
@lru_cache() | ||
def _feature_extractor(repo_id: str): | ||
return WhisperFeatureExtractor(**_preprocessor_config(repo_id=repo_id)) | ||
|
||
|
||
def _audio_preprocess(sounds: List[SoundTyping], repo_id: str = _DEFAULT_MODEL, resample_rate: int = 16000): | ||
datas = [] | ||
for sf in sounds: | ||
sound = Sound.load(sf) | ||
if sound.channels != 1: | ||
raise ValueError(f'Only 1-channel audio is supported, ' | ||
f'{plural_word(sound.channels, "channel")} found in {sf}.') | ||
sound = sound.resample(resample_rate) | ||
data, sr = sound.to_numpy() | ||
datas.append(data[0]) | ||
|
||
fr = _feature_extractor(repo_id=repo_id) | ||
return fr(datas)['input_features'] | ||
|
||
|
||
def _raw_sound_langid(sound: SoundTyping, model_name: str = _DEFAULT_MODEL): | ||
input_ = _audio_preprocess([sound], repo_id=model_name) | ||
model = _open_model(repo_id=model_name) | ||
input_names = [input.name for input in model.get_inputs()] | ||
assert len(input_names) == 1, f'Non-unique input for model {model_name!r} - {input_names!r}.' | ||
output_names = [output.name for output in model.get_outputs()] | ||
assert len(output_names) == 1, f'Non-unique output for model {model_name!r} - {output_names!r}.' | ||
|
||
output, = model.run(output_names, { | ||
input_names[0]: input_ | ||
}) | ||
logits = output[0] | ||
return softmax(logits) | ||
|
||
|
||
def whisper_langid(sound: SoundTyping, model_name: str = _DEFAULT_MODEL) -> Tuple[str, float]: | ||
scores = _raw_sound_langid( | ||
sound=sound, | ||
model_name=model_name, | ||
) | ||
idx = np.argmax(scores).item() | ||
best_label = _config(repo_id=model_name)["id2label"][str(idx)] | ||
best_score = scores[idx].item() | ||
return best_label, best_score | ||
|
||
|
||
def whisper_langid_score(sound: SoundTyping, model_name: str = _DEFAULT_MODEL) -> Dict[str, float]: | ||
score = _raw_sound_langid( | ||
sound=sound, | ||
model_name=model_name, | ||
) | ||
retval = {} | ||
for i, v in enumerate(score.tolist()): | ||
label = _config(repo_id=model_name)["id2label"][str(i)] | ||
retval[label] = v | ||
return retval |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .enum import ExplicitEnum | ||
from .np import softmax | ||
from .onnx import open_onnx_model, get_onnx_provider |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import numpy as np | ||
|
||
|
||
def softmax(x, axis=-1): | ||
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) | ||
sum_exp_x = np.sum(exp_x, axis=axis, keepdims=True) | ||
return exp_x / sum_exp_x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
Overview: | ||
Management of onnx models. | ||
""" | ||
import logging | ||
import os | ||
import shutil | ||
from typing import Optional | ||
|
||
from hbutils.system import pip_install | ||
|
||
__all__ = [ | ||
'get_onnx_provider', 'open_onnx_model' | ||
] | ||
|
||
|
||
def _ensure_onnxruntime(): | ||
try: | ||
import onnxruntime | ||
except (ImportError, ModuleNotFoundError): | ||
logging.warning('Onnx runtime not installed, preparing to install ...') | ||
if shutil.which('nvidia-smi'): | ||
logging.info('Installing onnxruntime-gpu ...') | ||
pip_install(['onnxruntime-gpu'], silent=True) | ||
else: | ||
logging.info('Installing onnxruntime (cpu) ...') | ||
pip_install(['onnxruntime'], silent=True) | ||
|
||
|
||
_ensure_onnxruntime() | ||
from onnxruntime import get_available_providers, get_all_providers, InferenceSession, SessionOptions, \ | ||
GraphOptimizationLevel | ||
|
||
alias = { | ||
'gpu': "CUDAExecutionProvider", | ||
"trt": "TensorrtExecutionProvider", | ||
} | ||
|
||
|
||
def get_onnx_provider(provider: Optional[str] = None): | ||
""" | ||
Overview: | ||
Get onnx provider. | ||
:param provider: The provider for ONNX runtime. ``None`` by default and will automatically detect | ||
if the ``CUDAExecutionProvider`` is available. If it is available, it will be used, | ||
otherwise the default ``CPUExecutionProvider`` will be used. | ||
:return: String of the provider. | ||
""" | ||
if not provider: | ||
if "CUDAExecutionProvider" in get_available_providers(): | ||
return "CUDAExecutionProvider" | ||
else: | ||
return "CPUExecutionProvider" | ||
elif provider.lower() in alias: | ||
return alias[provider.lower()] | ||
else: | ||
for p in get_all_providers(): | ||
if provider.lower() == p.lower() or f'{provider}ExecutionProvider'.lower() == p.lower(): | ||
return p | ||
|
||
raise ValueError(f'One of the {get_all_providers()!r} expected, ' | ||
f'but unsupported provider {provider!r} found.') | ||
|
||
|
||
def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True) -> InferenceSession: | ||
options = SessionOptions() | ||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL | ||
if provider == "CPUExecutionProvider": | ||
options.intra_op_num_threads = os.cpu_count() | ||
|
||
providers = [provider] | ||
if use_cpu and "CPUExecutionProvider" not in providers: | ||
providers.append("CPUExecutionProvider") | ||
|
||
logging.info(f'Model {ckpt!r} loaded with provider {provider!r}') | ||
return InferenceSession(ckpt, options, providers=providers) | ||
|
||
|
||
def open_onnx_model(ckpt: str, mode: str = None) -> InferenceSession: | ||
""" | ||
Overview: | ||
Open an ONNX model and load its ONNX runtime. | ||
:param ckpt: ONNX model file. | ||
:param mode: Provider of the ONNX. Default is ``None`` which means the provider will be auto-detected, | ||
see :func:`get_onnx_provider` for more details. | ||
:return: A loaded ONNX runtime object. | ||
.. note:: | ||
When ``mode`` is set to ``None``, it will attempt to detect the environment variable ``ONNX_MODE``. | ||
This means you can decide which ONNX runtime to use by setting the environment variable. For example, | ||
on Linux, executing ``export ONNX_MODE=cpu`` will ignore any existing CUDA and force the model inference | ||
to run on CPU. | ||
""" | ||
return _open_onnx_model(ckpt, get_onnx_provider(mode or os.environ.get('ONNX_MODE', None))) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
from soundutils.langid import whisper_langid | ||
from test.testings import get_testfile | ||
|
||
|
||
@pytest.mark.unittest | ||
class TestLangidWhisper: | ||
@pytest.mark.parametrize(['lang', 'file'], [ | ||
('Mandarin Chinese', os.path.join('zh', 'zh_long.wav')), | ||
('Japanese', os.path.join('jp', 'jp_long.wav')), | ||
('Korean', os.path.join('kr', 'kr_long.wav')), | ||
('English', os.path.join('en', 'en_long.wav')), | ||
]) | ||
def test_whisper_langid(self, lang, file): | ||
alang, ascore = whisper_langid(get_testfile('assets', 'langs', file)) | ||
assert alang == lang |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from soundutils.utils import softmax | ||
|
||
|
||
@pytest.fixture | ||
def vector(): | ||
return np.array([1.0, 2.0, 3.0]) | ||
|
||
|
||
@pytest.fixture | ||
def matrix(): | ||
return np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) | ||
|
||
|
||
@pytest.fixture | ||
def high_precision_vector(): | ||
return np.array([1234567890.0, 1234567891.0, 1234567892.0]) | ||
|
||
|
||
@pytest.mark.unittest | ||
class TestUtilsNp: | ||
def test_softmax_vector(self, vector): | ||
result = softmax(vector) | ||
expected = np.array([0.09003057, 0.24472847, 0.66524096]) | ||
np.testing.assert_almost_equal(result, expected, decimal=8) | ||
|
||
def test_softmax_matrix(self, matrix): | ||
result = softmax(matrix, axis=1) | ||
expected = np.array([ | ||
[0.09003057, 0.24472847, 0.66524096], | ||
[0.09003057, 0.24472847, 0.66524096] | ||
]) | ||
np.testing.assert_almost_equal(result, expected, decimal=8) | ||
|
||
def test_softmax_matrix_along_default_axis(self, matrix): | ||
result = softmax(matrix) | ||
expected = np.array([ | ||
[0.09003057, 0.24472847, 0.66524096], | ||
[0.09003057, 0.24472847, 0.66524096] | ||
]) | ||
np.testing.assert_almost_equal(result, expected, decimal=8) | ||
|
||
def test_softmax_high_precision(self, high_precision_vector): | ||
result = softmax(high_precision_vector) | ||
expected = np.array([0.09003057, 0.24472847, 0.66524096]) | ||
np.testing.assert_almost_equal(result, expected, decimal=8) |