From 1d8d50794e022311c0fa321092cf7766ff15d60a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 20:20:57 +0100 Subject: [PATCH] formatting --- src/varinfo.jl | 36 ++++++++++++++++-------------------- test/varinfo.jl | 13 +++++-------- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 35ca1ef72..8b95277df 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -302,8 +302,11 @@ end Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. """ -Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) =_merge(varinfo_left, varinfo_right) -Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) =_merge(varinfo_left, varinfo_right) +Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) = + _merge(varinfo_left, varinfo_right) +function Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) + return _merge(varinfo_left, varinfo_right) +end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) @@ -314,9 +317,8 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) end function merge_metadata( - metadata_left::NamedTuple{names_left}, - metadata_right::NamedTuple{names_right} -) where {names_left, names_right} + metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right} +) where {names_left,names_right} # TODO: Improve this. Maybe make `@generated`? metadata = map(names_left) do sym if sym in names_right @@ -332,7 +334,9 @@ function merge_metadata( end end - return NamedTuple{(names_left..., names_right_only...)}(tuple(metadata..., metadata_right_only...)) + return NamedTuple{(names_left..., names_right_only...)}( + tuple(metadata..., metadata_right_only...) + ) end function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) @@ -361,13 +365,13 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # Initialize required fields for `metadata`. vns = VarName[] - idcs = Dict{VarName, Int}() + idcs = Dict{VarName,Int}() ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` orders = Int[] - flags = Dict{String, BitVector}() + flags = Dict{String,BitVector}() # Initialize the `flags`. for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) flags[k] = BitVector() @@ -442,16 +446,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) end end - return Metadata( - idcs, - vns, - ranges, - vals, - dists, - gids, - orders, - flags, - ) + return Metadata(idcs, vns, ranges, vals, dists, gids, orders, flags) end const VarView = Union{Int,UnitRange,Vector{Int}} @@ -1601,7 +1596,6 @@ run before sampling `vn`. getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn) getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)] - ####################################### # Rand & replaying method for VarInfo # ####################################### @@ -1614,7 +1608,9 @@ Check whether `vn` has a true value for `flag` in `vi`. function is_flagged(vi::VarInfo, vn::VarName, flag::String) return is_flagged(getmetadata(vi, vn), vn, flag) end -is_flagged(metadata::Metadata, vn::VarName, flag::String) = metadata.flags[flag][getidx(metadata, vn)] +function is_flagged(metadata::Metadata, vn::VarName, flag::String) + return metadata.flags[flag][getidx(metadata, vn)] +end """ unset_flag!(vi::VarInfo, vn::VarName, flag::String) diff --git a/test/varinfo.jl b/test/varinfo.jl index a77bb6a13..20e9b9823 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -481,12 +481,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) ] # All variables. - @test isempty( - setdiff( - keys(varinfo), - vns, - ), - ) + @test isempty(setdiff(keys(varinfo), vns)) @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in [ [@varname(s)], @@ -526,7 +521,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS @testset "$(short_varinfo_name(varinfo))" for varinfo in [ VarInfo(model), - last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) + last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())), ] vns = DynamicPPL.TestUtils.varnames(model) @testset "with itself" begin @@ -551,7 +546,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "with different value" begin x = DynamicPPL.TestUtils.rand(model) - varinfo_changed = DynamicPPL.TestUtils.update_values!!(deepcopy(varinfo), x, vns) + varinfo_changed = DynamicPPL.TestUtils.update_values!!( + deepcopy(varinfo), x, vns + ) # After `merge`, we should have the same values as `x`. varinfo_merged = merge(varinfo, varinfo_changed) DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns)