Skip to content

Commit

Permalink
Merge pull request #26 from argmaxinc/arda/multilingual_eval
Browse files Browse the repository at this point in the history
Add multilingual dataset eval support
  • Loading branch information
atiorh authored Dec 5, 2024
2 parents b77094c + 4e3610e commit 6ce531c
Show file tree
Hide file tree
Showing 15 changed files with 514 additions and 64 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/public-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ jobs:
shell: bash -el {0}
run: |
TEST_DEV=cpu $(which python) tests/test_audio_encoder.py
- name: Evaluate Unit Test
shell: bash -el {0}
run: |
$(which python) tests/test_evaluate.py --dataset librispeech-debug --pipeline WhisperKit
- name: Folder Evaluate Unit Test
shell: bash -el {0}
run: |
$(which python) tests/test_evaluate.py --dataset common_voice_17_0-debug-zip --pipeline WhisperKit --language-subset en
- name: Lint
shell: bash -el {0}
run: |
Expand Down
11 changes: 11 additions & 0 deletions scripts/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ def cli():
choices=("WhisperKit", "whisper.cpp", "WhisperMLX", "WhisperOpenAIAPI"),
required=True
)
parser.add_argument(
"--force-language",
action="store_true",
help="If specified, forces the language in each data sample (if available)"
)
parser.add_argument(
"--language-subset",
type=str,
default=None,
help="If specified, filters the dataset for the given language"
)

# Alias the CLI args to match the test scripts
args = parser.parse_args()
Expand Down
11 changes: 4 additions & 7 deletions tests/test_aihub.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@

from whisperkit.android import models as android
from whisperkit.android import utils as aihub_utils
from whisperkit import audio_encoder as apple

from tests.test_audio_encoder import TEST_N_SAMPLES

TEST_VOCAB_SIZE = 51865
TEST_PSNR_THR = 40
Expand Down Expand Up @@ -48,13 +45,12 @@ def setUpClass(cls):
}
}
super().setUpClass()

@classmethod
def tearDownClass(cls):
cls.models = None
super().tearDownClass()


def test_torch2torch_correctness(self):
""" Test forward pass functionality and correctness of PyTorch models
"""
Expand All @@ -70,8 +66,9 @@ def test_torch2torch_correctness(self):
logger.info(f"torch2torch model={model_key} PSNR={psnr:.3g}")
else:
logger.info(
f"torch2torch correctness test skipped: Reference model does not exist for {model_key}")

"torch2torch correctness test skipped: "
f"Reference model does not exist for {model_key}"
)

