Skip to content

Commit

Permalink
Add RASR compatible feature extraction (#44)
Browse files Browse the repository at this point in the history
Co-authored-by: Albert Zeyer <[email protected]>
  • Loading branch information
kuacakuaca and albertz authored Mar 19, 2024
1 parent ea4354c commit a15dad4
Show file tree
Hide file tree
Showing 2 changed files with 933 additions and 2 deletions.
117 changes: 116 additions & 1 deletion i6_models/primitives/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Public API of this module. Declared once; the earlier two-element assignment
# was dead code because it was immediately overwritten by this list.
__all__ = [
    "LogMelFeatureExtractionV1",
    "LogMelFeatureExtractionV1Config",
    "RasrCompatibleLogMelFeatureExtractionV1",
    "RasrCompatibleLogMelFeatureExtractionV1Config",
]

import math
from dataclasses import dataclass
from typing import Optional, Tuple

Expand Down Expand Up @@ -111,3 +117,112 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
length = ((length - self.n_fft) // self.hop_length) + 1

return feature_data, length.int()


@dataclass
class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration):
    """
    Configuration for the RASR-compatible log-mel feature extraction.

    Attributes:
        sample_rate: audio sample rate in Hz
        win_size: window size in seconds
        hop_size: window shift in seconds
        min_amp: minimum amplitude for safe log
        num_filters: number of mel filters
        alpha: preemphasis weight
    """

    sample_rate: int
    win_size: float
    hop_size: float
    min_amp: float
    num_filters: int
    alpha: float = 1.0

    def __post_init__(self) -> None:
        """Validate the settings; raises AssertionError on inconsistent values."""
        super().__post_init__()
        # A non-positive sample rate would yield non-positive window/hop lengths in samples.
        assert self.sample_rate > 0, "sample rate needs to be positive"
        assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive"
        assert self.num_filters > 0, "number of filters needs to be positive"
        assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense"


class RasrCompatibleLogMelFeatureExtractionV1(nn.Module):
    """
    Rasr-compatible log-mel feature extraction using log10. Does not use torchaudio.

    Processing steps in :func:`forward`: preemphasis, framing, Hann windowing (with a shorter
    window for a shorter last frame), zero-padded FFT of size ``n_fft`` scaled by
    ``1 / sample_rate``, mel filterbank on the amplitude spectrum, and ``log10`` with ``min_amp`` offset.
    """

    def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config):
        """
        :param cfg: feature extraction settings; window and hop sizes are given in seconds
            and converted to samples here (truncating towards zero)
        """
        super().__init__()

        self.sample_rate = int(cfg.sample_rate)
        self.hop_length = int(cfg.hop_size * cfg.sample_rate)
        self.min_amp = cfg.min_amp
        self.win_length = int(cfg.win_size * cfg.sample_rate)
        # Fail early with a clear message: a zero window length would otherwise raise a
        # "math domain error" in log2 below, and a zero hop length a ZeroDivisionError in forward.
        assert self.win_length > 0, "window size is shorter than one sample at the given sample rate"
        assert self.hop_length > 0, "hop size is shorter than one sample at the given sample rate"
        self.n_fft = 2 ** math.ceil(
            math.log2(self.win_length)
        )  # smallest power of two which is greater than or equal to win_length
        self.alpha = cfg.alpha

        self.register_buffer(
            "mel_basis",
            torch.tensor(
                filters.mel(
                    sr=cfg.sample_rate,
                    n_fft=self.n_fft,
                    n_mels=cfg.num_filters,
                    fmin=0,
                    fmax=cfg.sample_rate // 2,
                    htk=True,
                    norm=None,
                ),
            ),
        )
        # window is computed in float64 and then cast down to float32
        self.register_buffer(
            "window", torch.hann_window(self.win_length, periodic=False, dtype=torch.float64).to(torch.float32)
        )

    def forward(self, raw_audio: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param raw_audio: [B, T]
        :param length: in samples [B]
        :return: features as [B, T', F'] (T' frames, F' = num_filters) and length in frames [B]
        """
        assert raw_audio.shape[1] > 0  # also same for length
        # Number of frames is ceil((T - win_length) / hop_length) + 1 and at least 1,
        # i.e. a trailing partial frame is kept. res_size is for the whole (padded) batch,
        # res_length per sequence.
        res_size = max(raw_audio.shape[1] - self.win_length + self.hop_length - 1, 0) // self.hop_length + 1
        res_length = (
            torch.maximum(length - self.win_length + self.hop_length - 1, torch.zeros_like(length)) // self.hop_length
            + 1
        )

        # preemphasize (the right-hand side is evaluated before the in-place subtraction)
        preemphasized = raw_audio.clone()
        preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1]
        preemphasized[..., 0] = 0.0

        # zero pad for the last frame of each sequence in the batch
        last_win_size = length - (res_length - 1) * self.hop_length  # [B]
        last_pad = self.win_length - last_win_size  # [B]

        # zero pad for the whole batch
        last_pad_batch = self.win_length - (preemphasized.shape[1] - (res_size - 1) * self.hop_length)
        padded = torch.nn.functional.pad(preemphasized, (0, last_pad_batch))

        windowed = padded.unfold(1, size=self.win_length, step=self.hop_length)  # [B, T', W=self.win_length]

        smoothed = windowed * self.window[None, None, :]  # [B, T', W]

        # The last window of a sequence might be shorter than win_length. In that case a
        # correspondingly shorter Hanning window (zero-padded to W) is applied, so the last
        # frame of every sequence is recomputed here.
        for i, (last_w_size, last_p, res_l) in enumerate(zip(last_win_size, last_pad, res_length)):
            last_win = torch.hann_window(last_w_size, periodic=False, dtype=torch.float64).to(
                self.window.device, torch.float32
            )
            last_win = torch.nn.functional.pad(last_win, (0, last_p))  # [W]
            smoothed[i, res_l - 1] = windowed[i, res_l - 1] * last_win[None, :]

        # compute amplitude spectrum with a 1-D real FFT over the window axis (zero-padded to n_fft)
        # with Rasr specific scaling
        fft = torch.fft.rfft(smoothed, n=self.n_fft) / self.sample_rate  # [B, T', F=n_fft//2+1]
        amplitude_spectrum = torch.abs(fft)  # [B, T', F=n_fft//2+1]

        melspec = torch.einsum("...tf,mf->...tm", amplitude_spectrum, self.mel_basis)  # [B, T', F'=num_filters]
        log_melspec = torch.log10(melspec + self.min_amp)

        return log_melspec, res_length
Loading

0 comments on commit a15dad4

Please sign in to comment.