From c2d939af91933ab4cdd01090d64b6ac42765b500 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 27 Apr 2024 09:48:52 +0000 Subject: [PATCH 01/10] update README --- README.md | 13 +++++++++---- amt/data.py | 3 +++ scripts/{eval => }/split.py | 7 ++++--- 3 files changed, 16 insertions(+), 7 deletions(-) rename scripts/{eval => }/split.py (89%) diff --git a/README.md b/README.md index 831948c..2ea35e7 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,14 @@ pip install -e . Download the preliminary model weights: +Piano (not final) ``` -wget https://storage.googleapis.com/aria-checkpoints/amt/small-0.safetensors +wget https://storage.googleapis.com/aria-checkpoints/amt/guitar-temp.safetensors +``` + +Classical guitar (not final) +``` +wget https://storage.googleapis.com/aria-checkpoints/amt/piano-temp.safetensors ``` ## Usage @@ -39,7 +45,6 @@ aria-amt transcribe \ -q8 ``` -If you want to do batch transcription, use the `-load_dir` flag and adjust `-bs` accordingly. Compiling may take some time, but provides a significant speedup. - -NOTE: Currently only bf16 is supported. +If you want to do batch transcription, use the `-load_dir` flag and adjust `-bs` accordingly. Compiling and may take some time, but provides a significant speedup. Quantizing (`-q8`) further speeds up inference when the `-compile` flag is also used. +NOTE: Int8 quantization is only supported on GPUs that support BF16. diff --git a/amt/data.py b/amt/data.py index 4bf3ccb..6abb2f3 100644 --- a/amt/data.py +++ b/amt/data.py @@ -371,6 +371,9 @@ def build( num_processes: int = 1, ): assert os.path.isfile(save_path) is False, f"{save_path} already exists" + assert ( + len(save_path.rsplit(".", 1)) == 2 + ), "path is missing a file extension" index_path = AmtDataset._get_index_path(load_path=save_path) if os.path.isfile(index_path): diff --git a/scripts/eval/split.py b/scripts/split.py similarity index 89% rename from scripts/eval/split.py rename to scripts/split.py index c912cbc..ef688b3 100644 --- a/scripts/eval/split.py +++ b/scripts/split.py @@ -33,13 +33,13 @@ def get_matched_paths(audio_dir: str, mid_dir: str): return res -def create_csv(matched_paths, csv_path): +def create_csv(matched_paths, csv_path, ratio): split_csv = open(csv_path, "w") csv_writer = csv.writer(split_csv) csv_writer.writerow(["mid_path", "audio_path", "split"]) for audio_path, mid_path in matched_paths: - if random.random() < 0.1: + if random.random() < ratio: csv_writer.writerow([mid_path, audio_path, "test"]) else: csv_writer.writerow([mid_path, audio_path, "train"]) @@ -50,8 +50,9 @@ def create_csv(matched_paths, csv_path): parser.add_argument("-mid_dir", type=str) parser.add_argument("-audio_dir", type=str) parser.add_argument("-csv_path", type=str) + parser.add_argument("-ratio", type=int, default=0.1) args = parser.parse_args() matched_paths = get_matched_paths(args.audio_dir, args.mid_dir) - create_csv(matched_paths, args.csv_path) + create_csv(matched_paths, args.csv_path, args.ratio) From 148db775c6d4eda654ce27b4e6f9b93a791fd1ba Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 27 Apr 2024 09:50:34 +0000 Subject: [PATCH 02/10] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2ea35e7..7327587 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,6 @@ aria-amt transcribe \ -q8 ``` -If you want to do batch transcription, use the `-load_dir` flag and adjust `-bs` accordingly. Compiling and may take some time, but provides a significant speedup. 
Quantizing (`-q8`) further speeds up inference when the `-compile` flag is also used. +If you want to do batch transcription, use the `-load_dir` flag and adjust `-bs` accordingly. Compiling and may take some time, but provides a significant speedup. Quantizing (`-q8` flag) further speeds up inference when the `-compile` flag is also used. NOTE: Int8 quantization is only supported on GPUs that support BF16. From 69dca6358acfdf7e84c4629d485eb716da52ac9b Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 27 Apr 2024 09:59:30 +0000 Subject: [PATCH 03/10] add fp16 --- amt/inference/model.py | 3 +++ amt/inference/transcribe.py | 8 ++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/amt/inference/model.py b/amt/inference/model.py index e8390f4..44655c6 100644 --- a/amt/inference/model.py +++ b/amt/inference/model.py @@ -386,6 +386,7 @@ def setup_cache( batch_size, max_seq_len=4096, max_audio_len=1500, + dtype=torch.bfloat16, ): self.causal_mask = torch.tril( torch.ones(max_seq_len, max_seq_len, dtype=torch.bool) @@ -397,12 +398,14 @@ def setup_cache( max_seq_length=max_seq_len, n_heads=8, head_dim=64, + dtype=dtype, ).cuda() b.cross_attn.kv_cache = KVCache( max_batch_size=batch_size, max_seq_length=max_audio_len, n_heads=8, head_dim=64, + dtype=dtype, ).cuda() diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index 622b005..ba50102 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -132,7 +132,7 @@ def wrapper(*args, **kwargs): with torch.autocast("cuda", dtype=torch.bfloat16): return func(*args, **kwargs) else: - with torch.autocast("cuda", dtype=torch.float32): + with torch.autocast("cuda", dtype=torch.float16): return func(*args, **kwargs) return wrapper @@ -265,7 +265,11 @@ def gpu_manager( if gpu_id is not None: os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) - model.decoder.setup_cache(batch_size=batch_size, max_seq_len=MAX_BLOCK_LEN) + model.decoder.setup_cache( + batch_size=batch_size, + max_seq_len=MAX_BLOCK_LEN, + dtype=torch.bfloat16 if is_bf16_supported() else torch.float16, + ) model.cuda() model.eval() if compile is True: From 4b0082bfdd0210c7b247613f8405ebf61fddb842 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 4 May 2024 19:27:33 +0000 Subject: [PATCH 04/10] inference modification --- amt/inference/model.py | 10 ++++++---- amt/inference/transcribe.py | 4 ++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/amt/inference/model.py b/amt/inference/model.py index 44655c6..8819dd5 100644 --- a/amt/inference/model.py +++ b/amt/inference/model.py @@ -344,6 +344,8 @@ def __init__( self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int ): super().__init__() + self.n_head = n_head + self.n_state = n_state self.token_embedding = nn.Embedding(n_vocab, n_state) self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) @@ -396,15 +398,15 @@ def setup_cache( b.attn.kv_cache = KVCache( max_batch_size=batch_size, max_seq_length=max_seq_len, - n_heads=8, - head_dim=64, + n_heads=self.n_head, + head_dim=self.n_state // self.n_head, dtype=dtype, ).cuda() b.cross_attn.kv_cache = KVCache( max_batch_size=batch_size, max_seq_length=max_audio_len, - n_heads=8, - head_dim=64, + n_heads=self.n_head, + head_dim=self.n_state // self.n_head, dtype=dtype, ).cuda() diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index ba50102..d4ea3cd 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -218,6 +218,9 @@ def process_segments( ), ) + logits[:, 389] *= 1.2 + 
next_tok_ids = torch.argmax(logits, dim=-1) + next_tok_ids = recalculate_tok_ids( logits=logits, tok_ids=next_tok_ids, @@ -683,6 +686,7 @@ def batch_transcribe( model.decoder = quantize_int8(model.decoder) file_queue = Queue() + sorted(file_paths, key=lambda x: os.path.getsize(x), reverse=True) for file_path in file_paths: if ( os.path.isfile(get_save_path(file_path, input_dir, save_dir)) From 17b271f60100f46092dbe891ba7c2e51c636cef1 Mon Sep 17 00:00:00 2001 From: Louis Date: Sat, 4 May 2024 19:38:01 +0000 Subject: [PATCH 05/10] add stacked mels --- amt/audio.py | 57 +++++++++----- amt/data.py | 2 + config/config.json | 4 +- config/models/medium-stacked.json | 11 +++ .../models/{medium-final.json => medium.json} | 4 +- tests/test_data.py | 78 ++++++++++++++----- 6 files changed, 117 insertions(+), 39 deletions(-) create mode 100644 config/models/medium-stacked.json rename config/models/{medium-final.json => medium.json} (77%) diff --git a/amt/audio.py b/amt/audio.py index 28913ab..7bb2a47 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -82,6 +82,10 @@ def __init__( self.config = load_config()["audio"] self.sample_rate = self.config["sample_rate"] self.chunk_len = self.config["chunk_len"] + self.n_fft = self.config["n_fft"] + self.n_fft_reduced = self.config["n_fft_reduced"] + self.n_mels = self.config["n_mels"] + self.n_mels_reduced = self.config["n_mels_reduced"] self.num_samples = self.sample_rate * self.chunk_len self.noise_ratio = noise_ratio @@ -95,7 +99,7 @@ def __init__( self.spec_aug_ratio = spec_aug_ratio self.time_mask_param = 2500 - self.freq_mask_param = 15 + self.freq_mask_param = 0 self.reduction_resample_rate = 6000 # Audio aug @@ -126,13 +130,26 @@ def __init__( self.num_applause += 1 self.spec_transform = torchaudio.transforms.Spectrogram( - n_fft=self.config["n_fft"], + n_fft=self.n_fft, hop_length=self.config["hop_len"], ) self.mel_transform = torchaudio.transforms.MelScale( - n_mels=self.config["n_mels"], - sample_rate=self.config["sample_rate"], - n_stft=self.config["n_fft"] // 2 + 1, + n_mels=self.n_mels, + sample_rate=self.sample_rate, + n_stft=self.n_fft // 2 + 1, + f_min=30, + f_max=8000, + ) + self.spec_transform_reduced = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft_reduced, + hop_length=self.config["hop_len"], + ) + self.mel_transform_reduced = torchaudio.transforms.MelScale( + n_mels=self.n_mels_reduced, + sample_rate=self.sample_rate, + n_stft=self.n_fft_reduced // 2 + 1, + f_min=30, + f_max=8000, ) self.spec_aug = torch.nn.Sequential( torchaudio.transforms.TimeMasking( @@ -315,16 +332,9 @@ def shift_spec(self, specs: torch.Tensor, shift: int | float): return shifted_specs - def detune_spec(self, specs: torch.Tensor): - if random.random() < self.detune_ratio: - detune_shift = random.uniform( - -self.detune_max_shift, self.detune_max_shift - ) - detuned_specs = self.shift_spec(specs, shift=detune_shift) - - return (specs + detuned_specs) / 2 - else: - return specs + def detune_spec(self, specs: torch.Tensor, detune_shift: float): + detuned_specs = self.shift_spec(specs, shift=detune_shift) + return (specs + detuned_specs) / 2 def aug_wav(self, wav: torch.Tensor): # This function doesn't apply distortion. 
If distortion is desired it @@ -361,19 +371,30 @@ def log_mel( self, wav: torch.Tensor, shift: int | None = None, detune: bool = False ): spec = self.spec_transform(wav)[..., :-1] + spec_reduced = self.spec_transform_reduced(wav)[..., :-1] if shift is not None and shift != 0: spec = self.shift_spec(spec, shift) + spec_reduced = self.shift_spec(spec_reduced, shift) elif detune is True: # Don't detune and spec shift at the same time - spec = self.detune_spec(spec) + if random.random() < self.detune_ratio: + detune_shift = random.uniform( + -self.detune_max_shift, self.detune_max_shift + ) + spec = self.detune_spec(spec, detune_shift=detune_shift) + spec_reduced = self.detune_spec( + spec_reduced, detune_shift=detune_shift + ) mel_spec = self.mel_transform(spec) + mel_spec_reduced = self.mel_transform_reduced(spec_reduced) # Norm - log_spec = self.norm_mel(mel_spec) + concat_mel = torch.cat((mel_spec, mel_spec_reduced), dim=1) + log_mel = self.norm_mel(concat_mel) - return log_spec + return log_mel def forward(self, wav: torch.Tensor, shift: int = 0): # Noise, and reverb diff --git a/amt/data.py b/amt/data.py index 6abb2f3..c966e96 100644 --- a/amt/data.py +++ b/amt/data.py @@ -435,6 +435,8 @@ def build( print("The GNU cat command is not available") else: for _path in sharded_save_paths: + if os.path.isfile(_path) is False: + continue shell_cmd = f"cat {_path} >> {save_path}" os.system(shell_cmd) os.remove(_path) diff --git a/config/config.json b/config/config.json index 9442648..7503777 100644 --- a/config/config.json +++ b/config/config.json @@ -12,9 +12,11 @@ "audio": { "sample_rate": 16000, "n_fft": 2048, + "n_fft_reduced": 800, "hop_len": 160, "chunk_len": 30, - "n_mels": 256 + "n_mels": 384, + "n_mels_reduced": 128 }, "data": { "stride_factor": 15, diff --git a/config/models/medium-stacked.json b/config/models/medium-stacked.json new file mode 100644 index 0000000..0cf3a51 --- /dev/null +++ b/config/models/medium-stacked.json @@ -0,0 +1,11 @@ +{ + "n_mels": 512, + "n_audio_ctx": 1500, + "n_audio_state": 768, + "n_audio_head": 12, + "n_audio_layer": 6, + "n_text_ctx": 4096, + "n_text_state": 768, + "n_text_head": 12, + "n_text_layer": 6 +} \ No newline at end of file diff --git a/config/models/medium-final.json b/config/models/medium.json similarity index 77% rename from config/models/medium-final.json rename to config/models/medium.json index 69b79c7..a0b3857 100644 --- a/config/models/medium-final.json +++ b/config/models/medium.json @@ -3,9 +3,9 @@ "n_audio_ctx": 1500, "n_audio_state": 768, "n_audio_head": 12, - "n_audio_layer": 12, + "n_audio_layer": 6, "n_text_ctx": 4096, "n_text_state": 768, "n_text_head": 12, - "n_text_layer": 12 + "n_text_layer": 6 } \ No newline at end of file diff --git a/tests/test_data.py b/tests/test_data.py index 9c566ad..56b9eba 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -20,16 +20,44 @@ if os.path.isdir("tests/test_results") is False: os.mkdir("tests/test_results") -MAESTRO_PATH = "/mnt/ssd1/amt/training_data/maestro/train-s15.txt" - +MAESTRO_PATH = "/home/loubb/work/aria-amt/temp/train.txt" + + +def plot_spec( + mel: torch.Tensor, + name: str | int, + onsets: list = [], + offsets: list = [], +): + # mel tensor dimensions [height, width] + + height, width = mel.shape + fig_width, fig_height = width // 100, height // 100 + plt.figure(figsize=(fig_width, fig_height), dpi=100) + plt.imshow( + mel, aspect="auto", origin="lower", cmap="viridis", interpolation="none" + ) + + line_width_in_points = 1 / 100 * 72 # Convert pixel width to points + 
+ for x in onsets: + plt.axvline( + x=x, + color="red", + alpha=0.5, + linewidth=line_width_in_points, # setting the correct line width + ) + for x in offsets: + plt.axvline( + x=x, + color="purple", + alpha=0.5, + linewidth=line_width_in_points, # setting the correct line width + ) -def plot_spec(mel: torch.Tensor, name: str | int): - plt.figure(figsize=(10, 4)) - plt.imshow(mel, aspect="auto", origin="lower", cmap="viridis") - plt.colorbar(format="%+2.0f dB") - plt.title("(mel)-Spectrogram") - plt.tight_layout() - plt.savefig(f"tests/test_results/{name}.png") + plt.axis("off") + plt.tight_layout(pad=0) + plt.savefig(f"tests/test_results/{name}.png", dpi=100) plt.close() @@ -184,7 +212,7 @@ def test_spec(self): spec = audio_transform.spec_transform(wav) shift_spec = audio_transform.shift_spec(spec, 1) - shift_wav = griffin_lim(shift_spec) + shift_wav = griffin_lim(shift_spec[..., :384]) torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE) torchaudio.save("tests/test_results/shift.wav", shift_wav, SAMPLE_RATE) @@ -232,28 +260,42 @@ def test_detune(self): spec = audio_transform.spec_transform(wav) shift_spec = audio_transform.detune_spec(spec) shift_wav = griffin_lim(shift_spec) - gl_wav = griffin_lim(spec) + gl_wav = griffin_lim(spec[..., :384]) torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE) torchaudio.save("tests/test_results/orig_gl.wav", gl_wav, SAMPLE_RATE) torchaudio.save("tests/test_results/detune.wav", shift_wav, SAMPLE_RATE) def test_mels(self): - SAMPLE_RATE, CHUNK_LEN = 16000, 30 audio_transform = AudioTransform() + SAMPLE_RATE, N_FFT, CHUNK_LEN = ( + audio_transform.sample_rate, + audio_transform.n_fft, + 30, + ) wav, sr = torchaudio.load("tests/test_data/maestro.wav") wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE).mean( 0, keepdim=True )[:, : SAMPLE_RATE * CHUNK_LEN] - wav_aug = audio_transform.aug_wav( - audio_transform.distortion_aug_cpu(wav) - ) - torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE) - torchaudio.save("tests/test_results/aug.wav", wav_aug, SAMPLE_RATE) + + # tokenizer = AmtTokenizer() + # mid_dict = MidiDict.from_midi("tests/test_data/maestro-test.mid") + # seq = tokenizer._tokenize_midi_dict(mid_dict, 0, 30000, 10000) + # mid_dict = tokenizer._detokenize_midi_dict(seq, 30000) + # onsets = [msg["data"]["start"] // 10 for msg in mid_dict.note_msgs] + # offsets = [ + # msg["data"]["end"] // 10 + # for msg in mid_dict.note_msgs + # if msg["data"]["end"] < 30000 + # ] wavs = torch.stack((wav[0], wav[0], wav[0])) mels = audio_transform(wavs) for idx in range(mels.shape[0]): - plot_spec(mels[idx], idx) + plot_spec( + mels[idx], + f"{mels[0].shape[0]}-{N_FFT}-{SAMPLE_RATE}", + ) + break def test_distortion(self): SAMPLE_RATE, CHUNK_LEN = 16000, 30 From 8e512b10faa63c372397aa70ecd1dbc1c7af0d9f Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 7 May 2024 10:07:22 +0000 Subject: [PATCH 06/10] triple --- amt/audio.py | 78 ++++++++++++++++++++++---------- config/config.json | 10 ++-- config/models/medium-triple.json | 11 +++++ 3 files changed, 70 insertions(+), 29 deletions(-) create mode 100644 config/models/medium-triple.json diff --git a/amt/audio.py b/amt/audio.py index 7bb2a47..d0c1a16 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -14,7 +14,7 @@ # hard-coded audio hyperparameters config = load_config()["audio"] SAMPLE_RATE = config["sample_rate"] -N_FFT = config["n_fft"] +N_FFT = config["n_fft_large"] HOP_LENGTH = config["hop_len"] CHUNK_LENGTH = config["chunk_len"] N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 
samples in a 30-second chunk @@ -82,10 +82,12 @@ def __init__( self.config = load_config()["audio"] self.sample_rate = self.config["sample_rate"] self.chunk_len = self.config["chunk_len"] - self.n_fft = self.config["n_fft"] - self.n_fft_reduced = self.config["n_fft_reduced"] - self.n_mels = self.config["n_mels"] - self.n_mels_reduced = self.config["n_mels_reduced"] + self.n_fft_large = self.config["n_fft_large"] + self.n_fft_med = self.config["n_fft_med"] + self.n_fft_small = self.config["n_fft_small"] + self.n_mels_large = self.config["n_mels_large"] + self.n_mels_med = self.config["n_mels_med"] + self.n_mels_small = self.config["n_mels_small"] self.num_samples = self.sample_rate * self.chunk_len self.noise_ratio = noise_ratio @@ -129,28 +131,39 @@ def __init__( self.register_buffer(f"applause_{i}", applause) self.num_applause += 1 - self.spec_transform = torchaudio.transforms.Spectrogram( - n_fft=self.n_fft, + self.spec_transform_large = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft_large, hop_length=self.config["hop_len"], ) - self.mel_transform = torchaudio.transforms.MelScale( - n_mels=self.n_mels, + self.mel_transform_large = torchaudio.transforms.MelScale( + n_mels=self.n_mels_large, sample_rate=self.sample_rate, - n_stft=self.n_fft // 2 + 1, + n_stft=self.n_fft_large // 2 + 1, f_min=30, f_max=8000, ) - self.spec_transform_reduced = torchaudio.transforms.Spectrogram( - n_fft=self.n_fft_reduced, + self.spec_transform_med = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft_med, hop_length=self.config["hop_len"], ) - self.mel_transform_reduced = torchaudio.transforms.MelScale( - n_mels=self.n_mels_reduced, + self.mel_transform_med = torchaudio.transforms.MelScale( + n_mels=self.n_mels_med, sample_rate=self.sample_rate, - n_stft=self.n_fft_reduced // 2 + 1, + n_stft=self.n_fft_med // 2 + 1, f_min=30, f_max=8000, ) + self.spec_transform_small = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft_small, + hop_length=self.config["hop_len"], + ) + self.mel_transform_small = torchaudio.transforms.MelScale( + n_mels=self.n_mels_small, + sample_rate=self.sample_rate, + n_stft=self.n_fft_small // 2 + 1, + f_min=30, + f_max=4000, + ) self.spec_aug = torch.nn.Sequential( torchaudio.transforms.TimeMasking( time_mask_param=self.time_mask_param, @@ -370,28 +383,43 @@ def norm_mel(self, mel_spec: torch.Tensor): def log_mel( self, wav: torch.Tensor, shift: int | None = None, detune: bool = False ): - spec = self.spec_transform(wav)[..., :-1] - spec_reduced = self.spec_transform_reduced(wav)[..., :-1] + spec_large = self.spec_transform_large(wav)[..., :-1] + spec_med = self.spec_transform_med(wav)[..., :-1] + spec_small = self.spec_transform_small(wav)[..., :-1] if shift is not None and shift != 0: - spec = self.shift_spec(spec, shift) - spec_reduced = self.shift_spec(spec_reduced, shift) + spec_large = self.shift_spec(spec_large, shift) + spec_med = self.shift_spec(spec_med, shift) + spec_small = self.shift_spec(spec_small, shift) elif detune is True: # Don't detune and spec shift at the same time if random.random() < self.detune_ratio: detune_shift = random.uniform( -self.detune_max_shift, self.detune_max_shift ) - spec = self.detune_spec(spec, detune_shift=detune_shift) - spec_reduced = self.detune_spec( - spec_reduced, detune_shift=detune_shift + spec_large = self.detune_spec( + spec_large, + detune_shift=detune_shift, + ) + spec_med = self.detune_spec( + spec_med, + detune_shift=detune_shift, + ) + spec_small = self.detune_spec( + spec_small, + detune_shift=detune_shift, ) - mel_spec 
= self.mel_transform(spec) - mel_spec_reduced = self.mel_transform_reduced(spec_reduced) + mel_spec_large = self.mel_transform_large(spec_large) + mel_spec_med = self.mel_transform_med(spec_med) + mel_spec_small = self.mel_transform_small(spec_small) # Norm - concat_mel = torch.cat((mel_spec, mel_spec_reduced), dim=1) + concat_mel = torch.cat( + (mel_spec_large, mel_spec_med, mel_spec_small), + # (mel_spec_large, mel_spec_small), + dim=1, + ) log_mel = self.norm_mel(concat_mel) return log_mel diff --git a/config/config.json b/config/config.json index 7503777..23c56d9 100644 --- a/config/config.json +++ b/config/config.json @@ -11,12 +11,14 @@ }, "audio": { "sample_rate": 16000, - "n_fft": 2048, - "n_fft_reduced": 800, + "n_fft_large": 4096, + "n_fft_med": 2048, + "n_fft_small": 768, "hop_len": 160, "chunk_len": 30, - "n_mels": 384, - "n_mels_reduced": 128 + "n_mels_large": 384, + "n_mels_med": 256, + "n_mels_small": 128 }, "data": { "stride_factor": 15, diff --git a/config/models/medium-triple.json b/config/models/medium-triple.json new file mode 100644 index 0000000..5f92924 --- /dev/null +++ b/config/models/medium-triple.json @@ -0,0 +1,11 @@ +{ + "n_mels": 768, + "n_audio_ctx": 1500, + "n_audio_state": 768, + "n_audio_head": 12, + "n_audio_layer": 4, + "n_text_ctx": 4096, + "n_text_state": 768, + "n_text_head": 12, + "n_text_layer": 4 +} \ No newline at end of file From dd5b9347d4d623d5f3aa4a571f0bc5c20224ba55 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 17 May 2024 18:39:07 +0000 Subject: [PATCH 07/10] add silence detection --- amt/inference/model.py | 13 ++- amt/inference/transcribe.py | 183 +++++++++++++++++++++++++++++-- amt/tokenizer.py | 2 +- config/models/medium-triple.json | 4 +- requirements.txt | 1 + 5 files changed, 186 insertions(+), 17 deletions(-) diff --git a/amt/inference/model.py b/amt/inference/model.py index 8819dd5..6b0d6b5 100644 --- a/amt/inference/model.py +++ b/amt/inference/model.py @@ -36,9 +36,14 @@ def __init__( dtype=torch.bfloat16, ): super().__init__() - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) - self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + self.dtype = dtype + self.cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer( + "k_cache", torch.zeros(self.cache_shape, dtype=dtype) + ) + self.register_buffer( + "v_cache", torch.zeros(self.cache_shape, dtype=dtype) + ) def update(self, input_pos, k_val, v_val): # input_pos: [S], k_val, v_val: [B, H, L, D] @@ -118,7 +123,7 @@ def forward( class CrossAttention(nn.Module): def __init__(self, n_state: int, n_head: int): super().__init__() - assert n_state % n_head == 0, "n_head does not evenly devide n_state" + assert n_state % n_head == 0, "n_head does not evenly divide n_state" self.n_head = n_head self.d_head = n_state // n_head diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index d4ea3cd..945ad0f 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -1,6 +1,7 @@ import os import signal import time +import copy import random import logging import traceback @@ -9,15 +10,17 @@ import torch.multiprocessing as multiprocessing import torch._dynamo.config import torch._inductor.config +import numpy as np from torch.multiprocessing import Queue from tqdm import tqdm from functools import wraps from torch.cuda import is_bf16_supported +from librosa.effects import _signal_to_frame_nonsilent from amt.inference.model 
import AmtEncoderDecoder from amt.tokenizer import AmtTokenizer -from amt.audio import AudioTransform +from amt.audio import AudioTransform, SAMPLE_RATE from amt.data import get_wav_mid_segments torch._inductor.config.coordinate_descent_tuning = True @@ -78,8 +81,8 @@ def recalculate_tok_ids( # Mask out tok_ids larger than 30ms from original tok_id tok_ids_expanded = tok_ids.unsqueeze(1) - mask_c = col_indices <= tok_ids_expanded + 3 - mask_d = col_indices >= tok_ids_expanded - 3 + mask_c = col_indices <= tok_ids_expanded + 2 + mask_d = col_indices >= tok_ids_expanded - 2 beam_mask = mask_c & mask_d # Don't mask out the original tok_id (required for non-onset/vel toks) @@ -218,8 +221,8 @@ def process_segments( ), ) - logits[:, 389] *= 1.2 - next_tok_ids = torch.argmax(logits, dim=-1) + # logits[:, 389] *= 1.05 + # next_tok_ids = torch.argmax(logits, dim=-1) next_tok_ids = recalculate_tok_ids( logits=logits, @@ -429,6 +432,7 @@ def _truncate_seq( if len(_mid_dict.note_msgs) == 0: return [tokenizer.bos_tok] else: + # The end_ms - 1 is a workaround to get rid of the off msgs res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1) if res[-1] == tokenizer.eos_tok: @@ -436,6 +440,133 @@ def _truncate_seq( return res +# This is a sloppy implementation +def process_silent_intervals( + seq: list, intervals: list, tokenizer: AmtTokenizer +): + def adjust_onset(_onset: int): + # Adjusts the onset according to the silence intervals + for start, end in intervals: + if start <= _onset <= end: + return start + + return _onset + + if len(intervals) == 0: + return seq + + res = [] + logger = logging.getLogger(__name__) + active_notes = {pitch: False for pitch in range(0, 127)} + active_notes["pedal"] = False + + for tok_1, tok_2, tok_3 in zip( + seq, + seq[1:] + [tokenizer.pad_tok], + seq[2:] + [tokenizer.pad_tok, tokenizer.pad_tok], + ): + if isinstance(tok_1, tuple) is False: + res.append(tok_1) + continue + elif tok_1[0] == "prev": + res.append(tok_1) + active_notes[tok_1[1]] = True + continue + elif tok_1[0] in {"onset", "vel"}: + continue + + if tok_1[0] == "pedal": + note_type = "on" if tok_1[1] == 1 else "off" + note_val = "pedal" + elif tok_1[0] in {"on", "off"}: + note_type = tok_1[0] + note_val = tok_1[1] + + if note_type == "on": + # Check that the rest of the tokens are valid + if isinstance(tok_2, tuple) is False: + logger.debug(f"Invalid token sequence {tok_1}, {tok_2}") + continue + if note_val != "pedal" and isinstance(tok_3, tuple) is False: + logger.debug( + f"Invalid token sequence {tok_1}, {tok_2}, {tok_3}" + ) + continue + + # Don't add on if note is already on + if active_notes[note_val] is True: + continue + + # Calculate adjusted onset and add if conditions are met + onset = tok_2[1] + onset_adj = adjust_onset(onset) + if onset != onset_adj: + continue + else: + active_notes[note_val] = True + res.append(tok_1) + res.append(tok_2) + if note_val != "pedal": + res.append(tok_3) + + elif note_type == "off": + # Check that the rest of the tokens are valid + if isinstance(tok_2, tuple) is False and tok_2[0] != "onset": + logger.debug(f"Invalid token sequence {tok_1}, {tok_2}") + continue + + # Don't add on if note is not on + if active_notes[note_val] is False: + continue + + # Add note with adjusted offset + offset = tok_2[1] + offset_adj = adjust_onset(offset) + if offset != offset_adj: + logger.debug( + f"Adjusted offset of {tok_1}, {tok_2} -> {offset_adj}" + ) + res.append(tok_1) + res.append(("onset", tokenizer._quantize_onset(offset_adj))) + active_notes[note_val] = 
False + + return res + + +def get_silent_intervals(wav: torch.Tensor): + FRAME_LEN = 2048 + HOP_LEN = 512 + MIN_WINDOW_S = 5 + MIN_WINDOW_STEPS = (SAMPLE_RATE // HOP_LEN) * MIN_WINDOW_S + 1 + MS_PER_HOP = int((HOP_LEN * 1e3) / SAMPLE_RATE) + + non_silent = _signal_to_frame_nonsilent( + wav.numpy(), + frame_length=FRAME_LEN, + hop_length=HOP_LEN, + top_db=30, + ref=np.max, + ) + non_silent = np.concatenate(([True], non_silent, [True])) + + edges = np.diff(non_silent.astype(int)) + starts = np.where(edges == -1)[0] + ends = np.where(edges == 1)[0] + + # Calculate lengths + lengths = ends - starts + + # Filter intervals by minimum length + valid = lengths > MIN_WINDOW_STEPS + silent_intervals = [ + (start * MS_PER_HOP, (end - 1) * MS_PER_HOP) + for start, end, vl in zip(starts, ends, valid) + if vl + ] + + return silent_intervals + + def transcribe_file( file_path, gpu_task_queue: Queue, @@ -463,7 +594,10 @@ def transcribe_file( init_idx = len(seq) # Add to gpu queue and wait for results - gpu_task_queue.put(((audio_segments.pop(0), seq), pid)) + curr_audio_segment = audio_segments.pop(0) + silent_intervals = get_silent_intervals(curr_audio_segment) + input_seq = copy.deepcopy(seq) + gpu_task_queue.put(((curr_audio_segment, seq), pid)) while True: try: gpu_result = result_queue.get(timeout=0.1) @@ -476,6 +610,23 @@ def transcribe_file( else: result_queue.put(gpu_result) + if len(silent_intervals) > 0: + logger.debug( + f"Seen silent intervals in segment {idx}: {silent_intervals}" + ) + + seq_raw = seq + seq = process_silent_intervals( + seq, intervals=silent_intervals, tokenizer=tokenizer + ) + + if len(seq) != len(seq_raw): + logger.info( + f"Removed tokens ({len(seq_raw)} -> {len(seq)}) " + f"in segment {idx} according to silence in intervals: " + f"{silent_intervals}", + ) + try: next_seq = _truncate_seq( seq, @@ -483,11 +634,23 @@ def transcribe_file( LEN_MS - CHUNK_LEN_MS, ) except Exception as e: - logger.info( - f"Skipping segment {idx} (failed to transcribe): {file_path}" - ) + logger.info(f"Failed to reconcile segment {idx}: {file_path}") logger.debug(traceback.format_exc()) - seq = [tokenizer.bos_tok] + + try: + seq = _truncate_seq( + input_seq, + CHUNK_LEN_MS - 2, + CHUNK_LEN_MS, + ) + except Exception as e: + seq = [tokenizer.bos_tok] + logger.info( + f"Failed to recover prompt, proceeding with default: {seq}" + ) + else: + logger.info(f"Proceeding with prompt: {seq}") + else: if seq[-1] == tokenizer.eos_tok: logger.info(f"Seen eos_tok at segment {idx}: {file_path}") diff --git a/amt/tokenizer.py b/amt/tokenizer.py index dd48a93..dd34d0b 100644 --- a/amt/tokenizer.py +++ b/amt/tokenizer.py @@ -447,7 +447,7 @@ def msg_mixup(src: list): raise Exception random.shuffle(res) # Only includes prev toks - res.append(self.bos_tok) # Beggining of sequence + res.append(self.bos_tok) # Beginning of sequence buffer = defaultdict(lambda: defaultdict(list)) for tok_1, tok_2, tok_3 in zip( diff --git a/config/models/medium-triple.json b/config/models/medium-triple.json index 5f92924..463f65d 100644 --- a/config/models/medium-triple.json +++ b/config/models/medium-triple.json @@ -3,9 +3,9 @@ "n_audio_ctx": 1500, "n_audio_state": 768, "n_audio_head": 12, - "n_audio_layer": 4, + "n_audio_layer": 6, "n_text_ctx": 4096, "n_text_state": 768, "n_text_head": 12, - "n_text_layer": 4 + "n_text_layer": 6 } \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 696fb40..96a7801 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ torch >= 2.2 torchaudio accelerate 
psutil +librosa mido tqdm orjson From f45febd9d99224251a691e2676ec5d54e435b017 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 24 May 2024 16:51:01 +0000 Subject: [PATCH 08/10] update inference --- amt/audio.py | 7 +++++-- amt/inference/transcribe.py | 6 +++--- config/config.json | 4 ++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/amt/audio.py b/amt/audio.py index d0c1a16..9ab6170 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -131,6 +131,9 @@ def __init__( self.register_buffer(f"applause_{i}", applause) self.num_applause += 1 + # 256 - 0-8000 2048-256 + # 512 - 30-8000 2048-384 30-8000 800-128 + # 764 - 30-8000 4096-384 30-8000 2048-256 30-4000 768-128 self.spec_transform_large = torchaudio.transforms.Spectrogram( n_fft=self.n_fft_large, hop_length=self.config["hop_len"], @@ -416,8 +419,8 @@ def log_mel( # Norm concat_mel = torch.cat( - (mel_spec_large, mel_spec_med, mel_spec_small), - # (mel_spec_large, mel_spec_small), + # (mel_spec_large, mel_spec_med, mel_spec_small), + (mel_spec_large, mel_spec_small), dim=1, ) log_mel = self.norm_mel(concat_mel) diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index 945ad0f..1cd1b27 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -221,8 +221,8 @@ def process_segments( ), ) - # logits[:, 389] *= 1.05 - # next_tok_ids = torch.argmax(logits, dim=-1) + logits[:, 389] *= 1.05 + next_tok_ids = torch.argmax(logits, dim=-1) next_tok_ids = recalculate_tok_ids( logits=logits, @@ -544,7 +544,7 @@ def get_silent_intervals(wav: torch.Tensor): wav.numpy(), frame_length=FRAME_LEN, hop_length=HOP_LEN, - top_db=30, + top_db=45, ref=np.max, ) non_silent = np.concatenate(([True], non_silent, [True])) diff --git a/config/config.json b/config/config.json index 23c56d9..d2c24a0 100644 --- a/config/config.json +++ b/config/config.json @@ -11,9 +11,9 @@ }, "audio": { "sample_rate": 16000, - "n_fft_large": 4096, + "n_fft_large": 2048, "n_fft_med": 2048, - "n_fft_small": 768, + "n_fft_small": 800, "hop_len": 160, "chunk_len": 30, "n_mels_large": 384, From 837a18b5748406b0fe544ef4f9151d5ac0474275 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 24 May 2024 16:51:53 +0000 Subject: [PATCH 09/10] update README --- README.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index 7327587..bcd3c9e 100644 --- a/README.md +++ b/README.md @@ -14,12 +14,8 @@ pip install -e . 
Download the preliminary model weights: -Piano (not final) -``` -wget https://storage.googleapis.com/aria-checkpoints/amt/guitar-temp.safetensors -``` +Piano (v1) -Classical guitar (not final) ``` wget https://storage.googleapis.com/aria-checkpoints/amt/piano-temp.safetensors ``` From 8b9f0db3485ed5e17c00bae0ab7f4555f9842458 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 24 May 2024 16:54:59 +0000 Subject: [PATCH 10/10] update scripts --- README.md | 2 +- amt/mir.py | 246 ++++++++++++++++++++++++++++++++++++++++++++ amt/run.py | 240 ++++++++++++++++++++++++++++++++---------- scripts/eval/mir.py | 128 ----------------------- scripts/split.py | 95 ++++++++++++++--- 5 files changed, 516 insertions(+), 195 deletions(-) create mode 100644 amt/mir.py delete mode 100644 scripts/eval/mir.py diff --git a/README.md b/README.md index bcd3c9e..029d059 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ You can then transcribe using the cli: ``` aria-amt transcribe \ - small-final \ + medium-stacked \ \ -load_path \ -save_dir \ diff --git a/amt/mir.py b/amt/mir.py new file mode 100644 index 0000000..d92c09e --- /dev/null +++ b/amt/mir.py @@ -0,0 +1,246 @@ +import glob +from tqdm.auto import tqdm +import pretty_midi +import numpy as np +import mir_eval +import json +import os + +from aria.data.midi import MidiDict, get_duration_ms + +pretty_midi.pretty_midi.MAX_TICK = 1e10 + + +def midi_to_intervals_and_pitches(midi_file_path): + mid_dict = MidiDict.from_midi(midi_file_path) + mid_dict.resolve_pedal() + + intervals, pitches, velocities = [], [], [] + for note_msg in mid_dict.note_msgs: + pitch = note_msg["data"]["pitch"] + onset_s = ( + get_duration_ms( + start_tick=0, + end_tick=note_msg["data"]["start"], + tempo_msgs=mid_dict.tempo_msgs, + ticks_per_beat=mid_dict.ticks_per_beat, + ) + * 1e-3 + ) + offset_s = ( + get_duration_ms( + start_tick=0, + end_tick=note_msg["data"]["end"], + tempo_msgs=mid_dict.tempo_msgs, + ticks_per_beat=mid_dict.ticks_per_beat, + ) + * 1e-3 + ) + velocity = note_msg["data"]["velocity"] + + if onset_s >= offset_s: + print("Skipping duration zero note") + continue + + intervals.append([onset_s, offset_s]) + pitches.append(pitch) + velocities.append(velocity) + + return np.array(intervals), np.array(pitches), np.array(velocities) + + +def midi_to_hz(note, shift=0): + """ + Convert MIDI to HZ. + + Shift, if != 0, is subtracted from the MIDI note. + Use "2" for the hFT augmented model transcriptions, else pitches won't match. 
+ """ + # the one used in hFT transformer + return 440.0 * (2.0 ** (note.astype(int) - shift - 69) / 12) + # a = 440 # frequency of A (common value is 440Hz) + # return (a / 32) * (2 ** ((note - 9) / 12)) + + +def get_matched_files(est_dir: str, ref_dir: str): + # We assume that the files have the same path relative to their directory + + res = [] + est_paths = glob.glob(os.path.join(est_dir, "**/*.mid"), recursive=True) + if len(est_paths) == 0: + est_paths = glob.glob( + os.path.join(est_dir, "**/*.midi"), recursive=True + ) + print(f"found {len(est_paths)} est files") + + for est_path in est_paths: + est_rel_path = os.path.relpath(est_path, est_dir) + ref_path = os.path.join( + ref_dir, os.path.splitext(est_rel_path)[0] + ".mid" + ) + if os.path.isfile(ref_path): + res.append((est_path, ref_path)) + else: + ref_path = os.path.join( + ref_dir, os.path.splitext(est_rel_path)[0] + ".midi" + ) + if os.path.isfile(ref_path): + res.append((est_path, ref_path)) + + print(f"found {len(res)} matched est-ref pairs") + + return res + + +def get_matched_files_direct(est_dir: str, ref_dir: str): + # Helper to extract filenames with normalized extensions + def get_filenames(paths): + normalized_files = {} + for path in paths: + basename = os.path.basename(path) + name, ext = os.path.splitext(basename) + + name = name[:-12] if name.endswith("_transcribed") else name + + if ext in [".mid", ".midi"]: + normalized_files[name] = path + return normalized_files + + # Gather all potential MIDI files in both directories + est_files = glob.glob(os.path.join(est_dir, "**/*.*"), recursive=True) + ref_files = glob.glob(os.path.join(ref_dir, "**/*.*"), recursive=True) + + # Map filenames to their full paths with normalized extensions + est_file_map = get_filenames(est_files) + ref_file_map = get_filenames(ref_files) + + # Find matching files by filename disregarding extension differences + matched_files = [] + for filename, ref_path in ref_file_map.items(): + if filename in est_file_map: + matched_files.append((est_file_map[filename], ref_path)) + + print(f"found {len(est_file_map)} MIDI files in estimation directory") + print(f"found {len(ref_file_map)} MIDI files in reference directory") + print(f"found {len(matched_files)} matched MIDI file pairs") + + return matched_files + + +def get_avg_scores(scores): + totals = {} + counts = {} + for d in scores: + for key, value in d.items(): + if key == "f_name": + continue + totals[key] = totals.get(key, 0) + value + counts[key] = counts.get(key, 0) + 1 + averages = {f"{key}_avg": totals[key] / counts[key] for key in totals} + return averages + + +def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0): + """ + Evaluate the estimated pitches against the reference pitches using mir_eval. 
+ """ + + est_ref_pairs = get_matched_files(est_dir, ref_dir) + if len(est_ref_pairs) == 0: + print("Failed to find files, trying direct search") + est_ref_pairs = get_matched_files_direct(est_dir, ref_dir) + + output_fhandle = ( + open(output_stats_file, "w") if output_stats_file is not None else None + ) + + res = [] + for est_file, ref_file in tqdm(est_ref_pairs): + ref_intervals, ref_pitches, ref_velocities = ( + midi_to_intervals_and_pitches(ref_file) + ) + est_intervals, est_pitches, est_velocities = ( + midi_to_intervals_and_pitches(est_file) + ) + ref_pitches_hz = midi_to_hz(ref_pitches) + est_pitches_hz = midi_to_hz(est_pitches, est_shift) + + scores = mir_eval.transcription.evaluate( + ref_intervals, + ref_pitches_hz, + est_intervals, + est_pitches_hz, + ) + + prec_vel, recall_vel, f1_vel, _ = ( + mir_eval.transcription_velocity.precision_recall_f1_overlap( + ref_intervals=ref_intervals, + ref_pitches=ref_pitches, + ref_velocities=ref_velocities, + est_intervals=est_intervals, + est_pitches=est_pitches, + est_velocities=est_velocities, + ) + ) + + scores["Precision_vel"] = prec_vel + scores["Recall_vel"] = recall_vel + scores["F1_vel"] = f1_vel + scores["f_name"] = est_file + res.append(scores) + + avg_scores = get_avg_scores(res) + output_fhandle.write(json.dumps(avg_scores)) + output_fhandle.write("\n") + + res.sort(key=lambda x: x["F-measure"]) + for s in res: + output_fhandle.write(json.dumps(s)) + output_fhandle.write("\n") + + +def evaluate_single(est_file, ref_file): + ref_intervals, ref_pitches, ref_velocities = midi_to_intervals_and_pitches( + ref_file + ) + est_intervals, est_pitches, est_velocities = midi_to_intervals_and_pitches( + est_file + ) + ref_pitches_hz = midi_to_hz(ref_pitches) + est_pitches_hz = midi_to_hz(est_pitches) + + return mir_eval.transcription.evaluate( + ref_intervals, + ref_pitches_hz, + est_intervals, + est_pitches_hz, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(usage="evaluate []") + parser.add_argument( + "--est-dir", + type=str, + help="Path to the directory containing either the transcribed MIDI files or WAV files to be transcribed.", + ) + parser.add_argument( + "--ref-dir", + type=str, + help="Path to the directory containing the reference files (we'll use gold MIDI for mir_eval, WAV for dtw).", + ) + parser.add_argument( + "--output-stats-file", + default=None, + type=str, + help="Path to the file to save the evaluation stats", + ) + + args = parser.parse_args() + evaluate_mir_eval( + args.est_dir, + args.ref_dir, + args.output_stats_file, + ) diff --git a/amt/run.py b/amt/run.py index 88fdfb3..b6c1880 100644 --- a/amt/run.py +++ b/amt/run.py @@ -21,6 +21,21 @@ def _add_maestro_args(subparser): ) +def _add_matched_args(subparser): + subparser.add_argument("audio", help="audio directory path") + subparser.add_argument("mid", help="midi directory path") + subparser.add_argument("csv", help="path to split.csv") + subparser.add_argument("-train", help="train save path", required=False) + subparser.add_argument("-val", help="val save path", required=False) + subparser.add_argument("-test", help="test save path", required=False) + subparser.add_argument( + "-mp", + help="number of processes to use", + type=int, + default=1, + ) + + def _add_synth_args(subparser): subparser.add_argument("dir", help="Directory containing MIDIs") subparser.add_argument("csv", help="Split csv") @@ -97,18 +112,18 @@ def get_synth_mid_paths(mid_dir: str, csv_path: str): def build_synth( mid_dir: str, csv_path: str, - 
train_file: str, - test_file: str, + train_path: str, + test_path: str, num_procs: int, ): from amt.data import AmtDataset, pianoteq_cmd_fn - if os.path.isfile(train_file): - print(f"Dataset file already exists at {train_file} - removing") - os.remove(train_file) - if os.path.isfile(test_file): - print(f"Dataset file already exists at {test_file} - removing") - os.remove(test_file) + if os.path.isfile(train_path): + print(f"Dataset file already exists at {train_path} - removing") + os.remove(train_path) + if os.path.isfile(test_path): + print(f"Dataset file already exists at {test_path} - removing") + os.remove(test_path) ( train_paths, @@ -117,23 +132,23 @@ def build_synth( print(f"Found {len(train_paths)} train and {len(test_paths)} test paths") - print(f"Building {train_file}") + print(f"Building {train_path}") AmtDataset.build( load_paths=train_paths, - save_path=train_file, + save_path=train_path, num_processes=num_procs, cli_cmd_fn=pianoteq_cmd_fn, ) - print(f"Building {test_file}") + print(f"Building {test_path}") AmtDataset.build( load_paths=test_paths, - save_path=test_file, + save_path=test_path, num_processes=num_procs, cli_cmd_fn=pianoteq_cmd_fn, ) -def get_matched_maestro_paths(maestro_dir): +def _get_matched_maestro_paths(maestro_dir): assert os.path.isdir(maestro_dir), "MAESTRO directory not found" maestro_csv_path = os.path.join(maestro_dir, "maestro-v3.0.0.csv") @@ -152,7 +167,7 @@ def get_matched_maestro_paths(maestro_dir): os.path.join(maestro_dir, entry["midi_filename"]) ) - if not os.path.isfile(audio_path) or not os.path.isfile(audio_path): + if not os.path.isfile(audio_path) or not os.path.isfile(midi_path): print("File missing - skipping") print(audio_path) print(midi_path) @@ -170,46 +185,151 @@ def get_matched_maestro_paths(maestro_dir): return matched_paths_train, matched_paths_val, matched_paths_test -def build_maestro(maestro_dir, train_file, val_file, test_file, num_procs): +def _get_matched_paths(audio_dir: str, mid_dir: str, split_csv_path: str): + assert os.path.isdir(audio_dir), "audio dir not found" + assert os.path.isdir(mid_dir), "mid dir not found" + assert os.path.isfile(split_csv_path), "split csv not found" + + matched_paths_train = [] + matched_paths_val = [] + matched_paths_test = [] + with open(split_csv_path, "r") as f: + dict_reader = DictReader(f) + for entry in dict_reader: + audio_path = os.path.normpath( + os.path.join(audio_dir, entry["audio_path"]) + ) + mid_path = os.path.normpath( + os.path.join(mid_dir, entry["mid_path"]) + ) + + if not os.path.isfile(audio_path) or not os.path.isfile(mid_path): + raise FileNotFoundError( + f"File pair missing: {(audio_path, mid_path)}" + ) + + if entry["split"] == "train": + matched_paths_train.append((audio_path, mid_path)) + elif entry["split"] == "val": + matched_paths_val.append((audio_path, mid_path)) + elif entry["split"] == "test": + matched_paths_test.append((audio_path, mid_path)) + else: + raise ValueError("Invalid split") + + return matched_paths_train, matched_paths_val, matched_paths_test + + +def _build_from_matched_paths( + matched_paths_train: list, + matched_paths_val: list, + matched_paths_test: list, + train_path: str | None = None, + val_path: str | None = None, + test_path: str | None = None, + num_procs: int = 1, +): from amt.data import AmtDataset - if os.path.isfile(train_file): - print(f"Dataset file already exists at {train_file} - removing") - os.remove(train_file) - if os.path.isfile(val_file): - print(f"Dataset file already exists at {val_file} - removing") - 
os.remove(val_file) - if os.path.isfile(test_file): - print(f"Dataset file already exists at {test_file} - removing") - os.remove(test_file) + if train_path is None: + pass + elif len(matched_paths_train) >= 1: + if os.path.isfile(train_path): + input( + f"Dataset file already exists at {train_path} - Press enter to continue (^C to quit)" + ) + os.remove(train_path) + print(f"Building {train_path}") + AmtDataset.build( + load_paths=matched_paths_train, + save_path=train_path, + num_processes=num_procs, + ) + if val_path is None: + pass + elif len(matched_paths_val) >= 1: + if os.path.isfile(val_path): + input( + f"Dataset file already exists at {val_path} - Press enter to continue (^C to quit)" + ) + os.remove(val_path) + print(f"Building {val_path}") + AmtDataset.build( + load_paths=matched_paths_val, + save_path=val_path, + num_processes=num_procs, + ) + if test_path is None: + pass + elif len(matched_paths_test) >= 1 and test_path: + if os.path.isfile(test_path): + input( + f"Dataset file already exists at {test_path} - Press enter to continue (^C to quit)" + ) + os.remove(test_path) + print(f"Building {test_path}") + AmtDataset.build( + load_paths=matched_paths_test, + save_path=test_path, + num_processes=num_procs, + ) + +def build_from_csv( + audio_dir: str, + mid_dir: str, + split_csv_path: str, + train_path: str, + val_path: str, + test_path: str, + num_procs: int, +): ( matched_paths_train, matched_paths_val, matched_paths_test, - ) = get_matched_maestro_paths(maestro_dir) + ) = _get_matched_paths( + audio_dir=audio_dir, + mid_dir=mid_dir, + split_csv_path=split_csv_path, + ) print( f"Found {len(matched_paths_train)}, {len(matched_paths_val)}, {len(matched_paths_test)} train, val, and test paths" ) - print(f"Building {train_file}") - AmtDataset.build( - load_paths=matched_paths_train, - save_path=train_file, - num_processes=num_procs, + _build_from_matched_paths( + matched_paths_train=matched_paths_train, + matched_paths_val=matched_paths_val, + matched_paths_test=matched_paths_test, + train_path=train_path, + val_path=val_path, + test_path=test_path, + num_procs=num_procs, ) - print(f"Building {val_file}") - AmtDataset.build( - load_paths=matched_paths_val, - save_path=val_file, - num_processes=num_procs, - ) - print(f"Building {test_file}") - AmtDataset.build( - load_paths=matched_paths_test, - save_path=test_file, - num_processes=num_procs, + + +def build_maestro( + maestro_dir: str, + train_path: str, + val_path: str, + test_path: str, + num_procs: int, +): + ( + matched_paths_train, + matched_paths_val, + matched_paths_test, + ) = _get_matched_maestro_paths(maestro_dir=maestro_dir) + + _build_from_matched_paths( + matched_paths_train=matched_paths_train, + matched_paths_val=matched_paths_val, + matched_paths_test=matched_paths_test, + train_path=train_path, + val_path=val_path, + test_path=test_path, + num_procs=num_procs, ) @@ -298,11 +418,12 @@ def transcribe( file_paths = found_mp3 + found_wav elif trans_mode == "maestro": matched_train_paths, matched_val_paths, matched_test_paths = ( - get_matched_maestro_paths(load_dir) + _get_matched_maestro_paths(load_dir) ) + train_mp3_paths = [ap for ap, mp in matched_train_paths] val_mp3_paths = [ap for ap, mp in matched_val_paths] test_mp3_paths = [ap for ap, mp in matched_test_paths] - file_paths = test_mp3_paths # val_mp3_paths + test_mp3_paths + file_paths = test_mp3_paths assert len(file_paths) == 177, "Invalid maestro files" else: file_paths = [load_path] @@ -340,17 +461,20 @@ def main(): parser = argparse.ArgumentParser() 
subparsers = parser.add_subparsers(help="sub-command help", dest="command") - # add maestro and transcribe subparsers subparser_maestro = subparsers.add_parser( - "maestro", help="Commands to build the maestro dataset." + "build-maestro", help="Commands to build the maestro dataset." + ) + subparser_matched = subparsers.add_parser( + "build-matched", help="Commands to build dataset from matched paths." ) subparser_synth = subparsers.add_parser( - "synth", help="Commands to build the maestro dataset." + "build-synth", help="Commands to build the synthetic dataset." ) subparser_transcribe = subparsers.add_parser( "transcribe", help="Commands to run transcription." ) _add_maestro_args(subparser_maestro) + _add_matched_args(subparser_matched) _add_synth_args(subparser_synth) _add_transcribe_args(subparser_transcribe) @@ -360,20 +484,30 @@ def main(): parser.print_help() print("Unrecognized command") exit(1) - elif args.command == "maestro": + elif args.command == "build-maestro": build_maestro( maestro_dir=args.dir, - train_file=args.train, - val_file=args.val, - test_file=args.test, + train_path=args.train, + val_path=args.val, + test_path=args.test, + num_procs=args.mp, + ) + elif args.command == "build-matched": + build_from_csv( + audio_dir=args.audio, + mid_dir=args.mid, + split_csv_path=args.csv, + train_path=args.train, + val_path=args.val, + test_path=args.test, num_procs=args.mp, ) - elif args.command == "synth": + elif args.command == "build-synth": build_synth( mid_dir=args.dir, csv_path=args.csv, - train_file=args.train, - test_file=args.test, + train_path=args.train, + test_path=args.test, num_procs=args.mp, ) elif args.command == "transcribe": diff --git a/scripts/eval/mir.py b/scripts/eval/mir.py deleted file mode 100644 index 0bfc520..0000000 --- a/scripts/eval/mir.py +++ /dev/null @@ -1,128 +0,0 @@ -import glob -from tqdm.auto import tqdm -from collections import defaultdict -import pretty_midi -import numpy as np -import mir_eval -import json -import os - -pretty_midi.pretty_midi.MAX_TICK = 1e10 - - -def midi_to_intervals_and_pitches(midi_file_path): - """ - This function reads a MIDI file and extracts note intervals and pitches - suitable for use with mir_eval's transcription evaluation functions. - """ - # Load the MIDI file - midi_data = pretty_midi.PrettyMIDI(midi_file_path) - - # Prepare lists to collect note intervals and pitches - notes = [] - for instrument in midi_data.instruments: - # Skip drum instruments - if not instrument.is_drum: - for note in instrument.notes: - notes.append([note.start, note.end, note.pitch]) - notes = sorted(notes, key=lambda x: x[0]) - notes = np.array(notes) - intervals, pitches = notes[:, :2], notes[:, 2] - intervals -= intervals[0][0] - return intervals, pitches - - -def midi_to_hz(note, shift=0): - """ - Convert MIDI to HZ. - - Shift, if != 0, is subtracted from the MIDI note. - Use "2" for the hFT augmented model transcriptions, else pitches won't match. 
- """ - # the one used in hFT transformer - return 440.0 * (2.0 ** (note.astype(int) - shift - 69) / 12) - # a = 440 # frequency of A (common value is 440Hz) - # return (a / 32) * (2 ** ((note - 9) / 12)) - - -def get_matched_files(est_dir: str, ref_dir: str): - # We assume that the files have the same path relative to their directory - - res = [] - est_paths = glob.glob(os.path.join(est_dir, "**/*.mid"), recursive=True) - print(f"found {len(est_paths)} est files") - - for est_path in est_paths: - est_rel_path = os.path.relpath(est_path, est_dir) - ref_path = os.path.join( - ref_dir, os.path.splitext(est_rel_path)[0] + ".midi" - ) - if os.path.isfile(ref_path): - res.append((est_path, ref_path)) - - print(f"found {len(res)} matched est-ref pairs") - - return res - - -def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0): - """ - Evaluate the estimated pitches against the reference pitches using mir_eval. - """ - - est_ref_pairs = get_matched_files(est_dir, ref_dir) - - output_fhandle = ( - open(output_stats_file, "w") if output_stats_file is not None else None - ) - - res = defaultdict(list) - for est_file, ref_file in tqdm(est_ref_pairs): - ref_intervals, ref_pitches = midi_to_intervals_and_pitches(ref_file) - est_intervals, est_pitches = midi_to_intervals_and_pitches(est_file) - ref_pitches_hz = midi_to_hz(ref_pitches) - est_pitches_hz = midi_to_hz(est_pitches, est_shift) - scores = mir_eval.transcription.evaluate( - ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz - ) - if output_fhandle is not None: - output_fhandle.write(json.dumps(scores)) - output_fhandle.write("\n") - for k, v in scores.items(): - res[k].append(v) - else: - print(json.dumps(scores, indent=4)) - for k, v in scores.items(): - res[k].append(v) - - for k, v in res.items(): - print(k, sum(v) / len(v)) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(usage="evaluate []") - parser.add_argument( - "--est-dir", - type=str, - help="Path to the directory containing either the transcribed MIDI files or WAV files to be transcribed.", - ) - parser.add_argument( - "--ref-dir", - type=str, - help="Path to the directory containing the reference files (we'll use gold MIDI for mir_eval, WAV for dtw).", - ) - parser.add_argument( - "--output-stats-file", - default=None, - type=str, - help="Path to the file to save the evaluation stats", - ) - - args = parser.parse_args() - evaluate_mir_eval( - args.est_dir, - args.ref_dir, - args.output_stats_file, - ) diff --git a/scripts/split.py b/scripts/split.py index ef688b3..cbacf58 100644 --- a/scripts/split.py +++ b/scripts/split.py @@ -4,12 +4,44 @@ import argparse import os +from typing import Callable -def get_matched_paths(audio_dir: str, mid_dir: str): + +def guitarset_file_hook(audio_path: str): + base, ext = os.path.splitext(audio_path) + + return [base + "_mic" + ext, base + "_mix" + ext] + + +def gaps_file_hook(audio_path: str): + base, ext = os.path.splitext(audio_path) + assert base.endswith("-fine-aligned") + + return [base[: -len("-fine-aligned")] + ext] + + +def get_hook(hook_name: str): + name_to_fn = { + "guitarset": guitarset_file_hook, + "gaps": gaps_file_hook, + } + + return name_to_fn[hook_name] + + +def get_matched_paths( + audio_dir: str, + mid_dir: str, + midi_ex: str, + audio_ex: str, + audio_hook: Callable | None = None, +): # Assume that the files have the same path relative to their directory res = [] - mid_paths = glob.glob(os.path.join(mid_dir, "**/*.mid"), recursive=True) - print(f"found 
{len(mid_paths)} mid files") + mid_paths = glob.glob( + os.path.join(mid_dir, f"**/*.{midi_ex}"), recursive=True + ) + print(f"found {len(mid_paths)} .{midi_ex} files") audio_dir_last = os.path.basename(audio_dir) mid_dir_last = os.path.basename(mid_dir) @@ -17,18 +49,28 @@ def get_matched_paths(audio_dir: str, mid_dir: str): for mid_path in mid_paths: input_rel_path = os.path.relpath(mid_path, mid_dir) - mp3_rel_path = os.path.splitext(input_rel_path)[0] + ".mp3" - mp3_path = os.path.join(audio_dir, mp3_rel_path) + audio_rel_path = os.path.splitext(input_rel_path)[0] + f".{audio_ex}" + audio_path = os.path.join(audio_dir, audio_rel_path) + + if audio_hook is not None: + audio_paths = audio_hook(audio_path) + audio_rel_paths = audio_hook(audio_rel_path) + else: + audio_paths = [audio_path] + audio_rel_paths = [audio_rel_path] - # Check if the corresponding .mp3 file exists - if os.path.isfile(mp3_path): - matched_mid_path = os.path.join(mid_dir_last, input_rel_path) - matched_mp3_path = os.path.join(audio_dir_last, mp3_rel_path) + for _audio_path, _audio_rel_path in zip(audio_paths, audio_rel_paths): + if os.path.isfile(_audio_path): + matched_mid_path = os.path.join(mid_dir_last, input_rel_path) + matched_audio_path = os.path.join( + audio_dir_last, _audio_rel_path + ) - res.append((matched_mp3_path, matched_mid_path)) + # print((matched_audio_path, matched_mid_path)) + res.append((matched_audio_path, matched_mid_path)) - print(f"found {len(res)} matched mp3-midi pairs") - assert len(mid_paths) == len(res), "audio files missing" + print(f"found {len(res)} matched audio-midi pairs") + assert len(mid_paths) <= len(res), "audio files missing" return res @@ -50,9 +92,36 @@ def create_csv(matched_paths, csv_path, ratio): parser.add_argument("-mid_dir", type=str) parser.add_argument("-audio_dir", type=str) parser.add_argument("-csv_path", type=str) + parser.add_argument( + "-midi_ex", + type=str, + choices=["mid", "midi"], + default="mid", + help="File extension of the MIDI files", + ) + parser.add_argument( + "-audio_ex", + type=str, + choices=["mp3", "wav"], + default="mp3", + help="File extension of the audio files", + ) + parser.add_argument( + "-hook", + type=str, + choices=["guitarset", "gaps"], + help="use dataset specific hook for audio filenames", + required=False, + ) parser.add_argument("-ratio", type=int, default=0.1) args = parser.parse_args() - matched_paths = get_matched_paths(args.audio_dir, args.mid_dir) + matched_paths = get_matched_paths( + args.audio_dir, + args.mid_dir, + args.midi_ex, + args.audio_ex, + audio_hook=get_hook(args.hook) if args.hook else None, + ) create_csv(matched_paths, args.csv_path, args.ratio)
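
The split and dataset-build steps introduced in these patches can be chained: `scripts/split.py` writes a CSV with `mid_path`, `audio_path`, and `split` columns, and the new `build-matched` subcommand in `amt/run.py` reads that CSV to build the dataset files. The following is a minimal usage sketch, assuming the `aria-amt` console entrypoint shown in the README and placeholder paths under `data/`; the optional `-hook`, `-val`, and `-ratio` arguments are left at their defaults.

```
# Write a train/test split CSV for matched audio/MIDI pairs
# (paths in the CSV are stored relative to the parents of -audio_dir and -mid_dir).
python scripts/split.py \
    -mid_dir data/midi \
    -audio_dir data/audio \
    -csv_path data/split.csv \
    -midi_ex mid \
    -audio_ex wav

# Build AmtDataset train/test files from that split; the positional audio and
# mid arguments are the parent directory that the CSV paths are relative to.
aria-amt build-matched data data data/split.csv \
    -train train.txt \
    -test test.txt \
    -mp 4
```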