From daa7eb01e1f3fe60d0e0b9a643886f32d3c3ffe7 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 23 Dec 2024 13:05:22 -0600 Subject: [PATCH] Fix isSigned usage for scalar prints. (#201) Signed-off-by: Ilya Enkovich --- .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index 21b4756b506e..33e1753e31b2 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -32,8 +32,10 @@ class TritonLLVMConversionTarget : public ConversionTarget { // TODO: This code is the same as the GPU-backend code. Consider refactoring. std::string getFormatSubstr(Value value, bool hex = false, - std::optional width = std::nullopt) { + std::optional width = std::nullopt, + bool isSigned = false) { Type type = value.getType(); + // If the `value` is a pointer, just return %p. if (isa(type)) { return "%p"; } @@ -52,23 +54,15 @@ std::string getFormatSubstr(Value value, bool hex = false, std::string prefix = "%"; if (width.has_value()) { prefix += std::to_string(*width); - } else if (hex) { - prefix += "0"; - prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); } if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { return prefix + "f"; - } else if (type.isSignedInteger()) { + } else if (type.isInteger()) { if (type.getIntOrFloatBitWidth() == 64) - return prefix + "lli"; + return prefix + (isSigned ? "lli" : "llu"); else - return prefix + "i"; - } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { - if (type.getIntOrFloatBitWidth() == 64) - return prefix + "llu"; - else - return prefix + "u"; + return prefix + (isSigned ? "i" : "u"); } assert(false && "not supported type"); return ""; @@ -163,7 +157,8 @@ static StringRef makeNullTerminatedString(StringRef s) { void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter, std::array pid, StringRef prefix, - std::optional arg, bool hex = false) { + std::optional arg, bool hex = false, + bool isSigned = false) { assert(!prefix.empty() && "printf with empty string not supported"); auto loc = UnknownLoc::get(rewriter.getContext()); @@ -172,7 +167,7 @@ void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter, os << "(" << getFormatSubstr(pid[0]) << ", " << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" << prefix; if (arg.has_value()) - os << getFormatSubstr(arg.value(), hex); + os << getFormatSubstr(arg.value(), hex, std::nullopt, isSigned); llvm::SmallString<64> formatStrNewline(formatStr); formatStrNewline.push_back('\n'); @@ -242,7 +237,8 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { std::nullopt); } else { createRuntimePrintScalarCall(rewriter, pid, op.getPrefix(), - adaptor.getOperands()[0], op.getHex()); + adaptor.getOperands()[0], op.getHex(), + op.getIsSigned()[0]); } rewriter.eraseOp(op); return success();