From 90b19031d8bbdb9d26f33699e01a409263b2b284 Mon Sep 17 00:00:00 2001 From: Alex Spangher Date: Mon, 11 Mar 2024 14:01:03 -0700 Subject: [PATCH 1/4] added codecs compression to augmentation --- .gitignore | 2 +- amt/audio.py | 18 ++ ..._experiment-with-sound-augmentations.ipynb | 213 ++++++++++++++++++ 3 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 notebooks/2024-03-11__experiment-with-sound-augmentations.ipynb diff --git a/.gitignore b/.gitignore index 5b48bfd..2ec3467 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,7 @@ *.htm .idea/ - +notebooks/scratch # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/amt/audio.py b/amt/audio.py index ed90f65..8e50c04 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -197,6 +197,7 @@ def __init__( distort_ratio: float = 0.15, reduce_ratio: float = 0.01, spec_aug_ratio: float = 0.25, + codecs_ratio: float = 0.01, ): super().__init__() self.tokenizer = AmtTokenizer() @@ -217,6 +218,7 @@ def __init__( self.distort_ratio = distort_ratio self.reduce_ratio = reduce_ratio self.spec_aug_ratio = spec_aug_ratio + self.codecs_ratio = codecs_ratio self.reduction_resample_rate = 6000 # Hardcoded? # Audio aug @@ -374,6 +376,21 @@ def apply_distortion(self, wav: torch.tensor): return AF.overdrive(wav, gain=gain, colour=colour) + def apply_codec(self, wav: torch.tensor): + """ + Apply different audio codecs to the audio. 
+ """ + format_encoder_pairs = [ + ("wav", "pcm_mulaw"), + ("g722", None), + ("ogg", "vorbis") + ] + for format, encoder in format_encoder_pairs: + encoder = torchaudio.io.AudioEffector(format=format, encoder=encoder) + if random.random() < self.codecs_ratio: + wav = encoder.apply(wav, self.sample_rate) + return wav + def shift_spec(self, specs: torch.Tensor, shift: int): if shift == 0: return specs @@ -403,6 +420,7 @@ def aug_wav(self, wav: torch.Tensor): # Noise if random.random() < self.noise_ratio: wav = self.apply_noise(wav) + if random.random() < self.applause_ratio: wav = self.apply_applause(wav) diff --git a/notebooks/2024-03-11__experiment-with-sound-augmentations.ipynb b/notebooks/2024-03-11__experiment-with-sound-augmentations.ipynb new file mode 100644 index 0000000..f45c6dd --- /dev/null +++ b/notebooks/2024-03-11__experiment-with-sound-augmentations.ipynb @@ -0,0 +1,213 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "outputs": [ + { + "ename": "OSError", + "evalue": "dlopen(/Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so, 0x0006): Symbol not found: __ZN2at4_ops15sum_dim_IntList4callERKNS_6TensorEN3c108ArrayRefIxEEbNS5_8optionalINS5_10ScalarTypeEEE\n Referenced from: <34C7FCDA-98E6-3DB6-B57D-478635DE1F58> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so\n Expected in: <89972BE7-3028-34DA-B561-E66870D59767> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mOSError\u001B[0m Traceback (most recent call last)", + "Input \u001B[0;32mIn [1]\u001B[0m, in \u001B[0;36m\u001B[0;34m()\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mIPython\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdisplay\u001B[39;00m 
\u001B[38;5;28;01mimport\u001B[39;00m Audio\n\u001B[0;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m download_asset\n\u001B[1;32m 5\u001B[0m SAMPLE_WAV \u001B[38;5;241m=\u001B[39m download_asset(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtutorial-assets/steam-train-whistle-daniel_simon.wav\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 6\u001B[0m SAMPLE_RIR \u001B[38;5;241m=\u001B[39m download_asset(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtutorial-assets/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo-8000hz.wav\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/__init__.py:1\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ( \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[1;32m 2\u001B[0m _extension,\n\u001B[1;32m 3\u001B[0m compliance,\n\u001B[1;32m 4\u001B[0m datasets,\n\u001B[1;32m 5\u001B[0m functional,\n\u001B[1;32m 6\u001B[0m io,\n\u001B[1;32m 7\u001B[0m kaldi_io,\n\u001B[1;32m 8\u001B[0m models,\n\u001B[1;32m 9\u001B[0m pipelines,\n\u001B[1;32m 10\u001B[0m sox_effects,\n\u001B[1;32m 11\u001B[0m transforms,\n\u001B[1;32m 12\u001B[0m utils,\n\u001B[1;32m 13\u001B[0m )\n\u001B[1;32m 14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbackend\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m get_audio_backend, list_audio_backends, set_audio_backend\n\u001B[1;32m 16\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:103\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 99\u001B[0m 
\u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mException\u001B[39;00m:\n\u001B[1;32m 100\u001B[0m \u001B[38;5;28;01mpass\u001B[39;00m\n\u001B[0;32m--> 103\u001B[0m \u001B[43m_init_extension\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:88\u001B[0m, in \u001B[0;36m_init_extension\u001B[0;34m()\u001B[0m\n\u001B[1;32m 85\u001B[0m warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtorchaudio C++ extension is not available.\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 86\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m\n\u001B[0;32m---> 88\u001B[0m \u001B[43m_load_lib\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mlibtorchaudio\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m 89\u001B[0m \u001B[38;5;66;03m# This import is for initializing the methods registered via PyBind11\u001B[39;00m\n\u001B[1;32m 90\u001B[0m \u001B[38;5;66;03m# This has to happen after the base library is loaded\u001B[39;00m\n\u001B[1;32m 91\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _torchaudio \u001B[38;5;66;03m# noqa\u001B[39;00m\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:51\u001B[0m, in \u001B[0;36m_load_lib\u001B[0;34m(lib)\u001B[0m\n\u001B[1;32m 49\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m path\u001B[38;5;241m.\u001B[39mexists():\n\u001B[1;32m 50\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[0;32m---> 51\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mops\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload_library\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 52\u001B[0m 
torch\u001B[38;5;241m.\u001B[39mclasses\u001B[38;5;241m.\u001B[39mload_library(path)\n\u001B[1;32m 53\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mTrue\u001B[39;00m\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torch/_ops.py:643\u001B[0m, in \u001B[0;36m_Ops.load_library\u001B[0;34m(self, path)\u001B[0m\n\u001B[1;32m 638\u001B[0m path \u001B[38;5;241m=\u001B[39m _utils_internal\u001B[38;5;241m.\u001B[39mresolve_library_path(path)\n\u001B[1;32m 639\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m dl_open_guard():\n\u001B[1;32m 640\u001B[0m \u001B[38;5;66;03m# Import the shared library into the process, thus running its\u001B[39;00m\n\u001B[1;32m 641\u001B[0m \u001B[38;5;66;03m# static (global) initialization code in order to register custom\u001B[39;00m\n\u001B[1;32m 642\u001B[0m \u001B[38;5;66;03m# operators with the JIT.\u001B[39;00m\n\u001B[0;32m--> 643\u001B[0m \u001B[43mctypes\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mCDLL\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 644\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mloaded_libraries\u001B[38;5;241m.\u001B[39madd(path)\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/ctypes/__init__.py:382\u001B[0m, in \u001B[0;36mCDLL.__init__\u001B[0;34m(self, name, mode, handle, use_errno, use_last_error, winmode)\u001B[0m\n\u001B[1;32m 379\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_FuncPtr \u001B[38;5;241m=\u001B[39m _FuncPtr\n\u001B[1;32m 381\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m handle \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m--> 382\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_handle \u001B[38;5;241m=\u001B[39m \u001B[43m_dlopen\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_name\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43mmode\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 383\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 384\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_handle \u001B[38;5;241m=\u001B[39m handle\n", + "\u001B[0;31mOSError\u001B[0m: dlopen(/Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so, 0x0006): Symbol not found: __ZN2at4_ops15sum_dim_IntList4callERKNS_6TensorEN3c108ArrayRefIxEEbNS5_8optionalINS5_10ScalarTypeEEE\n Referenced from: <34C7FCDA-98E6-3DB6-B57D-478635DE1F58> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so\n Expected in: <89972BE7-3028-34DA-B561-E66870D59767> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib" + ] + } + ], + "source": [ + "from IPython.display import Audio\n", + "\n", + "from torchaudio.utils import download_asset\n", + "\n", + "SAMPLE_WAV = download_asset(\"tutorial-assets/steam-train-whistle-daniel_simon.wav\")\n", + "SAMPLE_RIR = download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo-8000hz.wav\")\n", + "SAMPLE_SPEECH = download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042-8000hz.wav\")\n", + "SAMPLE_NOISE = download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo-8000hz.wav\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-03-11T20:15:46.720641Z", + "start_time": "2024-03-11T20:15:40.451397Z" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [], + "source": [ + "import IPython" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-03-11T18:24:40.427485Z", + 
"start_time": "2024-03-11T18:24:40.423546Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 18, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "bach_old_1.mp3\r\n" + ] + } + ], + "source": [ + "ls scratch/files-with-reverb" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-03-11T18:25:01.398215Z", + "start_time": "2024-03-11T18:25:01.250743Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 21, + "outputs": [ + { + "data": { + "text/plain": "", + "text/html": "\n \n " + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "IPython.display.Audio('scratch/files-with-reverb/bach_old_1.mp3')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-03-11T18:25:42.144419Z", + "start_time": "2024-03-11T18:25:42.058763Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 22, + "outputs": [ + { + "ename": "OSError", + "evalue": "dlopen(/Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so, 0x0006): Symbol not found: __ZN2at4_ops15sum_dim_IntList4callERKNS_6TensorEN3c108ArrayRefIxEEbNS5_8optionalINS5_10ScalarTypeEEE\n Referenced from: <34C7FCDA-98E6-3DB6-B57D-478635DE1F58> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so\n Expected in: <89972BE7-3028-34DA-B561-E66870D59767> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mOSError\u001B[0m Traceback (most recent call last)", + "Input \u001B[0;32mIn [22]\u001B[0m, in \u001B[0;36m\u001B[0;34m()\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mpretrained\u001B[39;00m 
\u001B[38;5;28;01mimport\u001B[39;00m SpectralMaskEnhancement\n\u001B[1;32m 2\u001B[0m model \u001B[38;5;241m=\u001B[39m SpectralMaskEnhancement\u001B[38;5;241m.\u001B[39mfrom_hparams(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mspeechbrain/mtl-mimic-voicebank\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/__init__.py:4\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;124;03m\"\"\" Comprehensive speech processing toolkit\u001B[39;00m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mos\u001B[39;00m\n\u001B[0;32m----> 4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mcore\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Stage, Brain, create_experiment_directory, parse_arguments\n\u001B[1;32m 5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m alignment \u001B[38;5;66;03m# noqa\u001B[39;00m\n\u001B[1;32m 6\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m dataio \u001B[38;5;66;03m# noqa\u001B[39;00m\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/core.py:36\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 34\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mnn\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mparallel\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m DistributedDataParallel \u001B[38;5;28;01mas\u001B[39;00m DDP\n\u001B[1;32m 35\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mhyperpyyaml\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m resolve_references\n\u001B[0;32m---> 36\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m 
\u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdistributed\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m run_on_main\n\u001B[1;32m 37\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdataio\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdataloader\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m LoopedLoader\n\u001B[1;32m 38\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdataio\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdataloader\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m SaveableDataLoader\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/utils/__init__.py:11\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 8\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m filename\u001B[38;5;241m.\u001B[39mendswith(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m.py\u001B[39m\u001B[38;5;124m\"\u001B[39m) \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m filename\u001B[38;5;241m.\u001B[39mstartswith(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m__\u001B[39m\u001B[38;5;124m\"\u001B[39m):\n\u001B[1;32m 9\u001B[0m __all__\u001B[38;5;241m.\u001B[39mappend(filename[:\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m3\u001B[39m])\n\u001B[0;32m---> 11\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/utils/parameter_transfer.py:12\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 9\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mlogging\u001B[39;00m\n\u001B[1;32m 10\u001B[0m 
\u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpathlib\u001B[39;00m\n\u001B[0;32m---> 12\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mpretrained\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mfetching\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m fetch\n\u001B[1;32m 13\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mcheckpoints\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (\n\u001B[1;32m 14\u001B[0m DEFAULT_LOAD_HOOKS,\n\u001B[1;32m 15\u001B[0m DEFAULT_TRANSFER_HOOKS,\n\u001B[1;32m 16\u001B[0m PARAMFILE_EXT,\n\u001B[1;32m 17\u001B[0m get_default_hook,\n\u001B[1;32m 18\u001B[0m )\n\u001B[1;32m 20\u001B[0m logger \u001B[38;5;241m=\u001B[39m logging\u001B[38;5;241m.\u001B[39mgetLogger(\u001B[38;5;18m__name__\u001B[39m)\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/pretrained/__init__.py:3\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;124;03m\"\"\"Pretrained models\"\"\"\u001B[39;00m\n\u001B[0;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01minterfaces\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;241m*\u001B[39m\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/speechbrain/pretrained/interfaces.py:18\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 16\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mspeechbrain\u001B[39;00m\n\u001B[1;32m 17\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[0;32m---> 18\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m\n\u001B[1;32m 19\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m 
\u001B[38;5;21;01msentencepiece\u001B[39;00m\n\u001B[1;32m 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtypes\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m SimpleNamespace\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/__init__.py:1\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ( \u001B[38;5;66;03m# noqa: F401\u001B[39;00m\n\u001B[1;32m 2\u001B[0m _extension,\n\u001B[1;32m 3\u001B[0m compliance,\n\u001B[1;32m 4\u001B[0m datasets,\n\u001B[1;32m 5\u001B[0m functional,\n\u001B[1;32m 6\u001B[0m io,\n\u001B[1;32m 7\u001B[0m kaldi_io,\n\u001B[1;32m 8\u001B[0m models,\n\u001B[1;32m 9\u001B[0m pipelines,\n\u001B[1;32m 10\u001B[0m sox_effects,\n\u001B[1;32m 11\u001B[0m transforms,\n\u001B[1;32m 12\u001B[0m utils,\n\u001B[1;32m 13\u001B[0m )\n\u001B[1;32m 14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbackend\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m get_audio_backend, list_audio_backends, set_audio_backend\n\u001B[1;32m 16\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:103\u001B[0m, in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 99\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mException\u001B[39;00m:\n\u001B[1;32m 100\u001B[0m \u001B[38;5;28;01mpass\u001B[39;00m\n\u001B[0;32m--> 103\u001B[0m \u001B[43m_init_extension\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:88\u001B[0m, in \u001B[0;36m_init_extension\u001B[0;34m()\u001B[0m\n\u001B[1;32m 85\u001B[0m warnings\u001B[38;5;241m.\u001B[39mwarn(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtorchaudio C++ extension is 
not available.\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 86\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m\n\u001B[0;32m---> 88\u001B[0m \u001B[43m_load_lib\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mlibtorchaudio\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m 89\u001B[0m \u001B[38;5;66;03m# This import is for initializing the methods registered via PyBind11\u001B[39;00m\n\u001B[1;32m 90\u001B[0m \u001B[38;5;66;03m# This has to happen after the base library is loaded\u001B[39;00m\n\u001B[1;32m 91\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m _torchaudio \u001B[38;5;66;03m# noqa\u001B[39;00m\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torchaudio/_extension.py:51\u001B[0m, in \u001B[0;36m_load_lib\u001B[0;34m(lib)\u001B[0m\n\u001B[1;32m 49\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m path\u001B[38;5;241m.\u001B[39mexists():\n\u001B[1;32m 50\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[0;32m---> 51\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mops\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload_library\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 52\u001B[0m torch\u001B[38;5;241m.\u001B[39mclasses\u001B[38;5;241m.\u001B[39mload_library(path)\n\u001B[1;32m 53\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mTrue\u001B[39;00m\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/site-packages/torch/_ops.py:643\u001B[0m, in \u001B[0;36m_Ops.load_library\u001B[0;34m(self, path)\u001B[0m\n\u001B[1;32m 638\u001B[0m path \u001B[38;5;241m=\u001B[39m _utils_internal\u001B[38;5;241m.\u001B[39mresolve_library_path(path)\n\u001B[1;32m 639\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m dl_open_guard():\n\u001B[1;32m 
640\u001B[0m \u001B[38;5;66;03m# Import the shared library into the process, thus running its\u001B[39;00m\n\u001B[1;32m 641\u001B[0m \u001B[38;5;66;03m# static (global) initialization code in order to register custom\u001B[39;00m\n\u001B[1;32m 642\u001B[0m \u001B[38;5;66;03m# operators with the JIT.\u001B[39;00m\n\u001B[0;32m--> 643\u001B[0m \u001B[43mctypes\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mCDLL\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 644\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mloaded_libraries\u001B[38;5;241m.\u001B[39madd(path)\n", + "File \u001B[0;32m~/opt/anaconda3/lib/python3.9/ctypes/__init__.py:382\u001B[0m, in \u001B[0;36mCDLL.__init__\u001B[0;34m(self, name, mode, handle, use_errno, use_last_error, winmode)\u001B[0m\n\u001B[1;32m 379\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_FuncPtr \u001B[38;5;241m=\u001B[39m _FuncPtr\n\u001B[1;32m 381\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m handle \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m--> 382\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_handle \u001B[38;5;241m=\u001B[39m \u001B[43m_dlopen\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_name\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmode\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 383\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 384\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_handle \u001B[38;5;241m=\u001B[39m handle\n", + "\u001B[0;31mOSError\u001B[0m: dlopen(/Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so, 0x0006): Symbol not found: __ZN2at4_ops15sum_dim_IntList4callERKNS_6TensorEN3c108ArrayRefIxEEbNS5_8optionalINS5_10ScalarTypeEEE\n Referenced from: <34C7FCDA-98E6-3DB6-B57D-478635DE1F58> 
/Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torchaudio/lib/libtorchaudio.so\n Expected in: <89972BE7-3028-34DA-B561-E66870D59767> /Users/spangher/opt/anaconda3/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib" + ] + } + ], + "source": [ + "from speechbrain.pretrained import SpectralMaskEnhancement\n", + "model = SpectralMaskEnhancement.from_hparams(\"speechbrain/mtl-mimic-voicebank\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-03-11T18:38:33.342704Z", + "start_time": "2024-03-11T18:38:09.584639Z" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 8f9f9cfcc22cc9fb882a40acd7e2ea6c210a2ea4 Mon Sep 17 00:00:00 2001 From: Alex Spangher Date: Mon, 11 Mar 2024 19:01:51 -0700 Subject: [PATCH 2/4] updated --- baselines/giantmidi/transcribe_new_files.py | 67 +++ baselines/hft_transformer/src/amt.py | 407 ++++++++++++++++++ .../hft_transformer/transcribe_new_files.py | 196 +++++++++ baselines/requirements-baselines.txt | 3 + 4 files changed, 673 insertions(+) create mode 100644 baselines/giantmidi/transcribe_new_files.py create mode 100644 baselines/hft_transformer/src/amt.py create mode 100644 
baselines/hft_transformer/transcribe_new_files.py create mode 100644 baselines/requirements-baselines.txt diff --git a/baselines/giantmidi/transcribe_new_files.py b/baselines/giantmidi/transcribe_new_files.py new file mode 100644 index 0000000..0650c73 --- /dev/null +++ b/baselines/giantmidi/transcribe_new_files.py @@ -0,0 +1,67 @@ +import os +import argparse +import time +import torch +import piano_transcription_inference +import glob + + +def transcribe_piano(mp3s_dir, midis_dir, begin_index=None, end_index=None): + """Transcribe piano solo mp3s to midi files.""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + os.makedirs(midis_dir, exist_ok=True) + + # Transcriptor + transcriptor = piano_transcription_inference.PianoTranscription(device=device) + + transcribe_time = time.time() + for n, mp3_path in enumerate(glob.glob(os.path.join(mp3s_dir, '*.mp3'))[begin_index:end_index]): + print(n, mp3_path) + midi_file = os.path.basename(mp3_path).replace('.mp3', '.midi') + midi_path = os.path.join(midis_dir, midi_file) + if os.path.exists(midi_path): + continue + + (audio, _) = ( + piano_transcription_inference + .load_audio(mp3_path, sr=piano_transcription_inference.sample_rate, mono=True) + ) + + try: + # Transcribe + transcribed_dict = transcriptor.transcribe(audio, midi_path) + print(transcribed_dict) + except: + print('Failed for this audio!') + + print('Time: {:.3f} s'.format(time.time() - transcribe_time)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Example of parser. ') + parser.add_argument('--mp3s_dir', type=str, required=True, help='') + parser.add_argument('--midis_dir', type=str, required=True, help='') + parser.add_argument( + '--begin_index', type=int, required=False, + help='File num., of an ordered list of files, to start transcribing from.', default=None + ) + parser.add_argument( + '--end_index', type=int, required=False, default=None, + help='File num., of an ordered list of files, to end transcription.' 
class AMT():
    """Inference wrapper around the hFT-Transformer baseline model.

    The methods form a pipeline: ``wav2feature`` (audio file -> log-mel
    frames), ``transcript`` / ``transcript_stride`` (frames -> per-frame
    onset / offset / mpe / velocity arrays), ``mpe2note`` (arrays -> note
    dicts) and ``note2midi`` (note dicts -> MIDI file).
    """

    def __init__(self, config, model_path, batch_size=1, verbose_flag=False):
        """Pick the device and, when ``model_path`` is given, load the model.

        Args:
            config: Nested dict read by the other methods (top-level keys
                ``feature``, ``input`` and ``midi``).
            model_path: Path of a pickled model object, or ``None`` to leave
                ``self.model`` as ``None`` so the caller can assign it later.
            batch_size: Stored on the instance; no method of this class
                reads it.
            verbose_flag: When exactly ``True``, print torch/CUDA information
                and the loaded model.
        """
        if verbose_flag is True:
            print('torch version: '+torch.__version__)
            print('torch cuda : '+str(torch.cuda.is_available()))
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.config = config

        if model_path is None:
            self.model = None
        else:
            # SECURITY: pickle.load executes arbitrary code embedded in the
            # file. Only load model files obtained from a trusted source.
            with open(model_path, 'rb') as f:
                self.model = pickle.load(f)
            self.model = self.model.to(self.device)
            self.model.eval()
            if verbose_flag is True:
                print(self.model)

        self.batch_size = batch_size

    def wav2feature(self, f_wav):
        """Load an audio file and return its log-mel spectrogram.

        The audio is averaged over channels, resampled to
        ``config['feature']['sr']`` and converted with
        ``torchaudio.transforms.MelSpectrogram``. Settings not passed below
        keep torchaudio's defaults; ``norm`` is overridden to ``'slaney'``.

        Returns:
            Tensor ``log(mel + log_offset)`` transposed so that frames come
            first, i.e. [num_frame, n_mels].
        """
        feature_cfg = self.config['feature']

        wave, sr = torchaudio.load(f_wav)
        wave_mono = torch.mean(wave, dim=0)
        tr_fsconv = torchaudio.transforms.Resample(sr, feature_cfg['sr'])
        wave_resampled = tr_fsconv(wave_mono)
        tr_mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=feature_cfg['sr'],
            n_fft=feature_cfg['fft_bins'],
            win_length=feature_cfg['window_length'],
            hop_length=feature_cfg['hop_sample'],
            pad_mode=feature_cfg['pad_mode'],
            n_mels=feature_cfg['mel_bins'],
            norm='slaney'
        )
        mel_spec = tr_mel(wave_resampled)
        a_feature = (torch.log(mel_spec + feature_cfg['log_offset'])).T

        return a_feature

    def _new_outputs(self, n_frame):
        """Allocate zeroed (onset, offset, mpe, velocity) arrays of shape [n_frame, num_note]."""
        shape = (n_frame, self.config['midi']['num_note'])
        return (
            np.zeros(shape, dtype=np.float32),
            np.zeros(shape, dtype=np.float32),
            np.zeros(shape, dtype=np.float32),
            np.zeros(shape, dtype=np.int8),
        )

    def _forward(self, input_spec, mode, ablation_flag):
        """Run the model on one window and split its outputs into passes.

        Returns ``(first_pass, second_pass)`` where each pass is an
        (onset, offset, mpe, velocity) tuple of tensors and ``second_pass``
        is ``None`` unless ``mode == 'combination'``.
        """
        with torch.no_grad():
            outputs = self.model(input_spec)
        # output_onset / output_offset / output_mpe: [batch_size, n_frame, n_note]
        # output_velocity: [batch_size, n_frame, n_note, n_velocity]
        if mode != 'combination':
            onset_A, offset_A, mpe_A, velocity_A = outputs
            return (onset_A, offset_A, mpe_A, velocity_A), None
        if ablation_flag is True:
            (onset_A, offset_A, mpe_A, velocity_A,
             onset_B, offset_B, mpe_B, velocity_B) = outputs
        else:
            # Without ablation the model also returns an attention output
            # between the two passes; it is not used here.
            (onset_A, offset_A, mpe_A, velocity_A, attention,
             onset_B, offset_B, mpe_B, velocity_B) = outputs
        return (
            (onset_A, offset_A, mpe_A, velocity_A),
            (onset_B, offset_B, mpe_B, velocity_B),
        )

    @staticmethod
    def _store(dst, rows, src, window=slice(None)):
        """Copy one window of model outputs into the result arrays.

        ``dst`` is an (onset, offset, mpe, velocity) tuple of numpy arrays,
        ``rows`` the destination frame slice, ``src`` the matching tuple of
        tensors and ``window`` the frame slice taken from each tensor after
        the batch dimension is squeezed away.
        """
        onset, offset, mpe, velocity = src
        dst[0][rows] = (onset.squeeze(0)[window]).to('cpu').detach().numpy()
        dst[1][rows] = (offset.squeeze(0)[window]).to('cpu').detach().numpy()
        dst[2][rows] = (mpe.squeeze(0)[window]).to('cpu').detach().numpy()
        # Velocity is predicted as one score per velocity value; keep the
        # index of the best one.
        dst[3][rows] = (velocity.squeeze(0)[window].argmax(2)).to('cpu').detach().numpy()

    def transcript(self, a_feature, mode='combination', ablation_flag=False):
        """Run the model over non-overlapping windows of ``num_frame`` frames.

        Args:
            a_feature: [num_frame, n_mels] features (see ``wav2feature``).
            mode: ``'combination'`` returns both passes; anything else
                returns the first pass only.
            ablation_flag: When exactly ``True`` the model is expected to
                return 8 outputs instead of 9 (no attention output).

        Returns:
            ``(onset, offset, mpe, velocity)`` numpy arrays for the first
            pass, followed by the same four for the second pass when
            ``mode == 'combination'``. Each array has
            ``a_feature.shape[0] + len_s`` rows, i.e. the input length
            rounded up to a multiple of ``num_frame``.
        """
        input_cfg = self.config['input']
        n_bins = self.config['feature']['n_bins']
        num_frame = input_cfg['num_frame']

        a_feature = np.array(a_feature, dtype=np.float32)
        n_feature = a_feature.shape[0]

        # Pad with min_value: margin_b frames in front; behind, enough to
        # complete the last window plus margin_f frames.
        a_tmp_b = np.full([input_cfg['margin_b'], n_bins], input_cfg['min_value'], dtype=np.float32)
        len_s = int(np.ceil(n_feature / num_frame) * num_frame) - n_feature
        a_tmp_f = np.full([len_s+input_cfg['margin_f'], n_bins], input_cfg['min_value'], dtype=np.float32)
        a_input = torch.from_numpy(np.concatenate([a_tmp_b, a_feature, a_tmp_f], axis=0))
        # a_input: [margin_b+a_feature.shape[0]+len_s+margin_f, n_bins]

        outputs_A = self._new_outputs(n_feature+len_s)
        outputs_B = self._new_outputs(n_feature+len_s) if mode == 'combination' else None
        window_len = input_cfg['margin_b']+num_frame+input_cfg['margin_f']

        self.model.eval()
        for i in range(0, n_feature, num_frame):
            input_spec = (a_input[i:i+window_len]).T.unsqueeze(0).to(self.device)
            first, second = self._forward(input_spec, mode, ablation_flag)

            rows = slice(i, i+num_frame)
            self._store(outputs_A, rows, first)
            if mode == 'combination':
                self._store(outputs_B, rows, second)

        if mode == 'combination':
            return outputs_A + outputs_B
        return outputs_A

    def transcript_stride(self, a_feature, n_offset, mode='combination', ablation_flag=False):
        """Run the model over windows that advance by half of ``num_frame``.

        Same contract as ``transcript`` except that consecutive windows
        overlap: each window contributes only the ``half_frame`` output
        frames starting at index ``n_offset``, and the input is shifted by
        ``n_offset`` extra leading padding frames to compensate.

        Args:
            a_feature: [num_frame, n_mels] features.
            n_offset: Start index, inside each window's output, of the
                frames that are kept.
            mode: See ``transcript``.
            ablation_flag: See ``transcript``.
        """
        input_cfg = self.config['input']
        n_bins = self.config['feature']['n_bins']

        a_feature = np.array(a_feature, dtype=np.float32)
        n_feature = a_feature.shape[0]

        half_frame = int(input_cfg['num_frame']/2)
        a_tmp_b = np.full([input_cfg['margin_b'] + n_offset, n_bins], input_cfg['min_value'], dtype=np.float32)
        tmp_len = n_feature + input_cfg['margin_b'] + input_cfg['margin_f'] + half_frame
        len_s = int(np.ceil(tmp_len / half_frame) * half_frame) - tmp_len
        a_tmp_f = np.full([len_s+input_cfg['margin_f']+(half_frame-n_offset), n_bins], input_cfg['min_value'], dtype=np.float32)

        a_input = torch.from_numpy(np.concatenate([a_tmp_b, a_feature, a_tmp_f], axis=0))
        # a_input: [n_offset+margin_b+a_feature.shape[0]+len_s+(half_frame-n_offset)+margin_f, n_bins]

        outputs_A = self._new_outputs(n_feature+len_s)
        outputs_B = self._new_outputs(n_feature+len_s) if mode == 'combination' else None
        window_len = input_cfg['margin_b']+input_cfg['num_frame']+input_cfg['margin_f']
        kept = slice(n_offset, n_offset+half_frame)

        self.model.eval()
        for i in range(0, n_feature, half_frame):
            input_spec = (a_input[i:i+window_len]).T.unsqueeze(0).to(self.device)
            first, second = self._forward(input_spec, mode, ablation_flag)

            rows = slice(i, i+half_frame)
            self._store(outputs_A, rows, first, kept)
            if mode == 'combination':
                self._store(outputs_B, rows, second, kept)

        if mode == 'combination':
            return outputs_A + outputs_B
        return outputs_A

    @staticmethod
    def _detect_peaks(a_act, j, thred, hop_sec, time_key):
        """Find local maxima of note ``j`` in a [frame][note] activation array.

        A frame is a peak when its value is ``>= thred`` and, walking
        outwards on each side, the first differing value is lower (so every
        frame of a flat top counts). The peak time is refined to sub-frame
        precision from the two neighbouring frames.

        Returns:
            List of ``{'loc': frame_index, time_key: seconds}`` in frame order.
        """
        def first_different_is_lower(value, neighbours):
            for other in neighbours:
                if value > other:
                    return True
                if value < other:
                    return False
            return True  # ran off the edge without meeting a different value

        n_frame = len(a_act)
        a_detect = []
        for i in range(n_frame):
            value = a_act[i][j]
            if not (value >= thred):
                continue
            if not first_different_is_lower(value, (a_act[ii][j] for ii in range(i-1, -1, -1))):
                continue
            if not first_different_is_lower(value, (a_act[ii][j] for ii in range(i+1, n_frame))):
                continue

            if (i == 0) or (i == n_frame - 1):
                peak_time = i * hop_sec
            else:
                prev_value = a_act[i-1][j]
                next_value = a_act[i+1][j]
                if prev_value == next_value:
                    peak_time = i * hop_sec
                elif prev_value > next_value:
                    # Higher left neighbour: shift the peak earlier.
                    peak_time = (i * hop_sec - (hop_sec * 0.5 * (prev_value - next_value) / (value - next_value)))
                else:
                    # Higher right neighbour: shift the peak later.
                    peak_time = (i * hop_sec + (hop_sec * 0.5 * (next_value - prev_value) / (value - prev_value)))
            a_detect.append({'loc': i, time_key: peak_time})
        return a_detect

    def mpe2note(
        self,
        a_onset=None,
        a_offset=None,
        a_mpe=None,
        a_velocity=None,
        thred_onset=0.5,
        thred_offset=0.5,
        thred_mpe=0.5,
        mode_velocity='ignore_zero',
        mode_offset='shorter'
    ):
        """Convert frame-level predictions into a list of notes.

        All four arrays are indexed ``[frame][note]`` and are required
        despite the ``None`` defaults.

        Args:
            a_onset: Onset activations; peaks ``>= thred_onset`` start notes.
            a_offset: Offset activations; peaks ``>= thred_offset`` end notes.
            a_mpe: Frame activations; the first frame after an onset that is
                below ``thred_mpe`` also ends the note.
            a_velocity: Velocity index per frame/note, read at the onset frame.
            thred_onset: Onset threshold.
            thred_offset: Offset threshold.
            thred_mpe: Frame-activation threshold.
            mode_velocity: ``'ignore_zero'`` (default) drops notes whose
                velocity is 0; any other value keeps them.
            mode_offset: How to choose when both an offset peak and an mpe
                end exist: ``'shorter'`` (default) takes the earlier one,
                ``'longer'`` the later one, ``'offset'`` always the offset
                peak.

        Returns:
            List of ``{'pitch', 'onset', 'offset', 'velocity'}`` dicts sorted
            by onset time, then pitch. Times are in seconds.
        """
        a_note = []
        hop_sec = float(self.config['feature']['hop_sample'] / self.config['feature']['sr'])

        for j in range(self.config['midi']['num_note']):
            a_onset_detect = self._detect_peaks(a_onset, j, thred_onset, hop_sec, 'onset_time')
            a_offset_detect = self._detect_peaks(a_offset, j, thred_offset, hop_sec, 'offset_time')

            time_next = 0.0
            time_offset = 0.0
            time_mpe = 0.0
            for idx_on, onset in enumerate(a_onset_detect):
                loc_onset = onset['loc']
                time_onset = onset['onset_time']

                # A note can last at most until the next onset of the same
                # pitch, or until the final frame for the last onset.
                if idx_on + 1 < len(a_onset_detect):
                    loc_next = a_onset_detect[idx_on+1]['loc']
                    time_next = a_onset_detect[idx_on+1]['onset_time']
                else:
                    loc_next = len(a_mpe)
                    time_next = (loc_next-1) * hop_sec

                # End candidate 1: first offset peak after the onset.
                loc_offset = loc_onset+1
                flag_offset = False
                for offset in a_offset_detect:
                    if loc_onset < offset['loc']:
                        loc_offset = offset['loc']
                        time_offset = offset['offset_time']
                        flag_offset = True
                        break
                if loc_offset > loc_next:
                    loc_offset = loc_next
                    time_offset = time_next

                # End candidate 2: first frame whose mpe drops below the
                # threshold (one frame after the last active frame).
                loc_mpe = loc_onset+1
                flag_mpe = False
                for ii_mpe in range(loc_onset+1, loc_next):
                    if a_mpe[ii_mpe][j] < thred_mpe:
                        loc_mpe = ii_mpe
                        flag_mpe = True
                        time_mpe = loc_mpe * hop_sec
                        break

                pitch_value = int(j+self.config['midi']['note_min'])
                velocity_value = int(a_velocity[loc_onset][j])

                if (flag_offset is False) and (flag_mpe is False):
                    offset_value = float(time_next)
                elif (flag_offset is True) and (flag_mpe is False):
                    offset_value = float(time_offset)
                elif (flag_offset is False) and (flag_mpe is True):
                    offset_value = float(time_mpe)
                elif mode_offset == 'offset':
                    offset_value = float(time_offset)
                elif mode_offset == 'longer':
                    if loc_offset >= loc_mpe:
                        offset_value = float(time_offset)
                    else:
                        offset_value = float(time_mpe)
                else:
                    # 'shorter'
                    if loc_offset <= loc_mpe:
                        offset_value = float(time_offset)
                    else:
                        offset_value = float(time_mpe)

                if (mode_velocity != 'ignore_zero') or (velocity_value > 0):
                    a_note.append({'pitch': pitch_value, 'onset': float(time_onset), 'offset': offset_value, 'velocity': velocity_value})

                # Two consecutive notes of the same pitch must not overlap:
                # cut the earlier one at the start of the later one.
                if (
                    (len(a_note) > 1) and
                    (a_note[-1]['pitch'] == a_note[-2]['pitch']) and
                    (a_note[-1]['onset'] < a_note[-2]['offset'])
                ):
                    a_note[-2]['offset'] = a_note[-1]['onset']

        # A single stable sort on (onset, pitch) orders the notes by onset
        # with pitch as tie-breaker.
        a_note = sorted(a_note, key=lambda x: (x['onset'], x['pitch']))
        return a_note

    def note2midi(self, a_note, f_midi):
        """Write notes (as returned by ``mpe2note``) to the MIDI file ``f_midi``.

        All notes go on a single instrument with ``program=0``.
        """
        midi = pretty_midi.PrettyMIDI()
        instrument = pretty_midi.Instrument(program=0)
        for note in a_note:
            instrument.notes.append(
                pretty_midi.Note(
                    velocity=note['velocity'],
                    pitch=note['pitch'],
                    start=note['onset'],
                    end=note['offset']
                )
            )
        midi.instruments.append(instrument)
        midi.write(f_midi)

        return
