Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix isSigned usage for scalar prints. #201

Merged
merged 1 commit into from
Dec 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ class TritonLLVMConversionTarget : public ConversionTarget {

// TODO: This code is the same as the GPU-backend code. Consider refactoring.
std::string getFormatSubstr(Value value, bool hex = false,
std::optional<int> width = std::nullopt) {
std::optional<int> width = std::nullopt,
bool isSigned = false) {
Type type = value.getType();
// If the `value` is a pointer, just return %p.
if (isa<LLVM::LLVMPointerType>(type)) {
return "%p";
}
Expand All @@ -52,23 +54,15 @@ std::string getFormatSubstr(Value value, bool hex = false,
std::string prefix = "%";
if (width.has_value()) {
prefix += std::to_string(*width);
} else if (hex) {
prefix += "0";
prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4);
}

if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
return prefix + "f";
} else if (type.isSignedInteger()) {
} else if (type.isInteger()) {
if (type.getIntOrFloatBitWidth() == 64)
return prefix + "lli";
return prefix + (isSigned ? "lli" : "llu");
else
return prefix + "i";
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
if (type.getIntOrFloatBitWidth() == 64)
return prefix + "llu";
else
return prefix + "u";
return prefix + (isSigned ? "i" : "u");
}
assert(false && "not supported type");
return "";
Expand Down Expand Up @@ -163,7 +157,8 @@ static StringRef makeNullTerminatedString(StringRef s) {

void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter,
std::array<Value, 3> pid, StringRef prefix,
std::optional<Value> arg, bool hex = false) {
std::optional<Value> arg, bool hex = false,
bool isSigned = false) {
assert(!prefix.empty() && "printf with empty string not supported");
auto loc = UnknownLoc::get(rewriter.getContext());

Expand All @@ -172,7 +167,7 @@ void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter,
os << "(" << getFormatSubstr(pid[0]) << ", " << getFormatSubstr(pid[1])
<< ", " << getFormatSubstr(pid[2]) << ")" << prefix;
if (arg.has_value())
os << getFormatSubstr(arg.value(), hex);
os << getFormatSubstr(arg.value(), hex, std::nullopt, isSigned);

llvm::SmallString<64> formatStrNewline(formatStr);
formatStrNewline.push_back('\n');
Expand Down Expand Up @@ -242,7 +237,8 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::cpu::PrintOp> {
std::nullopt);
} else {
createRuntimePrintScalarCall(rewriter, pid, op.getPrefix(),
adaptor.getOperands()[0], op.getHex());
adaptor.getOperands()[0], op.getHex(),
op.getIsSigned()[0]);
}
rewriter.eraseOp(op);
return success();
Expand Down