changing tmpdir when running in slurm + not depending anymore on torchaudio for writing audio files. (facebookresearch#306)

* changing tmpdir when running in slurm

* fixing typing in dadam

* limiting dependency on torchaudio for writing files

* not using torchaudio for reading anymore

* trying desperately to get those unit tests to pass

* plop

* fixing tests once more

* linter

* plop
adefossez authored Oct 12, 2023
1 parent cf4537c commit 00f7ac4
Showing 5 changed files with 62 additions and 44 deletions.
45 changes: 29 additions & 16 deletions data/audio.py
@@ -18,11 +18,11 @@
import soundfile
import torch
from torch.nn import functional as F
import torchaudio as ta

import av
import subprocess as sp

from .audio_utils import f32_pcm, i16_pcm, normalize_audio
from .audio_utils import f32_pcm, normalize_audio


_av_initialized = False
@@ -136,12 +136,6 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
wav = torch.from_numpy(wav).t().contiguous()
if len(wav.shape) == 1:
wav = torch.unsqueeze(wav, 0)
elif (
fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
and duration <= 0 and seek_time == 0
):
# Torchaudio is faster if we load an entire file at once.
wav, sr = ta.load(fp)
else:
wav, sr = _av_read(filepath, seek_time, duration)
if pad and duration > 0:
@@ -150,10 +144,22 @@ def audio_write(stem_name: tp.Union[str, Path],
return wav, sr


def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
# ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
assert wav.dim() == 2, wav.shape
command = [
'ffmpeg',
'-loglevel', 'error',
'-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
'-i', '-'] + flags + [str(out_path)]
input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
sp.run(command, input=input_, check=True)


def audio_write(stem_name: tp.Union[str, Path],
wav: torch.Tensor, sample_rate: int,
format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
loudness_compressor: bool = False,
log_clipping: bool = True, make_parent_dir: bool = True,
@@ -164,8 +170,9 @@ def audio_write(stem_name: tp.Union[str, Path],
stem_name (str or Path): Filename without extension which will be added automatically.
wav (torch.Tensor): Audio data to save.
sample_rate (int): Sample rate of audio data.
format (str): Either "wav" or "mp3".
format (str): Either "wav", "mp3", "ogg", or "flac".
mp3_rate (int): kbps when using mp3s.
ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
normalize (bool): if `True` (default), normalizes according to the prescribed
strategy (see after). If `False`, the strategy is only used in case clipping
would happen.
@@ -193,14 +200,20 @@ def audio_write(stem_name: tp.Union[str, Path],
rms_headroom_db, loudness_headroom_db, loudness_compressor,
log_clipping=log_clipping, sample_rate=sample_rate,
stem_name=str(stem_name))
kwargs: dict = {}
if format == 'mp3':
suffix = '.mp3'
kwargs.update({"compression": mp3_rate})
flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
elif format == 'wav':
wav = i16_pcm(wav)
suffix = '.wav'
kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
elif format == 'ogg':
suffix = '.ogg'
flags = ['-f', 'ogg', '-c:a', 'libvorbis']
if ogg_rate is not None:
flags += ['-b:a', f'{ogg_rate}k']
elif format == 'flac':
suffix = '.flac'
flags = ['-f', 'flac']
else:
raise RuntimeError(f"Invalid format {format}. Only wav, mp3, ogg, or flac are supported.")
if not add_suffix:
@@ -209,7 +222,7 @@ def audio_write(stem_name: tp.Union[str, Path],
if make_parent_dir:
path.parent.mkdir(exist_ok=True, parents=True)
try:
ta.save(path, wav, sample_rate, **kwargs)
_piping_to_ffmpeg(path, wav, sample_rate, flags)
except Exception:
if path.exists():
# we do not want to leave half written files around.
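For context, here is a minimal sketch of how the new ffmpeg-backed writer can be exercised. The stem name and the sine tone are illustrative, the import path is assumed from this repo's layout, and ffmpeg must be on the PATH since every format is now piped through `_piping_to_ffmpeg`:

```python
import math
import torch

from data.audio import audio_write  # import path assumed from this repo's layout

sr = 32000
t = torch.arange(sr, dtype=torch.float32) / sr
wav = torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)  # [channels=1, samples]: one second of A440

audio_write('demo_tone', wav, sr, format='mp3', mp3_rate=320)   # -> demo_tone.mp3
audio_write('demo_tone', wav, sr, format='ogg', ogg_rate=128)   # -> demo_tone.ogg
audio_write('demo_tone', wav, sr, format='flac')                # -> demo_tone.flac
```
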
24 changes: 14 additions & 10 deletions modules/rope.py
@@ -81,13 +81,16 @@ def get_rotation(self, start: int, end: int):
self.rotation = torch.polar(torch.ones_like(angles), angles)
return self.rotation[start:end]

def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
"""Apply rope rotation to query or key tensor."""
T = x.shape[1]
rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
T = x.shape[time_dim]
target_shape = [1] * x.dim()
target_shape[time_dim] = T
target_shape[-1] = -1
rotation = self.get_rotation(start, start + T).view(target_shape)

if self.xpos:
decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
decay = self.xpos.get_decay(start, start + T).view(target_shape)
else:
decay = 1.0

@@ -96,11 +99,11 @@ def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):

x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)

return x_out.type_as(x)

def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
""" Apply rope rotation to both query and key tensors.
Supports streaming mode, in which query and key are not expected to have the same shape.
In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
@@ -110,12 +113,13 @@ def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
query (torch.Tensor): Query to rotate.
key (torch.Tensor): Key to rotate.
start (int): Start index of the sequence for time offset.
time_dim (int): which dimension represents the time steps.
"""
query_timesteps = query.shape[1]
key_timesteps = key.shape[1]
query_timesteps = query.shape[time_dim]
key_timesteps = key.shape[time_dim]
streaming_offset = key_timesteps - query_timesteps

query_out = self.rotate(query, start + streaming_offset)
key_out = self.rotate(key, start, invert_decay=True)
query_out = self.rotate(query, start + streaming_offset, time_dim)
key_out = self.rotate(key, start, time_dim, invert_decay=True)

return query_out, key_out
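To make the new `time_dim` handling concrete, here is a small standalone sketch of the reshaping trick used by `rotate` above. The shapes and tensors are illustrative and the helper name `reshape_rotation` is not part of the repo:

```python
import torch

def reshape_rotation(rotation: torch.Tensor, x: torch.Tensor, time_dim: int) -> torch.Tensor:
    """Reshape a [T, F] per-step rotation so it broadcasts against x along time_dim."""
    T = x.shape[time_dim]
    target_shape = [1] * x.dim()
    target_shape[time_dim] = T
    target_shape[-1] = -1   # keep the frequency axis as the last dimension
    return rotation.view(target_shape)

rotation = torch.polar(torch.ones(8, 16), torch.zeros(8, 16))  # complex, [T=8, F=16]
x_bthd = torch.zeros(2, 8, 4, 32)   # "b t h d" layout -> time_dim=1
x_bhtd = torch.zeros(2, 4, 8, 32)   # "b h t d" layout -> time_dim=2
print(reshape_rotation(rotation, x_bthd, 1).shape)  # torch.Size([1, 8, 1, 16])
print(reshape_rotation(rotation, x_bhtd, 2).shape)  # torch.Size([1, 1, 8, 16])
```
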
23 changes: 11 additions & 12 deletions modules/transformer.py
@@ -35,8 +35,8 @@ def set_efficient_attention_backend(backend: str = 'torch'):
_efficient_attention_backend = backend


def _get_attention_time_dimension() -> int:
if _efficient_attention_backend == 'torch':
def _get_attention_time_dimension(memory_efficient: bool) -> int:
if _efficient_attention_backend == 'torch' and memory_efficient:
return 2
else:
return 1
@@ -89,11 +89,11 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)


def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
if n_rep == 1:
return x
if _efficient_attention_backend == 'torch':
if _efficient_attention_backend == 'torch' and memory_efficient:
bs, n_kv_heads, slen, head_dim = x.shape
return (
x[:, :, None, :, :]
@@ -234,7 +234,7 @@ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype
# Return a causal mask, accounting for potentially stored past keys/values
# We actually return a bias for the attention score, as this has the same
# convention both in the builtin MHA in Pytorch, and Xformers functions.
time_dim = _get_attention_time_dimension()
time_dim = _get_attention_time_dimension(self.memory_efficient)
if self.memory_efficient:
from xformers.ops import LowerTriangularMask
if current_steps == 1:
@@ -264,7 +264,7 @@ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype
torch.full([], float('-inf'), device=device, dtype=dtype))

def _complete_kv(self, k, v):
time_dim = _get_attention_time_dimension()
time_dim = _get_attention_time_dimension(self.memory_efficient)
if self.cross_attention:
# With cross attention we assume all keys and values
# are already available, and streaming is with respect
@@ -298,8 +298,7 @@ def _complete_kv(self, k, v):
return nk, nv

def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
# TODO: fix and verify layout.
assert _efficient_attention_backend == 'xformers', "Rope not supported with torch attn."
time_dim = _get_attention_time_dimension(self.memory_efficient)
# Apply rope embeddings to query and key tensors.
assert self.rope is not None
if 'past_keys' in self._streaming_state:
@@ -311,7 +310,7 @@ def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
else:
past_context_offset = 0
streaming_offset = past_context_offset + past_keys_offset
return self.rope.rotate_qk(query, key, start=streaming_offset)
return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)

def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
key_padding_mask=None, need_weights=False, attn_mask=None,
@@ -320,7 +319,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
assert not is_causal, ("New param added in torch 2.0.1 not supported, "
"use the causal args in the constructor.")

time_dim = _get_attention_time_dimension()
time_dim = _get_attention_time_dimension(self.memory_efficient)
if time_dim == 2:
layout = "b h t d"
else:
@@ -394,8 +393,8 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
q, k = self._apply_rope(q, k)
k, v = self._complete_kv(k, v)
if self.kv_repeat > 1:
k = expand_repeated_kv(k, self.kv_repeat)
v = expand_repeated_kv(v, self.kv_repeat)
k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
if self.attention_as_float32:
q, k, v = [x.float() for x in [q, k, v]]
if self.memory_efficient:
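As a quick reference, the layout rule after this change can be restated as the sketch below. It is a standalone paraphrase of `_get_attention_time_dimension`, assuming the only backends are 'torch' and 'xformers' as in the module above:

```python
def attention_time_dim(backend: str, memory_efficient: bool) -> int:
    """Time dim 2 ("b h t d") only when the torch backend is used with memory-efficient attention."""
    return 2 if backend == 'torch' and memory_efficient else 1

assert attention_time_dim('torch', True) == 2      # "b h t d"
assert attention_time_dim('torch', False) == 1     # "b t h d"
assert attention_time_dim('xformers', True) == 1   # "b t h d"
```
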
8 changes: 2 additions & 6 deletions optim/dadam.py
@@ -5,19 +5,15 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import TYPE_CHECKING, Any
from typing import Any

import torch
import torch.optim
import torch.distributed as dist

if TYPE_CHECKING:
from torch.optim.optimizer import _params_t
else:
_params_t = Any


logger = logging.getLogger(__name__)
_params_t = Any


def to_real(x):
6 changes: 6 additions & 0 deletions train.py
@@ -12,6 +12,7 @@
import logging
import multiprocessing
import os
from pathlib import Path
import sys
import typing as tp

@@ -119,6 +120,11 @@ def init_seed_and_system(cfg):
logger.debug('Setting num threads to %d', cfg.num_threads)
set_efficient_attention_backend(cfg.efficient_attention_backend)
logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
if 'SLURM_JOB_ID' in os.environ:
tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID'])
if tmpdir.exists():
logger.info("Changing tmpdir to %s", tmpdir)
os.environ['TMPDIR'] = str(tmpdir)


@hydra_main(config_path='../config', config_name='config', version_base='1.1')
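For context on the train.py change, a hedged sketch of its effect (the job id and scratch path are illustrative): once `TMPDIR` points at the per-job scratch directory, Python's standard `tempfile` module creates temporary files there.

```python
import os
import tempfile
from pathlib import Path

os.environ['SLURM_JOB_ID'] = '123456'  # illustrative job id
tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID'])
if tmpdir.exists():  # only true on clusters exposing this per-job scratch layout
    os.environ['TMPDIR'] = str(tmpdir)
    tempfile.tempdir = None  # clear tempfile's cached default so it re-reads TMPDIR
print(tempfile.gettempdir())  # the scratch dir when available, otherwise /tmp or similar
```
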
