Skip to content

Commit

Permalink
Allow more kwargs syntax with kernel launches (#269)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Dec 6, 2023
1 parent 699211b commit 6441d1c
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 9 deletions.
8 changes: 4 additions & 4 deletions src/MetalKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,15 @@ function (obj::KA.Kernel{MetalBackend})(args...; ndrange=nothing, workgroupsize=
ctx = KA.mkcontext(obj, ndrange, iterspace)
end

nblocks = length(KA.blocks(iterspace))
groups = length(KA.blocks(iterspace))
threads = length(KA.workitems(iterspace))

if nblocks == 0
if groups == 0
return nothing
end

# Launch kernel
kernel(ctx, args...; threads=threads, groups=nblocks)
kernel(ctx, args...; threads, groups)
return nothing
end

Expand All @@ -143,7 +143,7 @@ end
end

@device_override @inline function KA.__index_Group_Cartesian(ctx)
@inbounds blocks(KA.__iterspace(ctx))[threadgroup_position_in_grid_1d()]
@inbounds KA.blocks(KA.__iterspace(ctx))[threadgroup_position_in_grid_1d()]
end

@device_override @inline function KA.__index_Global_Cartesian(ctx)
Expand Down
10 changes: 9 additions & 1 deletion src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@ There is one supported keyword argument that influences the behavior of `@metal`
"""
macro metal(ex...)
call = ex[end]
kwargs = ex[1:end-1]
kwargs = map(ex[1:end-1]) do kwarg
if kwarg isa Symbol
:($kwarg = $kwarg)
elseif Meta.isexpr(kwarg, :(=))
kwarg
else
throw(ArgumentError("Invalid keyword argument '$kwarg'"))
end
end

# destructure the kernel call
Meta.isexpr(call, :call) || throw(ArgumentError("second argument to @metal should be a function call"))
Expand Down
4 changes: 2 additions & 2 deletions src/gpuarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ struct mtlKernelContext <: AbstractKernelContext end
return (; threads=Int(threads), blocks=Int(blocks))
end

function GPUArrays.gpu_call(::mtlArrayBackend, f, args, threads::Int, blocks::Int;
function GPUArrays.gpu_call(::mtlArrayBackend, f, args, threads::Int, groups::Int;
name::Union{String,Nothing})
@metal threads=threads groups=blocks name=name f(mtlKernelContext(), args...)
@metal threads groups name f(mtlKernelContext(), args...)
end


Expand Down
4 changes: 2 additions & 2 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
# perform the actual reduction
if reduce_groups == 1
# we can cover the dimensions to reduce using a single group
@metal threads=threads groups=groups partial_mapreduce_device(
@metal threads groups partial_mapreduce_device(
f, op, init, Val(threads), Val(Rreduce), Val(Rother),
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R′, A)
else
Expand All @@ -236,7 +236,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
# use broadcasting to extend singleton dimensions
partial .= R
end
@metal threads=threads groups=groups partial_mapreduce_device(
@metal threads groups partial_mapreduce_device(
f, op, init, Val(threads), Val(Rreduce), Val(Rother),
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)

Expand Down
4 changes: 4 additions & 0 deletions test/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ dummy() = return
@testset "launch configuration" begin
@metal dummy()

threads = 1
@metal threads dummy()
@metal threads=1 dummy()
@metal threads=(1,1) dummy()
@metal threads=(1,1,1) dummy()

groups = 1
@metal groups dummy()
@metal groups=1 dummy()
@metal groups=(1,1) dummy()
@metal groups=(1,1,1) dummy()
Expand Down

0 comments on commit 6441d1c

Please sign in to comment.