From c615ce944ce17e45ca2186f49a13998808e4170b Mon Sep 17 00:00:00 2001 From: Izzy Putterman Date: Tue, 11 Jul 2023 19:19:48 -0700 Subject: [PATCH] [FRONTEND] use local bindings in triton.cc (#1932) Another follow up with the relative imports this time dealing with the bindings. --- python/src/triton.cc | 69 ++++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 46593ef103fb..406412a12316 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -65,7 +65,7 @@ enum backend_t { void init_triton_runtime(py::module &&m) { // wrap backend_t - py::enum_(m, "backend") + py::enum_(m, "backend", py::module_local()) .value("HOST", HOST) .value("CUDA", CUDA) .value("ROCM", ROCM) @@ -164,12 +164,14 @@ void init_triton_ir(py::module &&m) { using ret = py::return_value_policy; using namespace pybind11::literals; - py::enum_(m, "PADDING_OPTION") + py::enum_(m, "PADDING_OPTION", + py::module_local()) .value("PAD_ZERO", mlir::triton::PaddingOption::PAD_ZERO) .value("PAD_NAN", mlir::triton::PaddingOption::PAD_NAN) .export_values(); - py::enum_(m, "CACHE_MODIFIER") + py::enum_(m, "CACHE_MODIFIER", + py::module_local()) .value("NONE", mlir::triton::CacheModifier::NONE) .value("CA", mlir::triton::CacheModifier::CA) .value("CG", mlir::triton::CacheModifier::CG) @@ -178,20 +180,21 @@ void init_triton_ir(py::module &&m) { .value("WT", mlir::triton::CacheModifier::WT) .export_values(); - py::enum_(m, "MEM_SEMANTIC") + py::enum_(m, "MEM_SEMANTIC", py::module_local()) .value("ACQUIRE_RELEASE", mlir::triton::MemSemantic::ACQUIRE_RELEASE) .value("ACQUIRE", mlir::triton::MemSemantic::ACQUIRE) .value("RELEASE", mlir::triton::MemSemantic::RELEASE) .value("RELAXED", mlir::triton::MemSemantic::RELAXED) .export_values(); - py::enum_(m, "EVICTION_POLICY") + py::enum_(m, "EVICTION_POLICY", + py::module_local()) .value("NORMAL", mlir::triton::EvictionPolicy::NORMAL) .value("EVICT_FIRST", mlir::triton::EvictionPolicy::EVICT_FIRST) .value("EVICT_LAST", mlir::triton::EvictionPolicy::EVICT_LAST) .export_values(); - py::enum_(m, "ATOMIC_OP") + py::enum_(m, "ATOMIC_OP", py::module_local()) .value("ADD", mlir::triton::RMWOp::ADD) .value("FADD", mlir::triton::RMWOp::FADD) .value("AND", mlir::triton::RMWOp::AND) @@ -203,7 +206,7 @@ void init_triton_ir(py::module &&m) { .value("UMIN", mlir::triton::RMWOp::UMIN) .value("UMAX", mlir::triton::RMWOp::UMAX); - py::class_(m, "context") + py::class_(m, "context", py::module_local()) .def(py::init<>()) .def("load_triton", [](mlir::MLIRContext &self) { self.getOrLoadDialect(); @@ -259,7 +262,7 @@ void init_triton_ir(py::module &&m) { // // py::class_(m, "undef") // // .def("get", &ir::undef_value::get, ret::reference); - py::class_(m, "type") + py::class_(m, "type", py::module_local()) .def("is_integer", &mlir::Type::isInteger) .def("is_fp16", &mlir::Type::isF16) .def("__str__", [](mlir::Type &self) { @@ -269,13 +272,13 @@ void init_triton_ir(py::module &&m) { return os.str(); }); - py::class_(m, "function_type") + py::class_(m, "function_type", py::module_local()) .def("param_types", [](mlir::FunctionType &self) { return std::vector(self.getInputs().begin(), self.getInputs().end()); }); - py::class_(m, "location") + py::class_(m, "location", py::module_local()) .def("__str__", [](mlir::Location &self) { std::string str; llvm::raw_string_ostream os(str); @@ -283,7 +286,7 @@ void init_triton_ir(py::module &&m) { return os.str(); }); - py::class_(m, "value") + py::class_(m, "value", py::module_local()) .def("set_attr", [](mlir::Value &self, std::string &name, mlir::Attribute &attr) -> void { @@ -307,14 +310,15 @@ void init_triton_ir(py::module &&m) { }) .def("get_type", &mlir::Value::getType); - py::class_(m, "block_argument"); + py::class_(m, "block_argument", + py::module_local()); - py::class_(m, "region") + py::class_(m, "region", py::module_local()) .def("get_parent_region", &mlir::Region::getParentRegion, ret::reference) .def("size", [](mlir::Region &self) { return self.getBlocks().size(); }) .def("empty", &mlir::Region::empty); - py::class_(m, "block") + py::class_(m, "block", py::module_local()) .def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument { return self.getArgument(index); @@ -383,12 +387,14 @@ void init_triton_ir(py::module &&m) { // .value("retune", eattr::retune) // .value("not_implemented", eattr::not_implemented); - py::class_(m, "attribute"); - py::class_(m, "integer_attr"); - py::class_(m, "bool_attr"); + py::class_(m, "attribute", py::module_local()); + py::class_(m, "integer_attr", + py::module_local()); + py::class_(m, "bool_attr", + py::module_local()); // Ops - py::class_(m, "OpState") + py::class_(m, "OpState", py::module_local()) .def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void { self->setAttr(name, attr); }) @@ -427,23 +433,27 @@ void init_triton_ir(py::module &&m) { return mlir::succeeded(mlir::verify(self.getOperation())); }); // scf Ops - py::class_(m, "ForOp") + py::class_(m, "ForOp", py::module_local()) .def("get_induction_var", &mlir::scf::ForOp::getInductionVar); - py::class_(m, "IfOp") + py::class_(m, "IfOp", py::module_local()) .def("get_then_block", &mlir::scf::IfOp::thenBlock, ret::reference) .def("get_else_block", &mlir::scf::IfOp::elseBlock, ret::reference) .def("get_then_yield", &mlir::scf::IfOp::thenYield) .def("get_else_yield", &mlir::scf::IfOp::elseYield); - py::class_(m, "YieldOp"); - py::class_(m, "WhileOp") + py::class_(m, "YieldOp", + py::module_local()); + py::class_(m, "WhileOp", + py::module_local()) .def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference) .def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference); - py::class_(m, "ConditionOp"); + py::class_(m, "ConditionOp", + py::module_local()); // dynamic_attr is used to transfer ownership of the MLIR context to the // module - py::class_(m, "module", py::dynamic_attr()) + py::class_(m, "module", py::module_local(), + py::dynamic_attr()) .def("dump", &mlir::ModuleOp::dump) .def("str", [](mlir::ModuleOp &self) -> std::string { @@ -523,7 +533,8 @@ void init_triton_ir(py::module &&m) { }, ret::take_ownership); - py::class_(m, "function") + py::class_(m, "function", + py::module_local()) // .def_property_readonly("attrs", &ir::function::attrs) // .def("add_attr", &ir::function::add_attr); .def("args", @@ -571,9 +582,11 @@ void init_triton_ir(py::module &&m) { .def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType) .def("reset_type", &mlir::triton::FuncOp::setType); - py::class_(m, "InsertPoint"); + py::class_(m, "InsertPoint", + py::module_local()); - py::class_(m, "builder", py::dynamic_attr()) + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()) .def(py::init()) // getters .def("create_module", @@ -1507,7 +1520,7 @@ void init_triton_ir(py::module &&m) { offsets); }); - py::class_(m, "pass_manager") + py::class_(m, "pass_manager", py::module_local()) .def(py::init()) .def("enable_debug", [](mlir::PassManager &self) {