diff --git a/data/audio.py b/data/audio.py
index 2ac5e6cf..a35dfd9c 100644
--- a/data/audio.py
+++ b/data/audio.py
@@ -18,11 +18,11 @@
 import soundfile
 import torch
 from torch.nn import functional as F
-import torchaudio as ta
 
 import av
+import subprocess as sp
 
-from .audio_utils import f32_pcm, i16_pcm, normalize_audio
+from .audio_utils import f32_pcm, normalize_audio
 
 
 _av_initialized = False
@@ -136,12 +136,6 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
         wav = torch.from_numpy(wav).t().contiguous()
         if len(wav.shape) == 1:
             wav = torch.unsqueeze(wav, 0)
-    elif (
-        fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
-        and duration <= 0 and seek_time == 0
-    ):
-        # Torchaudio is faster if we load an entire file at once.
-        wav, sr = ta.load(fp)
     else:
         wav, sr = _av_read(filepath, seek_time, duration)
     if pad and duration > 0:
@@ -150,10 +144,22 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
     return wav, sr
 
 
+def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
+    # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
+    assert wav.dim() == 2, wav.shape
+    command = [
+        'ffmpeg',
+        '-loglevel', 'error',
+        '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
+        '-i', '-'] + flags + [str(out_path)]
+    input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
+    sp.run(command, input=input_, check=True)
+
+
 def audio_write(stem_name: tp.Union[str, Path],
                 wav: torch.Tensor, sample_rate: int,
-                format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
-                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
+                normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                 rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                 loudness_compressor: bool = False, log_clipping: bool = True,
                 make_parent_dir: bool = True,
@@ -164,8 +170,9 @@ def audio_write(stem_name: tp.Union[str, Path],
         stem_name (str or Path): Filename without extension which will be added automatically.
         wav (torch.Tensor): Audio data to save.
         sample_rate (int): Sample rate of audio data.
-        format (str): Either "wav" or "mp3".
+        format (str): Either "wav", "mp3", "ogg", or "flac".
         mp3_rate (int): kbps when using mp3s.
+        ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
         normalize (bool): if `True` (default), normalizes according to the prescribed
             strategy (see after). If `False`, the strategy is only used in case clipping
             would happen.
@@ -193,14 +200,20 @@ def audio_write(stem_name: tp.Union[str, Path],
                           rms_headroom_db, loudness_headroom_db, loudness_compressor,
                           log_clipping=log_clipping, sample_rate=sample_rate,
                           stem_name=str(stem_name))
-    kwargs: dict = {}
     if format == 'mp3':
         suffix = '.mp3'
-        kwargs.update({"compression": mp3_rate})
+        flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
     elif format == 'wav':
-        wav = i16_pcm(wav)
         suffix = '.wav'
-        kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
+        flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
+    elif format == 'ogg':
+        suffix = '.ogg'
+        flags = ['-f', 'ogg', '-c:a', 'libvorbis']
+        if ogg_rate is not None:
+            flags += ['-b:a', f'{ogg_rate}k']
+    elif format == 'flac':
+        suffix = '.flac'
+        flags = ['-f', 'flac']
     else:
         raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
     if not add_suffix:
@@ -209,7 +222,7 @@ def audio_write(stem_name: tp.Union[str, Path],
     if make_parent_dir:
         path.parent.mkdir(exist_ok=True, parents=True)
     try:
-        ta.save(path, wav, sample_rate, **kwargs)
+        _piping_to_ffmpeg(path, wav, sample_rate, flags)
     except Exception:
         if path.exists():
             # we do not want to leave half written files around.
diff --git a/modules/rope.py b/modules/rope.py
index 503e6748..c12cee09 100644
--- a/modules/rope.py
+++ b/modules/rope.py
@@ -81,13 +81,16 @@ def get_rotation(self, start: int, end: int):
             self.rotation = torch.polar(torch.ones_like(angles), angles)
         return self.rotation[start:end]
 
-    def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
+    def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
         """Apply rope rotation to query or key tensor."""
-        T = x.shape[1]
-        rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
+        T = x.shape[time_dim]
+        target_shape = [1] * x.dim()
+        target_shape[time_dim] = T
+        target_shape[-1] = -1
+        rotation = self.get_rotation(start, start + T).view(target_shape)
 
         if self.xpos:
-            decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
+            decay = self.xpos.get_decay(start, start + T).view(target_shape)
         else:
             decay = 1.0
 
@@ -96,11 +99,11 @@ def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
 
         x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
         scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
-        x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
+        x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
 
         return x_out.type_as(x)
 
-    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
+    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
         """ Apply rope rotation to both query and key tensors.
         Supports streaming mode, in which query and key are not expected to have the same shape.
         In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
@@ -110,12 +113,13 @@ def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
             query (torch.Tensor): Query to rotate.
             key (torch.Tensor): Key to rotate.
             start (int): Start index of the sequence for time offset.
+            time_dim (int): which dimension represents the time steps.
""" - query_timesteps = query.shape[1] - key_timesteps = key.shape[1] + query_timesteps = query.shape[time_dim] + key_timesteps = key.shape[time_dim] streaming_offset = key_timesteps - query_timesteps - query_out = self.rotate(query, start + streaming_offset) - key_out = self.rotate(key, start, invert_decay=True) + query_out = self.rotate(query, start + streaming_offset, time_dim) + key_out = self.rotate(key, start, time_dim, invert_decay=True) return query_out, key_out diff --git a/modules/transformer.py b/modules/transformer.py index 048c06df..691df6a2 100644 --- a/modules/transformer.py +++ b/modules/transformer.py @@ -35,8 +35,8 @@ def set_efficient_attention_backend(backend: str = 'torch'): _efficient_attention_backend = backend -def _get_attention_time_dimension() -> int: - if _efficient_attention_backend == 'torch': +def _get_attention_time_dimension(memory_efficient: bool) -> int: + if _efficient_attention_backend == 'torch' and memory_efficient: return 2 else: return 1 @@ -89,11 +89,11 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1) -def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: +def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers.""" if n_rep == 1: return x - if _efficient_attention_backend == 'torch': + if _efficient_attention_backend == 'torch' and memory_efficient: bs, n_kv_heads, slen, head_dim = x.shape return ( x[:, :, None, :, :] @@ -234,7 +234,7 @@ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype # Return a causal mask, accounting for potentially stored past keys/values # We actually return a bias for the attention score, as this has the same # convention both in the builtin MHA in Pytorch, and Xformers functions. - time_dim = _get_attention_time_dimension() + time_dim = _get_attention_time_dimension(self.memory_efficient) if self.memory_efficient: from xformers.ops import LowerTriangularMask if current_steps == 1: @@ -264,7 +264,7 @@ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype torch.full([], float('-inf'), device=device, dtype=dtype)) def _complete_kv(self, k, v): - time_dim = _get_attention_time_dimension() + time_dim = _get_attention_time_dimension(self.memory_efficient) if self.cross_attention: # With cross attention we assume all keys and values # are already available, and streaming is with respect @@ -298,8 +298,7 @@ def _complete_kv(self, k, v): return nk, nv def _apply_rope(self, query: torch.Tensor, key: torch.Tensor): - # TODO: fix and verify layout. - assert _efficient_attention_backend == 'xformers', "Rope not supported with torch attn." + time_dim = _get_attention_time_dimension(self.memory_efficient) # Apply rope embeddings to query and key tensors. 
         assert self.rope is not None
         if 'past_keys' in self._streaming_state:
@@ -311,7 +310,7 @@ def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
         else:
             past_context_offset = 0
         streaming_offset = past_context_offset + past_keys_offset
-        return self.rope.rotate_qk(query, key, start=streaming_offset)
+        return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 key_padding_mask=None, need_weights=False, attn_mask=None,
@@ -320,7 +319,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
         assert not is_causal, ("New param added in torch 2.0.1 not supported, "
                                "use the causal args in the constructor.")
 
-        time_dim = _get_attention_time_dimension()
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
         if time_dim == 2:
             layout = "b h t d"
         else:
@@ -394,8 +393,8 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
             q, k = self._apply_rope(q, k)
         k, v = self._complete_kv(k, v)
         if self.kv_repeat > 1:
-            k = expand_repeated_kv(k, self.kv_repeat)
-            v = expand_repeated_kv(v, self.kv_repeat)
+            k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
+            v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
         if self.attention_as_float32:
             q, k, v = [x.float() for x in [q, k, v]]
         if self.memory_efficient:
diff --git a/optim/dadam.py b/optim/dadam.py
index a84402f7..e009969f 100644
--- a/optim/dadam.py
+++ b/optim/dadam.py
@@ -5,19 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
 import torch
 import torch.optim
 import torch.distributed as dist
 
-if TYPE_CHECKING:
-    from torch.optim.optimizer import _params_t
-else:
-    _params_t = Any
-
 
 logger = logging.getLogger(__name__)
+_params_t = Any
 
 
 def to_real(x):
diff --git a/train.py b/train.py
index 22dd1178..5851222c 100644
--- a/train.py
+++ b/train.py
@@ -12,6 +12,7 @@
 import logging
 import multiprocessing
 import os
+from pathlib import Path
 import sys
 import typing as tp
 
@@ -119,6 +120,11 @@ def init_seed_and_system(cfg):
     logger.debug('Setting num threads to %d', cfg.num_threads)
     set_efficient_attention_backend(cfg.efficient_attention_backend)
     logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
+    if 'SLURM_JOB_ID' in os.environ:
+        tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID'])
+        if tmpdir.exists():
+            logger.info("Changing tmpdir to %s", tmpdir)
+            os.environ['TMPDIR'] = str(tmpdir)
 
 
 @hydra_main(config_path='../config', config_name='config', version_base='1.1')
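
Illustrative usage, not part of the diff above: a minimal sketch of the new write path in data/audio.py, where format/ogg_rate are turned into ffmpeg flags and piped through _piping_to_ffmpeg. The import path data.audio is assumed from the diff headers (adjust to your package layout), the stem 'demo_tone' is made up, and an ffmpeg binary with libvorbis must be on PATH.

import math

import torch

from data.audio import audio_write  # import path assumed from the diff's file layout

sr = 32000
t = torch.arange(sr) / sr
# One second of a 440 Hz tone, duplicated to stereo: shape [channels, samples], as audio_write expects.
wav = (0.5 * torch.sin(2 * math.pi * 440.0 * t)).repeat(2, 1)

# format='ogg' selects the libvorbis flags added above; ogg_rate=128 becomes '-b:a 128k'.
# The suffix is appended automatically, so this writes 'demo_tone.ogg'.
out = audio_write('demo_tone', wav, sr, format='ogg', ogg_rate=128)
print(out)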