From 0f5a148a26db6fe24a9e668174d7087be989da96 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Fri, 15 Dec 2023 06:06:36 -0500 Subject: [PATCH 01/30] add rasr compatible feature extraction --- i6_models/primitives/feature_extraction.py | 91 +++++++++++++++++++++- 1 file changed, 90 insertions(+), 1 deletion(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index ead52dd5..5cdcc6ed 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -1,5 +1,11 @@ -__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"] +__all__ = [ + "LogMelFeatureExtractionV1", + "LogMelFeatureExtractionV1Config", + "RasrCompatibleLogMelFeatureExtractionV1", + "RasrCompatibleLogMelFeatureExtractionV1Config", +] +import math from dataclasses import dataclass from typing import Optional, Tuple @@ -111,3 +117,86 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: length = ((length - self.n_fft) // self.hop_length) + 1 return feature_data, length.int() + + +@dataclass +class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration): + """ + Attributes: + sample_rate: audio sample rate in Hz + win_size: window size in seconds + hop_size: window shift in seconds + f_min: minimum filter frequency in Hz + f_max: maximum filter frequency in Hz + min_amp: minimum amplitude for safe log + num_filters: number of mel windows + """ + + sample_rate: int + win_size: float + hop_size: float + f_min: int + f_max: int + min_amp: float + num_filters: int + + def __post_init__(self) -> None: + super().__post_init__() + assert self.f_max <= self.sample_rate // 2, "f_max can not be larger than half of the sample rate" + assert self.f_min >= 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive" + assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive" + assert self.num_filters > 0, "number of filters needs to be positive" + assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense" + + +class RasrCompatibleLogMelFeatureExtractionV1(nn.Module): + """ + Rasr-compatible log-mel feature extraction using log10. Does not use torchaudio. + + Using it wrapped with torch.no_grad() is recommended if no gradient is needed + """ + + def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config): + super().__init__() + self.hop_length = int(cfg.hop_size * cfg.sample_rate) + self.min_amp = cfg.min_amp + self.win_length = int(cfg.win_size * cfg.sample_rate) + # smallest power if two which is greater than or equal to win_length + self.n_fft = 2 ** math.ceil(math.log2(self.win_length)) + + self.register_buffer( + "mel_basis", + torch.tensor( + filters.mel( + sr=cfg.sample_rate, + n_fft=cfg.n_fft, + n_mels=cfg.num_filters, + fmin=cfg.f_min, + fmax=cfg.f_max, + htk=True, + norm=None, + ), + ), + ) + self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) + + def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: + """ + :param raw_audio: [B, T] + :param length in samples: [B] + :return features as [B,T,F] and length in frames [B] + """ + windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] + smoothed = windowed * self.window.unsqueeze(0) # [B, T', W] + + # Compute power spectrum using torch.fft.rfftn + power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1] + power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T'] + + melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) # [B, F'=num_filters, T'] + log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp)) + feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F'] + + length = ((length - self.win_length) // self.hop_length) + 1 + + return feature_data, length.int() From e8c5dde2d738d26344c1497c99c16affe453b6b5 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Fri, 15 Dec 2023 06:17:29 -0500 Subject: [PATCH 02/30] remove f_min and f_max from config --- i6_models/primitives/feature_extraction.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 5cdcc6ed..02706574 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -126,8 +126,6 @@ class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration): sample_rate: audio sample rate in Hz win_size: window size in seconds hop_size: window shift in seconds - f_min: minimum filter frequency in Hz - f_max: maximum filter frequency in Hz min_amp: minimum amplitude for safe log num_filters: number of mel windows """ @@ -135,15 +133,11 @@ class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration): sample_rate: int win_size: float hop_size: float - f_min: int - f_max: int min_amp: float num_filters: int def __post_init__(self) -> None: super().__post_init__() - assert self.f_max <= self.sample_rate // 2, "f_max can not be larger than half of the sample rate" - assert self.f_min >= 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive" assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive" assert self.num_filters > 0, "number of filters needs to be positive" assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense" @@ -171,8 +165,8 @@ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config): sr=cfg.sample_rate, n_fft=cfg.n_fft, n_mels=cfg.num_filters, - fmin=cfg.f_min, - fmax=cfg.f_max, + fmin=0, + fmax=cfg.sample_rate // 2, htk=True, norm=None, ), From 5fa4bf3abba33659cc8731c2d0e5b2702486d787 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Wed, 3 Jan 2024 12:19:00 -0500 Subject: [PATCH 03/30] fix --- i6_models/primitives/feature_extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 02706574..34bbfde4 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -163,7 +163,7 @@ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config): torch.tensor( filters.mel( sr=cfg.sample_rate, - n_fft=cfg.n_fft, + n_fft=self.n_fft, n_mels=cfg.num_filters, fmin=0, fmax=cfg.sample_rate // 2, From 33dd0473d1fd2aba581ce3ae4cf0bbf108389a16 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Mon, 29 Jan 2024 12:40:16 -0500 Subject: [PATCH 04/30] add preemphasis, use amplitude instead of power spectrum, additive log offset --- i6_models/primitives/feature_extraction.py | 23 +++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 34bbfde4..0bbfc40c 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -146,8 +146,6 @@ def __post_init__(self) -> None: class RasrCompatibleLogMelFeatureExtractionV1(nn.Module): """ Rasr-compatible log-mel feature extraction using log10. Does not use torchaudio. - - Using it wrapped with torch.no_grad() is recommended if no gradient is needed """ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config): @@ -180,17 +178,24 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: :param length in samples: [B] :return features as [B,T,F] and length in frames [B] """ - windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] + # preemphasis + preemphasized = raw_audio.clone() + preemphasized[..., 1:] -= 1.0 * preemphasized[..., :-1] + + # zero pad for the last frame + padded = torch.cat([preemphasized, torch.zeros(preemphasized.shape[0], (self.hop_length - 1))], dim=1) + + windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] smoothed = windowed * self.window.unsqueeze(0) # [B, T', W] - # Compute power spectrum using torch.fft.rfftn - power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1] - power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T'] + # Compute amplitude spectrum using torch.fft.rfftn + amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1] + amplitude_spectrum = amplitude_spectrum.transpose(1, 2) # [B, F, T'] - melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) # [B, F'=num_filters, T'] - log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp)) + melspec = torch.einsum("...ft,mf->...mt", amplitude_spectrum, self.mel_basis) # [B, F'=num_filters, T'] + log_melspec = torch.log10(melspec + self.min_amp) feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F'] - length = ((length - self.win_length) // self.hop_length) + 1 + length = math.ceil((length - self.win_length) / self.hop_length) + 1 return feature_data, length.int() From 3071b44cad6be9455f95b6f3070202b363c066db Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Mon, 29 Jan 2024 12:42:12 -0500 Subject: [PATCH 05/30] small change --- i6_models/primitives/feature_extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 0bbfc40c..383112f3 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -188,7 +188,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] smoothed = windowed * self.window.unsqueeze(0) # [B, T', W] - # Compute amplitude spectrum using torch.fft.rfftn + # compute amplitude spectrum using torch.fft.rfftn amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1] amplitude_spectrum = amplitude_spectrum.transpose(1, 2) # [B, F, T'] From 687854ad109ae60dac28dbb6b2f54ec631888063 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Tue, 30 Jan 2024 06:12:21 -0500 Subject: [PATCH 06/30] make alpha a parameter --- i6_models/primitives/feature_extraction.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 383112f3..d6e4ad9d 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -128,6 +128,7 @@ class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration): hop_size: window shift in seconds min_amp: minimum amplitude for safe log num_filters: number of mel windows + alpha: preemphasis weight """ sample_rate: int @@ -135,6 +136,7 @@ class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration): hop_size: float min_amp: float num_filters: int + alpha: float = 1.0 def __post_init__(self) -> None: super().__post_init__() @@ -153,8 +155,10 @@ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config): self.hop_length = int(cfg.hop_size * cfg.sample_rate) self.min_amp = cfg.min_amp self.win_length = int(cfg.win_size * cfg.sample_rate) - # smallest power if two which is greater than or equal to win_length - self.n_fft = 2 ** math.ceil(math.log2(self.win_length)) + self.n_fft = 2 ** math.ceil( + math.log2(self.win_length) + ) # smallest power if two which is greater than or equal to win_length + self.alpha = cfg.alpha self.register_buffer( "mel_basis", @@ -178,9 +182,9 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: :param length in samples: [B] :return features as [B,T,F] and length in frames [B] """ - # preemphasis + # preemphasize preemphasized = raw_audio.clone() - preemphasized[..., 1:] -= 1.0 * preemphasized[..., :-1] + preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1] # zero pad for the last frame padded = torch.cat([preemphasized, torch.zeros(preemphasized.shape[0], (self.hop_length - 1))], dim=1) From e7c850d0d38be8eecb57634c373a1643c9293974 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Thu, 1 Feb 2024 06:48:26 -0500 Subject: [PATCH 07/30] fix errors --- i6_models/primitives/feature_extraction.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index d6e4ad9d..f87d6ab9 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -187,7 +187,10 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1] # zero pad for the last frame - padded = torch.cat([preemphasized, torch.zeros(preemphasized.shape[0], (self.hop_length - 1))], dim=1) + padded = torch.cat( + [preemphasized, torch.zeros(preemphasized.shape[0], (self.hop_length - 1), device=preemphasized.device)], + dim=1, + ) windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] smoothed = windowed * self.window.unsqueeze(0) # [B, T', W] @@ -200,6 +203,6 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: log_melspec = torch.log10(melspec + self.min_amp) feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F'] - length = math.ceil((length - self.win_length) / self.hop_length) + 1 + length = torch.ceil((length - self.win_length) / self.hop_length) + 1 return feature_data, length.int() From 482f5609361dde4124d9d3d9d7af669f59847a5f Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 13 Feb 2024 13:58:03 +0100 Subject: [PATCH 08/30] fix window broadcasting --- i6_models/primitives/feature_extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index f87d6ab9..0e6274f6 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -193,7 +193,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: ) windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] - smoothed = windowed * self.window.unsqueeze(0) # [B, T', W] + smoothed = windowed * self.window[None, None, :] # [B, T', W] # compute amplitude spectrum using torch.fft.rfftn amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1] From 9e0cde32ff60e378971168ad580109f17ed9ee78 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 13 Feb 2024 14:14:04 +0100 Subject: [PATCH 09/30] test_rasr_compatible --- tests/test_feature_extraction.py | 63 +++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 9e7f241f..815f4a21 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -1,10 +1,12 @@ +import os import copy import numpy as np import torch +import unittest from librosa.feature import melspectrogram -from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1, LogMelFeatureExtractionV1Config +from i6_models.primitives.feature_extraction import * def test_logmel_librosa_compatibility(): @@ -65,3 +67,62 @@ def test_logmel_length(): assert torch.all(mel_center.size()[1] == length_center) mel_no_center, length_no_center = fe_no_center(audio, audio_length) assert torch.all(mel_no_center.size()[1] == length_no_center) + + +def test_rasr_compatible(): + try: + from i6_core.lib.rasr_cache import FileArchive + except ImportError: + raise unittest.SkipTest("i6_core not available") + try: + import soundfile + except ImportError: + raise unittest.SkipTest("soundfile not available") + if not os.path.exists("test_data/features.cache") or not os.path.exists("test_data/103-1240-0000.wav"): + raise unittest.SkipTest("test data not available") + + def _torch_repr(x: torch.Tensor) -> str: + try: + from lovely_tensors import lovely + except ImportError: + mean, std = x.mean(), x.std() + min_, max_ = x.min(), x.max() + return f"{x.shape} x∈[{min_}, {max_}] μ={mean} σ={std} {x.dtype}" + else: + return lovely(x) + + rasr_cache = FileArchive("test_data/features.cache", must_exists=True) + print(rasr_cache.file_list()) + print(rasr_cache.read("corpus/103-1240-0000/1.attribs", "str")) + time_, rasr_feat = rasr_cache.read("corpus/103-1240-0000/1", "feat") + assert len(time_) == len(rasr_feat) + print("RASR feature len:", len(rasr_feat), "frame 0 times:", time_[0], "frame 0 shape:", rasr_feat[0].shape) + rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32) + print("RASR feature shape:", rasr_feat.shape) + + cfg = RasrCompatibleLogMelFeatureExtractionV1Config( + sample_rate=16_000, + win_size=0.025, + hop_size=0.01, + min_amp=1.175494e-38, + num_filters=80, + ) + feat_extractor = RasrCompatibleLogMelFeatureExtractionV1(cfg) + + # int16 audio is in [2**15, 2**15-1]. + # This is how BlissToPcmHDFJob does it by default: + # https://github.com/rwth-i6/i6_core/blob/add09a8b640a2ba5928b815fa65f7504242be038/returnn/hdf.py#L207 + # This is also how our standard RASR flow handles it: + # https://github.com/rwth-i6/i6_models/pull/44#issuecomment-1938264642 + audio, sample_rate = soundfile.read(open("test_data/103-1240-0000.wav", "rb"), dtype="int16") + assert sample_rate == cfg.sample_rate + audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1] + print("raw audio", _torch_repr(audio)) + + i6m_feat, _ = feat_extractor(audio.unsqueeze(0), torch.tensor([len(audio)])) + i6m_feat = i6m_feat.squeeze(0) + + print("i6_models:", _torch_repr(i6m_feat)) + print("RASR:", _torch_repr(rasr_feat)) + + torch.testing.assert_allclose(i6m_feat, rasr_feat, rtol=1e-5, atol=1e-5) From f45f01e36233c05bb9c34a16a0652d38f058715b Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 13 Feb 2024 17:22:07 +0100 Subject: [PATCH 10/30] test_rasr_compatible more --- tests/test_feature_extraction.py | 273 +++++++++++++++++++++++++++++-- 1 file changed, 256 insertions(+), 17 deletions(-) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 815f4a21..8e8f56c2 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -1,9 +1,14 @@ import os +import sys import copy import numpy as np import torch import unittest +import tempfile +import atexit +import textwrap +from typing import Optional from librosa.feature import melspectrogram from i6_models.primitives.feature_extraction import * @@ -78,23 +83,48 @@ def test_rasr_compatible(): import soundfile except ImportError: raise unittest.SkipTest("soundfile not available") - if not os.path.exists("test_data/features.cache") or not os.path.exists("test_data/103-1240-0000.wav"): - raise unittest.SkipTest("test data not available") - - def _torch_repr(x: torch.Tensor) -> str: - try: - from lovely_tensors import lovely - except ImportError: - mean, std = x.mean(), x.std() - min_, max_ = x.min(), x.max() - return f"{x.shape} x∈[{min_}, {max_}] μ={mean} σ={std} {x.dtype}" - else: - return lovely(x) - - rasr_cache = FileArchive("test_data/features.cache", must_exists=True) + rasr_feature_extractor_bin_path = ( + "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/" + "feature-extraction.linux-x86_64-standard" + ) + if not os.path.exists(rasr_feature_extractor_bin_path): + raise unittest.SkipTest("RASR feature-extraction binary not found") + + wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio") + atexit.register(os.remove, wav_file_path) + generate_random_speech_like_audio_wav(wav_file_path) + rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow( + rasr_feature_extractor_bin_path, + wav_file_path, + textwrap.dedent( + f"""\ + + + + + + + + + + + + + + + + + + + """ + ), + ) + + rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True) print(rasr_cache.file_list()) - print(rasr_cache.read("corpus/103-1240-0000/1.attribs", "str")) - time_, rasr_feat = rasr_cache.read("corpus/103-1240-0000/1", "feat") + print(rasr_cache.read("corpus/recording/1.attribs", "str")) + time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat") assert len(time_) == len(rasr_feat) print("RASR feature len:", len(rasr_feat), "frame 0 times:", time_[0], "frame 0 shape:", rasr_feat[0].shape) rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32) @@ -114,7 +144,7 @@ def _torch_repr(x: torch.Tensor) -> str: # https://github.com/rwth-i6/i6_core/blob/add09a8b640a2ba5928b815fa65f7504242be038/returnn/hdf.py#L207 # This is also how our standard RASR flow handles it: # https://github.com/rwth-i6/i6_models/pull/44#issuecomment-1938264642 - audio, sample_rate = soundfile.read(open("test_data/103-1240-0000.wav", "rb"), dtype="int16") + audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16") assert sample_rate == cfg.sample_rate audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1] print("raw audio", _torch_repr(audio)) @@ -126,3 +156,212 @@ def _torch_repr(x: torch.Tensor) -> str: print("RASR:", _torch_repr(rasr_feat)) torch.testing.assert_allclose(i6m_feat, rasr_feat, rtol=1e-5, atol=1e-5) + + +def generate_rasr_feature_cache_from_wav_and_flow( + rasr_feature_extractor_bin_path: str, + wav_file_path: str, + flow_network_str: str, + *, + flow_input_name: str = "samples", + flow_output_name: str = "nonlinear", +) -> str: + import subprocess + + corpus_xml_path = tempfile.mktemp(suffix=".xml", prefix="tmp-rasr-corpus") + atexit.register(os.remove, corpus_xml_path) + with open(corpus_xml_path, "w") as f: + f.write( + textwrap.dedent( + f"""\ + + + + + + + + """ + ) + ) + + rasr_feature_cache_path = tempfile.mktemp(suffix=".cache", prefix="tmp-rasr-features") + atexit.register(os.remove, rasr_feature_cache_path) + rasr_flow_xml_path = tempfile.mktemp(suffix=".config", prefix="tmp-rasr-flow") + atexit.register(os.remove, rasr_flow_xml_path) + with open(rasr_flow_xml_path, "w") as f: + f.write( + textwrap.dedent( + f"""\ + + + + + + + + + + """ + ) + + textwrap.indent(flow_network_str, " ") + + textwrap.dedent( + f"""\ + + + + + """ + ) + ) + + rasr_config_path = tempfile.mktemp(suffix=".config", prefix="tmp-rasr-feature-extract") + atexit.register(os.remove, rasr_config_path) + with open(rasr_config_path, "w") as f: + f.write( + textwrap.dedent( + f"""\ + [*.corpus] + file = {corpus_xml_path} + + [*.feature-extraction] + file = {rasr_flow_xml_path} + """ + ) + ) + + subprocess.check_call([rasr_feature_extractor_bin_path, "--config", rasr_config_path]) + return rasr_feature_cache_path + + +def _get_wav_file_duration_sec(wav_file_path: str) -> float: + import wave + import contextlib + + with contextlib.closing(wave.open(wav_file_path, "r")) as f: + frames = f.getnframes() + rate = f.getframerate() + return frames / float(rate) + + +def generate_random_speech_like_audio_wav( + output_wav_file_path: str, + duration_sec: float = 5.0, + *, + samples_per_sec: int = 16_000, + sample_width_bytes: int = 2, # int16 + frequency: float = 150.0, + num_random_freqs_per_sec: int = 15, + amplitude: float = 0.3, + amplitude_frequency: Optional[float] = None, +): + import wave + + f = wave.open(output_wav_file_path, "wb") + f.setframerate(samples_per_sec) + f.setnchannels(1) + f.setsampwidth(sample_width_bytes) + + samples = generate_random_speech_like_audio( + batch_size=1, + num_frames=int(duration_sec * samples_per_sec), + samples_per_sec=samples_per_sec, + frequency=frequency, + num_random_freqs_per_sec=num_random_freqs_per_sec, + amplitude=amplitude, + amplitude_frequency=amplitude_frequency, + ) # [B,T] + samples = samples[0] # [T] + print("generated raw samples:", _torch_repr(samples)) + + samples_int = (samples * (2 ** (8 * sample_width_bytes - 1) - 1)).to( + {1: torch.int8, 2: torch.int16, 4: torch.int32}[sample_width_bytes] + ) + + f.writeframes(samples_int.numpy().tobytes()) + f.close() + + +def generate_random_speech_like_audio( + batch_size: int, + num_frames: int, + *, + samples_per_sec: int = 16_000, + frequency: float = 150.0, + num_random_freqs_per_sec: int = 15, + amplitude: float = 0.3, + amplitude_frequency: Optional[float] = None, +) -> torch.Tensor: + """ + generate audio + + Source: + https://github.com/albertz/playground/blob/master/create-random-speech-like-sound.py + + :return: shape [batch_size,num_frames] + """ + frame_idxs = torch.arange(num_frames, dtype=torch.int64) # [T] + + samples = _integrate_rnd_frequencies( + batch_size, + frame_idxs, + base_frequency=frequency, + samples_per_sec=samples_per_sec, + num_random_freqs_per_sec=num_random_freqs_per_sec, + ) # [T,B] + + if amplitude_frequency is None: + amplitude_frequency = frequency / 75.0 + amplitude_variations = _integrate_rnd_frequencies( + batch_size, + frame_idxs, + base_frequency=amplitude_frequency, + samples_per_sec=samples_per_sec, + num_random_freqs_per_sec=amplitude_frequency, + ) # [T,B] + + samples *= amplitude * (0.666 + 0.333 * amplitude_variations) + return samples.permute(1, 0) # [B,T] + + +def _integrate_rnd_frequencies( + batch_size: int, + frame_idxs: torch.Tensor, + *, + base_frequency: float, + samples_per_sec: int, + num_random_freqs_per_sec: float, +) -> torch.Tensor: + rnd_freqs = torch.empty( + size=(int(len(frame_idxs) * num_random_freqs_per_sec / samples_per_sec) + 1, batch_size), + dtype=torch.float32, + ) # [T',B] + torch.nn.init.trunc_normal_(rnd_freqs, a=-1.0, b=1.0) + rnd_freqs = (rnd_freqs * 0.5 + 1.0) * base_frequency # [T',B] + + freq_idx_f = (frame_idxs * num_random_freqs_per_sec) / samples_per_sec + freq_idx = freq_idx_f.to(torch.int64) + next_freq_idx = torch.clip(freq_idx + 1, 0, len(rnd_freqs) - 1) + frac = (freq_idx_f % 1)[:, None] # [T,1] + freq = rnd_freqs[freq_idx] * (1 - frac) + rnd_freqs[next_freq_idx] * frac # [T,B] + + ts = torch.cumsum(freq / samples_per_sec, dim=0) # [T,B] + return torch.sin(2 * torch.pi * ts) + + +def _torch_repr(x: torch.Tensor) -> str: + try: + from lovely_tensors import lovely + except ImportError: + mean, std = x.mean(), x.std() + min_, max_ = x.min(), x.max() + return f"{x.shape} x∈[{min_}, {max_}] μ={mean} σ={std} {x.dtype}" + else: + return lovely(x) + + +if __name__ == "__main__": + for arg in sys.argv[1:]: + print(f"*** {arg}()") + globals()[arg]() From 71f9e6a95693c3afdfe9f3c97dcca3915c9f5fcc Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 13 Feb 2024 17:48:25 +0100 Subject: [PATCH 11/30] test_rasr_compatible_raw_audio_samples (passing) --- tests/test_feature_extraction.py | 49 +++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 8e8f56c2..205c0b3d 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -119,6 +119,7 @@ def test_rasr_compatible(): """ ), + flow_output_name="nonlinear", ) rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True) @@ -155,7 +156,53 @@ def test_rasr_compatible(): print("i6_models:", _torch_repr(i6m_feat)) print("RASR:", _torch_repr(rasr_feat)) - torch.testing.assert_allclose(i6m_feat, rasr_feat, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(i6m_feat, rasr_feat, rtol=1e-5, atol=1e-5) + + +def test_rasr_compatible_raw_audio_samples(): + try: + from i6_core.lib.rasr_cache import FileArchive + except ImportError: + raise unittest.SkipTest("i6_core not available") + try: + import soundfile + except ImportError: + raise unittest.SkipTest("soundfile not available") + rasr_feature_extractor_bin_path = ( + "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/" + "feature-extraction.linux-x86_64-standard" + ) + if not os.path.exists(rasr_feature_extractor_bin_path): + raise unittest.SkipTest("RASR feature-extraction binary not found") + + wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio") + atexit.register(os.remove, wav_file_path) + generate_random_speech_like_audio_wav(wav_file_path) + rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow( + rasr_feature_extractor_bin_path, + wav_file_path, + textwrap.dedent( + f"""\ + + + + + """ + ), + flow_output_name="convert", + ) + + rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True) + time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat") + assert len(time_) == len(rasr_feat) + rasr_feat = torch.tensor(np.concatenate(rasr_feat, axis=0), dtype=torch.float32) + print("RASR:", _torch_repr(rasr_feat)) + + audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16") + audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1] + print("raw audio", _torch_repr(audio)) + + torch.testing.assert_close(audio, rasr_feat, rtol=1e-30, atol=1e-30) def generate_rasr_feature_cache_from_wav_and_flow( From 259d7f332aeec0dd41366ce6d2d97304d6642032 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 13 Feb 2024 17:48:41 +0100 Subject: [PATCH 12/30] test_rasr_compatible_preemphasis (failing) --- tests/test_feature_extraction.py | 49 ++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 205c0b3d..6332daf5 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -205,6 +205,55 @@ def test_rasr_compatible_raw_audio_samples(): torch.testing.assert_close(audio, rasr_feat, rtol=1e-30, atol=1e-30) +def test_rasr_compatible_preemphasis(): + try: + from i6_core.lib.rasr_cache import FileArchive + except ImportError: + raise unittest.SkipTest("i6_core not available") + try: + import soundfile + except ImportError: + raise unittest.SkipTest("soundfile not available") + rasr_feature_extractor_bin_path = ( + "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/" + "feature-extraction.linux-x86_64-standard" + ) + if not os.path.exists(rasr_feature_extractor_bin_path): + raise unittest.SkipTest("RASR feature-extraction binary not found") + + wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio") + atexit.register(os.remove, wav_file_path) + generate_random_speech_like_audio_wav(wav_file_path) + rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow( + rasr_feature_extractor_bin_path, + wav_file_path, + textwrap.dedent( + f"""\ + + + + + + + """ + ), + flow_output_name="preemphasis", + ) + + rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True) + time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat") + assert len(time_) == len(rasr_feat) + rasr_feat = torch.tensor(np.concatenate(rasr_feat, axis=0), dtype=torch.float32) + print("RASR:", _torch_repr(rasr_feat)) + + audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16") + audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1] + audio[..., 1:] -= 1.0 * audio[..., :-1] + print("i6_models", _torch_repr(audio)) + + torch.testing.assert_close(audio, rasr_feat, rtol=1e-30, atol=1e-30) + + def generate_rasr_feature_cache_from_wav_and_flow( rasr_feature_extractor_bin_path: str, wav_file_path: str, From f41a60c9ccfa26422da29f91769fd9083efe2c41 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 13 Feb 2024 17:54:21 +0100 Subject: [PATCH 13/30] fix preemphasize --- i6_models/primitives/feature_extraction.py | 1 + tests/test_feature_extraction.py | 1 + 2 files changed, 2 insertions(+) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 0e6274f6..ff0d790e 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -185,6 +185,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: # preemphasize preemphasized = raw_audio.clone() preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1] + preemphasized[..., 0] = 0.0 # zero pad for the last frame padded = torch.cat( diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 6332daf5..e4914c0c 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -249,6 +249,7 @@ def test_rasr_compatible_preemphasis(): audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16") audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1] audio[..., 1:] -= 1.0 * audio[..., :-1] + audio[..., 0] = 0.0 print("i6_models", _torch_repr(audio)) torch.testing.assert_close(audio, rasr_feat, rtol=1e-30, atol=1e-30) From acefd99e4f93378f0ff6b5d108b9caf724085012 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 13 Feb 2024 18:04:55 +0100 Subject: [PATCH 14/30] test_rasr_compatible_window (failing) --- tests/test_feature_extraction.py | 70 ++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index e4914c0c..3dea7d4f 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -255,6 +255,76 @@ def test_rasr_compatible_preemphasis(): torch.testing.assert_close(audio, rasr_feat, rtol=1e-30, atol=1e-30) +def test_rasr_compatible_window(): + try: + from i6_core.lib.rasr_cache import FileArchive + except ImportError: + raise unittest.SkipTest("i6_core not available") + try: + import soundfile + except ImportError: + raise unittest.SkipTest("soundfile not available") + rasr_feature_extractor_bin_path = ( + "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/" + "feature-extraction.linux-x86_64-standard" + ) + if not os.path.exists(rasr_feature_extractor_bin_path): + raise unittest.SkipTest("RASR feature-extraction binary not found") + + wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio") + atexit.register(os.remove, wav_file_path) + generate_random_speech_like_audio_wav(wav_file_path) + rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow( + rasr_feature_extractor_bin_path, + wav_file_path, + textwrap.dedent( + f"""\ + + + + + + + + + """ + ), + flow_output_name="window", + ) + + rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True) + time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat") + assert len(time_) == len(rasr_feat) + rasr_feat[-1] = np.pad(rasr_feat[-1], (0, rasr_feat[0].shape[0] - rasr_feat[-1].shape[0])) + rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32) + print("RASR:", _torch_repr(rasr_feat)) + + audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16") + audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1] + + # preemphasize + audio[..., 1:] -= 1.0 * audio[..., :-1] + audio[..., 0] = 0.0 + + # windowing + win_size = 0.025 + hop_size = 0.01 + hop_length = int(hop_size * sample_rate) + win_length = int(win_size * sample_rate) + padded = torch.cat( # zero pad for the last frame + [audio, torch.zeros((hop_length - 1), device=audio.device)], + dim=0, + ) + + windowed = padded.unfold(0, size=win_length, step=hop_length) # [T', W=win_length] + window = torch.hann_window(win_length, periodic=False) + smoothed = windowed * window[None, :] # [T', W] + + print("i6_models", _torch_repr(smoothed)) + + torch.testing.assert_close(smoothed, rasr_feat, rtol=1e-30, atol=1e-30) + + def generate_rasr_feature_cache_from_wav_and_flow( rasr_feature_extractor_bin_path: str, wav_file_path: str, From ba47a7864c8c76c7c6b0be841f55f541c650cfb5 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 13 Feb 2024 23:52:45 +0100 Subject: [PATCH 15/30] testing custom hanning window implementations --- tests/test_feature_extraction.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 3dea7d4f..79bdff46 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -311,6 +311,37 @@ def test_rasr_compatible_window(): hop_size = 0.01 hop_length = int(hop_size * sample_rate) win_length = int(win_size * sample_rate) + + # https://github.com/rwth-i6/rasr/blob/master/src/Signal/WindowFunction.cc + # https://pytorch.org/docs/stable/generated/torch.hann_window.html + + def _my_hann_win_np(size: int) -> torch.Tensor: + M_PI = 3.14159265358979323846264338327950288 + M = size - 1 + n = np.arange(0, size // 2) # [size//2] + win = 0.5 - 0.5 * np.cos(2.0 * M_PI * n / M) # [size//2] + win = torch.tensor(win, dtype=torch.float32) + win = torch.concatenate([win, win.flip(0)[size % 2 :]]) # [size] + return win + + def _my_hann_win(size: int) -> torch.Tensor: + M_PI = 3.14159265358979323846264338327950288 + M = size - 1 + n = torch.arange(0, size // 2, dtype=torch.float64) # [size//2] + win = 0.5 - 0.5 * torch.cos(2.0 * M_PI * n / M) # [size//2] + win = win.to(torch.float32) + win = torch.concatenate([win, win.flip(0)[size % 2 :]]) # [size] + return win + + # manual + for i, t in enumerate(range(0, audio.shape[0], hop_length)): + x = audio[t : t + win_length] + x = x * torch.hann_window(x.shape[0], periodic=False, dtype=torch.float64).to(torch.float32) + # x = x * _my_hann_win(x.shape[0]) + print("x", i, ":", _torch_repr(x)) + print(" RASR:", _torch_repr(rasr_feat[i])) + torch.testing.assert_close(x, rasr_feat[i][: x.shape[0]], rtol=1e-30, atol=1e-30) + padded = torch.cat( # zero pad for the last frame [audio, torch.zeros((hop_length - 1), device=audio.device)], dim=0, From 44aba814698a2251e0f614538c7dc3a19c496c0e Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 00:15:22 +0100 Subject: [PATCH 16/30] cleanup, fix windowing (WIP) --- i6_models/primitives/feature_extraction.py | 4 +++- tests/test_feature_extraction.py | 26 ++++------------------ 2 files changed, 7 insertions(+), 23 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index ff0d790e..10e3fa7f 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -174,7 +174,9 @@ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config): ), ), ) - self.register_buffer("window", torch.hann_window(self.win_length, periodic=False)) + self.register_buffer( + "window", torch.hann_window(self.win_length, periodic=False, dtype=torch.float64).to(torch.float32) + ) def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: """ diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 79bdff46..9e5bd82b 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -315,32 +315,14 @@ def test_rasr_compatible_window(): # https://github.com/rwth-i6/rasr/blob/master/src/Signal/WindowFunction.cc # https://pytorch.org/docs/stable/generated/torch.hann_window.html - def _my_hann_win_np(size: int) -> torch.Tensor: - M_PI = 3.14159265358979323846264338327950288 - M = size - 1 - n = np.arange(0, size // 2) # [size//2] - win = 0.5 - 0.5 * np.cos(2.0 * M_PI * n / M) # [size//2] - win = torch.tensor(win, dtype=torch.float32) - win = torch.concatenate([win, win.flip(0)[size % 2 :]]) # [size] - return win - - def _my_hann_win(size: int) -> torch.Tensor: - M_PI = 3.14159265358979323846264338327950288 - M = size - 1 - n = torch.arange(0, size // 2, dtype=torch.float64) # [size//2] - win = 0.5 - 0.5 * torch.cos(2.0 * M_PI * n / M) # [size//2] - win = win.to(torch.float32) - win = torch.concatenate([win, win.flip(0)[size % 2 :]]) # [size] - return win - # manual for i, t in enumerate(range(0, audio.shape[0], hop_length)): x = audio[t : t + win_length] x = x * torch.hann_window(x.shape[0], periodic=False, dtype=torch.float64).to(torch.float32) - # x = x * _my_hann_win(x.shape[0]) - print("x", i, ":", _torch_repr(x)) - print(" RASR:", _torch_repr(rasr_feat[i])) torch.testing.assert_close(x, rasr_feat[i][: x.shape[0]], rtol=1e-30, atol=1e-30) + # once end was reached, stop + if x.shape[0] < win_length: + break padded = torch.cat( # zero pad for the last frame [audio, torch.zeros((hop_length - 1), device=audio.device)], @@ -348,7 +330,7 @@ def _my_hann_win(size: int) -> torch.Tensor: ) windowed = padded.unfold(0, size=win_length, step=hop_length) # [T', W=win_length] - window = torch.hann_window(win_length, periodic=False) + window = torch.hann_window(win_length, periodic=False, dtype=torch.float64).to(torch.float32) smoothed = windowed * window[None, :] # [T', W] print("i6_models", _torch_repr(smoothed)) From d84565047266048141c0aa20103a21236c6e7a4a Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 01:35:04 +0100 Subject: [PATCH 17/30] fix last Hanning window --- i6_models/primitives/feature_extraction.py | 27 +++++++++----- tests/test_feature_extraction.py | 41 ++++++++++++++++++---- 2 files changed, 52 insertions(+), 16 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 10e3fa7f..f578e20b 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -181,22 +181,33 @@ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config): def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: """ :param raw_audio: [B, T] - :param length in samples: [B] + :param length: in samples [B] :return features as [B,T,F] and length in frames [B] """ + assert raw_audio.shape[1] > 0 # also same for length + res_size = max(raw_audio.shape[1] - self.win_length + self.hop_length - 1, 0) // self.hop_length + 1 + res_length = ( + torch.maximum(length - self.win_length + self.hop_length - 1, torch.zeros_like(length)) // self.hop_length + + 1 + ) + # preemphasize preemphasized = raw_audio.clone() preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1] preemphasized[..., 0] = 0.0 # zero pad for the last frame - padded = torch.cat( - [preemphasized, torch.zeros(preemphasized.shape[0], (self.hop_length - 1), device=preemphasized.device)], - dim=1, - ) + last_win_size = preemphasized.shape[1] - (res_size - 1) * self.hop_length + last_pad = self.win_length - last_win_size + padded = torch.nn.functional.pad(preemphasized, (0, last_pad)) windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] - smoothed = windowed * self.window[None, None, :] # [B, T', W] + smoothed = windowed[:, :-1] * self.window[None, None, :] # [B, T'-1, W] + + # The last window might be shorter. Will use a shorter Hanning window then. Need to fix that. + last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(torch.float32) + last_win = torch.nn.functional.pad(last_win, (0, last_pad)) # [W] + smoothed = torch.cat([smoothed, (windowed[:, -1] * last_win[None, :])[:, None]], dim=1) # [B, T', W] # compute amplitude spectrum using torch.fft.rfftn amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1] @@ -206,6 +217,4 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: log_melspec = torch.log10(melspec + self.min_amp) feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F'] - length = torch.ceil((length - self.win_length) / self.hop_length) + 1 - - return feature_data, length.int() + return feature_data, res_length diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 9e5bd82b..b98af129 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -295,7 +295,8 @@ def test_rasr_compatible_window(): rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True) time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat") assert len(time_) == len(rasr_feat) - rasr_feat[-1] = np.pad(rasr_feat[-1], (0, rasr_feat[0].shape[0] - rasr_feat[-1].shape[0])) + rasr_last_pad = rasr_feat[0].shape[0] - rasr_feat[-1].shape[0] + rasr_feat[-1] = np.pad(rasr_feat[-1], (0, rasr_last_pad)) rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32) print("RASR:", _torch_repr(rasr_feat)) @@ -312,6 +313,25 @@ def test_rasr_compatible_window(): hop_length = int(hop_size * sample_rate) win_length = int(win_size * sample_rate) + def _get_length(audio_len: int) -> int: + if audio_len == 0: + return 0 + return max(audio_len - win_length + hop_length - 1, 0) // hop_length + 1 + + def _get_length_naive(audio_len: int) -> int: + n = 0 + for t in range(0, audio_len, hop_length): + n += 1 + if audio_len <= t + win_length: + break + return n + + for t in range(10 * win_length): + assert _get_length(t) == _get_length_naive(t), f"t={t}, {_get_length(t)} != {_get_length_naive(t)}" + + res_len = _get_length(audio.shape[0]) + assert res_len == len(rasr_feat) + # https://github.com/rwth-i6/rasr/blob/master/src/Signal/WindowFunction.cc # https://pytorch.org/docs/stable/generated/torch.hann_window.html @@ -321,17 +341,24 @@ def test_rasr_compatible_window(): x = x * torch.hann_window(x.shape[0], periodic=False, dtype=torch.float64).to(torch.float32) torch.testing.assert_close(x, rasr_feat[i][: x.shape[0]], rtol=1e-30, atol=1e-30) # once end was reached, stop - if x.shape[0] < win_length: + if audio.shape[0] <= t + win_length: + assert win_length - x.shape[0] == rasr_last_pad, f"win {win_length}, cur {x.shape[0]}, pad {rasr_last_pad}" break - padded = torch.cat( # zero pad for the last frame - [audio, torch.zeros((hop_length - 1), device=audio.device)], - dim=0, - ) + last_win_size = audio.shape[0] - (res_len - 1) * hop_length + last_pad = win_length - last_win_size + assert last_pad == rasr_last_pad, f"last pad {last_pad}, RASR last pad {rasr_last_pad}" + padded = torch.nn.functional.pad(audio, (0, last_pad)) # zero pad for the last frame windowed = padded.unfold(0, size=win_length, step=hop_length) # [T', W=win_length] + assert len(windowed) == res_len window = torch.hann_window(win_length, periodic=False, dtype=torch.float64).to(torch.float32) - smoothed = windowed * window[None, :] # [T', W] + smoothed = windowed[:-1] * window[None, :] # [T'-1, W] + + # The last window might be shorter. Will use a shorter Hanning window then. Need to fix that. + last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(torch.float32) + last_win = torch.nn.functional.pad(last_win, (0, last_pad)) + smoothed = torch.cat([smoothed, (windowed[-1] * last_win)[None, :]], dim=0) print("i6_models", _torch_repr(smoothed)) From e2cda8b396f4766cfe0383181e10604e6542712c Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 01:39:33 +0100 Subject: [PATCH 18/30] fix device --- i6_models/primitives/feature_extraction.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index f578e20b..4a08e8f3 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -205,7 +205,9 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: smoothed = windowed[:, :-1] * self.window[None, None, :] # [B, T'-1, W] # The last window might be shorter. Will use a shorter Hanning window then. Need to fix that. - last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(torch.float32) + last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to( + self.window.device, torch.float32 + ) last_win = torch.nn.functional.pad(last_win, (0, last_pad)) # [W] smoothed = torch.cat([smoothed, (windowed[:, -1] * last_win[None, :])[:, None]], dim=1) # [B, T', W] From 6925acc6a967b50035ec60d04b624b56b7bbe805 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 01:41:51 +0100 Subject: [PATCH 19/30] simplify --- i6_models/primitives/feature_extraction.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 4a08e8f3..de7a5163 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -213,10 +213,8 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: # compute amplitude spectrum using torch.fft.rfftn amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1] - amplitude_spectrum = amplitude_spectrum.transpose(1, 2) # [B, F, T'] - melspec = torch.einsum("...ft,mf->...mt", amplitude_spectrum, self.mel_basis) # [B, F'=num_filters, T'] + melspec = torch.einsum("...tf,mf->...tm", amplitude_spectrum, self.mel_basis) # [B, T', F'=num_filters] log_melspec = torch.log10(melspec + self.min_amp) - feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F'] - return feature_data, res_length + return log_melspec, res_length From 848c88642cf024fc3306d083d17c934cbbee5316 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 01:58:46 +0100 Subject: [PATCH 20/30] test_rasr_compatible_fft (failing) --- tests/test_feature_extraction.py | 90 ++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index b98af129..282e52a0 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -1,6 +1,7 @@ import os import sys import copy +import math import numpy as np import torch import unittest @@ -365,6 +366,95 @@ def _get_length_naive(audio_len: int) -> int: torch.testing.assert_close(smoothed, rasr_feat, rtol=1e-30, atol=1e-30) +def test_rasr_compatible_fft(): + try: + from i6_core.lib.rasr_cache import FileArchive + except ImportError: + raise unittest.SkipTest("i6_core not available") + try: + import soundfile + except ImportError: + raise unittest.SkipTest("soundfile not available") + rasr_feature_extractor_bin_path = ( + "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/" + "feature-extraction.linux-x86_64-standard" + ) + if not os.path.exists(rasr_feature_extractor_bin_path): + raise unittest.SkipTest("RASR feature-extraction binary not found") + + wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio") + atexit.register(os.remove, wav_file_path) + generate_random_speech_like_audio_wav(wav_file_path) + rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow( + rasr_feature_extractor_bin_path, + wav_file_path, + textwrap.dedent( + f"""\ + + + + + + + + + + + + + + + """ + ), + flow_output_name="amplitude-spectrum", + ) + + rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True) + time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat") + assert len(time_) == len(rasr_feat) + rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32) + print("RASR:", _torch_repr(rasr_feat)) + + audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16") + audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1] + + # preemphasize + audio[..., 1:] -= 1.0 * audio[..., :-1] + audio[..., 0] = 0.0 + + # windowing + win_size = 0.025 + hop_size = 0.01 + hop_length = int(hop_size * sample_rate) + win_length = int(win_size * sample_rate) + + res_len = max(audio.shape[0] - win_length + hop_length - 1, 0) // hop_length + 1 + assert res_len == len(rasr_feat) + + last_win_size = audio.shape[0] - (res_len - 1) * hop_length + last_pad = win_length - last_win_size + padded = torch.nn.functional.pad(audio, (0, last_pad)) # zero pad for the last frame + + windowed = padded.unfold(0, size=win_length, step=hop_length) # [T', W=win_length] + assert len(windowed) == res_len + window = torch.hann_window(win_length, periodic=False, dtype=torch.float64).to(torch.float32) + smoothed = windowed[:-1] * window[None, :] # [T'-1, W] + + # The last window might be shorter. Will use a shorter Hanning window then. Need to fix that. + last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(torch.float32) + last_win = torch.nn.functional.pad(last_win, (0, last_pad)) + smoothed = torch.cat([smoothed, (windowed[-1] * last_win)[None, :]], dim=0) + + n_fft = 2 ** math.ceil(math.log2(win_length)) + # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F] + # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F*2] + amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=n_fft)) # [B, T', F=n_fft//2+1] + + print("i6_models", _torch_repr(amplitude_spectrum)) + + torch.testing.assert_close(amplitude_spectrum, rasr_feat, rtol=1e-30, atol=1e-30) + + def generate_rasr_feature_cache_from_wav_and_flow( rasr_feature_extractor_bin_path: str, wav_file_path: str, From 990a9779c4acc3eeb44a145f7858ef63f2d36d68 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 11:39:14 +0100 Subject: [PATCH 21/30] FFT test more direct (still failing) --- tests/test_feature_extraction.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 282e52a0..5d7bc76e 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -402,11 +402,9 @@ def test_rasr_compatible_fft(): - - """ ), - flow_output_name="amplitude-spectrum", + flow_output_name="scaling", ) rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True) @@ -446,13 +444,18 @@ def test_rasr_compatible_fft(): smoothed = torch.cat([smoothed, (windowed[-1] * last_win)[None, :]], dim=0) n_fft = 2 ** math.ceil(math.log2(win_length)) + print(f"win_length={win_length}, n_fft={n_fft}") + smoothed = smoothed.to(torch.float64) + # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F] # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F*2] - amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=n_fft)) # [B, T', F=n_fft//2+1] + fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F=n_fft//2+1] + fft = torch.view_as_real(fft).flatten(-2) # # [B, T', F=(n_fft//2+1)*2] + fft = fft.to(torch.float32) - print("i6_models", _torch_repr(amplitude_spectrum)) + print("i6_models", _torch_repr(fft)) - torch.testing.assert_close(amplitude_spectrum, rasr_feat, rtol=1e-30, atol=1e-30) + torch.testing.assert_close(fft, rasr_feat, rtol=1e-30, atol=1e-30) def generate_rasr_feature_cache_from_wav_and_flow( From 6556b3c4f96789d9a26d64f0ae2b8dd9f2e4f49b Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 14:39:04 +0100 Subject: [PATCH 22/30] tests deterministic --- tests/test_feature_extraction.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 5d7bc76e..fd6f9257 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -91,6 +91,7 @@ def test_rasr_compatible(): if not os.path.exists(rasr_feature_extractor_bin_path): raise unittest.SkipTest("RASR feature-extraction binary not found") + torch.manual_seed(42) wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio") atexit.register(os.remove, wav_file_path) generate_random_speech_like_audio_wav(wav_file_path) @@ -176,6 +177,7 @@ def test_rasr_compatible_raw_audio_samples(): if not os.path.exists(rasr_feature_extractor_bin_path): raise unittest.SkipTest("RASR feature-extraction binary not found") + torch.manual_seed(42) wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio") atexit.register(os.remove, wav_file_path) generate_random_speech_like_audio_wav(wav_file_path) @@ -222,6 +224,7 @@ def test_rasr_compatible_preemphasis(): if not os.path.exists(rasr_feature_extractor_bin_path): raise unittest.SkipTest("RASR feature-extraction binary not found") + torch.manual_seed(42) wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio") atexit.register(os.remove, wav_file_path) generate_random_speech_like_audio_wav(wav_file_path) @@ -272,6 +275,7 @@ def test_rasr_compatible_window(): if not os.path.exists(rasr_feature_extractor_bin_path): raise unittest.SkipTest("RASR feature-extraction binary not found") + torch.manual_seed(42) wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio") atexit.register(os.remove, wav_file_path) generate_random_speech_like_audio_wav(wav_file_path) @@ -382,6 +386,7 @@ def test_rasr_compatible_fft(): if not os.path.exists(rasr_feature_extractor_bin_path): raise unittest.SkipTest("RASR feature-extraction binary not found") + torch.manual_seed(42) wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio") atexit.register(os.remove, wav_file_path) generate_random_speech_like_audio_wav(wav_file_path) From 79a043e69a881ba1d9d528d3d07fc6519c706abd Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 14:39:51 +0100 Subject: [PATCH 23/30] copy RASR C++ FFT code for testing --- tests/test_feature_extraction.py | 118 ++++++++++++++++++++++++++++++- 1 file changed, 116 insertions(+), 2 deletions(-) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index fd6f9257..1b4d6d06 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -454,8 +454,9 @@ def test_rasr_compatible_fft(): # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F] # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F*2] - fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F=n_fft//2+1] - fft = torch.view_as_real(fft).flatten(-2) # # [B, T', F=(n_fft//2+1)*2] + # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F=n_fft//2+1] + # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F=(n_fft//2+1)*2] + fft = my_fft(smoothed, n_fft=n_fft) fft = fft.to(torch.float32) print("i6_models", _torch_repr(fft)) @@ -463,6 +464,119 @@ def test_rasr_compatible_fft(): torch.testing.assert_close(fft, rasr_feat, rtol=1e-30, atol=1e-30) +def create_bit_reversal_reordering(size): + """ + Creates a bit reversal reordering tensor for a given size. + """ + # Initial setup + length = size // 2 + reordering = torch.arange(size) + + # Bit reversal reordering logic + j = 1 + for i in range(1, size, 2): + if j > i: + reordering[i - 1] = j - 1 + reordering[i] = j + m = length + while 2 <= m < j: + j -= m + m //= 2 + j += m + + return reordering + + +def bit_reversal_reordering(tensor: torch.Tensor) -> torch.Tensor: + """ + Reorders a given tensor according to bit reversal pattern. + """ + size = tensor.size(-1) + reordering = create_bit_reversal_reordering(size) + return tensor[..., reordering] + + +def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor: + # https://github.com/rwth-i6/rasr/blob/master/src/Math/FastFourierTransform.cc#L95 + size = n_fft + d_pi = torch.tensor(6.28318530717959, dtype=torch.float64) + theta_base = d_pi + + tensor = torch.nn.functional.pad(tensor, (0, size - tensor.shape[-1])) + v = bit_reversal_reordering(tensor) + + cur_length = 2 + # estimate DFFT using Danielson and Lanczos formula + while cur_length < size: + # initialization of trigonometric recurrence + step = cur_length * 2 + theta = theta_base / cur_length + sin_h_theta = torch.sin(0.5 * theta) + wp_r = -2.0 * sin_h_theta * sin_h_theta + wp_i = torch.sin(theta) + w_r = 1.0 + w_i = 0.0 + for m in range(1, cur_length, 2): + for i in range(m, size, step): + # Danielson & Lanczos formula + j = i + cur_length + tempr = w_r * v[..., j - 1] - w_i * v[..., j] + tempi = w_r * v[..., j] + w_i * v[..., j - 1] + v[..., j - 1] = v[..., i - 1] - tempr + v[..., j] = v[..., i] - tempi + v[..., i - 1] += tempr + v[..., i] += tempi + w_temp_r, w_temp_i = w_r, w_i + w_r = w_temp_r * wp_r - w_temp_i * wp_i + w_r + w_i = w_temp_i * wp_r + w_temp_r * wp_i + w_i + cur_length = step + + pi = torch.tensor(3.141592653589793238, dtype=torch.float64) + size_d4 = size >> 2 + theta = pi / (size >> 1) + c = -0.5 + + sin_h_theta = torch.sin(0.5 * theta) + wp_r = -2.0 * sin_h_theta * sin_h_theta + wp_i = torch.sin(theta) + w_r = wp_r + 1 + w_i = wp_i + + for i in range(1, size_d4): + i1 = i + i + i2 = i1 + 1 + i3 = size - i1 + i4 = i3 + 1 + + # separate the two transforms + h1_r = 0.5 * (v[..., i1] + v[..., i3]) + h1_i = 0.5 * (v[..., i2] - v[..., i4]) + h2_r = -c * (v[..., i2] + v[..., i4]) + h2_i = c * (v[..., i1] - v[..., i3]) + + # Calculating the true transform of the original real data + v[..., i1] = h1_r + w_r * h2_r - w_i * h2_i + v[..., i2] = h1_i + w_r * h2_i + w_i * h2_r + v[..., i3] = h1_r - w_r * h2_r + w_i * h2_i + v[..., i4] = -h1_i + w_r * h2_i + w_i * h2_r + + # Updating the trigonometric recurrences + old_wr = w_r + w_r = w_r * wp_r - w_i * wp_i + w_r + w_i = w_i * wp_r + old_wr * wp_i + w_i + + h = v[..., 0].clone() + v[..., 0] = h + v[..., 1] + v[..., 1] = h - v[..., 1] + + # unpack logic + v = torch.nn.functional.pad(v, (0, 2)) + v[..., -2] = v[..., 1] + v[..., 1] = 0 + + return v + + def generate_rasr_feature_cache_from_wav_and_flow( rasr_feature_extractor_bin_path: str, wav_file_path: str, From 34765574728a2c7a0babf913f3d6e8ca6d78be77 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 15:06:52 +0100 Subject: [PATCH 24/30] FFT fixes --- tests/test_feature_extraction.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 1b4d6d06..27393aa6 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -493,7 +493,12 @@ def bit_reversal_reordering(tensor: torch.Tensor) -> torch.Tensor: """ size = tensor.size(-1) reordering = create_bit_reversal_reordering(size) - return tensor[..., reordering] + v = tensor.clone() + for i in range(size): + h = v[..., i].clone() + v[..., i] = v[..., reordering[i]] + v[..., reordering[i]] = h + return v def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor: @@ -527,8 +532,8 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor: v[..., i - 1] += tempr v[..., i] += tempi w_temp_r, w_temp_i = w_r, w_i - w_r = w_temp_r * wp_r - w_temp_i * wp_i + w_r - w_i = w_temp_i * wp_r + w_temp_r * wp_i + w_i + w_r = w_r * wp_r - w_temp_i * wp_i + w_r + w_i = w_i * wp_r + w_temp_r * wp_i + w_i cur_length = step pi = torch.tensor(3.141592653589793238, dtype=torch.float64) From 467f8284d3f92cb6aa54c9e0d6014b79e4f9452b Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 15:15:11 +0100 Subject: [PATCH 25/30] FFT becomes more exact --- tests/test_feature_extraction.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 27393aa6..315a720b 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -450,14 +450,12 @@ def test_rasr_compatible_fft(): n_fft = 2 ** math.ceil(math.log2(win_length)) print(f"win_length={win_length}, n_fft={n_fft}") - smoothed = smoothed.to(torch.float64) # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F] # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F*2] # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F=n_fft//2+1] # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F=(n_fft//2+1)*2] fft = my_fft(smoothed, n_fft=n_fft) - fft = fft.to(torch.float32) print("i6_models", _torch_repr(fft)) @@ -525,8 +523,10 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor: for i in range(m, size, step): # Danielson & Lanczos formula j = i + cur_length - tempr = w_r * v[..., j - 1] - w_i * v[..., j] - tempi = w_r * v[..., j] + w_i * v[..., j - 1] + tempr = w_r * v[..., j - 1].to(torch.float64) - w_i * v[..., j].to(torch.float64) + tempi = w_r * v[..., j].to(torch.float64) + w_i * v[..., j - 1].to(torch.float64) + tempr = tempr.to(torch.float32) + tempi = tempi.to(torch.float32) v[..., j - 1] = v[..., i - 1] - tempr v[..., j] = v[..., i] - tempi v[..., i - 1] += tempr @@ -554,10 +554,10 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor: i4 = i3 + 1 # separate the two transforms - h1_r = 0.5 * (v[..., i1] + v[..., i3]) - h1_i = 0.5 * (v[..., i2] - v[..., i4]) - h2_r = -c * (v[..., i2] + v[..., i4]) - h2_i = c * (v[..., i1] - v[..., i3]) + h1_r = 0.5 * (v[..., i1] + v[..., i3]).to(torch.float64) + h1_i = 0.5 * (v[..., i2] - v[..., i4]).to(torch.float64) + h2_r = (-c * (v[..., i2] + v[..., i4])).to(torch.float64) + h2_i = (c * (v[..., i1] - v[..., i3])).to(torch.float64) # Calculating the true transform of the original real data v[..., i1] = h1_r + w_r * h2_r - w_i * h2_i From 69cd90d9b75ecd664eee3775fd592a1e9c0af160 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 15:25:02 +0100 Subject: [PATCH 26/30] FFT cleanup --- tests/test_feature_extraction.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 315a720b..849f2c9e 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -517,8 +517,8 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor: sin_h_theta = torch.sin(0.5 * theta) wp_r = -2.0 * sin_h_theta * sin_h_theta wp_i = torch.sin(theta) - w_r = 1.0 - w_i = 0.0 + w_r = torch.tensor(1.0, dtype=torch.float64) + w_i = torch.tensor(0.0, dtype=torch.float64) for m in range(1, cur_length, 2): for i in range(m, size, step): # Danielson & Lanczos formula @@ -539,7 +539,6 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor: pi = torch.tensor(3.141592653589793238, dtype=torch.float64) size_d4 = size >> 2 theta = pi / (size >> 1) - c = -0.5 sin_h_theta = torch.sin(0.5 * theta) wp_r = -2.0 * sin_h_theta * sin_h_theta @@ -556,8 +555,8 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor: # separate the two transforms h1_r = 0.5 * (v[..., i1] + v[..., i3]).to(torch.float64) h1_i = 0.5 * (v[..., i2] - v[..., i4]).to(torch.float64) - h2_r = (-c * (v[..., i2] + v[..., i4])).to(torch.float64) - h2_i = (c * (v[..., i1] - v[..., i3])).to(torch.float64) + h2_r = 0.5 * (v[..., i2] + v[..., i4]).to(torch.float64) + h2_i = -0.5 * (v[..., i1] - v[..., i3]).to(torch.float64) # Calculating the true transform of the original real data v[..., i1] = h1_r + w_r * h2_r - w_i * h2_i From aca1eb3617dcb174e145359664edb3c31bf0ea7b Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 14 Feb 2024 15:32:17 +0100 Subject: [PATCH 27/30] test_rasr_compatible_amplitude_spectrum (failing) --- tests/test_feature_extraction.py | 97 +++++++++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 1 deletion(-) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 849f2c9e..56f26f46 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -459,7 +459,102 @@ def test_rasr_compatible_fft(): print("i6_models", _torch_repr(fft)) - torch.testing.assert_close(fft, rasr_feat, rtol=1e-30, atol=1e-30) + torch.testing.assert_close(fft, rasr_feat, rtol=1e-6, atol=1e-2) # values are huge, accept some errors...? + + +def test_rasr_compatible_amplitude_spectrum(): + try: + from i6_core.lib.rasr_cache import FileArchive + except ImportError: + raise unittest.SkipTest("i6_core not available") + try: + import soundfile + except ImportError: + raise unittest.SkipTest("soundfile not available") + rasr_feature_extractor_bin_path = ( + "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/" + "feature-extraction.linux-x86_64-standard" + ) + if not os.path.exists(rasr_feature_extractor_bin_path): + raise unittest.SkipTest("RASR feature-extraction binary not found") + + torch.manual_seed(42) + wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio") + atexit.register(os.remove, wav_file_path) + generate_random_speech_like_audio_wav(wav_file_path) + rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow( + rasr_feature_extractor_bin_path, + wav_file_path, + textwrap.dedent( + f"""\ + + + + + + + + + + + + + + + """ + ), + flow_output_name="amplitude-spectrum", + ) + + rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True) + time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat") + assert len(time_) == len(rasr_feat) + rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32) + print("RASR:", _torch_repr(rasr_feat)) + + audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16") + audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1] + + # preemphasize + audio[..., 1:] -= 1.0 * audio[..., :-1] + audio[..., 0] = 0.0 + + # windowing + win_size = 0.025 + hop_size = 0.01 + hop_length = int(hop_size * sample_rate) + win_length = int(win_size * sample_rate) + + res_len = max(audio.shape[0] - win_length + hop_length - 1, 0) // hop_length + 1 + assert res_len == len(rasr_feat) + + last_win_size = audio.shape[0] - (res_len - 1) * hop_length + last_pad = win_length - last_win_size + padded = torch.nn.functional.pad(audio, (0, last_pad)) # zero pad for the last frame + + windowed = padded.unfold(0, size=win_length, step=hop_length) # [T', W=win_length] + assert len(windowed) == res_len + window = torch.hann_window(win_length, periodic=False, dtype=torch.float64).to(torch.float32) + smoothed = windowed[:-1] * window[None, :] # [T'-1, W] + + # The last window might be shorter. Will use a shorter Hanning window then. Need to fix that. + last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(torch.float32) + last_win = torch.nn.functional.pad(last_win, (0, last_pad)) + smoothed = torch.cat([smoothed, (windowed[-1] * last_win)[None, :]], dim=0) + + n_fft = 2 ** math.ceil(math.log2(win_length)) + print(f"win_length={win_length}, n_fft={n_fft}") + + # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F] + # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F*2] + # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F=n_fft//2+1] + # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F=(n_fft//2+1)*2] + fft = my_fft(smoothed, n_fft=n_fft) + amplitude_spectrum = torch.abs(torch.view_as_complex(fft.unflatten(-1, (-1, 2)))) + + print("i6_models", _torch_repr(amplitude_spectrum)) + + torch.testing.assert_close(amplitude_spectrum, rasr_feat, rtol=1e-30, atol=1e-30) def create_bit_reversal_reordering(size): From 3ead58feaad02bbed2cf32b98c6eaa98af6fa8bb Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Tue, 12 Mar 2024 09:46:07 -0400 Subject: [PATCH 28/30] add fft scaling, updata test and remove spaces --- i6_models/primitives/feature_extraction.py | 7 +++++-- tests/test_feature_extraction.py | 18 ++++++++---------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index de7a5163..e1528e81 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -152,6 +152,8 @@ class RasrCompatibleLogMelFeatureExtractionV1(nn.Module): def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config): super().__init__() + + self.sample_rate = int(cfg.sample_rate) self.hop_length = int(cfg.hop_size * cfg.sample_rate) self.min_amp = cfg.min_amp self.win_length = int(cfg.win_size * cfg.sample_rate) @@ -211,8 +213,9 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: last_win = torch.nn.functional.pad(last_win, (0, last_pad)) # [W] smoothed = torch.cat([smoothed, (windowed[:, -1] * last_win[None, :])[:, None]], dim=1) # [B, T', W] - # compute amplitude spectrum using torch.fft.rfftn - amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1] + # compute amplitude spectrum using torch.fft.rfftn with Rasr specific scaling + fft = torch.fft.rfftn(smoothed, s=self.n_fft) / self.sample_rate # [B, T', F=n_fft//2+1] + amplitude_spectrum = torch.abs(fft) # [B, T', F=n_fft//2+1] melspec = torch.einsum("...tf,mf->...tm", amplitude_spectrum, self.mel_basis) # [B, T', F'=num_filters] log_melspec = torch.log10(melspec + self.min_amp) diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py index 56f26f46..0bc7a4c2 100644 --- a/tests/test_feature_extraction.py +++ b/tests/test_feature_extraction.py @@ -110,15 +110,13 @@ def test_rasr_compatible(): - - - + - + """ ), flow_output_name="nonlinear", @@ -189,7 +187,7 @@ def test_rasr_compatible_raw_audio_samples(): - + """ ), flow_output_name="convert", @@ -236,7 +234,7 @@ def test_rasr_compatible_preemphasis(): - + """ @@ -287,7 +285,7 @@ def test_rasr_compatible_window(): - + @@ -398,7 +396,7 @@ def test_rasr_compatible_fft(): - + @@ -490,7 +488,7 @@ def test_rasr_compatible_amplitude_spectrum(): - + @@ -742,7 +740,7 @@ def generate_rasr_feature_cache_from_wav_and_flow( f"""\ [*.corpus] file = {corpus_xml_path} - + [*.feature-extraction] file = {rasr_flow_xml_path} """ From 1af10fc6fa4298638acbb1d48659b165e11f4ba2 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Tue, 12 Mar 2024 09:50:06 -0400 Subject: [PATCH 29/30] black --- i6_models/primitives/feature_extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index e1528e81..ac9022ab 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -214,7 +214,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: smoothed = torch.cat([smoothed, (windowed[:, -1] * last_win[None, :])[:, None]], dim=1) # [B, T', W] # compute amplitude spectrum using torch.fft.rfftn with Rasr specific scaling - fft = torch.fft.rfftn(smoothed, s=self.n_fft) / self.sample_rate # [B, T', F=n_fft//2+1] + fft = torch.fft.rfftn(smoothed, s=self.n_fft) / self.sample_rate # [B, T', F=n_fft//2+1] amplitude_spectrum = torch.abs(fft) # [B, T', F=n_fft//2+1] melspec = torch.einsum("...tf,mf->...tm", amplitude_spectrum, self.mel_basis) # [B, T', F'=num_filters] From fff39f0d56300edc49bf3c67ab32e543bfb8260a Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Mon, 18 Mar 2024 08:42:03 -0400 Subject: [PATCH 30/30] adjust last window for different sequence lengths --- i6_models/primitives/feature_extraction.py | 27 +++++++++++++--------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index ac9022ab..d791f776 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -198,20 +198,25 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1] preemphasized[..., 0] = 0.0 - # zero pad for the last frame - last_win_size = preemphasized.shape[1] - (res_size - 1) * self.hop_length - last_pad = self.win_length - last_win_size - padded = torch.nn.functional.pad(preemphasized, (0, last_pad)) + # zero pad for the last frame of each sequence in the batch + last_win_size = length - (res_length - 1) * self.hop_length # [B] + last_pad = self.win_length - last_win_size # [B] - windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] - smoothed = windowed[:, :-1] * self.window[None, None, :] # [B, T'-1, W] + # zero pad for the whole batch + last_pad_batch = self.win_length - (preemphasized.shape[1] - (res_size - 1) * self.hop_length) + padded = torch.nn.functional.pad(preemphasized, (0, last_pad_batch)) + + windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=self.win_length] + + smoothed = windowed * self.window[None, None, :] # [B, T', W] # The last window might be shorter. Will use a shorter Hanning window then. Need to fix that. - last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to( - self.window.device, torch.float32 - ) - last_win = torch.nn.functional.pad(last_win, (0, last_pad)) # [W] - smoothed = torch.cat([smoothed, (windowed[:, -1] * last_win[None, :])[:, None]], dim=1) # [B, T', W] + for i, (last_w_size, last_p, res_l) in enumerate(zip(last_win_size, last_pad, res_length)): + last_win = torch.hann_window(last_w_size, periodic=False, dtype=torch.float64).to( + self.window.device, torch.float32 + ) + last_win = torch.nn.functional.pad(last_win, (0, last_p)) # [W] + smoothed[i, res_l - 1] = windowed[i, res_l - 1] * last_win[None, :] # compute amplitude spectrum using torch.fft.rfftn with Rasr specific scaling fft = torch.fft.rfftn(smoothed, s=self.n_fft) / self.sample_rate # [B, T', F=n_fft//2+1]