From 0f35be7c96317c4184007c70a1ddeb84df39ba3f Mon Sep 17 00:00:00 2001 From: Jesse Engel Date: Fri, 12 Mar 2021 14:08:43 -0800 Subject: [PATCH] All inference models inherit from common base class. Preparation for model conversion scripts. PiperOrigin-RevId: 362597624 --- ddsp/training/inference.py | 91 +++++++++++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 7 deletions(-) diff --git a/ddsp/training/inference.py b/ddsp/training/inference.py index f42ee8f0..28587a37 100644 --- a/ddsp/training/inference.py +++ b/ddsp/training/inference.py @@ -28,26 +28,102 @@ import tensorflow as tf -class AutoencoderInference(models.Autoencoder): +class InferenceModel(object): + """Base class for inference models.""" + + def __init__(self, ckpt, model_class, **kwargs): + self.parse_gin_config(ckpt) + model_class.__init__(self, **kwargs) + self.restore(ckpt) + self.build_network() + + def parse_gin_config(self, ckpt): + with gin.unlock_config(): + ckpt_dir = os.path.dirname(ckpt) + operative_config = train_util.get_latest_operative_config(ckpt_dir) + print(f'Parsing from operative_config {operative_config}') + gin.parse_config_file(operative_config, skip_unknown=True) + + def build_network(self): + """Run a fake batch through the network.""" + raise NotImplementedError('Need to specify build_network() method.') + + def save_model(self, save_dir): + """Save model to save_dir, override for custom function signatures.""" + self.save(save_dir) + + +@gin.configurable +class AutoencoderInference(models.Autoencoder, InferenceModel): """Create an inference-only version of the model.""" + def __init__(self, + ckpt, + length_seconds=4, + sample_rate=16000, + frame_rate=250, + **kwargs): + # pylint: disable=super-init-not-called + self.length_seconds = length_seconds + self.sample_rate = sample_rate + self.frame_rate = frame_rate + self.hop_size = int(sample_rate / frame_rate) + self.time_steps = int(length_seconds * sample_rate / self.hop_size) + self.n_samples = self.time_steps * self.hop_size + self.n_frames = int(frame_rate * length_seconds) + InferenceModel.__init__(self, ckpt, models.Autoencoder, **kwargs) + @tf.function def call(self, input_dict): """Run the core of the network, get predictions.""" input_dict = ddsp.core.copy_if_tf_function(input_dict) return super().call(input_dict, training=False) + def parse_gin_config(self, ckpt): + """Parse the model operative config with new length parameters.""" + with gin.unlock_config(): + ckpt_dir = os.path.dirname(ckpt) + operative_config = train_util.get_latest_operative_config(ckpt_dir) + print(f'Parsing from operative_config {operative_config}') + gin.parse_config_file(operative_config, skip_unknown=True) + # Set gin params to new length. + # Remove reverb processor. + pg_string = """ProcessorGroup.dag = [ + (@synths.Harmonic(), + ['amps', 'harmonic_distribution', 'f0_hz']), + (@synths.FilteredNoise(), + ['noise_magnitudes']), + (@processors.Add(), + ['filtered_noise/signal', 'harmonic/signal']), + ]""" + gin.parse_config([ + 'Harmonic.n_samples=%d' % self.n_samples, + 'FilteredNoise.n_samples=%d' % self.n_samples, + 'F0LoudnessPreprocessor.time_steps=%d' % self.time_steps, + 'oscillator_bank.use_angular_cumsum=True', + pg_string, + ]) + + def build_network(self): + """Run a fake batch through the network.""" + input_dict = { + 'loudness_db': tf.zeros([self.n_frames]), + 'f0_hz': tf.zeros([self.n_frames]), + } + print('Inputs to Model:', input_dict) + unused_outputs = self(input_dict) + print('Outputs', unused_outputs) + -class StreamingF0Pw(models.Autoencoder): +@gin.configurable +class StreamingF0PwInference(models.Autoencoder, InferenceModel): """Create an inference-only version of the model.""" def __init__(self, ckpt, **kwargs): - self.parse_and_modify_gin_config(ckpt) - super().__init__(**kwargs) - self.restore(ckpt) - self.build_network() + # pylint: disable=super-init-not-called + InferenceModel.__init__(self, ckpt, models.Autoencoder, **kwargs) - def parse_and_modify_gin_config(self, ckpt): + def parse_gin_config(self, ckpt): """Parse the model operative config with special streaming parameters.""" with gin.unlock_config(): ckpt_dir = os.path.dirname(ckpt) @@ -94,3 +170,4 @@ def call(self, input_dict): noise = controls['filtered_noise']['controls']['magnitudes'] return amps, hd, noise +