Skip to content

Commit

Permalink
code-gen: Simplify switch
Browse files Browse the repository at this point in the history
Improves model code generation by collapsing cases with identical statements.

I.e.
```
switch(a)
 case b:
 case c:
  statements;
  break;
```
instead of

```
switch(a)
 case b:
  statements;
  break;
 case c:
  statements;
  break;
```

For my current model of interest, containing many events, this significantly reduces the generated code:

E.g.:
```
 16K my_model/deltasx.cpp
6,6M my_model_old/deltasx.cpp
```

Overall, for this model, I got from 204201 LOC down to 7936 LOC (i.e. -96%).
  • Loading branch information
dweindl committed Dec 15, 2023
1 parent 594b07e commit 4b2a5cf
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
43 changes: 31 additions & 12 deletions python/sdist/amici/cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ def get_switch_statement(
indentation_step: Optional[str] = " " * 4,
):
"""
Generate code for switch statement
Generate code for a C++ switch statement.
Generate code for a C++ switch statements with a ``break`` after each case.
:param condition:
Condition for switch
Expand All @@ -321,26 +323,43 @@ def get_switch_statement(
:return:
Code for switch expression as list of strings
"""
lines = []

if not cases:
return lines
return []

indent0 = indentation_level * indentation_step
indent1 = (indentation_level + 1) * indentation_step
indent2 = (indentation_level + 2) * indentation_step

# try to find redundant statements and collapse those cases
# map statements to case expressions
cases_map: dict[tuple[str, ...], list[str]] = {}
for expression, statements in cases.items():
if statements:
lines.extend(
statement_code = tuple(
[
f"{indent1}case {expression}:",
*(f"{indent2}{statement}" for statement in statements),
f"{indent2}break;",
]
)

if lines:
lines.insert(0, f"{indent0}switch({condition}) {{")
lines.append(indent0 + "}")

return lines
case_code = f"{indent1}case {expression}:"

try:
# there is already a case with the same statement, append
cases_map[statement_code].append(case_code)
except KeyError:
# add new case + statement
cases_map[statement_code] = [case_code]

if not cases_map:
return []

def get_lines():
for statements, case_code in cases_map.items():
yield from case_code
yield from statements

return [
f"{indent0}switch({condition}) {{",
*(get_lines()),
indent0 + "}",
]
2 changes: 1 addition & 1 deletion python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2674,7 +2674,7 @@ def _get_unique_root(
return None

for root in roots:
if sp.simplify(root_found - root.get_val()) == 0:
if (root_found - root.get_val()).is_zero:
return root.get_id()

# create an event for a new root function
Expand Down

0 comments on commit 4b2a5cf

Please sign in to comment.