Skip to content

Commit

Permalink
Couple typos and is_m4 function (#498)
Browse files Browse the repository at this point in the history
[skip benchmarks]
  • Loading branch information
christiangnrd authored Dec 17, 2024
1 parent 60a9e34 commit d37e9dd
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ function encode!(cmdbuf::MTLCommandBuffer, matmul::MPSMatrixMultiplication, left
end

"""
matMulMPS(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
matmul!(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
transpose_left=false, transpose_right=false)
A `MPSMatrixMultiplication` kernel thay computes:
`c = alpha * op(a) * beta * op(b) + beta * C`
Expand Down
2 changes: 1 addition & 1 deletion lib/mps/matrixrandom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export MPSMatrixRandomDistributionDescriptor
@autoproperty distributionType::MPSMatrixRandomDistribution
@autoproperty maximum::Float32 setter=setMaximum
@autoproperty mean::Float32 setter=setMean
@autoproperty minimum::Float32 setter=setMimimum
@autoproperty minimum::Float32 setter=setMinimum
@autoproperty standardDeviation::Float32 setter=setStandardDeviation
end

Expand Down
16 changes: 7 additions & 9 deletions lib/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end
end

function MPSTemporaryNDArray(cmdbuf::MTLCommandBuffer, descriptor::MPSNDArrayDescriptor)
@objc [MPSNDTemporaryNDArray temporaryNDArrayWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
@objc [MPSTemporaryNDArray temporaryNDArrayWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
descriptor:descriptor::id{MPSNDArrayDescriptor}]::id{MPSTemporaryNDArray}
return obj
end
Expand Down Expand Up @@ -123,7 +123,7 @@ end
return obj
end
else
function MPSNDArray(buffer::MTLBuffer, offset::UInt, descriptor::MPSNDArrayDescriptor)
function MPSNDArray(_::MTLBuffer, _::UInt, _::MPSNDArrayDescriptor)
@assert false "Creating an MPSNDArray that shares data with user-provided MTLBuffer is only supported in macOS v15+"
end
end
Expand All @@ -135,20 +135,18 @@ function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
end

function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode)
function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false)
ndims = Int(ndarr.numberOfDimensions)
arrsize = [lengthOfDimension(ndarr,i) for i in 0:ndims-1]
T = convert(DataType, ndarr.dataType)
arr = MtlArray{T,ndims,storage}(undef, reverse(arrsize)...)
dev = device(arr)

cmdBuf = MTLCommandBuffer(global_queue(dev))

exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))

commit!(cmdBuf)
wait_completed(cmdBuf)
cmdBuf = MTLCommandBuffer(global_queue(dev)) do cmdBuf
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))
end

async || wait_completed(cmdBuf)
return arr
end

Expand Down
7 changes: 5 additions & 2 deletions lib/mtl/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ MTLDevice(i::Integer) = devices()[i]
# family
#

export supports_family, is_m3, is_m2, is_m1
export supports_family, is_m4, is_m3, is_m2, is_m1

@cenum MTLGPUFamily::NSInteger begin
MTLGPUFamilyMetal3 = 5001 # Metal 3 support
Expand Down Expand Up @@ -121,4 +121,7 @@ is_m1(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple7) &&
!supports_family(dev, MTLGPUFamilyApple8)
is_m2(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple8) &&
!supports_family(dev, MTLGPUFamilyApple9)
is_m3(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9)
is_m3(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9) &&
occursin("M3", String(dev.name))
is_m4(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9) &&
occursin("M4", String(dev.name))

0 comments on commit d37e9dd

Please sign in to comment.