diff --git a/ddsp/core.py b/ddsp/core.py index 5ad7eac7..fe59c98e 100644 --- a/ddsp/core.py +++ b/ddsp/core.py @@ -106,6 +106,11 @@ def soft_limit(x, x_min=0.0, x_max=1.0): return tf.nn.softplus(x) + x_min - tf.nn.softplus(x - (x_max - x_min)) +def gradient_reversal(x): + """Identity operation that reverses the gradient.""" + return tf.stop_gradient(2.0 * x) - x + + # Unit Conversions ------------------------------------------------------------- def midi_to_hz(notes: Number) -> Number: """TF-compatible midi_to_hz function."""