Switch to DynamicPPL (#1042)
* use DynamicPPL

* add DynamicPPL to Project.toml

* resolve deprecation

* fix adaptor with n_adapts=0

* fix adaptor with n_adapts=0 for stan interface and add fixme

* raise compat of AdvancedHMC

Co-authored-by: Kai Xu <[email protected]>
2 people authored and yebai committed Jan 5, 2020
1 parent 4afcf2d commit 7b81df4
Showing 20 changed files with 42 additions and 3,470 deletions.
4 changes: 3 additions & 1 deletion Project.toml
@@ -9,6 +9,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -28,11 +29,12 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[compat]
AbstractMCMC = "~0.1"
AdvancedHMC = "0.2.4"
AdvancedHMC = "0.2.17"
Bijectors = "0.4.0, 0.5"
BinaryProvider = "0.5.6"
Distributions = "0.21.11"
DistributionsAD = "0.1.2"
DynamicPPL = "0.1.0"
FiniteDifferences = "0.9"
ForwardDiff = "0.10.3"
Libtask = "0.3.1"
2 changes: 1 addition & 1 deletion docs/src/using-turing/advanced.md
@@ -142,7 +142,7 @@ mf(vi, sampler, ctx, model) = begin
end

# Instantiate a Model object.
model = Turing.Model(mf, data, Turing.Core.ModelGen{()}(nothing, nothing))
model = DynamicPPL.Model(mf, data, DynamicPPL.ModelGen{()}(nothing, nothing))

# Sample the model.
chain = sample(model, HMC(0.1, 5), 1000)
153 changes: 2 additions & 151 deletions src/Turing.jl
@@ -18,169 +18,19 @@ using Tracker: Tracker

import Base: ~, ==, convert, hash, promote_rule, rand, getindex, setindex!
import MCMCChains: AbstractChains, Chains
import DynamicPPL: getspace, runmodel!

const PROGRESS = Ref(true)
function turnprogress(switch::Bool)
@info("[Turing]: global PROGRESS is set as $switch")
PROGRESS[] = switch
end
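A minimal usage sketch (illustrative, not part of this commit's diff): the global flag can be flipped at any point before sampling.

turnprogress(false)   # logs the change and sets PROGRESS[] to false
turnprogress(true)    # re-enables progress reporting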

# Constants for caching
const CACHERESET = 0b00
const CACHEIDCS = 0b10
const CACHERANGES = 0b01

const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_TURING", "0")))

# Random probability measures.
include("stdlib/distributions.jl")
include("stdlib/RandomMeasures.jl")

"""
struct Model{F, Targs <: NamedTuple, Tmodelgen, Tmissings <: Val}
f::F
args::Targs
modelgen::Tmodelgen
missings::Tmissings
end
A `Model` struct with arguments `args`, inner function `f`, model generator `modelgen` and
missing data `missings`. `missings` is a `Val` instance, e.g. `Val{(:a, :b)}()`. An
argument in `args` with a value `missing` will be in `missings` by default. However, in
non-traditional use-cases `missings` can be defined differently. All variables in
`missings` are treated as random variables rather than observations.
"""
struct Model{F, Targs <: NamedTuple, Tmodelgen, Tmissings <: Val} <: AbstractModel
f::F
args::Targs
modelgen::Tmodelgen
missings::Tmissings
end
Model(f, args::NamedTuple, modelgen) = Model(f, args, modelgen, getmissing(args))
(model::Model)(vi) = model(vi, SampleFromPrior())
(model::Model)(vi, spl) = model(vi, spl, DefaultContext())
(model::Model)(args...; kwargs...) = model.f(args..., model; kwargs...)

getmissing(model::Model) = model.missings
@generated function getmissing(args::NamedTuple{names, ttuple}) where {names, ttuple}
length(names) == 0 && return :(Val{()}())
minds = filter(1:length(names)) do i
ttuple.types[i] == Missing
end
mnames = names[minds]
return :(Val{$mnames}())
end
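An illustrative sketch (assumed usage inside the Turing module, not part of the diff): `getmissing` collects the names of `missing`-valued arguments into a compile-time `Val` tuple, and those arguments are then treated as random variables.

args = (x = [1.0, 2.0], y = missing)
getmissing(args)            # Val{(:y,)}() -- :y is sampled rather than observed
getmissing((x = 1, y = 2))  # Val{()}()    -- no missing arguments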

function runmodel! end
function getspace end

struct Selector
gid :: UInt64
tag :: Symbol # :default, :invalid, :Gibbs, :HMC, etc.
rerun :: Bool
end
function Selector(tag::Symbol = :default, rerun = tag != :default)
return Selector(time_ns(), tag, rerun)
end
function Selector(gid::Integer, tag::Symbol = :default)
return Selector(gid, tag, tag != :default)
end
hash(s::Selector) = hash(s.gid)
==(s1::Selector, s2::Selector) = s1.gid == s2.gid
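For illustration (a sketch, not taken from the diff): hashing and equality of `Selector`s depend only on `gid`, so two selectors built from the same integer id compare equal even if their tags differ.

s1 = Selector(1, :HMC)    # rerun = true, since the tag is not :default
s2 = Selector(1, :Gibbs)
s1 == s2                  # true: comparison is by gid only
Selector()                # fresh gid drawn from time_ns(), tag :default, rerun = false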

"""
Robust initialization method for model parameters in Hamiltonian samplers.
"""
struct SampleFromUniform <: AbstractSampler end
struct SampleFromPrior <: AbstractSampler end

getspace(::Union{SampleFromPrior, SampleFromUniform}) = ()

"""
An abstract type that mutable sampler state structs inherit from.
"""
abstract type AbstractSamplerState end

"""
Sampler{T}
Generic interface for implementing inference algorithms.
An implementation of an algorithm should include the following:
1. A type specifying the algorithm and its parameters, derived from InferenceAlgorithm
2. A method of the `sample` function that produces the results of inference; this is where the actual inference happens.
Turing translates models into chunks that call the modelling functions at specified points.
The dispatch is based on the value of a `sampler` variable.
To add a new inference algorithm, implement the requirements above in a separate file,
then include that file at the end of this one.
"""
mutable struct Sampler{T, S<:AbstractSamplerState} <: AbstractSampler
alg :: T
info :: Dict{Symbol, Any} # sampler information
selector :: Selector
state :: S
end
Sampler(alg) = Sampler(alg, Selector())
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
Sampler(alg, model::Model, s::Selector) = Sampler(alg, model, s)
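As a hypothetical sketch of point 1 above (the type name and fields are invented for illustration and are not part of this commit), a new algorithm is just a parameter struct derived from Turing's abstract `InferenceAlgorithm` type:

struct MyMetropolis <: InferenceAlgorithm
    n_iters::Int          # number of iterations to run
    proposal_std::Float64 # standard deviation of the random-walk proposal
end

Point 2 is then covered by a matching `sample` method, with any per-run bookkeeping kept in the `state` field of the `Sampler` wrapper above.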

abstract type AbstractContext end

"""
struct DefaultContext <: AbstractContext end
The `DefaultContext` is used by default to compute the log joint probability of the data
and parameters when running the model.
"""
struct DefaultContext <: AbstractContext end

"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
The `PriorContext` enables the computation of the log prior of the parameters `vars` when
running the model.
"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext() = PriorContext(nothing)

"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
The `LikelihoodContext` enables the computation of the log likelihood of the data when
running the model. `vars` can be used to evaluate the log likelihood for specific values
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext() = LikelihoodContext(nothing)
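A construction sketch (assumed usage, not part of the diff): the three contexts above select which log-density terms are accumulated when the model runs.

DefaultContext()                       # log joint: prior + likelihood
PriorContext()                         # log prior of all parameters
LikelihoodContext((m = 0.0, s = 1.0))  # log likelihood at the given values (passing a NamedTuple is an assumption)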

"""
struct MiniBatchContext{Tctx, T} <: AbstractContext
ctx::Tctx
loglike_scalar::T
end
The `MiniBatchContext` enables the computation of
`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the
`loglike_scalar` field, typically equal to `the number of data points / batch size`.
This is useful in batch-based stochastic gradient descent algorithms, where the optimized
objective equals `log(prior) + log(likelihood of all the data points)` in expectation.
"""
struct MiniBatchContext{Tctx, T} <: AbstractContext
ctx::Tctx
loglike_scalar::T
end
function MiniBatchContext(ctx = DefaultContext(); batch_size, npoints)
return MiniBatchContext(ctx, npoints/batch_size)
end
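A short sketch (not part of the diff): for 10,000 data points processed in batches of 100, each batch likelihood is scaled by 100 so that the objective matches the full-data log joint in expectation.

ctx = MiniBatchContext(DefaultContext(); batch_size = 100, npoints = 10_000)
ctx.loglike_scalar   # 100.0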
include("utilities/Utilities.jl")
using .Utilities
include("core/Core.jl")
@@ -217,6 +67,7 @@ export @model, # modelling
@varinfo,
@logpdf,
@sampler,
DynamicPPL,

MH, # classic sampling
ESS,
9 changes: 3 additions & 6 deletions src/core/Core.jl
@@ -4,21 +4,18 @@ using MacroTools, Libtask, ForwardDiff, Random
using Distributions, LinearAlgebra
using ..Utilities, Reexport
using Tracker: Tracker
using ..Turing: Turing, Model, runmodel!,
using ..Turing: Turing
using DynamicPPL: Model, runmodel!,
AbstractSampler, Sampler, SampleFromPrior
using LinearAlgebra: copytri!
using Bijectors: PDMatDistribution
import Bijectors: link, invlink
using DistributionsAD
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL

include("RandomVariables.jl")
@reexport using .RandomVariables

include("compiler.jl")
include("container.jl")
include("ad.jl")
include("prob_macro.jl")

export @model,
@varname,