-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dev(narugo): add silero_langid function
- Loading branch information
1 parent
a2a0277
commit 95dd43b
Showing
11 changed files
with
639 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .silero import silero_langid, silero_langid_score | ||
from .whisper import whisper_langid, whisper_langid_score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import json | ||
from functools import lru_cache | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
from huggingface_hub import hf_hub_download | ||
|
||
from ..data import SoundTyping, Sound | ||
from ..utils import open_onnx_model, softmax, vreplace | ||
|
||
_REPO_ID = 'deepghs/silero-lang95-onnx' | ||
|
||
|
||
@lru_cache()
def _lang_dict_95():
    """
    Load the class-index-to-language label table for the 95-language classifier.

    The file is fetched from the ``_REPO_ID`` repository through ``hf_hub_download``
    and the parsed result is memoized by ``lru_cache``, so it is read at most once
    per process.

    :return: The parsed content of ``lang_dict_95.json``.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(hf_hub_download(repo_id=_REPO_ID, filename='lang_dict_95.json'), 'r', encoding='utf-8') as f:
        return json.load(f)
|
||
|
||
@lru_cache()
def _lang_group_dict_95():
    """
    Load the class-index-to-language-group label table for the 95-language classifier.

    The file is fetched from the ``_REPO_ID`` repository through ``hf_hub_download``
    and the parsed result is memoized by ``lru_cache``, so it is read at most once
    per process.

    :return: The parsed content of ``lang_group_dict_95.json``.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(hf_hub_download(repo_id=_REPO_ID, filename='lang_group_dict_95.json'), 'r', encoding='utf-8') as f:
        return json.load(f)
|
||
|
||
@lru_cache()
def _open_model():
    """
    Open the ONNX language classifier stored in the ``_REPO_ID`` repository.

    The result is memoized by ``lru_cache``, so the model is opened once and reused.

    :return: Whatever ``open_onnx_model`` returns for ``lang_classifier_95.onnx``.
    """
    model_file = hf_hub_download(repo_id=_REPO_ID, filename='lang_classifier_95.onnx')
    return open_onnx_model(model_file)
|
||
|
||
def _raw_langid(sound: SoundTyping, top_n: int = 5):
    """
    Run the language classifier on a sound and collect the best-ranked labels.

    The sound is converted to mono and resampled to 16000 Hz before inference.
    Languages and language groups are ranked independently by softmax probability.
    When ``top_n`` exceeds the number of available classes, all classes are returned
    instead of raising ``IndexError``; a non-positive ``top_n`` yields empty results.

    :param sound: The sound to analyse, in any form accepted by ``Sound.load``.
    :type sound: SoundTyping
    :param top_n: Maximum number of languages and of language groups to report.
    :type top_n: int
    :return: ``(scores, group_scores)`` where ``scores`` is a dict mapping language
        label to probability (most probable first) and ``group_scores`` is a list of
        ``(group label, probability)`` tuples (most probable first).
    :rtype: Tuple[dict, list]
    """
    sound = Sound.load(sound).to_mono().resample(sample_rate=16000)
    # NOTE(review): presumably to_numpy() yields (channels, samples), so the mono
    # signal doubles as a batch of one for the model -- confirm against Sound.
    wav, _ = sound.to_numpy()

    model = _open_model()
    lang_logits, lang_group_logits = model.run(None, {'input': wav.astype(np.float32)})
    lang_probs = softmax(lang_logits, axis=-1)[0]
    group_probs = softmax(lang_group_logits, axis=-1)[0]

    # Slicing clamps to the real class count; the two heads may differ in size,
    # so each ranking is limited on its own.
    limit = max(top_n, 0)
    lang_order = np.argsort(lang_probs)[::-1][:limit]
    group_order = np.argsort(group_probs)[::-1][:limit]

    lang_dict = _lang_dict_95()
    lang_group_dict = _lang_group_dict_95()
    # The label tables are keyed by the class index rendered as a string.
    scores = {
        lang_dict[str(idx.item())]: lang_probs[idx].item()
        for idx in lang_order
    }
    group_scores = [
        (lang_group_dict[str(idx.item())], group_probs[idx].item())
        for idx in group_order
    ]

    return scores, group_scores
|
||
|
||
def silero_langid(sound: SoundTyping) -> Tuple[str, float]:
    """
    Identify the single most probable language of a sound.

    :param sound: The sound to analyse, passed through to ``_raw_langid``.
    :type sound: SoundTyping
    :return: The best language label together with its probability.
    :rtype: Tuple[str, float]
    """
    top_scores, _ = _raw_langid(sound=sound, top_n=1)
    # Only one entry was requested, so the first item is the winner.
    best_lang, best_score = next(iter(top_scores.items()))
    return best_lang, best_score
