Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace current_device() with device() #366

Merged
merged 6 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bin/metallib-load
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function main(args)
end
isempty(metallib) && error("Empty input")

dev = current_device()
dev = device()
verbose && println("Using device: ", dev.name)

lib = try
Expand Down
2 changes: 1 addition & 1 deletion docs/src/api/essentials.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
```@docs
device!
devices
current_device
device
global_queue
synchronize
device_synchronize
Expand Down
2 changes: 1 addition & 1 deletion docs/src/faq/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ There are varying degrees of user-facing interfaces from Metal.jl. At the lowest
`Metal.MTL.xxx`. This is for low-level functionality close to or at bare Objective-C, or things
that a normal user wouldn't directly be using. `Metal.MPS.xxx` is for Metal Performance Shader
specifics (like `MPSMatrix`).
Next, is `Metal.xxx`. This is for higher-level, usually pure-Julian functionality (like `current_device()`).
Next, is `Metal.xxx`. This is for higher-level, usually pure-Julian functionality (like `device()`).
The only thing beyond this is exporting into the global namespace. That would be useful for uniquely-named
functions/structures/macros with clear and common use-cases (`MtlArray` or `@metal`).
Expand Down
4 changes: 2 additions & 2 deletions lib/mps/copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ export MPSMatrixCopy
@autoproperty destinationsAreTransposed::Bool
end

function MPSMatrixCopy(device, copyRows, copyColumns, sourcesAreTransposed, destinationsAreTransposed)
function MPSMatrixCopy(dev, copyRows, copyColumns, sourcesAreTransposed, destinationsAreTransposed)
kernel = @objc [MPSMatrixCopy alloc]::id{MPSMatrixCopy}
obj = MPSMatrixCopy(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixCopy} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixCopy} initWithDevice:dev::id{MTLDevice}
copyRows:copyRows::NSUInteger
copyColumns:copyColumns::NSUInteger
sourcesAreTransposed:sourcesAreTransposed::Bool
Expand Down
8 changes: 4 additions & 4 deletions lib/mps/decomposition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ export MPSMatrixDecompositionLU

@objcwrapper immutable=false MPSMatrixDecompositionLU <: MPSMatrixUnaryKernel

function MPSMatrixDecompositionLU(device, rows, columns)
function MPSMatrixDecompositionLU(dev, rows, columns)
kernel = @objc [MPSMatrixDecompositionLU alloc]::id{MPSMatrixDecompositionLU}
obj = MPSMatrixDecompositionLU(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixDecompositionLU} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixDecompositionLU} initWithDevice:dev::id{MTLDevice}
rows:rows::NSUInteger
columns:columns::NSUInteger]::id{MPSMatrixDecompositionLU}
return obj
Expand All @@ -37,11 +37,11 @@ export MPSMatrixDecompositionCholesky

@objcwrapper immutable=false MPSMatrixDecompositionCholesky <: MPSMatrixUnaryKernel

function MPSMatrixDecompositionCholesky(device, lower, order)
function MPSMatrixDecompositionCholesky(dev, lower, order)
kernel = @objc [MPSMatrixDecompositionCholesky alloc]::id{MPSMatrixDecompositionCholesky}
obj = MPSMatrixDecompositionCholesky(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixDecompositionCholesky} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixDecompositionCholesky} initWithDevice:dev::id{MTLDevice}
lower:lower::Bool
order:order::NSUInteger]::id{MPSMatrixDecompositionCholesky}
return obj
Expand Down
16 changes: 8 additions & 8 deletions lib/mps/images.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ end

@objcwrapper immutable=false MPSImageGaussianBlur <: MPSUnaryImageKernel

function MPSImageGaussianBlur(device, sigma)
function MPSImageGaussianBlur(dev, sigma)
kernel = @objc [MPSImageGaussianBlur alloc]::id{MPSImageGaussianBlur}
obj = MPSImageGaussianBlur(kernel)
finalizer(release, obj)
@objc [obj::id{MPSImageGaussianBlur} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSImageGaussianBlur} initWithDevice:dev::id{MTLDevice}
sigma:sigma::Float32]::id{MPSImageGaussianBlur}
return obj
end


@objcwrapper immutable=false MPSImageBox <: MPSUnaryImageKernel

function MPSImageBox(device, kernelWidth, kernelHeight)
function MPSImageBox(dev, kernelWidth, kernelHeight)
kernel = @objc [MPSImageBox alloc]::id{MPSImageBox}
obj = MPSImageBox(kernel)
finalizer(release, obj)
@objc [obj::id{MPSImageBox} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSImageBox} initWithDevice:dev::id{MTLDevice}
kernelWidth:kernelWidth::Int
kernelHeight:kernelHeight::Int]::id{MPSImageBox}
return obj
Expand All @@ -69,7 +69,7 @@ function blur(image, kernel; pixelFormat=MTL.MTLPixelFormatRGBA8Unorm)

w,h = size(image)

alignment = MTL.minimumLinearTextureAlignmentForPixelFormat(current_device(), pixelFormat)
alignment = MTL.minimumLinearTextureAlignmentForPixelFormat(device(), pixelFormat)
preBytesPerRow = sizeof(eltype(image))*w

rowoffset = alignment - (preBytesPerRow - 1) % alignment - 1
Expand All @@ -83,19 +83,19 @@ function blur(image, kernel; pixelFormat=MTL.MTLPixelFormatRGBA8Unorm)
textDesc2.usage = MTL.MTLTextureUsageShaderRead | MTL.MTLTextureUsageShaderWrite
text2 = MTL.MTLTexture(res.data.rc.obj, textDesc2, 0, bytesPerRow)

cmdbuf = MTLCommandBuffer(global_queue(current_device()))
cmdbuf = MTLCommandBuffer(global_queue(device()))
encode!(cmdbuf, kernel, text1, text2)
commit!(cmdbuf)

return res
end

function gaussianblur(image; sigma, pixelFormat=MTL.MTLPixelFormatRGBA8Unorm)
kernel = MPSImageGaussianBlur(current_device(), sigma)
kernel = MPSImageGaussianBlur(device(), sigma)
return blur(image, kernel; pixelFormat)
end

function boxblur(image, kernelWidth, kernelHeight; pixelFormat=MTL.MTLPixelFormatRGBA8Unorm)
kernel = MPSImageBox(current_device(), kernelWidth, kernelHeight)
kernel = MPSImageBox(device(), kernelWidth, kernelHeight)
return blur(image, kernel; pixelFormat)
end
10 changes: 5 additions & 5 deletions lib/mps/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
typC = eltype(C)

# If possible, dispatch to performance shaders
if is_supported(current_device()) &&
if is_supported(device()) &&
typA == typB && (typA, typC) in MPS_VALID_MATMUL_TYPES
matmul!(C, A, B, alpha, beta, transA, transB)
else
Expand Down Expand Up @@ -131,7 +131,7 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
typC = eltype(C)

# If possible, dispatch to performance shaders
if is_supported(current_device()) &&
if is_supported(device()) &&
typA == typB && (typA, typC) in MPS_VALID_MATVECMUL_TYPES
matvecmul!(C, A, B, alpha, beta, transA)
else
Expand Down Expand Up @@ -177,7 +177,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T =
@autoreleasepool function LinearAlgebra.lu(A::MtlMatrix{T};
check::Bool=true) where {T<:MtlFloat}
M,N = size(A)
dev = current_device()
dev = device()
queue = global_queue(dev)

At = MtlMatrix{T,Private}(undef, (N, M))
Expand Down Expand Up @@ -235,7 +235,7 @@ end
check::Bool=true,
allowsingular::Bool=false) where {T<:MtlFloat}
M,N = size(A)
dev = current_device()
dev = device()
queue = global_queue(dev)

At = MtlMatrix{T,Private}(undef, (N, M))
Expand Down Expand Up @@ -278,7 +278,7 @@ end
axes(B,2) == axes(A,1) && axes(B,1) == axes(A,2) || throw(DimensionMismatch("transpose"))

M,N = size(A)
dev = current_device()
dev = device()
queue = global_queue(dev)
cmdbuf = MTLCommandBuffer(queue)

Expand Down
20 changes: 10 additions & 10 deletions lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,12 @@ export MPSMatrixMultiplication, matmul!
@autoproperty batchStart::NSUInteger setter=setBatchStart
end

function MPSMatrixMultiplication(device, transposeLeft, transposeRight, resultRows,
function MPSMatrixMultiplication(dev, transposeLeft, transposeRight, resultRows,
resultColumns, interiorColumns, alpha, beta)
kernel = @objc [MPSMatrixMultiplication alloc]::id{MPSMatrixMultiplication}
obj = MPSMatrixMultiplication(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixMultiplication} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixMultiplication} initWithDevice:dev::id{MTLDevice}
transposeLeft:transposeLeft::Bool
transposeRight:transposeRight::Bool
resultRows:resultRows::NSUInteger
Expand Down Expand Up @@ -225,14 +225,14 @@ function matmul!(c::MtlArray{T1,N}, a::MtlArray{T2,N}, b::MtlArray{T3,N},
mps_b = MPSMatrix(b)
mps_c = MPSMatrix(c)

mat_mul_kernel = MPSMatrixMultiplication(current_device(),
mat_mul_kernel = MPSMatrixMultiplication(device(),
transpose_b, transpose_a,
rows_c, cols_c, cols_a,
alpha, beta)


# Encode and commit matmul kernel
cmdbuf = MTLCommandBuffer(global_queue(current_device()))
cmdbuf = MTLCommandBuffer(global_queue(device()))
encode!(cmdbuf, mat_mul_kernel, mps_b, mps_a, mps_c)
commit!(cmdbuf)

Expand All @@ -253,11 +253,11 @@ export MPSMatrixFindTopK, topk, topk!
@autoproperty sourceRows::NSInteger setter=setSourceRows
end

function MPSMatrixFindTopK(device, numberOfTopKValues)
function MPSMatrixFindTopK(dev, numberOfTopKValues)
kernel = @objc [MPSMatrixFindTopK alloc]::id{MPSMatrixFindTopK}
obj = MPSMatrixFindTopK(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixFindTopK} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixFindTopK} initWithDevice:dev::id{MTLDevice}
numberOfTopKValues:numberOfTopKValues::NSUInteger]::id{MPSMatrixFindTopK}
return obj
end
Expand Down Expand Up @@ -299,11 +299,11 @@ end
mps_i = MPSMatrix(I)
mps_v = MPSMatrix(V)

topk_kernel = MPSMatrixFindTopK(current_device(), k)
topk_kernel = MPSMatrixFindTopK(device(), k)
topk_kernel.indexOffset = 1

# Encode and commit topk kernel
cmdbuf = MTLCommandBuffer(global_queue(current_device()))
cmdbuf = MTLCommandBuffer(global_queue(device()))
encode!(cmdbuf, topk_kernel, mps_a, mps_i, mps_v)
commit!(cmdbuf)

Expand Down Expand Up @@ -343,11 +343,11 @@ end

for f in (:MPSMatrixSoftMax, :MPSMatrixLogSoftMax)
@eval begin
function $(f)(device)
function $(f)(dev)
kernel = @objc [$(f) alloc]::id{$(f)}
obj = $(f)(kernel)
finalizer(release, obj)
@objc [obj::id{$(f)} initWithDevice:device::id{MTLDevice}]::id{$(f)}
@objc [obj::id{$(f)} initWithDevice:dev::id{MTLDevice}]::id{$(f)}
return obj
end

Expand Down
8 changes: 4 additions & 4 deletions lib/mps/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ export MPSMatrixVectorMultiplication, matvecmul!

@objcwrapper immutable=false MPSMatrixVectorMultiplication <: MPSMatrixBinaryKernel

function MPSMatrixVectorMultiplication(device, transpose, rows, columns, alpha, beta)
function MPSMatrixVectorMultiplication(dev, transpose, rows, columns, alpha, beta)
kernel = @objc [MPSMatrixVectorMultiplication alloc]::id{MPSMatrixVectorMultiplication}
obj = MPSMatrixVectorMultiplication(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixVectorMultiplication} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixVectorMultiplication} initWithDevice:dev::id{MTLDevice}
transpose:transpose::Bool
rows:rows::NSUInteger
columns:columns::NSUInteger
Expand Down Expand Up @@ -129,12 +129,12 @@ function matvecmul!(c::MtlVector, a::MtlMatrix, b::MtlVector, alpha::Number=true
mps_b = MPSVector(b)
mps_c = MPSVector(c)

matvec_mul_kernel = MPSMatrixVectorMultiplication(current_device(), !transpose,
matvec_mul_kernel = MPSMatrixVectorMultiplication(device(), !transpose,
rows_c, cols_a,
alpha, beta)

# Encode and commit matmul kernel
cmdbuf = MTLCommandBuffer(global_queue(current_device()))
cmdbuf = MTLCommandBuffer(global_queue(device()))
encode!(cmdbuf, matvec_mul_kernel, mps_a, mps_b, mps_c)
commit!(cmdbuf)

Expand Down
4 changes: 2 additions & 2 deletions lib/mtl/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ export supports_family, is_m3, is_m2, is_m1
MTLGPUFamilyMac2 = 2002 # Mac family 2 GPU features
end

function supports_family(device::MTLDevice, gpufamily::MTLGPUFamily)
@objc [device::MTLDevice supportsFamily:gpufamily::MTLGPUFamily]::Bool
function supports_family(dev::MTLDevice, gpufamily::MTLGPUFamily)
@objc [dev::MTLDevice supportsFamily:gpufamily::MTLGPUFamily]::Bool
end

is_m1(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple7) &&
Expand Down
14 changes: 7 additions & 7 deletions lib/mtl/library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ export MTLLibrary, MTLLibraryFromFile, MTLLibraryFromData
@autoproperty functionNames::id{NSArray} type=Vector{NSString}
end

function MTLLibrary(device::MTLDevice, src::String,
function MTLLibrary(dev::MTLDevice, src::String,
opts::MTLCompileOptions=MTLCompileOptions())
err = Ref{id{NSError}}(nil)
handle = @objc [device::id{MTLDevice} newLibraryWithSource:src::id{NSString}
handle = @objc [dev::id{MTLDevice} newLibraryWithSource:src::id{NSString}
options:opts::id{MTLCompileOptions}
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
err[] == nil || throw(NSError(err[]))
Expand All @@ -21,14 +21,14 @@ function MTLLibrary(device::MTLDevice, src::String,
return obj
end

function MTLLibraryFromFile(device::MTLDevice, path::String)
function MTLLibraryFromFile(dev::MTLDevice, path::String)
err = Ref{id{NSError}}(nil)
handle = if macos_version() >= v"13"
url = NSFileURL(path)
@objc [device::id{MTLDevice} newLibraryWithURL:url::id{NSURL}
@objc [dev::id{MTLDevice} newLibraryWithURL:url::id{NSURL}
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
else
@objc [device::id{MTLDevice} newLibraryWithFile:path::id{NSString}
@objc [dev::id{MTLDevice} newLibraryWithFile:path::id{NSString}
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
end
err[] == nil || throw(NSError(err[]))
Expand All @@ -38,11 +38,11 @@ function MTLLibraryFromFile(device::MTLDevice, path::String)
return obj
end

function MTLLibraryFromData(device::MTLDevice, input_data)
function MTLLibraryFromData(dev::MTLDevice, input_data)
err = Ref{id{NSError}}(nil)
GC.@preserve input_data begin
data = dispatch_data(pointer(input_data), sizeof(input_data))
handle = @objc [device::id{MTLDevice} newLibraryWithData:data::dispatch_data_t
handle = @objc [dev::id{MTLDevice} newLibraryWithData:data::dispatch_data_t
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
end
err[] == nil || throw(NSError(err[]))
Expand Down
8 changes: 4 additions & 4 deletions lib/mtl/texture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ end
## bitwise operations lose type information, so allow conversions
Base.convert(::Type{MTLPixelFormat}, x::Integer) = MTLPixelFormat(x)

function minimumLinearTextureAlignmentForPixelFormat(device, format)
return @objc [device::MTLDevice minimumLinearTextureAlignmentForPixelFormat:format::MTLPixelFormat]::NSUInteger
function minimumLinearTextureAlignmentForPixelFormat(dev, format)
return @objc [dev::MTLDevice minimumLinearTextureAlignmentForPixelFormat:format::MTLPixelFormat]::NSUInteger
end

@cenum MTLTextureUsage::NSUInteger begin
Expand Down Expand Up @@ -247,8 +247,8 @@ function MTLTexture(buffer, descriptor, offset, bytesPerRow)
return obj
end

function MTLTexture(device, descriptor)
texture = @objc [device::id{MTLDevice} newTextureWithDescriptor:descriptor::id{MTLTextureDescriptor}]::id{MTLTexture}
function MTLTexture(dev, descriptor)
texture = @objc [dev::id{MTLDevice} newTextureWithDescriptor:descriptor::id{MTLTextureDescriptor}]::id{MTLTexture}
obj = MTLTexture(texture)
finalizer(release, obj)

Expand Down
2 changes: 2 additions & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ export MetalBackend
include("../ext/BFloat16sExt.jl")
end

include("deprecated.jl")

include("precompile.jl")

end # module
6 changes: 3 additions & 3 deletions src/array.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# host array

export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl, is_shared, is_managed, is_private, device
export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl, is_shared, is_managed, is_private

function hasfieldcount(@nospecialize(dt))
try
Expand Down Expand Up @@ -61,7 +61,7 @@ mutable struct MtlArray{T,N,S} <: AbstractGPUArray{T,N}
maxsize
end

dev = current_device()
dev = device()
if bufsize == 0
# Metal doesn't support empty allocations. For simplicity (i.e., the ability to get
# a pointer, query the buffer's properties, etc), we use a 1-byte buffer instead.
Expand Down Expand Up @@ -557,7 +557,7 @@ function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPtr{T}, dims; own=false)
end

function Base.unsafe_wrap(A::Type{<:MtlArray{T,N}}, arr::Array, dims=size(arr);
dev=current_device(), kwargs...) where {T,N}
dev=device(), kwargs...) where {T,N}
GC.@preserve arr begin
buf = MTLBuffer(dev, prod(dims) * sizeof(T), pointer(arr); nocopy=true, kwargs...)
return A(buf, Dims(dims))
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ end
@signpost_event log=log_compiler() "Link" "Job=$job"

@signpost_interval log=log_compiler() "Instantiate compute pipeline" begin
dev = current_device()
dev = device()
lib = MTLLibraryFromData(dev, compiled.image)
fun = MTLFunction(lib, compiled.entry)
pipeline_state = try
Expand Down
Loading