Skip to content

Commit

Permalink
Fix isSigned usage for scalar prints. (#201)
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Enkovich <[email protected]>
  • Loading branch information
ienkovich authored Dec 23, 2024
1 parent 4e4e6b8 commit daa7eb0
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ class TritonLLVMConversionTarget : public ConversionTarget {

// TODO: This code is the same as the GPU-backend code. Consider refactoring.
std::string getFormatSubstr(Value value, bool hex = false,
std::optional<int> width = std::nullopt) {
std::optional<int> width = std::nullopt,
bool isSigned = false) {
Type type = value.getType();
// If the `value` is a pointer, just return %p.
if (isa<LLVM::LLVMPointerType>(type)) {
return "%p";
}
Expand All @@ -52,23 +54,15 @@ std::string getFormatSubstr(Value value, bool hex = false,
std::string prefix = "%";
if (width.has_value()) {
prefix += std::to_string(*width);
} else if (hex) {
prefix += "0";
prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4);
}

if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
return prefix + "f";
} else if (type.isSignedInteger()) {
} else if (type.isInteger()) {
if (type.getIntOrFloatBitWidth() == 64)
return prefix + "lli";
return prefix + (isSigned ? "lli" : "llu");
else
return prefix + "i";
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
if (type.getIntOrFloatBitWidth() == 64)
return prefix + "llu";
else
return prefix + "u";
return prefix + (isSigned ? "i" : "u");
}
assert(false && "not supported type");
return "";
Expand Down Expand Up @@ -163,7 +157,8 @@ static StringRef makeNullTerminatedString(StringRef s) {

void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter,
std::array<Value, 3> pid, StringRef prefix,
std::optional<Value> arg, bool hex = false) {
std::optional<Value> arg, bool hex = false,
bool isSigned = false) {
assert(!prefix.empty() && "printf with empty string not supported");
auto loc = UnknownLoc::get(rewriter.getContext());

Expand All @@ -172,7 +167,7 @@ void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter,
os << "(" << getFormatSubstr(pid[0]) << ", " << getFormatSubstr(pid[1])
<< ", " << getFormatSubstr(pid[2]) << ")" << prefix;
if (arg.has_value())
os << getFormatSubstr(arg.value(), hex);
os << getFormatSubstr(arg.value(), hex, std::nullopt, isSigned);

llvm::SmallString<64> formatStrNewline(formatStr);
formatStrNewline.push_back('\n');
Expand Down Expand Up @@ -242,7 +237,8 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::cpu::PrintOp> {
std::nullopt);
} else {
createRuntimePrintScalarCall(rewriter, pid, op.getPrefix(),
adaptor.getOperands()[0], op.getHex());
adaptor.getOperands()[0], op.getHex(),
op.getIsSigned()[0]);
}
rewriter.eraseOp(op);
return success();
Expand Down

0 comments on commit daa7eb0

Please sign in to comment.