Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend macros with rand to support custom samplers #1210

Merged
merged 10 commits into from
Jan 3, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StaticArrays"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.8.2"
version = "1.9.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
113 changes: 94 additions & 19 deletions src/SArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,22 +142,65 @@
cat_any(Val(maxdim), Val(catdim), nargs)
end

#=
For example,
* `@SArray rand(2, 3, 4)`
* `@SArray rand(rng, 3, 4)`
will be expanded to the following.
* `_rand_with_Val(SArray, 2, 3, _int2val(2), _int2val(3), Val((4,)))`
* `_rand_with_Val(SArray, 2, 3, _int2val(rng), _int2val(3), Val((4,)))`
The function `_int2val` is required to avoid the following case.
* `_rand_with_Val(SArray, 2, 3, Val(2), Val(3), Val((4,)))`
* `_rand_with_Val(SArray, 2, 3, Val(rng), Val(3), Val((4,)))`
Mutable object such as `rng` cannot be type parameter, and `Val(rng)` throws an error.
=#
_int2val(x::Int) = Val(x)
_int2val(::Any) = nothing
# @SArray zeros(...)
_zeros_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = zeros(SA{Tuple{n1, ns...}})
_zeros_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = zeros(SA{Tuple{ns...}, T})
# @SArray ones(...)
_ones_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = ones(SA{Tuple{n1, ns...}})
_ones_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = ones(SA{Tuple{ns...}, T})
# @SArray rand(...)
_rand_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = rand(SA{Tuple{n1,n2,ns...}})
_rand_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, T, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
_rand_with_Val(::Type{SA}, sampler, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, sampler, Size(n1, ns...), SA{Tuple{n1, ns...}, Random.gentype(sampler)})
_rand_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(rng, Float64, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
_rand_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, T, Size(ns...), SA{Tuple{ns...}, T})
_rand_with_Val(::Type{SA}, rng::AbstractRNG, sampler, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, sampler, Size(ns...), SA{Tuple{ns...}, Random.gentype(sampler)})
# @SArray randn(...)
_randn_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randn(SA{Tuple{n1,n2,ns...}})
_randn_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
_randn_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
_randn_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randn(rng, Size(ns...), SA{Tuple{ns...}, T})
# @SArray randexp(...)
_randexp_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randexp(SA{Tuple{n1,n2,ns...}})
_randexp_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
_randexp_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
_randexp_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randexp(rng, Size(ns...), SA{Tuple{ns...}, T})

escall(args) = Iterators.map(esc, args)
function _isnonnegvec(args)
length(args) == 0 && return false
all(isa.(args, Integer)) && return all(args .≥ 0)
return false
end
function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
if !isa(ex, Expr)
error("Bad input for @$SA")
end
head = ex.head
if head === :vect # vector
return :($SA{$Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...))))
return :($SA{Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...))))
elseif head === :ref # typed, vector
return :($SA{$Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...))))
return :($SA{Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...))))
elseif head === :typed_vcat || head === :typed_hcat || head === :typed_ncat # typed, cat
args = parse_cat_ast(ex)
return :($SA{$Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...))))
return :($SA{Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...))))
elseif head === :vcat || head === :hcat || head === :ncat # untyped, cat
args = parse_cat_ast(ex)
return :($SA{$Tuple{$(size(args)...)}}($tuple($(escall(args)...))))
return :($SA{Tuple{$(size(args)...)}}($tuple($(escall(args)...))))
elseif head === :comprehension
if length(ex.args) != 1
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]")
Expand All @@ -173,7 +216,7 @@
return quote
let
f($(escall(rng_args)...)) = $(esc(ex.args[1]))
$SA{$Tuple{$(size(exprs)...)}}($tuple($(exprs...)))
$SA{Tuple{$(size(exprs)...)}}($tuple($(exprs...)))
end
end
elseif head === :typed_comprehension
Expand All @@ -192,26 +235,58 @@
return quote
let
f($(escall(rng_args)...)) = $(esc(ex.args[1]))
$SA{$Tuple{$(size(exprs)...)},$T}($tuple($(exprs...)))
$SA{Tuple{$(size(exprs)...)},$T}($tuple($(exprs...)))
end
end
elseif head === :call
f = ex.args[1]
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
if length(ex.args) == 1
f === :zeros || f === :ones || error("@$SA got bad expression: $(ex)")
return :($f($SA{$Tuple{},$Float64}))
end
return quote
if isa($(esc(ex.args[2])), DataType)
$f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))})
else
$f($SA{$Tuple{$(escall(ex.args[2:end])...)}})
end
fargs = ex.args[2:end]
if f === :zeros || f === :ones
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 0
# for calls like `zeros()`
return :($f($SA{Tuple{},$Float64}))
elseif _isnonnegvec(fargs)
# for calls like `zeros(dims...)`
return :($f($SA{Tuple{$(escall(fargs)...)}}))
else
# for calls like `zeros(type)`
# for calls like `zeros(type, dims...)`
return :($_f_with_Val($SA, $(esc(fargs[1])), Val($(esc(fargs[1]))), Val(tuple($(escall(fargs[2:end])...)))))
end
elseif f === :fill
length(ex.args) == 1 && error("@$SA got bad expression: $(ex)")
return :($f($(esc(ex.args[2])), $SA{$Tuple{$(escall(ex.args[3:end])...)}}))
# for calls like `fill(value, dims...)`
return :($f($(esc(fargs[1])), $SA{Tuple{$(escall(fargs[2:end])...)}}))
elseif f === :rand || f === :randn || f === :randexp
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 0
# No support for `@SArray rand()`
error("@$SA got bad expression: $(ex)")

