Skip to content

Commit

Permalink
Add wrappers for MPSMatrixRandom (#321)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Besard <[email protected]>
  • Loading branch information
christiangnrd and maleadt authored Aug 27, 2024
1 parent 28576b3 commit 1dde978
Show file tree
Hide file tree
Showing 7 changed files with 606 additions and 46 deletions.
45 changes: 45 additions & 0 deletions docs/src/usage/array.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
```@meta
DocTestSetup = quote
using Metal
using GPUArrays
import Random
Random.seed!(1)
Metal.seed!(1)
end
```

Expand Down Expand Up @@ -106,3 +112,42 @@ julia> Base.mapreducedim!(identity, +, b, a)
1×1 MtlMatrix{Float32, Metal.PrivateStorage}:
6.0
```

## Random numbers

Base's convenience functions for generating random numbers are available in Metal as well:

```jldoctest
julia> Metal.rand(2)
2-element MtlVector{Float32, Metal.PrivateStorage}:
0.89025915
0.8946847
julia> Metal.randn(Float32, 2, 1)
2×1 MtlMatrix{Float32, Metal.PrivateStorage}:
1.2279074
1.2518331
```

Behind the scenes, these random numbers come from two different generators: one backed by
[Metal Performance Shaders](https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixrandom?language=objc),
another by using the GPUArrays.jl random methods. Operations on these generators are implemented using methods from the Random
standard library:

```jldoctest
julia> using Random, GPUArrays
julia> a = Random.rand(MPS.default_rng(), Float32, 1)
1-element MtlVector{Float32, Metal.PrivateStorage}:
0.89025915
julia> a = Random.rand!(GPUArrays.default_rng(MtlArray), a)
1-element MtlVector{Float32, Metal.PrivateStorage}:
0.0705002
```

!!! note
`MPSMatrixRandom` functionality requires Metal.jl >= v1.4

!!! warning
`Random.rand!(::MPS.RNG, args...)` and `Random.randn!(::MPS.RNG, args...)` have a framework limitation that requires the byte offset and byte size of the destination array to be a multiple of 4.
2 changes: 2 additions & 0 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ include("kernel.jl")
include("images.jl")
include("matrix.jl")
include("vector.jl")
include("matrixrandom.jl")
include("decomposition.jl")
include("copy.jl")

# integrations
include("random.jl")
include("linalg.jl")

end
145 changes: 145 additions & 0 deletions lib/mps/matrixrandom.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
@cenum MPSMatrixRandomDistribution::UInt begin
MPSMatrixRandomDistributionDefault = 1
MPSMatrixRandomDistributionUniform = 2
MPSMatrixRandomDistributionNormal = 3
end

#
# matrix random descriptor
#

export MPSMatrixRandomDistributionDescriptor

@objcwrapper immutable=false MPSMatrixRandomDistributionDescriptor <: NSObject

@objcproperties MPSMatrixRandomDistributionDescriptor begin
@autoproperty distributionType::MPSMatrixRandomDistribution
@autoproperty maximum::Float32 setter=setMaximum
@autoproperty mean::Float32 setter=setMean
@autoproperty minimum::Float32 setter=setMimimum
@autoproperty standardDeviation::Float32 setter=setStandardDeviation
end


function MPSMatrixRandomDefaultDistributionDescriptor()
desc = @objc [MPSMatrixRandomDistributionDescriptor defaultDistributionDescriptor]::id{MPSMatrixRandomDistributionDescriptor}
obj = MPSMatrixRandomDistributionDescriptor(desc)
return obj
end

# Default constructor
MPSMatrixRandomDistributionDescriptor() = MPSMatrixRandomDefaultDistributionDescriptor()

function MPSMatrixRandomNormalDistributionDescriptor(mean, standardDeviation)
desc = @objc [MPSMatrixRandomDistributionDescriptor normalDistributionDescriptorWithMean:mean::Float32
standardDeviation:standardDeviation::Float32]::id{MPSMatrixRandomDistributionDescriptor}
obj = MPSMatrixRandomDistributionDescriptor(desc)
return obj
end

function MPSMatrixRandomNormalDistributionDescriptor(mean, standardDeviation, minimum, maximum)
desc = @objc [MPSMatrixRandomDistributionDescriptor normalDistributionDescriptorWithMean:mean::Float32
standardDeviation:standardDeviation::Float32
minimum:minimum::Float32
maximum:maximum::Float32]::id{MPSMatrixRandomDistributionDescriptor}
obj = MPSMatrixRandomDistributionDescriptor(desc)
return obj
end

function MPSMatrixRandomUniformDistributionDescriptor(minimum, maximum)
desc = @objc [MPSMatrixRandomDistributionDescriptor uniformDistributionDescriptorWithMinimum:minimum::Float32
maximum:maximum::Float32]::id{MPSMatrixRandomDistributionDescriptor}
obj = MPSMatrixRandomDistributionDescriptor(desc)
return obj
end


@objcwrapper immutable=false MPSMatrixRandom <: MPSKernel

@objcproperties MPSMatrixRandom begin
@autoproperty batchSize::NSUInteger
@autoproperty batchStart::NSUInteger
@autoproperty destinationDataType::id{MPSDataType}
@autoproperty distributionType::id{MPSMatrixRandomDistributionDescriptor}
end

function encode!(cmdbuf::MTLCommandBuffer, kernel::K, destinationMatrix::MPSMatrix) where {K<:MPSMatrixRandom}
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
destinationMatrix:destinationMatrix::id{MPSMatrix}]::Nothing
end
function encode!(cmdbuf::MTLCommandBuffer, kernel::K, destinationVector::MPSVector) where {K<:MPSMatrixRandom}
@objc [kernel::id{K} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
destinationVector:destinationVector::id{MPSVector}]::Nothing
end

@objcwrapper immutable=false MPSMatrixRandomMTGP32 <: MPSMatrixRandom
@objcwrapper immutable=false MPSMatrixRandomPhilox <: MPSMatrixRandom

for R in [:MPSMatrixRandomMTGP32, :MPSMatrixRandomPhilox]
@eval begin
function $R(device)
kernel = @objc [$R alloc]::id{$R}
obj = $R(kernel)
finalizer(release, obj)
@objc [obj::id{$R} initWithDevice:device::id{MTLDevice}]::id{$R}
return obj
end
function $R(device, destinationDataType, seed)
kernel = @objc [$R alloc]::id{$R}
obj = $R(kernel)
finalizer(release, obj)
@objc [obj::id{$R} initWithDevice:device::id{MTLDevice}
destinationDataType:destinationDataType::MPSDataType
seed:seed::NSUInteger]::id{$R}
return obj
end
function $R(device, destinationDataType, seed, distributionDescriptor)
kernel = @objc [$R alloc]::id{$R}
obj = $R(kernel)
finalizer(release, obj)
@objc [obj::id{$R} initWithDevice:device::id{MTLDevice}
destinationDataType:destinationDataType::MPSDataType
seed:seed::NSUInteger
distributionDescriptor:distributionDescriptor::id{MPSMatrixRandomDistributionDescriptor}]::id{$R}
return obj
end
end
end

synchronize_state(kern::MPSMatrixRandomMTGP32, cmdbuf::MTLCommandBuffer) =
@objc [obj::id{MPSMatrixRandomMTGP32} synchronizeStateOnCommandBuffer:cmdbuf::id{MTLCommandBuffer}]::Nothing


@inline function _mpsmat_rand!(randkern::MPSMatrixRandom, dest::MtlArray{T}, ::Type{T2};
queue::MTLCommandQueue = global_queue(device()),
async::Bool=false) where {T,T2}
byteoffset = dest.offset * sizeof(T)
bytesize = sizeof(dest)

# Even though `append_copy`` seems to work with any size or offset values, the documentation at
# https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc
# mentions that both must be multiples of 4 bytes in MacOS so error when they are not
(bytesize % 4 == 0) || error(lazy"Destination buffer bytesize ($(bytesize)) must be a multiple of 4.")
(byteoffset % 4 == 0) || error(lazy"Destination buffer offset ($(byteoffset)) must be a multiple of 4.")

cmdbuf = if bytesize % 16 == 0 && dest.offset == 0
MTLCommandBuffer(queue) do cmdbuf
vecDesc = MPSVectorDescriptor(bytesize ÷ sizeof(T2), T2)
mpsdest = MPSVector(dest, vecDesc)
encode!(cmdbuf, randkern, mpsdest)
end
else
MTLCommandBuffer(queue) do cmdbuf
len = UInt(ceil(bytesize / sizeof(T2)) * 4)
vecDesc = MPSVectorDescriptor(len, T2)
tempVec = MPSTemporaryVector(cmdbuf, vecDesc)
encode!(cmdbuf, randkern, tempVec)
MTLBlitCommandEncoder(cmdbuf) do enc
MTL.append_copy!(enc, dest.data[], byteoffset, tempVec.data, tempVec.offset, bytesize)
end
end
end

async || wait_completed(cmdbuf)
return
end
109 changes: 109 additions & 0 deletions lib/mps/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using Random
using Metal: DefaultStorageMode

"""
MPS.RNG()
A random number generator using `rand()` in a device kernel.
"""
mutable struct RNG <: AbstractRNG
device::MTLDevice
uniformInteger::MPSMatrixRandomPhilox
uniformFloat32::MPSMatrixRandomPhilox
normalFloat32::MPSMatrixRandomPhilox
end


make_seed() = Base.rand(RandomDevice(), UInt)

function RNG(device::MTLDevice, seed::Integer)
seed = seed%UInt
RNG(device,
MPSMatrixRandomPhilox(device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor()),
MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1)),
MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1)),)
end
@autoreleasepool RNG(seed::Integer) = RNG(device(), seed)
RNG(device::MTLDevice) = RNG(device, make_seed())

