Skip to content

Commit

Permalink
Merge pull request #263 from jsandham/csritilu0_cherry_pick
Browse files Browse the repository at this point in the history
Cherry pick Csritilu0 (#407)
  • Loading branch information
YvanMokwinski authored Oct 16, 2022
2 parents e10b786 + d75c149 commit 1f576a9
Show file tree
Hide file tree
Showing 66 changed files with 12,316 additions and 416 deletions.
1 change: 1 addition & 0 deletions clients/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ set(ROCSPARSE_CLIENTS_TESTINGS
../testings/testing_bsrilu0.cpp
../testings/testing_csric0.cpp
../testings/testing_csrilu0.cpp
../testings/testing_csritilu0.cpp
../testings/testing_gpsv_interleaved_batch.cpp
../testings/testing_gtsv.cpp
../testings/testing_gtsv_no_pivot.cpp
Expand Down
20 changes: 18 additions & 2 deletions clients/benchmarks/rocsparse_arguments_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
* ************************************************************************ */

#include "rocsparse_arguments_config.hpp"

#include "rocsparse_enum.hpp"
rocsparse_arguments_config::rocsparse_arguments_config()
{
//
Expand Down Expand Up @@ -73,6 +73,7 @@ rocsparse_arguments_config::rocsparse_arguments_config()
this->order = static_cast<rocsparse_order>(0);
this->format = static_cast<rocsparse_format>(0);

this->itilu0_alg = rocsparse_itilu0_alg_default;
this->sddmm_alg = rocsparse_sddmm_alg_default;
this->spmv_alg = rocsparse_spmv_alg_default;
this->spsv_alg = rocsparse_spsv_alg_default;
Expand Down Expand Up @@ -275,7 +276,7 @@ void rocsparse_arguments_config::set_description(options_description& desc)
" Level2: bsrmv, bsrxmv, bsrsv, coomv, coomv_aos, csrmv, csrmv_managed, csrsv, coosv, ellmv, hybmv, gebsrmv, gemvi\n"
" Level3: bsrmm, bsrsm, gebsrmm, csrmm, csrmm_batched, coomm, coomm_batched, cscmm, cscmm_batched, csrsm, coosm, gemmi, sddmm\n"
" Extra: csrgeam, csrgemm, csrgemm_reuse\n"
" Preconditioner: bsric0, bsrilu0, csric0, csrilu0, gtsv, gtsv_no_pivot, gtsv_no_pivot_strided_batch, gtsv_interleaved_batch, gpsv_interleaved_batch\n"
" Preconditioner: bsric0, bsrilu0, csric0, csrilu0, csritilu0, gtsv, gtsv_no_pivot, gtsv_no_pivot_strided_batch, gtsv_interleaved_batch, gpsv_interleaved_batch\n"
" Conversion: csr2coo, csr2csc, gebsr2gebsc, csr2ell, csr2hyb, csr2bsr, csr2gebsr\n"
" coo2csr, ell2csr, hyb2csr, dense2csr, dense2coo, prune_dense2csr, prune_dense2csr_by_percentage, dense2csc\n"
" csr2dense, csc2dense, coo2dense, bsr2csr, gebsr2csr, gebsr2gebsr, csr2csr_compress, prune_csr2csr, prune_csr2csr_by_percentage\n"
Expand Down Expand Up @@ -349,6 +350,10 @@ void rocsparse_arguments_config::set_description(options_description& desc)
value<rocsparse_int>(&this->b_spmv_alg)->default_value(rocsparse_spmv_alg_default),
"Indicates what algorithm to use when running SpMV. Possibly choices are default: 0, COO: 1, CSR adaptive: 2, CSR stream: 3, ELL: 4, COO atomic: 5 (default:0)")

("itilu0_alg",
value<rocsparse_int>(&this->b_itilu0_alg)->default_value(rocsparse_itilu0_alg_default),
"Indicates what algorithm to use when running Iterative ILU0. see documentation.")

("spmm_alg",
value<rocsparse_int>(&this->b_spmm_alg)->default_value(rocsparse_spmm_alg_default),
"Indicates what algorithm to use when running SpMM. Possibly choices are default: 0, CSR: 1, COO segmented: 2, COO atomic: 3, CSR row split: 4, CSR merge: 5, COO segmented atomic: 6, BELL: 7 (default:0)")
Expand Down Expand Up @@ -390,6 +395,16 @@ int rocsparse_arguments_config::parse(int&argc,char**&argv, options_description&
return -1;
}

if (rocsparse_itilu0_alg_t::is_invalid(this->b_itilu0_alg))
{
std::cerr << "Invalid value '"
<< this->b_itilu0_alg
<< "' for --itilu0_alg, valid values are : (";
rocsparse_itilu0_alg_t::info(std::cerr);
std::cerr << ")" << std::endl;
return -1;
}

if(this->b_spmv_alg != rocsparse_spmv_alg_default
&& this->b_spmv_alg != rocsparse_spmv_alg_coo
&& this->b_spmv_alg != rocsparse_spmv_alg_csr_adaptive
Expand Down Expand Up @@ -472,6 +487,7 @@ int rocsparse_arguments_config::parse(int&argc,char**&argv, options_description&
this->order = (this->b_order == rocsparse_order_row) ? rocsparse_order_row : rocsparse_order_column;
this->format = (rocsparse_format)this->b_format;
this->spmv_alg = (rocsparse_spmv_alg)this->b_spmv_alg;
this->itilu0_alg = (rocsparse_itilu0_alg)this->b_itilu0_alg;
this->spmm_alg = (rocsparse_spmm_alg)this->b_spmm_alg;
this->gtsv_interleaved_alg = (rocsparse_gtsv_interleaved_alg)this->b_gtsv_interleaved_alg;

Expand Down
1 change: 1 addition & 0 deletions clients/benchmarks/rocsparse_arguments_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ struct rocsparse_arguments_config : Arguments
rocsparse_int b_dir{};
rocsparse_int b_order{};
rocsparse_int b_format{};
rocsparse_int b_itilu0_alg{};
rocsparse_int b_spmv_alg{};
rocsparse_int b_spmm_alg{};
rocsparse_int b_gtsv_interleaved_alg{};
Expand Down
1 change: 1 addition & 0 deletions clients/benchmarks/rocsparse_bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void rocsparse_bench::parse(int& argc, char**& argv, rocsparse_arguments_config&
config.betai = 0.0;
config.threshold = 0.0;
config.percentage = 0.0;
config.itilu0_alg = rocsparse_itilu0_alg_default;
config.sddmm_alg = rocsparse_sddmm_alg_default;
config.spmv_alg = rocsparse_spmv_alg_default;
config.spsv_alg = rocsparse_spsv_alg_default;
Expand Down
2 changes: 2 additions & 0 deletions clients/benchmarks/rocsparse_routine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ constexpr const char* rocsparse_routine::to_string() const
#include "testing_bsrilu0.hpp"
#include "testing_csric0.hpp"
#include "testing_csrilu0.hpp"
#include "testing_csritilu0.hpp"
#include "testing_gpsv_interleaved_batch.hpp"
#include "testing_gtsv.hpp"
#include "testing_gtsv_interleaved_batch.hpp"
Expand Down Expand Up @@ -454,6 +455,7 @@ rocsparse_status rocsparse_routine::dispatch_call(const Arguments& arg)
DEFINE_CASE_T(csrcolor);
DEFINE_CASE_T(csric0);
DEFINE_CASE_T(csrilu0);
DEFINE_CASE_T(csritilu0);
DEFINE_CASE_T(csrgeam);
DEFINE_CASE_IJT_X(csrgemm, testing_spgemm_csr);
DEFINE_CASE_T(csrgemm_reuse);
Expand Down
1 change: 1 addition & 0 deletions clients/benchmarks/rocsparse_routine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ ROCSPARSE_DO_ROUTINE(csc2dense) \
ROCSPARSE_DO_ROUTINE(csrcolor) \
ROCSPARSE_DO_ROUTINE(csric0) \
ROCSPARSE_DO_ROUTINE(csrilu0) \
ROCSPARSE_DO_ROUTINE(csritilu0) \
ROCSPARSE_DO_ROUTINE(csrgeam) \
ROCSPARSE_DO_ROUTINE(csrgemm) \
ROCSPARSE_DO_ROUTINE(csrgemm_reuse) \
Expand Down
23 changes: 23 additions & 0 deletions clients/include/auto_testing_bad_arg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ inline void auto_testing_bad_arg_set_invalid(int32_t& p)
p = -1;
}

template <>
inline void auto_testing_bad_arg_set_invalid(size_t& p)
{
}

template <>
inline void auto_testing_bad_arg_set_invalid(int64_t& p)
{
Expand Down Expand Up @@ -132,6 +137,12 @@ inline rocsparse_status auto_testing_bad_arg_get_status(int32_t& p)
return rocsparse_status_invalid_size;
}

template <>
inline rocsparse_status auto_testing_bad_arg_get_status(size_t& p)
{
return rocsparse_status_invalid_size;
}

template <>
inline rocsparse_status auto_testing_bad_arg_get_status(int64_t& p)
{
Expand Down Expand Up @@ -234,6 +245,12 @@ inline rocsparse_status auto_testing_bad_arg_get_status(rocsparse_spgemm_alg& p)
return rocsparse_status_invalid_value;
}

template <>
inline rocsparse_status auto_testing_bad_arg_get_status(rocsparse_itilu0_alg& p)
{
return rocsparse_status_invalid_value;
}

template <>
inline rocsparse_status auto_testing_bad_arg_get_status(rocsparse_indextype& p)
{
Expand Down Expand Up @@ -396,6 +413,12 @@ inline void auto_testing_bad_arg_set_invalid(rocsparse_spgemm_alg& p)
p = (rocsparse_spgemm_alg)-1;
}

template <>
inline void auto_testing_bad_arg_set_invalid(rocsparse_itilu0_alg& p)
{
p = (rocsparse_itilu0_alg)-1;
}

template <>
inline void auto_testing_bad_arg_set_invalid(rocsparse_indextype& p)
{
Expand Down
26 changes: 26 additions & 0 deletions clients/include/rocsparse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,32 @@ REAL_COMPLEX_TEMPLATE(csric0,
rocsparse_solve_policy policy,
void* temp_buffer);

// csritilu0_compute
REAL_COMPLEX_TEMPLATE(csritilu0_compute,
rocsparse_handle handle,
rocsparse_itilu0_alg alg,
rocsparse_int options,
rocsparse_int* nsweeps,
floating_data_t<T> tol_correction,
rocsparse_int m,
rocsparse_int nnz,
const rocsparse_int* ptr,
const rocsparse_int* ind,
const T* val,
T* ilu0,
rocsparse_index_base base,
size_t buffer_size_,
void* buffer);

// csritilu0_update
REAL_COMPLEX_TEMPLATE(csritilu0_history,
rocsparse_handle handle,
rocsparse_itilu0_alg alg,
rocsparse_int* niter,
floating_data_t<T>* data,
size_t buffer_size_,
void* buffer);

// csrilu0
REAL_COMPLEX_TEMPLATE(csrilu0_buffer_size,
rocsparse_handle handle,
Expand Down
3 changes: 3 additions & 0 deletions clients/include/rocsparse_arguments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ struct Arguments
rocsparse_direction direction;
rocsparse_order order;
rocsparse_format format;
rocsparse_itilu0_alg itilu0_alg;
rocsparse_sddmm_alg sddmm_alg;
rocsparse_spmv_alg spmv_alg;
rocsparse_spsv_alg spsv_alg;
Expand Down Expand Up @@ -207,6 +208,7 @@ struct Arguments
ROCSPARSE_FORMAT_CHECK(direction);
ROCSPARSE_FORMAT_CHECK(order);
ROCSPARSE_FORMAT_CHECK(format);
ROCSPARSE_FORMAT_CHECK(itilu0_alg);
ROCSPARSE_FORMAT_CHECK(sddmm_alg);
ROCSPARSE_FORMAT_CHECK(spmv_alg);
ROCSPARSE_FORMAT_CHECK(spsv_alg);
Expand Down Expand Up @@ -405,6 +407,7 @@ struct Arguments
print("direction", rocsparse_direction2string(arg.direction));
print("order", rocsparse_order2string(arg.order));
print("format", rocsparse_format2string(arg.format));
print("itilu0_alg", rocsparse_itilu0alg2string(arg.itilu0_alg));
print("sddmm_alg", rocsparse_sddmmalg2string(arg.sddmm_alg));
print("spmv_alg", rocsparse_spmvalg2string(arg.spmv_alg));
print("spsv_alg", rocsparse_spsvalg2string(arg.spsv_alg));
Expand Down
11 changes: 11 additions & 0 deletions clients/include/rocsparse_common.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ Datatypes:
rocsparse_format_csr: 2
rocsparse_format_csc: 3
rocsparse_format_ell: 4
- rocsparse_itilu0_alg:
bases: [c_int ]
attr:
rocsparse_itilu0_alg_default: 0
rocsparse_itilu0_alg_async_inplace: 1
rocsparse_itilu0_alg_async_split: 2
rocsparse_itilu0_alg_sync_split: 3
rocsparse_itilu0_alg_sync_split_fusion: 4
- rocsparse_sddmm_alg:
bases: [c_int ]
attr:
Expand Down Expand Up @@ -297,6 +305,7 @@ Arguments:
- direction: rocsparse_direction
- order: rocsparse_order
- format: rocsparse_format
- itilu0_alg: rocsparse_itilu0_alg
- sddmm_alg: rocsparse_sddmm_alg
- spmv_alg: rocsparse_spmv_alg
- spsv_alg: rocsparse_spsv_alg
Expand Down Expand Up @@ -394,6 +403,7 @@ Defaults:
betai: 0.0
threshold: 1.0
percentage: 0.0
tol: 0.0
transA: rocsparse_operation_none
transB: rocsparse_operation_none
baseA: rocsparse_index_base_zero
Expand All @@ -411,6 +421,7 @@ Defaults:
direction: rocsparse_direction_row
order: rocsparse_order_column
format: rocsparse_format_coo
itilu0_alg: rocsparse_itilu0_alg_default
sddmm_alg: rocsparse_sddmm_alg_default
spmv_alg: rocsparse_spmv_alg_default
spsv_alg: rocsparse_spsv_alg_default
Expand Down
25 changes: 25 additions & 0 deletions clients/include/rocsparse_datatype2string.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,31 @@ constexpr auto rocsparse_sddmmalg2string(rocsparse_sddmm_alg alg)
return "invalid";
}

constexpr auto rocsparse_itilu0alg2string(rocsparse_itilu0_alg alg)
{
switch(alg)
{
case rocsparse_itilu0_alg_default:
case rocsparse_itilu0_alg_async_inplace:
{
return "async_inplace";
}
case rocsparse_itilu0_alg_async_split:
{
return "async_split";
}
case rocsparse_itilu0_alg_sync_split:
{
return "sync_split";
}
case rocsparse_itilu0_alg_sync_split_fusion:
{
return "sync_split_fusion";
}
}
return "invalid";
}

constexpr auto rocsparse_spmvalg2string(rocsparse_spmv_alg alg)
{
switch(alg)
Expand Down
92 changes: 92 additions & 0 deletions clients/include/rocsparse_enum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,68 @@
#include <hip/hip_runtime_api.h>
#include <vector>

struct rocsparse_itilu0_alg_t
{
using value_t = rocsparse_itilu0_alg;
static constexpr unsigned int nvalues = 5;
static constexpr value_t values[nvalues] = {rocsparse_itilu0_alg_default,
rocsparse_itilu0_alg_async_inplace,
rocsparse_itilu0_alg_async_split,
rocsparse_itilu0_alg_sync_split,
rocsparse_itilu0_alg_sync_split_fusion};

static void info(std::ostream& out_)
{
for(int i = 0; i < nvalues; ++i)
{
if(i > 0)
out_ << ", ";
const value_t v = values[i];
switch(v)
{
#define LOCAL_CASE(TOKEN) \
case TOKEN: \
{ \
out_ << TOKEN << " : " << #TOKEN; \
break; \
}

case rocsparse_itilu0_alg_default:
{
out_ << rocsparse_itilu0_alg_default << " : "
<< "rocsparse_itilu0_alg_default";
}
LOCAL_CASE(rocsparse_itilu0_alg_async_inplace);
LOCAL_CASE(rocsparse_itilu0_alg_async_split);
LOCAL_CASE(rocsparse_itilu0_alg_sync_split);
LOCAL_CASE(rocsparse_itilu0_alg_sync_split_fusion);
#undef LOCAL_CASE
}
}
};

static constexpr bool is_invalid(rocsparse_int value_)
{
return is_invalid((value_t)value_);
};

static constexpr bool is_invalid(value_t value_)
{
switch(value_)
{
case rocsparse_itilu0_alg_default:
case rocsparse_itilu0_alg_async_inplace:
case rocsparse_itilu0_alg_async_split:
case rocsparse_itilu0_alg_sync_split:
case rocsparse_itilu0_alg_sync_split_fusion:
{
return false;
}
}
return true;
}
};

struct rocsparse_matrix_type_t
{
using value_t = rocsparse_matrix_type;
Expand Down Expand Up @@ -64,4 +126,34 @@ struct rocsparse_storage_mode_t
std::ostream& operator<<(std::ostream& out, const rocsparse_operation& v);
std::ostream& operator<<(std::ostream& out, const rocsparse_direction& v);

struct rocsparse_datatype_t
{
using value_t = rocsparse_datatype;
template <typename T>
static inline rocsparse_datatype get();
};

template <>
inline rocsparse_datatype rocsparse_datatype_t::get<float>()
{
return rocsparse_datatype_f32_r;
}
template <>
inline rocsparse_datatype rocsparse_datatype_t::get<double>()
{
return rocsparse_datatype_f64_r;
}

template <>
inline rocsparse_datatype rocsparse_datatype_t::get<rocsparse_float_complex>()
{
return rocsparse_datatype_f32_c;
}

template <>
inline rocsparse_datatype rocsparse_datatype_t::get<rocsparse_double_complex>()
{
return rocsparse_datatype_f64_c;
}

#endif // ROCSPARSE_ENUM_HPP
Loading

0 comments on commit 1f576a9

Please sign in to comment.