From e6f12f0cac68f3b9ca63e3720db5b713097b9828 Mon Sep 17 00:00:00 2001 From: Alex Spangher Date: Tue, 2 Apr 2024 22:50:03 -0700 Subject: [PATCH 1/3] updated --- .gitignore | 3 +- amt/audio.py | 102 ++++-- baselines/requirements-baselines.txt | 3 - .../giantmidi/transcribe_new_files.py | 11 +- .../baselines/google_t5/test-output-file.midi | Bin 0 -> 16389 bytes .../google_t5/transcribe_new_files.py | 316 ++++++++++++++++++ .../baselines}/hft_transformer/src/amt.py | 0 .../hft_transformer/transcribe_new_files.py | 2 +- .../baselines/requirements-baselines.txt | 14 + experiments/get_maestro_split.py | 45 +++ experiments/run_maestro_vanilla.sh | 0 11 files changed, 465 insertions(+), 31 deletions(-) delete mode 100644 baselines/requirements-baselines.txt rename {baselines => experiments/baselines}/giantmidi/transcribe_new_files.py (78%) create mode 100644 experiments/baselines/google_t5/test-output-file.midi create mode 100644 experiments/baselines/google_t5/transcribe_new_files.py rename {baselines => experiments/baselines}/hft_transformer/src/amt.py (100%) rename {baselines => experiments/baselines}/hft_transformer/transcribe_new_files.py (99%) create mode 100644 experiments/baselines/requirements-baselines.txt create mode 100644 experiments/get_maestro_split.py create mode 100644 experiments/run_maestro_vanilla.sh diff --git a/.gitignore b/.gitignore index 026b7af..ab9832b 100644 --- a/.gitignore +++ b/.gitignore @@ -19,7 +19,8 @@ .idea/ notebooks/scratch -baselines/hft_transformer/model_files/ +experiments/baselines/hft_transformer/model_files/ +experiments/baselines/google_t5/model_files/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/amt/audio.py b/amt/audio.py index 6c37f3f..797926d 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -1,5 +1,5 @@ """Contains code taken from https://github.com/openai/whisper""" - +import functools import os import random import torch @@ -200,6 +200,7 @@ def __init__( detune_ratio: float = 0.1, detune_max_shift: float = 0.15, spec_aug_ratio: float = 0.5, + max_num_transforms: int = None, # currently we're doing 8 different transformations ): super().__init__() self.tokenizer = AmtTokenizer() @@ -224,6 +225,14 @@ def __init__( self.detune_max_shift = detune_max_shift self.spec_aug_ratio = spec_aug_ratio self.reduction_resample_rate = 6000 # Hardcoded? + # the following two variables, `self.t_count` and `self.max_num_transforms` + # are state variables that track the # of transformations applied. + # `self.t_count` is set in `forward` method to 0 + # `t_count` can also be passed into the following methods: `distortion_aug_cpu`, `log_mel`, `aug_wav`, + # the methods that we're stochastically applying transformations. + # a little messy/stateful, but helps the code be backwards compatible. + self.t_count = None + self.max_num_transforms = max_num_transforms # Audio aug impulse_paths = self._get_paths( @@ -270,6 +279,34 @@ def __init__( ), ) + # inverse mel transform + self.inverse_mel = torchaudio.transforms.InverseMelScale( + n_mels=self.config["n_mels"], + sample_rate=self.config["sample_rate"], + n_stft=self.config["n_fft"] // 2 + 1, + ) + self.inverse_spec_transform = torchaudio.transforms.GriffinLim( + n_fft=self.config["n_fft"], + hop_length=self.config["hop_len"], + ) + + def check_apply_transform(self, ratio: float): + """ + Check if a transformation should be applied based on the ratio and the + number of transformations already applied. 
+ """ + + if ( + (self.max_num_transforms is not None) and + (self.t_count is not None) and + (self.t_count >= self.max_num_transforms) + ): + return False + apply_transform = random.random() < ratio + if apply_transform: + self.t_count += 1 + return apply_transform + def get_params(self): return { "noise_ratio": self.noise_ratio, @@ -403,13 +440,16 @@ def apply_distortion(self, wav: torch.tensor): return AF.overdrive(wav, gain=gain, colour=colour) - def distortion_aug_cpu(self, wav: torch.Tensor): + def distortion_aug_cpu(self, wav: torch.Tensor, t_count: int = None): # This function should run on the cpu (i.e. in the dataloader collate # function) in order to not be a bottlekneck + if t_count is not None: + self.t_count = t_count - if random.random() < self.reduce_ratio: + if self.check_apply_transform(self.reduce_ratio): wav = self.apply_reduction(wav) - if random.random() < self.distort_ratio: + + if self.check_apply_transform(self.distort_ratio): wav = self.apply_distortion(wav) return wav @@ -440,34 +480,34 @@ def shift_spec(self, specs: torch.Tensor, shift: int | float): return shifted_specs def detune_spec(self, specs: torch.Tensor): - if random.random() < self.detune_ratio: - detune_shift = random.uniform( - -self.detune_max_shift, self.detune_max_shift - ) - detuned_specs = self.shift_spec(specs, shift=detune_shift) + detune_shift = random.uniform( + -self.detune_max_shift, self.detune_max_shift + ) + detuned_specs = self.shift_spec(specs, shift=detune_shift) - return (specs + detuned_specs) / 2 - else: - return specs + specs = (specs + detuned_specs) / 2 + return specs - def aug_wav(self, wav: torch.Tensor): + def aug_wav(self, wav: torch.Tensor, t_count: int = None): # This function doesn't apply distortion. If distortion is desired it # should be run beforehand on the cpu with distortion_aug_cpu. Note # also that detuning is done to the spectrogram in log_mel, not the wav. + if t_count is not None: + self.t_count = t_count # Noise - if random.random() < self.noise_ratio: + if self.check_apply_transform(self.noise_ratio): wav = self.apply_noise(wav) - if random.random() < self.applause_ratio: + if self.check_apply_transform(self.applause_ratio): wav = self.apply_applause(wav) # Reverb - if random.random() < self.reverb_ratio: + if self.check_apply_transform(self.reverb_ratio): wav = self.apply_reverb(wav) # EQ - if random.random() < self.bandpass_ratio: + if self.check_apply_transform(self.bandpass_ratio): wav = self.apply_bandpass(wav) return wav @@ -482,15 +522,25 @@ def norm_mel(self, mel_spec: torch.Tensor): return log_spec def log_mel( - self, wav: torch.Tensor, shift: int | None = None, detune: bool = False + self, + wav: torch.Tensor, + shift: int | None = None, + detune: bool = False, + t_count: int = None, ): + if t_count is not None: + self.t_count = t_count + spec = self.spec_transform(wav)[..., :-1] + # check: are detune and shift mutually exclusive? + # should we also put a ratio on shift? if shift is not None and shift != 0: spec = self.shift_spec(spec, shift) elif detune is True: - # Don't detune and spec shift at the same time - spec = self.detune_spec(spec) + if self.check_apply_transform(self.detune_ratio): + # Don't detune and spec shift at the same time + spec = self.detune_spec(spec) mel_spec = self.mel_transform(spec) @@ -499,15 +549,25 @@ def log_mel( return log_spec + def inverse_log_mel(self, mel: torch.Tensor): + """ + Takes as input a log mel spectrogram and returns the corresponding audio. 
+ """ + mel = (4 * mel) - 4 + mel = torch.pow(10, mel) + mel = self.inverse_mel(mel) + return self.inverse_spec_transform(mel) + def forward(self, wav: torch.Tensor, shift: int = 0): # Noise, and reverb + self.t_count = 0 wav = self.aug_wav(wav) # Spec, detuning & pitch shift log_mel = self.log_mel(wav, shift, detune=True) # Spec aug - if random.random() < self.spec_aug_ratio: + if self.check_apply_transform(self.spec_aug_ratio): log_mel = self.spec_aug(log_mel) return log_mel diff --git a/baselines/requirements-baselines.txt b/baselines/requirements-baselines.txt deleted file mode 100644 index b56d966..0000000 --- a/baselines/requirements-baselines.txt +++ /dev/null @@ -1,3 +0,0 @@ -pretty_midi -librosa -piano_transcription_inference diff --git a/baselines/giantmidi/transcribe_new_files.py b/experiments/baselines/giantmidi/transcribe_new_files.py similarity index 78% rename from baselines/giantmidi/transcribe_new_files.py rename to experiments/baselines/giantmidi/transcribe_new_files.py index 0650c73..1fc6929 100644 --- a/baselines/giantmidi/transcribe_new_files.py +++ b/experiments/baselines/giantmidi/transcribe_new_files.py @@ -39,8 +39,10 @@ def transcribe_piano(mp3s_dir, midis_dir, begin_index=None, end_index=None): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Example of parser. ') - parser.add_argument('--mp3s_dir', type=str, required=True, help='') - parser.add_argument('--midis_dir', type=str, required=True, help='') + parser.add_argument('-input_dir_to_transcribe', default=None, help='file list') + parser.add_argument('-input_file_to_transcribe', default=None, help='one file') + parser.add_argument('-output_dir', help='output directory') + parser.add_argument('-output_file', default=None, help='output file') parser.add_argument( '--begin_index', type=int, required=False, help='File num., of an ordered list of files, to start transcribing from.', default=None @@ -61,7 +63,6 @@ def transcribe_piano(mp3s_dir, midis_dir, begin_index=None, end_index=None): """ python transcribe_new_files.py \ - transcribe_piano \ - --mp3s_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data \ - --midis_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data/kong-model + --input_dir_to_transcribe /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data \ + --output_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data/kong-model """ \ No newline at end of file diff --git a/experiments/baselines/google_t5/test-output-file.midi b/experiments/baselines/google_t5/test-output-file.midi new file mode 100644 index 0000000000000000000000000000000000000000..a0814f7d424732ab99e5aa62098360fa7bae1c1f GIT binary patch literal 16389 zcmZvDS#Kj*mfo51y*2}I$RKBi?xwnuSxJ@1ltSEs!B`jzV_~dZLKG!YN-3AF&O&t? 
[... remainder of the base85-encoded GIT binary patch for experiments/baselines/google_t5/test-output-file.midi omitted ...]

Date: Tue, 9 Apr 2024 13:20:41 -0700
Subject: [PATCH 2/3] updated

---
 .gitignore | 1 +
 .../giantmidi/transcribe_new_files.py | 68 +-
 .../baselines/google_t5/test-output-file.midi | Bin 16389 -> 0 bytes
 .../google_t5/transcribe_new_files.py | 51 +-
 .../{src/amt.py => model/hft_amt.py} | 0
 .../hft_transformer/model/model_spec2midi.py | 378 ++++
 .../model/model_spec2midi_ablation.py | 707 ++++++++++++++++++
 .../hft_transformer/transcribe_new_files.py | 102 +--
.../baselines/requirements-baselines.txt | 9 +- experiments/get_maestro_split.py | 45 -- experiments/loader_util.py | 139 ++++ experiments/process_input_files.py | 96 +++ experiments/run_dtw_transcription.sh | 36 + experiments/run_maestro_aug_1.sh | 50 ++ experiments/run_maestro_aug_2.sh | 52 ++ experiments/run_maestro_vanilla.sh | 44 ++ 16 files changed, 1591 insertions(+), 187 deletions(-) delete mode 100644 experiments/baselines/google_t5/test-output-file.midi rename experiments/baselines/hft_transformer/{src/amt.py => model/hft_amt.py} (100%) create mode 100644 experiments/baselines/hft_transformer/model/model_spec2midi.py create mode 100644 experiments/baselines/hft_transformer/model/model_spec2midi_ablation.py delete mode 100644 experiments/get_maestro_split.py create mode 100644 experiments/loader_util.py create mode 100644 experiments/process_input_files.py create mode 100644 experiments/run_dtw_transcription.sh create mode 100644 experiments/run_maestro_aug_1.sh create mode 100644 experiments/run_maestro_aug_2.sh diff --git a/.gitignore b/.gitignore index ab9832b..bfc1814 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ notebooks/scratch experiments/baselines/hft_transformer/model_files/ experiments/baselines/google_t5/model_files/ +experiments/aria-amt-intermediate-transcribed-data # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/experiments/baselines/giantmidi/transcribe_new_files.py b/experiments/baselines/giantmidi/transcribe_new_files.py index 1fc6929..ee986dd 100644 --- a/experiments/baselines/giantmidi/transcribe_new_files.py +++ b/experiments/baselines/giantmidi/transcribe_new_files.py @@ -4,63 +4,45 @@ import torch import piano_transcription_inference import glob +from more_itertools import unique_everseen +from tqdm.auto import tqdm +from random import shuffle +import sys +here = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.join(here, '../..')) +import loader_util -def transcribe_piano(mp3s_dir, midis_dir, begin_index=None, end_index=None): - """Transcribe piano solo mp3s to midi files.""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - os.makedirs(midis_dir, exist_ok=True) +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Example of parser. 
') + parser = loader_util.add_io_arguments(parser) + args = parser.parse_args() - # Transcriptor + files_to_transcribe = loader_util.get_files_to_transcribe(args) + + device = 'cuda' if torch.cuda.is_available() else 'cpu' transcriptor = piano_transcription_inference.PianoTranscription(device=device) - transcribe_time = time.time() - for n, mp3_path in enumerate(glob.glob(os.path.join(mp3s_dir, '*.mp3'))[begin_index:end_index]): - print(n, mp3_path) - midi_file = os.path.basename(mp3_path).replace('.mp3', '.midi') - midi_path = os.path.join(midis_dir, midi_file) - if os.path.exists(midi_path): + # Transcriptor + for n, (input_fname, output_fname) in tqdm(enumerate(files_to_transcribe), total=len(files_to_transcribe)): + if os.path.exists(output_fname): continue - (audio, _) = ( - piano_transcription_inference - .load_audio(mp3_path, sr=piano_transcription_inference.sample_rate, mono=True) - ) - + now_start = time.time() + (audio, _) = (piano_transcription_inference + .load_audio(input_fname, sr=piano_transcription_inference.sample_rate, mono=True)) + print(f'READING ELAPSED TIME: {time.time() - now_start}') + now_read = time.time() try: # Transcribe - transcribed_dict = transcriptor.transcribe(audio, midi_path) - print(transcribed_dict) + transcribed_dict = transcriptor.transcribe(audio, output_fname) except: print('Failed for this audio!') + print(f'TRANSCRIPTION ELAPSED TIME: {time.time() - now_read}') + print(f'TOTAL ELAPSED TIME: {time.time() - now_start}') - print('Time: {:.3f} s'.format(time.time() - transcribe_time)) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Example of parser. ') - parser.add_argument('-input_dir_to_transcribe', default=None, help='file list') - parser.add_argument('-input_file_to_transcribe', default=None, help='one file') - parser.add_argument('-output_dir', help='output directory') - parser.add_argument('-output_file', default=None, help='output file') - parser.add_argument( - '--begin_index', type=int, required=False, - help='File num., of an ordered list of files, to start transcribing from.', default=None - ) - parser.add_argument( - '--end_index', type=int, required=False, default=None, - help='File num., of an ordered list of files, to end transcription.' - ) - - # Parse arguments - args = parser.parse_args() - transcribe_piano( - mp3s_dir=args.mp3s_dir, - midis_dir=args.midis_dir, - begin_index=args.begin_index, - end_index=args.end_index - ) - """ python transcribe_new_files.py \ --input_dir_to_transcribe /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data \ diff --git a/experiments/baselines/google_t5/test-output-file.midi b/experiments/baselines/google_t5/test-output-file.midi deleted file mode 100644 index a0814f7d424732ab99e5aa62098360fa7bae1c1f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16389 zcmZvDS#Kj*mfo51y*2}I$RKBi?xwnuSxJ@1ltSEs!B`jzV_~dZLKG!YN-3AF&O&t? 
[... duplicate base85-encoded binary data for the deleted test-output-file.midi omitted ...]
{output_fname}...')
+ if os.path.exists(output_fname): continue
- audio = load_audio(mp3_path, sample_rate=args.sample_rate)
+ now_start = time.time()
+ audio = load_audio(input_fname, sample_rate=args.sample_rate)
+ print(f'READING ELAPSED TIME: {time.time() - now_start}')
+ now_read = time.time()
  est_ns = inference_model(audio)
- note_seq.sequence_proto_to_midi_file(est_ns, midi_path)
+ print(f'TRANSCRIPTION ELAPSED TIME: {time.time() - now_read}')
+ print(f'TOTAL
ELAPSED TIME: {time.time() - now_start}') + note_seq.sequence_proto_to_midi_file(est_ns, output_fname) """ diff --git a/experiments/baselines/hft_transformer/src/amt.py b/experiments/baselines/hft_transformer/model/hft_amt.py similarity index 100% rename from experiments/baselines/hft_transformer/src/amt.py rename to experiments/baselines/hft_transformer/model/hft_amt.py diff --git a/experiments/baselines/hft_transformer/model/model_spec2midi.py b/experiments/baselines/hft_transformer/model/model_spec2midi.py new file mode 100644 index 0000000..9555568 --- /dev/null +++ b/experiments/baselines/hft_transformer/model/model_spec2midi.py @@ -0,0 +1,378 @@ +#! python + +import torch +import torch.nn as nn + +## +## Model +## +class Model_SPEC2MIDI(nn.Module): + def __init__(self, encoder, decoder): + super().__init__() + self.encoder_spec2midi = encoder + self.decoder_spec2midi = decoder + + def forward(self, input_spec): + #input_spec = [batch_size, n_bin, margin+n_frame+margin] (8, 256, 192) + #print('Model_SPEC2MIDI(0) input_spec: '+str(input_spec.shape)) + + enc_vector = self.encoder_spec2midi(input_spec) + #enc_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Model_SPEC2MIDI(1) enc_vector: '+str(enc_vector.shape)) + + output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, attention, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B = self.decoder_spec2midi(enc_vector) + #output_onset_A = [batch_size, n_frame, n_note] (8, 128, 88) + #output_onset_B = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_A = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #output_velocity_B = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #attention = [batch_size, n_frame, n_heads, n_note, n_bin] (8, 128, 4, 88, 256) + #print('Model_SPEC2MIDI(2) output_onset_A: '+str(output_onset_A.shape)) + #print('Model_SPEC2MIDI(2) output_onset_B: '+str(output_onset_B.shape)) + #print('Model_SPEC2MIDI(2) output_velocity_A: '+str(output_velocity_A.shape)) + #print('Model_SPEC2MIDI(2) output_velocity_B: '+str(output_velocity_B.shape)) + #print('Model_SPEC2MIDI(2) attention: '+str(attention.shape)) + + return output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, attention, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B + + +## +## Encoder +## +class Encoder_SPEC2MIDI(nn.Module): + def __init__(self, n_margin, n_frame, n_bin, cnn_channel, cnn_kernel, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + + self.device = device + self.n_frame = n_frame + self.n_bin = n_bin + self.cnn_channel = cnn_channel + self.cnn_kernel = cnn_kernel + self.hid_dim = hid_dim + self.conv = nn.Conv2d(1, self.cnn_channel, kernel_size=(1, self.cnn_kernel)) + self.n_proc = n_margin * 2 + 1 + self.cnn_dim = self.cnn_channel * (self.n_proc - (self.cnn_kernel - 1)) + self.tok_embedding_freq = nn.Linear(self.cnn_dim, hid_dim) + self.pos_embedding_freq = nn.Embedding(n_bin, hid_dim) + self.layers_freq = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.dropout = nn.Dropout(dropout) + self.scale_freq = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + + def forward(self, spec_in): + #spec_in = [batch_size, n_bin, n_margin+n_frame+n_margin] (8, 256, 192) (batch_size=8, n_bins=256, margin=32/n_frame=128) + #print('Encoder_SPEC2MIDI(0) spec_in: '+str(spec_in.shape)) + batch_size = spec_in.shape[0] + + spec = spec_in.unfold(2, self.n_proc, 1).permute(0, 2, 1, 
3).contiguous() + #spec = [batch_size, n_frame, n_bin, n_proc] (8, 128, 256, 65) (batch_size=8, n_frame=128, n_bins=256, n_proc=65) + #print('Encoder_SPEC2MIDI(1) spec: '+str(spec.shape)) + + # CNN 1D + spec_cnn = spec.reshape(batch_size*self.n_frame, self.n_bin, self.n_proc).unsqueeze(1) + #spec = [batch_size*n_frame, 1, n_bin, n_proc] (8*128, 1, 256, 65) (batch_size=128, 1, n_frame, n_bins=256, n_proc=65) + #print('Encoder_SPEC2MIDI(2) spec_cnn: '+str(spec_cnn.shape)) + spec_cnn = self.conv(spec_cnn).permute(0, 2, 1, 3).contiguous() + # spec_cnn: [batch_size*n_frame, n_bin, cnn_channel, n_proc-(cnn_kernel-1)] (8*128, 256, 4, 61) + #print('Encoder_SPEC2MIDI(2) spec_cnn: '+str(spec_cnn.shape)) + + ## + ## frequency + ## + spec_cnn_freq = spec_cnn.reshape(batch_size*self.n_frame, self.n_bin, self.cnn_dim) + # spec_cnn_freq: [batch_size*n_frame, n_bin, cnn_channel, (n_proc)-(cnn_kernel-1)] (8*128, 256, 244) + #print('Encoder_SPEC2MIDI(3) spec_cnn_freq: '+str(spec_cnn_freq.shape)) + + # embedding + spec_emb_freq = self.tok_embedding_freq(spec_cnn_freq) + # spec_emb_freq: [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_SPEC2MIDI(4) spec_emb_freq: '+str(spec_emb_freq.shape)) + + # position coding + pos_freq = torch.arange(0, self.n_bin).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + #pos_freq = [batch_size, n_frame, n_bin] (8*128, 256) + #print('Encoder_SPEC2MIDI(5) pos_freq: '+str(pos_freq.shape)) + + # embedding + spec_freq = self.dropout((spec_emb_freq * self.scale_freq) + self.pos_embedding_freq(pos_freq)) + #spec_freq = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_SPEC2MIDI(6) spec_freq: '+str(spec_freq.shape)) + + # transformer encoder + for layer_freq in self.layers_freq: + spec_freq = layer_freq(spec_freq) + spec_freq = spec_freq.reshape(batch_size, self.n_frame, self.n_bin, self.hid_dim) + #spec_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Encoder_SPEC2MIDI(7) spec_freq: '+str(spec_freq.shape)) + + return spec_freq + + +## +## Decoder +## +class Decoder_SPEC2MIDI(nn.Module): + def __init__(self, n_frame, n_bin, n_note, n_velocity, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + self.device = device + self.n_note = n_note + self.n_frame = n_frame + self.n_velocity = n_velocity + self.n_bin = n_bin + self.hid_dim = hid_dim + self.sigmoid = nn.Sigmoid() + self.dropout = nn.Dropout(dropout) + + # CAfreq + self.pos_embedding_freq = nn.Embedding(n_note, hid_dim) + self.layer_zero_freq = DecoderLayer_Zero(hid_dim, n_heads, pf_dim, dropout, device) + self.layers_freq = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers-1)]) + + self.fc_onset_freq = nn.Linear(hid_dim, 1) + self.fc_offset_freq = nn.Linear(hid_dim, 1) + self.fc_mpe_freq = nn.Linear(hid_dim, 1) + self.fc_velocity_freq = nn.Linear(hid_dim, self.n_velocity) + + # SAtime + self.scale_time = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + self.pos_embedding_time = nn.Embedding(n_frame, hid_dim) + #self.layers_time = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.layers_time = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + + self.fc_onset_time = nn.Linear(hid_dim, 1) + self.fc_offset_time = nn.Linear(hid_dim, 1) + self.fc_mpe_time = nn.Linear(hid_dim, 1) + self.fc_velocity_time = nn.Linear(hid_dim, self.n_velocity) + + def forward(self, enc_spec): + 
batch_size = enc_spec.shape[0] + enc_spec = enc_spec.reshape([batch_size*self.n_frame, self.n_bin, self.hid_dim]) + #enc_spec = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Decoder_SPEC2MIDI(0) enc_spec: '+str(enc_spec.shape)) + + ## + ## CAfreq freq(256)/note(88) + ## + pos_freq = torch.arange(0, self.n_note).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + midi_freq = self.pos_embedding_freq(pos_freq) + #pos_freq = [batch_size*n_frame, n_note] (8*128, 88) + #midi_freq = [batch_size, n_note, hid_dim] (8*128, 88, 256) + #print('Decoder_SPEC2MIDI(1) pos_freq: '+str(pos_freq.shape)) + #print('Decoder_SPEC2MIDI(1) midi_freq: '+str(midi_freq.shape)) + + midi_freq, attention_freq = self.layer_zero_freq(enc_spec, midi_freq) + for layer_freq in self.layers_freq: + midi_freq, attention_freq = layer_freq(enc_spec, midi_freq) + dim = attention_freq.shape + attention_freq = attention_freq.reshape([batch_size, self.n_frame, dim[1], dim[2], dim[3]]) + #midi_freq = [batch_size*n_frame, n_note, hid_dim] (8*128, 88, 256) + #attention_freq = [batch_size, n_frame, n_heads, n_note, n_bin] (8, 128, 4, 88, 256) + #print('Decoder_SPEC2MIDI(2) midi_freq: '+str(midi_freq.shape)) + #print('Decoder_SPEC2MIDI(2) attention_freq: '+str(attention_freq.shape)) + + ## output(freq) + output_onset_freq = self.sigmoid(self.fc_onset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_offset_freq = self.sigmoid(self.fc_offset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_mpe_freq = self.sigmoid(self.fc_mpe_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_velocity_freq = self.fc_velocity_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note, self.n_velocity]) + #output_onset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_freq = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_SPEC2MIDI(3) output_onset_freq: '+str(output_onset_freq.shape)) + #print('Decoder_SPEC2MIDI(3) output_offset_freq: '+str(output_offset_freq.shape)) + #print('Decoder_SPEC2MIDI(3) output_mpe_freq: '+str(output_mpe_freq.shape)) + #print('Decoder_SPEC2MIDI(3) output_velocity_freq: '+str(output_velocity_freq.shape)) + + ## + ## SAtime time(64) + ## + #midi_time: [batch_size*n_frame, n_note, hid_dim] -> [batch_size*n_note, n_frame, hid_dim] + midi_time = midi_freq.reshape([batch_size, self.n_frame, self.n_note, self.hid_dim]).permute(0, 2, 1, 3).contiguous().reshape([batch_size*self.n_note, self.n_frame, self.hid_dim]) + pos_time = torch.arange(0, self.n_frame).unsqueeze(0).repeat(batch_size*self.n_note, 1).to(self.device) + midi_time = self.dropout((midi_time * self.scale_time) + self.pos_embedding_time(pos_time)) + #pos_time = [batch_size*n_note, n_frame] (8*88, 128) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_SPEC2MIDI(4) pos_time: '+str(pos_time.shape)) + #print('Decoder_SPEC2MIDI(4) midi_time: '+str(midi_time.shape)) + + for layer_time in self.layers_time: + midi_time = layer_time(midi_time) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_SPEC2MIDI(5) midi_time: '+str(midi_time.shape)) + + ## output(time) + output_onset_time = self.sigmoid(self.fc_onset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_offset_time = 
self.sigmoid(self.fc_offset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_mpe_time = self.sigmoid(self.fc_mpe_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_velocity_time = self.fc_velocity_time(midi_time).reshape([batch_size, self.n_note, self.n_frame, self.n_velocity]).permute(0, 2, 1, 3).contiguous() + #output_onset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_time = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_SPEC2MIDI(6) output_onset_time: '+str(output_onset_time.shape)) + #print('Decoder_SPEC2MIDI(6) output_offset_time: '+str(output_offset_time.shape)) + #print('Decoder_SPEC2MIDI(6) output_mpe_time: '+str(output_mpe_time.shape)) + #print('Decoder_SPEC2MIDI(6) output_velocity_time: '+str(output_velocity_time.shape)) + + return output_onset_freq, output_offset_freq, output_mpe_freq, output_velocity_freq, attention_freq, output_onset_time, output_offset_time, output_mpe_time, output_velocity_time + + +## +## sub functions +## +class EncoderLayer(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, src): + #src = [batch_size, src_len, hid_dim] + + #self attention + _src, _ = self.self_attention(src, src, src) + #dropout, residual connection and layer norm + src = self.layer_norm(src + self.dropout(_src)) + #src = [batch_size, src_len, hid_dim] + + #positionwise feedforward + _src = self.positionwise_feedforward(src) + #dropout, residual and layer norm + src = self.layer_norm(src + self.dropout(_src)) + #src = [batch_size, src_len, hid_dim] + + return src + +class DecoderLayer_Zero(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, enc_src, trg): + #trg = [batch_size, trg_len, hid_dim] + #enc_src = [batch_size, src_len, hid_dim] + + #encoder attention + _trg, attention = self.encoder_attention(trg, enc_src, enc_src) + #dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #positionwise feedforward + _trg = self.positionwise_feedforward(trg) + #dropout, residual and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + #attention = [batch_size, n_heads, trg_len, src_len] + + return trg, attention + +class DecoderLayer(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = 
nn.Dropout(dropout) + + def forward(self, enc_src, trg): + #trg = [batch_size, trg_len, hid_dim] + #enc_src = [batch_size, src_len, hid_dim] + + #self attention + _trg, _ = self.self_attention(trg, trg, trg) + #dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #encoder attention + _trg, attention = self.encoder_attention(trg, enc_src, enc_src) + #dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #positionwise feedforward + _trg = self.positionwise_feedforward(trg) + #dropout, residual and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + #attention = [batch_size, n_heads, trg_len, src_len] + + return trg, attention + +class MultiHeadAttentionLayer(nn.Module): + def __init__(self, hid_dim, n_heads, dropout, device): + super().__init__() + assert hid_dim % n_heads == 0 + self.hid_dim = hid_dim + self.n_heads = n_heads + self.head_dim = hid_dim // n_heads + self.fc_q = nn.Linear(hid_dim, hid_dim) + self.fc_k = nn.Linear(hid_dim, hid_dim) + self.fc_v = nn.Linear(hid_dim, hid_dim) + self.fc_o = nn.Linear(hid_dim, hid_dim) + self.dropout = nn.Dropout(dropout) + self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device) + + def forward(self, query, key, value): + batch_size = query.shape[0] + #query = [batch_size, query_len, hid_dim] + #key = [batch_size, key_len, hid_dim] + #value = [batch_size, value_len, hid_dim] + + Q = self.fc_q(query) + K = self.fc_k(key) + V = self.fc_v(value) + #Q = [batch_size, query_len, hid_dim] + #K = [batch_size, key_len, hid_dim] + #V = [batch_size, value_len, hid_dim] + + Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + #Q = [batch_size, n_heads, query_len, head_dim] + #K = [batch_size, n_heads, key_len, head_dim] + #V = [batch_size, n_heads, value_len, head_dim] + + energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale + #energy = [batch_size, n_heads, seq len, seq len] + + attention = torch.softmax(energy, dim = -1) + #attention = [batch_size, n_heads, query_len, key_len] + + x = torch.matmul(self.dropout(attention), V) + #x = [batch_size, n_heads, seq len, head_dim] + + x = x.permute(0, 2, 1, 3).contiguous() + #x = [batch_size, seq_len, n_heads, head_dim] + + x = x.view(batch_size, -1, self.hid_dim) + #x = [batch_size, seq_len, hid_dim] + + x = self.fc_o(x) + #x = [batch_size, seq_len, hid_dim] + + return x, attention + +class PositionwiseFeedforwardLayer(nn.Module): + def __init__(self, hid_dim, pf_dim, dropout): + super().__init__() + self.fc_1 = nn.Linear(hid_dim, pf_dim) + self.fc_2 = nn.Linear(pf_dim, hid_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + #x = [batch_size, seq_len, hid_dim] + + x = self.dropout(torch.relu(self.fc_1(x))) + #x = [batch_size, seq_len, pf dim] + + x = self.fc_2(x) + #x = [batch_size, seq_len, hid_dim] + + return x diff --git a/experiments/baselines/hft_transformer/model/model_spec2midi_ablation.py b/experiments/baselines/hft_transformer/model/model_spec2midi_ablation.py new file mode 100644 index 0000000..5756cc4 --- /dev/null +++ b/experiments/baselines/hft_transformer/model/model_spec2midi_ablation.py @@ -0,0 +1,707 @@ +#! 
python + +import torch +import torch.nn as nn + +## +## Model (single output) +## +# 1FDN: Encoder_CNNtime_SAfreq / Decoder_CAfreq +class Model_single(nn.Module): + def __init__(self, encoder, decoder): + super().__init__() + self.encoder_spec2midi = encoder + self.decoder_spec2midi = decoder + + def forward(self, input_spec): + #input_spec = [batch_size, n_bin, margin+n_frame+margin] (8, 256, 192) + #print('Model_single(0) input_spec: '+str(input_spec.shape)) + + enc_vector = self.encoder_spec2midi(input_spec) + #enc_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Model_single(1) enc_vector: '+str(enc_vector.shape)) + + output_onset, output_offset, output_mpe, output_velocity = self.decoder_spec2midi(enc_vector) + #output_onset = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Model_single(2) output_onset: '+str(output_onset.shape)) + #print('Model_single(2) output_velocity: '+str(output_velocity.shape)) + + return output_onset, output_offset, output_mpe, output_velocity + + +## +## Model (combination output) +## +# 1FDT: Encoder_CNNtime_SAfreq / Decoder_CAfreq_SAtime +# 1FLT: Encoder_CNNtime_SAfreq / Decoder_linear_SAtime +# 2FDT: Encoder_CNNblock_SAfreq / Decoder_CAfreq_SAtime +class Model_combination(nn.Module): + def __init__(self, encoder, decoder): + super().__init__() + self.encoder_spec2midi = encoder + self.decoder_spec2midi = decoder + + def forward(self, input_spec): + #input_spec = [batch_size, n_bin, margin+n_frame+margin] (8, 256, 192) + #print('Model_combination(0) input_spec: '+str(input_spec.shape)) + + enc_vector = self.encoder_spec2midi(input_spec) + #enc_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Model_combination(1) enc_vector: '+str(enc_vector.shape)) + + output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B = self.decoder_spec2midi(enc_vector) + #output_onset_A = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_A = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Model_combination(2) output_onset_A: '+str(output_onset_A.shape)) + #print('Model_combination(2) output_velocity_A: '+str(output_velocity_A.shape)) + #print('Model_combination(2) output_onset_B: '+str(output_onset_B.shape)) + #print('Model_combination(2) output_velocity_B: '+str(output_velocity_B.shape)) + + return output_onset_A, output_offset_A, output_mpe_A, output_velocity_A, output_onset_B, output_offset_B, output_mpe_B, output_velocity_B + + +## +## Encoder +## +# Encoder_CNNtime_SAfreq +# Encoder_CNNblock_SAfreq +## +## Encoder CNN(time)+SA(freq) +## +class Encoder_CNNtime_SAfreq(nn.Module): + def __init__(self, n_margin, n_frame, n_bin, cnn_channel, cnn_kernel, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + + self.device = device + self.n_frame = n_frame + self.n_bin = n_bin + self.cnn_channel = cnn_channel + self.cnn_kernel = cnn_kernel + self.hid_dim = hid_dim + self.conv = nn.Conv2d(1, self.cnn_channel, kernel_size=(1, self.cnn_kernel)) + self.n_proc = n_margin * 2 + 1 + self.cnn_dim = self.cnn_channel * (self.n_proc - (self.cnn_kernel - 1)) + self.tok_embedding_freq = nn.Linear(self.cnn_dim, hid_dim) + self.pos_embedding_freq = nn.Embedding(n_bin, hid_dim) + self.layers_freq = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.dropout = nn.Dropout(dropout) + 
self.scale_freq = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + + def forward(self, spec_in): + #spec_in = [batch_size, n_bin, n_margin+n_frame+n_margin] (8, 256, 192) (batch_size=8, n_bins=256, margin=32/n_frame=128) + #print('Encoder_CNNtime_SAfreq(0) spec_in: '+str(spec_in.shape)) + batch_size = spec_in.shape[0] + + # CNN + spec_cnn = self.conv(spec_in.unsqueeze(1)) + #spec_cnn: [batch_size, cnn_channel, n_bin, n_margin+n_frame+n_margin-(cnn_kernel-1)] (8, 4, 256, 188) + #print('Encoder_CNNtime_SAfreq(1) spec_cnn: '+str(spec_cnn.shape)) + + # n_frame block + spec_cnn = spec_cnn.unfold(3, 61, 1).permute(0, 3, 2, 1, 4).contiguous().reshape([batch_size*self.n_frame, self.n_bin, self.cnn_dim]) + #spec_cnn: [batch_size*n_frame, n_bin, cnn_dim] (8*128, 256, 244) + #print('Encoder_CNNtime_SAfreq(2) spec_cnn: '+str(spec_cnn.shape)) + + # embedding + spec_emb_freq = self.tok_embedding_freq(spec_cnn) + # spec_emb_freq: [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_CNNtime_SAfreq(3) spec_emb_freq: '+str(spec_emb_freq.shape)) + + # position coding + pos_freq = torch.arange(0, self.n_bin).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + #pos_freq = [batch_size*n_frame, n_bin] (8*128, 256) + #print('Encoder_CNNtime_SAfreq(4) pos_freq: '+str(pos_freq.shape)) + + # embedding + spec_freq = self.dropout((spec_emb_freq * self.scale_freq) + self.pos_embedding_freq(pos_freq)) + #spec_freq = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_CNNtime_SAfreq(5) spec_freq: '+str(spec_freq.shape)) + + # transformer encoder + for layer_freq in self.layers_freq: + spec_freq = layer_freq(spec_freq) + spec_freq = spec_freq.reshape([batch_size, self.n_frame, self.n_bin, self.hid_dim]) + #spec_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Encoder_CNNtime_SAfreq(6) spec_freq: '+str(spec_freq.shape)) + + return spec_freq + + +## +## Encoder CNN(block)+SA(freq) +## +class Encoder_CNNblock_SAfreq(nn.Module): + def __init__(self, n_margin, n_frame, n_bin, hid_dim, n_layers, n_heads, pf_dim, dropout, dropout_convblock, device): + super().__init__() + + self.device = device + self.n_frame = n_frame + self.n_bin = n_bin + self.hid_dim = hid_dim + + k = 3 + p = 1 + # ConvBlock1 + layers_conv_1 = [] + ch1 = 48 + layers_conv_1.append(nn.Conv2d(1, ch1, kernel_size=k, stride=1, padding=p)) + layers_conv_1.append(nn.BatchNorm2d(ch1)) + layers_conv_1.append(nn.ReLU(True)) + layers_conv_1.append(nn.Conv2d(ch1, ch1, kernel_size=k, stride=1, padding=p)) + layers_conv_1.append(nn.BatchNorm2d(ch1)) + layers_conv_1.append(nn.ReLU(True)) + layers_conv_1.append(nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)) + self.conv_1 = nn.Sequential(*layers_conv_1) + self.dropout_1 = nn.Dropout(dropout_convblock) + # ConvBlock2 + layers_conv_2 = [] + ch2 = 64 + layers_conv_2.append(nn.Conv2d(ch1, ch2, kernel_size=k, stride=1, padding=p)) + layers_conv_2.append(nn.BatchNorm2d(ch2)) + layers_conv_2.append(nn.ReLU(True)) + layers_conv_2.append(nn.Conv2d(ch2, ch2, kernel_size=k, stride=1, padding=p)) + layers_conv_2.append(nn.BatchNorm2d(ch2)) + layers_conv_2.append(nn.ReLU(True)) + layers_conv_2.append(nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)) + self.conv_2 = nn.Sequential(*layers_conv_2) + self.dropout_2 = nn.Dropout(dropout_convblock) + # ConvBlock3 + layers_conv_3 = [] + ch3 = 96 + layers_conv_3.append(nn.Conv2d(ch2, ch3, kernel_size=k, stride=1, padding=p)) + layers_conv_3.append(nn.BatchNorm2d(ch3)) + 
layers_conv_3.append(nn.ReLU(True)) + layers_conv_3.append(nn.Conv2d(ch3, ch3, kernel_size=k, stride=1, padding=p)) + layers_conv_3.append(nn.BatchNorm2d(ch3)) + layers_conv_3.append(nn.ReLU(True)) + layers_conv_3.append(nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)) + self.conv_3 = nn.Sequential(*layers_conv_3) + self.dropout_3 = nn.Dropout(dropout_convblock) + # ConvBlock4 + layers_conv_4 = [] + ch4 = 128 + layers_conv_4.append(nn.Conv2d(ch3, ch4, kernel_size=k, stride=1, padding=p)) + layers_conv_4.append(nn.BatchNorm2d(ch4)) + layers_conv_4.append(nn.ReLU(True)) + layers_conv_4.append(nn.Conv2d(ch4, ch4, kernel_size=k, stride=1, padding=p)) + layers_conv_4.append(nn.BatchNorm2d(ch4)) + layers_conv_4.append(nn.ReLU(True)) + layers_conv_4.append(nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)) + self.conv_4 = nn.Sequential(*layers_conv_4) + self.dropout_4 = nn.Dropout(dropout_convblock) + + self.n_proc = n_margin * 2 + 1 + self.cnn_dim = int(int(int(int(self.n_bin/2)/2)/2)/2) + self.cnn_channel_A = 16 + self.cnn_channel_B = 8 + self.cnn_out_dim = self.n_proc * self.cnn_channel_B + + self.tok_embedding_freq = nn.Linear(self.cnn_out_dim, hid_dim) + self.pos_embedding_freq = nn.Embedding(n_bin, hid_dim) + self.layers_freq = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.dropout = nn.Dropout(dropout) + self.scale_freq = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + + def forward(self, spec_in): + #spec_in = [batch_size, n_bin, n_margin+n_frame+n_margin] (8, 256, 192) (batch_size=8, n_bins=256, margin=32/n_frame=128) + #print('Encoder_CNNblock_SAfreq(0) spec_in: '+str(spec_in.shape)) + batch_size = spec_in.shape[0] + + # conv blocks + spec1 = self.dropout_1(self.conv_1(spec_in.permute(0, 2, 1).contiguous().unsqueeze(1))) + #spec1 = [batch_size, ch1, n_margin+n_frame+n_margin, int(n_bin/2)] (8, 48, 192, 128) + #print('Encoder_CNNblock_SAfreq(1) spec1: '+str(spec1.shape)) + + spec2 = self.dropout_2(self.conv_2(spec1)) + #spec2 = [batch_size, ch2, n_margin+n_frame+n_margin, int(int(n_bin/2)/2)] (8, 64, 192, 64) + #print('Encoder_CNNblock_SAfreq(2) spec2: '+str(spec2.shape)) + + spec3 = self.dropout_3(self.conv_3(spec2)) + #spec3 = [batch_size, ch3, n_margin+n_frame+n_margin, int(int(int(n_bin/2)/2)/2)] (8, 96, 192, 32) + #print('Encoder_CNNblock_SAfreq(3) spec3: '+str(spec3.shape)) + + spec4 = self.dropout_4(self.conv_4(spec3)) + #spec4 = [batch_size, ch4, n_margin+n_frame+n_margin, int(int(int(int(n_bin/2)/2)/2)/2)] (8, 128, 192, 16) + #print('Encoder_CNNblock_SAfreq(4) spec4: '+str(spec4.shape)) + + # n_frame block + spec5 = spec4.unfold(2, self.n_proc, 1) + #spec5: [batch_size, ch4, n_frame, 16bin, n_proc] (8, 128, 128, 16, 65) + #print('Encoder_CNNblock_SAfreq(5) spec5: '+str(spec5.shape)) + + spec6 = spec5.permute(0, 2, 3, 1, 4).contiguous() + #spec6: [batch_size, n_frame, cnn_dim, ch4, n_proc] (8, 128, 16, 128, 65) + #print('Encoder_CNNblock_SAfreq(6) spec6: '+str(spec6.shape)) + + spec7 = spec6.reshape([batch_size, self.n_frame, self.cnn_dim, self.cnn_channel_A, self.cnn_channel_B, self.n_proc]) + #spec7: [batch_size, n_frame, cnn_dim, cnn_channel_A, cnn_channel_B, n_proc] (8, 128, 16, 16, 8, 65) + #print('Encoder_CNNblock_SAfreq(7) spec7: '+str(spec7.shape)) + + spec8 = spec7.reshape([batch_size, self.n_frame, self.n_bin, self.cnn_out_dim]) + #spec8: [batch_size, n_frame, n_bin, cnn_out_dim] (8, 128, 256, 520) + #print('Encoder_CNNblock_SAfreq(8) spec8: '+str(spec8.shape)) + + spec_emb_freq = 
self.tok_embedding_freq(spec8).reshape([batch_size*self.n_frame, self.n_bin, self.hid_dim]) + # spec_emb_freq: [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_CNNblock_SAfreq(9) spec_emb_freq: '+str(spec_emb_freq.shape)) + + # position coding + pos_freq = torch.arange(0, self.n_bin).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + #pos_freq = [batch_size*n_frame, n_bin] (8*128, 256) + #print('Encoder_CNNblock_SAfreq(10) pos_freq: '+str(pos_freq.shape)) + + # embedding + spec_freq = self.dropout((spec_emb_freq * self.scale_freq) + self.pos_embedding_freq(pos_freq)) + #spec_freq = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Encoder_CNNblock_SAfreq(11) spec_freq: '+str(spec_freq.shape)) + + # transformer encoder + for layer_freq in self.layers_freq: + spec_freq = layer_freq(spec_freq) + spec_freq = spec_freq.reshape(batch_size, self.n_frame, self.n_bin, self.hid_dim) + #spec_freq = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + #print('Encoder_CNNblock_SAfreq(12) spec_freq: '+str(spec_freq.shape)) + + return spec_freq + + +## +## Decoder +## +# Decoder_CAfreq +# Decoder_CAfreq_SAtime +# Decoder_linear_SAtime +## +## Decoder CA(freq) +## +class Decoder_CAfreq(nn.Module): + def __init__(self, n_frame, n_bin, n_note, n_velocity, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + self.device = device + self.n_note = n_note + self.n_frame = n_frame + self.n_velocity = n_velocity + self.n_bin = n_bin + self.hid_dim = hid_dim + + self.pos_embedding_freq = nn.Embedding(n_note, hid_dim) + self.layer_zero_freq = DecoderLayer_Zero(hid_dim, n_heads, pf_dim, dropout, device) + self.layers_freq = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers-1)]) + + self.fc_onset_freq = nn.Linear(hid_dim, 1) + self.fc_offset_freq = nn.Linear(hid_dim, 1) + self.fc_mpe_freq = nn.Linear(hid_dim, 1) + self.fc_velocity_freq = nn.Linear(hid_dim, self.n_velocity) + + self.sigmoid = nn.Sigmoid() + + def forward(self, enc_spec): + #enc_spec = [batch_size, n_frame, n_bin, hid_dim] (8, 128, 256, 256) + batch_size = enc_spec.shape[0] + + enc_spec = enc_spec.reshape([batch_size*self.n_frame, self.n_bin, self.hid_dim]) + #enc_spec = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Decoder_CAfreq(0) enc_spec: '+str(enc_spec.shape)) + + ## + ## CAfreq bin(256)/note(88) + ## + pos_freq = torch.arange(0, self.n_note).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + midi_freq = self.pos_embedding_freq(pos_freq) + #pos_freq = [batch_size*n_frame, n_note] (8*128, 88) + #midi_freq = [batch_size*n_frame, n_note, hid_dim] (8*128, 88, 256) + #print('Decoder_CAfreq(1) pos_freq: '+str(pos_freq.shape)) + #print('Decoder_CAfreq(1) midi_freq: '+str(midi_freq.shape)) + + midi_freq, attention_freq = self.layer_zero_freq(enc_spec, midi_freq) + for layer_freq in self.layers_freq: + midi_freq, attention_freq = layer_freq(enc_spec, midi_freq) + dim = attention_freq.shape + attention_freq = attention_freq.reshape([batch_size, self.n_frame, dim[1], dim[2], dim[3]]) + #midi_freq = [batch_size*n_frame, n_note, hid_dim] (8*128, 88, 256) + #attention_freq = [batch_size, n_frame, n_heads, n_note, n_bin] (8, 128, 4, 88, 256) + #print('Decoder_CAfreq(2) midi_freq: '+str(midi_freq.shape)) + #print('Decoder_CAfreq(2) attention_freq: '+str(attention_freq.shape)) + + ## output + output_onset_freq = self.sigmoid(self.fc_onset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) 
+ output_offset_freq = self.sigmoid(self.fc_offset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_mpe_freq = self.sigmoid(self.fc_mpe_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_velocity_freq = self.fc_velocity_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note, self.n_velocity]) + #output_onset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_freq = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_CAfreq(3) output_onset_freq: '+str(output_onset_freq.shape)) + #print('Decoder_CAfreq(3) output_offset_freq: '+str(output_offset_freq.shape)) + #print('Decoder_CAfreq(3) output_mpe_freq: '+str(output_mpe_freq.shape)) + #print('Decoder_CAfreq(3) output_velocity_freq: '+str(output_velocity_freq.shape)) + + return output_onset_freq, output_offset_freq, output_mpe_freq, output_velocity_freq + + +## +## Decoder CA(freq)/SA(time) +## +class Decoder_CAfreq_SAtime(nn.Module): + def __init__(self, n_frame, n_bin, n_note, n_velocity, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + self.device = device + self.n_note = n_note + self.n_frame = n_frame + self.n_velocity = n_velocity + self.n_bin = n_bin + self.hid_dim = hid_dim + self.sigmoid = nn.Sigmoid() + self.dropout = nn.Dropout(dropout) + + # CAfreq + self.pos_embedding_freq = nn.Embedding(n_note, hid_dim) + self.layer_zero_freq = DecoderLayer_Zero(hid_dim, n_heads, pf_dim, dropout, device) + self.layers_freq = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers-1)]) + + self.fc_onset_freq = nn.Linear(hid_dim, 1) + self.fc_offset_freq = nn.Linear(hid_dim, 1) + self.fc_mpe_freq = nn.Linear(hid_dim, 1) + self.fc_velocity_freq = nn.Linear(hid_dim, self.n_velocity) + + # SAtime + self.scale_time = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + self.pos_embedding_time = nn.Embedding(n_frame, hid_dim) + #self.layers_time = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.layers_time = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + + self.fc_onset_time = nn.Linear(hid_dim, 1) + self.fc_offset_time = nn.Linear(hid_dim, 1) + self.fc_mpe_time = nn.Linear(hid_dim, 1) + self.fc_velocity_time = nn.Linear(hid_dim, self.n_velocity) + + def forward(self, enc_spec): + batch_size = enc_spec.shape[0] + enc_spec = enc_spec.reshape([batch_size*self.n_frame, self.n_bin, self.hid_dim]) + #enc_spec = [batch_size*n_frame, n_bin, hid_dim] (8*128, 256, 256) + #print('Decoder_CAfreq_SAtime(0) enc_spec: '+str(enc_spec.shape)) + + ## + ## CAfreq freq(256)/note(88) + ## + pos_freq = torch.arange(0, self.n_note).unsqueeze(0).repeat(batch_size*self.n_frame, 1).to(self.device) + midi_freq = self.pos_embedding_freq(pos_freq) + #pos_freq = [batch_size*n_frame, n_note] (8*128, 88) + #midi_freq = [batch_size, n_note, hid_dim] (8*128, 88, 256) + #print('Decoder_CAfreq_SAtime(1) pos_freq: '+str(pos_freq.shape)) + #print('Decoder_CAfreq_SAtime(1) midi_freq: '+str(midi_freq.shape)) + + midi_freq, attention_freq = self.layer_zero_freq(enc_spec, midi_freq) + for layer_freq in self.layers_freq: + midi_freq, attention_freq = layer_freq(enc_spec, midi_freq) + dim = attention_freq.shape + attention_freq = attention_freq.reshape([batch_size, self.n_frame, dim[1], 
dim[2], dim[3]]) + #midi_freq = [batch_size*n_frame, n_note, hid_dim] (8*128, 88, 256) + #attention_freq = [batch_size, n_frame, n_heads, n_note, n_bin] (8, 128, 4, 88, 256) + #print('Decoder_CAfreq_SAtime(2) midi_freq: '+str(midi_freq.shape)) + #print('Decoder_CAfreq_SAtime(2) attention_freq: '+str(attention_freq.shape)) + + ## output(freq) + output_onset_freq = self.sigmoid(self.fc_onset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_offset_freq = self.sigmoid(self.fc_offset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_mpe_freq = self.sigmoid(self.fc_mpe_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_velocity_freq = self.fc_velocity_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note, self.n_velocity]) + #output_onset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_freq = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_CAfreq_SAtime(3) output_onset_freq: '+str(output_onset_freq.shape)) + #print('Decoder_CAfreq_SAtime(3) output_offset_freq: '+str(output_offset_freq.shape)) + #print('Decoder_CAfreq_SAtime(3) output_mpe_freq: '+str(output_mpe_freq.shape)) + #print('Decoder_CAfreq_SAtime(3) output_velocity_freq: '+str(output_velocity_freq.shape)) + + ## + ## SAtime time(64) + ## + #midi_time: [batch_size*n_frame, n_note, hid_dim] -> [batch_size*n_note, n_frame, hid_dim] + midi_time = midi_freq.reshape([batch_size, self.n_frame, self.n_note, self.hid_dim]).permute(0, 2, 1, 3).contiguous().reshape([batch_size*self.n_note, self.n_frame, self.hid_dim]) + pos_time = torch.arange(0, self.n_frame).unsqueeze(0).repeat(batch_size*self.n_note, 1).to(self.device) + midi_time = self.dropout((midi_time * self.scale_time) + self.pos_embedding_time(pos_time)) + #pos_time = [batch_size*n_note, n_frame] (8*88, 128) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_CAfreq_SAtime(4) pos_time: '+str(pos_time.shape)) + #print('Decoder_CAfreq_SAtime(4) midi_time: '+str(midi_time.shape)) + + for layer_time in self.layers_time: + midi_time = layer_time(midi_time) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_CAfreq_SAtime(5) midi_time: '+str(midi_time.shape)) + + ## output(time) + output_onset_time = self.sigmoid(self.fc_onset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_offset_time = self.sigmoid(self.fc_offset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_mpe_time = self.sigmoid(self.fc_mpe_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_velocity_time = self.fc_velocity_time(midi_time).reshape([batch_size, self.n_note, self.n_frame, self.n_velocity]).permute(0, 2, 1, 3).contiguous() + #output_onset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_time = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_CAfreq_SAtime(6) output_onset_time: '+str(output_onset_time.shape)) + #print('Decoder_CAfreq_SAtime(6) output_offset_time: '+str(output_offset_time.shape)) + #print('Decoder_CAfreq_SAtime(6) output_mpe_time: 
'+str(output_mpe_time.shape)) + #print('Decoder_CAfreq_SAtime(6) output_velocity_time: '+str(output_velocity_time.shape)) + + return output_onset_freq, output_offset_freq, output_mpe_freq, output_velocity_freq, output_onset_time, output_offset_time, output_mpe_time, output_velocity_time + + +## +## Decoder linear/SA(time) +## +class Decoder_linear_SAtime(nn.Module): + def __init__(self, n_frame, n_bin, n_note, n_velocity, hid_dim, n_layers, n_heads, pf_dim, dropout, device): + super().__init__() + self.device = device + self.n_note = n_note + self.n_frame = n_frame + self.n_velocity = n_velocity + self.n_bin = n_bin + self.hid_dim = hid_dim + self.sigmoid = nn.Sigmoid() + self.dropout = nn.Dropout(dropout) + + self.fc_convert = nn.Linear(n_bin, n_note) + + self.fc_onset_freq = nn.Linear(hid_dim, 1) + self.fc_offset_freq = nn.Linear(hid_dim, 1) + self.fc_mpe_freq = nn.Linear(hid_dim, 1) + self.fc_velocity_freq = nn.Linear(hid_dim, self.n_velocity) + + # SAtime + self.scale_time = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) + self.pos_embedding_time = nn.Embedding(n_frame, hid_dim) + #self.layers_time = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + self.layers_time = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]) + + self.fc_onset_time = nn.Linear(hid_dim, 1) + self.fc_offset_time = nn.Linear(hid_dim, 1) + self.fc_mpe_time = nn.Linear(hid_dim, 1) + self.fc_velocity_time = nn.Linear(hid_dim, self.n_velocity) + + def forward(self, enc_spec): + batch_size = enc_spec.shape[0] + enc_spec = enc_spec.permute(0, 1, 3, 2).contiguous().reshape([batch_size*self.n_frame, self.hid_dim, self.n_bin]) + #enc_spec = [batch_size*n_frame, hid_dim, n_bin] (8*128, 256, 256) + #print('Decoder_linear_SAtime(0) enc_spec: '+str(enc_spec.shape)) + + ## + ## linear bin(256)/note(88) + ## + midi_freq = self.fc_convert(enc_spec).permute(0, 2, 1).contiguous() + #midi_freq = [batch_size*n_frame, n_note, hid_dim] (8*128, 88, 256) + #print('Decoder_linear_SAtime(1) midi_freq: '+str(midi_freq.shape)) + + ## output(freq) + output_onset_freq = self.sigmoid(self.fc_onset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_offset_freq = self.sigmoid(self.fc_offset_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_mpe_freq = self.sigmoid(self.fc_mpe_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note])) + output_velocity_freq = self.fc_velocity_freq(midi_freq).reshape([batch_size, self.n_frame, self.n_note, self.n_velocity]) + #output_onset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_freq = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_freq = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_linear_SAtime(2) output_onset_freq: '+str(output_onset_freq.shape)) + #print('Decoder_linear_SAtime(2) output_offset_freq: '+str(output_offset_freq.shape)) + #print('Decoder_linear_SAtime(2) output_mpe_freq: '+str(output_mpe_freq.shape)) + #print('Decoder_linear_SAtime(2) output_velocity_freq: '+str(output_velocity_freq.shape)) + + ## + ## SAtime time(64) + ## + #midi_time: [batch_size*n_frame, n_note, hid_dim] -> [batch_size*n_note, n_frame, hid_dim] + midi_time = midi_freq.reshape([batch_size, self.n_frame, self.n_note, self.hid_dim]).permute(0, 2, 1, 3).contiguous().reshape([batch_size*self.n_note, self.n_frame, self.hid_dim]) + pos_time = 
torch.arange(0, self.n_frame).unsqueeze(0).repeat(batch_size*self.n_note, 1).to(self.device) + midi_time = self.dropout((midi_time * self.scale_time) + self.pos_embedding_time(pos_time)) + #pos_time = [batch_size*n_note, n_frame] (8*88, 128) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_linear_SAtime(3) pos_time: '+str(pos_time.shape)) + #print('Decoder_linear_SAtime(3) midi_time: '+str(midi_time.shape)) + + for layer_time in self.layers_time: + midi_time = layer_time(midi_time) + #midi_time = [batch_size*n_note, n_frame, hid_dim] (8*88, 128, 256) + #print('Decoder_linear_SAtime(4) midi_time: '+str(midi_time.shape)) + + ## output(time) + output_onset_time = self.sigmoid(self.fc_onset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_offset_time = self.sigmoid(self.fc_offset_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_mpe_time = self.sigmoid(self.fc_mpe_time(midi_time).reshape([batch_size, self.n_note, self.n_frame]).permute(0, 2, 1).contiguous()) + output_velocity_time = self.fc_velocity_time(midi_time).reshape([batch_size, self.n_note, self.n_frame, self.n_velocity]).permute(0, 2, 1, 3).contiguous() + #output_onset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_offset_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_mpe_time = [batch_size, n_frame, n_note] (8, 128, 88) + #output_velocity_time = [batch_size, n_frame, n_note, n_velocity] (8, 128, 88, 128) + #print('Decoder_linear_SAtime(5) output_onset_time: '+str(output_onset_time.shape)) + #print('Decoder_linear_SAtime(5) output_offset_time: '+str(output_offset_time.shape)) + #print('Decoder_linear_SAtime(5) output_mpe_time: '+str(output_mpe_time.shape)) + #print('Decoder_linear_SAtime(5) output_velocity_time: '+str(output_velocity_time.shape)) + + return output_onset_freq, output_offset_freq, output_mpe_freq, output_velocity_freq, output_onset_time, output_offset_time, output_mpe_time, output_velocity_time + + +## +## sub functions +## +class EncoderLayer(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, src): + #src = [batch_size, src_len, hid_dim] + + #self attention + _src, _ = self.self_attention(src, src, src) + #dropout, residual connection and layer norm + src = self.layer_norm(src + self.dropout(_src)) + #src = [batch_size, src_len, hid_dim] + + #positionwise feedforward + _src = self.positionwise_feedforward(src) + #dropout, residual and layer norm + src = self.layer_norm(src + self.dropout(_src)) + #src = [batch_size, src_len, hid_dim] + + return src + +class DecoderLayer_Zero(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, enc_src, trg): + #trg = [batch_size, trg_len, hid_dim] + #enc_src = [batch_size, src_len, hid_dim] + + #encoder attention + _trg, attention = self.encoder_attention(trg, enc_src, enc_src) + 
#dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #positionwise feedforward + _trg = self.positionwise_feedforward(trg) + #dropout, residual and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + #attention = [batch_size, n_heads, trg_len, src_len] + + return trg, attention + +class DecoderLayer(nn.Module): + def __init__(self, hid_dim, n_heads, pf_dim, dropout, device): + super().__init__() + self.layer_norm = nn.LayerNorm(hid_dim) + self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) + self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, enc_src, trg): + #trg = [batch_size, trg_len, hid_dim] + #enc_src = [batch_size, src_len, hid_dim] + + #self attention + _trg, _ = self.self_attention(trg, trg, trg) + #dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #encoder attention + _trg, attention = self.encoder_attention(trg, enc_src, enc_src) + #dropout, residual connection and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + + #positionwise feedforward + _trg = self.positionwise_feedforward(trg) + #dropout, residual and layer norm + trg = self.layer_norm(trg + self.dropout(_trg)) + #trg = [batch_size, trg_len, hid_dim] + #attention = [batch_size, n_heads, trg_len, src_len] + + return trg, attention + +class MultiHeadAttentionLayer(nn.Module): + def __init__(self, hid_dim, n_heads, dropout, device): + super().__init__() + assert hid_dim % n_heads == 0 + self.hid_dim = hid_dim + self.n_heads = n_heads + self.head_dim = hid_dim // n_heads + self.fc_q = nn.Linear(hid_dim, hid_dim) + self.fc_k = nn.Linear(hid_dim, hid_dim) + self.fc_v = nn.Linear(hid_dim, hid_dim) + self.fc_o = nn.Linear(hid_dim, hid_dim) + self.dropout = nn.Dropout(dropout) + self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device) + + def forward(self, query, key, value): + batch_size = query.shape[0] + #query = [batch_size, query_len, hid_dim] + #key = [batch_size, key_len, hid_dim] + #value = [batch_size, value_len, hid_dim] + + Q = self.fc_q(query) + K = self.fc_k(key) + V = self.fc_v(value) + #Q = [batch_size, query_len, hid_dim] + #K = [batch_size, key_len, hid_dim] + #V = [batch_size, value_len, hid_dim] + + Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) + #Q = [batch_size, n_heads, query_len, head_dim] + #K = [batch_size, n_heads, key_len, head_dim] + #V = [batch_size, n_heads, value_len, head_dim] + + energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale + #energy = [batch_size, n_heads, seq len, seq len] + + attention = torch.softmax(energy, dim = -1) + #attention = [batch_size, n_heads, query_len, key_len] + + x = torch.matmul(self.dropout(attention), V) + #x = [batch_size, n_heads, seq len, head_dim] + + x = x.permute(0, 2, 1, 3).contiguous() + #x = [batch_size, seq_len, n_heads, head_dim] + + x = x.view(batch_size, -1, self.hid_dim) + #x = [batch_size, seq_len, hid_dim] + + x = self.fc_o(x) + #x = [batch_size, seq_len, hid_dim] + + 
return x, attention + +class PositionwiseFeedforwardLayer(nn.Module): + def __init__(self, hid_dim, pf_dim, dropout): + super().__init__() + self.fc_1 = nn.Linear(hid_dim, pf_dim) + self.fc_2 = nn.Linear(pf_dim, hid_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + #x = [batch_size, seq_len, hid_dim] + + x = self.dropout(torch.relu(self.fc_1(x))) + #x = [batch_size, seq_len, pf dim] + + x = self.fc_2(x) + #x = [batch_size, seq_len, hid_dim] + + return x diff --git a/experiments/baselines/hft_transformer/transcribe_new_files.py b/experiments/baselines/hft_transformer/transcribe_new_files.py index 37f874e..5ac5a80 100644 --- a/experiments/baselines/hft_transformer/transcribe_new_files.py +++ b/experiments/baselines/hft_transformer/transcribe_new_files.py @@ -4,13 +4,17 @@ import json import sys import glob -from experiments.baselines.hft_transformer.src import amt -from pydub import AudioSegment -from pydub.exceptions import CouldntDecodeError import random import torch here = os.path.dirname(os.path.abspath(__file__)) - +import sys +sys.path.append(os.path.join(here, 'model')) +import hft_amt as amt +import time +from random import shuffle +sys.path.append(os.path.join(here, '../..')) +import loader_util +from tqdm.auto import tqdm _AMT = None def get_AMT(config_file=None, model_file=None): @@ -33,17 +37,6 @@ def get_AMT(config_file=None, model_file=None): _AMT.model = model return _AMT -def check_and_convert_mp3_to_wav(fname): - wav_file = fname.replace('.mp3', '.wav') - if not os.path.exists(wav_file): - print('converting ' + fname + ' to .wav...') - try: - sound = AudioSegment.from_mp3(fname) - sound.export(fname.replace('.mp3', '.wav'), format="wav") - except CouldntDecodeError: - print('failed to convert ' + fname) - return None - return wav_file def transcribe_file( @@ -59,9 +52,10 @@ def transcribe_file( ): if AMT is None: AMT = get_AMT() - + now_start = time.time() a_feature = AMT.wav2feature(fname) - + print(f'READING ELAPSED TIME: {time.time() - now_start}') + now_read = time.time() # transcript if n_stride > 0: output = AMT.transcript_stride(a_feature, n_stride, mode=mode, ablation_flag=ablation) @@ -69,7 +63,8 @@ def transcribe_file( output = AMT.transcript(a_feature, mode=mode, ablation_flag=ablation) (output_1st_onset, output_1st_offset, output_1st_mpe, output_1st_velocity, output_2nd_onset, output_2nd_offset, output_2nd_mpe, output_2nd_velocity) = output - + print(f'TRANSCRIPTION ELAPSED TIME: {time.time() - now_read}') + print(f'TOTAL ELAPSED TIME: {time.time() - now_start}') # note (mpe2note) a_note_1st_predict = AMT.mpe2note( a_onset=output_1st_onset, @@ -101,15 +96,9 @@ def transcribe_file( if __name__ == '__main__': parser = argparse.ArgumentParser() # necessary arguments - parser.add_argument('-input_dir_to_transcribe', default=None, help='file list') - parser.add_argument('-input_file_to_transcribe', default=None, help='one file') - parser.add_argument('-output_dir', help='output directory') - parser.add_argument('-output_file', default=None, help='output file') + parser = loader_util.add_io_arguments(parser) parser.add_argument('-f_config', help='config json file', default=None) parser.add_argument('-model_file', help='input model file', default=None) - parser.add_argument('-start_index', help='start index', type=int, default=None) - parser.add_argument('-end_index', help='end index', type=int, default=None) - parser.add_argument('-skip_transcribe_mp3', action='store_true', default=False) # parameters parser.add_argument('-mode', help='mode to 
transcript (combination|single)', default='combination') parser.add_argument('-thred_mpe', help='threshold value for mpe detection', type=float, default=0.5) @@ -121,56 +110,23 @@ def transcribe_file( assert (args.input_dir_to_transcribe is not None) or (args.input_file_to_transcribe is not None), "input file or directory is not specified" - if args.input_dir_to_transcribe is not None: - if not args.skip_transcribe_mp3: - # list file - a_mp3s = ( - glob.glob(os.path.join(args.input_dir_to_transcribe, '*.mp3')) + - glob.glob(os.path.join(args.input_dir_to_transcribe, '*', '*.mp3')) - ) - print(f'transcribing {len(a_mp3s)} files: [{str(a_mp3s)}]...') - list(map(check_and_convert_mp3_to_wav, a_mp3s)) - - a_list = ( - glob.glob(os.path.join(args.input_dir_to_transcribe, '*.wav')) + - glob.glob(os.path.join(args.input_dir_to_transcribe, '*', '*.wav')) - ) - if (args.start_index is not None) or (args.end_index is not None): - if args.start_index is None: - args.start_index = 0 - if args.end_index is None: - args.end_index = len(a_list) - a_list = a_list[args.start_index:args.end_index] - # shuffle a_list - random.shuffle(a_list) - - elif args.input_file_to_transcribe is not None: - args.input_file_to_transcribe = check_and_convert_mp3_to_wav(args.input_file_to_transcribe) - if args.input_file_to_transcribe is None: - sys.exit() - a_list = [args.input_file_to_transcribe] - print(f'transcribing {str(a_list)} files...') + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + a_list = loader_util.get_files_to_transcribe(args) # load model AMT = get_AMT(args.f_config, args.model_file) long_filename_counter = 0 - for fname in a_list: - if args.output_file is not None: - output_fname = args.output_file - else: - output_fname = fname.replace('.wav', '') - if len(output_fname) > 200: - output_fname = output_fname[:200] + f'_fnabbrev-{long_filename_counter}' - output_fname += '_transcribed.mid' - output_fname = os.path.join(args.output_dir, os.path.basename(output_fname)) - if os.path.exists(output_fname): - continue - - print('[' + fname + ']') + for input_fname, output_fname in tqdm(a_list): + if os.path.exists(output_fname): + continue + + print(f'transcribing {input_fname} -> {output_fname}') try: transcribe_file( - fname, + input_fname, output_fname, args.mode, args.thred_mpe, @@ -180,6 +136,8 @@ def transcribe_file( args.ablation, AMT, ) + now = time.time() + print(f'ELAPSED TIME: {time.time() - now}') except Exception as e: print(e) continue @@ -193,4 +151,12 @@ def transcribe_file( python evaluation/transcribe_new_files.py \ -input_dir_to_transcribe evaluation/glenn-gould-bach-data \ -output_dir hft-evaluation-data/ \ + +python baselines/hft_transformer/transcribe_new_files.py \ + -input_dir_to_transcribe ../../music-corpora/ \ + -input_files_map other-dataset-splits.csv \ + -split_col_name split \ + -split test \ + -output_dir hft-dtw-evaluation-data/ \ + -file_col_name audio_path """ diff --git a/experiments/baselines/requirements-baselines.txt b/experiments/baselines/requirements-baselines.txt index c0fb844..17a212d 100644 --- a/experiments/baselines/requirements-baselines.txt +++ b/experiments/baselines/requirements-baselines.txt @@ -1,14 +1,17 @@ +# you need python 3.11 for this to work (tensorflow-text dependency) pretty_midi -librosa +librosa==0.9.2 piano_transcription_inference -jax[cuda11_local] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +# jax[cuda12_local] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +# 
jax[cuda11_local] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +jax[cuda12_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html nest-asyncio gdown tensorflow[and-cuda] gin-config -seqio t5 wget +note-seq git+https://github.com/magenta/mt3.git # -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ No newline at end of file diff --git a/experiments/get_maestro_split.py b/experiments/get_maestro_split.py deleted file mode 100644 index cd374c3..0000000 --- a/experiments/get_maestro_split.py +++ /dev/null @@ -1,45 +0,0 @@ -# Helper script to print out all files in the desired split of the MAESTRO dataset. -import pandas as pd -import os -import shutil -from amt.audio import AudioTransform - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('-split', type=str, required=True, help='Split to print out.') - parser.add_argument('-maestro_dir', type=str, required=True, help='Directory of the MAESTRO dataset.') - parser.add_argument('-output_dir', type=str, required=True, help='Output directory.') - parser.add_argument('-apply_augmentation', action='store_true', default=False, help='Apply augmentation to the files.') - args = parser.parse_args() - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - - if args.apply_augmentation: - audio_transformer = AudioTransform( - reverb_factor=1, - min_snr=20, - max_snr=50, - max_dist_gain=25, - min_dist_gain=0, - noise_ratio=0.95, - reverb_ratio=0.95, - applause_ratio=0.01, - bandpass_ratio=0.15, - distort_ratio=0.15, - reduce_ratio=0.01, - detune_ratio=0.1, - detune_max_shift=0.15, - spec_aug_ratio=0.5, - ) - - # Load the split - maestro_df = pd.read_csv(os.path.join(args.maestro_dir, 'maestro-v3.0.0.csv')) - split_df = maestro_df.loc[lambda df: df['split'] == args.split] - for _, row in split_df.iterrows(): - shutil.copy( - os.path.join(args.maestro_dir, row['audio_filename']), - os.path.join(args.output_dir, row['audio_filename']) - ) - diff --git a/experiments/loader_util.py b/experiments/loader_util.py new file mode 100644 index 0000000..6a7ef84 --- /dev/null +++ b/experiments/loader_util.py @@ -0,0 +1,139 @@ +import argparse +from pydub import AudioSegment +from pydub.exceptions import CouldntDecodeError +import os +import glob +import random +import sys +import pandas as pd +from more_itertools import unique_everseen +from tqdm.auto import tqdm +from random import shuffle + + +def add_io_arguments(parser: argparse.ArgumentParser): + parser.add_argument('-input_dir_to_transcribe', default=None, help='file list') + parser.add_argument('-input_file_to_transcribe', default=None, help='one file') + + # params for if we're reading file names from a CSV + parser.add_argument('-input_files_map', help='CSV of files to transcribe', default=None) + parser.add_argument('-file_col_name', help='column name for file', default='file') + parser.add_argument('-split', help='split', default=None) + parser.add_argument('-split_col_name', help='column name for split', default='split') + parser.add_argument('-dataset', help='dataset', default=None) + parser.add_argument('-dataset_col_name', help='column name for dataset', default='dataset') + + # some algorithms only take a certain file format (e.g. MP3 or WAV) + parser.add_argument('-input_file_format', default=None, + help='Required input format ["mp3", "wav"]. ' + 'E.g. (I think) hFT only takes in WAV files.' 
+                        )
+    parser.add_argument('-output_dir', help='output directory')
+    parser.add_argument('-output_file', default=None, help='output file')
+    parser.add_argument('-start_index', help='start index', type=int, default=None)
+    parser.add_argument('-end_index', help='end index', type=int, default=None)
+    return parser
+
+
+def check_and_convert_between_mp3_and_wav(input_fname, current_fmt='mp3', desired_fmt='wav'):
+    input_fmt, output_fmt = f'.{current_fmt}', f'.{desired_fmt}'
+    output_file = input_fname.replace(input_fmt, output_fmt)
+    # only convert if the target file doesn't already exist
+    if not os.path.exists(output_file):
+        print(f'converting {input_fname}: {input_fmt}->{output_fmt}...')
+        try:
+            if current_fmt == 'mp3':
+                sound = AudioSegment.from_mp3(input_fname)
+                sound.export(output_file, format="wav")
+            else:
+                sound = AudioSegment.from_wav(input_fname)
+                sound.export(output_file, format="mp3")
+        except CouldntDecodeError:
+            print('failed to convert ' + input_fname)
+            return None
+    return output_file
+
+
+def get_files_to_transcribe(args):
+    """
+    Helper function to collect the files to transcribe.
+    Files are read from a CSV map, a directory, or a single file.
+    (If a CSV is provided, the input directory serves as the root that the CSV paths are joined onto;
+    otherwise we simply glob all audio files in the directory.)
+
+    Returns a list of (input_file, output_file) tuples.
+    Each output file is the input file name with a "_transcribed.midi" suffix.
+    If no output directory is provided, outputs are placed next to the inputs;
+    otherwise they are placed in the output directory, preserving the input file hierarchy.
+
+    :param args: argparse.Namespace produced by a parser built with `add_io_arguments`
+    :return: list of (input_file, output_file) tuples
+    """
+    # get files to transcribe
+
+    # if just one filename is provided, format it as a list
+    if args.input_file_to_transcribe is not None:
+        files_to_transcribe = [args.input_file_to_transcribe]
+
+    # get a list of files from a CSV
+    elif args.input_files_map is not None:
+        files_to_transcribe = pd.read_csv(args.input_files_map)
+        if args.split is not None:
+            files_to_transcribe = files_to_transcribe.loc[lambda df: df[args.split_col_name] == args.split]
+        if args.dataset is not None:
+            files_to_transcribe = files_to_transcribe.loc[lambda df: df[args.dataset_col_name] == args.dataset]
+        files_to_transcribe = files_to_transcribe[args.file_col_name].tolist()
+        if args.input_dir_to_transcribe is not None:
+            files_to_transcribe = list(map(lambda x: os.path.join(args.input_dir_to_transcribe, x), files_to_transcribe))
+
+    # get all files in a directory
+    elif args.input_dir_to_transcribe is not None:
+        files_to_transcribe = (
+            glob.glob(os.path.join(args.input_dir_to_transcribe, '**', '*.mp3'), recursive=True) +
+            glob.glob(os.path.join(args.input_dir_to_transcribe, '**', '*.wav'), recursive=True)
+        )
+
+    # convert file-types
+    if args.input_file_format is not None:
+        # make sure all the files of mp3 are converted to wav, or v.v.
+        other_fmt = 'mp3' if args.input_file_format == 'wav' else 'wav'
+        files_to_convert = list(filter(lambda x: os.path.splitext(x)[1] == f'.{other_fmt}', files_to_transcribe))
+        print(f'converting {len(files_to_convert)} files...')
+        for f in files_to_convert:
+            check_and_convert_between_mp3_and_wav(f, current_fmt=other_fmt, desired_fmt=args.input_file_format)
+        # point every entry at the desired format and drop duplicates
+        files_to_transcribe = list(unique_everseen(
+            map(lambda x: f"{os.path.splitext(x)[0]}.{args.input_file_format}", files_to_transcribe)
+        ))
+    else:
+        # input format doesn't matter, so we just want 1 of each
+        files_to_transcribe = list(unique_everseen(files_to_transcribe, key=lambda x: os.path.splitext(x)[0]))
+
+    # apply cutoffs
+    if (args.start_index is not None) or (args.end_index is not None):
+        if args.start_index is None:
+            args.start_index = 0
+        if args.end_index is None:
+            args.end_index = len(files_to_transcribe)
+        files_to_transcribe = files_to_transcribe[args.start_index:args.end_index]
+
+    # format output
+    if args.output_file is not None:
+        os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
+        return [(files_to_transcribe[0], args.output_file)]
+
+    # if the output directory is not provided, then we just put the output files in the same directory
+    # otherwise, we output to the output directory, preserving the hierarchy of the original files.
+    output_files = list(map(lambda x: f"{os.path.splitext(x)[0]}_transcribed.midi", files_to_transcribe))
+    if args.output_dir is not None:
+        if args.input_dir_to_transcribe is not None:
+            # strip the input root (and any leading separator) so os.path.join keeps the output dir
+            output_files = list(map(lambda x: x[len(args.input_dir_to_transcribe):].lstrip(os.sep), output_files))
+        output_files = list(map(lambda x: os.path.join(args.output_dir, x), output_files))
+        for o in output_files:
+            os.makedirs(os.path.dirname(o), exist_ok=True)
+
+    # pair inputs with outputs, drop files that are already transcribed, and shuffle
+    output = list(zip(files_to_transcribe, output_files))
+    output = list(filter(lambda x: not os.path.exists(x[1]), output))
+    random.shuffle(output)
+    return output
+
+
+
diff --git a/experiments/process_input_files.py b/experiments/process_input_files.py
new file mode 100644
index 0000000..0d11b53
--- /dev/null
+++ b/experiments/process_input_files.py
@@ -0,0 +1,96 @@
+"""
+Helper script to copy (and optionally augment) a dataset split, controlled by the
+`-apply_augmentation` and `-augmentation_config` flags.
+"""
+import pandas as pd
+import os
+import shutil
+from amt.audio import AudioTransform, pad_or_trim
+from amt.data import get_wav_mid_segments, load_config
+import json
+import librosa
+import torch
+import torchaudio
+from tqdm.auto import tqdm
+
+SAMPLE_RATE = load_config()['audio']['sample_rate']
+AUG_BATCH_SIZE = 100
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-split', type=str, required=True, default='test', help='Split to process.')
+    parser.add_argument('-dataset', type=str, default=None, help='Dataset to use.')
+    parser.add_argument('-input_file_dir', type=str, default=None, help='Directory of the dataset to use.')
+    parser.add_argument(
+        '-input_splits_file',
+        type=str,
+        required=True,
+        help='CSV file listing the dataset splits (e.g. the MAESTRO metadata CSV).'
+ ) + parser.add_argument('-midi_col_name', type=str, default=None, help='Column name for MIDI files.') + parser.add_argument('-audio_col_name', type=str, default=None, help='Column name for audio files.') + parser.add_argument('-output_dir', type=str, required=True, help='Output directory.') + parser.add_argument('-apply_augmentation', action='store_true', default=False, help='Apply augmentation to the files.') + parser.add_argument('-augmentation_config', type=str, default=None, help='Path to the augmentation config file.') + parser.add_argument('-device', type=str, default='cpu', help='Device to use.') + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + audio_transformer = None + if args.apply_augmentation: + aug_config = json.load(open(args.augmentation_config)) + audio_transformer = AudioTransform(**aug_config).to(args.device) + + # Load the split + input_files_to_process = pd.read_csv(args.input_splits_file) + if args.split is not None: + input_files_to_process = input_files_to_process.loc[lambda df: df['split'] == args.split] + if args.dataset is not None: + input_files_to_process = input_files_to_process.loc[lambda df: df['dataset'] == args.dataset] + + # Process the files + for _, row in tqdm( + input_files_to_process.iterrows(), + total=len(input_files_to_process), + desc=f'Processing {args.split} split' + ): + # copy MIDI file into the output directory + if args.midi_col_name is not None: + midi_outfile = os.path.basename(row[args.midi_col_name]) + fname, ext = os.path.splitext(midi_outfile) + midi_outfile = f'{fname}_gold{ext}' + midi_outfile = os.path.join(args.output_dir, midi_outfile) + if not os.path.exists(midi_outfile): + shutil.copy( + os.path.join(args.input_file_dir, row['midi_filename']), + os.path.join(args.output_dir, midi_outfile) + ) + + # either just vanilla copy the audio file, or apply augmentation + if args.audio_col_name is not None: + audio_outfile = os.path.basename(row[args.audio_col_name]) + audio_outfile = os.path.join(args.output_dir, audio_outfile) + audio_input_file = os.path.join(args.input_file_dir, row[args.audio_col_name]) + if not os.path.exists(audio_outfile): + if args.apply_augmentation: + try: + segments = get_wav_mid_segments(audio_input_file) + segments = list(map(lambda x: x[0], segments)) + aug_wav_parts = [] + for i in range(0, len(segments), AUG_BATCH_SIZE): + batch_to_augment = torch.vstack(segments[i:i + AUG_BATCH_SIZE]).to(args.device) + mel = audio_transformer(batch_to_augment) + aug_wav = audio_transformer.inverse_log_mel(mel) + aug_wav_parts.append(aug_wav) + aug_wav = torch.vstack(aug_wav_parts) + aug_wav = aug_wav.reshape(1, -1).cpu() + torchaudio.save(audio_outfile, src=aug_wav, sample_rate=SAMPLE_RATE) + except Exception as e: + print(f'Failed to augment {audio_input_file}: {e}') + else: + shutil.copy( + os.path.join(audio_input_file), + os.path.join(args.output_dir, audio_outfile) + ) diff --git a/experiments/run_dtw_transcription.sh b/experiments/run_dtw_transcription.sh new file mode 100644 index 0000000..45c5a5a --- /dev/null +++ b/experiments/run_dtw_transcription.sh @@ -0,0 +1,36 @@ +#!/bin/sh +#SBATCH --output=dtw_transcription__%x.%j.out +#SBATCH --error=dtw_transcription__%x.%j.err +#SBATCH -N 1 +#SBATCH -n 1 +#SBATCH --time=24:00:00 +#SBATCH --gres=gpu:1 +#SBATCH --mem-per-gpu=100GB +#SBATCH --cpus-per-gpu=20 +#SBATCH --partition=isi + + +python baselines/hft_transformer/transcribe_new_files.py \ + -input_dir_to_transcribe ../../music-corpora/ \ + 
-input_files_map other-dataset-splits.csv \ + -split_col_name split \ + -split test \ + -output_dir hft_transformer-evaluation-data/ \ + -file_col_name audio_path + +python baselines/giantmidi/transcribe_new_files.py \ + -input_dir_to_transcribe ../../music-corpora/ \ + -input_files_map other-dataset-splits.csv \ + -split_col_name split \ + -split test \ + -output_dir giantmidi-evaluation-data/ \ + -file_col_name audio_path + +conda activate py311 +python baselines/google_t5/transcribe_new_files.py \ + -input_dir_to_transcribe ../../music-corpora/ \ + -input_files_map other-dataset-splits.csv \ + -split_col_name split \ + -split test \ + -output_dir google-evaluation-data/ \ + -file_col_name audio_path \ No newline at end of file diff --git a/experiments/run_maestro_aug_1.sh b/experiments/run_maestro_aug_1.sh new file mode 100644 index 0000000..914e554 --- /dev/null +++ b/experiments/run_maestro_aug_1.sh @@ -0,0 +1,50 @@ +#!/bin/sh +#SBATCH --output=aug_1__%x.%j.out +#SBATCH --error=aug_1__%x.%j.err +#SBATCH -N 1 +#SBATCH -n 1 +#SBATCH --time=24:00:00 +#SBATCH --gres=gpu:1 +#SBATCH --mem-per-gpu=100GB +#SBATCH --cpus-per-gpu=20 +#SBATCH --partition=isi + +conda activate py311 +PROJ_DIR=/project/jonmay_231/spangher/Projects/aria-amt +OUTPUT_DIR="$PROJ_DIR/experiments/aug_1_files" + +# process data +if [ ! -d "$OUTPUT_DIR" ]; then + python process_maestro.py \ + -split test \ + -maestro_dir "$PROJ_DIR/../maestro-v3.0.0/maestro-v3.0.0.csv" \ + -output_dir $OUTPUT_DIR \ + -split test \ + -midi_col_name 'midi_filename' \ + -audio_col_name 'audio_filename' \ + -apply_augmentation \ + -augmentation_config "$PROJ_DIR/experiments/augmentation_configs/config_2.json" \ + -device 'cuda:0' +fi + +# run google inference +echo "Running google inference" +GOOGLE_OUTPUT_DIR="$OUTPUT_DIR/google_t5_transcriptions" +#if [ ! -d "$GOOGLE_OUTPUT_DIR" ]; then +python baselines/google_t5/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $GOOGLE_OUTPUT_DIR +#fi + +echo "Running giant midi inference" +GIANT_MIDI_OUTPUT_DIR="$OUTPUT_DIR/giant_midi_transcriptions" +python baselines/giantmidi/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $GIANT_MIDI_OUTPUT_DIR + +echo "Running hft inference" +HFT_OUTPUT_DIR="$OUTPUT_DIR/hft_transcriptions" +conda activate py311 +python baselines/hft_transformer/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $HFT_OUTPUT_DIR \ No newline at end of file diff --git a/experiments/run_maestro_aug_2.sh b/experiments/run_maestro_aug_2.sh new file mode 100644 index 0000000..b579c3e --- /dev/null +++ b/experiments/run_maestro_aug_2.sh @@ -0,0 +1,52 @@ +#!/bin/sh +#SBATCH --output=aug_2__%x.%j.out +#SBATCH --error=aug_2__%x.%j.err +#SBATCH -N 1 +#SBATCH -n 1 +#SBATCH --time=24:00:00 +#SBATCH --gres=gpu:1 +#SBATCH --mem-per-gpu=100GB +#SBATCH --cpus-per-gpu=20 +#SBATCH --partition=isi + +conda activate py311 +PROJ_DIR=/project/jonmay_231/spangher/Projects/aria-amt +OUTPUT_DIR="$PROJ_DIR/experiments/aug_2_files" + +# process data +if [ ! 
-d "$OUTPUT_DIR" ]; then + python process_maestro.py \ + -split test \ + -maestro_dir "$PROJ_DIR/../maestro-v3.0.0/maestro-v3.0.0.csv" \ + -output_dir $OUTPUT_DIR \ + -split test \ + -midi_col_name 'midi_filename' \ + -audio_col_name 'audio_filename' \ + -apply_augmentation \ + -augmentation_config "$PROJ_DIR/experiments/augmentation_configs/config_2.json" \ + -device 'cuda:0' +fi + +source /home1/${USER}/.bashrc +conda activate py311 + +## run google inference +#echo "Running google inference" +#GOOGLE_OUTPUT_DIR="$OUTPUT_DIR/google_t5_transcriptions" +##if [ ! -d "$GOOGLE_OUTPUT_DIR" ]; then +#python baselines/google_t5/transcribe_new_files.py \ +# -input_dir_to_transcribe $OUTPUT_DIR \ +# -output_dir $GOOGLE_OUTPUT_DIR +##fi +# +#echo "Running giant midi inference" +#GIANT_MIDI_OUTPUT_DIR="$OUTPUT_DIR/giant_midi_transcriptions" +#python baselines/giantmidi/transcribe_new_files.py \ +# -input_dir_to_transcribe $OUTPUT_DIR \ +# -output_dir $GIANT_MIDI_OUTPUT_DIR + +echo "Running hft inference" +HFT_OUTPUT_DIR="$OUTPUT_DIR/hft_transcriptions" +python baselines/hft_transformer/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $HFT_OUTPUT_DIR \ No newline at end of file diff --git a/experiments/run_maestro_vanilla.sh b/experiments/run_maestro_vanilla.sh index e69de29..b45f503 100644 --- a/experiments/run_maestro_vanilla.sh +++ b/experiments/run_maestro_vanilla.sh @@ -0,0 +1,44 @@ +#!/bin/sh +#SBATCH --output=vanilla__%x.%j.out +#SBATCH --error=vanilla__%x.%j.err +#SBATCH -N 1 +#SBATCH -n 1 +#SBATCH --time=24:00:00 +#SBATCH --gres=gpu:1 +#SBATCH --mem-per-gpu=100GB +#SBATCH --cpus-per-gpu=20 +#SBATCH --partition=isi + +source /home1/${USER}/.bashrc +conda activate py311 + +PROJ_DIR=/project/jonmay_231/spangher/Projects/aria-amt +OUTPUT_DIR="$PROJ_DIR/experiments/vanilla_files" + +if [ ! -d "$OUTPUT_DIR" ]; then + python process_maestro.py \ + -split test \ + -maestro_dir "$PROJ_DIR/../maestro-v3.0.0/maestro-v3.0.0.csv" \ + -output_dir $OUTPUT_DIR \ + -midi_col_name 'midi_filename' \ + -audio_col_name 'audio_filename' +fi + +# run google inference +echo "Running google inference" +GOOGLE_OUTPUT_DIR="$OUTPUT_DIR/google_t5_transcriptions" +python baselines/google_t5/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $GOOGLE_OUTPUT_DIR + +echo "Running giant midi inference" +GIANT_MIDI_OUTPUT_DIR="$OUTPUT_DIR/giant_midi_transcriptions" +python baselines/giantmidi/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $GIANT_MIDI_OUTPUT_DIR + +echo "Running hft inference" +HFT_OUTPUT_DIR="$OUTPUT_DIR/hft_transcriptions" +python baselines/hft_transformer/transcribe_new_files.py \ + -input_dir_to_transcribe $OUTPUT_DIR \ + -output_dir $HFT_OUTPUT_DIR \ No newline at end of file From fa9ac6e488529e7795ea4540efb949a28ef518c4 Mon Sep 17 00:00:00 2001 From: Alex Spangher Date: Tue, 9 Apr 2024 13:22:01 -0700 Subject: [PATCH 3/3] updated requirements.txt --- requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a74c9cf..1ed3700 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,6 @@ accelerate mido tqdm orjson -mir_eval \ No newline at end of file +mir_eval +pyfluidsynth @ git+https://github.com/nwhitehead/pyfluidsynth.git +midi2audio \ No newline at end of file
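The pyfluidsynth and midi2audio additions to requirements.txt are presumably intended for rendering transcribed MIDI back to audio (e.g. for listening checks or for the DTW-based evaluation that run_dtw_transcription.sh prepares). A minimal sketch of that rendering step, assuming a local soundfont and purely illustrative file names:

    from midi2audio import FluidSynth

    # hypothetical soundfont path; any General MIDI .sf2 file works
    fs = FluidSynth('path/to/soundfont.sf2')
    # render a transcription produced by one of the transcribe_new_files.py scripts
    fs.midi_to_audio('piece_transcribed.midi', 'piece_transcribed.wav')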