Skip to content

InitContext, part 2 - Move hasvalue and getvalue to AbstractPPL; enforce key type of AbstractDict #980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: breaking
Choose a base branch
from
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.11, 0.12"
AbstractPPL = "0.13"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
AbstractPPL = "0.11, 0.12"
AbstractPPL = "0.13"
Accessors = "0.1"
DataStructures = "0.18"
Distributions = "0.25"
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using DocStringExtensions
using Random: Random

# For extending
import AbstractPPL: predict
import AbstractPPL: predict, hasvalue, getvalue

# TODO: Remove these when it's possible.
import Bijectors: link, invlink
Expand Down
12 changes: 8 additions & 4 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,11 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
Generate a sample of type `T` from the prior distribution of the `model`.
"""
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
x = last(evaluate_and_sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict())))
x = last(
evaluate_and_sample!!(
rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())
),
)
return values_as(x, T)
end

Expand Down Expand Up @@ -1028,7 +1032,7 @@ julia> logjoint(demo_model([1., 2.]), chain);
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
argvals_dict = OrderedDict{VarName,Any}(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
Expand Down Expand Up @@ -1082,7 +1086,7 @@ julia> logprior(demo_model([1., 2.]), chain);
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
argvals_dict = OrderedDict{VarName,Any}(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
Expand Down Expand Up @@ -1136,7 +1140,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain);
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
argvals_dict = OrderedDict{VarName,Any}(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
Expand Down
16 changes: 8 additions & 8 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,19 @@ ERROR: type NamedTuple has no field x
[...]

julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict()));
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));

julia> # (✓) Sort of fast, but only possible at runtime.
vi[@varname(x[1])]
-1.019202452456547

julia> # In addtion, we can only access varnames as they appear in the model!
vi[@varname(x)]
ERROR: KeyError: key x not found
ERROR: x was not found in the dictionary provided
[...]

julia> vi[@varname(x[1:2])]
ERROR: KeyError: key x[1:2] not found
ERROR: x[1:2] was not found in the dictionary provided
[...]
```

Expand Down Expand Up @@ -107,7 +107,7 @@ julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
true

julia> # And with `OrderedDict` of course!
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true));
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));

julia> vi[@varname(x)] # (✓) -∞ < x < ∞
0.6225185067787314
Expand Down Expand Up @@ -177,11 +177,11 @@ julia> svi_dict[@varname(m.a[1])]
1.0

julia> svi_dict[@varname(m.a[2])]
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
ERROR: m.a[2] was not found in the dictionary provided
[...]

julia> svi_dict[@varname(m.b)]
ERROR: type NamedTuple has no field b
ERROR: m.b was not found in the dictionary provided
[...]
```
"""
Expand Down Expand Up @@ -212,7 +212,7 @@ end
function SimpleVarInfo(values)
return SimpleVarInfo{LogProbType}(values)
end
function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict})
function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}})
return if isempty(values)
# Can't infer from values, so we just use default.
SimpleVarInfo{LogProbType}(values)
Expand Down Expand Up @@ -264,7 +264,7 @@ function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
end

function untyped_simple_varinfo(model::Model)
varinfo = SimpleVarInfo(OrderedDict())
varinfo = SimpleVarInfo(OrderedDict{VarName,Any}())
return last(evaluate_and_sample!!(model, varinfo))
end

Expand Down
2 changes: 1 addition & 1 deletion src/test_utils/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function setup_varinfos(

# SimpleVarInfo
svi_typed = SimpleVarInfo(example_values)
svi_untyped = SimpleVarInfo(OrderedDict())
svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}())
svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector())

varinfos = map((
Expand Down
193 changes: 0 additions & 193 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -751,199 +751,6 @@ function unflatten(original::AbstractDict, x::AbstractVector)
return D(zip(keys(original), unflatten(collect(values(original)), x)))
end

# TODO: Move `getvalue` and `hasvalue` to AbstractPPL.jl.
"""
getvalue(vals, vn::VarName)

Return the value(s) in `vals` represented by `vn`.

Note that this method is different from `getindex`. See examples below.

# Examples

For `NamedTuple`:

```jldoctest
julia> vals = (x = [1.0],);

julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex`
1-element Vector{Float64}:
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex`
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[2]))
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
[...]
```

For `AbstractDict`:

```jldoctest
julia> vals = Dict(@varname(x) => [1.0]);

julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex`
1-element Vector{Float64}:
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex`
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[2]))
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
[...]
```

In the `AbstractDict` case we can also have keys such as `v[1]`:

```jldoctest
julia> vals = Dict(@varname(x[1]) => [1.0,]);

julia> DynamicPPL.getvalue(vals, @varname(x[1])) # same as `getindex`
1-element Vector{Float64}:
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[1][1])) # different from `getindex`
1.0

julia> DynamicPPL.getvalue(vals, @varname(x[1][2]))
ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
[...]

julia> DynamicPPL.getvalue(vals, @varname(x[2][1]))
ERROR: KeyError: key x[2][1] not found
[...]
```
"""
getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn)
getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn)

