From c20eb01d746a9d3f1046b50d874f66bb9f4c27e3 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 24 Oct 2023 08:32:19 +0100 Subject: [PATCH 1/8] Add `BangBang.possible` for general arrays --- src/utils.jl | 6 ++++++ test/utils.jl | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 9ba66e63e..569605e6d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -569,6 +569,12 @@ function BangBang.possible( return BangBang.implements(setindex!, C) && promote_type(eltype(C), eltype(T)) <: eltype(C) end +function BangBang.possible( + ::typeof(BangBang._setindex!), ::C, ::T, ::Vararg +) where {C<:AbstractArray,T<:AbstractArray} + return BangBang.implements(setindex!, C) && + promote_type(eltype(C), eltype(T)) <: eltype(C) +end # HACK(torfjelde): This makes it so it works on iterators, etc. by default. # TODO(torfjelde): Do better. diff --git a/test/utils.jl b/test/utils.jl index 1fcf09ef1..5afb5b947 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -48,4 +48,24 @@ x = rand(dist) @test vectorize(dist, x) == vec(x.UL) end + + @testset "BangBang.possible" begin + a = zeros(3, 3, 3, 3) + svi = SimpleVarInfo(Dict(@varname(a) => a)) + DynamicPPL.setindex!!(svi, ones(3, 2), @varname(a[1, 1:3, 1, 1:2])) + @test eltype(svi[@varname(a)]) != Any + + DynamicPPL.setindex!!(svi, ones(3), @varname(a[1, 1, :, 1])) + @test eltype(svi[@varname(a)]) != Any + + DynamicPPL.setindex!!(svi, [1, 2], @varname(a[[5, 8]])) + @test eltype(svi[@varname(a)]) != Any + + DynamicPPL.setindex!!( + svi, + [1, 2], + @varname(a[[CartesianIndex(1, 1, 3, 1), CartesianIndex(1, 1, 3, 2)]]) + ) + @test eltype(svi[@varname(a)]) != Any + end end From 88fdc24ec8b96ac437571b6c392a9704e81bf3c1 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 24 Oct 2023 13:11:44 +0100 Subject: [PATCH 2/8] Remove redundant BangBang.possible --- src/utils.jl | 29 ----------------------------- test/utils.jl | 9 ++++++++- 2 files changed, 8 insertions(+), 30 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 569605e6d..9d6381293 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -540,35 +540,6 @@ end # HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233 # and https://github.com/JuliaFolds/BangBang.jl/pull/238. # HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`. -function BangBang.possible( - ::typeof(BangBang._setindex!), ::C, ::T, ::Colon, ::Integer -) where {C<:AbstractMatrix,T<:AbstractVector} - return BangBang.implements(setindex!, C) && - promote_type(eltype(C), eltype(T)) <: eltype(C) -end -function BangBang.possible( - ::typeof(BangBang._setindex!), ::C, ::T, ::AbstractPPL.ConcretizedSlice, ::Integer -) where {C<:AbstractMatrix,T<:AbstractVector} - return BangBang.implements(setindex!, C) && - promote_type(eltype(C), eltype(T)) <: eltype(C) -end -# HACK: Makes it possible to use ranges, etc. for setting a vector. -# For example, without this hack, BangBang.jl will consider -# -# x[1:2] = [1, 2] -# -# as NOT supported. This results is calling the immutable -# `BangBang.setindex` instead, which also ends up expanding the -# type of the containing array (`x` in the above scenario) to -# have element type `Any`. -# The below code just, correctly, marks this as possible and -# thus we hit the mutable `setindex!` instead. -function BangBang.possible( - ::typeof(BangBang._setindex!), ::C, ::T, ::AbstractVector{<:Integer} -) where {C<:AbstractVector,T<:AbstractVector} - return BangBang.implements(setindex!, C) && - promote_type(eltype(C), eltype(T)) <: eltype(C) -end function BangBang.possible( ::typeof(BangBang._setindex!), ::C, ::T, ::Vararg ) where {C<:AbstractArray,T<:AbstractArray} diff --git a/test/utils.jl b/test/utils.jl index 5afb5b947..fe72134a3 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -50,7 +50,7 @@ end @testset "BangBang.possible" begin - a = zeros(3, 3, 3, 3) + a = zeros(3, 3, 3, 3) # also allow varname concretization svi = SimpleVarInfo(Dict(@varname(a) => a)) DynamicPPL.setindex!!(svi, ones(3, 2), @varname(a[1, 1:3, 1, 1:2])) @test eltype(svi[@varname(a)]) != Any @@ -67,5 +67,12 @@ @varname(a[[CartesianIndex(1, 1, 3, 1), CartesianIndex(1, 1, 3, 2)]]) ) @test eltype(svi[@varname(a)]) != Any + + svi = SimpleVarInfo(Dict(@varname(b) => [zeros(2), zeros(3)])) + DynamicPPL.setindex!!(svi, ones(2), @varname(b[1])) + @test eltype(svi[@varname(b)][1]) != Any + + DynamicPPL.setindex!!(svi, ones(2), @varname(b[2][1:2])) + @test eltype(svi[@varname(b)][2]) != Any end end From a9b51f1d8f09eab0bbf741bf363e55b26f6f15e6 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 31 Oct 2023 11:16:55 +0000 Subject: [PATCH 3/8] add comments rt JuliaFold2 PR --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 9d6381293..6a5bc7cc2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -538,7 +538,7 @@ function remove_parent_lens(vn_parent::VarName{sym}, vn_child::VarName{sym}) whe end # HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233 -# and https://github.com/JuliaFolds/BangBang.jl/pull/238. +# and https://github.com/JuliaFolds/BangBang.jl/pull/238, https://github.com/JuliaFolds2/BangBang.jl/pull/16. # HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`. function BangBang.possible( ::typeof(BangBang._setindex!), ::C, ::T, ::Vararg From 471b89928b55c4efa4477da9d8244167a63bd465 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 1 Nov 2023 07:12:25 +0000 Subject: [PATCH 4/8] copy tor's fix to JuliaFolds here --- src/utils.jl | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 6a5bc7cc2..2bc808b04 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -539,12 +539,39 @@ end # HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233 # and https://github.com/JuliaFolds/BangBang.jl/pull/238, https://github.com/JuliaFolds2/BangBang.jl/pull/16. -# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`. -function BangBang.possible( - ::typeof(BangBang._setindex!), ::C, ::T, ::Vararg -) where {C<:AbstractArray,T<:AbstractArray} - return BangBang.implements(setindex!, C) && - promote_type(eltype(C), eltype(T)) <: eltype(C) +# This avoids type-instability in `dot_assume` for `SimpleVarInfo`. +# The following code a copy from https://github.com/JuliaFolds2/BangBang.jl/pull/16 authored by torfjelde +# Default implementation for `_setindex!` with `AbstractArray`. +# But this will return `false` even in cases such as +# +# setindex!!([1, 2, 3], [4, 5, 6], :) +# +# because `promote_type(eltype(C), T) <: eltype(C)` is `false`. +# To address this, we specialize on the case where `T<:AbstractArray`. +# In addition, we need to support a wide range of indexing behaviors: +# +# We also need to ensure that the dimensionality of the index is +# valid, i.e. that we're not returning `true` in cases such as +# +# setindex!!([1, 2, 3], [4, 5], 1) +# +# which should return `false`. +_index_dimension(::Any) = 0 +_index_dimension(::Colon) = 1 +_index_dimension(::AbstractVector) = 1 +_index_dimension(indices::Tuple) = sum(map(_index_dimension, indices)) + +function possible( + ::typeof(_setindex!), ::C, ::T, indices::Vararg +) where {M,C<:AbstractArray{<:Real},T<:AbstractArray{<:Real,M}} + return implements(setindex!, C) && + promote_type(eltype(C), eltype(T)) <: eltype(C) && + # This will still return `false` for scenarios such as + # + # setindex!!([1, 2, 3], [4, 5, 6], :, 1) + # + # which are in fact valid. However, this cases are rare. + (_index_dimension(indices) == M || _index_dimension(indices) == 1) end # HACK(torfjelde): This makes it so it works on iterators, etc. by default. From c4c1105f1d4e0fa0ccea95f3458e3f2419a2695e Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 1 Nov 2023 07:19:01 +0000 Subject: [PATCH 5/8] add BangBang prefix to the functions --- src/utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 2bc808b04..ca068b1dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -561,10 +561,10 @@ _index_dimension(::Colon) = 1 _index_dimension(::AbstractVector) = 1 _index_dimension(indices::Tuple) = sum(map(_index_dimension, indices)) -function possible( - ::typeof(_setindex!), ::C, ::T, indices::Vararg +function BangBang.possible( + ::typeof(BangBang._setindex!), ::C, ::T, indices::Vararg ) where {M,C<:AbstractArray{<:Real},T<:AbstractArray{<:Real,M}} - return implements(setindex!, C) && + return BangBang.implements(setindex!, C) && promote_type(eltype(C), eltype(T)) <: eltype(C) && # This will still return `false` for scenarios such as # From 1bb3c49f802aa5acfc851da9b2784d18d5181563 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 1 Nov 2023 07:36:26 +0000 Subject: [PATCH 6/8] remove tests of possible, no longer needed --- test/utils.jl | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index fe72134a3..1fcf09ef1 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -48,31 +48,4 @@ x = rand(dist) @test vectorize(dist, x) == vec(x.UL) end - - @testset "BangBang.possible" begin - a = zeros(3, 3, 3, 3) # also allow varname concretization - svi = SimpleVarInfo(Dict(@varname(a) => a)) - DynamicPPL.setindex!!(svi, ones(3, 2), @varname(a[1, 1:3, 1, 1:2])) - @test eltype(svi[@varname(a)]) != Any - - DynamicPPL.setindex!!(svi, ones(3), @varname(a[1, 1, :, 1])) - @test eltype(svi[@varname(a)]) != Any - - DynamicPPL.setindex!!(svi, [1, 2], @varname(a[[5, 8]])) - @test eltype(svi[@varname(a)]) != Any - - DynamicPPL.setindex!!( - svi, - [1, 2], - @varname(a[[CartesianIndex(1, 1, 3, 1), CartesianIndex(1, 1, 3, 2)]]) - ) - @test eltype(svi[@varname(a)]) != Any - - svi = SimpleVarInfo(Dict(@varname(b) => [zeros(2), zeros(3)])) - DynamicPPL.setindex!!(svi, ones(2), @varname(b[1])) - @test eltype(svi[@varname(b)][1]) != Any - - DynamicPPL.setindex!!(svi, ones(2), @varname(b[2][1:2])) - @test eltype(svi[@varname(b)][2]) != Any - end end From 39f4ff15b7a5a915665d62f6487afb5f1183212c Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 1 Nov 2023 17:19:13 +0000 Subject: [PATCH 7/8] Add Tor's tests --- test/utils.jl | 113 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/test/utils.jl b/test/utils.jl index 1fcf09ef1..39bcd78e1 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -48,4 +48,117 @@ x = rand(dist) @test vectorize(dist, x) == vec(x.UL) end + + @testset "BangBang.possible" begin + # Some utility methods for testing `setindex!`. + test_linear_index_only(::Tuple, ::AbstractArray) = false + test_linear_index_only(inds::NTuple{1}, ::AbstractArray) = true + test_linear_index_only(inds::NTuple{1}, ::AbstractVector) = false + + function replace_colon_with_axis(inds::Tuple, x) + ntuple(length(inds)) do i + inds[i] isa Colon ? axes(x, i) : inds[i] + end + end + function replace_colon_with_vector(inds::Tuple, x) + ntuple(length(inds)) do i + inds[i] isa Colon ? collect(axes(x, i)) : inds[i] + end + end + function replace_colon_with_range(inds::Tuple, x) + ntuple(length(inds)) do i + inds[i] isa Colon ? (1:size(x, i)) : inds[i] + end + end + function replace_colon_with_booleans(inds::Tuple, x) + ntuple(length(inds)) do i + inds[i] isa Colon ? trues(size(x, i)) : inds[i] + end + end + + function replace_colon_with_range_linear(inds::NTuple{1}, x::AbstractArray) + return inds[1] isa Colon ? (1:length(x),) : inds + end + + @testset begin + @test setindex!!((1, 2, 3), :two, 2) === (1, :two, 3) + @test setindex!!((a=1, b=2, c=3), :two, :b) === (a=1, b=:two, c=3) + @test setindex!!([1, 2, 3], :two, 2) == [1, :two, 3] + @test setindex!!(Dict{Symbol,Int}(:a => 1, :b => 2), 10, :a) == + Dict(:a => 10, :b => 2) + @test setindex!!(Dict{Symbol,Int}(:a => 1, :b => 2), 3, "c") == + Dict(:a => 1, :b => 2, "c" => 3) + end + + @testset "mutation" begin + @testset "without type expansion" begin + for args in [([1, 2, 3], 20, 2), (Dict(:a => 1, :b => 2), 10, :a)] + @test setindex!!(args...) === args[1] + end + end + + @testset "with type expansion" begin + @test setindex!!([1, 2, 3], [4, 5], 1) == [[4, 5], 2, 3] + @test setindex!!([1, 2, 3], [4, 5, 6], :, 1) == [4, 5, 6] + end + end + + @testset "slices" begin + @testset "$(typeof(x)) with $(src_idx)" for (x, src_idx) in [ + # Vector. + (randn(2), (:,)), + (randn(2), (1:2,)), + # Matrix. + (randn(2, 3), (:,)), + (randn(2, 3), (:, 1)), + (randn(2, 3), (:, 1:3)), + # 3D array. + (randn(2, 3, 4), (:, 1, :)), + (randn(2, 3, 4), (:, 1:3, :)), + (randn(2, 3, 4), (1, 1:3, :)), + ] + # Base case. + @test @inferred(setindex!!(x, x[src_idx...], src_idx...)) === x + + # If we have `Colon` in the index, we replace this with other equivalent indices. + if any(Base.Fix2(isa, Colon), src_idx) + if test_linear_index_only(src_idx, x) + # With range instead of `Colon`. + @test @inferred( + setindex!!( + x, + x[src_idx...], + replace_colon_with_range_linear(src_idx, x)..., + ) + ) === x + else + # With axis instead of `Colon`. + @test @inferred( + setindex!!( + x, x[src_idx...], replace_colon_with_axis(src_idx, x)... + ) + ) === x + # With range instead of `Colon`. + @test @inferred( + setindex!!( + x, x[src_idx...], replace_colon_with_range(src_idx, x)... + ) + ) === x + # With vectors instead of `Colon`. + @test @inferred( + setindex!!( + x, x[src_idx...], replace_colon_with_vector(src_idx, x)... + ) + ) === x + # With boolean index instead of `Colon`. + @test @inferred( + setindex!!( + x, x[src_idx...], replace_colon_with_booleans(src_idx, x)... + ) + ) === x + end + end + end + end + end end From 3ca938c6fe1444723d94fa32e59efb1592999075 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 1 Nov 2023 17:44:27 +0000 Subject: [PATCH 8/8] Import `setindex!!` --- test/utils.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/utils.jl b/test/utils.jl index 39bcd78e1..a2d6f46fb 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -50,6 +50,8 @@ end @testset "BangBang.possible" begin + using DynamicPPL.BangBang: setindex!! + # Some utility methods for testing `setindex!`. test_linear_index_only(::Tuple, ::AbstractArray) = false test_linear_index_only(inds::NTuple{1}, ::AbstractArray) = true