From c51958ab6067bf5e8e060854278b16851fdc2d8c Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Fri, 6 Dec 2024 10:45:29 -0800 Subject: [PATCH] Reorganization and fix recompilation every step in the training loop. --- f5_tts_mlx/audio.py | 230 ++++++++++++++ f5_tts_mlx/cfm.py | 29 +- f5_tts_mlx/convnext_v2.py | 54 ++++ f5_tts_mlx/data.py | 81 ++--- f5_tts_mlx/dit.py | 236 +++++++++++--- f5_tts_mlx/duration.py | 5 +- f5_tts_mlx/duration_trainer.py | 171 +++++++++++ f5_tts_mlx/modules.py | 542 --------------------------------- f5_tts_mlx/rope.py | 107 +++++++ f5_tts_mlx/trainer.py | 307 +++++++++---------- train_libritts_small.py | 47 ++- 11 files changed, 976 insertions(+), 833 deletions(-) create mode 100644 f5_tts_mlx/audio.py create mode 100644 f5_tts_mlx/convnext_v2.py create mode 100644 f5_tts_mlx/duration_trainer.py delete mode 100644 f5_tts_mlx/modules.py create mode 100644 f5_tts_mlx/rope.py diff --git a/f5_tts_mlx/audio.py b/f5_tts_mlx/audio.py new file mode 100644 index 0000000..5a21e50 --- /dev/null +++ b/f5_tts_mlx/audio.py @@ -0,0 +1,230 @@ +from __future__ import annotations +from functools import lru_cache +import math +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + +import numpy as np + + +@lru_cache(maxsize=None) +def mel_filters( + sample_rate: int, + n_fft: int, + n_mels: int, + f_min: float = 0, + f_max: Optional[float] = None, + norm: Optional[str] = None, + mel_scale: str = "htk", +) -> mx.array: + """ + Compute torch-compatible mel filterbanks. + + Args: + sample_rate: Sampling rate of the audio. + n_fft: Number of FFT points. + n_mels: Number of mel bands. + f_min: Minimum frequency. + f_max: Maximum frequency. + norm: Normalization mode. + mel_scale: Mel scale type. + + Returns: + mx.array of shape (n_mels, n_fft // 2 + 1) containing mel filterbanks. + """ + + def hz_to_mel(freq, mel_scale="htk"): + if mel_scale == "htk": + return 2595.0 * math.log10(1.0 + freq / 700.0) + + # slaney scale + f_min, f_sp = 0.0, 200.0 / 3 + mels = (freq - f_min) / f_sp + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - f_min) / f_sp + logstep = math.log(6.4) / 27.0 + if freq >= min_log_hz: + mels = min_log_mel + math.log(freq / min_log_hz) / logstep + return mels + + def mel_to_hz(mels, mel_scale="htk"): + if mel_scale == "htk": + return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) + + # slaney scale + f_min, f_sp = 0.0, 200.0 / 3 + freqs = f_min + f_sp * mels + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - f_min) / f_sp + logstep = math.log(6.4) / 27.0 + log_t = mels >= min_log_mel + freqs[log_t] = min_log_hz * mx.exp(logstep * (mels[log_t] - min_log_mel)) + return freqs + + f_max = f_max or sample_rate / 2 + + # generate frequency points + + n_freqs = n_fft // 2 + 1 + all_freqs = mx.linspace(0, sample_rate // 2, n_freqs) + + # convert frequencies to mel and back to hz + + m_min = hz_to_mel(f_min, mel_scale) + m_max = hz_to_mel(f_max, mel_scale) + m_pts = mx.linspace(m_min, m_max, n_mels + 2) + f_pts = mel_to_hz(m_pts, mel_scale) + + # compute slopes for filterbank + + f_diff = f_pts[1:] - f_pts[:-1] + slopes = mx.expand_dims(f_pts, 0) - mx.expand_dims(all_freqs, 1) + + # calculate overlapping triangular filters + + down_slopes = (-slopes[:, :-2]) / f_diff[:-1] + up_slopes = slopes[:, 2:] / f_diff[1:] + filterbank = mx.maximum( + mx.zeros_like(down_slopes), mx.minimum(down_slopes, up_slopes) + ) + + if norm == "slaney": + enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) + filterbank *= mx.expand_dims(enorm, 0) + + filterbank = filterbank.moveaxis(0, 1) + return filterbank + + +@lru_cache(maxsize=None) +def hanning(size): + """ + Compute the Hanning window. + + Args: + size: Size of the window. + + Returns: + mx.array of shape (size,) containing the Hanning window. + """ + return mx.array(np.hanning(size + 1)[:-1]) + + +def stft( + x, + window, + nperseg=256, + noverlap=None, + nfft=None, + pad_mode="constant", +): + """ + Compute the short-time Fourier transform of a signal. + + Args: + x: mx.array of shape (t,) containing the input signal. + window: mx.array of shape (nperseg,) containing the window function. + nperseg: Number of samples per segment. + noverlap: Number of overlapping samples. + nfft: Number of FFT points. + pad_mode: Padding mode. + + Returns: + mx.array of shape (t, nfft // 2 + 1) containing the short-time Fourier transform. + """ + if nfft is None: + nfft = nperseg + if noverlap is None: + noverlap = nfft // 4 + + def _pad(x, padding, pad_mode="constant"): + if pad_mode == "constant": + return mx.pad(x, [(padding, padding)]) + elif pad_mode == "reflect": + prefix = x[1 : padding + 1][::-1] + suffix = x[-(padding + 1) : -1][::-1] + return mx.concatenate([prefix, x, suffix]) + else: + raise ValueError(f"Invalid pad_mode {pad_mode}") + + padding = nperseg // 2 + x = _pad(x, padding, pad_mode) + + strides = [noverlap, 1] + t = (x.size - nperseg + noverlap) // noverlap + shape = [t, nfft] + x = mx.as_strided(x, shape=shape, strides=strides) + return mx.fft.rfft(x * window) + + +def log_mel_spectrogram( + audio: mx.array, + sample_rate: int = 24_000, + n_mels: int = 100, + n_fft: int = 1024, + hop_length: int = 256, + padding: int = 0, +): + """ + Compute log-mel spectrograms for a batch of audio inputs. + + Args: + audio: mx.array of shape [t] or [b, t] containing audio samples. + sample_rate: Sampling rate of the audio. + n_mels: Number of mel bands. + n_fft: Number of FFT points. + hop_length: Hop length between frames. + padding: Amount of padding to add to each audio signal. + + Returns: + mx.array of shape (batch_size, n_mels, frames) containing log-mel spectrograms. + """ + + if audio.ndim == 1: + audio = mx.expand_dims(audio, axis=0) + + filters = mel_filters( + sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mels, norm=None, mel_scale="htk" + ) + + batch = audio.shape[0] + outputs = [] + + for i in range(batch): + one_audio = audio[i] + + if padding > 0: + one_audio = mx.pad(one_audio, (0, padding)) + + freqs = stft(one_audio, hanning(n_fft), nperseg=n_fft, noverlap=hop_length) + magnitudes = mx.abs(freqs[:-1, :]) + + mel_spec = mx.matmul(magnitudes, filters.T) + log_spec = mx.maximum(mel_spec, 1e-5).log() + outputs.append(log_spec) + + max_seq_len = max([x.shape[1] for x in outputs]) + outputs = [mx.pad(x, (0, max_seq_len - x.shape[1])) for x in outputs] + return mx.stack(outputs, axis=0) + + +class MelSpec(nn.Module): + def __init__( + self, + sample_rate=24_000, + n_fft=1024, + hop_length=256, + n_mels=100, + ): + super().__init__() + self.sample_rate = sample_rate + self.n_fft = n_fft + self.hop_length = hop_length + self.n_mels = n_mels + + def __call__(self, audio: mx.array, **kwargs) -> mx.array: + return log_mel_spectrogram( + audio, n_mels=self.n_mels, n_fft=self.n_fft, hop_length=self.hop_length + ) diff --git a/f5_tts_mlx/cfm.py b/f5_tts_mlx/cfm.py index c929811..9025429 100644 --- a/f5_tts_mlx/cfm.py +++ b/f5_tts_mlx/cfm.py @@ -9,7 +9,6 @@ from __future__ import annotations from pathlib import Path -from random import random from typing import Callable, Literal import mlx.core as mx @@ -19,9 +18,9 @@ from vocos_mlx import Vocos +from f5_tts_mlx.audio import MelSpec from f5_tts_mlx.duration import DurationPredictor, DurationTransformer from f5_tts_mlx.dit import DiT -from f5_tts_mlx.modules import MelSpec from f5_tts_mlx.utils import ( exists, fetch_from_hub, @@ -33,7 +32,6 @@ pad_sequence, ) - # ode solvers @@ -131,7 +129,6 @@ class F5TTS(nn.Module): def __init__( self, transformer: nn.Module, - sigma=0.0, audio_drop_prob=0.3, cond_drop_prob=0.2, num_channels=None, @@ -160,9 +157,6 @@ def __init__( dim = transformer.dim self.dim = dim - # conditional flow related - self.sigma = sigma - # vocab map for tokenization self._vocab_char_map = vocab_char_map @@ -178,15 +172,15 @@ def __call__( text: mx.array["b nt"] | list[str], *, lens: mx.array["b"] | None = None, - ) -> tuple[mx.array, mx.array, mx.array]: + ) -> mx.array: # handle raw wave if inp.ndim == 2: inp = self._mel_spec(inp) inp = rearrange(inp, "b d n -> b n d") assert inp.shape[-1] == self.num_channels - batch, seq_len, dtype, σ1 = *inp.shape[:2], inp.dtype, self.sigma - + batch, seq_len, dtype = *inp.shape[:2], inp.dtype + # handle text as string if isinstance(text, list): if exists(self._vocab_char_map): @@ -230,15 +224,13 @@ def __call__( ) # transformer and cfg training with a drop rate - drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper - if random() < self.cond_drop_prob: - drop_audio_cond = True - drop_text = True - else: - drop_text = False - # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here - # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences + rand_audio_drop = mx.random.uniform(0, 1, (1,)) + rand_cond_drop = mx.random.uniform(0, 1, (1,)) + drop_audio_cond = rand_audio_drop < self.audio_drop_prob + drop_text = rand_cond_drop < self.cond_drop_prob + drop_audio_cond = drop_audio_cond | drop_text + pred = self.transformer( x=φ, cond=cond, @@ -249,6 +241,7 @@ def __call__( ) # flow matching loss + loss = nn.losses.mse_loss(pred, flow, reduction="none") rand_span_mask = repeat(rand_span_mask, "b n -> b n d", d=self.num_channels) diff --git a/f5_tts_mlx/convnext_v2.py b/f5_tts_mlx/convnext_v2.py new file mode 100644 index 0000000..3b55a92 --- /dev/null +++ b/f5_tts_mlx/convnext_v2.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import mlx.core as mx +import mlx.nn as nn + +# global response normalization + + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = mx.zeros((1, 1, dim)) + self.beta = mx.zeros((1, 1, dim)) + + def __call__(self, x): + Gx = mx.linalg.norm(x, ord=2, axis=1, keepdims=True) + Nx = Gx / (Gx.mean(axis=-1, keepdims=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +# ConvNeXt-v2 block + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + + # depthwise conv + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) + self.norm = nn.LayerNorm(dim, eps=1e-6) + + # pointwise convs, implemented with linear layers + self.pwconv1 = nn.Linear(dim, intermediate_dim) + self.act = nn.GELU() + self.grn = GRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def __call__(self, x: mx.array) -> mx.array: + residual = x + x = self.dwconv(x) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x diff --git a/f5_tts_mlx/data.py b/f5_tts_mlx/data.py index e6b727a..b5e5dcc 100644 --- a/f5_tts_mlx/data.py +++ b/f5_tts_mlx/data.py @@ -6,8 +6,7 @@ import mlx.core as mx import mlx.data as dx import numpy as np - -from einops.array_api import rearrange +import os from mlx.data.datasets.common import ( CACHE_DIR, @@ -17,11 +16,10 @@ gzip_decompress, ) -from f5_tts_mlx.modules import log_mel_spectrogram +from f5_tts_mlx.audio import log_mel_spectrogram +from f5_tts_mlx.utils import list_str_to_idx SAMPLE_RATE = 24_000 -HOP_LENGTH = 256 -FRAMES_PER_SECOND = SAMPLE_RATE / HOP_LENGTH # utilities @@ -35,17 +33,23 @@ def files_with_extensions(dir: Path, extensions: list = ["wav"]): return [{"file": mx.array(f.as_posix().encode("utf-8"))} for f in files] -# transforms +def calculate_wav_duration(file_path): + # assumptions + bit_depth = 16 + num_channels = 1 + bytes_per_sample = bit_depth // 8 + bytes_per_second = SAMPLE_RATE * num_channels * bytes_per_sample -def _load_transcript_file(sample): - audio_file = Path(bytes(sample["file"]).decode("utf-8")) - if not audio_file.suffix == ".wav": - return dict() + file_size = os.path.getsize(file_path) + duration_seconds = file_size / bytes_per_second - transcript_file = audio_file.with_suffix(".normalized.txt") - sample["transcript_file"] = transcript_file.as_posix().encode("utf-8") - return sample + return duration_seconds + + +# transforms + +vocab = {chr(i): i for i in range(256)} def _load_transcript(sample): @@ -57,43 +61,25 @@ def _load_transcript(sample): if not transcript_file.exists(): return dict() - transcript = np.array( - list(transcript_file.read_text().strip().encode("utf-8")), dtype=np.int8 - ) - sample["transcript"] = transcript - + text = transcript_file.read_text().strip() + sample["transcript"] = mx.array(list_str_to_idx(text, vocab)) return sample -def _load_cached_mel_spec(sample, max_duration=5): +def _load_audio_file(sample, max_duration=10): audio_file = Path(bytes(sample["file"]).decode("utf-8")) - mel_file = audio_file.with_suffix(".mel.npy.npz") - mel_spec = mx.load(mel_file.as_posix())["arr_0"] - mel_len = mel_spec.shape[1] - if mel_len > int(max_duration * FRAMES_PER_SECOND): + duration = calculate_wav_duration(audio_file) + if duration > max_duration: return dict() - sample["mel_spec"] = mel_spec - sample["mel_len"] = mel_len - del sample["file"] - return sample - - -def _load_audio_file(sample): - audio_file = Path(bytes(sample["file"]).decode("utf-8")) audio = np.array(list(audio_file.read_bytes()), dtype=np.uint8) sample["audio"] = audio return sample -def _to_mel_spec(sample, max_duration=10): - audio = rearrange(mx.array(sample["audio"]), "t 1 -> t") - mel_len = audio.shape[0] // HOP_LENGTH - - if mel_len > int(max_duration * FRAMES_PER_SECOND): - return dict() - +def _to_mel_spec(sample): + audio = mx.squeeze(mx.array(sample["audio"]), axis=-1) mel_spec = log_mel_spectrogram(audio) sample["mel_spec"] = mel_spec sample["mel_len"] = mel_spec.shape[1] @@ -205,20 +191,7 @@ def load_libritts_r( tar.extractall(path=target.parent) tar.close() - files = files_with_extensions(path) - print(f"Found {len(files)} files at {path}") - - dset = ( - dx.buffer_from_vector(files) - .to_stream() - .sample_transform(lambda s: s if bytes(s["file"]).endswith(b".wav") else dict()) - .sample_transform(_load_transcript) - .sample_transform(_load_audio_file) - .load_audio("audio", from_memory=True) - .sample_transform(partial(_to_mel_spec, max_duration=max_duration)) - ) - - return dset + return load_dir(path, max_duration=max_duration), path def load_dir(dir=None, max_duration=30): @@ -232,7 +205,9 @@ def load_dir(dir=None, max_duration=30): .to_stream() .sample_transform(lambda s: s if bytes(s["file"]).endswith(b".wav") else dict()) .sample_transform(_load_transcript) - .sample_transform(partial(_load_cached_mel_spec, max_duration=max_duration)) + .sample_transform(partial(_load_audio_file, max_duration=max_duration)) + .load_audio("audio", from_memory=True) + .sample_transform(_to_mel_spec) ) return dset diff --git a/f5_tts_mlx/dit.py b/f5_tts_mlx/dit.py index fa5d72b..9dc628d 100644 --- a/f5_tts_mlx/dit.py +++ b/f5_tts_mlx/dit.py @@ -8,23 +8,185 @@ """ from __future__ import annotations +import math import mlx.core as mx import mlx.nn as nn -from einops.array_api import repeat +from einops.array_api import rearrange, repeat -from f5_tts_mlx.modules import ( - Attention, - FeedForward, +from f5_tts_mlx.convnext_v2 import ConvNeXtV2Block +from f5_tts_mlx.rope import ( RotaryEmbedding, - TimestepEmbedding, - ConvNeXtV2Block, - ConvPositionEmbedding, - precompute_freqs_cis, + apply_rotary_pos_emb, get_pos_embed_indices, + precompute_freqs_cis, ) +# convolutional position embedding + + +class ConvPositionEmbedding(nn.Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + ) + + def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array: + if mask is not None: + mask = mask[..., None] + x = x * mask + + out = self.conv1d(x) + + if mask is not None: + out = out * mask + + return out + + +# sinusoidal position embedding + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def __call__(self, x, scale=1000): + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = mx.exp(mx.arange(half_dim) * -emb) + emb = scale * mx.expand_dims(x, axis=1) * mx.expand_dims(emb, axis=0) + emb = mx.concatenate([emb.sin(), emb.cos()], axis=-1) + return emb + + +# time step conditioning embedding + + +class TimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential( + nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim) + ) + + def __call__(self, timestep: mx.array) -> mx.array: + time_hidden = self.time_embed(timestep) + time = self.time_mlp(time_hidden) + return time + + +# feed forward + + +class FeedForward(nn.Module): + def __init__( + self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none" + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + activation = nn.GELU(approx=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def __call__(self, x: mx.array) -> mx.array: + return self.ff(x) + + +# attention + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + ): + super().__init__() + + self.dim = dim + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + + self.to_q = nn.Linear(dim, self.inner_dim) + self.to_k = nn.Linear(dim, self.inner_dim) + self.to_v = nn.Linear(dim, self.inner_dim) + + self.to_out = nn.Sequential(nn.Linear(self.inner_dim, dim), nn.Dropout(dropout)) + + def __call__( + self, + x: mx.array, + mask: mx.array | None = None, + rope: mx.array | None = None, + ) -> mx.array: + batch, seq_len, _ = x.shape + + # `sample` projections. + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + ( + xpos_scale, + xpos_scale**-1.0, + ) + if xpos_scale is not None + else (1.0, 1.0) + ) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + query = rearrange(query, "b n (h d) -> b h n d", h=self.heads) + key = rearrange(key, "b n (h d) -> b h n d", h=self.heads) + value = rearrange(value, "b n (h d) -> b h n d", h=self.heads) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = mask + attn_mask = rearrange(attn_mask, "b n -> b () () n") + attn_mask = repeat(attn_mask, "b () () n -> b h () n", h=self.heads) + else: + attn_mask = None + + scale_factor = 1 / mx.sqrt(query.shape[-1]) + + x = mx.fast.scaled_dot_product_attention( + q=query, k=key, v=value, scale=scale_factor, mask=attn_mask + ) + x = x.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1).astype(query.dtype) + + # linear proj + x = self.to_out(x) + + if attn_mask is not None: + mask = rearrange(mask, "b n -> b n 1") + x = mx.where(mask, x, 0.0) + + return x + + # Text embedding @@ -48,22 +210,21 @@ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): else: self.extra_modeling = False - def __call__(self, text: int["b nt"], seq_len, drop_text=False): + def __call__(self, text, seq_len, drop_text=False): batch, text_len = text.shape[0], text.shape[1] - text = ( - text + 1 - ) # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() - text = text[ - :, :seq_len - ] # curtail if character tokens are more than the mel spec tokens - text = mx.pad(text, [(0, 0), (0, seq_len - text_len)], constant_values=0) - if drop_text: # cfg for text - text = mx.zeros_like(text) + # use 0 as filler token. we rely on text being padded with -1 values. + text = text + 1 + # curtail if character tokens are more than the mel spec tokens + text = text[:, :seq_len] + + text = mx.pad(text, [(0, 0), (0, seq_len - text_len)], constant_values=0) + + # cfg for text + text = mx.where(drop_text, mx.zeros_like(text), text) text = self.text_embed(text) # b n -> b n d - # possible extra modeling if self.extra_modeling: # sinus pos emb batch_start = mx.zeros((batch,), dtype=mx.int32) @@ -73,7 +234,7 @@ def __call__(self, text: int["b nt"], seq_len, drop_text=False): text_pos_embed = self._freqs_cis[pos_idx] text = text + text_pos_embed - # convnextv2 blocks + # convnext v2 blocks text = self.text_blocks(text) return text @@ -90,14 +251,13 @@ def __init__(self, mel_dim, text_dim, out_dim): def __call__( self, - x: float["b n d"], - cond: float["b n d"], - text_embed: float["b n d"], + x: mx.array, # b n d + cond: mx.array, # b n d + text_embed: mx.array, # b n d drop_audio_cond=False, ): - if drop_audio_cond: # cfg for cond audio - cond = mx.zeros_like(cond) - + # cfg for cond audio + cond = mx.where(drop_audio_cond, mx.zeros_like(cond), cond) x = self.proj(mx.concatenate((x, cond, text_embed), axis=-1)) x = self.conv_pos_embed(x) + x return x @@ -205,18 +365,17 @@ def __init__( text_num_embeds=256, text_dim=None, conv_layers=0, - long_skip_connection=False, ): super().__init__() - self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim + + self.time_embed = TimestepEmbedding(dim) self.text_embed = TextEmbedding( text_num_embeds, text_dim, conv_layers=conv_layers ) self.input_embed = InputEmbedding(mel_dim, text_dim, dim) - self.rotary_embed = RotaryEmbedding(dim_head) self.dim = dim @@ -232,22 +391,19 @@ def __init__( ) for _ in range(depth) ] - self.long_skip_connection = ( - nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None - ) self.norm_out = AdaLayerNormZero_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) def __call__( self, - x: float["b n d"], # nosied input audio - cond: float["b n d"], # masked cond audio - text: int["b nt"], # text - time: float["b"] | float[""], # time step + x: mx.array, # b n d, nosied input audio + cond: mx.array, # b n d, masked cond audio + text: mx.array, # b nt, text + time: mx.array, # b, time step drop_audio_cond, # cfg for cond audio drop_text, # cfg for text - mask: bool["b n"] | None = None, + mask: mx.array | None = None, # b n ): batch, seq_len = x.shape[0], x.shape[1] if time.ndim == 0: @@ -260,15 +416,9 @@ def __call__( rope = self.rotary_embed.forward_from_seq_len(seq_len) - if self.long_skip_connection is not None: - residual = x - for block in self.transformer_blocks: x = block(x, t, mask=mask, rope=rope) - if self.long_skip_connection is not None: - x = self.long_skip_connection(mx.concatenate((x, residual), axis=-1)) - x = self.norm_out(x, t) output = self.proj_out(x) diff --git a/f5_tts_mlx/duration.py b/f5_tts_mlx/duration.py index 91cfd33..c9158bd 100644 --- a/f5_tts_mlx/duration.py +++ b/f5_tts_mlx/duration.py @@ -15,8 +15,9 @@ from einops.array_api import rearrange, repeat import einx -from f5_tts_mlx.dit import TextEmbedding, ConvPositionEmbedding -from f5_tts_mlx.modules import Attention, FeedForward, MelSpec, RotaryEmbedding +from f5_tts_mlx.audio import MelSpec +from f5_tts_mlx.dit import TextEmbedding, ConvPositionEmbedding, Attention, FeedForward +from f5_tts_mlx.rope import RotaryEmbedding from f5_tts_mlx.utils import ( exists, default, diff --git a/f5_tts_mlx/duration_trainer.py b/f5_tts_mlx/duration_trainer.py new file mode 100644 index 0000000..9f672ed --- /dev/null +++ b/f5_tts_mlx/duration_trainer.py @@ -0,0 +1,171 @@ +from __future__ import annotations +import datetime +from functools import partial + +from einops.array_api import rearrange + +import mlx.core as mx +import mlx.nn as nn +from mlx.optimizers import ( + AdamW, + linear_schedule, + cosine_decay, + join_schedules, + clip_grad_norm, +) +from mlx.utils import tree_flatten + +from f5_tts_mlx.audio import MelSpec +from f5_tts_mlx.cfm import F5TTS +from f5_tts_mlx.duration import DurationPredictor + +import wandb + + +def exists(v): + return v is not None + + +def default(v, d): + return v if exists(v) else d + + +# trainer + + +class DurationTrainer: + def __init__( + self, + model: DurationPredictor, + num_warmup_steps=1000, + max_grad_norm=1.0, + sample_rate=24_000, + log_with_wandb=False, + ): + self.model = model + self.num_warmup_steps = num_warmup_steps + self.mel_spectrogram = MelSpec(sample_rate=sample_rate) + self.max_grad_norm = max_grad_norm + self.log_with_wandb = log_with_wandb + + def save_checkpoint(self, step, finetune=False): + mx.save_safetensors( + f"f5tts_duration_{step}", + dict(tree_flatten(self.model.trainable_parameters())), + ) + + def load_checkpoint(self, step): + params = mx.load(f"f5tts_duration_{step}.saftensors") + self.model.load_weights(params) + self.model.eval() + + def train( + self, + train_dataset, + learning_rate=1e-4, + weight_decay=1e-2, + total_steps=100_000, + batch_size=8, + log_every=10, + save_every=1000, + checkpoint: int | None = None, + ): + if self.log_with_wandb: + wandb.init( + project="f5tts_duration", + config=dict( + learning_rate=learning_rate, + total_steps=total_steps, + batch_size=batch_size, + ), + ) + + decay_steps = total_steps - self.num_warmup_steps + + warmup_scheduler = linear_schedule( + init=1e-8, + end=learning_rate, + steps=self.num_warmup_steps, + ) + decay_scheduler = cosine_decay(init=learning_rate, decay_steps=decay_steps) + scheduler = join_schedules( + schedules=[warmup_scheduler, decay_scheduler], + boundaries=[self.num_warmup_steps], + ) + self.optimizer = AdamW(learning_rate=scheduler, weight_decay=weight_decay) + + if checkpoint is not None: + self.load_checkpoint(checkpoint) + start_step = checkpoint + else: + start_step = 0 + + global_step = start_step + + def loss_fn(model: F5TTS, mel_spec, text, lens): + loss = model(mel_spec, text=text, lens=lens, return_loss=True) + return loss + + state = [self.model.state, self.optimizer.state, mx.random.state] + + @partial(mx.compile, inputs=state, outputs=state) + def train_step(mel_spec, text_inputs, mel_lens): + loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn) + loss, grads = loss_and_grad_fn( + self.model, mel_spec, text=text_inputs, lens=mel_lens + ) + + if self.max_grad_norm > 0: + grads, _ = clip_grad_norm(grads, max_norm=self.max_grad_norm) + + self.optimizer.update(self.model, grads) + + return loss + + training_start_date = datetime.datetime.now() + log_start_date = datetime.datetime.now() + + for batch in train_dataset: + effective_batch_size = batch["transcript"].shape[0] + text_inputs = [ + bytes(batch["transcript"][i]).decode("utf-8") + for i in range(effective_batch_size) + ] + + mel_spec = rearrange(mx.array(batch["mel_spec"]), "b 1 n c -> b n c") + mel_lens = mx.array(batch["mel_len"], dtype=mx.int32) + + loss = train_step(mel_spec, text_inputs, mel_lens) + mx.eval(state) + # mx.eval(self.model.parameters(), self.optimizer.state) + + if self.log_with_wandb: + wandb.log( + { + "loss": loss.item(), + "lr": self.optimizer.learning_rate.item(), + "batch_len": mel_lens.sum().item(), + }, + step=global_step, + ) + + if global_step > 0 and global_step % log_every == 0: + elapsed_time = datetime.datetime.now() - log_start_date + log_start_date = datetime.datetime.now() + + print( + f"step {global_step}: loss = {loss.item():.4f}, sec per step = {(log_every / elapsed_time.seconds):.2f}" + ) + + global_step += 1 + + if global_step % save_every == 0: + self.save_checkpoint(global_step) + + if global_step >= total_steps: + break + + if self.log_with_wandb: + wandb.finish() + + print(f"Training complete in {datetime.datetime.now() - training_start_date}") diff --git a/f5_tts_mlx/modules.py b/f5_tts_mlx/modules.py deleted file mode 100644 index 5030e00..0000000 --- a/f5_tts_mlx/modules.py +++ /dev/null @@ -1,542 +0,0 @@ -from __future__ import annotations -from functools import lru_cache -import math -from typing import Optional, Union - -import mlx.core as mx -import mlx.nn as nn - -import numpy as np - -from einops.array_api import rearrange, repeat - - -# rotary positional embedding related - - -class RotaryEmbedding(nn.Module): - def __init__( - self, - dim: int, - use_xpos: bool = False, - scale_base: int = 512, - interpolation_factor: float = 1.0, - base: float = 10000.0, - base_rescale_factor: float = 1.0, - ): - super().__init__() - base *= base_rescale_factor ** (dim / (dim - 2)) - self.inv_freq = 1.0 / (base ** (mx.arange(0, dim, 2).astype(mx.float32) / dim)) - - assert interpolation_factor >= 1.0 - self.interpolation_factor = interpolation_factor - - if not use_xpos: - self.scale = None - return - - scale = (mx.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) - - self.scale_base = scale_base - self.scale = scale - - def forward_from_seq_len(self, seq_len: int) -> tuple[mx.array, float]: - t = mx.arange(seq_len) - return self(t) - - def __call__(self, t: mx.array) -> tuple[mx.array, float]: - max_pos = t.max() + 1 - - freqs = ( - mx.einsum("i , j -> i j", t.astype(self.inv_freq.dtype), self.inv_freq) - / self.interpolation_factor - ) - freqs = mx.stack((freqs, freqs), axis=-1) - freqs = rearrange(freqs, "... d r -> ... (d r)") - - if self.scale is None: - return freqs, 1.0 - - power = (t - (max_pos // 2)) / self.scale_base - scale = self.scale ** rearrange(power, "n -> n 1") - scale = mx.stack((scale, scale), axis=-1) - scale = rearrange(scale, "... d r -> ... (d r)") - - return freqs, scale - - -def precompute_freqs_cis( - dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0 -): - freqs = 1.0 / ( - theta ** (mx.arange(0, dim, 2)[: (dim // 2)].astype(mx.float32) / dim) - ) - t = mx.arange(end) - freqs = mx.outer(t, freqs).astype(mx.float32) - freqs_cos = freqs.cos() # real part - freqs_sin = freqs.sin() # imaginary part - return mx.concatenate([freqs_cos, freqs_sin], axis=-1) - - -def get_pos_embed_indices(start, length, max_pos, scale=1.0): - # length = length if isinstance(length, int) else length.max() - scale = scale * mx.ones_like(start) - pos = mx.expand_dims(start, axis=1) + ( - mx.expand_dims(mx.arange(length), axis=0) * mx.expand_dims(scale, axis=1) - ).astype(mx.int32) - # avoid extra long error. - pos = mx.where(pos < max_pos, pos, max_pos - 1) - return pos - - -def rotate_half(x): - x = rearrange(x, "... (d r) -> ... d r", r=2) - x1, x2 = [mx.squeeze(s, axis=-1) for s in mx.split(x, x.shape[-1], axis=-1)] - x = mx.stack([-x2, x1], axis=-1) - return rearrange(x, "... d r -> ... (d r)") - - -def apply_rotary_pos_emb(t, freqs, scale=1): - rot_dim, seq_len = freqs.shape[-1], t.shape[-2] - - freqs = freqs[-seq_len:, :] - scale = scale[-seq_len:, :] if isinstance(scale, mx.array) else scale - - if t.ndim == 4 and freqs.ndim == 3: - freqs = rearrange(freqs, "b n d -> b 1 n d") - - t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] - t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) - out = mx.concatenate((t, t_unrotated), axis=-1) - - return out - - -# mel spectrogram - - -@lru_cache(maxsize=None) -def mel_filters( - sample_rate: int, - n_fft: int, - n_mels: int, - f_min: float = 0, - f_max: Optional[float] = None, - norm: Optional[str] = None, - mel_scale: str = "htk", -) -> mx.array: - def hz_to_mel(freq, mel_scale="htk"): - if mel_scale == "htk": - return 2595.0 * math.log10(1.0 + freq / 700.0) - - # slaney scale - f_min, f_sp = 0.0, 200.0 / 3 - mels = (freq - f_min) / f_sp - min_log_hz = 1000.0 - min_log_mel = (min_log_hz - f_min) / f_sp - logstep = math.log(6.4) / 27.0 - if freq >= min_log_hz: - mels = min_log_mel + math.log(freq / min_log_hz) / logstep - return mels - - def mel_to_hz(mels, mel_scale="htk"): - if mel_scale == "htk": - return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) - - # slaney scale - f_min, f_sp = 0.0, 200.0 / 3 - freqs = f_min + f_sp * mels - min_log_hz = 1000.0 - min_log_mel = (min_log_hz - f_min) / f_sp - logstep = math.log(6.4) / 27.0 - log_t = mels >= min_log_mel - freqs[log_t] = min_log_hz * mx.exp(logstep * (mels[log_t] - min_log_mel)) - return freqs - - f_max = f_max or sample_rate / 2 - - # generate frequency points - - n_freqs = n_fft // 2 + 1 - all_freqs = mx.linspace(0, sample_rate // 2, n_freqs) - - # convert frequencies to mel and back to hz - - m_min = hz_to_mel(f_min, mel_scale) - m_max = hz_to_mel(f_max, mel_scale) - m_pts = mx.linspace(m_min, m_max, n_mels + 2) - f_pts = mel_to_hz(m_pts, mel_scale) - - # compute slopes for filterbank - - f_diff = f_pts[1:] - f_pts[:-1] - slopes = mx.expand_dims(f_pts, 0) - mx.expand_dims(all_freqs, 1) - - # calculate overlapping triangular filters - - down_slopes = (-slopes[:, :-2]) / f_diff[:-1] - up_slopes = slopes[:, 2:] / f_diff[1:] - filterbank = mx.maximum( - mx.zeros_like(down_slopes), mx.minimum(down_slopes, up_slopes) - ) - - if norm == "slaney": - enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) - filterbank *= mx.expand_dims(enorm, 0) - - filterbank = filterbank.moveaxis(0, 1) - return filterbank - - -@lru_cache(maxsize=None) -def hanning(size): - return mx.array(np.hanning(size + 1)[:-1]) - - -def stft(x, window, nperseg=256, noverlap=None, nfft=None, pad_mode="constant"): - if nfft is None: - nfft = nperseg - if noverlap is None: - noverlap = nfft // 4 - - def _pad(x, padding, pad_mode="constant"): - if pad_mode == "constant": - return mx.pad(x, [(padding, padding)]) - elif pad_mode == "reflect": - prefix = x[1 : padding + 1][::-1] - suffix = x[-(padding + 1) : -1][::-1] - return mx.concatenate([prefix, x, suffix]) - else: - raise ValueError(f"Invalid pad_mode {pad_mode}") - - padding = nperseg // 2 - x = _pad(x, padding, pad_mode) - - strides = [noverlap, 1] - t = (x.size - nperseg + noverlap) // noverlap - shape = [t, nfft] - x = mx.as_strided(x, shape=shape, strides=strides) - return mx.fft.rfft(x * window) - - -def log_mel_spectrogram( - audio: Union[mx.array, np.ndarray], - sample_rate: int = 24_000, - n_mels: int = 100, - n_fft: int = 1024, - hop_length: int = 256, - padding: int = 0, - filterbank: mx.array | None = None, -): - if padding > 0: - audio = mx.pad(audio, (0, padding)) - - freqs = stft(audio, hanning(n_fft), nperseg=n_fft, noverlap=hop_length) - magnitudes = freqs[:-1, :].abs() - filters = filterbank if filterbank is not None else mel_filters( - sample_rate=sample_rate, - n_fft=n_fft, - n_mels=n_mels, - norm=None, - mel_scale="htk" - ) - mel_spec = magnitudes @ filters.T - log_spec = mx.maximum(mel_spec, 1e-5).log() - return mx.expand_dims(log_spec, axis=0) - - -class MelSpec(nn.Module): - def __init__( - self, - sample_rate=24_000, - n_fft=1024, - hop_length=256, - n_mels=100, - padding="center", - filterbank: mx.array | None = None, - ): - super().__init__() - if padding not in ["center", "same"]: - raise ValueError("Padding must be 'center' or 'same'.") - self.sample_rate = sample_rate - self.padding = padding - self.n_fft = n_fft - self.hop_length = hop_length - self.n_mels = n_mels - self.filterbank = filterbank - - def __call__(self, audio: mx.array, **kwargs) -> mx.array: - return log_mel_spectrogram( - audio, - n_mels=self.n_mels, - n_fft=self.n_fft, - hop_length=self.hop_length, - padding=0, - filterbank=self.filterbank, - ) - - -# sinusoidal position embedding - - -class SinusPositionEmbedding(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def __call__(self, x, scale=1000): - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = mx.exp(mx.arange(half_dim) * -emb) - emb = scale * mx.expand_dims(x, axis=1) * mx.expand_dims(emb, axis=0) - emb = mx.concatenate([emb.sin(), emb.cos()], axis=-1) - return emb - - -# convolutional position embedding - - -class Rearrange(nn.Module): - def __init__(self, pattern): - super().__init__() - self.pattern = pattern - - def __call__(self, x: mx.array) -> mx.array: - return rearrange(x, self.pattern) - - -class ConvPositionEmbedding(nn.Module): - def __init__(self, dim, kernel_size=31, groups=16): - super().__init__() - assert kernel_size % 2 != 0 - self.conv1d = nn.Sequential( - nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), - nn.Mish(), - nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), - nn.Mish(), - ) - - def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array: - if mask is not None: - mask = mask[..., None] - x = x * mask - - out = self.conv1d(x) - - if mask is not None: - out = out * mask - - return out - - -# global response normalization - - -class GRN(nn.Module): - def __init__(self, dim): - super().__init__() - self.gamma = mx.zeros((1, 1, dim)) - self.beta = mx.zeros((1, 1, dim)) - - def __call__(self, x): - Gx = mx.linalg.norm(x, ord=2, axis=1, keepdims=True) - Nx = Gx / (Gx.mean(axis=-1, keepdims=True) + 1e-6) - return self.gamma * (x * Nx) + self.beta + x - - -# ConvNeXt-v2 block - - -class ConvNeXtV2Block(nn.Module): - def __init__( - self, - dim: int, - intermediate_dim: int, - dilation: int = 1, - ): - super().__init__() - padding = (dilation * (7 - 1)) // 2 - self.dwconv = nn.Conv1d( - dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation - ) # depthwise conv - self.norm = nn.LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear( - dim, intermediate_dim - ) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.grn = GRN(intermediate_dim) - self.pwconv2 = nn.Linear(intermediate_dim, dim) - - def __call__(self, x: mx.array) -> mx.array: - residual = x - x = self.dwconv(x) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.grn(x) - x = self.pwconv2(x) - return residual + x - - -# AdaLayerNormZero -# return with modulated x for attn input, and params for later mlp modulation - - -class AdaLayerNormZero(nn.Module): - def __init__(self, dim): - super().__init__() - self.silu = nn.SiLU() - self.linear = nn.Linear(dim, dim * 6) - self.norm = nn.LayerNorm(dim, affine=False, eps=1e-6) - - def __call__(self, x: mx.array, emb: mx.array | None = None) -> mx.array: - emb = self.linear(self.silu(emb)) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mx.split( - emb, 6, axis=1 - ) - - x = self.norm(x) * (1 + mx.expand_dims(scale_msa, axis=1)) + mx.expand_dims( - shift_msa, axis=1 - ) - return x, gate_msa, shift_mlp, scale_mlp, gate_mlp - - -# AdaLayerNormZero for final layer -# return only with modulated x for attn input, cuz no more mlp modulation - - -class AdaLayerNormZero_Final(nn.Module): - def __init__(self, dim): - super().__init__() - self.silu = nn.SiLU() - self.linear = nn.Linear(dim, dim * 2) - self.norm = nn.LayerNorm(dim, affine=False, eps=1e-6) - - def __call__(self, x: mx.array, emb: mx.array | None = None) -> mx.array: - emb = self.linear(self.silu(emb)) - scale, shift = mx.split(emb, 2, axis=1) - - x = self.norm(x) * (1 + mx.expand_dims(scale, axis=1)) + mx.expand_dims( - shift, axis=1 - ) - return x - - -# feed forward - - -class FeedForward(nn.Module): - def __init__( - self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none" - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - activation = nn.GELU(approx=approximate) - project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) - self.ff = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) - - def __call__(self, x: mx.array) -> mx.array: - return self.ff(x) - - -# attention - - -class Attention(nn.Module): - def __init__( - self, - dim: int, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - ): - super().__init__() - - self.dim = dim - self.heads = heads - self.inner_dim = dim_head * heads - self.dropout = dropout - - self.to_q = nn.Linear(dim, self.inner_dim) - self.to_k = nn.Linear(dim, self.inner_dim) - self.to_v = nn.Linear(dim, self.inner_dim) - - self.to_out = nn.Sequential(nn.Linear(self.inner_dim, dim), nn.Dropout(dropout)) - - def __call__( - self, - x: mx.array, - mask: mx.array | None = None, - rope: mx.array | None = None, - ) -> mx.array: - batch, seq_len, _ = x.shape - - # `sample` projections. - query = self.to_q(x) - key = self.to_k(x) - value = self.to_v(x) - - # apply rotary position embedding - if rope is not None: - freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = ( - ( - xpos_scale, - xpos_scale**-1.0, - ) - if xpos_scale is not None - else (1.0, 1.0) - ) - - query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) - key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) - - # attention - query = rearrange(query, "b n (h d) -> b h n d", h=self.heads) - key = rearrange(key, "b n (h d) -> b h n d", h=self.heads) - value = rearrange(value, "b n (h d) -> b h n d", h=self.heads) - - # mask. e.g. inference got a batch with different target durations, mask out the padding - if mask is not None: - attn_mask = mask - attn_mask = rearrange(attn_mask, "b n -> b () () n") - attn_mask = repeat(attn_mask, "b () () n -> b h () n", h=self.heads) - else: - attn_mask = None - - scale_factor = 1 / math.sqrt(query.shape[-1]) - - x = mx.fast.scaled_dot_product_attention( - q=query, k=key, v=value, scale=scale_factor, mask=attn_mask - ) - x = x.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1).astype(query.dtype) - - # linear proj - x = self.to_out(x) - - if attn_mask is not None: - mask = rearrange(mask, "b n -> b n 1") - x = mx.where(mask, x, 0.0) - - return x - - -# time step conditioning embedding - - -class TimestepEmbedding(nn.Module): - def __init__(self, dim, freq_embed_dim=256): - super().__init__() - self.time_embed = SinusPositionEmbedding(freq_embed_dim) - self.time_mlp = nn.Sequential( - nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim) - ) - - def __call__(self, timestep: mx.array) -> mx.array: - time_hidden = self.time_embed(timestep) - time = self.time_mlp(time_hidden) # b d - return time diff --git a/f5_tts_mlx/rope.py b/f5_tts_mlx/rope.py new file mode 100644 index 0000000..26c57d4 --- /dev/null +++ b/f5_tts_mlx/rope.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import mlx.core as mx +import mlx.nn as nn + +from einops.array_api import rearrange + + +# rotary positional embedding related + + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim: int, + use_xpos: bool = False, + scale_base: int = 512, + interpolation_factor: float = 1.0, + base: float = 10000.0, + base_rescale_factor: float = 1.0, + ): + super().__init__() + base *= base_rescale_factor ** (dim / (dim - 2)) + self.inv_freq = 1.0 / (base ** (mx.arange(0, dim, 2).astype(mx.float32) / dim)) + + assert interpolation_factor >= 1.0 + self.interpolation_factor = interpolation_factor + + if not use_xpos: + self.scale = None + return + + scale = (mx.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = scale_base + self.scale = scale + + def forward_from_seq_len(self, seq_len: int) -> tuple[mx.array, float]: + t = mx.arange(seq_len) + return self(t) + + def __call__(self, t: mx.array) -> tuple[mx.array, float]: + max_pos = t.max() + 1 + + freqs = ( + mx.einsum("i , j -> i j", t.astype(self.inv_freq.dtype), self.inv_freq) + / self.interpolation_factor + ) + freqs = mx.stack((freqs, freqs), axis=-1) + freqs = rearrange(freqs, "... d r -> ... (d r)") + + if self.scale is None: + return freqs, 1.0 + + power = (t - (max_pos // 2)) / self.scale_base + scale = self.scale ** rearrange(power, "n -> n 1") + scale = mx.stack((scale, scale), axis=-1) + scale = rearrange(scale, "... d r -> ... (d r)") + + return freqs, scale + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0 +): + freqs = 1.0 / ( + theta ** (mx.arange(0, dim, 2)[: (dim // 2)].astype(mx.float32) / dim) + ) + t = mx.arange(end) + freqs = mx.outer(t, freqs).astype(mx.float32) + freqs_cos = freqs.cos() # real part + freqs_sin = freqs.sin() # imaginary part + return mx.concatenate([freqs_cos, freqs_sin], axis=-1) + + +def get_pos_embed_indices(start, length, max_pos, scale=1.0): + # length = length if isinstance(length, int) else length.max() + scale = scale * mx.ones_like(start) + pos = mx.expand_dims(start, axis=1) + ( + mx.expand_dims(mx.arange(length), axis=0) * mx.expand_dims(scale, axis=1) + ).astype(mx.int32) + # avoid extra long error. + pos = mx.where(pos < max_pos, pos, max_pos - 1) + return pos + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = [mx.squeeze(s, axis=-1) for s in mx.split(x, x.shape[-1], axis=-1)] + x = mx.stack([-x2, x1], axis=-1) + return rearrange(x, "... d r -> ... (d r)") + + +def apply_rotary_pos_emb(t, freqs, scale=1): + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + + freqs = freqs[-seq_len:, :] + scale = scale[-seq_len:, :] if isinstance(scale, mx.array) else scale + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, "b n d -> b 1 n d") + + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + out = mx.concatenate((t, t_unrotated), axis=-1) + + return out diff --git a/f5_tts_mlx/trainer.py b/f5_tts_mlx/trainer.py index 2284f20..8bc631e 100644 --- a/f5_tts_mlx/trainer.py +++ b/f5_tts_mlx/trainer.py @@ -1,10 +1,12 @@ from __future__ import annotations import datetime from functools import partial +import io import os from pathlib import Path +from tqdm import tqdm -from einops.array_api import rearrange +import numpy as np import mlx.core as mx import mlx.nn as nn @@ -17,9 +19,15 @@ ) from mlx.utils import tree_flatten +from einops.array_api import rearrange + +from f5_tts_mlx.audio import MelSpec from f5_tts_mlx.cfm import F5TTS -from f5_tts_mlx.duration import DurationPredictor -from f5_tts_mlx.modules import MelSpec + +import soundfile as sf + +from PIL import Image +import matplotlib.pyplot as plt import wandb @@ -34,11 +42,17 @@ def default(v, d): # trainer +TARGET_RMS = 0.1 -class DurationTrainer: +SAMPLE_RATE = 24_000 +HOP_LENGTH = 256 +FRAMES_PER_SEC = SAMPLE_RATE / HOP_LENGTH + + +class F5TTSTrainer: def __init__( self, - model: DurationPredictor, + model: F5TTS, num_warmup_steps=1000, max_grad_norm=1.0, sample_rate=24_000, @@ -51,159 +65,100 @@ def __init__( self.log_with_wandb = log_with_wandb def save_checkpoint(self, step, finetune=False): + if Path("results").exists() is False: + os.makedirs("results") + mx.save_safetensors( - f"f5tts_duration_{step}", + f"results/f5tts_{step}", dict(tree_flatten(self.model.trainable_parameters())), ) def load_checkpoint(self, step): - params = mx.load(f"f5tts_duration_{step}.saftensors") - self.model.load_weights(params) + params = mx.load(f"results/f5tts_{step}.safetensors") + self.model.load_weights(list(params.items())) self.model.eval() - def train( + def generate_sample( self, - train_dataset, - learning_rate=1e-4, - weight_decay=1e-2, - total_steps=100_000, - batch_size=8, - log_every=10, - save_every=1000, - checkpoint: int | None = None, + sample_audio: str, + sample_ref_text: str, + sample_generation_text: str, + sample_generation_duration: float, + step: int, ): - if self.log_with_wandb: - wandb.init( - project="f5tts_duration", - config=dict( - learning_rate=learning_rate, - total_steps=total_steps, - batch_size=batch_size, - ), - ) + audio, _ = sf.read(sample_audio) + audio = mx.array(audio) + ref_audio_duration = audio.shape[0] / SAMPLE_RATE - decay_steps = total_steps - self.num_warmup_steps + rms = mx.sqrt(mx.mean(mx.square(audio))) + if rms < TARGET_RMS: + audio = audio * TARGET_RMS / rms - warmup_scheduler = linear_schedule( - init=1e-8, - end=learning_rate, - steps=self.num_warmup_steps, - ) - decay_scheduler = cosine_decay(init=learning_rate, decay_steps=decay_steps) - scheduler = join_schedules( - schedules=[warmup_scheduler, decay_scheduler], - boundaries=[self.num_warmup_steps], - ) - self.optimizer = AdamW(learning_rate=scheduler, weight_decay=weight_decay) - - if checkpoint is not None: - self.load_checkpoint(checkpoint) - start_step = checkpoint - else: - start_step = 0 - - global_step = start_step - - def loss_fn(model: F5TTS, mel_spec, text, lens): - loss = model(mel_spec, text=text, lens=lens, return_loss=True) - return loss - - state = [self.model.state, self.optimizer.state, mx.random.state] - - @partial(mx.compile, inputs=state, outputs=state) - def train_step(mel_spec, text_inputs, mel_lens): - loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn) - loss, grads = loss_and_grad_fn( - self.model, mel_spec, text=text_inputs, lens=mel_lens - ) - - if self.max_grad_norm > 0: - grads, _ = clip_grad_norm(grads, max_norm=self.max_grad_norm) - - self.optimizer.update(self.model, grads) - - return loss - - training_start_date = datetime.datetime.now() - log_start_date = datetime.datetime.now() - - for batch in train_dataset: - effective_batch_size = batch["transcript"].shape[0] - text_inputs = [ - bytes(batch["transcript"][i]).decode("utf-8") - for i in range(effective_batch_size) - ] - - mel_spec = rearrange(mx.array(batch["mel_spec"]), "b 1 n c -> b n c") - mel_lens = mx.array(batch["mel_len"], dtype=mx.int32) - - loss = train_step(mel_spec, text_inputs, mel_lens) - mx.eval(state) - # mx.eval(self.model.parameters(), self.optimizer.state) + audio = mx.expand_dims(audio, axis=0) + text = [sample_ref_text + " " + sample_generation_text] - if self.log_with_wandb: - wandb.log( - { - "loss": loss.item(), - "lr": self.optimizer.learning_rate.item(), - "batch_len": mel_lens.sum().item(), - }, - step=global_step, - ) - - if global_step > 0 and global_step % log_every == 0: - elapsed_time = datetime.datetime.now() - log_start_date - log_start_date = datetime.datetime.now() - - print( - f"step {global_step}: loss = {loss.item():.4f}, sec per step = {(log_every / elapsed_time.seconds):.2f}" - ) - - global_step += 1 - - if global_step % save_every == 0: - self.save_checkpoint(global_step) - - if global_step >= total_steps: - break - - if self.log_with_wandb: - wandb.finish() + self.model.eval() - print(f"Training complete in {datetime.datetime.now() - training_start_date}") + start_date = datetime.datetime.now() + + wave, trajectories = self.model.sample( + audio, + text=text, + duration=int( + (ref_audio_duration + sample_generation_duration) * FRAMES_PER_SEC + ), + method="rk4", + steps=8, + cfg_strength=2, + speed=1, + sway_sampling_coef=-1.0, + ) + mx.eval([wave, trajectories]) + elapsed_time = (datetime.datetime.now() - start_date).total_seconds() + print(f"Generated sample at step {step} in {elapsed_time:0.1f}s") -SAMPLE_RATE = 24_000 + # save the generated audio + wave = wave[audio.shape[1] :] + os.makedirs("samples/audio", exist_ok=True) + sf.write( + f"samples/audio/step_{step}.wav", np.array(wave), samplerate=SAMPLE_RATE + ) -class F5TTSTrainer: - def __init__( - self, - model: DurationPredictor, - num_warmup_steps=1000, - max_grad_norm=1.0, - sample_rate=24_000, - log_with_wandb=False, - ): - self.model = model - self.num_warmup_steps = num_warmup_steps - self.mel_spectrogram = MelSpec(sample_rate=sample_rate) - self.max_grad_norm = max_grad_norm - self.log_with_wandb = log_with_wandb + # save a visualization of the trajectory - def save_checkpoint(self, step, finetune=False): - if Path("results").exists() is False: - os.makedirs("results") + frames = [] + + ref_audio_frame_len = audio.shape[1] // HOP_LENGTH - mx.save_safetensors( - f"results/f5tts_{step}", - dict(tree_flatten(self.model.trainable_parameters())), + for trajectory in trajectories: + plt.figure(figsize=(10, 4)) + plt.imshow( + np.array(trajectory[0, ref_audio_frame_len:]).T, + aspect="auto", + origin="lower", + interpolation="none", + ) + plt.yticks([]) + + buf = io.BytesIO() + plt.savefig(buf, format="png") + buf.seek(0) + + frames.append(Image.open(buf)) + plt.close() + + os.makedirs("samples/viz", exist_ok=True) + frames[0].save( + f"samples/viz/step_{step}.gif", + save_all=True, + append_images=frames[1:], + duration=300, + loop=0, ) - def load_checkpoint(self, step): - params = mx.load(f"f5tts_{step}.saftensors") - self.model.load_weights(params) - self.model.eval() + self.model.train() def train( self, @@ -211,8 +166,12 @@ def train( learning_rate=1e-4, weight_decay=1e-2, total_steps=1_000_000, - log_every=100, - save_every=5000, + save_every=10_000, + sample_every=5_000, + sample_reference_audio: str | None = None, + sample_reference_text: str | None = None, + sample_generation_text: str | None = None, + sample_generation_duration: float | None = None, checkpoint: int | None = None, ): if self.log_with_wandb: @@ -245,18 +204,23 @@ def train( start_step = 0 global_step = start_step - print(f"Starting training at step {global_step}") - def loss_fn(model: F5TTS, mel_spec, text, lens): + if global_step != 0: + print(f"Starting training at step {global_step}") + + def loss_fn(model, mel_spec, text, lens): return model(mel_spec, text=text, lens=lens) - # state = [self.model.state, self.optimizer.state, mx.random.state] + state = [self.model.state, self.optimizer.state, mx.random.state] - # @partial(mx.compile, inputs=state, outputs=state) + @partial(mx.compile, inputs=state, outputs=state) def train_step(mel_spec, text_inputs, mel_lens): loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn) loss, grads = loss_and_grad_fn( - self.model, mel_spec, text=text_inputs, lens=mel_lens + self.model, + mel_spec, + text=text_inputs, + lens=mel_lens, ) if self.max_grad_norm > 0: @@ -267,47 +231,64 @@ def train_step(mel_spec, text_inputs, mel_lens): return loss training_start_date = datetime.datetime.now() - log_start_date = datetime.datetime.now() self.model.train() - for step, batch in enumerate(train_dataset): - effective_batch_size = batch["transcript"].shape[0] - text_inputs = [ - bytes(batch["transcript"][i]).decode("utf-8") - for i in range(effective_batch_size) - ] - text_inputs = [text_input.replace("\x00", "") for text_input in text_inputs] + pbar = tqdm( + initial=start_step, total=total_steps, desc="", unit="step" + ) + for step, batch in enumerate(train_dataset): mel_spec = rearrange(mx.array(batch["mel_spec"]), "b 1 n c -> b n c") mel_lens = mx.array(batch["mel_len"], dtype=mx.int32) - loss = train_step(mel_spec, text_inputs, mel_lens) - # mx.eval(state) - mx.eval(self.model.parameters(), self.optimizer.state) + # pad text to sequence length with -1 + seq_len = mel_spec.shape[1] + text = mx.array(batch["transcript"]).squeeze(-1) + text = mx.pad( + text, [(0, 0), (0, seq_len - text.shape[-1])], constant_values=-1 + ) + + loss = train_step(mel_spec, text, mel_lens) + mx.eval(state) + # mx.eval(self.model.parameters(), self.optimizer.state) if self.log_with_wandb: wandb.log( - {"loss": loss.item(), "lr": self.optimizer.learning_rate.item()}, + { + "loss": loss.item(), + "lr": self.optimizer.learning_rate.item(), + "batch_len": mel_lens.sum().item(), + }, step=global_step, ) - if global_step > 0 and global_step % log_every == 0: - elapsed_time = datetime.datetime.now() - log_start_date - log_start_date = datetime.datetime.now() - - print( - f"step {global_step}: loss = {loss.item():.4f}, steps/s = {(log_every / elapsed_time.seconds):.2f}" - ) + pbar.update(1) + pbar.set_postfix( + { + "loss": f"{loss.item():.4f}", + "batch_len": f"{mel_lens.sum().item():04d}", + } + ) global_step += 1 if global_step % save_every == 0: self.save_checkpoint(global_step) + if ( + global_step % sample_every == 0 + and sample_reference_audio is not None + and sample_reference_text is not None + and sample_generation_text is not None + and sample_generation_duration is not None + ): + self.generate_sample(sample_reference_audio, sample_reference_text, sample_generation_text, sample_generation_duration, global_step) + if global_step >= total_steps: break + pbar.close() if self.log_with_wandb: wandb.finish() diff --git a/train_libritts_small.py b/train_libritts_small.py index 518d75f..112a3dc 100644 --- a/train_libritts_small.py +++ b/train_libritts_small.py @@ -1,23 +1,30 @@ +from pathlib import Path + from mlx.utils import tree_flatten from f5_tts_mlx.cfm import F5TTS from f5_tts_mlx.dit import DiT -from f5_tts_mlx.trainer import F5TTSTrainer +from f5_tts_mlx.trainer import F5TTSTrainer, FRAMES_PER_SEC from f5_tts_mlx.data import load_libritts_r +from vocos_mlx import Vocos + +vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz") + vocab = {chr(i): i for i in range(256)} f5tts = F5TTS( transformer=DiT( - dim=256, - depth=8, + dim=768, + depth=16, heads=8, ff_mult=2, - text_dim=128, - conv_layers=2, + text_dim=384, + conv_layers=4, text_num_embeds=len(vocab), ), vocab_char_map=vocab, + vocoder=vocos.decode, ) num_trainable_params = sum( @@ -25,14 +32,22 @@ ) print(f"Using {num_trainable_params:,} trainable parameters.") -dataset = load_libritts_r(max_duration = 10) +epochs = 100 +max_duration = 10 + +dataset, path = load_libritts_r(max_duration = max_duration) + +max_batch_duration = 40 +batch_size = int(max_batch_duration / max_duration) +max_data_size = int(max_batch_duration * FRAMES_PER_SEC) * f5tts._mel_spec.n_mels batched_dataset = ( dataset - .repeat(1_000_000) # repeat indefinitely - .shuffle(1000) - .prefetch(prefetch_size = 4, num_threads = 1) - .batch(4) + .repeat(epochs) + .shuffle(500) + .prefetch(prefetch_size = batch_size, num_threads = 6) + .batch(batch_size, pad=dict(mel_spec=0.0, transcript=-1)) + # .dynamic_batch(buffer_size = batch_size * 2, key = "mel_spec", max_data_size = max_data_size, shuffle = True) .pad_to_multiple("mel_spec", dim=2, pad_multiple=256, pad_value=0.0) ) @@ -43,10 +58,18 @@ log_with_wandb=False ) +sample_path = "tests/test_en_1_ref_short.wav" +sample_text = "Some call me nature, others call me mother nature." + trainer.train( train_dataset=batched_dataset, learning_rate=1e-4, - log_every=10, - save_every=10_000, total_steps=1_000_000, + save_every=10_000, + checkpoint=100_000, + sample_every=100, + sample_reference_audio=sample_path, + sample_reference_text=sample_text, + sample_generation_duration=3.5, + sample_generation_text="The quick brown fox jumped over the lazy dog.", )