@@ -8,13 +8,9 @@ export CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixBSR, CuSparseMatrixCO
88 CuSparseVecOrMat
99
1010using LinearAlgebra: BlasFloat
11- using SparseArrays: nonzeroinds, dimlub
12-
13- abstract type AbstractCuSparseArray{Tv, Ti, N} <: AbstractSparseArray{Tv, Ti, N} end
14- const AbstractCuSparseVector{Tv, Ti} = AbstractCuSparseArray{Tv, Ti, 1 }
15- const AbstractCuSparseMatrix{Tv, Ti} = AbstractCuSparseArray{Tv, Ti, 2 }
16-
17- Base. convert (T:: Type{<:AbstractCuSparseArray} , m:: AbstractArray ) = m isa T ? m : T (m)
11+ using SparseArrays
# Abstract supertype for one-dimensional CUSPARSE containers (N = 1 in
# `GPUArrays.AbstractGPUSparseArray`); `CuSparseVector` derives from it.
12+ abstract type AbstractCuSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseArray{Tv, Ti, 1} end
# Abstract supertype for two-dimensional CUSPARSE containers (N = 2).
# NOTE(review): in the visible code `CuSparseMatrixCSC` and `CuSparseMatrixCSR`
# are declared directly under `GPUArrays.AbstractGPUSparseMatrixCSC/CSR`, so they
# are not subtypes of this type — confirm that methods dispatching on
# `AbstractCuSparseMatrix` (e.g. `findnz`) are still meant to cover them.
13+ abstract type AbstractCuSparseMatrix{Tv, Ti} <: GPUArrays.AbstractGPUSparseArray{Tv, Ti, 2} end
1814
1915mutable struct CuSparseVector{Tv, Ti} <: AbstractCuSparseVector{Tv, Ti}
2016 iPtr:: CuVector{Ti}
@@ -34,7 +30,7 @@ function CUDA.unsafe_free!(xs::CuSparseVector)
3430 return
3531end
3632
37- mutable struct CuSparseMatrixCSC{Tv, Ti} <: AbstractCuSparseMatrix {Tv, Ti}
33+ mutable struct CuSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC {Tv, Ti}
3834 colPtr:: CuVector{Ti}
3935 rowVal:: CuVector{Ti}
4036 nzVal:: CuVector{Tv}
@@ -47,6 +43,11 @@ mutable struct CuSparseMatrixCSC{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti}
4743 end
4844end
4945
# `rowvals` of a sparse vector forwards to its stored indices. `nonzeroinds` is
# spelled with its module prefix because it is not exported by `SparseArrays`,
# and this file only does a plain `using SparseArrays`.
SparseArrays.rowvals(g::CuSparseVector) = SparseArrays.nonzeroinds(g)

# For CSC matrices the accessors return the `rowVal` and `colPtr` fields directly
# (no copy is made).
SparseArrays.rowvals(g::CuSparseMatrixCSC) = g.rowVal
SparseArrays.getcolptr(S::CuSparseMatrixCSC) = S.colPtr
50+
# Identity conversion: an existing CSC matrix is handed back unchanged (same object).
function CuSparseMatrixCSC(A::CuSparseMatrixCSC)
    return A
end
5152
5253function CUDA. unsafe_free! (xs:: CuSparseMatrixCSC )
6970!!! compat "CUDA 11"
7071 Support of indices type rather than `Cint` (`Int32`) requires at least CUDA 11.
7172"""
72- mutable struct CuSparseMatrixCSR{Tv, Ti} <: AbstractCuSparseMatrix {Tv, Ti}
73+ mutable struct CuSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR {Tv, Ti}
7374 rowPtr:: CuVector{Ti}
7475 colVal:: CuVector{Ti}
7576 nzVal:: CuVector{Tv}
@@ -91,6 +92,22 @@ function CUDA.unsafe_free!(xs::CuSparseMatrixCSR)
9192 return
9293end
9394
# GPUArrays trait methods: for each supported container, map both an instance
# and the type itself to the unparameterized sparse wrapper and to `CuArray` as
# the matching dense array type.
for SparseT in (:CuSparseMatrixCSC, :CuSparseMatrixCSR, :CuSparseVector)
    @eval begin
        GPUArrays._sparse_array_type(::$SparseT) = $SparseT
        GPUArrays._sparse_array_type(::Type{<:$SparseT}) = $SparseT
        GPUArrays._dense_array_type(::$SparseT) = CuArray
        GPUArrays._dense_array_type(::Type{<:$SparseT}) = CuArray
    end
end

# Cross-format lookups: the CSC type paired with a CSR matrix, and vice versa.
GPUArrays._csc_type(::CuSparseMatrixCSR) = CuSparseMatrixCSC
GPUArrays._csr_type(::CuSparseMatrixCSC) = CuSparseMatrixCSR

110+
94111"""
95112Container to hold sparse matrices in block compressed sparse row (BSR) format on
96113the GPU. BSR format is also used in Intel MKL, and is suited to matrices that are
142159
# Identity conversion: an existing COO matrix is handed back unchanged (same object).
function CuSparseMatrixCOO(A::CuSparseMatrixCOO)
    return A
