Skip to content

Commit

Permalink
feat: add power methods for estimating lipschitz constant.
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Dec 18, 2023
1 parent 387f1d6 commit a4c7284
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 1 deletion.
26 changes: 26 additions & 0 deletions src/mrinufft/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,32 @@ def proper_trajectory(trajectory, normalize="pi"):
return new_traj


def power_method(max_iter, operator, norm_func=None, x0=None):
"""Power method to find the Lipschitz constant of an operator."""

def AHA(x):
return operator.adj_op(operator.op(x))

if norm_func is None:
norm_func = np.linalg.norm
if x0 is None:
x = np.random.random(operator.shape).astype(operator.cpx_dtype)
x_norm = norm_func(x)
x /= x_norm
for i in range(max_iter): # noqa: B007
x_new = AHA(x)
x_new_norm = norm_func(x_new)
x_new /= x_new_norm
if abs(x_norm - x_new_norm) < 1e-6:
break
x_norm = x_new_norm
x = x_new

if i == max_iter - 1:
warnings.warn("Lipschitz constant did not converge")
return x_new_norm


class MethodRegister:
"""
A Decorator to register methods of the same type in dictionnaries.
Expand Down
25 changes: 25 additions & 0 deletions src/mrinufft/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial
import warnings
import numpy as np
from mrinufft._utils import _power_method

from mrinufft.density import get_density

Expand Down Expand Up @@ -181,6 +182,30 @@ def compute_density(self, method=None, **kwargs):

self.density = method(self.samples, self.shape, **kwargs)

def get_lipschitz_cst(self, max_iter=10, **kwargs):
"""Return the Lipschitz constant of the operator.
Parameters
---------
max_iter: int
number of iteration to compute the lipschitz constant.
**kwargs:
Extra arguments givent
Returns
-------
float
Spectral Radius
Notes
-----
This uses the Iterative Power Method to compute the largest singular value of a
minified version of the nufft operator (no coil or B0, but includes any computed density.
"""
tmp_op = self.__class__(
self.samples, self.shape, density=self.density, n_coils=1, **kwargs
)
return _power_method(max_iter, tmp_op)

@property
def uses_sense(self):
"""Return True if the operator uses sensitivity maps."""
Expand Down
38 changes: 37 additions & 1 deletion src/mrinufft/operators/interfaces/cufinufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import warnings
import numpy as np
from mrinufft.operators.base import FourierOperatorBase
from mrinufft._utils import proper_trajectory, get_array_module, auto_cast
from mrinufft._utils import (
proper_trajectory,
get_array_module,
auto_cast,
power_method,
)

from .utils import (
CUPY_AVAILABLE,
Expand Down Expand Up @@ -769,3 +774,34 @@ def __repr__(self):
f" eps:{self.raw_op.eps:.0e}\n"
")"
)

def get_lipschitz_cst(self, max_iter, **kwargs):
"""Return the Lipschitz constant of the operator.
Parameters
----------
max_iter: int
Number of iteration to perform to estimate the Lipschitz constant.
kwargs:
Extra kwargs for the cufinufft operator.
Returns
-------
float
Lipschitz constant of the operator.
"""

tmp_op = self.__class__(
self.samples,
self.shape,
density=self.density,
n_coils=1,
smaps=None,
**kwargs,
)
return power_method(
max_iter,
tmp_op,
norm_func=lambda x: cp.linalg.norm(x, ord="fro"),
x0=cp.zeros(self.shape, dtype=self.cpx_dtype),
)

0 comments on commit a4c7284

Please sign in to comment.