From 79b92c96f761a10525d7e9f110b6f6ec9459d3b0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Feb 2023 00:02:14 +0000 Subject: [PATCH 001/172] initial work on VecCorrBijector --- src/bijectors/corr.jl | 222 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 190 insertions(+), 32 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 252ecc68..f4cab829 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -78,19 +78,7 @@ function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) return w' * w end -function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) - K = LinearAlgebra.checksquare(y) - - result = float(zero(eltype(y))) - for j in 2:K, i in 1:(j - 1) - @inbounds abs_y_i_j = abs(y[i, j]) - result += (K - i + 1) * ( - IrrationalConstants.logtwo - (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j)) - ) - end - - return result -end +logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_chol_lkj(Y) function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) #= It may be more efficient if we can use un-contraint value to prevent call of b @@ -98,28 +86,159 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) `logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})` if possible. =# - return -logabsdetjac(inverse(b), (b(X))) + return -logabsdetjac(inverse(b), (b(X))) end -function _inv_link_chol_lkj(y) - K = LinearAlgebra.checksquare(y) +""" + VecCorrBijector <: Bijector - w = similar(y) +Similar to `CorrBijector`, but transforms a vector representing the Cholesky +to a correlation matrix, and its inverse transforms correlation matrix to vector +representing Cholesky. + +See also: [`CorrBijector`](@ref) + +# Example + +```jldoctest +julia> using LinearAlgebra + +julia> using StableRNGs; rng = StableRNG(42); + +julia> b = Bijectors.VecCorrBijector(); + +julia> X = rand(rng, LKJ(3, 1)) # Sample a correlation matrix. +3×3 Matrix{Float64}: + 1.0 -0.705273 -0.348638 + -0.705273 1.0 0.0534538 + -0.348638 0.0534538 1.0 + +julia> # Get the cholesky and convert to a vector. + u = Bijectors.triu1_to_vec(cholesky(X).U) + +julia> y = b(X) # Transform to unconstrained vector representation. +3-element Vector{Float64}: + -0.8777149781928181 + -0.3638927608636788 + -0.29813769428942216 + +julia> inverse(b)(y) ≈ u # (✓) Round-trip through `b` and its inverse. +true +""" +struct VecCorrBijector <: Bijector end +with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) + +function triu_mask(X::AbstractMatrix, k::Int) + # Ensure that we're working with a square matrix. + LinearAlgebra.checksquare(X) - @inbounds for j in 1:K - w[1, j] = 1 - for i in 2:j - z = tanh(y[i-1, j]) - tmp = w[i-1, j] - w[i-1, j] = z * tmp - w[i, j] = tmp * sqrt(1 - z^2) - end - for i in (j+1):K - w[i, j] = 0 + # Using `similar` allows us to respect device of array, etc., e.g. `CuArray`. + m = similar(X, Bool) + return triu(.~m .| m, k) +end + +triu_to_vec(X::AbstractMatrix{<:Real}, k::Int) = X[triu_mask(X, k)] + +function update_triu_from_vec!( + vals::AbstractVector{<:Real}, + k::Int, + X::AbstractMatrix{<:Real} +) + # Ensure that we're working with one-based indexing. + # `triu` requires this too. + LinearAlgebra.require_one_based_indexing(X) + + # Set the values. + idx = 1 + m, n = size(X) + for j = 1:n + for i = 1:min(j - k, m) + X[i, j] = vals[idx] + idx += 1 end end - - return w + + return X +end + +function update_triu_from_vec(vals::AbstractVector{<:Real}, k::Int, dim::Int) + X = similar(vals, dim, dim) + # TODO: Do we need this? + X .= 0 + return update_triu_from_vec!(vals, k, X) +end + +function ChainRulesCore.rrule(::typeof(update_triu_from_vec), x::AbstractVector{<:Real}, k::Int, dim::Int) + function update_triu_from_vec_pullback(ΔX) + return ( + ChainRulesCore.NoTangent(), + triu_to_vec(ChainRulesCore.unthunk(ΔX), k), + ChainRulesCore.NoTangent(), + ChainRulesCore.NoTangent() + ) + end + return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback +end + +""" + triu1_to_vec(X::AbstractMatrix{<:Real}) + +Extracts elements from upper triangle of `X` with offset `1` and returns them as a vector. +""" +triu1_to_vec(X::AbstractMatrix) = triu_to_vec(X, 1) + +inverse(::typeof(triu1_to_vec)) = vec_to_triu1 + +""" + vec_to_triu1(x::AbstractVector{<:Real}) + +Constructs a matrix from a vector `x` by filling the upper triangle with offset `1`. +""" +function vec_to_triu1(x::AbstractVector) + n = _triu1_dim_from_length(length(x)) + X = update_triu_from_vec(x, 1, n) + return UpperTriangular(X) +end + +inverse(::typeof(vec_to_triu1)) = triu1_to_vec + +# n * (n - 1) / 2 = d +# ⟺ n^2 - n - 2d = 0 +# ⟹ n = (1 + sqrt(1 + 8d)) / 2 +_triu1_dim_from_length(d) = Int((1 + sqrt(1 + 8d)) / 2) + +function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) + w = cholesky(X).U # keep LowerTriangular until here can avoid some computation + r = _link_chol_lkj(w) + + # Extract only the upper triangle of `r`. + return triu1_to_vec(r) +end + +# NOTE: The `logabsdetjac` is NOT the correcet on for this `transform`. +# The `logabsdetjac` implementation also includes the `logabsdetjac` of the +# cholesky decomposition, which is only valid if we're working on the space of +# postitive-definite matrices. +function transform(::VecCorrBijector, chol_vec::AbstractVector{<:Real}) + r = _link_chol_lkj(vec_to_triu1(chol_vec)) + + # Extract only the upper triangle of `r`. + return triu1_to_vec(r) +end + + +function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + Y = vec_to_triu1(y) + w = _inv_link_chol_lkj(Y) + # TODO: Should we just return `w` instead? + return triu1_to_vec(w) +end + +function logabsdetjac(b::VecCorrBijector, x) + return -logabsdetjac(inverse(b), b(x)) +end +function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + return _logabsdetjac_chol_lkj(vec_to_triu1(y)) end """ @@ -163,16 +282,55 @@ function _link_chol_lkj(w) # This block can't be integrated with loop below, because w[1,1] != 0. @inbounds z[1, 1] = 0 - @inbounds for j=2:K + @inbounds for j = 2:K z[1, j] = atanh(w[1, j]) tmp = sqrt(1 - w[1, j]^2) - for i in 2:(j - 1) + for i in 2:(j-1) p = w[i, j] / tmp tmp *= sqrt(1 - p^2) z[i, j] = atanh(p) end z[j, j] = 0 end - + return z end + +""" + _inv_link_chol_lkj(y) + +Inverse link function for cholesky factor. +""" +function _inv_link_chol_lkj(y) + K = LinearAlgebra.checksquare(y) + + w = similar(y) + + @inbounds for j in 1:K + w[1, j] = 1 + for i in 2:j + z = tanh(y[i-1, j]) + tmp = w[i-1, j] + w[i-1, j] = z * tmp + w[i, j] = tmp * sqrt(1 - z^2) + end + for i in (j+1):K + w[i, j] = 0 + end + end + + return w +end + +function _logabsdetjac_chol_lkj(Y::AbstractMatrix) + K = LinearAlgebra.checksquare(Y) + + result = float(zero(eltype(Y))) + for j in 2:K, i in 1:(j-1) + @inbounds abs_y_i_j = abs(Y[i, j]) + result += (K - i + 1) * ( + IrrationalConstants.logtwo - (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j)) + ) + end + return result +end From aa2fe610f52a95714ced5743fb145bd5120ffbfe Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Feb 2023 00:19:28 +0000 Subject: [PATCH 002/172] added some tests for CorrBijector, and fixed implementation for VecCorrBijector --- src/bijectors/corr.jl | 25 ++++--------------------- src/bijectors/pd.jl | 2 ++ test/bijectors/corr.jl | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 21 deletions(-) create mode 100644 test/bijectors/corr.jl diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index f4cab829..613863cd 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -92,9 +92,8 @@ end """ VecCorrBijector <: Bijector -Similar to `CorrBijector`, but transforms a vector representing the Cholesky -to a correlation matrix, and its inverse transforms correlation matrix to vector -representing Cholesky. +Similar to `CorrBijector`, but correlation matrix to a vector, +and its inverse transforms vector to a correlation matrix. See also: [`CorrBijector`](@ref) @@ -113,16 +112,13 @@ julia> X = rand(rng, LKJ(3, 1)) # Sample a correlation matrix. -0.705273 1.0 0.0534538 -0.348638 0.0534538 1.0 -julia> # Get the cholesky and convert to a vector. - u = Bijectors.triu1_to_vec(cholesky(X).U) - julia> y = b(X) # Transform to unconstrained vector representation. 3-element Vector{Float64}: -0.8777149781928181 -0.3638927608636788 -0.29813769428942216 -julia> inverse(b)(y) ≈ u # (✓) Round-trip through `b` and its inverse. +julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ struct VecCorrBijector <: Bijector end @@ -215,23 +211,10 @@ function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) return triu1_to_vec(r) end -# NOTE: The `logabsdetjac` is NOT the correcet on for this `transform`. -# The `logabsdetjac` implementation also includes the `logabsdetjac` of the -# cholesky decomposition, which is only valid if we're working on the space of -# postitive-definite matrices. -function transform(::VecCorrBijector, chol_vec::AbstractVector{<:Real}) - r = _link_chol_lkj(vec_to_triu1(chol_vec)) - - # Extract only the upper triangle of `r`. - return triu1_to_vec(r) -end - - function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) Y = vec_to_triu1(y) w = _inv_link_chol_lkj(Y) - # TODO: Should we just return `w` instead? - return triu1_to_vec(w) + return w' * w end function logabsdetjac(b::VecCorrBijector, x) diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index bed6ee9a..c9c0ff8a 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -18,6 +18,8 @@ function transform(ib::Inverse{PDBijector}, Y::AbstractMatrix{<:Real}) X = replace_diag(exp, Y) return getpd(X) end + +# TODO: AFAIK this is used because of AD-related issues; can we remove? getpd(X) = LowerTriangular(X) * LowerTriangular(X)' function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real}) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl new file mode 100644 index 00000000..7dbef68c --- /dev/null +++ b/test/bijectors/corr.jl @@ -0,0 +1,35 @@ +using Bijectors, DistributionsAD, LinearAlgebra, Test +using Bijectors: VecCorrBijector, CorrBijector + +@testset "PDBijector" begin + d = 3 + + b = CorrBijector() + bvec = VecCorrBijector() + + dist = LKJ(d, 1) + x = rand(dist) + + y = b(x) + yvec = bvec(x) + + # Make sure that they represent the same thing. + @test Bijectors.triu1_to_vec(y) ≈ yvec + + # Check the inverse. + binv = inverse(b) + xinv = binv(y) + bvecinv = inverse(bvec) + xvecinv = bvecinv(yvec) + + @test xinv ≈ xvecinv + + # And finally that the `logabsdetjac` is the same. + @test logabsdetjac(bvec, x) ≈ logabsdetjac(b, x) + + # NOTE: `CorrBijector` technically isn't bijective, and so the default `getjacobian` + # used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. + # Hence, we disable those tests. + test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false) + test_bijector(bvec, x; test_not_identity=true, changes_of_variables_test=false) +end From 8d2309453e0d0ef85fde7bd42e592a489330bb3e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Feb 2023 00:26:48 +0000 Subject: [PATCH 003/172] improved tests and are now using integer sqrt and division --- src/bijectors/corr.jl | 2 +- test/bijectors/corr.jl | 46 +++++++++++++++++++++--------------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 613863cd..5899d250 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -201,7 +201,7 @@ inverse(::typeof(vec_to_triu1)) = triu1_to_vec # n * (n - 1) / 2 = d # ⟺ n^2 - n - 2d = 0 # ⟹ n = (1 + sqrt(1 + 8d)) / 2 -_triu1_dim_from_length(d) = Int((1 + sqrt(1 + 8d)) / 2) +_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2 function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) w = cholesky(X).U # keep LowerTriangular until here can avoid some computation diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 7dbef68c..ceb1cb4a 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -2,34 +2,34 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test using Bijectors: VecCorrBijector, CorrBijector @testset "PDBijector" begin - d = 3 + for d ∈ [1, 2, 5] + b = CorrBijector() + bvec = VecCorrBijector() - b = CorrBijector() - bvec = VecCorrBijector() + dist = LKJ(d, 1) + x = rand(dist) - dist = LKJ(d, 1) - x = rand(dist) + y = b(x) + yvec = bvec(x) - y = b(x) - yvec = bvec(x) + # Make sure that they represent the same thing. + @test Bijectors.triu1_to_vec(y) ≈ yvec - # Make sure that they represent the same thing. - @test Bijectors.triu1_to_vec(y) ≈ yvec + # Check the inverse. + binv = inverse(b) + xinv = binv(y) + bvecinv = inverse(bvec) + xvecinv = bvecinv(yvec) - # Check the inverse. - binv = inverse(b) - xinv = binv(y) - bvecinv = inverse(bvec) - xvecinv = bvecinv(yvec) + @test xinv ≈ xvecinv - @test xinv ≈ xvecinv + # And finally that the `logabsdetjac` is the same. + @test logabsdetjac(bvec, x) ≈ logabsdetjac(b, x) - # And finally that the `logabsdetjac` is the same. - @test logabsdetjac(bvec, x) ≈ logabsdetjac(b, x) - - # NOTE: `CorrBijector` technically isn't bijective, and so the default `getjacobian` - # used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. - # Hence, we disable those tests. - test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false) - test_bijector(bvec, x; test_not_identity=true, changes_of_variables_test=false) + # NOTE: `CorrBijector` technically isn't bijective, and so the default `getjacobian` + # used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. + # Hence, we disable those tests. + test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) + test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) + end end From a35e36ff43cb66595e41a106f6a5b662c06b6a38 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 12 Feb 2023 15:23:07 +0000 Subject: [PATCH 004/172] moved things around a bit --- src/bijectors/corr.jl | 79 +++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5899d250..84cdb2e8 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -90,40 +90,10 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) end """ - VecCorrBijector <: Bijector - -Similar to `CorrBijector`, but correlation matrix to a vector, -and its inverse transforms vector to a correlation matrix. - -See also: [`CorrBijector`](@ref) - -# Example - -```jldoctest -julia> using LinearAlgebra - -julia> using StableRNGs; rng = StableRNG(42); - -julia> b = Bijectors.VecCorrBijector(); - -julia> X = rand(rng, LKJ(3, 1)) # Sample a correlation matrix. -3×3 Matrix{Float64}: - 1.0 -0.705273 -0.348638 - -0.705273 1.0 0.0534538 - -0.348638 0.0534538 1.0 - -julia> y = b(X) # Transform to unconstrained vector representation. -3-element Vector{Float64}: - -0.8777149781928181 - -0.3638927608636788 - -0.29813769428942216 + triu_mask(X::AbstractMatrix, k::Int) -julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. -true +Return a mask for elements of `X` above the `k`th diagonal. """ -struct VecCorrBijector <: Bijector end -with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) - function triu_mask(X::AbstractMatrix, k::Int) # Ensure that we're working with a square matrix. LinearAlgebra.checksquare(X) @@ -176,6 +146,11 @@ function ChainRulesCore.rrule(::typeof(update_triu_from_vec), x::AbstractVector{ return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback end +# n * (n - 1) / 2 = d +# ⟺ n^2 - n - 2d = 0 +# ⟹ n = (1 + sqrt(1 + 8d)) / 2 +_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2 + """ triu1_to_vec(X::AbstractMatrix{<:Real}) @@ -198,10 +173,40 @@ end inverse(::typeof(vec_to_triu1)) = triu1_to_vec -# n * (n - 1) / 2 = d -# ⟺ n^2 - n - 2d = 0 -# ⟹ n = (1 + sqrt(1 + 8d)) / 2 -_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2 +""" + VecCorrBijector <: Bijector + +Similar to `CorrBijector`, but correlation matrix to a vector, +and its inverse transforms vector to a correlation matrix. + +See also: [`CorrBijector`](@ref) + +# Example + +```jldoctest +julia> using LinearAlgebra + +julia> using StableRNGs; rng = StableRNG(42); + +julia> b = Bijectors.VecCorrBijector(); + +julia> X = rand(rng, LKJ(3, 1)) # Sample a correlation matrix. +3×3 Matrix{Float64}: + 1.0 -0.705273 -0.348638 + -0.705273 1.0 0.0534538 + -0.348638 0.0534538 1.0 + +julia> y = b(X) # Transform to unconstrained vector representation. +3-element Vector{Float64}: + -0.8777149781928181 + -0.3638927608636788 + -0.29813769428942216 + +julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. +true +""" +struct VecCorrBijector <: Bijector end +with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) w = cholesky(X).U # keep LowerTriangular until here can avoid some computation @@ -240,7 +245,7 @@ end But this implementation will not work when w[i-1, j] = 0. Though it is a zero measure set, unit matrix initialization will not work. -For equivelence, following explanations is given by @torfjelde: +For equivalence, following explanations is given by @torfjelde: For `(i, j)` in the loop below, we define From 8cadf69190eacfa37f826b7e1dbdfab6cd4cf6a2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Feb 2023 09:49:05 +0000 Subject: [PATCH 005/172] added chainrule for ReverseDiff --- src/compat/reversediff.jl | 4 +++- test/bijectors/corr.jl | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index c498205a..15369b38 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -1,7 +1,7 @@ module ReverseDiffCompat using ..ReverseDiff: ReverseDiff, @grad, value, track, TrackedReal, TrackedVector, - TrackedMatrix + TrackedMatrix, @grad_from_chainrules using Requires, LinearAlgebra using ..Bijectors: Elementwise, SimplexBijector, maphcat, simplex_link_jacobian, @@ -181,6 +181,8 @@ end return y, (wrap_chainrules_output ∘ Base.tail ∘ dy) end +@grad_from_chainrules update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int) + # NOTE: Probably doesn't work in complete generality. wrap_chainrules_output(x) = x wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index ceb1cb4a..16b7de9c 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -1,7 +1,7 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test using Bijectors: VecCorrBijector, CorrBijector -@testset "PDBijector" begin +@testset "CorrBijector & VecCorrBijector" begin for d ∈ [1, 2, 5] b = CorrBijector() bvec = VecCorrBijector() From eaf5324de656bd3db870ff0b492fe7ef4de1edee Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Feb 2023 10:50:51 +0000 Subject: [PATCH 006/172] some fixes for AD --- src/bijectors/corr.jl | 10 +++++----- src/bijectors/pd.jl | 8 ++------ src/compat/reversediff.jl | 29 +++++++++++++++++++++++------ src/compat/tracker.jl | 10 +++++----- src/compat/zygote.jl | 6 +++--- src/utils.jl | 7 +++++++ 6 files changed, 45 insertions(+), 25 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 84cdb2e8..e043b2bf 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -66,7 +66,7 @@ struct CorrBijector <: Bijector end with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x) function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) - w = cholesky(x).U # keep LowerTriangular until here can avoid some computation + w = upper_triangular(parent(cholesky(x).U)) # keep LowerTriangular until here can avoid some computation r = _link_chol_lkj(w) return r + zero(x) # This dense format itself is required by a test, though I can't get the point. @@ -75,7 +75,7 @@ end function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) w = _inv_link_chol_lkj(y) - return w' * w + return pd_from_upper(w) end logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_chol_lkj(Y) @@ -168,7 +168,7 @@ Constructs a matrix from a vector `x` by filling the upper triangle with offset function vec_to_triu1(x::AbstractVector) n = _triu1_dim_from_length(length(x)) X = update_triu_from_vec(x, 1, n) - return UpperTriangular(X) + return upper_triangular(X) end inverse(::typeof(vec_to_triu1)) = triu1_to_vec @@ -209,7 +209,7 @@ struct VecCorrBijector <: Bijector end with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) - w = cholesky(X).U # keep LowerTriangular until here can avoid some computation + w = upper_triangular(parent(cholesky(X).U)) r = _link_chol_lkj(w) # Extract only the upper triangle of `r`. @@ -219,7 +219,7 @@ end function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) Y = vec_to_triu1(y) w = _inv_link_chol_lkj(Y) - return w' * w + return pd_from_upper(w) end function logabsdetjac(b::VecCorrBijector, x) diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index c9c0ff8a..3ba5526b 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -9,19 +9,15 @@ function replace_diag(f, X) end transform(b::PDBijector, X::AbstractMatrix{<:Real}) = pd_link(X) function pd_link(X) - Y = lower(parent(cholesky(X; check = true).L)) + Y = lower_triangular(parent(cholesky(X; check = true).L)) return replace_diag(log, Y) end -lower(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) function transform(ib::Inverse{PDBijector}, Y::AbstractMatrix{<:Real}) X = replace_diag(exp, Y) - return getpd(X) + return pd_from_lower(X) end -# TODO: AFAIK this is used because of AD-related issues; can we remove? -getpd(X) = LowerTriangular(X) * LowerTriangular(X)' - function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real}) T = eltype(X) Xcf = cholesky(X, check = false) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 15369b38..78871ce1 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -7,7 +7,8 @@ using Requires, LinearAlgebra using ..Bijectors: Elementwise, SimplexBijector, maphcat, simplex_link_jacobian, simplex_invlink_jacobian, simplex_logabsdetjac_gradient, Inverse import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector, - _simplex_inv_bijector, replace_diag, jacobian, getpd, lower, + _simplex_inv_bijector, replace_diag, jacobian, pd_from_lower, pd_from_upper, + lower_triangular, upper_triangular, _inv_link_chol_lkj, _link_chol_lkj, _transform_ordered, _transform_inverse_ordered, find_alpha @@ -136,18 +137,34 @@ logabsdetjac(b::SimplexBijector, x::Union{TrackedVector, TrackedMatrix}) = track end end -getpd(X::TrackedMatrix) = track(getpd, X) -@grad function getpd(X::AbstractMatrix) +pd_from_lower(X::TrackedMatrix) = track(pd_from_lower, X) +@grad function pd_from_lower(X::AbstractMatrix) Xd = value(X) return LowerTriangular(Xd) * LowerTriangular(Xd)', Δ -> begin Xl = LowerTriangular(Xd) return (LowerTriangular(Δ' * Xl + Δ * Xl),) end end -lower(A::TrackedMatrix) = track(lower, A) -@grad function lower(A::AbstractMatrix) + +pd_from_upper(X::TrackedMatrix) = track(pd_from_upper, X) +@grad function pd_from_upper(X::AbstractMatrix) + Xd = value(X) + return UpperTriangular(Xd)' * UpperTriangular(Xd), Δ -> begin + Xu = UpperTriangular(Xd) + return (UpperTriangular(Δ * Xu + Δ' * Xu),) + end +end + +lower_triangular(A::TrackedMatrix) = track(lower_triangular, A) +@grad function lower_triangular(A::AbstractMatrix) + Ad = value(A) + return lower_triangular(Ad), Δ -> (lower_triangular(Δ),) +end + +upper_triangular(A::TrackedMatrix) = track(upper_triangular, A) +@grad function upper_triangular(A::AbstractMatrix) Ad = value(A) - return lower(Ad), Δ -> (lower(Δ),) + return upper_triangular(Ad), Δ -> (upper_triangular(Δ),) end function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal} diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index 1166a29e..4763d724 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -281,8 +281,8 @@ end (b::Elementwise{typeof(log)})(x::TrackedVector) = log.(x)::vectorof(float(eltype(x))) (b::Elementwise{typeof(log)})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x))) -Bijectors.getpd(X::TrackedMatrix) = track(Bijectors.getpd, X) -@grad function Bijectors.getpd(X::AbstractMatrix) +Bijectors.pd_from_lower(X::TrackedMatrix) = track(Bijectors.pd_from_lower, X) +@grad function Bijectors.pd_from_lower(X::AbstractMatrix) Xd = data(X) return Bijectors.LowerTriangular(Xd) * Bijectors.LowerTriangular(Xd)', Δ -> begin Xl = Bijectors.LowerTriangular(Xd) @@ -290,10 +290,10 @@ Bijectors.getpd(X::TrackedMatrix) = track(Bijectors.getpd, X) end end -Bijectors.lower(A::TrackedMatrix) = track(Bijectors.lower, A) -@grad function Bijectors.lower(A::AbstractMatrix) +Bijectors.lower_triangular(A::TrackedMatrix) = track(Bijectors.lower_triangular, A) +@grad function Bijectors.lower_triangular(A::AbstractMatrix) Ad = data(A) - return Bijectors.lower(Ad), Δ -> (Bijectors.lower(Δ),) + return Bijectors.lower_triangular(Ad), Δ -> (Bijectors.lower_triangular(Δ),) end Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_lkj, y) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index eedf4b3d..6a81a749 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -158,10 +158,10 @@ end end return pullback(_maximum, d) end -@adjoint function lower(A::AbstractMatrix) - return lower(A), Δ -> (lower(Δ),) +@adjoint function lower_triangular(A::AbstractMatrix) + return lower_triangular(A), Δ -> (lower_triangular(Δ),) end -@adjoint function getpd(X::AbstractMatrix) +@adjoint function pd_from_lower(X::AbstractMatrix) return LowerTriangular(X) * LowerTriangular(X)', Δ -> begin Xl = LowerTriangular(X) return (LowerTriangular(Δ' * Xl + Δ * Xl),) diff --git a/src/utils.jl b/src/utils.jl index 8203e1b4..dca95731 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,3 +6,10 @@ aT_b(a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) = dot(a, b) # flatten arrays with fallback for scalars _vec(x::AbstractArray{<:Real}) = vec(x) _vec(x::Real) = x + +# # Because `ReverseDiff` does not play well with structural matrices. +lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) +upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) + +pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' +pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) From 36ffbdb9b86e4bc417095fd38bbde92d6f32f5b6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Feb 2023 10:55:07 +0000 Subject: [PATCH 007/172] added some TODOs --- src/bijectors/corr.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index e043b2bf..6c4f3b97 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -262,6 +262,7 @@ and so which is the above implementation. """ function _link_chol_lkj(w) + # TODO: Implement adjoint to support reverse-mode AD backends properly. K = LinearAlgebra.checksquare(w) z = similar(w) # z is also UpperTriangular. @@ -290,6 +291,7 @@ end Inverse link function for cholesky factor. """ function _inv_link_chol_lkj(y) + # TODO: Implement adjoint to support reverse-mode AD backends properly. K = LinearAlgebra.checksquare(y) w = similar(y) From 62ae1ace63f8ff5b591ca2b5e2b139755784792a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Mar 2023 13:11:45 +0000 Subject: [PATCH 008/172] Update src/bijectors/corr.jl --- src/bijectors/corr.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 6c4f3b97..cedda705 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -291,7 +291,6 @@ end Inverse link function for cholesky factor. """ function _inv_link_chol_lkj(y) - # TODO: Implement adjoint to support reverse-mode AD backends properly. K = LinearAlgebra.checksquare(y) w = similar(y) From 3f25a8bf33ba00052954cfeaf021d60b65495540 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 4 Apr 2023 15:17:17 +0100 Subject: [PATCH 009/172] define bijectors for `LKJ` and `LKJCholesky` --- src/transformed_distribution.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 3b21bf4e..a130a41e 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -77,7 +77,8 @@ bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d) bijector(d::PDMatDistribution) = PDBijector() bijector(d::MatrixBeta) = PDBijector() -bijector(d::LKJ) = CorrBijector() +bijector(d::LKJ) = VecCorrBijector() +bijector(d::LKJCholesky) = d.uplo === 'L' ? VecTrilBijector() : VecTriuBijector() ############################## # Distributions.jl interface # From e1567c35cac4e96c65354738897e330cc452066f Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:49:55 +0100 Subject: [PATCH 010/172] add `TransformedDistribution` constructor for `LKJCholesky` --- src/transformed_distribution.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index a130a41e..7e51cbe6 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -6,6 +6,7 @@ struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D< TransformedDistribution(d::UnivariateDistribution, b) = new{typeof(d), typeof(b), Univariate}(d, b) TransformedDistribution(d::MultivariateDistribution, b) = new{typeof(d), typeof(b), Multivariate}(d, b) TransformedDistribution(d::MatrixDistribution, b) = new{typeof(d), typeof(b), Matrixvariate}(d, b) + TransformedDistribution(d::Distribution{CholeskyVariate}, b) = new{typeof(d), typeof(b), CholeskyVariate}(d, b) end # fields may contain nested numerical parameters From 8d07e342f451982f8860439bb872906494533585 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:50:27 +0100 Subject: [PATCH 011/172] define `logpdf` for `LKJ` & `LKJCholesky` --- src/transformed_distribution.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 7e51cbe6..dd045359 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -109,6 +109,11 @@ function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) return logpdf(td.dist, mappedarray(x->x+ϵ, x)) + logjac end +function logpdf(td::TransformedDistribution{T}, y::AbstractVector{<:Real}) where {T <: Union{LKJ, LKJCholesky}} + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return logpdf(td.dist, x) + logjac +end + function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) return logpdf(td.dist, x) + logjac From 9a59a9f796ba9dd2a83abb52531fc63d41605d46 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:50:53 +0100 Subject: [PATCH 012/172] define `rand` for `LKJ` & `LKJCholesky` --- src/transformed_distribution.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index dd045359..b1cf82a3 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -161,6 +161,10 @@ function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<:Real}) x .= td.transform(x) end +function rand(rng::AbstractRNG, td::TransformedDistribution{T}) where {T <: Union{LKJ, LKJCholesky}} + return td.transform(rand(rng, td.dist)) +end + # utility stuff Distributions.params(td::Transformed) = Distributions.params(td.dist) function Base.maximum(td::UnivariateTransformed) From f15ad85a55e015cb774dfa048988773939b786fd Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:51:28 +0100 Subject: [PATCH 013/172] add util to extract Cholesky factor --- src/utils.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index dca95731..974a28a6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,3 +13,8 @@ upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) + +cholesky_factor(X::AbstractMatrix) = cholesky(X).UL +cholesky_factor(X::Cholesky) = X.UL +cholesky_factor(X::UpperTriangular) = X +cholesky_factor(X::LowerTriangular) = X From 53e78f3be1554b7add086c9111b2795ba6a9911e Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:52:51 +0100 Subject: [PATCH 014/172] TYPO: capitalize matrix --- src/bijectors/corr.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 6c4f3b97..b44606fc 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -65,10 +65,10 @@ struct CorrBijector <: Bijector end with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x) -function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) - w = upper_triangular(parent(cholesky(x).U)) # keep LowerTriangular until here can avoid some computation +function transform(b::CorrBijector, X::AbstractMatrix{<:Real}) + w = upper_triangular(parent(cholesky(X).U)) # keep LowerTriangular until here can avoid some computation r = _link_chol_lkj(w) - return r + zero(x) + return r + zero(X) # This dense format itself is required by a test, though I can't get the point. # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67 end From ec7d20e1bfae179032ced6b01c41ce8918da7d81 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:53:44 +0100 Subject: [PATCH 015/172] add util to convert `Vector` index to `Matrix` row index --- src/bijectors/corr.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index b44606fc..909a3b8a 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -173,6 +173,13 @@ end inverse(::typeof(vec_to_triu1)) = triu1_to_vec +function vec_to_triu1_row_index(idx) + # Assumes that vector was saved in a column-major order + # and that vector is one-based indexed. + M = _triu1_dim_from_length(idx - 1) + return idx - (M*(M-1) ÷ 2) +end + """ VecCorrBijector <: Bijector From 2ed00f4949a8b9a66b4a5125920f8348f9d9c182 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:55:16 +0100 Subject: [PATCH 016/172] add `VecTriBijector`s for `LKJCholesky` --- src/bijectors/corr.jl | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 909a3b8a..d79638ba 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -180,6 +180,16 @@ function vec_to_triu1_row_index(idx) return idx - (M*(M-1) ÷ 2) end +abstract type AbstractVecCorrBijector <: Bijector end + +with_logabsdet_jacobian(b::AbstractVecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) + +transform(::AbstractVecCorrBijector, X) = (_link_chol_lkj ∘ cholesky_factor)(X) + +function logabsdetjac(b::AbstractVecCorrBijector, x) + return -logabsdetjac(inverse(b), b(x)) +end + """ VecCorrBijector <: Bijector @@ -212,29 +222,21 @@ julia> y = b(X) # Transform to unconstrained vector representation. julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ -struct VecCorrBijector <: Bijector end -with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) +struct VecCorrBijector <: AbstractVecCorrBijector end +transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = (pd_from_upper ∘ _inv_link_chol_lkj)(y) -function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) - w = upper_triangular(parent(cholesky(X).U)) - r = _link_chol_lkj(w) +logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) - # Extract only the upper triangle of `r`. - return triu1_to_vec(r) -end +struct VecTriuBijector <: AbstractVecCorrBijector end +transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = (Cholesky ∘ UpperTriangular ∘ _inv_link_chol_lkj)(y) -function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) - Y = vec_to_triu1(y) - w = _inv_link_chol_lkj(Y) - return pd_from_upper(w) -end +logabsdetjac(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) + +struct VecTrilBijector <: AbstractVecCorrBijector end +transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = (Cholesky ∘ LowerTriangular ∘ transpose ∘ _inv_link_chol_lkj)(y) + +logabsdetjac(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) -function logabsdetjac(b::VecCorrBijector, x) - return -logabsdetjac(inverse(b), b(x)) -end -function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) - return _logabsdetjac_chol_lkj(vec_to_triu1(y)) -end """ function _link_chol_lkj(w) From 07555fc713ec198b3f8351c6ae840a3bd7d64daa Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:55:53 +0100 Subject: [PATCH 017/172] TYPO: capitilize matrix --- src/bijectors/corr.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index d79638ba..5b010dcd 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -270,21 +270,21 @@ and so which is the above implementation. """ -function _link_chol_lkj(w) +function _link_chol_lkj(W::AbstractMatrix) # TODO: Implement adjoint to support reverse-mode AD backends properly. - K = LinearAlgebra.checksquare(w) + K = LinearAlgebra.checksquare(W) - z = similar(w) # z is also UpperTriangular. + z = similar(W) # z is also UpperTriangular. # Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero. - # This block can't be integrated with loop below, because w[1,1] != 0. + # This block can't be integrated with loop below, because W[1,1] != 0. @inbounds z[1, 1] = 0 @inbounds for j = 2:K - z[1, j] = atanh(w[1, j]) - tmp = sqrt(1 - w[1, j]^2) + z[1, j] = atanh(W[1, j]) + tmp = sqrt(1 - W[1, j]^2) for i in 2:(j-1) - p = w[i, j] / tmp + p = W[i, j] / tmp tmp *= sqrt(1 - p^2) z[i, j] = atanh(p) end From a75cabca7609a6af2ef1ddbfdc60d4b980a9b61a Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:56:44 +0100 Subject: [PATCH 018/172] add `LKJCholesky` link for `UpperTriangular` --- src/bijectors/corr.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5b010dcd..47aa4c50 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -294,6 +294,30 @@ function _link_chol_lkj(W::AbstractMatrix) return z end +function _link_chol_lkj(W::UpperTriangular) + K = LinearAlgebra.checksquare(W) + N = ((K-1)*K) ÷ 2 # {K \choose 2} free parameters + + z = zeros(eltype(W), N) + + # This block can't be integrated with loop below, because w[1,1] != 0. + idx = 1 + @inbounds for j = 2:K + z[idx] = atanh(W[1, j]) + idx += 1 + tmp = sqrt(1 - W[1, j]^2) + for i in 2:(j-1) + p = W[i, j] / tmp + tmp *= sqrt(1 - p^2) + z[idx] = atanh(p) + idx += 1 + end + end + + return z +end + +function _link_chol_lkj(W::LowerTriangular) """ _inv_link_chol_lkj(y) From 844b07ea58041236daed91d8987379f271889b84 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:57:05 +0100 Subject: [PATCH 019/172] add `LKJCholesky` link for `LowerTriangular` --- src/bijectors/corr.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 47aa4c50..596347c1 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -318,6 +318,28 @@ function _link_chol_lkj(W::UpperTriangular) end function _link_chol_lkj(W::LowerTriangular) + K = LinearAlgebra.checksquare(W) + N = div((K-1)*K, 2) # {K \choose 2} free parameters + + z = zeros(eltype(W), N) + + # This block can't be integrated with loop below, because w[1,1] != 0. + idx = 1 + @inbounds for i = 2:K + z[idx] = atanh(W[i, 1]) + idx += 1 + tmp = sqrt(1 - W[i, 1]^2) + for j in 2:(i-1) + p = W[i, j] / tmp + tmp *= sqrt(1 - p^2) + z[idx] = atanh(p) + idx += 1 + end + end + + return z +end + """ _inv_link_chol_lkj(y) From 792cfe9d7bbb12eeffef0e1be0faf22396bc1179 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:57:54 +0100 Subject: [PATCH 020/172] TYPO: capitalize matrix --- src/bijectors/corr.jl | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 596347c1..bc80f61a 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -345,22 +345,28 @@ end Inverse link function for cholesky factor. """ -function _inv_link_chol_lkj(y) +function _inv_link_chol_lkj(Y::AbstractMatrix) # TODO: Implement adjoint to support reverse-mode AD backends properly. - K = LinearAlgebra.checksquare(y) + K = LinearAlgebra.checksquare(Y) - w = similar(y) + W = similar(Y) @inbounds for j in 1:K - w[1, j] = 1 + W[1, j] = 1 for i in 2:j - z = tanh(y[i-1, j]) - tmp = w[i-1, j] - w[i-1, j] = z * tmp - w[i, j] = tmp * sqrt(1 - z^2) + z = tanh(Y[i-1, j]) + tmp = W[i-1, j] + W[i-1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) end for i in (j+1):K - w[i, j] = 0 + W[i, j] = 0 + end + end + + return W +end + end end From 8f0886b327cd33efe1408dbf4b1950265c061fa3 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:58:50 +0100 Subject: [PATCH 021/172] add `LKJCholesky` inverse link to `UpperTriangular` --- src/bijectors/corr.jl | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index bc80f61a..89d34aee 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -367,10 +367,26 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) return W end +function _inv_link_chol_lkj(y::AbstractVector) + # TODO: Implement adjoint to support reverse-mode AD backends properly. + K = _triu1_dim_from_length(length(y)) + + W = similar(y, K, K) + W .= zeros(eltype(y)) + + idx = 1 + @inbounds for j in 1:K + W[1, j] = 1 + for i in 2:j + z = tanh(y[idx]) + idx += 1 + tmp = W[i-1, j] + W[i-1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) end end - return w + return W end function _logabsdetjac_chol_lkj(Y::AbstractMatrix) From 35f1c035f42899b65aa346f62c5863a9f2e021d6 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 09:00:14 +0100 Subject: [PATCH 022/172] rename `_logabsdetjac_chol_lkj` to `_logabsdetjac_inv_corr` --- src/bijectors/corr.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 89d34aee..3496a016 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -78,7 +78,7 @@ function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) return pd_from_upper(w) end -logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_chol_lkj(Y) +logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_inv_corr(Y) function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) #= It may be more efficient if we can use un-contraint value to prevent call of b @@ -389,7 +389,7 @@ function _inv_link_chol_lkj(y::AbstractVector) return W end -function _logabsdetjac_chol_lkj(Y::AbstractMatrix) +function _logabsdetjac_inv_corr(Y::AbstractMatrix) K = LinearAlgebra.checksquare(Y) result = float(zero(eltype(Y))) From 9d558292b4405d00b231e0ceb09081c4d69d8046 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 09:01:50 +0100 Subject: [PATCH 023/172] dispatch `_logabsdetjac_inv_corr` for `::Vector` --- src/bijectors/corr.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 3496a016..2ffe9ee3 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -401,3 +401,18 @@ function _logabsdetjac_inv_corr(Y::AbstractMatrix) end return result end + +function _logabsdetjac_inv_corr(y::AbstractVector) + K = _triu1_dim_from_length(length(y)) + + result = float(zero(eltype(y))) + for (i, y_i) in enumerate(y) + abs_y_i = abs(y_i) + row_idx = vec_to_triu1_row_index(i) + result += (K - row_idx + 1) * ( + IrrationalConstants.logtwo - (abs_y_i + LogExpFunctions.log1pexp(-2 * abs_y_i)) + ) + end + return result +end + From adf10ad6f73cb0d6656a0b03a80206788de1366e Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 09:02:52 +0100 Subject: [PATCH 024/172] add logabsdetjac for inverse link of `LKJCholesky` --- src/bijectors/corr.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 2ffe9ee3..eb915b75 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -416,3 +416,21 @@ function _logabsdetjac_inv_corr(y::AbstractVector) return result end +function _logabsdetjac_inv_chol(y::AbstractVector) + K = _triu1_dim_from_length(length(y)) + + result = float(zero(eltype(y))) + idx = 1 + @inbounds for j in 2:K + tmp = zero(result) + for _ in 1:(j-1) + z = tanh(y[idx]) + logz = log(1 - z^2) + tmp += logz + result += logz + (tmp / 2) + idx += 1 + end + end + + return result +end From 03a55b23bd5145540efd91ba69fe32b5bd851562 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 09:44:06 +0100 Subject: [PATCH 025/172] add tests for `VecTriBijector`s --- test/bijectors/corr.jl | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 16b7de9c..900b5e68 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -1,5 +1,5 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test -using Bijectors: VecCorrBijector, CorrBijector +using Bijectors: VecCorrBijector, CorrBijector, VecTriuBijector, VecTrilBijector @testset "CorrBijector & VecCorrBijector" begin for d ∈ [1, 2, 5] @@ -33,3 +33,29 @@ using Bijectors: VecCorrBijector, CorrBijector test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) end end + +@testset "VecTriuBijector & VecTrilBijector" begin + for d ∈ [2, 5] + for dist in [LKJCholesky(d, 1, 'U'), LKJCholesky(d, 1, 'L')] + b = bijector(dist) + + b_lkj = VecCorrBijector() + x = rand(dist) + y = b(x) + y_lkj = b_lkj(x) + + @test y ≈ y_lkj + + binv = inverse(b) + xinv = binv(y) + binv_lkj = inverse(b_lkj) + xinv_lkj = binv_lkj(y_lkj) + + @test xinv.U ≈ cholesky(xinv_lkj).U + + # test_bijector is commented out for now, + # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) + # test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) + end + end +end From 1059569696ff363718b9a6edacd12fc514656a9c Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 09:46:13 +0100 Subject: [PATCH 026/172] add `rrule` for LKJ(Cholesky) link function --- src/chainrules.jl | 56 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 45cacf82..445e217b 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -156,5 +156,59 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM return y, _transform_inverse_ordered_adjoint end +function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) + project_W = ChainRulesCore.ProjectTo(W) + + K = LinearAlgebra.checksquare(W) + N = ((K-1)*K) ÷ 2 + + z = zeros(eltype(W), N) + tmp_vec = similar(z) + + idx = 1 + @inbounds for j = 2:K + z[idx] = atanh(W[1, j]) + tmp = sqrt(1 - W[1, j]^2) + tmp_vec[idx] = tmp + idx += 1 + for i in 2:(j-1) + p = W[i, j] / tmp + tmp *= sqrt(1 - p^2) + tmp_vec[idx] = tmp + z[idx] = atanh(p) + idx += 1 + end + end + + function pullback_link_chol_lkj(Δz_thunked) + Δz = ChainRulesCore.unthunk(Δz_thunked) + + ΔW = similar(W) + + @inbounds ΔW[1,1] = zero(eltype(Δz)) + @inbounds for j=2:K + idx_up_to_prev_column = ((j-1)*(j-2) ÷ 2) + ΔW[j, j] = zero(eltype(Δz)) + Δtmp = zero(eltype(Δz)) + for i in (j-1):-1:2 + tmp = tmp_vec[idx_up_to_prev_column + i - 1] + p = W[i, j] / tmp + ftmp = sqrt(1 - p^2) + d_ftmp_p = -p / ftmp + d_p_tmp = -W[i,j] / tmp^2 + + Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp * d_ftmp_p + ΔW[i, j] = Δp / tmp + Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp + end + ΔW[1, j] = Δz[1, j] / (1-W[1,j]^2) - Δtmp / sqrt(1 - W[1,j]^2) * W[1,j] + end + + return ChainRulesCore.NoTangent(), project_W(ΔW) + end + + return z, pullback_link_chol_lkj +end + # Fixes Zygote's issues with `@debug` -ChainRulesCore.@non_differentiable _debug(::Any) \ No newline at end of file +ChainRulesCore.@non_differentiable _debug(::Any) From ad080ea21f341d2becf0d3bfc2a9a657e4a38ae5 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 11 Apr 2023 15:36:36 +0100 Subject: [PATCH 027/172] use `transpose` in link for `::LowerTriangular' --- src/bijectors/corr.jl | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index eb915b75..a0269ce7 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -317,28 +317,7 @@ function _link_chol_lkj(W::UpperTriangular) return z end -function _link_chol_lkj(W::LowerTriangular) - K = LinearAlgebra.checksquare(W) - N = div((K-1)*K, 2) # {K \choose 2} free parameters - - z = zeros(eltype(W), N) - - # This block can't be integrated with loop below, because w[1,1] != 0. - idx = 1 - @inbounds for i = 2:K - z[idx] = atanh(W[i, 1]) - idx += 1 - tmp = sqrt(1 - W[i, 1]^2) - for j in 2:(i-1) - p = W[i, j] / tmp - tmp *= sqrt(1 - p^2) - z[idx] = atanh(p) - idx += 1 - end - end - - return z -end +_link_chol_lkj(W::LowerTriangular) = (_link_chol_lkj ∘ transpose)(W) """ _inv_link_chol_lkj(y) From 6e1a5b10413603a4fe950cdb0deaf7cae02619ab Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 12 Apr 2023 17:17:32 +0100 Subject: [PATCH 028/172] add `Tracker` support for inverse link --- src/compat/tracker.jl | 53 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index 4763d724..dae58086 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -12,7 +12,7 @@ using ..Tracker: Tracker, param import ..Bijectors -using ..Bijectors: Elementwise, SimplexBijector, Inverse, Stacked +using ..Bijectors: Elementwise, SimplexBijector, Inverse, Stacked, _triu1_dim_from_length import ChainRulesCore import LogExpFunctions @@ -296,8 +296,57 @@ Bijectors.lower_triangular(A::TrackedMatrix) = track(Bijectors.lower_triangular, return Bijectors.lower_triangular(Ad), Δ -> (Bijectors.lower_triangular(Δ),) end +Bijectors._inv_link_chol_lkj(y::TrackedVector) = track(Bijectors._inv_link_chol_lkj, y) +@grad function Bijectors._inv_link_chol_lkj(y_tracked::TrackedVector) + y = data(y_tracked) + K = _triu1_dim_from_length(length(y)) + + W = similar(y, K, K) + W .= zeros(eltype(y)) + + z_vec = similar(y) + tmp_vec = similar(y) + + idx = 1 + @inbounds for j in 1:K + W[1, j] = 1 + for i in 2:j + z = tanh(y[idx]) + tmp = W[i-1, j] + + z_vec[idx] = z + tmp_vec[idx] = tmp + idx += 1 + + W[i-1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) + end + end + + function pullback_inv_link_chol_lkj(ΔW) + LinearAlgebra.checksquare(ΔW) + + Δy = zero(y) + + @inbounds for j in 1:K + idx_up_to_prev_column = ((j-1)*(j-2) ÷ 2) + Δtmp = ΔW[j,j] + for i in j:-1:2 + idx = idx_up_to_prev_column + i - 1 + Δz = ΔW[i-1, j] * tmp_vec[idx] - Δtmp * tmp_vec[idx] / sqrt(1 - z_vec[idx]^2) * z_vec[idx] + Δy[idx] = Δz / cosh(y[idx])^2 + Δtmp = ΔW[i-1, j] * z_vec[idx] + Δtmp * sqrt(1 - z_vec[idx]^2) + end + end + + return (Δy,) + end + + return W, pullback_inv_link_chol_lkj +end + Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_lkj, y) -@grad function Bijectors._inv_link_chol_lkj(y_tracked) +@grad function Bijectors._inv_link_chol_lkj(y_tracked::TrackedMatrix) y = data(y_tracked) K = LinearAlgebra.checksquare(y) From 5fd0a652b107fa4e12d4e9fdfc242e949d06d6d1 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 12 Apr 2023 18:13:11 +0100 Subject: [PATCH 029/172] better utility function call --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 974a28a6..34842e89 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -14,7 +14,7 @@ upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) -cholesky_factor(X::AbstractMatrix) = cholesky(X).UL +cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(X)) cholesky_factor(X::Cholesky) = X.UL cholesky_factor(X::UpperTriangular) = X cholesky_factor(X::LowerTriangular) = X From b38acda3b9acdc72dcf0095f22f9d13c017e3f29 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 12 Apr 2023 18:13:43 +0100 Subject: [PATCH 030/172] use function barrier properly for type stability --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index a0269ce7..f2dec9f9 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -184,7 +184,7 @@ abstract type AbstractVecCorrBijector <: Bijector end with_logabsdet_jacobian(b::AbstractVecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) -transform(::AbstractVecCorrBijector, X) = (_link_chol_lkj ∘ cholesky_factor)(X) +transform(::AbstractVecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) function logabsdetjac(b::AbstractVecCorrBijector, x) return -logabsdetjac(inverse(b), b(x)) From 424f8cafee245967c540ede071e956814d442c79 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 13:17:34 +0100 Subject: [PATCH 031/172] account for difference in support dimensions --- test/transform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transform.jl b/test/transform.jl index 7be147d3..33f7a73b 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -61,7 +61,7 @@ function single_sample_tests(dist) else # This should probably be exact. @test logpdf(dist, x) == logpdf_with_trans(dist, x, false) - @test all(isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(x)) for _ in 1:100])) + @test all(isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(y)) for _ in 1:100])) end end From b749d37428af78e5ec5891857eaa5a70337d4cac Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 13:18:21 +0100 Subject: [PATCH 032/172] fix indexing in Jacobian of `VecCorrBijector` --- test/transform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transform.jl b/test/transform.jl index 33f7a73b..2776d7fa 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -182,7 +182,7 @@ end upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]] J = ForwardDiff.jacobian(x->link(dist, x), x) - J = J[upperinds, upperinds] + J = J[:, upperinds] logpdf_turing = logpdf_with_trans(dist, x, true) @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing end From 7b1f74d53d8ed282aa27da4232723e0a6e282071 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 15:06:52 +0100 Subject: [PATCH 033/172] add `_logabsdetjac_dist` for `::LKJCholesky` --- src/Bijectors.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index afabd3d7..3a339020 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -138,6 +138,8 @@ _logabsdetjac_dist(d::MultivariateDistribution, x::AbstractMatrix) = logabsdetja _logabsdetjac_dist(d::MatrixDistribution, x::AbstractMatrix) = logabsdetjac(bijector(d), x) _logabsdetjac_dist(d::MatrixDistribution, x::AbstractVector{<:AbstractMatrix}) = logabsdetjac.((bijector(d),), x) +_logabsdetjac_dist(d::LKJCholesky, x::Cholesky) = logabsdetjac(bijector(d), x) +_logabsdetjac_dist(d::LKJCholesky, x::AbstractVector) = logabsdetjac.((bijector(d),), x) function logpdf_with_trans(d::Distribution, x, transform::Bool) if ispd(d) From 75c605b5df5b8a06db930b38bd8de81ccd4af4b6 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 15:08:37 +0100 Subject: [PATCH 034/172] replace function composition for proper barrier --- src/bijectors/corr.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index f2dec9f9..8caa436c 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -223,17 +223,17 @@ julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ struct VecCorrBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = (pd_from_upper ∘ _inv_link_chol_lkj)(y) +transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = pd_from_upper(_inv_link_chol_lkj(y)) logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) struct VecTriuBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = (Cholesky ∘ UpperTriangular ∘ _inv_link_chol_lkj)(y) +transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = Cholesky(UpperTriangular(_inv_link_chol_lkj(y))) logabsdetjac(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) struct VecTrilBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = (Cholesky ∘ LowerTriangular ∘ transpose ∘ _inv_link_chol_lkj)(y) +transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = Cholesky(LowerTriangular(_transpose_matrix(_inv_link_chol_lkj(y)))) logabsdetjac(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) @@ -317,7 +317,7 @@ function _link_chol_lkj(W::UpperTriangular) return z end -_link_chol_lkj(W::LowerTriangular) = (_link_chol_lkj ∘ transpose)(W) +_link_chol_lkj(W::LowerTriangular) = _link_chol_lkj(transpose(W)) """ _inv_link_chol_lkj(y) From a7a6c05e3549a960b196435001dbedbaf0223d59 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 15:09:25 +0100 Subject: [PATCH 035/172] add util convert `Transpose -> Matrix` for type stability --- src/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.jl b/src/utils.jl index 34842e89..3d41b5d2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,6 +10,7 @@ _vec(x::Real) = x # # Because `ReverseDiff` does not play well with structural matrices. lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) +_transpose_matrix(A::AbstractMatrix) = Matrix(transpose(A)) pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) From 09c35b6657b9053f653a170136d7f160b4047e4b Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 15:10:20 +0100 Subject: [PATCH 036/172] add `LKJCholesky` Jacobian+type tests --- test/transform.jl | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/test/transform.jl b/test/transform.jl index 2776d7fa..f86b19ae 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -38,8 +38,13 @@ function single_sample_tests(dist) # Check that invlink is inverse of link. x = rand(dist) - @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol=1e-9 + if dist isa LKJCholesky + x_inv = @inferred(invlink(dist, link(dist, copy(x)))) + @test x_inv.UL ≈ x.UL atol=1e-9 + else + @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol=1e-9 + end # Check that link is inverse of invlink. Hopefully this just holds given the above... y = @inferred(link(dist, x)) if dist isa Dirichlet @@ -169,9 +174,9 @@ let end end -@testset "correlation matrix" begin +@testset "LKJ" begin - dist = LKJ(2, 1) + dist = LKJ(3, 1) single_sample_tests(dist) @@ -187,6 +192,22 @@ end @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing end + +@testset "LKJCholesky" begin + + dist = LKJCholesky(3, 1) + + single_sample_tests(dist) + + x = rand(dist) + + upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]] + J = ForwardDiff.jacobian(x->link(dist, x), x.U) + J = J[:, upperinds] + logpdf_turing = logpdf_with_trans(dist, x, true) + @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing +end + ################################## Miscelaneous old tests ################################## # julia> logpdf_with_trans(Dirichlet([1., 1., 1.]), exp.([-1000., -1000., -1000.]), true) From 2ad5038864f445e87bb5c72ed6bae2dcc7255a76 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Fri, 14 Apr 2023 15:35:15 +0100 Subject: [PATCH 037/172] fix `logabsdetjac` for inverse link --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 8caa436c..ccd061bc 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -405,8 +405,8 @@ function _logabsdetjac_inv_chol(y::AbstractVector) for _ in 1:(j-1) z = tanh(y[idx]) logz = log(1 - z^2) - tmp += logz result += logz + (tmp / 2) + tmp += logz idx += 1 end end From f5be4e2705e7393a663001d18ad9ca2ea4c794c2 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Fri, 14 Apr 2023 16:16:41 +0100 Subject: [PATCH 038/172] use `Cholesky` constructor compatible with `v1.6` --- src/bijectors/corr.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index ccd061bc..0cffe0db 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -228,12 +228,24 @@ transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = pd_from_upper logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) struct VecTriuBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = Cholesky(UpperTriangular(_inv_link_chol_lkj(y))) + +function transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) + # This constructor is compatible with Julia v1.6 + # for later versions Cholesky(::UpperTriangular) works + U = UpperTriangular(_inv_link_chol_lkj(y)) + return Cholesky(U.data, 'U', 0) +end logabsdetjac(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) struct VecTrilBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = Cholesky(LowerTriangular(_transpose_matrix(_inv_link_chol_lkj(y)))) + +function transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) + # This constructor is compatible with Julia v1.6 + # for later versions Cholesky(::LowerTriangular) works + L = LowerTriangular(_transpose_matrix(_inv_link_chol_lkj(y))) + return Cholesky(L.data, 'L', 0) +end logabsdetjac(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) From 10d93453a55f41c3ab5498c3cd7f2c3826b84bc1 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 12:43:59 +0100 Subject: [PATCH 039/172] add empty line --- src/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.jl b/src/utils.jl index 3d41b5d2..a0c2841b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,6 +10,7 @@ _vec(x::Real) = x # # Because `ReverseDiff` does not play well with structural matrices. lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) + _transpose_matrix(A::AbstractMatrix) = Matrix(transpose(A)) pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' From bcf32a3d213da61b7362c109381f10cbbf42f4f3 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 12:44:25 +0100 Subject: [PATCH 040/172] fix `rrule` for link function --- src/chainrules.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 445e217b..6f274f3b 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -157,7 +157,6 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM end function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) - project_W = ChainRulesCore.ProjectTo(W) K = LinearAlgebra.checksquare(W) N = ((K-1)*K) ÷ 2 @@ -186,9 +185,10 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) ΔW = similar(W) @inbounds ΔW[1,1] = zero(eltype(Δz)) + @inbounds for j=2:K idx_up_to_prev_column = ((j-1)*(j-2) ÷ 2) - ΔW[j, j] = zero(eltype(Δz)) + ΔW[j, j] = 0 Δtmp = zero(eltype(Δz)) for i in (j-1):-1:2 tmp = tmp_vec[idx_up_to_prev_column + i - 1] @@ -197,14 +197,14 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) d_ftmp_p = -p / ftmp d_p_tmp = -W[i,j] / tmp^2 - Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp * d_ftmp_p + Δp = Δz[idx_up_to_prev_column + i] / (1-p^2) + Δtmp * tmp * d_ftmp_p ΔW[i, j] = Δp / tmp - Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp + Δtmp = Δp * d_p_tmp + Δtmp * ftmp end - ΔW[1, j] = Δz[1, j] / (1-W[1,j]^2) - Δtmp / sqrt(1 - W[1,j]^2) * W[1,j] + ΔW[1, j] = Δz[idx_up_to_prev_column + 1] / (1-W[1,j]^2) - Δtmp / sqrt(1 - W[1,j]^2) * W[1,j] end - return ChainRulesCore.NoTangent(), project_W(ΔW) + return ChainRulesCore.NoTangent(), ΔW end return z, pullback_link_chol_lkj From 7f4551f92bb9f50953bdbf85bbb108af2fba429b Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 12:44:41 +0100 Subject: [PATCH 041/172] add link `rrule` test --- test/ad/chainrules.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index b0e4dc2e..a2289d26 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -13,4 +13,9 @@ test_rrule(Bijectors._transform_ordered, randn(5, 2)) test_rrule(Bijectors._transform_inverse_ordered, b(rand(5))) test_rrule(Bijectors._transform_inverse_ordered, b(rand(5, 2))) + + # LKJ and LKJCholesky bijector + dist = LKJCholesky(3, 1) + x = rand(dist) + test_rrule(Bijectors._link_chol_lkj, x.U) end From dc2c85611b527c9a832d916f15c0f96a13c254c9 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 13:27:55 +0100 Subject: [PATCH 042/172] add `rrule` for inverse link --- src/chainrules.jl | 51 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/chainrules.jl b/src/chainrules.jl index 6f274f3b..79389c09 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -210,5 +210,56 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) return z, pullback_link_chol_lkj end +function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) + + K = _triu1_dim_from_length(length(y)) + + W = similar(y, K, K) + W .= zeros(eltype(y)) + + z_vec = similar(y) + tmp_vec = similar(y) + + idx = 1 + @inbounds for j in 1:K + W[1, j] = 1 + for i in 2:j + z = tanh(y[idx]) + tmp = W[i-1, j] + + z_vec[idx] = z + tmp_vec[idx] = tmp + idx += 1 + + W[i-1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) + end + end + + function pullback_inv_link_chol_lkj(ΔW_thunked) + ΔW = ChainRulesCore.unthunk(ΔW_thunked) + + Δy = zero(y) + + @inbounds for j in 1:K + idx_up_to_prev_column = ((j-1)*(j-2) ÷ 2) + Δtmp = ΔW[j,j] + for i in j:-1:2 + idx = idx_up_to_prev_column + i - 1 + tmp = tmp_vec[idx] + z = z_vec[idx] + + Δz = ΔW[i-1, j] * tmp - Δtmp * tmp / sqrt(1 - z^2) * z + Δy[idx] = Δz / cosh(y[idx])^2 + Δtmp = ΔW[i-1, j] * z + Δtmp * sqrt(1 - z^2) + end + end + + return ChainRulesCore.NoTangent(), Δy + end + + return W, pullback_inv_link_chol_lkj +end + # Fixes Zygote's issues with `@debug` ChainRulesCore.@non_differentiable _debug(::Any) From 87bc3cafbf29addf120fde0b7cad85505f47f452 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 13:28:12 +0100 Subject: [PATCH 043/172] remove TODO --- src/bijectors/corr.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 0cffe0db..f6c3b5cb 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -359,7 +359,6 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) end function _inv_link_chol_lkj(y::AbstractVector) - # TODO: Implement adjoint to support reverse-mode AD backends properly. K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) From bfb7c151c963a22bf003f8205ae012e44eab70df Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 13:28:29 +0100 Subject: [PATCH 044/172] add inverse link `rrule` test --- test/ad/chainrules.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index a2289d26..2542e48f 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -18,4 +18,8 @@ dist = LKJCholesky(3, 1) x = rand(dist) test_rrule(Bijectors._link_chol_lkj, x.U) + + b = bijector(dist) + y = b(x) + test_rrule(Bijectors._inv_link_chol_lkj, y) end From 20ab3b4a6268988bbe14691666fd61b2b4d713a6 Mon Sep 17 00:00:00 2001 From: haris organtzidis Date: Mon, 17 Apr 2023 20:54:50 +0100 Subject: [PATCH 045/172] Update src/bijectors/corr.jl Co-authored-by: Tor Erlend Fjelde --- src/bijectors/corr.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index f6c3b5cb..5951840e 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -243,7 +243,9 @@ struct VecTrilBijector <: AbstractVecCorrBijector end function transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) # This constructor is compatible with Julia v1.6 # for later versions Cholesky(::LowerTriangular) works - L = LowerTriangular(_transpose_matrix(_inv_link_chol_lkj(y))) + # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. + # If we don't, the return-type can be both `Matrix` and `Transposed`. + L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) return Cholesky(L.data, 'L', 0) end From 7bb37e08b4476a852cabcce747dbdbca4199a85a Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:19:13 +0100 Subject: [PATCH 046/172] add link `rrule` for `LowerTriangular` --- src/chainrules.jl | 53 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/chainrules.jl b/src/chainrules.jl index 79389c09..2d3a0f2d 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -210,6 +210,59 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) return z, pullback_link_chol_lkj end +function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular) + K = LinearAlgebra.checksquare(W) + N = ((K-1)*K) ÷ 2 + + z = zeros(eltype(W), N) + tmp_vec = similar(z) + + idx = 1 + @inbounds for i = 2:K + z[idx] = atanh(W[i, 1]) + tmp = sqrt(1 - W[i, 1]^2) + tmp_vec[idx] = tmp + idx += 1 + for j in 2:(i-1) + p = W[i, j] / tmp + tmp *= sqrt(1 - p^2) + tmp_vec[idx] = tmp + z[idx] = atanh(p) + idx += 1 + end + end + + function pullback_link_chol_lkj(Δz_thunked) + Δz = ChainRulesCore.unthunk(Δz_thunked) + + ΔW = similar(W) + + @inbounds ΔW[1,1] = zero(eltype(Δz)) + + @inbounds for i=2:K + idx_up_to_prev_row = ((i-1)*(i-2) ÷ 2) + ΔW[i, i] = 0 + Δtmp = zero(eltype(Δz)) + for j in (i-1):-1:2 + tmp = tmp_vec[idx_up_to_prev_row + j - 1] + p = W[i, j] / tmp + ftmp = sqrt(1 - p^2) + d_ftmp_p = -p / ftmp + d_p_tmp = -W[i,j] / tmp^2 + + Δp = Δz[idx_up_to_prev_row + j] / (1-p^2) + Δtmp * tmp * d_ftmp_p + ΔW[i, j] = Δp / tmp + Δtmp = Δp * d_p_tmp + Δtmp * ftmp + end + ΔW[i, 1] = Δz[idx_up_to_prev_row + 1] / (1-W[i,1]^2) - Δtmp / sqrt(1 - W[i,1]^2) * W[i,1] + end + + return ChainRulesCore.NoTangent(), ΔW + end + + return z, pullback_link_chol_lkj +end + function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) K = _triu1_dim_from_length(length(y)) From 3e2c7a83eee447c9b032ed56d1808e8660b8ac17 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:19:52 +0100 Subject: [PATCH 047/172] add `LowerTriangular` chainrule test --- test/ad/chainrules.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index 2542e48f..59581fda 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -18,6 +18,7 @@ dist = LKJCholesky(3, 1) x = rand(dist) test_rrule(Bijectors._link_chol_lkj, x.U) + test_rrule(Bijectors._link_chol_lkj, x.L) b = bijector(dist) y = b(x) From adba9e8ada1da59a586021fad3eadac6ef215c1a Mon Sep 17 00:00:00 2001 From: haris organtzidis Date: Tue, 18 Apr 2023 10:38:12 +0100 Subject: [PATCH 048/172] Update src/bijectors/corr.jl Co-authored-by: Tor Erlend Fjelde --- src/bijectors/corr.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5951840e..01e89b1f 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -182,6 +182,7 @@ end abstract type AbstractVecCorrBijector <: Bijector end +TODO: Implement directly to make use of shared computations. with_logabsdet_jacobian(b::AbstractVecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) transform(::AbstractVecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) From ec18964957f14ef785fb8dfa759ec72a0dc87894 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:45:11 +0100 Subject: [PATCH 049/172] remove unused util --- src/utils.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index a0c2841b..34842e89 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -11,8 +11,6 @@ _vec(x::Real) = x lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) -_transpose_matrix(A::AbstractMatrix) = Matrix(transpose(A)) - pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) From 37c38abc716d9c11bc90ca223a1502ce3491f802 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:45:42 +0100 Subject: [PATCH 050/172] use `similar` instead of `zeros` --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 01e89b1f..eed36742 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -313,7 +313,7 @@ function _link_chol_lkj(W::UpperTriangular) K = LinearAlgebra.checksquare(W) N = ((K-1)*K) ÷ 2 # {K \choose 2} free parameters - z = zeros(eltype(W), N) + z = similar(W, N) # This block can't be integrated with loop below, because w[1,1] != 0. idx = 1 From 8fd13b098a93ad907c91797bfb0afe9f83199783 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:45:58 +0100 Subject: [PATCH 051/172] update comments --- src/bijectors/corr.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index eed36742..d3b8a07d 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -182,7 +182,7 @@ end abstract type AbstractVecCorrBijector <: Bijector end -TODO: Implement directly to make use of shared computations. +# TODO: Implement directly to make use of shared computations. with_logabsdet_jacobian(b::AbstractVecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) transform(::AbstractVecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) @@ -231,9 +231,9 @@ logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdet struct VecTriuBijector <: AbstractVecCorrBijector end function transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) - # This constructor is compatible with Julia v1.6 - # for later versions Cholesky(::UpperTriangular) works U = UpperTriangular(_inv_link_chol_lkj(y)) + # This Cholesky constructor is compatible with Julia v1.6 + # for later versions Cholesky(::UpperTriangular) works return Cholesky(U.data, 'U', 0) end @@ -242,11 +242,11 @@ logabsdetjac(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = _logabsdet struct VecTrilBijector <: AbstractVecCorrBijector end function transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) - # This constructor is compatible with Julia v1.6 - # for later versions Cholesky(::LowerTriangular) works # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) + # This Cholesky constructor is compatible with Julia v1.6 + # for later versions Cholesky(::LowerTriangular) works return Cholesky(L.data, 'L', 0) end From 56cc43f9e25f33bb3315319035d25ee696957a4d Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:46:55 +0100 Subject: [PATCH 052/172] remove old comment --- src/bijectors/corr.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index d3b8a07d..30fc5032 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -315,7 +315,6 @@ function _link_chol_lkj(W::UpperTriangular) z = similar(W, N) - # This block can't be integrated with loop below, because w[1,1] != 0. idx = 1 @inbounds for j = 2:K z[idx] = atanh(W[1, j]) From 8ee086aa4a4f37c80147afe5330751448de9fdbf Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 11:18:23 +0100 Subject: [PATCH 053/172] minimize zero-setting operations in inverse link --- src/bijectors/corr.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 30fc5032..41d7b533 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -364,7 +364,6 @@ function _inv_link_chol_lkj(y::AbstractVector) K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) - W .= zeros(eltype(y)) idx = 1 @inbounds for j in 1:K @@ -376,6 +375,9 @@ function _inv_link_chol_lkj(y::AbstractVector) W[i-1, j] = z * tmp W[i, j] = tmp * sqrt(1 - z^2) end + for i in (j+1):K + W[i, j] = 0 + end end return W From 837b49c1e8aa57b654c301115c3c1bfacd78e506 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 11:20:09 +0100 Subject: [PATCH 054/172] minimize zero-setting operations in `rrule` --- src/chainrules.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 2d3a0f2d..75e8676b 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -268,7 +268,6 @@ function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) - W .= zeros(eltype(y)) z_vec = similar(y) tmp_vec = similar(y) @@ -287,6 +286,9 @@ function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) W[i-1, j] = z * tmp W[i, j] = tmp * sqrt(1 - z^2) end + for i in (j+1):K + W[i, j] = 0 + end end function pullback_inv_link_chol_lkj(ΔW_thunked) From 0c3aa399c7be95141236a65a323bc6e08828c8c2 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 12:38:51 +0100 Subject: [PATCH 055/172] add parametric `Val` type to `VecCorrBijector` --- src/bijectors/corr.jl | 43 ++++++++++++++++----------------- src/transformed_distribution.jl | 4 +-- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 41d7b533..99c1fe25 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -180,17 +180,6 @@ function vec_to_triu1_row_index(idx) return idx - (M*(M-1) ÷ 2) end -abstract type AbstractVecCorrBijector <: Bijector end - -# TODO: Implement directly to make use of shared computations. -with_logabsdet_jacobian(b::AbstractVecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) - -transform(::AbstractVecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) - -function logabsdetjac(b::AbstractVecCorrBijector, x) - return -logabsdetjac(inverse(b), b(x)) -end - """ VecCorrBijector <: Bijector @@ -223,25 +212,33 @@ julia> y = b(X) # Transform to unconstrained vector representation. julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ -struct VecCorrBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = pd_from_upper(_inv_link_chol_lkj(y)) +struct VecCorrBijector{T} <: Bijector + uplo::Symbol + function VecCorrBijector(uplo) + s = Symbol(uplo) + new{Val{s}}(s) + end +end + +# TODO: Implement directly to make use of shared computations. +with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) + +transform(::VecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) -logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) +function logabsdetjac(b::VecCorrBijector, x) + return -logabsdetjac(inverse(b), b(x)) +end -struct VecTriuBijector <: AbstractVecCorrBijector end +transform(::Inverse{VecCorrBijector{Val{:C}}}, y::AbstractVector{<:Real}) = pd_from_upper(_inv_link_chol_lkj(y)) -function transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) +function transform(::Inverse{VecCorrBijector{Val{:U}}}, y::AbstractVector{<:Real}) U = UpperTriangular(_inv_link_chol_lkj(y)) # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works return Cholesky(U.data, 'U', 0) end -logabsdetjac(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) - -struct VecTrilBijector <: AbstractVecCorrBijector end - -function transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) +function transform(::Inverse{VecCorrBijector{Val{:L}}}, y::AbstractVector{<:Real}) # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) @@ -250,7 +247,9 @@ function transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) return Cholesky(L.data, 'L', 0) end -logabsdetjac(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) +logabsdetjac(::Inverse{VecCorrBijector{Val{:C}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) +logabsdetjac(::Inverse{VecCorrBijector{Val{:U}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) +logabsdetjac(::Inverse{VecCorrBijector{Val{:L}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) """ diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index b1cf82a3..ced16de3 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -78,8 +78,8 @@ bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d) bijector(d::PDMatDistribution) = PDBijector() bijector(d::MatrixBeta) = PDBijector() -bijector(d::LKJ) = VecCorrBijector() -bijector(d::LKJCholesky) = d.uplo === 'L' ? VecTrilBijector() : VecTriuBijector() +bijector(d::LKJ) = VecCorrBijector('C') +bijector(d::LKJCholesky) = VecCorrBijector(d.uplo) ############################## # Distributions.jl interface # From c1be27294c85da43a07659444b7a278d1cdeedea Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 12:39:05 +0100 Subject: [PATCH 056/172] update `VecCorrBijector` tests --- test/bijectors/corr.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 900b5e68..dd528eb5 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -1,5 +1,5 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test -using Bijectors: VecCorrBijector, CorrBijector, VecTriuBijector, VecTrilBijector +using Bijectors: VecCorrBijector, CorrBijector @testset "CorrBijector & VecCorrBijector" begin for d ∈ [1, 2, 5] @@ -34,12 +34,12 @@ using Bijectors: VecCorrBijector, CorrBijector, VecTriuBijector, VecTrilBijector end end -@testset "VecTriuBijector & VecTrilBijector" begin +@testset "VecCorrBijector on LKJCholesky" begin for d ∈ [2, 5] for dist in [LKJCholesky(d, 1, 'U'), LKJCholesky(d, 1, 'L')] b = bijector(dist) - b_lkj = VecCorrBijector() + b_lkj = VecCorrBijector('C') x = rand(dist) y = b(x) y_lkj = b_lkj(x) From 29fced6653940076a7dc2659d9246db36dc494c4 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 14:40:03 +0100 Subject: [PATCH 057/172] use field value instead of `Val`-parametric type --- src/bijectors/corr.jl | 50 +++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 99c1fe25..2169feaf 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -212,12 +212,9 @@ julia> y = b(X) # Transform to unconstrained vector representation. julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ -struct VecCorrBijector{T} <: Bijector - uplo::Symbol - function VecCorrBijector(uplo) - s = Symbol(uplo) - new{Val{s}}(s) - end +struct VecCorrBijector <: Bijector + mode::Symbol + VecCorrBijector(uplo) = new(Symbol(uplo)) end # TODO: Implement directly to make use of shared computations. @@ -229,29 +226,32 @@ function logabsdetjac(b::VecCorrBijector, x) return -logabsdetjac(inverse(b), b(x)) end -transform(::Inverse{VecCorrBijector{Val{:C}}}, y::AbstractVector{<:Real}) = pd_from_upper(_inv_link_chol_lkj(y)) - -function transform(::Inverse{VecCorrBijector{Val{:U}}}, y::AbstractVector{<:Real}) - U = UpperTriangular(_inv_link_chol_lkj(y)) - # This Cholesky constructor is compatible with Julia v1.6 - # for later versions Cholesky(::UpperTriangular) works - return Cholesky(U.data, 'U', 0) +function transform(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + if b.orig.mode === :U + U = UpperTriangular(_inv_link_chol_lkj(y)) + # This Cholesky constructor is compatible with Julia v1.6 + # for later versions Cholesky(::UpperTriangular) works + return Cholesky(U.data, 'U', 0) + elseif b.orig.mode === :L + # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. + # If we don't, the return-type can be both `Matrix` and `Transposed`. + L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) + # This Cholesky constructor is compatible with Julia v1.6 + # for later versions Cholesky(::LowerTriangular) works + return Cholesky(L.data, 'L', 0) + else + return pd_from_upper(_inv_link_chol_lkj(y)) + end end -function transform(::Inverse{VecCorrBijector{Val{:L}}}, y::AbstractVector{<:Real}) - # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. - # If we don't, the return-type can be both `Matrix` and `Transposed`. - L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) - # This Cholesky constructor is compatible with Julia v1.6 - # for later versions Cholesky(::LowerTriangular) works - return Cholesky(L.data, 'L', 0) +function logabsdetjac(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + if (b.orig.mode === :U) || (b.orig.mode === :L) + return _logabsdetjac_inv_chol(y) + else + return _logabsdetjac_inv_corr(y) + end end -logabsdetjac(::Inverse{VecCorrBijector{Val{:C}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) -logabsdetjac(::Inverse{VecCorrBijector{Val{:U}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) -logabsdetjac(::Inverse{VecCorrBijector{Val{:L}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) - - """ function _link_chol_lkj(w) From 74d6edbeb634e2d4bdf24668690959dc82f3796e Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 14:47:34 +0100 Subject: [PATCH 058/172] update tests with new `VecCorrBijector` --- test/bijectors/utils.jl | 14 ++++++++++++-- test/transform.jl | 7 ++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index dc1d3a55..17a7bc79 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -25,9 +25,19 @@ function test_bijector( y_test = @inferred b(x) ilogjac_test = !isnothing(y) ? @inferred(logabsdetjac(ib, y)) : @inferred(logabsdetjac(ib, y_test)) ires = if !isnothing(y) - @inferred(with_logabsdet_jacobian(inverse(b), y)) + if b isa VecCorrBijector + # Inverse{VecCorrBijector} returns a ::Cholesky{...} in the case of a LKJCholesky distribution + # and a ::Matrix{Float64} in the case of a LKJ distribution. + @inferred Tuple{Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}}, Float64} with_logabsdet_jacobian(inverse(b), y) + else + @inferred(with_logabsdet_jacobian(inverse(b), y)) + end else - @inferred(with_logabsdet_jacobian(inverse(b), y_test)) + if b isa VecCorrBijector + @inferred Tuple{Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}}, Float64} with_logabsdet_jacobian(inverse(b), y_test) + else + @inferred(with_logabsdet_jacobian(inverse(b), y_test)) + end end # ChangesOfVariables.jl diff --git a/test/transform.jl b/test/transform.jl index f86b19ae..11e04d76 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -39,9 +39,14 @@ function single_sample_tests(dist) # Check that invlink is inverse of link. x = rand(dist) + # LKJCholesky and LKJ use the same VecCorrBijector. + # The return type of Inverse{VecCorrBijector} depends on the distribution. if dist isa LKJCholesky - x_inv = @inferred(invlink(dist, link(dist, copy(x)))) + x_inv = @inferred Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}} invlink(dist, link(dist, copy(x))) @test x_inv.UL ≈ x.UL atol=1e-9 + elseif dist isa LKJ + x_inv = @inferred Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}} invlink(dist, link(dist, copy(x))) + @test x_inv ≈ x atol=1e-9 else @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol=1e-9 end From 4c27987ec698281d031a48aa9c5a934d6b9ed167 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 14:57:50 +0100 Subject: [PATCH 059/172] `using VecCorrBijector` in test utils --- test/bijectors/utils.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index 17a7bc79..45e863c6 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -1,3 +1,5 @@ +using Bijectors: VecCorrBijector + # Allows us to run `ChangesOfVariables.test_with_logabsdet_jacobian` include(joinpath(dirname(pathof(ChangesOfVariables)), "..", "test", "getjacobian.jl")) From 9108c40e2bc1ffa8c3bb4606481076a89536cbda Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 16:27:04 +0100 Subject: [PATCH 060/172] add `VecCorrBijector.mode` check --- src/bijectors/corr.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 2169feaf..69d5da69 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -214,7 +214,14 @@ true """ struct VecCorrBijector <: Bijector mode::Symbol - VecCorrBijector(uplo) = new(Symbol(uplo)) + function VecCorrBijector(uplo_or_corr) + s = Symbol(uplo_or_corr) + if (s === :U) || (s === :L) || (s === :C) + new(s) + else + throw(ArgumentError("mode must be :U (upper), :L (lower) or :C (correlation matrix)")) + end + end end # TODO: Implement directly to make use of shared computations. From 24847cc63dc7f45470558aa349805a02972703d6 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 16:27:30 +0100 Subject: [PATCH 061/172] update `VecCorrBijector` docstring --- src/bijectors/corr.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 69d5da69..5f6b91fd 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -183,8 +183,20 @@ end """ VecCorrBijector <: Bijector -Similar to `CorrBijector`, but correlation matrix to a vector, -and its inverse transforms vector to a correlation matrix. +A bijector to transform either a correlation matrix or a Cholesky factor of a correlation matrix +to an unconstrained vector. + +# Fields +- mode :`Symbol`. Controls the inverse tranformation : + - if `mode === :C` returns a correlation matrix + - if `mode === :U` returns a `LinearAlgebra.Cholesky` holding the `UpperTriangular` factor + - if `mode === :L` returns a `LinearAlgebra.Cholesky` holding the `LowerTriangular` factor + +# Reference +- Transforming a orrelation matrix : +https://mc-stan.org/docs/reference-manual/correlation-matrix-transform.html#absolute-jacobian-determinant-of-the-correlation-matrix-inverse-transform +- Transforming a Cholesky factor of a correlation matrix : +https://mc-stan.org/docs/reference-manual/cholesky-factors-of-correlation-matrices-1 See also: [`CorrBijector`](@ref) From bd4de96f7ce22fd6cce16c27cdd405dce650c4b1 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:20:53 +0100 Subject: [PATCH 062/172] specialise `Zygote@adjoint` for `AbstractMatrix` --- src/compat/zygote.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index 6a81a749..864286d7 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -174,7 +174,7 @@ end end end -@adjoint function _inv_link_chol_lkj(y) +@adjoint function _inv_link_chol_lkj(y::AbstractMatrix) K = LinearAlgebra.checksquare(y) w = similar(y) @@ -219,7 +219,7 @@ end return w, pullback_inv_link_chol_lkj end -@adjoint function _link_chol_lkj(w) +@adjoint function _link_chol_lkj(w::AbstractMatrix) K = LinearAlgebra.checksquare(w) z = similar(w) From 65bfc42420454dce07c252e91586355422d11e71 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:21:38 +0100 Subject: [PATCH 063/172] `ReverseDiff` opt-in to `ChainRules` --- src/compat/reversediff.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 78871ce1..52301333 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -200,6 +200,9 @@ end @grad_from_chainrules update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int) +@grad_from_chainrules _link_chol_lkj(x::TrackedMatrix) +@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector) + # NOTE: Probably doesn't work in complete generality. wrap_chainrules_output(x) = x wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing From eca34119c362a2e5963b6abfcf23a78857c35310 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:35:15 +0100 Subject: [PATCH 064/172] empty lines format --- src/chainrules.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 75e8676b..857389de 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -157,10 +157,9 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM end function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) - K = LinearAlgebra.checksquare(W) N = ((K-1)*K) ÷ 2 - + z = zeros(eltype(W), N) tmp_vec = similar(z) @@ -264,7 +263,6 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular) end function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) - K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) From f02fd9b04afa244ac8d0d81703a23db7180b1ecc Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:35:53 +0100 Subject: [PATCH 065/172] add AD test for inverse link --- test/bijectors/corr.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index dd528eb5..ac889424 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -4,7 +4,7 @@ using Bijectors: VecCorrBijector, CorrBijector @testset "CorrBijector & VecCorrBijector" begin for d ∈ [1, 2, 5] b = CorrBijector() - bvec = VecCorrBijector() + bvec = VecCorrBijector('C') dist = LKJ(d, 1) x = rand(dist) @@ -31,6 +31,8 @@ using Bijectors: VecCorrBijector, CorrBijector # Hence, we disable those tests. test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) + + test_ad(x -> sum(transform(inverse(b), x)), y, (:Tracker,)) end end From c90f7ac80ffbdb222b114a59060b3966213cf6a1 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:36:40 +0100 Subject: [PATCH 066/172] include `VecCorrBijector` tests --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 0a00ea5f..985da6cb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,6 +36,7 @@ if GROUP == "All" || GROUP == "Interface" include("bijectors/coupling.jl") include("bijectors/ordered.jl") include("bijectors/pd.jl") + include("bijectors/corr.jl") end if GROUP == "All" || GROUP == "AD" From 974efb5228d30aef1b69216373382f6fc19a329a Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:48:01 +0100 Subject: [PATCH 067/172] remove broken flag for `Tracker` --- test/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index ac889424..eedefc3e 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -32,7 +32,7 @@ using Bijectors: VecCorrBijector, CorrBijector test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) - test_ad(x -> sum(transform(inverse(b), x)), y, (:Tracker,)) + test_ad(x -> sum(transform(inverse(b), x)), y) end end From 71fdae6007b1f9d9b21ef4b12104821669d91b04 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 19:36:01 +0100 Subject: [PATCH 068/172] add roundtrip AD tests for `VecCorrBijector` --- test/bijectors/corr.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index eedefc3e..f1f6ab1f 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -32,7 +32,7 @@ using Bijectors: VecCorrBijector, CorrBijector test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) - test_ad(x -> sum(transform(inverse(b), x)), y) + test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Tracker,)) end end @@ -55,6 +55,8 @@ end @test xinv.U ≈ cholesky(xinv_lkj).U + test_ad(x -> sum(b(binv(x))), y, (:Tracker,)) + # test_bijector is commented out for now, # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) # test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) From 6524fe40e6df961ffd6f403d6d78938fa44db43e Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 19 Apr 2023 17:53:42 +0100 Subject: [PATCH 069/172] remove wrong `ReverseDiff.@grad` for `pd_from_upper` --- src/compat/reversediff.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 52301333..0487751f 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -146,14 +146,7 @@ pd_from_lower(X::TrackedMatrix) = track(pd_from_lower, X) end end -pd_from_upper(X::TrackedMatrix) = track(pd_from_upper, X) -@grad function pd_from_upper(X::AbstractMatrix) - Xd = value(X) - return UpperTriangular(Xd)' * UpperTriangular(Xd), Δ -> begin - Xu = UpperTriangular(Xd) - return (UpperTriangular(Δ * Xu + Δ' * Xu),) - end -end +@grad_from_chainrules pd_from_upper(X::TrackedMatrix) lower_triangular(A::TrackedMatrix) = track(lower_triangular, A) @grad function lower_triangular(A::AbstractMatrix) From 5e4abaec4eb7a2eb20038f3b789f41cf8ec2a19c Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 19 Apr 2023 17:54:23 +0100 Subject: [PATCH 070/172] add corrected `rrule` for `pd_from_upper` --- src/chainrules.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 857389de..c5041f36 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -264,7 +264,7 @@ end function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) K = _triu1_dim_from_length(length(y)) - + W = similar(y, K, K) z_vec = similar(y) @@ -314,5 +314,12 @@ function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) return W, pullback_inv_link_chol_lkj end +function ChainRulesCore.rrule(::typeof(pd_from_upper), X::AbstractMatrix) + return UpperTriangular(X)' * UpperTriangular(X), Δ -> begin + Xu = UpperTriangular(X) + return ChainRulesCore.NoTangent(), UpperTriangular(Xu * Δ + Xu * Δ') + end +end + # Fixes Zygote's issues with `@debug` ChainRulesCore.@non_differentiable _debug(::Any) From c547542e8de68e6d2ba4d49788dad036b38c3a3c Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 19 Apr 2023 17:56:10 +0100 Subject: [PATCH 071/172] update AD tests --- test/bijectors/corr.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index f1f6ab1f..5f0017de 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -31,8 +31,8 @@ using Bijectors: VecCorrBijector, CorrBijector # Hence, we disable those tests. test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) - - test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Tracker,)) + + test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Tracker, :Zygote,)) end end @@ -55,7 +55,7 @@ end @test xinv.U ≈ cholesky(xinv_lkj).U - test_ad(x -> sum(b(binv(x))), y, (:Tracker,)) + test_ad(x -> sum(b(binv(x))), y, (:Tracker, :Zygote,)) # test_bijector is commented out for now, # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) From 0d599e858131be8e9ae7af289b16f94112f502e7 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 19 Apr 2023 18:30:38 +0100 Subject: [PATCH 072/172] remove `Tracker` from broken --- test/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 5f0017de..a377b2a8 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -32,7 +32,7 @@ using Bijectors: VecCorrBijector, CorrBijector test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) - test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Tracker, :Zygote,)) + test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Zygote,)) end end From a1f16b60c3dcd1bc8d83676a2c4a3d07e250ae12 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 25 Apr 2023 16:19:01 +0100 Subject: [PATCH 073/172] update zero-filling in `Tracker` pullback --- src/compat/tracker.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index dae58086..dee73f8b 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -302,7 +302,6 @@ Bijectors._inv_link_chol_lkj(y::TrackedVector) = track(Bijectors._inv_link_chol_ K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) - W .= zeros(eltype(y)) z_vec = similar(y) tmp_vec = similar(y) @@ -321,6 +320,9 @@ Bijectors._inv_link_chol_lkj(y::TrackedVector) = track(Bijectors._inv_link_chol_ W[i-1, j] = z * tmp W[i, j] = tmp * sqrt(1 - z^2) end + for i in (j+1):K + W[i, j] = 0 + end end function pullback_inv_link_chol_lkj(ΔW) From 8b4b0c79c4957fdf0b0b982b6bb091af8a4d0fe2 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 25 Apr 2023 16:19:26 +0100 Subject: [PATCH 074/172] fix `Zygote` --- src/bijectors/corr.jl | 8 ++++---- src/utils.jl | 2 +- test/bijectors/corr.jl | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5f6b91fd..f0c7d638 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -247,17 +247,17 @@ end function transform(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) if b.orig.mode === :U - U = UpperTriangular(_inv_link_chol_lkj(y)) + U = _inv_link_chol_lkj(y) # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works - return Cholesky(U.data, 'U', 0) + return Cholesky(U, 'U', 0) elseif b.orig.mode === :L # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. - L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) + L = Matrix(transpose(_inv_link_chol_lkj(y))) # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::LowerTriangular) works - return Cholesky(L.data, 'L', 0) + return Cholesky(L, 'L', 0) else return pd_from_upper(_inv_link_chol_lkj(y)) end diff --git a/src/utils.jl b/src/utils.jl index 34842e89..4880ac4c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -15,6 +15,6 @@ pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(X)) -cholesky_factor(X::Cholesky) = X.UL +cholesky_factor(X::Cholesky) = X.U cholesky_factor(X::UpperTriangular) = X cholesky_factor(X::LowerTriangular) = X diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index a377b2a8..947678d8 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -32,7 +32,7 @@ using Bijectors: VecCorrBijector, CorrBijector test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) - test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Zygote,)) + test_ad(x -> sum(bvec(bvecinv(x))), yvec) end end @@ -55,7 +55,7 @@ end @test xinv.U ≈ cholesky(xinv_lkj).U - test_ad(x -> sum(b(binv(x))), y, (:Tracker, :Zygote,)) + test_ad(x -> sum(b(binv(x))), y, (:Tracker,)) # test_bijector is commented out for now, # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) From 890127fb06d04b7b5eb20cd9580668553b7164a1 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 4 May 2023 13:46:46 +0100 Subject: [PATCH 075/172] merge lines - applying feedback suggestions --- src/bijectors/corr.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index f0c7d638..5c60998f 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -247,17 +247,13 @@ end function transform(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) if b.orig.mode === :U - U = _inv_link_chol_lkj(y) # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works - return Cholesky(U, 'U', 0) + return Cholesky(_inv_link_chol_lkj(y), 'U', 0) elseif b.orig.mode === :L # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. - L = Matrix(transpose(_inv_link_chol_lkj(y))) - # This Cholesky constructor is compatible with Julia v1.6 - # for later versions Cholesky(::LowerTriangular) works - return Cholesky(L, 'L', 0) + return Cholesky(Matrix(transpose(_inv_link_chol_lkj(y))), 'L', 0) else return pd_from_upper(_inv_link_chol_lkj(y)) end From fa13e270f803e4722b2937731adec0c9755fba48 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 17:58:40 +0300 Subject: [PATCH 076/172] `unthunk` in `pd_from_upper` rrule --- src/chainrules.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index c5041f36..46a10db8 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -315,7 +315,8 @@ function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) end function ChainRulesCore.rrule(::typeof(pd_from_upper), X::AbstractMatrix) - return UpperTriangular(X)' * UpperTriangular(X), Δ -> begin + return UpperTriangular(X)' * UpperTriangular(X), Δ_thunked -> begin + Δ = ChainRulesCore.unthunk(Δ_thunked) Xu = UpperTriangular(X) return ChainRulesCore.NoTangent(), UpperTriangular(Xu * Δ + Xu * Δ') end From a36f2b6fd70d528be72286ca138a558f5739b8f4 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 19:05:17 +0300 Subject: [PATCH 077/172] split structs into `VecCorrBijector` and `VecCholeskyBijector` --- src/bijectors/corr.jl | 104 ++++++++++++++++++++++---------- src/transformed_distribution.jl | 4 +- 2 files changed, 74 insertions(+), 34 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5c60998f..367fe671 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -183,22 +183,12 @@ end """ VecCorrBijector <: Bijector -A bijector to transform either a correlation matrix or a Cholesky factor of a correlation matrix -to an unconstrained vector. - -# Fields -- mode :`Symbol`. Controls the inverse tranformation : - - if `mode === :C` returns a correlation matrix - - if `mode === :U` returns a `LinearAlgebra.Cholesky` holding the `UpperTriangular` factor - - if `mode === :L` returns a `LinearAlgebra.Cholesky` holding the `LowerTriangular` factor +A bijector to transform a correlation matrix to an unconstrained vector. # Reference -- Transforming a orrelation matrix : -https://mc-stan.org/docs/reference-manual/correlation-matrix-transform.html#absolute-jacobian-determinant-of-the-correlation-matrix-inverse-transform -- Transforming a Cholesky factor of a correlation matrix : -https://mc-stan.org/docs/reference-manual/cholesky-factors-of-correlation-matrices-1 +https://mc-stan.org/docs/reference-manual/correlation-matrix-transform.html -See also: [`CorrBijector`](@ref) +See also: [`CorrBijector`](@ref) and ['VecCholeskyBijector'](@ref) # Example @@ -224,48 +214,98 @@ julia> y = b(X) # Transform to unconstrained vector representation. julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ -struct VecCorrBijector <: Bijector +struct VecCorrBijector <: Bijector end + +with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) + +transform(::VecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) + +function logabsdetjac(b::VecCorrBijector, x) + return -logabsdetjac(inverse(b), b(x)) +end + +transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = pd_from_upper(_inv_link_chol_lkj(y)) + +logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) + +""" + VecCholeskyBijector <: Bijector + +A bijector to transform a Cholesky factor of a correlation matrix to an unconstrained vector. + +# Fields +- mode :`Symbol`. Controls the inverse tranformation : + - if `mode === :U` returns a `LinearAlgebra.Cholesky` holding the `UpperTriangular` factor + - if `mode === :L` returns a `LinearAlgebra.Cholesky` holding the `LowerTriangular` factor + +# Reference +https://mc-stan.org/docs/reference-manual/cholesky-factors-of-correlation-matrices-1 + +See also: [`VecCorrBijector`](@ref) + +# Example + +```jldoctest +julia> using LinearAlgebra + +julia> using StableRNGs; rng = StableRNG(42); + +julia> b = Bijectors.VecCholeskyBijector(:U); + +julia> X = rand(rng, LKJCholesky(3, 1, :U)) # Sample a correlation matrix. +Cholesky{Float64, Matrix{Float64}} +U factor: +3×3 UpperTriangular{Float64, Matrix{Float64}}: + 1.0 0.937494 0.865891 + ⋅ 0.348002 -0.320442 + ⋅ ⋅ 0.384122 + +julia> y = b(X) # Transform to unconstrained vector representation. +3-element Vector{Float64}: + -0.8777149781928181 + -0.3638927608636788 + -0.29813769428942216 + +julia> X_inv = inverse(b)(y); +julia> X_inv.U ≈ X.U # (✓) Round-trip through `b` and its inverse. +true +julia> X_inv.L ≈ X.L # (✓) Also works for the lower triangular factor. +true +""" +struct VecCholeskyBijector <: Bijector mode::Symbol - function VecCorrBijector(uplo_or_corr) - s = Symbol(uplo_or_corr) - if (s === :U) || (s === :L) || (s === :C) + function VecCholeskyBijector(uplo) + s = Symbol(uplo) + if (s === :U) || (s === :L) new(s) else - throw(ArgumentError("mode must be :U (upper), :L (lower) or :C (correlation matrix)")) + throw(ArgumentError("mode must be either :U (upper triangular) or :L (lower triangular)")) end end end # TODO: Implement directly to make use of shared computations. -with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) +with_logabsdet_jacobian(b::VecCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x) -transform(::VecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) +transform(::VecCholeskyBijector, X) = _link_chol_lkj(cholesky_factor(X)) -function logabsdetjac(b::VecCorrBijector, x) +function logabsdetjac(b::VecCholeskyBijector, x) return -logabsdetjac(inverse(b), b(x)) end -function transform(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) +function transform(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) if b.orig.mode === :U # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works return Cholesky(_inv_link_chol_lkj(y), 'U', 0) - elseif b.orig.mode === :L + else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor. # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. return Cholesky(Matrix(transpose(_inv_link_chol_lkj(y))), 'L', 0) - else - return pd_from_upper(_inv_link_chol_lkj(y)) end end -function logabsdetjac(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) - if (b.orig.mode === :U) || (b.orig.mode === :L) - return _logabsdetjac_inv_chol(y) - else - return _logabsdetjac_inv_corr(y) - end -end +logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) """ function _link_chol_lkj(w) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index ced16de3..cfa30445 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -78,8 +78,8 @@ bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d) bijector(d::PDMatDistribution) = PDBijector() bijector(d::MatrixBeta) = PDBijector() -bijector(d::LKJ) = VecCorrBijector('C') -bijector(d::LKJCholesky) = VecCorrBijector(d.uplo) +bijector(d::LKJ) = VecCorrBijector() +bijector(d::LKJCholesky) = VecCholeskyBijector(d.uplo) ############################## # Distributions.jl interface # From 9690dd2157f3358ba60a30226a34658fbb2a407c Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 19:05:33 +0300 Subject: [PATCH 078/172] remove old `Zygote` adjoints --- src/compat/zygote.jl | 97 -------------------------------------------- 1 file changed, 97 deletions(-) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index 864286d7..29497140 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -173,100 +173,3 @@ end return replace_diag(log, Y) end end - -@adjoint function _inv_link_chol_lkj(y::AbstractMatrix) - K = LinearAlgebra.checksquare(y) - - w = similar(y) - - z_mat = similar(y) # cache for adjoint - tmp_mat = similar(y) - - @inbounds for j in 1:K - w[1, j] = 1 - for i in 2:j - z = tanh(y[i-1, j]) - tmp = w[i-1, j] - - z_mat[i, j] = z - tmp_mat[i, j] = tmp - - w[i-1, j] = z * tmp - w[i, j] = tmp * sqrt(1 - z^2) - end - for i in (j+1):K - w[i, j] = 0 - end - end - - function pullback_inv_link_chol_lkj(Δw) - LinearAlgebra.checksquare(Δw) - - Δy = zero(y) - - @inbounds for j in 1:K - Δtmp = Δw[j,j] - for i in j:-1:2 - Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j] - Δy[i-1, j] = Δz / cosh(y[i-1, j])^2 - Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2) - end - end - - return (Δy,) - end - - return w, pullback_inv_link_chol_lkj -end - -@adjoint function _link_chol_lkj(w::AbstractMatrix) - K = LinearAlgebra.checksquare(w) - - z = similar(w) - - @inbounds z[1, 1] = 0 - - tmp_mat = similar(w) # cache for pullback. - - @inbounds for j=2:K - z[1, j] = atanh(w[1, j]) - tmp = sqrt(1 - w[1, j]^2) - tmp_mat[1, j] = tmp - for i in 2:(j - 1) - p = w[i, j] / tmp - tmp *= sqrt(1 - p^2) - tmp_mat[i, j] = tmp - z[i, j] = atanh(p) - end - z[j, j] = 0 - end - - function pullback_link_chol_lkj(Δz) - LinearAlgebra.checksquare(Δz) - - Δw = similar(w) - - @inbounds Δw[1,1] = zero(eltype(Δz)) - - @inbounds for j=2:K - Δw[j, j] = 0 - Δtmp = zero(eltype(Δz)) # Δtmp_mat[j-1,j] - for i in (j-1):-1:2 - p = w[i, j] / tmp_mat[i-1, j] - ftmp = sqrt(1 - p^2) - d_ftmp_p = -p / ftmp - d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2 - - Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p - Δw[i, j] = Δp / tmp_mat[i-1, j] - Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp - end - Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j] - end - - return (Δw,) - end - - return z, pullback_link_chol_lkj - -end From 8a677139d4dd91b78f8377a6f0e6510dafffbf80 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 19:05:59 +0300 Subject: [PATCH 079/172] update tests --- test/bijectors/corr.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 947678d8..e78c2da2 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -1,10 +1,10 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test -using Bijectors: VecCorrBijector, CorrBijector +using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector @testset "CorrBijector & VecCorrBijector" begin for d ∈ [1, 2, 5] b = CorrBijector() - bvec = VecCorrBijector('C') + bvec = VecCorrBijector() dist = LKJ(d, 1) x = rand(dist) @@ -36,12 +36,12 @@ using Bijectors: VecCorrBijector, CorrBijector end end -@testset "VecCorrBijector on LKJCholesky" begin +@testset "VecCholeskyBijector" begin for d ∈ [2, 5] for dist in [LKJCholesky(d, 1, 'U'), LKJCholesky(d, 1, 'L')] b = bijector(dist) - b_lkj = VecCorrBijector('C') + b_lkj = VecCorrBijector() x = rand(dist) y = b(x) y_lkj = b_lkj(x) @@ -55,7 +55,7 @@ end @test xinv.U ≈ cholesky(xinv_lkj).U - test_ad(x -> sum(b(binv(x))), y, (:Tracker,)) + test_ad(x -> sum(b(binv(x))), y) # test_bijector is commented out for now, # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) From 37cfd90781f5a8c9b26800c7911d77b181a8669d Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 19:07:10 +0300 Subject: [PATCH 080/172] fix `Union` in `@inferred` after splitting structs --- test/bijectors/utils.jl | 12 ------------ test/transform.jl | 6 +----- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index 45e863c6..4c986bee 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -1,5 +1,3 @@ -using Bijectors: VecCorrBijector - # Allows us to run `ChangesOfVariables.test_with_logabsdet_jacobian` include(joinpath(dirname(pathof(ChangesOfVariables)), "..", "test", "getjacobian.jl")) @@ -27,19 +25,9 @@ function test_bijector( y_test = @inferred b(x) ilogjac_test = !isnothing(y) ? @inferred(logabsdetjac(ib, y)) : @inferred(logabsdetjac(ib, y_test)) ires = if !isnothing(y) - if b isa VecCorrBijector - # Inverse{VecCorrBijector} returns a ::Cholesky{...} in the case of a LKJCholesky distribution - # and a ::Matrix{Float64} in the case of a LKJ distribution. - @inferred Tuple{Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}}, Float64} with_logabsdet_jacobian(inverse(b), y) - else @inferred(with_logabsdet_jacobian(inverse(b), y)) - end else - if b isa VecCorrBijector - @inferred Tuple{Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}}, Float64} with_logabsdet_jacobian(inverse(b), y_test) - else @inferred(with_logabsdet_jacobian(inverse(b), y_test)) - end end # ChangesOfVariables.jl diff --git a/test/transform.jl b/test/transform.jl index 11e04d76..aaa53a94 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -39,17 +39,13 @@ function single_sample_tests(dist) # Check that invlink is inverse of link. x = rand(dist) - # LKJCholesky and LKJ use the same VecCorrBijector. - # The return type of Inverse{VecCorrBijector} depends on the distribution. if dist isa LKJCholesky x_inv = @inferred Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}} invlink(dist, link(dist, copy(x))) @test x_inv.UL ≈ x.UL atol=1e-9 - elseif dist isa LKJ - x_inv = @inferred Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}} invlink(dist, link(dist, copy(x))) - @test x_inv ≈ x atol=1e-9 else @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol=1e-9 end + # Check that link is inverse of invlink. Hopefully this just holds given the above... y = @inferred(link(dist, x)) if dist isa Dirichlet From a3c7f577987bdbad355e37fd23d1187f27b410ef Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 19:07:36 +0300 Subject: [PATCH 081/172] remove `Tracker` tests as support is dropped --- test/ad/utils.jl | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 6bf8365f..da21e3da 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -4,16 +4,6 @@ const AD = get(ENV, "AD", "All") function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] - if AD == "All" || AD == "Tracker" - if :Tracker in broken - @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol - else - ∇tracker = Tracker.gradient(f, x)[1] - @test Tracker.data(∇tracker) ≈ finitediff rtol=rtol atol=atol - @test Tracker.istracked(∇tracker) - end - end - if AD == "All" || AD == "ForwardDiff" if :ForwardDiff in broken @test_broken ForwardDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol From df4d9602c060288347a0676be299a6186bda50b2 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 6 Jun 2023 13:11:12 +0300 Subject: [PATCH 082/172] use `permutedims` instead of casting --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 367fe671..b92ad70f 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -301,7 +301,7 @@ function transform(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor. # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. - return Cholesky(Matrix(transpose(_inv_link_chol_lkj(y))), 'L', 0) + return Cholesky(permutedims(_inv_link_chol_lkj(y), (2, 1)), 'L', 0) end end From 17f784f4ddbc59e09e7385e66f10b12a3afbc40b Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 6 Jun 2023 13:11:58 +0300 Subject: [PATCH 083/172] remove `Union` in `@inferred` --- test/transform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transform.jl b/test/transform.jl index aaa53a94..29d0dbcf 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -40,7 +40,7 @@ function single_sample_tests(dist) x = rand(dist) if dist isa LKJCholesky - x_inv = @inferred Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}} invlink(dist, link(dist, copy(x))) + x_inv = @inferred Cholesky{Float64, Matrix{Float64}} invlink(dist, link(dist, copy(x))) @test x_inv.UL ≈ x.UL atol=1e-9 else @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol=1e-9 From 852573d79832b152da579785bd55e4bf080aeae3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Feb 2023 00:02:14 +0000 Subject: [PATCH 084/172] initial work on VecCorrBijector --- src/bijectors/corr.jl | 222 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 190 insertions(+), 32 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 252ecc68..f4cab829 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -78,19 +78,7 @@ function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) return w' * w end -function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) - K = LinearAlgebra.checksquare(y) - - result = float(zero(eltype(y))) - for j in 2:K, i in 1:(j - 1) - @inbounds abs_y_i_j = abs(y[i, j]) - result += (K - i + 1) * ( - IrrationalConstants.logtwo - (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j)) - ) - end - - return result -end +logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_chol_lkj(Y) function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) #= It may be more efficient if we can use un-contraint value to prevent call of b @@ -98,28 +86,159 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) `logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})` if possible. =# - return -logabsdetjac(inverse(b), (b(X))) + return -logabsdetjac(inverse(b), (b(X))) end -function _inv_link_chol_lkj(y) - K = LinearAlgebra.checksquare(y) +""" + VecCorrBijector <: Bijector - w = similar(y) +Similar to `CorrBijector`, but transforms a vector representing the Cholesky +to a correlation matrix, and its inverse transforms correlation matrix to vector +representing Cholesky. + +See also: [`CorrBijector`](@ref) + +# Example + +```jldoctest +julia> using LinearAlgebra + +julia> using StableRNGs; rng = StableRNG(42); + +julia> b = Bijectors.VecCorrBijector(); + +julia> X = rand(rng, LKJ(3, 1)) # Sample a correlation matrix. +3×3 Matrix{Float64}: + 1.0 -0.705273 -0.348638 + -0.705273 1.0 0.0534538 + -0.348638 0.0534538 1.0 + +julia> # Get the cholesky and convert to a vector. + u = Bijectors.triu1_to_vec(cholesky(X).U) + +julia> y = b(X) # Transform to unconstrained vector representation. +3-element Vector{Float64}: + -0.8777149781928181 + -0.3638927608636788 + -0.29813769428942216 + +julia> inverse(b)(y) ≈ u # (✓) Round-trip through `b` and its inverse. +true +""" +struct VecCorrBijector <: Bijector end +with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) + +function triu_mask(X::AbstractMatrix, k::Int) + # Ensure that we're working with a square matrix. + LinearAlgebra.checksquare(X) - @inbounds for j in 1:K - w[1, j] = 1 - for i in 2:j - z = tanh(y[i-1, j]) - tmp = w[i-1, j] - w[i-1, j] = z * tmp - w[i, j] = tmp * sqrt(1 - z^2) - end - for i in (j+1):K - w[i, j] = 0 + # Using `similar` allows us to respect device of array, etc., e.g. `CuArray`. + m = similar(X, Bool) + return triu(.~m .| m, k) +end + +triu_to_vec(X::AbstractMatrix{<:Real}, k::Int) = X[triu_mask(X, k)] + +function update_triu_from_vec!( + vals::AbstractVector{<:Real}, + k::Int, + X::AbstractMatrix{<:Real} +) + # Ensure that we're working with one-based indexing. + # `triu` requires this too. + LinearAlgebra.require_one_based_indexing(X) + + # Set the values. + idx = 1 + m, n = size(X) + for j = 1:n + for i = 1:min(j - k, m) + X[i, j] = vals[idx] + idx += 1 end end - - return w + + return X +end + +function update_triu_from_vec(vals::AbstractVector{<:Real}, k::Int, dim::Int) + X = similar(vals, dim, dim) + # TODO: Do we need this? + X .= 0 + return update_triu_from_vec!(vals, k, X) +end + +function ChainRulesCore.rrule(::typeof(update_triu_from_vec), x::AbstractVector{<:Real}, k::Int, dim::Int) + function update_triu_from_vec_pullback(ΔX) + return ( + ChainRulesCore.NoTangent(), + triu_to_vec(ChainRulesCore.unthunk(ΔX), k), + ChainRulesCore.NoTangent(), + ChainRulesCore.NoTangent() + ) + end + return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback +end + +""" + triu1_to_vec(X::AbstractMatrix{<:Real}) + +Extracts elements from upper triangle of `X` with offset `1` and returns them as a vector. +""" +triu1_to_vec(X::AbstractMatrix) = triu_to_vec(X, 1) + +inverse(::typeof(triu1_to_vec)) = vec_to_triu1 + +""" + vec_to_triu1(x::AbstractVector{<:Real}) + +Constructs a matrix from a vector `x` by filling the upper triangle with offset `1`. +""" +function vec_to_triu1(x::AbstractVector) + n = _triu1_dim_from_length(length(x)) + X = update_triu_from_vec(x, 1, n) + return UpperTriangular(X) +end + +inverse(::typeof(vec_to_triu1)) = triu1_to_vec + +# n * (n - 1) / 2 = d +# ⟺ n^2 - n - 2d = 0 +# ⟹ n = (1 + sqrt(1 + 8d)) / 2 +_triu1_dim_from_length(d) = Int((1 + sqrt(1 + 8d)) / 2) + +function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) + w = cholesky(X).U # keep LowerTriangular until here can avoid some computation + r = _link_chol_lkj(w) + + # Extract only the upper triangle of `r`. + return triu1_to_vec(r) +end + +# NOTE: The `logabsdetjac` is NOT the correcet on for this `transform`. +# The `logabsdetjac` implementation also includes the `logabsdetjac` of the +# cholesky decomposition, which is only valid if we're working on the space of +# postitive-definite matrices. +function transform(::VecCorrBijector, chol_vec::AbstractVector{<:Real}) + r = _link_chol_lkj(vec_to_triu1(chol_vec)) + + # Extract only the upper triangle of `r`. + return triu1_to_vec(r) +end + + +function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + Y = vec_to_triu1(y) + w = _inv_link_chol_lkj(Y) + # TODO: Should we just return `w` instead? + return triu1_to_vec(w) +end + +function logabsdetjac(b::VecCorrBijector, x) + return -logabsdetjac(inverse(b), b(x)) +end +function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + return _logabsdetjac_chol_lkj(vec_to_triu1(y)) end """ @@ -163,16 +282,55 @@ function _link_chol_lkj(w) # This block can't be integrated with loop below, because w[1,1] != 0. @inbounds z[1, 1] = 0 - @inbounds for j=2:K + @inbounds for j = 2:K z[1, j] = atanh(w[1, j]) tmp = sqrt(1 - w[1, j]^2) - for i in 2:(j - 1) + for i in 2:(j-1) p = w[i, j] / tmp tmp *= sqrt(1 - p^2) z[i, j] = atanh(p) end z[j, j] = 0 end - + return z end + +""" + _inv_link_chol_lkj(y) + +Inverse link function for cholesky factor. +""" +function _inv_link_chol_lkj(y) + K = LinearAlgebra.checksquare(y) + + w = similar(y) + + @inbounds for j in 1:K + w[1, j] = 1 + for i in 2:j + z = tanh(y[i-1, j]) + tmp = w[i-1, j] + w[i-1, j] = z * tmp + w[i, j] = tmp * sqrt(1 - z^2) + end + for i in (j+1):K + w[i, j] = 0 + end + end + + return w +end + +function _logabsdetjac_chol_lkj(Y::AbstractMatrix) + K = LinearAlgebra.checksquare(Y) + + result = float(zero(eltype(Y))) + for j in 2:K, i in 1:(j-1) + @inbounds abs_y_i_j = abs(Y[i, j]) + result += (K - i + 1) * ( + IrrationalConstants.logtwo - (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j)) + ) + end + return result +end From cea5f19d975aafcc86c44dca24a993a8171cbb0b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Feb 2023 00:19:28 +0000 Subject: [PATCH 085/172] added some tests for CorrBijector, and fixed implementation for VecCorrBijector --- src/bijectors/corr.jl | 25 ++++--------------------- src/bijectors/pd.jl | 2 ++ test/bijectors/corr.jl | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 21 deletions(-) create mode 100644 test/bijectors/corr.jl diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index f4cab829..613863cd 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -92,9 +92,8 @@ end """ VecCorrBijector <: Bijector -Similar to `CorrBijector`, but transforms a vector representing the Cholesky -to a correlation matrix, and its inverse transforms correlation matrix to vector -representing Cholesky. +Similar to `CorrBijector`, but correlation matrix to a vector, +and its inverse transforms vector to a correlation matrix. See also: [`CorrBijector`](@ref) @@ -113,16 +112,13 @@ julia> X = rand(rng, LKJ(3, 1)) # Sample a correlation matrix. -0.705273 1.0 0.0534538 -0.348638 0.0534538 1.0 -julia> # Get the cholesky and convert to a vector. - u = Bijectors.triu1_to_vec(cholesky(X).U) - julia> y = b(X) # Transform to unconstrained vector representation. 3-element Vector{Float64}: -0.8777149781928181 -0.3638927608636788 -0.29813769428942216 -julia> inverse(b)(y) ≈ u # (✓) Round-trip through `b` and its inverse. +julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ struct VecCorrBijector <: Bijector end @@ -215,23 +211,10 @@ function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) return triu1_to_vec(r) end -# NOTE: The `logabsdetjac` is NOT the correcet on for this `transform`. -# The `logabsdetjac` implementation also includes the `logabsdetjac` of the -# cholesky decomposition, which is only valid if we're working on the space of -# postitive-definite matrices. -function transform(::VecCorrBijector, chol_vec::AbstractVector{<:Real}) - r = _link_chol_lkj(vec_to_triu1(chol_vec)) - - # Extract only the upper triangle of `r`. - return triu1_to_vec(r) -end - - function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) Y = vec_to_triu1(y) w = _inv_link_chol_lkj(Y) - # TODO: Should we just return `w` instead? - return triu1_to_vec(w) + return w' * w end function logabsdetjac(b::VecCorrBijector, x) diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index bed6ee9a..c9c0ff8a 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -18,6 +18,8 @@ function transform(ib::Inverse{PDBijector}, Y::AbstractMatrix{<:Real}) X = replace_diag(exp, Y) return getpd(X) end + +# TODO: AFAIK this is used because of AD-related issues; can we remove? getpd(X) = LowerTriangular(X) * LowerTriangular(X)' function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real}) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl new file mode 100644 index 00000000..7dbef68c --- /dev/null +++ b/test/bijectors/corr.jl @@ -0,0 +1,35 @@ +using Bijectors, DistributionsAD, LinearAlgebra, Test +using Bijectors: VecCorrBijector, CorrBijector + +@testset "PDBijector" begin + d = 3 + + b = CorrBijector() + bvec = VecCorrBijector() + + dist = LKJ(d, 1) + x = rand(dist) + + y = b(x) + yvec = bvec(x) + + # Make sure that they represent the same thing. + @test Bijectors.triu1_to_vec(y) ≈ yvec + + # Check the inverse. + binv = inverse(b) + xinv = binv(y) + bvecinv = inverse(bvec) + xvecinv = bvecinv(yvec) + + @test xinv ≈ xvecinv + + # And finally that the `logabsdetjac` is the same. + @test logabsdetjac(bvec, x) ≈ logabsdetjac(b, x) + + # NOTE: `CorrBijector` technically isn't bijective, and so the default `getjacobian` + # used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. + # Hence, we disable those tests. + test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false) + test_bijector(bvec, x; test_not_identity=true, changes_of_variables_test=false) +end From 89612cc0a90604797846c636cac9d7e6e71177c7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Feb 2023 00:26:48 +0000 Subject: [PATCH 086/172] improved tests and are now using integer sqrt and division --- src/bijectors/corr.jl | 2 +- test/bijectors/corr.jl | 46 +++++++++++++++++++++--------------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 613863cd..5899d250 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -201,7 +201,7 @@ inverse(::typeof(vec_to_triu1)) = triu1_to_vec # n * (n - 1) / 2 = d # ⟺ n^2 - n - 2d = 0 # ⟹ n = (1 + sqrt(1 + 8d)) / 2 -_triu1_dim_from_length(d) = Int((1 + sqrt(1 + 8d)) / 2) +_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2 function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) w = cholesky(X).U # keep LowerTriangular until here can avoid some computation diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 7dbef68c..ceb1cb4a 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -2,34 +2,34 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test using Bijectors: VecCorrBijector, CorrBijector @testset "PDBijector" begin - d = 3 + for d ∈ [1, 2, 5] + b = CorrBijector() + bvec = VecCorrBijector() - b = CorrBijector() - bvec = VecCorrBijector() + dist = LKJ(d, 1) + x = rand(dist) - dist = LKJ(d, 1) - x = rand(dist) + y = b(x) + yvec = bvec(x) - y = b(x) - yvec = bvec(x) + # Make sure that they represent the same thing. + @test Bijectors.triu1_to_vec(y) ≈ yvec - # Make sure that they represent the same thing. - @test Bijectors.triu1_to_vec(y) ≈ yvec + # Check the inverse. + binv = inverse(b) + xinv = binv(y) + bvecinv = inverse(bvec) + xvecinv = bvecinv(yvec) - # Check the inverse. - binv = inverse(b) - xinv = binv(y) - bvecinv = inverse(bvec) - xvecinv = bvecinv(yvec) + @test xinv ≈ xvecinv - @test xinv ≈ xvecinv + # And finally that the `logabsdetjac` is the same. + @test logabsdetjac(bvec, x) ≈ logabsdetjac(b, x) - # And finally that the `logabsdetjac` is the same. - @test logabsdetjac(bvec, x) ≈ logabsdetjac(b, x) - - # NOTE: `CorrBijector` technically isn't bijective, and so the default `getjacobian` - # used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. - # Hence, we disable those tests. - test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false) - test_bijector(bvec, x; test_not_identity=true, changes_of_variables_test=false) + # NOTE: `CorrBijector` technically isn't bijective, and so the default `getjacobian` + # used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. + # Hence, we disable those tests. + test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) + test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) + end end From bc8f755a4aab3cef265ecb864c4597d2124c409c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 12 Feb 2023 15:23:07 +0000 Subject: [PATCH 087/172] moved things around a bit --- src/bijectors/corr.jl | 79 +++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5899d250..84cdb2e8 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -90,40 +90,10 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) end """ - VecCorrBijector <: Bijector - -Similar to `CorrBijector`, but correlation matrix to a vector, -and its inverse transforms vector to a correlation matrix. - -See also: [`CorrBijector`](@ref) - -# Example - -```jldoctest -julia> using LinearAlgebra - -julia> using StableRNGs; rng = StableRNG(42); - -julia> b = Bijectors.VecCorrBijector(); - -julia> X = rand(rng, LKJ(3, 1)) # Sample a correlation matrix. -3×3 Matrix{Float64}: - 1.0 -0.705273 -0.348638 - -0.705273 1.0 0.0534538 - -0.348638 0.0534538 1.0 - -julia> y = b(X) # Transform to unconstrained vector representation. -3-element Vector{Float64}: - -0.8777149781928181 - -0.3638927608636788 - -0.29813769428942216 + triu_mask(X::AbstractMatrix, k::Int) -julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. -true +Return a mask for elements of `X` above the `k`th diagonal. """ -struct VecCorrBijector <: Bijector end -with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) - function triu_mask(X::AbstractMatrix, k::Int) # Ensure that we're working with a square matrix. LinearAlgebra.checksquare(X) @@ -176,6 +146,11 @@ function ChainRulesCore.rrule(::typeof(update_triu_from_vec), x::AbstractVector{ return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback end +# n * (n - 1) / 2 = d +# ⟺ n^2 - n - 2d = 0 +# ⟹ n = (1 + sqrt(1 + 8d)) / 2 +_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2 + """ triu1_to_vec(X::AbstractMatrix{<:Real}) @@ -198,10 +173,40 @@ end inverse(::typeof(vec_to_triu1)) = triu1_to_vec -# n * (n - 1) / 2 = d -# ⟺ n^2 - n - 2d = 0 -# ⟹ n = (1 + sqrt(1 + 8d)) / 2 -_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2 +""" + VecCorrBijector <: Bijector + +Similar to `CorrBijector`, but correlation matrix to a vector, +and its inverse transforms vector to a correlation matrix. + +See also: [`CorrBijector`](@ref) + +# Example + +```jldoctest +julia> using LinearAlgebra + +julia> using StableRNGs; rng = StableRNG(42); + +julia> b = Bijectors.VecCorrBijector(); + +julia> X = rand(rng, LKJ(3, 1)) # Sample a correlation matrix. +3×3 Matrix{Float64}: + 1.0 -0.705273 -0.348638 + -0.705273 1.0 0.0534538 + -0.348638 0.0534538 1.0 + +julia> y = b(X) # Transform to unconstrained vector representation. +3-element Vector{Float64}: + -0.8777149781928181 + -0.3638927608636788 + -0.29813769428942216 + +julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. +true +""" +struct VecCorrBijector <: Bijector end +with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) w = cholesky(X).U # keep LowerTriangular until here can avoid some computation @@ -240,7 +245,7 @@ end But this implementation will not work when w[i-1, j] = 0. Though it is a zero measure set, unit matrix initialization will not work. -For equivelence, following explanations is given by @torfjelde: +For equivalence, following explanations is given by @torfjelde: For `(i, j)` in the loop below, we define From 9b3d7e90cb8bf02ee044836a92b39af9b59296ff Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Feb 2023 09:49:05 +0000 Subject: [PATCH 088/172] added chainrule for ReverseDiff --- src/compat/reversediff.jl | 4 +++- test/bijectors/corr.jl | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index c498205a..15369b38 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -1,7 +1,7 @@ module ReverseDiffCompat using ..ReverseDiff: ReverseDiff, @grad, value, track, TrackedReal, TrackedVector, - TrackedMatrix + TrackedMatrix, @grad_from_chainrules using Requires, LinearAlgebra using ..Bijectors: Elementwise, SimplexBijector, maphcat, simplex_link_jacobian, @@ -181,6 +181,8 @@ end return y, (wrap_chainrules_output ∘ Base.tail ∘ dy) end +@grad_from_chainrules update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int) + # NOTE: Probably doesn't work in complete generality. wrap_chainrules_output(x) = x wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index ceb1cb4a..16b7de9c 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -1,7 +1,7 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test using Bijectors: VecCorrBijector, CorrBijector -@testset "PDBijector" begin +@testset "CorrBijector & VecCorrBijector" begin for d ∈ [1, 2, 5] b = CorrBijector() bvec = VecCorrBijector() From b1176d061d2e172b54579769cd200c1d5349eb3e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Feb 2023 10:50:51 +0000 Subject: [PATCH 089/172] some fixes for AD --- src/bijectors/corr.jl | 10 +++++----- src/bijectors/pd.jl | 8 ++------ src/compat/reversediff.jl | 29 +++++++++++++++++++++++------ src/compat/tracker.jl | 10 +++++----- src/compat/zygote.jl | 6 +++--- src/utils.jl | 7 +++++++ 6 files changed, 45 insertions(+), 25 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 84cdb2e8..e043b2bf 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -66,7 +66,7 @@ struct CorrBijector <: Bijector end with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x) function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) - w = cholesky(x).U # keep LowerTriangular until here can avoid some computation + w = upper_triangular(parent(cholesky(x).U)) # keep LowerTriangular until here can avoid some computation r = _link_chol_lkj(w) return r + zero(x) # This dense format itself is required by a test, though I can't get the point. @@ -75,7 +75,7 @@ end function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) w = _inv_link_chol_lkj(y) - return w' * w + return pd_from_upper(w) end logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_chol_lkj(Y) @@ -168,7 +168,7 @@ Constructs a matrix from a vector `x` by filling the upper triangle with offset function vec_to_triu1(x::AbstractVector) n = _triu1_dim_from_length(length(x)) X = update_triu_from_vec(x, 1, n) - return UpperTriangular(X) + return upper_triangular(X) end inverse(::typeof(vec_to_triu1)) = triu1_to_vec @@ -209,7 +209,7 @@ struct VecCorrBijector <: Bijector end with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) - w = cholesky(X).U # keep LowerTriangular until here can avoid some computation + w = upper_triangular(parent(cholesky(X).U)) r = _link_chol_lkj(w) # Extract only the upper triangle of `r`. @@ -219,7 +219,7 @@ end function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) Y = vec_to_triu1(y) w = _inv_link_chol_lkj(Y) - return w' * w + return pd_from_upper(w) end function logabsdetjac(b::VecCorrBijector, x) diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index c9c0ff8a..3ba5526b 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -9,19 +9,15 @@ function replace_diag(f, X) end transform(b::PDBijector, X::AbstractMatrix{<:Real}) = pd_link(X) function pd_link(X) - Y = lower(parent(cholesky(X; check = true).L)) + Y = lower_triangular(parent(cholesky(X; check = true).L)) return replace_diag(log, Y) end -lower(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) function transform(ib::Inverse{PDBijector}, Y::AbstractMatrix{<:Real}) X = replace_diag(exp, Y) - return getpd(X) + return pd_from_lower(X) end -# TODO: AFAIK this is used because of AD-related issues; can we remove? -getpd(X) = LowerTriangular(X) * LowerTriangular(X)' - function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real}) T = eltype(X) Xcf = cholesky(X, check = false) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 15369b38..78871ce1 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -7,7 +7,8 @@ using Requires, LinearAlgebra using ..Bijectors: Elementwise, SimplexBijector, maphcat, simplex_link_jacobian, simplex_invlink_jacobian, simplex_logabsdetjac_gradient, Inverse import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector, - _simplex_inv_bijector, replace_diag, jacobian, getpd, lower, + _simplex_inv_bijector, replace_diag, jacobian, pd_from_lower, pd_from_upper, + lower_triangular, upper_triangular, _inv_link_chol_lkj, _link_chol_lkj, _transform_ordered, _transform_inverse_ordered, find_alpha @@ -136,18 +137,34 @@ logabsdetjac(b::SimplexBijector, x::Union{TrackedVector, TrackedMatrix}) = track end end -getpd(X::TrackedMatrix) = track(getpd, X) -@grad function getpd(X::AbstractMatrix) +pd_from_lower(X::TrackedMatrix) = track(pd_from_lower, X) +@grad function pd_from_lower(X::AbstractMatrix) Xd = value(X) return LowerTriangular(Xd) * LowerTriangular(Xd)', Δ -> begin Xl = LowerTriangular(Xd) return (LowerTriangular(Δ' * Xl + Δ * Xl),) end end -lower(A::TrackedMatrix) = track(lower, A) -@grad function lower(A::AbstractMatrix) + +pd_from_upper(X::TrackedMatrix) = track(pd_from_upper, X) +@grad function pd_from_upper(X::AbstractMatrix) + Xd = value(X) + return UpperTriangular(Xd)' * UpperTriangular(Xd), Δ -> begin + Xu = UpperTriangular(Xd) + return (UpperTriangular(Δ * Xu + Δ' * Xu),) + end +end + +lower_triangular(A::TrackedMatrix) = track(lower_triangular, A) +@grad function lower_triangular(A::AbstractMatrix) + Ad = value(A) + return lower_triangular(Ad), Δ -> (lower_triangular(Δ),) +end + +upper_triangular(A::TrackedMatrix) = track(upper_triangular, A) +@grad function upper_triangular(A::AbstractMatrix) Ad = value(A) - return lower(Ad), Δ -> (lower(Δ),) + return upper_triangular(Ad), Δ -> (upper_triangular(Δ),) end function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal} diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index 1166a29e..4763d724 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -281,8 +281,8 @@ end (b::Elementwise{typeof(log)})(x::TrackedVector) = log.(x)::vectorof(float(eltype(x))) (b::Elementwise{typeof(log)})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x))) -Bijectors.getpd(X::TrackedMatrix) = track(Bijectors.getpd, X) -@grad function Bijectors.getpd(X::AbstractMatrix) +Bijectors.pd_from_lower(X::TrackedMatrix) = track(Bijectors.pd_from_lower, X) +@grad function Bijectors.pd_from_lower(X::AbstractMatrix) Xd = data(X) return Bijectors.LowerTriangular(Xd) * Bijectors.LowerTriangular(Xd)', Δ -> begin Xl = Bijectors.LowerTriangular(Xd) @@ -290,10 +290,10 @@ Bijectors.getpd(X::TrackedMatrix) = track(Bijectors.getpd, X) end end -Bijectors.lower(A::TrackedMatrix) = track(Bijectors.lower, A) -@grad function Bijectors.lower(A::AbstractMatrix) +Bijectors.lower_triangular(A::TrackedMatrix) = track(Bijectors.lower_triangular, A) +@grad function Bijectors.lower_triangular(A::AbstractMatrix) Ad = data(A) - return Bijectors.lower(Ad), Δ -> (Bijectors.lower(Δ),) + return Bijectors.lower_triangular(Ad), Δ -> (Bijectors.lower_triangular(Δ),) end Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_lkj, y) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index eedf4b3d..6a81a749 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -158,10 +158,10 @@ end end return pullback(_maximum, d) end -@adjoint function lower(A::AbstractMatrix) - return lower(A), Δ -> (lower(Δ),) +@adjoint function lower_triangular(A::AbstractMatrix) + return lower_triangular(A), Δ -> (lower_triangular(Δ),) end -@adjoint function getpd(X::AbstractMatrix) +@adjoint function pd_from_lower(X::AbstractMatrix) return LowerTriangular(X) * LowerTriangular(X)', Δ -> begin Xl = LowerTriangular(X) return (LowerTriangular(Δ' * Xl + Δ * Xl),) diff --git a/src/utils.jl b/src/utils.jl index 8203e1b4..dca95731 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,3 +6,10 @@ aT_b(a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) = dot(a, b) # flatten arrays with fallback for scalars _vec(x::AbstractArray{<:Real}) = vec(x) _vec(x::Real) = x + +# # Because `ReverseDiff` does not play well with structural matrices. +lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) +upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) + +pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' +pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) From f3a623f28167eeb23ae35afd45471a54cdaf67dc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Feb 2023 10:55:07 +0000 Subject: [PATCH 090/172] added some TODOs --- src/bijectors/corr.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index e043b2bf..6c4f3b97 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -262,6 +262,7 @@ and so which is the above implementation. """ function _link_chol_lkj(w) + # TODO: Implement adjoint to support reverse-mode AD backends properly. K = LinearAlgebra.checksquare(w) z = similar(w) # z is also UpperTriangular. @@ -290,6 +291,7 @@ end Inverse link function for cholesky factor. """ function _inv_link_chol_lkj(y) + # TODO: Implement adjoint to support reverse-mode AD backends properly. K = LinearAlgebra.checksquare(y) w = similar(y) From d46e966b38534927c48022696ca9a7b6904c1bd3 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 4 Apr 2023 15:17:17 +0100 Subject: [PATCH 091/172] define bijectors for `LKJ` and `LKJCholesky` --- src/transformed_distribution.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index ffa30237..62aba47c 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -77,7 +77,8 @@ bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d) bijector(d::PDMatDistribution) = PDBijector() bijector(d::MatrixBeta) = PDBijector() -bijector(d::LKJ) = CorrBijector() +bijector(d::LKJ) = VecCorrBijector() +bijector(d::LKJCholesky) = d.uplo === 'L' ? VecTrilBijector() : VecTriuBijector() function bijector(d::Distributions.ReshapedDistribution) inner_dims = size(d.dist) From f21035693c46dea8487584237d001294d881c559 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:49:55 +0100 Subject: [PATCH 092/172] add `TransformedDistribution` constructor for `LKJCholesky` --- src/transformed_distribution.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 62aba47c..df5c7eed 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -6,6 +6,7 @@ struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D< TransformedDistribution(d::UnivariateDistribution, b) = new{typeof(d), typeof(b), Univariate}(d, b) TransformedDistribution(d::MultivariateDistribution, b) = new{typeof(d), typeof(b), Multivariate}(d, b) TransformedDistribution(d::MatrixDistribution, b) = new{typeof(d), typeof(b), Matrixvariate}(d, b) + TransformedDistribution(d::Distribution{CholeskyVariate}, b) = new{typeof(d), typeof(b), CholeskyVariate}(d, b) end # fields may contain nested numerical parameters From 71e1017f3764a1b8d8a0846b6cd82e4550f578a3 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:50:27 +0100 Subject: [PATCH 093/172] define `logpdf` for `LKJ` & `LKJCholesky` --- src/transformed_distribution.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index df5c7eed..3cb2aed5 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -116,6 +116,11 @@ function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) return logpdf(td.dist, mappedarray(x->x+ϵ, x)) + logjac end +function logpdf(td::TransformedDistribution{T}, y::AbstractVector{<:Real}) where {T <: Union{LKJ, LKJCholesky}} + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return logpdf(td.dist, x) + logjac +end + function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) return logpdf(td.dist, x) + logjac From 37e649c01120a899bb71fa3ab540df4f59de1abf Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:50:53 +0100 Subject: [PATCH 094/172] define `rand` for `LKJ` & `LKJCholesky` --- src/transformed_distribution.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 3cb2aed5..fa53dc95 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -168,6 +168,10 @@ function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<:Real}) x .= td.transform(x) end +function rand(rng::AbstractRNG, td::TransformedDistribution{T}) where {T <: Union{LKJ, LKJCholesky}} + return td.transform(rand(rng, td.dist)) +end + # utility stuff Distributions.params(td::Transformed) = Distributions.params(td.dist) function Base.maximum(td::UnivariateTransformed) From c09c5c8399ab4655597e62abe6ff1a49d636d45a Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:51:28 +0100 Subject: [PATCH 095/172] add util to extract Cholesky factor --- src/utils.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index dca95731..974a28a6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,3 +13,8 @@ upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) + +cholesky_factor(X::AbstractMatrix) = cholesky(X).UL +cholesky_factor(X::Cholesky) = X.UL +cholesky_factor(X::UpperTriangular) = X +cholesky_factor(X::LowerTriangular) = X From 2a514c85f3858564b8f7eef444a4fd2739a7c770 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:52:51 +0100 Subject: [PATCH 096/172] TYPO: capitalize matrix --- src/bijectors/corr.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 6c4f3b97..b44606fc 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -65,10 +65,10 @@ struct CorrBijector <: Bijector end with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x) -function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) - w = upper_triangular(parent(cholesky(x).U)) # keep LowerTriangular until here can avoid some computation +function transform(b::CorrBijector, X::AbstractMatrix{<:Real}) + w = upper_triangular(parent(cholesky(X).U)) # keep LowerTriangular until here can avoid some computation r = _link_chol_lkj(w) - return r + zero(x) + return r + zero(X) # This dense format itself is required by a test, though I can't get the point. # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67 end From 6596c9e5e5024d234892c4898ee33eb54263d408 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:53:44 +0100 Subject: [PATCH 097/172] add util to convert `Vector` index to `Matrix` row index --- src/bijectors/corr.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index b44606fc..909a3b8a 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -173,6 +173,13 @@ end inverse(::typeof(vec_to_triu1)) = triu1_to_vec +function vec_to_triu1_row_index(idx) + # Assumes that vector was saved in a column-major order + # and that vector is one-based indexed. + M = _triu1_dim_from_length(idx - 1) + return idx - (M*(M-1) ÷ 2) +end + """ VecCorrBijector <: Bijector From 6123d6d7fdea96bc1c924ddefce10978a5c2b29f Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:55:16 +0100 Subject: [PATCH 098/172] add `VecTriBijector`s for `LKJCholesky` --- src/bijectors/corr.jl | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 909a3b8a..d79638ba 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -180,6 +180,16 @@ function vec_to_triu1_row_index(idx) return idx - (M*(M-1) ÷ 2) end +abstract type AbstractVecCorrBijector <: Bijector end + +with_logabsdet_jacobian(b::AbstractVecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) + +transform(::AbstractVecCorrBijector, X) = (_link_chol_lkj ∘ cholesky_factor)(X) + +function logabsdetjac(b::AbstractVecCorrBijector, x) + return -logabsdetjac(inverse(b), b(x)) +end + """ VecCorrBijector <: Bijector @@ -212,29 +222,21 @@ julia> y = b(X) # Transform to unconstrained vector representation. julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ -struct VecCorrBijector <: Bijector end -with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) +struct VecCorrBijector <: AbstractVecCorrBijector end +transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = (pd_from_upper ∘ _inv_link_chol_lkj)(y) -function transform(::VecCorrBijector, X::AbstractMatrix{<:Real}) - w = upper_triangular(parent(cholesky(X).U)) - r = _link_chol_lkj(w) +logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) - # Extract only the upper triangle of `r`. - return triu1_to_vec(r) -end +struct VecTriuBijector <: AbstractVecCorrBijector end +transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = (Cholesky ∘ UpperTriangular ∘ _inv_link_chol_lkj)(y) -function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) - Y = vec_to_triu1(y) - w = _inv_link_chol_lkj(Y) - return pd_from_upper(w) -end +logabsdetjac(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) + +struct VecTrilBijector <: AbstractVecCorrBijector end +transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = (Cholesky ∘ LowerTriangular ∘ transpose ∘ _inv_link_chol_lkj)(y) + +logabsdetjac(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) -function logabsdetjac(b::VecCorrBijector, x) - return -logabsdetjac(inverse(b), b(x)) -end -function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) - return _logabsdetjac_chol_lkj(vec_to_triu1(y)) -end """ function _link_chol_lkj(w) From 791f7646f32efe9a031d4d305299fb19afd7aca7 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:55:53 +0100 Subject: [PATCH 099/172] TYPO: capitilize matrix --- src/bijectors/corr.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index d79638ba..5b010dcd 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -270,21 +270,21 @@ and so which is the above implementation. """ -function _link_chol_lkj(w) +function _link_chol_lkj(W::AbstractMatrix) # TODO: Implement adjoint to support reverse-mode AD backends properly. - K = LinearAlgebra.checksquare(w) + K = LinearAlgebra.checksquare(W) - z = similar(w) # z is also UpperTriangular. + z = similar(W) # z is also UpperTriangular. # Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero. - # This block can't be integrated with loop below, because w[1,1] != 0. + # This block can't be integrated with loop below, because W[1,1] != 0. @inbounds z[1, 1] = 0 @inbounds for j = 2:K - z[1, j] = atanh(w[1, j]) - tmp = sqrt(1 - w[1, j]^2) + z[1, j] = atanh(W[1, j]) + tmp = sqrt(1 - W[1, j]^2) for i in 2:(j-1) - p = w[i, j] / tmp + p = W[i, j] / tmp tmp *= sqrt(1 - p^2) z[i, j] = atanh(p) end From f47cdacc54ab73a6271c43d381e1a7c9db1e514b Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:56:44 +0100 Subject: [PATCH 100/172] add `LKJCholesky` link for `UpperTriangular` --- src/bijectors/corr.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5b010dcd..47aa4c50 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -294,6 +294,30 @@ function _link_chol_lkj(W::AbstractMatrix) return z end +function _link_chol_lkj(W::UpperTriangular) + K = LinearAlgebra.checksquare(W) + N = ((K-1)*K) ÷ 2 # {K \choose 2} free parameters + + z = zeros(eltype(W), N) + + # This block can't be integrated with loop below, because w[1,1] != 0. + idx = 1 + @inbounds for j = 2:K + z[idx] = atanh(W[1, j]) + idx += 1 + tmp = sqrt(1 - W[1, j]^2) + for i in 2:(j-1) + p = W[i, j] / tmp + tmp *= sqrt(1 - p^2) + z[idx] = atanh(p) + idx += 1 + end + end + + return z +end + +function _link_chol_lkj(W::LowerTriangular) """ _inv_link_chol_lkj(y) From 959b83648f053f073ce65cf70206da4baf3f437a Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:57:05 +0100 Subject: [PATCH 101/172] add `LKJCholesky` link for `LowerTriangular` --- src/bijectors/corr.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 47aa4c50..596347c1 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -318,6 +318,28 @@ function _link_chol_lkj(W::UpperTriangular) end function _link_chol_lkj(W::LowerTriangular) + K = LinearAlgebra.checksquare(W) + N = div((K-1)*K, 2) # {K \choose 2} free parameters + + z = zeros(eltype(W), N) + + # This block can't be integrated with loop below, because w[1,1] != 0. + idx = 1 + @inbounds for i = 2:K + z[idx] = atanh(W[i, 1]) + idx += 1 + tmp = sqrt(1 - W[i, 1]^2) + for j in 2:(i-1) + p = W[i, j] / tmp + tmp *= sqrt(1 - p^2) + z[idx] = atanh(p) + idx += 1 + end + end + + return z +end + """ _inv_link_chol_lkj(y) From a8ccaa1cd8603049de51822c74a383594a32e8b5 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:57:54 +0100 Subject: [PATCH 102/172] TYPO: capitalize matrix --- src/bijectors/corr.jl | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 596347c1..bc80f61a 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -345,22 +345,28 @@ end Inverse link function for cholesky factor. """ -function _inv_link_chol_lkj(y) +function _inv_link_chol_lkj(Y::AbstractMatrix) # TODO: Implement adjoint to support reverse-mode AD backends properly. - K = LinearAlgebra.checksquare(y) + K = LinearAlgebra.checksquare(Y) - w = similar(y) + W = similar(Y) @inbounds for j in 1:K - w[1, j] = 1 + W[1, j] = 1 for i in 2:j - z = tanh(y[i-1, j]) - tmp = w[i-1, j] - w[i-1, j] = z * tmp - w[i, j] = tmp * sqrt(1 - z^2) + z = tanh(Y[i-1, j]) + tmp = W[i-1, j] + W[i-1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) end for i in (j+1):K - w[i, j] = 0 + W[i, j] = 0 + end + end + + return W +end + end end From 82bf085b790a9d7e13d40e186ef68b87a28e8b0a Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 08:58:50 +0100 Subject: [PATCH 103/172] add `LKJCholesky` inverse link to `UpperTriangular` --- src/bijectors/corr.jl | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index bc80f61a..89d34aee 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -367,10 +367,26 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) return W end +function _inv_link_chol_lkj(y::AbstractVector) + # TODO: Implement adjoint to support reverse-mode AD backends properly. + K = _triu1_dim_from_length(length(y)) + + W = similar(y, K, K) + W .= zeros(eltype(y)) + + idx = 1 + @inbounds for j in 1:K + W[1, j] = 1 + for i in 2:j + z = tanh(y[idx]) + idx += 1 + tmp = W[i-1, j] + W[i-1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) end end - return w + return W end function _logabsdetjac_chol_lkj(Y::AbstractMatrix) From 597b6a16bdf120c24cc59e0932b54ed94012f10e Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 09:00:14 +0100 Subject: [PATCH 104/172] rename `_logabsdetjac_chol_lkj` to `_logabsdetjac_inv_corr` --- src/bijectors/corr.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 89d34aee..3496a016 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -78,7 +78,7 @@ function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) return pd_from_upper(w) end -logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_chol_lkj(Y) +logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_inv_corr(Y) function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) #= It may be more efficient if we can use un-contraint value to prevent call of b @@ -389,7 +389,7 @@ function _inv_link_chol_lkj(y::AbstractVector) return W end -function _logabsdetjac_chol_lkj(Y::AbstractMatrix) +function _logabsdetjac_inv_corr(Y::AbstractMatrix) K = LinearAlgebra.checksquare(Y) result = float(zero(eltype(Y))) From 54dd86d7608301d6bf273b5ec2e01936dbf6c34e Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 09:01:50 +0100 Subject: [PATCH 105/172] dispatch `_logabsdetjac_inv_corr` for `::Vector` --- src/bijectors/corr.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 3496a016..2ffe9ee3 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -401,3 +401,18 @@ function _logabsdetjac_inv_corr(Y::AbstractMatrix) end return result end + +function _logabsdetjac_inv_corr(y::AbstractVector) + K = _triu1_dim_from_length(length(y)) + + result = float(zero(eltype(y))) + for (i, y_i) in enumerate(y) + abs_y_i = abs(y_i) + row_idx = vec_to_triu1_row_index(i) + result += (K - row_idx + 1) * ( + IrrationalConstants.logtwo - (abs_y_i + LogExpFunctions.log1pexp(-2 * abs_y_i)) + ) + end + return result +end + From eaf60f7489393a63ce2a118ede93e10e0f500986 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 09:02:52 +0100 Subject: [PATCH 106/172] add logabsdetjac for inverse link of `LKJCholesky` --- src/bijectors/corr.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 2ffe9ee3..eb915b75 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -416,3 +416,21 @@ function _logabsdetjac_inv_corr(y::AbstractVector) return result end +function _logabsdetjac_inv_chol(y::AbstractVector) + K = _triu1_dim_from_length(length(y)) + + result = float(zero(eltype(y))) + idx = 1 + @inbounds for j in 2:K + tmp = zero(result) + for _ in 1:(j-1) + z = tanh(y[idx]) + logz = log(1 - z^2) + tmp += logz + result += logz + (tmp / 2) + idx += 1 + end + end + + return result +end From 861eef643d89be79898c33cd8be39d5a3c8715a2 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 09:44:06 +0100 Subject: [PATCH 107/172] add tests for `VecTriBijector`s --- test/bijectors/corr.jl | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 16b7de9c..900b5e68 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -1,5 +1,5 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test -using Bijectors: VecCorrBijector, CorrBijector +using Bijectors: VecCorrBijector, CorrBijector, VecTriuBijector, VecTrilBijector @testset "CorrBijector & VecCorrBijector" begin for d ∈ [1, 2, 5] @@ -33,3 +33,29 @@ using Bijectors: VecCorrBijector, CorrBijector test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) end end + +@testset "VecTriuBijector & VecTrilBijector" begin + for d ∈ [2, 5] + for dist in [LKJCholesky(d, 1, 'U'), LKJCholesky(d, 1, 'L')] + b = bijector(dist) + + b_lkj = VecCorrBijector() + x = rand(dist) + y = b(x) + y_lkj = b_lkj(x) + + @test y ≈ y_lkj + + binv = inverse(b) + xinv = binv(y) + binv_lkj = inverse(b_lkj) + xinv_lkj = binv_lkj(y_lkj) + + @test xinv.U ≈ cholesky(xinv_lkj).U + + # test_bijector is commented out for now, + # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) + # test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) + end + end +end From 78b9999afe668bb04c01ec7a07dbbf7cbb0f761b Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 6 Apr 2023 09:46:13 +0100 Subject: [PATCH 108/172] add `rrule` for LKJ(Cholesky) link function --- src/chainrules.jl | 56 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 45cacf82..445e217b 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -156,5 +156,59 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM return y, _transform_inverse_ordered_adjoint end +function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) + project_W = ChainRulesCore.ProjectTo(W) + + K = LinearAlgebra.checksquare(W) + N = ((K-1)*K) ÷ 2 + + z = zeros(eltype(W), N) + tmp_vec = similar(z) + + idx = 1 + @inbounds for j = 2:K + z[idx] = atanh(W[1, j]) + tmp = sqrt(1 - W[1, j]^2) + tmp_vec[idx] = tmp + idx += 1 + for i in 2:(j-1) + p = W[i, j] / tmp + tmp *= sqrt(1 - p^2) + tmp_vec[idx] = tmp + z[idx] = atanh(p) + idx += 1 + end + end + + function pullback_link_chol_lkj(Δz_thunked) + Δz = ChainRulesCore.unthunk(Δz_thunked) + + ΔW = similar(W) + + @inbounds ΔW[1,1] = zero(eltype(Δz)) + @inbounds for j=2:K + idx_up_to_prev_column = ((j-1)*(j-2) ÷ 2) + ΔW[j, j] = zero(eltype(Δz)) + Δtmp = zero(eltype(Δz)) + for i in (j-1):-1:2 + tmp = tmp_vec[idx_up_to_prev_column + i - 1] + p = W[i, j] / tmp + ftmp = sqrt(1 - p^2) + d_ftmp_p = -p / ftmp + d_p_tmp = -W[i,j] / tmp^2 + + Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp * d_ftmp_p + ΔW[i, j] = Δp / tmp + Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp + end + ΔW[1, j] = Δz[1, j] / (1-W[1,j]^2) - Δtmp / sqrt(1 - W[1,j]^2) * W[1,j] + end + + return ChainRulesCore.NoTangent(), project_W(ΔW) + end + + return z, pullback_link_chol_lkj +end + # Fixes Zygote's issues with `@debug` -ChainRulesCore.@non_differentiable _debug(::Any) \ No newline at end of file +ChainRulesCore.@non_differentiable _debug(::Any) From 5b4119a464e15877235cff56c6df93b146c96817 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 11 Apr 2023 15:36:36 +0100 Subject: [PATCH 109/172] use `transpose` in link for `::LowerTriangular' --- src/bijectors/corr.jl | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index eb915b75..a0269ce7 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -317,28 +317,7 @@ function _link_chol_lkj(W::UpperTriangular) return z end -function _link_chol_lkj(W::LowerTriangular) - K = LinearAlgebra.checksquare(W) - N = div((K-1)*K, 2) # {K \choose 2} free parameters - - z = zeros(eltype(W), N) - - # This block can't be integrated with loop below, because w[1,1] != 0. - idx = 1 - @inbounds for i = 2:K - z[idx] = atanh(W[i, 1]) - idx += 1 - tmp = sqrt(1 - W[i, 1]^2) - for j in 2:(i-1) - p = W[i, j] / tmp - tmp *= sqrt(1 - p^2) - z[idx] = atanh(p) - idx += 1 - end - end - - return z -end +_link_chol_lkj(W::LowerTriangular) = (_link_chol_lkj ∘ transpose)(W) """ _inv_link_chol_lkj(y) From 011534c075aa10d193a06a6f36fa1431c3b51a07 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 12 Apr 2023 17:17:32 +0100 Subject: [PATCH 110/172] add `Tracker` support for inverse link --- src/compat/tracker.jl | 53 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index 4763d724..dae58086 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -12,7 +12,7 @@ using ..Tracker: Tracker, param import ..Bijectors -using ..Bijectors: Elementwise, SimplexBijector, Inverse, Stacked +using ..Bijectors: Elementwise, SimplexBijector, Inverse, Stacked, _triu1_dim_from_length import ChainRulesCore import LogExpFunctions @@ -296,8 +296,57 @@ Bijectors.lower_triangular(A::TrackedMatrix) = track(Bijectors.lower_triangular, return Bijectors.lower_triangular(Ad), Δ -> (Bijectors.lower_triangular(Δ),) end +Bijectors._inv_link_chol_lkj(y::TrackedVector) = track(Bijectors._inv_link_chol_lkj, y) +@grad function Bijectors._inv_link_chol_lkj(y_tracked::TrackedVector) + y = data(y_tracked) + K = _triu1_dim_from_length(length(y)) + + W = similar(y, K, K) + W .= zeros(eltype(y)) + + z_vec = similar(y) + tmp_vec = similar(y) + + idx = 1 + @inbounds for j in 1:K + W[1, j] = 1 + for i in 2:j + z = tanh(y[idx]) + tmp = W[i-1, j] + + z_vec[idx] = z + tmp_vec[idx] = tmp + idx += 1 + + W[i-1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) + end + end + + function pullback_inv_link_chol_lkj(ΔW) + LinearAlgebra.checksquare(ΔW) + + Δy = zero(y) + + @inbounds for j in 1:K + idx_up_to_prev_column = ((j-1)*(j-2) ÷ 2) + Δtmp = ΔW[j,j] + for i in j:-1:2 + idx = idx_up_to_prev_column + i - 1 + Δz = ΔW[i-1, j] * tmp_vec[idx] - Δtmp * tmp_vec[idx] / sqrt(1 - z_vec[idx]^2) * z_vec[idx] + Δy[idx] = Δz / cosh(y[idx])^2 + Δtmp = ΔW[i-1, j] * z_vec[idx] + Δtmp * sqrt(1 - z_vec[idx]^2) + end + end + + return (Δy,) + end + + return W, pullback_inv_link_chol_lkj +end + Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_lkj, y) -@grad function Bijectors._inv_link_chol_lkj(y_tracked) +@grad function Bijectors._inv_link_chol_lkj(y_tracked::TrackedMatrix) y = data(y_tracked) K = LinearAlgebra.checksquare(y) From ff61ef0024a1b1599e5e6701be28f1a6297c4a09 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 12 Apr 2023 18:13:11 +0100 Subject: [PATCH 111/172] better utility function call --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 974a28a6..34842e89 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -14,7 +14,7 @@ upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) -cholesky_factor(X::AbstractMatrix) = cholesky(X).UL +cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(X)) cholesky_factor(X::Cholesky) = X.UL cholesky_factor(X::UpperTriangular) = X cholesky_factor(X::LowerTriangular) = X From a2ec6038207d6bb06e0b973ecfbefd69b81c653c Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 12 Apr 2023 18:13:43 +0100 Subject: [PATCH 112/172] use function barrier properly for type stability --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index a0269ce7..f2dec9f9 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -184,7 +184,7 @@ abstract type AbstractVecCorrBijector <: Bijector end with_logabsdet_jacobian(b::AbstractVecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) -transform(::AbstractVecCorrBijector, X) = (_link_chol_lkj ∘ cholesky_factor)(X) +transform(::AbstractVecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) function logabsdetjac(b::AbstractVecCorrBijector, x) return -logabsdetjac(inverse(b), b(x)) From 4c3a68bbcde7d6de3cff75f0e76bd7c14a937a19 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 13:17:34 +0100 Subject: [PATCH 113/172] account for difference in support dimensions --- test/transform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transform.jl b/test/transform.jl index 7be147d3..33f7a73b 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -61,7 +61,7 @@ function single_sample_tests(dist) else # This should probably be exact. @test logpdf(dist, x) == logpdf_with_trans(dist, x, false) - @test all(isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(x)) for _ in 1:100])) + @test all(isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(y)) for _ in 1:100])) end end From 6349546a4e4cea4eec384dd67c21333b8a44a5bb Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 13:18:21 +0100 Subject: [PATCH 114/172] fix indexing in Jacobian of `VecCorrBijector` --- test/transform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transform.jl b/test/transform.jl index 33f7a73b..2776d7fa 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -182,7 +182,7 @@ end upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]] J = ForwardDiff.jacobian(x->link(dist, x), x) - J = J[upperinds, upperinds] + J = J[:, upperinds] logpdf_turing = logpdf_with_trans(dist, x, true) @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing end From e65a78bc4db32b4d13345d4bcad680ff347da134 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 15:06:52 +0100 Subject: [PATCH 115/172] add `_logabsdetjac_dist` for `::LKJCholesky` --- src/Bijectors.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index d934381f..6e1171da 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -137,6 +137,8 @@ _logabsdetjac_dist(d::MultivariateDistribution, x::AbstractMatrix) = logabsdetja _logabsdetjac_dist(d::MatrixDistribution, x::AbstractMatrix) = logabsdetjac(bijector(d), x) _logabsdetjac_dist(d::MatrixDistribution, x::AbstractVector{<:AbstractMatrix}) = logabsdetjac.((bijector(d),), x) +_logabsdetjac_dist(d::LKJCholesky, x::Cholesky) = logabsdetjac(bijector(d), x) +_logabsdetjac_dist(d::LKJCholesky, x::AbstractVector) = logabsdetjac.((bijector(d),), x) function logpdf_with_trans(d::Distribution, x, transform::Bool) if ispd(d) From b6b7fa6e5f10446f38c8a578f6d48b479219255c Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 15:08:37 +0100 Subject: [PATCH 116/172] replace function composition for proper barrier --- src/bijectors/corr.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index f2dec9f9..8caa436c 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -223,17 +223,17 @@ julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ struct VecCorrBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = (pd_from_upper ∘ _inv_link_chol_lkj)(y) +transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = pd_from_upper(_inv_link_chol_lkj(y)) logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) struct VecTriuBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = (Cholesky ∘ UpperTriangular ∘ _inv_link_chol_lkj)(y) +transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = Cholesky(UpperTriangular(_inv_link_chol_lkj(y))) logabsdetjac(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) struct VecTrilBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = (Cholesky ∘ LowerTriangular ∘ transpose ∘ _inv_link_chol_lkj)(y) +transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = Cholesky(LowerTriangular(_transpose_matrix(_inv_link_chol_lkj(y)))) logabsdetjac(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) @@ -317,7 +317,7 @@ function _link_chol_lkj(W::UpperTriangular) return z end -_link_chol_lkj(W::LowerTriangular) = (_link_chol_lkj ∘ transpose)(W) +_link_chol_lkj(W::LowerTriangular) = _link_chol_lkj(transpose(W)) """ _inv_link_chol_lkj(y) From fd2460289a7ee964235df5c1c0ea85b9e852d810 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 15:09:25 +0100 Subject: [PATCH 117/172] add util convert `Transpose -> Matrix` for type stability --- src/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.jl b/src/utils.jl index 34842e89..3d41b5d2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,6 +10,7 @@ _vec(x::Real) = x # # Because `ReverseDiff` does not play well with structural matrices. lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) +_transpose_matrix(A::AbstractMatrix) = Matrix(transpose(A)) pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) From 1cd62d141f5bf7f5bdc3331d4ff9737da16f818f Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 13 Apr 2023 15:10:20 +0100 Subject: [PATCH 118/172] add `LKJCholesky` Jacobian+type tests --- test/transform.jl | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/test/transform.jl b/test/transform.jl index 2776d7fa..f86b19ae 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -38,8 +38,13 @@ function single_sample_tests(dist) # Check that invlink is inverse of link. x = rand(dist) - @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol=1e-9 + if dist isa LKJCholesky + x_inv = @inferred(invlink(dist, link(dist, copy(x)))) + @test x_inv.UL ≈ x.UL atol=1e-9 + else + @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol=1e-9 + end # Check that link is inverse of invlink. Hopefully this just holds given the above... y = @inferred(link(dist, x)) if dist isa Dirichlet @@ -169,9 +174,9 @@ let end end -@testset "correlation matrix" begin +@testset "LKJ" begin - dist = LKJ(2, 1) + dist = LKJ(3, 1) single_sample_tests(dist) @@ -187,6 +192,22 @@ end @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing end + +@testset "LKJCholesky" begin + + dist = LKJCholesky(3, 1) + + single_sample_tests(dist) + + x = rand(dist) + + upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]] + J = ForwardDiff.jacobian(x->link(dist, x), x.U) + J = J[:, upperinds] + logpdf_turing = logpdf_with_trans(dist, x, true) + @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing +end + ################################## Miscelaneous old tests ################################## # julia> logpdf_with_trans(Dirichlet([1., 1., 1.]), exp.([-1000., -1000., -1000.]), true) From f437e68452d3bbd4e4481d758eec854464002f92 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Fri, 14 Apr 2023 15:35:15 +0100 Subject: [PATCH 119/172] fix `logabsdetjac` for inverse link --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 8caa436c..ccd061bc 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -405,8 +405,8 @@ function _logabsdetjac_inv_chol(y::AbstractVector) for _ in 1:(j-1) z = tanh(y[idx]) logz = log(1 - z^2) - tmp += logz result += logz + (tmp / 2) + tmp += logz idx += 1 end end From 85397e8dfed3bb214b559f51ebeef9258a083032 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Fri, 14 Apr 2023 16:16:41 +0100 Subject: [PATCH 120/172] use `Cholesky` constructor compatible with `v1.6` --- src/bijectors/corr.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index ccd061bc..0cffe0db 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -228,12 +228,24 @@ transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = pd_from_upper logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) struct VecTriuBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = Cholesky(UpperTriangular(_inv_link_chol_lkj(y))) + +function transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) + # This constructor is compatible with Julia v1.6 + # for later versions Cholesky(::UpperTriangular) works + U = UpperTriangular(_inv_link_chol_lkj(y)) + return Cholesky(U.data, 'U', 0) +end logabsdetjac(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) struct VecTrilBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = Cholesky(LowerTriangular(_transpose_matrix(_inv_link_chol_lkj(y)))) + +function transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) + # This constructor is compatible with Julia v1.6 + # for later versions Cholesky(::LowerTriangular) works + L = LowerTriangular(_transpose_matrix(_inv_link_chol_lkj(y))) + return Cholesky(L.data, 'L', 0) +end logabsdetjac(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) From aa5685a896af9f1c7eb61d3d9030747328ff4024 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 12:43:59 +0100 Subject: [PATCH 121/172] add empty line --- src/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.jl b/src/utils.jl index 3d41b5d2..a0c2841b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,6 +10,7 @@ _vec(x::Real) = x # # Because `ReverseDiff` does not play well with structural matrices. lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) + _transpose_matrix(A::AbstractMatrix) = Matrix(transpose(A)) pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' From df264d606a183931fe5f4dd0a5830bcf7ea5fa23 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 12:44:25 +0100 Subject: [PATCH 122/172] fix `rrule` for link function --- src/chainrules.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 445e217b..6f274f3b 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -157,7 +157,6 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM end function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) - project_W = ChainRulesCore.ProjectTo(W) K = LinearAlgebra.checksquare(W) N = ((K-1)*K) ÷ 2 @@ -186,9 +185,10 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) ΔW = similar(W) @inbounds ΔW[1,1] = zero(eltype(Δz)) + @inbounds for j=2:K idx_up_to_prev_column = ((j-1)*(j-2) ÷ 2) - ΔW[j, j] = zero(eltype(Δz)) + ΔW[j, j] = 0 Δtmp = zero(eltype(Δz)) for i in (j-1):-1:2 tmp = tmp_vec[idx_up_to_prev_column + i - 1] @@ -197,14 +197,14 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) d_ftmp_p = -p / ftmp d_p_tmp = -W[i,j] / tmp^2 - Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp * d_ftmp_p + Δp = Δz[idx_up_to_prev_column + i] / (1-p^2) + Δtmp * tmp * d_ftmp_p ΔW[i, j] = Δp / tmp - Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp + Δtmp = Δp * d_p_tmp + Δtmp * ftmp end - ΔW[1, j] = Δz[1, j] / (1-W[1,j]^2) - Δtmp / sqrt(1 - W[1,j]^2) * W[1,j] + ΔW[1, j] = Δz[idx_up_to_prev_column + 1] / (1-W[1,j]^2) - Δtmp / sqrt(1 - W[1,j]^2) * W[1,j] end - return ChainRulesCore.NoTangent(), project_W(ΔW) + return ChainRulesCore.NoTangent(), ΔW end return z, pullback_link_chol_lkj From 599cb6648eaad2f34e56190915d4d614257a52e0 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 12:44:41 +0100 Subject: [PATCH 123/172] add link `rrule` test --- test/ad/chainrules.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index b0e4dc2e..a2289d26 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -13,4 +13,9 @@ test_rrule(Bijectors._transform_ordered, randn(5, 2)) test_rrule(Bijectors._transform_inverse_ordered, b(rand(5))) test_rrule(Bijectors._transform_inverse_ordered, b(rand(5, 2))) + + # LKJ and LKJCholesky bijector + dist = LKJCholesky(3, 1) + x = rand(dist) + test_rrule(Bijectors._link_chol_lkj, x.U) end From 9cd42c0741ee3be60446eb7634bfb6f4a08a4542 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 13:27:55 +0100 Subject: [PATCH 124/172] add `rrule` for inverse link --- src/chainrules.jl | 51 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/chainrules.jl b/src/chainrules.jl index 6f274f3b..79389c09 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -210,5 +210,56 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) return z, pullback_link_chol_lkj end +function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) + + K = _triu1_dim_from_length(length(y)) + + W = similar(y, K, K) + W .= zeros(eltype(y)) + + z_vec = similar(y) + tmp_vec = similar(y) + + idx = 1 + @inbounds for j in 1:K + W[1, j] = 1 + for i in 2:j + z = tanh(y[idx]) + tmp = W[i-1, j] + + z_vec[idx] = z + tmp_vec[idx] = tmp + idx += 1 + + W[i-1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) + end + end + + function pullback_inv_link_chol_lkj(ΔW_thunked) + ΔW = ChainRulesCore.unthunk(ΔW_thunked) + + Δy = zero(y) + + @inbounds for j in 1:K + idx_up_to_prev_column = ((j-1)*(j-2) ÷ 2) + Δtmp = ΔW[j,j] + for i in j:-1:2 + idx = idx_up_to_prev_column + i - 1 + tmp = tmp_vec[idx] + z = z_vec[idx] + + Δz = ΔW[i-1, j] * tmp - Δtmp * tmp / sqrt(1 - z^2) * z + Δy[idx] = Δz / cosh(y[idx])^2 + Δtmp = ΔW[i-1, j] * z + Δtmp * sqrt(1 - z^2) + end + end + + return ChainRulesCore.NoTangent(), Δy + end + + return W, pullback_inv_link_chol_lkj +end + # Fixes Zygote's issues with `@debug` ChainRulesCore.@non_differentiable _debug(::Any) From 9de473468522d52f901703b3c94b4d036a084d6f Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 13:28:12 +0100 Subject: [PATCH 125/172] remove TODO --- src/bijectors/corr.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 0cffe0db..f6c3b5cb 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -359,7 +359,6 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) end function _inv_link_chol_lkj(y::AbstractVector) - # TODO: Implement adjoint to support reverse-mode AD backends properly. K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) From befa1cccbb50b5bb58c3aff89aa029c7414d04c0 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 17 Apr 2023 13:28:29 +0100 Subject: [PATCH 126/172] add inverse link `rrule` test --- test/ad/chainrules.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index a2289d26..2542e48f 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -18,4 +18,8 @@ dist = LKJCholesky(3, 1) x = rand(dist) test_rrule(Bijectors._link_chol_lkj, x.U) + + b = bijector(dist) + y = b(x) + test_rrule(Bijectors._inv_link_chol_lkj, y) end From 6ba1c1f4c92a0c5b98face5f2e875e1fb5a25bf2 Mon Sep 17 00:00:00 2001 From: haris organtzidis Date: Mon, 17 Apr 2023 20:54:50 +0100 Subject: [PATCH 127/172] Update src/bijectors/corr.jl Co-authored-by: Tor Erlend Fjelde --- src/bijectors/corr.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index f6c3b5cb..5951840e 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -243,7 +243,9 @@ struct VecTrilBijector <: AbstractVecCorrBijector end function transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) # This constructor is compatible with Julia v1.6 # for later versions Cholesky(::LowerTriangular) works - L = LowerTriangular(_transpose_matrix(_inv_link_chol_lkj(y))) + # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. + # If we don't, the return-type can be both `Matrix` and `Transposed`. + L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) return Cholesky(L.data, 'L', 0) end From 79ad5f8343beb545164b187cd8cc5c0d1da6be44 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:19:13 +0100 Subject: [PATCH 128/172] add link `rrule` for `LowerTriangular` --- src/chainrules.jl | 53 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/chainrules.jl b/src/chainrules.jl index 79389c09..2d3a0f2d 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -210,6 +210,59 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) return z, pullback_link_chol_lkj end +function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular) + K = LinearAlgebra.checksquare(W) + N = ((K-1)*K) ÷ 2 + + z = zeros(eltype(W), N) + tmp_vec = similar(z) + + idx = 1 + @inbounds for i = 2:K + z[idx] = atanh(W[i, 1]) + tmp = sqrt(1 - W[i, 1]^2) + tmp_vec[idx] = tmp + idx += 1 + for j in 2:(i-1) + p = W[i, j] / tmp + tmp *= sqrt(1 - p^2) + tmp_vec[idx] = tmp + z[idx] = atanh(p) + idx += 1 + end + end + + function pullback_link_chol_lkj(Δz_thunked) + Δz = ChainRulesCore.unthunk(Δz_thunked) + + ΔW = similar(W) + + @inbounds ΔW[1,1] = zero(eltype(Δz)) + + @inbounds for i=2:K + idx_up_to_prev_row = ((i-1)*(i-2) ÷ 2) + ΔW[i, i] = 0 + Δtmp = zero(eltype(Δz)) + for j in (i-1):-1:2 + tmp = tmp_vec[idx_up_to_prev_row + j - 1] + p = W[i, j] / tmp + ftmp = sqrt(1 - p^2) + d_ftmp_p = -p / ftmp + d_p_tmp = -W[i,j] / tmp^2 + + Δp = Δz[idx_up_to_prev_row + j] / (1-p^2) + Δtmp * tmp * d_ftmp_p + ΔW[i, j] = Δp / tmp + Δtmp = Δp * d_p_tmp + Δtmp * ftmp + end + ΔW[i, 1] = Δz[idx_up_to_prev_row + 1] / (1-W[i,1]^2) - Δtmp / sqrt(1 - W[i,1]^2) * W[i,1] + end + + return ChainRulesCore.NoTangent(), ΔW + end + + return z, pullback_link_chol_lkj +end + function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) K = _triu1_dim_from_length(length(y)) From 19e8843aca072e72be9455018ce54cd82abbfeeb Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:19:52 +0100 Subject: [PATCH 129/172] add `LowerTriangular` chainrule test --- test/ad/chainrules.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index 2542e48f..59581fda 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -18,6 +18,7 @@ dist = LKJCholesky(3, 1) x = rand(dist) test_rrule(Bijectors._link_chol_lkj, x.U) + test_rrule(Bijectors._link_chol_lkj, x.L) b = bijector(dist) y = b(x) From 4216dbdcd6425660e5b95fbfc987f30637f5e159 Mon Sep 17 00:00:00 2001 From: haris organtzidis Date: Tue, 18 Apr 2023 10:38:12 +0100 Subject: [PATCH 130/172] Update src/bijectors/corr.jl Co-authored-by: Tor Erlend Fjelde --- src/bijectors/corr.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5951840e..01e89b1f 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -182,6 +182,7 @@ end abstract type AbstractVecCorrBijector <: Bijector end +TODO: Implement directly to make use of shared computations. with_logabsdet_jacobian(b::AbstractVecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) transform(::AbstractVecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) From e70430f44a51ccff8f374adfec39a7b42c7f4a80 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:45:11 +0100 Subject: [PATCH 131/172] remove unused util --- src/utils.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index a0c2841b..34842e89 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -11,8 +11,6 @@ _vec(x::Real) = x lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) -_transpose_matrix(A::AbstractMatrix) = Matrix(transpose(A)) - pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) From 2caba1cd23e421b6948b2d5818a0385fc9204c82 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:45:42 +0100 Subject: [PATCH 132/172] use `similar` instead of `zeros` --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 01e89b1f..eed36742 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -313,7 +313,7 @@ function _link_chol_lkj(W::UpperTriangular) K = LinearAlgebra.checksquare(W) N = ((K-1)*K) ÷ 2 # {K \choose 2} free parameters - z = zeros(eltype(W), N) + z = similar(W, N) # This block can't be integrated with loop below, because w[1,1] != 0. idx = 1 From 561f6b151af90c8eb749a86a3b2023b9a5b058fa Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:45:58 +0100 Subject: [PATCH 133/172] update comments --- src/bijectors/corr.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index eed36742..d3b8a07d 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -182,7 +182,7 @@ end abstract type AbstractVecCorrBijector <: Bijector end -TODO: Implement directly to make use of shared computations. +# TODO: Implement directly to make use of shared computations. with_logabsdet_jacobian(b::AbstractVecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) transform(::AbstractVecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) @@ -231,9 +231,9 @@ logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdet struct VecTriuBijector <: AbstractVecCorrBijector end function transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) - # This constructor is compatible with Julia v1.6 - # for later versions Cholesky(::UpperTriangular) works U = UpperTriangular(_inv_link_chol_lkj(y)) + # This Cholesky constructor is compatible with Julia v1.6 + # for later versions Cholesky(::UpperTriangular) works return Cholesky(U.data, 'U', 0) end @@ -242,11 +242,11 @@ logabsdetjac(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = _logabsdet struct VecTrilBijector <: AbstractVecCorrBijector end function transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) - # This constructor is compatible with Julia v1.6 - # for later versions Cholesky(::LowerTriangular) works # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) + # This Cholesky constructor is compatible with Julia v1.6 + # for later versions Cholesky(::LowerTriangular) works return Cholesky(L.data, 'L', 0) end From 69f5daadeb91709e7e00937472a6b5e2c01768ab Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 10:46:55 +0100 Subject: [PATCH 134/172] remove old comment --- src/bijectors/corr.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index d3b8a07d..30fc5032 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -315,7 +315,6 @@ function _link_chol_lkj(W::UpperTriangular) z = similar(W, N) - # This block can't be integrated with loop below, because w[1,1] != 0. idx = 1 @inbounds for j = 2:K z[idx] = atanh(W[1, j]) From ca9807e1c882b7eb3cbe279d063eeb3860cea071 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 11:18:23 +0100 Subject: [PATCH 135/172] minimize zero-setting operations in inverse link --- src/bijectors/corr.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 30fc5032..41d7b533 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -364,7 +364,6 @@ function _inv_link_chol_lkj(y::AbstractVector) K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) - W .= zeros(eltype(y)) idx = 1 @inbounds for j in 1:K @@ -376,6 +375,9 @@ function _inv_link_chol_lkj(y::AbstractVector) W[i-1, j] = z * tmp W[i, j] = tmp * sqrt(1 - z^2) end + for i in (j+1):K + W[i, j] = 0 + end end return W From 1883b361c3425e5dbd662e35c565e3d6d9d8d752 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 11:20:09 +0100 Subject: [PATCH 136/172] minimize zero-setting operations in `rrule` --- src/chainrules.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 2d3a0f2d..75e8676b 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -268,7 +268,6 @@ function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) - W .= zeros(eltype(y)) z_vec = similar(y) tmp_vec = similar(y) @@ -287,6 +286,9 @@ function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) W[i-1, j] = z * tmp W[i, j] = tmp * sqrt(1 - z^2) end + for i in (j+1):K + W[i, j] = 0 + end end function pullback_inv_link_chol_lkj(ΔW_thunked) From f84b329fd5b6a490c567abb9604fb8bcda51ee76 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 12:38:51 +0100 Subject: [PATCH 137/172] add parametric `Val` type to `VecCorrBijector` --- src/bijectors/corr.jl | 43 ++++++++++++++++----------------- src/transformed_distribution.jl | 4 +-- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 41d7b533..99c1fe25 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -180,17 +180,6 @@ function vec_to_triu1_row_index(idx) return idx - (M*(M-1) ÷ 2) end -abstract type AbstractVecCorrBijector <: Bijector end - -# TODO: Implement directly to make use of shared computations. -with_logabsdet_jacobian(b::AbstractVecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) - -transform(::AbstractVecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) - -function logabsdetjac(b::AbstractVecCorrBijector, x) - return -logabsdetjac(inverse(b), b(x)) -end - """ VecCorrBijector <: Bijector @@ -223,25 +212,33 @@ julia> y = b(X) # Transform to unconstrained vector representation. julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ -struct VecCorrBijector <: AbstractVecCorrBijector end -transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = pd_from_upper(_inv_link_chol_lkj(y)) +struct VecCorrBijector{T} <: Bijector + uplo::Symbol + function VecCorrBijector(uplo) + s = Symbol(uplo) + new{Val{s}}(s) + end +end + +# TODO: Implement directly to make use of shared computations. +with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) + +transform(::VecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) -logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) +function logabsdetjac(b::VecCorrBijector, x) + return -logabsdetjac(inverse(b), b(x)) +end -struct VecTriuBijector <: AbstractVecCorrBijector end +transform(::Inverse{VecCorrBijector{Val{:C}}}, y::AbstractVector{<:Real}) = pd_from_upper(_inv_link_chol_lkj(y)) -function transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) +function transform(::Inverse{VecCorrBijector{Val{:U}}}, y::AbstractVector{<:Real}) U = UpperTriangular(_inv_link_chol_lkj(y)) # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works return Cholesky(U.data, 'U', 0) end -logabsdetjac(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) - -struct VecTrilBijector <: AbstractVecCorrBijector end - -function transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) +function transform(::Inverse{VecCorrBijector{Val{:L}}}, y::AbstractVector{<:Real}) # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) @@ -250,7 +247,9 @@ function transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) return Cholesky(L.data, 'L', 0) end -logabsdetjac(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) +logabsdetjac(::Inverse{VecCorrBijector{Val{:C}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) +logabsdetjac(::Inverse{VecCorrBijector{Val{:U}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) +logabsdetjac(::Inverse{VecCorrBijector{Val{:L}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) """ diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index fa53dc95..e32b2408 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -78,8 +78,8 @@ bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d) bijector(d::PDMatDistribution) = PDBijector() bijector(d::MatrixBeta) = PDBijector() -bijector(d::LKJ) = VecCorrBijector() -bijector(d::LKJCholesky) = d.uplo === 'L' ? VecTrilBijector() : VecTriuBijector() +bijector(d::LKJ) = VecCorrBijector('C') +bijector(d::LKJCholesky) = VecCorrBijector(d.uplo) function bijector(d::Distributions.ReshapedDistribution) inner_dims = size(d.dist) From 29184631705470889925cea5f66433fa99893754 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 12:39:05 +0100 Subject: [PATCH 138/172] update `VecCorrBijector` tests --- test/bijectors/corr.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 900b5e68..dd528eb5 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -1,5 +1,5 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test -using Bijectors: VecCorrBijector, CorrBijector, VecTriuBijector, VecTrilBijector +using Bijectors: VecCorrBijector, CorrBijector @testset "CorrBijector & VecCorrBijector" begin for d ∈ [1, 2, 5] @@ -34,12 +34,12 @@ using Bijectors: VecCorrBijector, CorrBijector, VecTriuBijector, VecTrilBijector end end -@testset "VecTriuBijector & VecTrilBijector" begin +@testset "VecCorrBijector on LKJCholesky" begin for d ∈ [2, 5] for dist in [LKJCholesky(d, 1, 'U'), LKJCholesky(d, 1, 'L')] b = bijector(dist) - b_lkj = VecCorrBijector() + b_lkj = VecCorrBijector('C') x = rand(dist) y = b(x) y_lkj = b_lkj(x) From 2c4920d00f11c2411c655fc175f43dc468408940 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 14:40:03 +0100 Subject: [PATCH 139/172] use field value instead of `Val`-parametric type --- src/bijectors/corr.jl | 50 +++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 99c1fe25..2169feaf 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -212,12 +212,9 @@ julia> y = b(X) # Transform to unconstrained vector representation. julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ -struct VecCorrBijector{T} <: Bijector - uplo::Symbol - function VecCorrBijector(uplo) - s = Symbol(uplo) - new{Val{s}}(s) - end +struct VecCorrBijector <: Bijector + mode::Symbol + VecCorrBijector(uplo) = new(Symbol(uplo)) end # TODO: Implement directly to make use of shared computations. @@ -229,29 +226,32 @@ function logabsdetjac(b::VecCorrBijector, x) return -logabsdetjac(inverse(b), b(x)) end -transform(::Inverse{VecCorrBijector{Val{:C}}}, y::AbstractVector{<:Real}) = pd_from_upper(_inv_link_chol_lkj(y)) - -function transform(::Inverse{VecCorrBijector{Val{:U}}}, y::AbstractVector{<:Real}) - U = UpperTriangular(_inv_link_chol_lkj(y)) - # This Cholesky constructor is compatible with Julia v1.6 - # for later versions Cholesky(::UpperTriangular) works - return Cholesky(U.data, 'U', 0) +function transform(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + if b.orig.mode === :U + U = UpperTriangular(_inv_link_chol_lkj(y)) + # This Cholesky constructor is compatible with Julia v1.6 + # for later versions Cholesky(::UpperTriangular) works + return Cholesky(U.data, 'U', 0) + elseif b.orig.mode === :L + # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. + # If we don't, the return-type can be both `Matrix` and `Transposed`. + L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) + # This Cholesky constructor is compatible with Julia v1.6 + # for later versions Cholesky(::LowerTriangular) works + return Cholesky(L.data, 'L', 0) + else + return pd_from_upper(_inv_link_chol_lkj(y)) + end end -function transform(::Inverse{VecCorrBijector{Val{:L}}}, y::AbstractVector{<:Real}) - # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. - # If we don't, the return-type can be both `Matrix` and `Transposed`. - L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) - # This Cholesky constructor is compatible with Julia v1.6 - # for later versions Cholesky(::LowerTriangular) works - return Cholesky(L.data, 'L', 0) +function logabsdetjac(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + if (b.orig.mode === :U) || (b.orig.mode === :L) + return _logabsdetjac_inv_chol(y) + else + return _logabsdetjac_inv_corr(y) + end end -logabsdetjac(::Inverse{VecCorrBijector{Val{:C}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) -logabsdetjac(::Inverse{VecCorrBijector{Val{:U}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) -logabsdetjac(::Inverse{VecCorrBijector{Val{:L}}}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) - - """ function _link_chol_lkj(w) From 1872bb618538fe0f1166b142855132193652efba Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 14:47:34 +0100 Subject: [PATCH 140/172] update tests with new `VecCorrBijector` --- test/bijectors/utils.jl | 14 ++++++++++++-- test/transform.jl | 7 ++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index dc1d3a55..17a7bc79 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -25,9 +25,19 @@ function test_bijector( y_test = @inferred b(x) ilogjac_test = !isnothing(y) ? @inferred(logabsdetjac(ib, y)) : @inferred(logabsdetjac(ib, y_test)) ires = if !isnothing(y) - @inferred(with_logabsdet_jacobian(inverse(b), y)) + if b isa VecCorrBijector + # Inverse{VecCorrBijector} returns a ::Cholesky{...} in the case of a LKJCholesky distribution + # and a ::Matrix{Float64} in the case of a LKJ distribution. + @inferred Tuple{Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}}, Float64} with_logabsdet_jacobian(inverse(b), y) + else + @inferred(with_logabsdet_jacobian(inverse(b), y)) + end else - @inferred(with_logabsdet_jacobian(inverse(b), y_test)) + if b isa VecCorrBijector + @inferred Tuple{Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}}, Float64} with_logabsdet_jacobian(inverse(b), y_test) + else + @inferred(with_logabsdet_jacobian(inverse(b), y_test)) + end end # ChangesOfVariables.jl diff --git a/test/transform.jl b/test/transform.jl index f86b19ae..11e04d76 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -39,9 +39,14 @@ function single_sample_tests(dist) # Check that invlink is inverse of link. x = rand(dist) + # LKJCholesky and LKJ use the same VecCorrBijector. + # The return type of Inverse{VecCorrBijector} depends on the distribution. if dist isa LKJCholesky - x_inv = @inferred(invlink(dist, link(dist, copy(x)))) + x_inv = @inferred Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}} invlink(dist, link(dist, copy(x))) @test x_inv.UL ≈ x.UL atol=1e-9 + elseif dist isa LKJ + x_inv = @inferred Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}} invlink(dist, link(dist, copy(x))) + @test x_inv ≈ x atol=1e-9 else @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol=1e-9 end From 1250592e63e0566877ee4e4cf4b853a178841c0e Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 14:57:50 +0100 Subject: [PATCH 141/172] `using VecCorrBijector` in test utils --- test/bijectors/utils.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index 17a7bc79..45e863c6 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -1,3 +1,5 @@ +using Bijectors: VecCorrBijector + # Allows us to run `ChangesOfVariables.test_with_logabsdet_jacobian` include(joinpath(dirname(pathof(ChangesOfVariables)), "..", "test", "getjacobian.jl")) From 66b4caa12c8add21e86e64f4b3b4facb47889b6f Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 16:27:04 +0100 Subject: [PATCH 142/172] add `VecCorrBijector.mode` check --- src/bijectors/corr.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 2169feaf..69d5da69 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -214,7 +214,14 @@ true """ struct VecCorrBijector <: Bijector mode::Symbol - VecCorrBijector(uplo) = new(Symbol(uplo)) + function VecCorrBijector(uplo_or_corr) + s = Symbol(uplo_or_corr) + if (s === :U) || (s === :L) || (s === :C) + new(s) + else + throw(ArgumentError("mode must be :U (upper), :L (lower) or :C (correlation matrix)")) + end + end end # TODO: Implement directly to make use of shared computations. From c5cb535e7e3f2050c7880d1a59a99ccc3284d4fc Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 16:27:30 +0100 Subject: [PATCH 143/172] update `VecCorrBijector` docstring --- src/bijectors/corr.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 69d5da69..5f6b91fd 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -183,8 +183,20 @@ end """ VecCorrBijector <: Bijector -Similar to `CorrBijector`, but correlation matrix to a vector, -and its inverse transforms vector to a correlation matrix. +A bijector to transform either a correlation matrix or a Cholesky factor of a correlation matrix +to an unconstrained vector. + +# Fields +- mode :`Symbol`. Controls the inverse tranformation : + - if `mode === :C` returns a correlation matrix + - if `mode === :U` returns a `LinearAlgebra.Cholesky` holding the `UpperTriangular` factor + - if `mode === :L` returns a `LinearAlgebra.Cholesky` holding the `LowerTriangular` factor + +# Reference +- Transforming a orrelation matrix : +https://mc-stan.org/docs/reference-manual/correlation-matrix-transform.html#absolute-jacobian-determinant-of-the-correlation-matrix-inverse-transform +- Transforming a Cholesky factor of a correlation matrix : +https://mc-stan.org/docs/reference-manual/cholesky-factors-of-correlation-matrices-1 See also: [`CorrBijector`](@ref) From 8a06239404acd19281907d45a043cbb5807a5286 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:20:53 +0100 Subject: [PATCH 144/172] specialise `Zygote@adjoint` for `AbstractMatrix` --- src/compat/zygote.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index 6a81a749..864286d7 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -174,7 +174,7 @@ end end end -@adjoint function _inv_link_chol_lkj(y) +@adjoint function _inv_link_chol_lkj(y::AbstractMatrix) K = LinearAlgebra.checksquare(y) w = similar(y) @@ -219,7 +219,7 @@ end return w, pullback_inv_link_chol_lkj end -@adjoint function _link_chol_lkj(w) +@adjoint function _link_chol_lkj(w::AbstractMatrix) K = LinearAlgebra.checksquare(w) z = similar(w) From 44b3b9fa217a765a792a98529518e70972f783d9 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:21:38 +0100 Subject: [PATCH 145/172] `ReverseDiff` opt-in to `ChainRules` --- src/compat/reversediff.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 78871ce1..52301333 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -200,6 +200,9 @@ end @grad_from_chainrules update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int) +@grad_from_chainrules _link_chol_lkj(x::TrackedMatrix) +@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector) + # NOTE: Probably doesn't work in complete generality. wrap_chainrules_output(x) = x wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing From a5d601d8468cab6e3ae6e8618a2d8aeaf2710ff5 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:35:15 +0100 Subject: [PATCH 146/172] empty lines format --- src/chainrules.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 75e8676b..857389de 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -157,10 +157,9 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM end function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) - K = LinearAlgebra.checksquare(W) N = ((K-1)*K) ÷ 2 - + z = zeros(eltype(W), N) tmp_vec = similar(z) @@ -264,7 +263,6 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular) end function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) - K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) From 8783271342024fdab3897c114d3d8864fd2d7cf6 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:35:53 +0100 Subject: [PATCH 147/172] add AD test for inverse link --- test/bijectors/corr.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index dd528eb5..ac889424 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -4,7 +4,7 @@ using Bijectors: VecCorrBijector, CorrBijector @testset "CorrBijector & VecCorrBijector" begin for d ∈ [1, 2, 5] b = CorrBijector() - bvec = VecCorrBijector() + bvec = VecCorrBijector('C') dist = LKJ(d, 1) x = rand(dist) @@ -31,6 +31,8 @@ using Bijectors: VecCorrBijector, CorrBijector # Hence, we disable those tests. test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) + + test_ad(x -> sum(transform(inverse(b), x)), y, (:Tracker,)) end end From a197076bfd7bb40e3e04eec8bae07b7b2c68a52e Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:36:40 +0100 Subject: [PATCH 148/172] include `VecCorrBijector` tests --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index b48c656d..b8a1cf39 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,6 +37,7 @@ if GROUP == "All" || GROUP == "Interface" include("bijectors/ordered.jl") include("bijectors/pd.jl") include("bijectors/reshape.jl") + include("bijectors/corr.jl") end if GROUP == "All" || GROUP == "AD" From 7b9d1b2a4d7023d9ace9d60c7c6713f2f7bde9df Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 18:48:01 +0100 Subject: [PATCH 149/172] remove broken flag for `Tracker` --- test/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index ac889424..eedefc3e 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -32,7 +32,7 @@ using Bijectors: VecCorrBijector, CorrBijector test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) - test_ad(x -> sum(transform(inverse(b), x)), y, (:Tracker,)) + test_ad(x -> sum(transform(inverse(b), x)), y) end end From 5d1a7b8d72d49f656ba5408c5e488fe434dda1f9 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 18 Apr 2023 19:36:01 +0100 Subject: [PATCH 150/172] add roundtrip AD tests for `VecCorrBijector` --- test/bijectors/corr.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index eedefc3e..f1f6ab1f 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -32,7 +32,7 @@ using Bijectors: VecCorrBijector, CorrBijector test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) - test_ad(x -> sum(transform(inverse(b), x)), y) + test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Tracker,)) end end @@ -55,6 +55,8 @@ end @test xinv.U ≈ cholesky(xinv_lkj).U + test_ad(x -> sum(b(binv(x))), y, (:Tracker,)) + # test_bijector is commented out for now, # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) # test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) From a0d5e5262c38f55c564c38605eb7ad4b471ff5d0 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 19 Apr 2023 17:53:42 +0100 Subject: [PATCH 151/172] remove wrong `ReverseDiff.@grad` for `pd_from_upper` --- src/compat/reversediff.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 52301333..0487751f 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -146,14 +146,7 @@ pd_from_lower(X::TrackedMatrix) = track(pd_from_lower, X) end end -pd_from_upper(X::TrackedMatrix) = track(pd_from_upper, X) -@grad function pd_from_upper(X::AbstractMatrix) - Xd = value(X) - return UpperTriangular(Xd)' * UpperTriangular(Xd), Δ -> begin - Xu = UpperTriangular(Xd) - return (UpperTriangular(Δ * Xu + Δ' * Xu),) - end -end +@grad_from_chainrules pd_from_upper(X::TrackedMatrix) lower_triangular(A::TrackedMatrix) = track(lower_triangular, A) @grad function lower_triangular(A::AbstractMatrix) From bd0efffc22e458e4becf2f6d3c1e3194ababf743 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 19 Apr 2023 17:54:23 +0100 Subject: [PATCH 152/172] add corrected `rrule` for `pd_from_upper` --- src/chainrules.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 857389de..c5041f36 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -264,7 +264,7 @@ end function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) K = _triu1_dim_from_length(length(y)) - + W = similar(y, K, K) z_vec = similar(y) @@ -314,5 +314,12 @@ function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) return W, pullback_inv_link_chol_lkj end +function ChainRulesCore.rrule(::typeof(pd_from_upper), X::AbstractMatrix) + return UpperTriangular(X)' * UpperTriangular(X), Δ -> begin + Xu = UpperTriangular(X) + return ChainRulesCore.NoTangent(), UpperTriangular(Xu * Δ + Xu * Δ') + end +end + # Fixes Zygote's issues with `@debug` ChainRulesCore.@non_differentiable _debug(::Any) From e3314a4cbaef42140b7db71141c9d1c7d532a9f1 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 19 Apr 2023 17:56:10 +0100 Subject: [PATCH 153/172] update AD tests --- test/bijectors/corr.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index f1f6ab1f..5f0017de 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -31,8 +31,8 @@ using Bijectors: VecCorrBijector, CorrBijector # Hence, we disable those tests. test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) - - test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Tracker,)) + + test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Tracker, :Zygote,)) end end @@ -55,7 +55,7 @@ end @test xinv.U ≈ cholesky(xinv_lkj).U - test_ad(x -> sum(b(binv(x))), y, (:Tracker,)) + test_ad(x -> sum(b(binv(x))), y, (:Tracker, :Zygote,)) # test_bijector is commented out for now, # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) From c34ad471497dcf1a07f9d9e92d00379ee0a28b8f Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 19 Apr 2023 18:30:38 +0100 Subject: [PATCH 154/172] remove `Tracker` from broken --- test/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 5f0017de..a377b2a8 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -32,7 +32,7 @@ using Bijectors: VecCorrBijector, CorrBijector test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) - test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Tracker, :Zygote,)) + test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Zygote,)) end end From e154061d4bb72ec03d0a4ba6924b7d0292b8f911 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 25 Apr 2023 16:19:01 +0100 Subject: [PATCH 155/172] update zero-filling in `Tracker` pullback --- src/compat/tracker.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index dae58086..dee73f8b 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -302,7 +302,6 @@ Bijectors._inv_link_chol_lkj(y::TrackedVector) = track(Bijectors._inv_link_chol_ K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) - W .= zeros(eltype(y)) z_vec = similar(y) tmp_vec = similar(y) @@ -321,6 +320,9 @@ Bijectors._inv_link_chol_lkj(y::TrackedVector) = track(Bijectors._inv_link_chol_ W[i-1, j] = z * tmp W[i, j] = tmp * sqrt(1 - z^2) end + for i in (j+1):K + W[i, j] = 0 + end end function pullback_inv_link_chol_lkj(ΔW) From cffb616edf26470ec20854efcce1dd0aadf62aaf Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 25 Apr 2023 16:19:26 +0100 Subject: [PATCH 156/172] fix `Zygote` --- src/bijectors/corr.jl | 8 ++++---- src/utils.jl | 2 +- test/bijectors/corr.jl | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5f6b91fd..f0c7d638 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -247,17 +247,17 @@ end function transform(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) if b.orig.mode === :U - U = UpperTriangular(_inv_link_chol_lkj(y)) + U = _inv_link_chol_lkj(y) # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works - return Cholesky(U.data, 'U', 0) + return Cholesky(U, 'U', 0) elseif b.orig.mode === :L # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. - L = LowerTriangular(Matrix(transpose(_inv_link_chol_lkj(y)))) + L = Matrix(transpose(_inv_link_chol_lkj(y))) # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::LowerTriangular) works - return Cholesky(L.data, 'L', 0) + return Cholesky(L, 'L', 0) else return pd_from_upper(_inv_link_chol_lkj(y)) end diff --git a/src/utils.jl b/src/utils.jl index 34842e89..4880ac4c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -15,6 +15,6 @@ pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(X)) -cholesky_factor(X::Cholesky) = X.UL +cholesky_factor(X::Cholesky) = X.U cholesky_factor(X::UpperTriangular) = X cholesky_factor(X::LowerTriangular) = X diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index a377b2a8..947678d8 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -32,7 +32,7 @@ using Bijectors: VecCorrBijector, CorrBijector test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) - test_ad(x -> sum(bvec(bvecinv(x))), yvec, (:Zygote,)) + test_ad(x -> sum(bvec(bvecinv(x))), yvec) end end @@ -55,7 +55,7 @@ end @test xinv.U ≈ cholesky(xinv_lkj).U - test_ad(x -> sum(b(binv(x))), y, (:Tracker, :Zygote,)) + test_ad(x -> sum(b(binv(x))), y, (:Tracker,)) # test_bijector is commented out for now, # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) From c13fce6e1ceca793c9619f535141064babbb7b75 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 4 May 2023 13:46:46 +0100 Subject: [PATCH 157/172] merge lines - applying feedback suggestions --- src/bijectors/corr.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index f0c7d638..5c60998f 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -247,17 +247,13 @@ end function transform(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) if b.orig.mode === :U - U = _inv_link_chol_lkj(y) # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works - return Cholesky(U, 'U', 0) + return Cholesky(_inv_link_chol_lkj(y), 'U', 0) elseif b.orig.mode === :L # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. - L = Matrix(transpose(_inv_link_chol_lkj(y))) - # This Cholesky constructor is compatible with Julia v1.6 - # for later versions Cholesky(::LowerTriangular) works - return Cholesky(L, 'L', 0) + return Cholesky(Matrix(transpose(_inv_link_chol_lkj(y))), 'L', 0) else return pd_from_upper(_inv_link_chol_lkj(y)) end From dfeb71e4f43d80f0f5fea5210253f3dfd13cca43 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 17:58:40 +0300 Subject: [PATCH 158/172] `unthunk` in `pd_from_upper` rrule --- src/chainrules.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index c5041f36..46a10db8 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -315,7 +315,8 @@ function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) end function ChainRulesCore.rrule(::typeof(pd_from_upper), X::AbstractMatrix) - return UpperTriangular(X)' * UpperTriangular(X), Δ -> begin + return UpperTriangular(X)' * UpperTriangular(X), Δ_thunked -> begin + Δ = ChainRulesCore.unthunk(Δ_thunked) Xu = UpperTriangular(X) return ChainRulesCore.NoTangent(), UpperTriangular(Xu * Δ + Xu * Δ') end From 5210437786aa02608e8d6e76f3bfb1474fa21deb Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 19:05:17 +0300 Subject: [PATCH 159/172] split structs into `VecCorrBijector` and `VecCholeskyBijector` --- src/bijectors/corr.jl | 104 ++++++++++++++++++++++---------- src/transformed_distribution.jl | 4 +- 2 files changed, 74 insertions(+), 34 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5c60998f..367fe671 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -183,22 +183,12 @@ end """ VecCorrBijector <: Bijector -A bijector to transform either a correlation matrix or a Cholesky factor of a correlation matrix -to an unconstrained vector. - -# Fields -- mode :`Symbol`. Controls the inverse tranformation : - - if `mode === :C` returns a correlation matrix - - if `mode === :U` returns a `LinearAlgebra.Cholesky` holding the `UpperTriangular` factor - - if `mode === :L` returns a `LinearAlgebra.Cholesky` holding the `LowerTriangular` factor +A bijector to transform a correlation matrix to an unconstrained vector. # Reference -- Transforming a orrelation matrix : -https://mc-stan.org/docs/reference-manual/correlation-matrix-transform.html#absolute-jacobian-determinant-of-the-correlation-matrix-inverse-transform -- Transforming a Cholesky factor of a correlation matrix : -https://mc-stan.org/docs/reference-manual/cholesky-factors-of-correlation-matrices-1 +https://mc-stan.org/docs/reference-manual/correlation-matrix-transform.html -See also: [`CorrBijector`](@ref) +See also: [`CorrBijector`](@ref) and ['VecCholeskyBijector'](@ref) # Example @@ -224,48 +214,98 @@ julia> y = b(X) # Transform to unconstrained vector representation. julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. true """ -struct VecCorrBijector <: Bijector +struct VecCorrBijector <: Bijector end + +with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) + +transform(::VecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) + +function logabsdetjac(b::VecCorrBijector, x) + return -logabsdetjac(inverse(b), b(x)) +end + +transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = pd_from_upper(_inv_link_chol_lkj(y)) + +logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y) + +""" + VecCholeskyBijector <: Bijector + +A bijector to transform a Cholesky factor of a correlation matrix to an unconstrained vector. + +# Fields +- mode :`Symbol`. Controls the inverse tranformation : + - if `mode === :U` returns a `LinearAlgebra.Cholesky` holding the `UpperTriangular` factor + - if `mode === :L` returns a `LinearAlgebra.Cholesky` holding the `LowerTriangular` factor + +# Reference +https://mc-stan.org/docs/reference-manual/cholesky-factors-of-correlation-matrices-1 + +See also: [`VecCorrBijector`](@ref) + +# Example + +```jldoctest +julia> using LinearAlgebra + +julia> using StableRNGs; rng = StableRNG(42); + +julia> b = Bijectors.VecCholeskyBijector(:U); + +julia> X = rand(rng, LKJCholesky(3, 1, :U)) # Sample a correlation matrix. +Cholesky{Float64, Matrix{Float64}} +U factor: +3×3 UpperTriangular{Float64, Matrix{Float64}}: + 1.0 0.937494 0.865891 + ⋅ 0.348002 -0.320442 + ⋅ ⋅ 0.384122 + +julia> y = b(X) # Transform to unconstrained vector representation. +3-element Vector{Float64}: + -0.8777149781928181 + -0.3638927608636788 + -0.29813769428942216 + +julia> X_inv = inverse(b)(y); +julia> X_inv.U ≈ X.U # (✓) Round-trip through `b` and its inverse. +true +julia> X_inv.L ≈ X.L # (✓) Also works for the lower triangular factor. +true +""" +struct VecCholeskyBijector <: Bijector mode::Symbol - function VecCorrBijector(uplo_or_corr) - s = Symbol(uplo_or_corr) - if (s === :U) || (s === :L) || (s === :C) + function VecCholeskyBijector(uplo) + s = Symbol(uplo) + if (s === :U) || (s === :L) new(s) else - throw(ArgumentError("mode must be :U (upper), :L (lower) or :C (correlation matrix)")) + throw(ArgumentError("mode must be either :U (upper triangular) or :L (lower triangular)")) end end end # TODO: Implement directly to make use of shared computations. -with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) +with_logabsdet_jacobian(b::VecCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x) -transform(::VecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) +transform(::VecCholeskyBijector, X) = _link_chol_lkj(cholesky_factor(X)) -function logabsdetjac(b::VecCorrBijector, x) +function logabsdetjac(b::VecCholeskyBijector, x) return -logabsdetjac(inverse(b), b(x)) end -function transform(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) +function transform(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) if b.orig.mode === :U # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works return Cholesky(_inv_link_chol_lkj(y), 'U', 0) - elseif b.orig.mode === :L + else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor. # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. return Cholesky(Matrix(transpose(_inv_link_chol_lkj(y))), 'L', 0) - else - return pd_from_upper(_inv_link_chol_lkj(y)) end end -function logabsdetjac(b::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) - if (b.orig.mode === :U) || (b.orig.mode === :L) - return _logabsdetjac_inv_chol(y) - else - return _logabsdetjac_inv_corr(y) - end -end +logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y) """ function _link_chol_lkj(w) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index e32b2408..bc805694 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -78,8 +78,8 @@ bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d) bijector(d::PDMatDistribution) = PDBijector() bijector(d::MatrixBeta) = PDBijector() -bijector(d::LKJ) = VecCorrBijector('C') -bijector(d::LKJCholesky) = VecCorrBijector(d.uplo) +bijector(d::LKJ) = VecCorrBijector() +bijector(d::LKJCholesky) = VecCholeskyBijector(d.uplo) function bijector(d::Distributions.ReshapedDistribution) inner_dims = size(d.dist) From 25a70b4b78a408c735c9af82baed0d52ed23fcdd Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 19:05:33 +0300 Subject: [PATCH 160/172] remove old `Zygote` adjoints --- src/compat/zygote.jl | 97 -------------------------------------------- 1 file changed, 97 deletions(-) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index 864286d7..29497140 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -173,100 +173,3 @@ end return replace_diag(log, Y) end end - -@adjoint function _inv_link_chol_lkj(y::AbstractMatrix) - K = LinearAlgebra.checksquare(y) - - w = similar(y) - - z_mat = similar(y) # cache for adjoint - tmp_mat = similar(y) - - @inbounds for j in 1:K - w[1, j] = 1 - for i in 2:j - z = tanh(y[i-1, j]) - tmp = w[i-1, j] - - z_mat[i, j] = z - tmp_mat[i, j] = tmp - - w[i-1, j] = z * tmp - w[i, j] = tmp * sqrt(1 - z^2) - end - for i in (j+1):K - w[i, j] = 0 - end - end - - function pullback_inv_link_chol_lkj(Δw) - LinearAlgebra.checksquare(Δw) - - Δy = zero(y) - - @inbounds for j in 1:K - Δtmp = Δw[j,j] - for i in j:-1:2 - Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j] - Δy[i-1, j] = Δz / cosh(y[i-1, j])^2 - Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2) - end - end - - return (Δy,) - end - - return w, pullback_inv_link_chol_lkj -end - -@adjoint function _link_chol_lkj(w::AbstractMatrix) - K = LinearAlgebra.checksquare(w) - - z = similar(w) - - @inbounds z[1, 1] = 0 - - tmp_mat = similar(w) # cache for pullback. - - @inbounds for j=2:K - z[1, j] = atanh(w[1, j]) - tmp = sqrt(1 - w[1, j]^2) - tmp_mat[1, j] = tmp - for i in 2:(j - 1) - p = w[i, j] / tmp - tmp *= sqrt(1 - p^2) - tmp_mat[i, j] = tmp - z[i, j] = atanh(p) - end - z[j, j] = 0 - end - - function pullback_link_chol_lkj(Δz) - LinearAlgebra.checksquare(Δz) - - Δw = similar(w) - - @inbounds Δw[1,1] = zero(eltype(Δz)) - - @inbounds for j=2:K - Δw[j, j] = 0 - Δtmp = zero(eltype(Δz)) # Δtmp_mat[j-1,j] - for i in (j-1):-1:2 - p = w[i, j] / tmp_mat[i-1, j] - ftmp = sqrt(1 - p^2) - d_ftmp_p = -p / ftmp - d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2 - - Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p - Δw[i, j] = Δp / tmp_mat[i-1, j] - Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp - end - Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j] - end - - return (Δw,) - end - - return z, pullback_link_chol_lkj - -end From b056fdd991341eeeb8a74a88a34a85e65790144d Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 19:05:59 +0300 Subject: [PATCH 161/172] update tests --- test/bijectors/corr.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 947678d8..e78c2da2 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -1,10 +1,10 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test -using Bijectors: VecCorrBijector, CorrBijector +using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector @testset "CorrBijector & VecCorrBijector" begin for d ∈ [1, 2, 5] b = CorrBijector() - bvec = VecCorrBijector('C') + bvec = VecCorrBijector() dist = LKJ(d, 1) x = rand(dist) @@ -36,12 +36,12 @@ using Bijectors: VecCorrBijector, CorrBijector end end -@testset "VecCorrBijector on LKJCholesky" begin +@testset "VecCholeskyBijector" begin for d ∈ [2, 5] for dist in [LKJCholesky(d, 1, 'U'), LKJCholesky(d, 1, 'L')] b = bijector(dist) - b_lkj = VecCorrBijector('C') + b_lkj = VecCorrBijector() x = rand(dist) y = b(x) y_lkj = b_lkj(x) @@ -55,7 +55,7 @@ end @test xinv.U ≈ cholesky(xinv_lkj).U - test_ad(x -> sum(b(binv(x))), y, (:Tracker,)) + test_ad(x -> sum(b(binv(x))), y) # test_bijector is commented out for now, # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) From 33a8a29cd3c768ef462e929ca5adb4c008a670a7 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 19:07:10 +0300 Subject: [PATCH 162/172] fix `Union` in `@inferred` after splitting structs --- test/bijectors/utils.jl | 12 ------------ test/transform.jl | 6 +----- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index 45e863c6..4c986bee 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -1,5 +1,3 @@ -using Bijectors: VecCorrBijector - # Allows us to run `ChangesOfVariables.test_with_logabsdet_jacobian` include(joinpath(dirname(pathof(ChangesOfVariables)), "..", "test", "getjacobian.jl")) @@ -27,19 +25,9 @@ function test_bijector( y_test = @inferred b(x) ilogjac_test = !isnothing(y) ? @inferred(logabsdetjac(ib, y)) : @inferred(logabsdetjac(ib, y_test)) ires = if !isnothing(y) - if b isa VecCorrBijector - # Inverse{VecCorrBijector} returns a ::Cholesky{...} in the case of a LKJCholesky distribution - # and a ::Matrix{Float64} in the case of a LKJ distribution. - @inferred Tuple{Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}}, Float64} with_logabsdet_jacobian(inverse(b), y) - else @inferred(with_logabsdet_jacobian(inverse(b), y)) - end else - if b isa VecCorrBijector - @inferred Tuple{Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}}, Float64} with_logabsdet_jacobian(inverse(b), y_test) - else @inferred(with_logabsdet_jacobian(inverse(b), y_test)) - end end # ChangesOfVariables.jl diff --git a/test/transform.jl b/test/transform.jl index 11e04d76..aaa53a94 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -39,17 +39,13 @@ function single_sample_tests(dist) # Check that invlink is inverse of link. x = rand(dist) - # LKJCholesky and LKJ use the same VecCorrBijector. - # The return type of Inverse{VecCorrBijector} depends on the distribution. if dist isa LKJCholesky x_inv = @inferred Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}} invlink(dist, link(dist, copy(x))) @test x_inv.UL ≈ x.UL atol=1e-9 - elseif dist isa LKJ - x_inv = @inferred Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}} invlink(dist, link(dist, copy(x))) - @test x_inv ≈ x atol=1e-9 else @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol=1e-9 end + # Check that link is inverse of invlink. Hopefully this just holds given the above... y = @inferred(link(dist, x)) if dist isa Dirichlet From bfa448ba33e76749d2db81f9f2a7efeb1c15f380 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Wed, 24 May 2023 19:07:36 +0300 Subject: [PATCH 163/172] remove `Tracker` tests as support is dropped --- test/ad/utils.jl | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 6bf8365f..da21e3da 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -4,16 +4,6 @@ const AD = get(ENV, "AD", "All") function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] - if AD == "All" || AD == "Tracker" - if :Tracker in broken - @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol - else - ∇tracker = Tracker.gradient(f, x)[1] - @test Tracker.data(∇tracker) ≈ finitediff rtol=rtol atol=atol - @test Tracker.istracked(∇tracker) - end - end - if AD == "All" || AD == "ForwardDiff" if :ForwardDiff in broken @test_broken ForwardDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol From 96b90e6c43b559b46a3fb04917a4de89a8e72ad0 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 6 Jun 2023 13:11:12 +0300 Subject: [PATCH 164/172] use `permutedims` instead of casting --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 367fe671..b92ad70f 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -301,7 +301,7 @@ function transform(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor. # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. - return Cholesky(Matrix(transpose(_inv_link_chol_lkj(y))), 'L', 0) + return Cholesky(permutedims(_inv_link_chol_lkj(y), (2, 1)), 'L', 0) end end From 48edf87f044e874ebec92fc6cdc5c0ada7303e53 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 6 Jun 2023 13:11:58 +0300 Subject: [PATCH 165/172] remove `Union` in `@inferred` --- test/transform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transform.jl b/test/transform.jl index aaa53a94..29d0dbcf 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -40,7 +40,7 @@ function single_sample_tests(dist) x = rand(dist) if dist isa LKJCholesky - x_inv = @inferred Union{Cholesky{Float64, Matrix{Float64}}, Matrix{Float64}} invlink(dist, link(dist, copy(x))) + x_inv = @inferred Cholesky{Float64, Matrix{Float64}} invlink(dist, link(dist, copy(x))) @test x_inv.UL ≈ x.UL atol=1e-9 else @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol=1e-9 From 159ddb618df27477555bf3ce132f60ef6e9860e4 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Tue, 6 Jun 2023 17:27:28 +0300 Subject: [PATCH 166/172] wrap matrix in `Hermitian` before `cholesky` --- src/Bijectors.jl | 2 +- src/utils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 6e1171da..7caa597d 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -33,7 +33,7 @@ using Reexport, Requires using LinearAlgebra using MappedArrays using Base.Iterators: drop -using LinearAlgebra: AbstractTriangular +using LinearAlgebra: AbstractTriangular, Hermitian using InverseFunctions: InverseFunctions diff --git a/src/utils.jl b/src/utils.jl index 4880ac4c..439a2d0c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -14,7 +14,7 @@ upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) -cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(X)) +cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(Hermitian(X))) cholesky_factor(X::Cholesky) = X.U cholesky_factor(X::UpperTriangular) = X cholesky_factor(X::LowerTriangular) = X From 9c3dec8546de60d7704b4035960ba4db28d73c50 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 8 Jun 2023 09:42:24 +0300 Subject: [PATCH 167/172] add hacky dispatch for `cholesky_factor` and `ReverseDiff` --- src/compat/reversediff.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 0487751f..3bbd5278 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -10,7 +10,7 @@ import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector, _simplex_inv_bijector, replace_diag, jacobian, pd_from_lower, pd_from_upper, lower_triangular, upper_triangular, _inv_link_chol_lkj, _link_chol_lkj, _transform_ordered, _transform_inverse_ordered, - find_alpha + find_alpha, cholesky_factor import ChainRulesCore @@ -201,4 +201,10 @@ wrap_chainrules_output(x) = x wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) +# HACK: To make it work for julia v1.6 . +# This dispatch does not wrap X in Hermitian before calling cholesky. +# cholesky does not work with AbstractMatrix in julia v1.6, +# and it would error with Hermitian{ReverseDiff.TrackedArray}. +cholesky_factor(X::ReverseDiff.TrackedArray) = cholesky_factor(cholesky(X)) + end From 87a6faca5a151773d630f19ccd767d6f81228eae Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 8 Jun 2023 09:45:04 +0300 Subject: [PATCH 168/172] import `cholesky_factor` in ReverseDiff module for hacky dispatch --- src/compat/reversediff.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index b5bed098..cf7a00c9 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -35,7 +35,8 @@ import ..Bijectors: _link_chol_lkj, _transform_ordered, _transform_inverse_ordered, - find_alpha + find_alpha, + cholesky_factor using ChainRulesCore: ChainRulesCore From 1d8999f511b35b1ad908b14f681b094904ea669a Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 8 Jun 2023 15:01:37 +0300 Subject: [PATCH 169/172] only use hacky `cholesky_factor` in versions before fix --- src/compat/reversediff.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index cf7a00c9..7e95e69c 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -237,10 +237,13 @@ wrap_chainrules_output(x) = x wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) -# HACK: To make it work for julia v1.6 . -# This dispatch does not wrap X in Hermitian before calling cholesky. -# cholesky does not work with AbstractMatrix in julia v1.6, -# and it would error with Hermitian{ReverseDiff.TrackedArray}. -cholesky_factor(X::ReverseDiff.TrackedArray) = cholesky_factor(cholesky(X)) +if VERSION <= v"1.8.0-DEV.1526" + # HACK: This dispatch does not wrap X in Hermitian before calling cholesky. + # cholesky does not work with AbstractMatrix in julia versions before the compared one, + # and it would error with Hermitian{ReverseDiff.TrackedArray}. + # See commit when the fix was introduced : + # https://github.com/JuliaLang/julia/commit/635449dabee81bba315ab066627a98f856141969 + cholesky_factor(X::ReverseDiff.TrackedArray) = cholesky_factor(cholesky(X)) +end end From 424607d27bdee224d587255e91f0edc79d434fe1 Mon Sep 17 00:00:00 2001 From: harisorgn Date: Thu, 8 Jun 2023 17:17:32 +0300 Subject: [PATCH 170/172] change `LKJCholesky` shape to avoid stochastic test failures --- test/ad/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index dd74c1d4..e639e4ea 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -21,7 +21,7 @@ test_rrule(Bijectors._transform_inverse_ordered, b(rand(5, 2))) # LKJ and LKJCholesky bijector - dist = LKJCholesky(3, 1) + dist = LKJCholesky(3, 4) x = rand(dist) test_rrule(Bijectors._link_chol_lkj, x.U) test_rrule(Bijectors._link_chol_lkj, x.L) From 6aeebbfe5165991a5c8fe6af2e3406b0864d5eca Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 12 Jun 2023 12:16:15 +0300 Subject: [PATCH 171/172] remove old TODOs --- src/bijectors/corr.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 99fb3fa2..0a339b6a 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -350,7 +350,6 @@ and so which is the above implementation. """ function _link_chol_lkj(W::AbstractMatrix) - # TODO: Implement adjoint to support reverse-mode AD backends properly. K = LinearAlgebra.checksquare(W) z = similar(W) # z is also UpperTriangular. @@ -403,7 +402,6 @@ _link_chol_lkj(W::LowerTriangular) = _link_chol_lkj(transpose(W)) Inverse link function for cholesky factor. """ function _inv_link_chol_lkj(Y::AbstractMatrix) - # TODO: Implement adjoint to support reverse-mode AD backends properly. K = LinearAlgebra.checksquare(Y) W = similar(Y) From 62ca234569a665c0f5cb9a58d760d66bbff2144a Mon Sep 17 00:00:00 2001 From: harisorgn Date: Mon, 12 Jun 2023 12:17:14 +0300 Subject: [PATCH 172/172] add explicit zero-filling in link for `CorrBijector` --- src/bijectors/corr.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 0a339b6a..9367b0cd 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -356,7 +356,7 @@ function _link_chol_lkj(W::AbstractMatrix) # Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero. # This block can't be integrated with loop below, because W[1,1] != 0. - @inbounds z[1, 1] = 0 + @inbounds z[:, 1] .= 0 @inbounds for j in 2:K z[1, j] = atanh(W[1, j]) @@ -366,7 +366,9 @@ function _link_chol_lkj(W::AbstractMatrix) tmp *= sqrt(1 - p^2) z[i, j] = atanh(p) end - z[j, j] = 0 + for i in j:K + z[i, j] = 0 + end end return z