Skip to content

Commit

Permalink
MPSNDArray improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Dec 18, 2024
1 parent 0981389 commit 2a6d162
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
3 changes: 3 additions & 0 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import GPUArrays

const MtlFloat = Union{Float32, Float16}

const MPSShape = NSArray#{NSNumber}
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))

is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice},), dev)

include("size.jl")
Expand Down
28 changes: 24 additions & 4 deletions lib/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ else
end
end

function Base.size(ndarr::MPSNDArray)
ndims = Int(ndarr.numberOfDimensions)
Tuple([Int(lengthOfDimension(ndarr,i)) for i in 0:ndims-1])
end

@objcwrapper immutable=false MPSTemporaryNDArray <: MPSNDArray

@objcproperties MPSTemporaryNDArray begin
Expand Down Expand Up @@ -136,14 +141,17 @@ function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
end

function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false)
ndims = Int(ndarr.numberOfDimensions)
arrsize = [lengthOfDimension(ndarr,i) for i in 0:ndims-1]
arrsize = size(ndarr)
T = convert(DataType, ndarr.dataType)
arr = MtlArray{T,ndims,storage}(undef, reverse(arrsize)...)
arr = MtlArray{T,length(arrsize),storage}(undef, (arrsize)...)
return exportToMtlArray!(arr, ndarr; async)
end

function exportToMtlArray!(arr::MtlArray{T}, ndarr::MPSNDArray; async=false) where T
dev = device(arr)

cmdBuf = MTLCommandBuffer(global_queue(dev)) do cmdBuf
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, arr.offset)
end

async || wait_completed(cmdBuf)
Expand All @@ -157,6 +165,12 @@ exportDataWithCommandBuffer(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, toBuffe
destinationDataType:destinationDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing
exportDataWithCommandBuffer(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, toBuffer, destinationDataType, offset) =
@objc [ndarr::MPSNDArray exportDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
toBuffer:toBuffer::id{MTLBuffer}
destinationDataType:destinationDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:nil::id{ObjectiveC.Object}]::Nothing

# rowStrides in Bytes
importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBuffer, sourceDataType, offset, rowStrides) =
Expand All @@ -165,6 +179,12 @@ importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBu
sourceDataType:sourceDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:pointer(rowStrides)::Ptr{NSInteger}]::Nothing
importDataWithCommandBuffer!(ndarr::MPSNDArray, cmdbuf::MTLCommandBuffer, fromBuffer, sourceDataType, offset) =
@objc [ndarr::MPSNDArray importDataWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
fromBuffer:fromBuffer::id{MTLBuffer}
sourceDataType:sourceDataType::MPSDataType
offset:offset::NSUInteger
rowStrides:nil::id{ObjectiveC.Object}]::Nothing

# TODO
# exportDataWithCommandBuffer(toImages, offset)
Expand Down

0 comments on commit 2a6d162

Please sign in to comment.