@@ -358,7 +358,7 @@ function typed_vector_varinfo(
358358end
359359
360360function make_leaf_metadata ((r, dist), optic)
361- md = Metadata (Float64)
361+ md = Metadata (Float64, VarName{ :_ } )
362362 vn = VarName {:_} (optic)
363363 push! (md, vn, r, dist)
364364 return md
@@ -439,13 +439,13 @@ unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x)
439439
440440Construct an empty type unstable instance of `Metadata`.
441441"""
442- function Metadata (eltype= Real)
442+ function Metadata (eltype= Real, vntype = VarName )
443443 vals = Vector {eltype} ()
444444 is_transformed = BitVector ()
445445
446446 return Metadata (
447- Dict {VarName ,Int} (),
448- Vector {VarName } (),
447+ Dict {vntype ,Int} (),
448+ Vector {vntype } (),
449449 Vector {UnitRange{Int}} (),
450450 vals,
451451 Vector {Distribution} (),
@@ -814,7 +814,7 @@ The values may or may not be transformed to Euclidean space.
814814setval! (vi:: VarInfo , val, vn:: VarName ) = setval! (getmetadata (vi, vn), val, vn)
815815function setval! (vi:: TupleVarInfo , val, vn:: VarName )
816816 main_vn, optic = split_trailing_index (vn)
817- return setval! (getindex (vi. metadata, main_vn), VarName {:_} (optic))
817+ return setval! (getindex (vi. metadata, main_vn), val, VarName {:_} (optic))
818818end
819819function setval! (md:: Metadata , val:: AbstractVector , vn:: VarName )
820820 return md. vals[getrange (md, vn)] = val
@@ -1980,3 +1980,80 @@ end
19801980function from_linked_internal_transform (:: VarNamedVector , :: VarName , dist)
19811981 return from_linked_vec_transform (dist)
19821982end
1983+
1984+ function link (vi:: TupleVarInfo , model:: Model )
1985+ metadata = map (value -> link (value, model), vi. metadata)
1986+ return VarInfo (metadata, vi. accs)
1987+ end
1988+
1989+ function link (metadata:: Metadata , model:: Model )
1990+ vns = metadata. vns
1991+ cumulative_logjac = zero (LogProbType)
1992+
1993+ # Construct the new transformed values, and keep track of their lengths.
1994+ vals_new = map (vns) do vn
1995+ # Return early if we're already in unconstrained space.
1996+ # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
1997+ if is_transformed (metadata, vn)
1998+ return metadata. vals[getrange (metadata, vn)]
1999+ end
2000+
2001+ # Transform to constrained space.
2002+ x = getindex_internal (metadata, vn)
2003+ dist = getdist (metadata, vn)
2004+ f_from_internal = from_internal_transform (metadata, vn, dist)
2005+ f_to_linked_internal = inverse (from_linked_internal_transform (metadata, vn, dist))
2006+ f = f_to_linked_internal ∘ f_from_internal
2007+ y, logjac = with_logabsdet_jacobian (f, x)
2008+ # Vectorize value.
2009+ yvec = tovec (y)
2010+ # Accumulate the log-abs-det jacobian correction.
2011+ cumulative_logjac += logjac
2012+ # Return the vectorized transformed value.
2013+ return yvec
2014+ end
2015+
2016+ # Determine new ranges.
2017+ ranges_new = similar (metadata. ranges)
2018+ offset = 0
2019+ for (i, v) in enumerate (vals_new)
2020+ r_start, r_end = offset + 1 , length (v) + offset
2021+ offset = r_end
2022+ ranges_new[i] = r_start: r_end
2023+ end
2024+
2025+ # Now we just create a new metadata with the new `vals` and `ranges`.
2026+ return Metadata (
2027+ metadata. idcs,
2028+ metadata. vns,
2029+ ranges_new,
2030+ reduce (vcat, vals_new),
2031+ metadata. dists,
2032+ BitVector (fill (true , length (metadata. vns))),
2033+ )
2034+ end
2035+
2036+ function Base. haskey (vi:: TupleVarInfo , vn:: VarName )
2037+ # TODO (mhauru) Fix this to account for the index.
2038+ main_vn, optic = split_trailing_index (vn)
2039+ haskey (vi. metadata, main_vn) || return false
2040+ value = getindex (vi. metadata, main_vn)
2041+ if value isa Metadata
2042+ return haskey (value, VarName {:_} (optic))
2043+ else
2044+ error (" TODO(mhauru) Implement me" )
2045+ end
2046+ end
2047+
2048+ function BangBang. setindex!! (metadata:: Metadata , val, optic)
2049+ return setindex!! (metadata, val, VarName {:_} (optic))
2050+ end
2051+
2052+ function BangBang. setindex!! (metadata:: Metadata , (r, dist), vn:: VarName )
2053+ if haskey (metadata, vn)
2054+ setval! (metadata, r, vn)
2055+ else
2056+ push! (metadata, vn, r, dist)
2057+ end
2058+ return metadata
2059+ end
0 commit comments