Remove duplicate env spec
Improve threading in _set_interpolated_values_device

Try debugging

Fixes + use length(out)

Test dist remapping only

Fixes

Add pipeline back in
charleskawczynski committed Jan 24, 2025
1 parent bbad885 commit 8c1c3f6
Showing 1 changed file with 104 additions and 114 deletions.
218 changes: 104 additions & 114 deletions ext/cuda/remapping_distributed.jl
@@ -18,26 +18,27 @@ function _set_interpolated_values_device!(
     # FIXME: Avoid allocation of tuple
     field_values = tuple(map(f -> Fields.field_values(f), fields)...)
 
-    purely_vertical_space = isnothing(interpolation_matrix)
-    num_horizontal_points =
-        purely_vertical_space ? 1 : prod(size(local_horiz_indices))
-    num_points = num_horizontal_points * length(vert_interpolation_weights)
-    max_threads = 256
-    nthreads = min(num_points, max_threads)
-    nblocks = cld(num_points, nthreads)
+    purely_vert_space = isnothing(interpolation_matrix)
+    nitems = length(out)
+
+    # For purely vertical spaces, `Nq` is not used, so we pass in -1 here.
+    _, Nq = purely_vert_space ? (-1, -1) : size(interpolation_matrix[1])
     args = (
         out,
         interpolation_matrix,
         local_horiz_indices,
         vert_interpolation_weights,
         vert_bounding_indices,
         field_values,
+        Val(Nq),
     )
+    threads = threads_via_occupancy(set_interpolated_values_kernel!, args)
+    p = linear_partition(nitems, threads)
     auto_launch!(
         set_interpolated_values_kernel!,
         args;
-        threads_s = (nthreads),
-        blocks_s = (nblocks),
+        threads_s = (p.threads),
+        blocks_s = (p.blocks),
     )
     call_post_op_callback() && post_op_callback(
         out,
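The hunk above replaces the hard-coded max_threads = 256 launch with an occupancy-driven one: threads_via_occupancy asks the CUDA runtime how many threads per block this kernel can sustain, and linear_partition splits the flat item count into a 1D grid. Below is a minimal sketch of the partitioning step, assuming from the call sites (not from the helper's source) that it returns a (; threads, blocks) named tuple:

    # Hedged sketch: split `nitems` work items across blocks of at most
    # `max_threads` threads; the real `linear_partition` lives elsewhere in
    # ClimaCore's CUDA extension and may differ in detail.
    function linear_partition_sketch(nitems::Integer, max_threads::Integer)
        threads = min(nitems, max_threads)  # never request more threads than items
        blocks = cld(nitems, threads)       # ceiling division covers the remainder
        return (; threads, blocks)
    end

    p = linear_partition_sketch(10_000, 256)
    @assert p.threads * p.blocks >= 10_000  # every item gets a thread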
@@ -60,41 +61,36 @@ function set_interpolated_values_kernel!(
     vert_interpolation_weights,
     vert_bounding_indices,
     field_values,
-)
-    # TODO: Check the memory access pattern. This was not optimized and likely inefficient!
+    ::Val{Nq},
+) where {Nq}
     num_horiz = length(local_horiz_indices)
     num_vert = length(vert_bounding_indices)
     num_fields = length(field_values)
 
-    hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
-    vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y
-    findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z
-
-    totalThreadsX = gridDim().x * blockDim().x
-    totalThreadsY = gridDim().y * blockDim().y
-    totalThreadsZ = gridDim().z * blockDim().z
-
-    _, Nq = size(I1)
-    CI = CartesianIndex
-    for i in hindex:totalThreadsX:num_horiz
-        h = local_horiz_indices[i]
-        for j in vindex:totalThreadsY:num_vert
-            v_lo, v_hi = vert_bounding_indices[j]
-            A, B = vert_interpolation_weights[j]
-            for k in findex:totalThreadsZ:num_fields
-                if i ≤ num_horiz && j ≤ num_vert && k ≤ num_fields
-                    out[i, j, k] = 0
-                    for t in 1:Nq, s in 1:Nq
-                        out[i, j, k] +=
-                            I1[i, t] *
-                            I2[i, s] *
-                            (
-                                A * field_values[k][CI(t, s, 1, v_lo, h)] +
-                                B * field_values[k][CI(t, s, 1, v_hi, h)]
-                            )
-                    end
-                end
-            end
-        end
-    end
+    @inbounds begin
+        i_thread =
+            (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x +
+            CUDA.threadIdx().x
+        inds = (num_vert, num_horiz, num_fields)
+
+        1 ≤ i_thread ≤ prod(inds) || return nothing
+
+        (j, i, k) = CartesianIndices(map(x -> Base.OneTo(x), inds))[i_thread].I
+
+        CI = CartesianIndex
+        h = local_horiz_indices[i]
+        v_lo, v_hi = vert_bounding_indices[j]
+        A, B = vert_interpolation_weights[j]
+        fvals = field_values[k]
+        out[i, j, k] = 0
+        for t in 1:Nq, s in 1:Nq
+            out[i, j, k] +=
+                I1[i, t] *
+                I2[i, s] *
+                (
+                    A * fvals[CI(t, s, 1, v_lo, h)] +
+                    B * fvals[CI(t, s, 1, v_hi, h)]
+                )
+        end
+    end
     return nothing
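The rewritten kernel assigns one flat thread id per output entry and recovers (j, i, k) with CartesianIndices. The decomposition can be checked on the CPU; the extents here are made up:

    # CPU check of the kernel's index decomposition: flat ids map to
    # (vertical, horizontal, field) indices in column-major order, so
    # consecutive ids vary the vertical index j fastest.
    num_vert, num_horiz, num_fields = 4, 3, 2
    inds = (num_vert, num_horiz, num_fields)
    seen = Set{NTuple{3, Int}}()
    for i_thread in 1:prod(inds)
        (j, i, k) = CartesianIndices(map(x -> Base.OneTo(x), inds))[i_thread].I
        push!(seen, (j, i, k))
    end
    @assert length(seen) == prod(inds)  # each (j, i, k) is hit exactly once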
@@ -108,41 +104,36 @@ function set_interpolated_values_kernel!(
     vert_interpolation_weights,
     vert_bounding_indices,
     field_values,
-)
+    ::Val{Nq},
+) where {Nq}
     # TODO: Check the memory access pattern. This was not optimized and likely inefficient!
     num_horiz = length(local_horiz_indices)
     num_vert = length(vert_bounding_indices)
     num_fields = length(field_values)
 
-    hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
-    vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y
-    findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z
-
-    totalThreadsX = gridDim().x * blockDim().x
-    totalThreadsY = gridDim().y * blockDim().y
-    totalThreadsZ = gridDim().z * blockDim().z
-
-    _, Nq = size(I)
-    CI = CartesianIndex
-    for i in hindex:totalThreadsX:num_horiz
-        h = local_horiz_indices[i]
-        for j in vindex:totalThreadsY:num_vert
-            v_lo, v_hi = vert_bounding_indices[j]
-            A, B = vert_interpolation_weights[j]
-            for k in findex:totalThreadsZ:num_fields
-                if i ≤ num_horiz && j ≤ num_vert && k ≤ num_fields
-                    out[i, j, k] = 0
-                    for t in 1:Nq
-                        out[i, j, k] +=
-                            I[i, t] *
-                            (
-                                A * field_values[k][CI(t, 1, 1, v_lo, h)] +
-                                B * field_values[k][CI(t, 1, 1, v_hi, h)]
-                            )
-                    end
-                end
-            end
-        end
-    end
+    @inbounds begin
+        i_thread =
+            (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x +
+            CUDA.threadIdx().x
+        inds = (num_vert, num_horiz, num_fields)
+
+        1 ≤ i_thread ≤ prod(inds) || return nothing
+
+        (j, i, k) = CartesianIndices(map(x -> Base.OneTo(x), inds))[i_thread].I
+
+        CI = CartesianIndex
+        h = local_horiz_indices[i]
+        v_lo, v_hi = vert_bounding_indices[j]
+        A, B = vert_interpolation_weights[j]
+        out[i, j, k] = 0
+        for t in 1:Nq
+            out[i, j, k] +=
+                I[i, t] *
+                (
+                    A * field_values[k][CI(t, 1, 1, v_lo, h)] +
+                    B * field_values[k][CI(t, 1, 1, v_hi, h)]
+                )
+        end
+    end
     return nothing
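In all of these kernels the vertical part is a two-point interpolation: weights A and B blend the field at the bounding levels v_lo and v_hi. A standalone sketch with invented sample values (for linear interpolation A + B == 1):

    # Sketch of the vertical step: a weighted combination of the two
    # bounding model levels.
    vert_interp(column, v_lo, v_hi, A, B) = A * column[v_lo] + B * column[v_hi]

    column = [0.0, 10.0, 20.0, 30.0]       # field values at model levels
    vert_interp(column, 2, 3, 0.25, 0.75)  # 17.5, 3/4 of the way from level 2 to 3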
@@ -156,29 +147,29 @@ function set_interpolated_values_kernel!(
     vert_interpolation_weights,
     vert_bounding_indices,
     field_values,
-)
+    ::Val{Nq},
+) where {Nq}
     # TODO: Check the memory access pattern. This was not optimized and likely inefficient!
     num_fields = length(field_values)
     num_vert = length(vert_bounding_indices)
 
-    vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y
-    findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z
-
-    totalThreadsY = gridDim().y * blockDim().y
-    totalThreadsZ = gridDim().z * blockDim().z
-
-    CI = CartesianIndex
-    for j in vindex:totalThreadsY:num_vert
-        v_lo, v_hi = vert_bounding_indices[j]
-        A, B = vert_interpolation_weights[j]
-        for k in findex:totalThreadsZ:num_fields
-            if j ≤ num_vert && k ≤ num_fields
-                out[j, k] = (
-                    A * field_values[k][CI(1, 1, 1, v_lo, 1)] +
-                    B * field_values[k][CI(1, 1, 1, v_hi, 1)]
-                )
-            end
-        end
-    end
+    @inbounds begin
+        i_thread =
+            (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x +
+            CUDA.threadIdx().x
+        inds = (num_vert, num_fields)
+
+        1 ≤ i_thread ≤ prod(inds) || return nothing
+
+        (j, k) = CartesianIndices(map(x -> Base.OneTo(x), inds))[i_thread].I
+
+        CI = CartesianIndex
+        v_lo, v_hi = vert_bounding_indices[j]
+        A, B = vert_interpolation_weights[j]
+        out[j, k] = (
+            A * field_values[k][CI(1, 1, 1, v_lo, 1)] +
+            B * field_values[k][CI(1, 1, 1, v_hi, 1)]
+        )
+    end
     return nothing
 end
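Each kernel now receives ::Val{Nq} instead of computing Nq = size(I)[2] at run time, so the quadrature loop bound is encoded in the method's type (the purely vertical method accepts the -1 sentinel and never uses it). A small illustration of the pattern:

    # Passing the loop bound as Val{N} makes it a compile-time constant,
    # allowing the compiler to unroll fixed-trip-count loops in GPU kernels.
    function weighted_sum(w, ::Val{Nq}) where {Nq}
        acc = zero(eltype(w))
        for t in 1:Nq    # Nq is known statically here
            acc += w[t]
        end
        return acc
    end

    weighted_sum((0.1, 0.2, 0.3, 0.4), Val(4))  # 1.0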
@@ -205,18 +196,21 @@ function _set_interpolated_values_device!(
     # FIXME: Avoid allocation of tuple
     field_values = tuple(map(f -> Fields.field_values(f), fields)...)
     nitems = length(out)
-    nthreads, nblocks = _configure_threadblock(nitems)
 
     args = (
         out,
         local_horiz_interpolation_weights,
         local_horiz_indices,
         field_values,
+        Val(Nq),
     )
+    threads = threads_via_occupancy(set_interpolated_values_kernel!, args)
+    p = linear_partition(nitems, threads)
     auto_launch!(
         set_interpolated_values_kernel!,
         args;
-        threads_s = (nthreads),
-        blocks_s = (nblocks),
+        threads_s = p.threads,
+        blocks_s = p.blocks,
     )
 end

@@ -225,31 +219,29 @@ function set_interpolated_values_kernel!(
     (I1, I2)::NTuple{2},
     local_horiz_indices,
     field_values,
-)
+    ::Val{Nq},
+) where {Nq}
     # TODO: Check the memory access pattern. This was not optimized and likely inefficient!
     num_horiz = length(local_horiz_indices)
     num_fields = length(field_values)
 
-    hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
-    findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z
-
-    totalThreadsX = gridDim().x * blockDim().x
-    totalThreadsZ = gridDim().z * blockDim().z
-
-    _, Nq = size(I1)
-
-    for i in hindex:totalThreadsX:num_horiz
-        h = local_horiz_indices[i]
-        for k in findex:totalThreadsZ:num_fields
-            if i ≤ num_horiz && k ≤ num_fields
-                out[i, k] = 0
-                for t in 1:Nq, s in 1:Nq
-                    out[i, k] +=
-                        I1[i, t] *
-                        I2[i, s] *
-                        field_values[k][CartesianIndex(t, s, 1, 1, h)]
-                end
-            end
-        end
-    end
+    @inbounds begin
+        i_thread =
+            (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x +
+            CUDA.threadIdx().x
+        inds = (num_horiz, num_fields)
+
+        1 ≤ i_thread ≤ prod(inds) || return nothing
+
+        (i, k) = CartesianIndices(map(x -> Base.OneTo(x), inds))[i_thread].I
+
+        h = local_horiz_indices[i]
+        out[i, k] = 0
+        for t in 1:Nq, s in 1:Nq
+            out[i, k] +=
+                I1[i, t] *
+                I2[i, s] *
+                field_values[k][CartesianIndex(t, s, 1, 1, h)]
+        end
+    end
     return nothing
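The 2D horizontal kernel contracts the Nq × Nq nodal values of an element with two 1D interpolation rows. A CPU reference of the same contraction, with toy data:

    # CPU reference for the tensor-product contraction above:
    # out = sum over t, s of I1[t] * I2[s] * nodal[t, s] for one target point.
    function horiz_interp(I1row, I2row, nodal)
        acc = zero(eltype(nodal))
        for t in eachindex(I1row), s in eachindex(I2row)
            acc += I1row[t] * I2row[s] * nodal[t, s]
        end
        return acc
    end

    nodal = ones(4, 4)  # constant field on the element
    horiz_interp(fill(0.25, 4), fill(0.25, 4), nodal)  # 1.0: constants are reproduced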
@@ -260,29 +252,27 @@ function set_interpolated_values_kernel!(
     (I,)::NTuple{1},
     local_horiz_indices,
     field_values,
-)
+    ::Val{Nq},
+) where {Nq}
     # TODO: Check the memory access pattern. This was not optimized and likely inefficient!
     num_horiz = length(local_horiz_indices)
     num_fields = length(field_values)
 
-    hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
-    findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z
-
-    totalThreadsX = gridDim().x * blockDim().x
-    totalThreadsZ = gridDim().z * blockDim().z
-
-    _, Nq = size(I)
-
-    for i in hindex:totalThreadsX:num_horiz
-        h = local_horiz_indices[i]
-        for k in findex:totalThreadsZ:num_fields
-            if i ≤ num_horiz && k ≤ num_fields
-                out[i, k] = 0
-                for t in 1:Nq
-                    out[i, k] +=
-                        I[i, t] * field_values[k][CartesianIndex(t, 1, 1, 1, h)]
-                end
-            end
-        end
-    end
+    @inbounds begin
+        i_thread =
+            (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x +
+            CUDA.threadIdx().x
+        inds = (num_horiz, num_fields)
+
+        1 ≤ i_thread ≤ prod(inds) || return nothing
+
+        (i, k) = CartesianIndices(map(x -> Base.OneTo(x), inds))[i_thread].I
+
+        h = local_horiz_indices[i]
+        out[i, k] = 0
+        for t in 1:Nq
+            out[i, k] +=
+                I[i, t] * field_values[k][CartesianIndex(t, 1, 1, 1, h)]
+        end
+    end
    return nothing
