diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl index a4157ece6..52235ea84 100644 --- a/lib/mps/MPS.jl +++ b/lib/mps/MPS.jl @@ -23,6 +23,7 @@ include("vector.jl") include("matrixrandom.jl") # integrations +include("random.jl") include("linalg.jl") # decompositions diff --git a/lib/mps/random.jl b/lib/mps/random.jl new file mode 100644 index 000000000..3fd4e1ec2 --- /dev/null +++ b/lib/mps/random.jl @@ -0,0 +1,69 @@ +using Random + +""" + MPS.RNG() + +A random number generator using `rand()` in a device kernel. +""" +mutable struct RNG <: AbstractRNG + seed::UInt + counter::UInt32 + + function RNG(seed::Integer) + new(seed%UInt, 0) + end + RNG(seed::UInt, counter::UInt32) = new(seed, counter) +end + +make_seed() = Base.rand(RandomDevice(), UInt) + +RNG() = RNG(make_seed()) + +Base.copy(rng::RNG) = RNG(rng.seed, rng.counter) +Base.hash(rng::RNG, h::UInt) = hash(rng.seed, hash(rng.counter, h)) +Base.:(==)(a::RNG, b::RNG) = (a.seed == b.seed) && (a.counter == b.counter) + +function Random.seed!(rng::RNG, seed::Integer) + rng.seed = seed % UInt + rng.counter = 0 +end + +Random.seed!(rng::RNG) = Random.seed!(rng, make_seed()) + +@inline function update_state!(rng::RNG, len) + new_counter = Int64(rng.counter) + len + overflow, remainder = fldmod(new_counter, typemax(UInt32)) + rng.seed += overflow # XXX: is this OK? + rng.counter = remainder + return rng +end + +const GLOBAL_RNGs = Dict{MTLDevice,MPS.RNG}() +function default_rng() + dev = current_device() + get!(GLOBAL_RNGs, dev) do + RNG() + end +end + +function Random.rand!(rng::RNG, A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} + mpsvecormat = MPSVector(A, UInt32) + _mpsmat_rand!(mpsvecormat, seed = rng.seed + rng.counter) + + update_state!(rng,length(A)) + return A +end +function Random.rand!(rng::RNG, A::MtlArray{Float32}) + mpsvecormat = MPSVector(A, Float32) + _mpsmat_rand!(mpsvecormat; desc=MPSMatrixRandomUniformDistributionDescriptor(0, 1), seed = rng.seed + rng.counter) + + update_state!(rng,length(A)) + return A +end +function Random.randn!(rng::RNG, A::MtlArray{Float32}) + mpsvecormat = MPSVector(A, Float32) + _mpsmat_rand!(mpsvecormat; desc=MPSMatrixRandomNormalDistributionDescriptor(0, 1), seed = rng.seed + rng.counter) + + update_state!(rng,length(A)) + return A +end diff --git a/lib/mps/vector.jl b/lib/mps/vector.jl index 2d4bf9bf3..55dcee4f3 100644 --- a/lib/mps/vector.jl +++ b/lib/mps/vector.jl @@ -28,6 +28,10 @@ function MPSVectorDescriptor(length, vectors, vectorBytes, dataType) return obj end +function vectorBytesForLength(length, dataType) + @objc [MPSVectorDescriptor vectorBytesForLength:length::NSUInteger + dataType:dataType::MPSDataType]::NSUInteger +end export MPSVector @@ -48,9 +52,16 @@ end Metal vector representation used in Performance Shaders. """ -function MPSVector(arr::MtlVector{T}) where T - len = length(arr) - desc = MPSVectorDescriptor(len, T) +MPSVector(arr::MtlVector{T}) where T = mpsvector(arr, T, length(arr)) + +# For rand! and randn! +function MPSVector(arr::MtlArray{T}, ::Type{T2}) where {T,T2} + len = UInt(ceil(length(arr) * sizeof(T) / sizeof(T2) / 4) * 4) + return mpsvector(arr, T2, len) +end + +@inline function mpsvector(arr::MtlArray{T}, ::Type{T2}, len) where {T,T2} + desc = MPSVectorDescriptor(len, T2) vec = @objc [MPSVector alloc]::id{MPSVector} obj = MPSVector(vec) offset = arr.offset * sizeof(T) @@ -61,6 +72,12 @@ function MPSVector(arr::MtlVector{T}) where T return obj end +resourceSize(vecormat::M) where {M<:Union{MPSVector,MPSMatrix}} = + @objc [vecormat::id{M} resourceSize]::NSUInteger + +synchronizeOnCommandBuffer(vecormat::M, cmdBuf::MTLCommandBuffer) where {M<:Union{MPSVector,MPSMatrix}} = + @objc [vecormat::id{M} synchronizeOnCommandBuffer:cmdBuf::id{MTLCommandBuffer}]::Nothing + # # matrix vector multiplication # diff --git a/lib/mtl/buffer.jl b/lib/mtl/buffer.jl index cb3435b33..622776e33 100644 --- a/lib/mtl/buffer.jl +++ b/lib/mtl/buffer.jl @@ -21,14 +21,17 @@ end ## allocation +const BUFFER_ALIGNMENT_FOR_RAND::Int = 16 +@inline bufferbytesize(bytesize::T) where {T <: Integer} = ceil(T, bytesize / BUFFER_ALIGNMENT_FOR_RAND) * T(BUFFER_ALIGNMENT_FOR_RAND) function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer; storage=Private, hazard_tracking=DefaultTracking, cache_mode=DefaultCPUCache) opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode - @assert 0 < bytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap - ptr = alloc_buffer(dev, bytesize, opts) + realbytesize = bufferbytesize(bytesize) + @assert 0 < realbytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap + ptr = alloc_buffer(dev, realbytesize, opts) return MTLBuffer(ptr) end @@ -39,8 +42,9 @@ function MTLBuffer(dev::MTLDevice, bytesize::Integer, ptr::Ptr; storage == Private && error("Can't create a Private copy-allocated buffer.") opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode - @assert 0 < bytesize <= dev.maxBufferLength - ptr = alloc_buffer(dev, bytesize, opts, ptr) + realbytesize = bufferbytesize(bytesize) + @assert 0 < realbytesize <= dev.maxBufferLength + ptr = alloc_buffer(dev, realbytesize, opts, ptr) return MTLBuffer(ptr) end diff --git a/src/Metal.jl b/src/Metal.jl index 7c5a4e29b..7d4b286b5 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -51,6 +51,10 @@ include("compiler/compilation.jl") include("compiler/execution.jl") include("compiler/reflection.jl") +# libraries +include("../lib/mps/MPS.jl") +export MPS + # array implementation include("utilities.jl") include("broadcast.jl") @@ -58,10 +62,6 @@ include("mapreduce.jl") include("random.jl") include("gpuarrays.jl") -# libraries -include("../lib/mps/MPS.jl") -export MPS - # KernelAbstractions include("MetalKernels.jl") import .MetalKernels: MetalBackend diff --git a/src/random.jl b/src/random.jl index 81cc48c00..f2711437b 100644 --- a/src/random.jl +++ b/src/random.jl @@ -1,24 +1,74 @@ using Random +using ..MPS: MPSVector, _mpsmat_rand!, MPSMatrixRandomUniformDistributionDescriptor, + MPSMatrixRandomNormalDistributionDescriptor gpuarrays_rng() = GPUArrays.default_rng(MtlArray) +mpsrand_rng() = MPS.default_rng() # GPUArrays in-place Random.rand!(A::MtlArray) = Random.rand!(gpuarrays_rng(), A) Random.randn!(A::MtlArray) = Random.randn!(gpuarrays_rng(), A) +@inline function usempsrandom(A::MtlArray{T}) where {T} + return (A.offset == 0 && + (length(A) * sizeof(T) % MTL.BUFFER_ALIGNMENT_FOR_RAND == 0)) +end + +# Use MPS random functionality where possible +function Random.rand!(A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} + if usempsrandom(A) + @inline Random.rand!(gpuarrays_rng(), A) + else + @inline Random.rand!(gpuarrays_rng(), A) + end + return A +end +function Random.rand!(A::MtlArray{Float32}) + if usempsrandom(A) + @inline Random.rand!(mpsrand_rng(), A) + else + @inline Random.rand!(gpuarrays_rng(), A) + end + return A +end +function Random.randn!(A::MtlArray{Float32}) + if usempsrandom(A) + @inline Random.randn!(mpsrand_rng(), A) + else + @inline Random.randn!(gpuarrays_rng(), A) + end + return A +end + # GPUArrays out-of-place -rand(T::Type, dims::Dims; storage=DefaultStorageMode) = Random.rand!(MtlArray{T,length(dims),storage}(undef, dims...)) -randn(T::Type, dims::Dims; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{T,length(dims),storage}(undef, dims...); kwargs...) +rand(::Type{T}, dims::Dims; storage=DefaultStorageMode) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64,Float32}} = + Random.rand!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) +randn(::Type{Float32}, dims::Dims; storage=DefaultStorageMode) = + Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims),storage}(undef, dims...)) +rand(T::Type, dims::Dims; storage=DefaultStorageMode) = + Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) +randn(T::Type, dims::Dims; storage=DefaultStorageMode) = + Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) # support all dimension specifications +rand(::Type{T}, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64,Float32}} = + Random.rand!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) +randn(::Type{Float32}, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = + Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) + rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = - Random.rand!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...)) -randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = - Random.randn!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...); kwargs...) + Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) +randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = + Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) # untyped out-of-place -rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = Random.rand!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...)) -randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...); kwargs...) +rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = + Random.rand!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) +randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = + Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) # seeding -seed!(seed=Base.rand(UInt64)) = Random.seed!(gpuarrays_rng(), seed) +function seed!(seed=Base.rand(UInt64)) + Random.seed!(gpuarrays_rng(), seed) + Random.seed!(mpsrand_rng(), seed) +end diff --git a/test/metal.jl b/test/metal.jl index a9894614e..71605dcab 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -235,8 +235,8 @@ dev = first(devices()) buf = MTLBuffer(dev, 8; storage=Shared) -@test buf.length == 8 -@test sizeof(buf) == 8 +@test buf.length == 16 +@test sizeof(buf) == 16 # MTLResource properties @test buf.device == dev diff --git a/test/mps.jl b/test/mps.jl index d3e12b5f8..1baa921c6 100644 --- a/test/mps.jl +++ b/test/mps.jl @@ -60,7 +60,7 @@ end buf_a = MtlArray{input_jl_type}(arr_a) buf_b = MtlArray{input_jl_type}(arr_b) buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size)) - + truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size)) for i in 1:batch_size @views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i]) diff --git a/test/random.jl b/test/random.jl index 89c771bca..d1aec343d 100644 --- a/test/random.jl +++ b/test/random.jl @@ -1,39 +1,102 @@ using Random +const RAND_TYPES = [Float16, Float32, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, + UInt64] +const RANDN_TYPES = [Float16, Float32] +const INPLACE_TUPLES = [[(rand!, T) for T in RAND_TYPES]; + [(randn!, T) for T in RANDN_TYPES]] +const OOPLACE_TUPLES = [[(Metal.rand, T) for T in RAND_TYPES]; + [(Metal.randn, T) for T in RANDN_TYPES]; + [(rand, T) for T in RAND_TYPES]; + [(randn, T) for T in RANDN_TYPES]] + @testset "rand" begin + # in-place + @testset "in-place" begin + @testset "$f with $T" for (f, T) in INPLACE_TUPLES + @testset "$d" for d in (1, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16)) + A = MtlArray{T}(undef, d) + fill!(A, T(0)) + f(A) + @test Metal.usempsrandom(A) == + ((prod(d) * sizeof(T)) % MTL.BUFFER_ALIGNMENT_FOR_RAND == 0) + @test !iszero(collect(A)) + end + end + end + + # in-place contiguous views + @testset "in-place for views" begin + @testset "$f with $T" for (f, T) in INPLACE_TUPLES + alen = 100 + A = MtlArray{T}(undef, alen) + function test_view!(X::MtlArray{T}, idx; shouldusemps) where {T} + fill!(X, T(0)) + view_X = @view X[idx] + f(view_X) + cpuX = collect(X) + @test Metal.usempsrandom(view_X) == shouldusemps + @test !iszero(cpuX[idx]) + @test iszero(cpuX[1:alen .∉ Ref(idx)]) + return + end + + # Test when view offset is 0 and buffer size not multiple of 16 + @testset "Off == 0, buf % 16 != 0" begin + test_view!(A, 1:51; shouldusemps=false) + end + + # Test when view offset is 0 and buffer size is multiple of 16 + @testset "Off == 0, buf % 16 == 0" begin + test_view!(A, 1:32; shouldusemps=true) + end + + # Test when view offset is not 0 nor multiple of 16 and buffer size not multiple of 16 + @testset "Off != 0, buf % 16 != 0" begin + test_view!(A, 3:51; shouldusemps=false) + end + + # Test when view offset is multiple of 16 and buffer size not multiple of 16 + @testset "Off % 16 == 0, buf % 16 != 0" begin + test_view!(A, 17:51; shouldusemps=false) + end -# in-place -for (f,T) in ((rand!,Float16), - (rand!,Float32), - (randn!,Float16), - (randn!,Float32)), - d in (2, (2,2), (2,2,2), 3, (3,3), (3,3,3)) - A = MtlArray{T}(undef, d) - fill!(A, T(0)) - f(A) - @test !iszero(collect(A)) -end - -# out-of-place, with implicit type -for (f,T) in ((Metal.rand,Float32), (Metal.randn,Float32)), - args in ((2,), (2, 2), (3,), (3, 3)) - A = f(args...) - @test eltype(A) == T -end - -# out-of-place, with type specified -for (f,T) in ((Metal.rand,Float32), (Metal.randn,Float32), - (rand,Float32), (randn,Float32)), - args in ((T, 2), (T, 2, 2), (T, (2, 2)), (T, 3), (T, 3, 3), (T, (3, 3))) - A = f(args...) - @test eltype(A) == T -end - -## seeding -Metal.seed!(1) -a = Metal.rand(Int32, 1) -Metal.seed!(1) -b = Metal.rand(Int32, 1) -@test iszero(collect(a) - collect(b)) + # Test when view offset is multiple of 16 and buffer size multiple of 16 + @testset "Off % 16 == 0, buf % 16 == 0" begin + test_view!(A, 17:32; shouldusemps=false) + end + end + end + # out-of-place, with implicit type + @testset "out-of-place" begin + @testset "$f with implicit type" for (f, T) in + ((Metal.rand, Float32), (Metal.randn, Float32)) + @testset "args" for args in ((1,), (3,), (3, 3), (16,), (16, 16)) + A = f(args...) + @test eltype(A) == T + end + end + # out-of-place, with type specified + @testset "$f with $T" for (f, T) in OOPLACE_TUPLES + @testset "$args" for args in ((T, 1), + (T, 3), + (T, 3, 3), + (T, (3, 3)), + (T, 16), + (T, 16, 16), + (T, (16, 16))) + A = f(args...) + @test eltype(A) == T + end + end + end + ## seeding + @testset "Seeding" begin + Metal.seed!(1) + a = Metal.rand(Int32, 1) + Metal.seed!(1) + b = Metal.rand(Int32, 1) + @test iszero(collect(a) - collect(b)) + end end # testset