From 14e81fe3ada24fcfaaf46f0dc85a50295635b042 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 5 Dec 2024 14:51:44 +0000 Subject: [PATCH 1/3] Add varname tests from DPPL cf. https://github.com/TuringLang/DynamicPPL.jl/issues/737 --- test/varname.jl | 81 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 23 deletions(-) diff --git a/test/varname.jl b/test/varname.jl index 7a92d1e..4f9d329 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -14,7 +14,7 @@ macro test_strict_subsumption(x, y) end end -function test_equal(o1::VarName{sym1}, o2::VarName{sym2}) where {sym1, sym2} +function test_equal(o1::VarName{sym1}, o2::VarName{sym2}) where {sym1,sym2} return sym1 === sym2 && test_equal(o1.optic, o2.optic) end function test_equal(o1::ComposedFunction, o2::ComposedFunction) @@ -28,26 +28,53 @@ function test_equal(o1, o2) end @testset "varnames" begin + @testset "string and symbol conversion" begin + vn1 = @varname x[1][2] + @test string(vn1) == "x[1][2]" + @test Symbol(vn1) == Symbol("x[1][2]") + end + + @testset "equality and hashing" begin + vn1 = @varname x[1][2] + vn2 = @varname x[1][2] + @test vn2 == vn1 + @test hash(vn2) == hash(vn1) + end + + @testset "inspace" begin + space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) + @test inspace(@varname(x), space) + @test inspace(@varname(y), space) + @test inspace(@varname(x[1]), space) + @test inspace(@varname(z[1][1]), space) + @test inspace(@varname(z[1][:]), space) + @test inspace(@varname(z[1][2:3:10]), space) + @test inspace(@varname(M[[2, 3], 1]), space) + @test_throws ErrorException inspace(@varname(M[:, 1:4]), space) + @test inspace(@varname(M[1, [2, 4, 6]]), space) + @test !inspace(@varname(z[2]), space) + @test !inspace(@varname(z), space) + end + @testset "construction & concretization" begin i = 1:10 j = 2:2:5 @test @varname(A[1].b[i]) == @varname(A[1].b[1:10]) @test @varname(A[j]) == @varname(A[2:2:5]) - + @test @varname(A[:, 1][1+1]) == @varname(A[:, 1][2]) - 
@test(@varname(A[:, 1][2]) == - VarName{:A}(@o(_[:, 1]) ⨟ @o(_[2]))) + @test(@varname(A[:, 1][2]) == VarName{:A}(@o(_[:, 1]) ⨟ @o(_[2]))) # concretization y = zeros(10, 10) - x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0],); + x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0],) @test @varname(y[begin, i], true) == @varname(y[1, 1:10]) @test test_equal(@varname(y[:], true), @varname(y[1:100])) @test test_equal(@varname(y[:, begin], true), @varname(y[1:10, 1])) - @test getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] === - AbstractPPL.ConcretizedSlice(to_indices(y, (:,))[1]) - @test test_equal(@varname(x.a[1:end, end][:], true), @varname(x.a[1:3,2][1:3])) + @test getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] === + AbstractPPL.ConcretizedSlice(to_indices(y, (:,))[1]) + @test test_equal(@varname(x.a[1:end, end][:], true), @varname(x.a[1:3, 2][1:3])) end @testset "compose and opcompose" begin @@ -63,13 +90,13 @@ end end @testset "get & set" begin - x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], b = 1.0); + x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], b = 1.0) @test get(x, @varname(a[1, 2])) == 2.0 @test get(x, @varname(b)) == 1.0 @test set(x, @varname(a[1, 2]), 10) == (a = [1.0 10.0; 3.0 4.0; 5.0 6.0], b = 1.0) @test set(x, @varname(b), 10) == (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], b = 10.0) end - + @testset "subsumption with standard indexing" begin # x ⊑ x @test @varname(x) ⊑ @varname(x) @@ -97,11 +124,11 @@ end @test_strict_subsumption x[1] x[1:10] @test_strict_subsumption x[1:5] x[1:10] @test_strict_subsumption x[4:6] x[1:10] - - @test_strict_subsumption x[[2,3,5]] x[[7,6,5,4,3,2,1]] + + @test_strict_subsumption x[[2, 3, 5]] x[[7, 6, 5, 4, 3, 2, 1]] @test_strict_subsumption x[:a][1] x[:a] - + # boolean indexing works as long as it is concretized A = rand(10, 10) @test @varname(A[iseven.(1:10), 1], true) ⊑ @varname(A[1:10, 1]) @@ -116,8 +143,11 @@ end @testset "non-standard indexing" begin A = rand(10, 10) - @test test_equal(@varname(A[1, Not(3)], true), @varname(A[1, [1, 2, 4, 
5, 6, 7, 8, 9, 10]])) - + @test test_equal( + @varname(A[1, Not(3)], true), + @varname(A[1, [1, 2, 4, 5, 6, 7, 8, 9, 10]]) + ) + B = OffsetArray(A, -5, -5) # indices -4:5×-4:5 @test test_equal(@varname(B[1, :], true), @varname(B[1, -4:5])) @@ -129,11 +159,11 @@ end @inferred VarName{:a}(PropertyLens(:b)) @inferred VarName{:a}(Accessors.opcompose(IndexLens(1), PropertyLens(:b))) - b = (a=[1, 2, 3],) + b = (a = [1, 2, 3],) @inferred get(b, @varname(a[1])) @inferred Accessors.set(b, @varname(a[1]), 10) - c = (b=(a=[1, 2, 3],),) + c = (b = (a = [1, 2, 3],),) @inferred get(c, @varname(b.a[1])) @inferred Accessors.set(c, @varname(b.a[1]), 10) end @@ -166,10 +196,10 @@ end @varname(z[:], true), @varname(z[:][:], false), @varname(z[:][:], true), - @varname(z[:,:], false), - @varname(z[:,:], true), - @varname(z[2:5,:], false), - @varname(z[2:5,:], true), + @varname(z[:, :], false), + @varname(z[:, :], true), + @varname(z[2:5, :], false), + @varname(z[2:5, :], true), ] for vn in vns @test string_to_varname(varname_to_string(vn)) == vn @@ -194,8 +224,13 @@ end @test_throws MethodError varname_to_string(vn) # Now define the relevant methods - AbstractPPL.index_to_dict(o::OffsetArrays.IdOffsetRange{I, R}) where {I,R} = Dict("type" => "OffsetArrays.OffsetArray", "parent" => AbstractPPL.index_to_dict(o.parent), "offset" => o.offset) - AbstractPPL.dict_to_index(::Val{Symbol("OffsetArrays.OffsetArray")}, d) = OffsetArrays.IdOffsetRange(AbstractPPL.dict_to_index(d["parent"]), d["offset"]) + AbstractPPL.index_to_dict(o::OffsetArrays.IdOffsetRange{I,R}) where {I,R} = Dict( + "type" => "OffsetArrays.OffsetArray", + "parent" => AbstractPPL.index_to_dict(o.parent), + "offset" => o.offset, + ) + AbstractPPL.dict_to_index(::Val{Symbol("OffsetArrays.OffsetArray")}, d) = + OffsetArrays.IdOffsetRange(AbstractPPL.dict_to_index(d["parent"]), d["offset"]) # Serialisation should now work @test string_to_varname(varname_to_string(vn)) == vn From 36549d74e2d32b32e6aa8e518600b5ca73baba15 Mon Sep 17 
00:00:00 2001 From: Penelope Yong Date: Thu, 5 Dec 2024 15:12:51 +0000 Subject: [PATCH 2/3] Format --- docs/make.jl | 12 ++--- src/AbstractPPL.jl | 5 +- src/abstractprobprog.jl | 5 -- src/varname.jl | 115 +++++++++++++++++++++++++++------------- test/runtests.jl | 5 +- test/varname.jl | 20 ++++--- 6 files changed, 95 insertions(+), 67 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index cf031b4..33bf21b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,12 +2,12 @@ using Documenter using AbstractPPL # Doctest setup -DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive = true) +DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true) makedocs(; - sitename = "AbstractPPL", - modules = [AbstractPPL], - pages = ["Home" => "index.md", "API" => "api.md"], - checkdocs = :exports, - doctest = false, + sitename="AbstractPPL", + modules=[AbstractPPL], + pages=["Home" => "index.md", "API" => "api.md"], + checkdocs=:exports, + doctest=false, ) diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index f121b97..86015a6 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -16,14 +16,13 @@ export VarName, varname_to_string, string_to_varname - # Abstract model functions -export AbstractProbabilisticProgram, condition, decondition, fix, unfix, logdensityof, densityof, AbstractContext, evaluate!! +export AbstractProbabilisticProgram, + condition, decondition, fix, unfix, logdensityof, densityof, AbstractContext, evaluate!! 
# Abstract traces export AbstractModelTrace - include("varname.jl") include("abstractmodeltrace.jl") include("abstractprobprog.jl") diff --git a/src/abstractprobprog.jl b/src/abstractprobprog.jl index 9051d3e..07e5546 100644 --- a/src/abstractprobprog.jl +++ b/src/abstractprobprog.jl @@ -2,7 +2,6 @@ using AbstractMCMC using DensityInterface using Random - """ AbstractProbabilisticProgram @@ -12,7 +11,6 @@ abstract type AbstractProbabilisticProgram <: AbstractMCMC.AbstractModel end DensityInterface.DensityKind(::AbstractProbabilisticProgram) = HasDensity() - """ logdensityof(model, trace) @@ -26,7 +24,6 @@ probability theory. """ DensityInterface.logdensityof(::AbstractProbabilisticProgram, ::AbstractModelTrace) - """ decondition(conditioned_model) @@ -43,7 +40,6 @@ should hold for models `m` with conditioned variables `obs`. """ function decondition end - """ condition(model, observations) @@ -84,7 +80,6 @@ should hold for any model `m` and parameters `params`. """ function fix end - """ unfix(model) diff --git a/src/varname.jl b/src/varname.jl index 48d7c5c..0d0d62d 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -36,7 +36,11 @@ struct VarName{sym,T} function VarName{sym}(optic=identity) where {sym} if !is_static_optic(typeof(optic)) - throw(ArgumentError("attempted to construct `VarName` with unsupported optic of type $(nameof(typeof(optic)))")) + throw( + ArgumentError( + "attempted to construct `VarName` with unsupported optic of type $(nameof(typeof(optic)))", + ), + ) end return new{sym,typeof(optic)}(optic) end @@ -168,7 +172,7 @@ end function Base.show(io::IO, vn::VarName{sym,T}) where {sym,T} print(io, getsym(vn)) - _show_optic(io, getoptic(vn)) + return _show_optic(io, getoptic(vn)) end # modified from https://github.com/JuliaObjects/Accessors.jl/blob/01528a81fdf17c07436e1f3d99119d3f635e4c26/src/sugar.jl#L502 @@ -181,7 +185,7 @@ function _show_optic(io::IO, optic) print(io, " ∘ ") end shortstr = reduce(_shortstring, inner; init="") - print(io, 
shortstr) + return print(io, shortstr) end _shortstring(prev, o::IndexLens) = "$prev[$(join(map(prettify_index, o.indices), ", "))]" @@ -207,7 +211,6 @@ Symbol("x[1][:]") """ Base.Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol - """ inspace(vn::Union{VarName, Symbol}, space::Tuple) @@ -244,7 +247,6 @@ inspace(vn::VarName, space::Tuple) = any(_in(vn, s) for s in space) _in(vn::VarName, s::Symbol) = getsym(vn) == s _in(vn::VarName, s::VarName) = subsumes(s, vn) - """ subsumes(u::VarName, v::VarName) @@ -297,8 +299,9 @@ subsumes(::typeof(identity), ::typeof(identity)) = true subsumes(::typeof(identity), ::ALLOWED_OPTICS) = true subsumes(::ALLOWED_OPTICS, ::typeof(identity)) = false -subsumes(t::ComposedOptic, u::ComposedOptic) = - subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner) +function subsumes(t::ComposedOptic, u::ComposedOptic) + return subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner) +end # If `t` is still a composed lens, then there is no way it can subsume `u` since `u` is a # leaf of the "lens-tree". @@ -317,11 +320,12 @@ subsumes(t::PropertyLens, u::PropertyLens) = false # FIXME: Does not support `DynamicIndexLens`. # FIXME: Does not correctly handle cases such as `subsumes(x, x[:])` # (but neither did old implementation). 
-subsumes( +function subsumes( t::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}}, - u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}} -) = subsumes_indices(t, u) - + u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}}, +) + return subsumes_indices(t, u) +end """ subsumedby(t, u) @@ -444,7 +448,6 @@ subsumes_index(i::Colon, j) = true subsumes_index(i::AbstractVector, j) = issubset(j, i) subsumes_index(i, j) = i == j - """ ConcretizedSlice(::Base.Slice) @@ -455,10 +458,13 @@ struct ConcretizedSlice{T,R} <: AbstractVector{T} range::R end -ConcretizedSlice(s::Base.Slice{R}) where {R} = ConcretizedSlice{eltype(s.indices),R}(s.indices) +function ConcretizedSlice(s::Base.Slice{R}) where {R} + return ConcretizedSlice{eltype(s.indices),R}(s.indices) +end Base.show(io::IO, s::ConcretizedSlice) = print(io, ":") -Base.show(io::IO, ::MIME"text/plain", s::ConcretizedSlice) = - print(io, "ConcretizedSlice(", s.range, ")") +function Base.show(io::IO, ::MIME"text/plain", s::ConcretizedSlice) + return print(io, "ConcretizedSlice(", s.range, ")") +end Base.size(s::ConcretizedSlice) = size(s.range) Base.iterate(s::ConcretizedSlice, state...) = Base.iterate(s.range, state...) Base.collect(s::ConcretizedSlice) = collect(s.range) @@ -480,8 +486,9 @@ The only purpose of this are special cases like `:`, which we want to avoid beco `ConcretizedSlice` based on the `lowered_index`, just what you'd get with an explicit `begin:end` """ reconcretize_index(original_index, lowered_index) = lowered_index -reconcretize_index(original_index::Colon, lowered_index::Base.Slice) = - ConcretizedSlice(lowered_index) +function reconcretize_index(original_index::Colon, lowered_index::Base.Slice) + return ConcretizedSlice(lowered_index) +end """ concretize(l, x) @@ -495,7 +502,9 @@ the result close to the original indexing. 
""" concretize(I::ALLOWED_OPTICS, x) = I concretize(I::DynamicIndexLens, x) = concretize(IndexLens(I.f(x)), x) -concretize(I::IndexLens, x) = IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices))) +function concretize(I::IndexLens, x) + return IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices))) +end function concretize(I::ComposedOptic, x) x_inner = I.inner(x) # TODO: get view here return ComposedOptic(concretize(I.outer, x_inner), concretize(I.inner, x)) @@ -646,11 +655,9 @@ function varname(expr::Expr, concretize=Accessors.need_dynamic_optic(expr)) end if concretize - return :( - $(AbstractPPL.VarName){$sym}( + return :($(AbstractPPL.VarName){$sym}( $(AbstractPPL.concretize)($optics, $sym_escaped) - ) - ) + )) elseif Accessors.need_dynamic_optic(expr) error("Variable name `$(expr)` is dynamic and requires concretization!") else @@ -672,7 +679,7 @@ end function _parse_obj_optic(ex) obj, optics = _parse_obj_optics(ex) optic = Expr(:call, Accessors.opticcompose, optics...) - obj, optic + return obj, optic end # Accessors doesn't have the same support for interpolation @@ -688,7 +695,8 @@ function _parse_obj_optics(ex) indices = Accessors.replace_underscore.(indices, collection) dims = length(indices) == 1 ? nothing : 1:length(indices) lindices = esc.(Accessors.lower_index.(collection, indices, dims)) - optics = :($(Accessors.DynamicIndexLens)($(esc(collection)) -> ($(lindices...),))) + optics = + :($(Accessors.DynamicIndexLens)($(esc(collection)) -> ($(lindices...),))) else index = esc(Expr(:tuple, indices...)) optics = :($(Accessors.IndexLens)($index)) @@ -702,16 +710,20 @@ function _parse_obj_optics(ex) elseif Meta.isexpr(property, :$, 1) optics = :($(Accessors.PropertyLens){$(esc(property.args[1]))}()) else - throw(ArgumentError( - string("Error while parsing :($ex). 
Second argument to `getproperty` can only be", - "a `Symbol` or `String` literal, received `$property` instead.") - )) + throw( + ArgumentError( + string( + "Error while parsing :($ex). Second argument to `getproperty` can only be", + "a `Symbol` or `String` literal, received `$property` instead.", + ), + ), + ) end else obj = esc(ex) return obj, () end - obj, tuple(frontoptics..., optics) + return obj, tuple(frontoptics..., optics) end """ @@ -778,12 +790,27 @@ Convert an index `i` to a dictionary representation. """ index_to_dict(i::Integer) = Dict("type" => _BASE_INTEGER_TYPE, "value" => i) index_to_dict(v::Vector{Int}) = Dict("type" => _BASE_VECTOR_TYPE, "values" => v) -index_to_dict(r::UnitRange) = Dict("type" => _BASE_UNITRANGE_TYPE, "start" => r.start, "stop" => r.stop) -index_to_dict(r::StepRange) = Dict("type" => _BASE_STEPRANGE_TYPE, "start" => r.start, "stop" => r.stop, "step" => r.step) -index_to_dict(r::Base.OneTo{I}) where {I} = Dict("type" => _BASE_ONETO_TYPE, "stop" => r.stop) +function index_to_dict(r::UnitRange) + return Dict("type" => _BASE_UNITRANGE_TYPE, "start" => r.start, "stop" => r.stop) +end +function index_to_dict(r::StepRange) + return Dict( + "type" => _BASE_STEPRANGE_TYPE, + "start" => r.start, + "stop" => r.stop, + "step" => r.step, + ) +end +function index_to_dict(r::Base.OneTo{I}) where {I} + return Dict("type" => _BASE_ONETO_TYPE, "stop" => r.stop) +end index_to_dict(::Colon) = Dict("type" => _BASE_COLON_TYPE) -index_to_dict(s::ConcretizedSlice{T,R}) where {T,R} = Dict("type" => _CONCRETIZED_SLICE_TYPE, "range" => index_to_dict(s.range)) -index_to_dict(t::Tuple) = Dict("type" => _BASE_TUPLE_TYPE, "values" => map(index_to_dict, t)) +function index_to_dict(s::ConcretizedSlice{T,R}) where {T,R} + return Dict("type" => _CONCRETIZED_SLICE_TYPE, "range" => index_to_dict(s.range)) +end +function index_to_dict(t::Tuple) + return Dict("type" => _BASE_TUPLE_TYPE, "values" => map(index_to_dict, t)) +end """ dict_to_index(dict) @@ -839,9 
+866,17 @@ function dict_to_index(dict) end optic_to_dict(::typeof(identity)) = Dict("type" => "identity") -optic_to_dict(::PropertyLens{sym}) where {sym} = Dict("type" => "property", "field" => String(sym)) +function optic_to_dict(::PropertyLens{sym}) where {sym} + return Dict("type" => "property", "field" => String(sym)) +end optic_to_dict(i::IndexLens) = Dict("type" => "index", "indices" => index_to_dict(i.indices)) -optic_to_dict(c::ComposedOptic) = Dict("type" => "composed", "outer" => optic_to_dict(c.outer), "inner" => optic_to_dict(c.inner)) +function optic_to_dict(c::ComposedOptic) + return Dict( + "type" => "composed", + "outer" => optic_to_dict(c.outer), + "inner" => optic_to_dict(c.inner), + ) +end function dict_to_optic(dict) if dict["type"] == "identity" @@ -857,9 +892,13 @@ function dict_to_optic(dict) end end -varname_to_dict(vn::VarName) = Dict("sym" => getsym(vn), "optic" => optic_to_dict(getoptic(vn))) +function varname_to_dict(vn::VarName) + return Dict("sym" => getsym(vn), "optic" => optic_to_dict(getoptic(vn))) +end -dict_to_varname(dict::Dict{<:AbstractString, Any}) = VarName{Symbol(dict["sym"])}(dict_to_optic(dict["optic"])) +function dict_to_varname(dict::Dict{<:AbstractString,Any}) + return VarName{Symbol(dict["sym"])}(dict_to_optic(dict["optic"])) +end """ varname_to_string(vn::VarName) diff --git a/test/runtests.jl b/test/runtests.jl index eb97c36..71ef2cc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,10 +16,7 @@ using Test include("abstractprobprog.jl") @testset "doctests" begin DocMeta.setdocmeta!( - AbstractPPL, - :DocTestSetup, - :(using AbstractPPL); - recursive=true, + AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true ) doctest(AbstractPPL; manual=false) end diff --git a/test/varname.jl b/test/varname.jl index 4f9d329..32ac1c1 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -62,18 +62,18 @@ end @test @varname(A[1].b[i]) == @varname(A[1].b[1:10]) @test @varname(A[j]) == @varname(A[2:2:5]) - @test 
@varname(A[:, 1][1+1]) == @varname(A[:, 1][2]) + @test @varname(A[:, 1][1 + 1]) == @varname(A[:, 1][2]) @test(@varname(A[:, 1][2]) == VarName{:A}(@o(_[:, 1]) ⨟ @o(_[2]))) # concretization y = zeros(10, 10) - x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0],) + x = (a=[1.0 2.0; 3.0 4.0; 5.0 6.0],) @test @varname(y[begin, i], true) == @varname(y[1, 1:10]) @test test_equal(@varname(y[:], true), @varname(y[1:100])) @test test_equal(@varname(y[:, begin], true), @varname(y[1:10, 1])) @test getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] === - AbstractPPL.ConcretizedSlice(to_indices(y, (:,))[1]) + AbstractPPL.ConcretizedSlice(to_indices(y, (:,))[1]) @test test_equal(@varname(x.a[1:end, end][:], true), @varname(x.a[1:3, 2][1:3])) end @@ -90,11 +90,11 @@ end end @testset "get & set" begin - x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], b = 1.0) + x = (a=[1.0 2.0; 3.0 4.0; 5.0 6.0], b=1.0) @test get(x, @varname(a[1, 2])) == 2.0 @test get(x, @varname(b)) == 1.0 - @test set(x, @varname(a[1, 2]), 10) == (a = [1.0 10.0; 3.0 4.0; 5.0 6.0], b = 1.0) - @test set(x, @varname(b), 10) == (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], b = 10.0) + @test set(x, @varname(a[1, 2]), 10) == (a=[1.0 10.0; 3.0 4.0; 5.0 6.0], b=1.0) + @test set(x, @varname(b), 10) == (a=[1.0 2.0; 3.0 4.0; 5.0 6.0], b=10.0) end @testset "subsumption with standard indexing" begin @@ -144,13 +144,11 @@ end @testset "non-standard indexing" begin A = rand(10, 10) @test test_equal( - @varname(A[1, Not(3)], true), - @varname(A[1, [1, 2, 4, 5, 6, 7, 8, 9, 10]]) + @varname(A[1, Not(3)], true), @varname(A[1, [1, 2, 4, 5, 6, 7, 8, 9, 10]]) ) B = OffsetArray(A, -5, -5) # indices -4:5×-4:5 @test test_equal(@varname(B[1, :], true), @varname(B[1, -4:5])) - end @testset "type stability" begin @inferred VarName{:a}() @@ -159,11 +157,11 @@ end @inferred VarName{:a}(PropertyLens(:b)) @inferred VarName{:a}(Accessors.opcompose(IndexLens(1), PropertyLens(:b))) - b = (a = [1, 2, 3],) + b = (a=[1, 2, 3],) @inferred get(b, @varname(a[1])) @inferred 
Accessors.set(b, @varname(a[1]), 10) - c = (b = (a = [1, 2, 3],),) + c = (b=(a=[1, 2, 3],),) @inferred get(c, @varname(b.a[1])) @inferred Accessors.set(c, @varname(b.a[1]), 10) end From 19e3ea4740a7b7c064f494642461cbe752400563 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 5 Dec 2024 15:16:28 +0000 Subject: [PATCH 3/3] Format readme --- README.md | 210 +++++++++++++++++++++++++----------------------------- 1 file changed, 99 insertions(+), 111 deletions(-) diff --git a/README.md b/README.md index f95b508..f92ecab 100644 --- a/README.md +++ b/README.md @@ -8,12 +8,12 @@ A light-weight package to factor out interfaces and associated APIs for modelling languages for probabilistic programming. High level goals are: -- Definition of an interface of few abstract types and a small set of functions that should be -supported by all [probabilistic programs](./src/abstractprobprog.jl) and [trace -types](./src/abstractmodeltrace.jl). -- Provision of some commonly used functionality and data structures, e.g., for managing [variable names](./src/varname.jl) and - traces. - + - Definition of an interface of few abstract types and a small set of functions that should be + supported by all [probabilistic programs](./src/abstractprobprog.jl) and [trace + types](./src/abstractmodeltrace.jl). + - Provision of some commonly used functionality and data structures, e.g., for managing [variable names](./src/varname.jl) and + traces. + This should facilitate reuse of functions in modelling languages, to allow end users to handle models in a consistent way, and to simplify interaction between different languages and sampler implementations, from very rich, dynamic languages like Turing.jl to highly constrained or @@ -22,41 +22,38 @@ simplified models such as GPs, GLMs, or plain log-density problems. A more short term goal is to start a process of cleanly refactoring and justifying parts of DynamicPPL.jl’s design, and hopefully to get on closer terms with Soss.jl. 
- ## `AbstractProbabilisticProgram` interface (still somewhat drafty) There are at least two somewhat incompatible conventions used for the term “model”. None of this is particularly exact, but: -- In Turing.jl, if you write down a `@model` function and call it on arguments, you get a model - object paired with (a possibly empty set of) observations. This can be treated as instantiated - “conditioned” object with fixed values for parameters and observations. -- In Soss.jl, “model” is used for a symbolic “generative” object from which concrete functions, such as - densities and sampling functions, can be derived, _and_ which you can later condition on (and in - turn get a conditional density etc.). + - In Turing.jl, if you write down a `@model` function and call it on arguments, you get a model + object paired with (a possibly empty set of) observations. This can be treated as instantiated + “conditioned” object with fixed values for parameters and observations. + - In Soss.jl, “model” is used for a symbolic “generative” object from which concrete functions, such as + densities and sampling functions, can be derived, _and_ which you can later condition on (and in + turn get a conditional density etc.). Relevant discussions: [1](https://julialang.zulipchat.com/#narrow/stream/234072-probprog/topic/Naming.20the.20.22likelihood.22.20thingy), [2](https://github.com/TuringLang/AbstractPPL.jl/discussions/10). 
- ### TL/DR: There are three interrelating aspects that this interface intends to standardize: -- Density calculation -- Sampling -- “Conversions” between different conditionings of models + - Density calculation + - Sampling + - “Conversions” between different conditionings of models Therefore, the interface consists of an `AbstractProbabilisticProgram` supertype, together with functions -- `condition(::Model, ::Trace) -> ConditionedModel` -- `decondition(::ConditionedModel) -> GenerativeModel` -- `sample(::Model, ::Sampler = Exact(), [Int])` (from `AbstractMCMC.sample`) -- `logdensityof(::Model, ::Trace)` and `densityof(::Model, ::Trace)` (from - [DensityInterface.jl](https://github.com/JuliaMath/DensityInterface.jl)) - + - `condition(::Model, ::Trace) -> ConditionedModel` + - `decondition(::ConditionedModel) -> GenerativeModel` + - `sample(::Model, ::Sampler = Exact(), [Int])` (from `AbstractMCMC.sample`) + - `logdensityof(::Model, ::Trace)` and `densityof(::Model, ::Trace)` (from + [DensityInterface.jl](https://github.com/JuliaMath/DensityInterface.jl)) ### Traces & probability expressions @@ -79,7 +76,6 @@ just choose some arbitrary macro-like syntax like the following: Some more ideas for this kind of object can be found at the end. 
- ### “Conversions” The purpose of this part is to provide common names for how we want a model instance to be @@ -95,7 +91,7 @@ Let’s start from a generative model with parameter `μ`: @generative_model function foo_gen(μ) X ~ Normal(0, μ) Y[1] ~ Normal(X) - Y[2] ~ Normal(X + 1) + return Y[2] ~ Normal(X + 1) end ``` @@ -103,7 +99,7 @@ Applying the “constructor” `foo_gen` now means to fix the parameter, and sho object of the generative type: ```julia -g = foo_gen(μ=…)::SomeGenerativeModel +g = foo_gen(; μ=…)::SomeGenerativeModel ``` With this kind of object, we should be able to sample and calculate joint log-densities from, i.e., @@ -131,10 +127,10 @@ we have a situation like this, with the observations `Y` fixed in the instantiat @model function foo(Y, μ) X ~ Normal(0, μ) Y[1] ~ Normal(X) - Y[2] ~ Normal(X + 1) + return Y[2] ~ Normal(X + 1) end -m = foo(Y=…, μ=…)::SomeConditionedModel +m = foo(; Y=…, μ=…)::SomeConditionedModel ``` From this we can, if supported, go back to the generative form via `decondition`, and back via @@ -170,7 +166,6 @@ rather easy, since it is only a marginal of the generative distribution, while t more structural information. Perhaps both can be generalized under the `query` function I discuss at the end. 
- ### Sampling Sampling in this case refers to producing values from the distribution specified in a model @@ -192,15 +187,15 @@ a (posterior) conditioned model with no known sampling procedure, we just have w `AbstractMCMC`: ```julia -sample([rng], m, N, sampler; [args…]) # chain of length N using `sampler` +sample([rng], m, N, sampler; [args]) # chain of length N using `sampler` ``` In the case of a generative model, or a posterior model with exact solution, we can have some more methods without the need to specify a sampler: ```julia -sample([rng], m; [args…]) # one random sample -sample([rng], m, N; [args…]) # N iid samples; equivalent to `rand` in certain cases +sample([rng], m; [args]) # one random sample +sample([rng], m, N; [args]) # N iid samples; equivalent to `rand` in certain cases ``` It should be possible to implement this by a special sampler, say, `Exact` (name still to be @@ -223,7 +218,6 @@ Not all variants need to be supported – for example, a posterior model might n `rand` is then just a special case when “trivial” exact sampling works for a model, e.g. a joint model. - ### Density Evaluation Since the different “versions” of how a model is to be understood as generative or conditioned are @@ -234,7 +228,7 @@ therefore adapt the interface of `logdensityof` should suffice for variants, with the distinction being made by the capabilities of the concrete model instance. - DensityInterface.jl also requires the trait function `DensityKind`, which is set to `HasDensity()` +DensityInterface.jl also requires the trait function `DensityKind`, which is set to `HasDensity()` for the `AbstractProbabilisticProgram` type. Additional functions ``` @@ -243,7 +237,7 @@ DensityInterface.logdensityof(d) = Base.Fix1(logdensityof, d) DensityInterface.densityof(d) = Base.Fix1(densityof, d) ``` -are provided automatically (repeated here for clarity). +are provided automatically (repeated here for clarity). 
Note that `logdensityof` strictly generalizes `logpdf`, since the posterior density will of course in general be unnormalized and hence not a probability density. @@ -265,8 +259,7 @@ logdensityof(m, @T(X = …)) Densities need (and usually, will) not be normalized. - -#### Implementation notes +#### Implementation notes It should be able to make this fall back on the internal method with the right definition and implementation of `maketrace`: @@ -286,7 +279,6 @@ logdensityof(g, @T(X = …, Y = …, Z = …); normalized=Val{true}) Although there is proably a better way through traits; maybe like for arrays, with `NormalizationStyle(g, t) = IsNormalized()`? - ## More on probability expressions Note that this needs to be a macro, if written this way, since the keys may themselves be more @@ -311,14 +303,13 @@ and probability expression combination. Possible extensions of this idea: -- Pearl-style do-notation: `@T(Y = y | do(X = x))` -- Allowing free variables, to specify model transformations: `query(m, @T(X | Y))` -- “Graph queries”: `@T(X | Parents(X))`, `@T(Y | Not(X))` (a nice way to express Gibbs conditionals!) -- Predicate style for “measure queries”: `@T(X < Y + Z)` + - Pearl-style do-notation: `@T(Y = y | do(X = x))` + - Allowing free variables, to specify model transformations: `query(m, @T(X | Y))` + - “Graph queries”: `@T(X | Parents(X))`, `@T(Y | Not(X))` (a nice way to express Gibbs conditionals!) + - Predicate style for “measure queries”: `@T(X < Y + Z)` The latter applications are the reason I originally liked the idea of the macro being called `@P` -(or even `@𝓅` or `@ℙ`), since then it would look like a “Bayesian probability expression”: `@P(X < -Y + Z)`. But this would not be so meaningful in the case of representing a trace instance. +(or even `@𝓅` or `@ℙ`), since then it would look like a “Bayesian probability expression”: `@P(X < Y + Z)`. But this would not be so meaningful in the case of representing a trace instance. 
Perhaps both `@T` and `@P` can coexist, and both produce different kinds of `ProbabilityExpression` objects? @@ -326,8 +317,6 @@ objects? NB: the exact details of this kind of “schema application”, and what results from it, will need to be specified in the interface of `AbstractModelTrace`, aka “the new `VarInfo`”. - - # `AbstractModelTrace`/`VarInfo` interface draft **This part is even draftier than the above – we’ll try out things in DynamicPPL.jl first** @@ -336,7 +325,7 @@ be specified in the interface of `AbstractModelTrace`, aka “the new `VarInfo` ### Why do we do this? -As I have said before: +As I have said before: > There are many aspects that make VarInfo a very complex data structure. @@ -363,8 +352,8 @@ for a dictionary-like structure. Related previous discussions: -- [Discussion about `VarName`](https://github.com/TuringLang/AbstractPPL.jl/discussions/7) -- [`AbstractVarInfo` representation](https://github.com/TuringLang/AbstractPPL.jl/discussions/5) + - [Discussion about `VarName`](https://github.com/TuringLang/AbstractPPL.jl/discussions/7) + - [`AbstractVarInfo` representation](https://github.com/TuringLang/AbstractPPL.jl/discussions/5) Additionally (but closely related), the second part tries to formalize the “subsumption” mechanism of `VarName`s, and its interaction with using `VarName`s as keys/indices. @@ -378,57 +367,56 @@ ParetoSmoothing.jl. ### What is going to change? -- For the end user of Turing.jl: nothing. You usually don’t use `VarInfo`, or the raw evaluator -interface, anyways. (Although if the newer data structures are more user-friendly, they might occur -in more places in the future?) -- For people having a look into code using `VarInfo`, or starting to hack on Turing.jl/DPPL.jl: a -huge reduction in cognitive complexity. `VarInfo` implementations should be readable on their own, -and the implemented functions layed out somewhere. Its usages should look like for any other nice, -normal data structure. 
-- For core DPPL.jl implementors: same as the previous, plus: a standard against which to improve and -test `VarInfo`, and a clearly defined design space for new data structures. -- For AbstractPPL.jl clients/PPL implementors: an interface to program against (as with the rest of -APPL), and an existing set of well-specified, flexible trace data types with different -characteristics. + - For the end user of Turing.jl: nothing. You usually don’t use `VarInfo`, or the raw evaluator + interface, anyways. (Although if the newer data structures are more user-friendly, they might occur + in more places in the future?) + - For people having a look into code using `VarInfo`, or starting to hack on Turing.jl/DPPL.jl: a + huge reduction in cognitive complexity. `VarInfo` implementations should be readable on their own, + and the implemented functions laid out somewhere. Its usages should look like for any other nice, + normal data structure. + - For core DPPL.jl implementors: same as the previous, plus: a standard against which to improve and + test `VarInfo`, and a clearly defined design space for new data structures. + - For AbstractPPL.jl clients/PPL implementors: an interface to program against (as with the rest of + APPL), and an existing set of well-specified, flexible trace data types with different + characteristics. And in terms of implementation work in DPPL.jl: once the interface is fixed (or even during fixing it), varinfo.jl will undergo a heavy refactoring – which should make it _simpler_! (No three different getter functions with slightly different semantics, etc…). - ## Property interface The basic idea is for all `VarInfo`s to behave like ordered dictionaries with `VarName` keys – all common operations should just work. There are two things that make them more special, though: -1. “Fancy indexing”: since `VarName`s are structured themselves, the `VarInfo` should be have a bit - like a trie, in the sense that all prefixes of stored keys should be retrievable.
Also, - subsumption of `VarName`s should be respected (see end of this document): - + 1. “Fancy indexing”: since `VarName`s are structured themselves, the `VarInfo` should behave a bit + like a trie, in the sense that all prefixes of stored keys should be retrievable. Also, + subsumption of `VarName`s should be respected (see end of this document): + ```julia - vi[@varname(x.a)] = [1,2,3] - vi[@varname(x.b)] = [4,5,6] + vi[@varname(x.a)] = [1, 2, 3] + vi[@varname(x.b)] = [4, 5, 6] vi[@varname(x.a[2])] == 2 - vi[@varname(x)] == (; a = [1,2,3], b = [4,5,6]) + vi[@varname(x)] == (; a=[1, 2, 3], b=[4, 5, 6]) ``` Generalizations that go beyond simple cases (those that you can imagine by storing individual `setfield!`s in a tree) need not be implemented in the beginning; e.g., - + ```julia vi[@varname(x[1])] = 1 vi[@varname(x[2])] = 2 keys(vi) == [x[1], x[2]] - vi[@varname(x)] = [1,2] + vi[@varname(x)] = [1, 2] keys(vi) == [x] ``` - -2. (_This has to be discussed further._) Information other than the sampled values, such as flags, - metadata, pointwise likelihoods, etc., can in principle be stored in multiple of these “`VarInfo` - dicts” with parallel structure. For efficiency, it is thinkable to devise a design such that - multiple fields can be stored under the same indexing structure. + 2. (_This has to be discussed further._) Information other than the sampled values, such as flags, + metadata, pointwise likelihoods, etc., can in principle be stored in multiple of these “`VarInfo` + dicts” with parallel structure. For efficiency, it is thinkable to devise a design such that + multiple fields can be stored under the same indexing structure. + ```julia vi[@varname(x[1])] == 1 vi[@varname(x[1])].meta["bla"] == false @@ -446,23 +434,26 @@ common operations should just work. 
There are two things that make them more sp The important question here is: should the “joint data structure” behave like a dictionary of `NamedTuple`s (`eltype(vi) == @NamedTuple{value::T, ℓ::Float64, meta}`), or like a struct of dicts with shared keys (`eltype(vi.value) <: T`, `eltype(vi.ℓ) <: Float64`, …)? - + The required dictionary functions are about the following: -- Pure functions: - - `iterate`, yielding pairs of `VarName` and the stored value - - `IteratorEltype == HasEltype()`, `IteratorSize = HasLength()` - - `keys`, `values`, `pairs`, `length` consistent with `iterate` - - `eltype`, `keytype`, `valuetype` - - `get`, `getindex`, `haskey` for indexing by `VarName` - - `merge` to join two `VarInfo`s -- Mutating functions: - - `insert!!`, `set!!` - - `merge!!` to add and join elements (TODO: think about `merge`) - - `setindex!!` - - `empty!!`, `delete!!`, `unset!!` (_Are these really used anywhere? Not having them makes persistent - implementations much easier!_) + - Pure functions: + + `iterate`, yielding pairs of `VarName` and the stored value + + `IteratorEltype == HasEltype()`, `IteratorSize = HasLength()` + + `keys`, `values`, `pairs`, `length` consistent with `iterate` + + `eltype`, `keytype`, `valuetype` + + `get`, `getindex`, `haskey` for indexing by `VarName` + + `merge` to join two `VarInfo`s + + - Mutating functions: + + + `insert!!`, `set!!` + + `merge!!` to add and join elements (TODO: think about `merge`) + + `setindex!!` + + `empty!!`, `delete!!`, `unset!!` (_Are these really used anywhere? 
Not having them makes persistent + implementations much easier!_) + I believe that adopting the interface of [Dictionaries.jl](https://github.com/andyferris/Dictionaries.jl), not `Base.AbstractDict`, would be ideal, since their approach makes key sharing and certain operations naturally easy (particularly @@ -471,14 +462,13 @@ ideal, since their approach make key sharing and certain operations naturally ea Other `Base` functions, like `enumerate`, should follow from the above. `length` might appear weird – but it should definitely be consistent with the iterator. - + It would be really cool if `merge` supported the combination of distinct types of implementations, e.g., a dynamic and a tuple-based part. To support both mutable and immutable/persistent implementations, let’s require consistent BangBang.jl style mutators throughout. - ## Transformations/Bijectors Transformations should ideally be handled explicitly and from outside: automatically by the @@ -490,20 +480,18 @@ Implementation-wise, they can probably be expressed as folds? map(v -> link(v.dist, v.value), vi) ``` - ## Linearization There are multiple possible approaches to handle this: -1. As a special case of conversion: `Vector(vi)` -2. `copy!(vals_array, vi)`. -3. As a fold: `mapreduce(v -> vec(v.value), append!, vi, init=Float64[])` + 1. As a special case of conversion: `Vector(vi)` + 2. `copy!(vals_array, vi)`. + 3. As a fold: `mapreduce(v -> vec(v.value), append!, vi, init=Float64[])` Also here, I think that the best implementation would be through a fold. Variants (1) or (2) might additionally be provided as syntactic sugar. 
- ---- +* * * # `VarName`-based axioms @@ -518,7 +506,7 @@ Now, `VarName`s have a compositional structure: they can be built by composing a more and more lenses (`VarName{v}()` starts off with an `IdentityLens`): ```julia -julia> vn = VarName{:x}() ∘ Setfield.IndexLens((1:10, 1) ∘ Setfield.IndexLens((2, ))) +julia> vn = VarName{:x}() ∘ Setfield.IndexLens((1:10, 1) ∘ Setfield.IndexLens((2,))) x[1:10,1][2] ``` @@ -535,21 +523,21 @@ subsumes(@varname(x.a), @varname(x.a[1])) Thus, we have the following axioms for `VarName`s (“variables” are `VarName{n}()`): -1. `x ⊑ x` for all variables `x` -2. `x ≍ y` for `x ≠ y` (i.e., distinct variables are incomparable; `x ⋢ y` and `y ⋢ x`) (`≍` is `\asymp`) -3. `x ∘ ℓ ⊑ x` for all variables `x` and lenses `ℓ` -4. `x ∘ ℓ₁ ⊑ x ∘ ℓ₂ ⇔ ℓ₁ ⊑ ℓ₂` + 1. `x ⊑ x` for all variables `x` + 2. `x ≍ y` for `x ≠ y` (i.e., distinct variables are incomparable; `x ⋢ y` and `y ⋢ x`) (`≍` is `\asymp`) + 3. `x ∘ ℓ ⊑ x` for all variables `x` and lenses `ℓ` + 4. `x ∘ ℓ₁ ⊑ x ∘ ℓ₂ ⇔ ℓ₁ ⊑ ℓ₂` For the last axiom to work, we also have to define subsumption of individual, non-composed lenses: -1. `PropertyLens(a) == PropertyLens(b) ⇔ a == b`, for all symbols `a`, `b` -2. `FunctionLens(f) == FunctionLens(g) ⇔ f == g` (under extensional equality; I’m only mentioning - this in case we ever generalize to Bijector-ed variables like `@varname(log(x))`) -3. `IndexLens(ι₁) ⊑ IndexLens(ι₂)` if the index tuple `ι₂` covers all indices in `ι₁`; for example, - `_[1, 2:10] ⊑ _[1:10, 1:20]`. (_This is a bit fuzzy and not all corner cases have been - considered yet!_) -4. `IdentityLens() == IdentityLens()` -4. `ℓ₁ ≍ ℓ₂`, otherwise + 1. `PropertyLens(a) == PropertyLens(b) ⇔ a == b`, for all symbols `a`, `b` + 2. `FunctionLens(f) == FunctionLens(g) ⇔ f == g` (under extensional equality; I’m only mentioning + this in case we ever generalize to Bijector-ed variables like `@varname(log(x))`) + 3. 
`IndexLens(ι₁) ⊑ IndexLens(ι₂)` if the index tuple `ι₂` covers all indices in `ι₁`; for example, + `_[1, 2:10] ⊑ _[1:10, 1:20]`. (_This is a bit fuzzy and not all corner cases have been + considered yet!_) + 4. `IdentityLens() == IdentityLens()` + 5. `ℓ₁ ≍ ℓ₂`, otherwise Together, this should make `VarName`s under subsumption a reflexive poset.