-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support
rand!
and rand
using MPS where appropriate
- Loading branch information
1 parent
298ded2
commit 40e640e
Showing
9 changed files
with
259 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
using Random | ||
|
||
""" | ||
MPS.RNG() | ||
A random number generator using `rand()` in a device kernel. | ||
""" | ||
mutable struct RNG <: AbstractRNG | ||
seed::UInt | ||
counter::UInt32 | ||
|
||
function RNG(seed::Integer) | ||
new(seed%UInt, 0) | ||
end | ||
RNG(seed::UInt, counter::UInt32) = new(seed, counter) | ||
end | ||
|
||
make_seed() = Base.rand(RandomDevice(), UInt) | ||
|
||
RNG() = RNG(make_seed()) | ||
|
||
Base.copy(rng::RNG) = RNG(rng.seed, rng.counter) | ||
Base.hash(rng::RNG, h::UInt) = hash(rng.seed, hash(rng.counter, h)) | ||
Base.:(==)(a::RNG, b::RNG) = (a.seed == b.seed) && (a.counter == b.counter) | ||
|
||
function Random.seed!(rng::RNG, seed::Integer) | ||
rng.seed = seed % UInt | ||
rng.counter = 0 | ||
end | ||
|
||
Random.seed!(rng::RNG) = Random.seed!(rng, make_seed()) | ||
|
||
@inline function update_state!(rng::RNG, len) | ||
new_counter = Int64(rng.counter) + len | ||
overflow, remainder = fldmod(new_counter, typemax(UInt32)) | ||
rng.seed += overflow # XXX: is this OK? | ||
rng.counter = remainder | ||
return rng | ||
end | ||
|
||
const GLOBAL_RNGs = Dict{MTLDevice,MPS.RNG}() | ||
function default_rng() | ||
dev = current_device() | ||
get!(GLOBAL_RNGs, dev) do | ||
RNG() | ||
end | ||
end | ||
|
||
function Random.rand!(rng::RNG, A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} | ||
mpsvecormat = MPSVector(A, UInt32) | ||
_mpsmat_rand!(mpsvecormat, seed = rng.seed + rng.counter) | ||
|
||
update_state!(rng,length(A)) | ||
return A | ||
end | ||
function Random.rand!(rng::RNG, A::MtlArray{Float32}) | ||
mpsvecormat = MPSVector(A, Float32) | ||
_mpsmat_rand!(mpsvecormat; desc=MPSMatrixRandomUniformDistributionDescriptor(0, 1), seed = rng.seed + rng.counter) | ||
|
||
update_state!(rng,length(A)) | ||
return A | ||
end | ||
function Random.randn!(rng::RNG, A::MtlArray{Float32}) | ||
mpsvecormat = MPSVector(A, Float32) | ||
_mpsmat_rand!(mpsvecormat; desc=MPSMatrixRandomNormalDistributionDescriptor(0, 1), seed = rng.seed + rng.counter) | ||
|
||
update_state!(rng,length(A)) | ||
return A | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,74 @@ | ||
using Random | ||
using ..MPS: MPSVector, _mpsmat_rand!, MPSMatrixRandomUniformDistributionDescriptor, | ||
MPSMatrixRandomNormalDistributionDescriptor | ||
|
||
gpuarrays_rng() = GPUArrays.default_rng(MtlArray) | ||
mpsrand_rng() = MPS.default_rng() | ||
|
||
# GPUArrays in-place | ||
Random.rand!(A::MtlArray) = Random.rand!(gpuarrays_rng(), A) | ||
Random.randn!(A::MtlArray) = Random.randn!(gpuarrays_rng(), A) | ||
|
||
@inline function usempsrandom(A::MtlArray{T}) where {T} | ||
return (A.offset == 0 && | ||
(length(A) * sizeof(T) % MTL.BUFFER_ALIGNMENT_FOR_RAND == 0)) | ||
end | ||
|
||
# Use MPS random functionality where possible | ||
function Random.rand!(A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} | ||
if usempsrandom(A) | ||
@inline Random.rand!(gpuarrays_rng(), A) | ||
else | ||
@inline Random.rand!(gpuarrays_rng(), A) | ||
end | ||
return A | ||
end | ||
function Random.rand!(A::MtlArray{Float32}) | ||
if usempsrandom(A) | ||
@inline Random.rand!(mpsrand_rng(), A) | ||
else | ||
@inline Random.rand!(gpuarrays_rng(), A) | ||
end | ||
return A | ||
end | ||
function Random.randn!(A::MtlArray{Float32}) | ||
if usempsrandom(A) | ||
@inline Random.randn!(mpsrand_rng(), A) | ||
else | ||
@inline Random.randn!(gpuarrays_rng(), A) | ||
end | ||
return A | ||
end | ||
|
||
# GPUArrays out-of-place | ||
rand(T::Type, dims::Dims; storage=DefaultStorageMode) = Random.rand!(MtlArray{T,length(dims),storage}(undef, dims...)) | ||
randn(T::Type, dims::Dims; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{T,length(dims),storage}(undef, dims...); kwargs...) | ||
rand(::Type{T}, dims::Dims; storage=DefaultStorageMode) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64,Float32}} = | ||
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) | ||
randn(::Type{Float32}, dims::Dims; storage=DefaultStorageMode) = | ||
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims),storage}(undef, dims...)) | ||
rand(T::Type, dims::Dims; storage=DefaultStorageMode) = | ||
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) | ||
randn(T::Type, dims::Dims; storage=DefaultStorageMode) = | ||
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) | ||
|
||
# support all dimension specifications | ||
rand(::Type{T}, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64,Float32}} = | ||
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
randn(::Type{Float32}, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
|
||
rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.rand!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...)) | ||
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = | ||
Random.randn!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...); kwargs...) | ||
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
|
||
# untyped out-of-place | ||
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = Random.rand!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...)) | ||
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...); kwargs...) | ||
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.rand!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
|
||
# seeding | ||
seed!(seed=Base.rand(UInt64)) = Random.seed!(gpuarrays_rng(), seed) | ||
function seed!(seed=Base.rand(UInt64)) | ||
Random.seed!(gpuarrays_rng(), seed) | ||
Random.seed!(mpsrand_rng(), seed) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.