Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Couple typos and is_m4 function #498

Merged
merged 5 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ function encode!(cmdbuf::MTLCommandBuffer, matmul::MPSMatrixMultiplication, left
end

"""
matMulMPS(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
matmul!(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
transpose_left=false, transpose_right=false)
A `MPSMatrixMultiplication` kernel thay computes:
`c = alpha * op(a) * beta * op(b) + beta * C`
Expand Down
2 changes: 1 addition & 1 deletion lib/mps/matrixrandom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export MPSMatrixRandomDistributionDescriptor
@autoproperty distributionType::MPSMatrixRandomDistribution
@autoproperty maximum::Float32 setter=setMaximum
@autoproperty mean::Float32 setter=setMean
@autoproperty minimum::Float32 setter=setMimimum
@autoproperty minimum::Float32 setter=setMinimum
@autoproperty standardDeviation::Float32 setter=setStandardDeviation
end

Expand Down
16 changes: 7 additions & 9 deletions lib/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end
end

function MPSTemporaryNDArray(cmdbuf::MTLCommandBuffer, descriptor::MPSNDArrayDescriptor)
@objc [MPSNDTemporaryNDArray temporaryNDArrayWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
@objc [MPSTemporaryNDArray temporaryNDArrayWithCommandBuffer:cmdbuf::id{MTLCommandBuffer}
descriptor:descriptor::id{MPSNDArrayDescriptor}]::id{MPSTemporaryNDArray}
return obj
end
Expand Down Expand Up @@ -123,7 +123,7 @@ end
return obj
end
else
function MPSNDArray(buffer::MTLBuffer, offset::UInt, descriptor::MPSNDArrayDescriptor)
function MPSNDArray(_::MTLBuffer, _::UInt, _::MPSNDArrayDescriptor)
@assert false "Creating an MPSNDArray that shares data with user-provided MTLBuffer is only supported in macOS v15+"
end
end
Expand All @@ -135,20 +135,18 @@ function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
end

function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode)
function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false)
ndims = Int(ndarr.numberOfDimensions)
arrsize = [lengthOfDimension(ndarr,i) for i in 0:ndims-1]
T = convert(DataType, ndarr.dataType)
arr = MtlArray{T,ndims,storage}(undef, reverse(arrsize)...)
dev = device(arr)

cmdBuf = MTLCommandBuffer(global_queue(dev))

exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))

commit!(cmdBuf)
wait_completed(cmdBuf)
cmdBuf = MTLCommandBuffer(global_queue(dev)) do cmdBuf
exportDataWithCommandBuffer(ndarr, cmdBuf, arr.data[], T, 0, collect(sizeof(T) .* reverse(strides(arr))))
end

async || wait_completed(cmdBuf)
return arr
end

Expand Down
7 changes: 5 additions & 2 deletions lib/mtl/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ MTLDevice(i::Integer) = devices()[i]
# family
#

export supports_family, is_m3, is_m2, is_m1
export supports_family, is_m4, is_m3, is_m2, is_m1

@cenum MTLGPUFamily::NSInteger begin
MTLGPUFamilyMetal3 = 5001 # Metal 3 support
Expand Down Expand Up @@ -121,4 +121,7 @@ is_m1(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple7) &&
!supports_family(dev, MTLGPUFamilyApple8)
is_m2(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple8) &&
!supports_family(dev, MTLGPUFamilyApple9)
is_m3(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9)
is_m3(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9) &&
occursin("M3", String(dev.name))
is_m4(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple9) &&
occursin("M4", String(dev.name))
Loading