From b203523390f278c8aba408d0c922162cde118ff4 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Wed, 27 Mar 2024 17:19:05 -0300 Subject: [PATCH] Make buffers always be a multiple of 16 bytes to support MPS rand functionality --- lib/mtl/buffer.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/lib/mtl/buffer.jl b/lib/mtl/buffer.jl index cb3435b33..43d890506 100644 --- a/lib/mtl/buffer.jl +++ b/lib/mtl/buffer.jl @@ -21,14 +21,16 @@ end ## allocation +bufferbytesize(bytesize::T) where T<:Integer = ceil(T, bytesize / 16) * T(16) function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer; storage=Private, hazard_tracking=DefaultTracking, cache_mode=DefaultCPUCache) opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode - @assert 0 < bytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap - ptr = alloc_buffer(dev, bytesize, opts) + realbytesize = bufferbytesize(bytesize) + @assert 0 < realbytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap + ptr = alloc_buffer(dev, realbytesize, opts) return MTLBuffer(ptr) end @@ -39,8 +41,9 @@ function MTLBuffer(dev::MTLDevice, bytesize::Integer, ptr::Ptr; storage == Private && error("Can't create a Private copy-allocated buffer.") opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode - @assert 0 < bytesize <= dev.maxBufferLength - ptr = alloc_buffer(dev, bytesize, opts, ptr) + realbytesize = bufferbytesize(bytesize) + @assert 0 < realbytesize <= dev.maxBufferLength + ptr = alloc_buffer(dev, realbytesize, opts, ptr) return MTLBuffer(ptr) end