From 017b9b34e2469b746081b141efbb9603ef23e71d Mon Sep 17 00:00:00 2001 From: kliegeois Date: Mon, 25 Sep 2023 15:50:22 -0600 Subject: [PATCH] Add a smart_holder option. --- source/class.cpp | 8 +++++++- source/config.cpp | 24 ++++++++++++++++++++++++ source/config.hpp | 6 +++++- source/context.cpp | 27 +++++++++++++++------------ 4 files changed, 51 insertions(+), 14 deletions(-) diff --git a/source/class.cpp b/source/class.cpp index c79391f0..79073f7d 100644 --- a/source/class.cpp +++ b/source/class.cpp @@ -1226,11 +1226,17 @@ void ClassBinder::bind(Context &context) std::string extra_annotation = module_local_annotation + buffer_protocol_annotation; - if( named_class ) + if( named_class ) { + if (Config::get().is_smart_holder_requested(qualified_name_without_template)) { + c += '\t' + + R"(PYBIND11_TYPE_CASTER_BASE_HOLDER({}, {})"_format(qualified_name, maybe_holder_type) + + '\n'; + } c += '\t' + R"(pybind11::class_<{}{}{}{}> cl({}, "{}", "{}"{});)"_format(qualified_name, maybe_holder_type, maybe_trampoline, maybe_base_classes(context), module_variable_name, python_class_name(C), generate_documentation_string_for_declaration(C), extra_annotation) + '\n'; + } // c += "\tpybind11::handle cl_type = cl;\n\n"; // if( C->isAbstract() and callback_structure) c += "\tcl.def(pybind11::init<>());\n"; diff --git a/source/config.cpp b/source/config.cpp index a7a056e9..9cb8e5d2 100644 --- a/source/config.cpp +++ b/source/config.cpp @@ -75,6 +75,8 @@ void Config::read(string const &file_name) string const _custom_shared_{"custom_shared"}; + string const _smart_holder_{"smart_holder"}; + string const _default_static_pointer_return_value_policy_{"default_static_pointer_return_value_policy"}; string const _default_static_lvalue_reference_return_value_policy_{"default_static_lvalue_reference_return_value_policy"}; string const _default_static_rvalue_reference_return_value_policy_{"default_static_rvalue_reference_return_value_policy"}; @@ -202,6 +204,13 @@ void Config::read(string const &file_name) } else if( token == _custom_shared_ ) holder_type_ = name_without_spaces; + else if( token == _smart_holder_ ) { + include_file_ = "pybind11/smart_holder.h"; + if(bind) { + smart_held_classes.push_back(name_without_spaces); + } + } + else if( token == _default_static_pointer_return_value_policy_ ) default_static_pointer_return_value_policy_ = name_without_spaces; else if( token == _default_static_lvalue_reference_return_value_policy_ ) default_static_lvalue_reference_return_value_policy_ = name_without_spaces; else if( token == _default_static_rvalue_reference_return_value_policy_ ) default_static_rvalue_reference_return_value_policy_ = name_without_spaces; @@ -401,6 +410,21 @@ bool Config::is_module_local_requested(string const &namespace_) const return false; } +bool Config::is_smart_holder_requested(string const &class__) const +{ + string class_{class__}; + class_.erase(std::remove(class_.begin(), class_.end(), ' '), class_.end()); + + auto smart_held_class = std::find(smart_held_classes.begin(), smart_held_classes.end(), class_); + + if( smart_held_class != smart_held_classes.end() ) { + // outs() << "Using smart holder for class : " << class_ << "\n"; + return true; + } + + return false; +} + bool Config::is_include_skipping_requested(string const &include) const { for( auto &i : includes_to_skip ) diff --git a/source/config.hpp b/source/config.hpp index 9f7b63d8..00fb033e 100644 --- a/source/config.hpp +++ b/source/config.hpp @@ -52,6 +52,7 @@ class Config string default_member_rvalue_reference_return_value_policy_ = "pybind11::return_value_policy::automatic"; string default_call_guard_ = ""; string holder_type_ = "std::shared_ptr"; + string include_file_ = "pybind11/pybind11.h"; public: static Config &get(); @@ -62,7 +63,7 @@ class Config string root_module; std::vector namespaces_to_bind, classes_to_bind, functions_to_bind, namespaces_to_skip, classes_to_skip, functions_to_skip, includes_to_add, includes_to_skip; - std::vector buffer_protocols, module_local_namespaces_to_add, module_local_namespaces_to_skip; + std::vector buffer_protocols, module_local_namespaces_to_add, module_local_namespaces_to_skip, smart_held_classes; std::map const &binders() const { return binders_; } std::map const &add_on_binders() const { return add_on_binders_; } @@ -85,6 +86,7 @@ class Config string const &default_call_guard() { return default_call_guard_; } string const &holder_type() const { return holder_type_; } + string const &include_file() const { return include_file_; } string prefix; @@ -102,6 +104,8 @@ class Config bool is_class_skipping_requested(string const &class_) const; bool is_buffer_protocol_requested(string const &class_) const; + bool is_smart_holder_requested(string const &class_) const; + bool is_include_skipping_requested(string const &include) const; string is_custom_trampoline_function_requested(string const &function__) const; diff --git a/source/context.cpp b/source/context.cpp index 2a945e13..a27fb0dc 100644 --- a/source/context.cpp +++ b/source/context.cpp @@ -87,14 +87,14 @@ const char *main_module_header = R"_(#include #include #include -#include +{0} typedef std::function< pybind11::module & (std::string const &) > ModuleGetter; -{0} +{1} -PYBIND11_MODULE({1}, root_module) {{ - root_module.doc() = "{1} module"; +PYBIND11_MODULE({2}, root_module) {{ + root_module.doc() = "{2} module"; std::map modules; ModuleGetter M = [&](std::string const &namespace_) -> pybind11::module & {{ @@ -115,25 +115,25 @@ PYBIND11_MODULE({1}, root_module) {{ ); std::vector< std::pair > sub_modules {{ -{2} }}; +{3} }}; for(auto &p : sub_modules ) modules[p.first.size() ? p.first+"::"+p.second : p.second] = modules[p.first].def_submodule( mangle_namespace_name(p.second).c_str(), ("Bindings for " + p.first + "::" + p.second + " namespace").c_str() ); //pybind11::class_>(M(""), "_encapsulated_data_"); -{3} +{4} }} )_"; const char *module_header = R"_( #include -#include +{0} #include -{} +{1} #ifndef BINDER_PYBIND11_TYPE_CASTER #define BINDER_PYBIND11_TYPE_CASTER - {} + {2} PYBIND11_DECLARE_HOLDER_TYPE(T, T*) - {} + {3} #endif )_"; @@ -437,7 +437,8 @@ void Context::generate(Config const &config) string shared_declare = "PYBIND11_DECLARE_HOLDER_TYPE(T, "+holder_type+")"; string shared_make_opaque = "PYBIND11_MAKE_OPAQUE("+holder_type+")"; - code = generate_include_directives(includes) + fmt::format(module_header, config.includes_code(), shared_declare, shared_make_opaque) + prefix_code + "void " + function_name + module_function_suffix + "\n{\n" + code + "}\n"; + string const pybind11_include = "#include <" + Config::get().include_file() + ">"; + code = generate_include_directives(includes) + fmt::format(module_header, pybind11_include, config.includes_code(), shared_declare, shared_make_opaque) + prefix_code + "void " + function_name + module_function_suffix + "\n{\n" + code + "}\n"; if( O_single_file ) root_module_file_handle << "// File: " << file_name << '\n' << code << "\n\n"; else update_source_file(config.prefix, file_name, code); @@ -462,8 +463,10 @@ void Context::generate(Config const &config) binding_function_calls += "\t" + f + "(M);\n"; } + string const pybind11_include = "#include <" + Config::get().include_file() + ">"; + std::stringstream s; - s << fmt::format(main_module_header, binding_function_decls, config.root_module, namespace_pairs, binding_function_calls); + s << fmt::format(main_module_header, pybind11_include, binding_function_decls, config.root_module, namespace_pairs, binding_function_calls); root_module_file_handle << s.str();