Skip to content

Commit

Permalink
Revert unintended change to vadd example script.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Jul 17, 2024
1 parent 3b16be4 commit 116b1bf
Showing 1 changed file with 13 additions and 25 deletions.
38 changes: 13 additions & 25 deletions examples/vadd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,21 @@ using Test
using Metal

function vadd(a, b, c)
i0 = Tuple(thread_position_in_grid_3d())
stride = Tuple(threads_per_grid_3d())
is = i0
while 1 <= is[1] <= size(a, 1) &&
1 <= is[2] <= size(a, 2) &&
1 <= is[3] <= size(a, 3)
I = CartesianIndex(is)
c[I] = a[I] + b[I]
is = (is[1] + stride[1],
is[2] + stride[2],
is[3] + stride[3])
end
i = thread_position_in_grid_1d()
c[i] = a[i] + b[i]
return
end

function main()
dims = (3,4,5)
a = round.(rand(Float32, dims) * 100)
b = round.(rand(Float32, dims) * 100)
c = similar(a)
dims = (3,4)
a = round.(rand(Float32, dims) * 100)
b = round.(rand(Float32, dims) * 100)
c = similar(a)

d_a = MtlArray(a)
d_b = MtlArray(b)
d_c = MtlArray(c)
d_a = MtlArray(a)
d_b = MtlArray(b)
d_c = MtlArray(c)

len = prod(dims)
@metal threads=dims vadd(d_a, d_b, d_c)
c = Array(d_c)
@test a+b c
end
len = prod(dims)
@metal threads=len vadd(d_a, d_b, d_c)
c = Array(d_c)
@test a+b c

0 comments on commit 116b1bf

Please sign in to comment.