end
144161
145- mutable struct CuSparseArrayCSR{Tv, Ti, N} <: AbstractCuSparseArray {Tv, Ti, N}
162+ mutable struct CuSparseArrayCSR{Tv, Ti, N} <: GPUArrays.AbstractGPUSparseArray {Tv, Ti, N}
146163 rowPtr:: CuArray{Ti}
147164 colVal:: CuArray{Ti}
148165 nzVal:: CuArray{Tv}
308325
309326# # sparse array interface
310327
311- SparseArrays. nnz (g:: AbstractCuSparseArray ) = g. nnz
312- SparseArrays. nonzeros (g:: AbstractCuSparseArray ) = g. nzVal
313-
314- SparseArrays. nonzeroinds (g:: AbstractCuSparseVector ) = g. iPtr
315- SparseArrays. rowvals (g:: AbstractCuSparseVector ) = nonzeroinds (g)
316-
317- SparseArrays. rowvals (g:: CuSparseMatrixCSC ) = g. rowVal
318- SparseArrays. getcolptr (S:: CuSparseMatrixCSC ) = S. colPtr
319-
320328function SparseArrays. findnz (S:: MT ) where {MT <: AbstractCuSparseMatrix }
321329 S2 = CuSparseMatrixCOO (S)
322330 I = S2. rowInd
@@ -570,8 +578,8 @@ CuSparseMatrixCSC(x::Adjoint{T,<:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuS
# Build a COO matrix from a lazy `Adjoint` wrapper around a CSC/CSR/COO matrix:
# `_spadjoint` is applied to the wrapped parent and the result is converted to COO.
function CuSparseMatrixCOO(
        x::Adjoint{T, <:Union{CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO}}
    ) where {T}
    unwrapped = parent(x)
    return CuSparseMatrixCOO(_spadjoint(unwrapped))
end
571579
572580# gpu to cpu
# Copy a GPU sparse vector to the host: the stored indices and values are each
# materialized as `Array`s and wrapped in a `SparseVector` of the same length.
function SparseArrays.SparseVector(x::CuSparseVector)
    len = length(x)
    host_inds = Array(SparseArrays.nonzeroinds(x))
    host_vals = Array(SparseArrays.nonzeros(x))
    return SparseVector(len, host_inds, host_vals)
end

# Copy a GPU CSC matrix to the host by materializing its three buffers.
function SparseArrays.SparseMatrixCSC(x::CuSparseMatrixCSC)
    dims = size(x)
    host_colptr = Array(x.colPtr)
    host_rowval = Array(SparseArrays.rowvals(x))
    host_nzval = Array(SparseArrays.nonzeros(x))
    return SparseMatrixCSC(dims..., host_colptr, host_rowval, host_nzval)
end

# The remaining formats have no direct host conversion and go through GPU-side
# format conversions first:
#   CSR: gpu_CSR -> gpu_CSC -> cpu_CSC
#   BSR: gpu_BSR -> gpu_CSR -> gpu_CSC -> cpu_CSC
#   COO: gpu_COO -> gpu_CSC -> cpu_CSC
function SparseArrays.SparseMatrixCSC(x::CuSparseMatrixCSR)
    return SparseMatrixCSC(CuSparseMatrixCSC(x))
end

function SparseArrays.SparseMatrixCSC(x::CuSparseMatrixBSR)
    return SparseMatrixCSC(CuSparseMatrixCSR(x))
end

function SparseArrays.SparseMatrixCSC(x::CuSparseMatrixCOO)
    return SparseMatrixCSC(CuSparseMatrixCSC(x))
end
@@ -729,25 +737,48 @@ end
729737
730738# interop with device arrays
731739
# Outer constructor that derives all type parameters of `GPUSparseDeviceVector`
# from `CuDeviceVector` arguments. The index and value vectors must share the
# same `A` parameter, and `nnz` must have the index element type `Ti`.
function GPUArrays.GPUSparseDeviceVector(
        iPtr::CuDeviceVector{Ti, A},
        nzVal::CuDeviceVector{Tv, A},
        len::Int,
        nnz::Ti,
    ) where {Ti, Tv, A}
    IndexVec = CuDeviceVector{Ti, A}
    ValueVec = CuDeviceVector{Tv, A}
    return GPUArrays.GPUSparseDeviceVector{Tv, Ti, IndexVec, ValueVec, A}(iPtr, nzVal, len, nnz)
end
746+
# Kernel-argument conversion for `CuSparseVector`: both buffers are adapted with
# the kernel adaptor and bundled, together with the length and stored-entry
# count, into a `GPUArrays.GPUSparseDeviceVector`.
function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseVector)
    dev_inds = adapt(to, x.iPtr)
    dev_vals = adapt(to, x.nzVal)
    return GPUArrays.GPUSparseDeviceVector(dev_inds, dev_vals, length(x), x.nnz)
end
739754
# Outer constructor that derives all type parameters of `GPUSparseDeviceMatrixCSR`
# from `CuDeviceVector` arguments. The pointer, index and value vectors must
# share the same `A` parameter, and `nnz` must have the index element type `Ti`.
function GPUArrays.GPUSparseDeviceMatrixCSR(
        rowPtr::CuDeviceVector{Ti, A},
        colVal::CuDeviceVector{Ti, A},
        nzVal::CuDeviceVector{Tv, A},
        dims::NTuple{2, Int},
        nnz::Ti,
    ) where {Ti, Tv, A}
    IndexVec = CuDeviceVector{Ti, A}
    ValueVec = CuDeviceVector{Tv, A}
    return GPUArrays.GPUSparseDeviceMatrixCSR{Tv, Ti, IndexVec, ValueVec, A}(
        rowPtr, colVal, nzVal, dims, nnz
    )
end
762+
# Kernel-argument conversion for `CuSparseMatrixCSR`: the three buffers are
# adapted with the kernel adaptor and bundled, together with the matrix size and
# stored-entry count, into a `GPUArrays.GPUSparseDeviceMatrixCSR`.
function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCSR)
    dev_rowptr = adapt(to, x.rowPtr)
    dev_colval = adapt(to, x.colVal)
    dev_nzval = adapt(to, x.nzVal)
    return GPUArrays.GPUSparseDeviceMatrixCSR(dev_rowptr, dev_colval, dev_nzval, size(x), x.nnz)
