Skip to content

Commit

Permalink
Adapt Type call broadcasting to a function
Browse files Browse the repository at this point in the history
This is a more generic solution to the existing `broadcasted` definition that fixes JuliaGPU#1761 (as suggested in JuliaGPU#1761 (comment)).
  • Loading branch information
simonbyrne committed Jul 14, 2023
1 parent 8af74bf commit e573f74
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
4 changes: 0 additions & 4 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,3 @@ Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}) where {N,T} =

Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}, dims) where {N,T} =
CuArray{T}(undef, dims)

# broadcasting type ctors isn't GPU compatible
Broadcast.broadcasted(::CuArrayStyle{N}, f::Type{T}, args...) where {N, T} =
Broadcasted{CuArrayStyle{N}}((x...) -> T(x...), args, nothing)
4 changes: 4 additions & 0 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ struct CuRefType{T} <: Ref{DataType} end
Base.getindex(r::CuRefType{T}) where T = T
Adapt.adapt_structure(to::Adaptor, r::Base.RefValue{<:Union{DataType,Type}}) = CuRefType{r[]}()

Adapt.adapt_structure(to::Adaptor, bc::Base.Broadcast.Broadcasted{Style, <:Any, Type{T}}) where {Style, T} =
Base.Broadcast.Broadcasted{Style}((x...) -> T(x...), adapt(to, bc.args), bc.axes)


Adapt.adapt_storage(::Adaptor, xs::CuArray{T,N}) where {T,N} =
Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs)

Expand Down
14 changes: 14 additions & 0 deletions test/base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,17 @@ end
A = CuArray{ComplexF64}(undef, (2,2))
@test eltype(convert.(ComplexF32, A)) == ComplexF32
end

# https://github.com/JuliaGPU/CUDA.jl/issues/261
@testset "Broadcast Ref{<:Type}" begin
A = CuArray{ComplexF64}(undef, (2,2))
@test eltype(convert.(ComplexF32, A)) == ComplexF32
end

# https://github.com/JuliaGPU/CUDA.jl/issues/1761
@testset "Broadcast Type(args)" begin
A = CuArray{ComplexF64}(undef, (2,2))
@test eltype(ComplexF32.(A)) == ComplexF32
@test eltype(A .+ ComplexF32.(1)) == ComplexF64
@test eltype(ComplexF32.(A) .+ ComplexF32.(1)) == ComplexF32
end

0 comments on commit e573f74

Please sign in to comment.