diff --git a/.gitignore b/.gitignore index 1d36432..9a4767b 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,9 @@ .idea/ notebooks/scratch -baselines/hft_transformer/model_files/ +experiments/baselines/hft_transformer/model_files/ +experiments/baselines/google_t5/model_files/ +experiments/aria-amt-intermediate-transcribed-data # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/amt/audio.py b/amt/audio.py index 447822e..b9b7059 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -1,5 +1,5 @@ """Contains code taken from https://github.com/openai/whisper""" - +import functools import os import random import torch @@ -197,9 +197,11 @@ def __init__( bandpass_ratio: float = 0.15, distort_ratio: float = 0.15, reduce_ratio: float = 0.01, + max_num_transforms: int = None, # currently we're doing 8 different transformations detune_ratio: float = 0.0, detune_max_shift: float = 0.0, spec_aug_ratio: float = 0.9, + ): super().__init__() self.tokenizer = AmtTokenizer() @@ -223,7 +225,14 @@ def __init__( self.detune_ratio = detune_ratio self.detune_max_shift = detune_max_shift self.spec_aug_ratio = spec_aug_ratio - + # the following two variables, `self.t_count` and `self.max_num_transforms` + # are state variables that track the # of transformations applied. + # `self.t_count` is set in `forward` method to 0 + # `t_count` can also be passed into the following methods: `distortion_aug_cpu`, `log_mel`, `aug_wav`, + # the methods that we're stochastically applying transformations. + # a little messy/stateful, but helps the code be backwards compatible. + self.t_count = None + self.max_num_transforms = max_num_transforms self.time_mask_param = 2500 self.freq_mask_param = 15 self.reduction_resample_rate = 6000 @@ -273,6 +282,34 @@ def __init__( ), ) + # inverse mel transform + self.inverse_mel = torchaudio.transforms.InverseMelScale( + n_mels=self.config["n_mels"], + sample_rate=self.config["sample_rate"], + n_stft=self.config["n_fft"] // 2 + 1, + ) + self.inverse_spec_transform = torchaudio.transforms.GriffinLim( + n_fft=self.config["n_fft"], + hop_length=self.config["hop_len"], + ) + + def check_apply_transform(self, ratio: float): + """ + Check if a transformation should be applied based on the ratio and the + number of transformations already applied. + """ + + if ( + (self.max_num_transforms is not None) and + (self.t_count is not None) and + (self.t_count >= self.max_num_transforms) + ): + return False + apply_transform = random.random() < ratio + if apply_transform: + self.t_count += 1 + return apply_transform + def get_params(self): return { "noise_ratio": self.noise_ratio, @@ -408,13 +445,16 @@ def apply_distortion(self, wav: torch.tensor): return AF.overdrive(wav, gain=gain, colour=colour) - def distortion_aug_cpu(self, wav: torch.Tensor): + def distortion_aug_cpu(self, wav: torch.Tensor, t_count: int = None): # This function should run on the cpu (i.e. 
in the dataloader collate # function) in order to not be a bottlekneck + if t_count is not None: + self.t_count = t_count - if random.random() < self.reduce_ratio: + if self.check_apply_transform(self.reduce_ratio): wav = self.apply_reduction(wav) - if random.random() < self.distort_ratio: + + if self.check_apply_transform(self.distort_ratio): wav = self.apply_distortion(wav) return wav @@ -445,34 +485,34 @@ def shift_spec(self, specs: torch.Tensor, shift: int | float): return shifted_specs def detune_spec(self, specs: torch.Tensor): - if random.random() < self.detune_ratio: - detune_shift = random.uniform( - -self.detune_max_shift, self.detune_max_shift - ) - detuned_specs = self.shift_spec(specs, shift=detune_shift) + detune_shift = random.uniform( + -self.detune_max_shift, self.detune_max_shift + ) + detuned_specs = self.shift_spec(specs, shift=detune_shift) - return (specs + detuned_specs) / 2 - else: - return specs + specs = (specs + detuned_specs) / 2 + return specs - def aug_wav(self, wav: torch.Tensor): + def aug_wav(self, wav: torch.Tensor, t_count: int = None): # This function doesn't apply distortion. If distortion is desired it # should be run beforehand on the cpu with distortion_aug_cpu. Note # also that detuning is done to the spectrogram in log_mel, not the wav. + if t_count is not None: + self.t_count = t_count # Noise - if random.random() < self.noise_ratio: + if self.check_apply_transform(self.noise_ratio): wav = self.apply_noise(wav) - if random.random() < self.applause_ratio: + if self.check_apply_transform(self.applause_ratio): wav = self.apply_applause(wav) # Reverb - if random.random() < self.reverb_ratio: + if self.check_apply_transform(self.reverb_ratio): wav = self.apply_reverb(wav) # EQ - if random.random() < self.bandpass_ratio: + if self.check_apply_transform(self.bandpass_ratio): wav = self.apply_bandpass(wav) return wav @@ -487,15 +527,25 @@ def norm_mel(self, mel_spec: torch.Tensor): return log_spec def log_mel( - self, wav: torch.Tensor, shift: int | None = None, detune: bool = False + self, + wav: torch.Tensor, + shift: int | None = None, + detune: bool = False, + t_count: int = None, ): + if t_count is not None: + self.t_count = t_count + spec = self.spec_transform(wav)[..., :-1] + # check: are detune and shift mutually exclusive? + # should we also put a ratio on shift? if shift is not None and shift != 0: spec = self.shift_spec(spec, shift) elif detune is True: - # Don't detune and spec shift at the same time - spec = self.detune_spec(spec) + if self.check_apply_transform(self.detune_ratio): + # Don't detune and spec shift at the same time + spec = self.detune_spec(spec) mel_spec = self.mel_transform(spec) @@ -504,15 +554,25 @@ def log_mel( return log_spec + def inverse_log_mel(self, mel: torch.Tensor): + """ + Takes as input a log mel spectrogram and returns the corresponding audio. 
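A minimal standalone sketch of the capped-augmentation gate introduced above (check_apply_transform / max_num_transforms / t_count), handy for sanity-checking the counting behaviour outside the training pipeline. The CappedAug class name and the toy ratio are illustrative only, not the repo's transform class:

import random

class CappedAug:
    def __init__(self, max_num_transforms=None):
        self.max_num_transforms = max_num_transforms
        self.t_count = None  # must be initialised (e.g. to 0, as forward() does) before transforms run

    def check_apply_transform(self, ratio: float) -> bool:
        # Mirrors the gate above: once the counter reaches the cap, no further
        # stochastic transforms are applied for this pass.
        if (
            self.max_num_transforms is not None
            and self.t_count is not None
            and self.t_count >= self.max_num_transforms
        ):
            return False
        apply_transform = random.random() < ratio
        if apply_transform:
            self.t_count += 1
        return apply_transform

aug = CappedAug(max_num_transforms=2)
aug.t_count = 0  # mirrors the reset at the top of forward()
applied = [aug.check_apply_transform(0.9) for _ in range(8)]
assert sum(applied) <= 2  # the cap bounds how many transforms fire per pass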
+ """ + mel = (4 * mel) - 4 + mel = torch.pow(10, mel) + mel = self.inverse_mel(mel) + return self.inverse_spec_transform(mel) + def forward(self, wav: torch.Tensor, shift: int = 0): # Noise, and reverb + self.t_count = 0 wav = self.aug_wav(wav) # Spec, detuning & pitch shift log_mel = self.log_mel(wav, shift, detune=True) # Spec aug - if random.random() < self.spec_aug_ratio: + if self.check_apply_transform(self.spec_aug_ratio): log_mel = self.spec_aug(log_mel) return log_mel diff --git a/baselines/giantmidi/transcribe_new_files.py b/baselines/giantmidi/transcribe_new_files.py deleted file mode 100644 index 0650c73..0000000 --- a/baselines/giantmidi/transcribe_new_files.py +++ /dev/null @@ -1,67 +0,0 @@ -import os -import argparse -import time -import torch -import piano_transcription_inference -import glob - - -def transcribe_piano(mp3s_dir, midis_dir, begin_index=None, end_index=None): - """Transcribe piano solo mp3s to midi files.""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - os.makedirs(midis_dir, exist_ok=True) - - # Transcriptor - transcriptor = piano_transcription_inference.PianoTranscription(device=device) - - transcribe_time = time.time() - for n, mp3_path in enumerate(glob.glob(os.path.join(mp3s_dir, '*.mp3'))[begin_index:end_index]): - print(n, mp3_path) - midi_file = os.path.basename(mp3_path).replace('.mp3', '.midi') - midi_path = os.path.join(midis_dir, midi_file) - if os.path.exists(midi_path): - continue - - (audio, _) = ( - piano_transcription_inference - .load_audio(mp3_path, sr=piano_transcription_inference.sample_rate, mono=True) - ) - - try: - # Transcribe - transcribed_dict = transcriptor.transcribe(audio, midi_path) - print(transcribed_dict) - except: - print('Failed for this audio!') - - print('Time: {:.3f} s'.format(time.time() - transcribe_time)) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Example of parser. ') - parser.add_argument('--mp3s_dir', type=str, required=True, help='') - parser.add_argument('--midis_dir', type=str, required=True, help='') - parser.add_argument( - '--begin_index', type=int, required=False, - help='File num., of an ordered list of files, to start transcribing from.', default=None - ) - parser.add_argument( - '--end_index', type=int, required=False, default=None, - help='File num., of an ordered list of files, to end transcription.' 
- ) - - # Parse arguments - args = parser.parse_args() - transcribe_piano( - mp3s_dir=args.mp3s_dir, - midis_dir=args.midis_dir, - begin_index=args.begin_index, - end_index=args.end_index - ) - -""" -python transcribe_new_files.py \ - transcribe_piano \ - --mp3s_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data \ - --midis_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data/kong-model -""" \ No newline at end of file diff --git a/baselines/requirements-baselines.txt b/baselines/requirements-baselines.txt deleted file mode 100644 index b56d966..0000000 --- a/baselines/requirements-baselines.txt +++ /dev/null @@ -1,3 +0,0 @@ -pretty_midi -librosa -piano_transcription_inference diff --git a/experiments/baselines/giantmidi/transcribe_new_files.py b/experiments/baselines/giantmidi/transcribe_new_files.py new file mode 100644 index 0000000..ee986dd --- /dev/null +++ b/experiments/baselines/giantmidi/transcribe_new_files.py @@ -0,0 +1,50 @@ +import os +import argparse +import time +import torch +import piano_transcription_inference +import glob +from more_itertools import unique_everseen +from tqdm.auto import tqdm +from random import shuffle +import sys +here = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.join(here, '../..')) +import loader_util + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Example of parser. ') + parser = loader_util.add_io_arguments(parser) + args = parser.parse_args() + + files_to_transcribe = loader_util.get_files_to_transcribe(args) + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + transcriptor = piano_transcription_inference.PianoTranscription(device=device) + + # Transcriptor + for n, (input_fname, output_fname) in tqdm(enumerate(files_to_transcribe), total=len(files_to_transcribe)): + if os.path.exists(output_fname): + continue + + now_start = time.time() + (audio, _) = (piano_transcription_inference + .load_audio(input_fname, sr=piano_transcription_inference.sample_rate, mono=True)) + print(f'READING ELAPSED TIME: {time.time() - now_start}') + now_read = time.time() + try: + # Transcribe + transcribed_dict = transcriptor.transcribe(audio, output_fname) + except: + print('Failed for this audio!') + print(f'TRANSCRIPTION ELAPSED TIME: {time.time() - now_read}') + print(f'TOTAL ELAPSED TIME: {time.time() - now_start}') + + + +""" +python transcribe_new_files.py \ + --input_dir_to_transcribe /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data \ + --output_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data/kong-model +""" \ No newline at end of file diff --git a/experiments/baselines/google_t5/transcribe_new_files.py b/experiments/baselines/google_t5/transcribe_new_files.py new file mode 100644 index 0000000..147cc83 --- /dev/null +++ b/experiments/baselines/google_t5/transcribe_new_files.py @@ -0,0 +1,311 @@ +# ! git clone --branch=main https://github.com/magenta/mt3 +# !python3 -m pip install jax[cuda11_local] nest-asyncio -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +# copy checkpoints +# ! gsutil -q -m cp -r gs://mt3/checkpoints . 
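Both new baseline scripts (the giantmidi one above and the google_t5 one below) import a shared experiments/loader_util module that is not part of this diff. The following is a hypothetical sketch of the interface they assume; the function names come from the scripts, but the flag names are inferred from the usage strings and the real implementation may differ:

import argparse
import glob
import os

def add_io_arguments(parser):
    # Flag names inferred from the usage strings in these scripts; the real
    # module may register them differently (e.g. single-dash spellings).
    parser.add_argument('--input_dir_to_transcribe', default=None)
    parser.add_argument('--output_dir', default=None)
    parser.add_argument('--input_file_to_transcribe', default=None)
    parser.add_argument('--output_file', default=None)
    return parser

def get_files_to_transcribe(args):
    # Returns (input_path, output_path) pairs; single-file mode takes
    # precedence, otherwise every mp3/wav in the input dir maps to a .midi
    # path in the output dir.
    if args.input_file_to_transcribe:
        return [(args.input_file_to_transcribe, args.output_file)]
    pairs = []
    for ext in ('*.mp3', '*.wav'):
        for path in sorted(glob.glob(os.path.join(args.input_dir_to_transcribe, ext))):
            base = os.path.splitext(os.path.basename(path))[0] + '.midi'
            pairs.append((path, os.path.join(args.output_dir, base)))
    return pairs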
+ +import numpy as np +import tensorflow.compat.v2 as tf +import functools +import gin +import jax +import note_seq +import seqio +import t5 +import t5x +import librosa +import time +import sys +import os +from tqdm.auto import tqdm +here = os.path.dirname(__file__) +sys.path.append(os.path.join(here, '../..')) +import loader_util + +from mt3 import metrics_utils +from mt3 import models +from mt3 import network +from mt3 import note_sequences +from mt3 import preprocessors +from mt3 import spectrograms +from mt3 import vocabularies +from scipy.io import wavfile +import os +import glob +from more_itertools import unique_everseen +from random import shuffle +here = os.path.dirname(__file__) + +def download_model(): + # download model + import gdown + import tarfile + url = "https://drive.google.com/file/d/1H9i8AszhJf9xonSaY6YkVoCkh1KHxYSk/view?usp=sharing" + output = os.path.join(here, "model_files/checkpoint.tar.gz") + model_files_dirname = os.path.dirname(output) + if not os.path.exists(model_files_dirname): + os.makedirs(model_files_dirname) + gdown.download(url, output, fuzzy=True) + tar = tarfile.open(output) + tar.extractall(model_files_dirname) + tar.close() + # download configs + import wget + gin_dir = os.path.join(here, 'model_files', 'gin') + if not os.path.exists(gin_dir): + os.makedirs(gin_dir) + url_1 = 'https://raw.githubusercontent.com/magenta/mt3/main/mt3/gin/model.gin' + url_2 = 'https://raw.githubusercontent.com/magenta/mt3/main/mt3/gin/ismir2021.gin' + url_3 = 'https://raw.githubusercontent.com/magenta/mt3/main/mt3/gin/mt3.gin' + wget.download(url_1, os.path.join(gin_dir, 'model.gin')) + wget.download(url_2, os.path.join(gin_dir, 'ismir2021.gin')) + wget.download(url_3, os.path.join(gin_dir, 'mt3.gin')) + + +class InferenceModel(object): + """Wrapper of T5X model for music transcription.""" + + def __init__(self, checkpoint_path, model_type='mt3'): + + # Model Constants. + if model_type == 'ismir2021': + num_velocity_bins = 127 + self.encoding_spec = note_sequences.NoteEncodingSpec + self.inputs_length = 512 + elif model_type == 'mt3': + num_velocity_bins = 1 + self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec + self.inputs_length = 256 + else: + raise ValueError('unknown model_type: %s' % model_type) + + gin_files = [ + os.path.join(here, 'model_files', 'gin', 'model.gin'), + os.path.join(here, 'model_files', 'gin', f'{model_type}.gin') + ] + + self.batch_size = 8 + self.outputs_length = 1024 + self.sequence_length = { + 'inputs': self.inputs_length, + 'targets': self.outputs_length + } + + self.partitioner = t5x.partitioning.PjitPartitioner(num_partitions=1) + + # Build Codecs and Vocabularies. + self.spectrogram_config = spectrograms.SpectrogramConfig() + self.codec = vocabularies.build_codec( + vocab_config=vocabularies.VocabularyConfig(num_velocity_bins=num_velocity_bins) + ) + self.vocabulary = vocabularies.vocabulary_from_codec(self.codec) + self.output_features = { + 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2), + 'targets': seqio.Feature(vocabulary=self.vocabulary), + } + + # Create a T5X model. + self._parse_gin(gin_files) + self.model = self._load_model() + + # Restore from checkpoint. 
+ self.restore_from_checkpoint(checkpoint_path) + + @property + def input_shapes(self): + return { + 'encoder_input_tokens': (self.batch_size, self.inputs_length), + 'decoder_input_tokens': (self.batch_size, self.outputs_length) + } + + def _parse_gin(self, gin_files): + """Parse gin files used to train the model.""" + gin_bindings = [ + 'from __gin__ import dynamic_registration', + 'from mt3 import vocabularies', + 'VOCAB_CONFIG=@vocabularies.VocabularyConfig()', + 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS' + ] + with gin.unlock_config(): + gin.parse_config_files_and_bindings( + gin_files, gin_bindings, finalize_config=False) + + def _load_model(self): + """Load up a T5X `Model` after parsing training gin config.""" + model_config = gin.get_configurable(network.T5Config)() + module = network.Transformer(config=model_config) + return models.ContinuousInputsEncoderDecoderModel( + module=module, + input_vocabulary=self.output_features['inputs'].vocabulary, + output_vocabulary=self.output_features['targets'].vocabulary, + optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0), + input_depth=spectrograms.input_depth(self.spectrogram_config)) + + + def restore_from_checkpoint(self, checkpoint_path): + """Restore training state from checkpoint, resets self._predict_fn().""" + train_state_initializer = t5x.utils.TrainStateInitializer( + optimizer_def=self.model.optimizer_def, + init_fn=self.model.get_initial_variables, + input_shapes=self.input_shapes, + partitioner=self.partitioner) + + restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig( + path=checkpoint_path, mode='specific', dtype='float32') + + train_state_axes = train_state_initializer.train_state_axes + self._predict_fn = self._get_predict_fn(train_state_axes) + self._train_state = train_state_initializer.from_checkpoint_or_scratch( + [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0)) + + @functools.lru_cache() + def _get_predict_fn(self, train_state_axes): + """Generate a partitioned prediction function for decoding.""" + def partial_predict_fn(params, batch, decode_rng): + return self.model.predict_batch_with_aux( + params, batch, decoder_params={'decode_rng': None}) + return self.partitioner.partition( + partial_predict_fn, + in_axis_resources=( + train_state_axes.params, + t5x.partitioning.PartitionSpec('data',), None), + out_axis_resources=t5x.partitioning.PartitionSpec('data',) + ) + + def predict_tokens(self, batch, seed=0): + """Predict tokens from preprocessed dataset batch.""" + prediction, _ = self._predict_fn( + self._train_state.params, batch, jax.random.PRNGKey(seed)) + return self.vocabulary.decode_tf(prediction).numpy() + + def __call__(self, audio): + """Infer note sequence from audio samples. + Args: + audio: 1-d numpy array of audio samples (16kHz) for a single example. + Returns: + A note_sequence of the transcribed audio. 
+ """ + ds = self.audio_to_dataset(audio) + ds = self.preprocess(ds) + + model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)( + ds, task_feature_lengths=self.sequence_length) + model_ds = model_ds.batch(self.batch_size) + + inferences = (tokens for batch in model_ds.as_numpy_iterator() + for tokens in self.predict_tokens(batch)) + + predictions = [] + for example, tokens in zip(ds.as_numpy_iterator(), inferences): + predictions.append(self.postprocess(tokens, example)) + + result = metrics_utils.event_predictions_to_ns( + predictions, codec=self.codec, encoding_spec=self.encoding_spec) + return result['est_ns'] + + def audio_to_dataset(self, audio): + """Create a TF Dataset of spectrograms from input audio.""" + frames, frame_times = self._audio_to_frames(audio) + return tf.data.Dataset.from_tensors({ + 'inputs': frames, + 'input_times': frame_times, + }) + + def _audio_to_frames(self, audio): + """Compute spectrogram frames from audio.""" + frame_size = self.spectrogram_config.hop_width + padding = [0, frame_size - len(audio) % frame_size] + audio = np.pad(audio, padding, mode='constant') + frames = spectrograms.split_audio(audio, self.spectrogram_config) + num_frames = len(audio) // frame_size + times = np.arange(num_frames) / self.spectrogram_config.frames_per_second + return frames, times + + def preprocess(self, ds): + pp_chain = [ + functools.partial( + t5.data.preprocessors.split_tokens_to_inputs_length, + sequence_length=self.sequence_length, + output_features=self.output_features, + feature_key='inputs', + additional_feature_keys=['input_times']), + # Cache occurs here during training. + preprocessors.add_dummy_targets, + functools.partial( + preprocessors.compute_spectrograms, + spectrogram_config=self.spectrogram_config) + ] + for pp in pp_chain: + ds = pp(ds) + return ds + + def postprocess(self, tokens, example): + tokens = self._trim_eos(tokens) + start_time = example['input_times'][0] + # Round down to nearest symbolic token step. + start_time -= start_time % (1 / self.codec.steps_per_second) + return { + 'est_tokens': tokens, + 'start_time': start_time, + # Internal MT3 code expects raw inputs, not used here. 
+ 'raw_inputs': [] + } + + @staticmethod + def _trim_eos(tokens): + tokens = np.array(tokens, np.int32) + if vocabularies.DECODED_EOS_ID in tokens: + tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)] + return tokens + + +def load_audio(data, sample_rate=None): + # read in wave data and convert to samples + # todo: check if this still works with mp3s + f, sr = librosa.load(data, sr=sample_rate) + return f + # return note_seq.audio_io.wav_data_to_samples_librosa(data, sample_rate=sample_rate) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser = loader_util.add_io_arguments(parser) + # necessary arguments + parser.add_argument('-f_config', help='config json file', default=None) + parser.add_argument('-model_file', help='input model file', default="ismir2021") + parser.add_argument('-sample_rate', help='sample rate', type=int, default=16000) + + args = parser.parse_args() + # get model + MODEL = args.model_file + checkpoint_path = os.path.join(here, 'model_files', 'checkpoints', MODEL) + if not os.path.exists(checkpoint_path): + download_model() + inference_model = InferenceModel(checkpoint_path, MODEL) + + files_to_transcribe = loader_util.get_files_to_transcribe(args) + for n, (input_fname, output_fname) in tqdm(enumerate(files_to_transcribe), total=len(files_to_transcribe)): + print(f'Transcribing {input_fname} -> {output_fname}...') + if os.path.exists(output_fname): + continue + now_start = time.time() + audio = load_audio(input_fname, sample_rate=args.sample_rate) + print(f'READING ELAPSED TIME: {time.time() - now_start}') + now_read = time.time() + est_ns = inference_model(audio) + print(f'TRANSCRIPTION ELAPSED TIME: {time.time() - now_read}') + print(f'TOTAL ELAPSED TIME: {time.time() - now_start}') + note_seq.sequence_proto_to_midi_file(est_ns, output_fname) + + +""" +# test one transcription +python transcribe_new_files.py \ + -input_file_to_transcribe ../../../maestro-v3.0.0/2004/MIDI-Unprocessed_SMF_13_01_2004_01-05_ORIG_MID--AUDIO_13_R1_2004_12_Track12_wav.wav \ + -output_file test-output-file.midi + +# test multiple transcriptions +python transcribe_new_files.py \ + -input_dir_to_transcribe /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data \ + -output_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data/kong-model +""" \ No newline at end of file diff --git a/baselines/hft_transformer/src/amt.py b/experiments/baselines/hft_transformer/model/hft_amt.py similarity index 100% rename from baselines/hft_transformer/src/amt.py rename to experiments/baselines/hft_transformer/model/hft_amt.py diff --git a/experiments/baselines/hft_transformer/model/model_spec2midi.py b/experiments/baselines/hft_transformer/model/model_spec2midi.py new file mode 100644 index 0000000..9555568 --- /dev/null +++ b/experiments/baselines/hft_transformer/model/model_spec2midi.py @@ -0,0 +1,378 @@ +#! 
python + +import torch +import torch.nn as nn + +## +## Model +## +class Model_SPEC2MIDI(nn.Module): + def __init__(self, encoder, decoder): + super().__init__() + self.encoder_spec2midi = encoder + self.decoder_spec2midi = decoder + + def forward(self, input_spec): + #input_spec = [batch_size, n_bin, margin+n_frame+margin] (8, 256, 192) + #print('Model_SPEC2MIDI(0) input_spec: '+str(input_spec.shape)) + + enc_vector = self.encoder_spec2midi(input_spec) + #enc_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Model_SPEC2MIDI(1) enc_vector: '+str(enc_vector.shape)) + + output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, attention, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B = self.decoder_spec2midi(enc_vector) + #output_onset_A = [batch_size, n_frame, n_note] (8, 128, 88) + #output_onset_B = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_A = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #output_velocity_B = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #attention = [batch_size, n_frame, n_heads, n_note, n_bin] (8, 128, 4, 88, 256) + #print('Model_SPEC2MIDI(2) output_onset_A: '+str(output_onset_A.shape)) + #print('Model_SPEC2MIDI(2) output_onset_B: '+str(output_onset_B.shape)) + #print('Model_SPEC2MIDI(2) output_velocity_A: '+str(output_velocity_A.shape)) + #print('Model_SPEC2MIDI(2) output_velocity_B: '+str(output_velocity_B.shape)) + #print('Model_SPEC2MIDI(2) attention: '+str(attention.shape)) + + return output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, attention, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B + + +## +## Encoder +## +class Encoder_SPEC2MIDI(nn.Module): + def __init__(self, n_margin, n_frame, n_bin, cnn_channel, cnn_kernel, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + + self.device = device + self.n_frame = n_frame + self.n_bin = n_bin + self.cnn_channel = cnn_channel + self.cnn_kernel = cnn_kernel + self.hid_dim = hid_dim + self.conv = nn.Conv2d(1, self.cnn_channel, kernel_size=(1, self.cnn_kernel)) + self.n_proc = n_margin * 2 + 1 + self.cnn_dim = self.cnn_channel * (self.n_proc - (self.cnn_kernel - 1)) + self.tok_embedding_freq = nn.Linear(self.cnn_dim, hid_dim) + self.pos_embedding_freq = nn.Embedding(n_bin, hid_dim) + self.layers_freq = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.dropout = nn.Dropout(dropout) + self.scale_freq = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + + def forward(self, spec_in): + #spec_in = [batch_size, n_bin, n_margin+n_frame+n_margin] (8, 256, 192) (batch_size=8, n_bins=256, margin=32/n_frame=128) + #print('Encoder_SPEC2MIDI(0) spec_in: '+str(spec_in.shape)) + batch_size = spec_in.shape[0] + + spec = spec_in.unfold(2, self.n_proc, 1).permute(0, 2, 1, 3).contiguous() + #spec = [batch_size, n_frame, n_bin, n_proc] (8, 128, 256, 65) (batch_size=8, n_frame=128, n_bins=256, n_proc=65) + #print('Encoder_SPEC2MIDI(1) spec: '+str(spec.shape)) + + # CNN 1D + spec_cnn = spec.reshape(batch_size*self.n_frame, self.n_bin, self.n_proc).unsqueeze(1) + #spec = [batch_size*n_frame, 1, n_bin, n_proc] (8*128, 1, 256, 65) (batch_size=128, 1, n_frame, n_bins=256, n_proc=65) + #print('Encoder_SPEC2MIDI(2) spec_cnn: '+str(spec_cnn.shape)) + spec_cnn = self.conv(spec_cnn).permute(0, 2, 1, 3).contiguous() + # spec_cnn: [batch_size*n_frame, n_bin, cnn_channel, n_proc-(cnn_kernel-1)] (8*128, 256, 4, 61) + 
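A standalone shape check of the unfold-then-Conv2d windowing above. Batch size 2 is arbitrary; n_margin=32, n_frame=128 and n_bin=256 are taken from the shape comments, and cnn_channel=4 / cnn_kernel=5 are implied by the (..., 4, 61) comment rather than read from a config in this diff:

import torch

n_margin, n_frame, n_bin = 32, 128, 256
n_proc = 2 * n_margin + 1                                              # 65
spec_in = torch.randn(2, n_bin, n_margin + n_frame + n_margin)         # (2, 256, 192)
spec = spec_in.unfold(2, n_proc, 1).permute(0, 2, 1, 3).contiguous()   # (2, 128, 256, 65)
conv = torch.nn.Conv2d(1, 4, kernel_size=(1, 5))
spec_cnn = spec.reshape(2 * n_frame, n_bin, n_proc).unsqueeze(1)       # (256, 1, 256, 65)
spec_cnn = conv(spec_cnn).permute(0, 2, 1, 3).contiguous()             # (256, 256, 4, 61)
print(spec.shape, spec_cnn.shape)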
#print('Encoder_SPEC2MIDI(2) spec_cnn: '+str(spec_cnn.shape)) + + ## + ## frequency + ## + spec_cnn_freq = spec_cnn.reshape(batch_size*self.n_frame, self.n_bin, self.cnn_dim) + # spec_cnn_freq: [batch_size*n_frame, n_bin, cnn_channel, (n_proc)-(cnn_kernel-1)] (8*128, 256, 244) + #print('Encoder_SPEC2MIDI(3) spec_cnn_freq: '+str(spec_cnn_freq.shape)) + + # embedding + spec_emb_freq = self.tok_embedding_freq(spec_cnn_freq) + # spec_emb_freq: [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_SPEC2MIDI(4) spec_emb_freq: '+str(spec_emb_freq.shape)) + + # position coding + pos_freq = torch.arange(0, self.n_bin).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + #pos_freq = [batch_size, n_frame, n_bin] (8*128, 256) + #print('Encoder_SPEC2MIDI(5) pos_freq: '+str(pos_freq.shape)) + + # embedding + spec_freq = self.dropout((spec_emb_freq * self.scale_freq) + self.pos_embedding_freq(pos_freq)) + #spec_freq = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_SPEC2MIDI(6) spec_freq: '+str(spec_freq.shape)) + + # transformer encoder + for layer_freq in self.layers_freq: + spec_freq = layer_freq(spec_freq) + spec_freq = spec_freq.reshape(batch_size, self.n_frame, self.n_bin, self.hid_dim) + #spec_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Encoder_SPEC2MIDI(7) spec_freq: '+str(spec_freq.shape)) + + return spec_freq + + +## +## Decoder +## +class Decoder_SPEC2MIDI(nn.Module): + def __init__(self, n_frame, n_bin, n_note, n_velocity, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + self.device = device + self.n_note = n_note + self.n_frame = n_frame + self.n_velocity = n_velocity + self.n_bin = n_bin + self.hid_dim = hid_dim + self.sigmoid = nn.Sigmoid() + self.dropout = nn.Dropout(dropout) + + # CAfreq + self.pos_embedding_freq = nn.Embedding(n_note, hid_dim) + self.layer_zero_freq = DecoderLayer_Zero(hid_dim, n_heads, pf_dim, dropout, device) + self.layers_freq = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers-1)]) + + self.fc_onset_freq = nn.Linear(hid_dim, 1) + self.fc_offset_freq = nn.Linear(hid_dim, 1) + self.fc_mpe_freq = nn.Linear(hid_dim, 1) + self.fc_velocity_freq = nn.Linear(hid_dim, self.n_velocity) + + # SAtime + self.scale_time = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + self.pos_embedding_time = nn.Embedding(n_frame, hid_dim) + #self.layers_time = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.layers_time = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + + self.fc_onset_time = nn.Linear(hid_dim, 1) + self.fc_offset_time = nn.Linear(hid_dim, 1) + self.fc_mpe_time = nn.Linear(hid_dim, 1) + self.fc_velocity_time = nn.Linear(hid_dim, self.n_velocity) + + def forward(self, enc_spec): + batch_size = enc_spec.shape[0] + enc_spec = enc_spec.reshape([batch_size*self.n_frame, self.n_bin, self.hid_dim]) + #enc_spec = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Decoder_SPEC2MIDI(0) enc_spec: '+str(enc_spec.shape)) + + ## + ## CAfreq freq(256)/note(88) + ## + pos_freq = torch.arange(0, self.n_note).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + midi_freq = self.pos_embedding_freq(pos_freq) + #pos_freq = [batch_size*n_frame, n_note] (8*128, 88) + #midi_freq = [batch_size, n_note, hid_dim] (8*128, 88, 256) + #print('Decoder_SPEC2MIDI(1) pos_freq: '+str(pos_freq.shape)) + 
#print('Decoder_SPEC2MIDI(1) midi_freq: '+str(midi_freq.shape)) + + midi_freq, attention_freq = self.layer_zero_freq(enc_spec, midi_freq) + for layer_freq in self.layers_freq: + midi_freq, attention_freq = layer_freq(enc_spec, midi_freq) + dim = attention_freq.shape + attention_freq = attention_freq.reshape([batch_size, self.n_frame, dim[1], dim[2], dim[3]]) + #midi_freq = [batch_size*n_frame, n_note, hid_dim] (8*128, 88, 256) + #attention_freq = [batch_size, n_frame, n_heads, n_note, n_bin] (8, 128, 4, 88, 256) + #print('Decoder_SPEC2MIDI(2) midi_freq: '+str(midi_freq.shape)) + #print('Decoder_SPEC2MIDI(2) attention_freq: '+str(attention_freq.shape)) + + ## output(freq) + output_onset_freq = self.sigmoid(self.fc_onset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_offset_freq = self.sigmoid(self.fc_offset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_mpe_freq = self.sigmoid(self.fc_mpe_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_velocity_freq = self.fc_velocity_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note, self.n_velocity]) + #output_onset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_freq = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_SPEC2MIDI(3) output_onset_freq: '+str(output_onset_freq.shape)) + #print('Decoder_SPEC2MIDI(3) output_offset_freq: '+str(output_offset_freq.shape)) + #print('Decoder_SPEC2MIDI(3) output_mpe_freq: '+str(output_mpe_freq.shape)) + #print('Decoder_SPEC2MIDI(3) output_velocity_freq: '+str(output_velocity_freq.shape)) + + ## + ## SAtime time(64) + ## + #midi_time: [batch_size*n_frame, n_note, hid_dim] -> [batch_size*n_note, n_frame, hid_dim] + midi_time = midi_freq.reshape([batch_size, self.n_frame, self.n_note, self.hid_dim]).permute(0, 2, 1, 3).contiguous().reshape([batch_size*self.n_note, self.n_frame, self.hid_dim]) + pos_time = torch.arange(0, self.n_frame).unsqueeze(0).repeat(batch_size*self.n_note, 1).to(self.device) + midi_time = self.dropout((midi_time * self.scale_time) + self.pos_embedding_time(pos_time)) + #pos_time = [batch_size*n_note, n_frame] (8*88, 128) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_SPEC2MIDI(4) pos_time: '+str(pos_time.shape)) + #print('Decoder_SPEC2MIDI(4) midi_time: '+str(midi_time.shape)) + + for layer_time in self.layers_time: + midi_time = layer_time(midi_time) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_SPEC2MIDI(5) midi_time: '+str(midi_time.shape)) + + ## output(time) + output_onset_time = self.sigmoid(self.fc_onset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_offset_time = self.sigmoid(self.fc_offset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_mpe_time = self.sigmoid(self.fc_mpe_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_velocity_time = self.fc_velocity_time(midi_time).reshape([batch_size, self.n_note, self.n_frame, self.n_velocity]).permute(0, 2, 1, 3).contiguous() + #output_onset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_time = [batch_size, n_frame, n_note] (8, 128, 88) + 
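A standalone check (toy batch size) of the frequency-to-time axis swap used in the SAtime branch above, where [batch*n_frame, n_note, hid_dim] activations are regrouped into [batch*n_note, n_frame, hid_dim] sequences before the time-axis self-attention:

import torch

B, n_frame, n_note, hid_dim = 2, 128, 88, 256
midi_freq = torch.randn(B * n_frame, n_note, hid_dim)
midi_time = (
    midi_freq.reshape(B, n_frame, n_note, hid_dim)
    .permute(0, 2, 1, 3)
    .contiguous()
    .reshape(B * n_note, n_frame, hid_dim)
)
print(midi_time.shape)  # torch.Size([176, 128, 256])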
#output_velocity_time = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_SPEC2MIDI(6) output_onset_time: '+str(output_onset_time.shape)) + #print('Decoder_SPEC2MIDI(6) output_offset_time: '+str(output_offset_time.shape)) + #print('Decoder_SPEC2MIDI(6) output_mpe_time: '+str(output_mpe_time.shape)) + #print('Decoder_SPEC2MIDI(6) output_velocity_time: '+str(output_velocity_time.shape)) + + return output_onset_freq, output_offset_freq, output_mpe_freq, output_velocity_freq, attention_freq, output_onset_time, output_offset_time, output_mpe_time, output_velocity_time + + +## +## sub functions +## +class EncoderLayer(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, src): + #src = [batch_size, src_len, hid_dim] + + #self attention + _src, _ = self.self_attention(src, src, src) + #dropout, residual connection and layer norm + src = self.layer_norm(src + self.dropout(_src)) + #src = [batch_size, src_len, hid_dim] + + #positionwise feedforward + _src = self.positionwise_feedforward(src) + #dropout, residual and layer norm + src = self.layer_norm(src + self.dropout(_src)) + #src = [batch_size, src_len, hid_dim] + + return src + +class DecoderLayer_Zero(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, enc_src, trg): + #trg = [batch_size, trg_len, hid_dim] + #enc_src = [batch_size, src_len, hid_dim] + + #encoder attention + _trg, attention = self.encoder_attention(trg, enc_src, enc_src) + #dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #positionwise feedforward + _trg = self.positionwise_feedforward(trg) + #dropout, residual and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + #attention = [batch_size, n_heads, trg_len, src_len] + + return trg, attention + +class DecoderLayer(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, enc_src, trg): + #trg = [batch_size, trg_len, hid_dim] + #enc_src = [batch_size, src_len, hid_dim] + + #self attention + _trg, _ = self.self_attention(trg, trg, trg) + #dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #encoder attention + _trg, attention = self.encoder_attention(trg, enc_src, enc_src) + #dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #positionwise feedforward + _trg = 
self.positionwise_feedforward(trg) + #dropout, residual and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + #attention = [batch_size, n_heads, trg_len, src_len] + + return trg, attention + +class MultiHeadAttentionLayer(nn.Module): + def __init__(self, hid_dim, n_heads, dropout, device): + super().__init__() + assert hid_dim % n_heads == 0 + self.hid_dim = hid_dim + self.n_heads = n_heads + self.head_dim = hid_dim // n_heads + self.fc_q = nn.Linear(hid_dim, hid_dim) + self.fc_k = nn.Linear(hid_dim, hid_dim) + self.fc_v = nn.Linear(hid_dim, hid_dim) + self.fc_o = nn.Linear(hid_dim, hid_dim) + self.dropout = nn.Dropout(dropout) + self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device) + + def forward(self, query, key, value): + batch_size = query.shape[0] + #query = [batch_size, query_len, hid_dim] + #key = [batch_size, key_len, hid_dim] + #value = [batch_size, value_len, hid_dim] + + Q = self.fc_q(query) + K = self.fc_k(key) + V = self.fc_v(value) + #Q = [batch_size, query_len, hid_dim] + #K = [batch_size, key_len, hid_dim] + #V = [batch_size, value_len, hid_dim] + + Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + #Q = [batch_size, n_heads, query_len, head_dim] + #K = [batch_size, n_heads, key_len, head_dim] + #V = [batch_size, n_heads, value_len, head_dim] + + energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale + #energy = [batch_size, n_heads, seq len, seq len] + + attention = torch.softmax(energy, dim = -1) + #attention = [batch_size, n_heads, query_len, key_len] + + x = torch.matmul(self.dropout(attention), V) + #x = [batch_size, n_heads, seq len, head_dim] + + x = x.permute(0, 2, 1, 3).contiguous() + #x = [batch_size, seq_len, n_heads, head_dim] + + x = x.view(batch_size, -1, self.hid_dim) + #x = [batch_size, seq_len, hid_dim] + + x = self.fc_o(x) + #x = [batch_size, seq_len, hid_dim] + + return x, attention + +class PositionwiseFeedforwardLayer(nn.Module): + def __init__(self, hid_dim, pf_dim, dropout): + super().__init__() + self.fc_1 = nn.Linear(hid_dim, pf_dim) + self.fc_2 = nn.Linear(pf_dim, hid_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + #x = [batch_size, seq_len, hid_dim] + + x = self.dropout(torch.relu(self.fc_1(x))) + #x = [batch_size, seq_len, pf dim] + + x = self.fc_2(x) + #x = [batch_size, seq_len, hid_dim] + + return x diff --git a/experiments/baselines/hft_transformer/model/model_spec2midi_ablation.py b/experiments/baselines/hft_transformer/model/model_spec2midi_ablation.py new file mode 100644 index 0000000..5756cc4 --- /dev/null +++ b/experiments/baselines/hft_transformer/model/model_spec2midi_ablation.py @@ -0,0 +1,707 @@ +#! 
python + +import torch +import torch.nn as nn + +## +## Model (single output) +## +# 1FDN: Encoder_CNNtime_SAfreq / Decoder_CAfreq +class Model_single(nn.Module): + def __init__(self, encoder, decoder): + super().__init__() + self.encoder_spec2midi = encoder + self.decoder_spec2midi = decoder + + def forward(self, input_spec): + #input_spec = [batch_size, n_bin, margin+n_frame+margin] (8, 256, 192) + #print('Model_single(0) input_spec: '+str(input_spec.shape)) + + enc_vector = self.encoder_spec2midi(input_spec) + #enc_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Model_single(1) enc_vector: '+str(enc_vector.shape)) + + output_onset, output_offset, output_mpe, output_velocity = self.decoder_spec2midi(enc_vector) + #output_onset = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Model_single(2) output_onset: '+str(output_onset.shape)) + #print('Model_single(2) output_velocity: '+str(output_velocity.shape)) + + return output_onset, output_offset, output_mpe, output_velocity + + +## +## Model (combination output) +## +# 1FDT: Encoder_CNNtime_SAfreq / Decoder_CAfreq_SAtime +# 1FLT: Encoder_CNNtime_SAfreq / Decoder_linear_SAtime +# 2FDT: Encoder_CNNblock_SAfreq / Decoder_CAfreq_SAtime +class Model_combination(nn.Module): + def __init__(self, encoder, decoder): + super().__init__() + self.encoder_spec2midi = encoder + self.decoder_spec2midi = decoder + + def forward(self, input_spec): + #input_spec = [batch_size, n_bin, margin+n_frame+margin] (8, 256, 192) + #print('Model_combination(0) input_spec: '+str(input_spec.shape)) + + enc_vector = self.encoder_spec2midi(input_spec) + #enc_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Model_combination(1) enc_vector: '+str(enc_vector.shape)) + + output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B = self.decoder_spec2midi(enc_vector) + #output_onset_A = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_A = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Model_combination(2) output_onset_A: '+str(output_onset_A.shape)) + #print('Model_combination(2) output_velocity_A: '+str(output_velocity_A.shape)) + #print('Model_combination(2) output_onset_B: '+str(output_onset_B.shape)) + #print('Model_combination(2) output_velocity_B: '+str(output_velocity_B.shape)) + + return output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B + + +## +## Encoder +## +# Encoder_CNNtime_SAfreq +# Encoder_CNNblock_SAfreq +## +## Encoder CNN(time)+SA(freq) +## +class Encoder_CNNtime_SAfreq(nn.Module): + def __init__(self, n_margin, n_frame, n_bin, cnn_channel, cnn_kernel, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + + self.device = device + self.n_frame = n_frame + self.n_bin = n_bin + self.cnn_channel = cnn_channel + self.cnn_kernel = cnn_kernel + self.hid_dim = hid_dim + self.conv = nn.Conv2d(1, self.cnn_channel, kernel_size=(1, self.cnn_kernel)) + self.n_proc = n_margin * 2 + 1 + self.cnn_dim = self.cnn_channel * (self.n_proc - (self.cnn_kernel - 1)) + self.tok_embedding_freq = nn.Linear(self.cnn_dim, hid_dim) + self.pos_embedding_freq = nn.Embedding(n_bin, hid_dim) + self.layers_freq = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.dropout = nn.Dropout(dropout) + 
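A composition sketch for the ablation wrappers above: per the comments, the "1FDN" variant pairs Encoder_CNNtime_SAfreq with Decoder_CAfreq inside Model_single. Hyper-parameters here are illustrative placeholders rather than the repo's training-config values (the hard-coded 61-frame unfold in this encoder's forward does require n_margin=32 and cnn_kernel=5 for the shapes to line up):

import torch

device = 'cpu'
enc = Encoder_CNNtime_SAfreq(
    n_margin=32, n_frame=128, n_bin=256, cnn_channel=4, cnn_kernel=5,
    hid_dim=256, n_layers=2, n_heads=4, pf_dim=512, dropout=0.1, device=device,
)
dec = Decoder_CAfreq(
    n_frame=128, n_bin=256, n_note=88, n_velocity=128,
    hid_dim=256, n_layers=2, n_heads=4, pf_dim=512, dropout=0.1, device=device,
)
model = Model_single(enc, dec)
spec = torch.randn(1, 256, 32 + 128 + 32)        # [batch, n_bin, margin+n_frame+margin]
onset, offset, mpe, velocity = model(spec)
print(onset.shape, velocity.shape)               # (1, 128, 88) and (1, 128, 88, 128)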
self.scale_freq = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + + def forward(self, spec_in): + #spec_in = [batch_size, n_bin, n_margin+n_frame+n_margin] (8, 256, 192) (batch_size=8, n_bins=256, margin=32/n_frame=128) + #print('Encoder_CNNtime_SAfreq(0) spec_in: '+str(spec_in.shape)) + batch_size = spec_in.shape[0] + + # CNN + spec_cnn = self.conv(spec_in.unsqueeze(1)) + #spec_cnn: [batch_size, cnn_channel, n_bin, n_margin+n_frame+n_margin-(cnn_kernel-1)] (8, 4, 256, 188) + #print('Encoder_CNNtime_SAfreq(1) spec_cnn: '+str(spec_cnn.shape)) + + # n_frame block + spec_cnn = spec_cnn.unfold(3, 61, 1).permute(0, 3, 2, 1, 4).contiguous().reshape([batch_size*self.n_frame, self.n_bin, self.cnn_dim]) + #spec_cnn: [batch_size*n_frame, n_bin, cnn_dim] (8*128, 256, 244) + #print('Encoder_CNNtime_SAfreq(2) spec_cnn: '+str(spec_cnn.shape)) + + # embedding + spec_emb_freq = self.tok_embedding_freq(spec_cnn) + # spec_emb_freq: [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_CNNtime_SAfreq(3) spec_emb_freq: '+str(spec_emb_freq.shape)) + + # position coding + pos_freq = torch.arange(0, self.n_bin).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + #pos_freq = [batch_size*n_frame, n_bin] (8*128, 256) + #print('Encoder_CNNtime_SAfreq(4) pos_freq: '+str(pos_freq.shape)) + + # embedding + spec_freq = self.dropout((spec_emb_freq * self.scale_freq) + self.pos_embedding_freq(pos_freq)) + #spec_freq = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_CNNtime_SAfreq(5) spec_freq: '+str(spec_freq.shape)) + + # transformer encoder + for layer_freq in self.layers_freq: + spec_freq = layer_freq(spec_freq) + spec_freq = spec_freq.reshape([batch_size, self.n_frame, self.n_bin, self.hid_dim]) + #spec_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Encoder_CNNtime_SAfreq(6) spec_freq: '+str(spec_freq.shape)) + + return spec_freq + + +## +## Encoder CNN(block)+SA(freq) +## +class Encoder_CNNblock_SAfreq(nn.Module): + def __init__(self, n_margin, n_frame, n_bin, hid_dim, n_layers, n_heads, pf_dim, dropout, dropout_convblock, device): + super().__init__() + + self.device = device + self.n_frame = n_frame + self.n_bin = n_bin + self.hid_dim = hid_dim + + k = 3 + p = 1 + # ConvBlock1 + layers_conv_1 = [] + ch1 = 48 + layers_conv_1.append(nn.Conv2d(1, ch1, kernel_size=k, stride=1, padding=p)) + layers_conv_1.append(nn.BatchNorm2d(ch1)) + layers_conv_1.append(nn.ReLU(True)) + layers_conv_1.append(nn.Conv2d(ch1, ch1, kernel_size=k, stride=1, padding=p)) + layers_conv_1.append(nn.BatchNorm2d(ch1)) + layers_conv_1.append(nn.ReLU(True)) + layers_conv_1.append(nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)) + self.conv_1 = nn.Sequential(*layers_conv_1) + self.dropout_1 = nn.Dropout(dropout_convblock) + # ConvBlock2 + layers_conv_2 = [] + ch2 = 64 + layers_conv_2.append(nn.Conv2d(ch1, ch2, kernel_size=k, stride=1, padding=p)) + layers_conv_2.append(nn.BatchNorm2d(ch2)) + layers_conv_2.append(nn.ReLU(True)) + layers_conv_2.append(nn.Conv2d(ch2, ch2, kernel_size=k, stride=1, padding=p)) + layers_conv_2.append(nn.BatchNorm2d(ch2)) + layers_conv_2.append(nn.ReLU(True)) + layers_conv_2.append(nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)) + self.conv_2 = nn.Sequential(*layers_conv_2) + self.dropout_2 = nn.Dropout(dropout_convblock) + # ConvBlock3 + layers_conv_3 = [] + ch3 = 96 + layers_conv_3.append(nn.Conv2d(ch2, ch3, kernel_size=k, stride=1, padding=p)) + layers_conv_3.append(nn.BatchNorm2d(ch3)) + 
layers_conv_3.append(nn.ReLU(True)) + layers_conv_3.append(nn.Conv2d(ch3, ch3, kernel_size=k, stride=1, padding=p)) + layers_conv_3.append(nn.BatchNorm2d(ch3)) + layers_conv_3.append(nn.ReLU(True)) + layers_conv_3.append(nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)) + self.conv_3 = nn.Sequential(*layers_conv_3) + self.dropout_3 = nn.Dropout(dropout_convblock) + # ConvBlock4 + layers_conv_4 = [] + ch4 = 128 + layers_conv_4.append(nn.Conv2d(ch3, ch4, kernel_size=k, stride=1, padding=p)) + layers_conv_4.append(nn.BatchNorm2d(ch4)) + layers_conv_4.append(nn.ReLU(True)) + layers_conv_4.append(nn.Conv2d(ch4, ch4, kernel_size=k, stride=1, padding=p)) + layers_conv_4.append(nn.BatchNorm2d(ch4)) + layers_conv_4.append(nn.ReLU(True)) + layers_conv_4.append(nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)) + self.conv_4 = nn.Sequential(*layers_conv_4) + self.dropout_4 = nn.Dropout(dropout_convblock) + + self.n_proc = n_margin * 2 + 1 + self.cnn_dim = int(int(int(int(self.n_bin/2)/2)/2)/2) + self.cnn_channel_A = 16 + self.cnn_channel_B = 8 + self.cnn_out_dim = self.n_proc * self.cnn_channel_B + + self.tok_embedding_freq = nn.Linear(self.cnn_out_dim, hid_dim) + self.pos_embedding_freq = nn.Embedding(n_bin, hid_dim) + self.layers_freq = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.dropout = nn.Dropout(dropout) + self.scale_freq = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + + def forward(self, spec_in): + #spec_in = [batch_size, n_bin, n_margin+n_frame+n_margin] (8, 256, 192) (batch_size=8, n_bins=256, margin=32/n_frame=128) + #print('Encoder_CNNblock_SAfreq(0) spec_in: '+str(spec_in.shape)) + batch_size = spec_in.shape[0] + + # conv blocks + spec1 = self.dropout_1(self.conv_1(spec_in.permute(0, 2, 1).contiguous().unsqueeze(1))) + #spec1 = [batch_size, ch1, n_margin+n_frame+n_margin, int(n_bin/2)] (8, 48, 192, 128) + #print('Encoder_CNNblock_SAfreq(1) spec1: '+str(spec1.shape)) + + spec2 = self.dropout_2(self.conv_2(spec1)) + #spec2 = [batch_size, ch2, n_margin+n_frame+n_margin, int(int(n_bin/2)/2)] (8, 64, 192, 64) + #print('Encoder_CNNblock_SAfreq(2) spec2: '+str(spec2.shape)) + + spec3 = self.dropout_3(self.conv_3(spec2)) + #spec3 = [batch_size, ch3, n_margin+n_frame+n_margin, int(int(int(n_bin/2)/2)/2)] (8, 96, 192, 32) + #print('Encoder_CNNblock_SAfreq(3) spec3: '+str(spec3.shape)) + + spec4 = self.dropout_4(self.conv_4(spec3)) + #spec4 = [batch_size, ch4, n_margin+n_frame+n_margin, int(int(int(int(n_bin/2)/2)/2)/2)] (8, 128, 192, 16) + #print('Encoder_CNNblock_SAfreq(4) spec4: '+str(spec4.shape)) + + # n_frame block + spec5 = spec4.unfold(2, self.n_proc, 1) + #spec5: [batch_size, ch4, n_frame, 16bin, n_proc] (8, 128, 128, 16, 65) + #print('Encoder_CNNblock_SAfreq(5) spec5: '+str(spec5.shape)) + + spec6 = spec5.permute(0, 2, 3, 1, 4).contiguous() + #spec6: [batch_size, n_frame, cnn_dim, ch4, n_proc] (8, 128, 16, 128, 65) + #print('Encoder_CNNblock_SAfreq(6) spec6: '+str(spec6.shape)) + + spec7 = spec6.reshape([batch_size, self.n_frame, self.cnn_dim, self.cnn_channel_A, self.cnn_channel_B, self.n_proc]) + #spec7: [batch_size, n_frame, cnn_dim, cnn_channel_A, cnn_channel_B, n_proc] (8, 128, 16, 16, 8, 65) + #print('Encoder_CNNblock_SAfreq(7) spec7: '+str(spec7.shape)) + + spec8 = spec7.reshape([batch_size, self.n_frame, self.n_bin, self.cnn_out_dim]) + #spec8: [batch_size, n_frame, n_bin, cnn_out_dim] (8, 128, 256, 520) + #print('Encoder_CNNblock_SAfreq(8) spec8: '+str(spec8.shape)) + + spec_emb_freq = 
self.tok_embedding_freq(spec8).reshape([batch_size*self.n_frame, self.n_bin, self.hid_dim]) + # spec_emb_freq: [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_CNNblock_SAfreq(9) spec_emb_freq: '+str(spec_emb_freq.shape)) + + # position coding + pos_freq = torch.arange(0, self.n_bin).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + #pos_freq = [batch_size*n_frame, n_bin] (8*128, 256) + #print('Encoder_CNNblock_SAfreq(10) pos_freq: '+str(pos_freq.shape)) + + # embedding + spec_freq = self.dropout((spec_emb_freq * self.scale_freq) + self.pos_embedding_freq(pos_freq)) + #spec_freq = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_CNNblock_SAfreq(11) spec_freq: '+str(spec_freq.shape)) + + # transformer encoder + for layer_freq in self.layers_freq: + spec_freq = layer_freq(spec_freq) + spec_freq = spec_freq.reshape(batch_size, self.n_frame, self.n_bin, self.hid_dim) + #spec_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Encoder_CNNblock_SAfreq(12) spec_freq: '+str(spec_freq.shape)) + + return spec_freq + + +## +## Decoder +## +# Decoder_CAfreq +# Decoder_CAfreq_SAtime +# Decoder_linear_SAtime +## +## Decoder CA(freq) +## +class Decoder_CAfreq(nn.Module): + def __init__(self, n_frame, n_bin, n_note, n_velocity, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + self.device = device + self.n_note = n_note + self.n_frame = n_frame + self.n_velocity = n_velocity + self.n_bin = n_bin + self.hid_dim = hid_dim + + self.pos_embedding_freq = nn.Embedding(n_note, hid_dim) + self.layer_zero_freq = DecoderLayer_Zero(hid_dim, n_heads, pf_dim, dropout, device) + self.layers_freq = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers-1)]) + + self.fc_onset_freq = nn.Linear(hid_dim, 1) + self.fc_offset_freq = nn.Linear(hid_dim, 1) + self.fc_mpe_freq = nn.Linear(hid_dim, 1) + self.fc_velocity_freq = nn.Linear(hid_dim, self.n_velocity) + + self.sigmoid = nn.Sigmoid() + + def forward(self, enc_spec): + #enc_spec = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + batch_size = enc_spec.shape[0] + + enc_spec = enc_spec.reshape([batch_size*self.n_frame, self.n_bin, self.hid_dim]) + #enc_spec = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Decoder_CAfreq(0) enc_spec: '+str(enc_spec.shape)) + + ## + ## CAfreq bin(256)/note(88) + ## + pos_freq = torch.arange(0, self.n_note).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + midi_freq = self.pos_embedding_freq(pos_freq) + #pos_freq = [batch_size*n_frame, n_note] (8*128, 88) + #midi_freq = [batch_size*n_frame, n_note, hid_dim] (8*128, 88, 256) + #print('Decoder_CAfreq(1) pos_freq: '+str(pos_freq.shape)) + #print('Decoder_CAfreq(1) midi_freq: '+str(midi_freq.shape)) + + midi_freq, attention_freq = self.layer_zero_freq(enc_spec, midi_freq) + for layer_freq in self.layers_freq: + midi_freq, attention_freq = layer_freq(enc_spec, midi_freq) + dim = attention_freq.shape + attention_freq = attention_freq.reshape([batch_size, self.n_frame, dim[1], dim[2], dim[3]]) + #midi_freq = [batch_size*n_frame, n_note, hid_dim] (8*128, 88, 256) + #attention_freq = [batch_size, n_frame, n_heads, n_note, n_bin] (8, 128, 4, 88, 256) + #print('Decoder_CAfreq(2) midi_freq: '+str(midi_freq.shape)) + #print('Decoder_CAfreq(2) attention_freq: '+str(attention_freq.shape)) + + ## output + output_onset_freq = self.sigmoid(self.fc_onset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) 
+ output_offset_freq = self.sigmoid(self.fc_offset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_mpe_freq = self.sigmoid(self.fc_mpe_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_velocity_freq = self.fc_velocity_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note, self.n_velocity]) + #output_onset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_freq = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_CAfreq(3) output_onset_freq: '+str(output_onset_freq.shape)) + #print('Decoder_CAfreq(3) output_offset_freq: '+str(output_offset_freq.shape)) + #print('Decoder_CAfreq(3) output_mpe_freq: '+str(output_mpe_freq.shape)) + #print('Decoder_CAfreq(3) output_velocity_freq: '+str(output_velocity_freq.shape)) + + return output_onset_freq, output_offset_freq, output_mpe_freq, output_velocity_freq + + +## +## Decoder CA(freq)/SA(time) +## +class Decoder_CAfreq_SAtime(nn.Module): + def __init__(self, n_frame, n_bin, n_note, n_velocity, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + self.device = device + self.n_note = n_note + self.n_frame = n_frame + self.n_velocity = n_velocity + self.n_bin = n_bin + self.hid_dim = hid_dim + self.sigmoid = nn.Sigmoid() + self.dropout = nn.Dropout(dropout) + + # CAfreq + self.pos_embedding_freq = nn.Embedding(n_note, hid_dim) + self.layer_zero_freq = DecoderLayer_Zero(hid_dim, n_heads, pf_dim, dropout, device) + self.layers_freq = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers-1)]) + + self.fc_onset_freq = nn.Linear(hid_dim, 1) + self.fc_offset_freq = nn.Linear(hid_dim, 1) + self.fc_mpe_freq = nn.Linear(hid_dim, 1) + self.fc_velocity_freq = nn.Linear(hid_dim, self.n_velocity) + + # SAtime + self.scale_time = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + self.pos_embedding_time = nn.Embedding(n_frame, hid_dim) + #self.layers_time = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.layers_time = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + + self.fc_onset_time = nn.Linear(hid_dim, 1) + self.fc_offset_time = nn.Linear(hid_dim, 1) + self.fc_mpe_time = nn.Linear(hid_dim, 1) + self.fc_velocity_time = nn.Linear(hid_dim, self.n_velocity) + + def forward(self, enc_spec): + batch_size = enc_spec.shape[0] + enc_spec = enc_spec.reshape([batch_size*self.n_frame, self.n_bin, self.hid_dim]) + #enc_spec = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Decoder_CAfreq_SAtime(0) enc_spec: '+str(enc_spec.shape)) + + ## + ## CAfreq freq(256)/note(88) + ## + pos_freq = torch.arange(0, self.n_note).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + midi_freq = self.pos_embedding_freq(pos_freq) + #pos_freq = [batch_size*n_frame, n_note] (8*128, 88) + #midi_freq = [batch_size, n_note, hid_dim] (8*128, 88, 256) + #print('Decoder_CAfreq_SAtime(1) pos_freq: '+str(pos_freq.shape)) + #print('Decoder_CAfreq_SAtime(1) midi_freq: '+str(midi_freq.shape)) + + midi_freq, attention_freq = self.layer_zero_freq(enc_spec, midi_freq) + for layer_freq in self.layers_freq: + midi_freq, attention_freq = layer_freq(enc_spec, midi_freq) + dim = attention_freq.shape + attention_freq = attention_freq.reshape([batch_size, self.n_frame, dim[1], 
dim[2], dim[3]]) + #midi_freq = [batch_size*n_frame, n_note, hid_dim] (8*128, 88, 256) + #attention_freq = [batch_size, n_frame, n_heads, n_note, n_bin] (8, 128, 4, 88, 256) + #print('Decoder_CAfreq_SAtime(2) midi_freq: '+str(midi_freq.shape)) + #print('Decoder_CAfreq_SAtime(2) attention_freq: '+str(attention_freq.shape)) + + ## output(freq) + output_onset_freq = self.sigmoid(self.fc_onset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_offset_freq = self.sigmoid(self.fc_offset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_mpe_freq = self.sigmoid(self.fc_mpe_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_velocity_freq = self.fc_velocity_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note, self.n_velocity]) + #output_onset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_freq = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_CAfreq_SAtime(3) output_onset_freq: '+str(output_onset_freq.shape)) + #print('Decoder_CAfreq_SAtime(3) output_offset_freq: '+str(output_offset_freq.shape)) + #print('Decoder_CAfreq_SAtime(3) output_mpe_freq: '+str(output_mpe_freq.shape)) + #print('Decoder_CAfreq_SAtime(3) output_velocity_freq: '+str(output_velocity_freq.shape)) + + ## + ## SAtime time(64) + ## + #midi_time: [batch_size*n_frame, n_note, hid_dim] -> [batch_size*n_note, n_frame, hid_dim] + midi_time = midi_freq.reshape([batch_size, self.n_frame, self.n_note, self.hid_dim]).permute(0, 2, 1, 3).contiguous().reshape([batch_size*self.n_note, self.n_frame, self.hid_dim]) + pos_time = torch.arange(0, self.n_frame).unsqueeze(0).repeat(batch_size*self.n_note, 1).to(self.device) + midi_time = self.dropout((midi_time * self.scale_time) + self.pos_embedding_time(pos_time)) + #pos_time = [batch_size*n_note, n_frame] (8*88, 128) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_CAfreq_SAtime(4) pos_time: '+str(pos_time.shape)) + #print('Decoder_CAfreq_SAtime(4) midi_time: '+str(midi_time.shape)) + + for layer_time in self.layers_time: + midi_time = layer_time(midi_time) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_CAfreq_SAtime(5) midi_time: '+str(midi_time.shape)) + + ## output(time) + output_onset_time = self.sigmoid(self.fc_onset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_offset_time = self.sigmoid(self.fc_offset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_mpe_time = self.sigmoid(self.fc_mpe_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_velocity_time = self.fc_velocity_time(midi_time).reshape([batch_size, self.n_note, self.n_frame, self.n_velocity]).permute(0, 2, 1, 3).contiguous() + #output_onset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_time = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_CAfreq_SAtime(6) output_onset_time: '+str(output_onset_time.shape)) + #print('Decoder_CAfreq_SAtime(6) output_offset_time: '+str(output_offset_time.shape)) + #print('Decoder_CAfreq_SAtime(6) output_mpe_time: 
'+str(output_mpe_time.shape)) + #print('Decoder_CAfreq_SAtime(6) output_velocity_time: '+str(output_velocity_time.shape)) + + return output_onset_freq, output_offset_freq, output_mpe_freq, output_velocity_freq, output_onset_time, output_offset_time, output_mpe_time, output_velocity_time + + +## +## Decoder linear/SA(time) +## +class Decoder_linear_SAtime(nn.Module): + def __init__(self, n_frame, n_bin, n_note, n_velocity, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + self.device = device + self.n_note = n_note + self.n_frame = n_frame + self.n_velocity = n_velocity + self.n_bin = n_bin + self.hid_dim = hid_dim + self.sigmoid = nn.Sigmoid() + self.dropout = nn.Dropout(dropout) + + self.fc_convert = nn.Linear(n_bin, n_note) + + self.fc_onset_freq = nn.Linear(hid_dim, 1) + self.fc_offset_freq = nn.Linear(hid_dim, 1) + self.fc_mpe_freq = nn.Linear(hid_dim, 1) + self.fc_velocity_freq = nn.Linear(hid_dim, self.n_velocity) + + # SAtime + self.scale_time = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + self.pos_embedding_time = nn.Embedding(n_frame, hid_dim) + #self.layers_time = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.layers_time = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + + self.fc_onset_time = nn.Linear(hid_dim, 1) + self.fc_offset_time = nn.Linear(hid_dim, 1) + self.fc_mpe_time = nn.Linear(hid_dim, 1) + self.fc_velocity_time = nn.Linear(hid_dim, self.n_velocity) + + def forward(self, enc_spec): + batch_size = enc_spec.shape[0] + enc_spec = enc_spec.permute(0, 1, 3, 2).contiguous().reshape([batch_size*self.n_frame, self.hid_dim, self.n_bin]) + #enc_spec = [batch_size*n_frame, hid_dim, n_bin] (8*128, 256, 256) + #print('Decoder_linear_SAtime(0) enc_spec: '+str(enc_spec.shape)) + + ## + ## linear bin(256)/note(88) + ## + midi_freq = self.fc_convert(enc_spec).permute(0, 2, 1).contiguous() + #midi_freq = [batch_size*n_frame, n_note, hid_dim] (8*128, 88, 256) + #print('Decoder_linear_SAtime(1) midi_freq: '+str(midi_freq.shape)) + + ## output(freq) + output_onset_freq = self.sigmoid(self.fc_onset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_offset_freq = self.sigmoid(self.fc_offset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_mpe_freq = self.sigmoid(self.fc_mpe_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_velocity_freq = self.fc_velocity_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note, self.n_velocity]) + #output_onset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_freq = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_linear_SAtime(2) output_onset_freq: '+str(output_onset_freq.shape)) + #print('Decoder_linear_SAtime(2) output_offset_freq: '+str(output_offset_freq.shape)) + #print('Decoder_linear_SAtime(2) output_mpe_freq: '+str(output_mpe_freq.shape)) + #print('Decoder_linear_SAtime(2) output_velocity_freq: '+str(output_velocity_freq.shape)) + + ## + ## SAtime time(64) + ## + #midi_time: [batch_size*n_frame, n_note, hid_dim] -> [batch_size*n_note, n_frame, hid_dim] + midi_time = midi_freq.reshape([batch_size, self.n_frame, self.n_note, self.hid_dim]).permute(0, 2, 1, 3).contiguous().reshape([batch_size*self.n_note, self.n_frame, self.hid_dim]) + pos_time = 
torch.arange(0, self.n_frame).unsqueeze(0).repeat(batch_size*self.n_note, 1).to(self.device) + midi_time = self.dropout((midi_time * self.scale_time) + self.pos_embedding_time(pos_time)) + #pos_time = [batch_size*n_note, n_frame] (8*88, 128) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_linear_SAtime(3) pos_time: '+str(pos_time.shape)) + #print('Decoder_linear_SAtime(3) midi_time: '+str(midi_time.shape)) + + for layer_time in self.layers_time: + midi_time = layer_time(midi_time) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_linear_SAtime(4) midi_time: '+str(midi_time.shape)) + + ## output(time) + output_onset_time = self.sigmoid(self.fc_onset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_offset_time = self.sigmoid(self.fc_offset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_mpe_time = self.sigmoid(self.fc_mpe_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_velocity_time = self.fc_velocity_time(midi_time).reshape([batch_size, self.n_note, self.n_frame, self.n_velocity]).permute(0, 2, 1, 3).contiguous() + #output_onset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_time = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_linear_SAtime(5) output_onset_time: '+str(output_onset_time.shape)) + #print('Decoder_linear_SAtime(5) output_offset_time: '+str(output_offset_time.shape)) + #print('Decoder_linear_SAtime(5) output_mpe_time: '+str(output_mpe_time.shape)) + #print('Decoder_linear_SAtime(5) output_velocity_time: '+str(output_velocity_time.shape)) + + return output_onset_freq, output_offset_freq, output_mpe_freq, output_velocity_freq, output_onset_time, output_offset_time, output_mpe_time, output_velocity_time + + +## +## sub functions +## +class EncoderLayer(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, src): + #src = [batch_size, src_len, hid_dim] + + #self attention + _src, _ = self.self_attention(src, src, src) + #dropout, residual connection and layer norm + src = self.layer_norm(src + self.dropout(_src)) + #src = [batch_size, src_len, hid_dim] + + #positionwise feedforward + _src = self.positionwise_feedforward(src) + #dropout, residual and layer norm + src = self.layer_norm(src + self.dropout(_src)) + #src = [batch_size, src_len, hid_dim] + + return src + +class DecoderLayer_Zero(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, enc_src, trg): + #trg = [batch_size, trg_len, hid_dim] + #enc_src = [batch_size, src_len, hid_dim] + + #encoder attention + _trg, attention = self.encoder_attention(trg, enc_src, enc_src) + 
#dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #positionwise feedforward + _trg = self.positionwise_feedforward(trg) + #dropout, residual and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + #attention = [batch_size, n_heads, trg_len, src_len] + + return trg, attention + +class DecoderLayer(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, enc_src, trg): + #trg = [batch_size, trg_len, hid_dim] + #enc_src = [batch_size, src_len, hid_dim] + + #self attention + _trg, _ = self.self_attention(trg, trg, trg) + #dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #encoder attention + _trg, attention = self.encoder_attention(trg, enc_src, enc_src) + #dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #positionwise feedforward + _trg = self.positionwise_feedforward(trg) + #dropout, residual and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + #attention = [batch_size, n_heads, trg_len, src_len] + + return trg, attention + +class MultiHeadAttentionLayer(nn.Module): + def __init__(self, hid_dim, n_heads, dropout, device): + super().__init__() + assert hid_dim % n_heads == 0 + self.hid_dim = hid_dim + self.n_heads = n_heads + self.head_dim = hid_dim // n_heads + self.fc_q = nn.Linear(hid_dim, hid_dim) + self.fc_k = nn.Linear(hid_dim, hid_dim) + self.fc_v = nn.Linear(hid_dim, hid_dim) + self.fc_o = nn.Linear(hid_dim, hid_dim) + self.dropout = nn.Dropout(dropout) + self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device) + + def forward(self, query, key, value): + batch_size = query.shape[0] + #query = [batch_size, query_len, hid_dim] + #key = [batch_size, key_len, hid_dim] + #value = [batch_size, value_len, hid_dim] + + Q = self.fc_q(query) + K = self.fc_k(key) + V = self.fc_v(value) + #Q = [batch_size, query_len, hid_dim] + #K = [batch_size, key_len, hid_dim] + #V = [batch_size, value_len, hid_dim] + + Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + #Q = [batch_size, n_heads, query_len, head_dim] + #K = [batch_size, n_heads, key_len, head_dim] + #V = [batch_size, n_heads, value_len, head_dim] + + energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale + #energy = [batch_size, n_heads, seq len, seq len] + + attention = torch.softmax(energy, dim = -1) + #attention = [batch_size, n_heads, query_len, key_len] + + x = torch.matmul(self.dropout(attention), V) + #x = [batch_size, n_heads, seq len, head_dim] + + x = x.permute(0, 2, 1, 3).contiguous() + #x = [batch_size, seq_len, n_heads, head_dim] + + x = x.view(batch_size, -1, self.hid_dim) + #x = [batch_size, seq_len, hid_dim] + + x = self.fc_o(x) + #x = [batch_size, seq_len, hid_dim] + + 
return x, attention + +class PositionwiseFeedforwardLayer(nn.Module): + def __init__(self, hid_dim, pf_dim, dropout): + super().__init__() + self.fc_1 = nn.Linear(hid_dim, pf_dim) + self.fc_2 = nn.Linear(pf_dim, hid_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + #x = [batch_size, seq_len, hid_dim] + + x = self.dropout(torch.relu(self.fc_1(x))) + #x = [batch_size, seq_len, pf dim] + + x = self.fc_2(x) + #x = [batch_size, seq_len, hid_dim] + + return x diff --git a/baselines/hft_transformer/transcribe_new_files.py b/experiments/baselines/hft_transformer/transcribe_new_files.py similarity index 58% rename from baselines/hft_transformer/transcribe_new_files.py rename to experiments/baselines/hft_transformer/transcribe_new_files.py index 594bb44..5ac5a80 100644 --- a/baselines/hft_transformer/transcribe_new_files.py +++ b/experiments/baselines/hft_transformer/transcribe_new_files.py @@ -4,13 +4,17 @@ import json import sys import glob -from baselines.hft_transformer.src import amt -from pydub import AudioSegment -from pydub.exceptions import CouldntDecodeError import random import torch here = os.path.dirname(os.path.abspath(__file__)) - +import sys +sys.path.append(os.path.join(here, 'model')) +import hft_amt as amt +import time +from random import shuffle +sys.path.append(os.path.join(here, '../..')) +import loader_util +from tqdm.auto import tqdm _AMT = None def get_AMT(config_file=None, model_file=None): @@ -33,17 +37,6 @@ def get_AMT(config_file=None, model_file=None): _AMT.model = model return _AMT -def check_and_convert_mp3_to_wav(fname): - wav_file = fname.replace('.mp3', '.wav') - if not os.path.exists(wav_file): - print('converting ' + fname + ' to .wav...') - try: - sound = AudioSegment.from_mp3(fname) - sound.export(fname.replace('.mp3', '.wav'), format="wav") - except CouldntDecodeError: - print('failed to convert ' + fname) - return None - return wav_file def transcribe_file( @@ -59,9 +52,10 @@ def transcribe_file( ): if AMT is None: AMT = get_AMT() - + now_start = time.time() a_feature = AMT.wav2feature(fname) - + print(f'READING ELAPSED TIME: {time.time() - now_start}') + now_read = time.time() # transcript if n_stride > 0: output = AMT.transcript_stride(a_feature, n_stride, mode=mode, ablation_flag=ablation) @@ -69,7 +63,8 @@ def transcribe_file( output = AMT.transcript(a_feature, mode=mode, ablation_flag=ablation) (output_1st_onset, output_1st_offset, output_1st_mpe, output_1st_velocity, output_2nd_onset, output_2nd_offset, output_2nd_mpe, output_2nd_velocity) = output - + print(f'TRANSCRIPTION ELAPSED TIME: {time.time() - now_read}') + print(f'TOTAL ELAPSED TIME: {time.time() - now_start}') # note (mpe2note) a_note_1st_predict = AMT.mpe2note( a_onset=output_1st_onset, @@ -101,15 +96,9 @@ def transcribe_file( if __name__ == '__main__': parser = argparse.ArgumentParser() # necessary arguments - parser.add_argument('-input_dir_to_transcribe', default=None, help='file list') - parser.add_argument('-input_file_to_transcribe', default=None, help='one file') - parser.add_argument('-output_dir', help='output directory') - parser.add_argument('-output_file', default=None, help='output file') + parser = loader_util.add_io_arguments(parser) parser.add_argument('-f_config', help='config json file', default=None) parser.add_argument('-model_file', help='input model file', default=None) - parser.add_argument('-start_index', help='start index', type=int, default=None) - parser.add_argument('-end_index', help='end index', type=int, default=None) - 
parser.add_argument('-skip_transcribe_mp3', action='store_true', default=False) # parameters parser.add_argument('-mode', help='mode to transcript (combination|single)', default='combination') parser.add_argument('-thred_mpe', help='threshold value for mpe detection', type=float, default=0.5) @@ -121,56 +110,23 @@ def transcribe_file( assert (args.input_dir_to_transcribe is not None) or (args.input_file_to_transcribe is not None), "input file or directory is not specified" - if args.input_dir_to_transcribe is not None: - if not args.skip_transcribe_mp3: - # list file - a_mp3s = ( - glob.glob(os.path.join(args.input_dir_to_transcribe, '*.mp3')) + - glob.glob(os.path.join(args.input_dir_to_transcribe, '*', '*.mp3')) - ) - print(f'transcribing {len(a_mp3s)} files: [{str(a_mp3s)}]...') - list(map(check_and_convert_mp3_to_wav, a_mp3s)) - - a_list = ( - glob.glob(os.path.join(args.input_dir_to_transcribe, '*.wav')) + - glob.glob(os.path.join(args.input_dir_to_transcribe, '*', '*.wav')) - ) - if (args.start_index is not None) or (args.end_index is not None): - if args.start_index is None: - args.start_index = 0 - if args.end_index is None: - args.end_index = len(a_list) - a_list = a_list[args.start_index:args.end_index] - # shuffle a_list - random.shuffle(a_list) - - elif args.input_file_to_transcribe is not None: - args.input_file_to_transcribe = check_and_convert_mp3_to_wav(args.input_file_to_transcribe) - if args.input_file_to_transcribe is None: - sys.exit() - a_list = [args.input_file_to_transcribe] - print(f'transcribing {str(a_list)} files...') + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + a_list = loader_util.get_files_to_transcribe(args) # load model AMT = get_AMT(args.f_config, args.model_file) long_filename_counter = 0 - for fname in a_list: - if args.output_file is not None: - output_fname = args.output_file - else: - output_fname = fname.replace('.wav', '') - if len(output_fname) > 200: - output_fname = output_fname[:200] + f'_fnabbrev-{long_filename_counter}' - output_fname += '_transcribed.mid' - output_fname = os.path.join(args.output_dir, os.path.basename(output_fname)) - if os.path.exists(output_fname): - continue - - print('[' + fname + ']') + for input_fname, output_fname in tqdm(a_list): + if os.path.exists(output_fname): + continue + + print(f'transcribing {input_fname} -> {output_fname}') try: transcribe_file( - fname, + input_fname, output_fname, args.mode, args.thred_mpe, @@ -180,6 +136,8 @@ def transcribe_file( args.ablation, AMT, ) + now = time.time() + print(f'ELAPSED TIME: {time.time() - now}') except Exception as e: print(e) continue @@ -193,4 +151,12 @@ def transcribe_file( python evaluation/transcribe_new_files.py \ -input_dir_to_transcribe evaluation/glenn-gould-bach-data \ -output_dir hft-evaluation-data/ \ + +python baselines/hft_transformer/transcribe_new_files.py \ + -input_dir_to_transcribe ../../music-corpora/ \ + -input_files_map other-dataset-splits.csv \ + -split_col_name split \ + -split test \ + -output_dir hft-dtw-evaluation-data/ \ + -file_col_name audio_path """ diff --git a/experiments/baselines/requirements-baselines.txt b/experiments/baselines/requirements-baselines.txt new file mode 100644 index 0000000..17a212d --- /dev/null +++ b/experiments/baselines/requirements-baselines.txt @@ -0,0 +1,17 @@ +# you need python 3.11 for this to work (tensorflow-text dependency) +pretty_midi +librosa==0.9.2 +piano_transcription_inference +# jax[cuda12_local] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +# 
jax[cuda11_local] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +jax[cuda12_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +nest-asyncio +gdown +tensorflow[and-cuda] +gin-config +t5 +wget +note-seq +git+https://github.com/magenta/mt3.git + +# -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ No newline at end of file diff --git a/experiments/loader_util.py b/experiments/loader_util.py new file mode 100644 index 0000000..6a7ef84 --- /dev/null +++ b/experiments/loader_util.py @@ -0,0 +1,139 @@ +import argparse +from pydub import AudioSegment +from pydub.exceptions import CouldntDecodeError +import os +import glob +import random +import sys +import pandas as pd +from more_itertools import unique_everseen +from tqdm.auto import tqdm +from random import shuffle + + +def add_io_arguments(parser: argparse.ArgumentParser): + parser.add_argument('-input_dir_to_transcribe', default=None, help='file list') + parser.add_argument('-input_file_to_transcribe', default=None, help='one file') + + # params for if we're reading file names from a CSV + parser.add_argument('-input_files_map', help='CSV of files to transcribe', default=None) + parser.add_argument('-file_col_name', help='column name for file', default='file') + parser.add_argument('-split', help='split', default=None) + parser.add_argument('-split_col_name', help='column name for split', default='split') + parser.add_argument('-dataset', help='dataset', default=None) + parser.add_argument('-dataset_col_name', help='column name for dataset', default='dataset') + + # some algorithms only take a certain file format (e.g. MP3 or WAV) + parser.add_argument('-input_file_format', default=None, + help='Required input format ["mp3", "wav"]. ' + 'E.g. (I think) hFT only takes in WAV files.' + ) + parser.add_argument('-output_dir', help='output directory') + parser.add_argument('-output_file', default=None, help='output file') + parser.add_argument('-start_index', help='start index', type=int, default=None) + parser.add_argument('-end_index', help='end index', type=int, default=None) + return parser + + +def check_and_convert_between_mp3_and_wav(input_fname, current_fmt='mp3', desired_fmt='wav'): + input_fmt, output_fmt = f'.{current_fmt}', f'.{desired_fmt}' + output_file = input_fname.replace(input_fmt, output_fmt) + if not os.path.exists(input_fname): + print(f'converting {input_fname}: {input_fmt}->{output_fmt}...') + try: + if input_fmt == 'mp3': + sound = AudioSegment.from_mp3(input_fname) + sound.export(output_file, format="wav") + else: + sound = AudioSegment.from_wav(input_fname) + sound.export(output_file, format="mp3") + except CouldntDecodeError: + print('failed to convert ' + input_fname) + return None + return output_file + + +def get_files_to_transcribe(args): + """ + Helper function to get the files to transcribe. + Reads in the files from a CSV, a directory, or a single file. + (if CSV is provided, then the input directory serves to give us a starting-point for the files.) + (otherwise, we just glob all the files in the directory.) + + Returns list of tuples (input_file, output_file). + Output file the same as input file, with "_transcribed.midi". + If no output directory is provided, it is placed in the same directory. + Otherwise, it is placed in the output directory. + The same file hierarchy is maintained. 
+
+    :param args: parsed argparse.Namespace carrying the I/O arguments added by `add_io_arguments`
+    :return: a list of (input_file, output_file) tuples
+    """
+    # get files to transcribe
+
+    # if just one filename is provided, format it as a list
+    if args.input_file_to_transcribe is not None:
+        files_to_transcribe = [args.input_file_to_transcribe]
+
+    # get a list of files from a CSV
+    elif args.input_files_map is not None:
+        files_to_transcribe = pd.read_csv(args.input_files_map)
+        if args.split is not None:
+            files_to_transcribe = files_to_transcribe.loc[lambda df: df[args.split_col_name] == args.split]
+        if args.dataset is not None:
+            files_to_transcribe = files_to_transcribe.loc[lambda df: df[args.dataset_col_name] == args.dataset]
+        files_to_transcribe = files_to_transcribe[args.file_col_name].tolist()
+        if args.input_dir_to_transcribe is not None:
+            files_to_transcribe = list(map(lambda x: os.path.join(args.input_dir_to_transcribe, x), files_to_transcribe))
+
+    # get all files in a directory
+    elif args.input_dir_to_transcribe is not None:
+        files_to_transcribe = (
+            glob.glob(os.path.join(args.input_dir_to_transcribe, '**', '*.mp3'), recursive=True) +
+            glob.glob(os.path.join(args.input_dir_to_transcribe, '**', '*.wav'), recursive=True)
+        )
+
+    # convert file-types
+    if args.input_file_format is not None:
+        # make sure all the files of mp3 are converted to wav, or v.v.
+        other_fmt = 'mp3' if args.input_file_format == 'wav' else 'wav'
+        files_to_convert = list(filter(lambda x: os.path.splitext(x)[1] == f'.{other_fmt}', files_to_transcribe))
+        print(f'converting {len(files_to_convert)} files...')
+        for f in files_to_convert:
+            check_and_convert_between_mp3_and_wav(f, current_fmt=other_fmt, desired_fmt=args.input_file_format)
+    else:
+        # input format doesn't matter, so we just want 1 of each
+        files_to_transcribe = list(unique_everseen(files_to_transcribe, key=lambda x: os.path.splitext(x)[0]))
+
+    # apply cutoffs
+    if (args.start_index is not None) or (args.end_index is not None):
+        if args.start_index is None:
+            args.start_index = 0
+        if args.end_index is None:
+            args.end_index = len(files_to_transcribe)
+        files_to_transcribe = files_to_transcribe[args.start_index:args.end_index]
+
+    # format output
+    if args.output_file is not None:
+        os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
+        return [(files_to_transcribe[0], args.output_file)]
+
+    # if the output directory is not provided, then we just put the output files in the same directory
+    # otherwise, we output to the output directory, preserving the hierarchy of the original files.
+    output_files = list(map(lambda x: f"{os.path.splitext(x)[0]}_transcribed.midi", files_to_transcribe))
+    if args.output_dir is not None:
+        if args.input_dir_to_transcribe is not None:
+            output_files = list(map(lambda x: x[len(args.input_dir_to_transcribe):], output_files))
+        output_files = list(map(lambda x: os.path.join(args.output_dir, x), output_files))
+        for o in output_files:
+            os.makedirs(os.path.dirname(o), exist_ok=True)
+
+    # shuffle, and drop files whose transcription already exists
+    output = list(zip(files_to_transcribe, output_files))
+    output = list(filter(lambda x: not os.path.exists(x[1]), output))
+    random.shuffle(output)
+    return output
+
+
diff --git a/experiments/process_input_files.py b/experiments/process_input_files.py
new file mode 100644
index 0000000..0d11b53
--- /dev/null
+++ b/experiments/process_input_files.py
@@ -0,0 +1,96 @@
+"""
+Helper script to fetch and augment the MAESTRO test split according to the `-augmentation_config` and `-apply_augmentation` flags.
+""" +import pandas as pd +import os +import shutil +from amt.audio import AudioTransform, pad_or_trim +from amt.data import get_wav_mid_segments, load_config +import json +import librosa +import torch +import torchaudio +from tqdm.auto import tqdm + +SAMPLE_RATE = load_config()['audio']['sample_rate'] +AUG_BATCH_SIZE = 100 + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-split', type=str, required=True, default='test', help='Split to print out.') + parser.add_argument('-dataset', type=str, default=None, help='Dataset to use.') + parser.add_argument('-input_file_dir', type=str, default=None, help='Directory of the dataset to use.') + parser.add_argument( + '-input_splits_file', + type=str, + required=True, + help='Directory of the MAESTRO dataset.' + ) + parser.add_argument('-midi_col_name', type=str, default=None, help='Column name for MIDI files.') + parser.add_argument('-audio_col_name', type=str, default=None, help='Column name for audio files.') + parser.add_argument('-output_dir', type=str, required=True, help='Output directory.') + parser.add_argument('-apply_augmentation', action='store_true', default=False, help='Apply augmentation to the files.') + parser.add_argument('-augmentation_config', type=str, default=None, help='Path to the augmentation config file.') + parser.add_argument('-device', type=str, default='cpu', help='Device to use.') + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + audio_transformer = None + if args.apply_augmentation: + aug_config = json.load(open(args.augmentation_config)) + audio_transformer = AudioTransform(**aug_config).to(args.device) + + # Load the split + input_files_to_process = pd.read_csv(args.input_splits_file) + if args.split is not None: + input_files_to_process = input_files_to_process.loc[lambda df: df['split'] == args.split] + if args.dataset is not None: + input_files_to_process = input_files_to_process.loc[lambda df: df['dataset'] == args.dataset] + + # Process the files + for _, row in tqdm( + input_files_to_process.iterrows(), + total=len(input_files_to_process), + desc=f'Processing {args.split} split' + ): + # copy MIDI file into the output directory + if args.midi_col_name is not None: + midi_outfile = os.path.basename(row[args.midi_col_name]) + fname, ext = os.path.splitext(midi_outfile) + midi_outfile = f'{fname}_gold{ext}' + midi_outfile = os.path.join(args.output_dir, midi_outfile) + if not os.path.exists(midi_outfile): + shutil.copy( + os.path.join(args.input_file_dir, row['midi_filename']), + os.path.join(args.output_dir, midi_outfile) + ) + + # either just vanilla copy the audio file, or apply augmentation + if args.audio_col_name is not None: + audio_outfile = os.path.basename(row[args.audio_col_name]) + audio_outfile = os.path.join(args.output_dir, audio_outfile) + audio_input_file = os.path.join(args.input_file_dir, row[args.audio_col_name]) + if not os.path.exists(audio_outfile): + if args.apply_augmentation: + try: + segments = get_wav_mid_segments(audio_input_file) + segments = list(map(lambda x: x[0], segments)) + aug_wav_parts = [] + for i in range(0, len(segments), AUG_BATCH_SIZE): + batch_to_augment = torch.vstack(segments[i:i + AUG_BATCH_SIZE]).to(args.device) + mel = audio_transformer(batch_to_augment) + aug_wav = audio_transformer.inverse_log_mel(mel) + aug_wav_parts.append(aug_wav) + aug_wav = torch.vstack(aug_wav_parts) + aug_wav = aug_wav.reshape(1, -1).cpu() + torchaudio.save(audio_outfile, 
src=aug_wav, sample_rate=SAMPLE_RATE) + except Exception as e: + print(f'Failed to augment {audio_input_file}: {e}') + else: + shutil.copy( + os.path.join(audio_input_file), + os.path.join(args.output_dir, audio_outfile) + ) diff --git a/experiments/run_dtw_transcription.sh b/experiments/run_dtw_transcription.sh new file mode 100644 index 0000000..45c5a5a --- /dev/null +++ b/experiments/run_dtw_transcription.sh @@ -0,0 +1,36 @@ +#!/bin/sh +#SBATCH --output=dtw_transcription__%x.%j.out +#SBATCH --error=dtw_transcription__%x.%j.err +#SBATCH -N 1 +#SBATCH -n 1 +#SBATCH --time=24:00:00 +#SBATCH --gres=gpu:1 +#SBATCH --mem-per-gpu=100GB +#SBATCH --cpus-per-gpu=20 +#SBATCH --partition=isi + + +python baselines/hft_transformer/transcribe_new_files.py \ + -input_dir_to_transcribe ../../music-corpora/ \ + -input_files_map other-dataset-splits.csv \ + -split_col_name split \ + -split test \ + -output_dir hft_transformer-evaluation-data/ \ + -file_col_name audio_path + +python baselines/giantmidi/transcribe_new_files.py \ + -input_dir_to_transcribe ../../music-corpora/ \ + -input_files_map other-dataset-splits.csv \ + -split_col_name split \ + -split test \ + -output_dir giantmidi-evaluation-data/ \ + -file_col_name audio_path + +conda activate py311 +python baselines/google_t5/transcribe_new_files.py \ + -input_dir_to_transcribe ../../music-corpora/ \ + -input_files_map other-dataset-splits.csv \ + -split_col_name split \ + -split test \ + -output_dir google-evaluation-data/ \ + -file_col_name audio_path \ No newline at end of file diff --git a/experiments/run_maestro_aug_1.sh b/experiments/run_maestro_aug_1.sh new file mode 100644 index 0000000..914e554 --- /dev/null +++ b/experiments/run_maestro_aug_1.sh @@ -0,0 +1,50 @@ +#!/bin/sh +#SBATCH --output=aug_1__%x.%j.out +#SBATCH --error=aug_1__%x.%j.err +#SBATCH -N 1 +#SBATCH -n 1 +#SBATCH --time=24:00:00 +#SBATCH --gres=gpu:1 +#SBATCH --mem-per-gpu=100GB +#SBATCH --cpus-per-gpu=20 +#SBATCH --partition=isi + +conda activate py311 +PROJ_DIR=/project/jonmay_231/spangher/Projects/aria-amt +OUTPUT_DIR="$PROJ_DIR/experiments/aug_1_files" + +# process data +if [ ! -d "$OUTPUT_DIR" ]; then + python process_maestro.py \ + -split test \ + -maestro_dir "$PROJ_DIR/../maestro-v3.0.0/maestro-v3.0.0.csv" \ + -output_dir $OUTPUT_DIR \ + -split test \ + -midi_col_name 'midi_filename' \ + -audio_col_name 'audio_filename' \ + -apply_augmentation \ + -augmentation_config "$PROJ_DIR/experiments/augmentation_configs/config_2.json" \ + -device 'cuda:0' +fi + +# run google inference +echo "Running google inference" +GOOGLE_OUTPUT_DIR="$OUTPUT_DIR/google_t5_transcriptions" +#if [ ! 
-d "$GOOGLE_OUTPUT_DIR" ]; then +python baselines/google_t5/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $GOOGLE_OUTPUT_DIR +#fi + +echo "Running giant midi inference" +GIANT_MIDI_OUTPUT_DIR="$OUTPUT_DIR/giant_midi_transcriptions" +python baselines/giantmidi/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $GIANT_MIDI_OUTPUT_DIR + +echo "Running hft inference" +HFT_OUTPUT_DIR="$OUTPUT_DIR/hft_transcriptions" +conda activate py311 +python baselines/hft_transformer/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $HFT_OUTPUT_DIR \ No newline at end of file diff --git a/experiments/run_maestro_aug_2.sh b/experiments/run_maestro_aug_2.sh new file mode 100644 index 0000000..b579c3e --- /dev/null +++ b/experiments/run_maestro_aug_2.sh @@ -0,0 +1,52 @@ +#!/bin/sh +#SBATCH --output=aug_2__%x.%j.out +#SBATCH --error=aug_2__%x.%j.err +#SBATCH -N 1 +#SBATCH -n 1 +#SBATCH --time=24:00:00 +#SBATCH --gres=gpu:1 +#SBATCH --mem-per-gpu=100GB +#SBATCH --cpus-per-gpu=20 +#SBATCH --partition=isi + +conda activate py311 +PROJ_DIR=/project/jonmay_231/spangher/Projects/aria-amt +OUTPUT_DIR="$PROJ_DIR/experiments/aug_2_files" + +# process data +if [ ! -d "$OUTPUT_DIR" ]; then + python process_maestro.py \ + -split test \ + -maestro_dir "$PROJ_DIR/../maestro-v3.0.0/maestro-v3.0.0.csv" \ + -output_dir $OUTPUT_DIR \ + -split test \ + -midi_col_name 'midi_filename' \ + -audio_col_name 'audio_filename' \ + -apply_augmentation \ + -augmentation_config "$PROJ_DIR/experiments/augmentation_configs/config_2.json" \ + -device 'cuda:0' +fi + +source /home1/${USER}/.bashrc +conda activate py311 + +## run google inference +#echo "Running google inference" +#GOOGLE_OUTPUT_DIR="$OUTPUT_DIR/google_t5_transcriptions" +##if [ ! -d "$GOOGLE_OUTPUT_DIR" ]; then +#python baselines/google_t5/transcribe_new_files.py \ +# -input_dir_to_transcribe $OUTPUT_DIR \ +# -output_dir $GOOGLE_OUTPUT_DIR +##fi +# +#echo "Running giant midi inference" +#GIANT_MIDI_OUTPUT_DIR="$OUTPUT_DIR/giant_midi_transcriptions" +#python baselines/giantmidi/transcribe_new_files.py \ +# -input_dir_to_transcribe $OUTPUT_DIR \ +# -output_dir $GIANT_MIDI_OUTPUT_DIR + +echo "Running hft inference" +HFT_OUTPUT_DIR="$OUTPUT_DIR/hft_transcriptions" +python baselines/hft_transformer/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $HFT_OUTPUT_DIR \ No newline at end of file diff --git a/experiments/run_maestro_vanilla.sh b/experiments/run_maestro_vanilla.sh new file mode 100644 index 0000000..b45f503 --- /dev/null +++ b/experiments/run_maestro_vanilla.sh @@ -0,0 +1,44 @@ +#!/bin/sh +#SBATCH --output=vanilla__%x.%j.out +#SBATCH --error=vanilla__%x.%j.err +#SBATCH -N 1 +#SBATCH -n 1 +#SBATCH --time=24:00:00 +#SBATCH --gres=gpu:1 +#SBATCH --mem-per-gpu=100GB +#SBATCH --cpus-per-gpu=20 +#SBATCH --partition=isi + +source /home1/${USER}/.bashrc +conda activate py311 + +PROJ_DIR=/project/jonmay_231/spangher/Projects/aria-amt +OUTPUT_DIR="$PROJ_DIR/experiments/vanilla_files" + +if [ ! 
-d "$OUTPUT_DIR" ]; then + python process_maestro.py \ + -split test \ + -maestro_dir "$PROJ_DIR/../maestro-v3.0.0/maestro-v3.0.0.csv" \ + -output_dir $OUTPUT_DIR \ + -midi_col_name 'midi_filename' \ + -audio_col_name 'audio_filename' +fi + +# run google inference +echo "Running google inference" +GOOGLE_OUTPUT_DIR="$OUTPUT_DIR/google_t5_transcriptions" +python baselines/google_t5/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $GOOGLE_OUTPUT_DIR + +echo "Running giant midi inference" +GIANT_MIDI_OUTPUT_DIR="$OUTPUT_DIR/giant_midi_transcriptions" +python baselines/giantmidi/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $GIANT_MIDI_OUTPUT_DIR + +echo "Running hft inference" +HFT_OUTPUT_DIR="$OUTPUT_DIR/hft_transcriptions" +python baselines/hft_transformer/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $HFT_OUTPUT_DIR \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 696fb40..78daa70 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,5 @@ mido tqdm orjson mir_eval +pyfluidsynth @ git+https://github.com/nwhitehead/pyfluidsynth.git +midi2audio