Skip to content

Commit

Permalink
Update RecordProvider to allow centered padding datasets.
Browse files Browse the repository at this point in the history
* Small helper function to core for NaNs.

PiperOrigin-RevId: 427597538
  • Loading branch information
jesseengel authored and Magenta Team committed Feb 10, 2022
1 parent 1935ff3 commit a38a0b3
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 45 deletions.
29 changes: 7 additions & 22 deletions ddsp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ def diff(x, axis=-1):


# Math -------------------------------------------------------------------------
def nan_to_num(x, value=0.0):
"""Replace NaNs with value."""
return tf.where(tf.math.is_nan(x), value * tf.ones_like(x), x)


def safe_divide(numerator, denominator, eps=1e-7):
"""Avoid dividing by zero by adding a small epsilon."""
safe_denominator = tf.where(denominator == 0.0, eps, denominator)
Expand Down Expand Up @@ -710,26 +715,6 @@ def upsample_with_windows(inputs: tf.Tensor,
return x[:, hop_size:-hop_size, :]


# TODO(jesseengel): Axis param, don't assume axis=1.
def center_pad(audio, frame_size, mode='CONSTANT'):
"""Pad an audio signal such that timestamps align to the center of frames.
Without centering, timestamps align to the front of frames.
Args:
audio: Input, shape [batch, time, ...].
frame_size: Size of each frame.
mode: Padding mode for tf.pad. One of "CONSTANT", "REFLECT", or
"SYMMETRIC" (case-insensitive).
Returns:
audio_padded: Shape [batch, time + (frame_size // 2) * 2, ...].
"""
pad_amount = int(frame_size // 2) # Symmetric even padding like librosa.
pads = [[0, 0] for _ in range(len(audio.shape))]
pads[1] = [pad_amount, pad_amount]
return tf.pad(audio, pads, mode=mode)


def center_crop(audio, frame_size):
"""Remove padding introduced from centering frames.
Expand Down Expand Up @@ -852,8 +837,8 @@ def angular_cumsum(angular_frequency, chunk_size=1000):
# Pad if needed.
remainder = n_time % chunk_size
if remainder:
pad = chunk_size - remainder
angular_frequency = pad_axis(angular_frequency, [0, pad], axis=1)
pad_amount = chunk_size - remainder
angular_frequency = pad_axis(angular_frequency, [0, pad_amount], axis=1)

# Split input into chunks.
length = angular_frequency.shape[1]
Expand Down
20 changes: 16 additions & 4 deletions ddsp/training/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os

from absl import logging
from ddsp.spectral_ops import get_framed_lengths
import gin
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
Expand Down Expand Up @@ -186,14 +187,24 @@ def __init__(self,
example_secs,
sample_rate,
frame_rate,
data_format_map_fn):
data_format_map_fn,
centered=False):
"""RecordProvider constructor."""
self._file_pattern = file_pattern or self.default_file_pattern
self._audio_length = example_secs * sample_rate
self._feature_length = example_secs * frame_rate
super().__init__(sample_rate, frame_rate)
self._feature_length = self.get_feature_length(centered)
self._data_format_map_fn = data_format_map_fn

def get_feature_length(self, centered):
"""Take into account center padding to get number of frames."""
# Number of frames is independent of frame size for "center/same" padding.
frame_size = 1024
hop_size = self.sample_rate / self.frame_rate
padding = 'center' if centered else 'same'
return get_framed_lengths(
self._audio_length, frame_size, hop_size, padding)[0]

@property
def default_file_pattern(self):
"""Used if file_pattern is not provided to constructor."""
Expand Down Expand Up @@ -244,10 +255,11 @@ def __init__(self,
file_pattern=None,
example_secs=4,
sample_rate=16000,
frame_rate=250):
frame_rate=250,
centered=False):
"""TFRecordProvider constructor."""
super().__init__(file_pattern, example_secs, sample_rate,
frame_rate, tf.data.TFRecordDataset)
frame_rate, tf.data.TFRecordDataset, centered=centered)


# ------------------------------------------------------------------------------
Expand Down
21 changes: 9 additions & 12 deletions ddsp/training/gin/models/vst/vst.gin
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,18 @@ n_samples = 64064 # Extra frame for center padding.


# Preprocessor
# Constructor requires path to a SavedModel of the base CREPE
# models. Available on GCS at gs://crepe-models/saved_models/[full,large,small].
# Use same preprocessor for creating dataset and for training / inference.
prepare_tfrecord_lib_vst.prepare_tfrecord.preprocessor = @preprocessing.OnlineF0PowerPreprocessor()
Autoencoder.preprocessor = @preprocessing.OnlineF0PowerPreprocessor()
OnlineF0PowerPreprocessor:
time_steps = 1001 # Extra frame added for center padding.
sample_rate = %sample_rate
frame_rate = %frame_rate
frame_size = %frame_size
padding = 'center'
compute_power = True
center_power = True
power_frame_rate = %frame_rate
power_frame_size = %frame_size
compute_f0 = True
center_f0 = True
f0_frame_rate = %frame_rate
f0_frame_size = %frame_size
crepe_saved_model_path = ''
compute_f0 = False
crepe_saved_model_path = 'full'
viterbi = False
# time_steps = 1001 # Extra frame added for center padding.


# Encoder
Expand Down
11 changes: 8 additions & 3 deletions ddsp/training/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def detect_notes(loudness_db,
exponent=2.0,
smoothing=40,
f0_confidence_threshold=0.7,
min_db=-120.):
min_db=-spectral_ops.DB_RANGE):
"""Detect note on-off using loudness and smoothed f0_confidence."""
mean_db = np.mean(loudness_db)
db = smooth(f0_confidence**exponent, smoothing) * (loudness_db - min_db)
Expand Down Expand Up @@ -253,13 +253,15 @@ def fit_transform(self, x):

def compute_dataset_statistics(data_provider,
batch_size=1,
power_frame_size=256):
power_frame_size=1024,
power_frame_rate=50):
"""Calculate dataset stats.
Args:
data_provider: A DataProvider from ddsp.training.data.
batch_size: Iterate over dataset with this batch size.
power_frame_size: Calculate power features on the fly with this frame size.
power_frame_rate: Calculate power features on the fly with this frame rate.
Returns:
Dictionary of dataset statistics. This is an overcomplete set of statistics,
Expand All @@ -280,7 +282,9 @@ def compute_dataset_statistics(data_provider,
for batch in data_iter:
loudness.append(batch['loudness_db'])
power.append(
spectral_ops.compute_power(batch['audio'], frame_size=power_frame_size))
spectral_ops.compute_power(batch['audio'],
frame_size=power_frame_size,
frame_rate=power_frame_rate))
f0.append(batch['f0_hz'])
f0_conf.append(batch['f0_confidence'])
audio.append(batch['audio'])
Expand All @@ -304,6 +308,7 @@ def compute_dataset_statistics(data_provider,

# Detect notes.
mask_on, _ = detect_notes(loudness_trimmed, f0_conf_trimmed)

quantile_transform = fit_quantile_transform(loudness_trimmed, mask_on)

# Pitch statistics.
Expand Down
6 changes: 3 additions & 3 deletions ddsp/training/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,18 @@ def __init__(self,
time_steps=1000,
frame_rate=250,
sample_rate=16000,
recompute_loudness=True,
compute_loudness=True,
**kwargs):
super().__init__(**kwargs)
self.time_steps = time_steps
self.frame_rate = frame_rate
self.sample_rate = sample_rate
self.recompute_loudness = recompute_loudness
self.compute_loudness = compute_loudness

def call(self, loudness_db, f0_hz, audio=None) -> [
'f0_hz', 'loudness_db', 'f0_scaled', 'ld_scaled']:
# Compute loudness fresh (it's fast).
if self.recompute_loudness:
if self.compute_loudness:
loudness_db = ddsp.spectral_ops.compute_loudness(
audio,
sample_rate=self.sample_rate,
Expand Down
2 changes: 1 addition & 1 deletion ddsp/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
pulling in all the dependencies in __init__.py.
"""

__version__ = '3.0.0'
__version__ = '3.1.0'

0 comments on commit a38a0b3

Please sign in to comment.