diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index bd2804830172..826b3d94e0d9 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1381,21 +1381,37 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) } // Type code is kFloat switch (op->dtype.bits()) { - case 64: + case 64: { + std::ostringstream temp; + if (std::isinf(op->value)) { + if (op->value < 0) { + temp << "-"; + } + temp << "CUDART_INF"; + p->need_math_constants_h_ = true; + } else if (std::isnan(op->value)) { + temp << "CUDART_NAN"; + p->need_math_constants_h_ = true; + } else { + temp << std::fixed << std::setprecision(15) << op->value; + } + p->MarkConst(temp.str()); + os << temp.str(); + break; + } case 32: { std::ostringstream temp; if (std::isinf(op->value)) { if (op->value < 0) { temp << "-"; } - temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF"); + temp << "CUDART_INF_F"; p->need_math_constants_h_ = true; } else if (std::isnan(op->value)) { - temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN"); + temp << "CUDART_NAN_F"; p->need_math_constants_h_ = true; } else { - temp << std::scientific << op->value; - if (op->dtype.bits() == 32) temp << 'f'; + temp << std::scientific << op->value << 'f'; } p->MarkConst(temp.str()); os << temp.str();