@autoreleasepool RNG() = RNG(device(), make_seed())

Base.copy(rng::RNG) = RNG(copy(rng.device), copy(rng.uniformInteger), copy(rng.uniformFloat32), copy(rng.normalFloat32))

@autoreleasepool function Random.seed!(rng::RNG, seed::Integer)
rng.uniformInteger = MPSMatrixRandomPhilox(rng.device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor())
rng.uniformFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1))
rng.normalFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1))
return rng
end

Random.seed!(rng::RNG) = Random.seed!(rng, make_seed())

const GLOBAL_RNGs = Dict{MTLDevice,MPS.RNG}()
@autoreleasepool function default_rng()
dev = device()
get!(GLOBAL_RNGs, dev) do
RNG(dev)
end
end

const UniformTypes = [Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64]
const UniformType = Union{[Type{T} for T in UniformTypes]...}
const UniformArray = MtlArray{<:Union{Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}}
@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}}
isempty(A) && return A
_mpsmat_rand!(rng.uniformInteger, A, UInt32)
return A
end

@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{Float32})
isempty(A) && return A
_mpsmat_rand!(rng.uniformFloat32, A, Float32)
return A
end

const NormalType = Type{Float32}
const NormalArray = MtlArray{<:Float32}
@autoreleasepool function Random.randn!(rng::RNG, A::MtlArray{Float32})
isempty(A) && return A
_mpsmat_rand!(rng.normalFloat32, A, Float32)
return A
end

