Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt at implementation of VarNamedVector (Metadata alternative) #555

Merged
merged 222 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
222 commits
Select commit Hold shift + click to select a range
5af1afa
initial implementation of VarNameVector
torfjelde Oct 31, 2023
8ce53f7
added some hacky getval and getdist get things to work for VarInfo
torfjelde Nov 7, 2023
fc6a051
Apply suggestions from code review
torfjelde Nov 7, 2023
7cd599d
added arbitrary metadata field as discussed
torfjelde Nov 12, 2023
ed0a757
renamed idcs to varname_to_index
torfjelde Nov 12, 2023
4ebd252
renamed vns to varnames for VarNameVector
torfjelde Nov 12, 2023
9f12c9a
added keys impl for Metadata
torfjelde Nov 12, 2023
5a15121
added push! and update! for VarNameVector
torfjelde Nov 13, 2023
edde2c1
added getindex_raw! and setindex_raw! for VarNameVector
torfjelde Nov 13, 2023
ed46002
added `iterate` and `convert` (for `AbstractDict) impls for `VarNameV…
torfjelde Nov 13, 2023
5b00059
make the key and eltype part of the `VarNameVector` type
torfjelde Nov 13, 2023
bef7e0a
added more tests for VarNameVector
torfjelde Nov 13, 2023
006ee8d
formatting
torfjelde Nov 13, 2023
9802811
more testing for VarNameVector
torfjelde Nov 13, 2023
88b1721
minor changes to some comments
torfjelde Nov 13, 2023
ca7b173
added a bunch more tests for VarNameVector + several bugfixes in the …
torfjelde Nov 13, 2023
fb01b94
formatting
torfjelde Nov 13, 2023
9634839
added `similar` implementation for `VarNameVector`
torfjelde Nov 13, 2023
5179f6f
formatting
torfjelde Nov 13, 2023
9f632bb
removed debug statement
torfjelde Nov 13, 2023
3c210f7
made VarInfo slighly more generic wrt. underlying metadata
torfjelde Nov 13, 2023
8bf6589
Merge branch 'master' into torfjelde/varnamevector
torfjelde Nov 14, 2023
8b2720f
fixed incorrect behavior in `keys` for `Metadata`
torfjelde Nov 14, 2023
9fa6446
minor style changes to VarNameVector tests
torfjelde Nov 14, 2023
0900c57
style
torfjelde Nov 14, 2023
1f7e633
added testing of `update!` with smaller sizes and fixed bug related t…
torfjelde Nov 14, 2023
8d05586
formatting
torfjelde Nov 14, 2023
7801fe1
move functionality related to `push!` for `VarNameVector` into `push!`
torfjelde Nov 14, 2023
cdc2373
Update src/varnamevector.jl
torfjelde Nov 16, 2023
d2d776d
Merge branch 'master' into torfjelde/varnamevector
torfjelde Nov 20, 2023
ae4bcb7
several fixes to make sampling with VarNameVector + initiall tests for
torfjelde Dec 30, 2023
97e1bcc
VarInfo + VarNameVector tests for all demo models
torfjelde Dec 30, 2023
be3c1b4
Merge remote-tracking branch 'origin/torfjelde/varnamevector' into to…
torfjelde Dec 30, 2023
ad343f3
Apply suggestions from code review
torfjelde Dec 30, 2023
f707b25
added docs on the design of `VarNameVector`
torfjelde Dec 31, 2023
4e7af1d
Merge branch 'master' into torfjelde/varnamevector
torfjelde Dec 31, 2023
f1faf18
Apply suggestions from code review
torfjelde Dec 31, 2023
87d3d01
added note on `update!`
torfjelde Dec 31, 2023
9c3b265
further elaboration of the design of `VarInfo` and `VarNameVector`
torfjelde Jan 1, 2024
958c66b
more writing improvements
torfjelde Jan 1, 2024
74c6efd
added docstring to `has_inactive_ranges` and `inactive_ranges_sweep!`
torfjelde Jan 1, 2024
d9ea878
moved docs on `VarInfo` design to a separate internals section
torfjelde Jan 1, 2024
5acce98
writing improvements for internal docs
torfjelde Jan 1, 2024
6f95cdd
further motivation of the design choices made in `VarNameVector`
torfjelde Jan 1, 2024
38a4b08
improved writing
torfjelde Jan 1, 2024
60edd10
VarNameVector is now grown as much as needed
torfjelde Jan 1, 2024
3f9d34f
updated `delete!`
torfjelde Jan 2, 2024
fb822b5
Significant changes to implementation of `VarNameVector`:
torfjelde Jan 2, 2024
66bc090
added `copy` when constructing `VectorVarInfo` from `VarInfo`
torfjelde Jan 2, 2024
ccd86f2
added missing `isempty` impl
torfjelde Jan 2, 2024
1d4a000
remove impl of `iterate` and instead implemented `pairs` and `values`…
torfjelde Jan 2, 2024
9a16dd1
added missing `empty!` for `num_inactive`
torfjelde Jan 2, 2024
e49b762
removed redundant `shift_left!` methd
torfjelde Jan 2, 2024
2b445c9
fixed `delete!` for `VarNameVector`
torfjelde Jan 2, 2024
e3c2633
added `is_contiguous` as an alterantive to `!has_inactive`
torfjelde Jan 2, 2024
19a829c
updates to internal docs
torfjelde Jan 2, 2024
a358bc4
renamed `sweep_inactive_ranges!` to `contiguify!`
torfjelde Jan 2, 2024
46be8d5
improvements to internal docs
torfjelde Jan 2, 2024
57d688e
more improvements to internal docs
torfjelde Jan 2, 2024
0968a07
moved additional methods description in internals to earlier in the doc
torfjelde Jan 2, 2024
0d008a4
moved internals docs to a separate directory and split into files
torfjelde Jan 2, 2024
ccd0d64
more improvements to internals doc
torfjelde Jan 2, 2024
7c45e67
formatting
torfjelde Jan 2, 2024
373215b
added tests for `delete!` and fixed reference to old method
torfjelde Jan 2, 2024
0cdafbf
addition to `delete!` test
torfjelde Jan 2, 2024
51c041f
added `values_as` impls for `VarNameVector`
torfjelde Jan 2, 2024
20b3742
added docs for `replace_valus` and `values_as` for `VarNameVector`
torfjelde Jan 2, 2024
ef6c618
fixed doctest
torfjelde Jan 2, 2024
8a1209c
formatting
torfjelde Jan 2, 2024
adeadf0
temporarily disable doctests so we can build docs
torfjelde Jan 2, 2024
7ff179d
added missing compat entry for ForwardDiff in docs
torfjelde Jan 2, 2024
c7ec08a
moved some shared code into methods to make things a bit cleaner
torfjelde Jan 3, 2024
c5a5e58
added impl of `merge` for `VarNameVector`
torfjelde Jan 3, 2024
c376d95
renamed a few variables in `merge` impl for `VarNameVector`
torfjelde Jan 3, 2024
f71baa5
forgot to include some changes in previous commit
torfjelde Jan 3, 2024
af25f3c
added impl of `subset` for `VarNameVector`
torfjelde Jan 3, 2024
c28f076
fixed `pairs` impl for `VarNameVector`
torfjelde Jan 3, 2024
f5d2c63
added missing impl of `subset` for `VectorVarInfo`
torfjelde Jan 3, 2024
3eb6c7f
added missing impl of `merge_metadata` for `VarNameVector`
torfjelde Jan 3, 2024
9ba8144
added a bunch of `from_vec_transform` and `tovec` impls to make
torfjelde Jan 3, 2024
acd6951
make default args use `from_vec_transform` rather than `FromVec`
torfjelde Jan 3, 2024
790f743
fixed `values_as` fro `VarInfo` with `VarNameVector` as `metadata`
torfjelde Jan 3, 2024
c474bb0
fixed impl of `getindex_raw` when using integer index for `VarNameVec…
torfjelde Jan 4, 2024
8251463
added tests for `getindex` with `Int` index for `VarNameVector`
torfjelde Jan 4, 2024
5df7031
fix for `setindex!` and `setindex_raw!` for `VarNameVector`
torfjelde Jan 4, 2024
683b776
introduction of `from_vec_transform` and `tovec` and its usage in `Va…
torfjelde Jan 19, 2024
4dae00d
moved definition of `is_splat_symbol` to the file where it's used
torfjelde Jan 19, 2024
e3b52a4
added `VarInfo` constructor with vector input for `VectorVarInfo`
torfjelde Jan 19, 2024
9626be1
make `extract_priors` take the `rng` as an argument
torfjelde Jan 19, 2024
e731fd6
added `replace_values` for `Metadata`
torfjelde Jan 19, 2024
0785abf
make link and invlink act on the `metadata` field for `VarInfo` +
torfjelde Jan 19, 2024
b3e0955
added temporary defs of `with_logabsdet_jacobian` and `inverse` for
torfjelde Jan 19, 2024
ff963ce
added invlink_with_logpdf overload for `ThreadSafeVarInfo`
torfjelde Jan 19, 2024
03f2b2b
added `is_transformed` field to `VarNameVector`
torfjelde Jan 19, 2024
949b33a
removed unnecessary defintions of `with_logabsdet_jacobian` and
torfjelde Jan 19, 2024
cc5ecc4
fixed issue where we were storing the wrong transformations in `VarNa…
torfjelde Jan 19, 2024
1aae1b4
make sure `extract_priors` doesn't mutate the `varinfo`
torfjelde Jan 19, 2024
8e0853d
updated `similar` for `VarNameVector` and fixed `invlink` for `VarNam…
torfjelde Jan 19, 2024
229b168
added handling of `is_transformed` in `merge` for `VarNameVector`
torfjelde Jan 19, 2024
c581dcf
removed unnecesasry `deepcopy` from outer `link`
torfjelde Jan 19, 2024
b4d3f55
updated `push!` to also `push!` on `is_transformed`
torfjelde Jan 19, 2024
ed1d006
skip tests for mutating linking when using VarNameVector
torfjelde Jan 19, 2024
f132209
use same projection for `Cholesky` in `VarNameVector` as in `VarInfo`
torfjelde Jan 19, 2024
49454de
fixed `settrans!!` for `VarInfo` with `VarNameVector`
torfjelde Jan 19, 2024
01ff2dd
fixed bug in `set_flag!`
torfjelde Jan 19, 2024
20adedf
fixed another typo
torfjelde Jan 19, 2024
8f9566a
fixed return values of `settrans!!`
torfjelde Jan 19, 2024
5532046
updated static transformation tests
torfjelde Jan 20, 2024
3c5d2ac
Update test/simple_varinfo.jl
torfjelde Jan 20, 2024
317d969
Merge branch 'master' into torfjelde/varnamevector
torfjelde Jan 20, 2024
f8441ea
Merge remote-tracking branch 'origin/torfjelde/varnamevector' into to…
torfjelde Jan 25, 2024
ab16323
Merge branch 'master' into torfjelde/varnamevector
torfjelde Jan 25, 2024
a9be219
removed unnecessary impl of `extract_priors`
torfjelde Jan 25, 2024
53c8d33
make `short_varinfo_name` of `TypedVarInfo` a bit more informative
torfjelde Jan 25, 2024
61d85ad
moved impl of `has_varnamevector` for `ThreadSafeVarInfo`
torfjelde Jan 25, 2024
9ace554
added back `extract_priors` impl as we do need it
torfjelde Jan 25, 2024
67c9821
forgot to include tests for `VarNameVector` in `runtests.jl`
torfjelde Jan 25, 2024
32a2d31
fix for `relax_container_types` in `test/varnamevector.jl`
torfjelde Jan 25, 2024
b3bb42d
fixed `need_transforms_relaxation`
torfjelde Jan 26, 2024
25ff2b1
updated some tests to not refer directly to `FromVec`
torfjelde Jan 28, 2024
004f038
introduce `from_internal_transform` and its siblings
torfjelde Jan 28, 2024
38c89bd
remove `with_logabsdet_jacobian_and_reconstruct` in favour of
torfjelde Jan 28, 2024
218dc23
added `internal_to_linked_internal_transform` + fixed a few bugs in
torfjelde Jan 28, 2024
1df4293
added `linked_internal_to_internal_transform` as a complement to `int…
torfjelde Jan 28, 2024
f8df896
fixed bugs in `invlink` for `VarInfo` using `linked_internal_to_inter…
torfjelde Jan 28, 2024
d62f26a
more work on removing calls to `reconstruct`
torfjelde Jan 28, 2024
b4517d6
removed redundant comment
torfjelde Jan 28, 2024
b7d4754
added `from_linked_vec_transform` specialization for `LKJ`
torfjelde Jan 28, 2024
0244dd9
more work on removing references to `reconstruct`
torfjelde Jan 28, 2024
e886d07
added `copy` in `values_from_metadata` to preserve behavior and avoid
torfjelde Jan 28, 2024
2af6605
remove `reconstruct_and_link` and `invlink_and_reconstruct`
torfjelde Jan 28, 2024
a0664d7
replaced references to `link_and_reconstruct` and `invlink_and_recons…
torfjelde Jan 28, 2024
f2d59b2
introduced `recombine` and replaced calls to `reconstruct` with `n` s…
torfjelde Jan 28, 2024
e3bfa76
completely removed `reconstruct`
torfjelde Jan 28, 2024
c0aef81
renamed `maybe_reconstruct_and_link` to `to_maybe_linked_internal` and
torfjelde Jan 28, 2024
f7c0853
added impls of `from_*_internal_transform` for `ThreadSafeVarInfo`
torfjelde Jan 30, 2024
77b835e
removed `reconstruct` from docs and from exports
torfjelde Jan 30, 2024
b83c262
renamed `getval` to `getindex_internal` and made `dist` an optional
torfjelde Jan 31, 2024
c4faf3e
updated docs + added description of how internals of transforms work
torfjelde Jan 31, 2024
c8d9695
added a bunch of illustrations for the transforms docs + dot files us…
torfjelde Jan 31, 2024
95dc8e3
temporarily removed `VarNameVector` completely
torfjelde Jan 31, 2024
8930f9c
formatting
torfjelde Jan 31, 2024
e45b668
Update docs/src/internals/transformations.md
torfjelde Jan 31, 2024
0e71092
Update docs/src/internals/transformations.md
torfjelde Jan 31, 2024
2de9ac9
removed refs to VectorVarInfo
torfjelde Jan 31, 2024
9b71428
added impls of `from_internal_transform` for `ThreadSafeVarInfo`
torfjelde Feb 1, 2024
786e9bf
reverted accidental removal of old `VarInfo` constructor
torfjelde Feb 1, 2024
f1fe42c
fixed incorrect `recombine` call
torfjelde Feb 1, 2024
2273954
removed undefined refs to `VarNameVector` stuff in `setup_varinfos`
torfjelde Feb 1, 2024
ab7c189
bump minior version because Turing breaks
torfjelde Feb 1, 2024
3a86601
fix: was using `from_linked_internal_transform` in
torfjelde Feb 1, 2024
28c7d85
removed `getindex_raw`
torfjelde Feb 1, 2024
59514d6
removed redundant docstrings
torfjelde Feb 1, 2024
cdc882b
fixed tests
torfjelde Feb 1, 2024
57ba7c0
fixed comparisons in tests
torfjelde Feb 1, 2024
902e59c
try relative references for images in transformation docs
torfjelde Feb 1, 2024
d7aba55
another attempt at fixing asset-references
torfjelde Feb 3, 2024
b0dd2f8
Merge branch 'master' into torfjelde/transformations
torfjelde Feb 3, 2024
1f51203
fixed getindex diagrams in docs
torfjelde Feb 3, 2024
0eb79b1
minor changes to comments
torfjelde Feb 3, 2024
071bebf
remove Combinatorics as a test dep, as it's not needed for this PR
torfjelde Feb 3, 2024
bbdc060
reverted unnecessary change
torfjelde Feb 3, 2024
e2f4d18
disable type-stability tests for models on older Julia versions
torfjelde Feb 3, 2024
3d823ac
removed seemingly completely unused impl of `setval!`
torfjelde Feb 3, 2024
54792f4
Revert "temporarily removed `VarNameVector` completely"
torfjelde Feb 3, 2024
ff68206
Revert "remove Combinatorics as a test dep, as it's not needed for th…
torfjelde Feb 6, 2024
9b1014d
Merge remote-tracking branch 'origin/master' into torfjelde/varnameve…
mhauru Aug 22, 2024
19978ec
More work on `VarNameVector` (#637)
mhauru Sep 3, 2024
95668eb
Merge remote-tracking branch 'origin/master' into torfjelde/varnameve…
mhauru Sep 3, 2024
1e4efe6
Bump Bijectors dependecy
mhauru Sep 3, 2024
3ee9832
Remove dead TODO note
mhauru Sep 3, 2024
26753e9
Remove old TODOs, improve VNV invlinking
mhauru Sep 3, 2024
ea18e1f
Fix from_vec_transform for 0-dim arrays
mhauru Sep 4, 2024
ffbf2ad
Fix unflatten for VarInfo
mhauru Sep 4, 2024
f077f4a
Fix some VarInfo index getters
mhauru Sep 4, 2024
e27af80
Change how VNV handles transformations, and other VNV stuff
mhauru Sep 4, 2024
b5677b4
Small docs fixes
mhauru Sep 4, 2024
9d1c8d3
Merge remote-tracking branch 'origin/master' into torfjelde/varnameve…
mhauru Sep 4, 2024
b778082
Small fixes all over for VNV
mhauru Sep 5, 2024
9750e60
Add comments
mhauru Sep 5, 2024
9ecc506
Fix some tests
mhauru Sep 5, 2024
3f1b9a2
Change long string formatting to support Julia 1.6
mhauru Sep 5, 2024
9145965
Small changes to ReshapeTransformation
mhauru Sep 5, 2024
937956d
Revert unrelated changes to ReverseDiff extension
mhauru Sep 5, 2024
4fbe5d2
Improve VarNamedVector VarInfo testing
mhauru Sep 5, 2024
9f11e7b
Fix some depwarns
mhauru Sep 5, 2024
86d97ae
Improvements to test/simple_varinfo.jl
mhauru Sep 5, 2024
2535517
Fix for unset_flag!, better docstring
mhauru Sep 5, 2024
93ef3ee
Add a comment about hasvalue/getvalue
mhauru Sep 5, 2024
f35eca6
Add @non_differentiable calls to work around Zygote limitations
mhauru Sep 9, 2024
d55fc00
Fix docs, workaround Zygote issue
mhauru Sep 9, 2024
5bbba91
Remove outdated workaround
mhauru Sep 17, 2024
851630f
Move has_varnamedvector(varinfo::VarInfo) to abstract_varinfo.jl
mhauru Sep 17, 2024
45c89c4
Make copies of logp and num_produce in subset
mhauru Sep 17, 2024
30252dc
Rename getindex_raw to getindex_internal
mhauru Sep 17, 2024
be77c36
Add push!(::VarNamedVector, ::Pair)
mhauru Sep 17, 2024
8b5bd47
Improve VarNamedVector docs
mhauru Sep 24, 2024
466acb2
Simplify VarNamedVector constructors
mhauru Sep 24, 2024
40909dd
Change how VNV setindex! et al work
mhauru Oct 2, 2024
a9a6ce2
More improvements to VNV setters and their tests
mhauru Oct 2, 2024
3cdd1d1
Fix style issues in VNV
mhauru Oct 2, 2024
ad06300
Update VNV docs. Add haskey to VarInfo
mhauru Oct 2, 2024
380ef42
Merge remote-tracking branch 'origin/master' into torfjelde/varnameve…
mhauru Oct 2, 2024
a203ab7
Fix VarInfo docs
mhauru Oct 2, 2024
775284d
Disable a test that only works for VectorVarInfo
mhauru Oct 2, 2024
79c3a81
Fix bug in isempty(::TypedVarInfo)
mhauru Oct 2, 2024
1f45e76
Make some doctests platform independent
mhauru Oct 3, 2024
2ffbca4
Merge remote-tracking branch 'origin/master' into torfjelde/varnameve…
mhauru Oct 3, 2024
7fcc7df
Better implementation of haskey(::VarInfo, ::VarName)
mhauru Oct 3, 2024
567c4bc
Improve haskey for VarInfo
mhauru Oct 3, 2024
172d128
Make a VNV doctest more robust
mhauru Oct 3, 2024
186a846
Remote IndexStyle for VNV
mhauru Oct 3, 2024
5508370
Clean up an old comment
mhauru Oct 3, 2024
25b85b4
Fix haskey(::VarInfo, ::VarName)
mhauru Oct 3, 2024
eb5577b
Clarify a TODO note in varinfo.jl
mhauru Oct 4, 2024
b3f92c2
Reintroduce Int indexing to VNV
mhauru Oct 4, 2024
65c94ca
Stop exporting any VNV stuff
mhauru Oct 4, 2024
47428c7
Fix docs
mhauru Oct 4, 2024
c966984
Revert default VarInfo metadata type to Metadata
mhauru Oct 7, 2024
e013926
Fix a few trivial issues with Metadata
mhauru Oct 7, 2024
ed747a4
Docs bug fix
mhauru Oct 7, 2024
b23d4e2
Fix type constraint
mhauru Oct 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export AbstractVarInfo,
UntypedVarInfo,
TypedVarInfo,
SimpleVarInfo,
VarNameVector,
push!!,
empty!!,
subset,
Expand Down Expand Up @@ -164,6 +165,7 @@ include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varnamevector.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")
Expand Down
31 changes: 27 additions & 4 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
logp::Base.RefValue{Tlogp}
num_produce::Base.RefValue{Int}
end
const VectorVarInfo = VarInfo{<:VarNameVector}
const UntypedVarInfo = VarInfo{<:Metadata}
const TypedVarInfo = VarInfo{<:NamedTuple}
const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
Expand Down Expand Up @@ -520,6 +521,8 @@
"""
getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn)
getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)]
# HACK: we shouldn't need this
getdist(::VarNameVector, ::VarName) = nothing

Check warning on line 525 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L525

Added line #L525 was not covered by tests

"""
getval(vi::VarInfo, vn::VarName)
Expand All @@ -530,6 +533,8 @@
"""
getval(vi::VarInfo, vn::VarName) = getval(getmetadata(vi, vn), vn)
getval(md::Metadata, vn::VarName) = view(md.vals, getrange(md, vn))
# HACK: We shouldn't need this
getval(vnv::VarNameVector, vn::VarName) = view(vnv.vals, getrange(vnv, vn))

Check warning on line 537 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L537

Added line #L537 was not covered by tests

"""
setval!(vi::VarInfo, val, vn::VarName)
Expand Down Expand Up @@ -562,13 +567,14 @@

The values may or may not be transformed to Euclidean space.
"""
getall(vi::UntypedVarInfo) = getall(vi.metadata)
getall(vi::VarInfo) = getall(vi.metadata)
# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference.
# See for example https://github.com/JuliaLang/julia/pull/46381.
getall(vi::TypedVarInfo) = reduce(vcat, map(getall, vi.metadata))
function getall(md::Metadata)
return mapreduce(Base.Fix1(getval, md), vcat, md.vns; init=similar(md.vals, 0))
end
getall(vnv::VarNameVector) = vnv.vals

Check warning on line 577 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L577

Added line #L577 was not covered by tests

"""
setall!(vi::VarInfo, val)
Expand Down Expand Up @@ -743,7 +749,7 @@
@inline function _getranges(vi::VarInfo, s::Selector, space)
return _getranges(vi, _getidcs(vi, s, space))
end
@inline function _getranges(vi::UntypedVarInfo, idcs::Vector{Int})
@inline function _getranges(vi::VarInfo, idcs::Vector{Int})

Check warning on line 752 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L752

Added line #L752 was not covered by tests
return mapreduce(i -> vi.metadata.ranges[i], vcat, idcs; init=Int[])
end
@inline _getranges(vi::TypedVarInfo, idcs::NamedTuple) = _getranges(vi.metadata, idcs)
Expand Down Expand Up @@ -848,6 +854,12 @@
return VarInfo(nt, Ref(logp), Ref(num_produce))
end
TypedVarInfo(vi::TypedVarInfo) = vi
function TypedVarInfo(vi::VectorVarInfo)
logp = getlogp(vi)
num_produce = get_num_produce(vi)
nt = group_by_symbol(vi.metadata)
return VarInfo(nt, Ref(logp), Ref(num_produce))

Check warning on line 861 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L857-L861

Added lines #L857 - L861 were not covered by tests
end

function BangBang.empty!!(vi::VarInfo)
_empty!(vi.metadata)
Expand All @@ -867,6 +879,8 @@
# Functions defined only for UntypedVarInfo
Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs)

Base.keys(vi::VectorVarInfo) = keys(vi.metadata)

Check warning on line 882 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L882

Added line #L882 was not covered by tests

# HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly
# on other methods in the codebase which requires `Vector{<:VarName}`.
Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[]
Expand All @@ -890,7 +904,10 @@
return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid)
end

istrans(vi::VarInfo, vn::VarName) = is_flagged(vi, vn, "trans")
istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn)
istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans")
istrans(vnv::VarNameVector, vn::VarName) = !(gettransform(vnv, vn) isa FromVec)

Check warning on line 909 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L909

Added line #L909 was not covered by tests


torfjelde marked this conversation as resolved.
Show resolved Hide resolved
getlogp(vi::VarInfo) = vi.logp[]

Expand Down Expand Up @@ -1406,6 +1423,12 @@
val = getval(vi, vn)
return maybe_invlink_and_reconstruct(vi, vn, dist, val)
end
function getindex(vi::VectorVarInfo, vn::VarName, ::Nothing)
if !haskey(vi, vn)
throw(KeyError(vn))

Check warning on line 1428 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1426-L1428

Added lines #L1426 - L1428 were not covered by tests
end
return getmetadata(vi, vn)[vn]

Check warning on line 1430 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1430

Added line #L1430 was not covered by tests
end
function getindex(vi::VarInfo, vns::Vector{<:VarName})
# FIXME(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases
# such as `x .~ [Normal(), Exponential()]`.
Expand Down Expand Up @@ -1440,7 +1463,7 @@

The value(s) may or may not be transformed to Euclidean space.
"""
getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl)))
getindex(vi::VarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl)))

Check warning on line 1466 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1466

Added line #L1466 was not covered by tests
function getindex(vi::TypedVarInfo, spl::Sampler)
# Gets the ranges as a NamedTuple
ranges = _getranges(vi, spl)
Expand Down
154 changes: 154 additions & 0 deletions src/varnamevector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Similar to `Metadata` but representing a `Vector` and simpler interface.
# TODO: Should we subtype `AbstractVector`?
struct VarNameVector{
TIdcs<:OrderedDict{<:VarName,Int},
TVN<:AbstractVector{<:VarName},
TVal<:AbstractVector,
TTrans<:AbstractVector
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
}
"mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists`"
idcs::TIdcs # Dict{<:VarName,Int}

"vector of identifiers for the random variables, where `vns[idcs[vn]] == vn`"
vns::TVN # AbstractVector{<:VarName}

"vector of index ranges in `vals` corresponding to `vns`; each `VarName` `vn` has a single index or a set of contiguous indices in `vals`"
ranges::Vector{UnitRange{Int}}

"vector of values of all the univariate, multivariate and matrix variables; the value(s) of `vn` is/are `vals[ranges[idcs[vn]]]`"
vals::TVal # AbstractVector{<:Real}

"vector of transformations whose inverse takes us back to the original space"
transforms::TTrans
end

# Useful transformation going from the flattened representation.
struct FromVec{Sz}
sz::Sz
end

FromVec(x::Union{Real,AbstractArray}) = FromVec(size(x))

Check warning on line 30 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L30

Added line #L30 was not covered by tests

# TODO: Should we materialize the `reshape`?
(f::FromVec)(x) = reshape(x, f.sz)
(f::FromVec{Tuple{}})(x) = only(x)

Check warning on line 34 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L33-L34

Added lines #L33 - L34 were not covered by tests

Bijectors.with_logabsdet_jacobian(f::FromVec, x) = (f(x), 0)

Check warning on line 36 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L36

Added line #L36 was not covered by tests

tovec(x::Real) = [x]
tovec(x::AbstractArray) = vec(x)

Check warning on line 39 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L38-L39

Added lines #L38 - L39 were not covered by tests

Bijectors.inverse(f::FromVec) = tovec
Bijectors.inverse(f::FromVec{Tuple{}}) = tovec

Check warning on line 42 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L41-L42

Added lines #L41 - L42 were not covered by tests

VarNameVector(x::AbstractDict) = VarNameVector(keys(x), values(x))
VarNameVector(vns, vals) = VarNameVector(collect(vns), collect(vals))
function VarNameVector(

Check warning on line 46 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L44-L46

Added lines #L44 - L46 were not covered by tests
vns::AbstractVector,
vals::AbstractVector,
transforms = map(FromVec, vals)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
)
# TODO: Check uniqueness of `vns`?

# Convert `vals` into a vector of vectors.
vals_vecs = map(tovec, vals)

Check warning on line 54 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L54

Added line #L54 was not covered by tests

# TODO: Is this really the way to do this?
if !(eltype(vns) <: VarName)
vns = convert(Vector{VarName}, vns)

Check warning on line 58 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L57-L58

Added lines #L57 - L58 were not covered by tests
end
idcs = OrderedDict{eltype(vns),Int}()
ranges = Vector{UnitRange{Int}}()
offset = 0
for (i, (vn, x)) in enumerate(zip(vns, vals_vecs))

Check warning on line 63 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L60-L63

Added lines #L60 - L63 were not covered by tests
# Add the varname index.
push!(idcs, vn => length(idcs) + 1)

Check warning on line 65 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L65

Added line #L65 was not covered by tests
# Add the range.
r = (offset + 1):(offset + length(x))
push!(ranges, r)

Check warning on line 68 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L67-L68

Added lines #L67 - L68 were not covered by tests
# Update the offset.
offset = r[end]
end

Check warning on line 71 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L70-L71

Added lines #L70 - L71 were not covered by tests

return VarNameVector(idcs, vns, ranges, reduce(vcat, vals_vecs), transforms)

Check warning on line 73 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L73

Added line #L73 was not covered by tests
end

# Basic array interface.
Base.eltype(vnv::VarNameVector) = eltype(vnv.vals)
Base.length(vnv::VarNameVector) = length(vnv.vals)
Base.size(vnv::VarNameVector) = size(vnv.vals)

Check warning on line 79 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L77-L79

Added lines #L77 - L79 were not covered by tests

Base.IndexStyle(::Type{<:VarNameVector}) = IndexLinear()

Check warning on line 81 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L81

Added line #L81 was not covered by tests

# Dictionary interface.
Base.keys(vnv::VarNameVector) = vnv.vns

Check warning on line 84 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L84

Added line #L84 was not covered by tests

Base.haskey(vnv::VarNameVector, vn::VarName) = haskey(vnv.idcs, vn)

Check warning on line 86 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L86

Added line #L86 was not covered by tests

# `getindex` & `setindex!`
getidx(vnv::VarNameVector, vn::VarName) = vnv.idcs[vn]

Check warning on line 89 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L89

Added line #L89 was not covered by tests

getrange(vnv::VarNameVector, i::Int) = vnv.ranges[i]
getrange(vnv::VarNameVector, vn::VarName) = getrange(vnv, getidx(vnv, vn))

Check warning on line 92 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L91-L92

Added lines #L91 - L92 were not covered by tests

gettransform(vnv::VarNameVector, vn::VarName) = vnv.transforms[getidx(vnv, vn)]

Check warning on line 94 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L94

Added line #L94 was not covered by tests

Base.getindex(vnv::VarNameVector, ::Colon) = vnv.vals
Base.getindex(vnv::VarNameVector, i::Int) = vnv.vals[i]
function Base.getindex(vnv::VarNameVector, vn::VarName)
x = vnv.vals[getrange(vnv, vn)]
f = gettransform(vnv, vn)
return f(x)

Check warning on line 101 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L96-L101

Added lines #L96 - L101 were not covered by tests
end

# HACK: remove this as soon as possible.
Base.getindex(vnv::VarNameVector, spl::AbstractSampler) = vnv[:]

Check warning on line 105 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L105

Added line #L105 was not covered by tests

Base.setindex!(vnv::VarNameVector, val, i::Int) = vnv.vals[i] = val
function Base.setindex!(vnv::VarNameVector, val, vn::VarName)
f = inverse(gettransform(vnv, vn))
vnv.vals[getrange(vnv, vn)] = f(val)

Check warning on line 110 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L107-L110

Added lines #L107 - L110 were not covered by tests
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end

function Base.empty!(vnv::VarNameVector)

Check warning on line 113 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L113

Added line #L113 was not covered by tests
# TODO: Or should the semantics be different, e.g. keeping `vns`?
empty!(vnv.idcs)
empty!(vnv.vns)
empty!(vnv.ranges)
empty!(vnv.vals)
empty!(vnv.transforms)
return nothing

Check warning on line 120 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L115-L120

Added lines #L115 - L120 were not covered by tests
end
BangBang.empty!!(vnv::VarNameVector) = empty!(vnv)

Check warning on line 122 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L122

Added line #L122 was not covered by tests

# TODO: Re-use some of the show functionality from Base?
function Base.show(io::IO, vnv::VarNameVector)
print(io, "[")
for (i, vn) in enumerate(vnv.vns)
if i > 1
print(io, ", ")

Check warning on line 129 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L125-L129

Added lines #L125 - L129 were not covered by tests
end
print(io, vn, " = ", vnv[vn])
end

Check warning on line 132 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L131-L132

Added lines #L131 - L132 were not covered by tests
end

# Typed version.
function group_by_symbol(vnv::VarNameVector)

Check warning on line 136 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L136

Added line #L136 was not covered by tests
# Group varnames in `vnv` by the symbol.
d = OrderedDict{Symbol,Vector{VarName}}()
for vn in vnv.vns
push!(get!(d, getsym(vn), Vector{VarName}()), vn)
end

Check warning on line 141 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L138-L141

Added lines #L138 - L141 were not covered by tests

# Create a `NamedTuple` from the grouped varnames.
nt_vals = map(values(d)) do vns

Check warning on line 144 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L144

Added line #L144 was not covered by tests
# TODO: Do we need to specialize the inputs here?
VarNameVector(

Check warning on line 146 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L146

Added line #L146 was not covered by tests
map(identity, vns),
map(Base.Fix1(getindex, vnv), vns),
map(Base.Fix1(gettransform, vnv), vns)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
)
end

return NamedTuple{Tuple(keys(d))}(nt_vals)

Check warning on line 153 in src/varnamevector.jl

View check run for this annotation

Codecov / codecov/patch

src/varnamevector.jl#L153

Added line #L153 was not covered by tests
end
27 changes: 27 additions & 0 deletions test/varnamevector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
@testset "VarNameVector" begin
vns = [
@varname(x[1]),
@varname(x[2]),
@varname(x[3]),
]
vals = [
1,
2:3,
reshape(4:9, 2, 3),
]
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
vnv = VarNameVector(vns, vals)

# `getindex`
for (vn, val) in zip(vns, vals)
@test vnv[vn] == val
end

# `setindex!`
for (vn, val) in zip(vns, vals)
vnv[vn] = val .+ 100
end

for (vn, val) in zip(vns, vals)
@test vnv[vn] == val .+ 100
end
end
Loading