Skip to content

Commit

Permalink
Add multiple paths to AmtDataset (#26)
Browse files Browse the repository at this point in the history
* clean up

* update test

* clean audio

* add multiple paths to dataset

* add tqdm
  • Loading branch information
loubbrad authored Apr 16, 2024
1 parent c333406 commit fb1fed4
Show file tree
Hide file tree
Showing 20 changed files with 256 additions and 2,377 deletions.
170 changes: 5 additions & 165 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@
import random
import torch
import torchaudio
import torch.nn.functional as F
import torchaudio.functional as AF
import numpy as np

from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union

from amt.config import load_config
from amt.tokenizer import AmtTokenizer
Expand All @@ -28,160 +22,6 @@
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Decode an audio file into a mono float32 waveform using the ffmpeg CLI.

    Parameters
    ----------
    file: str
        Path of the audio file to decode
    sr: int
        Sample rate handed to ffmpeg (``-ar``) for the decoded output

    Returns
    -------
    A one-dimensional NumPy float32 array: the signed 16-bit PCM samples
    written by ffmpeg, divided by 32768.0.

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status; the message carries the
        decoded stderr of the ffmpeg process.
    """
    # ffmpeg runs as a subprocess and handles decoding, down-mixing to a
    # single channel (-ac 1) and resampling (-ar), writing raw 16-bit
    # little-endian PCM to stdout. Requires the ffmpeg CLI in PATH.
    # NOTE: a variant of this command without the "-f s16le" and
    # "-acodec pcm_s16le" flags was suggested for reading mp3 files; it has
    # not been tested.
    # fmt: off
    ffmpeg_args = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-",
    ]
    # fmt: on

    try:
        completed = run(ffmpeg_args, capture_output=True, check=True)
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    # Reinterpret the raw bytes as int16 samples, then scale by 1 / 32768.
    pcm = np.frombuffer(completed.stdout, np.int16)
    return pcm.flatten().astype(np.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Force ``array`` to hold exactly ``length`` entries along ``axis``.

    Inputs that are too long keep their first ``length`` entries along that
    axis; inputs that are too short are padded at the end of that axis.
    Torch tensors are handled with torch ops; anything else is handled
    through the ndarray ``take`` method and ``np.pad``.

    Parameters
    ----------
    array:
        A torch tensor or NumPy array
    length: int
        Target size along ``axis`` (defaults to N_SAMPLES)
    axis: int
        Axis to pad or trim, keyword-only

    Returns
    -------
    The trimmed or padded array, or ``array`` itself when it already has
    the requested size.
    """
    current = array.shape[axis]
    is_tensor = torch.is_tensor(array)

    if current > length:
        if is_tensor:
            keep = torch.arange(length, device=array.device)
            return array.index_select(dim=axis, index=keep)
        return array.take(indices=range(length), axis=axis)

    if current < length:
        # One (before, after) pair per dimension; only `axis` grows, and
        # only at its end.
        widths = [(0, 0)] * array.ndim
        widths[axis] = (0, length - current)
        if is_tensor:
            # F.pad expects a flat list ordered from the last dimension
            # backwards, hence the reversal.
            flat = [amount for pair in widths[::-1] for amount in pair]
            return F.pad(array, flat)
        return np.pad(array, widths)

    return array


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    Load the mel filterbank matrix used to project an STFT onto mel bins.

    The matrices are read from ``assets/mel_filters.npz`` next to this
    module, which avoids a runtime dependency on librosa. The archive was
    saved using:
        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    NOTE(review): 256 is also accepted below, so the archive presumably
    contains a ``mel_256`` entry as well - confirm against the asset file.

    Results are memoised per ``(device, n_mels)`` pair by ``lru_cache``.

    Parameters
    ----------
    device:
        Device the returned tensor is moved to
    n_mels: int
        Number of mel bins; must be 80, 128 or 256

    Returns
    -------
    torch.Tensor
        The ``mel_{n_mels}`` array from the archive, on ``device``
    """
    assert n_mels in {80, 128, 256}, f"Unsupported n_mels: {n_mels}"

    assets_dir = os.path.join(os.path.dirname(__file__), "assets")
    filters_path = os.path.join(assets_dir, "mel_filters.npz")
    with np.load(filters_path, allow_pickle=False) as archive:
        filterbank = archive[f"mel_{n_mels}"]
        return torch.from_numpy(filterbank).to(device)


def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 256,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the normalised log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        Path to an audio file (decoded with load_audio), or a NumPy array
        or Tensor that already contains the waveform
    n_mels: int
        Number of Mel-frequency filters; forwarded to mel_filters, which
        accepts 80, 128 or 256
    padding: int
        Number of zero samples to pad to the right
    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor
        The log-Mel spectrogram, clamped to within 8.0 (log10 units) of its
        maximum and rescaled as (x + 4.0) / 4.0.
        NOTE(review): presumably of shape (n_mels, n_frames) - confirm
        against the stored filterbank shapes.
    """
    # A str is never a tensor, so decoding first and converting second
    # covers the path, ndarray and tensor cases.
    if isinstance(audio, str):
        audio = load_audio(audio)
    if not torch.is_tensor(audio):
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))

    hann = torch.hann_window(N_FFT).to(audio.device)
    spectrum = torch.stft(
        audio, N_FFT, HOP_LENGTH, window=hann, return_complex=True
    )
    # Drop the final STFT frame and take the squared magnitude.
    power = spectrum[..., :-1].abs() ** 2

    mel_spec = mel_filters(audio.device, n_mels) @ power

    # Clamp before log10 so empty bins do not produce -inf.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    floor = log_spec.max() - 8.0
    log_spec = torch.maximum(log_spec, floor)

    return (log_spec + 4.0) / 4.0


