Skip to content

Commit 7e35f64

Browse files
committed
Various fixes
1 parent 5551b73 commit 7e35f64

File tree

3 files changed

+92
-14
lines changed

3 files changed

+92
-14
lines changed

src/contexts/transformation.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ function tilde_assume!!(
2121
# vi[vn, right] always provides the value in unlinked space.
2222
x = vi[vn, right]
2323

24-
if is_transformed(vi, vn)
25-
isinverse || @warn "Trying to link an already transformed variable ($vn)"
26-
else
27-
isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
28-
end
24+
# TODO(mhauru) Warnings disabled for benchmarking purposes
25+
# if is_transformed(vi, vn)
26+
# isinverse || @warn "Trying to link an already transformed variable ($vn)"
27+
# else
28+
# isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
29+
# end
2930

3031
transform = isinverse ? identity : link_transform(right)
3132
y, logjac = with_logabsdet_jacobian(transform, x)

src/varinfo.jl

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ function typed_vector_varinfo(
358358
end
359359

360360
function make_leaf_metadata((r, dist), optic)
361-
md = Metadata(Float64)
361+
md = Metadata(Float64, VarName{:_})
362362
vn = VarName{:_}(optic)
363363
push!(md, vn, r, dist)
364364
return md
@@ -439,13 +439,13 @@ unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x)
439439
440440
Construct an empty type unstable instance of `Metadata`.
441441
"""
442-
function Metadata(eltype=Real)
442+
function Metadata(eltype=Real, vntype=VarName)
443443
vals = Vector{eltype}()
444444
is_transformed = BitVector()
445445

446446
return Metadata(
447-
Dict{VarName,Int}(),
448-
Vector{VarName}(),
447+
Dict{vntype,Int}(),
448+
Vector{vntype}(),
449449
Vector{UnitRange{Int}}(),
450450
vals,
451451
Vector{Distribution}(),
@@ -814,7 +814,7 @@ The values may or may not be transformed to Euclidean space.
814814
setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn)
815815
function setval!(vi::TupleVarInfo, val, vn::VarName)
816816
main_vn, optic = split_trailing_index(vn)
817-
return setval!(getindex(vi.metadata, main_vn), VarName{:_}(optic))
817+
return setval!(getindex(vi.metadata, main_vn), val, VarName{:_}(optic))
818818
end
819819
function setval!(md::Metadata, val::AbstractVector, vn::VarName)
820820
return md.vals[getrange(md, vn)] = val
@@ -1980,3 +1980,80 @@ end
19801980
function from_linked_internal_transform(::VarNamedVector, ::VarName, dist)
19811981
return from_linked_vec_transform(dist)
19821982
end
1983+
1984+
function link(vi::TupleVarInfo, model::Model)
1985+
metadata = map(value -> link(value, model), vi.metadata)
1986+
return VarInfo(metadata, vi.accs)
1987+
end
1988+
1989+
function link(metadata::Metadata, model::Model)
1990+
vns = metadata.vns
1991+
cumulative_logjac = zero(LogProbType)
1992+
1993+
# Construct the new transformed values, and keep track of their lengths.
1994+
vals_new = map(vns) do vn
1995+
# Return early if we're already in unconstrained space.
1996+
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
1997+
if is_transformed(metadata, vn)
1998+
return metadata.vals[getrange(metadata, vn)]
1999+
end
2000+
2001+
# Transform to constrained space.
2002+
x = getindex_internal(metadata, vn)
2003+
dist = getdist(metadata, vn)
2004+
f_from_internal = from_internal_transform(metadata, vn, dist)
2005+
f_to_linked_internal = inverse(from_linked_internal_transform(metadata, vn, dist))
2006+
f = f_to_linked_internal f_from_internal
2007+
y, logjac = with_logabsdet_jacobian(f, x)
2008+
# Vectorize value.
2009+
yvec = tovec(y)
2010+
# Accumulate the log-abs-det jacobian correction.
2011+
cumulative_logjac += logjac
2012+
# Return the vectorized transformed value.
2013+
return yvec
2014+
end
2015+
2016+
# Determine new ranges.
2017+
ranges_new = similar(metadata.ranges)
2018+
offset = 0
2019+
for (i, v) in enumerate(vals_new)
2020+
r_start, r_end = offset + 1, length(v) + offset
2021+
offset = r_end
2022+
ranges_new[i] = r_start:r_end
2023+
end
2024+
2025+
# Now we just create a new metadata with the new `vals` and `ranges`.
2026+
return Metadata(
2027+
metadata.idcs,
2028+
metadata.vns,
2029+
ranges_new,
2030+
reduce(vcat, vals_new),
2031+
metadata.dists,
2032+
BitVector(fill(true, length(metadata.vns))),
2033+
)
2034+
end
2035+
2036+
function Base.haskey(vi::TupleVarInfo, vn::VarName)
2037+
# TODO(mhauru) Fix this to account for the index.
2038+
main_vn, optic = split_trailing_index(vn)
2039+
haskey(vi.metadata, main_vn) || return false
2040+
value = getindex(vi.metadata, main_vn)
2041+
if value isa Metadata
2042+
return haskey(value, VarName{:_}(optic))
2043+
else
2044+
error("TODO(mhauru) Implement me")
2045+
end
2046+
end
2047+
2048+
function BangBang.setindex!!(metadata::Metadata, val, optic)
2049+
return setindex!!(metadata, val, VarName{:_}(optic))
2050+
end
2051+
2052+
function BangBang.setindex!!(metadata::Metadata, (r, dist), vn::VarName)
2053+
if haskey(metadata, vn)
2054+
setval!(metadata, r, vn)
2055+
else
2056+
push!(metadata, vn, r, dist)
2057+
end
2058+
return metadata
2059+
end

src/varname.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545
function remove_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic}
4646
return if Optic === typeof(identity)
4747
vn
48-
elseif Optic isa IndexLens
48+
elseif Optic <: Accessors.IndexLens
4949
VarName{sym}()
5050
else
5151
prefix(remove_trailing_index(unprefix(vn, sym)), sym)
@@ -55,10 +55,10 @@ end
5555
function split_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic}
5656
return if Optic === typeof(identity)
5757
(vn, identity)
58-
elseif Optic isa IndexLens
59-
(VarName{sym}(), Optic.index)
58+
elseif Optic <: Accessors.IndexLens
59+
(VarName{sym}(), getoptic(vn))
6060
else
61-
(prefix, index) = split_trailing_index(unprefix(vn, sym))
61+
(prefix, index) = split_trailing_index(unprefix(vn, VarName{sym}()))
6262
(prefix(prefix, sym), index)
6363
end
6464
end

0 commit comments

Comments
 (0)