Refactor losses accounting within model.
* Remove use of `self.add_loss()`, as it requires cumbersome tracking of a parallel `loss_names` list for TensorBoard summaries.
* The model's forward `__call__()` now takes a `return_losses` keyword argument; if True, it also returns a dictionary of all the losses from that pass, including 'total_loss' (see the usage sketch after this list).
* The functional form (instead of a `self.losses_dict` property) is needed to handle `tf.function()` headaches that arose from side effects.
* Bumps version to v0.1.0. The minor version should actually have been bumped at v0.0.7, which already made a minor-revision-level change.
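
A minimal usage sketch of the new keyword argument, assuming a built `model` and a `batch` dict with an 'audio' key as in ddsp/training/train_util.py (the print loop is purely illustrative):

  audio_gen, losses = model(batch, return_losses=True, training=True)
  for name, value in losses.items():  # Includes the summed 'total_loss'.
    print(name, float(value))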

PiperOrigin-RevId: 308137442
jesseengel authored and Magenta Team committed Apr 23, 2020
1 parent 546f9e5 commit ce3e995
Showing 4 changed files with 51 additions and 34 deletions.
17 changes: 9 additions & 8 deletions ddsp/training/eval_util.py
@@ -379,14 +379,9 @@ def evaluate_or_sample(data_provider,
# Load model.
model.restore(checkpoint_path)

- # Create metrics.
- if mode == 'eval':
-   f0_loudness_metrics = F0LoudnessMetrics()
-   avg_losses = {name: tf.keras.metrics.Mean(name=name, dtype=tf.float32)
-                 for name in model.loss_names}

# Iterate through dataset and make predictions
checkpoint_start_time = time.time()

for batch_idx in range(1, num_batches + 1):
try:
start_time = time.time()
@@ -396,9 +391,16 @@ def evaluate_or_sample(data_provider,
batch = next(dataset_iter)
audio = batch['audio']
# TODO(jesseengel): Find a way to add losses with training=False.
- audio_gen = model(batch, training=True)  # Adds losses.
+ audio_gen, losses = model(batch, return_losses=True, training=True)
outputs = model.get_controls(batch, training=True)

+ # Create metrics on first batch.
+ if mode == 'eval' and batch_idx == 1:
+   f0_loudness_metrics = F0LoudnessMetrics()
+   avg_losses = {
+       name: tf.keras.metrics.Mean(name=name, dtype=tf.float32)
+       for name in list(losses.keys())}


# Resample f0_hz outputs to match batch if they don't already.
has_f0 = ('f0_hz' in outputs and 'f0_hz' in batch)
@@ -439,7 +441,6 @@
outputs['f0_hz'])

# Loss.
- losses = model.losses_dict
for k, v in losses.items():
avg_losses[k].update_state(v)

41 changes: 28 additions & 13 deletions ddsp/training/models.py
@@ -43,18 +43,31 @@ def get_model(model=gin.REQUIRED):
class Model(tf.keras.Model):
"""Wrap the model function for dependency injection with gin."""

- def __init__(self, losses=None, name='model'):
+ def __init__(self, name='model'):
super().__init__(name=name)
- self.loss_objs = ddsp.core.make_iterable(losses)
- self.loss_names = [loss_obj.name
-                    for loss_obj in self.loss_objs] + ['total_loss']

- @property
- def losses_dict(self):
-   """For metrics, returns dict {loss_name: loss_value}."""
-   losses_dict = dict(zip(self.loss_names, self.losses))
-   losses_dict['total_loss'] = tf.reduce_sum(self.losses)
-   return losses_dict
+ self._losses_dict = {}

+ def __call__(self, *args, return_losses=False, **kwargs):
+   """Reset the losses dict on each call.
+
+   Args:
+     *args: Arguments passed on to call().
+     return_losses: Return a dictionary of losses in addition to the call()
+       function returns.
+     **kwargs: Keyword arguments passed on to call().
+
+   Returns:
+     Function results if return_losses=False, else the function results
+     and a dictionary of losses, {loss_name: loss_value}.
+   """
+   self._losses_dict = {}
+   results = super().__call__(*args, **kwargs)
+   if not return_losses:
+     return results
+   else:
+     self._losses_dict['total_loss'] = tf.reduce_sum(
+         list(self._losses_dict.values()))
+     return results, self._losses_dict

def restore(self, checkpoint_path):
"""Restore model and optimizer from a checkpoint."""
@@ -81,11 +94,12 @@ def __init__(self,
processor_group=None,
losses=None,
name='autoencoder'):
- super().__init__(name=name, losses=losses)
+ super().__init__(name=name)
self.preprocessor = preprocessor
self.encoder = encoder
self.decoder = decoder
self.processor_group = processor_group
+ self.loss_objs = ddsp.core.make_iterable(losses)

def controls_to_audio(self, controls):
return controls[self.processor_group.name]['signal']
@@ -106,7 +120,8 @@ def call(self, features, training=True):
audio_gen = self.decode(conditioning, training=training)
if training:
for loss_obj in self.loss_objs:
- self.add_loss(loss_obj(features['audio'], audio_gen))
+ loss = loss_obj(features['audio'], audio_gen)
+ self._losses_dict[loss_obj.name] = loss
return audio_gen

def get_controls(self, features, keys=None, training=False):
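For reference, a self-contained sketch of the pattern models.py now uses: call() records named losses in an instance dict as a side effect, and the overridden __call__() resets that dict, sums a 'total_loss', and optionally returns the dict. The `ToyModel` class and its `l2_loss` are hypothetical stand-ins, not code from the repo:

  import tensorflow as tf

  class ToyModel(tf.keras.Model):
    """Collects named losses in a dict instead of via self.add_loss()."""

    def __init__(self, name='toy_model'):
      super().__init__(name=name)
      self._losses_dict = {}
      self.dense = tf.keras.layers.Dense(1)

    def __call__(self, *args, return_losses=False, **kwargs):
      self._losses_dict = {}  # Reset so stale losses never leak across calls.
      results = super().__call__(*args, **kwargs)
      if not return_losses:
        return results
      self._losses_dict['total_loss'] = tf.reduce_sum(
          list(self._losses_dict.values()))
      return results, self._losses_dict

    def call(self, x, training=True):
      y = self.dense(x)
      if training:
        # Record a named loss as a side effect of the forward pass.
        self._losses_dict['l2_loss'] = tf.reduce_mean(y ** 2)
      return y

  model = ToyModel()
  y, losses = model(tf.ones([4, 8]), return_losses=True, training=True)
  print({k: float(v) for k, v in losses.items()})  # 'l2_loss', 'total_loss'.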
25 changes: 13 additions & 12 deletions ddsp/training/train_util.py
@@ -197,7 +197,7 @@ def run(self, fn, *args, **kwargs):
return self.strategy.experimental_run_v2(fn, args=args, kwargs=kwargs)

def build(self, batch):
"""Build the model by running a batch through it."""
"""Build the model by running a distributed batch through it."""
logging.info('Building the model...')
_ = self.run(tf.function(self.model.__call__), batch)
self.model.summary()
@@ -223,13 +223,12 @@ def train_step(self, dataset_iter):
def step_fn(self, batch):
"""Per-Replica training step."""
with tf.GradientTape() as tape:
- _ = self.model(batch, training=True)
- total_loss = tf.reduce_sum(self.model.losses)
+ _, losses = self.model(batch, return_losses=True, training=True)
# Clip and apply gradients.
- grads = tape.gradient(total_loss, self.model.trainable_variables)
+ grads = tape.gradient(losses['total_loss'], self.model.trainable_variables)
grads, _ = tf.clip_by_global_norm(grads, self.grad_clip_norm)
self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
- return self.model.losses_dict
+ return losses


@gin.configurable
@@ -252,11 +251,6 @@ def train(data_provider,
# Load latest checkpoint if one exists in model_dir.
trainer.restore(model_dir)

- # Create training loss metrics.
- logging.info('Creating metrics for %s', list(trainer.model.loss_names))
- avg_losses = {name: tf.keras.metrics.Mean(name=name, dtype=tf.float32)
-               for name in trainer.model.loss_names}

# Set up the summary writer and metrics.
summary_dir = os.path.join(model_dir, 'summaries', 'train')
summary_writer = tf.summary.create_file_writer(summary_dir)
@@ -268,12 +262,19 @@
with summary_writer.as_default():
tick = time.time()

- for _ in range(num_steps):
-   step = trainer.step
+ for iteration in range(num_steps):
+   step = trainer.step  # Step is not iteration if restarting a model.

# Take a step.
losses = trainer.train_step(dataset_iter)

+ # Create training loss metrics when starting/restarting training.
+ if iteration == 0:
+   loss_names = list(losses.keys())
+   logging.info('Creating metrics for %s', loss_names)
+   avg_losses = {name: tf.keras.metrics.Mean(name=name, dtype=tf.float32)
+                 for name in loss_names}

# Update metrics.
for k, v in losses.items():
avg_losses[k].update_state(v)
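These averaged losses are what feed the TensorBoard summaries the commit message refers to. A minimal, self-contained sketch of that final step, assuming tf.keras.metrics.Mean metrics like those above (the metric names, log directory, and 'losses/' tag prefix are illustrative, not from the repo):

  import tensorflow as tf

  # Hypothetical averaged-loss metrics, as created in the training loop above.
  avg_losses = {name: tf.keras.metrics.Mean(name=name, dtype=tf.float32)
                for name in ('spectral_loss', 'total_loss')}
  avg_losses['total_loss'].update_state(1.0)

  writer = tf.summary.create_file_writer('/tmp/summaries/train')
  with writer.as_default():
    for name, metric in avg_losses.items():
      tf.summary.scalar('losses/' + name, metric.result(), step=0)
      metric.reset_states()  # Start a fresh average for the next interval.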
2 changes: 1 addition & 1 deletion ddsp/version.py
@@ -19,4 +19,4 @@
pulling in all the dependencies in __init__.py.
"""

- __version__ = '0.0.10'
+ __version__ = '0.1.0'
