Fix distortion and dataset indexing (#16)
* working

* format

* fix distortion bottlekneck

* format

* adj
loubbrad authored Mar 9, 2024
1 parent b82a9da commit d25a7fa
Showing 6 changed files with 148 additions and 58 deletions.
28 changes: 18 additions & 10 deletions amt/audio.py
@@ -193,10 +193,10 @@ def __init__(
min_dist_gain: int = 0,
noise_ratio: float = 0.95,
reverb_ratio: float = 0.95,
applause_ratio: float = 0.01, # CHANGE
applause_ratio: float = 0.01,
distort_ratio: float = 0.15,
reduce_ratio: float = 0.01,
spec_aug_ratio: float = 0.25,
spec_aug_ratio: float = 0.5,
):
super().__init__()
self.tokenizer = AmtTokenizer()
@@ -257,7 +257,7 @@ def __init__(
)
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.FrequencyMasking(
freq_mask_param=10, iid_masks=True
freq_mask_param=15, iid_masks=True
),
torchaudio.transforms.TimeMasking(
time_mask_param=1000, iid_masks=True
@@ -374,6 +374,17 @@ def apply_distortion(self, wav: torch.tensor):

return AF.overdrive(wav, gain=gain, colour=colour)

def distortion_aug_cpu(self, wav: torch.Tensor):
# This function should run on the cpu (i.e. in the dataloader collate
# function) in order to not be a bottleneck

if random.random() < self.reduce_ratio:
wav = self.apply_reduction(wav)
if random.random() < self.distort_ratio:
wav = self.apply_distortion(wav)

return wav

def shift_spec(self, specs: torch.Tensor, shift: int):
if shift == 0:
return specs
@@ -400,18 +411,15 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
return shifted_specs

def aug_wav(self, wav: torch.Tensor):
# This function doesn't apply distortion. If distortion is desired it
# should be run beforehand on the cpu with distortion_aug_cpu.

# Noise
if random.random() < self.noise_ratio:
wav = self.apply_noise(wav)
if random.random() < self.applause_ratio:
wav = self.apply_applause(wav)

# Distortion
if random.random() < self.reduce_ratio:
wav = self.apply_reduction(wav)
elif random.random() < self.distort_ratio:
wav = self.apply_distortion(wav)

# Reverb
if random.random() < self.reverb_ratio:
return self.apply_reverb(wav)
@@ -439,7 +447,7 @@ def log_mel(self, wav: torch.Tensor, shift: int | None = None):
return log_spec

def forward(self, wav: torch.Tensor, shift: int = 0):
# Noise, distortion, and reverb
# Noise, and reverb
wav = self.aug_wav(wav)

# Spec & pitch shift
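The bottleneck fix in this file splits augmentation between CPU and GPU: the new distortion_aug_cpu is intended to run in the dataloader's collate function, while aug_wav and log_mel stay on the GPU. A minimal sketch of that wiring, assuming a dataset that yields (wav, seq) pairs of equal length; the DataLoader settings and variable names are illustrative and not part of this commit:

import torch
from torch.utils.data import DataLoader

from amt.audio import AudioTransform

cpu_transform = AudioTransform()         # runs inside dataloader workers
gpu_transform = AudioTransform().cuda()  # runs in the training loop

def collate_fn(batch):
    # batch is assumed to be a list of (wav, seq) pairs of equal length
    wavs = torch.stack([wav for wav, _ in batch])
    seqs = torch.stack([seq for _, seq in batch])
    # Reduction/overdrive happen here, on CPU, so AF.overdrive no longer
    # stalls the GPU step
    wavs = cpu_transform.distortion_aug_cpu(wavs)
    return wavs, seqs

# loader = DataLoader(train_dataset, batch_size=16, num_workers=4, collate_fn=collate_fn)
# for wavs, seqs in loader:
#     # noise/applause/reverb, then log-mel and pitch shift, all on GPU
#     mels = gpu_transform(wavs.cuda())
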
55 changes: 54 additions & 1 deletion amt/data.py
@@ -113,7 +113,17 @@ def __init__(self, load_path: str):
self.file_mmap = mmap.mmap(
self.file_buff.fileno(), 0, access=mmap.ACCESS_READ
)
self.index = self._build_index()

index_path = AmtDataset._get_index_path(load_path=load_path)
if os.path.isfile(index_path) is True:
self.index = self._load_index(load_path=index_path)
else:
print("Calculating index...")
self.index = self._build_index()
print(
f"Index of length {len(self.index)} calculated, saving to {index_path}"
)
self._save_index(index=self.index, save_path=index_path)

def close(self):
if self.file_buff:
@@ -167,6 +177,21 @@ def _build_index(self):

return index

def _save_index(self, index: list[int], save_path: str):
with open(save_path, "w") as file:
for idx in index:
file.write(f"{idx}\n")

def _load_index(self, load_path: str):
with open(load_path, "r") as file:
return [int(line.strip()) for line in file]

@staticmethod
def _get_index_path(load_path: str):
return (
f"{load_path.rsplit('.', 1)[0]}_index.{load_path.rsplit('.', 1)[1]}"
)

@classmethod
def build(
cls,
@@ -175,6 +200,12 @@ def build(
num_processes: int = 1,
):
assert os.path.isfile(save_path) is False, f"{save_path} already exists"

index_path = AmtDataset._get_index_path(load_path=save_path)
if os.path.isfile(index_path):
print(f"Removing existing index file at {index_path}")
os.remove(AmtDataset._get_index_path(load_path=save_path))

num_paths = len(matched_load_paths)
with Pool(processes=num_processes) as pool:
sharded_save_paths = []
@@ -202,3 +233,25 @@ def build(
os.system(shell_cmd)
for _path in sharded_save_paths:
os.remove(_path)

# Create index by loading object
AmtDataset(load_path=save_path)

def _build_index(self):
self.file_mmap.seek(0)
index = []
pos = 0
while True:
pos_buff = pos

pos = self.file_mmap.find(b"\n", pos)
if pos == -1:
break
pos = self.file_mmap.find(b"\n", pos + 1)
if pos == -1:
break

index.append(pos_buff)
pos += 1

return index
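
The indexing change above caches the record offsets in a sidecar file named by _get_index_path (for example, train.jsonl -> train_index.jsonl), so only the first AmtDataset load pays for the mmap scan in _build_index. A rough sketch of the resulting behaviour; the file name is a placeholder and the on-disk dataset format itself is unchanged by this commit:

from amt.data import AmtDataset

# First load after AmtDataset.build(): no sidecar exists yet, so the index
# is computed with _build_index() and written to train_index.jsonl
ds = AmtDataset(load_path="train.jsonl")

# Subsequent loads read one integer offset per line from the sidecar
# instead of rescanning the mmap
ds_again = AmtDataset(load_path="train.jsonl")
assert ds.index == ds_again.index
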
81 changes: 52 additions & 29 deletions amt/infer.py
@@ -1,6 +1,7 @@
import os
import time
import random
import logging
import torch
import torch.multiprocessing as multiprocessing

@@ -21,7 +22,23 @@
VEL_TOLERANCE = 50


# TODO: Profile and fix gpu util
def _setup_logger():
logger = logging.getLogger(__name__)
for h in logger.handlers[:]:
logger.removeHandler(h)

logger.propagate = False
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter(
"[%(asctime)s] %(process)d: [%(levelname)s] %(message)s",
)

ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(ch)

return logging.getLogger(__name__)


def calculate_vel(
@@ -101,7 +118,7 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)
else:
# Call the function with float16 if bfloat16 is not supported
with torch.autocast("cuda", dtype=torch.float16):
with torch.autocast("cuda", dtype=torch.float32):
return func(*args, **kwargs)

return wrapper
@@ -114,6 +131,7 @@ def process_segments(
audio_transform: AudioTransform,
tokenizer: AmtTokenizer,
):
logger = logging.getLogger(__name__)
audio_segs = torch.stack(
[audio_seg for (audio_seg, prefix), _ in tasks]
).cuda()
@@ -131,14 +149,14 @@

kv_cache = model.get_empty_cache()

for idx in (
pbar := tqdm(
range(min_prefix_len, MAX_SEQ_LEN - 1),
total=MAX_SEQ_LEN - (min_prefix_len + 1),
leave=False,
)
):
# for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
# for idx in (
# pbar := tqdm(
# range(min_prefix_len, MAX_SEQ_LEN - 1),
# total=MAX_SEQ_LEN - (min_prefix_len + 1),
# leave=False,
# )
# ):
for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
if idx == min_prefix_len:
logits = model.decoder(
xa=audio_features,
@@ -181,7 +199,7 @@ def process_segments(
break

if not all(eos_seen):
print("WARNING: OVERFLOW")
logger.warning("Context length overflow when transcribing segment")
for _idx in range(seq.shape[0]):
if eos_seen[_idx] == False:
eos_seen[_idx] = MAX_SEQ_LEN
@@ -201,19 +219,19 @@ def gpu_manager(
batch_size: int,
):
# model.compile()
logger = _setup_logger()
audio_transform = AudioTransform().cuda()
tokenizer = AmtTokenizer(return_tensors=True)
process_pid = multiprocessing.current_process().pid

wait_for_batch = True
batch = []
while True:
try:
task, pid = gpu_task_queue.get(timeout=5)
except:
print(f"{process_pid}: GPU task timeout")
logger.info(f"GPU task timeout")
if len(batch) == 0:
print(f"{process_pid}: Finished GPU tasks")
logger.info(f"Finished GPU tasks")
return
else:
wait_for_batch = False
@@ -274,8 +292,10 @@ def process_file(
result_queue: Queue,
tokenizer: AmtTokenizer = AmtTokenizer(),
):
process_pid = multiprocessing.current_process().pid
print(f"{process_pid}: Getting wav segments")
logger = logging.getLogger(__name__)
pid = multiprocessing.current_process().pid

logger.info(f"Getting wav segments")
audio_segments = [
f
for f, _ in get_wav_mid_segments(
@@ -288,10 +308,10 @@
init_idx = len(seq)

# Add to gpu queue and wait for results
gpu_task_queue.put(((audio_seg, seq), process_pid))
gpu_task_queue.put(((audio_seg, seq), pid))
while True:
gpu_result = result_queue.get()
if gpu_result["pid"] == process_pid:
if gpu_result["pid"] == pid:
seq = gpu_result["result"]
break
else:
@@ -307,7 +327,7 @@
else:
seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS)
if len(seq) == 1:
print(f"{process_pid}: exiting early")
logger.info(f"Exiting early")
return res

return res
@@ -336,19 +356,19 @@ def _get_save_path(_file_path: str):

return save_path

pid = multiprocessing.current_process().pid
logger = _setup_logger()
tokenizer = AmtTokenizer()
files_processed = 0
while not file_queue.empty():
file_path = file_queue.get()
save_path = _get_save_path(file_path)
if os.path.exists(save_path):
print(f"{pid}: {save_path} already exists, overwriting")
logger.info(f"{save_path} already exists, overwriting")

try:
res = process_file(file_path, gpu_task_queue, result_queue)
except Exception as e:
print(f"{pid}: Failed to transcribe {file_path}")
logger.error(f"Failed to transcribe {file_path}")
continue

files_processed += 1
@@ -365,14 +385,14 @@ def _get_save_path(_file_path: str):
mid = mid_dict.to_midi()
mid.save(save_path)
except Exception as e:
print(f"{pid}: Failed to detokenize with error {e}")
logger.error(f"Failed to detokenize with error {e}")
else:
print(f"{pid}: Finished file {files_processed} - {file_path}")
print(f"{pid}: {file_queue.qsize()} file(s) remaining in queue")
logger.info(f"Finished file {files_processed} - {file_path}")
logger.info(f"{file_queue.qsize()} file(s) remaining in queue")


def batch_transcribe(
file_paths: list,
file_paths, # Queue | list,
model: AmtEncoderDecoder,
save_dir: str,
batch_size: int = 16,
@@ -384,9 +404,12 @@

model.cuda()
model.eval()
file_queue = Queue()
for file_path in file_paths:
file_queue.put(file_path)
if isinstance(file_paths, list):
file_queue = Queue()
for file_path in file_paths:
file_queue.put(file_path)
else:
file_queue = file_paths

gpu_task_queue = Queue()
result_queue = Queue()
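With this change, batch_transcribe accepts either a plain list of file paths or a pre-filled Queue, as hinted by the file_paths,  # Queue | list annotation. A hedged sketch of both call styles; the model checkpoint and paths are placeholders, any further optional arguments of batch_transcribe are omitted, and the Queue class is assumed to match the one amt.infer uses internally:

from amt.infer import batch_transcribe
from torch.multiprocessing import Queue  # assumption: matches the Queue used in amt.infer

model = ...  # an AmtEncoderDecoder, loaded elsewhere (not shown in this commit)

# 1) Passing a list: batch_transcribe builds the file queue itself
batch_transcribe(
    file_paths=["a.wav", "b.wav"],  # placeholder paths
    model=model,
    save_dir="transcriptions/",
    batch_size=16,
)

# 2) Passing a Queue: the caller controls what gets enqueued and when
file_queue = Queue()
for path in ["a.wav", "b.wav"]:
    file_queue.put(path)
batch_transcribe(
    file_paths=file_queue,
    model=model,
    save_dir="transcriptions/",
    batch_size=16,
)
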