|
||
|
||
def silero_langid_score(sound: SoundTyping, fmt: str = 'scores', top_n: int = 5):
    """
    Return language and/or language-group scores arranged according to ``fmt``.

    ``fmt`` is handed to ``vreplace``, which substitutes the value ``'scores'`` with
    the language score dict and ``'group_scores'`` with the language-group score list.

    :param sound: The sound to analyse, passed through to ``_raw_langid``.
    :type sound: SoundTyping
    :param fmt: Output template; defaults to ``'scores'``.
    :param top_n: Number of top entries requested from ``_raw_langid``.
    :type top_n: int
    :return: ``fmt`` with its placeholders replaced by the computed results.
    """
    lang_scores, lang_group_scores = _raw_langid(sound=sound, top_n=top_n)
    placeholders = {
        'scores': lang_scores,
        'group_scores': lang_group_scores,
    }
    return vreplace(fmt, placeholders)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .conv import np_conv1d | ||
from .enum import ExplicitEnum | ||
from .format import vreplace | ||
from .np import softmax | ||
from .onnx import open_onnx_model, get_onnx_provider |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
__all__ = [ | ||
'vreplace', | ||
] | ||
|
||
|
||
def vreplace(v, mapping):
    """
    Replaces values in a data structure using a mapping dictionary.

    Lists, tuples (including namedtuples) and dicts are rebuilt recursively with
    their original container type. Only dictionary values are processed; dictionary
    keys are kept as they are. Any other hashable value is looked up in ``mapping``
    and replaced when present; unhashable leaves (e.g. sets) are returned unchanged.

    :param v: The input data structure.
    :type v: Any
    :param mapping: A dictionary mapping values to replacement values.
    :type mapping: Dict
    :return: The modified data structure.
    :rtype: Any
    """
    if isinstance(v, (list, tuple)):
        items = [vreplace(vitem, mapping) for vitem in v]
        if isinstance(v, tuple) and hasattr(v, '_fields'):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable like list/tuple do.
            return type(v)(*items)
        return type(v)(items)
    elif isinstance(v, dict):
        return type(v)({key: vreplace(value, mapping) for key, value in v.items()})
    else:
        try:
            # Unhashable objects cannot be dictionary keys, so they can never match.
            hash(v)
        except TypeError:
            return v
        else:
            return mapping.get(v, v)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os.path | ||
|
||
import pytest | ||
|
||
from soundutils.langid import silero_langid | ||
from test.testings import get_testfile | ||
|
||
_MAP = { | ||
'en': ['en, English'], | ||
'zh': ['zh, Chinese', 'zh-CN, Chinese'], | ||
'jp': ['ja, Japanese'], | ||
'kr': ['ko, Korean'], | ||
} | ||
|
||
|
||
@pytest.mark.unittest
class TestLangidSilero:
    """Checks ``silero_langid`` against the bundled per-language sample recordings."""

    @pytest.mark.parametrize(['file', 'lang'], [
        (os.path.join(code, f'{code}_{length}.wav'), code)
        for length in ['short', 'medium', 'long']
        for code in ['en', 'jp', 'zh', 'kr']
    ])
    def test_silero_langid(self, file, lang):
        # The predicted label must be one of the accepted labels for the sample's language.
        audio_path = get_testfile('assets', 'langs', file)
        label, score = silero_langid(audio_path)
        assert label in _MAP[lang]
        assert isinstance(score, float)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import pytest | ||
|
||
from soundutils.utils import vreplace | ||
|
||
|
||
@pytest.fixture
def simple_mapping():
    """Provide a string-to-string replacement mapping."""
    return dict(a='apple', b='banana')
|
||
|
||
@pytest.fixture
def complex_mapping():
    """Provide an integer-to-word replacement mapping for 1, 2 and 3."""
    words = ['one', 'two', 'three']
    return dict(enumerate(words, start=1))
|
||
|
||
@pytest.fixture
def nested_data_structure():
    """Provide a list holding an integer and a dict that itself contains a list."""
    inner = {'a': 2, 'b': [3, 4]}
    return [1, inner]
