Skip to content

Commit

Permalink
Add segment support for inference (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
loubbrad authored Aug 28, 2024
1 parent f43ec54 commit 27b711e
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 55 deletions.
6 changes: 5 additions & 1 deletion amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,12 @@ def _get_noise(self, noise_paths: list):
for wav, sr in noises
]

for wav in noises:
for wav, path in zip(noises, noise_paths):
assert wav.shape[-1] == self.num_samples, "noise wav too short"
assert not (
torch.all(wav < 0.01).item() is True
and torch.all(wav > -0.01).item() is True
), f"Loaded wav {path} is approximately silent which can cause NaN."

return noises

Expand Down
29 changes: 26 additions & 3 deletions amt/data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import mmap
import os
import io
import math
import random
import shlex
import base64
Expand All @@ -12,7 +11,7 @@
import torchaudio

from multiprocessing import Pool, Queue, Process
from typing import Callable
from typing import Callable, Tuple

from aria.data.midi import MidiDict
from amt.tokenizer import AmtTokenizer
Expand Down Expand Up @@ -69,6 +68,7 @@ def get_wav_segments(
audio_path: str,
stride_factor: int | None = None,
pad_last=False,
segment: Tuple[int, int] | None = None,
):
assert os.path.isfile(audio_path), "Audio file not found"
config = load_config()
Expand All @@ -83,16 +83,35 @@ def get_wav_segments(
stride_samples = int(chunk_samples // stride_factor)
assert chunk_samples % stride_samples == 0, "Invalid stride"

# Handle segmentation if provided
if segment is not None:
assert (
segment[0] < segment[1]
), "Invalid segment: start must be less than end"
start_time_s, end_time_s = segment
start_sample = int(start_time_s * sample_rate)
end_sample = int(end_time_s * sample_rate)
stream.seek(start_time_s)
else:
start_sample, end_sample = 0, None

stream.add_basic_audio_stream(
frames_per_chunk=stride_samples,
stream_index=0,
sample_rate=sample_rate,
)

buffer = torch.tensor([], dtype=torch.float32)
total_samples = start_sample
for stride_seg in stream.stream():
seg_chunk = stride_seg[0].mean(1)

if end_sample and total_samples + seg_chunk.shape[0] > end_sample:
samples_to_use = end_sample - total_samples
seg_chunk = seg_chunk[:samples_to_use]

total_samples += seg_chunk.shape[0]

# Pad seg_chunk if required
if seg_chunk.shape[0] < stride_samples:
seg_chunk = F.pad(
Expand All @@ -110,7 +129,10 @@ def get_wav_segments(
if buffer.shape[0] == chunk_samples:
yield buffer

if pad_last == True:
if end_sample and total_samples >= end_sample:
break

if pad_last and buffer.shape[0] > stride_samples:
yield torch.nn.functional.pad(
buffer[stride_samples:],
(0, chunk_samples - len(buffer[stride_samples:])),
Expand Down Expand Up @@ -296,6 +318,7 @@ def build_synth_worker_fn(

class AmtDataset(torch.utils.data.Dataset):
def __init__(self, load_paths: str | list):
super().__init__()
self.tokenizer = AmtTokenizer(return_tensors=True)
self.config = load_config()["data"]
self.mixup_fn = self.tokenizer.export_msg_mixup()
Expand Down
133 changes: 84 additions & 49 deletions amt/inference/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def gpu_manager(
# pid = -1 when its a pad sequence
for result, (_, pid) in zip(results, batch):
if pid != -1:
result_queue.put({"result": result, "pid": pid})
result_queue.put((result, pid))

except Exception as e:
logger.error(f"GPU manager failed with exception: {e}")
Expand Down Expand Up @@ -681,19 +681,20 @@ def transcribe_file(
result_queue: Queue,
pid: int,
tokenizer: AmtTokenizer = AmtTokenizer(),
segment: Tuple[int, int] | None = None,
):
logger = logging.getLogger(__name__)

logger.info(f"Getting wav segments: {file_path}")

res = []
seq = [tokenizer.bos_tok]
concat_seq = [tokenizer.bos_tok]
idx = 0
for curr_audio_segment in get_wav_segments(
audio_path=file_path,
stride_factor=STRIDE_FACTOR,
pad_last=True,
segment=segment,
):
init_idx = len(seq)
# Add to gpu queue and wait for results
Expand All @@ -706,25 +707,25 @@ def transcribe_file(
except Exception as e:
pass
else:
if gpu_result["pid"] == pid:
seq = gpu_result["result"]
if gpu_result[1] == pid:
seq = gpu_result[0]
break
else:
result_queue.put(gpu_result)

if len(silent_intervals) > 0:
logger.debug(
f"Seen silent intervals in segment {idx}: {silent_intervals}"
f"Seen silent intervals in audio chunk {idx}: {silent_intervals}"
)

seq_adj = _process_silent_intervals(
seq, intervals=silent_intervals, tokenizer=tokenizer
)

if len(seq_adj) < len(seq) - 5:
if len(seq_adj) < len(seq) - 15:
logger.info(
f"Removed tokens ({len(seq)} -> {len(seq_adj)}) "
f"in segment {idx} according to silence in intervals: "
f"in audio chunk {idx} according to silence in intervals: "
f"{silent_intervals}",
)
seq = seq_adj
Expand All @@ -736,7 +737,9 @@ def transcribe_file(
LEN_MS - CHUNK_LEN_MS,
)
except Exception as e:
logger.info(f"Failed to reconcile segment {idx}: {file_path}")
logger.info(
f"Failed to reconcile sequences for audio chunk {idx}: {file_path}"
)
logger.debug(traceback.format_exc())

try:
Expand All @@ -755,11 +758,13 @@ def transcribe_file(

else:
if seq[-1] == tokenizer.eos_tok:
logger.info(f"Seen eos_tok at segment {idx}: {file_path}")
logger.info(f"Seen eos_tok in audio chunk {idx}: {file_path}")
seq = seq[:-1]

if len(next_seq) == 1:
logger.info(f"Skipping segment {idx} (silence): {file_path}")
logger.info(
f"Skipping audio chunk {idx} (silence): {file_path}"
)
seq = [tokenizer.bos_tok]
else:
concat_seq += _shift_onset(
Expand All @@ -770,9 +775,7 @@ def transcribe_file(

idx += 1

res.append(concat_seq)

return res
return concat_seq


def get_save_path(
Expand Down Expand Up @@ -806,6 +809,7 @@ def process_file(
save_dir: str,
input_dir: str,
logger: logging.Logger,
segments: List[Tuple[int, int]] | None = None,
):
def _save_seq(_seq: List, _save_path: str):
if os.path.exists(_save_path):
Expand Down Expand Up @@ -846,26 +850,41 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
return num_removed

pid = threading.get_ident()
try:
seqs = transcribe_file(file_path, gpu_task_queue, result_queue, pid=pid)
except Exception as e:
logger.error(f"Failed to process {file_path}: {traceback.format_exc()}")
task_rmv_cnt = remove_failures_from_queue_(gpu_task_queue, pid)
res_rmv_cnt = remove_failures_from_queue_(result_queue, pid)
logger.info(f"Removed {task_rmv_cnt} from task queue")
logger.info(f"Removed {res_rmv_cnt} from result queue")
return
if segments is None:
segments = [None]

if len(segments) == 0:
logger.info(f"No segments to transcribe, skipping file: {file_path}")

for idx, segment in enumerate(segments):
try:
seq = transcribe_file(
file_path,
gpu_task_queue,
result_queue,
pid=pid,
segment=segment,
)
except Exception as e:
logger.error(
f"Failed to process {file_path} segment {idx}: {traceback.format_exc()}"
)
task_rmv_cnt = remove_failures_from_queue_(gpu_task_queue, pid)
res_rmv_cnt = remove_failures_from_queue_(result_queue, pid)
logger.info(f"Removed {task_rmv_cnt} from task queue")
logger.info(f"Removed {res_rmv_cnt} from result queue")
continue

logger.info(f"Finished file: {file_path}")
for seq in seqs:
logger.info(f"Finished file: {file_path} (segment: {idx})")
if len(seq) < 500:
logger.info("Skipping seq - too short")
logger.info(f"Skipping seq - too short (segment {idx})")
else:
logger.debug(
f"Saving seq of length {len(seq)} from file: {file_path}"
f"Saving seq of length {len(seq)} from file: {file_path} (segment: {idx})"
)

_save_seq(seq, get_save_path(file_path, input_dir, save_dir))
idx = f"_{idx}" if segment is not None else ""
save_path = get_save_path(file_path, input_dir, save_dir, idx)
_save_seq(seq, save_path)

logger.info(f"{file_queue.qsize()} file(s) remaining in queue")

Expand Down Expand Up @@ -905,7 +924,7 @@ def worker(
def process_file_wrapper():
while True:
try:
file_path = file_queue.get(timeout=15)
file_to_process = file_queue.get(timeout=15)
except Empty as e:
if file_queue.empty():
logger.info("File queue empty")
Expand All @@ -916,14 +935,15 @@ def process_file_wrapper():
continue

process_file(
file_path,
file_queue,
gpu_task_queue,
result_queue,
tokenizer,
save_dir,
input_dir,
logger,
file_path=file_to_process["path"],
file_queue=file_queue,
gpu_task_queue=gpu_task_queue,
result_queue=result_queue,
tokenizer=tokenizer,
save_dir=save_dir,
input_dir=input_dir,
logger=logger,
segments=file_to_process.get("segments", None),
)

if file_queue.empty():
Expand All @@ -943,7 +963,7 @@ def process_file_wrapper():


def batch_transcribe(
file_paths: List,
files_to_process: List[dict],
model: AmtEncoderDecoder,
save_dir: str,
batch_size: int = 8,
Expand All @@ -968,21 +988,36 @@ def batch_transcribe(
os.remove("transcribe.log")

if quantize is True:
logger.info("Quantising decoder weights to int8")
logger.info("Quantizing decoder weights to int8")
model.decoder = quantize_int8(model.decoder)

file_queue = Queue()
sorted(file_paths, key=lambda x: os.path.getsize(x), reverse=True)
for file_path in file_paths:
sorted(
files_to_process, key=lambda x: os.path.getsize(x["path"]), reverse=True
)
for file_to_process in files_to_process:
# Only add to file_queue if transcription MIDI file doesn't exist
if (
os.path.isfile(get_save_path(file_path, input_dir, save_dir))
os.path.isfile(
get_save_path(file_to_process["path"], input_dir, save_dir)
)
is False
):
file_queue.put(file_path)
elif len(file_paths) == 1:
file_queue.put(file_path)
) and os.path.isfile(
get_save_path(
file_to_process["path"], input_dir, save_dir, idx="_0"
)
) is False:
file_queue.put(file_to_process)
elif len(files_to_process) == 1:
file_queue.put(file_to_process)

logger.info(f"Files to process: {file_queue.qsize()}/{len(file_paths)}")
logger.info(
f"Files to process: {file_queue.qsize()}/{len(files_to_process)}"
)

if file_queue.qsize() == 0:
logger.info("No files to process")
return

if num_workers is None:
num_workers = min(
Expand Down Expand Up @@ -1113,10 +1148,10 @@ def batch_transcribe(
cleanup_processes(child_pids=child_pids)
logger.info("Complete")
finally:
gpu_batch_manager_process.terminate()
gpu_batch_manager_process.join()
watchdog_process.terminate()
watchdog_process.join()
gpu_batch_manager_process.terminate()
gpu_batch_manager_process.join()
file_queue.close()
file_queue.join_thread()
gpu_task_queue.close()
Expand Down
Loading

0 comments on commit 27b711e

Please sign in to comment.