Skip to content

Commit

Permalink
Add a smart_holder option.
Browse files Browse the repository at this point in the history
  • Loading branch information
kliegeois committed Sep 25, 2023
1 parent e9b5598 commit 017b9b3
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 14 deletions.
8 changes: 7 additions & 1 deletion source/class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1226,11 +1226,17 @@ void ClassBinder::bind(Context &context)

std::string extra_annotation = module_local_annotation + buffer_protocol_annotation;

if( named_class )
if( named_class ) {
if (Config::get().is_smart_holder_requested(qualified_name_without_template)) {
c += '\t' +
R"(PYBIND11_TYPE_CASTER_BASE_HOLDER({}, {})"_format(qualified_name, maybe_holder_type) +
'\n';
}
c += '\t' +
R"(pybind11::class_<{}{}{}{}> cl({}, "{}", "{}"{});)"_format(qualified_name, maybe_holder_type, maybe_trampoline, maybe_base_classes(context), module_variable_name, python_class_name(C),
generate_documentation_string_for_declaration(C), extra_annotation) +
'\n';
}
// c += "\tpybind11::handle cl_type = cl;\n\n";

// if( C->isAbstract() and callback_structure) c += "\tcl.def(pybind11::init<>());\n";
Expand Down
24 changes: 24 additions & 0 deletions source/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ void Config::read(string const &file_name)

string const _custom_shared_{"custom_shared"};

string const _smart_holder_{"smart_holder"};

string const _default_static_pointer_return_value_policy_{"default_static_pointer_return_value_policy"};
string const _default_static_lvalue_reference_return_value_policy_{"default_static_lvalue_reference_return_value_policy"};
string const _default_static_rvalue_reference_return_value_policy_{"default_static_rvalue_reference_return_value_policy"};
Expand Down Expand Up @@ -202,6 +204,13 @@ void Config::read(string const &file_name)
}
else if( token == _custom_shared_ ) holder_type_ = name_without_spaces;

else if( token == _smart_holder_ ) {
include_file_ = "pybind11/smart_holder.h";
if(bind) {
smart_held_classes.push_back(name_without_spaces);
}
}

else if( token == _default_static_pointer_return_value_policy_ ) default_static_pointer_return_value_policy_ = name_without_spaces;
else if( token == _default_static_lvalue_reference_return_value_policy_ ) default_static_lvalue_reference_return_value_policy_ = name_without_spaces;
else if( token == _default_static_rvalue_reference_return_value_policy_ ) default_static_rvalue_reference_return_value_policy_ = name_without_spaces;
Expand Down Expand Up @@ -401,6 +410,21 @@ bool Config::is_module_local_requested(string const &namespace_) const
return false;
}

bool Config::is_smart_holder_requested(string const &class__) const
{
string class_{class__};
class_.erase(std::remove(class_.begin(), class_.end(), ' '), class_.end());

auto smart_held_class = std::find(smart_held_classes.begin(), smart_held_classes.end(), class_);

if( smart_held_class != smart_held_classes.end() ) {
// outs() << "Using smart holder for class : " << class_ << "\n";
return true;
}

return false;
}

