Skip to content

Commit

Permalink
Unify padding for "frame" operations such as f0 and dB calculations.
Browse files Browse the repository at this point in the history
* Offer 3 padding modes: 'valid', 'same', and 'center'
* get_framed_lengths() function for verifying expected lengths after framing.
* Significantly revamp spectral_ops and prepare_tfrecord tests.
* Update ddsp_prepare_tfrecord.py to allow centered padding.
* Add type annotations for prepare_tfrecord_lib.py.
* Slim down prepare_tfrecord tests.

PiperOrigin-RevId: 427359579
  • Loading branch information
jesseengel authored and Magenta Team committed Feb 9, 2022
1 parent a260cde commit 1935ff3
Show file tree
Hide file tree
Showing 70 changed files with 609 additions and 427 deletions.
2 changes: 1 addition & 1 deletion ddsp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/colab/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/colab/colab_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/core_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/dags.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/dags_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/effects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/effects_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/losses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/losses_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/processors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion ddsp/processors_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
183 changes: 121 additions & 62 deletions ddsp/spectral_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The DDSP Authors.
# Copyright 2022 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -26,7 +26,7 @@
import tensorflow_probability as tfp

CREPE_SAMPLE_RATE = 16000
_CREPE_FRAME_SIZE = 1024
CREPE_FRAME_SIZE = 1024

F0_RANGE = 127.0 # MIDI.
DB_RANGE = core.DB_RANGE # dB (80.0).
Expand Down Expand Up @@ -57,18 +57,11 @@ def stft_np(audio, frame_size=2048, overlap=0.75, pad_end=True):
is_2d = (len(audio.shape) == 2)

if pad_end:
n_samples_initial = int(audio.shape[-1])
n_frames = int(np.ceil(n_samples_initial / hop_size))
n_samples_final = (n_frames - 1) * hop_size + frame_size
pad = n_samples_final - n_samples_initial
padding = ((0, 0), (0, pad)) if is_2d else ((0, pad),)
audio = np.pad(audio, padding, 'constant')
audio = pad(audio, frame_size, hop_size, 'same', axis=is_2d).numpy()

def stft_fn(y):
return librosa.stft(y=y,
n_fft=int(frame_size),
hop_length=hop_size,
center=False).T
return librosa.stft(
y=y, n_fft=int(frame_size), hop_length=hop_size, center=False).T

