From 05b9f23e2fd187c3ff51b69e556b75bd09c18a54 Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Fri, 14 Jul 2023 15:45:44 -0700 Subject: [PATCH] Adapt Type call broadcasting to a function This is a more generic solution to the existing `broadcasted` definition that fixes #1761 (as suggested in https://github.com/JuliaGPU/CUDA.jl/issues/1761#issuecomment-1425496244). --- src/broadcast.jl | 4 ---- src/compiler/execution.jl | 4 ++++ test/base/broadcast.jl | 14 ++++++++++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index dd5f83ddc1..702fcf0ad2 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -13,7 +13,3 @@ Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}) where {N,T} = Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}, dims) where {N,T} = CuArray{T}(undef, dims) - -# broadcasting type ctors isn't GPU compatible -Broadcast.broadcasted(::CuArrayStyle{N}, f::Type{T}, args...) where {N, T} = - Broadcasted{CuArrayStyle{N}}((x...) -> T(x...), args, nothing) diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index ed44662bce..5859de2d88 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -140,6 +140,10 @@ struct CuRefType{T} <: Ref{DataType} end Base.getindex(r::CuRefType{T}) where T = T Adapt.adapt_structure(to::Adaptor, r::Base.RefValue{<:Union{DataType,Type}}) = CuRefType{r[]}() +# case where type is the function being broadcasted +Adapt.adapt_structure(to::Adaptor, bc::Base.Broadcast.Broadcasted{Style, <:Any, Type{T}}) where {Style, T} = + Base.Broadcast.Broadcasted{Style}((x...) -> T(x...), adapt(to, bc.args), bc.axes) + Adapt.adapt_storage(::Adaptor, xs::CuArray{T,N}) where {T,N} = Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs) diff --git a/test/base/broadcast.jl b/test/base/broadcast.jl index 285a474070..96540d1b73 100644 --- a/test/base/broadcast.jl +++ b/test/base/broadcast.jl @@ -39,3 +39,17 @@ end A = CuArray{ComplexF64}(undef, (2,2)) @test eltype(convert.(ComplexF32, A)) == ComplexF32 end + +# https://github.com/JuliaGPU/CUDA.jl/issues/261 +@testset "Broadcast Ref{<:Type}" begin + A = CuArray{ComplexF64}(undef, (2,2)) + @test eltype(convert.(ComplexF32, A)) == ComplexF32 +end + +# https://github.com/JuliaGPU/CUDA.jl/issues/1761 +@testset "Broadcast Type(args)" begin + A = CuArray{ComplexF64}(undef, (2,2)) + @test eltype(ComplexF32.(A)) == ComplexF32 + @test eltype(A .+ ComplexF32.(1)) == ComplexF64 + @test eltype(ComplexF32.(A) .+ ComplexF32.(1)) == ComplexF32 +end