From b7beb12558d4e4413e81f3bb13ce1b6d7fabdcc8 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 26 Oct 2023 18:23:12 +0200 Subject: [PATCH 1/9] Extend macros with rand to support custom samplers --- Project.toml | 2 +- src/SMatrix.jl | 6 ++++++ src/SVector.jl | 7 ++++++- src/arraymath.jl | 4 ++-- test/arraymath.jl | 12 ++++++++++++ 5 files changed, 27 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index a7b345278..726345081 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StaticArrays" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.5" +version = "1.6.6" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/SMatrix.jl b/src/SMatrix.jl index 8c50c6f54..31c564285 100644 --- a/src/SMatrix.jl +++ b/src/SMatrix.jl @@ -73,6 +73,12 @@ function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM if length(ex.args) == 3 return :($f($SM{$(escall(ex.args[2:3])...), Float64})) # default to Float64 like Base elseif length(ex.args) == 4 + if f === :rand + # supports calls like rand(Type, n, m) and rand(sampler, n, m)) + return :(_rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(esc(ex.args[3])), $(esc(ex.args[4]))), $SM{$(esc(ex.args[3])),$(esc(ex.args[4]))})) + else + return :($f($SV{$(escall(ex.args[3,4,2])...)})) + end return :($f($SM{$(escall(ex.args[[3,4,2]])...)})) else error("@$SM expected a 2-dimensional array expression") diff --git a/src/SVector.jl b/src/SVector.jl index 455201283..47968e341 100644 --- a/src/SVector.jl +++ b/src/SVector.jl @@ -78,7 +78,12 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV if length(ex.args) == 2 return :($f($SV{$(esc(ex.args[2])), Float64})) # default to Float64 like Base elseif length(ex.args) == 3 - return :($f($SV{$(escall(ex.args[3:-1:2])...)})) + if f === :rand + # supports calls like rand(Type, n) and rand(sampler, n)) + return :(_rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(esc(ex.args[3]))), $SV{$(esc(ex.args[3]))})) + else + return :($f($SV{$(escall(ex.args[3:-1:2])...)})) + end else error("@$SV expected a 1-dimensional array expression") end diff --git a/src/arraymath.jl b/src/arraymath.jl index 8d9d75a30..fffcaed56 100644 --- a/src/arraymath.jl +++ b/src/arraymath.jl @@ -80,8 +80,8 @@ end @inline rand(rng::AbstractRNG, range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(rng, range, Size(SA), SA) @inline rand(range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(Random.GLOBAL_RNG, range, Size(SA), SA) -@generated function _rand(rng::AbstractRNG, range::AbstractArray, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray} - v = [:(rand(rng, range)) for i = 1:prod(s)] +@generated function _rand(rng::AbstractRNG, X, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray} + v = [:(rand(rng, X)) for i = 1:prod(s)] return quote @_inline_meta $SA(tuple($(v...))) diff --git a/test/arraymath.jl b/test/arraymath.jl index 682c6e23d..376aa8a78 100644 --- a/test/arraymath.jl +++ b/test/arraymath.jl @@ -1,6 +1,13 @@ using StaticArrays, Test import StaticArrays.arithmetic_closure +struct TestDie + nsides::Int +end +Random.rand(rng::AbstractRNG, ::Random.SamplerType{TestDie}) = TestDie(rand(rng, 4:20)) +Random.rand(rng::AbstractRNG, d::Random.SamplerTrivial{TestDie}) = rand(rng, 1:d[].nsides) +Base.eltype(::Type{TestDie}) = Int + @testset "Array math" begin @testset "zeros() and ones()" begin @test @inferred(zeros(SVector{3,Float64})) === @SVector [0.0, 0.0, 0.0] @@ -179,6 +186,11 @@ import StaticArrays.arithmetic_closure @test v4 isa SA{0, Float32} @test all(0 .< v4 .< 1) end + @test (@SVector rand(TestDie(6), 3)) isa SVector{3,Int} + @test (@MVector rand(TestDie(6), 3)) isa MVector{3,Int} + + @test (@SMatrix rand(TestDie(6), 3, 4)) isa SMatrix{3,4,Int} + @test (@MMatrix rand(TestDie(6), 3, 4)) isa MMatrix{3,4,Int} end @testset "rand!()" begin From 6c1511d6336907ac11db0e1606b2362079376aef Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 26 Oct 2023 18:50:52 +0200 Subject: [PATCH 2/9] Fix tests --- src/SMatrix.jl | 9 ++++----- src/SVector.jl | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/SMatrix.jl b/src/SMatrix.jl index 31c564285..8a44164f7 100644 --- a/src/SMatrix.jl +++ b/src/SMatrix.jl @@ -73,13 +73,12 @@ function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM if length(ex.args) == 3 return :($f($SM{$(escall(ex.args[2:3])...), Float64})) # default to Float64 like Base elseif length(ex.args) == 4 - if f === :rand - # supports calls like rand(Type, n, m) and rand(sampler, n, m)) - return :(_rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(esc(ex.args[3])), $(esc(ex.args[4]))), $SM{$(esc(ex.args[3])),$(esc(ex.args[4]))})) + if f === :rand && ex.args[3] isa Int && ex.args[3] > 0 && ex.args[4] isa Int && ex.args[4] > 0 + # supports calls like rand(Type, n, m) and rand(sampler, n, m)), but only if n, m > 0 + return :(_rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(esc(ex.args[3])), $(esc(ex.args[4]))), $SM{$(esc(ex.args[3])), $(esc(ex.args[4]))})) else - return :($f($SV{$(escall(ex.args[3,4,2])...)})) + return :($f($SM{$(escall(ex.args[[3,4,2]])...)})) end - return :($f($SM{$(escall(ex.args[[3,4,2]])...)})) else error("@$SM expected a 2-dimensional array expression") end diff --git a/src/SVector.jl b/src/SVector.jl index 47968e341..6d3ea5528 100644 --- a/src/SVector.jl +++ b/src/SVector.jl @@ -78,8 +78,8 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV if length(ex.args) == 2 return :($f($SV{$(esc(ex.args[2])), Float64})) # default to Float64 like Base elseif length(ex.args) == 3 - if f === :rand - # supports calls like rand(Type, n) and rand(sampler, n)) + if f === :rand && ex.args[3] isa Int && ex.args[3] > 0 + # supports calls like rand(Type, n) and rand(sampler, n)), but only if n > 0 return :(_rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(esc(ex.args[3]))), $SV{$(esc(ex.args[3]))})) else return :($f($SV{$(escall(ex.args[3:-1:2])...)})) From bfd0671e0dfc47e2b456db120117c2619591398f Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 27 Oct 2023 16:39:46 +0200 Subject: [PATCH 3/9] Update Project.toml Co-authored-by: Yuto Horikawa --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 726345081..7800f6642 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StaticArrays" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.6" +version = "1.7.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From eb485cd4c0b7f81da8e1b5492672cb1fcbb5ef1c Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 27 Oct 2023 20:47:48 +0200 Subject: [PATCH 4/9] Update test/arraymath.jl Co-authored-by: Yuto Horikawa --- test/arraymath.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/arraymath.jl b/test/arraymath.jl index 376aa8a78..5f6e13f30 100644 --- a/test/arraymath.jl +++ b/test/arraymath.jl @@ -4,7 +4,6 @@ import StaticArrays.arithmetic_closure struct TestDie nsides::Int end -Random.rand(rng::AbstractRNG, ::Random.SamplerType{TestDie}) = TestDie(rand(rng, 4:20)) Random.rand(rng::AbstractRNG, d::Random.SamplerTrivial{TestDie}) = rand(rng, 1:d[].nsides) Base.eltype(::Type{TestDie}) = Int From ce161c125d5a05f6579205edb374464375b631e3 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 27 Oct 2023 20:53:17 +0200 Subject: [PATCH 5/9] Update src/SVector.jl Co-authored-by: Yuto Horikawa --- src/SVector.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/SVector.jl b/src/SVector.jl index 6d3ea5528..00b6e9758 100644 --- a/src/SVector.jl +++ b/src/SVector.jl @@ -78,9 +78,9 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV if length(ex.args) == 2 return :($f($SV{$(esc(ex.args[2])), Float64})) # default to Float64 like Base elseif length(ex.args) == 3 - if f === :rand && ex.args[3] isa Int && ex.args[3] > 0 - # supports calls like rand(Type, n) and rand(sampler, n)), but only if n > 0 - return :(_rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(esc(ex.args[3]))), $SV{$(esc(ex.args[3]))})) + if f === :rand && ex.args[3] isa Int && ex.args[3] ≥ 0 + # supports calls like rand(Type, n) and rand(sampler, n)) + return :(_rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(esc(ex.args[3]))), $SV{$(esc(ex.args[3])), Random.gentype($(esc(ex.args[2])))})) else return :($f($SV{$(escall(ex.args[3:-1:2])...)})) end From 127dcd371f36d8d95772b56c855d5be085e81a6c Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Sat, 28 Oct 2023 11:17:00 +0200 Subject: [PATCH 6/9] Update src/SMatrix.jl Co-authored-by: Yuto Horikawa --- src/SMatrix.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/SMatrix.jl b/src/SMatrix.jl index 8a44164f7..06d51ff34 100644 --- a/src/SMatrix.jl +++ b/src/SMatrix.jl @@ -73,9 +73,9 @@ function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM if length(ex.args) == 3 return :($f($SM{$(escall(ex.args[2:3])...), Float64})) # default to Float64 like Base elseif length(ex.args) == 4 - if f === :rand && ex.args[3] isa Int && ex.args[3] > 0 && ex.args[4] isa Int && ex.args[4] > 0 - # supports calls like rand(Type, n, m) and rand(sampler, n, m)), but only if n, m > 0 - return :(_rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(esc(ex.args[3])), $(esc(ex.args[4]))), $SM{$(esc(ex.args[3])), $(esc(ex.args[4]))})) + if f === :rand && ex.args[3] isa Int && ex.args[3] ≥ 0 && ex.args[4] isa Int && ex.args[4] ≥ 0 + # supports calls like rand(Type, n, m) and rand(sampler, n, m)) + return :(_rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(esc(ex.args[3])), $(esc(ex.args[4]))), $SM{$(esc(ex.args[3])), $(esc(ex.args[4])), Random.gentype($(esc(ex.args[2])))})) else return :($f($SM{$(escall(ex.args[[3,4,2]])...)})) end From 4ed4c7c891d566cae3be4a5b2fee829ebe080d90 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Sat, 28 Oct 2023 11:42:23 +0200 Subject: [PATCH 7/9] Update `SArray` macro --- src/SArray.jl | 8 +++++++- test/arraymath.jl | 11 +++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/SArray.jl b/src/SArray.jl index 2094733c3..bf572bb8a 100644 --- a/src/SArray.jl +++ b/src/SArray.jl @@ -205,8 +205,14 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA} return quote if isa($(esc(ex.args[2])), DataType) $f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))}) - else + elseif isa($(esc(ex.args[2])), Integer) $f($SA{$Tuple{$(escall(ex.args[2:end])...)}}) + elseif isa($(esc(ex.args[2])), Random.AbstractRNG) + # for calls like rand(rng::AbstractRNG, sampler, dims::Integer...) + StaticArrays._rand($(esc(ex.args[2])), $(esc(ex.args[3])), Size($(escall(ex.args[4:end])...)), $SA{Tuple{$(escall(ex.args[4:end])...)}, Random.gentype($(esc(ex.args[3])))}) + else + # for calls like rand(sampler, dims::Integer...) + StaticArrays._rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(escall(ex.args[3:end])...)), $SA{Tuple{$(escall(ex.args[3:end])...)}, Random.gentype($(esc(ex.args[2])))}) end end elseif f === :fill diff --git a/test/arraymath.jl b/test/arraymath.jl index 5f6e13f30..553d7a639 100644 --- a/test/arraymath.jl +++ b/test/arraymath.jl @@ -185,11 +185,22 @@ Base.eltype(::Type{TestDie}) = Int @test v4 isa SA{0, Float32} @test all(0 .< v4 .< 1) end + rng = MersenneTwister(123) @test (@SVector rand(TestDie(6), 3)) isa SVector{3,Int} + @test (@SVector rand(TestDie(6), 0)) isa SVector{0,Int} @test (@MVector rand(TestDie(6), 3)) isa MVector{3,Int} @test (@SMatrix rand(TestDie(6), 3, 4)) isa SMatrix{3,4,Int} + @test (@SMatrix rand(TestDie(6), 0, 4)) isa SMatrix{0,4,Int} @test (@MMatrix rand(TestDie(6), 3, 4)) isa MMatrix{3,4,Int} + + @test (@SArray rand(TestDie(6), 3, 4, 5)) isa SArray{Tuple{3,4,5},Int} + @test (@SArray rand(rng, TestDie(6), 3, 4, 5)) isa SArray{Tuple{3,4,5},Int} + @test (@SArray rand(TestDie(6), 0, 4, 5)) isa SArray{Tuple{0,4,5},Int} + @test (@SArray rand(rng, TestDie(6), 0, 4, 5)) isa SArray{Tuple{0,4,5},Int} + # test if rng generator is actually respected + @test (@SArray rand(MersenneTwister(123), TestDie(6), 3, 4, 5)) === (@SArray rand(MersenneTwister(123), TestDie(6), 3, 4, 5)) + @test (@MArray rand(TestDie(6), 3, 4, 5)) isa MArray{Tuple{3,4,5},Int} end @testset "rand!()" begin From ebbe90ff9f9d28a3b7fbd4898ba31c99d05317d2 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Sat, 28 Oct 2023 18:30:46 +0200 Subject: [PATCH 8/9] fix reported issue; support rng in SArray and SMatrix --- src/SArray.jl | 49 +++++++++++++++++++++++++++++++++++------------ src/SMatrix.jl | 21 ++++++++++++++++++-- src/SVector.jl | 21 ++++++++++++++++++-- test/arraymath.jl | 24 +++++++++++++++++++++-- 4 files changed, 97 insertions(+), 18 deletions(-) diff --git a/src/SArray.jl b/src/SArray.jl index bf572bb8a..4a9a6f050 100644 --- a/src/SArray.jl +++ b/src/SArray.jl @@ -201,18 +201,43 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA} if length(ex.args) == 1 f === :zeros || f === :ones || error("@$SA got bad expression: $(ex)") return :($f($SA{$Tuple{},$Float64})) - end - return quote - if isa($(esc(ex.args[2])), DataType) - $f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))}) - elseif isa($(esc(ex.args[2])), Integer) - $f($SA{$Tuple{$(escall(ex.args[2:end])...)}}) - elseif isa($(esc(ex.args[2])), Random.AbstractRNG) - # for calls like rand(rng::AbstractRNG, sampler, dims::Integer...) - StaticArrays._rand($(esc(ex.args[2])), $(esc(ex.args[3])), Size($(escall(ex.args[4:end])...)), $SA{Tuple{$(escall(ex.args[4:end])...)}, Random.gentype($(esc(ex.args[3])))}) - else - # for calls like rand(sampler, dims::Integer...) - StaticArrays._rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(escall(ex.args[3:end])...)), $SA{Tuple{$(escall(ex.args[3:end])...)}, Random.gentype($(esc(ex.args[2])))}) + elseif f !== :rand || length(ex.args) == 2 + return quote + if isa($(esc(ex.args[2])), DataType) + $f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))}) + else + $f($SA{$Tuple{$(escall(ex.args[2:end])...)}}) + end + end + else + return quote + if isa($(esc(ex.args[2])), DataType) + $f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))}) + elseif isa($(esc(ex.args[2])), Integer) + $f($SA{$Tuple{$(escall(ex.args[2:end])...)}}) + elseif isa($(esc(ex.args[2])), Random.AbstractRNG) + # for calls like rand(rng::AbstractRNG, sampler, dims::Integer...) + StaticArrays._rand( + $(esc(ex.args[2])), + $(esc(ex.args[3])), + Size($(escall(ex.args[4:end])...)), + $SA{ + Tuple{$(escall(ex.args[4:end])...)}, + Random.gentype($(esc(ex.args[3]))), + }, + ) + else + # for calls like rand(sampler, dims::Integer...) + StaticArrays._rand( + Random.GLOBAL_RNG, + $(esc(ex.args[2])), + Size($(escall(ex.args[3:end])...)), + $SA{ + Tuple{$(escall(ex.args[3:end])...)}, + Random.gentype($(esc(ex.args[2]))), + }, + ) + end end end elseif f === :fill diff --git a/src/SMatrix.jl b/src/SMatrix.jl index 06d51ff34..1e92c4ebf 100644 --- a/src/SMatrix.jl +++ b/src/SMatrix.jl @@ -74,11 +74,28 @@ function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM return :($f($SM{$(escall(ex.args[2:3])...), Float64})) # default to Float64 like Base elseif length(ex.args) == 4 if f === :rand && ex.args[3] isa Int && ex.args[3] ≥ 0 && ex.args[4] isa Int && ex.args[4] ≥ 0 - # supports calls like rand(Type, n, m) and rand(sampler, n, m)) - return :(_rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(esc(ex.args[3])), $(esc(ex.args[4]))), $SM{$(esc(ex.args[3])), $(esc(ex.args[4])), Random.gentype($(esc(ex.args[2])))})) + # for calls like rand(sampler, n, m) or rand(type, n, m) + return quote + StaticArrays._rand( + Random.GLOBAL_RNG, + $(esc(ex.args[2])), + Size($(esc(ex.args[3])), $(esc(ex.args[4]))), + $SM{$(esc(ex.args[3])), $(esc(ex.args[4])), Random.gentype($(esc(ex.args[2])))}, + ) + end else return :($f($SM{$(escall(ex.args[[3,4,2]])...)})) end + elseif length(ex.args) == 5 && f === :rand && ex.args[4] isa Int && ex.args[4] ≥ 0 && ex.args[5] isa Int && ex.args[5] ≥ 0 + # for calls like rand(rng::AbstractRNG, sampler, n, m) or rand(rng::AbstractRNG, type, n, m) + return quote + StaticArrays._rand( + $(esc(ex.args[2])), + $(esc(ex.args[3])), + Size($(esc(ex.args[4])), $(esc(ex.args[5]))), + $SM{$(esc(ex.args[4])), $(esc(ex.args[5])), Random.gentype($(esc(ex.args[3])))}, + ) + end else error("@$SM expected a 2-dimensional array expression") end diff --git a/src/SVector.jl b/src/SVector.jl index 00b6e9758..a26412493 100644 --- a/src/SVector.jl +++ b/src/SVector.jl @@ -79,11 +79,28 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV return :($f($SV{$(esc(ex.args[2])), Float64})) # default to Float64 like Base elseif length(ex.args) == 3 if f === :rand && ex.args[3] isa Int && ex.args[3] ≥ 0 - # supports calls like rand(Type, n) and rand(sampler, n)) - return :(_rand(Random.GLOBAL_RNG, $(esc(ex.args[2])), Size($(esc(ex.args[3]))), $SV{$(esc(ex.args[3])), Random.gentype($(esc(ex.args[2])))})) + # for calls like rand(sampler, n) or rand(type, n) + return quote + StaticArrays._rand( + Random.GLOBAL_RNG, + $(esc(ex.args[2])), + Size($(esc(ex.args[3]))), + $SV{$(esc(ex.args[3])), Random.gentype($(esc(ex.args[2])))}, + ) + end else return :($f($SV{$(escall(ex.args[3:-1:2])...)})) end + elseif length(ex.args) == 4 && f === :rand && ex.args[4] isa Int && ex.args[4] ≥ 0 + # for calls like rand(rng::AbstractRNG, sampler, n) or rand(rng::AbstractRNG, type, n) + return quote + StaticArrays._rand( + $(esc(ex.args[2])), + $(esc(ex.args[3])), + Size($(esc(ex.args[4]))), + $SV{$(esc(ex.args[4])), Random.gentype($(esc(ex.args[3])))}, + ) + end else error("@$SV expected a 1-dimensional array expression") end diff --git a/test/arraymath.jl b/test/arraymath.jl index 553d7a639..be6a19ff7 100644 --- a/test/arraymath.jl +++ b/test/arraymath.jl @@ -186,21 +186,41 @@ Base.eltype(::Type{TestDie}) = Int @test all(0 .< v4 .< 1) end rng = MersenneTwister(123) + @test (@SVector rand(3)) isa SVector{3,Float64} + @test (@SMatrix rand(3, 4)) isa SMatrix{3,4,Float64} + @test (@SArray rand(3, 4, 5)) isa SArray{Tuple{3,4,5},Float64} + + @test (@MVector rand(3)) isa MVector{3,Float64} + @test (@MMatrix rand(3, 4)) isa MMatrix{3,4,Float64} + @test (@MArray rand(3, 4, 5)) isa MArray{Tuple{3,4,5},Float64} + @test (@SVector rand(TestDie(6), 3)) isa SVector{3,Int} + @test (@SVector rand(rng, TestDie(6), 3)) isa SVector{3,Int} @test (@SVector rand(TestDie(6), 0)) isa SVector{0,Int} + @test (@SVector rand(rng, TestDie(6), 0)) isa SVector{0,Int} @test (@MVector rand(TestDie(6), 3)) isa MVector{3,Int} + @test (@MVector rand(rng, TestDie(6), 3)) isa MVector{3,Int} @test (@SMatrix rand(TestDie(6), 3, 4)) isa SMatrix{3,4,Int} + @test (@SMatrix rand(rng, TestDie(6), 3, 4)) isa SMatrix{3,4,Int} @test (@SMatrix rand(TestDie(6), 0, 4)) isa SMatrix{0,4,Int} + @test (@SMatrix rand(rng, TestDie(6), 0, 4)) isa SMatrix{0,4,Int} @test (@MMatrix rand(TestDie(6), 3, 4)) isa MMatrix{3,4,Int} + @test (@MMatrix rand(rng, TestDie(6), 3, 4)) isa MMatrix{3,4,Int} @test (@SArray rand(TestDie(6), 3, 4, 5)) isa SArray{Tuple{3,4,5},Int} @test (@SArray rand(rng, TestDie(6), 3, 4, 5)) isa SArray{Tuple{3,4,5},Int} @test (@SArray rand(TestDie(6), 0, 4, 5)) isa SArray{Tuple{0,4,5},Int} @test (@SArray rand(rng, TestDie(6), 0, 4, 5)) isa SArray{Tuple{0,4,5},Int} - # test if rng generator is actually respected - @test (@SArray rand(MersenneTwister(123), TestDie(6), 3, 4, 5)) === (@SArray rand(MersenneTwister(123), TestDie(6), 3, 4, 5)) @test (@MArray rand(TestDie(6), 3, 4, 5)) isa MArray{Tuple{3,4,5},Int} + + # test if rng generator is actually respected + @test (@SVector rand(MersenneTwister(123), TestDie(6), 3)) === + (@SVector rand(MersenneTwister(123), TestDie(6), 3)) + @test (@SMatrix rand(MersenneTwister(123), TestDie(6), 3, 4)) === + (@SMatrix rand(MersenneTwister(123), TestDie(6), 3, 4)) + @test (@SArray rand(MersenneTwister(123), TestDie(6), 3, 4, 5)) === + (@SArray rand(MersenneTwister(123), TestDie(6), 3, 4, 5)) end @testset "rand!()" begin From 80f51368a4d33975c938dd6da69edaa1e2678c16 Mon Sep 17 00:00:00 2001 From: Yuto Horikawa Date: Wed, 3 Jan 2024 04:49:26 +0900 Subject: [PATCH 9/9] Code suggestions for #1210 (#1213) * move `ex.args[2] isa Integer` * split `if` block * simplify :zeros and :ones * refactor :rand * refactor :randn and :randexp * update comments * add _isnonnegvec * update with `_isnonnegvec` * add `_isnonnegvec(args, n)` method to check the size of `args` * fix `@SArray` for `@SArray rand(rng,T,dim)` etc. * update comments * update `@SVector` macro * update `@SMatrix` * update `@SVector` * update `@SArray` * introduce `fargs` variable * avoid `_isnonnegvec` in `static_matrix_gen` * avoid `_isnonnegvec` in `static_vector_gen` * remove unnecessary `_isnonnegvec` * add `_rng()` function * update tests on `@SVector` macro * update tests on `@MVector` macro * organize test/MMatrix.jl and test/SMatrix.jl * organize test/MMatrix.jl and test/SMatrix.jl * update with broken tests * organize test/MMatrix.jl and test/SMatrix.jl for `rand*` functions * fix around `broken` key for `@test` macro * fix zero-length tests * update `test/SArray.jl` to match `test/MArray.jl` * update tests for `@SArray ones` etc. * add supports for `@SArray ones(3-1,2)` etc. * move block for `fill` * update macro `@SArray rand(rng,2,3)` to use ordinary dispatches * update around `@SArray randn` etc. * remove unnecessary dollars * simplify `@SArray fill` * add `@testset "expand_error"` * update tests for `@SArray rand(...)` etc. * fix bug in `rand*_with_Val` * cleanup tests * update macro `@SMatrix rand(rng,2,3)` to use ordinary dispatches * update macro `@SVector rand(rng,3)` to use ordinary dispatches * move block for `fill` * simplify `_randexp_with_Val` --- src/SArray.jl | 142 ++++++++++++++++++++++++------------- src/SMatrix.jl | 86 +++++++++++++--------- src/SVector.jl | 86 +++++++++++++--------- test/MArray.jl | 180 +++++++++++++++++++++++++++++++++++++++-------- test/MMatrix.jl | 106 +++++++++++++++++++++++----- test/MVector.jl | 87 ++++++++++++++++++----- test/SArray.jl | 170 ++++++++++++++++++++++++++++++++++++-------- test/SMatrix.jl | 100 +++++++++++++++++++------- test/SVector.jl | 71 ++++++++++++++----- test/runtests.jl | 2 + 10 files changed, 780 insertions(+), 250 deletions(-) diff --git a/src/SArray.jl b/src/SArray.jl index 4a9a6f050..137a762fd 100644 --- a/src/SArray.jl +++ b/src/SArray.jl @@ -142,22 +142,65 @@ function parse_cat_ast(ex::Expr) cat_any(Val(maxdim), Val(catdim), nargs) end +#= +For example, +* `@SArray rand(2, 3, 4)` +* `@SArray rand(rng, 3, 4)` +will be expanded to the following. +* `_rand_with_Val(SArray, 2, 3, _int2val(2), _int2val(3), Val((4,)))` +* `_rand_with_Val(SArray, 2, 3, _int2val(rng), _int2val(3), Val((4,)))` +The function `_int2val` is required to avoid the following case. +* `_rand_with_Val(SArray, 2, 3, Val(2), Val(3), Val((4,)))` +* `_rand_with_Val(SArray, 2, 3, Val(rng), Val(3), Val((4,)))` +Mutable object such as `rng` cannot be type parameter, and `Val(rng)` throws an error. +=# +_int2val(x::Int) = Val(x) +_int2val(::Any) = nothing +# @SArray zeros(...) +_zeros_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = zeros(SA{Tuple{n1, ns...}}) +_zeros_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = zeros(SA{Tuple{ns...}, T}) +# @SArray ones(...) +_ones_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = ones(SA{Tuple{n1, ns...}}) +_ones_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = ones(SA{Tuple{ns...}, T}) +# @SArray rand(...) +_rand_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = rand(SA{Tuple{n1,n2,ns...}}) +_rand_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, T, Size(n1, ns...), SA{Tuple{n1, ns...}, T}) +_rand_with_Val(::Type{SA}, sampler, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, sampler, Size(n1, ns...), SA{Tuple{n1, ns...}, Random.gentype(sampler)}) +_rand_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(rng, Float64, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64}) +_rand_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, T, Size(ns...), SA{Tuple{ns...}, T}) +_rand_with_Val(::Type{SA}, rng::AbstractRNG, sampler, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, sampler, Size(ns...), SA{Tuple{ns...}, Random.gentype(sampler)}) +# @SArray randn(...) +_randn_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randn(SA{Tuple{n1,n2,ns...}}) +_randn_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T}) +_randn_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64}) +_randn_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randn(rng, Size(ns...), SA{Tuple{ns...}, T}) +# @SArray randexp(...) +_randexp_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randexp(SA{Tuple{n1,n2,ns...}}) +_randexp_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T}) +_randexp_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64}) +_randexp_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randexp(rng, Size(ns...), SA{Tuple{ns...}, T}) + escall(args) = Iterators.map(esc, args) +function _isnonnegvec(args) + length(args) == 0 && return false + all(isa.(args, Integer)) && return all(args .≥ 0) + return false +end function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA} if !isa(ex, Expr) error("Bad input for @$SA") end head = ex.head if head === :vect # vector - return :($SA{$Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...)))) + return :($SA{Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...)))) elseif head === :ref # typed, vector - return :($SA{$Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...)))) + return :($SA{Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...)))) elseif head === :typed_vcat || head === :typed_hcat || head === :typed_ncat # typed, cat args = parse_cat_ast(ex) - return :($SA{$Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...)))) + return :($SA{Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...)))) elseif head === :vcat || head === :hcat || head === :ncat # untyped, cat args = parse_cat_ast(ex) - return :($SA{$Tuple{$(size(args)...)}}($tuple($(escall(args)...)))) + return :($SA{Tuple{$(size(args)...)}}($tuple($(escall(args)...)))) elseif head === :comprehension if length(ex.args) != 1 error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]") @@ -173,7 +216,7 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA} return quote let f($(escall(rng_args)...)) = $(esc(ex.args[1])) - $SA{$Tuple{$(size(exprs)...)}}($tuple($(exprs...))) + $SA{Tuple{$(size(exprs)...)}}($tuple($(exprs...))) end end elseif head === :typed_comprehension @@ -192,57 +235,58 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA} return quote let f($(escall(rng_args)...)) = $(esc(ex.args[1])) - $SA{$Tuple{$(size(exprs)...)},$T}($tuple($(exprs...))) + $SA{Tuple{$(size(exprs)...)},$T}($tuple($(exprs...))) end end elseif head === :call f = ex.args[1] - if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp - if length(ex.args) == 1 - f === :zeros || f === :ones || error("@$SA got bad expression: $(ex)") - return :($f($SA{$Tuple{},$Float64})) - elseif f !== :rand || length(ex.args) == 2 - return quote - if isa($(esc(ex.args[2])), DataType) - $f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))}) - else - $f($SA{$Tuple{$(escall(ex.args[2:end])...)}}) - end - end + fargs = ex.args[2:end] + if f === :zeros || f === :ones + _f_with_Val = Symbol(:_, f, :_with_Val) + if length(fargs) == 0 + # for calls like `zeros()` + return :($f($SA{Tuple{},$Float64})) + elseif _isnonnegvec(fargs) + # for calls like `zeros(dims...)` + return :($f($SA{Tuple{$(escall(fargs)...)}})) else - return quote - if isa($(esc(ex.args[2])), DataType) - $f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))}) - elseif isa($(esc(ex.args[2])), Integer) - $f($SA{$Tuple{$(escall(ex.args[2:end])...)}}) - elseif isa($(esc(ex.args[2])), Random.AbstractRNG) - # for calls like rand(rng::AbstractRNG, sampler, dims::Integer...) - StaticArrays._rand( - $(esc(ex.args[2])), - $(esc(ex.args[3])), - Size($(escall(ex.args[4:end])...)), - $SA{ - Tuple{$(escall(ex.args[4:end])...)}, - Random.gentype($(esc(ex.args[3]))), - }, - ) - else - # for calls like rand(sampler, dims::Integer...) - StaticArrays._rand( - Random.GLOBAL_RNG, - $(esc(ex.args[2])), - Size($(escall(ex.args[3:end])...)), - $SA{ - Tuple{$(escall(ex.args[3:end])...)}, - Random.gentype($(esc(ex.args[2]))), - }, - ) - end - end + # for calls like `zeros(type)` + # for calls like `zeros(type, dims...)` + return :($_f_with_Val($SA, $(esc(fargs[1])), Val($(esc(fargs[1]))), Val(tuple($(escall(fargs[2:end])...))))) end elseif f === :fill - length(ex.args) == 1 && error("@$SA got bad expression: $(ex)") - return :($f($(esc(ex.args[2])), $SA{$Tuple{$(escall(ex.args[3:end])...)}})) + # for calls like `fill(value, dims...)` + return :($f($(esc(fargs[1])), $SA{Tuple{$(escall(fargs[2:end])...)}})) + elseif f === :rand || f === :randn || f === :randexp + _f_with_Val = Symbol(:_, f, :_with_Val) + if length(fargs) == 0 + # No support for `@SArray rand()` + error("@$SA got bad expression: $(ex)") + elseif _isnonnegvec(fargs) + # for calls like `rand(dims...)` + return :($f($SA{Tuple{$(escall(fargs)...)}})) + elseif length(fargs) ≥ 2 + # for calls like `rand(dim1, dim2, dims...)` + # for calls like `rand(type, dim1, dims...)` + # for calls like `rand(sampler, dim1, dims...)` + # for calls like `rand(rng, dim1, dims...)` + # for calls like `rand(rng, type, dims...)` + # for calls like `rand(rng, sampler, dims...)` + # for calls like `randn(dim1, dim2, dims...)` + # for calls like `randn(type, dim1, dims...)` + # for calls like `randn(rng, dim1, dims...)` + # for calls like `randn(rng, type, dims...)` + # for calls like `randexp(dim1, dim2, dims...)` + # for calls like `randexp(type, dim1, dims...)` + # for calls like `randexp(rng, dim1, dims...)` + # for calls like `randexp(rng, type, dims...)` + return :($_f_with_Val($SA, $(esc(fargs[1])), $(esc(fargs[2])), _int2val($(esc(fargs[1]))), _int2val($(esc(fargs[2]))), Val(tuple($(escall(fargs[3:end])...))))) + elseif length(fargs) == 1 + # for calls like `rand(dim)` + return :($f($SA{Tuple{$(escall(fargs)...)}})) + else + error("@$SA got bad expression: $(ex)") + end else error("@$SA only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.") end diff --git a/src/SMatrix.jl b/src/SMatrix.jl index 1e92c4ebf..e588be491 100644 --- a/src/SMatrix.jl +++ b/src/SMatrix.jl @@ -15,6 +15,21 @@ function check_matrix_size(x::Tuple, T = :S) x1, x2 end +# @SMatrix rand(...) +_rand_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = rand(rng, SM{n1, n2}) +_rand_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(Random.GLOBAL_RNG, T, Size(n1, n2), SM{n1, n2, T}) +_rand_with_Val(::Type{SM}, sampler, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(Random.GLOBAL_RNG, sampler, Size(n1, n2), SM{n1, n2, Random.gentype(sampler)}) +_rand_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = rand(rng, SM{n1, n2, T}) +_rand_with_Val(::Type{SM}, rng::AbstractRNG, sampler, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(rng, sampler, Size(n1, n2), SM{n1, n2, Random.gentype(sampler)}) +# @SMatrix randn(...) +_randn_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randn(rng, SM{n1, n2}) +_randn_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _randn(Random.GLOBAL_RNG, Size(n1, n2), SM{n1, n2, T}) +_randn_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randn(rng, SM{n1, n2, T}) +# @SMatrix randexp(...) +_randexp_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randexp(rng, SM{n1, n2}) +_randexp_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _randexp(Random.GLOBAL_RNG, Size(n1, n2), SM{n1, n2, T}) +_randexp_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randexp(rng, SM{n1, n2, T}) + function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM} if !isa(ex, Expr) error("Bad input for @$SM") @@ -69,44 +84,51 @@ function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM end elseif head === :call f = ex.args[1] - if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp - if length(ex.args) == 3 - return :($f($SM{$(escall(ex.args[2:3])...), Float64})) # default to Float64 like Base - elseif length(ex.args) == 4 - if f === :rand && ex.args[3] isa Int && ex.args[3] ≥ 0 && ex.args[4] isa Int && ex.args[4] ≥ 0 - # for calls like rand(sampler, n, m) or rand(type, n, m) - return quote - StaticArrays._rand( - Random.GLOBAL_RNG, - $(esc(ex.args[2])), - Size($(esc(ex.args[3])), $(esc(ex.args[4]))), - $SM{$(esc(ex.args[3])), $(esc(ex.args[4])), Random.gentype($(esc(ex.args[2])))}, - ) - end - else - return :($f($SM{$(escall(ex.args[[3,4,2]])...)})) - end - elseif length(ex.args) == 5 && f === :rand && ex.args[4] isa Int && ex.args[4] ≥ 0 && ex.args[5] isa Int && ex.args[5] ≥ 0 - # for calls like rand(rng::AbstractRNG, sampler, n, m) or rand(rng::AbstractRNG, type, n, m) - return quote - StaticArrays._rand( - $(esc(ex.args[2])), - $(esc(ex.args[3])), - Size($(esc(ex.args[4])), $(esc(ex.args[5]))), - $SM{$(esc(ex.args[4])), $(esc(ex.args[5])), Random.gentype($(esc(ex.args[3])))}, - ) - end + fargs = ex.args[2:end] + if f === :zeros || f === :ones + if length(fargs) == 2 + # for calls like `zeros(dim1, dim2)` + return :($f($SM{$(escall(fargs)...)})) + elseif length(fargs[2:end]) == 2 + # for calls like `zeros(type, dim1, dim2)` + return :($f($SM{$(escall(fargs[2:end])...), $(esc(fargs[1]))})) else - error("@$SM expected a 2-dimensional array expression") + error("@$SM got bad expression: $(ex)") end - elseif ex.args[1] === :fill - if length(ex.args) == 4 - return :($f($(esc(ex.args[2])), $SM{$(escall(ex.args[3:4])...)})) + elseif f === :fill + # for calls like `fill(value, dim1, dim2)` + if length(fargs[2:end]) == 2 + return :($f($(esc(fargs[1])), $SM{$(escall(fargs[2:end])...)})) else error("@$SM expected a 2-dimensional array expression") end + elseif f === :rand || f === :randn || f === :randexp + _f_with_Val = Symbol(:_, f, :_with_Val) + if length(fargs) == 2 + # for calls like `rand(dim1, dim2)` + # for calls like `randn(dim1, dim2)` + # for calls like `randexp(dim1, dim2)` + return :($f($SM{$(escall(fargs)...)})) + elseif length(fargs) == 3 + # for calls like `rand(rng, dim1, dim2)` + # for calls like `rand(type, dim1, dim2)` + # for calls like `rand(sampler, dim1, dim2)` + # for calls like `randn(rng, dim1, dim2)` + # for calls like `randn(type, dim1, dim2)` + # for calls like `randexp(rng, dim1, dim2)` + # for calls like `randexp(type, dim1, dim2)` + return :($_f_with_Val($SM, $(esc(fargs[1])), Val($(esc(fargs[2]))), Val($(esc(fargs[3]))))) + elseif length(fargs) == 4 + # for calls like `rand(rng, type, dim1, dim2)` + # for calls like `rand(rng, sampler, dim1, dim2)` + # for calls like `randn(rng, type, dim1, dim2)` + # for calls like `randexp(rng, type, dim1, dim2)` + return :($_f_with_Val($SM, $(esc(fargs[1])), $(esc(fargs[2])), Val($(esc(fargs[3]))), Val($(esc(fargs[4]))))) + else + error("@$SM got bad expression: $(ex)") + end else - error("@$SM only supports the zeros(), ones(), rand(), randn(), and randexp() functions.") + error("@$SM only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.") end else error("Bad input for @$SM") diff --git a/src/SVector.jl b/src/SVector.jl index a26412493..9e0175aba 100644 --- a/src/SVector.jl +++ b/src/SVector.jl @@ -16,6 +16,21 @@ function check_vector_length(x::Tuple, T = :S) length(x) >= 1 ? x[1] : 1 end +# @SVector rand(...) +_rand_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = rand(rng, SV{n}) +_rand_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _rand(Random.GLOBAL_RNG, T, Size(n), SV{n, T}) +_rand_with_Val(::Type{SV}, sampler, ::Val{n}) where {SV, n} = _rand(Random.GLOBAL_RNG, sampler, Size(n), SV{n, Random.gentype(sampler)}) +_rand_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = rand(rng, SV{n, T}) +_rand_with_Val(::Type{SV}, rng::AbstractRNG, sampler, ::Val{n}) where {SV, n} = _rand(rng, sampler, Size(n), SV{n, Random.gentype(sampler)}) +# @SVector randn(...) +_randn_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = randn(rng, SV{n}) +_randn_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _randn(Random.GLOBAL_RNG, Size(n), SV{n, T}) +_randn_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = randn(rng, SV{n, T}) +# @SVector randexp(...) +_randexp_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = randexp(rng, SV{n}) +_randexp_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _randexp(Random.GLOBAL_RNG, Size(n), SV{n, T}) +_randexp_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = randexp(rng, SV{n, T}) + function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV} if !isa(ex, Expr) error("Bad input for @$SV") @@ -74,44 +89,51 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV end elseif head === :call f = ex.args[1] - if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp - if length(ex.args) == 2 - return :($f($SV{$(esc(ex.args[2])), Float64})) # default to Float64 like Base - elseif length(ex.args) == 3 - if f === :rand && ex.args[3] isa Int && ex.args[3] ≥ 0 - # for calls like rand(sampler, n) or rand(type, n) - return quote - StaticArrays._rand( - Random.GLOBAL_RNG, - $(esc(ex.args[2])), - Size($(esc(ex.args[3]))), - $SV{$(esc(ex.args[3])), Random.gentype($(esc(ex.args[2])))}, - ) - end - else - return :($f($SV{$(escall(ex.args[3:-1:2])...)})) - end - elseif length(ex.args) == 4 && f === :rand && ex.args[4] isa Int && ex.args[4] ≥ 0 - # for calls like rand(rng::AbstractRNG, sampler, n) or rand(rng::AbstractRNG, type, n) - return quote - StaticArrays._rand( - $(esc(ex.args[2])), - $(esc(ex.args[3])), - Size($(esc(ex.args[4]))), - $SV{$(esc(ex.args[4])), Random.gentype($(esc(ex.args[3])))}, - ) - end + fargs = ex.args[2:end] + if f === :zeros || f === :ones + if length(fargs) == 1 + # for calls like `zeros(dim)` + return :($f($SV{$(esc(fargs[1]))})) + elseif length(fargs) == 2 + # for calls like `zeros(type, dim)` + return :($f($SV{$(esc(fargs[2])), $(esc(fargs[1]))})) else - error("@$SV expected a 1-dimensional array expression") + error("@$SV got bad expression: $(ex)") end - elseif ex.args[1] === :fill - if length(ex.args) == 3 - return :($f($(esc(ex.args[2])), $SV{$(esc(ex.args[3]))})) + elseif f === :fill + # for calls like `fill(value, dim)` + if length(fargs) == 2 + return :($f($(esc(fargs[1])), $SV{$(esc(fargs[2]))})) else error("@$SV expected a 1-dimensional array expression") end + elseif f === :rand || f === :randn || f === :randexp + _f_with_Val = Symbol(:_, f, :_with_Val) + if length(fargs) == 1 + # for calls like `rand(dim)` + # for calls like `randn(dim)` + # for calls like `randexp(dim)` + return :($f($SV{$(escall(fargs)...)})) + elseif length(fargs) == 2 + # for calls like `rand(rng, dim)` + # for calls like `rand(type, dim)` + # for calls like `rand(sampler, dim)` + # for calls like `randn(rng, dim)` + # for calls like `randn(type, dim)` + # for calls like `randexp(rng, dim)` + # for calls like `randexp(type, dim)` + return :($_f_with_Val($SV, $(esc(fargs[1])), Val($(esc(fargs[2]))))) + elseif length(fargs) == 3 + # for calls like `rand(rng, type, dim)` + # for calls like `rand(rng, sampler, dim)` + # for calls like `randn(rng, type, dim)` + # for calls like `randexp(rng, type, dim)` + return :($_f_with_Val($SV, $(esc(fargs[1])), $(esc(fargs[2])), Val($(esc(fargs[3]))))) + else + error("@$SV got bad expression: $(ex)") + end else - error("@$SV only supports the zeros(), ones(), rand(), randn() and randexp() functions.") + error("@$SV only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.") end else error("Use @$SV [a,b,c], @$SV Type[a,b,c] or a comprehension like @$SV [f(i) for i = i_min:i_max]") diff --git a/test/MArray.jl b/test/MArray.jl index 79022dd98..9a5b4c474 100644 --- a/test/MArray.jl +++ b/test/MArray.jl @@ -89,42 +89,154 @@ @test ((@MArray Float64[1 for i = 1:2, j = 2:3, k = 3:4, l = 1:2, m = 1:2, n = 1:2, o = 1:2, p = 1:2])::MArray{Tuple{2,2,2,2,2,2,2,2}}).data === ntuple(i->1.0, 256) @test ((@MArray Float64[1 for i = 1:2, j = 2:3, k = 3:4, l = 1:2, m = 1:2, n = 1:2, o = 1:2, p = 1:2, q = 1:2])::MArray{Tuple{2,2,2,2,2,2,2,2,2}}).data === ntuple(i->1.0, 512) - test_expand_error(:(@MArray [1 2; 3])) - test_expand_error(:(@MArray Float64[1 2; 3])) - test_expand_error(:(@MArray fill)) - test_expand_error(:(@MArray ones)) - test_expand_error(:(@MArray sin(1:5))) - test_expand_error(:(@MArray fill())) - test_expand_error(:(@MArray [1; 2; 3; 4]...)) - - @test ((@MArray fill(1))::MArray{Tuple{},Int}).data === (1,) - @test ((@MArray ones())::MArray{Tuple{},Float64}).data === (1.,) - - @test ((@MArray fill(3.,2,2,1))::MArray{Tuple{2,2,1}, Float64}).data === (3.0, 3.0, 3.0, 3.0) - @test ((@MArray zeros(2,2,1))::MArray{Tuple{2,2,1}, Float64}).data === (0.0, 0.0, 0.0, 0.0) - @test ((@MArray ones(2,2,1))::MArray{Tuple{2,2,1}, Float64}).data === (1.0, 1.0, 1.0, 1.0) - @test isa(@MArray(rand(2,2,1)), MArray{Tuple{2,2,1}, Float64}) - @test isa(@MArray(randn(2,2,1)), MArray{Tuple{2,2,1}, Float64}) - @test isa(@MArray(randexp(2,2,1)), MArray{Tuple{2,2,1}, Float64}) - @test isa(@MArray(rand(2,2,0)), MArray{Tuple{2,2,0}, Float64}) - @test isa(@MArray(randn(2,2,0)), MArray{Tuple{2,2,0}, Float64}) - @test isa(@MArray(randexp(2,2,0)), MArray{Tuple{2,2,0}, Float64}) + @testset "expand error" begin + test_expand_error(:(@MArray [1 2; 3])) + test_expand_error(:(@MArray Float64[1 2; 3])) + test_expand_error(:(@MArray fill)) + test_expand_error(:(@MArray ones)) + test_expand_error(:(@MArray sin(1:5))) + test_expand_error(:(@MArray fill())) + test_expand_error(:(@MArray [1; 2; 3; 4]...)) + + # (typed-)comprehension LoadError for `ex.args[1].head != :generator` + test_expand_error(:(@MArray [i+j for i in 1:2 for j in 1:2])) + test_expand_error(:(@MArray Int[i+j for i in 1:2 for j in 1:2])) + end - @test isa(randn!(@MArray zeros(2,2,1)), MArray{Tuple{2,2,1}, Float64}) - @test isa(randexp!(@MArray zeros(2,2,1)), MArray{Tuple{2,2,1}, Float64}) + @testset "@MArray rand*" begin + @testset "Same test as @MVector rand*" begin + n = 4 + @test (@MArray rand(n)) isa MVector{n, Float64} + @test (@MArray randn(n)) isa MVector{n, Float64} + @test (@MArray randexp(n)) isa MVector{n, Float64} + @test (@MArray rand(4)) isa MVector{4, Float64} + @test (@MArray randn(4)) isa MVector{4, Float64} + @test (@MArray randexp(4)) isa MVector{4, Float64} + @test (@MArray rand(_rng(), n)) isa MVector{n, Float64} + @test (@MArray rand(_rng(), n)) == rand(_rng(), n) + @test (@MArray randn(_rng(), n)) isa MVector{n, Float64} + @test (@MArray randn(_rng(), n)) == randn(_rng(), n) + @test (@MArray randexp(_rng(), n)) isa MVector{n, Float64} + @test (@MArray randexp(_rng(), n)) == randexp(_rng(), n) + @test (@MArray rand(_rng(), 4)) isa MVector{4, Float64} + @test (@MArray rand(_rng(), 4)) == rand(_rng(), 4) + @test (@MArray randn(_rng(), 4)) isa MVector{4, Float64} + @test (@MArray randn(_rng(), 4)) == randn(_rng(), 4) + @test (@MArray randexp(_rng(), 4)) isa MVector{4, Float64} + @test (@MArray randexp(_rng(), 4)) == randexp(_rng(), 4) + + for T in (Float32, Float64) + @test (@MArray rand(T, n)) isa MVector{n, T} + @test (@MArray randn(T, n)) isa MVector{n, T} + @test (@MArray randexp(T, n)) isa MVector{n, T} + @test (@MArray rand(T, 4)) isa MVector{4, T} + @test (@MArray randn(T, 4)) isa MVector{4, T} + @test (@MArray randexp(T, 4)) isa MVector{4, T} + @test (@MArray rand(_rng(), T, n)) isa MVector{n, T} + VERSION≥v"1.7" && @test (@MArray rand(_rng(), T, n)) == rand(_rng(), T, n) broken=(T===Float32) + @test (@MArray randn(_rng(), T, n)) isa MVector{n, T} + @test (@MArray randn(_rng(), T, n)) == randn(_rng(), T, n) + @test (@MArray randexp(_rng(), T, n)) isa MVector{n, T} + @test (@MArray randexp(_rng(), T, n)) == randexp(_rng(), T, n) + @test (@MArray rand(_rng(), T, 4)) isa MVector{4, T} + VERSION≥v"1.7" && @test (@MArray rand(_rng(), T, 4)) == rand(_rng(), T, 4) broken=(T===Float32) + @test (@MArray randn(_rng(), T, 4)) isa MVector{4, T} + @test (@MArray randn(_rng(), T, 4)) == randn(_rng(), T, 4) + @test (@MArray randexp(_rng(), T, 4)) isa MVector{4, T} + @test (@MArray randexp(_rng(), T, 4)) == randexp(_rng(), T, 4) + end + end - @test ((@MArray zeros(Float32, 2, 2, 1))::MArray{Tuple{2,2,1},Float32}).data === (0.0f0, 0.0f0, 0.0f0, 0.0f0) - @test ((@MArray ones(Float32, 2, 2, 1))::MArray{Tuple{2,2,1},Float32}).data === (1.0f0, 1.0f0, 1.0f0, 1.0f0) - @test isa(@MArray(rand(Float32, 2, 2, 1)), MArray{Tuple{2,2,1}, Float32}) - @test isa(@MArray(randn(Float32, 2, 2, 1)), MArray{Tuple{2,2,1}, Float32}) - @test isa(@MArray(randexp(Float32, 2, 2, 1)), MArray{Tuple{2,2,1}, Float32}) - @test isa(@MArray(rand(Float32, 2, 2, 0)), MArray{Tuple{2,2,0}, Float32}) - @test isa(@MArray(randn(Float32, 2, 2, 0)), MArray{Tuple{2,2,0}, Float32}) - @test isa(@MArray(randexp(Float32, 2, 2, 0)), MArray{Tuple{2,2,0}, Float32}) + @testset "Same tests as @MMatrix rand*" begin + n = 4 + @testset "zero-length" begin + @test (@MArray rand(0, 0)) isa MMatrix{0, 0, Float64} + @test (@MArray rand(0, n)) isa MMatrix{0, n, Float64} + @test (@MArray rand(n, 0)) isa MMatrix{n, 0, Float64} + @test (@MArray rand(Float32, 0, 0)) isa MMatrix{0, 0, Float32} + @test (@MArray rand(Float32, 0, n)) isa MMatrix{0, n, Float32} + @test (@MArray rand(Float32, n, 0)) isa MMatrix{n, 0, Float32} + @test (@MArray rand(_rng(), Float32, 0, 0)) isa MMatrix{0, 0, Float32} + @test (@MArray rand(_rng(), Float32, 0, n)) isa MMatrix{0, n, Float32} + @test (@MArray rand(_rng(), Float32, n, 0)) isa MMatrix{n, 0, Float32} + end + + @test (@MArray rand(n, n)) isa MMatrix{n, n, Float64} + @test (@MArray randn(n, n)) isa MMatrix{n, n, Float64} + @test (@MArray randexp(n, n)) isa MMatrix{n, n, Float64} + @test (@MArray rand(4, 4)) isa MMatrix{4, 4, Float64} + @test (@MArray randn(4, 4)) isa MMatrix{4, 4, Float64} + @test (@MArray randexp(4, 4)) isa MMatrix{4, 4, Float64} + @test (@MArray rand(_rng(), n, n)) isa MMatrix{n, n, Float64} + @test (@MArray rand(_rng(), n, n)) == rand(_rng(), n, n) + @test (@MArray randn(_rng(), n, n)) isa MMatrix{n, n, Float64} + @test (@MArray randn(_rng(), n, n)) == randn(_rng(), n, n) + @test (@MArray randexp(_rng(), n, n)) isa MMatrix{n, n, Float64} + @test (@MArray randexp(_rng(), n, n)) == randexp(_rng(), n, n) + @test (@MArray rand(_rng(), 4, 4)) isa MMatrix{4, 4, Float64} + @test (@MArray rand(_rng(), 4, 4)) == rand(_rng(), 4, 4) + @test (@MArray randn(_rng(), 4, 4)) isa MMatrix{4, 4, Float64} + @test (@MArray randn(_rng(), 4, 4)) == randn(_rng(), 4, 4) + @test (@MArray randexp(_rng(), 4, 4)) isa MMatrix{4, 4, Float64} + @test (@MArray randexp(_rng(), 4, 4)) == randexp(_rng(), 4, 4) + + for T in (Float32, Float64) + @test (@MArray rand(T, n, n)) isa MMatrix{n, n, T} + @test (@MArray randn(T, n, n)) isa MMatrix{n, n, T} + @test (@MArray randexp(T, n, n)) isa MMatrix{n, n, T} + @test (@MArray rand(T, 4, 4)) isa MMatrix{4, 4, T} + @test (@MArray randn(T, 4, 4)) isa MMatrix{4, 4, T} + @test (@MArray randexp(T, 4, 4)) isa MMatrix{4, 4, T} + @test (@MArray rand(_rng(), T, n, n)) isa MMatrix{n, n, T} + VERSION≥v"1.7" && @test (@MArray rand(_rng(), T, n, n)) == rand(_rng(), T, n, n) broken=(T===Float32) + @test (@MArray randn(_rng(), T, n, n)) isa MMatrix{n, n, T} + @test (@MArray randn(_rng(), T, n, n)) == randn(_rng(), T, n, n) + @test (@MArray randexp(_rng(), T, n, n)) isa MMatrix{n, n, T} + @test (@MArray randexp(_rng(), T, n, n)) == randexp(_rng(), T, n, n) + @test (@MArray rand(_rng(), T, 4, 4)) isa MMatrix{4, 4, T} + VERSION≥v"1.7" && @test (@MArray rand(_rng(), T, 4, 4)) == rand(_rng(), T, 4, 4) broken=(T===Float32) + @test (@MArray randn(_rng(), T, 4, 4)) isa MMatrix{4, 4, T} + @test (@MArray randn(_rng(), T, 4, 4)) == randn(_rng(), T, 4, 4) + @test (@MArray randexp(_rng(), T, 4, 4)) isa MMatrix{4, 4, T} + @test (@MArray randexp(_rng(), T, 4, 4)) == randexp(_rng(), T, 4, 4) + end + end + + @test (@MArray rand(2,2,1)) isa MArray{Tuple{2,2,1}, Float64} + @test (@MArray rand(2,2,0)) isa MArray{Tuple{2,2,0}, Float64} + @test (@MArray randn(2,2,1)) isa MArray{Tuple{2,2,1}, Float64} + @test (@MArray randn(2,2,0)) isa MArray{Tuple{2,2,0}, Float64} + @test (@MArray randexp(2,2,1)) isa MArray{Tuple{2,2,1}, Float64} + @test (@MArray randexp(2,2,0)) isa MArray{Tuple{2,2,0}, Float64} + @test (@MArray rand(Float32,2,2,1)) isa MArray{Tuple{2,2,1}, Float32} + @test (@MArray rand(Float32,2,2,0)) isa MArray{Tuple{2,2,0}, Float32} + @test (@MArray randn(Float32,2,2,1)) isa MArray{Tuple{2,2,1}, Float32} + @test (@MArray randn(Float32,2,2,0)) isa MArray{Tuple{2,2,0}, Float32} + @test (@MArray randexp(Float32,2,2,1)) isa MArray{Tuple{2,2,1}, Float32} + @test (@MArray randexp(Float32,2,2,0)) isa MArray{Tuple{2,2,0}, Float32} + end + + @testset "fill, zeros, ones" begin + @test ((@MArray fill(1))::MArray{Tuple{},Int}).data === (1,) + @test ((@MArray zeros())::MArray{Tuple{},Float64}).data === (0.,) + @test ((@MArray ones())::MArray{Tuple{},Float64}).data === (1.,) + @test ((@MArray fill(3.,2,2,1))::MArray{Tuple{2,2,1}, Float64}).data === (3.0, 3.0, 3.0, 3.0) + @test ((@MArray zeros(2,2,1))::MArray{Tuple{2,2,1}, Float64}).data === (0.0, 0.0, 0.0, 0.0) + @test ((@MArray ones(2,2,1))::MArray{Tuple{2,2,1}, Float64}).data === (1.0, 1.0, 1.0, 1.0) + @test ((@MArray zeros(3-1,2,1))::MArray{Tuple{2,2,1}, Float64}).data === (0.0, 0.0, 0.0, 0.0) + @test ((@MArray ones(3-1,2,1))::MArray{Tuple{2,2,1}, Float64}).data === (1.0, 1.0, 1.0, 1.0) + @test ((@MArray zeros(Float32,2,2,1))::MArray{Tuple{2,2,1}, Float32}).data === (0.f0, 0.f0, 0.f0, 0.f0) + @test ((@MArray ones(Float32,2,2,1))::MArray{Tuple{2,2,1}, Float32}).data === (1.f0, 1.f0, 1.f0, 1.f0) + @test ((@MArray zeros(Float32,3-1,2,1))::MArray{Tuple{2,2,1}, Float32}).data === (0.f0, 0.f0, 0.f0, 0.f0) + @test ((@MArray ones(Float32,3-1,2,1))::MArray{Tuple{2,2,1}, Float32}).data === (1.f0, 1.f0, 1.f0, 1.f0) + end m = [1 2; 3 4] @test MArray{Tuple{2,2}}(m) == @MArray [1 2; 3 4] + # Non-square comprehensions built from SVectors - see #76 + @test @MArray([1 for x = SVector(1,2), y = SVector(1,2,3)]) == ones(2,3) + # Nested cat @test ((@MArray [[1;2] [3;4]])::MMatrix{2,2}).data === (1,2,3,4) @test ((@MArray Float64[[1;2] [3;4]])::MMatrix{2,2}).data === (1.,2.,3.,4.) @@ -133,6 +245,8 @@ test_expand_error(:(@MArray [[1;2] [3]])) test_expand_error(:(@MArray [[1 2]; [3]])) + @test (@MArray [[[1,2],1]; 2; 3]) == [[[1,2],1]; 2; 3] + if VERSION >= v"1.7.0" function test_ex(ex) a = eval(:(@MArray $ex)) @@ -213,6 +327,12 @@ @test_throws ErrorException setindex!(m, "a", 1, 1, 1) end + @testset "rand! randn! randexp!" begin + @test isa(rand!(@MArray zeros(2,2,1)), MArray{Tuple{2,2,1}, Float64}) + @test isa(randn!(@MArray zeros(2,2,1)), MArray{Tuple{2,2,1}, Float64}) + @test isa(randexp!(@MArray zeros(2,2,1)), MArray{Tuple{2,2,1}, Float64}) + end + @testset "promotion" begin @test @inferred(promote_type(MVector{1,Float64}, MVector{1,BigFloat})) == MVector{1,BigFloat} @test @inferred(promote_type(MVector{2,Int}, MVector{2,Float64})) === MVector{2,Float64} diff --git a/test/MMatrix.jl b/test/MMatrix.jl index 48fb6351b..59e1cfa31 100644 --- a/test/MMatrix.jl +++ b/test/MMatrix.jl @@ -18,6 +18,7 @@ end @testset "Outer constructors and macro" begin + # The following tests are much similar in `test/SMatrix.jl` @test_throws Exception MMatrix(1,2,3,4) # unknown constructor @test MMatrix{1,1,Int}((1,)).data === (1,) @@ -27,10 +28,27 @@ @test MMatrix{2,2,Int}((1,2,3,4)).data === (1,2,3,4) @test MMatrix{2,2}((1,2,3,4)).data === (1,2,3,4) @test MMatrix{2}((1,2,3,4)).data === (1,2,3,4) + @test_throws DimensionMismatch MMatrix{2}((1,2,3,4,5)) # test for #557-like issues @test (@inferred MMatrix(SMatrix{0,0,Float64}()))::MMatrix{0,0,Float64} == MMatrix{0,0,Float64}() + @test (MMatrix{2,3}(i+10j for i in 1:2, j in 1:3)::MMatrix{2,3}).data === + (11,12,21,22,31,32) + @test (MMatrix{2,3}(float(i+10j) for i in 1:2, j in 1:3)::MMatrix{2,3}).data === + (11.0,12.0,21.0,22.0,31.0,32.0) + @test (MMatrix{0,0,Int}()::MMatrix{0,0}).data === () + @test (MMatrix{0,3,Int}()::MMatrix{0,3}).data === () + @test (MMatrix{2,0,Int}()::MMatrix{2,0}).data === () + @test (MMatrix{2,3,Int}(i+10j for i in 1:2, j in 1:3)::MMatrix{2,3}).data === + (11,12,21,22,31,32) + @test (MMatrix{2,3,Float64}(i+10j for i in 1:2, j in 1:3)::MMatrix{2,3}).data === + (11.0,12.0,21.0,22.0,31.0,32.0) + @test_throws Exception MMatrix{2,3}(i+10j for i in 1:1, j in 1:3) + @test_throws Exception MMatrix{2,3}(i+10j for i in 1:3, j in 1:3) + @test_throws Exception MMatrix{2,3,Int}(i+10j for i in 1:1, j in 1:3) + @test_throws Exception MMatrix{2,3,Int}(i+10j for i in 1:3, j in 1:3) + @test ((@MMatrix [1.0])::MMatrix{1,1}).data === (1.0,) @test ((@MMatrix [1 2])::MMatrix{1,2}).data === (1, 2) @test ((@MMatrix [1 ; 2])::MMatrix{2,1}).data === (1, 2) @@ -44,38 +62,88 @@ @test ((@MMatrix [i*j for i = 1:2, j=2:3])::MMatrix{2,2}).data === (2, 4, 3, 6) @test ((@MMatrix Float64[i*j for i = 1:2, j=2:3])::MMatrix{2,2}).data === (2.0, 4.0, 3.0, 6.0) - test_expand_error(:(@MMatrix [1 2; 3])) - test_expand_error(:(@MMatrix Float32[1 2; 3])) - test_expand_error(:(@MMatrix [i*j*k for i = 1:2, j=2:3, k=3:4])) - test_expand_error(:(@MMatrix Float32[i*j*k for i = 1:2, j=2:3, k=3:4])) - test_expand_error(:(@MMatrix fill(2.3, 4, 5, 6))) - test_expand_error(:(@MMatrix ones(4, 5, 6, 7))) - test_expand_error(:(@MMatrix ones)) - test_expand_error(:(@MMatrix sin(1:5))) - test_expand_error(:(@MMatrix [1; 2; 3; 4]...)) - test_expand_error(:(@MMatrix a)) + @testset "expand error" begin + test_expand_error(:(@MMatrix [1 2; 3])) + test_expand_error(:(@MMatrix Float32[1 2; 3])) + test_expand_error(:(@MMatrix [i*j*k for i = 1:2, j=2:3, k=3:4])) + test_expand_error(:(@MMatrix Float32[i*j*k for i = 1:2, j=2:3, k=3:4])) + test_expand_error(:(@MMatrix fill(2.3, 4, 5, 6))) + test_expand_error(:(@MMatrix ones(4, 5, 6, 7))) + test_expand_error(:(@MMatrix ones)) + test_expand_error(:(@MMatrix sin(1:5))) + test_expand_error(:(@MMatrix [1; 2; 3; 4]...)) + test_expand_error(:(@MMatrix a)) + end @test ((@MMatrix [1 2.;3 4])::MMatrix{2, 2, Float64}).data === (1., 3., 2., 4.) #issue #911 - @test ((@MMatrix zeros(2,2))::MMatrix{2, 2, Float64}).data === (0.0, 0.0, 0.0, 0.0) @test ((@MMatrix fill(3.4, 2,2))::MMatrix{2, 2, Float64}).data === (3.4, 3.4, 3.4, 3.4) + @test ((@MMatrix zeros(2,2))::MMatrix{2, 2, Float64}).data === (0.0, 0.0, 0.0, 0.0) @test ((@MMatrix ones(2,2))::MMatrix{2, 2, Float64}).data === (1.0, 1.0, 1.0, 1.0) - @test isa(@MMatrix(rand(2,2)), MMatrix{2, 2, Float64}) - @test isa(@MMatrix(randn(2,2)), MMatrix{2, 2, Float64}) - @test isa(@MMatrix(randexp(2,2)), MMatrix{2, 2, Float64}) - @test ((@MMatrix zeros(Float32, 2, 2))::MMatrix{2,2,Float32}).data === (0.0f0, 0.0f0, 0.0f0, 0.0f0) @test ((@MMatrix ones(Float32, 2, 2))::MMatrix{2,2,Float32}).data === (1.0f0, 1.0f0, 1.0f0, 1.0f0) - @test isa(@MMatrix(rand(Float32, 2, 2)), MMatrix{2, 2, Float32}) - @test isa(@MMatrix(randn(Float32, 2, 2)), MMatrix{2, 2, Float32}) - @test isa(@MMatrix(randexp(Float32, 2, 2)), MMatrix{2, 2, Float32}) + @testset "@MMatrix rand*" begin + n = 4 + @testset "zero-length" begin + @test (@MMatrix rand(0, 0)) isa MMatrix{0, 0, Float64} + @test (@MMatrix rand(0, n)) isa MMatrix{0, n, Float64} + @test (@MMatrix rand(n, 0)) isa MMatrix{n, 0, Float64} + @test (@MMatrix rand(Float32, 0, 0)) isa MMatrix{0, 0, Float32} + @test (@MMatrix rand(Float32, 0, n)) isa MMatrix{0, n, Float32} + @test (@MMatrix rand(Float32, n, 0)) isa MMatrix{n, 0, Float32} + @test (@MMatrix rand(_rng(), Float32, 0, 0)) isa MMatrix{0, 0, Float32} + @test (@MMatrix rand(_rng(), Float32, 0, n)) isa MMatrix{0, n, Float32} + @test (@MMatrix rand(_rng(), Float32, n, 0)) isa MMatrix{n, 0, Float32} + end + + @test (@MMatrix rand(n, n)) isa MMatrix{n, n, Float64} + @test (@MMatrix randn(n, n)) isa MMatrix{n, n, Float64} + @test (@MMatrix randexp(n, n)) isa MMatrix{n, n, Float64} + @test (@MMatrix rand(4, 4)) isa MMatrix{4, 4, Float64} + @test (@MMatrix randn(4, 4)) isa MMatrix{4, 4, Float64} + @test (@MMatrix randexp(4, 4)) isa MMatrix{4, 4, Float64} + @test (@MMatrix rand(_rng(), n, n)) isa MMatrix{n, n, Float64} + @test (@MMatrix rand(_rng(), n, n)) == rand(_rng(), n, n) + @test (@MMatrix randn(_rng(), n, n)) isa MMatrix{n, n, Float64} + @test (@MMatrix randn(_rng(), n, n)) == randn(_rng(), n, n) + @test (@MMatrix randexp(_rng(), n, n)) isa MMatrix{n, n, Float64} + @test (@MMatrix randexp(_rng(), n, n)) == randexp(_rng(), n, n) + @test (@MMatrix rand(_rng(), 4, 4)) isa MMatrix{4, 4, Float64} + @test (@MMatrix rand(_rng(), 4, 4)) == rand(_rng(), 4, 4) + @test (@MMatrix randn(_rng(), 4, 4)) isa MMatrix{4, 4, Float64} + @test (@MMatrix randn(_rng(), 4, 4)) == randn(_rng(), 4, 4) + @test (@MMatrix randexp(_rng(), 4, 4)) isa MMatrix{4, 4, Float64} + @test (@MMatrix randexp(_rng(), 4, 4)) == randexp(_rng(), 4, 4) + + for T in (Float32, Float64) + @test (@MMatrix rand(T, n, n)) isa MMatrix{n, n, T} + @test (@MMatrix randn(T, n, n)) isa MMatrix{n, n, T} + @test (@MMatrix randexp(T, n, n)) isa MMatrix{n, n, T} + @test (@MMatrix rand(T, 4, 4)) isa MMatrix{4, 4, T} + @test (@MMatrix randn(T, 4, 4)) isa MMatrix{4, 4, T} + @test (@MMatrix randexp(T, 4, 4)) isa MMatrix{4, 4, T} + @test (@MMatrix rand(_rng(), T, n, n)) isa MMatrix{n, n, T} + VERSION≥v"1.7" && @test (@MMatrix rand(_rng(), T, n, n)) == rand(_rng(), T, n, n) broken=(T===Float32) + @test (@MMatrix randn(_rng(), T, n, n)) isa MMatrix{n, n, T} + @test (@MMatrix randn(_rng(), T, n, n)) == randn(_rng(), T, n, n) + @test (@MMatrix randexp(_rng(), T, n, n)) isa MMatrix{n, n, T} + @test (@MMatrix randexp(_rng(), T, n, n)) == randexp(_rng(), T, n, n) + @test (@MMatrix rand(_rng(), T, 4, 4)) isa MMatrix{4, 4, T} + VERSION≥v"1.7" && @test (@MMatrix rand(_rng(), T, 4, 4)) == rand(_rng(), T, 4, 4) broken=(T===Float32) + @test (@MMatrix randn(_rng(), T, 4, 4)) isa MMatrix{4, 4, T} + @test (@MMatrix randn(_rng(), T, 4, 4)) == randn(_rng(), T, 4, 4) + @test (@MMatrix randexp(_rng(), T, 4, 4)) isa MMatrix{4, 4, T} + @test (@MMatrix randexp(_rng(), T, 4, 4)) == randexp(_rng(), T, 4, 4) + end + end + + @inferred MMatrix(rand(MMatrix{3, 3})) # issue 356 @test MMatrix(SMatrix{1,1,Int,1}((1,))).data == (1,) @test_throws DimensionMismatch MMatrix{3}((1,2,3,4)) if VERSION >= v"1.7.0" @test ((@MMatrix Float64[1;2;3;;;])::MMatrix{3,1}).data === (1.0, 2.0, 3.0) @test ((@MMatrix [1;2;3;;;])::MMatrix{3,1}).data === (1, 2, 3) - @test ((@MMatrix [1;2;3;;;])::MMatrix{3,1}).data === (1, 2, 3) test_expand_error(:(@MMatrix [1;2;;;1;2])) end end diff --git a/test/MVector.jl b/test/MVector.jl index 4c270f8d3..25d510c65 100644 --- a/test/MVector.jl +++ b/test/MVector.jl @@ -23,6 +23,20 @@ # test for #557-like issues @test (@inferred MVector(SVector{0,Float64}()))::MVector{0,Float64} == MVector{0,Float64}() + @test MVector{3}(i for i in 1:3).data === (1,2,3) + @test MVector{3}(float(i) for i in 1:3).data === (1.0,2.0,3.0) + @test MVector{0,Int}().data === () + @test MVector{3,Int}(i for i in 1:3).data === (1,2,3) + @test MVector{3,Float64}(i for i in 1:3).data === (1.0,2.0,3.0) + @test MVector{1}(MVector(MVector(1.0), MVector(2.0))[j] for j in 1:1) == MVector((MVector(1.0),)) + @test_throws Exception MVector{3}(i for i in 1:2) + @test_throws Exception MVector{3}(i for i in 1:4) + @test_throws Exception MVector{3,Int}(i for i in 1:2) + @test_throws Exception MVector{3,Int}(i for i in 1:4) + + @test MVector(1).data === (1,) + @test MVector(1,1.0).data === (1.0,1.0) + @test ((@MVector [1.0])::MVector{1}).data === (1.0,) @test ((@MVector [1, 2, 3])::MVector{3}).data === (1, 2, 3) @test ((@MVector Float64[1,2,3])::MVector{3}).data === (1.0, 2.0, 3.0) @@ -31,25 +45,64 @@ @test ((@MVector zeros(2))::MVector{2, Float64}).data === (0.0, 0.0) @test ((@MVector ones(2))::MVector{2, Float64}).data === (1.0, 1.0) - @test ((@MVector fill(2.5, 2))::MVector{2, Float64}).data === (2.5, 2.5) - @test isa(@MVector(rand(2)), MVector{2, Float64}) - @test isa(@MVector(randn(2)), MVector{2, Float64}) - @test isa(@MVector(randexp(2)), MVector{2, Float64}) - + @test ((@MVector fill(2.5, 2))::MVector{2,Float64}).data === (2.5, 2.5) @test ((@MVector zeros(Float32, 2))::MVector{2,Float32}).data === (0.0f0, 0.0f0) @test ((@MVector ones(Float32, 2))::MVector{2,Float32}).data === (1.0f0, 1.0f0) - @test isa(@MVector(rand(Float32, 2)), MVector{2, Float32}) - @test isa(@MVector(randn(Float32, 2)), MVector{2, Float32}) - @test isa(@MVector(randexp(Float32, 2)), MVector{2, Float32}) - - test_expand_error(:(@MVector fill(1.5, 2, 3))) - test_expand_error(:(@MVector ones(2, 3, 4))) - test_expand_error(:(@MVector sin(1:5))) - test_expand_error(:(@MVector [i*j for i in 1:2, j in 2:3])) - test_expand_error(:(@MVector Float32[i*j for i in 1:2, j in 2:3])) - test_expand_error(:(@MVector [1; 2; 3]...)) - test_expand_error(:(@MVector a)) - test_expand_error(:(@MVector [[1 2];[3 4]])) + + @testset "@MVector rand*" begin + n = 4 + @test (@MVector rand(n)) isa MVector{n, Float64} + @test (@MVector randn(n)) isa MVector{n, Float64} + @test (@MVector randexp(n)) isa MVector{n, Float64} + @test (@MVector rand(4)) isa MVector{4, Float64} + @test (@MVector randn(4)) isa MVector{4, Float64} + @test (@MVector randexp(4)) isa MVector{4, Float64} + @test (@MVector rand(_rng(), n)) isa MVector{n, Float64} + @test (@MVector rand(_rng(), n)) == rand(_rng(), n) + @test (@MVector randn(_rng(), n)) isa MVector{n, Float64} + @test (@MVector randn(_rng(), n)) == randn(_rng(), n) + @test (@MVector randexp(_rng(), n)) isa MVector{n, Float64} + @test (@MVector randexp(_rng(), n)) == randexp(_rng(), n) + @test (@MVector rand(_rng(), 4)) isa MVector{4, Float64} + @test (@MVector rand(_rng(), 4)) == rand(_rng(), 4) + @test (@MVector randn(_rng(), 4)) isa MVector{4, Float64} + @test (@MVector randn(_rng(), 4)) == randn(_rng(), 4) + @test (@MVector randexp(_rng(), 4)) isa MVector{4, Float64} + @test (@MVector randexp(_rng(), 4)) == randexp(_rng(), 4) + + for T in (Float32, Float64) + @test (@MVector rand(T, n)) isa MVector{n, T} + @test (@MVector randn(T, n)) isa MVector{n, T} + @test (@MVector randexp(T, n)) isa MVector{n, T} + @test (@MVector rand(T, 4)) isa MVector{4, T} + @test (@MVector randn(T, 4)) isa MVector{4, T} + @test (@MVector randexp(T, 4)) isa MVector{4, T} + @test (@MVector rand(_rng(), T, n)) isa MVector{n, T} + VERSION≥v"1.7" && @test (@MVector rand(_rng(), T, n)) == rand(_rng(), T, n) broken=(T===Float32) + @test (@MVector randn(_rng(), T, n)) isa MVector{n, T} + @test (@MVector randn(_rng(), T, n)) == randn(_rng(), T, n) + @test (@MVector randexp(_rng(), T, n)) isa MVector{n, T} + @test (@MVector randexp(_rng(), T, n)) == randexp(_rng(), T, n) + @test (@MVector rand(_rng(), T, 4)) isa MVector{4, T} + VERSION≥v"1.7" && @test (@MVector rand(_rng(), T, 4)) == rand(_rng(), T, 4) broken=(T===Float32) + @test (@MVector randn(_rng(), T, 4)) isa MVector{4, T} + @test (@MVector randn(_rng(), T, 4)) == randn(_rng(), T, 4) + @test (@MVector randexp(_rng(), T, 4)) isa MVector{4, T} + @test (@MVector randexp(_rng(), T, 4)) == randexp(_rng(), T, 4) + end + end + + @testset "expand error" begin + test_expand_error(:(@MVector fill(1.5, 2, 3))) + test_expand_error(:(@MVector ones(2, 3, 4))) + test_expand_error(:(@MVector rand(Float64, 2, 3, 4))) + test_expand_error(:(@MVector sin(1:5))) + test_expand_error(:(@MVector [i*j for i in 1:2, j in 2:3])) + test_expand_error(:(@MVector Float32[i*j for i in 1:2, j in 2:3])) + test_expand_error(:(@MVector [1; 2; 3]...)) + test_expand_error(:(@MVector a)) + test_expand_error(:(@MVector [[1 2];[3 4]])) + end if VERSION >= v"1.7.0" @test ((@MVector Float64[1;2;3;;;])::MVector{3}).data === (1.0, 2.0, 3.0) diff --git a/test/SArray.jl b/test/SArray.jl index 74a0421de..d530efd1f 100644 --- a/test/SArray.jl +++ b/test/SArray.jl @@ -83,35 +83,147 @@ @test ((@SArray Float64[1 for i = 1:2, j = 2:3, k = 3:4, l = 1:2, m = 1:2, n = 1:2, o = 1:2, p = 1:2])::SArray{Tuple{2,2,2,2,2,2,2,2}}).data === ntuple(i->1.0, 256) @test ((@SArray Float64[1 for i = 1:2, j = 2:3, k = 3:4, l = 1:2, m = 1:2, n = 1:2, o = 1:2, p = 1:2, q = 1:2])::SArray{Tuple{2,2,2,2,2,2,2,2,2}}).data === ntuple(i->1.0, 512) - test_expand_error(:(@SArray [1 2; 3])) - test_expand_error(:(@SArray Float64[1 2; 3])) - test_expand_error(:(@SArray ones)) - test_expand_error(:(@SArray fill)) - test_expand_error(:(@SArray sin(1:5))) - test_expand_error(:(@SArray fill())) - test_expand_error(:(@SArray [1; 2; 3; 4]...)) - - # (typed-)comprehension LoadError for `ex.args[1].head != :generator` - test_expand_error(:(@SArray [i+j for i in 1:2 for j in 1:2])) - test_expand_error(:(@SArray Int[i+j for i in 1:2 for j in 1:2])) - - @test ((@SArray fill(1))::SArray{Tuple{},Int}).data === (1,) - @test ((@SArray ones())::SArray{Tuple{},Float64}).data === (1.,) - - @test ((@SArray fill(3.,2,2,1))::SArray{Tuple{2,2,1}, Float64}).data === (3.0, 3.0, 3.0, 3.0) - @test ((@SArray zeros(2,2,1))::SArray{Tuple{2,2,1}, Float64}).data === (0.0, 0.0, 0.0, 0.0) - @test ((@SArray ones(2,2,1))::SArray{Tuple{2,2,1}, Float64}).data === (1.0, 1.0, 1.0, 1.0) - @test isa(@SArray(rand(2,2,0)), SArray{Tuple{2,2,0}, Float64}) - @test isa(@SArray(rand(2,2,1)), SArray{Tuple{2,2,1}, Float64}) - @test isa(@SArray(randn(2,2,1)), SArray{Tuple{2,2,1}, Float64}) - @test isa(@SArray(randexp(2,2,1)), SArray{Tuple{2,2,1}, Float64}) - - @test ((@SArray zeros(Float32, 2, 2, 1))::SArray{Tuple{2,2,1},Float32}).data === (0.0f0, 0.0f0, 0.0f0, 0.0f0) - @test ((@SArray ones(Float32, 2, 2, 1))::SArray{Tuple{2,2,1},Float32}).data === (1.0f0, 1.0f0, 1.0f0, 1.0f0) - @test isa(@SArray(rand(Float32, 2, 2, 0)), SArray{Tuple{2,2,0}, Float32}) - @test isa(@SArray(rand(Float32, 2, 2, 1)), SArray{Tuple{2,2,1}, Float32}) - @test isa(@SArray(randn(Float32, 2, 2, 1)), SArray{Tuple{2,2,1}, Float32}) - @test isa(@SArray(randexp(Float32, 2, 2, 1)), SArray{Tuple{2,2,1}, Float32}) + @testset "expand error" begin + test_expand_error(:(@SArray [1 2; 3])) + test_expand_error(:(@SArray Float64[1 2; 3])) + test_expand_error(:(@SArray ones)) + test_expand_error(:(@SArray fill)) + test_expand_error(:(@SArray sin(1:5))) + test_expand_error(:(@SArray fill())) + test_expand_error(:(@SArray [1; 2; 3; 4]...)) + + # (typed-)comprehension LoadError for `ex.args[1].head != :generator` + test_expand_error(:(@SArray [i+j for i in 1:2 for j in 1:2])) + test_expand_error(:(@SArray Int[i+j for i in 1:2 for j in 1:2])) + end + + @testset "@SArray rand*" begin + @testset "Same test as @SVector rand*" begin + n = 4 + @test (@SArray rand(n)) isa SVector{n, Float64} + @test (@SArray randn(n)) isa SVector{n, Float64} + @test (@SArray randexp(n)) isa SVector{n, Float64} + @test (@SArray rand(4)) isa SVector{4, Float64} + @test (@SArray randn(4)) isa SVector{4, Float64} + @test (@SArray randexp(4)) isa SVector{4, Float64} + @test (@SArray rand(_rng(), n)) isa SVector{n, Float64} + @test (@SArray rand(_rng(), n)) == rand(_rng(), n) + @test (@SArray randn(_rng(), n)) isa SVector{n, Float64} + @test (@SArray randn(_rng(), n)) == randn(_rng(), n) + @test (@SArray randexp(_rng(), n)) isa SVector{n, Float64} + @test (@SArray randexp(_rng(), n)) == randexp(_rng(), n) + @test (@SArray rand(_rng(), 4)) isa SVector{4, Float64} + @test (@SArray rand(_rng(), 4)) == rand(_rng(), 4) + @test (@SArray randn(_rng(), 4)) isa SVector{4, Float64} + @test (@SArray randn(_rng(), 4)) == randn(_rng(), 4) + @test (@SArray randexp(_rng(), 4)) isa SVector{4, Float64} + @test (@SArray randexp(_rng(), 4)) == randexp(_rng(), 4) + + for T in (Float32, Float64) + @test (@SArray rand(T, n)) isa SVector{n, T} + @test (@SArray randn(T, n)) isa SVector{n, T} + @test (@SArray randexp(T, n)) isa SVector{n, T} + @test (@SArray rand(T, 4)) isa SVector{4, T} + @test (@SArray randn(T, 4)) isa SVector{4, T} + @test (@SArray randexp(T, 4)) isa SVector{4, T} + @test (@SArray rand(_rng(), T, n)) isa SVector{n, T} + VERSION≥v"1.7" && @test (@SArray rand(_rng(), T, n)) == rand(_rng(), T, n) broken=(T===Float32) + @test (@SArray randn(_rng(), T, n)) isa SVector{n, T} + @test (@SArray randn(_rng(), T, n)) == randn(_rng(), T, n) + @test (@SArray randexp(_rng(), T, n)) isa SVector{n, T} + @test (@SArray randexp(_rng(), T, n)) == randexp(_rng(), T, n) + @test (@SArray rand(_rng(), T, 4)) isa SVector{4, T} + VERSION≥v"1.7" && @test (@SArray rand(_rng(), T, 4)) == rand(_rng(), T, 4) broken=(T===Float32) + @test (@SArray randn(_rng(), T, 4)) isa SVector{4, T} + @test (@SArray randn(_rng(), T, 4)) == randn(_rng(), T, 4) + @test (@SArray randexp(_rng(), T, 4)) isa SVector{4, T} + @test (@SArray randexp(_rng(), T, 4)) == randexp(_rng(), T, 4) + end + end + + @testset "Same tests as @SMatrix rand*" begin + n = 4 + @testset "zero-length" begin + @test (@SArray rand(0, 0)) isa SMatrix{0, 0, Float64} + @test (@SArray rand(0, n)) isa SMatrix{0, n, Float64} + @test (@SArray rand(n, 0)) isa SMatrix{n, 0, Float64} + @test (@SArray rand(Float32, 0, 0)) isa SMatrix{0, 0, Float32} + @test (@SArray rand(Float32, 0, n)) isa SMatrix{0, n, Float32} + @test (@SArray rand(Float32, n, 0)) isa SMatrix{n, 0, Float32} + @test (@SArray rand(_rng(), Float32, 0, 0)) isa SMatrix{0, 0, Float32} + @test (@SArray rand(_rng(), Float32, 0, n)) isa SMatrix{0, n, Float32} + @test (@SArray rand(_rng(), Float32, n, 0)) isa SMatrix{n, 0, Float32} + end + + @test (@SArray rand(n, n)) isa SMatrix{n, n, Float64} + @test (@SArray randn(n, n)) isa SMatrix{n, n, Float64} + @test (@SArray randexp(n, n)) isa SMatrix{n, n, Float64} + @test (@SArray rand(4, 4)) isa SMatrix{4, 4, Float64} + @test (@SArray randn(4, 4)) isa SMatrix{4, 4, Float64} + @test (@SArray randexp(4, 4)) isa SMatrix{4, 4, Float64} + @test (@SArray rand(_rng(), n, n)) isa SMatrix{n, n, Float64} + @test (@SArray rand(_rng(), n, n)) == rand(_rng(), n, n) + @test (@SArray randn(_rng(), n, n)) isa SMatrix{n, n, Float64} + @test (@SArray randn(_rng(), n, n)) == randn(_rng(), n, n) + @test (@SArray randexp(_rng(), n, n)) isa SMatrix{n, n, Float64} + @test (@SArray randexp(_rng(), n, n)) == randexp(_rng(), n, n) + @test (@SArray rand(_rng(), 4, 4)) isa SMatrix{4, 4, Float64} + @test (@SArray rand(_rng(), 4, 4)) == rand(_rng(), 4, 4) + @test (@SArray randn(_rng(), 4, 4)) isa SMatrix{4, 4, Float64} + @test (@SArray randn(_rng(), 4, 4)) == randn(_rng(), 4, 4) + @test (@SArray randexp(_rng(), 4, 4)) isa SMatrix{4, 4, Float64} + @test (@SArray randexp(_rng(), 4, 4)) == randexp(_rng(), 4, 4) + + for T in (Float32, Float64) + @test (@SArray rand(T, n, n)) isa SMatrix{n, n, T} + @test (@SArray randn(T, n, n)) isa SMatrix{n, n, T} + @test (@SArray randexp(T, n, n)) isa SMatrix{n, n, T} + @test (@SArray rand(T, 4, 4)) isa SMatrix{4, 4, T} + @test (@SArray randn(T, 4, 4)) isa SMatrix{4, 4, T} + @test (@SArray randexp(T, 4, 4)) isa SMatrix{4, 4, T} + @test (@SArray rand(_rng(), T, n, n)) isa SMatrix{n, n, T} + VERSION≥v"1.7" && @test (@SArray rand(_rng(), T, n, n)) == rand(_rng(), T, n, n) broken=(T===Float32) + @test (@SArray randn(_rng(), T, n, n)) isa SMatrix{n, n, T} + @test (@SArray randn(_rng(), T, n, n)) == randn(_rng(), T, n, n) + @test (@SArray randexp(_rng(), T, n, n)) isa SMatrix{n, n, T} + @test (@SArray randexp(_rng(), T, n, n)) == randexp(_rng(), T, n, n) + @test (@SArray rand(_rng(), T, 4, 4)) isa SMatrix{4, 4, T} + VERSION≥v"1.7" && @test (@SArray rand(_rng(), T, 4, 4)) == rand(_rng(), T, 4, 4) broken=(T===Float32) + @test (@SArray randn(_rng(), T, 4, 4)) isa SMatrix{4, 4, T} + @test (@SArray randn(_rng(), T, 4, 4)) == randn(_rng(), T, 4, 4) + @test (@SArray randexp(_rng(), T, 4, 4)) isa SMatrix{4, 4, T} + @test (@SArray randexp(_rng(), T, 4, 4)) == randexp(_rng(), T, 4, 4) + end + end + + @test (@SArray rand(2,2,1)) isa SArray{Tuple{2,2,1}, Float64} + @test (@SArray rand(2,2,0)) isa SArray{Tuple{2,2,0}, Float64} + @test (@SArray randn(2,2,1)) isa SArray{Tuple{2,2,1}, Float64} + @test (@SArray randn(2,2,0)) isa SArray{Tuple{2,2,0}, Float64} + @test (@SArray randexp(2,2,1)) isa SArray{Tuple{2,2,1}, Float64} + @test (@SArray randexp(2,2,0)) isa SArray{Tuple{2,2,0}, Float64} + @test (@SArray rand(Float32,2,2,1)) isa SArray{Tuple{2,2,1}, Float32} + @test (@SArray rand(Float32,2,2,0)) isa SArray{Tuple{2,2,0}, Float32} + @test (@SArray randn(Float32,2,2,1)) isa SArray{Tuple{2,2,1}, Float32} + @test (@SArray randn(Float32,2,2,0)) isa SArray{Tuple{2,2,0}, Float32} + @test (@SArray randexp(Float32,2,2,1)) isa SArray{Tuple{2,2,1}, Float32} + @test (@SArray randexp(Float32,2,2,0)) isa SArray{Tuple{2,2,0}, Float32} + end + + @testset "fill, zeros, ones" begin + @test ((@SArray fill(1))::SArray{Tuple{},Int}).data === (1,) + @test ((@SArray zeros())::SArray{Tuple{},Float64}).data === (0.,) + @test ((@SArray ones())::SArray{Tuple{},Float64}).data === (1.,) + @test ((@SArray fill(3.,2,2,1))::SArray{Tuple{2,2,1}, Float64}).data === (3.0, 3.0, 3.0, 3.0) + @test ((@SArray zeros(2,2,1))::SArray{Tuple{2,2,1}, Float64}).data === (0.0, 0.0, 0.0, 0.0) + @test ((@SArray ones(2,2,1))::SArray{Tuple{2,2,1}, Float64}).data === (1.0, 1.0, 1.0, 1.0) + @test ((@SArray zeros(3-1,2,1))::SArray{Tuple{2,2,1}, Float64}).data === (0.0, 0.0, 0.0, 0.0) + @test ((@SArray ones(3-1,2,1))::SArray{Tuple{2,2,1}, Float64}).data === (1.0, 1.0, 1.0, 1.0) + @test ((@SArray zeros(Float32,2,2,1))::SArray{Tuple{2,2,1}, Float32}).data === (0.f0, 0.f0, 0.f0, 0.f0) + @test ((@SArray ones(Float32,2,2,1))::SArray{Tuple{2,2,1}, Float32}).data === (1.f0, 1.f0, 1.f0, 1.f0) + @test ((@SArray zeros(Float32,3-1,2,1))::SArray{Tuple{2,2,1}, Float32}).data === (0.f0, 0.f0, 0.f0, 0.f0) + @test ((@SArray ones(Float32,3-1,2,1))::SArray{Tuple{2,2,1}, Float32}).data === (1.f0, 1.f0, 1.f0, 1.f0) + end m = [1 2; 3 4] @test SArray{Tuple{2,2}}(m) === @SArray [1 2; 3 4] diff --git a/test/SMatrix.jl b/test/SMatrix.jl index ae1120723..3d86441f7 100644 --- a/test/SMatrix.jl +++ b/test/SMatrix.jl @@ -17,19 +17,21 @@ end @testset "Outer constructors and macro" begin + # The following tests are much similar in `test/MMatrix.jl` @test_throws Exception SMatrix(1,2,3,4) # unknown constructor @test SMatrix{1,1,Int}((1,)).data === (1,) @test SMatrix{1,1}((1,)).data === (1,) @test SMatrix{1}((1,)).data === (1,) - @test (@inferred SMatrix(MMatrix{0,0,Float64}()))::SMatrix{0,0,Float64} == SMatrix{0,0,Float64}() - @test SMatrix{2,2,Int}((1,2,3,4)).data === (1,2,3,4) @test SMatrix{2,2}((1,2,3,4)).data === (1,2,3,4) @test SMatrix{2}((1,2,3,4)).data === (1,2,3,4) @test_throws DimensionMismatch SMatrix{2}((1,2,3,4,5)) + # test for #557-like issues + @test (@inferred SMatrix(MMatrix{0,0,Float64}()))::SMatrix{0,0,Float64} == SMatrix{0,0,Float64}() + @test (SMatrix{2,3}(i+10j for i in 1:2, j in 1:3)::SMatrix{2,3}).data === (11,12,21,22,31,32) @test (SMatrix{2,3}(float(i+10j) for i in 1:2, j in 1:3)::SMatrix{2,3}).data === @@ -58,39 +60,85 @@ @test ((@SMatrix [i*j for i = 1:2, j=2:3])::SMatrix{2,2}).data === (2, 4, 3, 6) @test ((@SMatrix Float64[i*j for i = 1:2, j=2:3])::SMatrix{2,2}).data === (2.0, 4.0, 3.0, 6.0) - test_expand_error(:(@SMatrix [1 2; 3])) - test_expand_error(:(@SMatrix Float64[1 2; 3])) - test_expand_error(:(@SMatrix [i*j*k for i = 1:2, j=2:3, k=3:4])) - test_expand_error(:(@SMatrix Float64[i*j*k for i = 1:2, j=2:3, k=3:4])) - test_expand_error(:(@SMatrix fill(1.5, 2, 3, 4))) - test_expand_error(:(@SMatrix ones(2, 3, 4, 5))) - test_expand_error(:(@SMatrix ones)) - test_expand_error(:(@SMatrix sin(1:5))) - test_expand_error(:(@SMatrix [1; 2; 3; 4]...)) - test_expand_error(:(@SMatrix a)) + @testset "expand error" begin + test_expand_error(:(@SMatrix [1 2; 3])) + test_expand_error(:(@SMatrix Float64[1 2; 3])) + test_expand_error(:(@SMatrix [i*j*k for i = 1:2, j=2:3, k=3:4])) + test_expand_error(:(@SMatrix Float64[i*j*k for i = 1:2, j=2:3, k=3:4])) + test_expand_error(:(@SMatrix fill(1.5, 2, 3, 4))) + test_expand_error(:(@SMatrix ones(2, 3, 4, 5))) + test_expand_error(:(@SMatrix ones)) + test_expand_error(:(@SMatrix sin(1:5))) + test_expand_error(:(@SMatrix [1; 2; 3; 4]...)) + test_expand_error(:(@SMatrix a)) + end + + @test ((@SMatrix [1 2.;3 4])::SMatrix{2, 2, Float64}).data === (1., 3., 2., 4.) #issue #911 @test ((@SMatrix fill(1.3, 2,2))::SMatrix{2, 2, Float64}).data === (1.3, 1.3, 1.3, 1.3) @test ((@SMatrix zeros(2,2))::SMatrix{2, 2, Float64}).data === (0.0, 0.0, 0.0, 0.0) @test ((@SMatrix ones(2,2))::SMatrix{2, 2, Float64}).data === (1.0, 1.0, 1.0, 1.0) - @test isa(@SMatrix(rand(2,2)), SMatrix{2, 2, Float64}) - @test isa(@SMatrix(randn(2,2)), SMatrix{2, 2, Float64}) - @test isa(@SMatrix(randexp(2,2)), SMatrix{2, 2, Float64}) - @test isa(@SMatrix(rand(2, 0)), SMatrix{2, 0, Float64}) - @test isa(@SMatrix(randn(2, 0)), SMatrix{2, 0, Float64}) - @test isa(@SMatrix(randexp(2, 0)), SMatrix{2, 0, Float64}) - @test ((@SMatrix zeros(Float32, 2, 2))::SMatrix{2,2,Float32}).data === (0.0f0, 0.0f0, 0.0f0, 0.0f0) @test ((@SMatrix ones(Float32, 2, 2))::SMatrix{2,2,Float32}).data === (1.0f0, 1.0f0, 1.0f0, 1.0f0) - @test isa(@SMatrix(rand(Float32, 2, 2)), SMatrix{2, 2, Float32}) - @test isa(@SMatrix(randn(Float32, 2, 2)), SMatrix{2, 2, Float32}) - @test isa(@SMatrix(randexp(Float32, 2, 2)), SMatrix{2, 2, Float32}) - @test isa(@SMatrix(rand(Float32, 2, 0)), SMatrix{2, 0, Float32}) - @test isa(@SMatrix(randn(Float32, 2, 0)), SMatrix{2, 0, Float32}) - @test isa(@SMatrix(randexp(Float32, 2, 0)), SMatrix{2, 0, Float32}) - @test isa(SMatrix(@SMatrix zeros(4,4)), SMatrix{4, 4, Float64}) + @testset "@SMatrix rand*" begin + n = 4 + @testset "zero-length" begin + @test (@SMatrix rand(0, 0)) isa SMatrix{0, 0, Float64} + @test (@SMatrix rand(0, n)) isa SMatrix{0, n, Float64} + @test (@SMatrix rand(n, 0)) isa SMatrix{n, 0, Float64} + @test (@SMatrix rand(Float32, 0, 0)) isa SMatrix{0, 0, Float32} + @test (@SMatrix rand(Float32, 0, n)) isa SMatrix{0, n, Float32} + @test (@SMatrix rand(Float32, n, 0)) isa SMatrix{n, 0, Float32} + @test (@SMatrix rand(_rng(), Float32, 0, 0)) isa SMatrix{0, 0, Float32} + @test (@SMatrix rand(_rng(), Float32, 0, n)) isa SMatrix{0, n, Float32} + @test (@SMatrix rand(_rng(), Float32, n, 0)) isa SMatrix{n, 0, Float32} + end + + @test (@SMatrix rand(n, n)) isa SMatrix{n, n, Float64} + @test (@SMatrix randn(n, n)) isa SMatrix{n, n, Float64} + @test (@SMatrix randexp(n, n)) isa SMatrix{n, n, Float64} + @test (@SMatrix rand(4, 4)) isa SMatrix{4, 4, Float64} + @test (@SMatrix randn(4, 4)) isa SMatrix{4, 4, Float64} + @test (@SMatrix randexp(4, 4)) isa SMatrix{4, 4, Float64} + @test (@SMatrix rand(_rng(), n, n)) isa SMatrix{n, n, Float64} + @test (@SMatrix rand(_rng(), n, n)) == rand(_rng(), n, n) + @test (@SMatrix randn(_rng(), n, n)) isa SMatrix{n, n, Float64} + @test (@SMatrix randn(_rng(), n, n)) == randn(_rng(), n, n) + @test (@SMatrix randexp(_rng(), n, n)) isa SMatrix{n, n, Float64} + @test (@SMatrix randexp(_rng(), n, n)) == randexp(_rng(), n, n) + @test (@SMatrix rand(_rng(), 4, 4)) isa SMatrix{4, 4, Float64} + @test (@SMatrix rand(_rng(), 4, 4)) == rand(_rng(), 4, 4) + @test (@SMatrix randn(_rng(), 4, 4)) isa SMatrix{4, 4, Float64} + @test (@SMatrix randn(_rng(), 4, 4)) == randn(_rng(), 4, 4) + @test (@SMatrix randexp(_rng(), 4, 4)) isa SMatrix{4, 4, Float64} + @test (@SMatrix randexp(_rng(), 4, 4)) == randexp(_rng(), 4, 4) + + for T in (Float32, Float64) + @test (@SMatrix rand(T, n, n)) isa SMatrix{n, n, T} + @test (@SMatrix randn(T, n, n)) isa SMatrix{n, n, T} + @test (@SMatrix randexp(T, n, n)) isa SMatrix{n, n, T} + @test (@SMatrix rand(T, 4, 4)) isa SMatrix{4, 4, T} + @test (@SMatrix randn(T, 4, 4)) isa SMatrix{4, 4, T} + @test (@SMatrix randexp(T, 4, 4)) isa SMatrix{4, 4, T} + @test (@SMatrix rand(_rng(), T, n, n)) isa SMatrix{n, n, T} + VERSION≥v"1.7" && @test (@SMatrix rand(_rng(), T, n, n)) == rand(_rng(), T, n, n) broken=(T===Float32) + @test (@SMatrix randn(_rng(), T, n, n)) isa SMatrix{n, n, T} + @test (@SMatrix randn(_rng(), T, n, n)) == randn(_rng(), T, n, n) + @test (@SMatrix randexp(_rng(), T, n, n)) isa SMatrix{n, n, T} + @test (@SMatrix randexp(_rng(), T, n, n)) == randexp(_rng(), T, n, n) + @test (@SMatrix rand(_rng(), T, 4, 4)) isa SMatrix{4, 4, T} + VERSION≥v"1.7" && @test (@SMatrix rand(_rng(), T, 4, 4)) == rand(_rng(), T, 4, 4) broken=(T===Float32) + @test (@SMatrix randn(_rng(), T, 4, 4)) isa SMatrix{4, 4, T} + @test (@SMatrix randn(_rng(), T, 4, 4)) == randn(_rng(), T, 4, 4) + @test (@SMatrix randexp(_rng(), T, 4, 4)) isa SMatrix{4, 4, T} + @test (@SMatrix randexp(_rng(), T, 4, 4)) == randexp(_rng(), T, 4, 4) + end + end @inferred SMatrix(rand(SMatrix{3, 3})) # issue 356 + @test SMatrix(MMatrix{1,1,Int,1}((1,))).data == (1,) + @test_throws DimensionMismatch SMatrix{3}((1,2,3,4)) if VERSION >= v"1.7.0" @test ((@SMatrix Float64[1;2;3;;;])::SMatrix{3,1}).data === (1.0, 2.0, 3.0) diff --git a/test/SVector.jl b/test/SVector.jl index d9231e9a1..2e62f4790 100644 --- a/test/SVector.jl +++ b/test/SVector.jl @@ -42,24 +42,63 @@ @test ((@SVector zeros(2))::SVector{2, Float64}).data === (0.0, 0.0) @test ((@SVector ones(2))::SVector{2, Float64}).data === (1.0, 1.0) @test ((@SVector fill(2.5, 2))::SVector{2,Float64}).data === (2.5, 2.5) - @test isa(@SVector(rand(2)), SVector{2, Float64}) - @test isa(@SVector(randn(2)), SVector{2, Float64}) - @test isa(@SVector(randexp(2)), SVector{2, Float64}) - @test ((@SVector zeros(Float32, 2))::SVector{2,Float32}).data === (0.0f0, 0.0f0) @test ((@SVector ones(Float32, 2))::SVector{2,Float32}).data === (1.0f0, 1.0f0) - @test isa(@SVector(rand(Float32, 2)), SVector{2, Float32}) - @test isa(@SVector(randn(Float32, 2)), SVector{2, Float32}) - @test isa(@SVector(randexp(Float32, 2)), SVector{2, Float32}) - - test_expand_error(:(@SVector fill(1.5, 2, 3))) - test_expand_error(:(@SVector ones(2, 3, 4))) - test_expand_error(:(@SVector sin(1:5))) - test_expand_error(:(@SVector [i*j for i in 1:2, j in 2:3])) - test_expand_error(:(@SVector Float32[i*j for i in 1:2, j in 2:3])) - test_expand_error(:(@SVector [1; 2; 3]...)) - test_expand_error(:(@SVector a)) - test_expand_error(:(@SVector [[1 2];[3 4]])) + + @testset "@SVector rand*" begin + n = 4 + @test (@SVector rand(n)) isa SVector{n, Float64} + @test (@SVector randn(n)) isa SVector{n, Float64} + @test (@SVector randexp(n)) isa SVector{n, Float64} + @test (@SVector rand(4)) isa SVector{4, Float64} + @test (@SVector randn(4)) isa SVector{4, Float64} + @test (@SVector randexp(4)) isa SVector{4, Float64} + @test (@SVector rand(_rng(), n)) isa SVector{n, Float64} + @test (@SVector rand(_rng(), n)) == rand(_rng(), n) + @test (@SVector randn(_rng(), n)) isa SVector{n, Float64} + @test (@SVector randn(_rng(), n)) == randn(_rng(), n) + @test (@SVector randexp(_rng(), n)) isa SVector{n, Float64} + @test (@SVector randexp(_rng(), n)) == randexp(_rng(), n) + @test (@SVector rand(_rng(), 4)) isa SVector{4, Float64} + @test (@SVector rand(_rng(), 4)) == rand(_rng(), 4) + @test (@SVector randn(_rng(), 4)) isa SVector{4, Float64} + @test (@SVector randn(_rng(), 4)) == randn(_rng(), 4) + @test (@SVector randexp(_rng(), 4)) isa SVector{4, Float64} + @test (@SVector randexp(_rng(), 4)) == randexp(_rng(), 4) + + for T in (Float32, Float64) + @test (@SVector rand(T, n)) isa SVector{n, T} + @test (@SVector randn(T, n)) isa SVector{n, T} + @test (@SVector randexp(T, n)) isa SVector{n, T} + @test (@SVector rand(T, 4)) isa SVector{4, T} + @test (@SVector randn(T, 4)) isa SVector{4, T} + @test (@SVector randexp(T, 4)) isa SVector{4, T} + @test (@SVector rand(_rng(), T, n)) isa SVector{n, T} + VERSION≥v"1.7" && @test (@SVector rand(_rng(), T, n)) == rand(_rng(), T, n) broken=(T===Float32) + @test (@SVector randn(_rng(), T, n)) isa SVector{n, T} + @test (@SVector randn(_rng(), T, n)) == randn(_rng(), T, n) + @test (@SVector randexp(_rng(), T, n)) isa SVector{n, T} + @test (@SVector randexp(_rng(), T, n)) == randexp(_rng(), T, n) + @test (@SVector rand(_rng(), T, 4)) isa SVector{4, T} + VERSION≥v"1.7" && @test (@SVector rand(_rng(), T, 4)) == rand(_rng(), T, 4) broken=(T===Float32) + @test (@SVector randn(_rng(), T, 4)) isa SVector{4, T} + @test (@SVector randn(_rng(), T, 4)) == randn(_rng(), T, 4) + @test (@SVector randexp(_rng(), T, 4)) isa SVector{4, T} + @test (@SVector randexp(_rng(), T, 4)) == randexp(_rng(), T, 4) + end + end + + @testset "expand error" begin + test_expand_error(:(@SVector fill(1.5, 2, 3))) + test_expand_error(:(@SVector ones(2, 3, 4))) + test_expand_error(:(@SVector rand(Float64, 2, 3, 4))) + test_expand_error(:(@SVector sin(1:5))) + test_expand_error(:(@SVector [i*j for i in 1:2, j in 2:3])) + test_expand_error(:(@SVector Float32[i*j for i in 1:2, j in 2:3])) + test_expand_error(:(@SVector [1; 2; 3]...)) + test_expand_error(:(@SVector a)) + test_expand_error(:(@SVector [[1 2];[3 4]])) + end if VERSION >= v"1.7.0" @test ((@SVector Float64[1;2;3;;;])::SVector{3}).data === (1.0, 2.0, 3.0) diff --git a/test/runtests.jl b/test/runtests.jl index 68aac8d68..c7eef00ab 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,8 @@ using InteractiveUtils # deterministic. Therefore seed the RNG here (and further down, to avoid test # file order dependence) Random.seed!(42) +# Useful function to regenerate rng +_rng() = Random.MersenneTwister(42) include("testutil.jl") # Hook into Pkg.test so that tests from a single file can be run. For example,