Skip to content

Commit

Permalink
[FRONTEND][BACKEND] Materialize line info for triton kernels (triton-…
Browse files Browse the repository at this point in the history
…lang#1902)

`export TRITON_DISABLE_LINE_INFO=1` to disable the feature.
  • Loading branch information
Jokeren authored Jul 7, 2023
1 parent cb0321f commit cc5a7ed
Show file tree
Hide file tree
Showing 17 changed files with 875 additions and 533 deletions.
1 change: 1 addition & 0 deletions include/triton/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(Target)
1 change: 1 addition & 0 deletions include/triton/Target/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(LLVMIR)
3 changes: 3 additions & 0 deletions include/triton/Target/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name LLVMIR)
add_public_tablegen_target(LLVMIRIncGen)
4 changes: 2 additions & 2 deletions include/triton/Target/LLVMIR/LLVMIRTranslation.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef TRITON_TARGET_LLVMIRTRANSLATION_H
#define TRITON_TARGET_LLVMIRTRANSLATION_H
#ifndef TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
#define TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <string>
Expand Down
17 changes: 17 additions & 0 deletions include/triton/Target/LLVMIR/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef TRITON_TARGET_LLVM_IR_PASSES_H
#define TRITON_TARGET_LLVM_IR_PASSES_H

#include "mlir/Pass/Pass.h"

namespace mlir {

/// Create a pass to add DIScope
std::unique_ptr<Pass> createLLVMDIScopePass();

/// Generate the code for registering conversion passes.
#define GEN_PASS_REGISTRATION
#include "triton/Target/LLVMIR/Passes.h.inc"

} // namespace mlir

#endif // TRITON_TARGET_LLVM_IR_PASSES_H
15 changes: 15 additions & 0 deletions include/triton/Target/LLVMIR/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef TRITON_TARGET_LLVMIR_PASSES
#define TRITON_TARGET_LLVMIR_PASSES

include "mlir/Pass/PassBase.td"

def LLVMDIScope: Pass<"enable-line-info", "mlir::ModuleOp"> {
let summary = "Materialize LLVM line info";
let description = [{
This pass materializes line mapping information for LLVM IR dialect operations.
}];

let constructor = "mlir::createLLVMDIScopePass()";
}

