Skip to content

Commit

Permalink
Eval Refactor
Browse files Browse the repository at this point in the history
A step towards more configurable eval jobs with gin.
* Move summaries and metrics into their own files.
* eval_util_test.py -> metrics_test.py
* Move Trainers to their own file too
* Reduced the size of the audio feature tests by shortening the fake audio data by a factor of 10.

No functional training changes.

PiperOrigin-RevId: 314978813
  • Loading branch information
jesseengel authored and Magenta Team committed Jun 5, 2020
1 parent 128c4f4 commit 959014c
Show file tree
Hide file tree
Showing 12 changed files with 594 additions and 510 deletions.
2 changes: 1 addition & 1 deletion ddsp/colab/demos/train_autoencoder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@
" --gin_param=\"batch_size=16\" \\\n",
" --gin_param=\"train_util.train.num_steps=30000\" \\\n",
" --gin_param=\"train_util.train.steps_per_save=300\" \\\n",
" --gin_param=\"train_util.Trainer.checkpoints_to_keep=10\""
" --gin_param=\"trainers.Trainer.checkpoints_to_keep=10\""
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions ddsp/colab/tutorials/3_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
"\n",
"import ddsp\n",
"from ddsp.training import (data, decoders, encoders, models, preprocessing, \n",
" train_util)\n",
" train_util, trainers)\n",
"from ddsp.colab.colab_utils import play, specplot, DEFAULT_SAMPLE_RATE\n",
"import gin\n",
"import matplotlib.pyplot as plt\n",
Expand Down Expand Up @@ -230,7 +230,7 @@
" decoder=decoder,\n",
" processor_group=processor_group,\n",
" losses=[spectral_loss])\n",
" trainer = train_util.Trainer(model, strategy, learning_rate=1e-3)"
" trainer = trainers.Trainer(model, strategy, learning_rate=1e-3)"
]
},
{
Expand Down Expand Up @@ -316,7 +316,7 @@
"with strategy.scope():\n",
" # Autoencoder arguments are filled by gin.\n",
" model = ddsp.training.models.Autoencoder()\n",
" trainer = train_util.Trainer(model, strategy, learning_rate=1e-3)"
" trainer = trainers.Trainer(model, strategy, learning_rate=1e-3)"
]
},
{
Expand Down
56 changes: 28 additions & 28 deletions ddsp/spectral_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ def setUp(self):
self.frame_rate = 250

@parameterized.named_parameters(
('16k_2.1secs', 16000, 2.1),
('24k_2.1secs', 24000, 2.1),
('44.1k_2.1secs', 44100, 2.1),
('48k_2.1secs', 48000, 2.1),
('16k_4secs', 16000, 4),
('24k_4secs', 24000, 4),
('44.1k_4secs', 44100, 4),
('48k_4secs', 48000, 4),
('16k_.21secs', 16000, .21),
('24k_.21secs', 24000, .21),
('44.1k_.21secs', 44100, .21),
('48k_.21secs', 48000, .21),
('16k_.4secs', 16000, .4),
('24k_.4secs', 24000, .4),
('44.1k_.4secs', 44100, .4),
('48k_.4secs', 48000, .4),
)
def test_compute_f0_at_sample_rate(self, sample_rate, audio_len_sec):
audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate,
Expand All @@ -158,12 +158,12 @@ def test_compute_f0_at_sample_rate(self, sample_rate, audio_len_sec):
self.assertTrue(np.all(np.isfinite(f0_confidence)))

@parameterized.named_parameters(
('16k_2.1secs', 16000, 2.1),
('24k_2.1secs', 24000, 2.1),
('48k_2.1secs', 48000, 2.1),
('16k_4secs', 16000, 4),
('24k_4secs', 24000, 4),
('48k_4secs', 48000, 4),
('16k_.21secs', 16000, .21),
('24k_.21secs', 24000, .21),
('48k_.21secs', 48000, .21),
('16k_.4secs', 16000, .4),
('24k_.4secs', 24000, .4),
('48k_.4secs', 48000, .4),
)
def test_compute_loudness_at_sample_rate_1d(self, sample_rate, audio_len_sec):
audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate,
Expand All @@ -177,12 +177,12 @@ def test_compute_loudness_at_sample_rate_1d(self, sample_rate, audio_len_sec):
self.assertTrue(np.all(np.isfinite(loudness)))

@parameterized.named_parameters(
('16k_2.1secs', 16000, 2.1),
('24k_2.1secs', 24000, 2.1),
('48k_2.1secs', 48000, 2.1),
('16k_4secs', 16000, 4),
('24k_4secs', 24000, 4),
('48k_4secs', 48000, 4),
('16k_.21secs', 16000, .21),
('24k_.21secs', 24000, .21),
('48k_.21secs', 48000, .21),
('16k_.4secs', 16000, .4),
('24k_.4secs', 24000, .4),
('48k_.4secs', 48000, .4),
)
def test_compute_loudness_at_sample_rate_2d(self, sample_rate, audio_len_sec):
batch_size = 8
Expand All @@ -209,12 +209,12 @@ def test_compute_loudness_at_sample_rate_2d(self, sample_rate, audio_len_sec):
self.assertAllClose(loudness_batch, loudness_batch_target, atol=1, rtol=1)

@parameterized.named_parameters(
('16k_2.1secs', 16000, 2.1),
('24k_2.1secs', 24000, 2.1),
('48k_2.1secs', 48000, 2.1),
('16k_4secs', 16000, 4),
('24k_4secs', 24000, 4),
('48k_4secs', 48000, 4),
('16k_.21secs', 16000, .21),
('24k_.21secs', 24000, .21),
('48k_.21secs', 48000, .21),
('16k_.4secs', 16000, .4),
('24k_.4secs', 24000, .4),
('48k_.4secs', 48000, .4),
)
def test_tf_compute_loudness_at_sample_rate(self, sample_rate, audio_len_sec):
audio_sin = gen_np_sinusoid(self.frequency, self.amp, sample_rate,
Expand All @@ -226,8 +226,8 @@ def test_tf_compute_loudness_at_sample_rate(self, sample_rate, audio_len_sec):
self.assertTrue(np.all(np.isfinite(loudness)))

@parameterized.named_parameters(
('44.1k_2.1secs', 44100, 2.1),
('44.1k_4secs', 44100, 4),
('44.1k_.21secs', 44100, .21),
('44.1k_.4secs', 44100, .4),
)
def test_compute_loudness_indivisible_rates_raises_error(
self, sample_rate, audio_len_sec):
Expand Down
12 changes: 9 additions & 3 deletions ddsp/training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ The DDSP training libraries are separated into several modules:

* [data](./data.py):
DataProvider objects that provide tf.data.Dataset.
* [inference](./inference.py):
Model wrappers for efficient inference and the ability to store as
SavedModels.
* [models](./models.py):
Model objects to encapsulate training and evaluation.
* [preprocessing](./preprocessing.py):
Expand All @@ -29,9 +32,6 @@ The DDSP training libraries are separated into several modules:
Layers to turn latents into ddsp processor inputs.
* [nn](./nn.py):
Helper library of network functions and layers.
* [inference](./inference.py):
Model wrappers for efficient inference and the ability to store as
SavedModels.


The main training file is `ddsp_run.py` and its helper libraries:
Expand All @@ -40,8 +40,14 @@ The main training file is `ddsp_run.py` and its helper libraries:
Main file for training, evaluating, and sampling from models.
* [train_util](./train_util.py):
Helper functions for training (the Trainer object now lives in `trainers`).
* [trainers](./trainers.py):
Helper objects to bind strategy, optimizer, and model, and define training step.
* [eval_util](./eval_util.py):
Helper functions for evaluation and sampling.
* [metrics](./metrics.py):
Metrics for evaluation.
* [summaries](./summaries.py):
Summaries for tensorboard.

While the modules in the `ddsp/` base directory can be used to train models
with `tf.compat.v1` or `tf.compat.v2` this directory only uses `tf.compat.v2`.
Expand Down
3 changes: 2 additions & 1 deletion ddsp/training/ddsp_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from ddsp.training import eval_util
from ddsp.training import models
from ddsp.training import train_util
from ddsp.training import trainers
import gin
import pkg_resources
import tensorflow.compat.v2 as tf
Expand Down Expand Up @@ -169,7 +170,7 @@ def main(unused_argv):
strategy = train_util.get_strategy(tpu=FLAGS.tpu, gpus=FLAGS.gpu)
with strategy.scope():
model = models.get_model()
trainer = train_util.Trainer(model, strategy)
trainer = trainers.Trainer(model, strategy)

train_util.train(data_provider=gin.REQUIRED,
trainer=trainer,
Expand Down
Loading

0 comments on commit 959014c

Please sign in to comment.