Skip to content

Commit

Permalink
Replace current_device() with device() (#366)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Besard <[email protected]>
  • Loading branch information
christiangnrd and maleadt authored Jun 18, 2024
1 parent a8c51d9 commit e3a8ed2
Show file tree
Hide file tree
Showing 31 changed files with 92 additions and 90 deletions.
2 changes: 1 addition & 1 deletion bin/metallib-load
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function main(args)
end
isempty(metallib) && error("Empty input")

dev = current_device()
dev = device()
verbose && println("Using device: ", dev.name)

lib = try
Expand Down
2 changes: 1 addition & 1 deletion docs/src/api/essentials.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
```@docs
device!
devices
current_device
device
global_queue
synchronize
device_synchronize
Expand Down
2 changes: 1 addition & 1 deletion docs/src/faq/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ There are varying degrees of user-facing interfaces from Metal.jl. At the lowest
`Metal.MTL.xxx`. This is for low-level functionality close to or at bare Objective-C, or things
that a normal user wouldn't directly be using. `Metal.MPS.xxx` is for Metal Performance Shader
specifics (like `MPSMatrix`).
Next, is `Metal.xxx`. This is for higher-level, usually pure-Julian functionality (like `current_device()`).
Next, is `Metal.xxx`. This is for higher-level, usually pure-Julian functionality (like `device()`).
The only thing beyond this is exporting into the global namespace. That would be useful for uniquely-named
functions/structures/macros with clear and common use-cases (`MtlArray` or `@metal`).
Expand Down
4 changes: 2 additions & 2 deletions lib/mps/copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ export MPSMatrixCopy
@autoproperty destinationsAreTransposed::Bool
end

function MPSMatrixCopy(device, copyRows, copyColumns, sourcesAreTransposed, destinationsAreTransposed)
function MPSMatrixCopy(dev, copyRows, copyColumns, sourcesAreTransposed, destinationsAreTransposed)
kernel = @objc [MPSMatrixCopy alloc]::id{MPSMatrixCopy}
obj = MPSMatrixCopy(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixCopy} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixCopy} initWithDevice:dev::id{MTLDevice}
copyRows:copyRows::NSUInteger
copyColumns:copyColumns::NSUInteger
sourcesAreTransposed:sourcesAreTransposed::Bool
Expand Down
8 changes: 4 additions & 4 deletions lib/mps/decomposition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ export MPSMatrixDecompositionLU

@objcwrapper immutable=false MPSMatrixDecompositionLU <: MPSMatrixUnaryKernel

function MPSMatrixDecompositionLU(device, rows, columns)
function MPSMatrixDecompositionLU(dev, rows, columns)
kernel = @objc [MPSMatrixDecompositionLU alloc]::id{MPSMatrixDecompositionLU}
obj = MPSMatrixDecompositionLU(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixDecompositionLU} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixDecompositionLU} initWithDevice:dev::id{MTLDevice}
rows:rows::NSUInteger
columns:columns::NSUInteger]::id{MPSMatrixDecompositionLU}
return obj
Expand All @@ -37,11 +37,11 @@ export MPSMatrixDecompositionCholesky

@objcwrapper immutable=false MPSMatrixDecompositionCholesky <: MPSMatrixUnaryKernel

function MPSMatrixDecompositionCholesky(device, lower, order)
function MPSMatrixDecompositionCholesky(dev, lower, order)
kernel = @objc [MPSMatrixDecompositionCholesky alloc]::id{MPSMatrixDecompositionCholesky}
obj = MPSMatrixDecompositionCholesky(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixDecompositionCholesky} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixDecompositionCholesky} initWithDevice:dev::id{MTLDevice}
lower:lower::Bool
order:order::NSUInteger]::id{MPSMatrixDecompositionCholesky}
return obj
Expand Down
16 changes: 8 additions & 8 deletions lib/mps/images.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ end

@objcwrapper immutable=false MPSImageGaussianBlur <: MPSUnaryImageKernel

function MPSImageGaussianBlur(device, sigma)
function MPSImageGaussianBlur(dev, sigma)
kernel = @objc [MPSImageGaussianBlur alloc]::id{MPSImageGaussianBlur}
obj = MPSImageGaussianBlur(kernel)
finalizer(release, obj)
@objc [obj::id{MPSImageGaussianBlur} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSImageGaussianBlur} initWithDevice:dev::id{MTLDevice}
sigma:sigma::Float32]::id{MPSImageGaussianBlur}
return obj
end


@objcwrapper immutable=false MPSImageBox <: MPSUnaryImageKernel

function MPSImageBox(device, kernelWidth, kernelHeight)
function MPSImageBox(dev, kernelWidth, kernelHeight)
kernel = @objc [MPSImageBox alloc]::id{MPSImageBox}
obj = MPSImageBox(kernel)
finalizer(release, obj)
@objc [obj::id{MPSImageBox} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSImageBox} initWithDevice:dev::id{MTLDevice}
kernelWidth:kernelWidth::Int
kernelHeight:kernelHeight::Int]::id{MPSImageBox}
return obj
Expand All @@ -69,7 +69,7 @@ function blur(image, kernel; pixelFormat=MTL.MTLPixelFormatRGBA8Unorm)

w,h = size(image)

alignment = MTL.minimumLinearTextureAlignmentForPixelFormat(current_device(), pixelFormat)
alignment = MTL.minimumLinearTextureAlignmentForPixelFormat(device(), pixelFormat)
preBytesPerRow = sizeof(eltype(image))*w

rowoffset = alignment - (preBytesPerRow - 1) % alignment - 1
Expand All @@ -83,19 +83,19 @@ function blur(image, kernel; pixelFormat=MTL.MTLPixelFormatRGBA8Unorm)
textDesc2.usage = MTL.MTLTextureUsageShaderRead | MTL.MTLTextureUsageShaderWrite
text2 = MTL.MTLTexture(res.data.rc.obj, textDesc2, 0, bytesPerRow)

cmdbuf = MTLCommandBuffer(global_queue(current_device()))
cmdbuf = MTLCommandBuffer(global_queue(device()))
encode!(cmdbuf, kernel, text1, text2)
commit!(cmdbuf)

return res
end

function gaussianblur(image; sigma, pixelFormat=MTL.MTLPixelFormatRGBA8Unorm)
kernel = MPSImageGaussianBlur(current_device(), sigma)
kernel = MPSImageGaussianBlur(device(), sigma)
return blur(image, kernel; pixelFormat)
end

function boxblur(image, kernelWidth, kernelHeight; pixelFormat=MTL.MTLPixelFormatRGBA8Unorm)
kernel = MPSImageBox(current_device(), kernelWidth, kernelHeight)
kernel = MPSImageBox(device(), kernelWidth, kernelHeight)
return blur(image, kernel; pixelFormat)
end
10 changes: 5 additions & 5 deletions lib/mps/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
typC = eltype(C)

# If possible, dispatch to performance shaders
if is_supported(current_device()) &&
if is_supported(device()) &&
typA == typB && (typA, typC) in MPS_VALID_MATMUL_TYPES
matmul!(C, A, B, alpha, beta, transA, transB)
else
Expand Down Expand Up @@ -131,7 +131,7 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
typC = eltype(C)

# If possible, dispatch to performance shaders
if is_supported(current_device()) &&
if is_supported(device()) &&
typA == typB && (typA, typC) in MPS_VALID_MATVECMUL_TYPES
matvecmul!(C, A, B, alpha, beta, transA)
else
Expand Down Expand Up @@ -177,7 +177,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T =
@autoreleasepool function LinearAlgebra.lu(A::MtlMatrix{T};
check::Bool=true) where {T<:MtlFloat}
M,N = size(A)
dev = current_device()
dev = device()
queue = global_queue(dev)

At = MtlMatrix{T,Private}(undef, (N, M))
Expand Down Expand Up @@ -235,7 +235,7 @@ end
check::Bool=true,
allowsingular::Bool=false) where {T<:MtlFloat}
M,N = size(A)
dev = current_device()
dev = device()
queue = global_queue(dev)

At = MtlMatrix{T,Private}(undef, (N, M))
Expand Down Expand Up @@ -278,7 +278,7 @@ end
axes(B,2) == axes(A,1) && axes(B,1) == axes(A,2) || throw(DimensionMismatch("transpose"))

M,N = size(A)
dev = current_device()
dev = device()
queue = global_queue(dev)
cmdbuf = MTLCommandBuffer(queue)

Expand Down
20 changes: 10 additions & 10 deletions lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,12 @@ export MPSMatrixMultiplication, matmul!
@autoproperty batchStart::NSUInteger setter=setBatchStart
end

function MPSMatrixMultiplication(device, transposeLeft, transposeRight, resultRows,
function MPSMatrixMultiplication(dev, transposeLeft, transposeRight, resultRows,
resultColumns, interiorColumns, alpha, beta)
kernel = @objc [MPSMatrixMultiplication alloc]::id{MPSMatrixMultiplication}
obj = MPSMatrixMultiplication(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixMultiplication} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixMultiplication} initWithDevice:dev::id{MTLDevice}
transposeLeft:transposeLeft::Bool
transposeRight:transposeRight::Bool
resultRows:resultRows::NSUInteger
Expand Down Expand Up @@ -225,14 +225,14 @@ function matmul!(c::MtlArray{T1,N}, a::MtlArray{T2,N}, b::MtlArray{T3,N},
mps_b = MPSMatrix(b)
mps_c = MPSMatrix(c)

mat_mul_kernel = MPSMatrixMultiplication(current_device(),
mat_mul_kernel = MPSMatrixMultiplication(device(),
transpose_b, transpose_a,
rows_c, cols_c, cols_a,
alpha, beta)


# Encode and commit matmul kernel
cmdbuf = MTLCommandBuffer(global_queue(current_device()))
cmdbuf = MTLCommandBuffer(global_queue(device()))
encode!(cmdbuf, mat_mul_kernel, mps_b, mps_a, mps_c)
commit!(cmdbuf)

Expand All @@ -253,11 +253,11 @@ export MPSMatrixFindTopK, topk, topk!
@autoproperty sourceRows::NSInteger setter=setSourceRows
end

function MPSMatrixFindTopK(device, numberOfTopKValues)
function MPSMatrixFindTopK(dev, numberOfTopKValues)
kernel = @objc [MPSMatrixFindTopK alloc]::id{MPSMatrixFindTopK}
obj = MPSMatrixFindTopK(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixFindTopK} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixFindTopK} initWithDevice:dev::id{MTLDevice}
numberOfTopKValues:numberOfTopKValues::NSUInteger]::id{MPSMatrixFindTopK}
return obj
end
Expand Down Expand Up @@ -299,11 +299,11 @@ end
mps_i = MPSMatrix(I)
mps_v = MPSMatrix(V)

topk_kernel = MPSMatrixFindTopK(current_device(), k)
topk_kernel = MPSMatrixFindTopK(device(), k)
topk_kernel.indexOffset = 1

# Encode and commit topk kernel
cmdbuf = MTLCommandBuffer(global_queue(current_device()))
cmdbuf = MTLCommandBuffer(global_queue(device()))
encode!(cmdbuf, topk_kernel, mps_a, mps_i, mps_v)
commit!(cmdbuf)

Expand Down Expand Up @@ -343,11 +343,11 @@ end

for f in (:MPSMatrixSoftMax, :MPSMatrixLogSoftMax)
@eval begin
function $(f)(device)
function $(f)(dev)
kernel = @objc [$(f) alloc]::id{$(f)}
obj = $(f)(kernel)
finalizer(release, obj)
@objc [obj::id{$(f)} initWithDevice:device::id{MTLDevice}]::id{$(f)}
@objc [obj::id{$(f)} initWithDevice:dev::id{MTLDevice}]::id{$(f)}
return obj
end

Expand Down
8 changes: 4 additions & 4 deletions lib/mps/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ export MPSMatrixVectorMultiplication, matvecmul!

@objcwrapper immutable=false MPSMatrixVectorMultiplication <: MPSMatrixBinaryKernel

function MPSMatrixVectorMultiplication(device, transpose, rows, columns, alpha, beta)
function MPSMatrixVectorMultiplication(dev, transpose, rows, columns, alpha, beta)
kernel = @objc [MPSMatrixVectorMultiplication alloc]::id{MPSMatrixVectorMultiplication}
obj = MPSMatrixVectorMultiplication(kernel)
finalizer(release, obj)
@objc [obj::id{MPSMatrixVectorMultiplication} initWithDevice:device::id{MTLDevice}
@objc [obj::id{MPSMatrixVectorMultiplication} initWithDevice:dev::id{MTLDevice}
transpose:transpose::Bool
rows:rows::NSUInteger
columns:columns::NSUInteger
Expand Down Expand Up @@ -129,12 +129,12 @@ function matvecmul!(c::MtlVector, a::MtlMatrix, b::MtlVector, alpha::Number=true
mps_b = MPSVector(b)
mps_c = MPSVector(c)

matvec_mul_kernel = MPSMatrixVectorMultiplication(current_device(), !transpose,
matvec_mul_kernel = MPSMatrixVectorMultiplication(device(), !transpose,
rows_c, cols_a,
alpha, beta)

# Encode and commit matmul kernel
cmdbuf = MTLCommandBuffer(global_queue(current_device()))
cmdbuf = MTLCommandBuffer(global_queue(device()))
encode!(cmdbuf, matvec_mul_kernel, mps_a, mps_b, mps_c)
commit!(cmdbuf)

Expand Down
4 changes: 2 additions & 2 deletions lib/mtl/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ export supports_family, is_m3, is_m2, is_m1
MTLGPUFamilyMac2 = 2002 # Mac family 2 GPU features
end

function supports_family(device::MTLDevice, gpufamily::MTLGPUFamily)
@objc [device::MTLDevice supportsFamily:gpufamily::MTLGPUFamily]::Bool
function supports_family(dev::MTLDevice, gpufamily::MTLGPUFamily)
@objc [dev::MTLDevice supportsFamily:gpufamily::MTLGPUFamily]::Bool
end

is_m1(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple7) &&
Expand Down
14 changes: 7 additions & 7 deletions lib/mtl/library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ export MTLLibrary, MTLLibraryFromFile, MTLLibraryFromData
@autoproperty functionNames::id{NSArray} type=Vector{NSString}
end

function MTLLibrary(device::MTLDevice, src::String,
function MTLLibrary(dev::MTLDevice, src::String,
opts::MTLCompileOptions=MTLCompileOptions())
err = Ref{id{NSError}}(nil)
handle = @objc [device::id{MTLDevice} newLibraryWithSource:src::id{NSString}
handle = @objc [dev::id{MTLDevice} newLibraryWithSource:src::id{NSString}
options:opts::id{MTLCompileOptions}
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
err[] == nil || throw(NSError(err[]))
Expand All @@ -21,14 +21,14 @@ function MTLLibrary(device::MTLDevice, src::String,
return obj
end

function MTLLibraryFromFile(device::MTLDevice, path::String)
function MTLLibraryFromFile(dev::MTLDevice, path::String)
err = Ref{id{NSError}}(nil)
handle = if macos_version() >= v"13"
url = NSFileURL(path)
@objc [device::id{MTLDevice} newLibraryWithURL:url::id{NSURL}
@objc [dev::id{MTLDevice} newLibraryWithURL:url::id{NSURL}
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
else
@objc [device::id{MTLDevice} newLibraryWithFile:path::id{NSString}
@objc [dev::id{MTLDevice} newLibraryWithFile:path::id{NSString}
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
end
err[] == nil || throw(NSError(err[]))
Expand All @@ -38,11 +38,11 @@ function MTLLibraryFromFile(device::MTLDevice, path::String)
return obj
end

function MTLLibraryFromData(device::MTLDevice, input_data)
function MTLLibraryFromData(dev::MTLDevice, input_data)
err = Ref{id{NSError}}(nil)
GC.@preserve input_data begin
data = dispatch_data(pointer(input_data), sizeof(input_data))
handle = @objc [device::id{MTLDevice} newLibraryWithData:data::dispatch_data_t
handle = @objc [dev::id{MTLDevice} newLibraryWithData:data::dispatch_data_t
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
end
err[] == nil || throw(NSError(err[]))
Expand Down
8 changes: 4 additions & 4 deletions lib/mtl/texture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ end
## bitwise operations lose type information, so allow conversions
Base.convert(::Type{MTLPixelFormat}, x::Integer) = MTLPixelFormat(x)

function minimumLinearTextureAlignmentForPixelFormat(device, format)
return @objc [device::MTLDevice minimumLinearTextureAlignmentForPixelFormat:format::MTLPixelFormat]::NSUInteger
function minimumLinearTextureAlignmentForPixelFormat(dev, format)
return @objc [dev::MTLDevice minimumLinearTextureAlignmentForPixelFormat:format::MTLPixelFormat]::NSUInteger
end

@cenum MTLTextureUsage::NSUInteger begin
Expand Down Expand Up @@ -247,8 +247,8 @@ function MTLTexture(buffer, descriptor, offset, bytesPerRow)
return obj
end

function MTLTexture(device, descriptor)
texture = @objc [device::id{MTLDevice} newTextureWithDescriptor:descriptor::id{MTLTextureDescriptor}]::id{MTLTexture}
function MTLTexture(dev, descriptor)
texture = @objc [dev::id{MTLDevice} newTextureWithDescriptor:descriptor::id{MTLTextureDescriptor}]::id{MTLTexture}
obj = MTLTexture(texture)
finalizer(release, obj)

Expand Down
2 changes: 2 additions & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ export MetalBackend
include("../ext/BFloat16sExt.jl")
end

include("deprecated.jl")

include("precompile.jl")

end # module
6 changes: 3 additions & 3 deletions src/array.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# host array

export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl, is_shared, is_managed, is_private, device
export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl, is_shared, is_managed, is_private

function hasfieldcount(@nospecialize(dt))
try
Expand Down Expand Up @@ -61,7 +61,7 @@ mutable struct MtlArray{T,N,S} <: AbstractGPUArray{T,N}
maxsize
end

dev = current_device()
dev = device()
if bufsize == 0
# Metal doesn't support empty allocations. For simplicity (i.e., the ability to get
# a pointer, query the buffer's properties, etc), we use a 1-byte buffer instead.
Expand Down Expand Up @@ -557,7 +557,7 @@ function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPtr{T}, dims; own=false)
end

function Base.unsafe_wrap(A::Type{<:MtlArray{T,N}}, arr::Array, dims=size(arr);
dev=current_device(), kwargs...) where {T,N}
dev=device(), kwargs...) where {T,N}
GC.@preserve arr begin
buf = MTLBuffer(dev, prod(dims) * sizeof(T), pointer(arr); nocopy=true, kwargs...)
return A(buf, Dims(dims))
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ end
@signpost_event log=log_compiler() "Link" "Job=$job"

@signpost_interval log=log_compiler() "Instantiate compute pipeline" begin
dev = current_device()
dev = device()
lib = MTLLibraryFromData(dev, compiled.image)
fun = MTLFunction(lib, compiled.entry)
pipeline_state = try
Expand Down
Loading

0 comments on commit e3a8ed2

Please sign in to comment.