Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
revert some previous changes

update

update

update

update

update

update
  • Loading branch information
sfzhu93 committed Jan 24, 2025
1 parent 6aa2df9 commit 36ddd87
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 95 deletions.
134 changes: 91 additions & 43 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,42 @@ class TritonOpBuilder {
bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
};

// Source-manager-backed diagnostic handler used while running the pass
// manager, so that emitted MLIR diagnostics can directly reference Python
// source code. Output goes to llvm::errs(). Diagnostics are filtered by
// `minSeverity`:
//   * errors are always emitted;
//   * warnings are emitted unless minSeverity is Error;
//   * remarks are emitted only when minSeverity is neither Error nor Warning;
//   * notes are never emitted here (failure() is returned for them).
struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler {
  TritonSourceMgrDiagnosticHandler(MLIRContext *ctx,
                                   DiagnosticSeverity minSeverity)
      // NOTE(review): `sourceMgr` is a member, so the base class receives a
      // reference to it before it has been constructed. Presumably the base
      // only stores the reference at this point -- confirm.
      : SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) {
    setHandler([this, minSeverity](Diagnostic &diag) {
      // Pre-compute which categories the requested threshold hides.
      const bool errorsOnly = minSeverity == DiagnosticSeverity::Error;
      const bool hideRemarks =
          errorsOnly || minSeverity == DiagnosticSeverity::Warning;

      switch (diag.getSeverity()) {
      case DiagnosticSeverity::Note:
        // notes are handled somewhere else.
        return failure();
      case DiagnosticSeverity::Remark:
        if (hideRemarks)
          return success(); // filtered out: report success without emitting
        break;
      case DiagnosticSeverity::Warning:
        if (errorsOnly)
          return success(); // filtered out: report success without emitting
        break;
      case DiagnosticSeverity::Error:
        // Errors always pass the filter.
        break;
      default:
        llvm_unreachable("Unknown diagnostic severity");
      }

      emitDiagnostic(diag);
      return success();
    });
  }

  // Source manager handed to the base SourceMgrDiagnosticHandler.
  llvm::SourceMgr sourceMgr;
};

