Broadcasting a function returning an anonymous function with a constructor over CUDA arrays fails to compile, "not isbits" #2514
Further reduced MWE:

    using CUDA  # provides cu and @cuda

    struct Bar{T}
        a::T
    end

    function main()
        a = cu(zeros(5))
        capture = Bar
        function closure(arg)
            capture(arg)
        end
        function kernel(f, x)
            f(x[])
            return
        end
        @cuda kernel(closure, a)
    end

So the problem is that you're passing a closure (

We could make the validation more lenient again by having it actually consider whether the type-unstable argument is unused in the LLVM IR -- something we removed after

Maybe we should approach this differently. Julia converted this Broadcast to a

All that said, I'm not sure it's worth the effort, because the generated broadcast kernel is broken anyway: because of the (inferred) type instability, the broadcast returns
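A minimal CPU-only sketch of the same capture, assuming only Base Julia (no GPU needed). The closure that reads Bar from a local variable stores that type as a field typed Type{Bar}, which is not isbits; a closure that refers to the global constant Bar directly captures nothing and is isbits. The function names make_capturing_closure and make_plain_closure are illustrative, not from the issue.

```julia
struct Bar{T}
    a::T
end

# Mirrors the reduced MWE: the closure captures the local `capture`, which holds a type.
function make_capturing_closure()
    capture = Bar
    closure(arg) = capture(arg)
    return closure
end

# Possible workaround sketch: refer to the global `Bar` directly, so nothing is captured.
function make_plain_closure()
    closure(arg) = Bar(arg)
    return closure
end

bad = make_capturing_closure()
good = make_plain_closure()

println(fieldtypes(typeof(bad)))   # the single captured field and its type
println(isbitstype(typeof(bad)))   # false: a Type-valued field is not isbits
println(isbitstype(typeof(good)))  # true: a closure with no captured fields
println(good(1.0f0))               # still constructs a Bar{Float32}
```

Whether this workaround applies to the closures Zygote builds internally is not established in this thread; the sketch only shows where the non-isbits field comes from.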
FWIW, that looks like:

diff --git a/src/driver.jl b/src/driver.jl
index 9e05eb6..a4cff8f 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -88,8 +88,7 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
end
@timeit_debug to "Validation" begin
- check_method(job) # not optional
- validate && check_invocation(job)
+ check_method(job)
end
prepare_job!(job)
@@ -99,6 +98,10 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
ir, ir_meta = emit_llvm(job; libraries, toplevel, optimize, cleanup, only_entry, validate)
+ validate && @timeit_debug to "Validation" begin
+ check_invocation(job, ir_meta.entry)
+ end
+
if output == :llvm
if strip
@timeit_debug to "strip debug info" strip_debuginfo!(ir)
diff --git a/src/validation.jl b/src/validation.jl
index e1a355b..9f1f869 100644
--- a/src/validation.jl
+++ b/src/validation.jl
@@ -66,7 +66,7 @@ function explain_nonisbits(@nospecialize(dt), depth=1; maxdepth=10)
return msg
end
-function check_invocation(@nospecialize(job::CompilerJob))
+function check_invocation(@nospecialize(job::CompilerJob), entry::LLVM.Function)
sig = job.source.specTypes
ft = sig.parameters[1]
tt = Tuple{sig.parameters[2:end]...}
@@ -77,6 +77,9 @@ function check_invocation(@nospecialize(job::CompilerJob))
real_arg_i = 0
for (arg_i,dt) in enumerate(sig.parameters)
+ println(Core.stdout, arg_i)
+ println(Core.stdout, dt)
+
isghosttype(dt) && continue
Core.Compiler.isconstType(dt) && continue
real_arg_i += 1
@@ -89,9 +92,13 @@ function check_invocation(@nospecialize(job::CompilerJob))
end
if !isbitstype(dt)
- throw(KernelError(job, "passing and using non-bitstype argument",
- """Argument $arg_i to your kernel function is of type $dt, which is not isbits:
- $(explain_nonisbits(dt))"""))
+ param = parameters(entry)[real_arg_i]
+ if !isempty(uses(param))
+ println(Core.stdout, string(entry))
+ throw(KernelError(job, "passing and using non-bitstype argument",
+ """Argument $arg_i to your kernel function is of type $dt, which is not isbits:
+ $(explain_nonisbits(dt))"""))
+ end
end
end
Is there maybe a way to avoid the type instability? Or should closures of this kind be warned against, and should e.g. Zygote make changes accordingly?
Zygote capturing all kinds of things in closures is definitely not great, but it's too late to fix that. The type instability was a red herring though, even typing
Okay, yeah. So I think on the Zygote side there's a way to work around the specific issue I encountered. Aside from trying to fix anything in CUDA, is there an opportunity to provide a more helpful error message in this case? I would be willing to work on that.
Isn't it already relatively helpful, pointing to the exact argument and field? In terms of actually fixing this, it would be possible to completely disable validation, as this code turns out to result in relatively compatible LLVM IR. But then we open the door to accidentally using CPU pointers (from boxed objects) on the GPU, which is what this check was designed to combat...
It's helpful if you already understand the internals, perhaps. As a naive user: "Okay great, Type{Bar} is not isbits. What do I do with this information? How do I use the error to correct the code?"
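One way to turn the error into an actionable hint is to list which fields of a kernel argument's type are not isbits before launching. The sketch below uses a hypothetical nonisbits_fields helper (not part of CUDA.jl or GPUCompiler.jl) built only on Base's isbitstype, fieldnames and fieldtypes; it inspects direct fields only and does not recurse.

```julia
struct Bar{T}
    a::T
end

"""
    nonisbits_fields(T)

Hypothetical helper: return the direct fields of the concrete type `T`
whose types are not isbits, as `(name, type)` pairs.
"""
function nonisbits_fields(T::DataType)
    return [(name, FT) for (name, FT) in zip(fieldnames(T), fieldtypes(T))
            if !isbitstype(FT)]
end

# Same capture pattern as the reduced MWE.
function make_closure()
    capture = Bar
    closure(arg) = capture(arg)
    return closure
end

f = make_closure()
for (name, FT) in nonisbits_fields(typeof(f))
    # Point the user at the captured variable that must not be captured.
    println("closure field `", name, "` has type ", FT, ", which is not isbits")
end
```

Here the report names the captured variable capture, which suggests the fix: stop capturing the type, for example by calling Bar directly inside the closure.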
Describe the bug
When broadcasting, over CUDA arrays, a function that returns a function, the broadcast fails to compile with the following error:
To reproduce
The Minimal Working Example (MWE) for this bug:
Manifest.toml
Expected behavior
Version info
Details on Julia:
Details on CUDA:
Additional context
Encountered when Zygote differentiates over a broadcast of the form Bar.(a, b), via broadcast_forward, where a and b are GPU arrays.
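The title's pattern, a function that returns an anonymous function, can be reproduced on the CPU with a hypothetical wrap helper (an illustration, not Zygote's actual code). It suggests the problem is specific to capturing a type: a captured function such as sin is a zero-size singleton and keeps the closure isbits, while a captured constructor type does not. Wrapping the constructor in a capture-free anonymous function is one possible workaround.

```julia
struct Bar{T}
    a::T
end

# Hypothetical wrapper: returns an anonymous function that captures `f`.
wrap(f) = x -> f(x)

g_fn   = wrap(sin)          # captures a function (a singleton)
g_type = wrap(Bar)          # captures a type, stored as Type{Bar}
g_fix  = wrap(x -> Bar(x))  # captures a capture-free anonymous function

println(isbitstype(typeof(g_fn)))    # true
println(isbitstype(typeof(g_type)))  # false
println(isbitstype(typeof(g_fix)))   # true
println(g_fix.(Float32[1, 2]))       # broadcasting still constructs Bar values
```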