# Refactor default params are stored in config.json
class AudioTransform(torch.nn.Module):
def __init__(
Expand All @@ -191,15 +31,15 @@ def __init__(
max_snr: int = 50,
max_dist_gain: int = 25,
min_dist_gain: int = 0,
noise_ratio: float = 0.75,
reverb_ratio: float = 0.75,
noise_ratio: float = 0.9,
reverb_ratio: float = 0.9,
applause_ratio: float = 0.01,
bandpass_ratio: float = 0.15,
distort_ratio: float = 0.15,
reduce_ratio: float = 0.01,
detune_ratio: float = 0.0,
detune_max_shift: float = 0.0,
spec_aug_ratio: float = 0.9,
detune_ratio: float = 0.1,
detune_max_shift: float = 0.15,
spec_aug_ratio: float = 0.95,
):
super().__init__()
self.tokenizer = AmtTokenizer()
Expand Down
105 changes: 48 additions & 57 deletions amt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_wav_mid_segments(
midi_dict=midi_dict,
start_ms=idx // samples_per_ms,
end_ms=(idx + num_samples) / samples_per_ms,
max_pedal_len_ms=10000,
max_pedal_len_ms=15000,
)

# Hardcoded to 2.5s
Expand Down Expand Up @@ -148,8 +148,8 @@ def pianoteq_cmd_fn(mid_path: str, wav_path: str):
safe_mid_path = shlex.quote(mid_path)
safe_wav_path = shlex.quote(wav_path)

# Construct the command
command = f"/home/mchorse/pianoteq/x86-64bit/Pianoteq\\ 8\\ STAGE --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}"
executable_path = "/home/loubb/pianoteq/x86-64bit/Pianoteq 8 STAGE"
command = f'"{executable_path}" --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}'

return command

Expand Down Expand Up @@ -205,8 +205,6 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str):
if os.path.isfile(audio_path_temp):
os.remove(audio_path_temp)

print(f"Found {len(features)}")