s = np.stack([stft_fn(a) for a in audio]) if is_2d else stft_fn(audio)
return s
Expand Down Expand Up @@ -143,28 +136,105 @@ def compute_mfcc(audio,
return mfccs[..., :mfcc_bins]


def get_framed_lengths(input_length, frame_size, hop_size, padding='center'):
  """Given a strided framing, such as tf.signal.frame, gives output lengths.

  Args:
    input_length: Original length along the dimension to be framed.
    frame_size: Size of frames for striding.
    hop_size: Striding, space between frames.
    padding: Type of padding to apply, ['valid', 'same', 'center']. 'valid' is
      a no-op. 'same' applies padding to the end such that
      n_frames = n_t / hop_size. 'center' applies padding to both ends such
      that each frame timestamp is centered and n_frames = n_t / hop_size + 1.

  Returns:
    n_frames: Number of frames left after striding.
    padded_length: Length of the padded signal before striding.

  Raises:
    ValueError: If `padding` is not one of ['valid', 'same', 'center'].
  """
  # Use numpy since this function isn't used dynamically.
  def get_n_frames(length):
    # Number of complete frames that fit in `length` samples.
    return int(np.floor((length - frame_size) / hop_size)) + 1

  if padding == 'valid':
    padded_length = input_length
    n_frames = get_n_frames(input_length)

  elif padding == 'center':
    # NOTE(review): assumes centered padding adds `frame_size` total samples;
    # exact only for even frame_size (pad() adds 2 * (frame_size // 2)).
    padded_length = input_length + frame_size
    n_frames = get_n_frames(padded_length)

  elif padding == 'same':
    # Pad the end so that every sample starts a frame: n_frames = ceil(n_t/hop).
    n_frames = int(np.ceil(input_length / hop_size))
    padded_length = (n_frames - 1) * hop_size + frame_size

  else:
    # Fail loudly instead of the original's implicit UnboundLocalError.
    raise ValueError('`padding` must be one of [\'valid\', \'same\', '
                     f'\'center\'], received ({padding}).')

  return n_frames, padded_length


def pad(x, frame_size, hop_size, padding='center',
        axis=1, mode='CONSTANT', constant_values=0):
  """Pad a tensor for strided framing such as tf.signal.frame.

  Args:
    x: Tensor to pad, any shape.
    frame_size: Size of frames for striding.
    hop_size: Striding, space between frames.
    padding: Type of padding to apply, ['valid', 'same', 'center']. 'valid' is
      a no-op. 'same' applies padding to the end such that
      n_frames = n_t / hop_size. 'center' applies padding to both ends such
      that each frame timestamp is centered and n_frames = n_t / hop_size + 1.
    axis: Axis along which to pad `x`. Forced to 0 for rank <= 1 inputs.
    mode: Padding mode for tf.pad(). One of "CONSTANT", "REFLECT", or
      "SYMMETRIC" (case-insensitive).
    constant_values: Passthrough kwarg for tf.pad().

  Returns:
    A padded version of `x` along axis. Output sizes can be computed separately
    with get_framed_lengths().

  Raises:
    ValueError: If `hop_size` exceeds `frame_size`, or `padding` is not one of
      ['valid', 'same', 'center'].
  """
  x = tf_float32(x)

  # 'valid' framing needs no padding at all.
  if padding == 'valid':
    return x

  if hop_size > frame_size:
    raise ValueError(f'During padding, frame_size ({frame_size})'
                     f' must be greater than hop_size ({hop_size}).')

  # For 1-D (or scalar) input the only time axis is axis 0.
  if len(x.shape) <= 1:
    axis = 0

  n_t = x.shape[axis]
  _, n_t_padded = get_framed_lengths(n_t, frame_size, hop_size, padding)
  pads = [[0, 0] for _ in range(len(x.shape))]

  if padding == 'same':
    # All padding goes at the end, so frames align with the signal start.
    pad_amount = int(n_t_padded - n_t)
    pads[axis] = [0, pad_amount]

  elif padding == 'center':
    pad_amount = int(frame_size // 2)  # Symmetric even padding like librosa.
    pads[axis] = [pad_amount, pad_amount]

  else:
    # Bug fix: original message concatenated 'same' and 'valid' with no
    # separator ("'same''valid'").
    raise ValueError('`padding` must be one of [\'center\', \'same\', '
                     f'\'valid\'], received ({padding}).')

  return tf.pad(x, pads, mode=mode, constant_values=constant_values)


def compute_rms_energy(audio,
sample_rate=16000,
frame_rate=250,
frame_size=512,
pad_end=True):
padding='center'):
"""Compute root mean squared energy of audio."""
n_samples = audio.shape[0] if len(audio.shape) == 1 else audio.shape[1]
n_secs = n_samples / float(sample_rate) # `n_secs` can have milliseconds
expected_len = int(n_secs * frame_rate)

audio = tf_float32(audio)

hop_size = sample_rate // frame_rate
audio_frames = tf.signal.frame(audio, frame_size, hop_size, pad_end=pad_end)
audio = pad(audio, frame_size, hop_size, padding=padding)
audio_frames = tf.signal.frame(audio, frame_size, hop_size, pad_end=False)
rms_energy = tf.reduce_mean(audio_frames**2.0, axis=-1)**0.5
if pad_end:
n_samples = audio.shape[0] if len(audio.shape) == 1 else audio.shape[1]
n_secs = n_samples / float(sample_rate) # `n_secs` can have milliseconds
expected_len = int(n_secs * frame_rate)
return pad_or_trim_to_expected_length(rms_energy, expected_len, use_tf=True)
else:
return rms_energy
return rms_energy


def compute_power(audio,
Expand All @@ -173,10 +243,10 @@ def compute_power(audio,
frame_size=512,
ref_db=0.0,
range_db=DB_RANGE,
pad_end=True):
padding='center'):
"""Compute power of audio in dB."""
rms_energy = compute_rms_energy(
audio, sample_rate, frame_rate, frame_size, pad_end=pad_end)
audio, sample_rate, frame_rate, frame_size, padding=padding)
power_db = core.amplitude_to_db(
rms_energy, ref_db=ref_db, range_db=range_db, use_tf=True)
return power_db
Expand All @@ -190,13 +260,13 @@ def compute_loudness(audio,
range_db=DB_RANGE,
ref_db=0.0,
use_tf=True,
pad_end=True):
padding='center'):
"""Perceptual loudness (weighted power) in dB.
Function is differentiable if use_tf=True.
Args:
audio: Numpy ndarray or tensor. Shape [batch_size, audio_length] or
[batch_size,].
[audio_length,].
sample_rate: Audio sample rate in Hz.
frame_rate: Rate of loudness frames in Hz.
n_fft: Fft window size.
Expand All @@ -208,33 +278,29 @@ def compute_loudness(audio,
n_fft=2048. With v2.0.0 it was set to 0.0 to be more consistent with power
calculations that have a natural scale for 0 dB being amplitude=1.0.
use_tf: Make function differentiable by using tensorflow.
pad_end: Add zero padding at end of audio (like `same` convolution).
padding: 'same', 'valid', or 'center'.
Returns:
Loudness in decibels. Shape [batch_size, n_frames] or [n_frames,].
"""
if sample_rate % frame_rate != 0:
raise ValueError(
'frame_rate: {} must evenly divide sample_rate: {}.'
'For default frame_rate: 250Hz, suggested sample_rate: 16kHz or 48kHz'
.format(frame_rate, sample_rate))

