Skip to content

Commit

Permalink
Remove tonamedtuple (#547)
Browse files Browse the repository at this point in the history
* Remove dependencies to `tonamedtuple`

* Remove `tonamedtuple`s

* Minor version bump

---------

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
sunxd3 and yebai authored Oct 26, 2023
1 parent 2e8adf4 commit 04b03cd
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 89 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.21"
version = "0.24.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ DynamicPPL.reconstruct
Base.merge(::AbstractVarInfo)
DynamicPPL.subset
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
DynamicPPL.varname_leaves
DynamicPPL.varname_and_value_leaves
```
Expand Down
1 change: 0 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ export AbstractVarInfo,
invlink,
invlink!,
invlink!!,
tonamedtuple,
values_as,
# VarName (reexport from AbstractPPL)
VarName,
Expand Down
15 changes: 0 additions & 15 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -738,21 +738,6 @@ function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::Abstrac
return unflatten(varinfo, sampler, θ)
end

"""
tonamedtuple(vi::AbstractVarInfo)
Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and
indexing string of the variable.
For example, a model that had a vector of vector-valued
variables `x` would return
```julia
(x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), )
```
"""
function tonamedtuple end

# TODO: Clean up all this linking stuff once and for all!
"""
with_logabsdet_jacobian_and_reconstruct([f, ]dist, x)
Expand Down
38 changes: 0 additions & 38 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -532,44 +532,6 @@ function dot_assume(
return value, lp, vi
end

# We need these to be compatible with how chains are constructed from `AbstractVarInfo` in Turing.jl.
# TODO: Move away from using these `tonamedtuple` methods.
function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names}
nt_vals = map(keys(vi)) do vn
val = vi[vn]
vns = collect(TestUtils.varname_leaves(vn, val))
vals = map(copy Base.Fix1(getindex, vi), vns)
(vals, map(string, vns))
end

return NamedTuple{names}(nt_vals)
end

function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict})
syms_to_result = Dict{Symbol,Tuple{Vector{Real},Vector{String}}}()
for vn in keys(vi)
# Extract the leaf varnames and values.
val = vi[vn]
vns = collect(TestUtils.varname_leaves(vn, val))
vals = map(copy Base.Fix1(getindex, vi), vns)

# Determine the corresponding symbol.
sym = only(unique(map(getsym, vns)))

# Initialize entry if not yet initialized.
if !haskey(syms_to_result, sym)
syms_to_result[sym] = (Real[], String[])
end

# Combine with old result.
old_vals, old_string_vns = syms_to_result[sym]
syms_to_result[sym] = (vcat(old_vals, vals), vcat(old_string_vns, map(string, vns)))
end

# Construct `NamedTuple`.
return NamedTuple(pairs(syms_to_result))
end

# NOTE: We don't implement `settrans!!(vi, trans, vn)`.
function settrans!!(vi::SimpleVarInfo, trans)
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
Expand Down
2 changes: 0 additions & 2 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,6 @@ function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
return is_flagged(vi.varinfo, vn, flag)
end

tonamedtuple(vi::ThreadSafeVarInfo) = tonamedtuple(vi.varinfo)

# Transformations.
function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName)
return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn)
Expand Down
16 changes: 0 additions & 16 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1506,22 +1506,6 @@ end
return expr
end

# TODO: Remove this completely.
tonamedtuple(varinfo::VarInfo) = tonamedtuple(varinfo.metadata, varinfo)
function tonamedtuple(metadata::NamedTuple{names}, varinfo::VarInfo) where {names}
length(names) === 0 && return NamedTuple()

vals_tuple = map(values(metadata)) do x
# NOTE: `tonamedtuple` is really only used in Turing.jl to convert to
# a "transition". This means that we really don't mutations of the values
# in `varinfo` to propoagate the previous samples. Hence we `copy.`
vals = map(copy Base.Fix1(getindex, varinfo), x.vns)
return vals, map(string, x.vns)
end

return NamedTuple{names}(vals_tuple)
end

@inline function findvns(vi, f_vns)
if length(f_vns) == 0
throw("Unidentified error, please report this error in an issue.")
Expand Down
30 changes: 15 additions & 15 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,22 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1)
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
θ_new = var_info[spl]
@test θ_old != θ_new
nt = DynamicPPL.tonamedtuple(var_info)
for (k, (vals, names)) in pairs(nt)
for (n, v) in zip(names, vals)
if Symbol(n) keys(chain)
# Assume it's a group
chain_val = vec(
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
)
v_true = vec(v)
else
chain_val = chain[sample_idx, n, chain_idx]
v_true = v
end

@test v_true == chain_val
vals = DynamicPPL.values_as(var_info, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
for (n, v) in mapreduce(collect, vcat, iters)
n = string(n)
if Symbol(n) keys(chain)
# Assume it's a group
chain_val = vec(
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
)
v_true = vec(v)
else
chain_val = chain[sample_idx, n, chain_idx]
v_true = v
end

@test v_true == chain_val
end
end

Expand Down

0 comments on commit 04b03cd

Please sign in to comment.