From 1c6c18965abdba0e54054c4d3e498007467be2b0 Mon Sep 17 00:00:00 2001 From: philgzl Date: Fri, 14 Apr 2023 01:21:34 +0200 Subject: [PATCH] Add lengths argument --- README.md | 6 ++++-- pesq/_pesq.py | 13 ++++++++++--- tests/test_pesq.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 980f73c..db772cf 100644 --- a/README.md +++ b/README.md @@ -65,9 +65,9 @@ print(pesq(rate, ref, deg, 'nb')) # Usage for `multiprocessing` feature ```python -def pesq_batch(fs, ref, deg, mode='wb', n_processor=None, on_error=PesqError.RAISE_EXCEPTION): +def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.RAISE_EXCEPTION, lengths=None): """ - Running `pesq` using multiple processors + Running `pesq` using multiple processors Args: on_error: ref: numpy 1D (n_sample,) or 2D array (n_file, n_sample), reference audio signal @@ -76,6 +76,7 @@ def pesq_batch(fs, ref, deg, mode='wb', n_processor=None, on_error=PesqError.RAI mode: 'wb' (wide-band) or 'nb' (narrow-band) n_processor: cpu_count() (default) or number of processors (chosen by the user) or 0 (without multiprocessing) on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES + lengths: None or list of original length of audio signals before batching, length n_file Returns: pesq_score: list of pesq scores, P.862.2 Prediction (MOS-LQO) """ @@ -86,6 +87,7 @@ When the `ref` is an 1-D numpy array and `deg` is a 2-D numpy array, the result When the `ref` is a 2-D numpy array and `deg` is a 2-D numpy array, the result of `pesq_batch` is identical to the value of `[pesq(fs, ref[i,:], deg[i,:],**kwargs) for i in range(deg.shape[0])]`. +The optional `lengths` argument is useful if the individual audio files initially had different length and were padded to match their length before creating the batch. It ensures the padded samples do not affect the results. # Correctness diff --git a/pesq/_pesq.py b/pesq/_pesq.py index f9eb27a..e808ffb 100644 --- a/pesq/_pesq.py +++ b/pesq/_pesq.py @@ -114,7 +114,7 @@ def pesq(fs, ref, deg, mode='wb', on_error=PesqError.RAISE_EXCEPTION): return _pesq_inner(ref, deg, fs, mode, on_error) -def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.RAISE_EXCEPTION): +def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.RAISE_EXCEPTION, lengths=None): """ Running `pesq` using multiple processors Args: @@ -125,12 +125,15 @@ def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.R mode: 'wb' (wide-band) or 'nb' (narrow-band) n_processor: cpu_count() (default) or number of processors (chosen by the user) or 0 (without multiprocessing) on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES + lengths: None or list of original length of audio signals before batching, length n_file Returns: pesq_score: list of pesq scores, P.862.2 Prediction (MOS-LQO) """ _check_fs_mode(mode, fs, USAGE_BATCH) # check dimension if len(ref.shape) == 1: + if lengths is not None: + raise ValueError("cannot provide lengths if ref is 1D") if len(deg.shape) == 1 and ref.shape == deg.shape: return [_pesq_inner(ref, deg, fs, mode, PesqError.RETURN_VALUES)] elif len(deg.shape) == 2 and ref.shape[-1] == deg.shape[-1]: @@ -147,14 +150,18 @@ def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.R raise ValueError("The shapes of `deg` is invalid!") elif len(ref.shape) == 2: if deg.shape == ref.shape: + if lengths is None: + lengths = [ref.shape[-1] for _ in range(ref.shape[0])] + elif len(lengths) != ref.shape[0]: + raise ValueError("len(lengths) does not match the batch size") if n_processor <= 0: pesq_score = [np.nan for i in range(deg.shape[0])] for i in range(deg.shape[0]): - pesq_score[i] = _pesq_inner(ref[i, :], deg[i, :], fs, mode, on_error) + pesq_score[i] = _pesq_inner(ref[i, :lengths[i]], deg[i, :lengths[i]], fs, mode, on_error) return pesq_score else: return _processor_mapping(_pesq_inner, - [(ref[i, :], deg[i, :], fs, mode, on_error) for i in range(deg.shape[0])], + [(ref[i, :lengths[i]], deg[i, :lengths[i]], fs, mode, on_error) for i in range(deg.shape[0])], n_processor) else: raise ValueError("The shape of `deg` is invalid!") diff --git a/tests/test_pesq.py b/tests/test_pesq.py index e76d87e..eda737d 100755 --- a/tests/test_pesq.py +++ b/tests/test_pesq.py @@ -95,6 +95,50 @@ def test_pesq_batch(): assert np.allclose(np.array(scores), ideally), scores +def test_lengths(): + data_dir = Path(__file__).parent.parent / 'audio' + ref_path = data_dir / 'speech.wav' + deg_path = data_dir / 'speech_bab_0dB.wav' + + sample_rate, ref = scipy.io.wavfile.read(ref_path) + sample_rate, deg = scipy.io.wavfile.read(deg_path) + + n_file = 10 + # random lengths smaller than len(ref) = 49600 + lengths = [43433, 40969, 30613, 38570, 45484, 10800, 28424, 22943, 30918, 30784] + + # 1D - 1D -- raises error + with pytest.raises(ValueError): + pesq_batch(ref=ref, deg=deg, fs=sample_rate, mode='wb', lengths=lengths) + + # 1D - 2D -- raises error + deg_2d = np.repeat(deg[np.newaxis, :], n_file, axis=0) + with pytest.raises(ValueError): + pesq_batch(ref=ref, deg=deg_2d, fs=sample_rate, mode='wb', lengths=lengths) + + # 2D - 2D -- bad len(lenghts) + ref_2d = np.repeat(ref[np.newaxis, :], n_file, axis=0) + with pytest.raises(ValueError): + pesq_batch(ref=ref, deg=deg_2d, fs=sample_rate, mode='wb', lengths=lengths[:3]) + + # 2D - 2D + ideally = [ + pesq(ref=ref_2d[i, :length], deg=deg_2d[i, :length], fs=sample_rate, mode='wb') + for i, length in enumerate(lengths) + ] + scores = pesq_batch(ref=ref_2d, deg=deg_2d, fs=sample_rate, mode='wb', lengths=lengths) + assert np.allclose(scores, ideally), scores + + # narrowband + ref_2d = np.repeat(ref[np.newaxis, :], n_file, axis=0) + ideally = [ + pesq(ref=ref_2d[i, :length], deg=deg_2d[i, :length], fs=sample_rate, mode='nb') + for i, length in enumerate(lengths) + ] + scores = pesq_batch(ref=ref_2d, deg=deg_2d, fs=sample_rate, mode='nb', lengths=lengths) + assert np.allclose(scores, ideally), scores + + # def test_time_efficiency(): # data_dir = Path(__file__).parent.parent / 'audio' # ref_path = data_dir / 'speech.wav'