# CPU arrays
function Random.rand!(rng::RNG, A::AbstractArray{T,N}) where {T <: Union{UniformTypes...}, N}
isempty(A) && return A
B = MtlArray{T,N,SharedStorage}(undef, size(A))
rand!(rng, B)
copyto!(A, unsafe_wrap(Array{T},B))
return A
end
function Random.randn!(rng::RNG, A::AbstractArray{T,N}) where {T <: Float32, N}
isempty(A) && return A
B = MtlArray{T,N,SharedStorage}(undef, size(A))
randn!(rng, B)
copyto!(A, unsafe_wrap(Array{T},B))
return A
end

# Out of place
Random.rand(rng::RNG, T::UniformType, dims::Dims; storage=DefaultStorageMode) =
Random.rand!(rng, MtlArray{T,length(dims),storage}(undef, dims...))
Random.randn(rng::RNG, T::NormalType, dims::Dims; storage=DefaultStorageMode) =
Random.randn!(rng, MtlArray{T,length(dims),storage}(undef, dims...))

# support all dimension specifications
Random.rand(rng::RNG, T::UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
Random.randn(rng::RNG, T::NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))

# untyped out-of-place
Random.rand(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
Random.randn(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))

# scalars
Random.rand(rng::RNG, T::UniformType=Float32; storage=SharedStorage) = rand(rng, T, 4; storage)[1]
Random.randn(rng::RNG, T::NormalType=Float32; storage=SharedStorage) = randn(rng, T, 4; storage)[1]
Loading

0 comments on commit 1dde978

Please sign in to comment.