Skip to content

Commit

Permalink
Updating test cases to massively improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
peekxc committed Nov 22, 2024
1 parent 53420ac commit 5fa1ffb
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 180 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![build_windows](https://img.shields.io/github/actions/workflow/status/peekxc/primate/build_windows.yml?logo=windows&logoColor=white)](https://github.com/peekxc/primate/actions/workflows/wheels.yml)
[![build_linux](https://img.shields.io/github/actions/workflow/status/peekxc/primate/build_linux.yml?logo=linux&logoColor=white)](https://github.com/peekxc/primate/actions/workflows/wheels.yml)

[![Tests](https://badgen.net/github/checks/peekxc/primate/pythran_overhaul?label=tests)]
[![Tests](https://badgen.net/github/checks/peekxc/primate/pythran_overhaul?label=tests)](https://cirrus-ci.com/github/peekxc/primate/pythran_overhaul)
[![Python versions](https://badgen.net/badge/python/3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13/blue)](https://github.com/peekxc/primate/actions)
[![PyPI Version](https://badgen.net/pypi/v/scikit-primate)](https://pypi.org/project/scikit-primate/)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
Expand Down
13 changes: 12 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ license = {file = "LICENSE"}
[project.optional-dependencies]
test = ["pytest", "pytest-cov", "pytest-benchmark"] # "bokeh"
doc = ["quartodoc"]
dev = ['meson-python', 'wheel', 'ninja', 'pybind11', 'numpy']
dev = ['meson-python', 'pybind11', 'numpy']

[tool.meson-python.args]
setup = ['--python.bytecompile=2']
Expand All @@ -48,6 +48,17 @@ addopts = "-rA"
testpaths = ["tests"]
norecursedirs = ["docs", "*.egg-info", ".git", "build", "dist"]

[tool.coverage.report]
omit = [
"src/**/__init__.py",
"src/**/fttr.py",
"src/**/tqli.py",
"src/**/plotting.py",
]
exclude_also = [
"def __repr__",
"def _.*"
]
## NOTE: https://github.com/pypi/warehouse/blob/8060bfa3cb00c7f68bb4b10021b5361e92a04017/warehouse/forklift/legacy.py#L70-L72
## PyPI limits file sizes to 100 MB and project sizes to 10 GB
[tool.cibuildwheel]
Expand Down
13 changes: 4 additions & 9 deletions src/primate/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

def diag(
A: Union[sp.sparse.linalg.LinearOperator, np.ndarray],
maxiter: int = 200,
pdf: Union[str, Callable] = "rademacher",
converge: Union[str, ConvergenceCriterion] = "tolerance",
seed: Union[int, np.random.Generator, None] = None,
Expand All @@ -21,7 +20,7 @@ def diag(
) -> Union[float, tuple]:
r"""Estimates the diagonal of a symmetric `A` via the Girard-Hutchinson estimator.
This function uses up to `maxiter` random vectors to estimate of the diagonal of $A$ via the approximation:
This function uses random vectors to estimate the diagonal of $A$ via the approximation:
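$$ \mathrm{diag}(A) \approx \left( \sum_{i=1}^m v_i \odot A v_i \right) \oslash \left( \sum_{i=1}^m v_i \odot v_i \right) $$
where $\odot$ and $\oslash$ denote elementwise multiplication and division, and $m$ is the number of random vectors sampled.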
When $v$ are isotropic, this approximation forms an unbiased estimator of the diagonal of $A$.
Expand All @@ -32,7 +31,6 @@ def diag(
Parameters:
A: real symmetric matrix or linear operator.
maxiter: Maximum number of random vectors to sample for the trace estimate.
pdf: Choice of zero-centered distribution to sample random vectors from.
estimator: Type of estimator to use for convergence testing. See details.
seed: Seed to initialize the `rng` entropy source. Set `seed` > -1 for reproducibility.
Expand Down Expand Up @@ -68,30 +66,27 @@ def diag(
return 0.0 if not full else (0.0, EstimatorResult(0.0, False, "", 0, []))

## Commence the Monte-Carlo iterations
converged = False
if full or callback is not None:
numer, denom = np.zeros(N, dtype=f_dtype), np.zeros(N, dtype=f_dtype)
result = EstimatorResult(numer, False, "", 0, [])
while not converged:
result = EstimatorResult(numer, False, converge, 0, {})
while not converge(estimator):
v = pdf(size=(N, 1), seed=rng).astype(f_dtype)
u = (A @ v).ravel()
numer += u * v.ravel()
denom += np.square(v.ravel())
estimator.update(np.atleast_2d(numer / denom))
converged = estimator.converged() or len(estimator) >= maxiter
result.update(estimator)
if callback is not None:
callback(result)
return (estimator.estimate, result)
else:
numer, denom = np.zeros(N, dtype=f_dtype), np.zeros(N, dtype=f_dtype)
while not converged:
while not converge(estimator):
v = pdf(size=(N, 1), seed=rng).astype(f_dtype)
u = (A @ v).ravel()
numer += u * v.ravel()
denom += np.square(v.ravel())
estimator.update(np.atleast_2d(numer / denom))
converged = estimator.converged() or len(estimator) >= maxiter
return estimator.estimate


Expand Down
44 changes: 22 additions & 22 deletions src/primate/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,30 @@ def update(self, x: Union[float, np.ndarray], **kwargs: dict): ...
def estimate(self) -> Union[float, np.ndarray]: ...


class ConvergenceCriterion(Callable):
"""Generic stopping criteria for sequences."""

def __init__(self, operation: Callable):
assert callable(operation)
self._operation = operation

def __or__(self, other: "ConvergenceCriterion"):
# other_op = other._operation if isinstance(other, ConvergenceCriterion) else (lambda: other)
return ConvergenceCriterion(lambda est: or_(self(est), other(est)))

def __and__(self, other: "ConvergenceCriterion"):
# other_op = other._operation if isinstance(other, ConvergenceCriterion) else (lambda: other)
return ConvergenceCriterion(lambda est: and_(self(est), other(est)))

def __call__(self, est: Estimator) -> bool:
return self._operation(est)


@dataclass
class EstimatorResult:
estimate: Union[float, np.ndarray]
estimator: Estimator
converged: bool = False
criterion: Union[ConvergenceCriterion, str, None] = None
status: str = ""
nit: int = 0
info: dict = field(default_factory=dict)
Expand Down Expand Up @@ -129,25 +148,6 @@ def converged(self) -> bool:
return self.margin_of_error <= self.atol or rel_error <= self.rtol


class ConvergenceCriterion(Callable):
"""Generic stopping criteria for sequences."""

def __init__(self, operation: Callable):
assert callable(operation)
self._operation = operation

def __or__(self, other: "ConvergenceCriterion"):
# other_op = other._operation if isinstance(other, ConvergenceCriterion) else (lambda: other)
return ConvergenceCriterion(lambda est: or_(self(est), other(est)))

def __and__(self, other: "ConvergenceCriterion"):
# other_op = other._operation if isinstance(other, ConvergenceCriterion) else (lambda: other)
return ConvergenceCriterion(lambda est: and_(self(est), other(est)))

def __call__(self, est: Estimator) -> bool:
return self._operation(est)


class CountCriterion(ConvergenceCriterion):
"""Convergence criterion that returns TRUE when above a given count."""

Expand Down Expand Up @@ -270,9 +270,9 @@ def convergence_criterion(criterion: Union[str, ConvergenceCriterion], **kwargs)
return criterion
criterion = criterion.lower()
if criterion == "count":
cc = CountCriterion(**{k: v for k, v in kwargs.items() if k in {"ord", "atol", "rtol"}})
cc = CountCriterion(**{k: v for k, v in kwargs.items() if k in {"count"}})
elif criterion == "tolerance":
cc = ToleranceCriterion(**{k: v for k, v in kwargs.items() if k in {"count"}})
cc = ToleranceCriterion(**{k: v for k, v in kwargs.items() if k in {"ord", "atol", "rtol"}})
elif criterion == "confidence":
cc = ConfidenceCriterion(**{k: v for k, v in kwargs.items() if k in {"confidence", "atol", "rtol"}})
elif criterion == "knee":
Expand Down
6 changes: 3 additions & 3 deletions src/primate/lanczos.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def rayleigh_ritz(A, deg: Optional[int] = None, return_eigenvectors: bool = Fals
## Also: Chen's Krylov-aware method


def _lanczos_py(A: np.ndarray, v0: np.ndarray, k: int, tol: float) -> int:
def _lanczos_py(A: np.ndarray, v0: np.ndarray, k: int, tol: float) -> int: # pragma: no cover
"""Base lanczos algorithm, for establishing a baseline"""
n = A.shape[0]
assert k <= n, "Can perform at most k = n iterations"
Expand All @@ -165,7 +165,7 @@ def _lanczos_py(A: np.ndarray, v0: np.ndarray, k: int, tol: float) -> int:
return alpha, beta


def _orth_vector(v, U, start_idx, p, reverse=False):
def _orth_vector(v, U, start_idx, p, reverse=False): # pragma: no cover
n = U.shape[0]
m = U.shape[1]
tol = 2 * np.finfo(U.dtype).eps * np.sqrt(n)
Expand All @@ -179,7 +179,7 @@ def _orth_vector(v, U, start_idx, p, reverse=False):
v -= (s_proj / u_norm) * U[:, i]


def _lanczos_recurrence(A, q, deg, rtol, orth, V, ncv):
def _lanczos_recurrence(A, q, deg, rtol, orth, V, ncv): # pragma: no cover
n, m = A.shape
residual_tol = np.sqrt(n) * rtol

Expand Down
113 changes: 0 additions & 113 deletions src/primate/quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,116 +75,3 @@ def lanczos_quadrature(
np.copyto(nodes, theta)
np.copyto(weights, tau)
return theta, tau


def spectral_density(
A: Union[LinearOperator, np.ndarray],
bins: Union[int, np.ndarray] = 100,
bw: Union[float, str] = "scott",
deg: int = 20,
rtol: float = 0.01,
verbose: bool = False,
info: bool = False,
plot: bool = False,
**kwargs,
):
"""Estimates the spectral density of an operator via stochastic Lanczos quadrature.
Parameters:
A: real symmetric matrix or linear operator.
bins: number of domain points to accumulate density
bw: bandwidth value or rule
deg: degree of each quadrature approximation
rtol: relative stopping tolerance
verbose: whether to report various statistics
Return:
(density, bins) = Estimate of the spectral density at domain points 'bins'
"""
## First probe info about the spectrum via a single adaptive Krylov expansion
n = A.shape[0]
# spec_radius = eigsh(A, k=1, which="LM")
# spec_radius, info = spectral_radius(A, full_output=True)
# min_rw = np.min(info["ritz_values"])
# fun = "identity" if fun is None or (isinstance(fun, str) and fun == "identity") else fun
# fun = param_callable(fun, kwargs) if isinstance(fun, str) else fun
# assert isinstance(fun(1.0), Number), "Function must return a real number."

## Parameterize the kernel
## Automatic bandwidth determination for "bimodal or multi-modal distributions tend to be oversmoothed."
N = deg * n
if bw == "scott":
h = N ** (-1 / 5)
h **= 2 # to prevent over-smoothing
elif bw == "silverman":
h = (N * 3 / 4) ** (-1 / 7)
h **= 2 # to prevent over-smoothing
else:
assert isinstance(bw, Number), f"Invalid bandwidth estimator '{bw}'; must be 'scott', 'silverman', or float."
h = bw
K = lambda u: np.exp(-0.5 * u**2)

## Prepare the bins for the estimate
# bins = np.linspace(min_rw, spec_radius, int(bins)) if isinstance(bins, Number) else np.asarray(bins)
bins = np.linspace(0, 1, int(bins), endpoint=True)
n_bins = len(bins)
spectral_density = np.zeros(n_bins) # accumulate density estimate
density_residual = np.zeros(n_bins) # difference in density per iteration
min_bins = np.inf * np.ones(n_bins) # min value encountered per bin

## Begin sampling stochastic quadrature estimates
rel_change, jj = np.inf, 0
trace_samples = array("f")
while rel_change > rtol and jj < A.shape[0]:
## Acquire a quadrature estimate
## TODO: The inner sum can likely be vectorized with an einsum or something
alpha, beta = lanczos(A, deg=deg, **kwargs)
nodes, weights = lanczos_quadrature(alpha, beta)
density_residual.fill(0)
for i, t in enumerate(nodes):
density_residual += weights[i] * K((bins - t) / h) # weights[i] * c # Note constant 'c' can be dropped

# density_residual = np.sum(weights * K((bins[:,np.newaxis] - nodes) / h), axis=1)
# np.sum(weights * K((bins[:,np.newaxis] - nodes) / h), axis=1)
# np.sum(weights * (bins[:,np.newaxis] - nodes), axis=1)
# np.einsum('i,ji,j->j', weights, bins[:, np.newaxis] - nodes, np.ones_like(bins))

## Maintain a minimum ritz estimate per bin to estimate spectral gap
bin_ind = np.clip(np.digitize(nodes, bins), 0, n_bins - 1)
min_bins[bin_ind] = np.minimum(min_bins[bin_ind], nodes)

## Accumulate the spectral density
spectral_density += density_residual
jj += 1
if jj > 2:
w1 = (spectral_density - density_residual) / np.sum(spectral_density - density_residual)
w2 = spectral_density / np.sum(spectral_density)
rel_change = np.mean(np.abs((w1 - w2) / np.where(w1 > 0, w1, 1.0)))

## Keep trace of the spectral sum each iteration
trace_samples.append(np.sum(weights * nodes * n))

## Normalize so that the density approximately integrates to 1
spectral_density /= np.sum(spectral_density) * np.diff(bins[:2])

## Plot if requested
if plot:
from bokeh.plotting import figure, show

p = figure(width=700, height=300, title=f"Estimated spectral density (bw = {h:.4f}, n_samples = {jj})")
p.scatter(bins, spectral_density)
p.line(bins, spectral_density)
# y_lb = np.min(spectral_density) - np.ptp(spectral_density) * 0.025
# p.scatter(ew, , marker='plus', color='red', fill_alpha=0.25, line_width=0, size=6)
show(p)

if info:
info_dict = {
"trace": np.mean(trace_samples),
"rtol": rtol,
"quad_est": trace_samples,
"bandwidth": h,
"n_samples": jj,
}
return (spectral_density, bins), info_dict
return (spectral_density, bins)
24 changes: 8 additions & 16 deletions src/primate/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,7 @@
from scipy.sparse import sparray
from scipy.sparse.linalg import LinearOperator

from .estimators import (
ConfidenceCriterion,
ConvergenceCriterion,
EstimatorResult,
MeanEstimator,
ToleranceCriterion,
convergence_criterion,
)
from .estimators import ConvergenceCriterion, EstimatorResult, MeanEstimator, convergence_criterion
from .linalg import update_trinv
from .operators import _operator_checks
from .random import isotropic
Expand Down Expand Up @@ -132,11 +125,11 @@ def hutch(

## Catch degenerate case
if np.prod(A.shape) == 0:
return 0.0 if not full else (0.0, EstimatorResult(0.0, False, "", 0, []))
return 0.0 if not full else (0.0, EstimatorResult(0.0, False, converge, 0, {}))

## Commence the Monte-Carlo iterations
if full or callback is not None:
result = EstimatorResult(0.0, False, "", 0, [])
result = EstimatorResult(0.0, False, converge, 0, {})
callback = lambda x: x if callback is None else callback
while not converge(estimator):
v = pdf(size=(N, batch)).astype(f_dtype)
Expand Down Expand Up @@ -165,7 +158,6 @@ def hutchpp(
A: Matrix or LinearOperator to estimate the trace of.
m: number of matvecs to use. If not given, defaults to `n // 3`.
batch: currently unused.
mode:
"""
f_dtype = _operator_checks(A)
N: int = A.shape[0]
Expand Down Expand Up @@ -207,14 +199,14 @@ def hutchpp(
if not full:
return tr_rng + tr_defl
else:
result = EstimatorResult(0.0, False, "", 0, [])
result = EstimatorResult(0.0, False, None, 0, {})
result.estimate = tr_rng + tr_defl
result.nit = 2 * nb
result.samples = np.concatenate([rng_ests, defl_ests])
return result.estimate, result


def __xtrace(W: np.ndarray, Z: np.ndarray, Q: np.ndarray, R: np.ndarray, R_inv: np.ndarray, pdf: str):
def _xtrace(W: np.ndarray, Z: np.ndarray, Q: np.ndarray, R: np.ndarray, R_inv: np.ndarray, pdf: str):
"""Helper for xtrace function.
Parameters:
Expand Down Expand Up @@ -280,7 +272,7 @@ def xtrace(
**kwargs: Additional keyword arguments to parameterize the convergence criterion.
Returns:
Estimate the trace of $f(A)$. If `info = True`, additional information about the computation is also returned.
Estimate the trace of `A`. If `info = True`, additional information about the computation is also returned.
"""

from scipy.linalg import qr_insert
Expand All @@ -300,7 +292,7 @@ def xtrace(
## Commence the batched-iteration
estimate = np.inf
it = 0
result = EstimatorResult(0.0, False, "", 0, [])
result = EstimatorResult(0.0, False, None, 0, {})
rng = np.random.default_rng(seed)
while (it * batch) < A.shape[1]: # err >= (error_atol + error_rtol * abs(t)):
## Determine number of new sample vectors to generate
Expand All @@ -318,7 +310,7 @@ def xtrace(
Z = np.c_[Z, A @ Q[:, -ns:]]

## Expand the subspace
estimate, t_samples, err = __xtrace(W, Z, Q, R, R_inv, pdf)
estimate, t_samples, err = _xtrace(W, Z, Q, R, R_inv, pdf)
it += 1

if full:
Expand Down
Loading

0 comments on commit 5fa1ffb

Please sign in to comment.