Skip to content

Commit

Permalink
All inference models inherit from common base class. Preparation for …
Browse files Browse the repository at this point in the history
…model conversion scripts.

PiperOrigin-RevId: 362597624
  • Loading branch information
jesseengel authored and Magenta Team committed Mar 12, 2021
1 parent 8ec6243 commit 0f35be7
Showing 1 changed file with 84 additions and 7 deletions.
91 changes: 84 additions & 7 deletions ddsp/training/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,102 @@
import tensorflow as tf


class AutoencoderInference(models.Autoencoder):
class InferenceModel(object):
"""Base class for inference models."""

def __init__(self, ckpt, model_class, **kwargs):
self.parse_gin_config(ckpt)
model_class.__init__(self, **kwargs)
self.restore(ckpt)
self.build_network()

def parse_gin_config(self, ckpt):
with gin.unlock_config():
ckpt_dir = os.path.dirname(ckpt)
operative_config = train_util.get_latest_operative_config(ckpt_dir)
print(f'Parsing from operative_config {operative_config}')
gin.parse_config_file(operative_config, skip_unknown=True)

def build_network(self):
"""Run a fake batch through the network."""
raise NotImplementedError('Need to specify build_network() method.')

def save_model(self, save_dir):
"""Save model to save_dir, override for custom function signatures."""
self.save(save_dir)


@gin.configurable
class AutoencoderInference(models.Autoencoder, InferenceModel):
"""Create an inference-only version of the model."""

def __init__(self,
ckpt,
length_seconds=4,
sample_rate=16000,
frame_rate=250,
**kwargs):
# pylint: disable=super-init-not-called
self.length_seconds = length_seconds
self.sample_rate = sample_rate
self.frame_rate = frame_rate
self.hop_size = int(sample_rate / frame_rate)
self.time_steps = int(length_seconds * sample_rate / self.hop_size)
self.n_samples = self.time_steps * self.hop_size
self.n_frames = int(frame_rate * length_seconds)
InferenceModel.__init__(self, ckpt, models.Autoencoder, **kwargs)

@tf.function
def call(self, input_dict):
"""Run the core of the network, get predictions."""
input_dict = ddsp.core.copy_if_tf_function(input_dict)
return super().call(input_dict, training=False)

def parse_gin_config(self, ckpt):
"""Parse the model operative config with new length parameters."""
with gin.unlock_config():
ckpt_dir = os.path.dirname(ckpt)
operative_config = train_util.get_latest_operative_config(ckpt_dir)
print(f'Parsing from operative_config {operative_config}')
gin.parse_config_file(operative_config, skip_unknown=True)
# Set gin params to new length.
# Remove reverb processor.
pg_string = """ProcessorGroup.dag = [
(@synths.Harmonic(),
['amps', 'harmonic_distribution', 'f0_hz']),
(@synths.FilteredNoise(),
['noise_magnitudes']),
(@processors.Add(),
['filtered_noise/signal', 'harmonic/signal']),
]"""
gin.parse_config([
'Harmonic.n_samples=%d' % self.n_samples,
'FilteredNoise.n_samples=%d' % self.n_samples,
'F0LoudnessPreprocessor.time_steps=%d' % self.time_steps,
'oscillator_bank.use_angular_cumsum=True',
pg_string,
])

def build_network(self):
"""Run a fake batch through the network."""
input_dict = {
'loudness_db': tf.zeros([self.n_frames]),
'f0_hz': tf.zeros([self.n_frames]),
}
print('Inputs to Model:', input_dict)
unused_outputs = self(input_dict)
print('Outputs', unused_outputs)


class StreamingF0Pw(models.Autoencoder):
@gin.configurable
class StreamingF0PwInference(models.Autoencoder, InferenceModel):
"""Create an inference-only version of the model."""

def __init__(self, ckpt, **kwargs):
self.parse_and_modify_gin_config(ckpt)
super().__init__(**kwargs)
self.restore(ckpt)
self.build_network()
# pylint: disable=super-init-not-called
InferenceModel.__init__(self, ckpt, models.Autoencoder, **kwargs)

def parse_and_modify_gin_config(self, ckpt):
def parse_gin_config(self, ckpt):
"""Parse the model operative config with special streaming parameters."""
with gin.unlock_config():
ckpt_dir = os.path.dirname(ckpt)
Expand Down Expand Up @@ -94,3 +170,4 @@ def call(self, input_dict):
noise = controls['filtered_noise']['controls']['magnitudes']
return amps, hd, noise


0 comments on commit 0f35be7

Please sign in to comment.