Skip to content

Commit

Permalink
Enable formatting Python files. (#1207)
Browse files Browse the repository at this point in the history
* Format Python files with `black`.
* Enable `black` formatter.
  • Loading branch information
1uc authored Mar 11, 2024
1 parent 9292c18 commit 40aa0bf
Show file tree
Hide file tree
Showing 21 changed files with 379 additions and 171 deletions.
6 changes: 6 additions & 0 deletions .bbp-project.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,9 @@ tools:
match:
- ext/.*
- src/language/templates/*
Black:
enable: True
version: ~=24.2.0
include:
match:
- .*\.py$
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
(master_doc, "nmodl.tex", "nmodl Documentation", "BlueBrain HPC team", "manual")
]

imgmath_image_format = 'svg'
imgmath_image_format = "svg"
imgmath_embed = True
imgmath_font_size = 14

Expand Down
1 change: 1 addition & 0 deletions python/nmodl/ast.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Module for vizualization of NMODL abstract syntax trees (ASTs).
"""

import getpass
import json
import os
Expand Down
79 changes: 61 additions & 18 deletions python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
# accessed through regular imports
major, minor = (int(v) for v in sp.__version__.split(".")[:2])
if major >= 1 and minor >= 7:
known_functions = import_module('sympy.printing.c').known_functions_C99
known_functions = import_module("sympy.printing.c").known_functions_C99
else:
known_functions = import_module('sympy.printing.ccode').known_functions_C99
known_functions.pop('Abs')
known_functions['abs'] = 'fabs'
known_functions = import_module("sympy.printing.ccode").known_functions_C99
known_functions.pop("Abs")
known_functions["abs"] = "fabs"


if not ((major >= 1) and (minor >= 2)):
Expand All @@ -29,7 +29,18 @@
# Some functions are protected inside sympy, if user has declared such a function, it will fail
# because sympy will try to use its own internal one.
# Rename it before and after to a single name
forbidden_var = ["beta", "gamma", "uppergamma", "lowergamma", "polygamma", "loggamma", "digamma", "trigamma"]
forbidden_var = [
"beta",
"gamma",
"uppergamma",
"lowergamma",
"polygamma",
"loggamma",
"digamma",
"trigamma",
]


def search_and_replace_protected_functions_to_sympy(eqs, function_calls):
for c in function_calls:
if c in forbidden_var:
Expand All @@ -38,13 +49,15 @@ def search_and_replace_protected_functions_to_sympy(eqs, function_calls):
eqs = [re.sub(r, f, x) for x in eqs]
return eqs


def search_and_replace_protected_functions_from_sympy(eqs, function_calls):
for c in function_calls:
if c in forbidden_var:
r = f"_sympy_{c}_fun"
eqs = [re.sub(r, f"{c}", x) for x in eqs]
return eqs


def _get_custom_functions(fcts):
custom_functions = {}
for f in fcts:
Expand Down Expand Up @@ -143,13 +156,16 @@ def _sympify_eqs(eq_strings, state_vars, vars):
for state_var in state_vars:
sympy_state_vars.append(sp.sympify(state_var, locals=sympy_vars))
eqs = [
(sp.sympify(eq.split("=", 1)[1], locals=sympy_vars)
- sp.sympify(eq.split("=", 1)[0], locals=sympy_vars)).expand()
(
sp.sympify(eq.split("=", 1)[1], locals=sympy_vars)
- sp.sympify(eq.split("=", 1)[0], locals=sympy_vars)
).expand()
for eq in eq_strings
]

return eqs, sympy_state_vars, sympy_vars


def _interweave_eqs(F, J):
"""Interweave F and J equations so that they are printed in code
rowwise from the equation J x = F. For example:
Expand Down Expand Up @@ -199,13 +215,21 @@ def _interweave_eqs(F, J):
n = len(F)
for i, expr in enumerate(F):
code.append(expr)
for j in range(i * n, (i+1) * n):
for j in range(i * n, (i + 1) * n):
code.append(J[j])

return code


def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_prefix, small_system=False, do_cse=False):
def solve_lin_system(
eq_strings,
vars,
constants,
function_calls,
tmp_unique_prefix,
small_system=False,
do_cse=False,
):
"""Solve linear system of equations, return solution as C code.
If system is small (small_system=True, typically N<=3):
Expand Down Expand Up @@ -233,7 +257,9 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre
vars: list of strings containing new local variables
"""

eq_strings = search_and_replace_protected_functions_to_sympy(eq_strings, function_calls)
eq_strings = search_and_replace_protected_functions_to_sympy(
eq_strings, function_calls
)

eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants)
custom_fcts = _get_custom_functions(function_calls)
Expand All @@ -246,7 +272,9 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre
solution_vector = sp.linsolve(eqs, state_vars).args[0]
if do_cse:
# generate prefix for new local vars that avoids clashes
my_symbols = sp.utilities.iterables.numbered_symbols(prefix=tmp_unique_prefix + '_')
my_symbols = sp.utilities.iterables.numbered_symbols(
prefix=tmp_unique_prefix + "_"
)
sub_exprs, simplified_solution_vector = sp.cse(
solution_vector,
symbols=my_symbols,
Expand All @@ -255,10 +283,14 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre
)
for var, expr in sub_exprs:
new_local_vars.append(sp.ccode(var))
code.append(f"{var} = {sp.ccode(expr.evalf(), user_functions=custom_fcts)}")
code.append(
f"{var} = {sp.ccode(expr.evalf(), user_functions=custom_fcts)}"
)
solution_vector = simplified_solution_vector[0]
for var, expr in zip(state_vars, solution_vector):
code.append(f"{sp.ccode(var)} = {sp.ccode(expr.evalf(), contract=False, user_functions=custom_fcts)}")
code.append(
f"{sp.ccode(var)} = {sp.ccode(expr.evalf(), contract=False, user_functions=custom_fcts)}"
)
else:
# large linear system: construct and return matrix J, vector F such that
# J X = F is the linear system to be solved for X by e.g. LU factorization
Expand All @@ -267,13 +299,17 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre
# construct vector F
vecFcode = []
for i, expr in enumerate(vecF):
vecFcode.append(f"F[{i}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}")
vecFcode.append(
f"F[{i}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}"
)
# construct matrix J
vecJcode = []
for i, expr in enumerate(matJ):
# todo: fix indexing to be ascending order
flat_index = matJ.rows * (i % matJ.rows) + (i // matJ.rows)
vecJcode.append(f"J[{flat_index}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}")
vecJcode.append(
f"J[{flat_index}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}"
)
# interweave
code = _interweave_eqs(vecFcode, vecJcode)

Expand All @@ -299,7 +335,9 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
List of strings containing assignment statements
"""

eq_strings = search_and_replace_protected_functions_to_sympy(eq_strings, function_calls)
eq_strings = search_and_replace_protected_functions_to_sympy(
eq_strings, function_calls
)

eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants)
custom_fcts = _get_custom_functions(function_calls)
Expand All @@ -310,13 +348,18 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):

vecFcode = []
for i, eq in enumerate(eqs):
vecFcode.append(f"F[{i}] = {sp.ccode(eq.simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts)}")
vecFcode.append(
f"F[{i}] = {sp.ccode(eq.simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts)}"
)

vecJcode = []
for i, j in itertools.product(range(jacobian.rows), range(jacobian.cols)):
flat_index = i + jacobian.rows * j

rhs = sp.ccode(jacobian[i,j].simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts)
rhs = sp.ccode(
jacobian[i, j].simplify().subs(X_vec_map).evalf(),
user_functions=custom_fcts,
)
vecJcode.append(f"J[{flat_index}] = {rhs}")

# interweave
Expand Down
1 change: 1 addition & 0 deletions src/language/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
# SPDX-License-Identifier: Apache-2.0


class Argument:
"""Utility class for holding all arguments for node classes"""

Expand Down
22 changes: 16 additions & 6 deletions src/language/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def jinja_template(self, path):
return self.jinja_env.get_template(name)

def _cmake_deps_task(self, tasks):
""""Construct the JinjaTask generating the CMake file exporting all dependencies
""" "Construct the JinjaTask generating the CMake file exporting all dependencies
Args:
tasks: list of JinjaTask objects
Expand Down Expand Up @@ -196,12 +196,18 @@ def workload(self):
task = JinjaTask(
app=self,
input=filepath,
output=self.base_dir / sub_dir / "pynode_{}.cpp".format(chunk_k),
output=self.base_dir
/ sub_dir
/ "pynode_{}.cpp".format(chunk_k),
context=dict(
nodes=self.nodes[
chunk_k * chunk_length : (chunk_k + 1) * chunk_length
chunk_k
* chunk_length : (chunk_k + 1)
* chunk_length
],
setup_pybind_method="init_pybind_classes_{}".format(chunk_k),
setup_pybind_method="init_pybind_classes_{}".format(
chunk_k
),
),
extradeps=extradeps[filepath],
)
Expand All @@ -212,7 +218,11 @@ def workload(self):
app=self,
input=filepath,
output=self.base_dir / sub_dir / filepath.name,
context=dict(nodes=self.nodes, node_info=node_info, **extracontext[filepath]),
context=dict(
nodes=self.nodes,
node_info=node_info,
**extracontext[filepath],
),
extradeps=extradeps[filepath],
)
tasks.append(task)
Expand All @@ -235,7 +245,7 @@ class JinjaTask(
"""

def execute(self):
""""Perform the Jinja task
""" "Perform the Jinja task
Execute Jinja renderer if the output file is out-of-date.
Expand Down
Loading

0 comments on commit 40aa0bf

Please sign in to comment.