Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lengths argument to pesq_batch #46

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ print(pesq(rate, ref, deg, 'nb'))
# Usage for `multiprocessing` feature

```python
def pesq_batch(fs, ref, deg, mode='wb', n_processor=None, on_error=PesqError.RAISE_EXCEPTION):
def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.RAISE_EXCEPTION, lengths=None):
"""
Running `pesq` using multiple processors
Running `pesq` using multiple processors
Args:
on_error:
ref: numpy 1D (n_sample,) or 2D array (n_file, n_sample), reference audio signal
Expand All @@ -76,6 +76,7 @@ def pesq_batch(fs, ref, deg, mode='wb', n_processor=None, on_error=PesqError.RAI
mode: 'wb' (wide-band) or 'nb' (narrow-band)
n_processor: cpu_count() (default) or number of processors (chosen by the user) or 0 (without multiprocessing)
on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES
lengths: None or list of original length of audio signals before batching, length n_file
Returns:
pesq_score: list of pesq scores, P.862.2 Prediction (MOS-LQO)
"""
Expand All @@ -86,6 +87,7 @@ When the `ref` is an 1-D numpy array and `deg` is a 2-D numpy array, the result

When the `ref` is a 2-D numpy array and `deg` is a 2-D numpy array, the result of `pesq_batch` is identical to the value of `[pesq(fs, ref[i,:], deg[i,:],**kwargs) for i in range(deg.shape[0])]`.

The optional `lengths` argument is useful when the individual audio files originally had different lengths and were padded to a common length before being stacked into a batch. It ensures that the padded samples do not affect the results.

# Correctness

Expand Down
13 changes: 10 additions & 3 deletions pesq/_pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def pesq(fs, ref, deg, mode='wb', on_error=PesqError.RAISE_EXCEPTION):
return _pesq_inner(ref, deg, fs, mode, on_error)


def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.RAISE_EXCEPTION):
def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.RAISE_EXCEPTION, lengths=None):
"""
Running `pesq` using multiple processors
Args:
Expand All @@ -125,12 +125,15 @@ def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.R
mode: 'wb' (wide-band) or 'nb' (narrow-band)
n_processor: cpu_count() (default) or number of processors (chosen by the user) or 0 (without multiprocessing)
on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES
lengths: None or list of original length of audio signals before batching, length n_file
Returns:
pesq_score: list of pesq scores, P.862.2 Prediction (MOS-LQO)
"""
_check_fs_mode(mode, fs, USAGE_BATCH)
# check dimension
if len(ref.shape) == 1:
if lengths is not None:
raise ValueError("cannot provide lengths if ref is 1D")
if len(deg.shape) == 1 and ref.shape == deg.shape:
return [_pesq_inner(ref, deg, fs, mode, PesqError.RETURN_VALUES)]
elif len(deg.shape) == 2 and ref.shape[-1] == deg.shape[-1]:
Expand All @@ -147,14 +150,18 @@ def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.R
raise ValueError("The shapes of `deg` is invalid!")
elif len(ref.shape) == 2:
if deg.shape == ref.shape:
if lengths is None:
lengths = [ref.shape[-1] for _ in range(ref.shape[0])]
elif len(lengths) != ref.shape[0]:
raise ValueError("len(lengths) does not match the batch size")
if n_processor <= 0:
pesq_score = [np.nan for i in range(deg.shape[0])]
for i in range(deg.shape[0]):
pesq_score[i] = _pesq_inner(ref[i, :], deg[i, :], fs, mode, on_error)
pesq_score[i] = _pesq_inner(ref[i, :lengths[i]], deg[i, :lengths[i]], fs, mode, on_error)
return pesq_score
else:
return _processor_mapping(_pesq_inner,
[(ref[i, :], deg[i, :], fs, mode, on_error) for i in range(deg.shape[0])],
[(ref[i, :lengths[i]], deg[i, :lengths[i]], fs, mode, on_error) for i in range(deg.shape[0])],
n_processor)
else:
raise ValueError("The shape of `deg` is invalid!")
Expand Down
44 changes: 44 additions & 0 deletions tests/test_pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,50 @@ def test_pesq_batch():
assert np.allclose(np.array(scores), ideally), scores


def test_lengths():
    """Check that `pesq_batch` honours the optional `lengths` argument.

    Covers:
      * a 1D `ref` combined with `lengths` raises `ValueError`
        (both for 1D and 2D `deg`);
      * a 2D batch whose `lengths` has the wrong number of entries raises
        `ValueError`;
      * for a valid 2D batch, every score equals `pesq` applied to the
        signals truncated to the corresponding length, in both wide-band
        and narrow-band modes.
    """
    data_dir = Path(__file__).parent.parent / 'audio'
    ref_path = data_dir / 'speech.wav'
    deg_path = data_dir / 'speech_bab_0dB.wav'

    sample_rate, ref = scipy.io.wavfile.read(ref_path)
    sample_rate, deg = scipy.io.wavfile.read(deg_path)

    n_file = 10
    # random lengths smaller than len(ref) = 49600
    lengths = [43433, 40969, 30613, 38570, 45484, 10800, 28424, 22943, 30918, 30784]

    # 1D - 1D -- raises error
    with pytest.raises(ValueError):
        pesq_batch(ref=ref, deg=deg, fs=sample_rate, mode='wb', lengths=lengths)

    # 1D - 2D -- raises error
    deg_2d = np.repeat(deg[np.newaxis, :], n_file, axis=0)
    with pytest.raises(ValueError):
        pesq_batch(ref=ref, deg=deg_2d, fs=sample_rate, mode='wb', lengths=lengths)

    # 2D - 2D -- bad len(lengths)
    # `ref_2d` (not the 1D `ref`) must be passed here; otherwise the call fails
    # on the "ref is 1D" guard and the batch-size check is never exercised.
    ref_2d = np.repeat(ref[np.newaxis, :], n_file, axis=0)
    with pytest.raises(ValueError):
        pesq_batch(ref=ref_2d, deg=deg_2d, fs=sample_rate, mode='wb', lengths=lengths[:3])

    # 2D - 2D
    ideally = [
        pesq(ref=ref_2d[i, :length], deg=deg_2d[i, :length], fs=sample_rate, mode='wb')
        for i, length in enumerate(lengths)
    ]
    scores = pesq_batch(ref=ref_2d, deg=deg_2d, fs=sample_rate, mode='wb', lengths=lengths)
    assert np.allclose(scores, ideally), scores

    # narrowband
    ideally = [
        pesq(ref=ref_2d[i, :length], deg=deg_2d[i, :length], fs=sample_rate, mode='nb')
        for i, length in enumerate(lengths)
    ]
    scores = pesq_batch(ref=ref_2d, deg=deg_2d, fs=sample_rate, mode='nb', lengths=lengths)
    assert np.allclose(scores, ideally), scores


# def test_time_efficiency():
# data_dir = Path(__file__).parent.parent / 'audio'
# ref_path = data_dir / 'speech.wav'
Expand Down