Skip to content

Commit

Permalink
dev(narugo): add silero_langid function
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Aug 31, 2024
1 parent a2a0277 commit 95dd43b
Show file tree
Hide file tree
Showing 11 changed files with 639 additions and 0 deletions.
1 change: 1 addition & 0 deletions soundutils/langid/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .silero import silero_langid, silero_langid_score
from .whisper import whisper_langid, whisper_langid_score
66 changes: 66 additions & 0 deletions soundutils/langid/silero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import json
from functools import lru_cache
from typing import Tuple

import numpy as np
from huggingface_hub import hf_hub_download

from ..data import SoundTyping, Sound
from ..utils import open_onnx_model, softmax, vreplace

_REPO_ID = 'deepghs/silero-lang95-onnx'


@lru_cache()
def _lang_dict_95():
with open(hf_hub_download(repo_id=_REPO_ID, filename='lang_dict_95.json'), 'r') as f:
return json.load(f)


@lru_cache()
def _lang_group_dict_95():
with open(hf_hub_download(repo_id=_REPO_ID, filename='lang_group_dict_95.json'), 'r') as f:
return json.load(f)


@lru_cache()
def _open_model():
return open_onnx_model(hf_hub_download(repo_id=_REPO_ID, filename='lang_classifier_95.onnx'))


def _raw_langid(sound: SoundTyping, top_n: int = 5):
sound = Sound.load(sound).to_mono().resample(sample_rate=16000)
wav, sr = sound.to_numpy()

model = _open_model()
lang_logits, lang_group_logits = model.run(None, {'input': wav.astype(np.float32)})
softm = softmax(lang_logits, axis=-1)[0]
softm_group = softmax(lang_group_logits, axis=-1)[0]
srtd = np.argsort(softm)[::-1]
srtd_group = np.argsort(softm_group)[::-1]

lang_dict = _lang_dict_95()
lang_group_dict = _lang_group_dict_95()
scores = {}
group_scores = []
for i in range(top_n):
prob = softm[srtd[i]].item()
prob_group = softm_group[srtd_group[i]].item()
scores[lang_dict[str(srtd[i].item())]] = prob
group_scores.append((lang_group_dict[str(srtd_group[i].item())], prob_group))

return scores, group_scores


def silero_langid(sound: SoundTyping) -> Tuple[str, float]:
scores, group_scores = _raw_langid(sound=sound, top_n=1)
lang, score = list(scores.items())[0]
return lang, score


def silero_langid_score(sound: SoundTyping, fmt: str = 'scores', top_n: int = 5):
scores, group_scores = _raw_langid(sound=sound, top_n=top_n)
return vreplace(fmt, {
'scores': scores,
'group_scores': group_scores,
})
1 change: 1 addition & 0 deletions soundutils/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .conv import np_conv1d
from .enum import ExplicitEnum
from .format import vreplace
from .np import softmax
from .onnx import open_onnx_model, get_onnx_provider
26 changes: 26 additions & 0 deletions soundutils/utils/format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
__all__ = [
'vreplace',
]


def vreplace(v, mapping):
"""
Replaces values in a data structure using a mapping dictionary.
:param v: The input data structure.
:type v: Any
:param mapping: A dictionary mapping values to replacement values.
:type mapping: Dict
:return: The modified data structure.
:rtype: Any
"""
if isinstance(v, (list, tuple)):
return type(v)([vreplace(vitem, mapping) for vitem in v])
elif isinstance(v, dict):
return type(v)({key: vreplace(value, mapping) for key, value in v.items()})
else:
try:
_ = hash(v)
except TypeError:
return v
else:
return mapping.get(v, v)
26 changes: 26 additions & 0 deletions test/langid/test_silero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os.path

import pytest

from soundutils.langid import silero_langid
from test.testings import get_testfile

_MAP = {
'en': ['en, English'],
'zh': ['zh, Chinese', 'zh-CN, Chinese'],
'jp': ['ja, Japanese'],
'kr': ['ko, Korean'],
}


@pytest.mark.unittest
class TestLangidSilero:
@pytest.mark.parametrize(['file', 'lang'], [
(os.path.join(lang, f'{lang}_{item}.wav'), lang)
for item in ['short', 'medium', 'long']
for lang in ['en', 'jp', 'zh', 'kr']
])
def test_silero_langid(self, file, lang):
label, score = silero_langid(get_testfile('assets', 'langs', file))
assert label in _MAP[lang]
assert isinstance(score, float)
52 changes: 52 additions & 0 deletions test/utils/test_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest

