diff --git a/src/MetalKernels.jl b/src/MetalKernels.jl
index 11aaedda3..48e4cf44c 100644
--- a/src/MetalKernels.jl
+++ b/src/MetalKernels.jl
@@ -111,15 +111,15 @@ function (obj::KA.Kernel{MetalBackend})(args...; ndrange=nothing, workgroupsize=
         ctx = KA.mkcontext(obj, ndrange, iterspace)
     end
 
-    nblocks = length(KA.blocks(iterspace))
+    groups = length(KA.blocks(iterspace))
     threads = length(KA.workitems(iterspace))
 
-    if nblocks == 0
+    if groups == 0
         return nothing
     end
 
     # Launch kernel
-    kernel(ctx, args...; threads=threads, groups=nblocks)
+    kernel(ctx, args...; threads, groups)
 
     return nothing
 end
@@ -143,7 +143,7 @@ end
 end
 
 @device_override @inline function KA.__index_Group_Cartesian(ctx)
-    @inbounds blocks(KA.__iterspace(ctx))[threadgroup_position_in_grid_1d()]
+    @inbounds KA.blocks(KA.__iterspace(ctx))[threadgroup_position_in_grid_1d()]
 end
 
 @device_override @inline function KA.__index_Global_Cartesian(ctx)
diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl
index caab02dde..f2caf3349 100644
--- a/src/compiler/execution.jl
+++ b/src/compiler/execution.jl
@@ -28,7 +28,15 @@ There is one supported keyword argument that influences the behavior of `@metal`
 """
 macro metal(ex...)
     call = ex[end]
-    kwargs = ex[1:end-1]
+    kwargs = map(ex[1:end-1]) do kwarg
+        if kwarg isa Symbol
+            :($kwarg = $kwarg)
+        elseif Meta.isexpr(kwarg, :(=))
+            kwarg
+        else
+            throw(ArgumentError("Invalid keyword argument '$kwarg'"))
+        end
+    end
 
     # destructure the kernel call
     Meta.isexpr(call, :call) || throw(ArgumentError("second argument to @metal should be a function call"))
diff --git a/src/gpuarrays.jl b/src/gpuarrays.jl
index 3d97335c3..3f688ae1f 100644
--- a/src/gpuarrays.jl
+++ b/src/gpuarrays.jl
@@ -20,9 +20,9 @@ struct mtlKernelContext <: AbstractKernelContext end
     return (; threads=Int(threads), blocks=Int(blocks))
 end
 
-function GPUArrays.gpu_call(::mtlArrayBackend, f, args, threads::Int, blocks::Int;
+function GPUArrays.gpu_call(::mtlArrayBackend, f, args, threads::Int, groups::Int;
                             name::Union{String,Nothing})
-    @metal threads=threads groups=blocks name=name f(mtlKernelContext(), args...)
+    @metal threads groups name f(mtlKernelContext(), args...)
 end
 
 
diff --git a/src/mapreduce.jl b/src/mapreduce.jl
index 2eaea9e7d..3eb4f9b70 100644
--- a/src/mapreduce.jl
+++ b/src/mapreduce.jl
@@ -225,7 +225,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # perform the actual reduction
     if reduce_groups == 1
         # we can cover the dimensions to reduce using a single group
-        @metal threads=threads groups=groups partial_mapreduce_device(
+        @metal threads groups partial_mapreduce_device(
             f, op, init, Val(threads), Val(Rreduce), Val(Rother),
             Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R′, A)
     else
@@ -236,7 +236,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
             # use broadcasting to extend singleton dimensions
             partial .= R
         end
-        @metal threads=threads groups=groups partial_mapreduce_device(
+        @metal threads groups partial_mapreduce_device(
             f, op, init, Val(threads), Val(Rreduce), Val(Rother),
             Val(UInt64(length(Rother))), Val(grain), Val(shuffle),
            partial, A)
diff --git a/test/execution.jl b/test/execution.jl
index a5b3fe11e..e4c8920d4 100644
--- a/test/execution.jl
+++ b/test/execution.jl
@@ -9,10 +9,14 @@ dummy() = return
 @testset "launch configuration" begin
     @metal dummy()
 
+    threads = 1
+    @metal threads dummy()
     @metal threads=1 dummy()
     @metal threads=(1,1) dummy()
     @metal threads=(1,1,1) dummy()
 
+    groups = 1
+    @metal groups dummy()
     @metal groups=1 dummy()
     @metal groups=(1,1) dummy()
     @metal groups=(1,1,1) dummy()
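
For context, a minimal usage sketch of the `@metal` shorthand this diff enables. A bare identifier passed to the macro is now rewritten to `identifier = identifier`, mirroring Julia's keyword-argument shorthand, so `@metal threads groups f(...)` is equivalent to `@metal threads=threads groups=groups f(...)`. The `vadd` kernel, array sizes, and launch parameters below are illustrative only and are not part of the diff.

using Metal

# Illustrative element-wise addition kernel (not part of this diff).
function vadd(a, b, c)
    i = thread_position_in_grid_1d()
    c[i] = a[i] + b[i]
    return
end

a = MtlArray(rand(Float32, 1024))
b = MtlArray(rand(Float32, 1024))
c = similar(a)

threads = 256
groups = cld(length(a), threads)

# Explicit keyword form, unchanged by this diff:
@metal threads=threads groups=groups vadd(a, b, c)

# Bare-symbol shorthand added by this diff; the macro expands a lone
# identifier to `threads=threads` / `groups=groups`:
@metal threads groups vadd(a, b, c)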