std::string locationToString(Location loc) {
std::string str;
llvm::raw_string_ostream os(str);
Expand All @@ -148,6 +184,23 @@ std::string locationToString(Location loc) {
return str;
}

// Split `input` on ',' and return one null-terminated C string per piece.
//
// StringRefs are not always null-terminated, so every piece is copied into
// `storage` as a std::string and the returned pointers refer to those copies.
//
// Parameters:
//   input   - comma-separated list to split.
//   storage - receives one std::string per piece (appended after any
//             elements already present). It owns the characters the returned
//             pointers refer to, so it must outlive the result and must not
//             be modified while the result is in use.
//
// Returns: pointers to the C strings of the newly appended `storage`
// elements, in input order.
//
// The pointers are collected only after `storage` has been completely
// filled. Taking `storage.back().c_str()` while still appending would leave
// earlier pointers dangling once the SmallVector outgrows its inline capacity
// and moves its (possibly small-string-optimised) elements.
llvm::SmallVector<const char *, 3>
parseCommaSeparatedValues(const std::string &input,
                          llvm::SmallVector<std::string, 3> &storage) {
  llvm::SmallVector<StringRef, 3> split;
  StringRef(input).split(split, ',');

  // Phase 1: materialise every piece as an owned, null-terminated string.
  const size_t firstNew = storage.size();
  storage.reserve(firstNew + split.size());
  for (StringRef piece : split)
    storage.push_back(piece.str());

  // Phase 2: storage no longer grows, so element addresses are stable.
  llvm::SmallVector<const char *, 3> result;
  result.reserve(storage.size() - firstNew);
  for (size_t i = firstNew, e = storage.size(); i != e; ++i)
    result.push_back(storage[i].c_str());
  return result;
}

void outputWarning(Location loc, const std::string &msg) {
std::string locStr = locationToString(loc);

Expand Down Expand Up @@ -1689,27 +1742,15 @@ void init_triton_ir(py::module &&m) {
.def("enable_debug",
[](PassManager &self) {
auto *context = self.getContext();
bool haveDiagnostics =
::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS");
bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
std::string funcToDump;
if (!haveDump) {
funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP");
if (!funcToDump.empty())
haveDump = true;
}
if (haveDiagnostics || haveDump) {
context->disableMultithreading();
}
if (haveDiagnostics) {
context->printOpOnDiagnostic(true);
context->printStackTraceOnDiagnostic(true);
context->getDiagEngine().registerHandler([](Diagnostic &diag) {
llvm::outs() << diag << "\n";
return success();
});
}
if (haveDump) {
context->disableMultithreading();
auto printingFlags = OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
printingFlags.enableDebugInfo();
Expand Down Expand Up @@ -1739,6 +1780,8 @@ void init_triton_ir(py::module &&m) {
// TODO: maybe dump module to file and print error for better
// diagnostics

auto *context = mod.getContext();

auto reproducerPath =
triton::tools::getStrEnv("TRITON_REPRODUCER_PATH");
if (!reproducerPath.empty()) {
Expand All @@ -1750,7 +1793,7 @@ void init_triton_ir(py::module &&m) {
makeReproducer(anchorName, passes, op, reproducerPath);
// But if the pass manager crashes, attempt to generate a local
// reproducer instead.
mod.getContext()->disableMultithreading();
context->disableMultithreading();
self.enableCrashReproducerGeneration(reproducerPath,
/*genLocalReproducer=*/true);
}
Expand All @@ -1761,20 +1804,9 @@ void init_triton_ir(py::module &&m) {

if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY");
!debugOnly.empty()) {
llvm::SmallVector<StringRef, 3> split;
llvm::SmallVector<std::string, 3> storage;
llvm::SmallVector<const char *, 3> debugTypes;

StringRef(debugOnly.c_str()).split(split, ',');
llvm::transform(split, std::back_inserter(debugTypes),
[&storage](StringRef str) {
// StringRefs are not always null-terminated.
// The purpose for this storage pattern is to
// produce a collection of C-strings that are.
storage.push_back(str.str());
return storage.back().c_str();
});

llvm::SmallVector<const char *, 3> debugTypes =
parseCommaSeparatedValues(debugOnly, storage);
::llvm::DebugFlag = true;
using namespace llvm;
setCurrentDebugTypes(debugTypes.data(), debugTypes.size());
Expand All @@ -1785,25 +1817,41 @@ void init_triton_ir(py::module &&m) {
self.enableTiming();
}

// Run the pass manager under a source manager diagnostic handler, which
// enables emitted MLIR diagnostics to directly reference Python source
// code. This diagnostic handler will only filter for errors.
struct SourceMgrErrorDiagnosticHandler
: public SourceMgrDiagnosticHandler {
SourceMgrErrorDiagnosticHandler(MLIRContext *ctx)
: SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) {
setHandler([this](Diagnostic &diag) {
if (diag.getSeverity() != DiagnosticSeverity::Error)
return failure();
emitDiagnostic(diag);
return success();
});
// setting up diagnostics
bool showOperations = false, showStacktraces = false,
showRemarks = false, showWarnings = false;

if (auto enableDiagnostics =
triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS");
!enableDiagnostics.empty()) {
llvm::SmallVector<std::string, 3> storage;
parseCommaSeparatedValues(enableDiagnostics, storage);
for (auto &str : storage) {
if (str == "warnings") {
showWarnings = true;
} else if (str == "remarks") {
showRemarks = true;
} else if (str == "stacktraces") {
showStacktraces = true;
} else if (str == "operations") {
showOperations = true;
}
// we show errors by default, so no need to set it
}
}

llvm::SourceMgr sourceMgr;
};
SourceMgrErrorDiagnosticHandler diagHandler(mod.getContext());
DiagnosticSeverity minSeverity = showWarnings
? DiagnosticSeverity::Warning
: DiagnosticSeverity::Error;
minSeverity = showRemarks ? DiagnosticSeverity::Remark : minSeverity;

TritonSourceMgrDiagnosticHandler diagHandler(context, minSeverity);

context->printOpOnDiagnostic(showOperations);
context->printStackTraceOnDiagnostic(showStacktraces);
if (showStacktraces) {
context->disableMultithreading();
}
if (failed(self.run(mod.getOperation())))
throw std::runtime_error("PassManager::run failed");
});
Expand Down
102 changes: 61 additions & 41 deletions python/test/unit/test_perf_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,12 @@


@contextmanager
def enable_remark_context():
def enable_diagnostics_context(value):
try:
os.environ["MLIR_ENABLE_REMARK"] = "1"
os.environ["MLIR_ENABLE_DIAGNOSTICS"] = value
yield
finally:
os.environ["MLIR_ENABLE_REMARK"] = "0"


def is_perf_warning_enabled():
return os.environ.get("MLIR_ENABLE_REMARK", "0") == "1"
os.environ["MLIR_ENABLE_DIAGNOSTICS"] = ""


def is_cuda():
Expand Down Expand Up @@ -74,29 +70,39 @@ def matmul_kernel(
c = tl.dot(a, b)
tl.store(c_block_ptr, c)

with enable_remark_context():
triton.compile(
triton.compiler.ASTSource(
fn=matmul_kernel,
signature={
"a_ptr": "*fp32",
"b_ptr": "*fp32",
"c_ptr": "*fp32",
"M": "i32",
"N": "i32",
"K": "i32",
"stride_am": "i32",
"stride_ak": "i32",
"stride_bk": "i32",
"stride_bn": "i32",
"stride_cm": "i32",
"stride_cn": "i32",
},
constexprs={},
))
signature = {
"a_ptr": "*fp32",
"b_ptr": "*fp32",
"c_ptr": "*fp32",
"M": "i32",
"N": "i32",
"K": "i32",
"stride_am": "i32",
"stride_ak": "i32",
"stride_bk": "i32",
"stride_bn": "i32",
"stride_cm": "i32",
"stride_cn": "i32",
}
with enable_diagnostics_context('remarks'):
triton.compile(triton.compiler.ASTSource(
fn=matmul_kernel,
signature=signature,
constexprs={},
))
captured = capfd.readouterr()

assert ("remark: Warning: can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark"
assert ("can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark"
assert "note: see current operation:" not in captured.err

with enable_diagnostics_context('remarks,operations,stacktraces'):
triton.compile(triton.compiler.ASTSource(
fn=matmul_kernel,
signature=signature,
constexprs={},
))
captured = capfd.readouterr()
assert "note: diagnostic emitted with trace:" in captured.err
assert "note: see current operation:" in captured.err


Expand Down Expand Up @@ -126,25 +132,39 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr)
tl.store(out_ptr0 + (x4), tmp22, None)

XBLOCK = 1024
with enable_remark_context():

astsource_args = {
"fn": ldst_vec,
"signature": {
"in_ptr0": "*i64",
"in_ptr1": "*i64",
"in_ptr2": "*fp16",
"in_ptr3": "*fp32",
"out_ptr0": "*fp16",
"XBLOCK": "constexpr",
},
"constexprs": {"XBLOCK": XBLOCK},
}

with enable_diagnostics_context('remarks'):
triton.compile(
triton.compiler.ASTSource(
fn=ldst_vec,
signature={
"in_ptr0": "*i64",
"in_ptr1": "*i64",
"in_ptr2": "*fp16",
"in_ptr3": "*fp32",
"out_ptr0": "*fp16",
"XBLOCK": "constexpr",
},
constexprs={"XBLOCK": XBLOCK},
),
triton.compiler.ASTSource(**astsource_args),
options={"num_warps": 1},
)

_, err = capfd.readouterr()
assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark"
assert "note: see current operation:" not in err

with enable_diagnostics_context('remarks,operations,stacktraces'):
triton.compile(
triton.compiler.ASTSource(**astsource_args),
options={"num_warps": 1},
)

_, err = capfd.readouterr()
assert "note: see current operation:" in err
assert "note: diagnostic emitted with trace:" in err


def test_remark_swp_op_before_operands(capfd, fresh_triton_cache):
Expand Down
11 changes: 0 additions & 11 deletions third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,6 @@ def make_ttgir(mod, metadata, opt, capability):
cluster_info.clusterDimX = opt.cluster_dims[0]
cluster_info.clusterDimY = opt.cluster_dims[1]
cluster_info.clusterDimZ = opt.cluster_dims[2]
# Set up Diagnostic
if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
srcMgr = llvm.source_mgr()
_ = ir.source_mgr_diag(srcMgr, mod.context)
mod.context.printOpOnDiagnostic(True)
# TTIR -> TTGIR
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
Expand Down Expand Up @@ -283,11 +277,6 @@ def make_llir(src, metadata, options, capability):
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
# Set up Diagnostic
if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
srcMgr = llvm.source_mgr()
_ = ir.source_mgr_diag(srcMgr, mod.context)
mod.context.printOpOnDiagnostic(True)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.convert.add_scf_to_cf(pm)
passes.convert.add_index_to_llvmir(pm)
Expand Down

0 comments on commit 36ddd87

Please sign in to comment.