diff --git a/ddsp/colab/demos/train_autoencoder.ipynb b/ddsp/colab/demos/train_autoencoder.ipynb
index 63cbeabe..d6814ad4 100644
--- a/ddsp/colab/demos/train_autoencoder.ipynb
+++ b/ddsp/colab/demos/train_autoencoder.ipynb
@@ -478,7 +478,7 @@
     "  --gin_param=\"batch_size=16\" \\\n",
     "  --gin_param=\"train_util.train.num_steps=30000\" \\\n",
     "  --gin_param=\"train_util.train.steps_per_save=300\" \\\n",
-    "  --gin_param=\"train_util.Trainer.checkpoints_to_keep=10\""
+    "  --gin_param=\"trainers.Trainer.checkpoints_to_keep=10\""
   ]
  },
  {
diff --git a/ddsp/colab/tutorials/3_training.ipynb b/ddsp/colab/tutorials/3_training.ipynb
index 3cb95302..ad6115bf 100644
--- a/ddsp/colab/tutorials/3_training.ipynb
+++ b/ddsp/colab/tutorials/3_training.ipynb
@@ -90,7 +90,7 @@
     "\n",
     "import ddsp\n",
     "from ddsp.training import (data, decoders, encoders, models, preprocessing, \n",
-    "                           train_util)\n",
+    "                           train_util, trainers)\n",
     "from ddsp.colab.colab_utils import play, specplot, DEFAULT_SAMPLE_RATE\n",
     "import gin\n",
     "import matplotlib.pyplot as plt\n",
@@ -230,7 +230,7 @@
     "      decoder=decoder,\n",
     "      processor_group=processor_group,\n",
     "      losses=[spectral_loss])\n",
-    "  trainer = train_util.Trainer(model, strategy, learning_rate=1e-3)"
+    "  trainer = trainers.Trainer(model, strategy, learning_rate=1e-3)"
   ]
  },
  {
@@ -316,7 +316,7 @@
     "with strategy.scope():\n",
     "  # Autoencoder arguments are filled by gin.\n",
     "  model = ddsp.training.models.Autoencoder()\n",
-    "  trainer = train_util.Trainer(model, strategy, learning_rate=1e-3)"
+    "  trainer = trainers.Trainer(model, strategy, learning_rate=1e-3)"
   ]
  },
  {
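Gin resolves configurable parameters by their module.Class.parameter path, so moving Trainer out of train_util changes the --gin_param flag string as well as the Python import, as the two notebook diffs above show. A minimal sketch of the equivalent programmatic binding, assuming only the module and parameter names visible in this diff:

import gin
from ddsp.training import trainers  # Importing registers trainers.Trainer with gin.

# Same effect as the notebook's flag:
#   --gin_param="trainers.Trainer.checkpoints_to_keep=10"
gin.bind_parameter('trainers.Trainer.checkpoints_to_keep', 10)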
diff --git a/ddsp/spectral_ops_test.py b/ddsp/spectral_ops_test.py
index 4453179e..d08eaf18 100644
--- a/ddsp/spectral_ops_test.py
+++ b/ddsp/spectral_ops_test.py
@@ -137,14 +137,14 @@ def setUp(self):
     self.frame_rate = 250

   @parameterized.named_parameters(
-      ('16k_2.1secs', 16000, 2.1),
-      ('24k_2.1secs', 24000, 2.1),
-      ('44.1k_2.1secs', 44100, 2.1),
-      ('48k_2.1secs', 48000, 2.1),
-      ('16k_4secs', 16000, 4),
-      ('24k_4secs', 24000, 4),
-      ('44.1k_4secs', 44100, 4),
-      ('48k_4secs', 48000, 4),
+      ('16k_.21secs', 16000, .21),
+      ('24k_.21secs', 24000, .21),
+      ('44.1k_.21secs', 44100, .21),
+      ('48k_.21secs', 48000, .21),
+      ('16k_.4secs', 16000, .4),
+      ('24k_.4secs', 24000, .4),
+      ('44.1k_.4secs', 44100, .4),
+      ('48k_.4secs', 48000, .4),
   )
   def test_compute_f0_at_sample_rate(self, sample_rate, audio_len_sec):
     audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate,
@@ -158,12 +158,12 @@ def test_compute_f0_at_sample_rate(self, sample_rate, audio_len_sec):
     self.assertTrue(np.all(np.isfinite(f0_confidence)))

   @parameterized.named_parameters(
-      ('16k_2.1secs', 16000, 2.1),
-      ('24k_2.1secs', 24000, 2.1),
-      ('48k_2.1secs', 48000, 2.1),
-      ('16k_4secs', 16000, 4),
-      ('24k_4secs', 24000, 4),
-      ('48k_4secs', 48000, 4),
+      ('16k_.21secs', 16000, .21),
+      ('24k_.21secs', 24000, .21),
+      ('48k_.21secs', 48000, .21),
+      ('16k_.4secs', 16000, .4),
+      ('24k_.4secs', 24000, .4),
+      ('48k_.4secs', 48000, .4),
   )
   def test_compute_loudness_at_sample_rate_1d(self, sample_rate, audio_len_sec):
     audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate,
@@ -177,12 +177,12 @@ def test_compute_loudness_at_sample_rate_1d(self, sample_rate, audio_len_sec):
     self.assertTrue(np.all(np.isfinite(loudness)))

   @parameterized.named_parameters(
-      ('16k_2.1secs', 16000, 2.1),
-      ('24k_2.1secs', 24000, 2.1),
-      ('48k_2.1secs', 48000, 2.1),
-      ('16k_4secs', 16000, 4),
-      ('24k_4secs', 24000, 4),
-      ('48k_4secs', 48000, 4),
+      ('16k_.21secs', 16000, .21),
+      ('24k_.21secs', 24000, .21),
+      ('48k_.21secs', 48000, .21),
+      ('16k_.4secs', 16000, .4),
+      ('24k_.4secs', 24000, .4),
+      ('48k_.4secs', 48000, .4),
   )
   def test_compute_loudness_at_sample_rate_2d(self, sample_rate, audio_len_sec):
     batch_size = 8
@@ -209,12 +209,12 @@ def test_compute_loudness_at_sample_rate_2d(self, sample_rate, audio_len_sec):
     self.assertAllClose(loudness_batch, loudness_batch_target, atol=1, rtol=1)

   @parameterized.named_parameters(
-      ('16k_2.1secs', 16000, 2.1),
-      ('24k_2.1secs', 24000, 2.1),
-      ('48k_2.1secs', 48000, 2.1),
-      ('16k_4secs', 16000, 4),
-      ('24k_4secs', 24000, 4),
-      ('48k_4secs', 48000, 4),
+      ('16k_.21secs', 16000, .21),
+      ('24k_.21secs', 24000, .21),
+      ('48k_.21secs', 48000, .21),
+      ('16k_.4secs', 16000, .4),
+      ('24k_.4secs', 24000, .4),
+      ('48k_.4secs', 48000, .4),
   )
   def test_tf_compute_loudness_at_sample_rate(self, sample_rate, audio_len_sec):
     audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate,
@@ -226,8 +226,8 @@ def test_tf_compute_loudness_at_sample_rate(self, sample_rate, audio_len_sec):
     self.assertTrue(np.all(np.isfinite(loudness)))

   @parameterized.named_parameters(
-      ('44.1k_2.1secs', 44100, 2.1),
-      ('44.1k_4secs', 44100, 4),
+      ('44.1k_.21secs', 44100, .21),
+      ('44.1k_.4secs', 44100, .4),
   )
   def test_compute_loudness_indivisible_rates_raises_error(
       self, sample_rate, audio_len_sec):
diff --git a/ddsp/training/README.md b/ddsp/training/README.md
index 0cf08fc8..8d53069c 100644
--- a/ddsp/training/README.md
+++ b/ddsp/training/README.md
@@ -19,6 +19,9 @@ The DDSP training libraries are separated into several modules:

 * [data](./data.py):
     DataProvider objects that provide tf.data.Dataset.
+* [inference](./inference.py):
+    Model wrappers for efficient inference and the ability to store as
+    SavedModels.
 * [models](./models.py):
     Model objects to encapsulate training and evalution.
 * [preprocessing](./preprocessing.py):
@@ -29,9 +32,6 @@ The DDSP training libraries are separated into several modules:
     Layers to turn latents into ddsp processor inputs.
 * [nn](./nn.py):
     Helper library of network functions and layers.
-* [inference](./inference.py):
-    Model wrappers for efficient inference and the ability to store as
-    SavedModels.

 The main training file is `ddsp_run.py` and its helper libraries:

@@ -40,8 +40,14 @@ The main training file is `ddsp_run.py` and its helper libraries:
     Main file for training, evaluating, and sampling from models.
 * [train_util](./train_util.py):
     Helper functions for training including the Trainer object.
+* [trainers](./trainers.py):
+    Helper objects to bind strategy, optimizer, and model, and define the
+    training step.
 * [eval_util](./eval_util.py):
     Helper functions for evaluation and sampling.
+* [metrics](./metrics.py):
+    Metrics for evaluation.
+* [summaries](./summaries.py):
+    Summaries for tensorboard.

 While the modules in the `ddsp/` base directory can be used to train models
 with `tf.compat.v1` or `tf.compat.v2` this directory only uses `tf.compat.v2`.
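The test changes above shrink the audio durations tenfold (2.1 s to 0.21 s, 4 s to 0.4 s), keeping the same sample-rate coverage while making the CREPE-backed tests far cheaper. The 44.1 kHz cases still expect an error because loudness and f0 are framed at the fixture's fixed frame rate; a small, hypothetical illustration of the divisibility rule those tests encode (frame_rate = 250 comes from setUp above, and per the test names the indivisible rate raises):

frame_rate = 250
for sample_rate in (16000, 24000, 44100, 48000):
  ok = sample_rate % frame_rate == 0
  print(sample_rate, 'divisible' if ok else 'raises')  # Only 44100 fails.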
diff --git a/ddsp/training/ddsp_run.py b/ddsp/training/ddsp_run.py
index 3003589f..804269eb 100644
--- a/ddsp/training/ddsp_run.py
+++ b/ddsp/training/ddsp_run.py
@@ -68,6 +68,7 @@
 from ddsp.training import eval_util
 from ddsp.training import models
 from ddsp.training import train_util
+from ddsp.training import trainers
 import gin
 import pkg_resources
 import tensorflow.compat.v2 as tf
@@ -169,7 +170,7 @@ def main(unused_argv):
   strategy = train_util.get_strategy(tpu=FLAGS.tpu, gpus=FLAGS.gpu)
   with strategy.scope():
     model = models.get_model()
-    trainer = train_util.Trainer(model, strategy)
+    trainer = trainers.Trainer(model, strategy)

     train_util.train(data_provider=gin.REQUIRED,
                      trainer=trainer,
diff --git a/ddsp/training/eval_util.py b/ddsp/training/eval_util.py
index 75b752c6..8ac86c55 100644
--- a/ddsp/training/eval_util.py
+++ b/ddsp/training/eval_util.py
@@ -15,333 +15,17 @@
 # Lint as: python3
 """Library of evaluation functions."""

-import io
 import os
 import time

 from absl import logging
 import ddsp
-from ddsp.core import tf_float32
+from ddsp.training import metrics
+from ddsp.training import summaries
 import gin
-import librosa
-import matplotlib.pyplot as plt
 import numpy as np
 import tensorflow.compat.v2 as tf

-# Global values for evaluation.
-MIN_F0_CONFIDENCE = 0.85
-OUTLIER_MIDI_THRESH = 12
-
-
-def squeeze(input_vector):
-  """Ensure vector only has one axis of dimensionality."""
-  if input_vector.ndim > 1:
-    return np.squeeze(input_vector)
-  else:
-    return input_vector
-
-
-# ---------------------- Metrics -----------------------------------------------
-def l1_distance(prediction, ground_truth):
-  """Average L1 distance difference between two 1-D vectors."""
-  prediction, ground_truth = np.squeeze(prediction), np.squeeze(ground_truth)
-  min_length = min(prediction.size, ground_truth.size)
-  diff = prediction[:min_length] - ground_truth[:min_length]
-  return np.abs(diff)
-
-
-def is_outlier(ground_truth_f0_conf):
-  """Determine if ground truth f0 for audio sample is an outlier."""
-  ground_truth_f0_conf = squeeze(ground_truth_f0_conf)
-  return np.max(ground_truth_f0_conf) < MIN_F0_CONFIDENCE
-
-
-def compute_audio_features(audio,
-                           n_fft=2048,
-                           sample_rate=16000,
-                           frame_rate=250):
-  """Compute features from audio."""
-  audio_feats = {'audio': audio}
-  audio = squeeze(audio)
-
-  audio_feats['loudness_db'] = ddsp.spectral_ops.compute_loudness(
-      audio, sample_rate, frame_rate, n_fft)
-
-  audio_feats['f0_hz'], audio_feats['f0_confidence'] = (
-      ddsp.spectral_ops.compute_f0(audio, sample_rate, frame_rate))
-
-  return audio_feats
-
-
-def f0_dist_conf_thresh(f0_hz,
-                        f0_hz_gen,
-                        f0_confidence,
-                        f0_confidence_thresh=MIN_F0_CONFIDENCE):
-  """Compute L1 between gen audio and ground truth audio.
-
-  Calculating F0 distance is more complicated than calculating loudness
-  distance because of inherent inaccuracies in pitch tracking.
-
-  We take the following steps:
-  - Define a `keep_mask` that only select f0 values above when f0_confidence in
-    the GENERATED AUDIO (not ground truth) exceeds a minimum threshold.
-    Experimentation by jessengel@ and hanoih@ found this to be optimal way to
-    filter out bad f0 pitch tracking.
-  - Compute `delta_f0` between generated audio and ground truth audio.
-  - Only select values in `delta_f0` based on this `keep_mask`
-  - Compute mean on this selection
-  - At the start of training, audio samples will sound bad and thus have no
-    pitch content. If the `f0_confidence` is all below the threshold, we keep a
-    count of it. A better performing model will have a smaller count of
-    "untrackable pitch" samples.
-
-  Args:
-    f0_hz: Ground truth audio f0 in hertz [MB,:].
-    f0_hz_gen: Generated audio f0 in hertz [MB,:].
-    f0_confidence: Ground truth audio f0 confidence [MB,:]
-    f0_confidence_thresh: Confidence threshold above which f0 metrics will be
-      computed
-
-  Returns:
-    delta_f0_mean: Float or None if entire generated sample had
-      f0_confidence below threshold. In units of MIDI (logarithmic frequency).
-  """
-  if len(f0_hz.shape) > 2:
-    f0_hz = f0_hz[:, :, 0]
-  if len(f0_hz_gen.shape) > 2:
-    f0_hz_gen = f0_hz_gen[:, :, 0]
-  if len(f0_confidence.shape) > 2:
-    f0_confidence = f0_confidence[:, :, 0]
-
-  if np.max(f0_confidence) < f0_confidence_thresh:
-    # Generated audio is not good enough for reliable pitch tracking.
-    return None
-  else:
-    keep_mask = f0_confidence >= f0_confidence_thresh
-
-    # Report mean error in midi space for easier interpretation.
-    f0_midi = librosa.core.hz_to_midi(f0_hz)
-    f0_midi_gen = librosa.core.hz_to_midi(f0_hz_gen)
-    # Set -infs introduced by hz_to_midi to 0.
-    f0_midi[f0_midi == -np.inf] = 0
-    f0_midi_gen[f0_midi_gen == -np.inf] = 0
-
-    delta_f0_midi = l1_distance(f0_midi, f0_midi_gen)
-    delta_f0_midi_filt = delta_f0_midi[keep_mask]
-    return np.mean(delta_f0_midi_filt)
-
-
-class F0LoudnessMetrics(object):
-  """Helper object for computing f0 and loudness metrics."""
-
-  def __init__(self, sample_rate):
-    self.metrics = {
-        'loudness_db': tf.keras.metrics.Mean('loudness_db'),
-        'f0_encoder': tf.keras.metrics.Mean('f0_encoder'),
-        'f0_crepe': tf.keras.metrics.Mean('f0_crepe'),
-        'f0_crepe_outlier_ratio':
-            tf.keras.metrics.Accuracy('f0_crepe_outlier_ratio'),
-    }
-    self._sample_rate = sample_rate
-
-  def update_state(self, batch, audio_gen, f0_hz_predict):
-    """Update metrics based on a batch of audio.
-
-    Args:
-      batch: Dictionary of input features.
-      audio_gen: Batch of generated audio.
-      f0_hz_predict: Batch of encoded f0, same as input f0 if no f0 encoder.
-    """
-    batch_size = int(audio_gen.shape[0])
-    # Compute metrics per sample. No batch operations possible.
-    for i in range(batch_size):
-      # Extract features from generated audio example.
-      keys = ['loudness_db', 'f0_hz', 'f0_confidence']
-      feats = {k: v[i] for k, v in batch.items() if k in keys}
-      feats_gen = compute_audio_features(
-          audio_gen[i], sample_rate=self._sample_rate)
-
-      # Loudness metric.
-      ld_dist = np.mean(l1_distance(feats['loudness_db'],
-                                    feats_gen['loudness_db']))
-      self.metrics['loudness_db'].update_state(ld_dist)
-
-      # F0 metric.
-      if is_outlier(feats['f0_confidence']):
-        # Ground truth f0 was unreliable to begin with. Discard.
-        f0_crepe_dist = None
-      else:
-        # Gound truth f0 was reliable, compute f0 distance with generated audio
-        f0_crepe_dist = f0_dist_conf_thresh(feats['f0_hz'],
-                                            feats_gen['f0_hz'],
-                                            feats['f0_confidence'])
-
-      # Compute distance original f0_hz labels and f0 encoder values.
-      # Resample if f0 encoder has different number of time steps.
-      f0_encoder = f0_hz_predict[i]
-      f0_original = feats['f0_hz']
-      if f0_encoder.shape[0] != f0_original.shape[0]:
-        f0_encoder = ddsp.core.resample(f0_encoder, f0_original.shape[0])
-      f0_encoder_dist = f0_dist_conf_thresh(f0_original,
-                                            f0_encoder,
-                                            feats['f0_confidence'])
-      self.metrics['f0_encoder'].update_state(f0_encoder_dist)
-
-      if f0_crepe_dist is None or f0_crepe_dist > OUTLIER_MIDI_THRESH:
-        # Generated audio had untrackable pitch content or is an outlier.
-        self.metrics['f0_crepe_outlier_ratio'].update_state(True, True)
-        logging.info('Sample %d has untrackable pitch content', i)
-      else:
-        # Generated audio had trackable pitch content and is within tolerance
-        self.metrics['f0_crepe'].update_state(f0_crepe_dist)
-        self.metrics['f0_crepe_outlier_ratio'].update_state(True, False)
-        logging.info(
-            'sample {} | ld_dist(db): {:.3f} | f0_crepe_dist(midi): {:.3f} | '
-            'f0_encoder_dist(midi): {:.3f}'.format(
-                i, ld_dist, f0_crepe_dist, f0_encoder_dist))
-
-  def flush(self, step):
-    """Add summaries for each metric and reset the state."""
-    # Start by logging the metrics result.
-    logging.info('COMPUTING METRICS COMPLETE. FLUSHING ALL METRICS')
-    metrics_str = ' | '.join(
-        '{}: {:0.3f}'.format(k, v.result()) for k, v in self.metrics.items())
-    logging.info(metrics_str)
-
-    for name, metric in self.metrics.items():
-      tf.summary.scalar('metrics/{}'.format(name), metric.result(), step)
-      metric.reset_states()
-
-    ddsp.spectral_ops.reset_crepe()  # Reset CREPE global state
-
-
-# ---------------------- Custom summaries --------------------------------------
-def fig_summary(tag, fig, step):
-  """Writes an image summary from a string buffer of an mpl figure.
-
-  This writer writes a scalar summary in V1 format using V2 API.
-
-  Args:
-    tag: An arbitrary tag name for this summary.
-    fig: A matplotlib figure.
-    step: The `int64` monotonic step variable, which defaults
-      to `tf.compat.v1.train.get_global_step`.
-  """
-  buffer = io.BytesIO()
-  fig.savefig(buffer, format='png')
-  image_summary = tf.compat.v1.Summary.Image(
-      encoded_image_string=buffer.getvalue())
-  plt.close(fig)
-
-  pb = tf.compat.v1.Summary()
-  pb.value.add(tag=tag, image=image_summary)
-  serialized = tf.convert_to_tensor(pb.SerializeToString())
-  tf.summary.experimental.write_raw_pb(serialized, step=step, name=tag)
-
-
-def waveform_summary(audio, audio_gen, step, name=''):
-  """Creates a waveform plot summary for a batch of audio."""
-
-  def plot_waveform(i, length=None, prefix='waveform', name=''):
-    """Plots a waveforms."""
-    waveform = squeeze(audio[i])
-    waveform = waveform[:length] if length is not None else waveform
-    waveform_gen = squeeze(audio_gen[i])
-    waveform_gen = waveform_gen[:length] if length is not None else waveform_gen
-    # Manually specify exact size of fig for tensorboard
-    fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(2.5, 2.5))
-    ax0.plot(waveform)
-    ax1.plot(waveform_gen)
-
-    # Format and save plot to image
-    name = name + '_' if name else ''
-    tag = 'waveform/{}{}_{}'.format(name, prefix, i + 1)
-    fig_summary(tag, fig, step)
-
-  # Make plots at multiple lengths.
-  batch_size = int(audio.shape[0])
-  for i in range(batch_size):
-    plot_waveform(i, length=None, prefix='full', name=name)
-    plot_waveform(i, length=2000, prefix='125ms', name=name)
-
-
-def get_spectrogram(audio, rotate=False, size=1024):
-  """Compute logmag spectrogram."""
-  mag = ddsp.spectral_ops.compute_logmag(tf_float32(audio), size=size)
-  if rotate:
-    mag = np.rot90(mag)
-  return mag
-
-
-def spectrogram_summary(audio, audio_gen, step, name=''):
-  """Writes a summary of spectrograms for a batch of images."""
-  specgram = lambda a: ddsp.spectral_ops.compute_logmag(tf_float32(a), size=768)
-
-  # Batch spectrogram operations
-  spectrograms = specgram(audio)
-  spectrograms_gen = specgram(audio_gen)
-
-  batch_size = int(audio.shape[0])
-  for i in range(batch_size):
-    # Manually specify exact size of fig for tensorboard
-    fig, axs = plt.subplots(2, 1, figsize=(8, 8))
-
-    ax = axs[0]
-    spec = np.rot90(spectrograms[i])
-    ax.matshow(spec, vmin=-5, vmax=1, aspect='auto', cmap=plt.cm.magma)
-    ax.set_title('original')
-    ax.set_xticks([])
-    ax.set_yticks([])
-
-    ax = axs[1]
-    spec = np.rot90(spectrograms_gen[i])
-    ax.matshow(spec, vmin=-5, vmax=1, aspect='auto', cmap=plt.cm.magma)
-    ax.set_title('synthesized')
-    ax.set_xticks([])
-    ax.set_yticks([])
-
-    # Format and save plot to image
-    name = name + '_' if name else ''
-    tag = 'spectrogram/{}{}'.format(name, i + 1)
-    fig_summary(tag, fig, step)
-
-
-def audio_summary(audio, step, sample_rate=16000, name='audio'):
-  """Update metrics dictionary given a batch of audio."""
-  # Ensure there is a single channel dimension.
-  batch_size = int(audio.shape[0])
-  if len(audio.shape) == 2:
-    audio = audio[:, :, tf.newaxis]
-  tf.summary.audio(
-      name, audio, sample_rate, step, max_outputs=batch_size, encoding='wav')
-
-
-def f0_summary(f0_hz, f0_hz_predict, step, name=''):
-  """Creates a plot comparison of ground truth f0_hz and predicted values."""
-  batch_size = int(f0_hz.shape[0])
-
-  for i in range(batch_size):
-    f0_midi = ddsp.core.hz_to_midi(squeeze(f0_hz[i]))
-    f0_midi_predict = ddsp.core.hz_to_midi(squeeze(f0_hz_predict[i]))
-
-    # Resample if f0_encoder has different number of time steps
-    if f0_midi_predict.shape[0] != f0_midi.shape[0]:
-      f0_midi_predict = ddsp.core.resample(f0_midi_predict, f0_midi.shape[0])
-
-    # Manually specify exact size of fig for tensorboard
-    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(6.0, 2.0))
-    ax0.plot(f0_midi)
-    ax0.plot(f0_midi_predict)
-    ax0.set_title('original vs. predicted')
-
-    ax1.plot(f0_midi_predict)
-    ax1.set_title('predicted')
-
-    # Format and save plot to image
-    name = name + '_' if name else ''
-    tag = 'f0_midi/{}{}'.format(name, i + 1)
-    fig_summary(tag, fig, step)
-

 # ---------------------- Evaluation --------------------------------------------
 def evaluate_or_sample(data_provider,
@@ -409,7 +93,8 @@ def evaluate_or_sample(data_provider,

     # Create metrics on first batch.
     if mode == 'eval' and batch_idx == 1:
-      f0_loudness_metrics = F0LoudnessMetrics(sample_rate=sample_rate)
+      f0_loudness_metrics = metrics.F0LoudnessMetrics(
+          sample_rate=sample_rate)
       avg_losses = {
           name: tf.keras.metrics.Mean(name=name, dtype=tf.float32)
           for name in list(losses.keys())}
@@ -431,14 +116,16 @@
         logging.info('Writing summmaries for batch %d', batch_idx)

         # Add audio.
-        audio_summary(audio_gen, step, sample_rate, name='audio_generated')
-        audio_summary(audio, step, sample_rate, name='audio_original')
+        summaries.audio_summary(
+            audio_gen, step, sample_rate, name='audio_generated')
+        summaries.audio_summary(
+            audio, step, sample_rate, name='audio_original')

         # Add plots.
-        waveform_summary(audio, audio_gen, step)
-        spectrogram_summary(audio, audio_gen, step)
+        summaries.waveform_summary(audio, audio_gen, step)
+        summaries.spectrogram_summary(audio, audio_gen, step)
         if has_f0:
-          f0_summary(batch['f0_hz'], outputs['f0_hz'], step)
+          summaries.f0_summary(batch['f0_hz'], outputs['f0_hz'], step)

         logging.info('Writing batch %i with size %i took %.1f seconds',
                      batch_idx, batch_size, time.time() - start_time)
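With the hunks above, eval_util.py keeps only the evaluate/sample loop and delegates metrics and summaries to the two new modules. Downstream code that imported the helpers from eval_util needs a one-line import change; a migration sketch assuming only the moves shown in this diff (CREPE makes compute_f0 slow, so this is illustrative rather than benchmark-ready):

import numpy as np
# Before: from ddsp.training.eval_util import compute_audio_features
from ddsp.training.metrics import compute_audio_features

audio = np.random.uniform(-1.0, 1.0, size=[1, 16000]).astype(np.float32)
feats = compute_audio_features(audio)  # 1 s at the 16 kHz default rate.
print(sorted(feats.keys()))  # ['audio', 'f0_confidence', 'f0_hz', 'loudness_db']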
diff --git a/ddsp/training/metrics.py b/ddsp/training/metrics.py
new file mode 100644
index 00000000..adfed7dc
--- /dev/null
+++ b/ddsp/training/metrics.py
@@ -0,0 +1,210 @@
+# Copyright 2020 The DDSP Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Library of performance metrics relevant to DDSP training."""
+
+from absl import logging
+import ddsp
+import librosa
+import numpy as np
+import tensorflow.compat.v2 as tf
+
+# Global values for evaluation.
+MIN_F0_CONFIDENCE = 0.85
+OUTLIER_MIDI_THRESH = 12
+
+
+# ---------------------- Helper Functions --------------------------------------
+def squeeze(input_vector):
+  """Ensure vector only has one axis of dimensionality."""
+  if input_vector.ndim > 1:
+    return np.squeeze(input_vector)
+  else:
+    return input_vector
+
+
+def l1_distance(prediction, ground_truth):
+  """Average L1 distance difference between two 1-D vectors."""
+  prediction, ground_truth = np.squeeze(prediction), np.squeeze(ground_truth)
+  min_length = min(prediction.size, ground_truth.size)
+  diff = prediction[:min_length] - ground_truth[:min_length]
+  return np.abs(diff)
+
+
+def is_outlier(ground_truth_f0_conf):
+  """Determine if ground truth f0 for audio sample is an outlier."""
+  ground_truth_f0_conf = squeeze(ground_truth_f0_conf)
+  return np.max(ground_truth_f0_conf) < MIN_F0_CONFIDENCE
+
+
+def compute_audio_features(audio,
+                           n_fft=2048,
+                           sample_rate=16000,
+                           frame_rate=250):
+  """Compute features from audio."""
+  audio_feats = {'audio': audio}
+  audio = squeeze(audio)
+
+  audio_feats['loudness_db'] = ddsp.spectral_ops.compute_loudness(
+      audio, sample_rate, frame_rate, n_fft)
+
+  audio_feats['f0_hz'], audio_feats['f0_confidence'] = (
+      ddsp.spectral_ops.compute_f0(audio, sample_rate, frame_rate))
+
+  return audio_feats
+
+
+def f0_dist_conf_thresh(f0_hz,
+                        f0_hz_gen,
+                        f0_confidence,
+                        f0_confidence_thresh=MIN_F0_CONFIDENCE):
+  """Compute L1 between gen audio and ground truth audio.
+
+  Calculating F0 distance is more complicated than calculating loudness
+  distance because of inherent inaccuracies in pitch tracking.
+
+  We take the following steps:
+  - Define a `keep_mask` that only selects f0 values where the f0_confidence
+    in the GENERATED AUDIO (not ground truth) exceeds a minimum threshold.
+    Experimentation by jessengel@ and hanoih@ found this to be the optimal way
+    to filter out bad f0 pitch tracking.
+  - Compute `delta_f0` between generated audio and ground truth audio.
+  - Only select values in `delta_f0` based on this `keep_mask`.
+  - Compute the mean of this selection.
+  - At the start of training, audio samples will sound bad and thus have no
+    pitch content. If the `f0_confidence` is all below the threshold, we keep a
+    count of it. A better performing model will have a smaller count of
+    "untrackable pitch" samples.
+
+  Args:
+    f0_hz: Ground truth audio f0 in hertz [MB,:].
+    f0_hz_gen: Generated audio f0 in hertz [MB,:].
+    f0_confidence: Ground truth audio f0 confidence [MB,:].
+    f0_confidence_thresh: Confidence threshold above which f0 metrics will be
+      computed.
+
+  Returns:
+    delta_f0_mean: Float, or None if the entire generated sample had
+      f0_confidence below threshold. In units of MIDI (logarithmic frequency).
+  """
+  if len(f0_hz.shape) > 2:
+    f0_hz = f0_hz[:, :, 0]
+  if len(f0_hz_gen.shape) > 2:
+    f0_hz_gen = f0_hz_gen[:, :, 0]
+  if len(f0_confidence.shape) > 2:
+    f0_confidence = f0_confidence[:, :, 0]
+
+  if np.max(f0_confidence) < f0_confidence_thresh:
+    # Generated audio is not good enough for reliable pitch tracking.
+    return None
+  else:
+    keep_mask = f0_confidence >= f0_confidence_thresh
+
+    # Report mean error in midi space for easier interpretation.
+    f0_midi = librosa.core.hz_to_midi(f0_hz)
+    f0_midi_gen = librosa.core.hz_to_midi(f0_hz_gen)
+    # Set -infs introduced by hz_to_midi to 0.
+    f0_midi[f0_midi == -np.inf] = 0
+    f0_midi_gen[f0_midi_gen == -np.inf] = 0
+
+    delta_f0_midi = l1_distance(f0_midi, f0_midi_gen)
+    delta_f0_midi_filt = delta_f0_midi[keep_mask]
+    return np.mean(delta_f0_midi_filt)
+
+
+# ---------------------- Metrics -----------------------------------------------
+class F0LoudnessMetrics(object):
+  """Helper object for computing f0 and loudness metrics."""
+
+  def __init__(self, sample_rate):
+    self.metrics = {
+        'loudness_db': tf.keras.metrics.Mean('loudness_db'),
+        'f0_encoder': tf.keras.metrics.Mean('f0_encoder'),
+        'f0_crepe': tf.keras.metrics.Mean('f0_crepe'),
+        'f0_crepe_outlier_ratio':
+            tf.keras.metrics.Accuracy('f0_crepe_outlier_ratio'),
+    }
+    self._sample_rate = sample_rate
+
+  def update_state(self, batch, audio_gen, f0_hz_predict):
+    """Update metrics based on a batch of audio.
+
+    Args:
+      batch: Dictionary of input features.
+      audio_gen: Batch of generated audio.
+      f0_hz_predict: Batch of encoded f0, same as input f0 if no f0 encoder.
+    """
+    batch_size = int(audio_gen.shape[0])
+    # Compute metrics per sample. No batch operations possible.
+    for i in range(batch_size):
+      # Extract features from generated audio example.
+      keys = ['loudness_db', 'f0_hz', 'f0_confidence']
+      feats = {k: v[i] for k, v in batch.items() if k in keys}
+      feats_gen = compute_audio_features(
+          audio_gen[i], sample_rate=self._sample_rate)
+
+      # Loudness metric.
+      ld_dist = np.mean(l1_distance(feats['loudness_db'],
+                                    feats_gen['loudness_db']))
+      self.metrics['loudness_db'].update_state(ld_dist)
+
+      # F0 metric.
+      if is_outlier(feats['f0_confidence']):
+        # Ground truth f0 was unreliable to begin with. Discard.
+        f0_crepe_dist = None
+      else:
+        # Ground truth f0 was reliable; compute f0 distance to generated audio.
+        f0_crepe_dist = f0_dist_conf_thresh(feats['f0_hz'],
+                                            feats_gen['f0_hz'],
+                                            feats['f0_confidence'])
+
+      # Compute distance between original f0_hz labels and f0 encoder values.
+      # Resample if f0 encoder has different number of time steps.
+      f0_encoder = f0_hz_predict[i]
+      f0_original = feats['f0_hz']
+      if f0_encoder.shape[0] != f0_original.shape[0]:
+        f0_encoder = ddsp.core.resample(f0_encoder, f0_original.shape[0])
+      f0_encoder_dist = f0_dist_conf_thresh(f0_original,
+                                            f0_encoder,
+                                            feats['f0_confidence'])
+      self.metrics['f0_encoder'].update_state(f0_encoder_dist)
+
+      if f0_crepe_dist is None or f0_crepe_dist > OUTLIER_MIDI_THRESH:
+        # Generated audio had untrackable pitch content or is an outlier.
+        self.metrics['f0_crepe_outlier_ratio'].update_state(True, True)
+        logging.info('Sample %d has untrackable pitch content', i)
+      else:
+        # Generated audio had trackable pitch content and is within tolerance.
+        self.metrics['f0_crepe'].update_state(f0_crepe_dist)
+        self.metrics['f0_crepe_outlier_ratio'].update_state(True, False)
+        logging.info(
+            'sample {} | ld_dist(db): {:.3f} | f0_crepe_dist(midi): {:.3f} | '
+            'f0_encoder_dist(midi): {:.3f}'.format(
+                i, ld_dist, f0_crepe_dist, f0_encoder_dist))
+
+  def flush(self, step):
+    """Add summaries for each metric and reset the state."""
+    # Start by logging the metrics result.
+    logging.info('COMPUTING METRICS COMPLETE. FLUSHING ALL METRICS')
+    metrics_str = ' | '.join(
+        '{}: {:0.3f}'.format(k, v.result()) for k, v in self.metrics.items())
+    logging.info(metrics_str)
+
+    for name, metric in self.metrics.items():
+      tf.summary.scalar('metrics/{}'.format(name), metric.result(), step)
+      metric.reset_states()
+
+    ddsp.spectral_ops.reset_crepe()  # Reset CREPE global state
diff --git a/ddsp/training/eval_util_test.py b/ddsp/training/metrics_test.py
similarity index 88%
rename from ddsp/training/eval_util_test.py
rename to ddsp/training/metrics_test.py
index 1efa009a..a1edd726 100644
--- a/ddsp/training/eval_util_test.py
+++ b/ddsp/training/metrics_test.py
@@ -17,7 +17,7 @@

 from absl.testing import parameterized
 from ddsp.spectral_ops_test import gen_np_sinusoid
-from ddsp.training.eval_util import compute_audio_features
+from ddsp.training.metrics import compute_audio_features
 import numpy as np
 import tensorflow.compat.v2 as tf

@@ -41,12 +41,12 @@ def validate_output_shapes(self, audio_features, expected_feature_lengths):
     self.assertTrue(np.all(np.isfinite(arr)))

   @parameterized.named_parameters(
-      ('16k_2.1secs', 16000, 2.1),
-      ('24k_2.1secs', 24000, 2.1),
-      ('48k_2.1secs', 48000, 2.1),
-      ('16k_4secs', 16000, 4),
-      ('24k_4secs', 24000, 4),
-      ('48k_4secs', 48000, 4),
+      ('16k_.21secs', 16000, .21),
+      ('24k_.21secs', 24000, .21),
+      ('48k_.21secs', 48000, .21),
+      ('16k_.4secs', 16000, .4),
+      ('24k_.4secs', 24000, .4),
+      ('48k_.4secs', 48000, .4),
   )
   def test_correct_shape_compute_af_at_sample_rate(self, sample_rate,
                                                    audio_len_sec):
@@ -65,8 +65,8 @@ def test_correct_shape_compute_af_at_sample_rate(self, sample_rate,
     })

   @parameterized.named_parameters(
-      ('44.1k_2.1secs', 44100, 2.1),
-      ('44.1k_4secs', 44100, 4),
+      ('44.1k_.21secs', 44100, .21),
+      ('44.1k_.4secs', 44100, .4),
   )
   def test_indivisible_rates_raises_error_compute_af(self, sample_rate,
                                                      audio_len_sec):
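The numpy helpers travel with the class, so they are importable on their own. A small runnable sketch of the two simplest ones, using the constants defined at the top of the new file (F0LoudnessMetrics.update_state chains these same calls per sample):

import numpy as np
from ddsp.training import metrics

loudness_a = np.zeros(250, np.float32)
loudness_b = np.full(250, -3.0, np.float32)
print(np.mean(metrics.l1_distance(loudness_a, loudness_b)))  # -> 3.0
print(metrics.is_outlier(np.full(250, 0.5, np.float32)))  # True: max conf < 0.85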
diff --git a/ddsp/training/summaries.py b/ddsp/training/summaries.py
new file mode 100644
index 00000000..0bf04a3d
--- /dev/null
+++ b/ddsp/training/summaries.py
@@ -0,0 +1,153 @@
+# Copyright 2020 The DDSP Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Library of tensorboard summary functions relevant to DDSP training."""
+
+import io
+
+import ddsp
+from ddsp.core import tf_float32
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow.compat.v2 as tf
+
+
+def fig_summary(tag, fig, step):
+  """Writes an image summary from a string buffer of an mpl figure.
+
+  This writer writes an image summary in V1 format using the V2 API.
+
+  Args:
+    tag: An arbitrary tag name for this summary.
+    fig: A matplotlib figure.
+    step: The `int64` monotonic step variable, which defaults
+      to `tf.compat.v1.train.get_global_step`.
+  """
+  buffer = io.BytesIO()
+  fig.savefig(buffer, format='png')
+  image_summary = tf.compat.v1.Summary.Image(
+      encoded_image_string=buffer.getvalue())
+  plt.close(fig)
+
+  pb = tf.compat.v1.Summary()
+  pb.value.add(tag=tag, image=image_summary)
+  serialized = tf.convert_to_tensor(pb.SerializeToString())
+  tf.summary.experimental.write_raw_pb(serialized, step=step, name=tag)
+
+
+def waveform_summary(audio, audio_gen, step, name=''):
+  """Creates a waveform plot summary for a batch of audio."""
+
+  def plot_waveform(i, length=None, prefix='waveform', name=''):
+    """Plots a waveform."""
+    waveform = np.squeeze(audio[i])
+    waveform = waveform[:length] if length is not None else waveform
+    waveform_gen = np.squeeze(audio_gen[i])
+    waveform_gen = waveform_gen[:length] if length is not None else waveform_gen
+    # Manually specify exact size of fig for tensorboard
+    fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(2.5, 2.5))
+    ax0.plot(waveform)
+    ax1.plot(waveform_gen)
+
+    # Format and save plot to image
+    name = name + '_' if name else ''
+    tag = 'waveform/{}{}_{}'.format(name, prefix, i + 1)
+    fig_summary(tag, fig, step)
+
+  # Make plots at multiple lengths.
+  batch_size = int(audio.shape[0])
+  for i in range(batch_size):
+    plot_waveform(i, length=None, prefix='full', name=name)
+    plot_waveform(i, length=2000, prefix='125ms', name=name)
+
+
+def get_spectrogram(audio, rotate=False, size=1024):
+  """Compute logmag spectrogram."""
+  mag = ddsp.spectral_ops.compute_logmag(tf_float32(audio), size=size)
+  if rotate:
+    mag = np.rot90(mag)
+  return mag
+
+
+def spectrogram_summary(audio, audio_gen, step, name=''):
+  """Writes a summary of spectrograms for a batch of audio."""
+  specgram = lambda a: ddsp.spectral_ops.compute_logmag(tf_float32(a), size=768)
+
+  # Batch spectrogram operations
+  spectrograms = specgram(audio)
+  spectrograms_gen = specgram(audio_gen)
+
+  batch_size = int(audio.shape[0])
+  for i in range(batch_size):
+    # Manually specify exact size of fig for tensorboard
+    fig, axs = plt.subplots(2, 1, figsize=(8, 8))
+
+    ax = axs[0]
+    spec = np.rot90(spectrograms[i])
+    ax.matshow(spec, vmin=-5, vmax=1, aspect='auto', cmap=plt.cm.magma)
+    ax.set_title('original')
+    ax.set_xticks([])
+    ax.set_yticks([])
+
+    ax = axs[1]
+    spec = np.rot90(spectrograms_gen[i])
+    ax.matshow(spec, vmin=-5, vmax=1, aspect='auto', cmap=plt.cm.magma)
+    ax.set_title('synthesized')
+    ax.set_xticks([])
+    ax.set_yticks([])
+
+    # Format and save plot to image
+    name = name + '_' if name else ''
+    tag = 'spectrogram/{}{}'.format(name, i + 1)
+    fig_summary(tag, fig, step)
+
+
+def audio_summary(audio, step, sample_rate=16000, name='audio'):
+  """Writes audio summaries for a batch of audio."""
+  # Ensure there is a single channel dimension.
+  batch_size = int(audio.shape[0])
+  if len(audio.shape) == 2:
+    audio = audio[:, :, tf.newaxis]
+  tf.summary.audio(
+      name, audio, sample_rate, step, max_outputs=batch_size, encoding='wav')
+
+
+def f0_summary(f0_hz, f0_hz_predict, step, name=''):
+  """Creates a plot comparison of ground truth f0_hz and predicted values."""
+  batch_size = int(f0_hz.shape[0])
+
+  for i in range(batch_size):
+    f0_midi = ddsp.core.hz_to_midi(tf.squeeze(f0_hz[i]))
+    f0_midi_predict = ddsp.core.hz_to_midi(tf.squeeze(f0_hz_predict[i]))
+
+    # Resample if f0_encoder has different number of time steps
+    if f0_midi_predict.shape[0] != f0_midi.shape[0]:
+      f0_midi_predict = ddsp.core.resample(f0_midi_predict, f0_midi.shape[0])
+
+    # Manually specify exact size of fig for tensorboard
+    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(6.0, 2.0))
+    ax0.plot(f0_midi)
+    ax0.plot(f0_midi_predict)
+    ax0.set_title('original vs. predicted')
+
+    ax1.plot(f0_midi_predict)
+    ax1.set_title('predicted')
+
+    # Format and save plot to image
+    name = name + '_' if name else ''
+    tag = 'f0_midi/{}{}'.format(name, i + 1)
+    fig_summary(tag, fig, step)
+
+
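These summary functions only write when a default writer is active; evaluate_or_sample() in eval_util.py supplies one, and a standalone sketch looks like this (the /tmp path is a placeholder):

import numpy as np
import tensorflow.compat.v2 as tf
from ddsp.training import summaries

audio = np.random.uniform(-1.0, 1.0, size=[2, 16000]).astype(np.float32)
audio_gen = np.random.uniform(-1.0, 1.0, size=[2, 16000]).astype(np.float32)

writer = tf.summary.create_file_writer('/tmp/ddsp_summaries')
with writer.as_default():
  summaries.waveform_summary(audio, audio_gen, step=0)
  summaries.spectrogram_summary(audio, audio_gen, step=0)
  summaries.audio_summary(audio_gen, step=0, name='audio_generated')
writer.flush()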
diff --git a/ddsp/training/train_util.py b/ddsp/training/train_util.py
index 8bb57184..a7eb17e5 100644
--- a/ddsp/training/train_util.py
+++ b/ddsp/training/train_util.py
@@ -23,6 +23,7 @@
 import tensorflow.compat.v2 as tf


+# ---------------------- Helper Functions --------------------------------------
 def get_strategy(tpu='', gpus=None):
   """Create a distribution strategy.

@@ -123,146 +124,7 @@ def format_for_tensorboard(line):
   summary_writer.flush()


-@gin.configurable
-class Trainer(object):
-  """Class to bind an optimizer, model, strategy, and training step function."""
-
-  def __init__(self,
-               model,
-               strategy,
-               checkpoints_to_keep=100,
-               learning_rate=0.001,
-               lr_decay_steps=10000,
-               lr_decay_rate=0.98,
-               grad_clip_norm=3.0,
-               restore_keys=None):
-    """Constructor.
-
-    Args:
-      model: Model to train.
-      strategy: A distribution strategy.
-      checkpoints_to_keep: Max number of checkpoints before deleting oldest.
-      learning_rate: Scalar initial learning rate.
-      lr_decay_steps: Exponential decay timescale.
-      lr_decay_rate: Exponential decay magnitude.
-      grad_clip_norm: Norm level by which to clip gradients.
-      restore_keys: List of names of model properties to restore. If no keys are
-        passed, restore the whole model.
-    """
-    self.model = model
-    self.strategy = strategy
-    self.checkpoints_to_keep = checkpoints_to_keep
-    self.grad_clip_norm = grad_clip_norm
-    self.restore_keys = restore_keys
-
-    # Create an optimizer.
-    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
-        initial_learning_rate=learning_rate,
-        decay_steps=lr_decay_steps,
-        decay_rate=lr_decay_rate)
-
-    with self.strategy.scope():
-      optimizer = tf.keras.optimizers.Adam(lr_schedule)
-      self.optimizer = optimizer
-
-  def save(self, save_dir):
-    """Saves model and optimizer to a checkpoint."""
-    # Saving weights in checkpoint format because saved_model requires
-    # handling variable batch size, which some synths and effects can't.
-    start_time = time.time()
-    checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)
-    manager = tf.train.CheckpointManager(
-        checkpoint, directory=save_dir, max_to_keep=self.checkpoints_to_keep)
-    step = self.step.numpy()
-    manager.save(checkpoint_number=step)
-    logging.info('Saved checkpoint to %s at step %s', save_dir, step)
-    logging.info('Saving model took %.1f seconds', time.time() - start_time)
-
-  def restore(self, checkpoint_path, restore_keys=None):
-    """Restore model and optimizer from a checkpoint if it exists."""
-    logging.info('Restoring from checkpoint...')
-    start_time = time.time()
-
-    # Prefer function args over object properties.
-    restore_keys = self.restore_keys if restore_keys is None else restore_keys
-    if restore_keys is None:
-      # If no keys are passed, restore the whole model.
-      model = self.model
-      logging.info('Trainer restoring the full model')
-    else:
-      # Restore only sub-modules by building a new subgraph.
-      restore_dict = {k: getattr(self.model, k) for k in restore_keys}
-      model = tf.train.Checkpoint(**restore_dict)
-
-      logging.info('Trainer restoring model subcomponents:')
-      for k, v in restore_dict.items():
-        log_str = 'Restoring {}: {}'.format(k, v)
-        logging.info(log_str)
-
-    # Restore from latest checkpoint.
-    checkpoint = tf.train.Checkpoint(model=model, optimizer=self.optimizer)
-    latest_checkpoint = get_latest_chekpoint(checkpoint_path)
-    if latest_checkpoint is not None:
-      # checkpoint.restore must be within a strategy.scope() so that optimizer
-      # slot variables are mirrored.
-      with self.strategy.scope():
-        if restore_keys is None:
-          checkpoint.restore(latest_checkpoint)
-        else:
-          checkpoint.restore(latest_checkpoint).expect_partial()
-        logging.info('Loaded checkpoint %s', latest_checkpoint)
-      logging.info('Loading model took %.1f seconds', time.time() - start_time)
-    else:
-      logging.info('No checkpoint, skipping.')
-
-  @property
-  def step(self):
-    """The number of training steps completed."""
-    return self.optimizer.iterations
-
-  def psum(self, x, axis=None):
-    """Sum across processors."""
-    return self.strategy.reduce(tf.distribute.ReduceOp.SUM, x, axis=axis)
-
-  def run(self, fn, *args, **kwargs):
-    """Distribute and run function on processors."""
-    return self.strategy.experimental_run_v2(fn, args=args, kwargs=kwargs)
-
-  def build(self, batch):
-    """Build the model by running a distributed batch through it."""
-    logging.info('Building the model...')
-    _ = self.run(tf.function(self.model.__call__), batch)
-    self.model.summary()
-
-  def distribute_dataset(self, dataset):
-    """Create a distributed dataset."""
-    if isinstance(dataset, tf.data.Dataset):
-      return self.strategy.experimental_distribute_dataset(dataset)
-    else:
-      return dataset
-
-  @tf.function
-  def train_step(self, inputs):
-    """Distributed training step."""
-    # Wrap iterator in tf.function, slight speedup passing in iter vs batch.
-    batch = next(inputs) if hasattr(inputs, '__next__') else inputs
-    losses = self.run(self.step_fn, batch)
-    # Add up the scalar losses across replicas.
-    n_replicas = self.strategy.num_replicas_in_sync
-    return {k: self.psum(v, axis=None) / n_replicas for k, v in losses.items()}
-
-  @tf.function
-  def step_fn(self, batch):
-    """Per-Replica training step."""
-    with tf.GradientTape() as tape:
-      _, losses = self.model(batch, return_losses=True, training=True)
-    # Clip and apply gradients.
-    grads = tape.gradient(losses['total_loss'], self.model.trainable_variables)
-    grads, _ = tf.clip_by_global_norm(grads, self.grad_clip_norm)
-    self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
-    return losses
-
-
+# ------------------------ Training Loop ---------------------------------------
 @gin.configurable
 def train(data_provider,
           trainer,
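After this hunk, train_util.py holds only the strategy/checkpoint helpers and the train() loop, while Trainer moves to the new module below. A wiring sketch mirroring main() in ddsp_run.py above; in the real script the model is assembled by gin, so models.Autoencoder() here is only an unconfigured shell showing where Trainer now comes from:

from ddsp.training import models, train_util, trainers

strategy = train_util.get_strategy()  # No tpu/gpus args: default strategy.
with strategy.scope():
  model = models.Autoencoder()  # Normally filled in by gin in ddsp_run.py.
  trainer = trainers.Trainer(model, strategy, checkpoints_to_keep=10)
print(int(trainer.step))  # 0: no train_step() calls yet.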
diff --git a/ddsp/training/trainers.py b/ddsp/training/trainers.py
new file mode 100644
index 00000000..7c370697
--- /dev/null
+++ b/ddsp/training/trainers.py
@@ -0,0 +1,165 @@
+# Copyright 2020 The DDSP Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Library of Trainer objects that define the training step and wrap the optimizer."""
+
+import time
+
+from absl import logging
+from ddsp.training import train_util
+import gin
+import tensorflow.compat.v2 as tf
+
+
+@gin.configurable
+class Trainer(object):
+  """Class to bind an optimizer, model, strategy, and training step function."""
+
+  def __init__(self,
+               model,
+               strategy,
+               checkpoints_to_keep=100,
+               learning_rate=0.001,
+               lr_decay_steps=10000,
+               lr_decay_rate=0.98,
+               grad_clip_norm=3.0,
+               restore_keys=None):
+    """Constructor.
+
+    Args:
+      model: Model to train.
+      strategy: A distribution strategy.
+      checkpoints_to_keep: Max number of checkpoints before deleting oldest.
+      learning_rate: Scalar initial learning rate.
+      lr_decay_steps: Exponential decay timescale.
+      lr_decay_rate: Exponential decay magnitude.
+      grad_clip_norm: Norm level by which to clip gradients.
+      restore_keys: List of names of model properties to restore. If no keys are
+        passed, restore the whole model.
+    """
+    self.model = model
+    self.strategy = strategy
+    self.checkpoints_to_keep = checkpoints_to_keep
+    self.grad_clip_norm = grad_clip_norm
+    self.restore_keys = restore_keys
+
+    # Create an optimizer.
+    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
+        initial_learning_rate=learning_rate,
+        decay_steps=lr_decay_steps,
+        decay_rate=lr_decay_rate)
+
+    with self.strategy.scope():
+      optimizer = tf.keras.optimizers.Adam(lr_schedule)
+      self.optimizer = optimizer
+
+  def save(self, save_dir):
+    """Saves model and optimizer to a checkpoint."""
+    # Saving weights in checkpoint format because saved_model requires
+    # handling variable batch size, which some synths and effects can't.
+    start_time = time.time()
+    checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)
+    manager = tf.train.CheckpointManager(
+        checkpoint, directory=save_dir, max_to_keep=self.checkpoints_to_keep)
+    step = self.step.numpy()
+    manager.save(checkpoint_number=step)
+    logging.info('Saved checkpoint to %s at step %s', save_dir, step)
+    logging.info('Saving model took %.1f seconds', time.time() - start_time)
+
+  def restore(self, checkpoint_path, restore_keys=None):
+    """Restore model and optimizer from a checkpoint if it exists."""
+    logging.info('Restoring from checkpoint...')
+    start_time = time.time()
+
+    # Prefer function args over object properties.
+    restore_keys = self.restore_keys if restore_keys is None else restore_keys
+    if restore_keys is None:
+      # If no keys are passed, restore the whole model.
+      model = self.model
+      logging.info('Trainer restoring the full model')
+    else:
+      # Restore only sub-modules by building a new subgraph.
+      restore_dict = {k: getattr(self.model, k) for k in restore_keys}
+      model = tf.train.Checkpoint(**restore_dict)
+
+      logging.info('Trainer restoring model subcomponents:')
+      for k, v in restore_dict.items():
+        log_str = 'Restoring {}: {}'.format(k, v)
+        logging.info(log_str)
+
+    # Restore from latest checkpoint.
+    checkpoint = tf.train.Checkpoint(model=model, optimizer=self.optimizer)
+    latest_checkpoint = train_util.get_latest_chekpoint(checkpoint_path)
+    if latest_checkpoint is not None:
+      # checkpoint.restore must be within a strategy.scope() so that optimizer
+      # slot variables are mirrored.
+      with self.strategy.scope():
+        if restore_keys is None:
+          checkpoint.restore(latest_checkpoint)
+        else:
+          checkpoint.restore(latest_checkpoint).expect_partial()
+        logging.info('Loaded checkpoint %s', latest_checkpoint)
+      logging.info('Loading model took %.1f seconds', time.time() - start_time)
+    else:
+      logging.info('No checkpoint, skipping.')
+
+  @property
+  def step(self):
+    """The number of training steps completed."""
+    return self.optimizer.iterations
+
+  def psum(self, x, axis=None):
+    """Sum across processors."""
+    return self.strategy.reduce(tf.distribute.ReduceOp.SUM, x, axis=axis)
+
+  def run(self, fn, *args, **kwargs):
+    """Distribute and run function on processors."""
+    return self.strategy.experimental_run_v2(fn, args=args, kwargs=kwargs)
+
+  def build(self, batch):
+    """Build the model by running a distributed batch through it."""
+    logging.info('Building the model...')
+    _ = self.run(tf.function(self.model.__call__), batch)
+    self.model.summary()
+
+  def distribute_dataset(self, dataset):
+    """Create a distributed dataset."""
+    if isinstance(dataset, tf.data.Dataset):
+      return self.strategy.experimental_distribute_dataset(dataset)
+    else:
+      return dataset
+
+  @tf.function
+  def train_step(self, inputs):
+    """Distributed training step."""
+    # Wrap iterator in tf.function, slight speedup passing in iter vs batch.
+    batch = next(inputs) if hasattr(inputs, '__next__') else inputs
+    losses = self.run(self.step_fn, batch)
+    # Add up the scalar losses across replicas.
+    n_replicas = self.strategy.num_replicas_in_sync
+    return {k: self.psum(v, axis=None) / n_replicas for k, v in losses.items()}
+
+  @tf.function
+  def step_fn(self, batch):
+    """Per-Replica training step."""
+    with tf.GradientTape() as tape:
+      _, losses = self.model(batch, return_losses=True, training=True)
+    # Clip and apply gradients.
+    grads = tape.gradient(losses['total_loss'], self.model.trainable_variables)
+    grads, _ = tf.clip_by_global_norm(grads, self.grad_clip_norm)
+    self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
+    return losses
+
+
diff --git a/ddsp/version.py b/ddsp/version.py
index 77ae4d87..f6257fc2 100644
--- a/ddsp/version.py
+++ b/ddsp/version.py
@@ -19,4 +19,4 @@
 pulling in all the dependencies in __init__.py.
 """

-__version__ = '0.4.1'
+__version__ = '0.5.0'
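The minor-version bump (0.4.1 to 0.5.0) signals the moved APIs rather than a bugfix release: gin configs and imports pinned to train_util.Trainer must be updated. A quick sanity check, assuming the bump above ships to PyPI as-is:

import ddsp
assert ddsp.__version__ == '0.5.0', ddsp.__version__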