Skip to content

Commit

Permalink
Removing protocol from conergence criterion class
Browse files Browse the repository at this point in the history
  • Loading branch information
peekxc committed Nov 22, 2024
1 parent 7320734 commit 53420ac
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
5 changes: 2 additions & 3 deletions src/primate/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,8 @@ def converged(self) -> bool:
return self.margin_of_error <= self.atol or rel_error <= self.rtol


@runtime_checkable
class ConvergenceCriterion(Protocol):
"""Protocol for generic stopping criteria for sequences."""
class ConvergenceCriterion(Callable):
"""Generic stopping criteria for sequences."""

def __init__(self, operation: Callable):
assert callable(operation)
Expand Down
17 changes: 9 additions & 8 deletions src/primate/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,17 @@ def xtrace(
"""Estimates the trace of `A` using the XTrace trace estimator.
Parameters:
A: all isotropic random vectors sampled thus far.
batch: the image A @ Q.
Q: orthogonal component of qr(A @ W)
R: upper-triangular component of qr(A @ W)
R_inv: inverse matrix of R.
pdf: the distribution with which `W` was sampled from.
A: real symmetric matrix or linear operator.
batch: Number of random vectors to sample at a time for batched matrix multiplication.
pdf: Choice of zero-centered distribution to sample random vectors from.
converge: Convergence criterion to test for estimator convergence. See details.
seed: Seed to initialize the `rng` entropy source. Set `seed` > -1 for reproducibility.
full: Whether to return additional information about the computation.
callback: Optional callable to execute after each batch of samples.
**kwargs: Additional keyword arguments to parameterize the convergence criterion.
Returns:
tuple (t, est, err) representing the average trace estimate.
Estimate the trace of $f(A)$. If `info = True`, additional information about the computation is also returned.
"""

from scipy.linalg import qr_insert
Expand Down Expand Up @@ -312,7 +314,6 @@ def xtrace(
y = A @ eta.T
Q, R = qr_insert(Q, R, u=y, k=Q.shape[1], which="col") # rcond=FLOAT_MIN
R_inv = update_trinv(R_inv, R[:, -1])

W = np.c_[W, N.T]
Z = np.c_[Z, A @ Q[:, -ns:]]

Expand Down
4 changes: 2 additions & 2 deletions tests/test_diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def test_xdiag():
errors = []
budget = np.linspace(2, A.shape[0], 10).astype(int)
for m in budget:
d = xdiag(A, m, pdf="signs", seed=rng)
d = xdiag(A, m, pdf="sphere", seed=rng)
errors.append(np.linalg.norm(np.diag(A) - d))
print(f"Error: {np.linalg.norm(np.diag(A) - d)}")
# print(f"Error: {np.linalg.norm(np.diag(A) - d)}")

y = np.array(errors)
B = np.c_[budget, np.ones(len(budget))]
Expand Down

0 comments on commit 53420ac

Please sign in to comment.