with open(save_path, mode="a") as file:
for wav, seq in features:
wav_buffer = io.BytesIO()
Expand Down Expand Up @@ -256,34 +254,37 @@ def build_synth_worker_fn(


class AmtDataset(torch.utils.data.Dataset):
def __init__(self, load_path: str):
def __init__(self, load_paths: str | list):
self.tokenizer = AmtTokenizer(return_tensors=True)
self.config = load_config()["data"]
self.mixup_fn = self.tokenizer.export_msg_mixup()
self.file_buff = open(load_path, mode="r")
self.file_mmap = mmap.mmap(
self.file_buff.fileno(), 0, access=mmap.ACCESS_READ
)

index_path = AmtDataset._get_index_path(load_path=load_path)
if os.path.isfile(index_path) is True:
self.index = self._load_index(load_path=index_path)
else:
print("Calculating index...")
self.index = self._build_index()
print(
f"Index of length {len(self.index)} calculated, saving to {index_path}"
)
self._save_index(index=self.index, save_path=index_path)

def close(self):
if self.file_buff:
self.file_buff.close()
if self.file_mmap:
self.file_mmap.close()
if isinstance(load_paths, str):
load_paths = [load_paths]
self.file_buffs = []
self.file_mmaps = []
self.index = []

for path in load_paths:
buff = open(path, mode="r")
self.file_buffs.append(buff)
mmap_obj = mmap.mmap(buff.fileno(), 0, access=mmap.ACCESS_READ)
self.file_mmaps.append(mmap_obj)

index_path = AmtDataset._get_index_path(load_path=path)
if os.path.isfile(index_path):
_index = self._load_index(load_path=index_path)
else:
print("Calculating index...")
_index = self._build_index(mmap_obj)
print(
f"Index of length {len(_index)} calculated, saving to {index_path}"
)
self._save_index(index=_index, save_path=index_path)

def __del__(self):
self.close()
self.index.extend(
[(len(self.file_mmaps) - 1, pos) for pos in _index]
)

def __len__(self):
return len(self.index)
Expand All @@ -295,13 +296,13 @@ def _format(tok):
return tuple(tok)
return tok

self.file_mmap.seek(self.index[idx])
file_id, pos = self.index[idx]
mmap_obj = self.file_mmaps[file_id]
mmap_obj.seek(pos)

# Load data from line
wav = torch.load(
io.BytesIO(base64.b64decode(self.file_mmap.readline()))
)
_seq = orjson.loads(base64.b64decode(self.file_mmap.readline()))
wav = torch.load(io.BytesIO(base64.b64decode(mmap_obj.readline())))
_seq = orjson.loads(base64.b64decode(mmap_obj.readline()))

_seq = [_format(tok) for tok in _seq] # Format seq
_seq = self.mixup_fn(_seq) # Data augmentation
Expand All @@ -317,18 +318,14 @@ def _format(tok):

return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt), idx

def _build_index(self):
self.file_mmap.seek(0)
index = []
while True:
pos = self.file_mmap.tell()
self.file_mmap.readline()
if self.file_mmap.readline() == b"":
break
else:
index.append(pos)
def close(self):
for buff in self.file_buffs:
buff.close()
for mmap in self.file_mmaps:
mmap.close()

return index
def __del__(self):
self.close()

def _save_index(self, index: list, save_path: str):
with open(save_path, "w") as file:
Expand All @@ -345,17 +342,17 @@ def _get_index_path(load_path: str):
f"{load_path.rsplit('.', 1)[0]}_index.{load_path.rsplit('.', 1)[1]}"
)

def _build_index(self):
self.file_mmap.seek(0)
def _build_index(self, mmap_obj):
mmap_obj.seek(0)
index = []
pos = 0
while True:
pos_buff = pos

pos = self.file_mmap.find(b"\n", pos)
pos = mmap_obj.find(b"\n", pos)
if pos == -1:
break
pos = self.file_mmap.find(b"\n", pos + 1)
pos = mmap_obj.find(b"\n", pos + 1)
if pos == -1:
break

Expand Down Expand Up @@ -433,16 +430,10 @@ def build(
if shutil.which("cat") is None:
print("The GNU cat command is not available")
else:
print("Concatinating sharded dataset files")
shell_cmd = f"cat "
for _path in sharded_save_paths:
shell_cmd += f"{_path} "
print()
shell_cmd += f">> {save_path}"

os.system(shell_cmd)
for _path in sharded_save_paths:
shell_cmd = f"cat {_path} >> {save_path}"
os.system(shell_cmd)
os.remove(_path)

# Create index by loading object
AmtDataset(load_path=save_path)
AmtDataset(load_paths=save_path)
Loading

0 comments on commit fb1fed4

Please sign in to comment.