From 1935ff3005bb070a6cd05d9ec31c1d1d9482c21c Mon Sep 17 00:00:00 2001
From: Jesse Engel
Date: Tue, 8 Feb 2022 19:50:38 -0800
Subject: [PATCH] Unify padding for "frame" operations such as f0 and dB calculations.

* Offer 3 padding modes: 'valid', 'same', and 'center'
* get_framed_lengths() function for verifying expected lengths after framing.
* Significantly revamp spectral_ops and prepare_tfrecord tests.
* Update ddsp_prepare_tfrecord.py to allow centered padding.
* Add type annotations for prepare_tfrecord_lib.py.
* Slim down prepare_tfrecord tests.

PiperOrigin-RevId: 427359579
---
 ddsp/__init__.py | 2 +-
 ddsp/colab/__init__.py | 2 +-
 ddsp/colab/colab_utils.py | 2 +-
 ddsp/core.py | 2 +-
 ddsp/core_test.py | 2 +-
 ddsp/dags.py | 2 +-
 ddsp/dags_test.py | 2 +-
 ddsp/effects.py | 2 +-
 ddsp/effects_test.py | 2 +-
 ddsp/losses.py | 2 +-
 ddsp/losses_test.py | 2 +-
 ddsp/processors.py | 2 +-
 ddsp/processors_test.py | 2 +-
 ddsp/spectral_ops.py | 183 ++++++++-----
 ddsp/spectral_ops_test.py | 259 ++++++++++--------
 ddsp/synths.py | 2 +-
 ddsp/synths_test.py | 2 +-
 ddsp/test_util.py | 34 +++
 ddsp/training/__init__.py | 2 +-
 ddsp/training/cloud.py | 2 +-
 ddsp/training/cloud_test.py | 2 +-
 ddsp/training/data.py | 2 +-
 ddsp/training/data_preparation/__init__.py | 2 +-
 .../ddsp_generate_synthetic_dataset.py | 2 +-
 .../data_preparation/ddsp_prepare_tfrecord.py | 31 ++-
 .../data_preparation/prepare_tfrecord_lib.py | 162 ++++++-----
 .../prepare_tfrecord_lib_test.py | 155 ++++++-----
 .../data_preparation/synthetic_data.py | 2 +-
 ddsp/training/ddsp_export.py | 2 +-
 ddsp/training/ddsp_run.py | 4 +-
 ddsp/training/decoders.py | 2 +-
 ddsp/training/decoders_test.py | 2 +-
 ddsp/training/docker/__init__.py | 2 +-
 ddsp/training/docker/ddsp_ai_platform.py | 2 +-
 ddsp/training/docker/task.py | 2 +-
 ddsp/training/docker/task_test.py | 2 +-
 ddsp/training/encoders.py | 2 +-
 ddsp/training/eval_util.py | 2 +-
 ddsp/training/evaluators.py | 2 +-
 ddsp/training/gin/__init__.py | 2 +-
 ddsp/training/gin/datasets/__init__.py | 2 +-
 ddsp/training/gin/eval/__init__.py | 2 +-
 ddsp/training/gin/models/__init__.py | 2 +-
 ddsp/training/gin/optimization/__init__.py | 2 +-
 ddsp/training/gin/papers/__init__.py | 2 +-
 ddsp/training/gin/papers/iclr2020/__init__.py | 2 +-
 ddsp/training/gin/papers/icml2020/__init__.py | 2 +-
 ddsp/training/heuristics.py | 2 +-
 ddsp/training/heuristics_test.py | 2 +-
 ddsp/training/inference.py | 2 +-
 ddsp/training/metrics.py | 22 +-
 ddsp/training/metrics_test.py | 58 ++--
 ddsp/training/models/__init__.py | 2 +-
 ddsp/training/models/autoencoder.py | 2 +-
 ddsp/training/models/autoencoder_test.py | 2 +-
 ddsp/training/models/inverse_synthesis.py | 2 +-
 ddsp/training/models/midi_autoencoder.py | 2 +-
 ddsp/training/models/model.py | 2 +-
 ddsp/training/nn.py | 2 +-
 ddsp/training/nn_test.py | 2 +-
 ddsp/training/plotting.py | 2 +-
 ddsp/training/postprocessing.py | 2 +-
 ddsp/training/preprocessing.py | 2 +-
 ddsp/training/preprocessing_test.py | 2 +-
 ddsp/training/summaries.py | 2 +-
 ddsp/training/train_util.py | 6 +-
 ddsp/training/trainers.py | 2 +-
 ddsp/version.py | 4 +-
 setup.py | 2 +-
 update_gin_config.py | 2 +-
 70 files changed, 609 insertions(+), 427 deletions(-)
 create mode 100644 ddsp/test_util.py

diff --git a/ddsp/__init__.py b/ddsp/__init__.py
index d8085855..d019836e 100644
--- a/ddsp/__init__.py
+++ b/ddsp/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The DDSP Authors.
+# Copyright 2022 The DDSP Authors.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/colab/__init__.py b/ddsp/colab/__init__.py index f962edef..941b8629 100644 --- a/ddsp/colab/__init__.py +++ b/ddsp/colab/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/colab/colab_utils.py b/ddsp/colab/colab_utils.py index 61d0c5dd..6c6527f7 100644 --- a/ddsp/colab/colab_utils.py +++ b/ddsp/colab/colab_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/core.py b/ddsp/core.py index ad92b92b..ec2b710b 100644 --- a/ddsp/core.py +++ b/ddsp/core.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/core_test.py b/ddsp/core_test.py index 815196c6..289253a0 100644 --- a/ddsp/core_test.py +++ b/ddsp/core_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/dags.py b/ddsp/dags.py index 07b970b8..f126b022 100644 --- a/ddsp/dags.py +++ b/ddsp/dags.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/dags_test.py b/ddsp/dags_test.py index 8449d7ed..95c76001 100644 --- a/ddsp/dags_test.py +++ b/ddsp/dags_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/effects.py b/ddsp/effects.py index 126195b4..39a49758 100644 --- a/ddsp/effects.py +++ b/ddsp/effects.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/effects_test.py b/ddsp/effects_test.py index 8a21ce27..66fa3df4 100644 --- a/ddsp/effects_test.py +++ b/ddsp/effects_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/losses.py b/ddsp/losses.py index 231bc294..dfc11df7 100644 --- a/ddsp/losses.py +++ b/ddsp/losses.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/losses_test.py b/ddsp/losses_test.py index acba3534..4a5becf9 100644 --- a/ddsp/losses_test.py +++ b/ddsp/losses_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. 
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/ddsp/processors.py b/ddsp/processors.py
index 6399d0fc..182c5004 100644
--- a/ddsp/processors.py
+++ b/ddsp/processors.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The DDSP Authors.
+# Copyright 2022 The DDSP Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/ddsp/processors_test.py b/ddsp/processors_test.py
index 7d63a21f..224beddb 100644
--- a/ddsp/processors_test.py
+++ b/ddsp/processors_test.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The DDSP Authors.
+# Copyright 2022 The DDSP Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/ddsp/spectral_ops.py b/ddsp/spectral_ops.py
index 200a24df..a35fe43a 100644
--- a/ddsp/spectral_ops.py
+++ b/ddsp/spectral_ops.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The DDSP Authors.
+# Copyright 2022 The DDSP Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@ import tensorflow_probability as tfp

 CREPE_SAMPLE_RATE = 16000
-_CREPE_FRAME_SIZE = 1024
+CREPE_FRAME_SIZE = 1024

 F0_RANGE = 127.0  # MIDI.
 DB_RANGE = core.DB_RANGE  # dB (80.0).
@@ -57,18 +57,11 @@ def stft_np(audio, frame_size=2048, overlap=0.75, pad_end=True):
   is_2d = (len(audio.shape) == 2)

   if pad_end:
-    n_samples_initial = int(audio.shape[-1])
-    n_frames = int(np.ceil(n_samples_initial / hop_size))
-    n_samples_final = (n_frames - 1) * hop_size + frame_size
-    pad = n_samples_final - n_samples_initial
-    padding = ((0, 0), (0, pad)) if is_2d else ((0, pad),)
-    audio = np.pad(audio, padding, 'constant')
+    audio = pad(audio, frame_size, hop_size, 'same', axis=is_2d).numpy()

   def stft_fn(y):
-    return librosa.stft(y=y,
-                        n_fft=int(frame_size),
-                        hop_length=hop_size,
-                        center=False).T
+    return librosa.stft(
+        y=y, n_fft=int(frame_size), hop_length=hop_size, center=False).T

   s = np.stack([stft_fn(a) for a in audio]) if is_2d else stft_fn(audio)
   return s
@@ -143,28 +136,105 @@ def compute_mfcc(audio,
   return mfccs[..., :mfcc_bins]


+def get_framed_lengths(input_length, frame_size, hop_size, padding='center'):
+  """Given a strided framing, such as tf.signal.frame, gives output lengths.
+
+  Args:
+    input_length: Original length along the dimension to be framed.
+    frame_size: Size of frames for striding.
+    hop_size: Striding, space between frames.
+    padding: Type of padding to apply, ['valid', 'same', 'center']. 'valid' is
+      a no-op. 'same' applies padding to the end such that
+      n_frames = n_t / hop_size. 'center' applies padding to both ends such that
+      each frame timestamp is centered and n_frames = n_t / hop_size + 1.
+
+  Returns:
+    n_frames: Number of frames left after striding.
+    padded_length: Length of the padded signal before striding.
+  """
+  # Use numpy since this function isn't used dynamically.
+  def get_n_frames(length):
+    return int(np.floor((length - frame_size) / hop_size)) + 1
+
+  if padding == 'valid':
+    padded_length = input_length
+    n_frames = get_n_frames(input_length)
+
+  elif padding == 'center':
+    padded_length = input_length + frame_size
+    n_frames = get_n_frames(padded_length)
+
+  elif padding == 'same':
+    n_frames = int(np.ceil(input_length / hop_size))
+    padded_length = (n_frames - 1) * hop_size + frame_size
+
+  return n_frames, padded_length
+
+
+def pad(x, frame_size, hop_size, padding='center',
+        axis=1, mode='CONSTANT', constant_values=0):
+  """Pad a tensor for strided framing such as tf.signal.frame.
+
+  Args:
+    x: Tensor to pad, any shape.
+    frame_size: Size of frames for striding.
+    hop_size: Striding, space between frames.
+    padding: Type of padding to apply, ['valid', 'same', 'center']. 'valid' is
+      a no-op. 'same' applies padding to the end such that
+      n_frames = n_t / hop_size. 'center' applies padding to both ends such that
+      each frame timestamp is centered and n_frames = n_t / hop_size + 1.
+    axis: Axis along which to pad `x`.
+    mode: Padding mode for tf.pad(). One of "CONSTANT", "REFLECT", or
+      "SYMMETRIC" (case-insensitive).
+    constant_values: Passthrough kwarg for tf.pad().
+
+  Returns:
+    A padded version of `x` along axis. Output sizes can be computed separately
+    with get_framed_lengths().
+  """
+  x = tf_float32(x)
+
+  if padding == 'valid':
+    return x
+
+  if hop_size > frame_size:
+    raise ValueError(f'During padding, frame_size ({frame_size})'
+                     f' must be greater than or equal to hop_size ({hop_size}).')
+
+  if len(x.shape) <= 1:
+    axis = 0
+
+  n_t = x.shape[axis]
+  _, n_t_padded = get_framed_lengths(n_t, frame_size, hop_size, padding)
+  pads = [[0, 0] for _ in range(len(x.shape))]
+
+  if padding == 'same':
+    pad_amount = int(n_t_padded - n_t)
+    pads[axis] = [0, pad_amount]
+
+  elif padding == 'center':
+    pad_amount = int(frame_size // 2)  # Symmetric even padding like librosa.
+    pads[axis] = [pad_amount, pad_amount]
+
+  else:
+    raise ValueError('`padding` must be one of [\'center\', \'same\', '
+                     f'\'valid\'], received ({padding}).')
+
+  return tf.pad(x, pads, mode=mode, constant_values=constant_values)
+
+
 def compute_rms_energy(audio,
                        sample_rate=16000,
                        frame_rate=250,
                        frame_size=512,
-                       pad_end=True):
+                       padding='center'):
   """Compute root mean squared energy of audio."""
-  n_samples = audio.shape[0] if len(audio.shape) == 1 else audio.shape[1]
-  n_secs = n_samples / float(sample_rate)  # `n_secs` can have milliseconds
-  expected_len = int(n_secs * frame_rate)
-
   audio = tf_float32(audio)
-
   hop_size = sample_rate // frame_rate
-  audio_frames = tf.signal.frame(audio, frame_size, hop_size, pad_end=pad_end)
+  audio = pad(audio, frame_size, hop_size, padding=padding)
+  audio_frames = tf.signal.frame(audio, frame_size, hop_size, pad_end=False)
   rms_energy = tf.reduce_mean(audio_frames**2.0, axis=-1)**0.5
-  if pad_end:
-    n_samples = audio.shape[0] if len(audio.shape) == 1 else audio.shape[1]
-    n_secs = n_samples / float(sample_rate)  # `n_secs` can have milliseconds
-    expected_len = int(n_secs * frame_rate)
-    return pad_or_trim_to_expected_length(rms_energy, expected_len, use_tf=True)
-  else:
-    return rms_energy
+  return rms_energy


 def compute_power(audio,
@@ -173,10 +243,10 @@ def compute_power(audio,
                   frame_size=512,
                   ref_db=0.0,
                   range_db=DB_RANGE,
-                  pad_end=True):
+                  padding='center'):
   """Compute power of audio in dB."""
   rms_energy = compute_rms_energy(
-      audio, sample_rate, frame_rate, frame_size, pad_end=pad_end)
+      audio, sample_rate, frame_rate, frame_size, padding=padding)
   power_db = core.amplitude_to_db(
       rms_energy, ref_db=ref_db, range_db=range_db, use_tf=True)
   return power_db
@@ -190,13 +260,13 @@ def compute_loudness(audio,
                      range_db=DB_RANGE,
                      ref_db=0.0,
                      use_tf=True,
-                     pad_end=True):
+                     padding='center'):
   """Perceptual loudness (weighted power) in dB.

   Function is differentiable if use_tf=True.

   Args:
     audio: Numpy ndarray or tensor. Shape [batch_size, audio_length] or
-      [batch_size,].
+      [audio_length,].
     sample_rate: Audio sample rate in Hz.
     frame_rate: Rate of loudness frames in Hz.
     n_fft: Fft window size.
@@ -208,33 +278,29 @@ def compute_loudness(audio,
       n_fft=2048. With v2.0.0 it was set to 0.0 to be more consistent with
       power calculations that have a natural scale for 0 dB being amplitude=1.0.
     use_tf: Make function differentiable by using tensorflow.
-    pad_end: Add zero padding at end of audio (like `same` convolution).
+    padding: 'same', 'valid', or 'center'.

   Returns:
     Loudness in decibels. Shape [batch_size, n_frames] or [n_frames,].
   """
-  if sample_rate % frame_rate != 0:
-    raise ValueError(
-        'frame_rate: {} must evenly divide sample_rate: {}.'
-        'For default frame_rate: 250Hz, suggested sample_rate: 16kHz or 48kHz'
-        .format(frame_rate, sample_rate))
-
   # Pick tensorflow or numpy.
   lib = tf if use_tf else np
   reduce_mean = tf.reduce_mean if use_tf else np.mean
   stft_fn = stft if use_tf else stft_np

   # Make inputs tensors for tensorflow.
-  audio = tf_float32(audio) if use_tf else audio
+  frame_size = n_fft
+  hop_size = sample_rate // frame_rate
+  audio = pad(audio, frame_size, hop_size, padding=padding)
+  audio = audio if use_tf else np.array(audio)

   # Temporarily add a batch dimension for single examples.
   is_1d = (len(audio.shape) == 1)
   audio = audio[lib.newaxis, :] if is_1d else audio

   # Take STFT.
- hop_size = sample_rate // frame_rate - overlap = 1 - hop_size / n_fft - s = stft_fn(audio, frame_size=n_fft, overlap=overlap, pad_end=pad_end) + overlap = 1 - hop_size / frame_size + s = stft_fn(audio, frame_size=frame_size, overlap=overlap, pad_end=False) # Compute power. amplitude = lib.abs(s) @@ -258,36 +324,30 @@ def compute_loudness(audio, # Remove temporary batch dimension. loudness = loudness[0] if is_1d else loudness - # Compute expected length of loudness vector. - expected_secs = audio.shape[-1] / float(sample_rate) - expected_len = int(expected_secs * frame_rate) - - # Pad with `-range_db` noise floor or trim vector. - loudness = pad_or_trim_to_expected_length( - loudness, expected_len, -range_db, use_tf=use_tf) - return loudness @gin.register -def compute_f0(audio, sample_rate, frame_rate, viterbi=True): +def compute_f0(audio, frame_rate, viterbi=True, padding='center'): """Fundamental frequency (f0) estimate using CREPE. This function is non-differentiable and takes input as a numpy array. Args: - audio: Numpy ndarray of single audio example. Shape [audio_length,]. - sample_rate: Sample rate in Hz. + audio: Numpy ndarray of single audio (16kHz) example. Shape [audio_length,]. frame_rate: Rate of f0 frames in Hz. viterbi: Use Viterbi decoding to estimate f0. + padding: Apply zero-padding for centered frames. + 'same', 'valid', or 'center'. Returns: f0_hz: Fundamental frequency in Hz. Shape [n_frames,]. f0_confidence: Confidence in Hz estimate (scaled [0, 1]). Shape [n_frames,]. """ - - n_secs = len(audio) / float(sample_rate) # `n_secs` can have milliseconds + sample_rate = CREPE_SAMPLE_RATE crepe_step_size = 1000 / frame_rate # milliseconds - expected_len = int(n_secs * frame_rate) + hop_size = sample_rate // frame_rate + + audio = pad(audio, CREPE_FRAME_SIZE, hop_size, padding) audio = np.asarray(audio) # Compute f0 with crepe. @@ -299,14 +359,11 @@ def compute_f0(audio, sample_rate, frame_rate, viterbi=True): center=False, verbose=0) - # Postprocessing on f0_hz - f0_hz = pad_or_trim_to_expected_length(f0_hz, expected_len, 0) # pad with 0 + # Postprocessing. f0_hz = f0_hz.astype(np.float32) - - # Postprocessing on f0_confidence - f0_confidence = pad_or_trim_to_expected_length(f0_confidence, expected_len, 1) - f0_confidence = np.nan_to_num(f0_confidence) # Set nans to 0 in confidence f0_confidence = f0_confidence.astype(np.float32) + f0_confidence = np.nan_to_num(f0_confidence) # Set nans to 0 in confidence + return f0_hz, f0_confidence @@ -446,10 +503,12 @@ def normalize_frames(self, frames): frames /= std[:, None] return frames - def predict_f0_and_confidence(self, audio, viterbi=False): + def predict_f0_and_confidence(self, audio, viterbi=False, padding='center'): audio = audio[None, :] if len(audio.shape) == 1 else audio batch_size = audio.shape[0] + audio = pad(audio, self.frame_size, self.hop_size, padding=padding) + frames = self.batch_frames(audio) frames = self.normalize_frames(frames) acts = self.core_model(frames, training=False) diff --git a/ddsp/spectral_ops_test.py b/ddsp/spectral_ops_test.py index 621277b3..9e7e20fd 100644 --- a/ddsp/spectral_ops_test.py +++ b/ddsp/spectral_ops_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
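To make the three padding modes concrete, the arithmetic in get_framed_lengths() can be mirrored in a few lines of standalone NumPy (an illustrative sketch, not part of the patch; the 16 kHz / 250 Hz numbers are assumptions for the example):

import numpy as np

def framed_lengths(input_length, frame_size, hop_size, padding='center'):
  # Mirrors spectral_ops.get_framed_lengths() for quick shape checks.
  def n_frames_for(length):
    return int(np.floor((length - frame_size) / hop_size)) + 1
  if padding == 'valid':
    return n_frames_for(input_length), input_length
  elif padding == 'center':
    padded = input_length + frame_size  # frame_size // 2 on each side.
    return n_frames_for(padded), padded
  elif padding == 'same':
    n_frames = int(np.ceil(input_length / hop_size))
    return n_frames, (n_frames - 1) * hop_size + frame_size

# 1 second of 16 kHz audio at a 250 Hz frame rate -> hop_size = 64 samples.
for mode in ['valid', 'same', 'center']:
  print(mode, framed_lengths(16000, 1024, 64, mode))
# valid:  (235, 16000) -- frames only where a full window fits.
# same:   (250, 16960) -- n_frames = n_t / hop_size.
# center: (251, 17024) -- one extra frame, timestamps centered.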
@@ -17,25 +17,11 @@ from absl.testing import parameterized from ddsp import spectral_ops +from ddsp.test_util import gen_np_sinusoid import numpy as np import tensorflow.compat.v2 as tf -def gen_np_sinusoid(frequency, amp, sample_rate, audio_len_sec): - x = np.linspace(0, audio_len_sec, int(audio_len_sec * sample_rate)) - audio_sin = amp * (np.sin(2 * np.pi * frequency * x)) - return audio_sin - - -def gen_np_batched_sinusoids(frequency, amp, sample_rate, audio_len_sec, - batch_size): - batch_sinusoids = [ - gen_np_sinusoid(frequency, amp, sample_rate, audio_len_sec) - for _ in range(batch_size) - ] - return np.array(batch_sinusoids) - - class STFTTest(tf.test.TestCase): def test_tf_and_np_are_consistent(self): @@ -116,146 +102,193 @@ def setUp(self): self.amp = 0.75 self.frequency = 440.0 self.frame_rate = 250 + self.frame_size = 512 + + def expected_f0_length(self, audio, padding): + n_t = audio.shape[-1] + frame_size = spectral_ops.CREPE_FRAME_SIZE + hop_size = int(16000 // self.frame_rate) + expected_len, _ = spectral_ops.get_framed_lengths( + n_t, frame_size, hop_size, padding) + return expected_len + + def expected_db_length(self, audio, sr, padding): + n_t = audio.shape[-1] + hop_size = int(sr // self.frame_rate) + expected_len, _ = spectral_ops.get_framed_lengths( + n_t, self.frame_size, hop_size, padding) + return expected_len @parameterized.named_parameters( - ('16k_.21secs', 16000, .21), - ('24k_.21secs', 24000, .21), - ('44.1k_.21secs', 44100, .21), - ('48k_.21secs', 48000, .21), - ('16k_.4secs', 16000, .4), - ('24k_.4secs', 24000, .4), - ('44.1k_.4secs', 44100, .4), - ('48k_.4secs', 48000, .4), + ('same_.21secs', 'same', .21), + ('same_.4secs', 'same', .4), + ('center_.21secs', 'center', .21), + ('center_.4secs', 'center', .4), + ('valid_.21secs', 'valid', .21), + ('valid_.4secs', 'valid', .4), ) - def test_compute_f0_at_sample_rate(self, sample_rate, audio_len_sec): - audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate, - audio_len_sec) - f0_hz, f0_confidence = spectral_ops.compute_f0(audio_sin, sample_rate, - self.frame_rate) - expected_f0_hz_and_f0_conf_len = int(self.frame_rate * audio_len_sec) - self.assertLen(f0_hz, expected_f0_hz_and_f0_conf_len) - self.assertLen(f0_confidence, expected_f0_hz_and_f0_conf_len) + def test_compute_f0(self, padding, audio_len_sec): + """Ensure that compute_f0 (crepe) has expected output shape.""" + sr = 16000 + audio_sin = gen_np_sinusoid(self.frequency, self.amp, sr, audio_len_sec) + expected_len = self.expected_f0_length(audio_sin, padding) + f0_hz, f0_confidence = spectral_ops.compute_f0( + audio_sin, self.frame_rate, viterbi=True, padding=padding) + self.assertLen(f0_hz, expected_len) + self.assertLen(f0_confidence, expected_len) self.assertTrue(np.all(np.isfinite(f0_hz))) self.assertTrue(np.all(np.isfinite(f0_confidence))) - @parameterized.named_parameters( - ('16k_.21secs', 16000, .21), - ('24k_.21secs', 24000, .21), - ('48k_.21secs', 48000, .21), - ('16k_.4secs', 16000, .4), - ('24k_.4secs', 24000, .4), - ('48k_.4secs', 48000, .4), - ) - def test_compute_loudness_at_sample_rate_1d(self, sample_rate, audio_len_sec): + def test_batch_compute_db(self): + """Ensure that compute_(loudness/power) can work on a batch.""" + batch_size = 2 + sample_rate = 16000 + audio_len_sec = 0.21 + padding = 'same' audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate, audio_len_sec) - expected_loudness_len = int(self.frame_rate * audio_len_sec) - - for use_tf in [False, True]: - loudness = spectral_ops.compute_loudness( - 
audio_sin, sample_rate, self.frame_rate, use_tf=use_tf) - self.assertLen(loudness, expected_loudness_len) - self.assertTrue(np.all(np.isfinite(loudness))) - - @parameterized.named_parameters( - ('16k_.21secs', 16000, .21), - ('24k_.21secs', 24000, .21), - ('48k_.21secs', 48000, .21), - ('16k_.4secs', 16000, .4), - ('24k_.4secs', 24000, .4), - ('48k_.4secs', 48000, .4), - ) - def test_compute_loudness_at_sample_rate_2d(self, sample_rate, audio_len_sec): - batch_size = 8 - audio_sin_batch = gen_np_batched_sinusoids(self.frequency, self.amp, - sample_rate, audio_len_sec, - batch_size) - expected_loudness_len = int(self.frame_rate * audio_len_sec) - - for use_tf in [False, True]: - loudness_batch = spectral_ops.compute_loudness( - audio_sin_batch, sample_rate, self.frame_rate, use_tf=use_tf) - - self.assertEqual(loudness_batch.shape[0], batch_size) - self.assertEqual(loudness_batch.shape[1], expected_loudness_len) - self.assertTrue(np.all(np.isfinite(loudness_batch))) - - # Check if batched loudness is equal to equivalent single computations - audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate, - audio_len_sec) - loudness_target = spectral_ops.compute_loudness( - audio_sin, sample_rate, self.frame_rate, use_tf=use_tf) - loudness_batch_target = np.tile(loudness_target, (batch_size, 1)) - # Allow tolerance within 1dB - self.assertAllClose(loudness_batch, loudness_batch_target, atol=1, rtol=1) + expected_len = self.expected_db_length(audio_sin, sample_rate, padding) + audio_batch = tf.tile(audio_sin[None, :], [batch_size, 1]) + loudness = spectral_ops.compute_loudness( + audio_batch, sample_rate, self.frame_rate, self.frame_size, + padding=padding) + power = spectral_ops.compute_power( + audio_batch, sample_rate, self.frame_rate, self.frame_size, + padding=padding) + self.assertLen(loudness.shape, 2) + self.assertLen(power.shape, 2) + self.assertEqual(batch_size, loudness.shape[0]) + self.assertEqual(batch_size, power.shape[0]) + self.assertEqual(expected_len, loudness.shape[1]) + self.assertEqual(expected_len, power.shape[1]) + + def test_compute_loudness_tf_np(self): + """Ensure that compute_loudness is the same output for np and tf.""" + sample_rate = 16000 + audio_len_sec = 0.21 + audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate, + audio_len_sec) + loudness_tf = spectral_ops.compute_loudness( + audio_sin, sample_rate, self.frame_rate, self.frame_size, use_tf=True) + loudness_np = spectral_ops.compute_loudness( + audio_sin, sample_rate, self.frame_rate, self.frame_size, use_tf=False) + # Allow tolerance within 1dB + self.assertAllClose(loudness_tf.numpy(), loudness_np, atol=1, rtol=1) @parameterized.named_parameters( ('16k_.21secs', 16000, .21), ('24k_.21secs', 24000, .21), - ('48k_.21secs', 48000, .21), + ('44.1k_.21secs', 44100, .21), ('16k_.4secs', 16000, .4), ('24k_.4secs', 24000, .4), - ('48k_.4secs', 48000, .4), + ('44.1k_.4secs', 44100, .4), ) - def test_tf_compute_loudness_at_sample_rate(self, sample_rate, audio_len_sec): + def test_compute_loudness(self, sample_rate, audio_len_sec): + """Ensure that compute_loudness has expected output shape.""" + padding = 'center' audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate, audio_len_sec) - loudness = spectral_ops.compute_loudness(audio_sin, sample_rate, - self.frame_rate) - expected_loudness_len = int(self.frame_rate * audio_len_sec) - self.assertLen(loudness, expected_loudness_len) + expected_len = self.expected_db_length(audio_sin, sample_rate, padding) + loudness = 
spectral_ops.compute_loudness( + audio_sin, sample_rate, self.frame_rate, self.frame_size, + padding=padding) + self.assertLen(loudness, expected_len) self.assertTrue(np.all(np.isfinite(loudness))) @parameterized.named_parameters( - ('44.1k_.21secs', 44100, .21), - ('44.1k_.4secs', 44100, .4), + ('same', 'same'), + ('valid', 'valid'), + ('center', 'center'), ) - def test_compute_loudness_indivisible_rates_raises_error( - self, sample_rate, audio_len_sec): + def test_compute_loudness_padding(self, padding): + """Ensure that compute_loudness works with different paddings.""" + sample_rate = 16000 + audio_len_sec = 0.21 audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate, audio_len_sec) - - for use_tf in [False, True]: - with self.assertRaises(ValueError): - spectral_ops.compute_loudness( - audio_sin, sample_rate, self.frame_rate, use_tf=use_tf) + expected_len = self.expected_db_length(audio_sin, sample_rate, padding) + loudness = spectral_ops.compute_loudness( + audio_sin, sample_rate, self.frame_rate, self.frame_size, + padding=padding) + self.assertLen(loudness, expected_len) + self.assertTrue(np.all(np.isfinite(loudness))) @parameterized.named_parameters( ('16k_.21secs', 16000, .21), ('24k_.21secs', 24000, .21), - ('48k_.21secs', 48000, .21), + ('44.1k_.21secs', 44100, .21), ('16k_.4secs', 16000, .4), ('24k_.4secs', 24000, .4), - ('48k_.4secs', 48000, .4), + ('44.1k_.4secs', 44100, .4), ) def test_compute_rms_energy(self, sample_rate, audio_len_sec): + """Ensure that compute_rms_energy has expected output shape.""" + padding = 'center' audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate, audio_len_sec) - expected_rmse_len = int(self.frame_rate * audio_len_sec) - + expected_len = self.expected_db_length(audio_sin, sample_rate, padding) rms_energy = spectral_ops.compute_rms_energy( - audio_sin, sample_rate, self.frame_rate) - self.assertLen(rms_energy, expected_rmse_len) + audio_sin, sample_rate, self.frame_rate, self.frame_size, + padding=padding) + self.assertLen(rms_energy, expected_len) self.assertTrue(np.all(np.isfinite(rms_energy))) @parameterized.named_parameters( - ('16k_.21secs', 16000, .21), - ('24k_.21secs', 24000, .21), - ('48k_.21secs', 48000, .21), - ('16k_.4secs', 16000, .4), - ('24k_.4secs', 24000, .4), - ('48k_.4secs', 48000, .4), + ('same', 'same'), + ('valid', 'valid'), + ('center', 'center'), ) - def test_compute_power(self, sample_rate, audio_len_sec): + def test_compute_power_padding(self, padding): + """Ensure that compute_power (-> +rms) work with different paddings.""" + sample_rate = 16000 + audio_len_sec = 0.21 audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate, audio_len_sec) - expected_power_len = int(self.frame_rate * audio_len_sec) - + expected_len = self.expected_db_length(audio_sin, sample_rate, padding) power = spectral_ops.compute_power( - audio_sin, sample_rate, self.frame_rate) - self.assertLen(power, expected_power_len) + audio_sin, sample_rate, self.frame_rate, self.frame_size, + padding=padding) + self.assertLen(power, expected_len) self.assertTrue(np.all(np.isfinite(power))) +class PadTest(parameterized.TestCase, tf.test.TestCase): + + def test_pad_end_stft_is_consistent(self): + """Ensure that spectral_ops.pad('same') is same as stft(pad_end=True).""" + frame_size = 200 + hop_size = 180 + audio = tf.random.normal([1, 1000]) + padded_audio = spectral_ops.pad(audio, frame_size, hop_size, 'same') + s_pad_end = tf.signal.stft(audio, frame_size, hop_size, pad_end=True) + s_same = tf.signal.stft(padded_audio, 
frame_size, hop_size, pad_end=False) + self.assertAllClose(np.abs(s_pad_end), np.abs(s_same), rtol=1e-3, atol=1e-3) + + @parameterized.named_parameters( + ('valid_odd', 'valid', 180), + ('same_odd', 'same', 180), + ('center_odd', 'center', 180), + ('valid_even', 'valid', 200), + ('same_even', 'same', 200), + ('center_even', 'center', 200), + ) + def test_padding_shapes_are_correct(self, padding, hop_size): + """Ensure that pad() and get_framed_lengths() have correct shapes.""" + frame_size = 200 + n_t = 1000 + audio = tf.random.normal([1, n_t]) + padded_audio = spectral_ops.pad(audio, frame_size, hop_size, padding) + n_t_pad = padded_audio.shape[1] + + frames = tf.signal.frame(padded_audio, frame_size, hop_size) + n_frames = frames.shape[1] + + exp_n_frames, exp_n_t_pad = spectral_ops.get_framed_lengths( + n_t, frame_size, hop_size, padding) + + self.assertEqual(n_frames, exp_n_frames) + self.assertEqual(n_t_pad, exp_n_t_pad) + + if __name__ == '__main__': tf.test.main() diff --git a/ddsp/synths.py b/ddsp/synths.py index ec528986..97081c0f 100644 --- a/ddsp/synths.py +++ b/ddsp/synths.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/synths_test.py b/ddsp/synths_test.py index a8151bcd..dde0c69a 100644 --- a/ddsp/synths_test.py +++ b/ddsp/synths_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/test_util.py b/ddsp/test_util.py new file mode 100644 index 00000000..b36bd534 --- /dev/null +++ b/ddsp/test_util.py @@ -0,0 +1,34 @@ +# Copyright 2022 The DDSP Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Library of helper functions for testing.""" + +import numpy as np + + +def gen_np_sinusoid(frequency, amp, sample_rate, audio_len_sec): + x = np.linspace(0, audio_len_sec, int(audio_len_sec * sample_rate)) + audio_sin = amp * (np.sin(2 * np.pi * frequency * x)) + return audio_sin + + +def gen_np_batched_sinusoids(frequency, amp, sample_rate, audio_len_sec, + batch_size): + batch_sinusoids = [ + gen_np_sinusoid(frequency, amp, sample_rate, audio_len_sec) + for _ in range(batch_size) + ] + return np.array(batch_sinusoids) + diff --git a/ddsp/training/__init__.py b/ddsp/training/__init__.py index 30713ecb..baafdbf5 100644 --- a/ddsp/training/__init__.py +++ b/ddsp/training/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
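The new ddsp/test_util.py helpers pair naturally with the padded framing ops when checking feature shapes; a minimal usage sketch (values are illustrative, assuming the package layout above):

from ddsp import spectral_ops
from ddsp.test_util import gen_np_sinusoid

sample_rate, frame_rate, n_fft = 16000, 250, 512
audio = gen_np_sinusoid(frequency=440.0, amp=0.75,
                        sample_rate=sample_rate, audio_len_sec=0.5)
loudness = spectral_ops.compute_loudness(
    audio, sample_rate, frame_rate, n_fft, padding='center')

# The frame count should agree with get_framed_lengths().
hop_size = sample_rate // frame_rate
n_frames, _ = spectral_ops.get_framed_lengths(
    len(audio), n_fft, hop_size, padding='center')
assert loudness.shape[-1] == n_frames  # 126 = 0.5 s * 250 Hz + 1 ('center').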
diff --git a/ddsp/training/cloud.py b/ddsp/training/cloud.py index 3b297c6c..b72c67f4 100644 --- a/ddsp/training/cloud.py +++ b/ddsp/training/cloud.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/cloud_test.py b/ddsp/training/cloud_test.py index 11bb0195..157c6fde 100644 --- a/ddsp/training/cloud_test.py +++ b/ddsp/training/cloud_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/data.py b/ddsp/training/data.py index 57fa377e..3bb5c626 100644 --- a/ddsp/training/data.py +++ b/ddsp/training/data.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/data_preparation/__init__.py b/ddsp/training/data_preparation/__init__.py index 34f439a5..f9ad3300 100644 --- a/ddsp/training/data_preparation/__init__.py +++ b/ddsp/training/data_preparation/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/data_preparation/ddsp_generate_synthetic_dataset.py b/ddsp/training/data_preparation/ddsp_generate_synthetic_dataset.py index 822f2139..e56ce5e0 100644 --- a/ddsp/training/data_preparation/ddsp_generate_synthetic_dataset.py +++ b/ddsp/training/data_preparation/ddsp_generate_synthetic_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/data_preparation/ddsp_prepare_tfrecord.py b/ddsp/training/data_preparation/ddsp_prepare_tfrecord.py index f7a4c0f3..e2641e32 100644 --- a/ddsp/training/data_preparation/ddsp_prepare_tfrecord.py +++ b/ddsp/training/data_preparation/ddsp_prepare_tfrecord.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -54,17 +54,28 @@ 'length using a sliding window. If 0, each full piece of audio will be ' 'used as an example.') flags.DEFINE_float( - 'sliding_window_hop_secs', 1, - 'The hop size in seconds to use when splitting audio into constant-length ' - 'examples.') + 'hop_secs', 1, + 'The hop size between example start points (in seconds), when splitting ' + 'audio into constant-length examples.') flags.DEFINE_float( 'eval_split_fraction', 0.0, 'Fraction of the dataset to reserve for eval split. If set to 0, no eval ' 'split is created.' ) flags.DEFINE_float( - 'coarse_chunk_secs', 20.0, - 'Chunk size in seconds used to split the input audio files.') + 'chunk_secs', 20.0, + 'Chunk size in seconds used to split the input audio files. These ' + 'non-overlapping chunks are partitioned into train and eval sets if ' + 'eval_split_fraction > 0. 
This is used to split large audio files into ' + 'manageable chunks for better parallelization and to enable ' + 'non-overlapping train/eval splits.') +flags.DEFINE_boolean( + 'center', False, + 'Add padding to audio such that frame timestamps are centered. Increases ' + 'number of frames by one.') +flags.DEFINE_boolean( + 'viterbi', True, + 'Use viterbi decoding of pitch.') flags.DEFINE_list( 'pipeline_options', '--runner=DirectRunner', 'A comma-separated list of command line arguments to be used as options ' @@ -82,10 +93,12 @@ def run(): num_shards=FLAGS.num_shards, sample_rate=FLAGS.sample_rate, frame_rate=FLAGS.frame_rate, - window_secs=FLAGS.example_secs, - hop_secs=FLAGS.sliding_window_hop_secs, + example_secs=FLAGS.example_secs, + hop_secs=FLAGS.hop_secs, eval_split_fraction=FLAGS.eval_split_fraction, - coarse_chunk_secs=FLAGS.coarse_chunk_secs, + chunk_secs=FLAGS.chunk_secs, + center=FLAGS.center, + viterbi=FLAGS.viterbi, pipeline_options=FLAGS.pipeline_options) diff --git a/ddsp/training/data_preparation/prepare_tfrecord_lib.py b/ddsp/training/data_preparation/prepare_tfrecord_lib.py index acf50a5c..a71fea24 100644 --- a/ddsp/training/data_preparation/prepare_tfrecord_lib.py +++ b/ddsp/training/data_preparation/prepare_tfrecord_lib.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ -def _load_audio_as_array(audio_path: str, sample_rate: int) -> np.array: +def _load_audio_as_array(audio_path, sample_rate): """Load audio file at specified sample rate and return an array. When `sample_rate` > original SR of audio file, Pydub may miss samples when @@ -61,23 +61,24 @@ def _load_audio(audio_path, sample_rate): return {'audio': audio} -def add_loudness(ex, sample_rate, frame_rate, n_fft=2048): - """Add loudness in dB.""" - beam.metrics.Metrics.counter('prepare-tfrecord', 'compute-loudness').inc() +def _chunk_audio(ex, sample_rate, chunk_secs): + """Pad audio and split into chunks.""" + beam.metrics.Metrics.counter('prepare-tfrecord', 'load-audio').inc() audio = ex['audio'] - mean_loudness_db = spectral_ops.compute_loudness(audio, sample_rate, - frame_rate, n_fft, - use_tf=False) - ex = dict(ex) - ex['loudness_db'] = mean_loudness_db.astype(np.float32) - return ex + chunk_size = int(chunk_secs * sample_rate) + chunks = tf.signal.frame(audio, chunk_size, chunk_size, pad_end=True) + n_chunks = chunks.shape[0] + for i in range(n_chunks): + yield {'audio': chunks[i]} -def _add_f0_estimate(ex, sample_rate, frame_rate): +def _add_f0_estimate(ex, frame_rate, center, viterbi): """Add fundamental frequency (f0) estimate using CREPE.""" beam.metrics.Metrics.counter('prepare-tfrecord', 'estimate-f0').inc() audio = ex['audio'] - f0_hz, f0_confidence = spectral_ops.compute_f0(audio, sample_rate, frame_rate) + padding = 'center' if center else 'same' + f0_hz, f0_confidence = spectral_ops.compute_f0( + audio, frame_rate, viterbi=viterbi, padding=padding) ex = dict(ex) ex.update({ 'f0_hz': f0_hz.astype(np.float32), @@ -86,24 +87,38 @@ def _add_f0_estimate(ex, sample_rate, frame_rate): return ex -def split_example(ex, sample_rate, frame_rate, window_secs, hop_secs): +def _add_loudness(ex, sample_rate, frame_rate, n_fft, center): + """Add loudness in dB.""" + beam.metrics.Metrics.counter('prepare-tfrecord', 'compute-loudness').inc() + audio = ex['audio'] + padding = 'center' if center else 'same' + loudness_db = 
spectral_ops.compute_loudness( + audio, sample_rate, frame_rate, n_fft, padding=padding) + ex = dict(ex) + ex['loudness_db'] = loudness_db.numpy().astype(np.float32) + return ex + + +def _split_example(ex, sample_rate, frame_rate, example_secs, hop_secs, center): """Splits example into windows, padding final window if needed.""" - def get_windows(sequence, rate): - window_size = int(window_secs * rate) + def get_windows(sequence, rate, center): + window_size = int(example_secs * rate) + if center: + window_size += 1 hop_size = int(hop_secs * rate) - n_windows = int(np.ceil((len(sequence) - window_size) / hop_size)) + 1 - n_samples_padded = (n_windows - 1) * hop_size + window_size - n_padding = n_samples_padded - len(sequence) - sequence = np.pad(sequence, (0, n_padding), mode='constant') - for window_end in range(window_size, len(sequence) + 1, hop_size): - yield sequence[window_end - window_size:window_end] + # Don't pad the end. + n_windows = int(np.floor((len(sequence) - window_size) / hop_size)) + 1 + for i in range(n_windows): + start = i * hop_size + end = start + window_size + yield sequence[start:end] for audio, loudness_db, f0_hz, f0_confidence in zip( - get_windows(ex['audio'], sample_rate), - get_windows(ex['loudness_db'], frame_rate), - get_windows(ex['f0_hz'], frame_rate), - get_windows(ex['f0_confidence'], frame_rate)): + get_windows(ex['audio'], sample_rate, center=False), + get_windows(ex['loudness_db'], frame_rate, center), + get_windows(ex['f0_hz'], frame_rate, center), + get_windows(ex['f0_confidence'], frame_rate, center)): beam.metrics.Metrics.counter('prepare-tfrecord', 'split-example').inc() yield { 'audio': audio, @@ -113,7 +128,7 @@ def get_windows(sequence, rate): } -def float_dict_to_tfexample(float_dict): +def _float_dict_to_tfexample(float_dict): """Convert dictionary of float arrays to tf.train.Example proto.""" return tf.train.Example( features=tf.train.Features( @@ -123,12 +138,12 @@ def float_dict_to_tfexample(float_dict): })) -def add_key(example): +def _add_key(example): """Add a key to this example by taking the hash of the values.""" return hash(example['audio'].tobytes()), example -def eval_split_partition_fn(example, num_partitions, eval_fraction, all_ids): +def _eval_split_partition_fn(example, num_partitions, eval_fraction, all_ids): """Partition function to split into train/eval based on the hash ids.""" del num_partitions example_id = example[0] @@ -144,10 +159,12 @@ def prepare_tfrecord(input_audio_paths, num_shards=None, sample_rate=16000, frame_rate=250, - window_secs=4, + example_secs=4, hop_secs=1, eval_split_fraction=0.0, - coarse_chunk_secs=20.0, + chunk_secs=20.0, + center=False, + viterbi=True, pipeline_options=''): """Prepares a TFRecord for use in training, evaluation, and prediction. @@ -161,18 +178,43 @@ def prepare_tfrecord(input_audio_paths, sample_rate: The sample rate to use for the audio. frame_rate: The frame rate to use for f0 and loudness features. If set to None, these features will not be computed. - window_secs: The size of the sliding window (in seconds) to use to split the - audio and features. If 0, they will not be split. + example_secs: The size of the sliding window (in seconds) to use to split + the audio and features. If 0, they will not be split. hop_secs: The number of seconds to hop when computing the sliding windows. eval_split_fraction: Fraction of the dataset to reserve for eval split. If set to 0, no eval split is created. - coarse_chunk_secs: Chunk size in seconds used to split the input audio - files. 
This is used to split large audio files into manageable chunks - for better parallelization and to enable non-overlapping train/eval - splits. + chunk_secs: Chunk size in seconds used to split the input audio + files. This is used to split large audio files into manageable chunks for + better parallelization and to enable non-overlapping train/eval splits. + center: Provide zero-padding to audio so that frame timestamps will be + centered. + viterbi: Use viterbi decoding of pitch. pipeline_options: An iterable of command line arguments to be used as options for the Beam Pipeline. """ + def postprocess_pipeline(examples, output_path, stage_name=''): + """After chunking, features, and train-eval split, create TFExamples.""" + if stage_name: + stage_name = f'_{stage_name}' + + if example_secs: + examples |= f'split_examples{stage_name}' >> beam.FlatMap( + _split_example, + sample_rate=sample_rate, + frame_rate=frame_rate, + example_secs=example_secs, + hop_secs=hop_secs, + center=center) + _ = ( + examples + | f'reshuffle{stage_name}' >> beam.Reshuffle() + | f'make_tfexample{stage_name}' >> beam.Map(_float_dict_to_tfexample) + | f'write{stage_name}' >> beam.io.tfrecordio.WriteToTFRecord( + output_path, + num_shards=num_shards, + coder=beam.coders.ProtoCoder(tf.train.Example))) + + # Start the pipeline. pipeline_options = beam.options.pipeline_options.PipelineOptions( pipeline_options) with beam.Pipeline(options=pipeline_options) as pipeline: @@ -181,36 +223,32 @@ def prepare_tfrecord(input_audio_paths, | beam.Create(input_audio_paths) | beam.Map(_load_audio, sample_rate)) + # Split into chunks for train/eval split and better parallelism. + if chunk_secs: + examples |= beam.FlatMap( + _chunk_audio, + sample_rate=sample_rate, + chunk_secs=chunk_secs) + + # Add features. if frame_rate: examples = ( examples - | beam.Map(_add_f0_estimate, sample_rate, frame_rate) - | beam.Map(add_loudness, sample_rate, frame_rate)) - - if coarse_chunk_secs: - examples |= beam.FlatMap(split_example, sample_rate, frame_rate, - coarse_chunk_secs, coarse_chunk_secs) - - def postprocess_pipeline(examples, output_path, stage_name=''): - if stage_name: - stage_name = f'_{stage_name}' - - if window_secs: - examples |= f'create_batches{stage_name}' >> beam.FlatMap( - split_example, sample_rate, frame_rate, window_secs, hop_secs) - _ = ( - examples - | f'reshuffle{stage_name}' >> beam.Reshuffle() - | f'make_tfexample{stage_name}' >> beam.Map(float_dict_to_tfexample) - | f'write{stage_name}' >> beam.io.tfrecordio.WriteToTFRecord( - output_path, - num_shards=num_shards, - coder=beam.coders.ProtoCoder(tf.train.Example))) - + | beam.Map(_add_f0_estimate, + frame_rate=frame_rate, + center=center, + viterbi=viterbi) + | beam.Map(_add_loudness, + sample_rate=sample_rate, + frame_rate=frame_rate, + n_fft=512, + center=center)) + + # Create train/eval split. if eval_split_fraction: - examples |= beam.Map(add_key) + examples |= beam.Map(_add_key) keys = examples | beam.Keys() - splits = examples | beam.Partition(eval_split_partition_fn, 2, + splits = examples | beam.Partition(_eval_split_partition_fn, 2, eval_split_fraction, beam.pvalue.AsList(keys)) diff --git a/ddsp/training/data_preparation/prepare_tfrecord_lib_test.py b/ddsp/training/data_preparation/prepare_tfrecord_lib_test.py index bbbb1fbf..7fac6270 100644 --- a/ddsp/training/data_preparation/prepare_tfrecord_lib_test.py +++ b/ddsp/training/data_preparation/prepare_tfrecord_lib_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,13 +21,14 @@ from absl import flags from absl.testing import absltest from absl.testing import parameterized +from ddsp import spectral_ops from ddsp.training.data_preparation import prepare_tfrecord_lib import numpy as np import scipy.io.wavfile import tensorflow.compat.v2 as tf -class ProcessTaskBeamTest(parameterized.TestCase): +class PrepareTFRecordBeamTest(parameterized.TestCase): def get_tempdir(self): try: @@ -43,7 +44,7 @@ def setUp(self): # Write test wav file. self.wav_sr = 22050 - self.wav_secs = 2.1 + self.wav_secs = 0.5 self.wav_path = os.path.join(self.test_dir, 'test.wav') scipy.io.wavfile.write( self.wav_path, @@ -73,110 +74,124 @@ def validate_outputs(self, expected_num_examples, expected_feature_lengths): raise AssertionError('%s feature: %s' % (e, feat)) self.assertFalse(any(np.isinf(arr))) - @parameterized.named_parameters(('16k', 16000), ('24k', 24000), - ('48k', 48000)) - def test_prepare_tfrecord(self, sample_rate): - frame_rate = 250 - window_secs = 1 - hop_secs = 0.5 - prepare_tfrecord_lib.prepare_tfrecord( - [self.wav_path], - os.path.join(self.test_dir, 'output.tfrecord'), - num_shards=2, - sample_rate=sample_rate, - frame_rate=frame_rate, - window_secs=window_secs, - hop_secs=hop_secs, - coarse_chunk_secs=None) - - expected_f0_and_loudness_length = int(window_secs * frame_rate) - self.validate_outputs( - 4, { - 'audio': window_secs * sample_rate, - 'f0_hz': expected_f0_and_loudness_length, - 'f0_confidence': expected_f0_and_loudness_length, - 'loudness_db': expected_f0_and_loudness_length, - }) + def get_expected_length(self, input_length, frame_rate, center=False): + sample_rate = 16000 # Features at CREPE_SAMPLE_RATE. + frame_size = 1024 # Unused for this calculation. + hop_size = sample_rate // frame_rate + padding = 'center' if center else 'same' + n_frames, _ = spectral_ops.get_framed_lengths( + input_length, frame_size, hop_size, padding) + return n_frames + + @staticmethod + def get_n_per_chunk(chunk_length, example_secs, hop_secs): + """Convenience function to calculate number examples from striding.""" + n = (chunk_length - example_secs) / hop_secs + # Deal with limited float precision that causes (.3 / .1) = 2.9999.... + return int(np.floor(np.round(n, decimals=3))) + 1 - @parameterized.named_parameters(('16k', 16000), ('24k', 24000), - ('48k', 48000)) - def test_prepare_tfrecord_no_split(self, sample_rate): + @parameterized.named_parameters( + ('chunk_and_split', 0.3, 0.2), + ('no_chunk', None, 0.2), + ('no_split', 0.3, None), + ('no_chunk_no_split', None, None), + ) + def test_prepare_tfrecord(self, chunk_secs, example_secs): + sample_rate = 16000 frame_rate = 250 + hop_secs = 0.1 + + # Calculate expected batch size. + if example_secs: + length = chunk_secs if chunk_secs else self.wav_secs + n_per_chunk = self.get_n_per_chunk(length, example_secs, hop_secs) + else: + n_per_chunk = 1 + + n_chunks = int(np.ceil(self.wav_secs / chunk_secs)) if chunk_secs else 1 + expected_n_batch = n_per_chunk * n_chunks + print('n_chunks, n_per_chunk, chunk_secs, example_secs', + n_chunks, n_per_chunk, chunk_secs, example_secs) + + # Calculate expected lengths. + if example_secs: + length = example_secs + elif chunk_secs: + length = chunk_secs + else: + length = self.wav_secs + + expected_n_t = int(length * sample_rate) + expected_n_frames = self.get_expected_length(expected_n_t, frame_rate) + + # Make the actual records. 
prepare_tfrecord_lib.prepare_tfrecord( [self.wav_path], os.path.join(self.test_dir, 'output.tfrecord'), num_shards=2, sample_rate=sample_rate, frame_rate=frame_rate, - window_secs=None, - coarse_chunk_secs=None) + example_secs=example_secs, + hop_secs=hop_secs, + chunk_secs=chunk_secs, + center=False) - expected_f0_and_loudness_length = int(self.wav_secs * frame_rate) self.validate_outputs( - 1, { - 'audio': int(self.wav_secs * sample_rate), - 'f0_hz': expected_f0_and_loudness_length, - 'f0_confidence': expected_f0_and_loudness_length, - 'loudness_db': expected_f0_and_loudness_length, + expected_n_batch, + { + 'audio': expected_n_t, + 'f0_hz': expected_n_frames, + 'f0_confidence': expected_n_frames, + 'loudness_db': expected_n_frames, }) - @parameterized.named_parameters(('16k', 16000), ('24k', 24000), - ('48k', 48000)) - def test_prepare_tfrecord_chunk(self, sample_rate): + @parameterized.named_parameters(('no_center', False), ('center', True)) + def test_centering(self, center): frame_rate = 250 - chunk_secs = 1.5 + sample_rate = 16000 + example_secs = 0.3 + hop_secs = 0.1 + n_batch = self.get_n_per_chunk(self.wav_secs, example_secs, hop_secs) prepare_tfrecord_lib.prepare_tfrecord( [self.wav_path], os.path.join(self.test_dir, 'output.tfrecord'), num_shards=2, sample_rate=sample_rate, frame_rate=frame_rate, - window_secs=None, - coarse_chunk_secs=chunk_secs) - - expected_f0_and_loudness_length = int(chunk_secs * frame_rate) + example_secs=example_secs, + hop_secs=hop_secs, + center=center, + chunk_secs=None) + n_t = int(example_secs * sample_rate) + n_frames = self.get_expected_length(n_t, frame_rate, center) + n_expected_frames = 76 if center else 75 # (250 * 0.3) [+1]. + self.assertEqual(n_frames, n_expected_frames) self.validate_outputs( - 2, { - 'audio': int(chunk_secs * sample_rate), - 'f0_hz': expected_f0_and_loudness_length, - 'f0_confidence': expected_f0_and_loudness_length, - 'loudness_db': expected_f0_and_loudness_length, + n_batch, { + 'audio': n_t, + 'f0_hz': n_frames, + 'f0_confidence': n_frames, + 'loudness_db': n_frames, }) @parameterized.named_parameters(('16k', 16000), ('24k', 24000), ('48k', 48000)) - def test_prepare_tfrecord_no_f0_and_loudness(self, sample_rate): + def test_audio_only(self, sample_rate): prepare_tfrecord_lib.prepare_tfrecord( [self.wav_path], os.path.join(self.test_dir, 'output.tfrecord'), num_shards=2, sample_rate=sample_rate, frame_rate=None, - window_secs=None, - coarse_chunk_secs=None) + example_secs=None, + chunk_secs=None) self.validate_outputs( 1, { 'audio': int(self.wav_secs * sample_rate), }) - @parameterized.named_parameters( - ('44.1k', 44100),) - def test_prepare_tfrecord_at_indivisible_sample_rate_throws_error( - self, sample_rate): - frame_rate = 250 - with self.assertRaises(ValueError): - prepare_tfrecord_lib.prepare_tfrecord([self.wav_path], - os.path.join( - self.test_dir, - 'output.tfrecord'), - num_shards=2, - sample_rate=sample_rate, - frame_rate=frame_rate, - window_secs=None, - coarse_chunk_secs=None) - if __name__ == '__main__': absltest.main() diff --git a/ddsp/training/data_preparation/synthetic_data.py b/ddsp/training/data_preparation/synthetic_data.py index f4c88dc7..38231245 100644 --- a/ddsp/training/data_preparation/synthetic_data.py +++ b/ddsp/training/data_preparation/synthetic_data.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
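For reference, the example counts these tests assert can be derived with a small helper that mirrors the chunk-then-window striding (a hypothetical helper for illustration, not part of the library):

import numpy as np

def expected_n_examples(total_secs, chunk_secs, example_secs, hop_secs):
  # Chunks are padded at the end (tf.signal.frame with pad_end=True);
  # example windows within a chunk are not.
  n_chunks = int(np.ceil(total_secs / chunk_secs)) if chunk_secs else 1
  if not example_secs:
    return n_chunks
  length = chunk_secs if chunk_secs else total_secs
  # Round before flooring to dodge float artifacts like .3 / .1 = 2.999...
  n = (length - example_secs) / hop_secs
  return n_chunks * (int(np.floor(np.round(n, decimals=3))) + 1)

# The 'chunk_and_split' case above: a 0.5 s file, 0.3 s chunks,
# 0.2 s examples, 0.1 s hops -> 2 chunks x 2 windows = 4 examples.
print(expected_n_examples(0.5, 0.3, 0.2, 0.1))  # 4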
diff --git a/ddsp/training/ddsp_export.py b/ddsp/training/ddsp_export.py index 20f23200..7df3a38b 100644 --- a/ddsp/training/ddsp_export.py +++ b/ddsp/training/ddsp_export.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/ddsp_run.py b/ddsp/training/ddsp_run.py index f7f710df..34ca4c91 100644 --- a/ddsp/training/ddsp_run.py +++ b/ddsp/training/ddsp_run.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -72,7 +72,7 @@ from ddsp.training import trainers import gin import pkg_resources -import tensorflow.compat.v2 as tf +import tensorflow as tf gfile = tf.io.gfile FLAGS = flags.FLAGS diff --git a/ddsp/training/decoders.py b/ddsp/training/decoders.py index 740940be..0ffbed3b 100644 --- a/ddsp/training/decoders.py +++ b/ddsp/training/decoders.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/decoders_test.py b/ddsp/training/decoders_test.py index d3a0d6fc..08a5aae7 100644 --- a/ddsp/training/decoders_test.py +++ b/ddsp/training/decoders_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/docker/__init__.py b/ddsp/training/docker/__init__.py index 7de792ba..da61c6d6 100644 --- a/ddsp/training/docker/__init__.py +++ b/ddsp/training/docker/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/docker/ddsp_ai_platform.py b/ddsp/training/docker/ddsp_ai_platform.py index 64865a04..767e4673 100644 --- a/ddsp/training/docker/ddsp_ai_platform.py +++ b/ddsp/training/docker/ddsp_ai_platform.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/docker/task.py b/ddsp/training/docker/task.py index 6518feef..95d417c4 100644 --- a/ddsp/training/docker/task.py +++ b/ddsp/training/docker/task.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/docker/task_test.py b/ddsp/training/docker/task_test.py index 1b637119..42fbcce5 100644 --- a/ddsp/training/docker/task_test.py +++ b/ddsp/training/docker/task_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/ddsp/training/encoders.py b/ddsp/training/encoders.py index d12830bc..0b1f469f 100644 --- a/ddsp/training/encoders.py +++ b/ddsp/training/encoders.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/eval_util.py b/ddsp/training/eval_util.py index b9dea615..2880f0e4 100644 --- a/ddsp/training/eval_util.py +++ b/ddsp/training/eval_util.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/evaluators.py b/ddsp/training/evaluators.py index fb0724f1..6d89a131 100644 --- a/ddsp/training/evaluators.py +++ b/ddsp/training/evaluators.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/gin/__init__.py b/ddsp/training/gin/__init__.py index f962edef..941b8629 100644 --- a/ddsp/training/gin/__init__.py +++ b/ddsp/training/gin/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/gin/datasets/__init__.py b/ddsp/training/gin/datasets/__init__.py index f962edef..941b8629 100644 --- a/ddsp/training/gin/datasets/__init__.py +++ b/ddsp/training/gin/datasets/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/gin/eval/__init__.py b/ddsp/training/gin/eval/__init__.py index f962edef..941b8629 100644 --- a/ddsp/training/gin/eval/__init__.py +++ b/ddsp/training/gin/eval/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/gin/models/__init__.py b/ddsp/training/gin/models/__init__.py index f962edef..941b8629 100644 --- a/ddsp/training/gin/models/__init__.py +++ b/ddsp/training/gin/models/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/gin/optimization/__init__.py b/ddsp/training/gin/optimization/__init__.py index f962edef..941b8629 100644 --- a/ddsp/training/gin/optimization/__init__.py +++ b/ddsp/training/gin/optimization/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/ddsp/training/gin/papers/__init__.py b/ddsp/training/gin/papers/__init__.py index f962edef..941b8629 100644 --- a/ddsp/training/gin/papers/__init__.py +++ b/ddsp/training/gin/papers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/gin/papers/iclr2020/__init__.py b/ddsp/training/gin/papers/iclr2020/__init__.py index f962edef..941b8629 100644 --- a/ddsp/training/gin/papers/iclr2020/__init__.py +++ b/ddsp/training/gin/papers/iclr2020/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/gin/papers/icml2020/__init__.py b/ddsp/training/gin/papers/icml2020/__init__.py index f962edef..941b8629 100644 --- a/ddsp/training/gin/papers/icml2020/__init__.py +++ b/ddsp/training/gin/papers/icml2020/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/heuristics.py b/ddsp/training/heuristics.py index 2aef7d19..6440aa4c 100644 --- a/ddsp/training/heuristics.py +++ b/ddsp/training/heuristics.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/heuristics_test.py b/ddsp/training/heuristics_test.py index 80a7ccf6..c30189b4 100644 --- a/ddsp/training/heuristics_test.py +++ b/ddsp/training/heuristics_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/inference.py b/ddsp/training/inference.py index 36e89689..2ef8cf96 100644 --- a/ddsp/training/inference.py +++ b/ddsp/training/inference.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/metrics.py b/ddsp/training/metrics.py index 4102f413..ce1c51fb 100644 --- a/ddsp/training/metrics.py +++ b/ddsp/training/metrics.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -54,19 +54,18 @@ def is_outlier(ground_truth_f0_conf): return np.max(ground_truth_f0_conf) < MIN_F0_CONFIDENCE -def compute_audio_features(audio, - n_fft=512, - sample_rate=16000, - frame_rate=250): +def compute_audio_features(audio, frame_rate=250): """Compute features from audio.""" audio_feats = {'audio': audio} audio = squeeze(audio) + # Requires 16kHz for CREPE. 
+ sample_rate = ddsp.spectral_ops.CREPE_SAMPLE_RATE audio_feats['loudness_db'] = ddsp.spectral_ops.compute_loudness( - audio, sample_rate, frame_rate, n_fft) + audio, sample_rate, frame_rate) audio_feats['f0_hz'], audio_feats['f0_confidence'] = ( - ddsp.spectral_ops.compute_f0(audio, sample_rate, frame_rate)) + ddsp.spectral_ops.compute_f0(audio, frame_rate)) return audio_feats @@ -194,12 +193,13 @@ def update_state(self, batch, audio_gen): loudness_original = batch['loudness_db'] else: loudness_original = ddsp.spectral_ops.compute_loudness( - batch['audio'], - sample_rate=self._sample_rate, frame_rate=self._frame_rate) + batch['audio'], sample_rate=self._sample_rate, + frame_rate=self._frame_rate) # Compute loudness across entire batch loudness_gen = ddsp.spectral_ops.compute_loudness( - audio_gen, sample_rate=self._sample_rate, frame_rate=self._frame_rate) + audio_gen, sample_rate=self._sample_rate, + frame_rate=self._frame_rate) batch_size = int(audio_gen.shape[0]) for i in range(batch_size): @@ -242,7 +242,6 @@ def update_state(self, batch, audio_gen): # Extract f0 from generated audio example. f0_hz_gen, _ = ddsp.spectral_ops.compute_f0( audio_gen[i], - sample_rate=self._sample_rate, frame_rate=self._frame_rate, viterbi=True) if 'f0_hz' in batch and 'f0_confidence' in batch: @@ -252,7 +251,6 @@ def update_state(self, batch, audio_gen): # Missing f0 in ground truth, extract it. f0_hz_gt, f0_conf_gt = ddsp.spectral_ops.compute_f0( batch['audio'][i], - sample_rate=self._sample_rate, frame_rate=self._frame_rate, viterbi=True) diff --git a/ddsp/training/metrics_test.py b/ddsp/training/metrics_test.py index c5417602..8de66e23 100644 --- a/ddsp/training/metrics_test.py +++ b/ddsp/training/metrics_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
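[Editor's note] The compute_audio_features() hunk above drops the sample_rate and n_fft arguments and pins the rate to CREPE's 16 kHz. A usage sketch under that assumption (input audio must already be at 16 kHz; the NumPy sinusoid stands in for real audio):

import numpy as np
import ddsp.training.metrics as ddsp_metrics

sample_rate = 16000  # i.e. ddsp.spectral_ops.CREPE_SAMPLE_RATE
t = np.linspace(0.0, 0.4, int(0.4 * sample_rate), endpoint=False)
audio = (0.75 * np.sin(2.0 * np.pi * 440.0 * t)).astype(np.float32)

# New signature: no sample_rate/n_fft; frame_rate alone sets the hop.
feats = ddsp_metrics.compute_audio_features(audio, frame_rate=250)
# feats contains 'audio', 'loudness_db', 'f0_hz', and 'f0_confidence'.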
@@ -18,8 +18,9 @@ from unittest import mock from absl.testing import parameterized -from ddsp.spectral_ops_test import gen_np_batched_sinusoids -from ddsp.spectral_ops_test import gen_np_sinusoid +import ddsp +from ddsp.test_util import gen_np_batched_sinusoids +from ddsp.test_util import gen_np_sinusoid import ddsp.training.metrics as ddsp_metrics import numpy as np import tensorflow.compat.v2 as tf @@ -33,6 +34,15 @@ def setUp(self): self.amp = 0.75 self.frequency = 440.0 self.frame_rate = 250 + self.sample_rate = 16000 + + def expected_length(self, audio): + n_t = audio.shape[-1] + frame_size = ddsp.spectral_ops.CREPE_FRAME_SIZE + hop_size = int(self.sample_rate // self.frame_rate) + expected_len, _ = ddsp.spectral_ops.get_framed_lengths( + n_t, frame_size, hop_size) + return expected_len def validate_output_shapes(self, audio_features, expected_feature_lengths): for feat, expected_len in expected_feature_lengths.items(): @@ -40,46 +50,28 @@ def validate_output_shapes(self, audio_features, expected_feature_lengths): try: self.assertLen(arr, expected_len) except AssertionError as e: - raise AssertionError('%s feature: %s' % (e, feat)) + raise AssertionError('%s feature: %s' % (e, feat)) from e self.assertTrue(np.all(np.isfinite(arr))) @parameterized.named_parameters( - ('16k_.21secs', 16000, .21), - ('24k_.21secs', 24000, .21), - ('48k_.21secs', 48000, .21), - ('16k_.4secs', 16000, .4), - ('24k_.4secs', 24000, .4), - ('48k_.4secs', 48000, .4), + ('0.21secs', .21), + ('0.4secs', .4), ) - def test_correct_shape_compute_af_at_sample_rate(self, sample_rate, - audio_len_sec): - audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate, + def test_correct_shape_compute_af_at_sample_rate(self, audio_len_sec): + audio_sin = gen_np_sinusoid(self.frequency, self.amp, self.sample_rate, audio_len_sec) + exp_length = self.expected_length(audio_sin) audio_features = ddsp_metrics.compute_audio_features( - audio_sin, sample_rate=sample_rate, frame_rate=self.frame_rate) + audio_sin, frame_rate=self.frame_rate) - expected_f0_and_loudness_length = int(audio_len_sec * self.frame_rate) self.validate_output_shapes( audio_features, { - 'audio': audio_len_sec * sample_rate, - 'f0_hz': expected_f0_and_loudness_length, - 'f0_confidence': expected_f0_and_loudness_length, - 'loudness_db': expected_f0_and_loudness_length, + 'audio': audio_len_sec * self.sample_rate, + 'f0_hz': exp_length, + 'f0_confidence': exp_length, + 'loudness_db': exp_length, }) - @parameterized.named_parameters( - ('44.1k_.21secs', 44100, .21), - ('44.1k_.4secs', 44100, .4), - ) - def test_indivisible_rates_raises_error_compute_af(self, sample_rate, - audio_len_sec): - audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate, - audio_len_sec) - - with self.assertRaises(ValueError): - ddsp_metrics.compute_audio_features( - audio_sin, sample_rate=sample_rate, frame_rate=self.frame_rate) - class MetricsObjectsTest(parameterized.TestCase, tf.test.TestCase): @@ -114,7 +106,7 @@ def gen_batch_of_features(cls, batch_of_audio): batch_size = batch_of_audio.shape[0] audio = batch_of_audio[0] feats = ddsp_metrics.compute_audio_features( - audio, sample_rate=cls.sample_rate, frame_rate=cls.frame_rate) + audio, frame_rate=cls.frame_rate) for k, v in feats.items(): feats[k] = np.tile(v[np.newaxis, :], [batch_size, 1]) return feats diff --git a/ddsp/training/models/__init__.py b/ddsp/training/models/__init__.py index 16240efa..3952265c 100644 --- a/ddsp/training/models/__init__.py +++ b/ddsp/training/models/__init__.py @@ -1,4 +1,4 @@ -# 
Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/models/autoencoder.py b/ddsp/training/models/autoencoder.py index c7a6cb19..3c7535a9 100644 --- a/ddsp/training/models/autoencoder.py +++ b/ddsp/training/models/autoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/models/autoencoder_test.py b/ddsp/training/models/autoencoder_test.py index eb597c81..337201a1 100644 --- a/ddsp/training/models/autoencoder_test.py +++ b/ddsp/training/models/autoencoder_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/models/inverse_synthesis.py b/ddsp/training/models/inverse_synthesis.py index a1058300..97865471 100644 --- a/ddsp/training/models/inverse_synthesis.py +++ b/ddsp/training/models/inverse_synthesis.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/models/midi_autoencoder.py b/ddsp/training/models/midi_autoencoder.py index 880ecdc4..9fba697f 100644 --- a/ddsp/training/models/midi_autoencoder.py +++ b/ddsp/training/models/midi_autoencoder.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/models/model.py b/ddsp/training/models/model.py index 30e7975d..3d97d970 100644 --- a/ddsp/training/models/model.py +++ b/ddsp/training/models/model.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/nn.py b/ddsp/training/nn.py index 471b4280..b28ac31c 100644 --- a/ddsp/training/nn.py +++ b/ddsp/training/nn.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/nn_test.py b/ddsp/training/nn_test.py index 29a2f313..ea9a3b36 100644 --- a/ddsp/training/nn_test.py +++ b/ddsp/training/nn_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/plotting.py b/ddsp/training/plotting.py index 5761081b..476c0c44 100644 --- a/ddsp/training/plotting.py +++ b/ddsp/training/plotting.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
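[Editor's note] The expected_length() helper added to metrics_test.py above leans on the new get_framed_lengths() from this patch. A usage sketch; the frame-count-first return order is inferred from the test's `expected_len, _ = ...` unpacking rather than from the function's docstring:

import ddsp

n_t = int(0.4 * 16000)                   # 0.4 s of 16 kHz audio.
frame_size = ddsp.spectral_ops.CREPE_FRAME_SIZE
hop_size = 16000 // 250                  # frame_rate of 250 -> hop of 64.
n_frames, _ = ddsp.spectral_ops.get_framed_lengths(
    n_t, frame_size, hop_size)           # Expected f0/loudness length.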
diff --git a/ddsp/training/postprocessing.py b/ddsp/training/postprocessing.py index e7c2e2ff..00c5499d 100644 --- a/ddsp/training/postprocessing.py +++ b/ddsp/training/postprocessing.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/preprocessing.py b/ddsp/training/preprocessing.py index 90a81a45..e4caae3a 100644 --- a/ddsp/training/preprocessing.py +++ b/ddsp/training/preprocessing.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/preprocessing_test.py b/ddsp/training/preprocessing_test.py index 0b268ef2..89b6618b 100644 --- a/ddsp/training/preprocessing_test.py +++ b/ddsp/training/preprocessing_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/summaries.py b/ddsp/training/summaries.py index c364f1c5..26c11193 100644 --- a/ddsp/training/summaries.py +++ b/ddsp/training/summaries.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/training/train_util.py b/ddsp/training/train_util.py index 32d8ca81..a13a2a78 100644 --- a/ddsp/training/train_util.py +++ b/ddsp/training/train_util.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -111,11 +111,11 @@ def get_latest_file(dir_path, prefix='operative_config-', suffix='.gin'): get_iter = lambda fp: abs(int(fp.split(dir_prefix)[-1].split(suffix)[0])) latest_file = max(file_paths, key=get_iter) return latest_file - except ValueError: + except ValueError as verror: raise FileNotFoundError( f'Files found with pattern \'{search_pattern}\' do not match ' f'the pattern \'{dir_prefix}[iteration_number]{suffix}\'.\n\n' - f'Files found:\n{file_paths}') + f'Files found:\n{file_paths}') from verror def get_latest_checkpoint(checkpoint_path): diff --git a/ddsp/training/trainers.py b/ddsp/training/trainers.py index c59ecfff..3e71fd2c 100644 --- a/ddsp/training/trainers.py +++ b/ddsp/training/trainers.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ddsp/version.py b/ddsp/version.py index 6b8bb2cc..ae69df64 100644 --- a/ddsp/version.py +++ b/ddsp/version.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,4 +19,4 @@ pulling in all the dependencies in __init__.py. 
""" -__version__ = '2.0.0' +__version__ = '3.0.0' diff --git a/setup.py b/setup.py index 65aa2e83..07d00544 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/update_gin_config.py b/update_gin_config.py index 9ac0dea1..09c73716 100644 --- a/update_gin_config.py +++ b/update_gin_config.py @@ -1,4 +1,4 @@ -# Copyright 2021 The DDSP Authors. +# Copyright 2022 The DDSP Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.