Skip to content

Commit

Permalink
added testing of update! with smaller sizes and fixed bug related t…
Browse files Browse the repository at this point in the history
…o this
  • Loading branch information
torfjelde committed Nov 14, 2023
1 parent 0900c57 commit 1f7e633
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 19 deletions.
13 changes: 12 additions & 1 deletion src/varnamevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,18 @@ end
function nextrange(vnv::VarNameVector, x)
# NOTE: Need to treat `isempty(vnv.ranges)` separately because `maximum`
# will error if `vnv.ranges` is empty.
offset = isempty(vnv.ranges) ? 0 : maximum(last, vnv.ranges)
max_active_range = isempty(vnv.ranges) ? 0 : maximum(last, vnv.ranges)
# Also need to consider inactive ranges, since we can have scenarios such as
#
# vnv = VarNameVector(@varname(x) => 1, @varname(y) => [2, 3])
# update!(vnv, @varname(y), [4]) # => `ranges = [1:1, 2:2], inactive_ranges = [3:3]`
#
# Here `nextrange(vnv, [5])` should return `4:4`, _not_ `3:3`.
# NOTE: We could of course attempt to make use of unused space, e.g. if we have an inactive
# range which can hold `x`, then we could just use that. Buuut the complexity of this is
# probably not worth it (at least at the moment).
max_inactive_range = isempty(vnv.inactive_ranges) ? 0 : maximum(last, vnv.inactive_ranges)
offset = max(max_active_range, max_inactive_range)
return (offset + 1):(offset + length(x))
end

Expand Down
81 changes: 63 additions & 18 deletions test/varnamevector.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
replace_sym(vn::VarName, sym_new::Symbol) = VarName{sym_new}(vn.lens)

change_size_for_test(x::Real) = [x]
change_size_for_test(x::AbstractArray) = repeat(x, 2)
increase_size_for_test(x::Real) = [x]
increase_size_for_test(x::AbstractArray) = repeat(x, 2)

decrease_size_for_test(x::Real) = x
decrease_size_for_test(x::AbstractVector) = first(x)
decrease_size_for_test(x::AbstractArray) = first(eachslice(x; dims=1))

function need_varnames_relaxation(vnv::VarNameVector, vn::VarName, val)
if isconcretetype(eltype(vnv.varnames))
Expand All @@ -16,6 +20,9 @@ function need_varnames_relaxation(vnv::VarNameVector, vn::VarName, val)

return false
end
function need_varnames_relaxation(vnv::VarNameVector, vns, vals)
return any(need_varnames_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals))
end

function need_values_relaxation(vnv::VarNameVector, vn::VarName, val)
if isconcretetype(eltype(vnv.vals))
Expand All @@ -24,6 +31,9 @@ function need_values_relaxation(vnv::VarNameVector, vn::VarName, val)

return false
end
function need_values_relaxation(vnv::VarNameVector, vns, vals)
return any(need_values_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals))
end

function need_transforms_relaxation(vnv::VarNameVector, vn::VarName, val)
if isconcretetype(eltype(vnv.transforms))
Expand All @@ -36,28 +46,50 @@ function need_transforms_relaxation(vnv::VarNameVector, vn::VarName, val)

return false
end
function need_transforms_relaxation(vnv::VarNameVector, vns, vals)
return any(need_transforms_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals))
end

"""
relax_container_types(vnv::VarNameVector, vn::VarName, val)
relax_container_types(vnv::VarNameVector, vns, val)
Relax the container types of `vnv` if necessary to accommodate `vn` and `val`.
This attempts to avoid unnecessary container type relaxations by checking whether
the container types of `vnv` are already compatible with `vn` and `val`.
# Notes
For example, if `vn` is not compatible with the current keys in `vnv`, then
the underlying types will be changed to `VarName` to accommodate `vn`.
Similarly:
- If `val` is not compatible with the current values in `vnv`, then
the underlying value type will be changed to `Real`.
- If `val` requires a transformation that is not compatible with the current
transformations type in `vnv`, then the underlying transformation type will
be changed to `Any`.
"""
function relax_container_types(vnv::VarNameVector, vn::VarName, val)
return relax_container_types(vnv, [vn], [val])
end
function relax_container_types(vnv::VarNameVector, vns, vals)
if any(need_varnames_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals))
if need_varnames_relaxation(vnv, vns, vals)
varname_to_index_new = convert(OrderedDict{VarName,Int}, vnv.varname_to_index)
varnames_new = convert(Vector{VarName}, vnv.varnames)
else
varname_to_index_new = vnv.varname_to_index
varnames_new = vnv.varnames
end

transforms_new =
if any(need_transforms_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals))
convert(Vector{Any}, vnv.transforms)
else
vnv.transforms
end
transforms_new = if need_transforms_relaxation(vnv, vns, vals)
convert(Vector{Any}, vnv.transforms)
else
vnv.transforms
end

vals_new = if any(need_values_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals))
convert(Vector{Any}, vnv.vals)
vals_new = if need_values_relaxation(vnv, vns, vals)
convert(Vector{Real}, vnv.vals)
else
vnv.vals
end
Expand Down Expand Up @@ -102,7 +134,7 @@ end
@varname(z[3]) => rand(1:10, 2, 3),
)
test_vns = collect(keys(test_pairs))
test_vals = collect(test_vals)
test_vals = collect(values(test_pairs))

@testset "constructor: no args" begin
# Empty.
Expand Down Expand Up @@ -237,14 +269,10 @@ end
@test !DynamicPPL.has_inactive_ranges(vnv)
end

# Need to recompute valid varnames for the changing of the sizes; before
# we required either a) the underlying `transforms` to be non-concrete,
# or b) the sizes of the values to match. But now the sizes of the values
# will change, so we can only test the former.
vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals)
@testset "$vn (different size)" for vn in test_vns
@testset "$vn (increased size)" for vn in test_vns
val_original = test_pairs[vn]
val = change_size_for_test(val_original)
val = increase_size_for_test(val_original)
vn_already_present = haskey(vnv, vn)
expected_length = if vn_already_present
# If it's already present, the resulting length will be altered.
Expand All @@ -258,6 +286,23 @@ end
@test length(vnv) == expected_length
@test length(vnv[:]) == length(vnv)
end

vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals)
@testset "$vn (decreased size)" for vn in test_vns
val_original = test_pairs[vn]
val = decrease_size_for_test(val_original)
vn_already_present = haskey(vnv, vn)
expected_length = if vn_already_present
# If it's already present, the resulting length will be altered.
length(vnv) + length(val) - length(val_original)
else
length(vnv) + length(val)
end
DynamicPPL.update!(vnv, vn, val .+ 1)
@test vnv[vn] == val .+ 1
@test length(vnv) == expected_length
@test length(vnv[:]) == length(vnv)
end
end
end
end

0 comments on commit 1f7e633

Please sign in to comment.