Skip to content

Commit

Permalink
Reorganization and fix recompilation every step in the training loop.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Dec 6, 2024
1 parent ffa8519 commit c51958a
Show file tree
Hide file tree
Showing 11 changed files with 976 additions and 833 deletions.
230 changes: 230 additions & 0 deletions f5_tts_mlx/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
from __future__ import annotations
from functools import lru_cache
import math
from typing import Optional

import mlx.core as mx
import mlx.nn as nn

import numpy as np


@lru_cache(maxsize=None)
def mel_filters(
    sample_rate: int,
    n_fft: int,
    n_mels: int,
    f_min: float = 0,
    f_max: Optional[float] = None,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
) -> mx.array:
    """
    Compute torch-compatible mel filterbanks.

    Results are memoized per argument combination through ``lru_cache``.
    Args:
        sample_rate: Sampling rate of the audio.
        n_fft: Number of FFT points.
        n_mels: Number of mel bands.
        f_min: Minimum frequency.
        f_max: Maximum frequency. Falls back to ``sample_rate / 2`` when it is
            None (or otherwise falsy, e.g. 0).
        norm: Normalization mode. ``"slaney"`` scales each filter by
            ``2 / (upper_edge - lower_edge)``; any other value applies no scaling.
        mel_scale: Mel scale type. ``"htk"`` selects the HTK formula; any other
            value selects the Slaney formula.
    Returns:
        mx.array of shape (n_mels, n_fft // 2 + 1) containing mel filterbanks.
    """

    def hz_to_mel(freq, mel_scale="htk"):
        # Scalar conversion (uses ``math``); only applied to f_min / f_max.
        if mel_scale == "htk":
            return 2595.0 * math.log10(1.0 + freq / 700.0)

        # slaney scale: linear below min_log_hz, logarithmic at and above it
        f_min, f_sp = 0.0, 200.0 / 3
        mels = (freq - f_min) / f_sp
        min_log_hz = 1000.0
        min_log_mel = (min_log_hz - f_min) / f_sp
        logstep = math.log(6.4) / 27.0
        if freq >= min_log_hz:
            mels = min_log_mel + math.log(freq / min_log_hz) / logstep
        return mels

    def mel_to_hz(mels, mel_scale="htk"):
        # Array conversion; applied to the mx.array of mel points.
        if mel_scale == "htk":
            return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)

        # slaney scale
        f_min, f_sp = 0.0, 200.0 / 3
        freqs = f_min + f_sp * mels
        min_log_hz = 1000.0
        min_log_mel = (min_log_hz - f_min) / f_sp
        logstep = math.log(6.4) / 27.0
        # Pick the logarithmic branch element-wise with mx.where instead of
        # boolean-mask indexing, which MLX arrays do not reliably support.
        log_freqs = min_log_hz * mx.exp(logstep * (mels - min_log_mel))
        return mx.where(mels >= min_log_mel, log_freqs, freqs)

    f_max = f_max or sample_rate / 2

    # generate frequency points

    n_freqs = n_fft // 2 + 1
    all_freqs = mx.linspace(0, sample_rate // 2, n_freqs)

    # convert frequencies to mel and back to hz

    m_min = hz_to_mel(f_min, mel_scale)
    m_max = hz_to_mel(f_max, mel_scale)
    m_pts = mx.linspace(m_min, m_max, n_mels + 2)
    f_pts = mel_to_hz(m_pts, mel_scale)

    # compute slopes for filterbank

    f_diff = f_pts[1:] - f_pts[:-1]
    # slopes[i, j] = f_pts[j] - all_freqs[i], shape (n_freqs, n_mels + 2)
    slopes = mx.expand_dims(f_pts, 0) - mx.expand_dims(all_freqs, 1)

    # calculate overlapping triangular filters

    down_slopes = (-slopes[:, :-2]) / f_diff[:-1]
    up_slopes = slopes[:, 2:] / f_diff[1:]
    filterbank = mx.maximum(
        mx.zeros_like(down_slopes), mx.minimum(down_slopes, up_slopes)
    )

    if norm == "slaney":
        enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
        filterbank = filterbank * mx.expand_dims(enorm, 0)

    # (n_freqs, n_mels) -> (n_mels, n_freqs)
    filterbank = filterbank.moveaxis(0, 1)
    return filterbank


@lru_cache(maxsize=None)
def hanning(size):
    """
    Build a Hanning window of the requested length.

    The values come from ``np.hanning(size + 1)`` with the final sample
    dropped. Results are memoized per ``size`` through ``lru_cache``.
    Args:
        size: Number of samples in the returned window.
    Returns:
        mx.array of shape (size,) containing the Hanning window.
    """
    extended_window = np.hanning(size + 1)
    trimmed_window = extended_window[:-1]
    return mx.array(trimmed_window)


def stft(
    x,
    window,
    nperseg=256,
    noverlap=None,
    nfft=None,
    pad_mode="constant",
):
    """
    Compute the short-time Fourier transform of a signal.

    The signal is padded by ``nperseg // 2`` samples on each side, cut into
    frames of ``nperseg`` samples, multiplied by ``window`` and transformed
    with a real FFT of length ``nfft``.
    Args:
        x: mx.array of shape (t,) containing the input signal.
        window: mx.array of shape (nperseg,) containing the window function.
        nperseg: Number of samples per segment.
        noverlap: Despite its name, this value is used as the stride (hop) in
            samples between the starts of consecutive frames. Defaults to
            ``nfft // 4``.
        nfft: Number of FFT points. Defaults to ``nperseg``; must not be
            smaller than ``nperseg``.
        pad_mode: Padding mode, either "constant" (zeros) or "reflect".
    Returns:
        mx.array of shape (frames, nfft // 2 + 1) containing the short-time
        Fourier transform, where
        ``frames = (padded_length - nperseg + noverlap) // noverlap``.
    Raises:
        ValueError: If ``pad_mode`` is not recognized or ``nfft < nperseg``.
    """
    if nfft is None:
        nfft = nperseg
    if noverlap is None:
        noverlap = nfft // 4
    if nfft < nperseg:
        raise ValueError(f"nfft ({nfft}) must be at least nperseg ({nperseg})")

    def _pad(x, padding, pad_mode="constant"):
        if pad_mode == "constant":
            return mx.pad(x, [(padding, padding)])
        elif pad_mode == "reflect":
            # mirror the signal around its first / last sample (edge excluded)
            prefix = x[1 : padding + 1][::-1]
            suffix = x[-(padding + 1) : -1][::-1]
            return mx.concatenate([prefix, x, suffix])
        else:
            raise ValueError(f"Invalid pad_mode {pad_mode}")

    padding = nperseg // 2
    x = _pad(x, padding, pad_mode)

    strides = [noverlap, 1]
    t = (x.size - nperseg + noverlap) // noverlap
    # Frames are nperseg samples long so they match the window and never
    # extend past the padded signal; rfft zero-pads each frame up to nfft.
    shape = [t, nperseg]
    x = mx.as_strided(x, shape=shape, strides=strides)
    return mx.fft.rfft(x * window, n=nfft)


def log_mel_spectrogram(
    audio: mx.array,
    sample_rate: int = 24_000,
    n_mels: int = 100,
    n_fft: int = 1024,
    hop_length: int = 256,
    padding: int = 0,
):
    """
    Compute log-mel spectrograms for a batch of audio inputs.

    Each signal is transformed with ``stft`` using a Hanning window of
    ``n_fft`` samples, the last STFT frame is dropped, magnitudes are projected
    onto HTK mel filters, and the result is clamped to at least 1e-5 before
    taking the natural log.
    Args:
        audio: mx.array of shape [t] or [b, t] containing audio samples.
        sample_rate: Sampling rate of the audio.
        n_mels: Number of mel bands.
        n_fft: Number of FFT points.
        hop_length: Hop length between frames.
        padding: Number of zero samples appended to the end of each signal.
    Returns:
        mx.array of shape (batch_size, frames, n_mels) containing log-mel
        spectrograms. A 1-D input yields a batch of size 1.
    """

    if audio.ndim == 1:
        audio = mx.expand_dims(audio, axis=0)

    filters = mel_filters(
        sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mels, norm=None, mel_scale="htk"
    )
    # loop-invariant: the same window is used for every signal in the batch
    window = hanning(n_fft)

    batch = audio.shape[0]
    outputs = []

    for i in range(batch):
        one_audio = audio[i]

        if padding > 0:
            one_audio = mx.pad(one_audio, (0, padding))

        freqs = stft(one_audio, window, nperseg=n_fft, noverlap=hop_length)
        # drop the final frame; magnitudes has shape (frames, n_fft // 2 + 1)
        magnitudes = mx.abs(freqs[:-1, :])

        # (frames, n_freqs) @ (n_freqs, n_mels) -> (frames, n_mels)
        mel_spec = mx.matmul(magnitudes, filters.T)
        log_spec = mx.maximum(mel_spec, 1e-5).log()
        outputs.append(log_spec)

    # Every row of a [b, t] array has the same length, so all per-signal
    # spectrograms already share one shape and can be stacked directly.
    return mx.stack(outputs, axis=0)


class MelSpec(nn.Module):
    """
    Module wrapper around ``log_mel_spectrogram`` with fixed settings.

    Args:
        sample_rate: Sampling rate of the audio.
        n_fft: Number of FFT points.
        hop_length: Hop length between frames.
        n_mels: Number of mel bands.
    """

    def __init__(
        self,
        sample_rate=24_000,
        n_fft=1024,
        hop_length=256,
        n_mels=100,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels

    def __call__(self, audio: mx.array, **kwargs) -> mx.array:
        """
        Compute the log-mel spectrogram of ``audio`` with this module's settings.

        Extra keyword arguments are accepted and ignored.
        """
        # Forward sample_rate as well, so a non-default rate is not silently
        # replaced by log_mel_spectrogram's own default.
        return log_mel_spectrogram(
            audio,
            sample_rate=self.sample_rate,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
        )
29 changes: 11 additions & 18 deletions f5_tts_mlx/cfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from __future__ import annotations
from pathlib import Path
from random import random
from typing import Callable, Literal

import mlx.core as mx
Expand All @@ -19,9 +18,9 @@

from vocos_mlx import Vocos

from f5_tts_mlx.audio import MelSpec
from f5_tts_mlx.duration import DurationPredictor, DurationTransformer
from f5_tts_mlx.dit import DiT
from f5_tts_mlx.modules import MelSpec
from f5_tts_mlx.utils import (
exists,
fetch_from_hub,
Expand All @@ -33,7 +32,6 @@
pad_sequence,
)


# ode solvers


Expand Down Expand Up @@ -131,7 +129,6 @@ class F5TTS(nn.Module):
def __init__(
self,
transformer: nn.Module,
sigma=0.0,
audio_drop_prob=0.3,
cond_drop_prob=0.2,
num_channels=None,
Expand Down Expand Up @@ -160,9 +157,6 @@ def __init__(
dim = transformer.dim
self.dim = dim

# conditional flow related
self.sigma = sigma

# vocab map for tokenization
self._vocab_char_map = vocab_char_map

Expand All @@ -178,15 +172,15 @@ def __call__(
text: mx.array["b nt"] | list[str],
*,
lens: mx.array["b"] | None = None,
) -> tuple[mx.array, mx.array, mx.array]:
) -> mx.array:
# handle raw wave
if inp.ndim == 2:
inp = self._mel_spec(inp)
inp = rearrange(inp, "b d n -> b n d")
assert inp.shape[-1] == self.num_channels

batch, seq_len, dtype, σ1 = *inp.shape[:2], inp.dtype, self.sigma

batch, seq_len, dtype = *inp.shape[:2], inp.dtype
# handle text as string
if isinstance(text, list):
if exists(self._vocab_char_map):
Expand Down Expand Up @@ -230,15 +224,13 @@ def __call__(
)

# transformer and cfg training with a drop rate
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
if random() < self.cond_drop_prob:
drop_audio_cond = True
drop_text = True
else:
drop_text = False

# to rigorously mask out padding, record the mask in collate_fn in dataset.py and pass it in here
# adding a mask will use more memory, so the batch sampler threshold also needs to be scaled down for long sequences
rand_audio_drop = mx.random.uniform(0, 1, (1,))
rand_cond_drop = mx.random.uniform(0, 1, (1,))
drop_audio_cond = rand_audio_drop < self.audio_drop_prob
drop_text = rand_cond_drop < self.cond_drop_prob
drop_audio_cond = drop_audio_cond | drop_text

pred = self.transformer(
x=φ,
cond=cond,
Expand All @@ -249,6 +241,7 @@ def __call__(
)

# flow matching loss

loss = nn.losses.mse_loss(pred, flow, reduction="none")

rand_span_mask = repeat(rand_span_mask, "b n -> b n d", d=self.num_channels)
Expand Down
54 changes: 54 additions & 0 deletions f5_tts_mlx/convnext_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

import mlx.core as mx
import mlx.nn as nn

# global response normalization


class GRN(nn.Module):
    """
    Global response normalization with a residual connection.

    ``gamma`` and ``beta`` both start as zeros of shape (1, 1, dim).
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, dim)
        self.gamma = mx.zeros(param_shape)
        self.beta = mx.zeros(param_shape)

    def __call__(self, x):
        # L2 norm along axis 1, kept as a singleton axis for broadcasting
        global_norm = mx.linalg.norm(x, ord=2, axis=1, keepdims=True)
        # divide each norm by the mean norm over the last axis (eps avoids /0)
        mean_norm = global_norm.mean(axis=-1, keepdims=True)
        relative_norm = global_norm / (mean_norm + 1e-6)
        rescaled = self.gamma * (x * relative_norm)
        return rescaled + self.beta + x


# ConvNeXt-v2 block


class ConvNeXtV2Block(nn.Module):
    """
    ConvNeXt-v2 block: a depthwise 1-D convolution, layer norm, and a
    pointwise expand / GELU / GRN / project stack, added back to the input.

    Args:
        dim: Channel count of the block's input and output.
        intermediate_dim: Width of the hidden pointwise layer.
        dilation: Dilation of the depthwise convolution.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        kernel_size = 7
        # half of the dilated kernel span
        padding = (dilation * (kernel_size - 1)) // 2

        # depthwise conv (one group per channel)
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim,
            dilation=dilation,
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)

        # pointwise convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def __call__(self, x: mx.array) -> mx.array:
        """Run the layer stack on ``x`` and add the result to ``x``."""
        hidden = x
        stages = (
            self.dwconv,
            self.norm,
            self.pwconv1,
            self.act,
            self.grn,
            self.pwconv2,
        )
        for stage in stages:
            hidden = stage(hidden)
        return x + hidden
Loading

0 comments on commit c51958a

Please sign in to comment.