dev(narugo): add whisper langid model
narugo1992 committed Aug 29, 2024
1 parent 374aa15 commit bbdf092
Showing 8 changed files with 274 additions and 0 deletions.
1 change: 1 addition & 0 deletions soundutils/langid/__init__.py
@@ -0,0 +1 @@
from .whisper import whisper_langid, whisper_langid_score
101 changes: 101 additions & 0 deletions soundutils/langid/whisper.py
@@ -0,0 +1,101 @@
import json
from functools import lru_cache
from typing import List, Dict, Tuple

import numpy as np
from hbutils.string import plural_word
from huggingface_hub import hf_hub_download

from ..data import SoundTyping, Sound
from ..preprocess.transformers import WhisperFeatureExtractor
from ..utils import open_onnx_model, softmax


@lru_cache()
def _preprocessor_config(repo_id: str):
with open(hf_hub_download(
repo_id=repo_id,
repo_type='model',
filename='preprocessor_config.json'
), 'r') as f:
return json.load(f)


@lru_cache()
def _config(repo_id: str):
with open(hf_hub_download(
repo_id=repo_id,
repo_type='model',
filename='config.json'
), 'r') as f:
return json.load(f)


@lru_cache()
def _open_model(repo_id: str):
return open_onnx_model(hf_hub_download(
repo_id=repo_id,
repo_type='model',
filename='model.onnx'
))


_DEFAULT_MODEL = 'deepghs/whisper-medium-fleurs-lang-id-onnx'


@lru_cache()
def _feature_extractor(repo_id: str):
return WhisperFeatureExtractor(**_preprocessor_config(repo_id=repo_id))


def _audio_preprocess(sounds: List[SoundTyping], repo_id: str = _DEFAULT_MODEL, resample_rate: int = 16000):
datas = []
for sf in sounds:
sound = Sound.load(sf)
if sound.channels != 1:
raise ValueError(f'Only 1-channel audio is supported, '
f'{plural_word(sound.channels, "channel")} found in {sf}.')
sound = sound.resample(resample_rate)
data, sr = sound.to_numpy()
datas.append(data[0])

fr = _feature_extractor(repo_id=repo_id)
return fr(datas)['input_features']


def _raw_sound_langid(sound: SoundTyping, model_name: str = _DEFAULT_MODEL):
input_ = _audio_preprocess([sound], repo_id=model_name)
model = _open_model(repo_id=model_name)
    input_names = [inp.name for inp in model.get_inputs()]
    assert len(input_names) == 1, f'Non-unique input for model {model_name!r} - {input_names!r}.'
    output_names = [out.name for out in model.get_outputs()]
    assert len(output_names) == 1, f'Non-unique output for model {model_name!r} - {output_names!r}.'

output, = model.run(output_names, {
input_names[0]: input_
})
logits = output[0]
return softmax(logits)


def whisper_langid(sound: SoundTyping, model_name: str = _DEFAULT_MODEL) -> Tuple[str, float]:
scores = _raw_sound_langid(
sound=sound,
model_name=model_name,
)
idx = np.argmax(scores).item()
best_label = _config(repo_id=model_name)["id2label"][str(idx)]
best_score = scores[idx].item()
return best_label, best_score


def whisper_langid_score(sound: SoundTyping, model_name: str = _DEFAULT_MODEL) -> Dict[str, float]:
score = _raw_sound_langid(
sound=sound,
model_name=model_name,
)
retval = {}
for i, v in enumerate(score.tolist()):
label = _config(repo_id=model_name)["id2label"][str(i)]
retval[label] = v
return retval
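
A minimal usage sketch of the two entry points added above (the audio path 'sample.wav' is hypothetical; any mono recording works, since the loader resamples to 16000 Hz before feature extraction):

from soundutils.langid import whisper_langid, whisper_langid_score

# Top-1 language and its softmax probability, using the default
# 'deepghs/whisper-medium-fleurs-lang-id-onnx' model.
label, score = whisper_langid('sample.wav')
print(label, score)

# Full probability distribution over every label in the model's id2label mapping.
scores = whisper_langid_score('sample.wav')
print(sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:3])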
2 changes: 2 additions & 0 deletions soundutils/utils/__init__.py
@@ -1 +1,3 @@
from .enum import ExplicitEnum
from .np import softmax
from .onnx import open_onnx_model, get_onnx_provider
7 changes: 7 additions & 0 deletions soundutils/utils/np.py
@@ -0,0 +1,7 @@
import numpy as np


def softmax(x, axis=-1):
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
sum_exp_x = np.sum(exp_x, axis=axis, keepdims=True)
return exp_x / sum_exp_x
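
The max-subtraction above is the standard numerical-stability trick: softmax(x) equals softmax(x - c) for any constant c, so shifting by the per-axis maximum keeps np.exp from overflowing on large inputs; this is exactly what test_softmax_high_precision below exercises. A quick illustration:

import numpy as np
from soundutils.utils import softmax

x = np.array([1000.0, 1001.0, 1002.0])
# A naive np.exp(x) / np.exp(x).sum() overflows to inf here and yields nan;
# the shifted form stays finite and sums to 1.
print(softmax(x))  # ~[0.09003057, 0.24472847, 0.66524096]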
96 changes: 96 additions & 0 deletions soundutils/utils/onnx.py
@@ -0,0 +1,96 @@
"""
Overview:
    Management of ONNX models.
"""
import logging
import os
import shutil
from typing import Optional

from hbutils.system import pip_install

__all__ = [
'get_onnx_provider', 'open_onnx_model'
]


def _ensure_onnxruntime():
try:
import onnxruntime
except (ImportError, ModuleNotFoundError):
        logging.warning('ONNX Runtime not installed, preparing to install ...')
if shutil.which('nvidia-smi'):
logging.info('Installing onnxruntime-gpu ...')
pip_install(['onnxruntime-gpu'], silent=True)
else:
logging.info('Installing onnxruntime (cpu) ...')
pip_install(['onnxruntime'], silent=True)


_ensure_onnxruntime()
from onnxruntime import get_available_providers, get_all_providers, InferenceSession, SessionOptions, \
GraphOptimizationLevel

alias = {
'gpu': "CUDAExecutionProvider",
"trt": "TensorrtExecutionProvider",
}


def get_onnx_provider(provider: Optional[str] = None):
"""
Overview:
        Get the ONNX provider.
    :param provider: The provider for the ONNX runtime. ``None`` by default, in which case
        ``CUDAExecutionProvider`` is used when available, otherwise ``CPUExecutionProvider``.
    :return: Name of the resolved provider.
"""
if not provider:
if "CUDAExecutionProvider" in get_available_providers():
return "CUDAExecutionProvider"
else:
return "CPUExecutionProvider"
elif provider.lower() in alias:
return alias[provider.lower()]
else:
for p in get_all_providers():
if provider.lower() == p.lower() or f'{provider}ExecutionProvider'.lower() == p.lower():
return p

raise ValueError(f'One of the {get_all_providers()!r} expected, '
f'but unsupported provider {provider!r} found.')


def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True) -> InferenceSession:
options = SessionOptions()
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
if provider == "CPUExecutionProvider":
options.intra_op_num_threads = os.cpu_count()

providers = [provider]
if use_cpu and "CPUExecutionProvider" not in providers:
providers.append("CPUExecutionProvider")

logging.info(f'Model {ckpt!r} loaded with provider {provider!r}')
return InferenceSession(ckpt, options, providers=providers)


def open_onnx_model(ckpt: str, mode: Optional[str] = None) -> InferenceSession:
"""
Overview:
Open an ONNX model and load its ONNX runtime.
:param ckpt: ONNX model file.
    :param mode: Provider of the ONNX runtime. Default is ``None``, which means the provider will be
        auto-detected; see :func:`get_onnx_provider` for more details.
:return: A loaded ONNX runtime object.
.. note::
When ``mode`` is set to ``None``, it will attempt to detect the environment variable ``ONNX_MODE``.
This means you can decide which ONNX runtime to use by setting the environment variable. For example,
on Linux, executing ``export ONNX_MODE=cpu`` will ignore any existing CUDA and force the model inference
to run on CPU.
"""
return _open_onnx_model(ckpt, get_onnx_provider(mode or os.environ.get('ONNX_MODE', None)))
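
A short sketch of the two ways to select a provider described in the note above (the checkpoint path 'model.onnx' is hypothetical):

import os
from soundutils.utils import open_onnx_model, get_onnx_provider

# Route 1: force CPU inference through the environment variable.
os.environ['ONNX_MODE'] = 'cpu'
session = open_onnx_model('model.onnx')

# Route 2: resolve a provider name explicitly; 'gpu' is an alias for
# 'CUDAExecutionProvider' in the alias table above.
print(get_onnx_provider('gpu'))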
Empty file added test/langid/__init__.py
Empty file.
19 changes: 19 additions & 0 deletions test/langid/test_whisper.py
@@ -0,0 +1,19 @@
import os

import pytest

from soundutils.langid import whisper_langid
from test.testings import get_testfile


@pytest.mark.unittest
class TestLangidWhisper:
@pytest.mark.parametrize(['lang', 'file'], [
('Mandarin Chinese', os.path.join('zh', 'zh_long.wav')),
('Japanese', os.path.join('jp', 'jp_long.wav')),
('Korean', os.path.join('kr', 'kr_long.wav')),
('English', os.path.join('en', 'en_long.wav')),
])
def test_whisper_langid(self, lang, file):
alang, ascore = whisper_langid(get_testfile('assets', 'langs', file))
assert alang == lang
48 changes: 48 additions & 0 deletions test/utils/test_np.py
@@ -0,0 +1,48 @@
import numpy as np
import pytest

from soundutils.utils import softmax


@pytest.fixture
def vector():
return np.array([1.0, 2.0, 3.0])


@pytest.fixture
def matrix():
return np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])


@pytest.fixture
def high_precision_vector():
return np.array([1234567890.0, 1234567891.0, 1234567892.0])


@pytest.mark.unittest
class TestUtilsNp:
def test_softmax_vector(self, vector):
result = softmax(vector)
expected = np.array([0.09003057, 0.24472847, 0.66524096])
np.testing.assert_almost_equal(result, expected, decimal=8)

def test_softmax_matrix(self, matrix):
result = softmax(matrix, axis=1)
expected = np.array([
[0.09003057, 0.24472847, 0.66524096],
[0.09003057, 0.24472847, 0.66524096]
])
np.testing.assert_almost_equal(result, expected, decimal=8)

def test_softmax_matrix_along_default_axis(self, matrix):
result = softmax(matrix)
expected = np.array([
[0.09003057, 0.24472847, 0.66524096],
[0.09003057, 0.24472847, 0.66524096]
])
np.testing.assert_almost_equal(result, expected, decimal=8)

def test_softmax_high_precision(self, high_precision_vector):
result = softmax(high_precision_vector)
expected = np.array([0.09003057, 0.24472847, 0.66524096])
np.testing.assert_almost_equal(result, expected, decimal=8)
