Commit: Formatting VI files

yebai committed Jun 10, 2024
1 parent 357dd9d, commit 798bad7
Showing 5 changed files with 33 additions and 51 deletions.
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
@@ -1,5 +1,6 @@
style="blue"
format_markdown = true
+import_to_using = false
# TODO
# We ignore these files because when formatting was first put in place they were being worked on.
# These ignores should be removed once the relevant PRs are merged/closed.
28 changes: 7 additions & 21 deletions src/variational/VariationalInference.jl
@@ -12,16 +12,9 @@ using Random: Random
import AdvancedVI
import Bijectors


# Reexports
using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad
-export
-    vi,
-    ADVI,
-    ELBO,
-    elbo,
-    TruncatedADAGrad,
-    DecayedADAGrad
+export vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad

"""
make_logjoint(model::Model; weight = 1.0)
@@ -31,17 +24,10 @@ use `DynamicPPL.MiniBatch` context to run the `Model` with a weight `num_total_o
## Notes
- For the sake of efficiency, the returned function closes over an instance of `VarInfo`. This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`.
"""
-function make_logjoint(model::DynamicPPL.Model; weight = 1.0)
+function make_logjoint(model::DynamicPPL.Model; weight=1.0)
    # setup
-    ctx = DynamicPPL.MiniBatchContext(
-        DynamicPPL.DefaultContext(),
-        weight
-    )
-    f = DynamicPPL.LogDensityFunction(
-        model,
-        DynamicPPL.VarInfo(model),
-        ctx
-    )
+    ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight)
+    f = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx)
    return Base.Fix1(LogDensityProblems.logdensity, f)
end
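
As a reading aid for the `make_logjoint` docstring above (not part of the commit), a minimal usage sketch; it assumes the function is reached via the internal `Turing.Variational` module, and the toy model is illustrative only:

using Turing

# Illustrative toy model (an assumption for the example, not from the diff).
@model function toy_gauss(x)
    m ~ Normal(0, 1)
    return x ~ Normal(m, 1)
end

logjoint = Turing.Variational.make_logjoint(toy_gauss(1.5); weight=1.0)
logjoint([0.0])  # weighted log joint evaluated at m = 0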

@@ -52,10 +38,10 @@ function (elbo::ELBO)(
    q,
    model::DynamicPPL.Model,
    num_samples;
-    weight = 1.0,
-    kwargs...
+    weight=1.0,
+    kwargs...,
)
-    return elbo(rng, alg, q, make_logjoint(model; weight = weight), num_samples; kwargs...)
+    return elbo(rng, alg, q, make_logjoint(model; weight=weight), num_samples; kwargs...)
end

# VI algorithms
33 changes: 14 additions & 19 deletions src/variational/advi.jl
@@ -14,34 +14,32 @@ function wrap_in_vec_reshape(f, in_size)
    return reshape_outer ∘ f ∘ reshape_inner
end


"""
bijector(model::Model[, sym2ranges = Val(false)])
Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d`
denoting the dimensionality of the latent variables.
"""
function Bijectors.bijector(
-    model::DynamicPPL.Model,
-    ::Val{sym2ranges} = Val(false);
-    varinfo = DynamicPPL.VarInfo(model)
+    model::DynamicPPL.Model, ::Val{sym2ranges}=Val(false); varinfo=DynamicPPL.VarInfo(model)
) where {sym2ranges}
-    num_params = sum([size(varinfo.metadata[sym].vals, 1)
-                      for sym ∈ keys(varinfo.metadata)])
+    num_params = sum([size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata)
+    ])

-    dists = vcat([varinfo.metadata[sym].dists for sym ∈ keys(varinfo.metadata)]...)
+    dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)

-    num_ranges = sum([length(varinfo.metadata[sym].ranges)
-                      for sym ∈ keys(varinfo.metadata)])
+    num_ranges = sum([
+        length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata)
+    ])
    ranges = Vector{UnitRange{Int}}(undef, num_ranges)
    idx = 0
    range_idx = 1

    # ranges might be discontinuous => values are vectors of ranges rather than just ranges
-    sym_lookup = Dict{Symbol, Vector{UnitRange{Int}}}()
-    for sym ∈ keys(varinfo.metadata)
+    sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}()
+    for sym in keys(varinfo.metadata)
        sym_lookup[sym] = Vector{UnitRange{Int}}()
-        for r ∈ varinfo.metadata[sym].ranges
+        for r in varinfo.metadata[sym].ranges
            ranges[range_idx] = idx .+ r
            push!(sym_lookup[sym], ranges[range_idx])
            range_idx += 1
@@ -117,27 +115,24 @@ function AdvancedVI.update(
end

function AdvancedVI.vi(
-    model::DynamicPPL.Model,
-    alg::AdvancedVI.ADVI;
-    optimizer = AdvancedVI.TruncatedADAGrad(),
+    model::DynamicPPL.Model, alg::AdvancedVI.ADVI; optimizer=AdvancedVI.TruncatedADAGrad()
)
    q = meanfield(model)
-    return AdvancedVI.vi(model, alg, q; optimizer = optimizer)
+    return AdvancedVI.vi(model, alg, q; optimizer=optimizer)
end


function AdvancedVI.vi(
    model::DynamicPPL.Model,
    alg::AdvancedVI.ADVI,
    q::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal};
-    optimizer = AdvancedVI.TruncatedADAGrad(),
+    optimizer=AdvancedVI.TruncatedADAGrad(),
)
    # Initial parameters for mean-field approx
    μ, σs = StatsBase.params(q)
    θ = vcat(μ, StatsFuns.invsoftplus.(σs))

    # Optimize
-    AdvancedVI.optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer = optimizer)
+    AdvancedVI.optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer=optimizer)

    # Return updated `Distribution`
    return AdvancedVI.update(q, θ)
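
A note on the unchanged `θ = vcat(μ, StatsFuns.invsoftplus.(σs))` line above: the standard deviations enter the parameter vector through an inverse softplus, so the optimiser works on unconstrained values. A small illustrative check (assuming `softplus` and `invsoftplus` from StatsFuns):

using StatsFuns: softplus, invsoftplus

σs = [0.1, 1.0, 2.5]
θσ = invsoftplus.(σs)    # unconstrained representation; entries may be negative
softplus.(θσ) ≈ σs       # true: softplus maps back to strictly positive values
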
18 changes: 9 additions & 9 deletions test/variational/advi.jl
@@ -27,12 +27,12 @@ using Turing.Essential: TuringDiagMvNormal
        N = 500

        alg = ADVI(10, 5000)
-        q = vi(gdemo_default, alg; optimizer = opt)
+        q = vi(gdemo_default, alg; optimizer=opt)
        samples = transpose(rand(q, N))
        chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"])

        # TODO: uhmm, seems like a large `eps` here...
-        check_gdemo(chn, atol = 0.5)
+        check_gdemo(chn; atol=0.5)
    end
end

@@ -52,7 +52,7 @@ using Turing.Essential: TuringDiagMvNormal

    # OR: implement `update` and pass a `Distribution`
    function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real})
-        return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[length(q) + 1:end]))
+        return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[(length(q) + 1):end]))
    end

    q0 = TuringDiagMvNormal(zeros(2), ones(2))
@@ -66,7 +66,7 @@ using Turing.Essential: TuringDiagMvNormal
    # https://github.com/TuringLang/Turing.jl/issues/2065
    @testset "simplex bijector" begin
        @model function dirichlet()
-            x ~ Dirichlet([1.0,1.0])
+            x ~ Dirichlet([1.0, 1.0])
            return x
        end

@@ -82,17 +82,17 @@ using Turing.Essential: TuringDiagMvNormal
        # And regression for https://github.com/TuringLang/Turing.jl/issues/2160.
        q = vi(m, ADVI(10, 1000))
        x = rand(q, 1000)
-        @test mean(eachcol(x)) ≈ [0.5, 0.5] atol=0.1
+        @test mean(eachcol(x)) ≈ [0.5, 0.5] atol = 0.1
    end

    # Ref: https://github.com/TuringLang/Turing.jl/issues/2205
    @testset "with `condition` (issue #2205)" begin
        @model function demo_issue2205()
            x ~ Normal()
-            y ~ Normal(x, 1)
+            return y ~ Normal(x, 1)
        end

-        model = demo_issue2205() | (y = 1.0,)
+        model = demo_issue2205() | (y=1.0,)
        q = vi(model, ADVI(10, 1000))
        # True mean.
        mean_true = 1 / 2
@@ -101,8 +101,8 @@ using Turing.Essential: TuringDiagMvNormal
        samples = rand(q, 1000)
        mean_est = mean(samples)
        var_est = var(samples)
-        @test mean_est ≈ mean_true atol=0.2
-        @test var_est ≈ var_true atol=0.2
+        @test mean_est ≈ mean_true atol = 0.2
+        @test var_est ≈ var_true atol = 0.2
    end
end

4 changes: 2 additions & 2 deletions test/variational/optimisers.jl
@@ -9,8 +9,8 @@ using Turing
function test_opt(ADPack, opt)
    θ = randn(10, 10)
    θ_fit = randn(10, 10)
-    loss(x, θ_) = mean(sum(abs2, θ*x - θ_*x; dims = 1))
-    for t = 1:10^4
+    loss(x, θ_) = mean(sum(abs2, θ * x - θ_ * x; dims=1))
+    for t in 1:(10^4)
        x = rand(10)
        Δ = ADPack.gradient(θ_ -> loss(x, θ_), θ_fit)
        Δ = apply!(opt, θ_fit, Δ)
