From 25f386f823c710b5cb069b12098793aa6f90ba98 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Wed, 2 Oct 2024 16:19:11 +0200 Subject: [PATCH] fix after rebase --- core/config/type_descriptor_helper.hpp | 3 +++ core/log/solver_progress.cpp | 8 ++++++++ cuda/solver/common_trs_kernels.cuh | 4 ++++ 3 files changed, 15 insertions(+) diff --git a/core/config/type_descriptor_helper.hpp b/core/config/type_descriptor_helper.hpp index 0edc4376f1a..63a953e3a1e 100644 --- a/core/config/type_descriptor_helper.hpp +++ b/core/config/type_descriptor_helper.hpp @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -38,8 +39,10 @@ struct type_string {}; TYPE_STRING_OVERLOAD(void, "void"); TYPE_STRING_OVERLOAD(double, "float64"); TYPE_STRING_OVERLOAD(float, "float32"); +TYPE_STRING_OVERLOAD(half, "float16"); TYPE_STRING_OVERLOAD(std::complex, "complex"); TYPE_STRING_OVERLOAD(std::complex, "complex"); +TYPE_STRING_OVERLOAD(std::complex, "complex"); TYPE_STRING_OVERLOAD(int32, "int32"); TYPE_STRING_OVERLOAD(int64, "int64"); diff --git a/core/log/solver_progress.cpp b/core/log/solver_progress.cpp index effa0279bba..4d1566e159f 100644 --- a/core/log/solver_progress.cpp +++ b/core/log/solver_progress.cpp @@ -247,6 +247,14 @@ class SolverProgressStore : public SolverProgress { run, gko::matrix::Dense, gko::matrix::Dense>, gko::matrix::Dense>, +#if GINKGO_ENABLE_HALF + gko::matrix::Dense, + gko::matrix::Dense>, + gko::WritableToMatrixData, + gko::WritableToMatrixData, int32>, + gko::WritableToMatrixData, + gko::WritableToMatrixData, int64>, +#endif // fallback for other matrix types gko::WritableToMatrixData, gko::WritableToMatrixData, diff --git a/cuda/solver/common_trs_kernels.cuh b/cuda/solver/common_trs_kernels.cuh index 362d22a653c..3dea9bd457c 100644 --- a/cuda/solver/common_trs_kernels.cuh +++ b/cuda/solver/common_trs_kernels.cuh @@ -359,6 +359,10 @@ struct float_to_unsigned_impl { using type = uint32; }; +template <> +struct float_to_unsigned_impl<__half> { + using type = uint16; +}; /** * Checks if a floating point number representation matches the representation