diff --git a/amt/data.py b/amt/data.py
index 5e74726..71982c5 100644
--- a/amt/data.py
+++ b/amt/data.py
@@ -56,7 +56,11 @@ def get_wav_mid_segments(
     # Create features
     total_samples = wav.shape[-1]
     res = []
-    for idx in range(0, total_samples, num_samples // stride_factor):
+    for idx in range(
+        0,
+        total_samples - (num_samples - (num_samples // stride_factor)),
+        num_samples // stride_factor,
+    ):
         audio_feature = pad_or_trim(wav[idx:], length=num_samples)
         if midi_dict is not None:
             mid_feature = tokenizer._tokenize_midi_dict(
diff --git a/amt/infer.py b/amt/infer.py
index 72c99b0..5f4dba1 100644
--- a/amt/infer.py
+++ b/amt/infer.py
@@ -5,15 +5,12 @@ import torch.multiprocessing as multiprocessing
 
 from torch.multiprocessing import Queue
-from torch.cuda import device_count, is_available
 from tqdm import tqdm
 
 from amt.model import AmtEncoderDecoder
 from amt.tokenizer import AmtTokenizer
 from amt.audio import AudioTransform
 from amt.data import get_wav_mid_segments
-from amt.config import load_config
-from aria.data.midi import MidiDict
 
 MAX_SEQ_LEN = 4096
 LEN_MS = 30000
 
@@ -24,8 +21,11 @@
 VEL_TOLERANCE = 50
 
 
+# TODO: Profile and fix gpu util
+
+
 def calculate_vel(
-    logits: torch.tensor,
+    logits: torch.Tensor,
     init_vel: int,
     tokenizer: AmtTokenizer = AmtTokenizer(),
 ):
@@ -51,13 +51,13 @@ def calculate_vel(
 
     vels = torch.tensor(vels).to(probs.device)
     new_vel = torch.sum(vels * probs) / torch.sum(probs)
-    new_vel = round(new_vel.item() / 10) * 10
+    new_vel = round(new_vel.item() / 5) * 5
 
     return tokenizer.tok_to_id[("vel", new_vel)]
 
 
 def calculate_onset(
-    logits: torch.tensor,
+    logits: torch.Tensor,
     init_onset: int,
     tokenizer: AmtTokenizer = AmtTokenizer(),
 ):
@@ -88,6 +88,7 @@ def calculate_onset(
     return tokenizer.tok_to_id[("onset", new_onset)]
 
 
+@torch.autocast("cuda", dtype=torch.bfloat16)
 def process_segments(
     tasks: list,
     model: AmtEncoderDecoder,
@@ -111,14 +112,14 @@ def process_segments(
 
     kv_cache = model.get_empty_cache()
 
-    # for idx in (
-    #     pbar := tqdm(
-    #         range(min_prefix_len, MAX_SEQ_LEN - 1),
-    #         total=MAX_SEQ_LEN - (min_prefix_len + 1),
-    #         leave=False,
-    #     )
-    # ):
-    for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
+    for idx in (
+        pbar := tqdm(
+            range(min_prefix_len, MAX_SEQ_LEN - 1),
+            total=MAX_SEQ_LEN - (min_prefix_len + 1),
+            leave=False,
+        )
+    ):
+        # for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
         if idx == min_prefix_len:
             logits = model.decoder(
                 xa=audio_features,
@@ -160,6 +161,12 @@ def process_segments(
         if all(eos_seen):
             break
 
+    if not all(eos_seen):
+        print("WARNING: OVERFLOW")
+        for _idx in range(seq.shape[0]):
+            if eos_seen[_idx] == False:
+                eos_seen[_idx] = MAX_SEQ_LEN
+
     results = [
         tokenizer.decode(seq[_idx, : eos_seen[_idx] + 1])
         for _idx in range(seq.shape[0])
@@ -174,9 +181,7 @@ def gpu_manager(
     model: AmtEncoderDecoder,
     batch_size: int,
 ):
-    model.cuda()
-    model.eval()
-    model.compile()
+    # model.compile()
     audio_transform = AudioTransform().cuda()
     tokenizer = AmtTokenizer(return_tensors=True)
     process_pid = multiprocessing.current_process().pid
@@ -277,9 +282,6 @@ def process_file(
             seq[init_idx : seq.index(tokenizer.eos_tok)],
             idx * CHUNK_LEN_MS,
         )
-        print(
-            f"{process_pid}: Finished {idx+1}/{len(audio_segments)} audio segments"
-        )
         if idx == len(audio_segments) - 1:
             break
 
@@ -310,8 +312,8 @@ def _get_save_path(_file_path: str):
     save_path = os.path.join(
         save_dir, os.path.splitext(input_rel_path)[0] + ".mid"
     )
-    if not os.path.exists(os.path.dirname(save_path)):
-        os.makedirs(os.path.dirname(save_path))
+    if not os.path.isdir(os.path.dirname(save_path)):
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
 
     return save_path
 
@@ -361,7 +363,8 @@ def batch_transcribe(
     if gpu_id is not None:
         os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
 
-    model.to("cuda")
+    model.cuda()
+    model.eval()
     file_queue = Queue()
     for file_path in file_paths:
         file_queue.put(file_path)
@@ -369,12 +372,6 @@ def batch_transcribe(
     gpu_task_queue = Queue()
     result_queue = Queue()
 
-    gpu_manager_process = multiprocessing.Process(
-        target=gpu_manager,
-        args=(gpu_task_queue, result_queue, model, batch_size),
-    )
-    gpu_manager_process.start()
-
     worker_processes = [
         multiprocessing.Process(
             target=worker,
@@ -391,6 +388,13 @@ def batch_transcribe(
     for p in worker_processes:
         p.start()
 
+    time.sleep(10)
+    gpu_manager_process = multiprocessing.Process(
+        target=gpu_manager,
+        args=(gpu_task_queue, result_queue, model, batch_size),
+    )
+    gpu_manager_process.start()
+
     for p in worker_processes:
         p.join()
diff --git a/amt/run.py b/amt/run.py
index c068b98..2199236 100644
--- a/amt/run.py
+++ b/amt/run.py
@@ -119,7 +119,9 @@ def transcribe(args):
     assert os.path.isfile(args.cp), "model checkpoint file not found"
     assert args.load_path or args.load_dir, "must give either load path or dir"
     if args.load_path:
-        assert os.path.isfile(args.load_path), "audio file not found"
+        assert os.path.isfile(
+            args.load_path
+        ), f"audio file not found: {args.load_path}"
         trans_mode = "single"
     if args.load_dir:
         assert os.path.isdir(args.load_dir), "load directory doesn't exist"
@@ -201,6 +203,7 @@ def transcribe(args):
         model=model,
         save_dir=args.save_dir,
         batch_size=args.bs,
+        input_dir=args.load_dir,
     )
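Note on the `amt/data.py` hunk: a minimal standalone sketch of what the new loop bound changes. Only the names `total_samples`, `num_samples`, and `stride_factor` come from the diff; the concrete values below are illustrative.

```python
# Compare window start indices under the old and new loop bounds.
num_samples = 10    # fixed segment length (illustrative)
stride_factor = 2   # controls window overlap (illustrative)
total_samples = 23  # illustrative
stride = num_samples // stride_factor  # 5

# Old bound walks all the way to the end, so the last window starts at 20
# and covers only 3 real samples before padding: [0, 5, 10, 15, 20]
old_starts = list(range(0, total_samples, stride))

# New bound keeps a window only if it ends less than one stride past the
# end of the audio, dropping the heavily padded tail: [0, 5, 10, 15]
new_starts = list(
    range(0, total_samples - (num_samples - stride), stride)
)

print(old_starts, new_starts)
```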
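Note on the `calculate_vel` hunk: the change only tightens the rounding grid of the probability-weighted velocity from multiples of 10 to multiples of 5 before the `("vel", value)` token lookup. A sketch of that arithmetic, with a made-up expected velocity:

```python
# The expected velocity is a probability-weighted average over candidate
# velocity tokens; the diff rounds it to the nearest multiple of 5
# instead of 10 before looking up the ("vel", value) token id.
expected_vel = 67.3  # illustrative weighted average

old_vel = round(expected_vel / 10) * 10  # 70
new_vel = round(expected_vel / 5) * 5    # 65

print(old_vel, new_vel)
```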
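Note on the overflow guard in `process_segments`: marking a row that never emitted EOS with `eos_seen[_idx] = MAX_SEQ_LEN` makes the later decode slice keep the whole truncated row instead of indexing with `False`. A small sketch with a toy `MAX_SEQ_LEN` and tensor (both illustrative):

```python
import torch

MAX_SEQ_LEN = 8  # 4096 in the diff; shrunk here for illustration
seq = torch.arange(2 * MAX_SEQ_LEN).reshape(2, MAX_SEQ_LEN)
eos_seen = [3, False]  # row 0 saw EOS at index 3; row 1 never did

# The diff's fallback: treat a row with no EOS as ending at MAX_SEQ_LEN,
# so the decode slice below returns the entire truncated row.
for i in range(seq.shape[0]):
    if eos_seen[i] == False:
        eos_seen[i] = MAX_SEQ_LEN

results = [seq[i, : eos_seen[i] + 1] for i in range(seq.shape[0])]
print([r.tolist() for r in results])
```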