diff --git a/nara_wpe/torch_wpe.py b/nara_wpe/torch_wpe.py index fa9291d..b113d52 100644 --- a/nara_wpe/torch_wpe.py +++ b/nara_wpe/torch_wpe.py @@ -225,3 +225,85 @@ def wpe_v6(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='fu X = Y - torch.matmul(hermite(G), Y_tilde) return X + + +def wpe_v8( + Y, + taps=10, + delay=3, + iterations=3, + psd_context=0, + statistics_mode='full', + inplace=False +): + """ + Loopy Multiple Input Multiple Output Weighted Prediction Error [1, 2] implementation + + For numpy it is often the fastest Numpy implementation, torch has to be + profiled. + It loops over the independent axes. This reduces the memory footprint. + + Args: + Y: Complex valued STFT signal with shape (..., D, T). + taps: Filter order + delay: Delay as a guard interval, such that X does not become zero. + iterations: + psd_context: Defines the number of elements in the time window + to improve the power estimation. Total number of elements will + be (psd_context + 1 + psd_context). + statistics_mode: Either 'full' or 'valid'. + 'full': Pad the observation with zeros on the left for the + estimation of the correlation matrix and vector. + 'valid': Only calculate correlation matrix and vector on valid + slices of the observation. + inplace: Whether to change Y inplace. Has only advantages, when Y has + independent axes, because the core WPE algorithm does not support + an inplace modification of the observation. + This option may be relevant, when Y is so large, that you do not + want to double the memory consumption (i.e. save Y and the + dereverberated signal in the memory). + + Returns: + Estimated signal with the same shape as Y + + [1] "Generalization of multi-channel linear prediction methods for blind MIMO + impulse response shortening", Yoshioka, Takuya and Nakatani, Tomohiro, 2012 + [2] NARA-WPE: A Python package for weighted prediction error dereverberation in + Numpy and Tensorflow for online and offline processing, Drude, Lukas and + Heymann, Jahn and Boeddeker, Christoph and Haeb-Umbach, Reinhold, 2018 + + """ + ndim = Y.ndim + if ndim == 2: + out = wpe_v6( + Y, + taps=taps, + delay=delay, + iterations=iterations, + psd_context=psd_context, + statistics_mode=statistics_mode + ) + if inplace: + Y[...] = out + return out + elif ndim >= 3: + if inplace: + out = Y + else: + out = torch.empty_like(Y) + + for index in np.ndindex(Y.shape[:-2]): + out[index] = wpe_v6( + Y=Y[index], + taps=taps, + delay=delay, + iterations=iterations, + psd_context=psd_context, + statistics_mode=statistics_mode, + ) + return out + else: + raise NotImplementedError( + 'Input shape has to be (..., D, T) and not {}.'.format(Y.shape) + ) +