bool Config::is_include_skipping_requested(string const &include) const
{
for( auto &i : includes_to_skip )
Expand Down
6 changes: 5 additions & 1 deletion source/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Config
string default_member_rvalue_reference_return_value_policy_ = "pybind11::return_value_policy::automatic";
string default_call_guard_ = "";
string holder_type_ = "std::shared_ptr";
string include_file_ = "pybind11/pybind11.h";

public:
static Config &get();
Expand All @@ -62,7 +63,7 @@ class Config
string root_module;

std::vector<string> namespaces_to_bind, classes_to_bind, functions_to_bind, namespaces_to_skip, classes_to_skip, functions_to_skip, includes_to_add, includes_to_skip;
std::vector<string> buffer_protocols, module_local_namespaces_to_add, module_local_namespaces_to_skip;
std::vector<string> buffer_protocols, module_local_namespaces_to_add, module_local_namespaces_to_skip, smart_held_classes;

std::map<string, string> const &binders() const { return binders_; }
std::map<string, string> const &add_on_binders() const { return add_on_binders_; }
Expand All @@ -85,6 +86,7 @@ class Config
string const &default_call_guard() { return default_call_guard_; }

string const &holder_type() const { return holder_type_; }
string const &include_file() const { return include_file_; }

string prefix;

Expand All @@ -102,6 +104,8 @@ class Config
bool is_class_skipping_requested(string const &class_) const;
bool is_buffer_protocol_requested(string const &class_) const;

bool is_smart_holder_requested(string const &class_) const;

bool is_include_skipping_requested(string const &include) const;

string is_custom_trampoline_function_requested(string const &function__) const;
Expand Down
27 changes: 15 additions & 12 deletions source/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ const char *main_module_header = R"_(#include <map>
#include <stdexcept>
#include <string>
#include <pybind11/pybind11.h>
{0}
typedef std::function< pybind11::module & (std::string const &) > ModuleGetter;
{0}
{1}
PYBIND11_MODULE({1}, root_module) {{
root_module.doc() = "{1} module";
PYBIND11_MODULE({2}, root_module) {{
root_module.doc() = "{2} module";
std::map <std::string, pybind11::module> modules;
ModuleGetter M = [&](std::string const &namespace_) -> pybind11::module & {{
Expand All @@ -115,25 +115,25 @@ PYBIND11_MODULE({1}, root_module) {{
);
std::vector< std::pair<std::string, std::string> > sub_modules {{
{2} }};
{3} }};
for(auto &p : sub_modules ) modules[p.first.size() ? p.first+"::"+p.second : p.second] = modules[p.first].def_submodule( mangle_namespace_name(p.second).c_str(), ("Bindings for " + p.first + "::" + p.second + " namespace").c_str() );
//pybind11::class_<std::shared_ptr<void>>(M(""), "_encapsulated_data_");
{3}
{4}
}}
)_";

const char *module_header = R"_(
#include <functional>
#include <pybind11/pybind11.h>
{0}
#include <string>
{}
{1}
#ifndef BINDER_PYBIND11_TYPE_CASTER
#define BINDER_PYBIND11_TYPE_CASTER
{}
{2}
PYBIND11_DECLARE_HOLDER_TYPE(T, T*)
{}
{3}
#endif
)_";
Expand Down Expand Up @@ -437,7 +437,8 @@ void Context::generate(Config const &config)
string shared_declare = "PYBIND11_DECLARE_HOLDER_TYPE(T, "+holder_type+"<T>)";
string shared_make_opaque = "PYBIND11_MAKE_OPAQUE("+holder_type+"<void>)";

code = generate_include_directives(includes) + fmt::format(module_header, config.includes_code(), shared_declare, shared_make_opaque) + prefix_code + "void " + function_name + module_function_suffix + "\n{\n" + code + "}\n";
string const pybind11_include = "#include <" + Config::get().include_file() + ">";
code = generate_include_directives(includes) + fmt::format(module_header, pybind11_include, config.includes_code(), shared_declare, shared_make_opaque) + prefix_code + "void " + function_name + module_function_suffix + "\n{\n" + code + "}\n";

if( O_single_file ) root_module_file_handle << "// File: " << file_name << '\n' << code << "\n\n";
else update_source_file(config.prefix, file_name, code);
Expand All @@ -462,8 +463,10 @@ void Context::generate(Config const &config)
binding_function_calls += "\t" + f + "(M);\n";
}

string const pybind11_include = "#include <" + Config::get().include_file() + ">";

std::stringstream s;
s << fmt::format(main_module_header, binding_function_decls, config.root_module, namespace_pairs, binding_function_calls);
s << fmt::format(main_module_header, pybind11_include, binding_function_decls, config.root_module, namespace_pairs, binding_function_calls);

root_module_file_handle << s.str();

Expand Down

0 comments on commit 017b9b3

Please sign in to comment.