Skip to content

Commit

Permalink
add wpe_v8 to torch, i.e., a loopy implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
boeddeker committed Aug 28, 2024
1 parent e4aa9b2 commit 338a7df
Showing 1 changed file with 82 additions and 0 deletions.
82 changes: 82 additions & 0 deletions nara_wpe/torch_wpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,85 @@ def wpe_v6(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='fu
X = Y - torch.matmul(hermite(G), Y_tilde)

return X


def wpe_v8(
Y,
taps=10,
delay=3,
iterations=3,
psd_context=0,
statistics_mode='full',
inplace=False
):
"""
Loopy Multiple Input Multiple Output Weighted Prediction Error [1, 2] implementation
For numpy it is often the fastest Numpy implementation, torch has to be
profiled.
It loops over the independent axes. This reduces the memory footprint.
Args:
Y: Complex valued STFT signal with shape (..., D, T).
taps: Filter order
delay: Delay as a guard interval, such that X does not become zero.
iterations:
psd_context: Defines the number of elements in the time window
to improve the power estimation. Total number of elements will
be (psd_context + 1 + psd_context).
statistics_mode: Either 'full' or 'valid'.
'full': Pad the observation with zeros on the left for the
estimation of the correlation matrix and vector.
'valid': Only calculate correlation matrix and vector on valid
slices of the observation.
inplace: Whether to change Y inplace. Has only advantages, when Y has
independent axes, because the core WPE algorithm does not support
an inplace modification of the observation.
This option may be relevant, when Y is so large, that you do not
want to double the memory consumption (i.e. save Y and the
dereverberated signal in the memory).
Returns:
Estimated signal with the same shape as Y
[1] "Generalization of multi-channel linear prediction methods for blind MIMO
impulse response shortening", Yoshioka, Takuya and Nakatani, Tomohiro, 2012
[2] NARA-WPE: A Python package for weighted prediction error dereverberation in
Numpy and Tensorflow for online and offline processing, Drude, Lukas and
Heymann, Jahn and Boeddeker, Christoph and Haeb-Umbach, Reinhold, 2018
"""
ndim = Y.ndim
if ndim == 2:
out = wpe_v6(
Y,
taps=taps,
delay=delay,
iterations=iterations,
psd_context=psd_context,
statistics_mode=statistics_mode
)
if inplace:
Y[...] = out
return out
elif ndim >= 3:
if inplace:
out = Y
else:
out = torch.empty_like(Y)

for index in np.ndindex(Y.shape[:-2]):
out[index] = wpe_v6(
Y=Y[index],
taps=taps,
delay=delay,
iterations=iterations,
psd_context=psd_context,
statistics_mode=statistics_mode,
)
return out
else:
raise NotImplementedError(
'Input shape has to be (..., D, T) and not {}.'.format(Y.shape)
)

0 comments on commit 338a7df

Please sign in to comment.