Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix mypy errors
Browse files Browse the repository at this point in the history
alexfikl committed May 29, 2024
1 parent fbaa18f commit 8813bce
Showing 2 changed files with 24 additions and 31 deletions.
51 changes: 22 additions & 29 deletions pytential/linalg/hmatrix.py
Original file line number Diff line number Diff line change
@@ -21,10 +21,11 @@
"""

from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Optional, Union
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union

import numpy as np
import numpy.linalg as la
from scipy.sparse.linalg import LinearOperator

from arraycontext import PyOpenCLArrayContext, ArrayOrContainerT, flatten, unflatten
from meshmode.dof_array import DOFArray
@@ -35,13 +36,6 @@
from pytential.linalg.skeletonization import (
SkeletonizationWrangler, SkeletonizationResult)

try:
from scipy.sparse.linalg import LinearOperator
except ImportError:
# NOTE: scipy should be available (for interp_decomp), but just in case
class LinearOperator:
pass

import logging
logger = logging.getLogger(__name__)

@@ -124,7 +118,7 @@ def _update_skeleton_diagonal(
targets, sources = parent.skel_tgt_src_index

# FIXME: nicer way to do this?
mat = np.empty(skeleton.nclusters, dtype=object)
mat: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
for k in range(skeleton.nclusters):
D = skeleton.D[k].copy()

@@ -146,9 +140,9 @@ def _update_skeleton_diagonal(

def _update_skeletons_diagonal(
wrangler: "ProxyHierarchicalMatrixWrangler",
func: Callable[[SkeletonizationResult], np.ndarray],
func: Callable[[SkeletonizationResult], Optional[np.ndarray]],
) -> np.ndarray:
skeletons = np.empty(wrangler.skeletons.shape, dtype=object)
skeletons: np.ndarray = np.empty(wrangler.skeletons.shape, dtype=object)
skeletons[0] = wrangler.skeletons[0]

for i in range(1, wrangler.ctree.nlevels):
@@ -263,11 +257,14 @@ def _matvec(self, x: ArrayOrContainerT) -> ArrayOrContainerT:
else:
raise TypeError(f"unsupported input type: {type(x)}")

assert actx is None or isinstance(actx, PyOpenCLArrayContext)
result = apply_skeleton_forward_matvec(self, ary)

if isinstance(x, DOFArray):
assert actx is not None
result = unflatten(x, actx.from_numpy(result), actx)

return result
return result # type: ignore[return-value]


def apply_skeleton_forward_matvec(
@@ -276,7 +273,7 @@ def apply_skeleton_forward_matvec(
) -> ArrayOrContainerT:
from pytential.linalg.cluster import split_array
targets, sources = hmat.skeletons[0].tgt_src_index
x = split_array(ary, sources)
x = split_array(ary, sources) # type: ignore[arg-type]

# NOTE: this computes a telescoping product of the form
#
@@ -297,7 +294,7 @@ def apply_skeleton_forward_matvec(
#
# which gives back the desired product when we reach the leaf level again.

d_dot_x = np.empty(hmat.nlevels, dtype=object)
d_dot_x: np.ndarray = np.empty(hmat.nlevels, dtype=object)

# {{{ recurse down

@@ -307,8 +304,8 @@ def apply_skeleton_forward_matvec(
assert x.shape == (skeleton.nclusters,)
assert skeleton.tgt_src_index.shape[1] == sum([xi.size for xi in x])

d_dot_x_k = np.empty(skeleton.nclusters, dtype=object)
r_dot_x_k = np.empty(skeleton.nclusters, dtype=object)
d_dot_x_k: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
r_dot_x_k: np.ndarray = np.empty(skeleton.nclusters, dtype=object)

for i in range(skeleton.nclusters):
r_dot_x_k[i] = skeleton.R[i] @ x[i]
@@ -366,23 +363,26 @@ def _matvec(self, x: ArrayOrContainerT) -> ArrayOrContainerT:
else:
raise TypeError(f"unsupported input type: {type(x)}")

assert actx is None or isinstance(actx, PyOpenCLArrayContext)
result = apply_skeleton_backward_matvec(actx, self, ary)

if isinstance(x, DOFArray):
assert actx is not None
result = unflatten(x, actx.from_numpy(result), actx)

return result
return result # type: ignore[return-value]


def apply_skeleton_backward_matvec(
actx: PyOpenCLArrayContext,
actx: Optional[PyOpenCLArrayContext],
hmat: ProxyHierarchicalMatrix,
ary: ArrayOrContainerT,
) -> ArrayOrContainerT:
from pytential.linalg.cluster import split_array
targets, sources = hmat.skeletons[0].tgt_src_index

b = split_array(ary, targets)
r_dot_b = np.empty(hmat.nlevels, dtype=object)
b = split_array(ary, targets) # type: ignore[arg-type]
r_dot_b: np.ndarray = np.empty(hmat.nlevels, dtype=object)

# {{{ recurse down

@@ -412,7 +412,7 @@ def apply_skeleton_backward_matvec(
assert b.shape == (skeleton.nclusters,)
assert skeleton.tgt_src_index.shape[0] == sum([bi.size for bi in b])

dhat_dot_b_k = np.empty(skeleton.nclusters, dtype=object)
dhat_dot_b_k: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
for i in range(skeleton.nclusters):
dhat_dot_b_k[i] = (
skeleton.Dhat[i] @ (skeleton.R[i] @ (skeleton.invD[i] @ b[i]))
@@ -467,7 +467,7 @@ def build_hmatrix_by_proxy(
exprs: Union[sym.Expression, Iterable[sym.Expression]],
input_exprs: Union[sym.Expression, Iterable[sym.Expression]], *,
auto_where: Optional[sym.DOFDescriptorLike] = None,
domains: Optional[Iterable[sym.DOFDescriptorLike]] = None,
domains: Optional[Sequence[sym.DOFDescriptorLike]] = None,
context: Optional[Dict[str, Any]] = None,
id_eps: float = 1.0e-8,

@@ -483,13 +483,6 @@ def build_hmatrix_by_proxy(
_approx_nproxy: Optional[int] = None,
_proxy_radius_factor: Optional[float] = None,
) -> ProxyHierarchicalMatrixWrangler:
try:
import scipy # noqa: F401
except ImportError:
raise ImportError(
"The direct solver requires 'scipy' for the interpolative "
"decomposition used in skeletonization")

from pytential.symbolic.matrix import P2PClusterMatrixBuilder
from pytential.linalg.skeletonization import make_skeletonization_wrangler

4 changes: 2 additions & 2 deletions pytential/linalg/utils.py
Original file line number Diff line number Diff line change
@@ -443,8 +443,8 @@ def mnorm(x: np.ndarray, y: np.ndarray) -> "np.floating[Any]":
def skeletonization_matrix(
mat: np.ndarray, skeleton: "SkeletonizationResult",
) -> Tuple[np.ndarray, np.ndarray]:
D = np.empty(skeleton.nclusters, dtype=object)
S = np.empty((skeleton.nclusters, skeleton.nclusters), dtype=object)
D: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
S: np.ndarray = np.empty((skeleton.nclusters, skeleton.nclusters), dtype=object)

from itertools import product
for i, j in product(range(skeleton.nclusters), repeat=2):

0 comments on commit 8813bce

Please sign in to comment.