From d4970b265904d5d1a0385e8cd9e7466e9b82062e Mon Sep 17 00:00:00 2001 From: Jesse Engel Date: Wed, 29 Apr 2020 14:20:18 -0700 Subject: [PATCH] Simplify and refactor RnnFcDecoder. Takes a list of input keys (from the preprocessor output) and builds a fully connected stack for each of them. Bump to version v0.2.0 as this will require old models to add a single line to their operative gin configs. `RnnFcDecoder.input_keys = ('f0_scaled', 'ld_scaled')` PiperOrigin-RevId: 309095767 --- ddsp/colab/demos/timbre_transfer.ipynb | 4 +- ddsp/losses.py | 2 +- ddsp/training/decoders.py | 64 ++++---------------- ddsp/training/gin/models/ae.gin | 18 +++--- ddsp/training/gin/models/solo_instrument.gin | 1 + ddsp/version.py | 2 +- 6 files changed, 27 insertions(+), 64 deletions(-) diff --git a/ddsp/colab/demos/timbre_transfer.ipynb b/ddsp/colab/demos/timbre_transfer.ipynb index a99b1f75..4770cc94 100644 --- a/ddsp/colab/demos/timbre_transfer.ipynb +++ b/ddsp/colab/demos/timbre_transfer.ipynb @@ -206,7 +206,6 @@ "model = 'Violin' #@param ['Violin', 'Flute', 'Flute2', 'Trumpet', 'Tenor_Saxophone','Upload your own (checkpoint folder as .zip)']\n", "MODEL = model\n", "\n", - "GCS_CKPT_DIR = 'gs://ddsp/models/tf2'\n", "\n", "def find_model_dir(dir_name):\n", " # Iterate through directories until model directory is found\n", @@ -224,7 +223,9 @@ " # Copy over from gs:// for faster loading.\n", " !rm -r $PRETRAINED_DIR \u0026\u003e /dev/null\n", " !mkdir $PRETRAINED_DIR \u0026\u003e /dev/null\n", + " GCS_CKPT_DIR = 'gs://ddsp/models/tf2'\n", " model_dir = os.path.join(GCS_CKPT_DIR, 'solo_%s_ckpt' % model.lower())\n", + " \n", " !gsutil cp $model_dir/* $PRETRAINED_DIR \u0026\u003e /dev/null\n", " model_dir = PRETRAINED_DIR\n", " gin_file = os.path.join(model_dir, 'operative_config-0.gin')\n", @@ -268,6 +269,7 @@ "# print('')\n", "\n", "gin_params = [\n", + " 'RnnFcDecoder.input_keys = (\"f0_scaled\", \"ld_scaled\")',\n", " 'Additive.n_samples = {}'.format(n_samples),\n", " 
'FilteredNoise.n_samples = {}'.format(n_samples),\n", " 'DefaultPreprocessor.time_steps = {}'.format(time_steps),\n", diff --git a/ddsp/losses.py b/ddsp/losses.py index dcfdffa9..be3828c8 100644 --- a/ddsp/losses.py +++ b/ddsp/losses.py @@ -220,7 +220,7 @@ def __init__(self, self._model = crepe.core.build_and_load_model(self._model_capacity) self.frame_length = 1024 - def build(self, x_shape): + def build(self, unused_x_shape): self.layer_names = [l.name for l in self._model.layers] if self._activation_layer not in self.layer_names: diff --git a/ddsp/training/decoders.py b/ddsp/training/decoders.py index d649b790..5266da09 100644 --- a/ddsp/training/decoders.py +++ b/ddsp/training/decoders.py @@ -53,50 +53,6 @@ def decode(self, conditioning): raise NotImplementedError -@gin.register -class ZRnnFcDecoder(Decoder): - """Decompress z in time with RNN. Fully connected stacks for z as well.""" - - def __init__(self, - rnn_channels=512, - rnn_type='gru', - ch=512, - layers_per_stack=3, - append_f0_loudness=True, - output_splits=(('amps', 1), ('harmonic_distribution', 40)), - name=None): - super().__init__(output_splits=output_splits, name=name) - self.append_f0_loudness = append_f0_loudness - stack = lambda: nn.fc_stack(ch, layers_per_stack) - - # Layers. - self.f_stack = stack() - self.l_stack = stack() - self.z_stack = stack() - self.rnn = nn.rnn(rnn_channels, rnn_type) - self.out_stack = stack() - self.dense_out = nn.dense(self.n_out) - - def decode(self, conditioning): - f, l, z = (conditioning['f0_scaled'], - conditioning['ld_scaled'], - conditioning['z']) - - # Initial processing. - f = self.f_stack(f) - l = self.l_stack(l) - z = self.z_stack(z) - - # Run an RNN over the latents. - x = tf.concat([f, l, z], axis=-1) if self.append_f0_loudness else z - x = self.rnn(x) - x = tf.concat([f, l, x], axis=-1) - - # Final processing. 
- x = self.out_stack(x)
- return self.dense_out(x)
-
-
 @gin.register
 class RnnFcDecoder(Decoder):
   """RNN and FC stacks for f0 and loudness."""
@@ -106,29 +62,33 @@ def __init__(self,
                rnn_type='gru',
                ch=512,
                layers_per_stack=3,
+               input_keys=('ld_scaled', 'f0_scaled', 'z'),
                output_splits=(('amps', 1), ('harmonic_distribution', 40)),
                name=None):
     super().__init__(output_splits=output_splits, name=name)
     stack = lambda: nn.fc_stack(ch, layers_per_stack)
+    self.input_keys = input_keys
 
     # Layers.
-    self.f_stack = stack()
-    self.l_stack = stack()
+    self.input_stacks = [stack() for k in self.input_keys]
     self.rnn = nn.rnn(rnn_channels, rnn_type)
     self.out_stack = stack()
     self.dense_out = nn.dense(self.n_out)
 
-  def decode(self, conditioning):
-    f, l = conditioning['f0_scaled'], conditioning['ld_scaled']
+    # Backwards compatibility.
+    self.f_stack = self.input_stacks[0] if len(self.input_stacks) >= 1 else None
+    self.l_stack = self.input_stacks[1] if len(self.input_stacks) >= 2 else None
+    self.z_stack = self.input_stacks[2] if len(self.input_stacks) >= 3 else None
 
+  def decode(self, conditioning):
     # Initial processing.
-    f = self.f_stack(f)
-    l = self.l_stack(l)
+    inputs = [conditioning[k] for k in self.input_keys]
+    inputs = [stack(x) for stack, x in zip(self.input_stacks, inputs)]
 
     # Run an RNN over the latents.
-    x = tf.concat([f, l], axis=-1)
+    x = tf.concat(inputs, axis=-1)
     x = self.rnn(x)
-    x = tf.concat([f, l, x], axis=-1)
+    x = tf.concat(inputs + [x], axis=-1)
 
     # Final processing. 
x = self.out_stack(x) diff --git a/ddsp/training/gin/models/ae.gin b/ddsp/training/gin/models/ae.gin index 992516fd..86fa9b1c 100644 --- a/ddsp/training/gin/models/ae.gin +++ b/ddsp/training/gin/models/ae.gin @@ -22,15 +22,15 @@ MfccTimeDistributedRnnEncoder.z_dims = 16 MfccTimeDistributedRnnEncoder.z_time_steps = 125 # Decoder -Autoencoder.decoder = @decoders.ZRnnFcDecoder() -ZRnnFcDecoder.rnn_channels = 512 -ZRnnFcDecoder.rnn_type = 'gru' -ZRnnFcDecoder.ch = 512 -ZRnnFcDecoder.layers_per_stack = 3 -ZRnnFcDecoder.append_f0_loudness = True -ZRnnFcDecoder.output_splits = (('amps', 1), - ('harmonic_distribution', 100), - ('noise_magnitudes', 65)) +Autoencoder.decoder = @decoders.RnnFcDecoder() +RnnFcDecoder.rnn_channels = 512 +RnnFcDecoder.rnn_type = 'gru' +RnnFcDecoder.ch = 512 +RnnFcDecoder.layers_per_stack = 3 +RnnFcDecoder.input_keys = ('ld_scaled', 'f0_scaled', 'z') +RnnFcDecoder.output_splits = (('amps', 1), + ('harmonic_distribution', 100), + ('noise_magnitudes', 65)) # Losses Autoencoder.losses = [ diff --git a/ddsp/training/gin/models/solo_instrument.gin b/ddsp/training/gin/models/solo_instrument.gin index 4cfa176d..a827d350 100644 --- a/ddsp/training/gin/models/solo_instrument.gin +++ b/ddsp/training/gin/models/solo_instrument.gin @@ -14,6 +14,7 @@ RnnFcDecoder.rnn_channels = 512 RnnFcDecoder.rnn_type = 'gru' RnnFcDecoder.ch = 512 RnnFcDecoder.layers_per_stack = 3 +RnnFcDecoder.input_keys = ('ld_scaled', 'f0_scaled') RnnFcDecoder.output_splits = (('amps', 1), ('harmonic_distribution', 60), ('noise_magnitudes', 65)) diff --git a/ddsp/version.py b/ddsp/version.py index 77088e9a..363d53cc 100644 --- a/ddsp/version.py +++ b/ddsp/version.py @@ -19,4 +19,4 @@ pulling in all the dependencies in __init__.py. """ -__version__ = '0.1.0' +__version__ = '0.2.0'