end
748771
# Outer constructor that derives all type parameters of `GPUSparseDeviceMatrixCSC`
# from `CuDeviceVector` arguments. The pointer, index and value vectors must
# share the same `A` parameter, and `nnz` must have the index element type `Ti`.
function GPUArrays.GPUSparseDeviceMatrixCSC(
        colPtr::CuDeviceVector{Ti, A},
        rowVal::CuDeviceVector{Ti, A},
        nzVal::CuDeviceVector{Tv, A},
        dims::NTuple{2, Int},
        nnz::Ti,
    ) where {Ti, Tv, A}
    IndexVec = CuDeviceVector{Ti, A}
    ValueVec = CuDeviceVector{Tv, A}
    return GPUArrays.GPUSparseDeviceMatrixCSC{Tv, Ti, IndexVec, ValueVec, A}(
        colPtr, rowVal, nzVal, dims, nnz
    )
end
779+
749780function Adapt. adapt_structure (to:: CUDA.KernelAdaptor , x:: CuSparseMatrixCSC )
750- return CuSparseDeviceMatrixCSC (
781+ return GPUArrays . GPUSparseDeviceMatrixCSC (
751782 adapt (to, x. colPtr),
752783 adapt (to, x. rowVal),
753784 adapt (to, x. nzVal),
@@ -756,7 +787,7 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCSC)
756787end
757788
758789function Adapt. adapt_structure (to:: CUDA.KernelAdaptor , x:: CuSparseMatrixBSR )
759- return CuSparseDeviceMatrixBSR (
790+ return GPUArrays . GPUSparseDeviceMatrixBSR (
760791 adapt (to, x. rowPtr),
761792 adapt (to, x. colVal),
762793 adapt (to, x. nzVal),
@@ -766,7 +797,7 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixBSR)
766797end
767798
768799function Adapt. adapt_structure (to:: CUDA.KernelAdaptor , x:: CuSparseMatrixCOO )
769- return CuSparseDeviceMatrixCOO (
800+ return GPUArrays . GPUSparseDeviceMatrixCOO (
770801 adapt (to, x. rowInd),
771802 adapt (to, x. colInd),
772803 adapt (to, x. nzVal),
@@ -775,7 +806,7 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCOO)
775806end
776807
777808function Adapt. adapt_structure (to:: CUDA.KernelAdaptor , x:: CuSparseArrayCSR )
778- return CuSparseDeviceArrayCSR (
809+ return GPUArrays . GPUSparseDeviceArrayCSR (
779810 adapt (to, x. rowPtr),
780811 adapt (to, x. colVal),
781812 adapt (to, x. nzVal),
0 commit comments