|
||
|
||
@pytest.mark.unittest
class TestVReplace:
    """Behavioural checks for ``vreplace`` on scalars and nested containers."""

    def test_basic_replacement(self, simple_mapping):
        # Scalars present in the mapping are swapped for their mapped values.
        assert vreplace('a', simple_mapping) == 'apple'
        assert vreplace('b', simple_mapping) == 'banana'

    def test_no_replacement(self, simple_mapping):
        # Scalars absent from the mapping come back untouched.
        assert vreplace('c', simple_mapping) == 'c'

    def test_list_replacement(self, simple_mapping):
        result = vreplace(['a', 'b', 'c'], simple_mapping)
        assert result == ['apple', 'banana', 'c']

    def test_tuple_replacement(self, simple_mapping):
        result = vreplace(('a', 'b', 'c'), simple_mapping)
        assert result == ('apple', 'banana', 'c')

    def test_dict_replacement_1(self, complex_mapping):
        # Dictionary values are replaced.
        input_dict = {'a': 1, 'b': 2}
        expected_dict = {'a': 'one', 'b': 'two'}
        assert vreplace(input_dict, complex_mapping) == expected_dict

    def test_dict_replacement_2(self, complex_mapping):
        # Dictionary keys are never replaced, even when they appear in the mapping.
        input_dict = {1: 'a', 2: 'b'}
        expected_dict = {1: 'a', 2: 'b'}
        assert vreplace(input_dict, complex_mapping) == expected_dict

    def test_nested_structure_replacement(self, complex_mapping, nested_data_structure):
        expected_structure = ['one', {'a': 'two', 'b': ['three', 4]}]
        result = vreplace(nested_data_structure, complex_mapping)
        assert result == expected_structure

    def test_unhashable_type(self):
        # The inner set is unhashable and must pass through unchanged.
        unhashable = [1, 2, {3}]
        result = vreplace(unhashable, {1: 'one'})
        assert result == ['one', 2, {3}]
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import json | ||
from pprint import pprint | ||
|
||
import numpy as np | ||
import onnxruntime | ||
from huggingface_hub import hf_hub_download | ||
|
||
from soundutils.data import Sound | ||
from soundutils.utils import softmax | ||
from test.testings import get_testfile | ||
|
||
languages = ['ru', 'en', 'de', 'es'] | ||
|
||
|
||
class Validator:
    """Thin callable wrapper around an ``onnxruntime.InferenceSession``."""

    def __init__(self, path):
        """
        Create the inference session.

        :param path: Model path handed to ``onnxruntime.InferenceSession``.
        """
        self.model = onnxruntime.InferenceSession(path)

    def __call__(self, inputs: np.ndarray):
        """
        Run the model on ``inputs``, fed under the input name ``'input'``.

        :param inputs: Array passed to the session as the ``'input'`` tensor.
        :type inputs: np.ndarray
        :return: All outputs returned by the session's ``run`` call.
        """
        return self.model.run(None, {'input': inputs})
|
||
|
||
def read_audio(path: str, sampling_rate: int = 16000):
    """
    Load an audio file, resample it and return the first row of its sample array.

    :param path: Location of the audio file, passed to ``Sound.load``.
    :type path: str
    :param sampling_rate: Target sample rate for resampling.
    :type sampling_rate: int
    :return: ``data[0]`` of the array produced by ``to_numpy()``
        (presumably the first channel -- confirm against ``Sound``).
    """
    resampled = Sound.load(path).resample(sampling_rate)
    samples, _ = resampled.to_numpy()
    return samples[0]
|
||
|
||
def get_language_and_group(wav: np.ndarray, model, lang_dict: dict, lang_group_dict: dict, top_n: int = 5):
    """
    Classify a waveform and report the best-ranked languages and language groups.

    Languages and language groups are ranked independently by softmax probability.
    When ``top_n`` exceeds the number of available classes, all classes are returned
    instead of raising ``IndexError``; a non-positive ``top_n`` yields empty results.

    :param wav: Waveform samples; a leading batch axis is added before inference.
    :type wav: np.ndarray
    :param model: Callable returning ``(lang_logits, lang_group_logits)`` for the batch.
    :param lang_dict: Maps a class index (as a string) to a language label.
    :type lang_dict: dict
    :param lang_group_dict: Maps a class index (as a string) to a language-group label.
    :type lang_group_dict: dict
    :param top_n: Maximum number of languages and of language groups to report.
    :type top_n: int
    :return: ``(outs, outs_group)`` where ``outs`` is a dict mapping language label to
        probability and ``outs_group`` is a list of ``(group label, probability)``
        tuples, both ordered from most to least probable.
    """
    wav = wav.astype(np.float32)[None, ...]
    lang_logits, lang_group_logits = model(wav)

    lang_probs = softmax(lang_logits, axis=-1)[0]
    group_probs = softmax(lang_group_logits, axis=-1)[0]

    # Slicing clamps to the real class count; the two heads may differ in size,
    # so each ranking is limited on its own.
    limit = max(top_n, 0)
    lang_order = np.argsort(lang_probs)[::-1][:limit]
    group_order = np.argsort(group_probs)[::-1][:limit]

    outs = {
        lang_dict[str(idx.item())]: lang_probs[idx].item()
        for idx in lang_order
    }
    outs_group = [
        (lang_group_dict[str(idx.item())], group_probs[idx].item())
        for idx in group_order
    ]

    return outs, outs_group
|
||
|
||
if __name__ == '__main__':
    # Manual smoke run: classify one bundled sample and print the top-5 results.
    lang = 'jp'
    wav = read_audio(get_testfile('assets', 'langs', lang, f'{lang}_medium.wav'), sampling_rate=16000)

    repo_id = 'deepghs/silero-lang95-onnx'
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(hf_hub_download(repo_id=repo_id, filename='lang_dict_95.json'), 'r', encoding='utf-8') as f:
        lang_dict = json.load(f)
    with open(hf_hub_download(repo_id=repo_id, filename='lang_group_dict_95.json'), 'r', encoding='utf-8') as f:
        lang_group_dict = json.load(f)
    model = Validator(path=hf_hub_download(repo_id=repo_id, filename='lang_classifier_95.onnx'))

    # Distinct names so the module-level ``languages`` constant is not overwritten.
    lang_scores, lang_group_scores = get_language_and_group(wav, model, lang_dict, lang_group_dict, top_n=5)

    pprint(lang_scores)
    for gp, score in lang_group_scores:
        print(f'Language group: {gp!r} with prob {score:.3f}')
Empty file.
Oops, something went wrong.