From 798bad716723a28644b2fe487e84efb43c492c4e Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 10 Jun 2024 17:53:57 +0100 Subject: [PATCH] Formatting VI files --- .JuliaFormatter.toml | 1 + src/variational/VariationalInference.jl | 28 ++++++--------------- src/variational/advi.jl | 33 +++++++++++-------------- test/variational/advi.jl | 18 +++++++------- test/variational/optimisers.jl | 4 +-- 5 files changed, 33 insertions(+), 51 deletions(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 04dcd5680..5e865d93a 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,5 +1,6 @@ style="blue" format_markdown = true +import_to_using = false # TODO # We ignore these files because when formatting was first put in place they were being worked on. # These ignores should be removed once the relevant PRs are merged/closed. diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index d601ae406..189d3f700 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -12,16 +12,9 @@ using Random: Random import AdvancedVI import Bijectors - # Reexports using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad -export - vi, - ADVI, - ELBO, - elbo, - TruncatedADAGrad, - DecayedADAGrad +export vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad """ make_logjoint(model::Model; weight = 1.0) @@ -31,17 +24,10 @@ use `DynamicPPL.MiniBatch` context to run the `Model` with a weight `num_total_o ## Notes - For sake of efficiency, the returned function is closes over an instance of `VarInfo`. This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`. """ -function make_logjoint(model::DynamicPPL.Model; weight = 1.0) +function make_logjoint(model::DynamicPPL.Model; weight=1.0) # setup - ctx = DynamicPPL.MiniBatchContext( - DynamicPPL.DefaultContext(), - weight - ) - f = DynamicPPL.LogDensityFunction( - model, - DynamicPPL.VarInfo(model), - ctx - ) + ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight) + f = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) return Base.Fix1(LogDensityProblems.logdensity, f) end @@ -52,10 +38,10 @@ function (elbo::ELBO)( q, model::DynamicPPL.Model, num_samples; - weight = 1.0, - kwargs... + weight=1.0, + kwargs..., ) - return elbo(rng, alg, q, make_logjoint(model; weight = weight), num_samples; kwargs...) + return elbo(rng, alg, q, make_logjoint(model; weight=weight), num_samples; kwargs...) end # VI algorithms diff --git a/src/variational/advi.jl b/src/variational/advi.jl index cf2d4034a..eacf2d01b 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -14,7 +14,6 @@ function wrap_in_vec_reshape(f, in_size) return reshape_outer ∘ f ∘ reshape_inner end - """ bijector(model::Model[, sym2ranges = Val(false)]) @@ -22,26 +21,25 @@ Returns a `Stacked <: Bijector` which maps from the support of the posterior to denoting the dimensionality of the latent variables. """ function Bijectors.bijector( - model::DynamicPPL.Model, - ::Val{sym2ranges} = Val(false); - varinfo = DynamicPPL.VarInfo(model) + model::DynamicPPL.Model, ::Val{sym2ranges}=Val(false); varinfo=DynamicPPL.VarInfo(model) ) where {sym2ranges} - num_params = sum([size(varinfo.metadata[sym].vals, 1) - for sym ∈ keys(varinfo.metadata)]) + num_params = sum([size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata) +]) - dists = vcat([varinfo.metadata[sym].dists for sym ∈ keys(varinfo.metadata)]...) + dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) - num_ranges = sum([length(varinfo.metadata[sym].ranges) - for sym ∈ keys(varinfo.metadata)]) + num_ranges = sum([ + length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) + ]) ranges = Vector{UnitRange{Int}}(undef, num_ranges) idx = 0 range_idx = 1 # ranges might be discontinuous => values are vectors of ranges rather than just ranges - sym_lookup = Dict{Symbol, Vector{UnitRange{Int}}}() - for sym ∈ keys(varinfo.metadata) + sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() + for sym in keys(varinfo.metadata) sym_lookup[sym] = Vector{UnitRange{Int}}() - for r ∈ varinfo.metadata[sym].ranges + for r in varinfo.metadata[sym].ranges ranges[range_idx] = idx .+ r push!(sym_lookup[sym], ranges[range_idx]) range_idx += 1 @@ -117,27 +115,24 @@ function AdvancedVI.update( end function AdvancedVI.vi( - model::DynamicPPL.Model, - alg::AdvancedVI.ADVI; - optimizer = AdvancedVI.TruncatedADAGrad(), + model::DynamicPPL.Model, alg::AdvancedVI.ADVI; optimizer=AdvancedVI.TruncatedADAGrad() ) q = meanfield(model) - return AdvancedVI.vi(model, alg, q; optimizer = optimizer) + return AdvancedVI.vi(model, alg, q; optimizer=optimizer) end - function AdvancedVI.vi( model::DynamicPPL.Model, alg::AdvancedVI.ADVI, q::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal}; - optimizer = AdvancedVI.TruncatedADAGrad(), + optimizer=AdvancedVI.TruncatedADAGrad(), ) # Initial parameters for mean-field approx μ, σs = StatsBase.params(q) θ = vcat(μ, StatsFuns.invsoftplus.(σs)) # Optimize - AdvancedVI.optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer = optimizer) + AdvancedVI.optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer=optimizer) # Return updated `Distribution` return AdvancedVI.update(q, θ) diff --git a/test/variational/advi.jl b/test/variational/advi.jl index 8f12562a5..639df018c 100644 --- a/test/variational/advi.jl +++ b/test/variational/advi.jl @@ -27,12 +27,12 @@ using Turing.Essential: TuringDiagMvNormal N = 500 alg = ADVI(10, 5000) - q = vi(gdemo_default, alg; optimizer = opt) + q = vi(gdemo_default, alg; optimizer=opt) samples = transpose(rand(q, N)) chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"]) # TODO: uhmm, seems like a large `eps` here... - check_gdemo(chn, atol = 0.5) + check_gdemo(chn; atol=0.5) end end @@ -52,7 +52,7 @@ using Turing.Essential: TuringDiagMvNormal # OR: implement `update` and pass a `Distribution` function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real}) - return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[length(q) + 1:end])) + return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[(length(q) + 1):end])) end q0 = TuringDiagMvNormal(zeros(2), ones(2)) @@ -66,7 +66,7 @@ using Turing.Essential: TuringDiagMvNormal # https://github.com/TuringLang/Turing.jl/issues/2065 @testset "simplex bijector" begin @model function dirichlet() - x ~ Dirichlet([1.0,1.0]) + x ~ Dirichlet([1.0, 1.0]) return x end @@ -82,17 +82,17 @@ using Turing.Essential: TuringDiagMvNormal # And regression for https://github.com/TuringLang/Turing.jl/issues/2160. q = vi(m, ADVI(10, 1000)) x = rand(q, 1000) - @test mean(eachcol(x)) ≈ [0.5, 0.5] atol=0.1 + @test mean(eachcol(x)) ≈ [0.5, 0.5] atol = 0.1 end # Ref: https://github.com/TuringLang/Turing.jl/issues/2205 @testset "with `condition` (issue #2205)" begin @model function demo_issue2205() x ~ Normal() - y ~ Normal(x, 1) + return y ~ Normal(x, 1) end - model = demo_issue2205() | (y = 1.0,) + model = demo_issue2205() | (y=1.0,) q = vi(model, ADVI(10, 1000)) # True mean. mean_true = 1 / 2 @@ -101,8 +101,8 @@ using Turing.Essential: TuringDiagMvNormal samples = rand(q, 1000) mean_est = mean(samples) var_est = var(samples) - @test mean_est ≈ mean_true atol=0.2 - @test var_est ≈ var_true atol=0.2 + @test mean_est ≈ mean_true atol = 0.2 + @test var_est ≈ var_true atol = 0.2 end end diff --git a/test/variational/optimisers.jl b/test/variational/optimisers.jl index 8063cdf2e..6f64d5fb1 100644 --- a/test/variational/optimisers.jl +++ b/test/variational/optimisers.jl @@ -9,8 +9,8 @@ using Turing function test_opt(ADPack, opt) θ = randn(10, 10) θ_fit = randn(10, 10) - loss(x, θ_) = mean(sum(abs2, θ*x - θ_*x; dims = 1)) - for t = 1:10^4 + loss(x, θ_) = mean(sum(abs2, θ * x - θ_ * x; dims=1)) + for t in 1:(10^4) x = rand(10) Δ = ADPack.gradient(θ_ -> loss(x, θ_), θ_fit) Δ = apply!(opt, θ_fit, Δ)