def transcribe_file(
    fname,
    output_fname,
    mode='combination',
    thred_mpe=0.5,
    thred_onset=0.5,
    thred_offset=0.5,
    n_stride=0,
    ablation=False,
    AMT=None
):
    """Transcribe one audio file to a MIDI file.

    Args:
        fname: Path of the audio file handed to ``AMT.wav2feature``.
        output_fname: Path of the MIDI file to write.
        mode: ``'combination'`` makes the model return two passes, of which
            the second is written; any other value yields a single pass,
            which is written as-is.
        thred_mpe: Threshold forwarded to ``AMT.mpe2note``.
        thred_onset: Threshold forwarded to ``AMT.mpe2note``.
        thred_offset: Threshold forwarded to ``AMT.mpe2note``.
        n_stride: When ``> 0``, use ``AMT.transcript_stride`` with this
            offset instead of ``AMT.transcript``.
        ablation: Forwarded as ``ablation_flag``.
        AMT: Transcriber instance; defaults to the shared one from
            ``get_AMT()``.
    """
    if AMT is None:
        AMT = get_AMT()

    a_feature = AMT.wav2feature(fname)

    # transcript
    if n_stride > 0:
        output = AMT.transcript_stride(a_feature, n_stride, mode=mode, ablation_flag=ablation)
    else:
        output = AMT.transcript(a_feature, mode=mode, ablation_flag=ablation)

    # 'combination' returns 8 arrays (first pass, then second pass) and any
    # other mode returns 4. The last four are the ones to write in both
    # cases, so an unconditional 8-way unpack would fail in single mode.
    output_onset, output_offset, output_mpe, output_velocity = output[-4:]

    # note (mpe2note); only the notes that are written are computed
    a_note_predict = AMT.mpe2note(
        a_onset=output_onset,
        a_offset=output_offset,
        a_mpe=output_mpe,
        a_velocity=output_velocity,
        thred_onset=thred_onset,
        thred_offset=thred_offset,
        thred_mpe=thred_mpe,
        mode_velocity='ignore_zero',
        mode_offset='shorter'
    )

    AMT.note2midi(a_note_predict, output_fname)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # necessary arguments
    parser.add_argument('-input_dir_to_transcribe', default=None, help='file list')
    parser.add_argument('-input_file_to_transcribe', default=None, help='one file')
    parser.add_argument('-output_dir', help='output directory')
    parser.add_argument('-output_file', default=None, help='output file')
    parser.add_argument('-f_config', help='config json file', default=None)
    parser.add_argument('-model_file', help='input model file', default=None)
    parser.add_argument('-start_index', help='start index', type=int, default=None)
    parser.add_argument('-end_index', help='end index', type=int, default=None)
    parser.add_argument('-skip_transcribe_mp3', action='store_true', default=False)
    # parameters
    parser.add_argument('-mode', help='mode to transcript (combination|single)', default='combination')
    parser.add_argument('-thred_mpe', help='threshold value for mpe detection', type=float, default=0.5)
    parser.add_argument('-thred_onset', help='threshold value for onset detection', type=float, default=0.5)
    parser.add_argument('-thred_offset', help='threshold value for offset detection', type=float, default=0.5)
    parser.add_argument('-n_stride', help='number of samples for offset', type=int, default=0)
    parser.add_argument('-ablation', help='ablation mode', action='store_true')
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts vanish under -O.
    if (args.input_dir_to_transcribe is None) and (args.input_file_to_transcribe is None):
        parser.error("input file or directory is not specified")
    if (args.output_file is None) and (args.output_dir is None):
        parser.error("-output_dir is required when -output_file is not given")

    if args.input_dir_to_transcribe is not None:
        if not args.skip_transcribe_mp3:
            # list file
            a_mp3s = (
                glob.glob(os.path.join(args.input_dir_to_transcribe, '*.mp3')) +
                glob.glob(os.path.join(args.input_dir_to_transcribe, '*', '*.mp3'))
            )
            print(f'transcribing {len(a_mp3s)} files: [{str(a_mp3s)}]...')
            for mp3_fname in a_mp3s:
                check_and_convert_mp3_to_wav(mp3_fname)

        a_list = (
            glob.glob(os.path.join(args.input_dir_to_transcribe, '*.wav')) +
            glob.glob(os.path.join(args.input_dir_to_transcribe, '*', '*.wav'))
        )
        if (args.start_index is not None) or (args.end_index is not None):
            if args.start_index is None:
                args.start_index = 0
            if args.end_index is None:
                args.end_index = len(a_list)
            a_list = a_list[args.start_index:args.end_index]
        # shuffle a_list
        random.shuffle(a_list)

    elif args.input_file_to_transcribe is not None:
        args.input_file_to_transcribe = check_and_convert_mp3_to_wav(args.input_file_to_transcribe)
        if args.input_file_to_transcribe is None:
            sys.exit()
        a_list = [args.input_file_to_transcribe]
        print(f'transcribing {str(a_list)} files...')

    # load model
    AMT = get_AMT(args.f_config, args.model_file)

    long_filename_counter = 0
    for fname in a_list:
        if args.output_file is not None:
            output_fname = args.output_file
        else:
            output_fname = fname.replace('.wav', '')
            if len(output_fname) > 200:
                output_fname = output_fname[:200] + f'_fnabbrev-{long_filename_counter}'
                # Advance the counter so two long names sharing the same
                # 200-character prefix do not map to the same output file.
                long_filename_counter += 1
            output_fname += '_transcribed.mid'
            output_fname = os.path.join(args.output_dir, os.path.basename(output_fname))
        # Already transcribed (possibly by an earlier or concurrent run).
        if os.path.exists(output_fname):
            continue

        print('[' + fname + ']')
        try:
            transcribe_file(
                fname,
                output_fname,
                args.mode,
                args.thred_mpe,
                args.thred_onset,
                args.thred_offset,
                args.n_stride,
                args.ablation,
                AMT,
            )
        except Exception as e:
            # Best-effort batch run: report the failure and move on.
            print(e)
            continue

    print('** done **')