#endif
6 changes: 3 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace SharedToDotOperandMMAv1 {
using CoordTy = SmallVector<Value>;
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;

SmallVector<CoordTy> getMNCoords(Value thread,
SmallVector<CoordTy> getMNCoords(Value thread, Location loc,
ConversionPatternRewriter &rewriter,
ArrayRef<unsigned int> wpt,
const MmaEncodingAttr &mmaLayout,
Expand Down Expand Up @@ -187,8 +187,8 @@ struct ConvertLayoutOpConversion
auto [isARow, isBRow, isAVec4, isBVec4, _] =
mmaLayout.decodeVoltaLayoutStates();
auto coords = SharedToDotOperandMMAv1::getMNCoords(
threadId, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape,
isARow, isBRow, isAVec4, isBVec4);
threadId, loc, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout,
shape, isARow, isBRow, isAVec4, isBVec4);
return coords[elemId];
} else {
llvm_unreachable("Unexpected MMALayout version");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ namespace SharedToDotOperandMMAv1 {
using CoordTy = SmallVector<Value>;
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;

SmallVector<CoordTy> getMNCoords(Value thread,
SmallVector<CoordTy> getMNCoords(Value thread, Location loc,
ConversionPatternRewriter &rewriter,
ArrayRef<unsigned int> wpt,
const MmaEncodingAttr &mmaLayout,
Expand All @@ -348,7 +348,6 @@ SmallVector<CoordTy> getMNCoords(Value thread,
static constexpr std::array<int, 3> fpw{{2, 2, 1}};

auto *ctx = thread.getContext();
auto loc = UnknownLoc::get(ctx);
Value _1 = i32_val(1);
Value _2 = i32_val(2);
Value _4 = i32_val(4);
Expand Down
1 change: 0 additions & 1 deletion lib/Dialect/Triton/Transforms/Combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ bool isBroadcastConstantCombinable(Attribute value) {

DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
Value bcast_res) {

auto resType = bcast_res.getType().cast<ShapedType>();
DenseElementsAttr res;
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ class MoveConvertOutOfLoop : public mlir::RewritePattern {
// replace
SmallVector<Value, 4> newResults = newForOp->getResults();
newResults[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
rewriter.getUnknownLoc(), origType, newForOp->getResult(i));
newForOp.getLoc(), origType, newForOp->getResult(i));
newResults[i].getDefiningOp()->moveAfter(newForOp);
return newResults;
}
Expand Down
4 changes: 4 additions & 0 deletions lib/Target/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
add_mlir_translation_library(TritonLLVMIR
LLVMIRTranslation.cpp
LLVMDIScope.cpp

LINK_COMPONENTS
Core

DEPENDS
LLVMIRIncGen

LINK_LIBS
${CMAKE_DL_LIBS}
PUBLIC
Expand Down
148 changes: 148 additions & 0 deletions lib/Target/LLVMIR/LLVMDIScope.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#include "triton/Target/LLVMIR/Passes.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "llvm/BinaryFormat/Dwarf.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Path.h"

//===----------------------------------------------------------------------===//
// This file implements a pass to add debug info scope to LLVM operations, and
// is inspired by the DIScopeForLLVMFuncOpPass in LLVM/MLIR. Different from the
// DIScopeForLLVMFuncOpPass, this pass also handles inlined functions.
//===----------------------------------------------------------------------===//

using namespace mlir;

#define GEN_PASS_CLASSES
#include "triton/Target/LLVMIR/Passes.h.inc"

namespace {

/// Attempt to extract a filename for the given loc.
FileLineColLoc extractFileLoc(Location loc) {
if (auto fileLoc = dyn_cast<FileLineColLoc>(loc))
return fileLoc;
if (auto nameLoc = dyn_cast<NameLoc>(loc))
return extractFileLoc(nameLoc.getChildLoc());
if (auto opaqueLoc = dyn_cast<OpaqueLoc>(loc))
return extractFileLoc(opaqueLoc.getFallbackLocation());
if (auto fusedLoc = dyn_cast<FusedLoc>(loc))
return extractFileLoc(fusedLoc.getLocations().front());
if (auto callerLoc = dyn_cast<CallSiteLoc>(loc))
return extractFileLoc(callerLoc.getCaller());
StringAttr unknownFile = mlir::StringAttr::get(loc.getContext(), "<unknown>");
return mlir::FileLineColLoc::get(unknownFile, 0, 0);
}

/// Add a debug info scope to LLVMFuncOp that are missing it.
struct LLVMDIScopePass : public LLVMDIScopeBase<LLVMDIScopePass> {
LLVMDIScopePass() = default;

void setSubprogramAttr(LLVM::LLVMFuncOp funcOp) {
Location loc = funcOp.getLoc();
if (loc->findInstanceOf<mlir::FusedLocWith<LLVM::DISubprogramAttr>>())
return;

MLIRContext *context = &getContext();

// To find a DICompileUnitAttr attached to a parent (the module for
// example), otherwise create a default one.
LLVM::DICompileUnitAttr compileUnitAttr;
if (ModuleOp module = funcOp->getParentOfType<ModuleOp>()) {
auto fusedCompileUnitAttr =
module->getLoc()
->findInstanceOf<mlir::FusedLocWith<LLVM::DICompileUnitAttr>>();
if (fusedCompileUnitAttr)
compileUnitAttr = fusedCompileUnitAttr.getMetadata();
}

// Filename, line and colmun to associate to the function.
LLVM::DIFileAttr fileAttr;
int64_t line = 1, col = 1;
FileLineColLoc fileLoc = extractFileLoc(loc);
if (!fileLoc && compileUnitAttr) {
fileAttr = compileUnitAttr.getFile();
} else if (!fileLoc) {
fileAttr = LLVM::DIFileAttr::get(context, "<unknown>", "");
} else {
line = fileLoc.getLine();
col = fileLoc.getColumn();
StringRef inputFilePath = fileLoc.getFilename().getValue();
fileAttr = LLVM::DIFileAttr::get(
context, llvm::sys::path::filename(inputFilePath),
llvm::sys::path::parent_path(inputFilePath));
}
if (!compileUnitAttr) {
compileUnitAttr = LLVM::DICompileUnitAttr::get(
context, llvm::dwarf::DW_LANG_C, fileAttr,
StringAttr::get(context, "triton"), /*isOptimized=*/true,
LLVM::DIEmissionKind::LineTablesOnly);
}
auto subroutineTypeAttr =
LLVM::DISubroutineTypeAttr::get(context, llvm::dwarf::DW_CC_normal, {});

StringAttr funcNameAttr = funcOp.getNameAttr();
// Note that scopeline is set differently from LLVM's
// DIScopeForLLVMFuncOpPass. I don't find reasons why scopeline should be
// the column offset
auto subprogramAttr =
LLVM::DISubprogramAttr::get(context, compileUnitAttr, fileAttr,
funcNameAttr, funcNameAttr, fileAttr,
/*line=*/line,
/*scopeline=*/line,
LLVM::DISubprogramFlags::Definition |
LLVM::DISubprogramFlags::Optimized,
subroutineTypeAttr);
funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr));
}

// Get a nested loc for inlined functions
Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr,
Location calleeLoc) {
auto calleeFileName = extractFileLoc(calleeLoc).getFilename();
auto context = op->getContext();
LLVM::DIFileAttr calleeFileAttr = LLVM::DIFileAttr::get(
context, llvm::sys::path::filename(calleeFileName),
llvm::sys::path::parent_path(calleeFileName));
auto lexicalBlockFileAttr = LLVM::DILexicalBlockFileAttr::get(
context, scopeAttr, calleeFileAttr, /*discriminator=*/0);
Location loc = op->getLoc();
if (calleeLoc.isa<CallSiteLoc>()) {
auto nestedLoc = calleeLoc.cast<CallSiteLoc>().getCallee();
loc = getNestedLoc(op, lexicalBlockFileAttr, nestedLoc);
}
return FusedLoc::get(context, {loc}, lexicalBlockFileAttr);
}

void setLexicalBlockFileAttr(Operation *op) {
auto opLoc = op->getLoc();
if (auto callSiteLoc = dyn_cast<CallSiteLoc>(opLoc)) {
auto callerLoc = callSiteLoc.getCaller();
auto calleeLoc = callSiteLoc.getCallee();
LLVM::DIScopeAttr scopeAttr;
// We assemble the full inline stack so the parent of this loc must be a
// function
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
auto funcOpLoc = funcOp.getLoc().cast<FusedLoc>();
scopeAttr = funcOpLoc.getMetadata().cast<LLVM::DISubprogramAttr>();
auto loc = getNestedLoc(op, scopeAttr, calleeLoc);
op->setLoc(loc);
}
}

void runOnOperation() override {
getOperation()->walk<WalkOrder::PreOrder>([&](Operation *op) -> void {
if (isa<LLVM::LLVMFuncOp>(op))
setSubprogramAttr(cast<LLVM::LLVMFuncOp>(op));
else
setLexicalBlockFileAttr(op);
});
}
};

} // end anonymous namespace

std::unique_ptr<Pass> mlir::createLLVMDIScopePass() {
return std::make_unique<LLVMDIScopePass>();
}
5 changes: 5 additions & 0 deletions lib/Target/LLVMIR/LLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/Dialect.h"
Expand All @@ -15,6 +16,7 @@
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Target/LLVMIR/Passes.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "triton/Tools/Sys/GetPlatform.hpp"
#include "llvm/ADT/APInt.h"
Expand Down Expand Up @@ -329,6 +331,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
}
auto printingFlags = mlir::OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
printingFlags.enableDebugInfo();
pm.enableIRPrinting(
/*shouldPrintBeforePass=*/nullptr,
/*shouldPrintAfterPass=*/
Expand All @@ -347,6 +350,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
// Simplify the IR
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
if (!::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"))
pm.addPass(mlir::createLLVMDIScopePass());

if (failed(pm.run(module))) {
llvm::errs() << "Pass execution failed";
Expand Down
Loading

0 comments on commit cc5a7ed

Please sign in to comment.