Check warning on line 264 in src/SArray.jl

View check run for this annotation

Codecov / codecov/patch

src/SArray.jl#L264

Added line #L264 was not covered by tests
elseif _isnonnegvec(fargs)
# for calls like `rand(dims...)`
return :($f($SA{Tuple{$(escall(fargs)...)}}))
elseif length(fargs) ≥ 2
# for calls like `rand(dim1, dim2, dims...)`
# for calls like `rand(type, dim1, dims...)`
# for calls like `rand(sampler, dim1, dims...)`
# for calls like `rand(rng, dim1, dims...)`
# for calls like `rand(rng, type, dims...)`
# for calls like `rand(rng, sampler, dims...)`
# for calls like `randn(dim1, dim2, dims...)`
# for calls like `randn(type, dim1, dims...)`
# for calls like `randn(rng, dim1, dims...)`
# for calls like `randn(rng, type, dims...)`
# for calls like `randexp(dim1, dim2, dims...)`
# for calls like `randexp(type, dim1, dims...)`
# for calls like `randexp(rng, dim1, dims...)`
# for calls like `randexp(rng, type, dims...)`
return :($_f_with_Val($SA, $(esc(fargs[1])), $(esc(fargs[2])), _int2val($(esc(fargs[1]))), _int2val($(esc(fargs[2]))), Val(tuple($(escall(fargs[3:end])...)))))
elseif length(fargs) == 1
# for calls like `rand(dim)`
return :($f($SA{Tuple{$(escall(fargs)...)}}))
else
error("@$SA got bad expression: $(ex)")

Check warning on line 288 in src/SArray.jl

View check run for this annotation

Codecov / codecov/patch

src/SArray.jl#L288

