diff --git a/src/mrinufft/_utils.py b/src/mrinufft/_utils.py index 4291c405..5f662781 100644 --- a/src/mrinufft/_utils.py +++ b/src/mrinufft/_utils.py @@ -98,6 +98,32 @@ def proper_trajectory(trajectory, normalize="pi"): return new_traj +def power_method(max_iter, operator, norm_func=None, x0=None): + """Power method to find the Lipschitz constant of an operator.""" + + def AHA(x): + return operator.adj_op(operator.op(x)) + + if norm_func is None: + norm_func = np.linalg.norm + if x0 is None: + x = np.random.random(operator.shape).astype(operator.cpx_dtype) + x_norm = norm_func(x) + x /= x_norm + for i in range(max_iter): # noqa: B007 + x_new = AHA(x) + x_new_norm = norm_func(x_new) + x_new /= x_new_norm + if abs(x_norm - x_new_norm) < 1e-6: + break + x_norm = x_new_norm + x = x_new + + if i == max_iter - 1: + warnings.warn("Lipschitz constant did not converge") + return x_new_norm + + class MethodRegister: """ A Decorator to register methods of the same type in dictionnaries. diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 9650bea2..c3c4699b 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -9,6 +9,7 @@ from functools import partial import warnings import numpy as np +from mrinufft._utils import _power_method from mrinufft.density import get_density @@ -181,6 +182,30 @@ def compute_density(self, method=None, **kwargs): self.density = method(self.samples, self.shape, **kwargs) + def get_lipschitz_cst(self, max_iter=10, **kwargs): + """Return the Lipschitz constant of the operator. + + Parameters + --------- + max_iter: int + number of iteration to compute the lipschitz constant. + **kwargs: + Extra arguments givent + Returns + ------- + float + Spectral Radius + + Notes + ----- + This uses the Iterative Power Method to compute the largest singular value of a + minified version of the nufft operator (no coil or B0, but includes any computed density. + """ + tmp_op = self.__class__( + self.samples, self.shape, density=self.density, n_coils=1, **kwargs + ) + return _power_method(max_iter, tmp_op) + @property def uses_sense(self): """Return True if the operator uses sensitivity maps.""" diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 5826ede0..81606314 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -3,7 +3,12 @@ import warnings import numpy as np from mrinufft.operators.base import FourierOperatorBase -from mrinufft._utils import proper_trajectory, get_array_module, auto_cast +from mrinufft._utils import ( + proper_trajectory, + get_array_module, + auto_cast, + power_method, +) from .utils import ( CUPY_AVAILABLE, @@ -769,3 +774,34 @@ def __repr__(self): f" eps:{self.raw_op.eps:.0e}\n" ")" ) + + def get_lipschitz_cst(self, max_iter, **kwargs): + """Return the Lipschitz constant of the operator. + + Parameters + ---------- + max_iter: int + Number of iteration to perform to estimate the Lipschitz constant. + kwargs: + Extra kwargs for the cufinufft operator. + + Returns + ------- + float + Lipschitz constant of the operator. + """ + + tmp_op = self.__class__( + self.samples, + self.shape, + density=self.density, + n_coils=1, + smaps=None, + **kwargs, + ) + return power_method( + max_iter, + tmp_op, + norm_func=lambda x: cp.linalg.norm(x, ord="fro"), + x0=cp.zeros(self.shape, dtype=self.cpx_dtype), + )