From 4b2a5cf4043ecdb3bcb84c2bdfdfb06fd9d5c332 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Fri, 15 Dec 2023 16:07:51 +0100 Subject: [PATCH] code-gen: Simplify `switch` Improves model code generation by collapsing cases with identical statements. I.e. ``` switch(a) case b: case c: statements; break; ``` instead of ``` switch(a) case b: statements; break; case c: statements; break; ``` For my current model of interest, containing many events, this significantly reduces the generated code: E.g.: ``` 16K my_model/deltasx.cpp 6,6M my_model_old/deltasx.cpp ``` Overall, for this model, I got from 204201 LOC down to 7936 LOC (i.e. -96%). --- python/sdist/amici/cxxcodeprinter.py | 43 ++++++++++++++++++++-------- python/sdist/amici/de_export.py | 2 +- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/python/sdist/amici/cxxcodeprinter.py b/python/sdist/amici/cxxcodeprinter.py index e6e377b331..b57f03a2c1 100644 --- a/python/sdist/amici/cxxcodeprinter.py +++ b/python/sdist/amici/cxxcodeprinter.py @@ -303,7 +303,9 @@ def get_switch_statement( indentation_step: Optional[str] = " " * 4, ): """ - Generate code for switch statement + Generate code for a C++ switch statement. + + Generate code for a C++ switch statements with a ``break`` after each case. :param condition: Condition for switch @@ -321,26 +323,43 @@ def get_switch_statement( :return: Code for switch expression as list of strings """ - lines = [] - if not cases: - return lines + return [] indent0 = indentation_level * indentation_step indent1 = (indentation_level + 1) * indentation_step indent2 = (indentation_level + 2) * indentation_step + + # try to find redundant statements and collapse those cases + # map statements to case expressions + cases_map: dict[tuple[str, ...], list[str]] = {} for expression, statements in cases.items(): if statements: - lines.extend( + statement_code = tuple( [ - f"{indent1}case {expression}:", *(f"{indent2}{statement}" for statement in statements), f"{indent2}break;", ] ) - - if lines: - lines.insert(0, f"{indent0}switch({condition}) {{") - lines.append(indent0 + "}") - - return lines + case_code = f"{indent1}case {expression}:" + + try: + # there is already a case with the same statement, append + cases_map[statement_code].append(case_code) + except KeyError: + # add new case + statement + cases_map[statement_code] = [case_code] + + if not cases_map: + return [] + + def get_lines(): + for statements, case_code in cases_map.items(): + yield from case_code + yield from statements + + return [ + f"{indent0}switch({condition}) {{", + *(get_lines()), + indent0 + "}", + ] diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 2694f753ad..e3ebbdcaea 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -2674,7 +2674,7 @@ def _get_unique_root( return None for root in roots: - if sp.simplify(root_found - root.get_val()) == 0: + if (root_found - root.get_val()).is_zero: return root.get_id() # create an event for a new root function