Added line #L288 was not covered by tests
end
else
error("@$SA only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
end
Expand Down
64 changes: 54 additions & 10 deletions src/SMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@
x1, x2
end

# @SMatrix rand(...)
_rand_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = rand(rng, SM{n1, n2})
_rand_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(Random.GLOBAL_RNG, T, Size(n1, n2), SM{n1, n2, T})
_rand_with_Val(::Type{SM}, sampler, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(Random.GLOBAL_RNG, sampler, Size(n1, n2), SM{n1, n2, Random.gentype(sampler)})
_rand_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = rand(rng, SM{n1, n2, T})
_rand_with_Val(::Type{SM}, rng::AbstractRNG, sampler, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(rng, sampler, Size(n1, n2), SM{n1, n2, Random.gentype(sampler)})
# @SMatrix randn(...)
_randn_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randn(rng, SM{n1, n2})
_randn_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _randn(Random.GLOBAL_RNG, Size(n1, n2), SM{n1, n2, T})
_randn_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randn(rng, SM{n1, n2, T})
# @SMatrix randexp(...)
_randexp_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randexp(rng, SM{n1, n2})
_randexp_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _randexp(Random.GLOBAL_RNG, Size(n1, n2), SM{n1, n2, T})
_randexp_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randexp(rng, SM{n1, n2, T})

function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM}
if !isa(ex, Expr)
error("Bad input for @$SM")
Expand Down Expand Up @@ -69,22 +84,51 @@
end
elseif head === :call
f = ex.args[1]
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
if length(ex.args) == 3
return :($f($SM{$(escall(ex.args[2:3])...), Float64})) # default to Float64 like Base
elseif length(ex.args) == 4
return :($f($SM{$(escall(ex.args[[3,4,2]])...)}))
fargs = ex.args[2:end]
if f === :zeros || f === :ones
if length(fargs) == 2
# for calls like `zeros(dim1, dim2)`
return :($f($SM{$(escall(fargs)...)}))
elseif length(fargs[2:end]) == 2
# for calls like `zeros(type, dim1, dim2)`
return :($f($SM{$(escall(fargs[2:end])...), $(esc(fargs[1]))}))
else
error("@$SM expected a 2-dimensional array expression")
error("@$SM got bad expression: $(ex)")
end
elseif ex.args[1] === :fill
if length(ex.args) == 4
return :($f($(esc(ex.args[2])), $SM{$(escall(ex.args[3:4])...)}))
elseif f === :fill
# for calls like `fill(value, dim1, dim2)`
if length(fargs[2:end]) == 2
return :($f($(esc(fargs[1])), $SM{$(escall(fargs[2:end])...)}))
else
error("@$SM expected a 2-dimensional array expression")
end
elseif f === :rand || f === :randn || f === :randexp
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 2
# for calls like `rand(dim1, dim2)`
# for calls like `randn(dim1, dim2)`
# for calls like `randexp(dim1, dim2)`
return :($f($SM{$(escall(fargs)...)}))
elseif length(fargs) == 3
# for calls like `rand(rng, dim1, dim2)`
# for calls like `rand(type, dim1, dim2)`
# for calls like `rand(sampler, dim1, dim2)`
# for calls like `randn(rng, dim1, dim2)`
# for calls like `randn(type, dim1, dim2)`
# for calls like `randexp(rng, dim1, dim2)`
# for calls like `randexp(type, dim1, dim2)`
return :($_f_with_Val($SM, $(esc(fargs[1])), Val($(esc(fargs[2]))), Val($(esc(fargs[3])))))
elseif length(fargs) == 4
# for calls like `rand(rng, type, dim1, dim2)`
# for calls like `rand(rng, sampler, dim1, dim2)`
# for calls like `randn(rng, type, dim1, dim2)`
# for calls like `randexp(rng, type, dim1, dim2)`
return :($_f_with_Val($SM, $(esc(fargs[1])), $(esc(fargs[2])), Val($(esc(fargs[3]))), Val($(esc(fargs[4])))))
else
error("@$SM got bad expression: $(ex)")

Check warning on line 128 in src/SMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/SMatrix.jl#L128

Added line #L128 was not covered by tests
end
else
error("@$SM only supports the zeros(), ones(), rand(), randn(), and randexp() functions.")
error("@$SM only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
end
else
error("Bad input for @$SM")
Expand Down
64 changes: 54 additions & 10 deletions src/SVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@ function check_vector_length(x::Tuple, T = :S)
length(x) >= 1 ? x[1] : 1
end

# @SVector rand(...)
_rand_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = rand(rng, SV{n})
_rand_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _rand(Random.GLOBAL_RNG, T, Size(n), SV{n, T})
_rand_with_Val(::Type{SV}, sampler, ::Val{n}) where {SV, n} = _rand(Random.GLOBAL_RNG, sampler, Size(n), SV{n, Random.gentype(sampler)})
_rand_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = rand(rng, SV{n, T})
_rand_with_Val(::Type{SV}, rng::AbstractRNG, sampler, ::Val{n}) where {SV, n} = _rand(rng, sampler, Size(n), SV{n, Random.gentype(sampler)})
# @SVector randn(...)
_randn_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = randn(rng, SV{n})
_randn_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _randn(Random.GLOBAL_RNG, Size(n), SV{n, T})
_randn_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = randn(rng, SV{n, T})
# @SVector randexp(...)
_randexp_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = randexp(rng, SV{n})
_randexp_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _randexp(Random.GLOBAL_RNG, Size(n), SV{n, T})
_randexp_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = randexp(rng, SV{n, T})

function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV}
if !isa(ex, Expr)
error("Bad input for @$SV")
Expand Down Expand Up @@ -74,22 +89,51 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV
end
elseif head === :call
f = ex.args[1]
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
if length(ex.args) == 2
return :($f($SV{$(esc(ex.args[2])), Float64})) # default to Float64 like Base
elseif length(ex.args) == 3
return :($f($SV{$(escall(ex.args[3:-1:2])...)}))
fargs = ex.args[2:end]
if f === :zeros || f === :ones
if length(fargs) == 1
# for calls like `zeros(dim)`
return :($f($SV{$(esc(fargs[1]))}))
elseif length(fargs) == 2
# for calls like `zeros(type, dim)`
return :($f($SV{$(esc(fargs[2])), $(esc(fargs[1]))}))
else
error("@$SV expected a 1-dimensional array expression")
error("@$SV got bad expression: $(ex)")
end
elseif ex.args[1] === :fill
if length(ex.args) == 3
return :($f($(esc(ex.args[2])), $SV{$(esc(ex.args[3]))}))
elseif f === :fill
# for calls like `fill(value, dim)`
if length(fargs) == 2
return :($f($(esc(fargs[1])), $SV{$(esc(fargs[2]))}))
else
error("@$SV expected a 1-dimensional array expression")
end
elseif f === :rand || f === :randn || f === :randexp
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 1
# for calls like `rand(dim)`
# for calls like `randn(dim)`
# for calls like `randexp(dim)`
return :($f($SV{$(escall(fargs)...)}))
elseif length(fargs) == 2
# for calls like `rand(rng, dim)`
# for calls like `rand(type, dim)`
# for calls like `rand(sampler, dim)`
# for calls like `randn(rng, dim)`
# for calls like `randn(type, dim)`
# for calls like `randexp(rng, dim)`
# for calls like `randexp(type, dim)`
return :($_f_with_Val($SV, $(esc(fargs[1])), Val($(esc(fargs[2])))))
elseif length(fargs) == 3
# for calls like `rand(rng, type, dim)`
# for calls like `rand(rng, sampler, dim)`
# for calls like `randn(rng, type, dim)`
# for calls like `randexp(rng, type, dim)`
return :($_f_with_Val($SV, $(esc(fargs[1])), $(esc(fargs[2])), Val($(esc(fargs[3])))))
else
error("@$SV got bad expression: $(ex)")
end
else
error("@$SV only supports the zeros(), ones(), rand(), randn() and randexp() functions.")
error("@$SV only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
end
else
error("Use @$SV [a,b,c], @$SV Type[a,b,c] or a comprehension like @$SV [f(i) for i = i_min:i_max]")
Expand Down
4 changes: 2 additions & 2 deletions src/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ end

@inline rand(rng::AbstractRNG, range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(rng, range, Size(SA), SA)
@inline rand(range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(Random.GLOBAL_RNG, range, Size(SA), SA)
@generated function _rand(rng::AbstractRNG, range::AbstractArray, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
v = [:(rand(rng, range)) for i = 1:prod(s)]
@generated function _rand(rng::AbstractRNG, X, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
v = [:(rand(rng, X)) for i = 1:prod(s)]
return quote
@_inline_meta
$SA(tuple($(v...)))
Expand Down
Loading
Loading