Skip to content

Commit

Permalink
Merge branch 'master' into torfjelde/subset-and-merge
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 13, 2023
2 parents cf02816 + 927799f commit d02cb61
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 27 deletions.
16 changes: 9 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.18"
version = "0.23.19"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -21,13 +22,20 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.6"
BangBang = "0.3"
Bijectors = "0.13"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1.5.4"
Compat = "4"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
LogDensityProblems = "2"
Expand All @@ -39,11 +47,5 @@ Setfield = "0.7.1, 0.8, 1"
ZygoteRules = "0.2"
julia = "1.6"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
[compat]
DataStructures = "0.18"
Distributions = "0.25"
Documenter = "0.27"
Documenter = "1"
FillArrays = "0.13, 1"
LogDensityProblems = "2"
MCMCChains = "5, 6"
Expand Down
1 change: 0 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ makedocs(;
"API" => "api.md",
"Tutorials" => ["tutorials/prob-interface.md"],
],
strict=true,
checkdocs=:exports,
)

Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module DynamicPPL
using AbstractMCMC: AbstractSampler, AbstractChains
using AbstractPPL
using Bijectors
using Compat
using Distributions
using OrderedCollections: OrderedDict

Expand Down
77 changes: 60 additions & 17 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1116,33 +1116,59 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f)
return vi
end

# HACK: We need `SampleFromPrior` to result in ALL values which are in need
# of a transformation to be transformed. `_getvns` will by default return
# an empty iterable for `SampleFromPrior`, so we need to override it here.
# This is quite hacky, but seems safer than changing the behavior of `_getvns`.
_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl)
_getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing
function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior)
return map(Returns(nothing), varinfo.metadata)
end

function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model)
return _link(varinfo)
return _link(varinfo, spl)
end

function _link(varinfo::UntypedVarInfo)
function _link(varinfo::UntypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
return VarInfo(
_link_metadata!(varinfo, varinfo.metadata),
_link_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)
end

function _link(varinfo::TypedVarInfo)
function _link(varinfo::TypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata)
# TODO: Update logp, etc.
md = _link_metadata_namedtuple!(
varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl))
)
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
end

function _link_metadata!(varinfo::VarInfo, metadata::Metadata)
@generated function _link_metadata_namedtuple!(
varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space}
) where {names,space}
vals = Expr(:tuple)
for f in names
if inspace(f, space) || length(space) == 0
push!(vals.args, :(_link_metadata!(varinfo, metadata.$f, vns.$f)))
else
push!(vals.args, :(metadata.$f))
end
end

return :(NamedTuple{$names}($vals))
end
function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns)
vns = metadata.vns

# Construct the new transformed values, and keep track of their lengths.
vals_new = map(vns) do vn
# Return early if we're already in unconstrained space.
if istrans(varinfo, vn)
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
if istrans(varinfo, vn) || (target_vns !== nothing && vn target_vns)
return metadata.vals[getrange(metadata, vn)]
end

Expand Down Expand Up @@ -1186,32 +1212,49 @@ end
function invlink(
::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model
)
return _invlink(varinfo)
return _invlink(varinfo, spl)
end

function _invlink(varinfo::UntypedVarInfo)
function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
return VarInfo(
_invlink_metadata!(varinfo, varinfo.metadata),
_invlink_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)
end

function _invlink(varinfo::TypedVarInfo)
function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler)
varinfo = deepcopy(varinfo)
md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata)
# TODO: Update logp, etc.
md = _invlink_metadata_namedtuple!(
varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl))
)
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
end

function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata)
@generated function _invlink_metadata_namedtuple!(
varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space}
) where {names,space}
vals = Expr(:tuple)
for f in names
if inspace(f, space) || length(space) == 0
push!(vals.args, :(_invlink_metadata!(varinfo, metadata.$f, vns.$f)))
else
push!(vals.args, :(metadata.$f))
end
end

return :(NamedTuple{$names}($vals))
end
function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns)
vns = metadata.vns

# Construct the new transformed values, and keep track of their lengths.
vals_new = map(vns) do vn
# Return early if we're already in constrained space.
if !istrans(varinfo, vn)
# Return early if we're already in constrained space OR if we're not
# supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler.
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
if !istrans(varinfo, vn) || (target_vns !== nothing && vn target_vns)
return metadata.vals[getrange(metadata, vn)]
end

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Bijectors = "0.13"
Compat = "4.3.0"
Distributions = "0.25"
DistributionsAD = "0.6.3"
Documenter = "0.26.1, 0.27"
Documenter = "0.26.1, 0.27, 1"
ForwardDiff = "0.10.12"
LogDensityProblems = "2"
MCMCChains = "4.0.4, 5, 6"
Expand Down
42 changes: 42 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ function check_varinfo_keys(varinfo, vns)
end
end

# A simple "algorithm" which only has `s` variables in its space.
struct MySAlg end
DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)

@testset "varinfo.jl" begin
@testset "TypedVarInfo" begin
@model gdemo(x, y) = begin
Expand Down Expand Up @@ -539,4 +543,42 @@ end
end
end
end

@testset "VarInfo with selectors" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
varinfo = VarInfo(model)
selector = DynamicPPL.Selector()
spl = Sampler(MySAlg(), model, selector)

vns = DynamicPPL.TestUtils.varnames(model)
vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns)
vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns)
for vn in vns_s
DynamicPPL.updategid!(varinfo, vn, spl)
end

# Should only get the variables subsumed by `@varname(s)`.
@test varinfo[spl] ==
mapreduce(Base.Fix1(DynamicPPL.getval, varinfo), vcat, vns_s)

# `link`
varinfo_linked = DynamicPPL.link(varinfo, spl, model)
# `s` variables should be linked
@test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s)
# `m` variables should NOT be linked
@test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
# And `varinfo` should be unchanged
@test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns)

# `invlink`
varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model)
# `s` variables should no longer be linked
@test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s)
# `m` variables should still not be linked
@test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m)
# And `varinfo_linked` should be unchanged
@test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s)
@test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
end
end
end

0 comments on commit d02cb61

Please sign in to comment.