Add silence detection to inference #36

Merged
merged 9 commits on May 17, 2024

Changes from all commits
13 changes: 9 additions & 4 deletions amt/inference/model.py
@@ -36,9 +36,14 @@ def __init__(
dtype=torch.bfloat16,
):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
self.dtype = dtype
self.cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer(
"k_cache", torch.zeros(self.cache_shape, dtype=dtype)
)
self.register_buffer(
"v_cache", torch.zeros(self.cache_shape, dtype=dtype)
)

def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val, v_val: [B, H, L, D]
@@ -118,7 +123,7 @@ def forward(
class CrossAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
assert n_state % n_head == 0, "n_head does not evenly devide n_state"
assert n_state % n_head == 0, "n_head does not evenly divide n_state"

self.n_head = n_head
self.d_head = n_state // n_head
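Note on the KVCache change above: storing self.cache_shape and self.dtype on the module lets callers re-zero or reallocate the buffers without recomputing their dimensions. A minimal sketch of how a static KV cache like this is typically filled during incremental decoding (hypothetical shapes and a hand-rolled update, not the repo's exact method):

    import torch

    # Buffer layout from the diff: (max_batch_size, n_heads, max_seq_length, head_dim)
    k_cache = torch.zeros(1, 8, 256, 64, dtype=torch.bfloat16)
    v_cache = torch.zeros(1, 8, 256, 64, dtype=torch.bfloat16)

    def update(input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor):
        # input_pos: [S]; k_val, v_val: [B, H, S, D]
        # Write the new keys/values into the cache at the given positions,
        # then return the full cache for attention.
        k_cache[:, :, input_pos] = k_val
        v_cache[:, :, input_pos] = v_val
        return k_cache, v_cache

    # One decode step writing position 5:
    k, v = update(
        torch.tensor([5]),
        torch.randn(1, 8, 1, 64, dtype=torch.bfloat16),
        torch.randn(1, 8, 1, 64, dtype=torch.bfloat16),
    )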
183 changes: 173 additions & 10 deletions amt/inference/transcribe.py
@@ -1,6 +1,7 @@
import os
import signal
import time
import copy
import random
import logging
import traceback
@@ -9,15 +10,17 @@
import torch.multiprocessing as multiprocessing
import torch._dynamo.config
import torch._inductor.config
import numpy as np

from torch.multiprocessing import Queue
from tqdm import tqdm
from functools import wraps
from torch.cuda import is_bf16_supported
from librosa.effects import _signal_to_frame_nonsilent

from amt.inference.model import AmtEncoderDecoder
from amt.tokenizer import AmtTokenizer
from amt.audio import AudioTransform
from amt.audio import AudioTransform, SAMPLE_RATE
from amt.data import get_wav_mid_segments

torch._inductor.config.coordinate_descent_tuning = True
@@ -78,8 +81,8 @@ def recalculate_tok_ids(

# Mask out tok_ids larger than 30ms from original tok_id
tok_ids_expanded = tok_ids.unsqueeze(1)
mask_c = col_indices <= tok_ids_expanded + 3
mask_d = col_indices >= tok_ids_expanded - 3
mask_c = col_indices <= tok_ids_expanded + 2
mask_d = col_indices >= tok_ids_expanded - 2
beam_mask = mask_c & mask_d

# Don't mask out the original tok_id (required for non-onset/vel toks)
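For intuition, this change narrows the beam window from three to two token positions on either side of the original id (roughly 20 ms rather than 30 ms, assuming 10 ms onset quantization). A standalone sketch of the band mask with small made-up shapes, not the repo's exact tensors:

    import torch

    # Hypothetical example: two beams, a 16-token vocabulary.
    tok_ids = torch.tensor([5, 10])               # original argmax ids, shape [B]
    col_indices = torch.arange(16).unsqueeze(0)   # shape [1, V], broadcasts over B

    tok_ids_expanded = tok_ids.unsqueeze(1)       # shape [B, 1]
    mask_c = col_indices <= tok_ids_expanded + 2
    mask_d = col_indices >= tok_ids_expanded - 2
    beam_mask = mask_c & mask_d
    # Row 0 keeps columns 3..7; row 1 keeps columns 8..12.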
@@ -218,8 +221,8 @@ def process_segments(
),
)

logits[:, 389] *= 1.2
next_tok_ids = torch.argmax(logits, dim=-1)
# logits[:, 389] *= 1.05
# next_tok_ids = torch.argmax(logits, dim=-1)

next_tok_ids = recalculate_tok_ids(
logits=logits,
@@ -429,13 +432,141 @@ def _truncate_seq(
if len(_mid_dict.note_msgs) == 0:
return [tokenizer.bos_tok]
else:
# The end_ms - 1 is a workaround to get rid of the off msgs
res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1)

if res[-1] == tokenizer.eos_tok:
res.pop()
return res


# This is a sloppy implementation
def process_silent_intervals(
seq: list, intervals: list, tokenizer: AmtTokenizer
):
def adjust_onset(_onset: int):
# Adjusts the onset according to the silence intervals
for start, end in intervals:
if start <= _onset <= end:
return start

return _onset

if len(intervals) == 0:
return seq

res = []
logger = logging.getLogger(__name__)
active_notes = {pitch: False for pitch in range(0, 127)}
active_notes["pedal"] = False

for tok_1, tok_2, tok_3 in zip(
seq,
seq[1:] + [tokenizer.pad_tok],
seq[2:] + [tokenizer.pad_tok, tokenizer.pad_tok],
):
if isinstance(tok_1, tuple) is False:
res.append(tok_1)
continue
elif tok_1[0] == "prev":
res.append(tok_1)
active_notes[tok_1[1]] = True
continue
elif tok_1[0] in {"onset", "vel"}:
continue

if tok_1[0] == "pedal":
note_type = "on" if tok_1[1] == 1 else "off"
note_val = "pedal"
elif tok_1[0] in {"on", "off"}:
note_type = tok_1[0]
note_val = tok_1[1]

if note_type == "on":
# Check that the rest of the tokens are valid
if isinstance(tok_2, tuple) is False:
logger.debug(f"Invalid token sequence {tok_1}, {tok_2}")
continue
if note_val != "pedal" and isinstance(tok_3, tuple) is False:
logger.debug(
f"Invalid token sequence {tok_1}, {tok_2}, {tok_3}"
)
continue

# Don't add on if note is already on
if active_notes[note_val] is True:
continue

# Calculate adjusted onset and add if conditions are met
onset = tok_2[1]
onset_adj = adjust_onset(onset)
if onset != onset_adj:
continue
else:
active_notes[note_val] = True
res.append(tok_1)
res.append(tok_2)
if note_val != "pedal":
res.append(tok_3)

elif note_type == "off":
# Check that the rest of the tokens are valid
if isinstance(tok_2, tuple) is False and tok_2[0] != "onset":
logger.debug(f"Invalid token sequence {tok_1}, {tok_2}")
continue

# Don't add on if note is not on
if active_notes[note_val] is False:
continue

# Add note with adjusted offset
offset = tok_2[1]
offset_adj = adjust_onset(offset)
if offset != offset_adj:
logger.debug(
f"Adjusted offset of {tok_1}, {tok_2} -> {offset_adj}"
)
res.append(tok_1)
res.append(("onset", tokenizer._quantize_onset(offset_adj)))
active_notes[note_val] = False

return res

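For illustration, the adjustment rule above snaps any time that falls inside a silent interval to that interval's start; note-ons whose onset moves are dropped entirely, while note-offs are kept with the snapped time. A small self-contained example of the rule (hand-picked numbers, hypothetical):

    # Silence detected from 1000 ms to 2000 ms.
    intervals = [(1000, 2000)]

    def adjust_onset(onset: int) -> int:
        for start, end in intervals:
            if start <= onset <= end:
                return start
        return onset

    assert adjust_onset(1500) == 1000  # inside silence: snapped to the start
    assert adjust_onset(500) == 500    # outside silence: unchanged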

def get_silent_intervals(wav: torch.Tensor):
FRAME_LEN = 2048
HOP_LEN = 512
MIN_WINDOW_S = 5
MIN_WINDOW_STEPS = (SAMPLE_RATE // HOP_LEN) * MIN_WINDOW_S + 1
MS_PER_HOP = int((HOP_LEN * 1e3) / SAMPLE_RATE)

non_silent = _signal_to_frame_nonsilent(
wav.numpy(),
frame_length=FRAME_LEN,
hop_length=HOP_LEN,
top_db=30,
ref=np.max,
)
non_silent = np.concatenate(([True], non_silent, [True]))

edges = np.diff(non_silent.astype(int))
starts = np.where(edges == -1)[0]
ends = np.where(edges == 1)[0]

# Calculate lengths
lengths = ends - starts

# Filter intervals by minimum length
valid = lengths > MIN_WINDOW_STEPS
silent_intervals = [
(start * MS_PER_HOP, (end - 1) * MS_PER_HOP)
for start, end, vl in zip(starts, ends, valid)
if vl
]

return silent_intervals

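The edge detection above turns librosa's per-frame non-silence flags into (start_ms, end_ms) pairs. A worked example (assuming SAMPLE_RATE is 16000, which makes MS_PER_HOP 32 ms; adjust if the project uses a different rate):

    import numpy as np

    # Five frames, of which the middle three are silent.
    non_silent = np.array([True, False, False, False, True])
    non_silent = np.concatenate(([True], non_silent, [True]))

    edges = np.diff(non_silent.astype(int))
    starts = np.where(edges == -1)[0]  # loud -> silent transitions: [1]
    ends = np.where(edges == 1)[0]     # silent -> loud transitions: [4]
    # One silent run of length 3; with MS_PER_HOP = 32 it maps to
    # (1 * 32, (4 - 1) * 32) = (32, 96) ms, which the MIN_WINDOW_STEPS
    # filter would then discard as too short (< 5 s).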

def transcribe_file(
file_path,
gpu_task_queue: Queue,
@@ -463,7 +594,10 @@ def transcribe_file(
init_idx = len(seq)

# Add to gpu queue and wait for results
gpu_task_queue.put(((audio_segments.pop(0), seq), pid))
curr_audio_segment = audio_segments.pop(0)
silent_intervals = get_silent_intervals(curr_audio_segment)
input_seq = copy.deepcopy(seq)
gpu_task_queue.put(((curr_audio_segment, seq), pid))
while True:
try:
gpu_result = result_queue.get(timeout=0.1)
@@ -476,18 +610,47 @@
else:
result_queue.put(gpu_result)

if len(silent_intervals) > 0:
logger.debug(
f"Seen silent intervals in segment {idx}: {silent_intervals}"
)

seq_raw = seq
seq = process_silent_intervals(
seq, intervals=silent_intervals, tokenizer=tokenizer
)

if len(seq) != len(seq_raw):
logger.info(
f"Removed tokens ({len(seq_raw)} -> {len(seq)}) "
f"in segment {idx} according to silence in intervals: "
f"{silent_intervals}",
)

try:
next_seq = _truncate_seq(
seq,
CHUNK_LEN_MS,
LEN_MS - CHUNK_LEN_MS,
)
except Exception as e:
logger.info(
f"Skipping segment {idx} (failed to transcribe): {file_path}"
)
logger.info(f"Failed to reconcile segment {idx}: {file_path}")
logger.debug(traceback.format_exc())
seq = [tokenizer.bos_tok]

try:
seq = _truncate_seq(
input_seq,
CHUNK_LEN_MS - 2,
CHUNK_LEN_MS,
)
except Exception as e:
seq = [tokenizer.bos_tok]
logger.info(
f"Failed to recover prompt, proceeding with default: {seq}"
)
else:
logger.info(f"Proceeding with prompt: {seq}")

else:
if seq[-1] == tokenizer.eos_tok:
logger.info(f"Seen eos_tok at segment {idx}: {file_path}")
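Taken together, the new error path tries progressively weaker prompts: first the silence-processed sequence, then the unprocessed input_seq, and finally a bare beginning-of-sequence token. A simplified schematic of that control flow (condensed from the diff, not the literal code):

    def build_next_prompt(seq, input_seq, tokenizer):
        try:
            # Preferred: reconcile the silence-processed sequence.
            return _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS)
        except Exception:
            try:
                # Fallback: recover a prompt from the unprocessed sequence.
                return _truncate_seq(input_seq, CHUNK_LEN_MS - 2, CHUNK_LEN_MS)
            except Exception:
                # Last resort: restart from the beginning-of-sequence token.
                return [tokenizer.bos_tok]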
2 changes: 1 addition & 1 deletion amt/tokenizer.py
@@ -447,7 +447,7 @@ def msg_mixup(src: list):
raise Exception

random.shuffle(res) # Only includes prev toks
res.append(self.bos_tok) # Beggining of sequence
res.append(self.bos_tok) # Beginning of sequence

buffer = defaultdict(lambda: defaultdict(list))
for tok_1, tok_2, tok_3 in zip(
4 changes: 2 additions & 2 deletions config/models/medium-triple.json
@@ -3,9 +3,9 @@
"n_audio_ctx": 1500,
"n_audio_state": 768,
"n_audio_head": 12,
"n_audio_layer": 4,
"n_audio_layer": 6,
"n_text_ctx": 4096,
"n_text_state": 768,
"n_text_head": 12,
"n_text_layer": 4
"n_text_layer": 6
}
1 change: 1 addition & 0 deletions requirements.txt
@@ -3,6 +3,7 @@ torch >= 2.2
torchaudio
accelerate
psutil
librosa
mido
tqdm
orjson