Skip to content

Commit

Permalink
Simplify and refactor RnnFcDecoder. Takes a list of input keys (from …
Browse files Browse the repository at this point in the history
…the preprocessor output) and builds a fully connected stack for each of them.

Bump to version v0.2.0 as this will require old models to add a single line to their operative gin configs.
`RnnFcDecoder.input_keys = ('f0_scaled', 'ld_scaled')`

PiperOrigin-RevId: 309095767
  • Loading branch information
jesseengel authored and Magenta Team committed Apr 29, 2020
1 parent 3b31ad6 commit d4970b2
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 64 deletions.
4 changes: 3 additions & 1 deletion ddsp/colab/demos/timbre_transfer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@
"model = 'Violin' #@param ['Violin', 'Flute', 'Flute2', 'Trumpet', 'Tenor_Saxophone','Upload your own (checkpoint folder as .zip)']\n",
"MODEL = model\n",
"\n",
"GCS_CKPT_DIR = 'gs://ddsp/models/tf2'\n",
"\n",
"def find_model_dir(dir_name):\n",
" # Iterate through directories until model directory is found\n",
Expand All @@ -224,7 +223,9 @@
" # Copy over from gs:// for faster loading.\n",
" !rm -r $PRETRAINED_DIR \u0026\u003e /dev/null\n",
" !mkdir $PRETRAINED_DIR \u0026\u003e /dev/null\n",
" GCS_CKPT_DIR = 'gs://ddsp/models/tf2'\n",
" model_dir = os.path.join(GCS_CKPT_DIR, 'solo_%s_ckpt' % model.lower())\n",
" \n",
" !gsutil cp $model_dir/* $PRETRAINED_DIR \u0026\u003e /dev/null\n",
" model_dir = PRETRAINED_DIR\n",
" gin_file = os.path.join(model_dir, 'operative_config-0.gin')\n",
Expand Down Expand Up @@ -268,6 +269,7 @@
"# print('')\n",
"\n",
"gin_params = [\n",
" 'RnnFcDecoder.input_keys = (\"f0_scaled\", \"ld_scaled\")',\n",
" 'Additive.n_samples = {}'.format(n_samples),\n",
" 'FilteredNoise.n_samples = {}'.format(n_samples),\n",
" 'DefaultPreprocessor.time_steps = {}'.format(time_steps),\n",
Expand Down
2 changes: 1 addition & 1 deletion ddsp/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __init__(self,
self._model = crepe.core.build_and_load_model(self._model_capacity)
self.frame_length = 1024

def build(self, x_shape):
def build(self, unused_x_shape):
self.layer_names = [l.name for l in self._model.layers]

if self._activation_layer not in self.layer_names:
Expand Down
64 changes: 12 additions & 52 deletions ddsp/training/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,50 +53,6 @@ def decode(self, conditioning):
raise NotImplementedError


@gin.register
class ZRnnFcDecoder(Decoder):
"""Decompress z in time with RNN. Fully connected stacks for z as well."""

def __init__(self,
rnn_channels=512,
rnn_type='gru',
ch=512,
layers_per_stack=3,
append_f0_loudness=True,
output_splits=(('amps', 1), ('harmonic_distribution', 40)),
name=None):
super().__init__(output_splits=output_splits, name=name)
self.append_f0_loudness = append_f0_loudness
stack = lambda: nn.fc_stack(ch, layers_per_stack)

# Layers.
self.f_stack = stack()
self.l_stack = stack()
self.z_stack = stack()
self.rnn = nn.rnn(rnn_channels, rnn_type)
self.out_stack = stack()
self.dense_out = nn.dense(self.n_out)

def decode(self, conditioning):
f, l, z = (conditioning['f0_scaled'],
conditioning['ld_scaled'],
conditioning['z'])

# Initial processing.
f = self.f_stack(f)
l = self.l_stack(l)
z = self.z_stack(z)

# Run an RNN over the latents.
x = tf.concat([f, l, z], axis=-1) if self.append_f0_loudness else z
x = self.rnn(x)
x = tf.concat([f, l, x], axis=-1)

# Final processing.
x = self.out_stack(x)
return self.dense_out(x)


@gin.register
class RnnFcDecoder(Decoder):
"""RNN and FC stacks for f0 and loudness."""
Expand All @@ -106,29 +62,33 @@ def __init__(self,
rnn_type='gru',
ch=512,
layers_per_stack=3,
input_keys=('ld_scaled', 'f0_scaled', 'z'),
output_splits=(('amps', 1), ('harmonic_distribution', 40)),
name=None):
super().__init__(output_splits=output_splits, name=name)
stack = lambda: nn.fc_stack(ch, layers_per_stack)
self.input_keys = input_keys

# Layers.
self.f_stack = stack()
self.l_stack = stack()
self.input_stacks = [stack() for k in self.input_keys]
self.rnn = nn.rnn(rnn_channels, rnn_type)
self.out_stack = stack()
self.dense_out = nn.dense(self.n_out)

def decode(self, conditioning):
f, l = conditioning['f0_scaled'], conditioning['ld_scaled']
# Backwards compatability.
self.f_stack = self.input_stacks[0] if len(self.input_stacks) >= 1 else None
self.l_stack = self.input_stacks[1] if len(self.input_stacks) >= 2 else None
self.z_stack = self.input_stacks[2] if len(self.input_stacks) >= 3 else None

def decode(self, conditioning):
# Initial processing.
f = self.f_stack(f)
l = self.l_stack(l)
inputs = [conditioning[k] for k in self.input_keys]
inputs = [stack(x) for stack, x in zip(self.input_stacks, inputs)]

# Run an RNN over the latents.
x = tf.concat([f, l], axis=-1)
x = tf.concat(inputs, axis=-1)
x = self.rnn(x)
x = tf.concat([f, l, x], axis=-1)
x = tf.concat(inputs + [x], axis=-1)

# Final processing.
x = self.out_stack(x)
Expand Down
18 changes: 9 additions & 9 deletions ddsp/training/gin/models/ae.gin
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ MfccTimeDistributedRnnEncoder.z_dims = 16
MfccTimeDistributedRnnEncoder.z_time_steps = 125

# Decoder
Autoencoder.decoder = @decoders.ZRnnFcDecoder()
ZRnnFcDecoder.rnn_channels = 512
ZRnnFcDecoder.rnn_type = 'gru'
ZRnnFcDecoder.ch = 512
ZRnnFcDecoder.layers_per_stack = 3
ZRnnFcDecoder.append_f0_loudness = True
ZRnnFcDecoder.output_splits = (('amps', 1),
('harmonic_distribution', 100),
('noise_magnitudes', 65))
Autoencoder.decoder = @decoders.RnnFcDecoder()
RnnFcDecoder.rnn_channels = 512
RnnFcDecoder.rnn_type = 'gru'
RnnFcDecoder.ch = 512
RnnFcDecoder.layers_per_stack = 3
RnnFcDecoder.input_keys = ('ld_scaled', 'f0_scaled', 'z')
RnnFcDecoder.output_splits = (('amps', 1),
('harmonic_distribution', 100),
('noise_magnitudes', 65))

# Losses
Autoencoder.losses = [
Expand Down
1 change: 1 addition & 0 deletions ddsp/training/gin/models/solo_instrument.gin
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ RnnFcDecoder.rnn_channels = 512
RnnFcDecoder.rnn_type = 'gru'
RnnFcDecoder.ch = 512
RnnFcDecoder.layers_per_stack = 3
RnnFcDecoder.input_keys = ('ld_scaled', 'f0_scaled')
RnnFcDecoder.output_splits = (('amps', 1),
('harmonic_distribution', 60),
('noise_magnitudes', 65))
Expand Down
2 changes: 1 addition & 1 deletion ddsp/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
pulling in all the dependencies in __init__.py.
"""

__version__ = '0.1.0'
__version__ = '0.2.0'

0 comments on commit d4970b2

Please sign in to comment.