diff --git a/src/Turing.jl b/src/Turing.jl index 80fbb0e5b..24f53e7f9 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -112,6 +112,7 @@ mutable struct Sampler{T<:InferenceAlgorithm} <: AbstractSampler alg :: T info :: Dict{Symbol, Any} # sampler infomation end +Sampler(alg, model) = Sampler(alg) # mutable struct HMCState{T<:Real} # epsilon :: T diff --git a/src/core/ad.jl b/src/core/ad.jl index c73546a08..f0c4718e0 100644 --- a/src/core/ad.jl +++ b/src/core/ad.jl @@ -30,7 +30,7 @@ getADtype(::Type{<:Hamiltonian{AD}}) where {AD} = AD gradient( θ::AbstractVector{<:Real}, vi::VarInfo, - model::Function, + model::Model, sampler::Union{Nothing, Sampler}=nothing, ) @@ -40,7 +40,7 @@ Computes the gradient of the log joint of `θ` for the model specified by @generated function gradient( θ::AbstractVector{<:Real}, vi::VarInfo, - model::Function, + model::Model, sampler::TS=nothing, ) where {TS <: Union{Nothing, Sampler}} if TS == Nothing @@ -68,7 +68,7 @@ end gradient_forward( θ::AbstractVector{<:Real}, vi::VarInfo, - model::Function, + model::Model, spl::Union{Nothing, Sampler}=nothing, chunk_size::Int=CHUNKSIZE[], ) @@ -79,7 +79,7 @@ using forwards-mode AD from ForwardDiff.jl. function gradient_forward( θ::AbstractVector{<:Real}, vi::VarInfo, - model::Function, + model::Model, sampler::Union{Nothing, Sampler}=nothing, ::Val{chunk_size}=Val(CHUNKSIZE[]), ) where chunk_size @@ -111,7 +111,7 @@ end gradient_reverse( θ::AbstractVector{<:Real}, vi::VarInfo, - model::Function, + model::Model, sampler::Union{Nothing, Sampler}=nothing, ) @@ -121,7 +121,7 @@ Computes the gradient of the log joint of `θ` for the model specified by function gradient_reverse( θ::AbstractVector{<:Real}, vi::Turing.VarInfo, - model::Function, + model::Model, sampler::Union{Nothing, Sampler}=nothing, ) vals_old, logp_old = copy(vi.vals), copy(vi.logp) diff --git a/src/core/compiler.jl b/src/core/compiler.jl index 53565b463..31584da82 100644 --- a/src/core/compiler.jl +++ b/src/core/compiler.jl @@ -4,6 +4,33 @@ using Base.Meta: parse # Overload of ~ # ################# +""" + struct Model{pvars, dvars, F, TD} + f::F + data::TD + end + +A `Model` struct with parameter variables `pvars`, data variables `dvars`, inner +function `f` and `data::NamedTuple`. +""" +struct Model{pvars, dvars, F, TD} + f::F + data::TD +end +function Model{pvars, dvars}(f::F, data::TD) where {pvars, dvars, F, TD} + return Model{pvars, dvars, F, TD}(f, data) +end +pvars(m::Model{params}) where {params} = Tuple(params.types) +dvars(m::Model{params, data}) where {params, data} = Tuple(data.types) +@generated function inpvars(::Val{sym}, ::Model{params}) where {sym, params} + return sym in params.types ? :(true) : :(false) +end +@generated function indvars(::Val{sym}, ::Model{params, data}) where {sym, params, data} + return sym in data.types ? :(true) : :(false) +end + +(model::Model)(args...; kwargs...) = model.f(args..., model; kwargs...) + # TODO: Replace this macro, see issue #514 """ Usage: @VarName x[1,2][1+5][45][3] @@ -28,179 +55,128 @@ function var_tuple(sym::Symbol, inds::Expr=:(())) end -wrong_dist_errormsg(l) = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line $(l)." +function wrong_dist_errormsg(l) + return "Right-hand side of a ~ must be subtype of Distribution or a vector of" * + "Distributions on line $(l)." +end """ - generate_observe(observation, distribution) + generate_observe(observation, dist, model_info) Generate an observe expression for observation `observation` drawn from -a distribution or a vector of distributions (`distribution`). +a distribution or a vector of distributions (`dist`). """ -function generate_observe(observation, distribution) - return esc( - quote - isdist = if isa($(distribution), AbstractVector) - # Check if the right-hand side is a vector of distributions. - all(d -> isa(d, Distribution), $(distribution)) - else - # Check if the right-hand side is a distribution. - isa($(distribution), Distribution) - end - @assert isdist @error($(wrong_dist_errormsg(@__LINE__))) - - vi.logp += Turing.observe( - sampler, - $(distribution), - $(observation), - vi - ) +function generate_observe(observation, dist, model_info) + main_body_names = model_info[:main_body_names] + vi = main_body_names[:vi] + sampler = main_body_names[:sampler] + return quote + isdist = if isa($dist, AbstractVector) + # Check if the right-hand side is a vector of distributions. + all(d -> isa(d, Distribution), $dist) + else + # Check if the right-hand side is a distribution. + isa($dist, Distribution) end - ) + @assert isdist @error($(wrong_dist_errormsg(@__LINE__))) + $vi.logp += Turing.observe($sampler, $dist, $observation, $vi) + end end """ - generate_assume(variable, distribution, syms) - -Generate an assume expression for parameters `variable` drawn from -a distribution or a vector of distributions (`distribution`). + generate_assume(var, dist, model_info) +Generate an assume expression for parameters `var` drawn from +a distribution or a vector of distributions (`dist`). """ -function generate_assume(variable, distribution, syms) - return esc( - quote - varname = Turing.VarName(vi, $syms, "") - - isdist = if isa($(distribution), AbstractVector) - # Check if the right-hand side is a vector of distributions. - all(d -> isa(d, Distribution), $(distribution)) - else - # Check if the right-hand side is a distribution. - isa($(distribution), Distribution) - end - @assert isdist @error($(wrong_dist_errormsg(@__LINE__))) - - ($(variable), _lp) = if isa($(distribution), AbstractVector) - Turing.assume( - sampler, - $(distribution), - varname, - $(variable), - vi - ) - else - Turing.assume( - sampler, - $(distribution), - varname, - vi - ) - end - vi.logp += _lp +function generate_assume(var::Union{Symbol, Expr}, dist, model_info) + main_body_names = model_info[:main_body_names] + vi = main_body_names[:vi] + sampler = main_body_names[:sampler] + + varname = gensym(:varname) + sym, idcs, csym = gensym(:sym), gensym(:idcs), gensym(:csym) + csym_str, indexing, syms = gensym(:csym_str), gensym(:indexing), gensym(:syms) + + if var isa Symbol + varname_expr = quote + $sym, $idcs, $csym = @VarName $var + $csym = Symbol($(model_info[:name]), $csym) + $syms = Symbol[$csym, $(QuoteNode(var))] + $varname = Turing.VarName($vi, $syms, "") end - ) -end + else + varname_expr = quote + $sym, $idcs, $csym = @VarName $var + $csym_str = string($(model_info[:name]))*string($csym) + $indexing = mapfoldl(string, *, $idcs, init = "") + $varname = Turing.VarName($vi, Symbol($csym_str), $sym, $indexing) + end + end + + lp = gensym(:lp) + return quote + $varname_expr + isdist = if isa($dist, AbstractVector) + # Check if the right-hand side is a vector of distributions. + all(d -> isa(d, Distribution), $dist) + else + # Check if the right-hand side is a distribution. + isa($dist, Distribution) + end + @assert isdist @error($(wrong_dist_errormsg(@__LINE__))) -function generate_assume(variable::Expr, distribution) - return esc( - quote - sym, idcs, csym = @VarName $variable - csym_str = string(Turing._compiler_[:name])*string(csym) - indexing = mapfoldl(string, *, idcs, init = "") - varname = Turing.VarName(vi, Symbol(csym_str), sym, indexing) - - # Sanity check. - isdist = if isa($(distribution), Vector) - all(d -> isa(d, Distribution), $(distribution)) - else - isa($(distribution), Distribution) - end - @assert isdist @error($(wrong_dist_errormsg(@__LINE__))) - - $(variable), _lp = Turing.assume( - sampler, - $(distribution), - varname, - vi - ) - vi.logp += _lp + ($var, $lp) = if isa($dist, AbstractVector) + Turing.assume($sampler, $dist, $varname, $var, $vi) + else + Turing.assume($sampler, $dist, $varname, $vi) end - ) + $vi.logp += $lp + end end """ - macro: @~ var Distribution() + tilde(left, right, model_info) -Tilde notation macro. This macro constructs Turing.observe or -Turing.assume calls depending on the left-hand argument. -Note that the macro is interconnected with the @model macro and -assumes that a `compiler` struct is available. - -Example: -```julia -@~ x Normal() -``` +The `tilde` function generates observation expression for data variables and assumption expressions for parameter variables, updating `model_info` in the process. """ -macro ~(left, right) - return tilde(left, right) +function tilde(left, right, model_info) + return generate_observe(left, right, model_info) end - -function tilde(left, right) - return generate_observe(left, right) +function tilde(left::Union{Symbol, Expr}, right, model_info) + return _tilde(getvsym(left), left, right, model_info) end -function tilde(left::Symbol, right) +function _tilde(vsym, left, dist, model_info) + main_body_names = model_info[:main_body_names] + model_name = main_body_names[:model] - # Check if left-hand side is a observation. - if left in Turing._compiler_[:args] - if !(left in Turing._compiler_[:dvars]) - @debug " Observe - `$(left)` is an observation" - push!(Turing._compiler_[:dvars], left) + if vsym in model_info[:arg_syms] + if !(vsym in model_info[:tent_dvars_list]) + @debug " Observe - `$(vsym)` is an observation" + push!(model_info[:tent_dvars_list], vsym) end - return generate_observe(left, right) - else - # Assume it is a parameter. - if !(left in Turing._compiler_[:pvars]) - msg = " Assume - `$(left)` is a parameter" - if isdefined(Main, left) - msg *= " (ignoring `$(left)` found in global scope)" + return quote + if Turing.indvars($(Val(vsym)), $model_name) + $(generate_observe(left, dist, model_info)) + else + $(generate_assume(left, dist, model_info)) end - - @debug msg - push!(Turing._compiler_[:pvars], left) - end - - sym, idcs, csym = @VarName(left) - csym = Symbol(Turing._compiler_[:name], csym) - syms = Symbol[csym, left] - - return generate_assume(left, right, syms) - end -end - -function tilde(left::Expr, right) - vsym = getvsym(left) - @assert isa(vsym, Symbol) - - if vsym in Turing._compiler_[:args] - if !(vsym in Turing._compiler_[:dvars]) - @debug " Observe - `$(vsym)` is an observation" - push!(Turing._compiler_[:dvars], vsym) end - - return generate_observe(left, right) else - if !(vsym in Turing._compiler_[:pvars]) + # Assume it is a parameter. + if !(vsym in model_info[:tent_pvars_list]) msg = " Assume - `$(vsym)` is a parameter" if isdefined(Main, vsym) msg *= " (ignoring `$(vsym)` found in global scope)" end @debug msg - push!(Turing._compiler_[:pvars], vsym) + push!(model_info[:tent_pvars_list], vsym) end - return generate_assume(left, right) + return generate_assume(left, dist, model_info) end end @@ -209,206 +185,266 @@ end ################# """ - @model(name, fbody) + @model(body) Macro to specify a probabilistic model. Example: +Model definition: + ```julia -@model Gaussian(x) = begin - s ~ InverseGamma(2,3) - m ~ Normal(0,sqrt.(s)) - for i in 1:length(x) - x[i] ~ Normal(m, sqrt.(s)) - end - return (s, m) +@model model_generator(x = default_x, y) = begin + ... end ``` -Compiler design: `sample(fname(x,y), sampler)`. +Expanded model definition + ```julia -fname(x=nothing,y=nothing; compiler=compiler) = begin - ex = quote - # Pour in kwargs for those args where value != nothing. - fname_model(vi::VarInfo, sampler::Sampler; x = x, y = y) = begin - vi.logp = zero(Real) - - # Pour in model definition. - x ~ Normal(0,1) - y ~ Normal(x, 1) - return x, y +# Allows passing arguments as kwargs +model_generator(; x = nothing, y = nothing)) = model_generator(x, y) +function model_generator(x = nothing, y = nothing) + pvars, dvars = Turing.get_vars(Tuple{:x, :y}, (x = x, y = y)) + data = Turing.get_data(dvars, (x = x, y = y)) + + inner_function(sampler::Turing.AnySampler, model) = inner_function(model) + function inner_function(model) + return inner_function(Turing.VarInfo(), Turing.SampleFromPrior(), model) + end + function inner_function(vi::Turing.VarInfo, model) + return inner_function(vi, Turing.SampleFromPrior(), model) + end + # Define the main inner function + function inner_function(vi::Turing.VarInfo, sampler::Turing.AnySampler, model) + local x + if isdefined(model.data, :x) + x = model.data.x + else + x = default_x + end + local y + if isdefined(model.data, :y) + y = model.data.y + else + y = nothing end + + vi.logp = zero(Real) + ... end - return Main.eval(ex) + model = Turing.Model{pvars, dvars}(inner_function, data) + return model end ``` -""" -macro model(fexpr) - - # translate all ~ occurences to macro calls - fexpr = translate(fexpr) - # extract model name (:name), arguments (:args), (:kwargs) and definition (:body) - modeldef = MacroTools.splitdef(fexpr) +Generating a model: `model_generator(x_value)::Model`. +""" +macro model(input_expr) + build_model_info(input_expr) |> translate_tilde! |> update_args! |> build_output +end - # function body of the model is empty - if all(l -> isa(l, LineNumberNode), modeldef[:body].args) - @warn("Model definition seems empty, still continue.") - end +""" + build_model_info(input_expr) - # construct compiler dictionary - compiler = Dict( +Builds the `model_info` dictionary from the model's expression. +""" +function build_model_info(input_expr) + # Extract model name (:name), arguments (:args), (:kwargs) and definition (:body) + modeldef = MacroTools.splitdef(input_expr) + # Function body of the model is empty + warn_empty(modeldef[:body]) + # Construct model_info dictionary + + arg_syms = [(arg isa Symbol) ? arg : arg.args[1] for arg in modeldef[:args]] + model_info = Dict( :name => modeldef[:name], - :closure_name => Symbol(modeldef[:name], :_model), - :args => [], + :input_expr => input_expr, + :main_body => modeldef[:body], + :arg_syms => arg_syms, + :args => modeldef[:args], :kwargs => modeldef[:kwargs], - :dvars => Set{Symbol}(), - :pvars => Set{Symbol}() + :tent_dvars_list => Symbol[], + :tent_pvars_list => Symbol[], + :main_body_names => Dict( + :vi => gensym(:vi), + :sampler => gensym(:sampler), + :model => gensym(:model), + :pvars => gensym(:pvars), + :dvars => gensym(:dvars), + :data => gensym(:data), + :inner_function => gensym(:inner_function) + ) ) - # Manipulate the function arguments. - fargs = deepcopy(vcat(modeldef[:args], modeldef[:kwargs])) + return model_info +end + +""" + translate_tilde!(model_info) + +Translates ~ expressions to observation or assumption expressions, updating `model_info`. +""" +function translate_tilde!(model_info) + ex = model_info[:main_body] + ex = MacroTools.postwalk(x -> @capture(x, L_ ~ R_) ? tilde(L, R, model_info) : x, ex) + model_info[:main_body] = ex + return model_info +end + +""" + update_args!(model_info) + +Extracts default argument values and replaces them with `nothing`. +""" +function update_args!(model_info) + fargs = model_info[:args] + fargs_default_values = Dict() for i in 1:length(fargs) if isa(fargs[i], Symbol) + fargs_default_values[fargs[i]] = :nothing fargs[i] = Expr(:kw, fargs[i], :nothing) + elseif isa(fargs[i], Expr) && fargs[i].head == :kw + fargs_default_values[fargs[i].args[1]] = fargs[i].args[2] + fargs[i] = Expr(:kw, fargs[i].args[1], :nothing) + else + throw("Unsupported argument type $(fargs[i]).") end end + model_info[:args] = fargs + model_info[:arg_defaults] = fargs_default_values - # Construct closure. - closure = MacroTools.combinedef( - Dict( - :name => compiler[:closure_name], - :kwargs => [], - :args => [ - :(vi::Turing.VarInfo), - :(sampler::Turing.AnySampler) - ], - # Initialise logp in VarInfo. - :body => Expr(:block, :(vi.logp = zero(Real)), modeldef[:body].args...) - ) - ) - - # Construct aliases. - alias1 = MacroTools.combinedef( - Dict( - :name => compiler[:closure_name], - :args => [:(vi::Turing.VarInfo)], - :kwargs => [], - :body => :(return $(compiler[:closure_name])(vi, Turing.SampleFromPrior())) - ) - ) - - alias2 = MacroTools.combinedef( - Dict( - :name => compiler[:closure_name], - :args => [:(sampler::Turing.AnySampler)], - :kwargs => [], - :body => :(return $(compiler[:closure_name])(Turing.VarInfo(), Turing.SampleFromPrior())) - ) - ) - - alias3 = MacroTools.combinedef( - Dict( - :name => compiler[:closure_name], - :args => [], - :kwargs => [], - :body => :(return $(compiler[:closure_name])(Turing.VarInfo(), Turing.SampleFromPrior())) - ) - ) - - # Add definitions to the compiler. - compiler[:closure] = closure - compiler[:alias1] = alias1 - compiler[:alias2] = alias2 - compiler[:alias3] = alias3 - - # Construct user function. - modelfun = MacroTools.combinedef( - Dict( - :name => compiler[:name], - :kwargs => [Expr(:kw, :compiler, compiler)], - :args => fargs, - :body => Expr(:block, - quote - Turing.eval(:(_compiler_ = deepcopy($compiler))) - # Copy the expr of function definition and callbacks - closure = Turing._compiler_[:closure] - alias1 = Turing._compiler_[:alias1] - alias2 = Turing._compiler_[:alias2] - alias3 = Turing._compiler_[:alias3] - modelname = Turing._compiler_[:closure_name] - end, - # Insert argument values as kwargs to the closure - map(data_insertion, fargs)..., - # Eval the closure's methods globally and return it - quote - Main.eval(Expr(:(=), modelname, closure)) - Main.eval(alias1) - Main.eval(alias2) - Main.eval(alias3) - return $(compiler[:closure_name]) - end, - ) - ) - ) - - return esc(modelfun) + return model_info end +""" + build_output(model_info) -#################### -# Helper functions # -#################### +Builds the output expression. +""" +function build_output(model_info) + # Construct user-facing function + main_body_names = model_info[:main_body_names] + vi_name = main_body_names[:vi] + model_name = main_body_names[:model] + sampler_name = main_body_names[:sampler] + data_name = main_body_names[:data] + pvars_name = main_body_names[:pvars] + dvars_name = main_body_names[:dvars] + inner_function_name = main_body_names[:inner_function] + + args = model_info[:args] + arg_syms = model_info[:arg_syms] + outer_function_name = model_info[:name] + tent_pvars_list = model_info[:tent_pvars_list] + tent_dvars_list = model_info[:tent_dvars_list] + main_body = model_info[:main_body] + arg_defaults = model_info[:arg_defaults] + + if length(tent_dvars_list) == 0 + tent_dvars_nt = :(NamedTuple()) + else + tent_dvars_nt = :($([:($var = $var) for var in tent_dvars_list]...),) + end -function data_insertion(k) - if isa(k, Symbol) - _k = k - elseif k.head == :kw - _k = k.args[1] + #= Does the following for each of the tentative dvars + local x + if isdefined(model.data, :x) + x = model.data.x else - return :() + x = default_x end + =# + unwrap_data_expr = Expr(:block) + for var in tent_dvars_list + push!(unwrap_data_expr.args, quote + local $var + if isdefined($model_name.data, $(QuoteNode(var))) + $var = $model_name.data.$var + else + $var = $(arg_defaults[var]) + end + end) + end + + return esc(quote + # Allows passing arguments as kwargs + $outer_function_name(;$(args...)) = $outer_function_name($(arg_syms...)) + # Outer function with `nothing` as default values + function $outer_function_name($(args...)) + # Adds variables equal to `nothing` to pvars and the rest to dvars + # `tent_pvars_list` is the tentative list of pvars + # `tent_dvars_nt` is the tentative named tuple of dvars + $pvars_name, $dvars_name = Turing.get_vars($(Tuple{tent_pvars_list...}), $(tent_dvars_nt)) + # Filter out the dvars equal to `nothing` + $data_name = Turing.get_data($dvars_name, $tent_dvars_nt) + + # Define fallback inner functions + function $inner_function_name($sampler_name::Turing.AnySampler, $model_name) + return $inner_function_name($model_name) + end + function $inner_function_name($model_name) + return $inner_function_name(Turing.VarInfo(), Turing.SampleFromPrior(), $model_name) + end + function $inner_function_name($vi_name::Turing.VarInfo, $model_name) + return $inner_function_name($vi_name, Turing.SampleFromPrior(), $model_name) + end - return quote - if $_k == nothing - # Notify the user if an argument is missing. - @warn("Data `"*$(string(_k))*"` not provided, treating as parameter instead.") - else - if $(QuoteNode(_k)) ∉ Turing._compiler_[:args] - push!(Turing._compiler_[:args], $(QuoteNode(_k))) - end - closure = Turing.setkwargs(closure, $(QuoteNode(_k)), $_k) - end + # Define the main inner function + function $inner_function_name( + $vi_name::Turing.VarInfo, + $sampler_name::Turing.AnySampler, + $model_name + ) + + $unwrap_data_expr + $vi_name.logp = zero(Real) + $main_body end + $model_name = Turing.Model{$pvars_name, $dvars_name}($inner_function_name, $data_name) + return $model_name + end + end) end -function setkwargs(fexpr::Expr, kw::Symbol, value) - # Split up the function definition. - funcdef = MacroTools.splitdef(fexpr) +@generated function get_vars(tent_pvars::Type{Tpvars}, tent_dvars_nt::NamedTuple) where {Tpvars <: Tuple} + tent_pvar_syms = [Tpvars.types...] + tent_dvar_syms = [tent_dvars_nt.names...] + dvar_types = [tent_dvars_nt.types...] + append!(tent_pvar_syms, [tent_dvar_syms[i] for i in 1:length(tent_dvar_syms) if dvar_types[i] == Nothing]) + setdiff!(tent_dvar_syms, tent_pvar_syms) + pvars_tuple = Tuple{tent_pvar_syms...} + dvars_tuple = Tuple{tent_dvar_syms...} - # Add the new keyword argument. - push!(funcdef[:kwargs], Expr(:kw, kw, value)) + return :($pvars_tuple, $dvars_tuple) +end - # Recompose the function. - return MacroTools.combinedef(funcdef) +@generated function get_data(::Type{Tdvars}, nt) where Tdvars + dvars = Tdvars.types + args = [] + for var in dvars + push!(args, :($var = nt.$var)) + end + if length(args) == 0 + return :(NamedTuple()) + else + return :($(args...),) + end +end + +function warn_empty(body) + if all(l -> isa(l, LineNumberNode), body.args) + @warn("Model definition seems empty, still continue.") + end + return end +#################### +# Helper functions # +#################### + getvsym(s::Symbol) = s function getvsym(expr::Expr) @assert expr.head == :ref "expr needs to be an indexing expression, e.g. :(x[1])" return getvsym(expr.args[1]) end - -translate!(ex::Any) = ex -function translate!(ex::Expr) - if ex.head === :call && ex.args[1] === :(~) - ex.head = :macrocall - ex.args[1] = Symbol("@~") - insert!(ex.args, 2, LineNumberNode(@__LINE__)) - else - map(translate!, ex.args) - end - return ex -end -translate(ex::Expr) = translate!(deepcopy(ex)) diff --git a/src/core/container.jl b/src/core/container.jl index d1a907a95..1905d9ab9 100644 --- a/src/core/container.jl +++ b/src/core/container.jl @@ -6,8 +6,8 @@ Data structure for particle filters - normalise!(pc::ParticleContainer) - consume(pc::ParticleContainer): return incremental likelihood """ -mutable struct ParticleContainer{T<:Particle} - model :: Function +mutable struct ParticleContainer{T<:Particle, F} + model :: F num_particles :: Int vals :: Array{T} logWs :: Array{Float64} # Log weights (Trace) or incremental likelihoods (ParticleContainer) @@ -17,8 +17,8 @@ mutable struct ParticleContainer{T<:Particle} n_consume :: Int # helpful for rejuvenation steps, e.g. in SMC2 end ParticleContainer{T}(m) where T = ParticleContainer{T}(m, 0) -function ParticleContainer{T}(m::Function,n::Int) where T - ParticleContainer{T}(m, n, Vector{T}(), Vector{Float64}(), 0.0, nothing, 0) +function ParticleContainer{T}(m::F, n::Int) where {T, F} + ParticleContainer{T, F}(m, n, Vector{T}(), Vector{Float64}(), 0.0, nothing, 0) end Base.collect(pc :: ParticleContainer) = pc.vals # prev: Dict, now: Array diff --git a/src/core/trace.jl b/src/core/trace.jl index ef1ea0b68..412fd656f 100644 --- a/src/core/trace.jl +++ b/src/core/trace.jl @@ -8,7 +8,7 @@ mutable struct Trace end # NOTE: this function is called by `forkr` -function Trace(f::Function) +function Trace(f) res = Trace(); # Task(()->f()); res.task = Task( () -> begin res=f(); produce(Val{:done}); res; end ) @@ -19,7 +19,7 @@ function Trace(f::Function) res end -function Trace(f::Function, spl::Sampler, vi :: VarInfo) +function Trace(f, spl::Sampler, vi :: VarInfo) res = Trace(); res.spl = spl # Task(()->f()); diff --git a/src/samplers/dynamichmc.jl b/src/samplers/dynamichmc.jl index aa2713777..1b5e59027 100644 --- a/src/samplers/dynamichmc.jl +++ b/src/samplers/dynamichmc.jl @@ -36,7 +36,11 @@ function Sampler(alg::DynamicNUTS{T}) where T <: Hamiltonian return Sampler(alg, Dict{Symbol,Any}()) end -function sample(model::Function, alg::DynamicNUTS, chunk_size=CHUNKSIZE[]) where T <: Hamiltonian +function sample(model::Model, + alg::DynamicNUTS, + chunk_size=CHUNKSIZE[] + ) where T <: Hamiltonian + if ADBACKEND[] == :forward_diff default_chunk_size = CHUNKSIZE[] # record global chunk size setchunksize(chunk_size) # set temp chunk size @@ -52,7 +56,7 @@ function sample(model::Function, alg::DynamicNUTS, chunk_size=CHUNKSIZE[]) where end vi = VarInfo() - Base.invokelatest(model, vi, HamiltonianRobustInit()) + model(vi, HamiltonianRobustInit()) if spl.alg.gid == 0 link!(vi, spl) diff --git a/src/samplers/gibbs.jl b/src/samplers/gibbs.jl index 1e28326e8..8aae200e5 100644 --- a/src/samplers/gibbs.jl +++ b/src/samplers/gibbs.jl @@ -35,7 +35,7 @@ Gibbs(alg::Gibbs, new_gid) = Gibbs(alg.n_iters, alg.algs, alg.thin, new_gid) const GibbsComponent = Union{Hamiltonian,MH,PG} -function Sampler(alg::Gibbs) +function Sampler(alg::Gibbs, model::Model) n_samplers = length(alg.algs) samplers = Array{Sampler}(undef, n_samplers) @@ -44,7 +44,7 @@ function Sampler(alg::Gibbs) for i in 1:n_samplers sub_alg = alg.algs[i] if isa(sub_alg, GibbsComponent) - samplers[i] = Sampler(typeof(sub_alg)(sub_alg, i)) + samplers[i] = Sampler(typeof(sub_alg)(sub_alg, i), model) else @error("[Gibbs] unsupport base sampling algorithm $alg") end @@ -52,10 +52,10 @@ function Sampler(alg::Gibbs) end # Sanity check for space - @assert issubset(Turing._compiler_[:pvars], space) "[Gibbs] symbols specified to samplers ($space) doesn't cover the model parameters ($(Turing._compiler_[:pvars]))" + @assert issubset(Set(pvars(model)), space) "[Gibbs] symbols specified to samplers ($space) doesn't cover the model parameters ($(Set(pvars(model))))" - if Turing._compiler_[:pvars] != space - @warn("[Gibbs] extra parameters specified by samplers don't exist in model: $(setdiff(space, Turing._compiler_[:pvars]))") + if Set(pvars(model)) != space + @warn("[Gibbs] extra parameters specified by samplers don't exist in model: $(setdiff(space, Set(pvars(model))))") end info = Dict{Symbol, Any}() @@ -65,7 +65,7 @@ function Sampler(alg::Gibbs) end function sample( - model::Function, + model::Model, alg::Gibbs; save_state=false, # flag for state saving resume_from=nothing, # chain to continue @@ -73,7 +73,7 @@ function sample( ) # Init the (master) Gibbs sampler - spl = reuse_spl_n > 0 ? resume_from.info[:spl] : Sampler(alg) + spl = reuse_spl_n > 0 ? resume_from.info[:spl] : Sampler(alg, model) @assert typeof(spl.alg) == typeof(alg) "[Turing] alg type mismatch; please use resume() to re-use spl" @@ -102,7 +102,7 @@ function sample( # Init parameters varInfo = if resume_from == nothing vi_ = VarInfo() - Base.invokelatest(model, vi_, HamiltonianRobustInit()) + model(vi_, HamiltonianRobustInit()) vi_ else resume_from.info[:vi] diff --git a/src/samplers/hmc.jl b/src/samplers/hmc.jl index 196906612..0349859dd 100644 --- a/src/samplers/hmc.jl +++ b/src/samplers/hmc.jl @@ -100,7 +100,7 @@ Sampler(alg::Hamiltonian, adapt_conf::DEFAULT_ADAPT_CONF_TYPE) = begin Sampler(alg, info) end -function sample(model::Function, alg::Hamiltonian; +function sample(model::Model, alg::Hamiltonian; chunk_size=CHUNKSIZE[], # set temporary chunk size save_state=false, # flag for state saving resume_from=nothing, # chain to continue @@ -137,7 +137,7 @@ function sample(model::Function, alg::Hamiltonian; vi = if resume_from == nothing vi_ = VarInfo() - Base.invokelatest(model, vi_, HamiltonianRobustInit()) + model(vi_, HamiltonianRobustInit()) spl.info[:eval_num] += 1 vi_ else diff --git a/src/samplers/ipmcmc.jl b/src/samplers/ipmcmc.jl index a72e883f0..ad4602abb 100644 --- a/src/samplers/ipmcmc.jl +++ b/src/samplers/ipmcmc.jl @@ -76,7 +76,7 @@ function Sampler(alg::IPMCMC) Sampler(alg, info) end -step(model::Function, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first::Bool) = begin +step(model, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first::Bool) = begin # Initialise array for marginal likelihood estimators log_zs = zeros(spl.alg.n_nodes) @@ -106,7 +106,7 @@ step(model::Function, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first VarInfos[nodes_permutation] end -sample(model::Function, alg::IPMCMC) = begin +function sample(model::Model, alg::IPMCMC) spl = Sampler(alg) diff --git a/src/samplers/is.jl b/src/samplers/is.jl index 68c3af9fd..093cdbbb4 100644 --- a/src/samplers/is.jl +++ b/src/samplers/is.jl @@ -40,7 +40,7 @@ Sampler(alg::IS) = begin Sampler(alg, info) end -sample(model::Function, alg::IS) = begin +function sample(model::Model, alg::IS) spl = Sampler(alg); samples = Array{Sample}(undef, alg.n_particles) diff --git a/src/samplers/mh.jl b/src/samplers/mh.jl index 287927644..6a463eda4 100644 --- a/src/samplers/mh.jl +++ b/src/samplers/mh.jl @@ -49,14 +49,14 @@ function MH(n_iters::Int, space...) end MH{T}(alg::MH, new_gid::Int) where T = MH{T}(alg.n_iters, alg.proposals, alg.space, new_gid) -Sampler(alg::MH) = begin +Sampler(alg::MH, model::Model) = begin alg_str = "MH" # Sanity check for space if alg.gid == 0 && !isempty(alg.space) - @assert issubset(Turing._compiler_[:pvars], alg.space) "[$alg_str] symbols specified to samplers ($alg.space) doesn't cover the model parameters ($(Turing._compiler_[:pvars]))" - if Turing._compiler_[:pvars] != alg.space - warn("[$alg_str] extra parameters specified by samplers don't exist in model: $(setdiff(alg.space, Turing._compiler_[:pvars]))") + @assert issubset(Set(pvars(model)), alg.space) "[$alg_str] symbols specified to samplers ($alg.space) doesn't cover the model parameters ($(Set(pvars(model))))" + if Set(pvars(model)) != alg.space + warn("[$alg_str] extra parameters specified by samplers don't exist in model: $(setdiff(alg.space, Set(pvars(model))))") end end @@ -68,18 +68,18 @@ Sampler(alg::MH) = begin Sampler(alg, info) end -propose(model::Function, spl::Sampler{<:MH}, vi::VarInfo) = begin +propose(model, spl::Sampler{<:MH}, vi::VarInfo) = begin spl.info[:proposal_ratio] = 0.0 spl.info[:prior_prob] = 0.0 spl.info[:violating_support] = false runmodel!(model, vi ,spl) end -function step(model::Function, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{true}) +function step(model, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{true}) return vi, true end -function step(model::Function, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{false}) +function step(model, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{false}) if spl.alg.gid != 0 # Recompute joint in logp runmodel!(model, vi, nothing) end @@ -104,7 +104,7 @@ function step(model::Function, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{fa return vi, is_accept end -function sample(model::Function, alg::MH; +function sample(model::Model, alg::MH; save_state=false, # flag for state saving resume_from=nothing, # chain to continue reuse_spl_n=0, # flag for spl re-using @@ -112,7 +112,7 @@ function sample(model::Function, alg::MH; spl = reuse_spl_n > 0 ? resume_from.info[:spl] : - Sampler(alg) + Sampler(alg, model) alg_str = "MH" # Initialization @@ -128,7 +128,7 @@ function sample(model::Function, alg::MH; vi = if resume_from == nothing vi_ = VarInfo() - Base.invokelatest(model, vi_, HamiltonianRobustInit()) + model(vi_, HamiltonianRobustInit()) vi_ else resume_from.info[:vi] diff --git a/src/samplers/pgibbs.jl b/src/samplers/pgibbs.jl index 7dc96672e..fbe9f29dd 100644 --- a/src/samplers/pgibbs.jl +++ b/src/samplers/pgibbs.jl @@ -50,9 +50,9 @@ Sampler(alg::PG) = begin Sampler(alg, info) end -step(model::Function, spl::Sampler{<:PG}, vi::VarInfo, _) = step(model, spl, vi) +step(model, spl::Sampler{<:PG}, vi::VarInfo, _) = step(model, spl, vi) -step(model::Function, spl::Sampler{<:PG}, vi::VarInfo) = begin +step(model, spl::Sampler{<:PG}, vi::VarInfo) = begin particles = ParticleContainer{Trace}(model) vi.num_produce = 0; # Reset num_produce before new sweep\. @@ -82,7 +82,7 @@ step(model::Function, spl::Sampler{<:PG}, vi::VarInfo) = begin return particles[indx].vi, true end -sample(model::Function, alg::PG; +sample(model::Model, alg::PG; save_state=false, # flag for state saving resume_from=nothing, # chain to continue reuse_spl_n=0 # flag for spl re-using diff --git a/src/samplers/pmmh.jl b/src/samplers/pmmh.jl index 591aeea78..2a63f3ae4 100644 --- a/src/samplers/pmmh.jl +++ b/src/samplers/pmmh.jl @@ -34,7 +34,7 @@ PMMH(alg::PMMH, new_gid) = PMMH(alg.n_iters, alg.algs, alg.space, new_gid) PIMH(n_iters::Int, smc_alg::SMC) = PMMH(n_iters, tuple(smc_alg), Set(), 0) -function Sampler(alg::PMMH) +function Sampler(alg::PMMH, model::Model) alg_str = "PMMH" n_samplers = length(alg.algs) samplers = Array{Sampler}(undef, n_samplers) @@ -44,7 +44,7 @@ function Sampler(alg::PMMH) for i in 1:n_samplers sub_alg = alg.algs[i] if isa(sub_alg, Union{SMC, MH}) - samplers[i] = Sampler(typeof(sub_alg)(sub_alg, i)) + samplers[i] = Sampler(typeof(sub_alg)(sub_alg, i), model) else error("[$alg_str] unsupport base sampling algorithm $alg") end @@ -56,10 +56,10 @@ function Sampler(alg::PMMH) # Sanity check for space if !isempty(space) - @assert issubset(Turing._compiler_[:pvars], space) "[$alg_str] symbols specified to samplers ($space) doesn't cover the model parameters ($(Turing._compiler_[:pvars]))" + @assert issubset(Set(pvars(model)), space) "[$alg_str] symbols specified to samplers ($space) doesn't cover the model parameters ($(Set(pvars(model))))" - if Turing._compiler_[:pvars] != space - warn("[$alg_str] extra parameters specified by samplers don't exist in model: $(setdiff(space, Turing._compiler_[:pvars]))") + if Set(pvars(model)) != space + warn("[$alg_str] extra parameters specified by samplers don't exist in model: $(setdiff(space, Set(pvars(model))))") end end @@ -71,7 +71,7 @@ function Sampler(alg::PMMH) Sampler(alg, info) end -step(model::Function, spl::Sampler{<:PMMH}, vi::VarInfo, is_first::Bool) = begin +step(model, spl::Sampler{<:PMMH}, vi::VarInfo, is_first::Bool) = begin violating_support = false proposal_ratio = 0.0 new_prior_prob = 0.0 @@ -113,7 +113,7 @@ step(model::Function, spl::Sampler{<:PMMH}, vi::VarInfo, is_first::Bool) = begin return vi, is_accept end -sample(model::Function, alg::PMMH; +sample(model::Model, alg::PMMH; save_state=false, # flag for state saving resume_from=nothing, # chain to continue reuse_spl_n=0 # flag for spl re-using @@ -136,7 +136,7 @@ sample(model::Function, alg::PMMH; # Init parameters vi = if resume_from == nothing vi_ = VarInfo() - Base.invokelatest(model, vi_, HamiltonianRobustInit()) + model(vi_, HamiltonianRobustInit()) vi_ else resume_from.info[:vi] diff --git a/src/samplers/smc.jl b/src/samplers/smc.jl index 3ae69e3f0..a098de140 100644 --- a/src/samplers/smc.jl +++ b/src/samplers/smc.jl @@ -47,7 +47,7 @@ Sampler(alg::SMC) = begin Sampler(alg, info) end -step(model::Function, spl::Sampler{<:SMC}, vi::VarInfo) = begin +step(model, spl::Sampler{<:SMC}, vi::VarInfo) = begin particles = ParticleContainer{Trace}(model) vi.num_produce = 0; # Reset num_produce before new sweep\. set_retained_vns_del_by_spl!(vi, spl) @@ -71,7 +71,7 @@ step(model::Function, spl::Sampler{<:SMC}, vi::VarInfo) = begin end ## wrapper for smc: run the sampler, collect results. -function sample(model::Function, alg::SMC) +function sample(model::Model, alg::SMC) spl = Sampler(alg); particles = ParticleContainer{Trace}(model) diff --git a/src/samplers/support/hmc_core.jl b/src/samplers/support/hmc_core.jl index 08f0da6cb..ec7d5e670 100644 --- a/src/samplers/support/hmc_core.jl +++ b/src/samplers/support/hmc_core.jl @@ -122,12 +122,12 @@ end ### -function runmodel!(model::Function, vi::VarInfo, spl::Union{Nothing,Sampler}) +function runmodel!(model, vi::VarInfo, spl::Union{Nothing,Sampler}) setlogp!(vi, zero(Real)) if spl != nothing && :eval_num ∈ keys(spl.info) spl.info[:eval_num] += 1 end - Base.invokelatest(model, vi, spl) + model(vi, spl) return vi end @@ -135,7 +135,7 @@ function leapfrog(θ::AbstractVector{<:Real}, p::AbstractVector{<:Real}, τ::Int, ϵ::Real, - model::Function, + model, vi::VarInfo, sampler::Sampler, ) @@ -238,7 +238,7 @@ end # TODO: remove used Turing-wrapper functions # Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/base_hmc.hpp -function find_good_eps(model::Function, spl::Sampler{T}, vi::VarInfo) where T +function find_good_eps(model, spl::Sampler{T}, vi::VarInfo) where T logpdf_func_float = gen_lj_func(vi, spl, model) momentum_sampler = gen_momentum_sampler(vi, spl) H_func = gen_H_func() diff --git a/src/utilities/io.jl b/src/utilities/io.jl index bf3322a7c..5e4b50b6a 100644 --- a/src/utilities/io.jl +++ b/src/utilities/io.jl @@ -200,7 +200,7 @@ function Base.vcat(c1::Chain, args::Chain...) Chain(0, value2) end -save!(c::Chain, spl::Sampler, model::Function, vi) = begin +save!(c::Chain, spl::Sampler, model, vi) = begin c.info[:spl] = spl c.info[:model] = model c.info[:vi] = deepcopy(vi) diff --git a/test/compiler.jl/explicit_ret.jl b/test/compiler.jl/explicit_ret.jl index 5f08c6e4f..877a928f7 100644 --- a/test/compiler.jl/explicit_ret.jl +++ b/test/compiler.jl/explicit_ret.jl @@ -14,7 +14,6 @@ mf = test_ex_rt() for alg = [HMC(2000, 0.2, 3), PG(20, 2000), SMC(10000), IS(10000), Gibbs(2000, PG(20, 1, :x), HMC(1, 0.2, 3, :y))] println("[explicit_ret.jl] testing $alg") chn = sample(mf, alg) - @test mean(chn[:x]) ≈ 9.0 atol=0.2 + @test mean(chn[:x]) ≈ 10.0 atol=0.2 @test mean(chn[:y]) ≈ 5.0 atol=0.2 - @test mean(chn[:z]) ≈ 6.0 atol=0.2 end diff --git a/test/compiler.jl/model_macro.jl b/test/compiler.jl/model_macro.jl index 2ff7e51b4..9d272197b 100644 --- a/test/compiler.jl/model_macro.jl +++ b/test/compiler.jl/model_macro.jl @@ -1,11 +1,11 @@ using Turing, Distributions, Test using MacroTools +model_info = Dict(:main_body_names => Dict(:vi => :vi, :sampler => :sampler)) # unit test model macro -expr = Turing.generate_observe(:x, :y) -@test expr.head == :escape -@test expr.args[1].head == :block -@test :(vi.logp += Turing.observe(sampler, y, x, vi)) in expr.args[1].args +expr = Turing.generate_observe(:x, :y, model_info) +@test expr.head == :block +@test :(vi.logp += Turing.observe(sampler, y, x, vi)) in expr.args @model testmodel_comp(x, y) = begin s ~ InverseGamma(2,3) @@ -16,35 +16,7 @@ expr = Turing.generate_observe(:x, :y) return x, y end - testmodel_comp(1.0, 1.2) -c = deepcopy(Turing._compiler_) - -alias1 = Dict( - :name => :testmodel_comp_model, - :args => [:(vi::Turing.VarInfo)], - :kwargs => [], - :body => :(return testmodel_comp_model(vi, Turing.SampleFromPrior())) - ) -@test c[:alias1] == MacroTools.combinedef(alias1) - -alias2 = Dict( - :name => :testmodel_comp_model, - :args => [:(sampler::Turing.AnySampler)], - :kwargs => [], - :body => :(return testmodel_comp_model(Turing.VarInfo(), Turing.SampleFromPrior())) - ) -@test c[:alias2] == MacroTools.combinedef(alias2) - -alias3 = Dict( - :name => :testmodel_comp_model, - :args => [], - :kwargs => [], - :body => :(return testmodel_comp_model(Turing.VarInfo(), Turing.SampleFromPrior())) - ) -@test c[:alias3] == MacroTools.combinedef(alias3) -@test length(c[:closure].args[2].args[2].args) == 6 -@test mapreduce(line -> line.head == :macrocall, +, c[:closure].args[2].args[2].args) == 4 # check if drawing from the prior works @model testmodel0(x) = begin @@ -54,6 +26,15 @@ end f0_mm = testmodel0() @test mean(f0_mm() for _ in 1:1000) ≈ 0. atol=0.1 +# Test #544 +@model testmodel0(x = Vector{Float64}(undef, 2)) = begin + x[1] ~ Normal() + x[2] ~ Normal() + return x +end +f0_mm = testmodel0() +@test all(isapprox.(mean(f0_mm() for _ in 1:1000), 0., atol=0.1)) + @model testmodel01(x) = begin x ~ Bernoulli(0.5) return x diff --git a/test/compiler.jl/tilde.jl b/test/compiler.jl/tilde.jl index 8868a2a84..80bb4aeda 100644 --- a/test/compiler.jl/tilde.jl +++ b/test/compiler.jl/tilde.jl @@ -1,23 +1,20 @@ using Turing -import Turing.translate! +import Turing.translate_tilde! + +model_info = Dict(:name => "model", :main_body_names => Dict(:model => :model, :vi => :vi, :sampler => :sampler), :arg_syms => [], :tent_pvars_list => []) + +ex = :(y ~ Normal(1,1)) +model_info[:main_body] = ex +translate_tilde!(model_info) +res = model_info[:main_body] +Base.@assert res.head == :block ex = quote x = 1 y = rand() y ~ Normal(0,1) end - -res = translate!(:(y~Normal(1,1))) - -Base.@assert res.head == :macrocall -Base.@assert res.args[1] == Symbol("@~") -Base.@assert res.args[3] == :y -Base.@assert res.args[4] == :(Normal(1, 1)) - - -res2 = translate!(ex) - -Base.@assert res2.args[end].head == :macrocall -Base.@assert res2.args[end].args[1] == Symbol("@~") -Base.@assert res2.args[end].args[3] == :y -Base.@assert res2.args[end].args[4] == :(Normal(0, 1)) +model_info[:main_body] = ex +translate_tilde!(model_info) +res = model_info[:main_body] +Base.@assert res.head == :block diff --git a/test/gibbs.jl/gibbs_constructor.jl b/test/gibbs.jl/gibbs_constructor.jl index 72e4a8651..147198b4a 100644 --- a/test/gibbs.jl/gibbs_constructor.jl +++ b/test/gibbs.jl/gibbs_constructor.jl @@ -31,7 +31,7 @@ end @test length(c4[:s]) == N * (3 + 2) # Test gid of each samplers -g = Turing.Sampler(s3) +g = Turing.Sampler(s3, gdemo()) @test g.info[:samplers][1].alg.gid == 1 @test g.info[:samplers][2].alg.gid == 2 diff --git a/test/hmc_core.jl/bayes_lr.jl b/test/hmc_core.jl/bayes_lr.jl index c4e7a055b..7c9ace305 100644 --- a/test/hmc_core.jl/bayes_lr.jl +++ b/test/hmc_core.jl/bayes_lr.jl @@ -81,9 +81,9 @@ lj = lj_func(θ) chn = [] accept_num = 1 - total_num = 2000 for iter = 1:total_num + global θ, chn, lj, lj_func, grad_func, std, accept_num push!(chn, θ) θ, lj, is_accept, τ_valid, α = _hmc_step(θ, lj, lj_func, grad_func, 3, 0.005, std) accept_num += is_accept diff --git a/test/hmc_core.jl/dual_averaging.jl b/test/hmc_core.jl/dual_averaging.jl index 1b4e07963..01c1cb401 100644 --- a/test/hmc_core.jl/dual_averaging.jl +++ b/test/hmc_core.jl/dual_averaging.jl @@ -1,18 +1,12 @@ function _adapt_ϵ(logϵ, Hbar, logϵbar, da_stat, m, M_adapt, δ, μ; γ=0.05, t0=10, κ=0.75) -if m <= M_adapt - -Hbar = (1.0 - 1.0 / (m + t0)) * Hbar + (1 / (m + t0)) * (δ - da_stat) -logϵ = μ - sqrt(m) / γ * Hbar -logϵbar = m^(-κ) * logϵ + (1 - m^(-κ)) * logϵbar - -else - -logϵ = logϵbar - -end - -return logϵ, Hbar, logϵbar - + if m <= M_adapt + Hbar = (1.0 - 1.0 / (m + t0)) * Hbar + (1 / (m + t0)) * (δ - da_stat) + logϵ = μ - sqrt(m) / γ * Hbar + logϵbar = m^(-κ) * logϵ + (1 - m^(-κ)) * logϵbar + else + logϵ = logϵbar + end + return logϵ, Hbar, logϵbar end diff --git a/test/hmc_core.jl/gdemo_hmc.jl b/test/hmc_core.jl/gdemo_hmc.jl index 1a08ce983..93f866bdf 100644 --- a/test/hmc_core.jl/gdemo_hmc.jl +++ b/test/hmc_core.jl/gdemo_hmc.jl @@ -26,11 +26,9 @@ end totla_num = 5000 for iter = 1:totla_num - push!(chn[:θ], θ) θ, lj, is_accept, τ_valid, α = _hmc_step(θ, lj, lj_func, grad_func, 5, 0.05, std) accept_num += is_accept - end @show lj diff --git a/test/hmc_core.jl/gdemo_nuts.jl b/test/hmc_core.jl/gdemo_nuts.jl index a92fc48bd..ecd82de17 100644 --- a/test/hmc_core.jl/gdemo_nuts.jl +++ b/test/hmc_core.jl/gdemo_nuts.jl @@ -44,7 +44,7 @@ for test_id = 1:2 totla_num = 10000 for iter = 1:totla_num - + global logϵ, lj_func, grad_func, M_adapt, δ, μ θ, da_stat = _nuts_step(θ, exp(logϵ), lj_func, grad_func, std) if test_id == 1 logϵ, Hbar, logϵbar = _adapt_ϵ(logϵ, Hbar, logϵbar, da_stat, iter, M_adapt, δ, μ) diff --git a/test/hmc_core.jl/unit_test_helper.jl b/test/hmc_core.jl/unit_test_helper.jl index 6c934258b..44a0d81cf 100644 --- a/test/hmc_core.jl/unit_test_helper.jl +++ b/test/hmc_core.jl/unit_test_helper.jl @@ -2,7 +2,7 @@ using Test function test_grad(turing_model, grad_f; trans=Dict()) model_f = turing_model() - vi = Base.invokelatest(model_f) + vi = model_f() for i in trans vi.flags["trans"][i] = true end diff --git a/test/hmcda.jl/hmcda_geweke.jl b/test/hmcda.jl/hmcda_geweke.jl index 16f7047b0..c1a689ab9 100644 --- a/test/hmcda.jl/hmcda_geweke.jl +++ b/test/hmcda.jl/hmcda_geweke.jl @@ -43,7 +43,8 @@ x = [s[:y][1]...] s_bk = Array{Turing.Chain}(undef, N) simple_logger = Base.CoreLogging.SimpleLogger(stderr, Base.CoreLogging.Debug) -with_logger(simple_logger) do +Base.CoreLogging.with_logger(simple_logger) do + global x, bk, s_bk i = 1 while i <= N s_bk[i] = sample(gdemo_bk(x), bk); diff --git a/test/mh.jl/mh_cons.jl b/test/mh.jl/mh_cons.jl index 6d66c3022..e57b067c8 100644 --- a/test/mh.jl/mh_cons.jl +++ b/test/mh.jl/mh_cons.jl @@ -1,5 +1,7 @@ using Turing using Test +using Random +Random.seed!(0) @model gdemo() = begin s ~ InverseGamma(2,3) diff --git a/test/varinfo.jl/varinfo.jl b/test/varinfo.jl/varinfo.jl index 9d308f858..922b6f849 100644 --- a/test/varinfo.jl/varinfo.jl +++ b/test/varinfo.jl/varinfo.jl @@ -89,7 +89,7 @@ end # Test the update of group IDs g_demo_f = gdemo() -g = Turing.Sampler(Gibbs(1000, PG(10, 2, :x, :y, :z), HMC(1, 0.4, 8, :w, :u))) +g = Turing.Sampler(Gibbs(1000, PG(10, 2, :x, :y, :z), HMC(1, 0.4, 8, :w, :u)), g_demo_f) pg, hmc = g.info[:samplers]