Skip to content

Commit

Permalink
[FRONTEND] use local bindings in triton.cc (triton-lang#1932)
Browse files Browse the repository at this point in the history
Another follow up with the relative imports this time dealing with the
bindings.
  • Loading branch information
IzzyPutterman authored Jul 12, 2023
1 parent 4795820 commit c615ce9
Showing 1 changed file with 41 additions and 28 deletions.
69 changes: 41 additions & 28 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ enum backend_t {

void init_triton_runtime(py::module &&m) {
// wrap backend_t
py::enum_<backend_t>(m, "backend")
py::enum_<backend_t>(m, "backend", py::module_local())
.value("HOST", HOST)
.value("CUDA", CUDA)
.value("ROCM", ROCM)
Expand Down Expand Up @@ -164,12 +164,14 @@ void init_triton_ir(py::module &&m) {
using ret = py::return_value_policy;
using namespace pybind11::literals;

py::enum_<mlir::triton::PaddingOption>(m, "PADDING_OPTION")
py::enum_<mlir::triton::PaddingOption>(m, "PADDING_OPTION",
py::module_local())
.value("PAD_ZERO", mlir::triton::PaddingOption::PAD_ZERO)
.value("PAD_NAN", mlir::triton::PaddingOption::PAD_NAN)
.export_values();

py::enum_<mlir::triton::CacheModifier>(m, "CACHE_MODIFIER")
py::enum_<mlir::triton::CacheModifier>(m, "CACHE_MODIFIER",
py::module_local())
.value("NONE", mlir::triton::CacheModifier::NONE)
.value("CA", mlir::triton::CacheModifier::CA)
.value("CG", mlir::triton::CacheModifier::CG)
Expand All @@ -178,20 +180,21 @@ void init_triton_ir(py::module &&m) {
.value("WT", mlir::triton::CacheModifier::WT)
.export_values();

py::enum_<mlir::triton::MemSemantic>(m, "MEM_SEMANTIC")
py::enum_<mlir::triton::MemSemantic>(m, "MEM_SEMANTIC", py::module_local())
.value("ACQUIRE_RELEASE", mlir::triton::MemSemantic::ACQUIRE_RELEASE)
.value("ACQUIRE", mlir::triton::MemSemantic::ACQUIRE)
.value("RELEASE", mlir::triton::MemSemantic::RELEASE)
.value("RELAXED", mlir::triton::MemSemantic::RELAXED)
.export_values();

py::enum_<mlir::triton::EvictionPolicy>(m, "EVICTION_POLICY")
py::enum_<mlir::triton::EvictionPolicy>(m, "EVICTION_POLICY",
py::module_local())
.value("NORMAL", mlir::triton::EvictionPolicy::NORMAL)
.value("EVICT_FIRST", mlir::triton::EvictionPolicy::EVICT_FIRST)
.value("EVICT_LAST", mlir::triton::EvictionPolicy::EVICT_LAST)
.export_values();

py::enum_<mlir::triton::RMWOp>(m, "ATOMIC_OP")
py::enum_<mlir::triton::RMWOp>(m, "ATOMIC_OP", py::module_local())
.value("ADD", mlir::triton::RMWOp::ADD)
.value("FADD", mlir::triton::RMWOp::FADD)
.value("AND", mlir::triton::RMWOp::AND)
Expand All @@ -203,7 +206,7 @@ void init_triton_ir(py::module &&m) {
.value("UMIN", mlir::triton::RMWOp::UMIN)
.value("UMAX", mlir::triton::RMWOp::UMAX);

py::class_<mlir::MLIRContext>(m, "context")
py::class_<mlir::MLIRContext>(m, "context", py::module_local())
.def(py::init<>())
.def("load_triton", [](mlir::MLIRContext &self) {
self.getOrLoadDialect<mlir::triton::TritonDialect>();
Expand Down Expand Up @@ -259,7 +262,7 @@ void init_triton_ir(py::module &&m) {
// // py::class_<ir::undef_value, ir::constant>(m, "undef")
// // .def("get", &ir::undef_value::get, ret::reference);

py::class_<mlir::Type>(m, "type")
py::class_<mlir::Type>(m, "type", py::module_local())
.def("is_integer", &mlir::Type::isInteger)
.def("is_fp16", &mlir::Type::isF16)
.def("__str__", [](mlir::Type &self) {
Expand All @@ -269,21 +272,21 @@ void init_triton_ir(py::module &&m) {
return os.str();
});

py::class_<mlir::FunctionType>(m, "function_type")
py::class_<mlir::FunctionType>(m, "function_type", py::module_local())
.def("param_types", [](mlir::FunctionType &self) {
return std::vector<mlir::Type>(self.getInputs().begin(),
self.getInputs().end());
});

py::class_<mlir::Location>(m, "location")
py::class_<mlir::Location>(m, "location", py::module_local())
.def("__str__", [](mlir::Location &self) {
std::string str;
llvm::raw_string_ostream os(str);
self.print(os);
return os.str();
});

py::class_<mlir::Value>(m, "value")
py::class_<mlir::Value>(m, "value", py::module_local())
.def("set_attr",
[](mlir::Value &self, std::string &name,
mlir::Attribute &attr) -> void {
Expand All @@ -307,14 +310,15 @@ void init_triton_ir(py::module &&m) {
})
.def("get_type", &mlir::Value::getType);

py::class_<mlir::BlockArgument, mlir::Value>(m, "block_argument");
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_argument",
py::module_local());

py::class_<mlir::Region>(m, "region")
py::class_<mlir::Region>(m, "region", py::module_local())
.def("get_parent_region", &mlir::Region::getParentRegion, ret::reference)
.def("size", [](mlir::Region &self) { return self.getBlocks().size(); })
.def("empty", &mlir::Region::empty);

py::class_<mlir::Block>(m, "block")
py::class_<mlir::Block>(m, "block", py::module_local())
.def("arg",
[](mlir::Block &self, int index) -> mlir::BlockArgument {
return self.getArgument(index);
Expand Down Expand Up @@ -383,12 +387,14 @@ void init_triton_ir(py::module &&m) {
// .value("retune", eattr::retune)
// .value("not_implemented", eattr::not_implemented);

py::class_<mlir::Attribute>(m, "attribute");
py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "integer_attr");
py::class_<mlir::BoolAttr, mlir::Attribute>(m, "bool_attr");
py::class_<mlir::Attribute>(m, "attribute", py::module_local());
py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "integer_attr",
py::module_local());
py::class_<mlir::BoolAttr, mlir::Attribute>(m, "bool_attr",
py::module_local());

// Ops
py::class_<mlir::OpState>(m, "OpState")
py::class_<mlir::OpState>(m, "OpState", py::module_local())
.def("set_attr",
[](mlir::OpState &self, std::string &name,
mlir::Attribute &attr) -> void { self->setAttr(name, attr); })
Expand Down Expand Up @@ -427,23 +433,27 @@ void init_triton_ir(py::module &&m) {
return mlir::succeeded(mlir::verify(self.getOperation()));
});
// scf Ops
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp")
py::class_<mlir::scf::ForOp, mlir::OpState>(m, "ForOp", py::module_local())
.def("get_induction_var", &mlir::scf::ForOp::getInductionVar);

py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp")
py::class_<mlir::scf::IfOp, mlir::OpState>(m, "IfOp", py::module_local())
.def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference)
.def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference)
.def("get_then_yield", &mlir::scf::IfOp::thenYield)
.def("get_else_yield", &mlir::scf::IfOp::elseYield);
py::class_<mlir::scf::YieldOp, mlir::OpState>(m, "YieldOp");
py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp")
py::class_<mlir::scf::YieldOp, mlir::OpState>(m, "YieldOp",
py::module_local());
py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp",
py::module_local())
.def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference)
.def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "ConditionOp");
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "ConditionOp",
py::module_local());

// dynamic_attr is used to transfer ownership of the MLIR context to the
// module
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module", py::dynamic_attr())
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module", py::module_local(),
py::dynamic_attr())
.def("dump", &mlir::ModuleOp::dump)
.def("str",
[](mlir::ModuleOp &self) -> std::string {
Expand Down Expand Up @@ -523,7 +533,8 @@ void init_triton_ir(py::module &&m) {
},
ret::take_ownership);

py::class_<mlir::triton::FuncOp, mlir::OpState>(m, "function")
py::class_<mlir::triton::FuncOp, mlir::OpState>(m, "function",
py::module_local())
// .def_property_readonly("attrs", &ir::function::attrs)
// .def("add_attr", &ir::function::add_attr);
.def("args",
Expand Down Expand Up @@ -571,9 +582,11 @@ void init_triton_ir(py::module &&m) {
.def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType)
.def("reset_type", &mlir::triton::FuncOp::setType);

py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint",
py::module_local());

py::class_<TritonOpBuilder>(m, "builder", py::dynamic_attr())
py::class_<TritonOpBuilder>(m, "builder", py::module_local(),
py::dynamic_attr())
.def(py::init<mlir::MLIRContext *>())
// getters
.def("create_module",
Expand Down Expand Up @@ -1507,7 +1520,7 @@ void init_triton_ir(py::module &&m) {
offsets);
});

py::class_<mlir::PassManager>(m, "pass_manager")
py::class_<mlir::PassManager>(m, "pass_manager", py::module_local())
.def(py::init<mlir::MLIRContext *>())
.def("enable_debug",
[](mlir::PassManager &self) {
Expand Down

0 comments on commit c615ce9

Please sign in to comment.