Skip to content

Commit

Permalink
Refactor HOC/Python wrapper printing. (#1522)
Browse files Browse the repository at this point in the history
* Split HOC/Py wrapper code printing.

* Move printing return variable.

* Rename HOC/Py wrapper function name.

* Extract HOC/Py signature.
  • Loading branch information
1uc authored Oct 22, 2024
1 parent 0c789ec commit 380207b
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 33 deletions.
86 changes: 55 additions & 31 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,21 +284,43 @@ void CodegenNeuronCppVisitor::print_function_procedure_helper(const ast::Block&
}
}


void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body(
void CodegenNeuronCppVisitor::print_hoc_py_wrapper_call_impl(
const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type) {
if (info.point_process && wrapper_type == InterpreterWrapper::Python) {
return;
}
const auto block_name = function_or_procedure_block->get_node_name();
if (wrapper_type == InterpreterWrapper::HOC) {
printer->fmt_push_block("{}", hoc_function_signature(block_name));

const auto get_func_call_str = [&]() {
const auto& params = function_or_procedure_block->get_parameters();
const auto func_proc_name = block_name + "_" + info.mod_suffix;
auto func_call = fmt::format("{}({}", func_proc_name, internal_method_arguments());
for (int i = 0; i < params.size(); ++i) {
func_call.append(fmt::format(", *getarg({})", i + 1));
}
func_call.append(")");
return func_call;
};

printer->add_line("double _r = 0.0;");
if (function_or_procedure_block->is_function_block()) {
printer->add_indent();
printer->fmt_text("_r = {};", get_func_call_str());
printer->add_newline();
} else {
printer->fmt_push_block("{}", py_function_signature(block_name));
printer->add_line("_r = 1.;");
printer->fmt_line("{};", get_func_call_str());
}
if (info.point_process || wrapper_type != InterpreterWrapper::HOC) {
printer->add_line("return(_r);");
} else if (wrapper_type == InterpreterWrapper::HOC) {
printer->add_line("hoc_retpushx(_r);");
}
}

void CodegenNeuronCppVisitor::print_hoc_py_wrapper_setup(
const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type) {
const auto block_name = function_or_procedure_block->get_node_name();
printer->add_multi_line(R"CODE(
double _r{};
Datum* _ppvar;
Datum* _thread;
NrnThread* nt;
Expand Down Expand Up @@ -369,38 +391,40 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body(
table_update_function_name(block_name),
internal_method_arguments());
}
const auto get_func_call_str = [&]() {
const auto& params = function_or_procedure_block->get_parameters();
const auto func_proc_name = block_name + "_" + info.mod_suffix;
auto func_call = fmt::format("{}({}", func_proc_name, internal_method_arguments());
for (int i = 0; i < params.size(); ++i) {
func_call.append(fmt::format(", *getarg({})", i + 1));
}
func_call.append(")");
return func_call;
};
if (function_or_procedure_block->is_function_block()) {
printer->add_indent();
printer->fmt_text("_r = {};", get_func_call_str());
printer->add_newline();
}


std::string CodegenNeuronCppVisitor::hoc_py_wrapper_signature(
const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type) {
const auto block_name = function_or_procedure_block->get_node_name();
if (wrapper_type == InterpreterWrapper::HOC) {
return hoc_function_signature(block_name);
} else {
printer->add_line("_r = 1.;");
printer->fmt_line("{};", get_func_call_str());
return py_function_signature(block_name);
}
if (info.point_process || wrapper_type != InterpreterWrapper::HOC) {
printer->add_line("return(_r);");
} else if (wrapper_type == InterpreterWrapper::HOC) {
printer->add_line("hoc_retpushx(_r);");
}

void CodegenNeuronCppVisitor::print_hoc_py_wrapper(const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type) {
if (info.point_process && wrapper_type == InterpreterWrapper::Python) {
return;
}

printer->push_block(hoc_py_wrapper_signature(function_or_procedure_block, wrapper_type));

print_hoc_py_wrapper_setup(function_or_procedure_block, wrapper_type);
print_hoc_py_wrapper_call_impl(function_or_procedure_block, wrapper_type);

printer->pop_block();
}


void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_definitions() {
auto print_wrappers = [this](const auto& callables) {
for (const auto& callable: callables) {
print_hoc_py_wrapper_function_body(callable, InterpreterWrapper::HOC);
print_hoc_py_wrapper_function_body(callable, InterpreterWrapper::Python);
print_hoc_py_wrapper(callable, InterpreterWrapper::HOC);
print_hoc_py_wrapper(callable, InterpreterWrapper::Python);
}
};

Expand Down
32 changes: 30 additions & 2 deletions src/codegen/codegen_neuron_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,37 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
void print_function_procedure_helper(const ast::Block& node) override;


void print_hoc_py_wrapper_function_body(const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type);
/** Print the wrapper for calling FUNCION/PROCEDURES from HOC/Py.
*
* Usually the function is made up of the following parts:
* * Print setup code `inst`, etc.
* * Print code to call the function and return.
*/
void print_hoc_py_wrapper(const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type);

/** Print the setup code for HOC/Py wrapper.
*/
void print_hoc_py_wrapper_setup(const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type);


/** Print the code that calls the impl from the HOC/Py wrapper.
*/
void print_hoc_py_wrapper_call_impl(const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type);

/** Return the wrapper signature.
*
* Everything without the `{` or `;`. Roughly, as an example:
* <return_type> <function_name>(<internal_args>, <args>)
*
* were `<internal_args> is the list of arguments required by the
* codegen to be passed along, while <args> are the arguments of
* of the function as they appear in the MOD file.
*/
std::string hoc_py_wrapper_signature(const ast::Block* function_or_procedure_block,
InterpreterWrapper wrapper_type);

void print_hoc_py_wrapper_function_definitions();

Expand Down

0 comments on commit 380207b

Please sign in to comment.