
Broadcasting a function returning an anonymous function with a constructor over CUDA arrays fails to compile, "not isbits" #2514

Closed
BioTurboNick opened this issue Oct 3, 2024 · 7 comments · Fixed by JuliaGPU/GPUCompiler.jl#638
Labels
bug Something isn't working

Comments

BioTurboNick commented Oct 3, 2024

Describe the bug

Broadcasting the anonymous function returned by another function (here, a closure that wraps a type constructor) over CUDA arrays fails to compile with the following error:

ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#34#36")(::CUDA.CuKernelContext, ::CuDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Tuple{Base.OneTo{Int64}}, var"#3#4"{Type{Bar}}, Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:
  .f is of type var"#3#4"{Type{Bar}} which is not isbits.
    .f is of type Type{Bar} which is not isbits.


Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\validation.jl:92
  [2] macro expansion
    @ C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:123 [inlined]
  [3] macro expansion
    @ C:\Users\nicho\.julia\packages\TimerOutputs\Lw5SP\src\TimerOutput.jl:253 [inlined]
  [4]
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:121
  [5] codegen
    @ C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:110 [inlined]
  [6]
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:106
  [7] compile
    @ C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:98 [inlined]
  [8] #1072
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\compilation.jl:247 [inlined]
  [9] JuliaContext(f::CUDA.var"#1072#1075"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}})
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:47
 [10] compile(job::GPUCompiler.CompilerJob)
    @ CUDA C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\compilation.jl:246
 [11] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\execution.jl:125
 [12] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler C:\Users\nicho\.julia\packages\GPUCompiler\U36Ed\src\execution.jl:103
 [13] macro expansion
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\execution.jl:367 [inlined]
 [14] macro expansion
    @ .\lock.jl:267 [inlined]
 [15] cufunction(f::GPUArrays.var"#34#36", tt::Type{Tuple{…}}; kwargs::@Kwargs{})
    @ CUDA C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\execution.jl:362
 [16] cufunction(f::GPUArrays.var"#34#36", tt::Type{Tuple{…}})
    @ CUDA C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\execution.jl:359
 [17] macro expansion
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\compiler\execution.jl:112 [inlined]
 [18] #launch_heuristic#1122
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\gpuarrays.jl:17 [inlined]
 [19] launch_heuristic
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\gpuarrays.jl:15 [inlined]
 [20] _copyto!
    @ C:\Users\nicho\.julia\packages\GPUArrays\OqrUV\src\host\broadcast.jl:78 [inlined]
 [21] copyto!
    @ C:\Users\nicho\.julia\packages\GPUArrays\OqrUV\src\host\broadcast.jl:44 [inlined]
 [22] copy
    @ C:\Users\nicho\.julia\packages\GPUArrays\OqrUV\src\host\broadcast.jl:29 [inlined]
 [23] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{…}, Nothing, var"#3#4"{…}, Tuple{…}})
    @ Base.Broadcast .\broadcast.jl:903
 [24] top-level scope
    @ REPL[11]:1
 [25] top-level scope
    @ C:\Users\nicho\.julia\packages\CUDA\htRwP\src\initialization.jl:206
Some type information was truncated. Use `show(err)` to see complete types.

To reproduce

The Minimal Working Example (MWE) for this bug:

using CUDA

struct Bar{T}
    a::T
    b::T
end

foo(f) = (args...) -> f(args...)

a = cu(zeros(5)); b = cu(ones(5)); c = Bar; d = foo(c)

c.(a, b) # works, produces GPU array

foo(c).(collect(a), collect(b)) # works, produces CPU array

((args...) -> Bar(args...)).(a, b) # works, produces GPU array

foo(c).(a, b) # fails
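The property the error complains about can be checked on the CPU, without a GPU. The sketch below uses plain Julia only (no CUDA.jl calls) and assumes nothing beyond what the error message states: the broadcast's function field must be isbits to be passed to a kernel.

```julia
struct Bar{T}
    a::T
    b::T
end

foo(f) = (args...) -> f(args...)

closure = foo(Bar)

# The closure stores the captured constructor in a field whose type is Type{Bar}
println(fieldtypes(typeof(closure)))   # (Type{Bar},)

# Type{Bar} is not an isbits type, so neither is the closure that holds it
println(isbitstype(Type{Bar}))         # false
println(isbits(closure))               # false

# A closure over an ordinary function captures a zero-size singleton instead
println(isbits(foo(+)))                # true
```

By contrast, the inline lambda in the MWE refers to the global constant Bar directly and captures nothing, so, like foo(+) here, it is a field-less singleton.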
Manifest.toml

Status `C:\Users\nicho\.julia\environments\v1.10\Manifest.toml`
  [621f4979] AbstractFFTs v1.5.0
  [79e6a3ab] Adapt v4.0.4
  [a9b6321e] Atomix v0.1.0
  [ab4f0b2a] BFloat16s v0.5.0
  [fa961155] CEnum v0.5.0
  [052768ef] CUDA v5.5.2
  [1af6417a] CUDA_Runtime_Discovery v0.3.5
  [da1fd8a2] CodeTracking v1.3.6
  [3da002f7] ColorTypes v0.11.5
  [5ae59095] Colors v0.12.11
  [34da2185] Compat v4.16.0
  [a8cc5b0e] Crayons v4.1.1
  [9a962f9c] DataAPI v1.16.0
  [a93c6f00] DataFrames v1.7.0
  [864edb3b] DataStructures v0.18.20
  [e2d170a0] DataValueInterfaces v1.0.0
  [e2ba6199] ExprTools v0.1.10
  [53c48c17] FixedPointNumbers v0.8.5
  [0c68f7d7] GPUArrays v10.3.1
  [46192b85] GPUArraysCore v0.1.6
  [61eb1bfa] GPUCompiler v0.27.8
  [842dd82b] InlineStrings v1.4.2
  [41ab1584] InvertedIndices v1.3.0
  [82899510] IteratorInterfaceExtensions v1.0.0
  [692b3bcd] JLLWrappers v1.6.0
  [aa1ae85d] JuliaInterpreter v0.9.36
  [63c18a36] KernelAbstractions v0.9.27
  [929cbde3] LLVM v9.1.2
  [8b046642] LLVMLoopInfo v1.0.0
  [b964fa9f] LaTeXStrings v1.3.1
  [6f1432cf] LoweredCodeUtils v3.0.2
  [1914dd2f] MacroTools v0.5.13
  [e1d29d7a] Missings v1.2.0
  [5da4648a] NVTX v0.3.4
  [bac558e1] OrderedCollections v1.6.3
  [2dfb63ee] PooledArrays v1.4.3
  [aea7be01] PrecompileTools v1.2.1
  [21216c6a] Preferences v1.4.3
  [08abe8d2] PrettyTables v2.4.0
  [74087812] Random123 v1.7.0
  [e6cf234a] RandomNumbers v1.6.0
  [189a3867] Reexport v1.2.2
  [ae029012] Requires v1.3.0
  [295af30f] Revise v3.6.0
  [6c6a2e73] Scratch v1.2.1
  [91c51154] SentinelArrays v1.4.5
  [a2af1166] SortingAlgorithms v1.2.1
  [90137ffa] StaticArrays v1.9.7
  [1e83bf80] StaticArraysCore v1.4.3
  [892a3eda] StringManipulation v0.4.0
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.12.0
  [a759f4b9] TimerOutputs v0.5.24
  [013be700] UnsafeAtomics v0.2.1
  [d80eeb9a] UnsafeAtomicsLLVM v0.2.1
  [4ee394cb] CUDA_Driver_jll v0.10.3+0
  [76a88914] CUDA_Runtime_jll v0.15.3+0
  [9c1d0b0a] JuliaNVTXCallbacks_jll v0.2.1+0
  [dad2f222] LLVMExtra_jll v0.0.34+0
  [e98f9f5b] NVTX_jll v3.1.0+2
  [1e29f10c] demumble_jll v1.3.0+0
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8ba89e20] Distributed
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [9fa8497b] Future
  [b77e0a4c] InteractiveUtils
  [4af54fe1] LazyArtifacts
  [b27032c2] LibCURL v0.6.4
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.10.0
  [de0858da] Printf
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [6462fe0b] Sockets
  [2f01184e] SparseArrays v1.10.0
  [10745b16] Statistics v1.10.0
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v1.1.1+0
  [deac9b47] LibCURL_jll v8.4.0+0
  [e37daf67] LibGit2_jll v1.6.4+0
  [29816b5a] LibSSH2_jll v1.11.0+1
  [c8ffd9c3] MbedTLS_jll v2.28.2+1
  [14a3606d] MozillaCACerts_jll v2023.1.10
  [4536629a] OpenBLAS_jll v0.3.23+4
  [bea87d4a] SuiteSparse_jll v7.2.1+1
  [83775a58] Zlib_jll v1.2.13+1
  [8e850b90] libblastrampoline_jll v5.8.0+1
  [8e850ede] nghttp2_jll v1.52.0+1
  [3f19e933] p7zip_jll v17.4.0+2

Expected behavior

c.(a, b) == foo(c).(a, b)
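On the CPU this equality already holds. A minimal check, using ordinary Vector{Float32} inputs in place of the cu(...) arrays (an assumption made so that it runs without a GPU):

```julia
struct Bar{T}
    a::T
    b::T
end

foo(f) = (args...) -> f(args...)

a = zeros(Float32, 5)
b = ones(Float32, 5)

# Broadcasting the constructor directly and through the closure gives the same result
direct  = Bar.(a, b)
wrapped = foo(Bar).(a, b)

println(eltype(direct))     # Bar{Float32}
println(direct == wrapped)  # true
```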

Version info

Details on Julia:

Julia Version 1.10.4
Commit 48d4fd4843 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900KF
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, goldmont)
Threads: 1 default, 0 interactive, 1 GC (on 32 virtual cores)

Details on CUDA:

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.94.0

CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.94

Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0

Toolchain:
- Julia: 1.10.4
- LLVM: 15.0.7

1 device:
  0: NVIDIA GeForce RTX 4080 (sm_89, 13.307 GiB / 15.992 GiB available)

Additional context

Encountered when Zygote differentiates a broadcast of the form Bar.(a, b) via broadcast_forward, where a and b are GPU arrays.

BioTurboNick added the bug Something isn't working label Oct 3, 2024
BioTurboNick changed the title from Broadcasting an anonymous function containing a constructor over CUDA arrays fails to compile, "not isbitstype" to Broadcasting an anonymous function containing a constructor over CUDA arrays fails to compile, "not isbits" Oct 3, 2024
BioTurboNick changed the title from Broadcasting an anonymous function containing a constructor over CUDA arrays fails to compile, "not isbits" to Broadcasting a function returning an anonymous function with a constructor over CUDA arrays fails to compile, "not isbits" Oct 3, 2024
maleadt (Member) commented Oct 7, 2024

Further reduced MWE:

using CUDA

struct Bar{T}
    a::T
end

function main()
    a = cu(zeros(5))

    capture = Bar
    function closure(arg)
        capture(arg)
    end

    function kernel(f, x)
        f(x[])
        return
    end
    @cuda kernel(closure, a)
end

main()  # fails to compile: the closure argument is not isbits

So the problem is that you're passing a closure (closure in my MWE, the result of foo(c) in yours) that captures a type-unstable variable (capture and c respectively, both of type Type{Bar} with unbound type variables). Because that type is contained in a closure, it doesn't pass the Core.Compiler.isconstType check, so we don't filter it out during validation.

We could make the validation more lenient again by having it check whether the type-unstable argument is actually unused in the LLVM IR (something we removed after introducing the Core.Compiler.isconstType check in JuliaGPU/GPUCompiler.jl#24). However, that only fixes my MWE here, not yours, because the Broadcasted argument (which is now also type-unstable) is used. So that approach doesn't cut it.

Maybe we should approach this differently. Julia converted this Broadcast to a { [1 x {}*], [2 x { { i8 addrspace(1)*, i64, [1 x i64], i64 }, [1 x i8], [1 x i64] }], [1 x [1 x i64]] } %1, but it seems hard to check whether that managed pointer in there is the only unused field...

All that said, I'm not sure it's worth the effort, because the generated broadcast kernel is broken regardless: because of the (inferred) type instability, the broadcast's result type is inferred as Any, which isn't supported on the GPU:

ERROR: ArgumentError: Broadcast operation resulting in Any is not GPU compatible
Stacktrace:
 [1] _copyto!
   @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:86 [inlined]
 [2] copyto!
   @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:44 [inlined]
 [3] copy
   @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:29 [inlined]
 [4] materialize
   @ ./broadcast.jl:903 [inlined]

maleadt (Member) commented Oct 7, 2024

We could make the validation more lenient again by having it actually consider whether the type-unstable argument is unused in the LLVM IR -- something we removed after Core.Compiler.isconstType in JuliaGPU/GPUCompiler.jl#24

FWIW, that looks like:

diff --git a/src/driver.jl b/src/driver.jl
index 9e05eb6..a4cff8f 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -88,8 +88,7 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
     end

     @timeit_debug to "Validation" begin
-        check_method(job)   # not optional
-        validate && check_invocation(job)
+        check_method(job)
     end

     prepare_job!(job)
@@ -99,6 +98,10 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool

     ir, ir_meta = emit_llvm(job; libraries, toplevel, optimize, cleanup, only_entry, validate)

+    validate && @timeit_debug to "Validation" begin
+        check_invocation(job, ir_meta.entry)
+    end
+
     if output == :llvm
         if strip
             @timeit_debug to "strip debug info" strip_debuginfo!(ir)
diff --git a/src/validation.jl b/src/validation.jl
index e1a355b..9f1f869 100644
--- a/src/validation.jl
+++ b/src/validation.jl
@@ -66,7 +66,7 @@ function explain_nonisbits(@nospecialize(dt), depth=1; maxdepth=10)
     return msg
 end

-function check_invocation(@nospecialize(job::CompilerJob))
+function check_invocation(@nospecialize(job::CompilerJob), entry::LLVM.Function)
     sig = job.source.specTypes
     ft = sig.parameters[1]
     tt = Tuple{sig.parameters[2:end]...}
@@ -77,6 +77,9 @@ function check_invocation(@nospecialize(job::CompilerJob))
     real_arg_i = 0

     for (arg_i,dt) in enumerate(sig.parameters)
+        println(Core.stdout, arg_i)
+        println(Core.stdout, dt)
+
         isghosttype(dt) && continue
         Core.Compiler.isconstType(dt) && continue
         real_arg_i += 1
@@ -89,9 +92,13 @@ function check_invocation(@nospecialize(job::CompilerJob))
         end

         if !isbitstype(dt)
-            throw(KernelError(job, "passing and using non-bitstype argument",
-                """Argument $arg_i to your kernel function is of type $dt, which is not isbits:
-                    $(explain_nonisbits(dt))"""))
+            param = parameters(entry)[real_arg_i]
+            if !isempty(uses(param))
+                println(Core.stdout, string(entry))
+                throw(KernelError(job, "passing and using non-bitstype argument",
+                      """Argument $arg_i to your kernel function is of type $dt, which is not isbits:
+                         $(explain_nonisbits(dt))"""))
+             end
         end
     end

BioTurboNick (Author) commented

Is there a way to avoid the type instability? Or should closures of this kind be warned against, with e.g. Zygote changing accordingly?

maleadt (Member) commented Oct 7, 2024

Zygote capturing all kinds of things in closures is definitely not great, but it's too late to fix that.

The type instability was a red herring, though: even giving c a concrete type wouldn't resolve this, because the type object itself is the problematic part:

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}}, var"#16#18"{Type{Bar{Float32}}}, Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:
  .f is of type var"#16#18"{Type{Bar{Float32}}} which is not isbits.
    .f is of type Type{Bar{Float32}} which is not isbits.
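That the Type object itself, and not the missing type parameter, is the problem can be checked on the CPU. A sketch reusing the Bar and foo definitions from the original report, with the type parameter fixed to Float32:

```julia
struct Bar{T}
    a::T
    b::T
end

foo(f) = (args...) -> f(args...)

# Even with a fully specified type, the closure still captures a type object
typed = foo(Bar{Float32})

println(fieldtypes(typeof(typed)))        # (Type{Bar{Float32}},)
println(isbitstype(Type{Bar{Float32}}))   # false
println(isbits(typed))                    # false
```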

BioTurboNick (Author) commented

Okay, yeah. So I think on the Zygote side there's a way to work around the specific issue I encountered. Aside from trying to fix anything in CUDA, is there an opportunity to provide a more helpful error message in this case? I would be willing to work on that.

maleadt (Member) commented Oct 7, 2024

Isn't it already relatively helpful, pointing to the exact argument and field?

In terms of actually fixing this, it would be possible to disable validation completely, as this code turns out to result in fairly compatible LLVM IR. But then we open the door to accidentally using pointers to boxed (CPU-allocated) objects on the GPU, which is what this check was designed to prevent...
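The per-field report in the error can be approximated with a short recursive walk over field types. The helper below, nonisbits_paths, is hypothetical and written only for illustration; it is loosely modeled on the ".f is of type ... which is not isbits" lines but is not GPUCompiler's explain_nonisbits and does not reproduce its exact output.

```julia
struct Bar{T}
    a::T
    b::T
end

foo(f) = (args...) -> f(args...)

# Hypothetical helper (not GPUCompiler's API): list the field paths (e.g. ".f")
# of type T whose declared field types are not isbits.
function nonisbits_paths(T::Type, prefix::String = ""; maxdepth::Int = 10)
    paths = String[]
    # Stop at isbits types, at non-concrete types such as Type{Bar} (no fixed
    # field layout to walk), and at the depth limit (guards against cycles)
    if isbitstype(T) || !isconcretetype(T) || maxdepth <= 0
        return paths
    end
    for (name, FT) in zip(fieldnames(T), fieldtypes(T))
        path = string(prefix, ".", name)
        if !isbitstype(FT)
            push!(paths, path)
            append!(paths, nonisbits_paths(FT, path; maxdepth = maxdepth - 1))
        end
    end
    return paths
end

# Hypothetical stand-in for Broadcasted: nests the closure one level deeper
struct Wrapper{F}
    g::F
end

println(nonisbits_paths(typeof(foo(Bar))))            # [".f"]
println(nonisbits_paths(typeof(Wrapper(foo(Bar)))))   # [".g", ".g.f"]
println(nonisbits_paths(typeof(foo(+))))              # String[]
```

Wrapper plays the role of Broadcasted here: in the real error, the outer .f is the Broadcasted field and the inner .f is the closure's captured Type{Bar}, which is why two nested lines are printed.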

BioTurboNick (Author) commented

It's helpful if you already understand the internals, perhaps.

As a naive user: "Okay great, Type{Bar} is not isbits. What do I do with this information? How do I use the error to correct the code?"
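One way to act on the error, sketched here under assumptions and not taken from the fix that closed this issue: move the constructor out of a field and into a type parameter, so that the callable has no fields at all. Construct and foo_isbits are hypothetical names. The sketch only checks isbits and CPU behavior; it assumes (not verified here) that a zero-field callable passes the same kernel-argument check on the GPU.

```julia
struct Bar{T}
    a::T
    b::T
end

# Hypothetical workaround: carry the constructor in a type parameter instead of
# a field. Construct{Bar}() has no fields, so it is a zero-size isbits singleton.
struct Construct{T} end
(::Construct{T})(args...) where {T} = T(args...)

# Hypothetical replacement for foo(c) when c is a type
foo_isbits(::Type{T}) where {T} = Construct{T}()

f = foo_isbits(Bar)
println(isbits(f))                                # true
println(f(1.0f0, 2.0f0) === Bar(1.0f0, 2.0f0))    # true

# On the CPU it broadcasts like the constructor itself
x = zeros(Float32, 5); y = ones(Float32, 5)
println(f.(x, y) == Bar.(x, y))                   # true
```

The design choice mirrors why foo(+) is isbits: a function object whose only information lives in its type is a zero-size singleton, whereas a closure that stores a Type in a field carries a pointer to a boxed object.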
