Hi,

I tried coding an argmax using KernelAbstractions.jl, which I need for a particle simulation. Sadly, the results from Metal and CPU differ.

Basically, I have a `field::Array{Float32, 4}` and I want to compute `argmax(field[x1, x2, x3, :])` in parallel for many (basically `Nnmc`) vectors `(x1, x2, x3)`. In the code below, this vector is fixed to `(x1, x2, x3) = (1, 1, 1)`.

I found that the argmax differs depending on whether the code runs on the CPU or on Metal, and only when `field` is large enough. This is the core of the issue.
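For reference, the quantity each work item is meant to compute can be written in plain Julia with `Base.argmax` over a view of the last dimension. This is a minimal CPU-only sketch; the function name `argmax_lastdim` and the `voxels` argument (a vector of `(x1, x2, x3)` tuples) are hypothetical and not part of the original code.

```julia
# Hypothetical CPU reference: for each (x1, x2, x3) in `voxels`,
# return argmax(field[x1, x2, x3, :]), i.e. an index along the 4th dimension.
function argmax_lastdim(field::AbstractArray{<:Real,4}, voxels)
    # view avoids copying the slice; argmax returns the first maximal index
    return [argmax(view(field, x1, x2, x3, :)) for (x1, x2, x3) in voxels]
end

field = rand(Float32, 4, 5, 6, 50)
voxels = [(1, 1, 1), (2, 3, 4), (4, 5, 6)]
idx = argmax_lastdim(field, voxels)
```

Like the kernel's strict `val > _val_max` comparison, `argmax` keeps the first maximal index. One difference: the kernel starts from `_val_max = 0f0`, so it returns `0` if no entry is positive, whereas `argmax` always returns a valid index.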
```julia
using LinearAlgebra   # for norm
using Metal
using KernelAbstractions

function _sample_gpu(field;
        Nnmc = 1000,
        TA = Array,
    )
    result = TA(zeros(Float32, 2, Nnmc))
    npb = size(field, 4)
    # launch the kernel on the backend that `result` lives on
    backend = get_backend(result)
    nth = backend isa KernelAbstractions.GPU ? 256 : 8   # workgroup size
    kernel! = _sample_mtl!(backend, nth)
    kernel!(result,
        TA(field),
        npb,
        ndrange = Nnmc,
    )
    KernelAbstractions.synchronize(backend)   # wait for the kernel to finish
    result
end

@kernel function _sample_mtl!(result,
        @Const(field),
        nd,
    )
    nₙₘ = @index(Global)
    voxel₁ = voxel₂ = voxel₃ = 1

    # compute argmax of field[voxel₁, voxel₂, voxel₃, :]
    _val_max::Float32 = 0f0
    ind_u = 0
    for ii in axes(field, 4)
        val = field[voxel₁, voxel₂, voxel₃, ii]
        if val > _val_max
            _val_max = val
            ind_u = ii
        end
    end
    result[1, nₙₘ] = nₙₘ
    # save argmax
    result[2, nₙₘ] = ind_u
end

all_od = rand(Float32, 100, 108, 100, 1000);
res_a = _sample_gpu(all_od)
res_g = _sample_gpu(all_od,
    TA = MtlArray,
) |> Array
norm(res_g[2, :] - res_a[2, :], Inf)
# returns 232.0f0
```
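Because every work item in this example reads the same voxel `(1, 1, 1)`, every entry of `result[2, :]` should hold the same index, namely `argmax(field[1, 1, 1, :])`. A minimal sketch of a check that lists the work items deviating from that reference; the helper `bad_columns` is hypothetical, and the small matrices below are synthetic stand-ins for `res_a` / `res_g`.

```julia
# Hypothetical helper: columns of `res` whose stored argmax (row 2)
# differs from the Base.argmax reference for voxel (1, 1, 1).
function bad_columns(res::AbstractMatrix, field::AbstractArray{<:Real,4})
    ref = argmax(view(field, 1, 1, 1, :))   # first index of the maximum
    return findall(!=(ref), res[2, :])
end

# Synthetic data shaped like the output of _sample_gpu
field = rand(Float32, 2, 2, 2, 20)
ref = argmax(view(field, 1, 1, 1, :))
good = zeros(Float32, 2, 5)
good[1, :] .= 1:5    # work-item ids, as in result[1, :]
good[2, :] .= ref    # every work item found the reference argmax
bad = copy(good)
bad[2, 3] = ref == 1 ? 2 : 1   # corrupt one work item's answer
```

With the arrays from the report, `bad_columns(res_a, all_od)` and `bad_columns(res_g, all_od)` would show which backend disagrees with `Base.argmax`, and for which work items.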
If `field` is smaller, the discrepancy seems to disappear.
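To put a number on "large enough": the array in the failing example occupies slightly more than 4 GiB, i.e. more than 2^32 bytes. This is only arithmetic on the reported sizes; whether a 32-bit size or offset limit is involved in the discrepancy is not established by this report.

```julia
dims = (100, 108, 100, 1000)           # size of all_od in the failing example
nbytes = sizeof(Float32) * prod(dims)  # total footprint: 4_320_000_000 bytes
exceeds_4GiB = nbytes > 2^32           # 2^32 = 4_294_967_296 bytes
```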