"""
hasvalue(vals, vn::VarName)

Determine whether `vals` has a mapping for a given `vn`, as compatible with [`getvalue`](@ref).

# Examples
With `x` as a `NamedTuple`:

```jldoctest
julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x))
true

julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x[1]))
false

julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x))
true

julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[1]))
true

julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[2]))
false
```

With `x` as a `AbstractDict`:

```jldoctest
julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x))
true

julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1]))
false

julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x))
true

julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1]))
true

julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2]))
false
```

In the `AbstractDict` case we can also have keys such as `v[1]`:

```jldoctest
julia> vals = Dict(@varname(x[1]) => [1.0,]);

julia> DynamicPPL.hasvalue(vals, @varname(x[1])) # same as `haskey`
true

julia> DynamicPPL.hasvalue(vals, @varname(x[1][1])) # different from `haskey`
true

julia> DynamicPPL.hasvalue(vals, @varname(x[1][2]))
false

julia> DynamicPPL.hasvalue(vals, @varname(x[2][1]))
false
```
"""
function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym}
# LHS: Ensure that `nt` indeed has the property we want.
# RHS: Ensure that the optic can view into `nt`.
return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym))
end

# For `dictlike` we need to check wether `vn` is "immediately" present, or
# if some ancestor of `vn` is present in `dictlike`.
function hasvalue(vals::AbstractDict, vn::VarName)
# First we check if `vn` is present as is.
haskey(vals, vn) && return true

# If `vn` is not present, we check any parent-varnames by attempting
# to split the optic into the key / `parent` and the extraction optic / `child`.
# If `issuccess` is `true`, we found such a split, and hence `vn` is present.
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
o = optic === nothing ? identity : optic
haskey(vals, VarName{getsym(vn)}(o))
end
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
keyoptic = parent === nothing ? identity : parent

# Return early if no such split could be found.
issuccess || return false

# At this point we just need to check that we `canview` the value.
value = vals[VarName{getsym(vn)}(keyoptic)]

return canview(child, value)
end

"""
nested_getindex(values::AbstractDict, vn::VarName)

Return value corresponding to `vn` in `values` by also looking
in the the actual values of the dict.
"""
function nested_getindex(values::AbstractDict, vn::VarName)
maybeval = get(values, vn, nothing)
if maybeval !== nothing
return maybeval
end

# Split the optic into the key / `parent` and the extraction optic / `child`.
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
o = optic === nothing ? identity : optic
haskey(values, VarName{getsym(vn)}(o))
end
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
keyoptic = parent === nothing ? identity : parent

# If we found a valid split, then we can extract the value.
if !issuccess
# At this point we just throw an error since the key could not be found.
throw(KeyError(vn))
end

# TODO: Should we also check that we `canview` the extracted `value`
# rather than just let it fail upon `get` call?
value = values[VarName{getsym(vn)}(keyoptic)]
return child(value)
end

"""
update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)

Expand Down
4 changes: 2 additions & 2 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ $(TYPEDFIELDS)
"""
struct ValuesAsInModelAccumulator <: AbstractAccumulator
"values that are extracted from the model"
values::OrderedDict
values::OrderedDict{<:VarName}
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
end
function ValuesAsInModelAccumulator(include_colon_eq)
return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq)
return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq)
end

function Base.copy(acc::ValuesAsInModelAccumulator)
Expand Down
2 changes: 1 addition & 1 deletion src/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1482,7 +1482,7 @@ function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict}
end

# See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how
# they differ from `haskey` and `getindex`. They can be found in src/utils.jl.
# they differ from `haskey` and `getindex`. They can be found in AbstractPPL.jl.

# TODO(mhauru) This is tricky to implement in the general case, and the below implementation
# only covers some simple cases. It's probably sufficient in most situations though.
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.11, 0.12"
AbstractPPL = "0.13"
Accessors = "0.1"
Aqua = "0.8"
Bijectors = "0.15.1"
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ using LinearAlgebra # Diagonal

using JET: JET

# need to call this to get the AbstractPPL I think
Pkg.update()

using Combinatorics: combinations
using OrderedCollections: OrderedSet

Expand Down
2 changes: 1 addition & 1 deletion test/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
DynamicPPL.TestUtils.DEMO_MODELS
values_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
@testset "$(typeof(vi))" for vi in (
SimpleVarInfo(Dict()),
SimpleVarInfo(Dict{VarName,Any}()),
SimpleVarInfo(values_constrained),
SimpleVarInfo(DynamicPPL.VarNamedVector()),
DynamicPPL.typed_varinfo(model),
Expand Down
Loading
Loading