From e3a8ed29e0a665cbc1126fb441b2a1bfd388e01b Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Tue, 18 Jun 2024 16:31:43 -0300 Subject: [PATCH] Replace `current_device()` with `device()` (#366) Co-authored-by: Tim Besard --- bin/metallib-load | 2 +- docs/src/api/essentials.md | 2 +- docs/src/faq/contributing.md | 2 +- lib/mps/copy.jl | 4 ++-- lib/mps/decomposition.jl | 8 ++++---- lib/mps/images.jl | 16 ++++++++-------- lib/mps/linalg.jl | 10 +++++----- lib/mps/matrix.jl | 20 ++++++++++---------- lib/mps/vector.jl | 8 ++++---- lib/mtl/device.jl | 4 ++-- lib/mtl/library.jl | 14 +++++++------- lib/mtl/texture.jl | 8 ++++---- src/Metal.jl | 2 ++ src/array.jl | 6 +++--- src/compiler/compilation.jl | 2 +- src/compiler/execution.jl | 4 ++-- src/compiler/reflection.jl | 8 ++++---- src/deprecated.jl | 3 +++ src/gpuarrays.jl | 5 +---- src/mapreduce.jl | 2 +- src/state.jl | 8 ++++---- src/utilities.jl | 2 +- test/array.jl | 2 +- test/capturing.jl | 10 +++++----- test/execution.jl | 2 +- test/metal.jl | 4 ++-- test/mps/copy.jl | 6 +++--- test/mps/linalg.jl | 6 +++--- test/mps/matrix.jl | 6 +++--- test/mps/vector.jl | 4 ++-- test/profiling.jl | 2 +- 31 files changed, 92 insertions(+), 90 deletions(-) create mode 100644 src/deprecated.jl diff --git a/bin/metallib-load b/bin/metallib-load index 3e72ba399..c7b6b83de 100755 --- a/bin/metallib-load +++ b/bin/metallib-load @@ -47,7 +47,7 @@ function main(args) end isempty(metallib) && error("Empty input") - dev = current_device() + dev = device() verbose && println("Using device: ", dev.name) lib = try diff --git a/docs/src/api/essentials.md b/docs/src/api/essentials.md index 87eef7b96..d1fc0da6f 100644 --- a/docs/src/api/essentials.md +++ b/docs/src/api/essentials.md @@ -6,7 +6,7 @@ ```@docs device! devices -current_device +device global_queue synchronize device_synchronize diff --git a/docs/src/faq/contributing.md b/docs/src/faq/contributing.md index 29c45fe85..2f17b21a9 100644 --- a/docs/src/faq/contributing.md +++ b/docs/src/faq/contributing.md @@ -95,7 +95,7 @@ There are varying degrees of user-facing interfaces from Metal.jl. At the lowest `Metal.MTL.xxx`. This is for low-level functionality close to or at bare Objective-C, or things that a normal user wouldn't directly be using. `Metal.MPS.xxx` is for Metal Performance Shader specifics (like `MPSMatrix`). -Next, is `Metal.xxx`. This is for higher-level, usually pure-Julian functionality (like `current_device()`). +Next, is `Metal.xxx`. This is for higher-level, usually pure-Julian functionality (like `device()`). The only thing beyond this is exporting into the global namespace. That would be useful for uniquely-named functions/structures/macros with clear and common use-cases (`MtlArray` or `@metal`). diff --git a/lib/mps/copy.jl b/lib/mps/copy.jl index fd13323fc..1a6b1893d 100644 --- a/lib/mps/copy.jl +++ b/lib/mps/copy.jl @@ -32,11 +32,11 @@ export MPSMatrixCopy @autoproperty destinationsAreTransposed::Bool end -function MPSMatrixCopy(device, copyRows, copyColumns, sourcesAreTransposed, destinationsAreTransposed) +function MPSMatrixCopy(dev, copyRows, copyColumns, sourcesAreTransposed, destinationsAreTransposed) kernel = @objc [MPSMatrixCopy alloc]::id{MPSMatrixCopy} obj = MPSMatrixCopy(kernel) finalizer(release, obj) - @objc [obj::id{MPSMatrixCopy} initWithDevice:device::id{MTLDevice} + @objc [obj::id{MPSMatrixCopy} initWithDevice:dev::id{MTLDevice} copyRows:copyRows::NSUInteger copyColumns:copyColumns::NSUInteger sourcesAreTransposed:sourcesAreTransposed::Bool diff --git a/lib/mps/decomposition.jl b/lib/mps/decomposition.jl index 8796367ff..8bf4f67ce 100644 --- a/lib/mps/decomposition.jl +++ b/lib/mps/decomposition.jl @@ -12,11 +12,11 @@ export MPSMatrixDecompositionLU @objcwrapper immutable=false MPSMatrixDecompositionLU <: MPSMatrixUnaryKernel -function MPSMatrixDecompositionLU(device, rows, columns) +function MPSMatrixDecompositionLU(dev, rows, columns) kernel = @objc [MPSMatrixDecompositionLU alloc]::id{MPSMatrixDecompositionLU} obj = MPSMatrixDecompositionLU(kernel) finalizer(release, obj) - @objc [obj::id{MPSMatrixDecompositionLU} initWithDevice:device::id{MTLDevice} + @objc [obj::id{MPSMatrixDecompositionLU} initWithDevice:dev::id{MTLDevice} rows:rows::NSUInteger columns:columns::NSUInteger]::id{MPSMatrixDecompositionLU} return obj @@ -37,11 +37,11 @@ export MPSMatrixDecompositionCholesky @objcwrapper immutable=false MPSMatrixDecompositionCholesky <: MPSMatrixUnaryKernel -function MPSMatrixDecompositionCholesky(device, lower, order) +function MPSMatrixDecompositionCholesky(dev, lower, order) kernel = @objc [MPSMatrixDecompositionCholesky alloc]::id{MPSMatrixDecompositionCholesky} obj = MPSMatrixDecompositionCholesky(kernel) finalizer(release, obj) - @objc [obj::id{MPSMatrixDecompositionCholesky} initWithDevice:device::id{MTLDevice} + @objc [obj::id{MPSMatrixDecompositionCholesky} initWithDevice:dev::id{MTLDevice} lower:lower::Bool order:order::NSUInteger]::id{MPSMatrixDecompositionCholesky} return obj diff --git a/lib/mps/images.jl b/lib/mps/images.jl index 5e5df3dd1..c1b6d6f17 100644 --- a/lib/mps/images.jl +++ b/lib/mps/images.jl @@ -39,11 +39,11 @@ end @objcwrapper immutable=false MPSImageGaussianBlur <: MPSUnaryImageKernel -function MPSImageGaussianBlur(device, sigma) +function MPSImageGaussianBlur(dev, sigma) kernel = @objc [MPSImageGaussianBlur alloc]::id{MPSImageGaussianBlur} obj = MPSImageGaussianBlur(kernel) finalizer(release, obj) - @objc [obj::id{MPSImageGaussianBlur} initWithDevice:device::id{MTLDevice} + @objc [obj::id{MPSImageGaussianBlur} initWithDevice:dev::id{MTLDevice} sigma:sigma::Float32]::id{MPSImageGaussianBlur} return obj end @@ -51,11 +51,11 @@ end @objcwrapper immutable=false MPSImageBox <: MPSUnaryImageKernel -function MPSImageBox(device, kernelWidth, kernelHeight) +function MPSImageBox(dev, kernelWidth, kernelHeight) kernel = @objc [MPSImageBox alloc]::id{MPSImageBox} obj = MPSImageBox(kernel) finalizer(release, obj) - @objc [obj::id{MPSImageBox} initWithDevice:device::id{MTLDevice} + @objc [obj::id{MPSImageBox} initWithDevice:dev::id{MTLDevice} kernelWidth:kernelWidth::Int kernelHeight:kernelHeight::Int]::id{MPSImageBox} return obj @@ -69,7 +69,7 @@ function blur(image, kernel; pixelFormat=MTL.MTLPixelFormatRGBA8Unorm) w,h = size(image) - alignment = MTL.minimumLinearTextureAlignmentForPixelFormat(current_device(), pixelFormat) + alignment = MTL.minimumLinearTextureAlignmentForPixelFormat(device(), pixelFormat) preBytesPerRow = sizeof(eltype(image))*w rowoffset = alignment - (preBytesPerRow - 1) % alignment - 1 @@ -83,7 +83,7 @@ function blur(image, kernel; pixelFormat=MTL.MTLPixelFormatRGBA8Unorm) textDesc2.usage = MTL.MTLTextureUsageShaderRead | MTL.MTLTextureUsageShaderWrite text2 = MTL.MTLTexture(res.data.rc.obj, textDesc2, 0, bytesPerRow) - cmdbuf = MTLCommandBuffer(global_queue(current_device())) + cmdbuf = MTLCommandBuffer(global_queue(device())) encode!(cmdbuf, kernel, text1, text2) commit!(cmdbuf) @@ -91,11 +91,11 @@ function blur(image, kernel; pixelFormat=MTL.MTLPixelFormatRGBA8Unorm) end function gaussianblur(image; sigma, pixelFormat=MTL.MTLPixelFormatRGBA8Unorm) - kernel = MPSImageGaussianBlur(current_device(), sigma) + kernel = MPSImageGaussianBlur(device(), sigma) return blur(image, kernel; pixelFormat) end function boxblur(image, kernelWidth, kernelHeight; pixelFormat=MTL.MTLPixelFormatRGBA8Unorm) - kernel = MPSImageBox(current_device(), kernelWidth, kernelHeight) + kernel = MPSImageBox(device(), kernelWidth, kernelHeight) return blur(image, kernel; pixelFormat) end diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index 03b801b2b..30d7d5de5 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -61,7 +61,7 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri typC = eltype(C) # If possible, dispatch to performance shaders - if is_supported(current_device()) && + if is_supported(device()) && typA == typB && (typA, typC) in MPS_VALID_MATMUL_TYPES matmul!(C, A, B, alpha, beta, transA, transB) else @@ -131,7 +131,7 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B typC = eltype(C) # If possible, dispatch to performance shaders - if is_supported(current_device()) && + if is_supported(device()) && typA == typB && (typA, typC) in MPS_VALID_MATVECMUL_TYPES matvecmul!(C, A, B, alpha, beta, transA) else @@ -177,7 +177,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T = @autoreleasepool function LinearAlgebra.lu(A::MtlMatrix{T}; check::Bool=true) where {T<:MtlFloat} M,N = size(A) - dev = current_device() + dev = device() queue = global_queue(dev) At = MtlMatrix{T,Private}(undef, (N, M)) @@ -235,7 +235,7 @@ end check::Bool=true, allowsingular::Bool=false) where {T<:MtlFloat} M,N = size(A) - dev = current_device() + dev = device() queue = global_queue(dev) At = MtlMatrix{T,Private}(undef, (N, M)) @@ -278,7 +278,7 @@ end axes(B,2) == axes(A,1) && axes(B,1) == axes(A,2) || throw(DimensionMismatch("transpose")) M,N = size(A) - dev = current_device() + dev = device() queue = global_queue(dev) cmdbuf = MTLCommandBuffer(queue) diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl index 725817702..7e1604595 100644 --- a/lib/mps/matrix.jl +++ b/lib/mps/matrix.jl @@ -179,12 +179,12 @@ export MPSMatrixMultiplication, matmul! @autoproperty batchStart::NSUInteger setter=setBatchStart end -function MPSMatrixMultiplication(device, transposeLeft, transposeRight, resultRows, +function MPSMatrixMultiplication(dev, transposeLeft, transposeRight, resultRows, resultColumns, interiorColumns, alpha, beta) kernel = @objc [MPSMatrixMultiplication alloc]::id{MPSMatrixMultiplication} obj = MPSMatrixMultiplication(kernel) finalizer(release, obj) - @objc [obj::id{MPSMatrixMultiplication} initWithDevice:device::id{MTLDevice} + @objc [obj::id{MPSMatrixMultiplication} initWithDevice:dev::id{MTLDevice} transposeLeft:transposeLeft::Bool transposeRight:transposeRight::Bool resultRows:resultRows::NSUInteger @@ -225,14 +225,14 @@ function matmul!(c::MtlArray{T1,N}, a::MtlArray{T2,N}, b::MtlArray{T3,N}, mps_b = MPSMatrix(b) mps_c = MPSMatrix(c) - mat_mul_kernel = MPSMatrixMultiplication(current_device(), + mat_mul_kernel = MPSMatrixMultiplication(device(), transpose_b, transpose_a, rows_c, cols_c, cols_a, alpha, beta) # Encode and commit matmul kernel - cmdbuf = MTLCommandBuffer(global_queue(current_device())) + cmdbuf = MTLCommandBuffer(global_queue(device())) encode!(cmdbuf, mat_mul_kernel, mps_b, mps_a, mps_c) commit!(cmdbuf) @@ -253,11 +253,11 @@ export MPSMatrixFindTopK, topk, topk! @autoproperty sourceRows::NSInteger setter=setSourceRows end -function MPSMatrixFindTopK(device, numberOfTopKValues) +function MPSMatrixFindTopK(dev, numberOfTopKValues) kernel = @objc [MPSMatrixFindTopK alloc]::id{MPSMatrixFindTopK} obj = MPSMatrixFindTopK(kernel) finalizer(release, obj) - @objc [obj::id{MPSMatrixFindTopK} initWithDevice:device::id{MTLDevice} + @objc [obj::id{MPSMatrixFindTopK} initWithDevice:dev::id{MTLDevice} numberOfTopKValues:numberOfTopKValues::NSUInteger]::id{MPSMatrixFindTopK} return obj end @@ -299,11 +299,11 @@ end mps_i = MPSMatrix(I) mps_v = MPSMatrix(V) - topk_kernel = MPSMatrixFindTopK(current_device(), k) + topk_kernel = MPSMatrixFindTopK(device(), k) topk_kernel.indexOffset = 1 # Encode and commit topk kernel - cmdbuf = MTLCommandBuffer(global_queue(current_device())) + cmdbuf = MTLCommandBuffer(global_queue(device())) encode!(cmdbuf, topk_kernel, mps_a, mps_i, mps_v) commit!(cmdbuf) @@ -343,11 +343,11 @@ end for f in (:MPSMatrixSoftMax, :MPSMatrixLogSoftMax) @eval begin - function $(f)(device) + function $(f)(dev) kernel = @objc [$(f) alloc]::id{$(f)} obj = $(f)(kernel) finalizer(release, obj) - @objc [obj::id{$(f)} initWithDevice:device::id{MTLDevice}]::id{$(f)} + @objc [obj::id{$(f)} initWithDevice:dev::id{MTLDevice}]::id{$(f)} return obj end diff --git a/lib/mps/vector.jl b/lib/mps/vector.jl index 66693e1d3..4de641dae 100644 --- a/lib/mps/vector.jl +++ b/lib/mps/vector.jl @@ -89,11 +89,11 @@ export MPSMatrixVectorMultiplication, matvecmul! @objcwrapper immutable=false MPSMatrixVectorMultiplication <: MPSMatrixBinaryKernel -function MPSMatrixVectorMultiplication(device, transpose, rows, columns, alpha, beta) +function MPSMatrixVectorMultiplication(dev, transpose, rows, columns, alpha, beta) kernel = @objc [MPSMatrixVectorMultiplication alloc]::id{MPSMatrixVectorMultiplication} obj = MPSMatrixVectorMultiplication(kernel) finalizer(release, obj) - @objc [obj::id{MPSMatrixVectorMultiplication} initWithDevice:device::id{MTLDevice} + @objc [obj::id{MPSMatrixVectorMultiplication} initWithDevice:dev::id{MTLDevice} transpose:transpose::Bool rows:rows::NSUInteger columns:columns::NSUInteger @@ -129,12 +129,12 @@ function matvecmul!(c::MtlVector, a::MtlMatrix, b::MtlVector, alpha::Number=true mps_b = MPSVector(b) mps_c = MPSVector(c) - matvec_mul_kernel = MPSMatrixVectorMultiplication(current_device(), !transpose, + matvec_mul_kernel = MPSMatrixVectorMultiplication(device(), !transpose, rows_c, cols_a, alpha, beta) # Encode and commit matmul kernel - cmdbuf = MTLCommandBuffer(global_queue(current_device())) + cmdbuf = MTLCommandBuffer(global_queue(device())) encode!(cmdbuf, matvec_mul_kernel, mps_a, mps_b, mps_c) commit!(cmdbuf) diff --git a/lib/mtl/device.jl b/lib/mtl/device.jl index a6b2b6d55..cbe29f3e4 100644 --- a/lib/mtl/device.jl +++ b/lib/mtl/device.jl @@ -113,8 +113,8 @@ export supports_family, is_m3, is_m2, is_m1 MTLGPUFamilyMac2 = 2002 # Mac family 2 GPU features end -function supports_family(device::MTLDevice, gpufamily::MTLGPUFamily) - @objc [device::MTLDevice supportsFamily:gpufamily::MTLGPUFamily]::Bool +function supports_family(dev::MTLDevice, gpufamily::MTLGPUFamily) + @objc [dev::MTLDevice supportsFamily:gpufamily::MTLGPUFamily]::Bool end is_m1(dev::MTLDevice) = supports_family(dev, MTLGPUFamilyApple7) && diff --git a/lib/mtl/library.jl b/lib/mtl/library.jl index e83feaaa8..f9037b535 100644 --- a/lib/mtl/library.jl +++ b/lib/mtl/library.jl @@ -8,10 +8,10 @@ export MTLLibrary, MTLLibraryFromFile, MTLLibraryFromData @autoproperty functionNames::id{NSArray} type=Vector{NSString} end -function MTLLibrary(device::MTLDevice, src::String, +function MTLLibrary(dev::MTLDevice, src::String, opts::MTLCompileOptions=MTLCompileOptions()) err = Ref{id{NSError}}(nil) - handle = @objc [device::id{MTLDevice} newLibraryWithSource:src::id{NSString} + handle = @objc [dev::id{MTLDevice} newLibraryWithSource:src::id{NSString} options:opts::id{MTLCompileOptions} error:err::Ptr{id{NSError}}]::id{MTLLibrary} err[] == nil || throw(NSError(err[])) @@ -21,14 +21,14 @@ function MTLLibrary(device::MTLDevice, src::String, return obj end -function MTLLibraryFromFile(device::MTLDevice, path::String) +function MTLLibraryFromFile(dev::MTLDevice, path::String) err = Ref{id{NSError}}(nil) handle = if macos_version() >= v"13" url = NSFileURL(path) - @objc [device::id{MTLDevice} newLibraryWithURL:url::id{NSURL} + @objc [dev::id{MTLDevice} newLibraryWithURL:url::id{NSURL} error:err::Ptr{id{NSError}}]::id{MTLLibrary} else - @objc [device::id{MTLDevice} newLibraryWithFile:path::id{NSString} + @objc [dev::id{MTLDevice} newLibraryWithFile:path::id{NSString} error:err::Ptr{id{NSError}}]::id{MTLLibrary} end err[] == nil || throw(NSError(err[])) @@ -38,11 +38,11 @@ function MTLLibraryFromFile(device::MTLDevice, path::String) return obj end -function MTLLibraryFromData(device::MTLDevice, input_data) +function MTLLibraryFromData(dev::MTLDevice, input_data) err = Ref{id{NSError}}(nil) GC.@preserve input_data begin data = dispatch_data(pointer(input_data), sizeof(input_data)) - handle = @objc [device::id{MTLDevice} newLibraryWithData:data::dispatch_data_t + handle = @objc [dev::id{MTLDevice} newLibraryWithData:data::dispatch_data_t error:err::Ptr{id{NSError}}]::id{MTLLibrary} end err[] == nil || throw(NSError(err[])) diff --git a/lib/mtl/texture.jl b/lib/mtl/texture.jl index a51aa342d..de2e35726 100644 --- a/lib/mtl/texture.jl +++ b/lib/mtl/texture.jl @@ -197,8 +197,8 @@ end ## bitwise operations lose type information, so allow conversions Base.convert(::Type{MTLPixelFormat}, x::Integer) = MTLPixelFormat(x) -function minimumLinearTextureAlignmentForPixelFormat(device, format) - return @objc [device::MTLDevice minimumLinearTextureAlignmentForPixelFormat:format::MTLPixelFormat]::NSUInteger +function minimumLinearTextureAlignmentForPixelFormat(dev, format) + return @objc [dev::MTLDevice minimumLinearTextureAlignmentForPixelFormat:format::MTLPixelFormat]::NSUInteger end @cenum MTLTextureUsage::NSUInteger begin @@ -247,8 +247,8 @@ function MTLTexture(buffer, descriptor, offset, bytesPerRow) return obj end -function MTLTexture(device, descriptor) - texture = @objc [device::id{MTLDevice} newTextureWithDescriptor:descriptor::id{MTLTextureDescriptor}]::id{MTLTexture} +function MTLTexture(dev, descriptor) + texture = @objc [dev::id{MTLDevice} newTextureWithDescriptor:descriptor::id{MTLTextureDescriptor}]::id{MTLTexture} obj = MTLTexture(texture) finalizer(release, obj) diff --git a/src/Metal.jl b/src/Metal.jl index 0796d29e3..f05765810 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -71,6 +71,8 @@ export MetalBackend include("../ext/BFloat16sExt.jl") end +include("deprecated.jl") + include("precompile.jl") end # module diff --git a/src/array.jl b/src/array.jl index 6b1bb2c8c..fe36340de 100644 --- a/src/array.jl +++ b/src/array.jl @@ -1,6 +1,6 @@ # host array -export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl, is_shared, is_managed, is_private, device +export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl, is_shared, is_managed, is_private function hasfieldcount(@nospecialize(dt)) try @@ -61,7 +61,7 @@ mutable struct MtlArray{T,N,S} <: AbstractGPUArray{T,N} maxsize end - dev = current_device() + dev = device() if bufsize == 0 # Metal doesn't support empty allocations. For simplicity (i.e., the ability to get # a pointer, query the buffer's properties, etc), we use a 1-byte buffer instead. @@ -557,7 +557,7 @@ function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPtr{T}, dims; own=false) end function Base.unsafe_wrap(A::Type{<:MtlArray{T,N}}, arr::Array, dims=size(arr); - dev=current_device(), kwargs...) where {T,N} + dev=device(), kwargs...) where {T,N} GC.@preserve arr begin buf = MTLBuffer(dev, prod(dims) * sizeof(T), pointer(arr); nocopy=true, kwargs...) return A(buf, Dims(dims)) diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl index b8d0e1e02..5afd4867d 100644 --- a/src/compiler/compilation.jl +++ b/src/compiler/compilation.jl @@ -169,7 +169,7 @@ end @signpost_event log=log_compiler() "Link" "Job=$job" @signpost_interval log=log_compiler() "Instantiate compute pipeline" begin - dev = current_device() + dev = device() lib = MTLLibraryFromData(dev, compiled.image) fun = MTLFunction(lib, compiled.entry) pipeline_state = try diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 37f89c2f2..4808c6224 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -176,7 +176,7 @@ in a hot path without degrading performance. New code will be generated automati the function changes, or when different types or keyword arguments are provided. """ function mtlfunction(f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where {F,TT} - dev = current_device() + dev = device() Base.@lock mtlfunction_lock begin # compile the function cache = compiler_cache(dev) @@ -260,7 +260,7 @@ end end @autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1, - queue=global_queue(current_device())) + queue=global_queue(device())) groups = MTLSize(groups) threads = MTLSize(threads) (groups.width>0 && groups.height>0 && groups.depth>0) || diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 65ec12601..e4b6f299b 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -31,7 +31,7 @@ function code_agx(io::IO, @nospecialize(func), @nospecialize(types), kernel::Bool=true; kwargs...) compiler_kwargs, kwargs = split_kwargs_runtime(kwargs, COMPILER_KWARGS) source = methodinstance(typeof(func), Base.to_tuple_type(types)) - config = compiler_config(current_device(); kernel, compiler_kwargs...) + config = compiler_config(device(); kernel, compiler_kwargs...) job = CompilerJob(source, config) code_agx(io, job) end @@ -55,7 +55,7 @@ end # create a binary archive bin_desc = MTLBinaryArchiveDescriptor() - bin = MTLBinaryArchive(current_device(), bin_desc) + bin = MTLBinaryArchive(device(), bin_desc) add_functions!(bin, pipeline_desc) mktempdir() do dir @@ -172,7 +172,7 @@ for method in (:code_typed, :code_warntype, :code_llvm) kernel::Bool=false, kwargs...) compiler_kwargs, kwargs = split_kwargs_runtime(kwargs, COMPILER_KWARGS) source = methodinstance(typeof(func), Base.to_tuple_type(types)) - config = compiler_config(current_device(); kernel, compiler_kwargs...) + config = compiler_config(device(); kernel, compiler_kwargs...) job = CompilerJob(source, config) GPUCompiler.$method($(args...); kwargs...) end @@ -226,7 +226,7 @@ Return a type `r` such that `f(args...)::r` where `args::tt`. """ function return_type(@nospecialize(func), @nospecialize(tt)) source = methodinstance(typeof(func), tt) - config = compiler_config(current_device()) + config = compiler_config(device()) job = CompilerJob(source, config) interp = GPUCompiler.get_interpreter(job) sig = Base.signature_type(func, tt) diff --git a/src/deprecated.jl b/src/deprecated.jl new file mode 100644 index 000000000..4c07dc090 --- /dev/null +++ b/src/deprecated.jl @@ -0,0 +1,3 @@ +export current_device + +@deprecate current_device() device() diff --git a/src/gpuarrays.jl b/src/gpuarrays.jl index 3f688ae1f..7ae6d0b78 100644 --- a/src/gpuarrays.jl +++ b/src/gpuarrays.jl @@ -1,8 +1,5 @@ ## GPUArrays interfaces -GPUArrays.device(x::MtlArray) = x.dev - - ## execution struct mtlArrayBackend <: AbstractGPUBackend end @@ -66,7 +63,7 @@ GPUArrays.backend(::Type{<:MtlArray}) = mtlArrayBackend() const GLOBAL_RNGs = Dict{MTLDevice,GPUArrays.RNG}() function GPUArrays.default_rng(::Type{<:MtlArray}) - dev = current_device() + dev = device() get!(GLOBAL_RNGs, dev) do N = dev.maxThreadsPerThreadgroup.width state = MtlArray{NTuple{4, UInt32}}(undef, N) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 85db94442..f5b823b45 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -179,7 +179,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T}, grain = contiguous ? prevpow(2, cld(16, sizeof(T))) : 1 # the maximum number of threads is limited by the hardware - dev = current_device() + dev = device() maxthreads = min(Int(dev.maxThreadsPerThreadgroup.width), Int(dev.maxThreadgroupMemoryLength) รท sizeof(T)) diff --git a/src/state.jl b/src/state.jl index edf88c18d..2208f739e 100644 --- a/src/state.jl +++ b/src/state.jl @@ -1,4 +1,4 @@ -export current_device, device!, global_queue, synchronize, device_synchronize +export device, device!, global_queue, synchronize, device_synchronize log_compiler() = OSLog("org.juliagpu.metal", "Compiler") log_compiler(args...) = log_compiler()(args...) @@ -6,14 +6,14 @@ log_array() = OSLog("org.juliagpu.metal", "Array") log_array(args...) = log_array()(args...) """ - current_device()::MTLDevice + device()::MTLDevice Return the Metal GPU device associated with the current Julia task. Since all M-series systems currently only externally show a single GPU, this function effectively returns the only system GPU. """ -function current_device() +function device() get!(task_local_storage(), :MTLDevice) do dev = MTLDevice(1) if !supports_family(dev, MTL.MTLGPUFamilyApple7) @@ -65,7 +65,7 @@ Create a new MTLCommandBuffer from the global command queue, commit it to the qu and simply wait for it to be completed. Since command buffers *should* execute in a First-In-First-Out manner, this synchronizes the GPU. """ -@autoreleasepool function synchronize(queue::MTLCommandQueue=global_queue(current_device())) +@autoreleasepool function synchronize(queue::MTLCommandQueue=global_queue(device())) cmdbuf = MTLCommandBuffer(queue) commit!(cmdbuf) wait_completed(cmdbuf) diff --git a/src/utilities.jl b/src/utilities.jl index fdf6c1c81..09de8e1ca 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -90,7 +90,7 @@ function capture_dir() end function captured(f; dest=MTL.MTLCaptureDestinationGPUTraceDocument, - object=global_queue(current_device())) + object=global_queue(device())) if !haskey(ENV, "METAL_CAPTURE_ENABLED") || ENV["METAL_CAPTURE_ENABLED"] != "1" @warn """Environment variable 'METAL_CAPTURE_ENABLED' is not set. In most cases, this will need to be set to 1 before launching Julia to enable GPU frame capture.""" diff --git a/test/array.jl b/test/array.jl index 5cd5a3bcb..e0ffbdf23 100644 --- a/test/array.jl +++ b/test/array.jl @@ -14,7 +14,7 @@ end @testset "constructors" begin xs = MtlArray{Int8}(undef, 2, 3) - @test device(xs) == current_device() + @test device(xs) == device() @test Base.elsize(xs) == sizeof(Int8) @test xs.data[].length == 6 xs2 = MtlArray{Int8, 2}(xs) diff --git a/test/capturing.jl b/test/capturing.jl index b3c657221..91df77673 100644 --- a/test/capturing.jl +++ b/test/capturing.jl @@ -20,10 +20,10 @@ manager = MTLCaptureManager() desc = MTLCaptureDescriptor() # Capture Object @test desc.captureObject == nothing -cmdq = global_queue(current_device()) +cmdq = global_queue(device()) desc.captureObject = cmdq @test desc.captureObject == cmdq -dev = current_device() +dev = device() desc.captureObject = dev @test desc.captureObject == dev @@ -39,12 +39,12 @@ desc.outputURL = NSFileURL(path) @test desc.outputURL == NSFileURL(path) # Capture Scope -queue = MTLCommandQueue(current_device()) +queue = MTLCommandQueue(device()) default_scope = manager.defaultCaptureScope @test default_scope == nothing new_scope = MTLCaptureScope(@objc [manager::id{MTLCaptureManager} newCaptureScopeWithCommandQueue:queue::id{MTLCommandQueue}]::id{MTLCaptureScope}) @test new_scope.commandQueue == queue -@test new_scope.device == current_device() +@test new_scope.device == device() @test new_scope.label == nothing new_label = "Metal.jl capturing test" new_scope.label = new_label @@ -72,7 +72,7 @@ release(new_scope) @testset "macro" begin Metal.@capture @metal threads=4 tester(bufferA) @test isdir("julia_1.gputrace") - Metal.@capture object=current_device() @metal threads=4 tester(bufferA) + Metal.@capture object=device() @metal threads=4 tester(bufferA) @test isdir("julia_2.gputrace") end diff --git a/test/execution.jl b/test/execution.jl index cb537881b..75cd75755 100644 --- a/test/execution.jl +++ b/test/execution.jl @@ -140,7 +140,7 @@ end @test all(vecA == Int.([5, 5, 5, 5, 5, 5, 0, 0])) vecA .= 0 - dev = current_device() + dev = device() queue = MTLCommandQueue(dev) @metal threads=(3) queue=queue tester(bufferA) synchronize(queue) diff --git a/test/metal.jl b/test/metal.jl index a9894614e..64614686a 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -331,7 +331,7 @@ desc.retainedReferences = false desc.errorOptions = MTL.MTLCommandBufferErrorOptionEncoderExecutionStatus @test desc.errorOptions == MTL.MTLCommandBufferErrorOptionEncoderExecutionStatus -cmq = MTLCommandQueue(current_device()) +cmq = MTLCommandQueue(device()) cmdbuf = MTLCommandBuffer(cmq, desc) if !runtime_validation # when the debug layer is activated, Metal seems to retain all resources? @@ -454,7 +454,7 @@ end arr = Metal.zeros(T, 4) buf = Base.unsafe_convert(MTL.MTLBuffer, arr) - Metal.unsafe_fill!(current_device(), Metal.MtlPtr{T}(buf, 0), T(val), 4) + Metal.unsafe_fill!(device(), Metal.MtlPtr{T}(buf, 0), T(val), 4) @test all(Array(arr) .== val) end diff --git a/test/mps/copy.jl b/test/mps/copy.jl index 57cce7ad5..fd2456cd8 100644 --- a/test/mps/copy.jl +++ b/test/mps/copy.jl @@ -4,8 +4,8 @@ using Metal, Test const IGNORE_UNION = Union{Complex, Int64, UInt64} function copytest(src, srctrans, dsttrans) - device = current_device() - queue = global_queue(device) + dev = device() + queue = global_queue(dev) dst = if srctrans == dsttrans similar(src) else @@ -23,7 +23,7 @@ function copytest(src, srctrans, dsttrans) dstMPS = MPS.MPSMatrix(dst) copydesc = MPS.MPSMatrixCopyDescriptor(srcMPS, dstMPS) - copykern = MPS.MPSMatrixCopy(device, cprows, cpcols, srctrans, dsttrans) + copykern = MPS.MPSMatrixCopy(dev, cprows, cpcols, srctrans, dsttrans) MPS.encode!(cbuf, copykern, copydesc) end wait_completed(cmdbuf) diff --git a/test/mps/linalg.jl b/test/mps/linalg.jl index 344ac7de1..d0f982489 100644 --- a/test/mps/linalg.jl +++ b/test/mps/linalg.jl @@ -1,6 +1,6 @@ using LinearAlgebra -if MPS.is_supported(current_device()) +if MPS.is_supported(device()) @testset "mixed-precision matrix matrix multiplication" begin N = 10 @@ -218,7 +218,7 @@ using .MPS: MPSMatrixSoftMax, MPSMatrixLogSoftMax cols = rand(UInt) rows = rand(UInt) - skern = MPSMatrixSoftMax(current_device()) + skern = MPSMatrixSoftMax(device()) skern.sourceColumns = cols skern.sourceRows = rows @@ -226,7 +226,7 @@ using .MPS: MPSMatrixSoftMax, MPSMatrixLogSoftMax @test skern.sourceColumns == cols @test skern.sourceRows == rows - lkern = MPSMatrixLogSoftMax(current_device()) + lkern = MPSMatrixLogSoftMax(device()) lkern.sourceColumns = cols lkern.sourceRows = rows diff --git a/test/mps/matrix.jl b/test/mps/matrix.jl index a94e876f6..37da8d0b2 100644 --- a/test/mps/matrix.jl +++ b/test/mps/matrix.jl @@ -37,7 +37,7 @@ end using .MPS: MPSMatrix @testset "MPSMatrix" begin - dev = current_device() + dev = device() T = Float32 DT = convert(MPSDataType, T) rows = 2 @@ -136,7 +136,7 @@ using .MPS: MPSMatrixMultiplication alpha = 1 beta = 0 - mat_mul = MPSMatrixMultiplication(current_device(), + mat_mul = MPSMatrixMultiplication(device(), transpose_b, transpose_a, rows_c, cols_c, cols_a, alpha, beta) @@ -159,7 +159,7 @@ using .MPS: MPSMatrixFindTopK rows = 2 cols = 3 - topk = MPSMatrixFindTopK(current_device(), k) + topk = MPSMatrixFindTopK(device(), k) topk.indexOffset = off topk.sourceColumns = cols topk.sourceRows = rows diff --git a/test/mps/vector.jl b/test/mps/vector.jl index 9293bd1b2..3a8a6aa74 100644 --- a/test/mps/vector.jl +++ b/test/mps/vector.jl @@ -32,7 +32,7 @@ end using .MPS: MPSVector @testset "MPSVector" begin - dev = current_device() + dev = device() T = Float32 DT = convert(MPSDataType, T) len = 4 @@ -90,7 +90,7 @@ using .MPS: MPSMatrixVectorMultiplication alpha = 1 beta = 0 - matvec_mul = MPSMatrixVectorMultiplication(current_device(), trans, + matvec_mul = MPSMatrixVectorMultiplication(device(), trans, rows_c, cols_a, alpha, beta) diff --git a/test/profiling.jl b/test/profiling.jl index 9d2b53b15..f280c2302 100644 --- a/test/profiling.jl +++ b/test/profiling.jl @@ -13,7 +13,7 @@ else error("Could not parse xctrace version output:\n$version_output") else xcode_version = VersionNumber(parse(Int, m.captures[1]), parse(Int, m.captures[2])) - if MTL.is_m1(current_device()) && macos_version() >= v"14.4" && xcode_version < v"15.3" + if MTL.is_m1(device()) && macos_version() >= v"14.4" && xcode_version < v"15.3" @warn "Skipping profiling tests because of an M1-related bug on macOS 14.4 and Xcode < 15.3; please upgrade Xcode first" else run_tests = true