Skip to content

Commit

Permalink
Fix incorrect output dimensions in stacked bijectors (#2066)
Browse files Browse the repository at this point in the history
* Add output_length and transform dispatches for Vec bijector

This fixes #2065

* Add tests

* Bump revision

* Rename model function in simplex bijector test
  • Loading branch information
bgroenks96 authored Aug 28, 2023
1 parent 4affc28 commit f891aa8
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.28.2"
version = "0.28.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
8 changes: 8 additions & 0 deletions src/variational/advi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ end

Bijectors.inverse(f::Vec) = Vec(Bijectors.inverse(f.b), f.size)

Bijectors.output_length(f::Vec, sz) = Bijectors.output_length(f.b, sz)
Bijectors.output_length(f::Vec, n::Int) = Bijectors.output_length(f.b, n)

function Bijectors.with_logabsdet_jacobian(f::Vec, x)
return Bijectors.transform(f, x), Bijectors.logabsdetjac(f, x)
end
Expand All @@ -15,6 +18,11 @@ function Bijectors.transform(f::Vec, x::AbstractVector)
return vec(f.b(reshape(x, f.size)))
end

function Bijectors.transform(f::Vec{N,<:Bijectors.Inverse}, x::AbstractVector) where N
# Reshape into shape compatible with original (forward) bijector and then `vec` again.
return vec(f.b(reshape(x, Bijectors.output_length(f.b.orig, prod(f.size)))))
end

function Bijectors.transform(f::Vec, x::AbstractMatrix)
# At the moment we do batching for higher-than-1-dim spaces by simply using
# lists of inputs rather than `AbstractArray` with `N + 1` dimension.
Expand Down
18 changes: 18 additions & 0 deletions test/variational/advi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,22 @@
xs = rand(target, 10)
@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) 0.05
end

# regression test for:
# https://github.com/TuringLang/Turing.jl/issues/2065
@turing_testset "simplex bijector" begin
@model function dirichlet()
x ~ Dirichlet([1.0,1.0])
return x
end

m = dirichlet()
b = bijector(m)
x0 = m()
z0 = b(x0)
@test size(z0) == (1,)
x0_inv = inverse(b)(z0)
@test size(x0_inv) == size(x0)
@test all(x0 .≈ x0_inv)
end
end

2 comments on commit f891aa8

@yebai
Copy link
Member

@yebai yebai commented on f891aa8 Aug 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/90398

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.28.3 -m "<description of version>" f891aa8f82a66f373bc1d87df1ee7aaa5e10795b
git push origin v0.28.3

Please sign in to comment.