Skip to content

Commit

Permalink
Update API using mypy linting (#72)
Browse files Browse the repository at this point in the history
* update API using mypy linting
* minimal python 3.7
  • Loading branch information
liubenyuan authored Nov 27, 2022
1 parent 960aefb commit 416ab6f
Show file tree
Hide file tree
Showing 18 changed files with 247 additions and 230 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pyeit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
steps:
- uses: actions/setup-python@v2
with:
Expand Down
21 changes: 18 additions & 3 deletions examples/fem_forward2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pyeit.eit.fem import Forward
from pyeit.mesh.shape import thorax
from pyeit.mesh.wrapper import PyEITAnomaly_Circle
from pyeit.eit.interp2d import sim2pts, pdegrad

""" 0. build mesh """
n_el = 16 # nb of electrodes
Expand All @@ -19,7 +20,7 @@
# Mesh shape is specified with fd parameter in the instantiation, e.g : fd=thorax
mesh_obj = mesh.create(n_el, h0=0.1, fd=thorax)
else:
mesh_obj = mesh.create(n_el, h0=0.1)
mesh_obj = mesh.create(n_el, h0=0.05)
el_pos = mesh_obj.el_pos

# extract node, element, alpha
Expand All @@ -46,8 +47,7 @@
f = np.real(f)

""" 2. plot """
fig = plt.figure()
ax1 = fig.add_subplot(111)
fig, ax1 = plt.subplots(figsize=(9, 6))
# draw equi-potential lines
vf = np.linspace(min(f), max(f), 32)
# vf = np.sort(f[el_pos])
Expand Down Expand Up @@ -78,3 +78,18 @@
fig.set_size_inches(6, 6)
# fig.savefig('demo_bp.png', dpi=96)
plt.show()

ux, uy = pdegrad(pts, tri, f)
uf = ux**2 + uy**2
uf_pts = sim2pts(pts, tri, uf)
uf_logpwr = 10 * np.log10(uf_pts)

fig, ax = plt.subplots(figsize=(9, 6))
# Draw contour lines on an unstructured triangular grid.
ax.tripcolor(x, y, tri, uf_logpwr, cmap=plt.cm.viridis)
ax.tricontour(x, y, tri, uf_logpwr, 10, cmap=plt.cm.hot)
ax.set_aspect("equal")
ax.set_ylim([-1.2, 1.2])
ax.set_xlim([-1.2, 1.2])
ax.set_title("E field (logmag)")
plt.show()
6 changes: 6 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[mypy]
warn_return_any = True
warn_unused_configs = True
exclude = (build|app|scripts.in|scripts|examples)
ignore_missing_imports = True

18 changes: 9 additions & 9 deletions pyeit/eit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ def __init__(
self.fwd = EITForward(mesh=mesh, protocol=protocol)

# initialize other parameters
self.params = None
self.xg = None
self.yg = None
self.mask = None
self.params: dict = {}
self.xg: np.ndarray = np.zeros(mesh.n_elems)
self.yg: np.ndarray = np.zeros(mesh.n_elems)
self.mask: np.ndarray = np.zeros(mesh.n_elems)
# user must run solver.setup() manually to get correct H
self.H = None
self.H: np.ndarray = np.zeros((mesh.n_elems, protocol.n_meas), dtype=mesh.dtype)
self.is_ready = False

@property
Expand All @@ -77,7 +77,7 @@ def setup(self) -> None:
"""

@abstractmethod
def _compute_h(self) -> np.ndarray:
def _compute_h(self):
"""
Compute H matrix for solving inv problem
Expand All @@ -96,7 +96,7 @@ def solve(
v0: np.ndarray,
normalize: bool = False,
log_scale: bool = False,
) -> np.ndarray:
):
"""
Dynamic imaging (conductivities imaging)
Expand Down Expand Up @@ -128,7 +128,7 @@ def solve(
ds = np.exp(ds) - 1.0
return ds

def map(self, dv: np.ndarray) -> np.ndarray:
def map(self, dv: np.ndarray):
"""
(NOT USED, Deprecated?) simple mat using projection matrix
Expand Down Expand Up @@ -167,7 +167,7 @@ def _check_solver_is_ready(self) -> None:
msg = "User must first run {type(self).__name__}.setup() before imaging."
raise SolverNotReadyError(msg)

def _normalize(self, v1: np.ndarray, v0: np.ndarray) -> np.ndarray:
def _normalize(self, v1: np.ndarray, v0: np.ndarray):
"""
Normalize current frame using the amplitude of the reference frame.
Boundary measurements v are complex-valued, we can use the real part of v,
Expand Down
14 changes: 8 additions & 6 deletions pyeit/eit/bp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Distributed under the (new) BSD License. See LICENSE.txt for more info.
from __future__ import division, absolute_import, print_function, annotations

from typing import Union
from typing import Union, Optional
import numpy as np
from .base import EitBase

Expand All @@ -14,7 +14,9 @@ class BP(EitBase):
"""A naive inversion of (Euclidean) back projection."""

def setup(
self, weight: str = "none", perm: Union[int, float, np.ndarray] = None
self,
weight: str = "none",
perm: Optional[Union[int, float, complex, np.ndarray]] = None,
) -> None:
"""
Setup BP solver
Expand All @@ -34,7 +36,7 @@ def setup(
self.H = self._compute_h(b_matrix=self.B)
self.is_ready = True

def _compute_h(self, b_matrix: np.ndarray) -> np.ndarray:
def _compute_h(self, b_matrix: np.ndarray) -> np.ndarray: # type: ignore[override]
"""
Compute H matrix for BP solver
Expand All @@ -53,7 +55,7 @@ def _compute_h(self, b_matrix: np.ndarray) -> np.ndarray:
b_matrix = weights * b_matrix
return b_matrix.T

def solve_gs(self, v1: np.ndarray, v0: np.ndarray) -> np.ndarray:
def solve_gs(self, v1: np.ndarray, v0: np.ndarray):
"""
Solving using gram-schmidt orthogonalization
Expand All @@ -79,7 +81,7 @@ def solve_gs(self, v1: np.ndarray, v0: np.ndarray) -> np.ndarray:
vn = -(v1 - a * v0) / np.sign(v0.real)
return np.dot(self.H, vn.transpose())

def _normalize(self, v1: np.ndarray, v0: np.ndarray) -> np.ndarray:
def _normalize(self, v1: np.ndarray, v0: np.ndarray):
"""
redefine normalize for BP (without amplitude normalization) using
only the sign of v0.real. [experimental]
Expand All @@ -101,7 +103,7 @@ def _normalize(self, v1: np.ndarray, v0: np.ndarray) -> np.ndarray:
"""
return (v1 - v0) / np.sign(v0.real)

def _simple_weight(self, num_voltages: int) -> np.ndarray:
def _simple_weight(self, num_voltages: int):
"""
Build weighting matrix : simple, normalize by radius.
Expand Down
59 changes: 29 additions & 30 deletions pyeit/eit/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Distributed under the (new) BSD License. See LICENSE.txt for more info.
from __future__ import division, absolute_import, print_function, annotations

from typing import Tuple, Union
from typing import Tuple, Union, Optional
import numpy as np
import numpy.linalg as la
from scipy import sparse
Expand Down Expand Up @@ -40,7 +40,9 @@ def __init__(self, mesh: PyEITMesh) -> None:
self.se = calculate_ke(self.mesh.node, self.mesh.element)
self.assemble_pde(self.mesh.perm)

def assemble_pde(self, perm: Union[int, float, np.ndarray]) -> None:
def assemble_pde(
self, perm: Optional[Union[int, float, complex, np.ndarray]] = None
) -> None:
"""
assemble PDE
Expand All @@ -53,12 +55,16 @@ def assemble_pde(self, perm: Union[int, float, np.ndarray]) -> None:
"""
if perm is None:
return
perm = self.mesh.get_valid_perm(perm)
perm_array = self.mesh.get_valid_perm_array(perm)
self.kg = assemble(
self.se, self.mesh.element, perm, self.mesh.n_nodes, ref=self.mesh.ref_node
self.se,
self.mesh.element,
perm_array,
self.mesh.n_nodes,
ref=self.mesh.ref_node,
)

def solve(self, ex_line: np.ndarray = None) -> np.ndarray:
def solve(self, ex_line: np.ndarray = np.array([0, 1])):
"""
Calculate and compute the potential distribution (complex-valued)
corresponding to the permittivity distribution `perm ` for a
Expand Down Expand Up @@ -148,7 +154,7 @@ def _check_mesh_protocol_compatibility(

def solve_eit(
self,
perm: Union[int, float, np.ndarray] = None,
perm: Optional[Union[int, float, complex, np.ndarray]] = None,
) -> np.ndarray:
"""
EIT simulation, generate forward v measurements
Expand All @@ -165,9 +171,7 @@ def solve_eit(
simulated boundary voltage measurements; shape(n_exe*n_el,)
"""
self.assemble_pde(perm)
v = np.zeros(
(self.protocol.n_exc, self.protocol.n_meas), dtype=self.mesh.perm.dtype
)
v = np.zeros((self.protocol.n_exc, self.protocol.n_meas), dtype=self.mesh.dtype)
for i, ex_line in enumerate(self.protocol.ex_mat):
f = self.solve(ex_line)
v[i] = subtract_row(f[self.mesh.el_pos], self.protocol.meas_mat[i])
Expand All @@ -176,7 +180,7 @@ def solve_eit(

def compute_jac(
self,
perm: Union[int, float, np.ndarray] = None,
perm: Optional[Union[int, float, complex, np.ndarray]] = None,
normalize: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Expand Down Expand Up @@ -206,11 +210,9 @@ def compute_jac(
# calculate v, jac per excitation pattern (ex_line)
_jac = np.zeros(
(self.protocol.n_exc, self.protocol.n_meas, self.mesh.n_elems),
dtype=self.mesh.perm.dtype,
)
v = np.zeros(
(self.protocol.n_exc, self.protocol.n_meas), dtype=self.mesh.perm.dtype
dtype=self.mesh.dtype,
)
v = np.zeros((self.protocol.n_exc, self.protocol.n_meas), dtype=self.mesh.dtype)
for i, ex_line in enumerate(self.protocol.ex_mat):
f = self.solve(ex_line)
v[i] = subtract_row(f[self.mesh.el_pos], self.protocol.meas_mat[i])
Expand All @@ -221,7 +223,7 @@ def compute_jac(
_jac[i, :, e] = np.dot(np.dot(ri[:, ijk], self.se[e]), f[ijk])

# measurement protocol
jac = np.vstack(_jac)
jac = np.concatenate(_jac)
v0 = v.reshape(-1)

# Jacobian normalization: divide each row of J (J[i]) by abs(v0[i])
Expand All @@ -231,8 +233,8 @@ def compute_jac(

def compute_b_matrix(
self,
perm: Union[int, float, np.ndarray] = None,
) -> np.ndarray:
perm: Optional[Union[int, float, complex, np.ndarray]] = None,
):
"""
Compute back-projection mappings (smear matrix)
Expand All @@ -251,7 +253,8 @@ def compute_b_matrix(
self.assemble_pde(perm)
b_mat = np.zeros((self.protocol.n_exc, self.protocol.n_meas, self.mesh.n_nodes))

for i, ex_line in enumerate(self.protocol.ex_mat):
for i in range(self.protocol.n_exc):
ex_line = self.protocol.ex_mat[i]
f = self.solve(ex_line=ex_line)
f_el = f[self.mesh.el_pos]
# build bp projection matrix
Expand All @@ -260,10 +263,10 @@ def compute_b_matrix(
# 2. or, simply smear at the nodes using f
b_mat[i] = _smear(f, f_el, self.protocol.meas_mat[i])

return np.vstack(b_mat)
return np.concatenate(b_mat)


def _smear(f: np.ndarray, fb: np.ndarray, pairs: np.ndarray) -> np.ndarray:
def _smear(f: np.ndarray, fb: np.ndarray, pairs: np.ndarray):
"""
Build smear matrix B for bp for one exitation
Expand Down Expand Up @@ -329,7 +332,7 @@ def smear_nd(
f_max = np.repeat(f_max[:, :, np.newaxis], n_pts, axis=2)
f0 = np.repeat(f[:, :, np.newaxis], n_meas, axis=2)
f0 = f0.swapaxes(1, 2)
return (f_min < f0) & (f0 <= f_max)
return np.array((f_min < f0) & (f0 <= f_max))
else:
# Replacing the below code by a faster implementation in Numpy
def b_matrix_init(k):
Expand All @@ -338,7 +341,7 @@ def b_matrix_init(k):
return np.array(list(map(b_matrix_init, np.arange(f.shape[0]))))


def subtract_row(v: np.ndarray, meas_pattern: np.ndarray) -> np.ndarray:
def subtract_row(v: np.ndarray, meas_pattern: np.ndarray):
"""
Build the voltage differences on axis=1 using the meas_pattern.
v_diff[k] = v[i, :] - v[j, :]
Expand All @@ -362,7 +365,7 @@ def subtract_row(v: np.ndarray, meas_pattern: np.ndarray) -> np.ndarray:

def assemble(
ke: np.ndarray, tri: np.ndarray, perm: np.ndarray, n_pts: int, ref: int = 0
) -> np.ndarray:
):
"""
Assemble the stiffness matrix (using sparse matrix)
Expand Down Expand Up @@ -458,7 +461,7 @@ def calculate_ke(pts: np.ndarray, tri: np.ndarray) -> np.ndarray:
return ke_array


def _k_triangle(xy: np.ndarray) -> np.ndarray:
def _k_triangle(xy: np.ndarray):
"""
Given a point-matrix of an element, solving for Kij analytically
using barycentric coordinates (simplex coordinates)
Expand All @@ -485,12 +488,12 @@ def _k_triangle(xy: np.ndarray) -> np.ndarray:
return np.dot(s, s.T) / (4.0 * at)


def det2x2(s1: np.ndarray, s2: np.ndarray) -> float:
def det2x2(s1: np.ndarray, s2: np.ndarray):
"""Calculate the determinant of a 2x2 matrix"""
return s1[0] * s2[1] - s1[1] * s2[0]


def _k_tetrahedron(xy: np.ndarray) -> np.ndarray:
def _k_tetrahedron(xy: np.ndarray):
"""
Given a point-matrix of an element, solving for Kij analytically
using barycentric coordinates (simplex coordinates)
Expand Down Expand Up @@ -525,7 +528,3 @@ def _k_tetrahedron(xy: np.ndarray) -> np.ndarray:

# local (e for element) stiffness matrix
return np.dot(a, a.transpose()) / (36.0 * vt)


if __name__ == "__main__":
""""""
Loading

0 comments on commit 416ab6f

Please sign in to comment.