def test_torch2aihub_performance_and_correctness(self):
""" Test AI Hub compilation and inference job results against local PyTorch test results
Expand Down
1 change: 1 addition & 0 deletions tests/test_audio_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TEST_PSNR_THR = 35

argmaxtools_test_utils.TEST_MIN_SPEEDUP_VS_CPU = 0.95
argmaxtools_test_utils.TEST_SKIP_SPEED_TESTS = True

# WhisperMelSpectrogram constants
# TEST_N_MELS = [80, 128]
Expand Down
26 changes: 21 additions & 5 deletions tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import datetime
import json
import os
import pprint
import subprocess
import unittest

Expand All @@ -17,6 +16,7 @@
from whisperkit._constants import EVALS_REPO_ID, MODEL_REPO_ID
from whisperkit.evaluate.datasets import EVAL_DATASETS
from whisperkit.evaluate.evaluate import evaluate
import whisperkit.evaluate.evaluate
from whisperkit.pipelines import get_pipeline_cls
from whisperkit.test_utils import BenchmarkContext

Expand All @@ -36,6 +36,7 @@
TEST_UPLOAD_RESULTS = os.getenv("TEST_UPLOAD_RESULTS", None) or False
TEST_QOI_REFERENCE = os.getenv("TEST_QOI_REFERENCE", None) or None # TODO
AVG_WER_SANITY_CHECK_THR = 0.5
LANGUAGE_SUBSET = None


class TestWhisperPipelineEvaluate(unittest.TestCase):
Expand All @@ -60,21 +61,28 @@ def setUpClass(cls) -> None:
shell=True
).stdout.decode('utf-8').strip()[:7]

inference_context_spec_dict = None
try:
inference_context_spec_dict = cls.inference_context.spec_dict()
except Exception as e:
logger.warning(f"Inference context spec dict failed: {e}")

cls.results = {
"results": evaluate(
cls.pipeline,
dataset_name=TEST_DATASET_NAME,
num_samples=TEST_NUM_SAMPLES,
cache_dir=TEST_CACHE_DIR,
num_proc=TEST_NUM_PROC),
num_proc=TEST_NUM_PROC,
language_subset=LANGUAGE_SUBSET),
"metadata": {
"num_samples": TEST_NUM_SAMPLES,
"num_proc": TEST_NUM_PROC,
"pipeline": TEST_PIPELINE,
"dataset_name": TEST_DATASET_NAME,
"model_version": TEST_MODEL_VERSION,
"whisperkittools_commit_hash": wkt_commit_hash,
"inference_context": cls.inference_context.spec_dict(),
"inference_context": inference_context_spec_dict,
"model_repo_id": MODEL_REPO_ID
}
}
Expand All @@ -96,7 +104,9 @@ def setUpClass(cls) -> None:
results_dir = os.path.join(
TEST_PIPELINE,
TEST_MODEL_VERSION.replace("/", "_"),
TEST_DATASET_NAME
TEST_DATASET_NAME,
"forced" if whisperkit.evaluate.evaluate.FORCE_LANGUAGE else "",
LANGUAGE_SUBSET if LANGUAGE_SUBSET else ""
)
results_fname = datetime.datetime.now().astimezone(
).strftime("%Y-%m-%d_%H:%M:%S_GMT%z") + ".json"
Expand All @@ -122,7 +132,7 @@ def test_evaluate(self):
def main(args):
global TEST_DATASET_NAME, TEST_PIPELINE, TEST_NUM_SAMPLES, TEST_CACHE_DIR, \
TEST_MODEL_VERSION, TEST_CODE_COMMIT_HASH, TEST_MODEL_COMMIT_HASH, \
TEST_NUM_PROC, TEST_UPLOAD_RESULTS, TEST_QOI_REFERENCE
TEST_NUM_PROC, TEST_UPLOAD_RESULTS, TEST_QOI_REFERENCE, LANGUAGE_SUBSET
TEST_DATASET_NAME = args.dataset
TEST_PIPELINE = args.pipeline
TEST_NUM_SAMPLES = args.num_samples
Expand All @@ -132,6 +142,10 @@ def main(args):
TEST_MODEL_COMMIT_HASH = args.model_commit_hash
TEST_NUM_PROC = args.num_proc
TEST_UPLOAD_RESULTS = args.upload_results
LANGUAGE_SUBSET = args.language_subset

# Force language option
whisperkit.evaluate.evaluate.FORCE_LANGUAGE = args.force_language

with argmaxtools_test_utils._get_test_cache_dir(
args.persistent_cache_dir
Expand Down Expand Up @@ -169,6 +183,8 @@ def main(args):
parser.add_argument("--model-commit-hash", type=str, default=None)
parser.add_argument("--num-proc", type=int, default=1)
parser.add_argument("--upload-results", action="store_true")
parser.add_argument("--language-subset", type=str, default=None)
parser.add_argument("--force-language", action="store_true")
args = parser.parse_args()

main(args)
2 changes: 2 additions & 0 deletions tests/test_text_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
TEST_PSNR_THR = 35
TEST_CACHE_DIR = os.getenv("TEST_CACHE_DIR", None) or "/tmp"

argmaxtools_test_utils.TEST_SKIP_SPEED_TESTS = True

# WhisperDecoderContextPrefill constants
TEST_PREFILL_CONSISTENCY_PSNR_THR = 20
TEST_BATCH = 16
Expand Down
108 changes: 107 additions & 1 deletion whisperkit/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
EVAL_DATASETS = [
"earnings22", "librispeech", "librispeech-200",
"earnings22-debug", "librispeech-debug",
"earnings22-12hours"
"earnings22-12hours",
"common_voice_17_0-debug-zip",
"common_voice_17_0-argmax_subset-400"
]
CUSTOM_EVAL_DATASET = os.getenv("EVAL_DATASET", None)
if CUSTOM_EVAL_DATASET is not None:
Expand All @@ -26,3 +28,107 @@
OPENAI_API_MAX_FILE_SIZE = 25e6 # bytes
OPENAI_API_COMPRESSED_UPLOAD_BIT_RATE = "12k" # kbps
TEST_DATA_REPO = "argmaxinc/whisperkit-test-data"

# Supported Languages
SUPPORTED_LANGUAGES = [
"af",
"am",
"ar",
"as",
"az",
"ba",
"be",
"bg",
"bn",
"bo",
"br",
"bs",
"ca",
"cs",
"cy",
"da",
"de",
"el",
"en",
"es",
"et",
"eu",
"fa",
"fi",
"fo",
"fr",
"gl",
"gu",
"ha",
"haw",
"he",
"hi",
"hr",
"ht",
"hu",
"hy",
"id",
"is",
"it",
"ja",
"jw",
"ka",
"kk",
"km",
"kn",
"ko",
"la",
"lb",
"ln",
"lo",
"lt",
"lv",
"mg",
"mi",
"mk",
"ml",
"mn",
"mr",
"ms",
"mt",
"my",
"ne",
"nl",
"nn",
"no",
"oc",
"pa",
"pl",
"ps",
"pt",
"ro",
"ru",
"sa",
"sd",
"si",
"sk",
"sl",
"sn",
"so",
"sq",
"sr",
"su",
"sv",
"sw",
"ta",
"te",
"tg",
"th",
"tk",
"tl",
"tr",
"tt",
"uk",
"ur",
"uz",
"vi",
"yi",
"yo",
"yue",
"zh",
]
10 changes: 5 additions & 5 deletions whisperkit/android/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, filter_length=1024, hop_length=512, win_length=None, window='
np.imag(fourier_basis[:self.cutoff, :])])
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])

assert(filter_length >= self.win_length)
assert (filter_length >= self.win_length)
fft_window = get_window(window, self.win_length, fftbins=True)
fft_window = pad_center(fft_window, size=filter_length)
fft_window = torch.from_numpy(fft_window).float()
Expand Down Expand Up @@ -91,14 +91,14 @@ def __init__(self, n_mels=80, n_fft=400, hop_length=160):
)

self.stft = DecomposedSTFT(
filter_length=self.n_fft,
hop_length=self.hop_length,
filter_length=self.n_fft,
hop_length=self.hop_length,
win_length=self.n_fft,
window='hann'
)

def forward(self, audio: tt.WhisperMelSpectrogramInputType) -> tt.WhisperMelSpectrogramOutputType:

transformed = self.stft(audio)
magnitudes = transformed[..., :-1]
mel_spec = self.mel_filters @ magnitudes
Expand All @@ -118,7 +118,7 @@ class WhisperDecoderPostProc(nn.Module):
def forward(self, logits):
TOKEN_TIMESTAMP_BEGIN = 50363
TOKEN_NO_SPEECH = 50361

# logprobs = F.log_softmax(logits, dim=0)
logprobs = torch.log(F.softmax(logits, dim=0))
timestamp_logprob = torch.logsumexp(logprobs[TOKEN_TIMESTAMP_BEGIN:], dim=0)
Expand Down
Loading

0 comments on commit 6ce531c

Please sign in to comment.