Compiler refactor 2.0 (#613)
* compiler refactor proof of concept

* rename compiler to model_info

* improve compiler readability and make it more compact

* invokelatest -> calling model

* respond to Martin's comments

* fix ambiguity

* fix missing data case and make more compact

* fix broken compiler.jl/explicit_ret.jl test

* fix sampling from prior

* make compiler more compact

* fix update_vars!

* fix tests

* fix mh.jl tests

* fix Sampler API

* uncomment commented tests

* fix seed in test/mh.jl/mh_cons.jl

* make pvars and dvars type parameters in CallableModel

* make data a field in CallableModel

* fix #544 - allows default value of data vars when treated as params

* add support for passing data as kwargs to outer function

* model::Function -> model

* CallableModel -> Model (see the sketch below)

* compiler hygiene and organization

* docstrings and some renaming

* respond to Will's comments

* some more style issues

* add test for #544

* comment

* model -> model::Model
mohamed82008 authored and yebai committed Dec 28, 2018
1 parent bf3494a commit 0e439dd
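
Taken together, these messages describe replacing the closure that the `@model` macro used to emit with a callable struct, so that samplers can dispatch on `model::Model` instead of `model::Function`. A minimal sketch of that pattern, with the field and type-parameter layout assumed for illustration (only `Model`, `pvars`, `dvars`, and the `data` field are named in the messages above):

    # Sketch only: the exact definition lives in the compiler, and the
    # field/parameter layout here is an assumption based on the commit messages.
    struct Model{pvars, dvars, F, TData}
        f::F          # model body generated by the @model macro
        data::TData   # observed data, now stored on the model itself
    end

    # The struct is callable, so code that invoked the old closure can
    # call a Model the same way.
    (model::Model)(args...; kwargs...) = model.f(args...; kwargs...)

Carrying `pvars` (parameter variables) and `dvars` (data variables) as type parameters lets dispatch tell them apart at compile time, which is what the #544 fix builds on: a data variable that is not supplied at call time can fall back to its default value or be treated as a parameter and sampled from its prior.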
Showing 28 changed files with 434 additions and 421 deletions.
1 change: 1 addition & 0 deletions src/Turing.jl
@@ -112,6 +112,7 @@ mutable struct Sampler{T<:InferenceAlgorithm} <: AbstractSampler
     alg  :: T
     info :: Dict{Symbol, Any} # sampler information
 end
+Sampler(alg, model) = Sampler(alg)
 
 # mutable struct HMCState{T<:Real}
 #   epsilon :: T
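
The added fallback constructor lets call sites construct a sampler uniformly, passing the model even when the algorithm does not use it. A hypothetical call site:

    # Dispatches to Sampler(alg) for any algorithm without a
    # model-specific Sampler method (illustrative only):
    spl = Sampler(alg, model)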
12 changes: 6 additions & 6 deletions src/core/ad.jl
@@ -30,7 +30,7 @@ getADtype(::Type{<:Hamiltonian{AD}}) where {AD} = AD
 gradient(
     θ::AbstractVector{<:Real},
     vi::VarInfo,
-    model::Function,
+    model::Model,
     sampler::Union{Nothing, Sampler}=nothing,
 )
@@ -40,7 +40,7 @@ Computes the gradient of the log joint of `θ` for the model specified by
 @generated function gradient(
     θ::AbstractVector{<:Real},
     vi::VarInfo,
-    model::Function,
+    model::Model,
     sampler::TS=nothing,
 ) where {TS <: Union{Nothing, Sampler}}
     if TS == Nothing
@@ -68,7 +68,7 @@ end
 gradient_forward(
     θ::AbstractVector{<:Real},
     vi::VarInfo,
-    model::Function,
+    model::Model,
     spl::Union{Nothing, Sampler}=nothing,
     chunk_size::Int=CHUNKSIZE[],
 )
@@ -79,7 +79,7 @@ using forwards-mode AD from ForwardDiff.jl.
 function gradient_forward(
     θ::AbstractVector{<:Real},
     vi::VarInfo,
-    model::Function,
+    model::Model,
     sampler::Union{Nothing, Sampler}=nothing,
     ::Val{chunk_size}=Val(CHUNKSIZE[]),
 ) where chunk_size
@@ -111,7 +111,7 @@ end
 gradient_reverse(
     θ::AbstractVector{<:Real},
     vi::VarInfo,
-    model::Function,
+    model::Model,
     sampler::Union{Nothing, Sampler}=nothing,
 )
@@ -121,7 +121,7 @@ Computes the gradient of the log joint of `θ` for the model specified by
 function gradient_reverse(
     θ::AbstractVector{<:Real},
     vi::Turing.VarInfo,
-    model::Function,
+    model::Model,
     sampler::Union{Nothing, Sampler}=nothing,
 )
     vals_old, logp_old = copy(vi.vals), copy(vi.logp)
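
After these changes the three gradient helpers share the argument order `(θ, vi, model, sampler)`, with `model::Model` throughout. A hypothetical call site based on the signatures shown in the diff (`θ`, `vi`, `model`, and `spl` are assumed to be in scope):

    # θ is the current parameter vector and vi the VarInfo bookkeeping structure.
    g_fwd = gradient_forward(θ, vi, model, spl)   # forward-mode, via ForwardDiff.jl
    g_rev = gradient_reverse(θ, vi, model, spl)   # reverse-mode AD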
