From 0f5a148a26db6fe24a9e668174d7087be989da96 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Fri, 15 Dec 2023 06:06:36 -0500
Subject: [PATCH 01/30] add rasr compatible feature extraction
---
i6_models/primitives/feature_extraction.py | 91 +++++++++++++++++++++-
1 file changed, 90 insertions(+), 1 deletion(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index ead52dd5..5cdcc6ed 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -1,5 +1,11 @@
-__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]
+__all__ = [
+ "LogMelFeatureExtractionV1",
+ "LogMelFeatureExtractionV1Config",
+ "RasrCompatibleLogMelFeatureExtractionV1",
+ "RasrCompatibleLogMelFeatureExtractionV1Config",
+]
+import math
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -111,3 +117,86 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
length = ((length - self.n_fft) // self.hop_length) + 1
return feature_data, length.int()
+
+
+@dataclass
+class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration):
+ """
+ Attributes:
+ sample_rate: audio sample rate in Hz
+ win_size: window size in seconds
+ hop_size: window shift in seconds
+ f_min: minimum filter frequency in Hz
+ f_max: maximum filter frequency in Hz
+ min_amp: minimum amplitude for safe log
+ num_filters: number of mel windows
+ """
+
+ sample_rate: int
+ win_size: float
+ hop_size: float
+ f_min: int
+ f_max: int
+ min_amp: float
+ num_filters: int
+
+ def __post_init__(self) -> None:
+ super().__post_init__()
+ assert self.f_max <= self.sample_rate // 2, "f_max can not be larger than half of the sample rate"
+ assert self.f_min >= 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive"
+ assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive"
+ assert self.num_filters > 0, "number of filters needs to be positive"
+ assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense"
+
+
+class RasrCompatibleLogMelFeatureExtractionV1(nn.Module):
+ """
+ Rasr-compatible log-mel feature extraction using log10. Does not use torchaudio.
+
+ Using it wrapped with torch.no_grad() is recommended if no gradient is needed
+ """
+
+ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config):
+ super().__init__()
+ self.hop_length = int(cfg.hop_size * cfg.sample_rate)
+ self.min_amp = cfg.min_amp
+ self.win_length = int(cfg.win_size * cfg.sample_rate)
+ # smallest power if two which is greater than or equal to win_length
+ self.n_fft = 2 ** math.ceil(math.log2(self.win_length))
+
+ self.register_buffer(
+ "mel_basis",
+ torch.tensor(
+ filters.mel(
+ sr=cfg.sample_rate,
+ n_fft=cfg.n_fft,
+ n_mels=cfg.num_filters,
+ fmin=cfg.f_min,
+ fmax=cfg.f_max,
+ htk=True,
+ norm=None,
+ ),
+ ),
+ )
+ self.register_buffer("window", torch.hann_window(self.win_length, periodic=False))
+
+ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ :param raw_audio: [B, T]
+ :param length in samples: [B]
+ :return features as [B,T,F] and length in frames [B]
+ """
+ windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
+ smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]
+
+ # Compute power spectrum using torch.fft.rfftn
+ power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
+ power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
+
+ melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) # [B, F'=num_filters, T']
+ log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
+ feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']
+
+ length = ((length - self.win_length) // self.hop_length) + 1
+
+ return feature_data, length.int()
From e8c5dde2d738d26344c1497c99c16affe453b6b5 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Fri, 15 Dec 2023 06:17:29 -0500
Subject: [PATCH 02/30] remove f_min and f_max from config
---
i6_models/primitives/feature_extraction.py | 10 ++--------
1 file changed, 2 insertions(+), 8 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 5cdcc6ed..02706574 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -126,8 +126,6 @@ class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration):
sample_rate: audio sample rate in Hz
win_size: window size in seconds
hop_size: window shift in seconds
- f_min: minimum filter frequency in Hz
- f_max: maximum filter frequency in Hz
min_amp: minimum amplitude for safe log
num_filters: number of mel windows
"""
@@ -135,15 +133,11 @@ class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration):
sample_rate: int
win_size: float
hop_size: float
- f_min: int
- f_max: int
min_amp: float
num_filters: int
def __post_init__(self) -> None:
super().__post_init__()
- assert self.f_max <= self.sample_rate // 2, "f_max can not be larger than half of the sample rate"
- assert self.f_min >= 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive"
assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive"
assert self.num_filters > 0, "number of filters needs to be positive"
assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense"
@@ -171,8 +165,8 @@ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config):
sr=cfg.sample_rate,
n_fft=cfg.n_fft,
n_mels=cfg.num_filters,
- fmin=cfg.f_min,
- fmax=cfg.f_max,
+ fmin=0,
+ fmax=cfg.sample_rate // 2,
htk=True,
norm=None,
),
From 5fa4bf3abba33659cc8731c2d0e5b2702486d787 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Wed, 3 Jan 2024 12:19:00 -0500
Subject: [PATCH 03/30] fix: pass self.n_fft instead of cfg.n_fft to filters.mel
---
i6_models/primitives/feature_extraction.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 02706574..34bbfde4 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -163,7 +163,7 @@ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config):
torch.tensor(
filters.mel(
sr=cfg.sample_rate,
- n_fft=cfg.n_fft,
+ n_fft=self.n_fft,
n_mels=cfg.num_filters,
fmin=0,
fmax=cfg.sample_rate // 2,
From 33dd0473d1fd2aba581ce3ae4cf0bbf108389a16 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Mon, 29 Jan 2024 12:40:16 -0500
Subject: [PATCH 04/30] add preemphasis, use amplitude instead of power
spectrum, additive log offset
---
i6_models/primitives/feature_extraction.py | 23 +++++++++++++---------
1 file changed, 14 insertions(+), 9 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 34bbfde4..0bbfc40c 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -146,8 +146,6 @@ def __post_init__(self) -> None:
class RasrCompatibleLogMelFeatureExtractionV1(nn.Module):
"""
Rasr-compatible log-mel feature extraction using log10. Does not use torchaudio.
-
- Using it wrapped with torch.no_grad() is recommended if no gradient is needed
"""
def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config):
@@ -180,17 +178,24 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
:param length in samples: [B]
:return features as [B,T,F] and length in frames [B]
"""
- windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
+ # preemphasis
+ preemphasized = raw_audio.clone()
+ preemphasized[..., 1:] -= 1.0 * preemphasized[..., :-1]
+
+ # zero pad for the last frame
+ padded = torch.cat([preemphasized, torch.zeros(preemphasized.shape[0], (self.hop_length - 1))], dim=1)
+
+ windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]
- # Compute power spectrum using torch.fft.rfftn
- power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
- power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
+ # Compute amplitude spectrum using torch.fft.rfftn
+ amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1]
+ amplitude_spectrum = amplitude_spectrum.transpose(1, 2) # [B, F, T']
- melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) # [B, F'=num_filters, T']
- log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
+ melspec = torch.einsum("...ft,mf->...mt", amplitude_spectrum, self.mel_basis) # [B, F'=num_filters, T']
+ log_melspec = torch.log10(melspec + self.min_amp)
feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']
- length = ((length - self.win_length) // self.hop_length) + 1
+ length = math.ceil((length - self.win_length) / self.hop_length) + 1
return feature_data, length.int()
From 3071b44cad6be9455f95b6f3070202b363c066db Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Mon, 29 Jan 2024 12:42:12 -0500
Subject: [PATCH 05/30] small change: lowercase the rfftn comment
---
i6_models/primitives/feature_extraction.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 0bbfc40c..383112f3 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -188,7 +188,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]
- # Compute amplitude spectrum using torch.fft.rfftn
+ # compute amplitude spectrum using torch.fft.rfftn
amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1]
amplitude_spectrum = amplitude_spectrum.transpose(1, 2) # [B, F, T']
From 687854ad109ae60dac28dbb6b2f54ec631888063 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Tue, 30 Jan 2024 06:12:21 -0500
Subject: [PATCH 06/30] make alpha a parameter
---
i6_models/primitives/feature_extraction.py | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 383112f3..d6e4ad9d 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -128,6 +128,7 @@ class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration):
hop_size: window shift in seconds
min_amp: minimum amplitude for safe log
num_filters: number of mel windows
+ alpha: preemphasis weight
"""
sample_rate: int
@@ -135,6 +136,7 @@ class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration):
hop_size: float
min_amp: float
num_filters: int
+ alpha: float = 1.0
def __post_init__(self) -> None:
super().__post_init__()
@@ -153,8 +155,10 @@ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config):
self.hop_length = int(cfg.hop_size * cfg.sample_rate)
self.min_amp = cfg.min_amp
self.win_length = int(cfg.win_size * cfg.sample_rate)
- # smallest power if two which is greater than or equal to win_length
- self.n_fft = 2 ** math.ceil(math.log2(self.win_length))
+ self.n_fft = 2 ** math.ceil(
+ math.log2(self.win_length)
+ ) # smallest power if two which is greater than or equal to win_length
+ self.alpha = cfg.alpha
self.register_buffer(
"mel_basis",
@@ -178,9 +182,9 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
:param length in samples: [B]
:return features as [B,T,F] and length in frames [B]
"""
- # preemphasis
+ # preemphasize
preemphasized = raw_audio.clone()
- preemphasized[..., 1:] -= 1.0 * preemphasized[..., :-1]
+ preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1]
# zero pad for the last frame
padded = torch.cat([preemphasized, torch.zeros(preemphasized.shape[0], (self.hop_length - 1))], dim=1)
From e7c850d0d38be8eecb57634c373a1643c9293974 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Thu, 1 Feb 2024 06:48:26 -0500
Subject: [PATCH 07/30] fix errors: padding device, torch.ceil for length
---
i6_models/primitives/feature_extraction.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index d6e4ad9d..f87d6ab9 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -187,7 +187,10 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1]
# zero pad for the last frame
- padded = torch.cat([preemphasized, torch.zeros(preemphasized.shape[0], (self.hop_length - 1))], dim=1)
+ padded = torch.cat(
+ [preemphasized, torch.zeros(preemphasized.shape[0], (self.hop_length - 1), device=preemphasized.device)],
+ dim=1,
+ )
windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]
@@ -200,6 +203,6 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
log_melspec = torch.log10(melspec + self.min_amp)
feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']
- length = math.ceil((length - self.win_length) / self.hop_length) + 1
+ length = torch.ceil((length - self.win_length) / self.hop_length) + 1
return feature_data, length.int()
From 482f5609361dde4124d9d3d9d7af669f59847a5f Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 13 Feb 2024 13:58:03 +0100
Subject: [PATCH 08/30] fix window broadcasting
---
i6_models/primitives/feature_extraction.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index f87d6ab9..0e6274f6 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -193,7 +193,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
)
windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
- smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]
+ smoothed = windowed * self.window[None, None, :] # [B, T', W]
# compute amplitude spectrum using torch.fft.rfftn
amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1]
From 9e0cde32ff60e378971168ad580109f17ed9ee78 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 13 Feb 2024 14:14:04 +0100
Subject: [PATCH 09/30] test_rasr_compatible
---
tests/test_feature_extraction.py | 63 +++++++++++++++++++++++++++++++-
1 file changed, 62 insertions(+), 1 deletion(-)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 9e7f241f..815f4a21 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -1,10 +1,12 @@
+import os
import copy
import numpy as np
import torch
+import unittest
from librosa.feature import melspectrogram
-from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1, LogMelFeatureExtractionV1Config
+from i6_models.primitives.feature_extraction import *
def test_logmel_librosa_compatibility():
@@ -65,3 +67,62 @@ def test_logmel_length():
assert torch.all(mel_center.size()[1] == length_center)
mel_no_center, length_no_center = fe_no_center(audio, audio_length)
assert torch.all(mel_no_center.size()[1] == length_no_center)
+
+
+def test_rasr_compatible():
+ try:
+ from i6_core.lib.rasr_cache import FileArchive
+ except ImportError:
+ raise unittest.SkipTest("i6_core not available")
+ try:
+ import soundfile
+ except ImportError:
+ raise unittest.SkipTest("soundfile not available")
+ if not os.path.exists("test_data/features.cache") or not os.path.exists("test_data/103-1240-0000.wav"):
+ raise unittest.SkipTest("test data not available")
+
+ def _torch_repr(x: torch.Tensor) -> str:
+ try:
+ from lovely_tensors import lovely
+ except ImportError:
+ mean, std = x.mean(), x.std()
+ min_, max_ = x.min(), x.max()
+ return f"{x.shape} x∈[{min_}, {max_}] μ={mean} σ={std} {x.dtype}"
+ else:
+ return lovely(x)
+
+ rasr_cache = FileArchive("test_data/features.cache", must_exists=True)
+ print(rasr_cache.file_list())
+ print(rasr_cache.read("corpus/103-1240-0000/1.attribs", "str"))
+ time_, rasr_feat = rasr_cache.read("corpus/103-1240-0000/1", "feat")
+ assert len(time_) == len(rasr_feat)
+ print("RASR feature len:", len(rasr_feat), "frame 0 times:", time_[0], "frame 0 shape:", rasr_feat[0].shape)
+ rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32)
+ print("RASR feature shape:", rasr_feat.shape)
+
+ cfg = RasrCompatibleLogMelFeatureExtractionV1Config(
+ sample_rate=16_000,
+ win_size=0.025,
+ hop_size=0.01,
+ min_amp=1.175494e-38,
+ num_filters=80,
+ )
+ feat_extractor = RasrCompatibleLogMelFeatureExtractionV1(cfg)
+
+ # int16 audio is in [2**15, 2**15-1].
+ # This is how BlissToPcmHDFJob does it by default:
+ # https://github.com/rwth-i6/i6_core/blob/add09a8b640a2ba5928b815fa65f7504242be038/returnn/hdf.py#L207
+ # This is also how our standard RASR flow handles it:
+ # https://github.com/rwth-i6/i6_models/pull/44#issuecomment-1938264642
+ audio, sample_rate = soundfile.read(open("test_data/103-1240-0000.wav", "rb"), dtype="int16")
+ assert sample_rate == cfg.sample_rate
+ audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1]
+ print("raw audio", _torch_repr(audio))
+
+ i6m_feat, _ = feat_extractor(audio.unsqueeze(0), torch.tensor([len(audio)]))
+ i6m_feat = i6m_feat.squeeze(0)
+
+ print("i6_models:", _torch_repr(i6m_feat))
+ print("RASR:", _torch_repr(rasr_feat))
+
+ torch.testing.assert_allclose(i6m_feat, rasr_feat, rtol=1e-5, atol=1e-5)
From f45f01e36233c05bb9c34a16a0652d38f058715b Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 13 Feb 2024 17:22:07 +0100
Subject: [PATCH 10/30] test_rasr_compatible more
---
tests/test_feature_extraction.py | 273 +++++++++++++++++++++++++++++--
1 file changed, 256 insertions(+), 17 deletions(-)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 815f4a21..8e8f56c2 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -1,9 +1,14 @@
import os
+import sys
import copy
import numpy as np
import torch
import unittest
+import tempfile
+import atexit
+import textwrap
+from typing import Optional
from librosa.feature import melspectrogram
from i6_models.primitives.feature_extraction import *
@@ -78,23 +83,48 @@ def test_rasr_compatible():
import soundfile
except ImportError:
raise unittest.SkipTest("soundfile not available")
- if not os.path.exists("test_data/features.cache") or not os.path.exists("test_data/103-1240-0000.wav"):
- raise unittest.SkipTest("test data not available")
-
- def _torch_repr(x: torch.Tensor) -> str:
- try:
- from lovely_tensors import lovely
- except ImportError:
- mean, std = x.mean(), x.std()
- min_, max_ = x.min(), x.max()
- return f"{x.shape} x∈[{min_}, {max_}] μ={mean} σ={std} {x.dtype}"
- else:
- return lovely(x)
-
- rasr_cache = FileArchive("test_data/features.cache", must_exists=True)
+ rasr_feature_extractor_bin_path = (
+ "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/"
+ "feature-extraction.linux-x86_64-standard"
+ )
+ if not os.path.exists(rasr_feature_extractor_bin_path):
+ raise unittest.SkipTest("RASR feature-extraction binary not found")
+
+ wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio")
+ atexit.register(os.remove, wav_file_path)
+ generate_random_speech_like_audio_wav(wav_file_path)
+ rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow(
+ rasr_feature_extractor_bin_path,
+ wav_file_path,
+ textwrap.dedent(
+ f"""\
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+ ),
+ )
+
+ rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True)
print(rasr_cache.file_list())
- print(rasr_cache.read("corpus/103-1240-0000/1.attribs", "str"))
- time_, rasr_feat = rasr_cache.read("corpus/103-1240-0000/1", "feat")
+ print(rasr_cache.read("corpus/recording/1.attribs", "str"))
+ time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat")
assert len(time_) == len(rasr_feat)
print("RASR feature len:", len(rasr_feat), "frame 0 times:", time_[0], "frame 0 shape:", rasr_feat[0].shape)
rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32)
@@ -114,7 +144,7 @@ def _torch_repr(x: torch.Tensor) -> str:
# https://github.com/rwth-i6/i6_core/blob/add09a8b640a2ba5928b815fa65f7504242be038/returnn/hdf.py#L207
# This is also how our standard RASR flow handles it:
# https://github.com/rwth-i6/i6_models/pull/44#issuecomment-1938264642
- audio, sample_rate = soundfile.read(open("test_data/103-1240-0000.wav", "rb"), dtype="int16")
+ audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16")
assert sample_rate == cfg.sample_rate
audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1]
print("raw audio", _torch_repr(audio))
@@ -126,3 +156,212 @@ def _torch_repr(x: torch.Tensor) -> str:
print("RASR:", _torch_repr(rasr_feat))
torch.testing.assert_allclose(i6m_feat, rasr_feat, rtol=1e-5, atol=1e-5)
+
+
+def generate_rasr_feature_cache_from_wav_and_flow(
+ rasr_feature_extractor_bin_path: str,
+ wav_file_path: str,
+ flow_network_str: str,
+ *,
+ flow_input_name: str = "samples",
+ flow_output_name: str = "nonlinear",
+) -> str:
+ import subprocess
+
+ corpus_xml_path = tempfile.mktemp(suffix=".xml", prefix="tmp-rasr-corpus")
+ atexit.register(os.remove, corpus_xml_path)
+ with open(corpus_xml_path, "w") as f:
+ f.write(
+ textwrap.dedent(
+ f"""\
+
+
+
+
+
+
+
+ """
+ )
+ )
+
+ rasr_feature_cache_path = tempfile.mktemp(suffix=".cache", prefix="tmp-rasr-features")
+ atexit.register(os.remove, rasr_feature_cache_path)
+ rasr_flow_xml_path = tempfile.mktemp(suffix=".config", prefix="tmp-rasr-flow")
+ atexit.register(os.remove, rasr_flow_xml_path)
+ with open(rasr_flow_xml_path, "w") as f:
+ f.write(
+ textwrap.dedent(
+ f"""\
+
+
+
+
+
+
+
+
+
+ """
+ )
+ + textwrap.indent(flow_network_str, " ")
+ + textwrap.dedent(
+ f"""\
+
+
+
+
+ """
+ )
+ )
+
+ rasr_config_path = tempfile.mktemp(suffix=".config", prefix="tmp-rasr-feature-extract")
+ atexit.register(os.remove, rasr_config_path)
+ with open(rasr_config_path, "w") as f:
+ f.write(
+ textwrap.dedent(
+ f"""\
+ [*.corpus]
+ file = {corpus_xml_path}
+
+ [*.feature-extraction]
+ file = {rasr_flow_xml_path}
+ """
+ )
+ )
+
+ subprocess.check_call([rasr_feature_extractor_bin_path, "--config", rasr_config_path])
+ return rasr_feature_cache_path
+
+
+def _get_wav_file_duration_sec(wav_file_path: str) -> float:
+ import wave
+ import contextlib
+
+ with contextlib.closing(wave.open(wav_file_path, "r")) as f:
+ frames = f.getnframes()
+ rate = f.getframerate()
+ return frames / float(rate)
+
+
+def generate_random_speech_like_audio_wav(
+ output_wav_file_path: str,
+ duration_sec: float = 5.0,
+ *,
+ samples_per_sec: int = 16_000,
+ sample_width_bytes: int = 2, # int16
+ frequency: float = 150.0,
+ num_random_freqs_per_sec: int = 15,
+ amplitude: float = 0.3,
+ amplitude_frequency: Optional[float] = None,
+):
+ import wave
+
+ f = wave.open(output_wav_file_path, "wb")
+ f.setframerate(samples_per_sec)
+ f.setnchannels(1)
+ f.setsampwidth(sample_width_bytes)
+
+ samples = generate_random_speech_like_audio(
+ batch_size=1,
+ num_frames=int(duration_sec * samples_per_sec),
+ samples_per_sec=samples_per_sec,
+ frequency=frequency,
+ num_random_freqs_per_sec=num_random_freqs_per_sec,
+ amplitude=amplitude,
+ amplitude_frequency=amplitude_frequency,
+ ) # [B,T]
+ samples = samples[0] # [T]
+ print("generated raw samples:", _torch_repr(samples))
+
+ samples_int = (samples * (2 ** (8 * sample_width_bytes - 1) - 1)).to(
+ {1: torch.int8, 2: torch.int16, 4: torch.int32}[sample_width_bytes]
+ )
+
+ f.writeframes(samples_int.numpy().tobytes())
+ f.close()
+
+
+def generate_random_speech_like_audio(
+ batch_size: int,
+ num_frames: int,
+ *,
+ samples_per_sec: int = 16_000,
+ frequency: float = 150.0,
+ num_random_freqs_per_sec: int = 15,
+ amplitude: float = 0.3,
+ amplitude_frequency: Optional[float] = None,
+) -> torch.Tensor:
+ """
+ generate audio
+
+ Source:
+ https://github.com/albertz/playground/blob/master/create-random-speech-like-sound.py
+
+ :return: shape [batch_size,num_frames]
+ """
+ frame_idxs = torch.arange(num_frames, dtype=torch.int64) # [T]
+
+ samples = _integrate_rnd_frequencies(
+ batch_size,
+ frame_idxs,
+ base_frequency=frequency,
+ samples_per_sec=samples_per_sec,
+ num_random_freqs_per_sec=num_random_freqs_per_sec,
+ ) # [T,B]
+
+ if amplitude_frequency is None:
+ amplitude_frequency = frequency / 75.0
+ amplitude_variations = _integrate_rnd_frequencies(
+ batch_size,
+ frame_idxs,
+ base_frequency=amplitude_frequency,
+ samples_per_sec=samples_per_sec,
+ num_random_freqs_per_sec=amplitude_frequency,
+ ) # [T,B]
+
+ samples *= amplitude * (0.666 + 0.333 * amplitude_variations)
+ return samples.permute(1, 0) # [B,T]
+
+
+def _integrate_rnd_frequencies(
+ batch_size: int,
+ frame_idxs: torch.Tensor,
+ *,
+ base_frequency: float,
+ samples_per_sec: int,
+ num_random_freqs_per_sec: float,
+) -> torch.Tensor:
+ rnd_freqs = torch.empty(
+ size=(int(len(frame_idxs) * num_random_freqs_per_sec / samples_per_sec) + 1, batch_size),
+ dtype=torch.float32,
+ ) # [T',B]
+ torch.nn.init.trunc_normal_(rnd_freqs, a=-1.0, b=1.0)
+ rnd_freqs = (rnd_freqs * 0.5 + 1.0) * base_frequency # [T',B]
+
+ freq_idx_f = (frame_idxs * num_random_freqs_per_sec) / samples_per_sec
+ freq_idx = freq_idx_f.to(torch.int64)
+ next_freq_idx = torch.clip(freq_idx + 1, 0, len(rnd_freqs) - 1)
+ frac = (freq_idx_f % 1)[:, None] # [T,1]
+ freq = rnd_freqs[freq_idx] * (1 - frac) + rnd_freqs[next_freq_idx] * frac # [T,B]
+
+ ts = torch.cumsum(freq / samples_per_sec, dim=0) # [T,B]
+ return torch.sin(2 * torch.pi * ts)
+
+
+def _torch_repr(x: torch.Tensor) -> str:
+ try:
+ from lovely_tensors import lovely
+ except ImportError:
+ mean, std = x.mean(), x.std()
+ min_, max_ = x.min(), x.max()
+ return f"{x.shape} x∈[{min_}, {max_}] μ={mean} σ={std} {x.dtype}"
+ else:
+ return lovely(x)
+
+
+if __name__ == "__main__":
+ for arg in sys.argv[1:]:
+ print(f"*** {arg}()")
+ globals()[arg]()
From 71f9e6a95693c3afdfe9f3c97dcca3915c9f5fcc Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 13 Feb 2024 17:48:25 +0100
Subject: [PATCH 11/30] test_rasr_compatible_raw_audio_samples (passing)
---
tests/test_feature_extraction.py | 49 +++++++++++++++++++++++++++++++-
1 file changed, 48 insertions(+), 1 deletion(-)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 8e8f56c2..205c0b3d 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -119,6 +119,7 @@ def test_rasr_compatible():
"""
),
+ flow_output_name="nonlinear",
)
rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True)
@@ -155,7 +156,53 @@ def test_rasr_compatible():
print("i6_models:", _torch_repr(i6m_feat))
print("RASR:", _torch_repr(rasr_feat))
- torch.testing.assert_allclose(i6m_feat, rasr_feat, rtol=1e-5, atol=1e-5)
+ torch.testing.assert_close(i6m_feat, rasr_feat, rtol=1e-5, atol=1e-5)
+
+
+def test_rasr_compatible_raw_audio_samples():
+ try:
+ from i6_core.lib.rasr_cache import FileArchive
+ except ImportError:
+ raise unittest.SkipTest("i6_core not available")
+ try:
+ import soundfile
+ except ImportError:
+ raise unittest.SkipTest("soundfile not available")
+ rasr_feature_extractor_bin_path = (
+ "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/"
+ "feature-extraction.linux-x86_64-standard"
+ )
+ if not os.path.exists(rasr_feature_extractor_bin_path):
+ raise unittest.SkipTest("RASR feature-extraction binary not found")
+
+ wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio")
+ atexit.register(os.remove, wav_file_path)
+ generate_random_speech_like_audio_wav(wav_file_path)
+ rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow(
+ rasr_feature_extractor_bin_path,
+ wav_file_path,
+ textwrap.dedent(
+ f"""\
+
+
+
+
+ """
+ ),
+ flow_output_name="convert",
+ )
+
+ rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True)
+ time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat")
+ assert len(time_) == len(rasr_feat)
+ rasr_feat = torch.tensor(np.concatenate(rasr_feat, axis=0), dtype=torch.float32)
+ print("RASR:", _torch_repr(rasr_feat))
+
+ audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16")
+ audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1]
+ print("raw audio", _torch_repr(audio))
+
+ torch.testing.assert_close(audio, rasr_feat, rtol=1e-30, atol=1e-30)
def generate_rasr_feature_cache_from_wav_and_flow(
From 259d7f332aeec0dd41366ce6d2d97304d6642032 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 13 Feb 2024 17:48:41 +0100
Subject: [PATCH 12/30] test_rasr_compatible_preemphasis (failing)
---
tests/test_feature_extraction.py | 49 ++++++++++++++++++++++++++++++++
1 file changed, 49 insertions(+)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 205c0b3d..6332daf5 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -205,6 +205,55 @@ def test_rasr_compatible_raw_audio_samples():
torch.testing.assert_close(audio, rasr_feat, rtol=1e-30, atol=1e-30)
+def test_rasr_compatible_preemphasis():
+ try:
+ from i6_core.lib.rasr_cache import FileArchive
+ except ImportError:
+ raise unittest.SkipTest("i6_core not available")
+ try:
+ import soundfile
+ except ImportError:
+ raise unittest.SkipTest("soundfile not available")
+ rasr_feature_extractor_bin_path = (
+ "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/"
+ "feature-extraction.linux-x86_64-standard"
+ )
+ if not os.path.exists(rasr_feature_extractor_bin_path):
+ raise unittest.SkipTest("RASR feature-extraction binary not found")
+
+ wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio")
+ atexit.register(os.remove, wav_file_path)
+ generate_random_speech_like_audio_wav(wav_file_path)
+ rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow(
+ rasr_feature_extractor_bin_path,
+ wav_file_path,
+ textwrap.dedent(
+ f"""\
+
+
+
+
+
+
+ """
+ ),
+ flow_output_name="preemphasis",
+ )
+
+ rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True)
+ time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat")
+ assert len(time_) == len(rasr_feat)
+ rasr_feat = torch.tensor(np.concatenate(rasr_feat, axis=0), dtype=torch.float32)
+ print("RASR:", _torch_repr(rasr_feat))
+
+ audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16")
+ audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1]
+ audio[..., 1:] -= 1.0 * audio[..., :-1]
+ print("i6_models", _torch_repr(audio))
+
+ torch.testing.assert_close(audio, rasr_feat, rtol=1e-30, atol=1e-30)
+
+
def generate_rasr_feature_cache_from_wav_and_flow(
rasr_feature_extractor_bin_path: str,
wav_file_path: str,
From f41a60c9ccfa26422da29f91769fd9083efe2c41 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 13 Feb 2024 17:54:21 +0100
Subject: [PATCH 13/30] fix preemphasize
---
i6_models/primitives/feature_extraction.py | 1 +
tests/test_feature_extraction.py | 1 +
2 files changed, 2 insertions(+)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 0e6274f6..ff0d790e 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -185,6 +185,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
# preemphasize
preemphasized = raw_audio.clone()
preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1]
+ preemphasized[..., 0] = 0.0
# zero pad for the last frame
padded = torch.cat(
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 6332daf5..e4914c0c 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -249,6 +249,7 @@ def test_rasr_compatible_preemphasis():
audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16")
audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1]
audio[..., 1:] -= 1.0 * audio[..., :-1]
+ audio[..., 0] = 0.0
print("i6_models", _torch_repr(audio))
torch.testing.assert_close(audio, rasr_feat, rtol=1e-30, atol=1e-30)
From acefd99e4f93378f0ff6b5d108b9caf724085012 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 13 Feb 2024 18:04:55 +0100
Subject: [PATCH 14/30] test_rasr_compatible_window (failing)
---
tests/test_feature_extraction.py | 70 ++++++++++++++++++++++++++++++++
1 file changed, 70 insertions(+)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index e4914c0c..3dea7d4f 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -255,6 +255,76 @@ def test_rasr_compatible_preemphasis():
torch.testing.assert_close(audio, rasr_feat, rtol=1e-30, atol=1e-30)
+def test_rasr_compatible_window():  # compares pre-emphasis + Hann windowing, frame by frame, against the RASR "window" flow output
+ try:
+ from i6_core.lib.rasr_cache import FileArchive
+ except ImportError:
+ raise unittest.SkipTest("i6_core not available")
+ try:
+ import soundfile
+ except ImportError:
+ raise unittest.SkipTest("soundfile not available")
+ rasr_feature_extractor_bin_path = (
+ "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/"
+ "feature-extraction.linux-x86_64-standard"
+ )
+ if not os.path.exists(rasr_feature_extractor_bin_path):
+ raise unittest.SkipTest("RASR feature-extraction binary not found")
+
+ wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio")
+ atexit.register(os.remove, wav_file_path)
+ generate_random_speech_like_audio_wav(wav_file_path)
+ rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow(
+ rasr_feature_extractor_bin_path,
+ wav_file_path,
+ textwrap.dedent(
+ f"""\
+
+
+
+
+
+
+
+
+ """
+ ),
+ flow_output_name="window",
+ )
+
+ rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True)
+ time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat")
+ assert len(time_) == len(rasr_feat)
+ rasr_feat[-1] = np.pad(rasr_feat[-1], (0, rasr_feat[0].shape[0] - rasr_feat[-1].shape[0]))
+ rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32)
+ print("RASR:", _torch_repr(rasr_feat))
+
+ audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16")  # NOTE(review): the file object returned by open() is never closed explicitly
+ audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1]
+
+ # preemphasize: first-order difference with coefficient 1.0 (y[t] = x[t] - x[t-1]); sample 0 is then set to 0
+ audio[..., 1:] -= 1.0 * audio[..., :-1]
+ audio[..., 0] = 0.0
+
+ # windowing: sizes are given in seconds (25 ms window, 10 ms shift) and converted to sample counts below
+ win_size = 0.025
+ hop_size = 0.01
+ hop_length = int(hop_size * sample_rate)
+ win_length = int(win_size * sample_rate)
+ padded = torch.cat( # zero pad for the last frame
+ [audio, torch.zeros((hop_length - 1), device=audio.device)],
+ dim=0,
+ )
+
+ windowed = padded.unfold(0, size=win_length, step=hop_length) # [T', W=win_length]
+ window = torch.hann_window(win_length, periodic=False)
+ smoothed = windowed * window[None, :] # [T', W]
+
+ print("i6_models", _torch_repr(smoothed))
+
+ torch.testing.assert_close(smoothed, rasr_feat, rtol=1e-30, atol=1e-30)
+
+
def generate_rasr_feature_cache_from_wav_and_flow(
rasr_feature_extractor_bin_path: str,
wav_file_path: str,
From ba47a7864c8c76c7c6b0be841f55f541c650cfb5 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 13 Feb 2024 23:52:45 +0100
Subject: [PATCH 15/30] testing custom hanning window implementations
---
tests/test_feature_extraction.py | 31 +++++++++++++++++++++++++++++++
1 file changed, 31 insertions(+)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 3dea7d4f..79bdff46 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -311,6 +311,37 @@ def test_rasr_compatible_window():
hop_size = 0.01
hop_length = int(hop_size * sample_rate)
win_length = int(win_size * sample_rate)
+
+ # https://github.com/rwth-i6/rasr/blob/master/src/Signal/WindowFunction.cc
+ # https://pytorch.org/docs/stable/generated/torch.hann_window.html
+
+ def _my_hann_win_np(size: int) -> torch.Tensor:
+ M_PI = 3.14159265358979323846264338327950288
+ M = size - 1
+ n = np.arange(0, size // 2) # [size//2]
+ win = 0.5 - 0.5 * np.cos(2.0 * M_PI * n / M) # [size//2]
+ win = torch.tensor(win, dtype=torch.float32)
+ win = torch.concatenate([win, win.flip(0)[size % 2 :]]) # [size]
+ return win
+
+ def _my_hann_win(size: int) -> torch.Tensor:
+ M_PI = 3.14159265358979323846264338327950288
+ M = size - 1
+ n = torch.arange(0, size // 2, dtype=torch.float64) # [size//2]
+ win = 0.5 - 0.5 * torch.cos(2.0 * M_PI * n / M) # [size//2]
+ win = win.to(torch.float32)
+ win = torch.concatenate([win, win.flip(0)[size % 2 :]]) # [size]
+ return win
+
+ # manual
+ for i, t in enumerate(range(0, audio.shape[0], hop_length)):
+ x = audio[t : t + win_length]
+ x = x * torch.hann_window(x.shape[0], periodic=False, dtype=torch.float64).to(torch.float32)
+ # x = x * _my_hann_win(x.shape[0])
+ print("x", i, ":", _torch_repr(x))
+ print(" RASR:", _torch_repr(rasr_feat[i]))
+ torch.testing.assert_close(x, rasr_feat[i][: x.shape[0]], rtol=1e-30, atol=1e-30)
+
padded = torch.cat( # zero pad for the last frame
[audio, torch.zeros((hop_length - 1), device=audio.device)],
dim=0,
From 44aba814698a2251e0f614538c7dc3a19c496c0e Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 00:15:22 +0100
Subject: [PATCH 16/30] cleanup, fix windowing (WIP)
---
i6_models/primitives/feature_extraction.py | 4 +++-
tests/test_feature_extraction.py | 26 ++++------------------
2 files changed, 7 insertions(+), 23 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index ff0d790e..10e3fa7f 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -174,7 +174,9 @@ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config):
),
),
)
- self.register_buffer("window", torch.hann_window(self.win_length, periodic=False))
+ self.register_buffer(
+ "window", torch.hann_window(self.win_length, periodic=False, dtype=torch.float64).to(torch.float32)
+ )
def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
"""
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 79bdff46..9e5bd82b 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -315,32 +315,14 @@ def test_rasr_compatible_window():
# https://github.com/rwth-i6/rasr/blob/master/src/Signal/WindowFunction.cc
# https://pytorch.org/docs/stable/generated/torch.hann_window.html
- def _my_hann_win_np(size: int) -> torch.Tensor:
- M_PI = 3.14159265358979323846264338327950288
- M = size - 1
- n = np.arange(0, size // 2) # [size//2]
- win = 0.5 - 0.5 * np.cos(2.0 * M_PI * n / M) # [size//2]
- win = torch.tensor(win, dtype=torch.float32)
- win = torch.concatenate([win, win.flip(0)[size % 2 :]]) # [size]
- return win
-
- def _my_hann_win(size: int) -> torch.Tensor:
- M_PI = 3.14159265358979323846264338327950288
- M = size - 1
- n = torch.arange(0, size // 2, dtype=torch.float64) # [size//2]
- win = 0.5 - 0.5 * torch.cos(2.0 * M_PI * n / M) # [size//2]
- win = win.to(torch.float32)
- win = torch.concatenate([win, win.flip(0)[size % 2 :]]) # [size]
- return win
-
# manual
for i, t in enumerate(range(0, audio.shape[0], hop_length)):
x = audio[t : t + win_length]
x = x * torch.hann_window(x.shape[0], periodic=False, dtype=torch.float64).to(torch.float32)
- # x = x * _my_hann_win(x.shape[0])
- print("x", i, ":", _torch_repr(x))
- print(" RASR:", _torch_repr(rasr_feat[i]))
torch.testing.assert_close(x, rasr_feat[i][: x.shape[0]], rtol=1e-30, atol=1e-30)
+ # once end was reached, stop
+ if x.shape[0] < win_length:
+ break
padded = torch.cat( # zero pad for the last frame
[audio, torch.zeros((hop_length - 1), device=audio.device)],
@@ -348,7 +330,7 @@ def _my_hann_win(size: int) -> torch.Tensor:
)
windowed = padded.unfold(0, size=win_length, step=hop_length) # [T', W=win_length]
- window = torch.hann_window(win_length, periodic=False)
+ window = torch.hann_window(win_length, periodic=False, dtype=torch.float64).to(torch.float32)
smoothed = windowed * window[None, :] # [T', W]
print("i6_models", _torch_repr(smoothed))
From d84565047266048141c0aa20103a21236c6e7a4a Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 01:35:04 +0100
Subject: [PATCH 17/30] fix last Hanning window
---
i6_models/primitives/feature_extraction.py | 27 +++++++++-----
tests/test_feature_extraction.py | 41 ++++++++++++++++++----
2 files changed, 52 insertions(+), 16 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 10e3fa7f..f578e20b 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -181,22 +181,33 @@ def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config):
def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param raw_audio: [B, T]
- :param length in samples: [B]
+ :param length: in samples [B]
:return features as [B,T,F] and length in frames [B]
"""
+ assert raw_audio.shape[1] > 0 # also same for length
+ res_size = max(raw_audio.shape[1] - self.win_length + self.hop_length - 1, 0) // self.hop_length + 1
+ res_length = (
+ torch.maximum(length - self.win_length + self.hop_length - 1, torch.zeros_like(length)) // self.hop_length
+ + 1
+ )
+
# preemphasize
preemphasized = raw_audio.clone()
preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1]
preemphasized[..., 0] = 0.0
# zero pad for the last frame
- padded = torch.cat(
- [preemphasized, torch.zeros(preemphasized.shape[0], (self.hop_length - 1), device=preemphasized.device)],
- dim=1,
- )
+ last_win_size = preemphasized.shape[1] - (res_size - 1) * self.hop_length
+ last_pad = self.win_length - last_win_size
+ padded = torch.nn.functional.pad(preemphasized, (0, last_pad))
windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
- smoothed = windowed * self.window[None, None, :] # [B, T', W]
+ smoothed = windowed[:, :-1] * self.window[None, None, :] # [B, T'-1, W]
+
+ # The last window might be shorter. Will use a shorter Hanning window then. Need to fix that.
+ last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(torch.float32)
+ last_win = torch.nn.functional.pad(last_win, (0, last_pad)) # [W]
+ smoothed = torch.cat([smoothed, (windowed[:, -1] * last_win[None, :])[:, None]], dim=1) # [B, T', W]
# compute amplitude spectrum using torch.fft.rfftn
amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1]
@@ -206,6 +217,4 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
log_melspec = torch.log10(melspec + self.min_amp)
feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']
- length = torch.ceil((length - self.win_length) / self.hop_length) + 1
-
- return feature_data, length.int()
+ return feature_data, res_length
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 9e5bd82b..b98af129 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -295,7 +295,8 @@ def test_rasr_compatible_window():
rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True)
time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat")
assert len(time_) == len(rasr_feat)
- rasr_feat[-1] = np.pad(rasr_feat[-1], (0, rasr_feat[0].shape[0] - rasr_feat[-1].shape[0]))
+ rasr_last_pad = rasr_feat[0].shape[0] - rasr_feat[-1].shape[0]
+ rasr_feat[-1] = np.pad(rasr_feat[-1], (0, rasr_last_pad))
rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32)
print("RASR:", _torch_repr(rasr_feat))
@@ -312,6 +313,25 @@ def test_rasr_compatible_window():
hop_length = int(hop_size * sample_rate)
win_length = int(win_size * sample_rate)
+ def _get_length(audio_len: int) -> int:  # closed-form number of frames for audio_len samples
+ if audio_len == 0:
+ return 0  # empty audio yields no frame at all
+ return max(audio_len - win_length + hop_length - 1, 0) // hop_length + 1  # == 1 + ceil(max(audio_len - win_length, 0) / hop_length)
+
+ def _get_length_naive(audio_len: int) -> int:  # reference implementation: counts frames by stepping through the frame start positions
+ n = 0
+ for t in range(0, audio_len, hop_length):  # t is the start sample of the current frame
+ n += 1
+ if audio_len <= t + win_length:  # this frame already reaches the end of the audio, so it is the last one
+ break
+ return n
+
+ for t in range(10 * win_length):
+ assert _get_length(t) == _get_length_naive(t), f"t={t}, {_get_length(t)} != {_get_length_naive(t)}"
+
+ res_len = _get_length(audio.shape[0])
+ assert res_len == len(rasr_feat)
+
# https://github.com/rwth-i6/rasr/blob/master/src/Signal/WindowFunction.cc
# https://pytorch.org/docs/stable/generated/torch.hann_window.html
@@ -321,17 +341,24 @@ def test_rasr_compatible_window():
x = x * torch.hann_window(x.shape[0], periodic=False, dtype=torch.float64).to(torch.float32)
torch.testing.assert_close(x, rasr_feat[i][: x.shape[0]], rtol=1e-30, atol=1e-30)
# once end was reached, stop
- if x.shape[0] < win_length:
+ if audio.shape[0] <= t + win_length:
+ assert win_length - x.shape[0] == rasr_last_pad, f"win {win_length}, cur {x.shape[0]}, pad {rasr_last_pad}"
break
- padded = torch.cat( # zero pad for the last frame
- [audio, torch.zeros((hop_length - 1), device=audio.device)],
- dim=0,
- )
+ last_win_size = audio.shape[0] - (res_len - 1) * hop_length
+ last_pad = win_length - last_win_size
+ assert last_pad == rasr_last_pad, f"last pad {last_pad}, RASR last pad {rasr_last_pad}"
+ padded = torch.nn.functional.pad(audio, (0, last_pad)) # zero pad for the last frame
windowed = padded.unfold(0, size=win_length, step=hop_length) # [T', W=win_length]
+ assert len(windowed) == res_len
window = torch.hann_window(win_length, periodic=False, dtype=torch.float64).to(torch.float32)
- smoothed = windowed * window[None, :] # [T', W]
+ smoothed = windowed[:-1] * window[None, :] # [T'-1, W]
+
+ # The last window might be shorter. Will use a shorter Hanning window then. Need to fix that.
+ last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(torch.float32)
+ last_win = torch.nn.functional.pad(last_win, (0, last_pad))
+ smoothed = torch.cat([smoothed, (windowed[-1] * last_win)[None, :]], dim=0)
print("i6_models", _torch_repr(smoothed))
From e2cda8b396f4766cfe0383181e10604e6542712c Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 01:39:33 +0100
Subject: [PATCH 18/30] fix device
---
i6_models/primitives/feature_extraction.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index f578e20b..4a08e8f3 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -205,7 +205,9 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
smoothed = windowed[:, :-1] * self.window[None, None, :] # [B, T'-1, W]
# The last window might be shorter. Will use a shorter Hanning window then. Need to fix that.
- last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(torch.float32)
+ last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(
+ self.window.device, torch.float32
+ )
last_win = torch.nn.functional.pad(last_win, (0, last_pad)) # [W]
smoothed = torch.cat([smoothed, (windowed[:, -1] * last_win[None, :])[:, None]], dim=1) # [B, T', W]
From 6925acc6a967b50035ec60d04b624b56b7bbe805 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 01:41:51 +0100
Subject: [PATCH 19/30] simplify
---
i6_models/primitives/feature_extraction.py | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 4a08e8f3..de7a5163 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -213,10 +213,8 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
# compute amplitude spectrum using torch.fft.rfftn
amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1]
- amplitude_spectrum = amplitude_spectrum.transpose(1, 2) # [B, F, T']
- melspec = torch.einsum("...ft,mf->...mt", amplitude_spectrum, self.mel_basis) # [B, F'=num_filters, T']
+ melspec = torch.einsum("...tf,mf->...tm", amplitude_spectrum, self.mel_basis) # [B, T', F'=num_filters]
log_melspec = torch.log10(melspec + self.min_amp)
- feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']
- return feature_data, res_length
+ return log_melspec, res_length
From 848c88642cf024fc3306d083d17c934cbbee5316 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 01:58:46 +0100
Subject: [PATCH 20/30] test_rasr_compatible_fft (failing)
---
tests/test_feature_extraction.py | 90 ++++++++++++++++++++++++++++++++
1 file changed, 90 insertions(+)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index b98af129..282e52a0 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -1,6 +1,7 @@
import os
import sys
import copy
+import math
import numpy as np
import torch
import unittest
@@ -365,6 +366,95 @@ def _get_length_naive(audio_len: int) -> int:
torch.testing.assert_close(smoothed, rasr_feat, rtol=1e-30, atol=1e-30)
+def test_rasr_compatible_fft():  # compares |rfft| of pre-emphasized, Hann-windowed frames against the RASR "amplitude-spectrum" flow output
+ try:
+ from i6_core.lib.rasr_cache import FileArchive
+ except ImportError:
+ raise unittest.SkipTest("i6_core not available")
+ try:
+ import soundfile
+ except ImportError:
+ raise unittest.SkipTest("soundfile not available")
+ rasr_feature_extractor_bin_path = (
+ "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/"
+ "feature-extraction.linux-x86_64-standard"
+ )
+ if not os.path.exists(rasr_feature_extractor_bin_path):
+ raise unittest.SkipTest("RASR feature-extraction binary not found")
+
+ wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio")
+ atexit.register(os.remove, wav_file_path)
+ generate_random_speech_like_audio_wav(wav_file_path)
+ rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow(
+ rasr_feature_extractor_bin_path,
+ wav_file_path,
+ textwrap.dedent(
+ f"""\
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+ ),
+ flow_output_name="amplitude-spectrum",
+ )
+
+ rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True)
+ time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat")
+ assert len(time_) == len(rasr_feat)
+ rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32)
+ print("RASR:", _torch_repr(rasr_feat))
+
+ audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16")  # NOTE(review): the file object returned by open() is never closed explicitly
+ audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1]
+
+ # preemphasize: first-order difference with coefficient 1.0 (y[t] = x[t] - x[t-1]); sample 0 is then set to 0
+ audio[..., 1:] -= 1.0 * audio[..., :-1]
+ audio[..., 0] = 0.0
+
+ # windowing: sizes are given in seconds (25 ms window, 10 ms shift) and converted to sample counts below
+ win_size = 0.025
+ hop_size = 0.01
+ hop_length = int(hop_size * sample_rate)
+ win_length = int(win_size * sample_rate)
+
+ res_len = max(audio.shape[0] - win_length + hop_length - 1, 0) // hop_length + 1  # frames: 1 + ceil(max(T - win_length, 0) / hop_length)
+ assert res_len == len(rasr_feat)
+
+ last_win_size = audio.shape[0] - (res_len - 1) * hop_length  # samples left for the last frame, counted from its start position
+ last_pad = win_length - last_win_size
+ padded = torch.nn.functional.pad(audio, (0, last_pad)) # zero pad for the last frame
+
+ windowed = padded.unfold(0, size=win_length, step=hop_length) # [T', W=win_length]
+ assert len(windowed) == res_len
+ window = torch.hann_window(win_length, periodic=False, dtype=torch.float64).to(torch.float32)
+ smoothed = windowed[:-1] * window[None, :] # [T'-1, W]
+
+ # The last frame may hold fewer than win_length real samples: it gets a Hann window of that shorter length, zero-padded to win_length.
+ last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(torch.float32)
+ last_win = torch.nn.functional.pad(last_win, (0, last_pad))
+ smoothed = torch.cat([smoothed, (windowed[-1] * last_win)[None, :]], dim=0)
+
+ n_fft = 2 ** math.ceil(math.log2(win_length))
+ # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F]
+ # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F*2]
+ amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=n_fft)) # [B, T', F=n_fft//2+1]
+
+ print("i6_models", _torch_repr(amplitude_spectrum))
+
+ torch.testing.assert_close(amplitude_spectrum, rasr_feat, rtol=1e-30, atol=1e-30)
+
+
def generate_rasr_feature_cache_from_wav_and_flow(
rasr_feature_extractor_bin_path: str,
wav_file_path: str,
From 990a9779c4acc3eeb44a145f7858ef63f2d36d68 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 11:39:14 +0100
Subject: [PATCH 21/30] FFT test more direct (still failing)
---
tests/test_feature_extraction.py | 15 +++++++++------
1 file changed, 9 insertions(+), 6 deletions(-)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 282e52a0..5d7bc76e 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -402,11 +402,9 @@ def test_rasr_compatible_fft():
-
-
"""
),
- flow_output_name="amplitude-spectrum",
+ flow_output_name="scaling",
)
rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True)
@@ -446,13 +444,18 @@ def test_rasr_compatible_fft():
smoothed = torch.cat([smoothed, (windowed[-1] * last_win)[None, :]], dim=0)
n_fft = 2 ** math.ceil(math.log2(win_length))
+ print(f"win_length={win_length}, n_fft={n_fft}")
+ smoothed = smoothed.to(torch.float64)
+
# fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F]
# fft = torch.view_as_real(fft).flatten(-2) # [B, T', F*2]
- amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=n_fft)) # [B, T', F=n_fft//2+1]
+ fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F=n_fft//2+1]
+ fft = torch.view_as_real(fft).flatten(-2) # # [B, T', F=(n_fft//2+1)*2]
+ fft = fft.to(torch.float32)
- print("i6_models", _torch_repr(amplitude_spectrum))
+ print("i6_models", _torch_repr(fft))
- torch.testing.assert_close(amplitude_spectrum, rasr_feat, rtol=1e-30, atol=1e-30)
+ torch.testing.assert_close(fft, rasr_feat, rtol=1e-30, atol=1e-30)
def generate_rasr_feature_cache_from_wav_and_flow(
From 6556b3c4f96789d9a26d64f0ae2b8dd9f2e4f49b Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 14:39:04 +0100
Subject: [PATCH 22/30] tests deterministic
---
tests/test_feature_extraction.py | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 5d7bc76e..fd6f9257 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -91,6 +91,7 @@ def test_rasr_compatible():
if not os.path.exists(rasr_feature_extractor_bin_path):
raise unittest.SkipTest("RASR feature-extraction binary not found")
+ torch.manual_seed(42)
wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio")
atexit.register(os.remove, wav_file_path)
generate_random_speech_like_audio_wav(wav_file_path)
@@ -176,6 +177,7 @@ def test_rasr_compatible_raw_audio_samples():
if not os.path.exists(rasr_feature_extractor_bin_path):
raise unittest.SkipTest("RASR feature-extraction binary not found")
+ torch.manual_seed(42)
wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio")
atexit.register(os.remove, wav_file_path)
generate_random_speech_like_audio_wav(wav_file_path)
@@ -222,6 +224,7 @@ def test_rasr_compatible_preemphasis():
if not os.path.exists(rasr_feature_extractor_bin_path):
raise unittest.SkipTest("RASR feature-extraction binary not found")
+ torch.manual_seed(42)
wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio")
atexit.register(os.remove, wav_file_path)
generate_random_speech_like_audio_wav(wav_file_path)
@@ -272,6 +275,7 @@ def test_rasr_compatible_window():
if not os.path.exists(rasr_feature_extractor_bin_path):
raise unittest.SkipTest("RASR feature-extraction binary not found")
+ torch.manual_seed(42)
wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio")
atexit.register(os.remove, wav_file_path)
generate_random_speech_like_audio_wav(wav_file_path)
@@ -382,6 +386,7 @@ def test_rasr_compatible_fft():
if not os.path.exists(rasr_feature_extractor_bin_path):
raise unittest.SkipTest("RASR feature-extraction binary not found")
+ torch.manual_seed(42)
wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio")
atexit.register(os.remove, wav_file_path)
generate_random_speech_like_audio_wav(wav_file_path)
From 79a043e69a881ba1d9d528d3d07fc6519c706abd Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 14:39:51 +0100
Subject: [PATCH 23/30] copy RASR C++ FFT code for testing
---
tests/test_feature_extraction.py | 118 ++++++++++++++++++++++++++++++-
1 file changed, 116 insertions(+), 2 deletions(-)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index fd6f9257..1b4d6d06 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -454,8 +454,9 @@ def test_rasr_compatible_fft():
# fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F]
# fft = torch.view_as_real(fft).flatten(-2) # [B, T', F*2]
- fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F=n_fft//2+1]
- fft = torch.view_as_real(fft).flatten(-2) # # [B, T', F=(n_fft//2+1)*2]
+ # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F=n_fft//2+1]
+ # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F=(n_fft//2+1)*2]
+ fft = my_fft(smoothed, n_fft=n_fft)
fft = fft.to(torch.float32)
print("i6_models", _torch_repr(fft))
@@ -463,6 +464,119 @@ def test_rasr_compatible_fft():
torch.testing.assert_close(fft, rasr_feat, rtol=1e-30, atol=1e-30)
+def create_bit_reversal_reordering(size):
+ """
+ Creates a bit reversal reordering tensor for a given size.
+ """
+ # Start from identity indices; only some entries are overwritten below, the rest keep their own index.
+ length = size // 2
+ reordering = torch.arange(size)
+
+ # i walks the odd 1-based positions, so each step handles the adjacent entries (i - 1, i); j is the 1-based partner position.
+ j = 1
+ for i in range(1, size, 2):
+ if j > i:  # record the partner only when it lies further right, so each pair is recorded once
+ reordering[i - 1] = j - 1
+ reordering[i] = j
+ m = length  # advance j to the next partner position; presumably the bit-reversed counter step of the RASR FFT code - TODO confirm
+ while 2 <= m < j:
+ j -= m
+ m //= 2
+ j += m
+
+ return reordering
+
+
+def bit_reversal_reordering(tensor: torch.Tensor) -> torch.Tensor:  # NOTE(review): a plain gather over the last dim, not pairwise swaps - confirm this matches the reference
+ """
+ Reorders a given tensor according to bit reversal pattern.
+ """
+ size = tensor.size(-1)
+ reordering = create_bit_reversal_reordering(size)
+ return tensor[..., reordering]
+
+
+def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor:
+ # Python copy of the RASR C++ FFT, used only for comparison in tests: https://github.com/rwth-i6/rasr/blob/master/src/Math/FastFourierTransform.cc#L95
+ size = n_fft
+ d_pi = torch.tensor(6.28318530717959, dtype=torch.float64)  # 2 * pi, written as a 15-digit literal
+ theta_base = d_pi
+
+ tensor = torch.nn.functional.pad(tensor, (0, size - tensor.shape[-1]))  # zero-pad the last dim up to n_fft
+ v = bit_reversal_reordering(tensor)
+
+ cur_length = 2
+ # Danielson-Lanczos stage loop: cur_length doubles on every pass (step = 2 * cur_length) until it reaches size
+ while cur_length < size:
+ # initialize the trigonometric recurrence: (wp_r, wp_i) = (cos(theta) - 1, sin(theta)), using -2 * sin^2(theta / 2) = cos(theta) - 1
+ step = cur_length * 2
+ theta = theta_base / cur_length
+ sin_h_theta = torch.sin(0.5 * theta)
+ wp_r = -2.0 * sin_h_theta * sin_h_theta
+ wp_i = torch.sin(theta)
+ w_r = 1.0
+ w_i = 0.0
+ for m in range(1, cur_length, 2):
+ for i in range(m, size, step):
+ # Danielson & Lanczos formula
+ j = i + cur_length
+ tempr = w_r * v[..., j - 1] - w_i * v[..., j]
+ tempi = w_r * v[..., j] + w_i * v[..., j - 1]
+ v[..., j - 1] = v[..., i - 1] - tempr
+ v[..., j] = v[..., i] - tempi
+ v[..., i - 1] += tempr
+ v[..., i] += tempi
+ w_temp_r, w_temp_i = w_r, w_i
+ w_r = w_temp_r * wp_r - w_temp_i * wp_i + w_r
+ w_i = w_temp_i * wp_r + w_temp_r * wp_i + w_i
+ cur_length = step
+
+ pi = torch.tensor(3.141592653589793238, dtype=torch.float64)
+ size_d4 = size >> 2
+ theta = pi / (size >> 1)
+ c = -0.5
+
+ sin_h_theta = torch.sin(0.5 * theta)
+ wp_r = -2.0 * sin_h_theta * sin_h_theta
+ wp_i = torch.sin(theta)
+ w_r = wp_r + 1  # the recurrence starts at (cos(theta), sin(theta)) because wp_r = cos(theta) - 1
+ w_i = wp_i
+
+ for i in range(1, size_d4):
+ i1 = i + i  # i1/i2 index one (even, odd) slot pair from the front, i3/i4 the mirrored pair counted from the end
+ i2 = i1 + 1
+ i3 = size - i1
+ i4 = i3 + 1
+
+ # separate the two transforms
+ h1_r = 0.5 * (v[..., i1] + v[..., i3])
+ h1_i = 0.5 * (v[..., i2] - v[..., i4])
+ h2_r = -c * (v[..., i2] + v[..., i4])
+ h2_i = c * (v[..., i1] - v[..., i3])
+
+ # Calculating the true transform of the original real data
+ v[..., i1] = h1_r + w_r * h2_r - w_i * h2_i
+ v[..., i2] = h1_i + w_r * h2_i + w_i * h2_r
+ v[..., i3] = h1_r - w_r * h2_r + w_i * h2_i
+ v[..., i4] = -h1_i + w_r * h2_i + w_i * h2_r
+
+ # Updating the trigonometric recurrences
+ old_wr = w_r  # keep the old real part: the w_i update below must not see the already-updated w_r
+ w_r = w_r * wp_r - w_i * wp_i + w_r
+ w_i = w_i * wp_r + old_wr * wp_i + w_i
+
+ h = v[..., 0].clone()  # clone so that overwriting v[..., 0] on the next line does not change h
+ v[..., 0] = h + v[..., 1]
+ v[..., 1] = h - v[..., 1]
+
+ # unpack: grow the last dim by 2, move v[..., 1] to the second-to-last slot and zero v[..., 1] (presumably the packed Nyquist term - TODO confirm)
+ v = torch.nn.functional.pad(v, (0, 2))
+ v[..., -2] = v[..., 1]
+ v[..., 1] = 0
+
+ return v
+
+
def generate_rasr_feature_cache_from_wav_and_flow(
rasr_feature_extractor_bin_path: str,
wav_file_path: str,
From 34765574728a2c7a0babf913f3d6e8ca6d78be77 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 15:06:52 +0100
Subject: [PATCH 24/30] FFT fixes
---
tests/test_feature_extraction.py | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 1b4d6d06..27393aa6 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -493,7 +493,12 @@ def bit_reversal_reordering(tensor: torch.Tensor) -> torch.Tensor:
"""
size = tensor.size(-1)
reordering = create_bit_reversal_reordering(size)
- return tensor[..., reordering]
+ v = tensor.clone()
+ for i in range(size):
+ h = v[..., i].clone()
+ v[..., i] = v[..., reordering[i]]
+ v[..., reordering[i]] = h
+ return v
def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor:
@@ -527,8 +532,8 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor:
v[..., i - 1] += tempr
v[..., i] += tempi
w_temp_r, w_temp_i = w_r, w_i
- w_r = w_temp_r * wp_r - w_temp_i * wp_i + w_r
- w_i = w_temp_i * wp_r + w_temp_r * wp_i + w_i
+ w_r = w_r * wp_r - w_temp_i * wp_i + w_r
+ w_i = w_i * wp_r + w_temp_r * wp_i + w_i
cur_length = step
pi = torch.tensor(3.141592653589793238, dtype=torch.float64)
From 467f8284d3f92cb6aa54c9e0d6014b79e4f9452b Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 15:15:11 +0100
Subject: [PATCH 25/30] FFT becomes more exact
---
tests/test_feature_extraction.py | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 27393aa6..315a720b 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -450,14 +450,12 @@ def test_rasr_compatible_fft():
n_fft = 2 ** math.ceil(math.log2(win_length))
print(f"win_length={win_length}, n_fft={n_fft}")
- smoothed = smoothed.to(torch.float64)
# fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F]
# fft = torch.view_as_real(fft).flatten(-2) # [B, T', F*2]
# fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F=n_fft//2+1]
# fft = torch.view_as_real(fft).flatten(-2) # [B, T', F=(n_fft//2+1)*2]
fft = my_fft(smoothed, n_fft=n_fft)
- fft = fft.to(torch.float32)
print("i6_models", _torch_repr(fft))
@@ -525,8 +523,10 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor:
for i in range(m, size, step):
# Danielson & Lanczos formula
j = i + cur_length
- tempr = w_r * v[..., j - 1] - w_i * v[..., j]
- tempi = w_r * v[..., j] + w_i * v[..., j - 1]
+ tempr = w_r * v[..., j - 1].to(torch.float64) - w_i * v[..., j].to(torch.float64)
+ tempi = w_r * v[..., j].to(torch.float64) + w_i * v[..., j - 1].to(torch.float64)
+ tempr = tempr.to(torch.float32)
+ tempi = tempi.to(torch.float32)
v[..., j - 1] = v[..., i - 1] - tempr
v[..., j] = v[..., i] - tempi
v[..., i - 1] += tempr
@@ -554,10 +554,10 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor:
i4 = i3 + 1
# separate the two transforms
- h1_r = 0.5 * (v[..., i1] + v[..., i3])
- h1_i = 0.5 * (v[..., i2] - v[..., i4])
- h2_r = -c * (v[..., i2] + v[..., i4])
- h2_i = c * (v[..., i1] - v[..., i3])
+ h1_r = 0.5 * (v[..., i1] + v[..., i3]).to(torch.float64)
+ h1_i = 0.5 * (v[..., i2] - v[..., i4]).to(torch.float64)
+ h2_r = (-c * (v[..., i2] + v[..., i4])).to(torch.float64)
+ h2_i = (c * (v[..., i1] - v[..., i3])).to(torch.float64)
# Calculating the true transform of the original real data
v[..., i1] = h1_r + w_r * h2_r - w_i * h2_i
From 69cd90d9b75ecd664eee3775fd592a1e9c0af160 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 15:25:02 +0100
Subject: [PATCH 26/30] FFT cleanup
---
tests/test_feature_extraction.py | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 315a720b..849f2c9e 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -517,8 +517,8 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor:
sin_h_theta = torch.sin(0.5 * theta)
wp_r = -2.0 * sin_h_theta * sin_h_theta
wp_i = torch.sin(theta)
- w_r = 1.0
- w_i = 0.0
+ w_r = torch.tensor(1.0, dtype=torch.float64)
+ w_i = torch.tensor(0.0, dtype=torch.float64)
for m in range(1, cur_length, 2):
for i in range(m, size, step):
# Danielson & Lanczos formula
@@ -539,7 +539,6 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor:
pi = torch.tensor(3.141592653589793238, dtype=torch.float64)
size_d4 = size >> 2
theta = pi / (size >> 1)
- c = -0.5
sin_h_theta = torch.sin(0.5 * theta)
wp_r = -2.0 * sin_h_theta * sin_h_theta
@@ -556,8 +555,8 @@ def my_fft(tensor: torch.Tensor, *, n_fft: int) -> torch.Tensor:
# separate the two transforms
h1_r = 0.5 * (v[..., i1] + v[..., i3]).to(torch.float64)
h1_i = 0.5 * (v[..., i2] - v[..., i4]).to(torch.float64)
- h2_r = (-c * (v[..., i2] + v[..., i4])).to(torch.float64)
- h2_i = (c * (v[..., i1] - v[..., i3])).to(torch.float64)
+ h2_r = 0.5 * (v[..., i2] + v[..., i4]).to(torch.float64)
+ h2_i = -0.5 * (v[..., i1] - v[..., i3]).to(torch.float64)
# Calculating the true transform of the original real data
v[..., i1] = h1_r + w_r * h2_r - w_i * h2_i
From aca1eb3617dcb174e145359664edb3c31bf0ea7b Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 14 Feb 2024 15:32:17 +0100
Subject: [PATCH 27/30] test_rasr_compatible_amplitude_spectrum (failing)
---
tests/test_feature_extraction.py | 97 +++++++++++++++++++++++++++++++-
1 file changed, 96 insertions(+), 1 deletion(-)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 849f2c9e..56f26f46 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -459,7 +459,102 @@ def test_rasr_compatible_fft():
print("i6_models", _torch_repr(fft))
- torch.testing.assert_close(fft, rasr_feat, rtol=1e-30, atol=1e-30)
+ torch.testing.assert_close(fft, rasr_feat, rtol=1e-6, atol=1e-2) # values are huge, accept some errors...?
+
+
+def test_rasr_compatible_amplitude_spectrum():
+ try:
+ from i6_core.lib.rasr_cache import FileArchive
+ except ImportError:
+ raise unittest.SkipTest("i6_core not available")
+ try:
+ import soundfile
+ except ImportError:
+ raise unittest.SkipTest("soundfile not available")
+ rasr_feature_extractor_bin_path = (
+ "/work/tools22/asr/rasr/rasr_onnx_haswell_0623/arch/linux-x86_64-standard/"
+ "feature-extraction.linux-x86_64-standard"
+ )
+ if not os.path.exists(rasr_feature_extractor_bin_path):
+ raise unittest.SkipTest("RASR feature-extraction binary not found")
+
+ torch.manual_seed(42)
+ wav_file_path = tempfile.mktemp(suffix=".wav", prefix="tmp-i6models-random-audio")
+ atexit.register(os.remove, wav_file_path)
+ generate_random_speech_like_audio_wav(wav_file_path)
+ rasr_feature_cache_path = generate_rasr_feature_cache_from_wav_and_flow(
+ rasr_feature_extractor_bin_path,
+ wav_file_path,
+ textwrap.dedent(
+ f"""\
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+ ),
+ flow_output_name="amplitude-spectrum",
+ )
+
+ rasr_cache = FileArchive(rasr_feature_cache_path, must_exists=True)
+ time_, rasr_feat = rasr_cache.read("corpus/recording/1", "feat")
+ assert len(time_) == len(rasr_feat)
+ rasr_feat = torch.tensor(np.stack(rasr_feat, axis=0), dtype=torch.float32)
+ print("RASR:", _torch_repr(rasr_feat))
+
+ audio, sample_rate = soundfile.read(open(wav_file_path, "rb"), dtype="int16")
+ audio = torch.tensor(audio.astype(np.float32)) # [-2**15, 2**15-1]
+
+ # preemphasize
+ audio[..., 1:] -= 1.0 * audio[..., :-1]
+ audio[..., 0] = 0.0
+
+ # windowing
+ win_size = 0.025
+ hop_size = 0.01
+ hop_length = int(hop_size * sample_rate)
+ win_length = int(win_size * sample_rate)
+
+ res_len = max(audio.shape[0] - win_length + hop_length - 1, 0) // hop_length + 1
+ assert res_len == len(rasr_feat)
+
+ last_win_size = audio.shape[0] - (res_len - 1) * hop_length
+ last_pad = win_length - last_win_size
+ padded = torch.nn.functional.pad(audio, (0, last_pad)) # zero pad for the last frame
+
+ windowed = padded.unfold(0, size=win_length, step=hop_length) # [T', W=win_length]
+ assert len(windowed) == res_len
+ window = torch.hann_window(win_length, periodic=False, dtype=torch.float64).to(torch.float32)
+ smoothed = windowed[:-1] * window[None, :] # [T'-1, W]
+
+ # The last window might be shorter. Will use a shorter Hanning window then. Need to fix that.
+ last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(torch.float32)
+ last_win = torch.nn.functional.pad(last_win, (0, last_pad))
+ smoothed = torch.cat([smoothed, (windowed[-1] * last_win)[None, :]], dim=0)
+
+ n_fft = 2 ** math.ceil(math.log2(win_length))
+ print(f"win_length={win_length}, n_fft={n_fft}")
+
+ # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F]
+ # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F*2]
+ # fft = torch.fft.rfftn(smoothed, s=n_fft) # [B, T', F=n_fft//2+1]
+ # fft = torch.view_as_real(fft).flatten(-2) # [B, T', F=(n_fft//2+1)*2]
+ fft = my_fft(smoothed, n_fft=n_fft)
+ amplitude_spectrum = torch.abs(torch.view_as_complex(fft.unflatten(-1, (-1, 2))))
+
+ print("i6_models", _torch_repr(amplitude_spectrum))
+
+ torch.testing.assert_close(amplitude_spectrum, rasr_feat, rtol=1e-30, atol=1e-30)
def create_bit_reversal_reordering(size):
From 3ead58feaad02bbed2cf32b98c6eaa98af6fa8bb Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Tue, 12 Mar 2024 09:46:07 -0400
Subject: [PATCH 28/30] add fft scaling, update test and remove spaces
---
i6_models/primitives/feature_extraction.py | 7 +++++--
tests/test_feature_extraction.py | 18 ++++++++----------
2 files changed, 13 insertions(+), 12 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index de7a5163..e1528e81 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -152,6 +152,8 @@ class RasrCompatibleLogMelFeatureExtractionV1(nn.Module):
def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config):
super().__init__()
+
+ self.sample_rate = int(cfg.sample_rate)
self.hop_length = int(cfg.hop_size * cfg.sample_rate)
self.min_amp = cfg.min_amp
self.win_length = int(cfg.win_size * cfg.sample_rate)
@@ -211,8 +213,9 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
last_win = torch.nn.functional.pad(last_win, (0, last_pad)) # [W]
smoothed = torch.cat([smoothed, (windowed[:, -1] * last_win[None, :])[:, None]], dim=1) # [B, T', W]
- # compute amplitude spectrum using torch.fft.rfftn
- amplitude_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) # [B, T', F=n_fft//2+1]
+ # compute amplitude spectrum using torch.fft.rfftn with Rasr specific scaling
+ fft = torch.fft.rfftn(smoothed, s=self.n_fft) / self.sample_rate # [B, T', F=n_fft//2+1]
+ amplitude_spectrum = torch.abs(fft) # [B, T', F=n_fft//2+1]
melspec = torch.einsum("...tf,mf->...tm", amplitude_spectrum, self.mel_basis) # [B, T', F'=num_filters]
log_melspec = torch.log10(melspec + self.min_amp)
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index 56f26f46..0bc7a4c2 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -110,15 +110,13 @@ def test_rasr_compatible():
-
-
-
+
-
+
"""
),
flow_output_name="nonlinear",
@@ -189,7 +187,7 @@ def test_rasr_compatible_raw_audio_samples():
-
+
"""
),
flow_output_name="convert",
@@ -236,7 +234,7 @@ def test_rasr_compatible_preemphasis():
-
+
"""
@@ -287,7 +285,7 @@ def test_rasr_compatible_window():
-
+
@@ -398,7 +396,7 @@ def test_rasr_compatible_fft():
-
+
@@ -490,7 +488,7 @@ def test_rasr_compatible_amplitude_spectrum():
-
+
@@ -742,7 +740,7 @@ def generate_rasr_feature_cache_from_wav_and_flow(
f"""\
[*.corpus]
file = {corpus_xml_path}
-
+
[*.feature-extraction]
file = {rasr_flow_xml_path}
"""
From 1af10fc6fa4298638acbb1d48659b165e11f4ba2 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Tue, 12 Mar 2024 09:50:06 -0400
Subject: [PATCH 29/30] black
---
i6_models/primitives/feature_extraction.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index e1528e81..ac9022ab 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -214,7 +214,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
smoothed = torch.cat([smoothed, (windowed[:, -1] * last_win[None, :])[:, None]], dim=1) # [B, T', W]
# compute amplitude spectrum using torch.fft.rfftn with Rasr specific scaling
- fft = torch.fft.rfftn(smoothed, s=self.n_fft) / self.sample_rate # [B, T', F=n_fft//2+1]
+ fft = torch.fft.rfftn(smoothed, s=self.n_fft) / self.sample_rate # [B, T', F=n_fft//2+1]
amplitude_spectrum = torch.abs(fft) # [B, T', F=n_fft//2+1]
melspec = torch.einsum("...tf,mf->...tm", amplitude_spectrum, self.mel_basis) # [B, T', F'=num_filters]
From fff39f0d56300edc49bf3c67ab32e543bfb8260a Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Mon, 18 Mar 2024 08:42:03 -0400
Subject: [PATCH 30/30] adjust last window for different sequence lengths
---
i6_models/primitives/feature_extraction.py | 27 +++++++++++++---------
1 file changed, 16 insertions(+), 11 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index ac9022ab..d791f776 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -198,20 +198,25 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1]
preemphasized[..., 0] = 0.0
- # zero pad for the last frame
- last_win_size = preemphasized.shape[1] - (res_size - 1) * self.hop_length
- last_pad = self.win_length - last_win_size
- padded = torch.nn.functional.pad(preemphasized, (0, last_pad))
+ # zero pad for the last frame of each sequence in the batch
+ last_win_size = length - (res_length - 1) * self.hop_length # [B]
+ last_pad = self.win_length - last_win_size # [B]
- windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
- smoothed = windowed[:, :-1] * self.window[None, None, :] # [B, T'-1, W]
+ # zero pad for the whole batch
+ last_pad_batch = self.win_length - (preemphasized.shape[1] - (res_size - 1) * self.hop_length)
+ padded = torch.nn.functional.pad(preemphasized, (0, last_pad_batch))
+
+ windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=self.win_length]
+
+ smoothed = windowed * self.window[None, None, :] # [B, T', W]
# The last window might be shorter. Will use a shorter Hanning window then. Need to fix that.
- last_win = torch.hann_window(last_win_size, periodic=False, dtype=torch.float64).to(
- self.window.device, torch.float32
- )
- last_win = torch.nn.functional.pad(last_win, (0, last_pad)) # [W]
- smoothed = torch.cat([smoothed, (windowed[:, -1] * last_win[None, :])[:, None]], dim=1) # [B, T', W]
+ for i, (last_w_size, last_p, res_l) in enumerate(zip(last_win_size, last_pad, res_length)):
+ last_win = torch.hann_window(last_w_size, periodic=False, dtype=torch.float64).to(
+ self.window.device, torch.float32
+ )
+ last_win = torch.nn.functional.pad(last_win, (0, last_p)) # [W]
+ smoothed[i, res_l - 1] = windowed[i, res_l - 1] * last_win[None, :]
# compute amplitude spectrum using torch.fft.rfftn with Rasr specific scaling
fft = torch.fft.rfftn(smoothed, s=self.n_fft) / self.sample_rate # [B, T', F=n_fft//2+1]