Skip to content

Commit

Permalink
add multilingual dataset eval support
Browse files Browse the repository at this point in the history
  • Loading branch information
arda-argmax committed Dec 3, 2024
1 parent 3e60c43 commit 19f13b0
Show file tree
Hide file tree
Showing 9 changed files with 440 additions and 34 deletions.
11 changes: 11 additions & 0 deletions scripts/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ def cli():
choices=("WhisperKit", "whisper.cpp", "WhisperMLX", "WhisperOpenAIAPI"),
required=True
)
parser.add_argument(
"--force-language",
action="store_true",
help="If specified, forces the language in each data sample (if available)"
)
parser.add_argument(
"--language-subset",
type=str,
default=None,
help="If specified, filters the dataset for the given language"
)

# Alias the CLI args to match the test scripts
args = parser.parse_args()
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
"soundfile",
"librosa",
"datasets",
"evaluate"
"evaluate",
"transliterate",
"openai",
"mlx-whisper",
],
"android": [
"qai-hub",
Expand Down
15 changes: 12 additions & 3 deletions tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from whisperkit._constants import EVALS_REPO_ID, MODEL_REPO_ID
from whisperkit.evaluate.datasets import EVAL_DATASETS
from whisperkit.evaluate.evaluate import evaluate
import whisperkit.evaluate.evaluate
from whisperkit.pipelines import get_pipeline_cls
from whisperkit.test_utils import BenchmarkContext

Expand All @@ -36,6 +37,7 @@
TEST_UPLOAD_RESULTS = os.getenv("TEST_UPLOAD_RESULTS", None) or False
TEST_QOI_REFERENCE = os.getenv("TEST_QOI_REFERENCE", None) or None # TODO
AVG_WER_SANITY_CHECK_THR = 0.5
LANGUAGE_SUBSET = None


class TestWhisperPipelineEvaluate(unittest.TestCase):
Expand Down Expand Up @@ -66,7 +68,8 @@ def setUpClass(cls) -> None:
dataset_name=TEST_DATASET_NAME,
num_samples=TEST_NUM_SAMPLES,
cache_dir=TEST_CACHE_DIR,
num_proc=TEST_NUM_PROC),
num_proc=TEST_NUM_PROC,
language_subset=LANGUAGE_SUBSET),
"metadata": {
"num_samples": TEST_NUM_SAMPLES,
"num_proc": TEST_NUM_PROC,
Expand Down Expand Up @@ -96,7 +99,9 @@ def setUpClass(cls) -> None:
results_dir = os.path.join(
TEST_PIPELINE,
TEST_MODEL_VERSION.replace("/", "_"),
TEST_DATASET_NAME
TEST_DATASET_NAME,
"forced" if whisperkit.evaluate.evaluate.FORCE_LANGUAGE else "",
LANGUAGE_SUBSET if LANGUAGE_SUBSET else ""
)
results_fname = datetime.datetime.now().astimezone(
).strftime("%Y-%m-%d_%H:%M:%S_GMT%z") + ".json"
Expand All @@ -122,7 +127,7 @@ def test_evaluate(self):
def main(args):
global TEST_DATASET_NAME, TEST_PIPELINE, TEST_NUM_SAMPLES, TEST_CACHE_DIR, \
TEST_MODEL_VERSION, TEST_CODE_COMMIT_HASH, TEST_MODEL_COMMIT_HASH, \
TEST_NUM_PROC, TEST_UPLOAD_RESULTS, TEST_QOI_REFERENCE
TEST_NUM_PROC, TEST_UPLOAD_RESULTS, TEST_QOI_REFERENCE, LANGUAGE_SUBSET
TEST_DATASET_NAME = args.dataset
TEST_PIPELINE = args.pipeline
TEST_NUM_SAMPLES = args.num_samples
Expand All @@ -132,6 +137,10 @@ def main(args):
TEST_MODEL_COMMIT_HASH = args.model_commit_hash
TEST_NUM_PROC = args.num_proc
TEST_UPLOAD_RESULTS = args.upload_results
LANGUAGE_SUBSET = args.language_subset

# Force language option
whisperkit.evaluate.evaluate.FORCE_LANGUAGE = args.force_language

with argmaxtools_test_utils._get_test_cache_dir(
args.persistent_cache_dir
Expand Down
108 changes: 107 additions & 1 deletion whisperkit/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
EVAL_DATASETS = [
"earnings22", "librispeech", "librispeech-200",
"earnings22-debug", "librispeech-debug",
"earnings22-12hours"
"earnings22-12hours",
"common_voice_17_0-debug-zip",
"common_voice_17_0-argmax_subset-400"
]
CUSTOM_EVAL_DATASET = os.getenv("EVAL_DATASET", None)
if CUSTOM_EVAL_DATASET is not None:
Expand All @@ -26,3 +28,107 @@
OPENAI_API_MAX_FILE_SIZE = 25e6 # bytes
OPENAI_API_COMPRESSED_UPLOAD_BIT_RATE = "12k" # kbps
TEST_DATA_REPO = "argmaxinc/whisperkit-test-data"

# Language codes accepted for language-based dataset filtering.
# `get_dataset` in whisperkit/evaluate/datasets.py rejects any
# `language_subset` value that is not a member of this list.
# Entries are kept in alphabetical order, one row per initial letter.
# NOTE(review): presumably mirrors the Whisper tokenizer's language set
# (it includes non-ISO-639-1 codes such as "jw", "haw", "yue") -- confirm
# before adding or renaming entries.
SUPPORTED_LANGUAGES = [
    "af", "am", "ar", "as", "az",
    "ba", "be", "bg", "bn", "bo", "br", "bs",
    "ca", "cs", "cy",
    "da", "de",
    "el", "en", "es", "et", "eu",
    "fa", "fi", "fo", "fr",
    "gl", "gu",
    "ha", "haw", "he", "hi", "hr", "ht", "hu", "hy",
    "id", "is", "it",
    "ja", "jw",
    "ka", "kk", "km", "kn", "ko",
    "la", "lb", "ln", "lo", "lt", "lv",
    "mg", "mi", "mk", "ml", "mn", "mr", "ms", "mt", "my",
    "ne", "nl", "nn", "no",
    "oc",
    "pa", "pl", "ps", "pt",
    "ro", "ru",
    "sa", "sd", "si", "sk", "sl", "sn", "so", "sq", "sr", "su", "sv", "sw",
    "ta", "te", "tg", "th", "tk", "tl", "tr", "tt",
    "uk", "ur", "uz",
    "vi",
    "yi", "yo", "yue",
    "zh",
]
8 changes: 8 additions & 0 deletions whisperkit/evaluate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from argmaxtools.utils import get_logger

# Module-level logger for the evaluation subpackage.
logger = get_logger(__name__)

# Fail fast at import time: `evaluate` is imported here only to confirm it is
# installed, so a missing dependency surfaces with installation instructions
# instead of a bare ModuleNotFoundError.
try:
    import evaluate  # noqa: F401
except ModuleNotFoundError as err:
    raise ModuleNotFoundError(
        "`evaluate` not found. "
        "Please install evals extras via: `pip install -e '.[evals]'`"
    ) from err
34 changes: 31 additions & 3 deletions whisperkit/evaluate/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,22 @@
from argmaxtools.utils import get_logger
from huggingface_hub import snapshot_download

from whisperkit._constants import DATASET_REPO_OWNER, EVAL_DATASETS
from whisperkit._constants import DATASET_REPO_OWNER, EVAL_DATASETS, SUPPORTED_LANGUAGES
from whisperkit.evaluate.normalize_en import EnglishTextNormalizer

logger = get_logger(__name__)

text_normalizer = EnglishTextNormalizer()


def get_dataset(dataset_name, cache_dir, max_num_samples=-1):
def get_dataset(dataset_name, cache_dir, max_num_samples=-1, language_subset=None):
if dataset_name not in EVAL_DATASETS:
raise ValueError(f"Dataset not yet registered: {dataset_name}")

if language_subset is not None:
assert language_subset in SUPPORTED_LANGUAGES, f"Unsupported language: {language_subset}"
logger.info(f"Filtering dataset for language: {language_subset}")

logger.info(f"""\n
=======================================================
Downloading and preprocessing '{dataset_name}' dataset
Expand All @@ -41,6 +45,21 @@ def get_dataset(dataset_name, cache_dir, max_num_samples=-1):
local_dir_use_symlinks=True
)

# Unzip if necessary
zip_files = [f for f in os.listdir(cache_dir) if f.endswith('.zip')]
if len(zip_files) > 0:
logger.info(f"Unzipping {len(zip_files)} files")
for zip_file in zip_files:
zip_path = os.path.join(cache_dir, zip_file)
os.system(f"unzip -q -o {zip_path} -d {cache_dir}")
os.remove(zip_path)

has_folders = False
for path in os.listdir(cache_dir):
if os.path.isdir(os.path.join(cache_dir, path)) and not path.startswith("."):
has_folders = True
break

audio_paths = _get_audio_paths(cache_dir)
audio_paths = {path.split("/")[-1]: path for path in audio_paths}

Expand Down Expand Up @@ -73,14 +92,23 @@ def preprocess_fn(batch):
if key in batch:
break
batch["original_text"] = batch[key]
batch["norm_text"] = text_normalizer(batch[key])
if not isinstance(batch[key], str):
logger.warning(f"non-string text dectected: {batch[key]} | Class: {type(batch[key])}")
logger.warning(f"Conversion to string: {str(batch[key])}")
batch["norm_text"] = text_normalizer(str(batch[key]))

# Remove invalid samples
drop = batch["norm_text"].strip() == "ignore time segment in scoring"
drop = drop or batch["norm_text"].strip() == ""

# Filter by language
if language_subset is not None:
drop = drop or batch.get("language", None) != language_subset

if drop:
return None
if has_folders:
batch["norm_folder"] = "/".join(batch["norm_path"].split("/")[:-1])
return batch

original_num_samples = len(dataset)
Expand Down
Loading

0 comments on commit 19f13b0

Please sign in to comment.