Skip to content

Commit

Permalink
update from Zygote.dropgrad to ignore_derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
mrazomej committed Jul 11, 2024
1 parent c0a22ca commit 3ff7d1e
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions src/rhvae.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Import ML libraries
import Flux

# Import ChainRulesCore to ignore functions when computing gradients
using ChainRulesCore: ignore_derivatives

# Import AutoDiff backends
import ChainRulesCore
import TaylorDiff
import Zygote
import ForwardDiff
Expand Down Expand Up @@ -424,7 +426,9 @@ function update_metric(
rhvae::RHVAE{<:VAE{<:AbstractGaussianEncoder,<:AbstractVariationalDecoder}}
)
# Extract centroids_data
centroids_data = Zygote.dropgrad(rhvae.centroids_data)
centroids_data = ignore_derivatives() do
rhvae.centroids_data
end
# Run centroids_data through encoder and update centroids_latent
centroids_latent = rhvae.vae.encoder(centroids_data).µ
# Run centroids_data through metric_chain and update L
Expand Down Expand Up @@ -666,16 +670,16 @@ function _G_inv(
# Compute L_ψᵢ L_ψᵢᵀ exp(-‖z - cᵢ‖₂² / T²). Notes:
# - We use the reshape function to broadcast the operation over the third
# dimension of M.
# - The Zygote.dropgrad function is used to prevent the gradient from being
# - The ignore_derivatives function is used to prevent the gradient from being
# computed with respect to T.
LLexp = M .*
reshape(
exp.(-sum((z .- centroids_latent) .^ 2 / Zygote.dropgrad(T^2), dims=1)),
exp.(-sum((z .- centroids_latent) .^ 2 / ignore_derivatives(T^2), dims=1)),
1, 1, :
)

# Compute the regularization term.
Λ = ChainRulesCore.ignore_derivatives() do
Λ = ignore_derivatives() do
Matrix(LinearAlgebra.I(length(z))) .* λ
end # ignore_derivatives

Expand Down Expand Up @@ -725,10 +729,10 @@ function _G_inv(

# Compute exp(-‖z - cᵢ‖₂² / T²). Notes:
# - We bradcast the operation by reshaping the input arrays.
# - We use Zygot.dropgrad to prevent the gradient from being computed for T.
# - We use ignore_derivatives to prevent the gradient from being computed for T.
# - The result is a 3D array of size (1, n_centroid, n_sample).
exp_term = exp.(-sum(
(z .- centroids_latent) .^ 2 / Zygote.dropgrad(T^2),
(z .- centroids_latent) .^ 2 / ignore_derivatives(T^2),
dims=1
))

Expand All @@ -739,7 +743,7 @@ function _G_inv(
LLexp = M .* exp_term

# Compute the regularization term.
Λ = ChainRulesCore.ignore_derivatives() do
Λ = ignore_derivatives() do
Matrix(LinearAlgebra.I(size(z, 1))) .* λ
end # ignore_derivatives

Expand Down Expand Up @@ -3129,7 +3133,9 @@ function general_leapfrog_tempering_step(
tempering_schedule::Function=quadratic_tempering,
)
# Sample γₒ ~ N(0, Gₒ⁻¹).
γₒ = ChainRulesCore.@ignore_derivatives sample_MvNormalCanon(Gₒ⁻¹)
γₒ = ignore_derivatives() do
sample_MvNormalCanon(Gₒ⁻¹)
end

# Define ρₒ = γₒ / √βₒ
ρₒ = γₒ ./ (βₒ)
Expand Down Expand Up @@ -3273,7 +3279,9 @@ function general_leapfrog_tempering_step(
logdetGₒ = -slogdet(Gₒ⁻¹)

# Sample γₒ ~ N(0, Gₒ⁻¹).
γₒ = ChainRulesCore.@ignore_derivatives sample_MvNormalCanon(Gₒ⁻¹)
γₒ = ignore_derivatives() do
sample_MvNormalCanon(Gₒ⁻¹)
end

# Define ρₒ = γₒ / √βₒ
ρₒ = γₒ ./ (βₒ)
Expand Down

0 comments on commit 3ff7d1e

Please sign in to comment.