Skip to content

Commit

Permalink
add remaining handling of padding
Browse files Browse the repository at this point in the history
  • Loading branch information
omlins committed Oct 30, 2024
1 parent b9e71be commit b129021
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ end
function parallel_kernel(caller::Module, package::Symbol, numbertype::DataType, inbounds::Bool, indices::Union{Symbol,Expr}, kernel::Expr)
if (!isa(indices,Symbol) && !isa(indices.head,Symbol)) @ArgumentError("@parallel_indices: argument 'indices' must be a tuple of indices or a single index (e.g. (ix, iy, iz) or (ix, iy) or ix ).") end
indices = extract_tuple(indices)
padding = get_padding(caller)
body = get_body(kernel)
body = remove_return(body)
body = macroexpand(caller, body)
Expand All @@ -183,7 +184,7 @@ function parallel_kernel(caller::Module, package::Symbol, numbertype::DataType,
end
if isgpu(package) kernel = insert_device_types(caller, kernel) end
kernel = adjust_signatures(kernel, package)
body = handle_padding(body, get_padding(caller)) # TODO: padding can later be made configurable per kernel (to enable working with arrays as before).
body = handle_padding(body, padding) # TODO: padding can later be made configurable per kernel (to enable working with arrays as before).
body = handle_indices_and_literals(body, indices, package, numbertype)
if (inbounds) body = add_inbounds(body) end
body = add_return(body)
Expand Down Expand Up @@ -365,13 +366,47 @@ function adjust_signatures(kernel::Expr, package::Symbol)
end

# Apply the padding-related rewrites to a kernel body.
# The inner-index substitution is always performed; the `firstindex`/`lastindex`
# substitution and the view-access substitution are only applied when `padding` is true.
function handle_padding(body::Expr, padding::Bool)
    rewritten = substitute_indices_inn(body, padding)
    padding || return rewritten
    return substitute_view_accesses(substitute_firstlastindex(rewritten), INDICES)
end

# Replace every symbol of INDICES_INN found in `body` by an expression in terms of the
# matching entry of INDICES.
function substitute_indices_inn(body::Expr, padding::Bool)
    for (i, inn) in enumerate(INDICES_INN)
        idx = INDICES[i]
        # NOTE: expression of ixi with ix, etc.: if padding is not used, they must be shifted by 1.
        replacement = padding ? idx : :($idx + 1)
        body = substitute(body, inn, replacement)
    end
    return body
end

# Walk `body` and turn every call to `firstindex(...)` / `lastindex(...)` into a call of
# the corresponding ParallelKernel macro, forwarding the original arguments followed by
# a literal `true` padding flag. All other expressions are returned untouched.
function substitute_firstlastindex(body::Expr)
    with_padding = true
    return postwalk(body) do ex
        # Anything that is not a function call is left as is.
        @capture(ex, f_(args__)) || return ex
        if f == :firstindex
            return :(ParallelStencil.ParallelKernel.@firstindex($(args...), $with_padding))
        elseif f == :lastindex
            return :(ParallelStencil.ParallelKernel.@lastindex($(args...), $with_padding))
        end
        return ex
    end
end

# Walk `expr` and rewrite every sub-expression that `is_access` recognizes for the given
# `indices` from `A[...]` to `A.parent[...]`, keeping the index expressions unchanged.
# Raises a module-internal error if such an access does not have the form `A[...]`.
function substitute_view_accesses(expr::Expr, indices::NTuple{N,<:Union{Symbol,Expr}} where N)
    return postwalk(expr) do ex
        # Expressions that are not stencil accesses pass through unchanged.
        is_access(ex, indices...) || return ex
        matched = @capture(ex, A_[idx__])
        if !matched
            @ModuleInternalError("a stencil access could not be pattern matched.")
        end
        return :($A.parent[$(idx...)])
    end
end

function handle_indices_and_literals(body::Expr, indices::Array, package::Symbol, numbertype::DataType)
int_type = kernel_int_type(package)
ranges = [:($RANGES_VARNAME[1]), :($RANGES_VARNAME[2]), :($RANGES_VARNAME[3])]
Expand Down

0 comments on commit b129021

Please sign in to comment.