# Pick tensorflow or numpy.
lib = tf if use_tf else np
reduce_mean = tf.reduce_mean if use_tf else np.mean
stft_fn = stft if use_tf else stft_np

# Make inputs tensors for tensorflow.
audio = tf_float32(audio) if use_tf else audio
frame_size = n_fft
hop_size = sample_rate // frame_rate
audio = pad(audio, frame_size, hop_size, padding=padding)
audio = audio if use_tf else np.array(audio)

# Temporarily a batch dimension for single examples.
is_1d = (len(audio.shape) == 1)
audio = audio[lib.newaxis, :] if is_1d else audio

# Take STFT.
hop_size = sample_rate // frame_rate
overlap = 1 - hop_size / n_fft
s = stft_fn(audio, frame_size=n_fft, overlap=overlap, pad_end=pad_end)
overlap = 1 - hop_size / frame_size
s = stft_fn(audio, frame_size=frame_size, overlap=overlap, pad_end=False)

# Compute power.
amplitude = lib.abs(s)
Expand All @@ -258,36 +324,30 @@ def compute_loudness(audio,
# Remove temporary batch dimension.
loudness = loudness[0] if is_1d else loudness

# Compute expected length of loudness vector.
expected_secs = audio.shape[-1] / float(sample_rate)
expected_len = int(expected_secs * frame_rate)

# Pad with `-range_db` noise floor or trim vector.
loudness = pad_or_trim_to_expected_length(
loudness, expected_len, -range_db, use_tf=use_tf)

return loudness


@gin.register
def compute_f0(audio, sample_rate, frame_rate, viterbi=True):
def compute_f0(audio, frame_rate, viterbi=True, padding='center'):
"""Fundamental frequency (f0) estimate using CREPE.
This function is non-differentiable and takes input as a numpy array.
Args:
audio: Numpy ndarray of single audio example. Shape [audio_length,].
sample_rate: Sample rate in Hz.
audio: Numpy ndarray of single audio (16kHz) example. Shape [audio_length,].
frame_rate: Rate of f0 frames in Hz.
viterbi: Use Viterbi decoding to estimate f0.
padding: Apply zero-padding for centered frames.
'same', 'valid', or 'center'.
Returns:
f0_hz: Fundamental frequency in Hz. Shape [n_frames,].
f0_confidence: Confidence in Hz estimate (scaled [0, 1]). Shape [n_frames,].
"""

n_secs = len(audio) / float(sample_rate) # `n_secs` can have milliseconds
sample_rate = CREPE_SAMPLE_RATE
crepe_step_size = 1000 / frame_rate # milliseconds
expected_len = int(n_secs * frame_rate)
hop_size = sample_rate // frame_rate

audio = pad(audio, CREPE_FRAME_SIZE, hop_size, padding)
audio = np.asarray(audio)

# Compute f0 with crepe.
Expand All @@ -299,14 +359,11 @@ def compute_f0(audio, sample_rate, frame_rate, viterbi=True):
center=False,
verbose=0)

# Postprocessing on f0_hz
f0_hz = pad_or_trim_to_expected_length(f0_hz, expected_len, 0) # pad with 0
# Postprocessing.
f0_hz = f0_hz.astype(np.float32)

# Postprocessing on f0_confidence
f0_confidence = pad_or_trim_to_expected_length(f0_confidence, expected_len, 1)
f0_confidence = np.nan_to_num(f0_confidence) # Set nans to 0 in confidence
f0_confidence = f0_confidence.astype(np.float32)
f0_confidence = np.nan_to_num(f0_confidence) # Set nans to 0 in confidence

return f0_hz, f0_confidence


Expand Down Expand Up @@ -446,10 +503,12 @@ def normalize_frames(self, frames):
frames /= std[:, None]
return frames

def predict_f0_and_confidence(self, audio, viterbi=False):
def predict_f0_and_confidence(self, audio, viterbi=False, padding='center'):
audio = audio[None, :] if len(audio.shape) == 1 else audio
batch_size = audio.shape[0]

audio = pad(audio, self.frame_size, self.hop_size, padding=padding)

frames = self.batch_frames(audio)
frames = self.normalize_frames(frames)
acts = self.core_model(frames, training=False)
Expand Down
Loading

0 comments on commit 1935ff3

Please sign in to comment.