from soundutils.utils import vreplace


@pytest.fixture
def simple_mapping():
return {'a': 'apple', 'b': 'banana'}


@pytest.fixture
def complex_mapping():
return {1: 'one', 2: 'two', 3: 'three'}


@pytest.fixture
def nested_data_structure():
return [1, {'a': 2, 'b': [3, 4]}]


@pytest.mark.unittest
class TestVReplace:
def test_basic_replacement(self, simple_mapping):
assert vreplace('a', simple_mapping) == 'apple'
assert vreplace('b', simple_mapping) == 'banana'

def test_no_replacement(self, simple_mapping):
assert vreplace('c', simple_mapping) == 'c'

def test_list_replacement(self, simple_mapping):
assert vreplace(['a', 'b', 'c'], simple_mapping) == ['apple', 'banana', 'c']

def test_tuple_replacement(self, simple_mapping):
assert vreplace(('a', 'b', 'c'), simple_mapping) == ('apple', 'banana', 'c')

def test_dict_replacement_1(self, complex_mapping):
input_dict = {'a': 1, 'b': 2}
expected_dict = {'a': 'one', 'b': 'two'}
assert vreplace(input_dict, complex_mapping) == expected_dict

def test_dict_replacement_2(self, complex_mapping):
input_dict = {1: 'a', 2: 'b'}
expected_dict = {1: 'a', 2: 'b'}
assert vreplace(input_dict, complex_mapping) == expected_dict

def test_nested_structure_replacement(self, complex_mapping, nested_data_structure):
expected_structure = ['one', {'a': 'two', 'b': ['three', 4]}]
assert vreplace(nested_data_structure, complex_mapping) == expected_structure

def test_unhashable_type(self):
unhashable = [1, 2, {3}]
assert vreplace(unhashable, {1: 'one'}) == ['one', 2, {3}]
Empty file added zoo/silero/__init__.py
Empty file.
Empty file added zoo/silero/lang95/__init__.py
Empty file.
68 changes: 68 additions & 0 deletions zoo/silero/lang95/source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import json
from pprint import pprint

import numpy as np
import onnxruntime
from huggingface_hub import hf_hub_download

from soundutils.data import Sound
from soundutils.utils import softmax
from test.testings import get_testfile

languages = ['ru', 'en', 'de', 'es']


class Validator():
def __init__(self, path):
self.model = onnxruntime.InferenceSession(path)

def __call__(self, inputs: np.ndarray):
ort_inputs = {'input': inputs}
outs = self.model.run(None, ort_inputs)
return outs


def read_audio(path: str,
sampling_rate: int = 16000):
sound = Sound.load(path).resample(sampling_rate)
data, sr = sound.to_numpy()
return data[0]


def get_language_and_group(wav: np.ndarray, model, lang_dict: dict, lang_group_dict: dict, top_n: int = 5):
wav = wav.astype(np.float32)[None, ...]
lang_logits, lang_group_logits = model(wav)

softm = softmax(lang_logits, axis=-1)[0]
softm_group = softmax(lang_group_logits, axis=-1)[0]

srtd = np.argsort(softm)[::-1]
srtd_group = np.argsort(softm_group)[::-1]

outs = {}
outs_group = []
for i in range(top_n):
prob = softm[srtd[i]].item()
prob_group = softm_group[srtd_group[i]].item()
outs[lang_dict[str(srtd[i].item())]] = prob
outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))

return outs, outs_group


if __name__ == '__main__':
lang = 'jp'
wav = read_audio(get_testfile('assets', 'langs', lang, f'{lang}_medium.wav'), sampling_rate=16000)

repo_id = 'deepghs/silero-lang95-onnx'
with open(hf_hub_download(repo_id=repo_id, filename='lang_dict_95.json'), 'r') as f:
lang_dict = json.load(f)
with open(hf_hub_download(repo_id=repo_id, filename='lang_group_dict_95.json'), 'r') as f:
lang_group_dict = json.load(f)
model = Validator(path=hf_hub_download(repo_id=repo_id, filename='lang_classifier_95.onnx'))

languages, language_groups = get_language_and_group(wav, model, lang_dict, lang_group_dict, top_n=5)

pprint(languages)
for gp, score in language_groups:
print(f'Language group: {gp!r} with prob {score:.3f}')
Empty file added zoo/silero/vad/__init__.py
Empty file.
Loading

0 comments on commit 95dd43b

Please sign in to comment.