class SyntheticMidiHandler:
    """Render MIDI files to audio with randomly sampled soundfonts."""

    def __init__(self, soundfont_path: str, soundfont_prob_dict: dict = None, num_wavs_per_midi: int = 1):
        """
        Load the soundfonts and prepare the sampling table.

        Parameters
        ----------
        soundfont_path : str
            Path to the directory containing soundfont files. Every entry of
            the directory is loaded.
        soundfont_prob_dict : dict, optional
            Dictionary containing the probability of using a soundfont file.
            The keys are the soundfont file names (as listed in
            ``soundfont_path``) and the values are the weights of the
            soundfont files; weights are normalised to sum to one. If none
            is given, then a uniform distribution is used.
        num_wavs_per_midi : int, optional
            Number of audio files to generate per MIDI file.

        Raises
        ------
        ValueError
            If ``soundfont_prob_dict`` names a file that was not loaded, or
            if there is nothing with positive weight to sample from.
        """

        self.soundfont_path = soundfont_path
        self.soundfont_prob_dict = soundfont_prob_dict
        self.num_wavs_per_midi = num_wavs_per_midi

        self.fs_objs = self._load_soundfonts()
        self.soundfont_cumul_prob_dict = self._get_cumulative_prob_dict()

    def _load_soundfonts(self):
        """Loads the soundfonts into fluidsynth objects, keyed by file name."""
        fs_objs = {}
        # Sorted so the sampling table has a stable order across runs.
        for fs_file in sorted(os.listdir(self.soundfont_path)):
            # os.listdir yields bare names; FluidSynth needs the real path.
            fs_objs[fs_file] = FluidSynth(os.path.join(self.soundfont_path, fs_file))
        return fs_objs

    def _get_cumulative_prob_dict(self):
        """Returns a dictionary with the cumulative probabilities of the soundfonts.
        Used for sampling the soundfonts.

        Each value is a ``(start, end)`` interval inside ``[0, 1)``.
        ``self.soundfont_prob_dict`` is replaced by its normalised version.
        """
        if self.soundfont_prob_dict is None:
            self.soundfont_prob_dict = {k: 1 / len(self.fs_objs) for k in self.fs_objs}

        unknown = [k for k in self.soundfont_prob_dict if k not in self.fs_objs]
        if unknown:
            raise ValueError(f"soundfont_prob_dict names soundfonts that were not loaded: {unknown}")
        total = sum(self.soundfont_prob_dict.values())
        if total <= 0:
            raise ValueError("no soundfonts with positive probability to sample from")

        self.soundfont_prob_dict = {k: v / total for k, v in self.soundfont_prob_dict.items()}
        cumul_prob_dict = {}
        cumul_prob = 0
        for k, v in self.soundfont_prob_dict.items():
            cumul_prob_dict[k] = (cumul_prob, cumul_prob + v)
            cumul_prob += v
        return cumul_prob_dict

    def _sample_soundfont(self):
        """Samples a soundfont object according to the cumulative table."""
        rand_num = random.random()
        fallback = None
        for name, (lower, upper) in self.soundfont_cumul_prob_dict.items():
            if lower <= rand_num < upper:
                return self.fs_objs[name]
            if upper > lower:
                fallback = name
        # Float rounding can leave the last interval ending just below 1.0;
        # a draw in that gap belongs to the last soundfont with any weight.
        if fallback is None:
            raise ValueError("no soundfonts with positive probability to sample from")
        return self.fs_objs[fallback]

    def get_wav(self, midi_path: str, save_path: str):
        """
        Converts a MIDI file to audio.

        Parameters
        ----------
        midi_path : str
            Path to the MIDI file.
        save_path : str
            Path to save the audio file. When ``num_wavs_per_midi > 1`` the
            i-th file is written to ``<save_path without extension>_<i>.wav``.
        """
        # Derive every numbered name from the original path so suffixes do
        # not pile up (x_0.wav, x_0_1.wav, ...).
        root, _ = os.path.splitext(save_path)
        for i in range(self.num_wavs_per_midi):
            soundfont = self._sample_soundfont()
            if self.num_wavs_per_midi > 1:
                out_path = f"{root}_{i}.wav"
            else:
                out_path = save_path
            soundfont.midi_to_audio(midi_path, out_path)