|
6 | 6 |
|
7 | 7 | module JLArrays |
8 | 8 |
|
9 | | -export JLArray, JLVector, JLMatrix, jl, JLBackend |
| 9 | +export JLArray, JLVector, JLMatrix, jl, JLBackend, JLSparseVector, JLSparseMatrixCSC, JLSparseMatrixCSR |
10 | 10 |
|
11 | 11 | using GPUArrays |
12 | 12 |
|
13 | 13 | using Adapt |
| 14 | +using SparseArrays, LinearAlgebra |
| 15 | + |
| 16 | +import GPUArrays: _dense_array_type |
14 | 17 |
|
15 | 18 | import KernelAbstractions |
16 | 19 | import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config |
@@ -115,7 +118,90 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} |
115 | 118 | end |
116 | 119 | end |
117 | 120 |
|
| 121 | +mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, Ti} |
| 122 | + iPtr::JLArray{Ti, 1} |
| 123 | + nzVal::JLArray{Tv, 1} |
| 124 | + len::Int |
| 125 | + nnz::Ti |
| 126 | + |
| 127 | + function JLSparseVector{Tv, Ti}(iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1}, |
| 128 | + len::Integer) where {Tv, Ti <: Integer} |
| 129 | + new{Tv, Ti}(iPtr, nzVal, len, length(nzVal)) |
| 130 | + end |
| 131 | +end |
| 132 | +SparseArrays.SparseVector(x::JLSparseVector) = SparseVector(length(x), Array(x.iPtr), Array(x.nzVal)) |
| 133 | +SparseArrays.nnz(x::JLSparseVector) = x.nnz |
| 134 | +SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr |
| 135 | +SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal |
| 136 | + |
| 137 | +mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC{Tv, Ti} |
| 138 | + colPtr::JLArray{Ti, 1} |
| 139 | + rowVal::JLArray{Ti, 1} |
| 140 | + nzVal::JLArray{Tv, 1} |
| 141 | + dims::NTuple{2,Int} |
| 142 | + nnz::Ti |
| 143 | + |
| 144 | + function JLSparseMatrixCSC{Tv, Ti}(colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1}, |
| 145 | + nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} |
| 146 | + new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal)) |
| 147 | + end |
| 148 | +end |
| 149 | +function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} |
| 150 | + return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims) |
| 151 | +end |
| 152 | +SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(x.rowVal), Array(x.nzVal)) |
| 153 | + |
| 154 | +JLSparseMatrixCSC(A::JLSparseMatrixCSC) = A |
| 155 | + |
| 156 | +function Base.getindex(A::JLSparseMatrixCSC{Tv, Ti}, i::Integer, j::Integer) where {Tv, Ti} |
| 157 | + r1 = Int(@inbounds A.colPtr[j]) |
| 158 | + r2 = Int(@inbounds A.colPtr[j+1]-1) |
| 159 | + (r1 > r2) && return zero(Tv) |
| 160 | + r1 = searchsortedfirst(view(A.rowVal, r1:r2), i) + r1 - 1 |
| 161 | + ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1] |
| 162 | +end |
| 163 | + |
| 164 | +mutable struct JLSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR{Tv, Ti} |
| 165 | + rowPtr::JLArray{Ti, 1} |
| 166 | + colVal::JLArray{Ti, 1} |
| 167 | + nzVal::JLArray{Tv, 1} |
| 168 | + dims::NTuple{2,Int} |
| 169 | + nnz::Ti |
| 170 | + |
| 171 | + function JLSparseMatrixCSR{Tv, Ti}(rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1}, |
| 172 | + nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti<:Integer} |
| 173 | + new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal)) |
| 174 | + end |
| 175 | +end |
| 176 | +function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} |
| 177 | + return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims) |
| 178 | +end |
| 179 | +function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR) |
| 180 | + x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal)) |
| 181 | + return SparseMatrixCSC(transpose(x_transpose)) |
| 182 | +end |
| 183 | + |
| 184 | +JLSparseMatrixCSR(A::JLSparseMatrixCSR) = A |
| 185 | + |
118 | 186 | GPUArrays.storage(a::JLArray) = a.data |
| 187 | +GPUArrays._dense_array_type(a::JLArray{T, N}) where {T, N} = JLArray{T, N} |
| 188 | +GPUArrays._dense_array_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, N} |
| 189 | +GPUArrays._dense_vector_type(a::JLArray{T, N}) where {T, N} = JLArray{T, 1} |
| 190 | +GPUArrays._dense_vector_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, 1} |
| 191 | + |
| 192 | +GPUArrays._sparse_array_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSC |
| 193 | +GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSC}) = JLSparseMatrixCSC |
| 194 | +GPUArrays._sparse_array_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSR |
| 195 | +GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR |
| 196 | +GPUArrays._sparse_array_type(sa::JLSparseVector) = JLSparseVector |
| 197 | +GPUArrays._sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector |
| 198 | + |
| 199 | +GPUArrays._dense_array_type(sa::JLSparseVector) = JLArray |
| 200 | +GPUArrays._dense_array_type(::Type{<:JLSparseVector}) = JLArray |
| 201 | +GPUArrays._dense_array_type(sa::JLSparseMatrixCSC) = JLArray |
| 202 | +GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray |
| 203 | +GPUArrays._dense_array_type(sa::JLSparseMatrixCSR) = JLArray |
| 204 | +GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray |
119 | 205 |
|
120 | 206 | # conversion of untyped data to a typed Array |
121 | 207 | function typed_data(x::JLArray{T}) where {T} |
@@ -217,6 +303,41 @@ JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs) |
217 | 303 | (::Type{JLArray{T,N} where T})(x::AbstractArray{S,N}) where {S,N} = JLArray{S,N}(x) |
218 | 304 | JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A) |
219 | 305 |
|
| 306 | +function JLSparseVector(xs::SparseVector{Tv, Ti}) where {Ti, Tv} |
| 307 | + iPtr = JLVector{Ti}(undef, length(xs.nzind)) |
| 308 | + nzVal = JLVector{Tv}(undef, length(xs.nzval)) |
| 309 | + copyto!(iPtr, convert(Vector{Ti}, xs.nzind)) |
| 310 | + copyto!(nzVal, convert(Vector{Tv}, xs.nzval)) |
| 311 | + return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs),) |
| 312 | +end |
| 313 | +Base.length(x::JLSparseVector) = x.len |
| 314 | +Base.size(x::JLSparseVector) = (x.len,) |
| 315 | + |
| 316 | +function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv} |
| 317 | + colPtr = JLVector{Ti}(undef, length(xs.colptr)) |
| 318 | + rowVal = JLVector{Ti}(undef, length(xs.rowval)) |
| 319 | + nzVal = JLVector{Tv}(undef, length(xs.nzval)) |
| 320 | + copyto!(colPtr, convert(Vector{Ti}, xs.colptr)) |
| 321 | + copyto!(rowVal, convert(Vector{Ti}, xs.rowval)) |
| 322 | + copyto!(nzVal, convert(Vector{Tv}, xs.nzval)) |
| 323 | + return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n)) |
| 324 | +end |
| 325 | +Base.length(x::JLSparseMatrixCSC) = prod(x.dims) |
| 326 | +Base.size(x::JLSparseMatrixCSC) = x.dims |
| 327 | + |
| 328 | +function JLSparseMatrixCSR(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv} |
| 329 | + csr_xs = SparseMatrixCSC(transpose(xs)) |
| 330 | + rowPtr = JLVector{Ti}(undef, length(csr_xs.colptr)) |
| 331 | + colVal = JLVector{Ti}(undef, length(csr_xs.rowval)) |
| 332 | + nzVal = JLVector{Tv}(undef, length(csr_xs.nzval)) |
| 333 | + copyto!(rowPtr, convert(Vector{Ti}, csr_xs.colptr)) |
| 334 | + copyto!(colVal, convert(Vector{Ti}, csr_xs.rowval)) |
| 335 | + copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval)) |
| 336 | + return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, (xs.m, xs.n)) |
| 337 | +end |
| 338 | +Base.length(x::JLSparseMatrixCSR) = prod(x.dims) |
| 339 | +Base.size(x::JLSparseMatrixCSR) = x.dims |
| 340 | + |
220 | 341 | # idempotency |
221 | 342 | JLArray{T,N}(xs::JLArray{T,N}) where {T,N} = xs |
222 | 343 |
|
@@ -358,9 +479,17 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br |
358 | 479 | R |
359 | 480 | end |
360 | 481 |
|
| 482 | +Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = |
| 483 | +GPUSparseDeviceMatrixCSC{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz) |
| 484 | +Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv,Ti}) where {Tv,Ti} = |
| 485 | +GPUSparseDeviceMatrixCSR{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz) |
| 486 | +Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv,Ti}) where {Tv,Ti} = |
| 487 | +GPUSparseDeviceVector{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz) |
| 488 | + |
361 | 489 | ## KernelAbstractions interface |
362 | 490 |
|
363 | 491 | KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend() |
| 492 | +KernelAbstractions.get_backend(a::JLA) where JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector} = JLBackend() |
364 | 493 |
|
365 | 494 | function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic |
366 | 495 | return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace) |
|
0 commit comments