From 28fe95255ad8eb00bb9fc582ad4adb56785fa4be Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Sat, 14 Dec 2024 00:15:48 -0400 Subject: [PATCH] Fix global linear indexing (`fill!`) --- src/MetalKernels.jl | 4 +++- test/array.jl | 14 ++++---------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/MetalKernels.jl b/src/MetalKernels.jl index fbe8b25e9..cdde3d0f2 100644 --- a/src/MetalKernels.jl +++ b/src/MetalKernels.jl @@ -140,7 +140,9 @@ end end @device_override @inline function KA.__index_Global_Linear(ctx) - return thread_position_in_grid_1d() + I = @inbounds KA.expand(KA.__iterspace(ctx), threadgroup_position_in_grid_1d(), thread_position_in_threadgroup_1d()) + # TODO: This is unfortunate, can we get the linear index cheaper + @inbounds LinearIndices(KA.__ndrange(ctx))[I] end @device_override @inline function KA.__index_Local_Cartesian(ctx) diff --git a/test/array.jl b/test/array.jl index 247b72b8f..4abbb9ea9 100644 --- a/test/array.jl +++ b/test/array.jl @@ -229,15 +229,12 @@ end @testset "fill($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Float16, Float32] - broken466a = T ∉ [Int8,UInt8] - broken466b = (Base.JLOptions().check_bounds != 1 || shader_validation) - b = rand(T) # Dims in tuple let A = Metal.fill(b, (10, 10, 10, 1000)) B = fill(b, (10, 10, 10, 1000)) - @test Array(A) == B broken=(broken466a && broken466b) + @test Array(A) == B end let M = Metal.fill(b, (10, 10)) @@ -253,7 +250,7 @@ end #Dims already unpacked let A = Metal.fill(b, 10, 1000, 1000) B = fill(b, 10, 1000, 1000) - @test Array(A) == B broken=broken466a + @test Array(A) == B end let M = Metal.fill(b, 10, 10) @@ -269,15 +266,12 @@ end @testset "fill!($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Float16, Float32] - broken466a = T ∉ [Int8,UInt8] - broken466b = (Base.JLOptions().check_bounds != 1 || shader_validation) - b = rand(T) # Dims in tuple let A = MtlArray{T,3}(undef, (10, 1000, 1000)) fill!(A, b) - @test all(Array(A) .== b) broken=broken466a + @test all(Array(A) .== b) end let M = MtlMatrix{T}(undef, (10, 10)) @@ -293,7 +287,7 @@ end # Dims already unpacked let A = MtlArray{T,4}(undef, 10, 10, 10, 1000) fill!(A, b) - @test all(Array(A) .== b) broken=(broken466a && broken466b) + @test all(Array(A) .== b) end let M = MtlMatrix{T}(undef, 10, 10)