Skip to content

Commit

Permalink
Merge pull request #748 from theislab/refactor/handle-solve-args
Browse files Browse the repository at this point in the history
Refactor/handle solve args
  • Loading branch information
selmanozleyen authored Dec 12, 2024
2 parents 4080a94 + 792d913 commit 29764d4
Show file tree
Hide file tree
Showing 24 changed files with 318 additions and 170 deletions.
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@
("py:class", "None. Update D from dict/iterable E and F."),
("py:class", "an object providing a view on D's values"),
("py:class", "a shallow copy of D"),
# ignore these classes until ott-jax adds them to their docs
("py:class", "ott.initializers.quadratic.initializers.BaseQuadraticInitializer"),
("py:class", "ott.initializers.linear.initializers.SinkhornInitializer"),
]
# TODO(michalk8): remove once typing has been cleaned-up
nitpick_ignore_regex = [
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ dependencies = [
"scanpy>=1.9.3",
"wrapt>=1.13.2",
"docrep>=0.3.2",
"ott-jax[neural]>=0.4.6,<=0.4.8",
"ott-jax[neural]>=0.5.0",
"cloudpickle>=2.2.0",
"rich>=13.5",
"docstring_inheritance>=2.0.0",
Expand Down
14 changes: 9 additions & 5 deletions src/moscot/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from typing import Any, Literal, Mapping, Optional, Sequence, Union

import numpy as np
from ott.initializers.linear.initializers import SinkhornInitializer
from ott.initializers.linear.initializers_lr import LRInitializer
from ott.initializers.quadratic.initializers import BaseQuadraticInitializer

# TODO(michalk8): polish

Expand All @@ -17,13 +20,14 @@
Numeric_t = Union[int, float] # type of `time_key` arguments
Filter_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type how to filter adata
Str_Dict_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type for `cell_transition`
SinkFullRankInit = Literal["default", "gaussian", "sorting"]
LRInitializer_t = Literal["random", "rank2", "k-means", "generalized-k-means"]
SinkhornInitializerTag_t = Literal["default", "gaussian", "sorting"]
LRInitializerTag_t = Literal["random", "rank2", "k-means", "generalized-k-means"]

SinkhornInitializer_t = Optional[Union[SinkFullRankInit, LRInitializer_t]]
QuadInitializer_t = Optional[LRInitializer_t]
LRInitializer_t = Optional[Union[LRInitializer, LRInitializerTag_t]]
SinkhornInitializer_t = Optional[Union[SinkhornInitializer, SinkhornInitializerTag_t]]
QuadInitializer_t = Optional[Union[BaseQuadraticInitializer]]

Initializer_t = Union[SinkhornInitializer_t, LRInitializer_t]
Initializer_t = Union[SinkhornInitializer_t, QuadInitializer_t, LRInitializer_t]
ProblemStage_t = Literal["prepared", "solved"]
Device_t = Union[Literal["cpu", "gpu", "tpu"], str]

Expand Down
95 changes: 92 additions & 3 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import numpy as np
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
from ott.initializers.linear import initializers as init_lib
from ott.initializers.linear import initializers_lr as lr_init_lib
from ott.neural import datasets
from ott.solvers import utils as solver_utils
from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div
Expand All @@ -21,6 +23,90 @@
__all__ = ["sinkhorn_divergence"]


class InitializerResolver:
    """Factory for creating OT solver initializers.

    This class provides static methods to create and manage different types of
    initializers used in optimal transport solvers, including low-rank, k-means,
    and standard Sinkhorn initializers. Both methods accept either a string
    specifier or an already-constructed initializer instance, which is returned
    unchanged.
    """

    @staticmethod
    def lr_from_str(
        initializer: "Union[str, lr_init_lib.LRInitializer]",
        rank: int,
        **kwargs: Any,
    ) -> lr_init_lib.LRInitializer:
        """Create a low-rank initializer from a string specification.

        Parameters
        ----------
        initializer : Union[str, LRInitializer]
            Either an existing initializer instance, which is returned as-is,
            or one of the string specifiers ``'k-means'``,
            ``'generalized-k-means'``, ``'random'`` or ``'rank2'``.
        rank : int
            Rank for the initialization.
        **kwargs : Any
            Additional keyword arguments for initializer creation.

        Returns
        -------
        LRInitializer
            Configured low-rank initializer.

        Raises
        ------
        NotImplementedError
            If the requested initializer type is not implemented.
        """
        # Pre-built initializers pass through untouched.
        if isinstance(initializer, lr_init_lib.LRInitializer):
            return initializer
        if initializer == "k-means":
            return lr_init_lib.KMeansInitializer(rank=rank, **kwargs)
        if initializer == "generalized-k-means":
            return lr_init_lib.GeneralizedKMeansInitializer(rank=rank, **kwargs)
        if initializer == "random":
            return lr_init_lib.RandomInitializer(rank=rank, **kwargs)
        if initializer == "rank2":
            return lr_init_lib.Rank2Initializer(rank=rank, **kwargs)
        raise NotImplementedError(f"Initializer `{initializer}` is not implemented.")

    @staticmethod
    def from_str(
        initializer: "Union[str, init_lib.SinkhornInitializer]",
        **kwargs: Any,
    ) -> init_lib.SinkhornInitializer:
        """Create a Sinkhorn initializer from a string specification.

        Parameters
        ----------
        initializer : Union[str, SinkhornInitializer]
            Either an existing initializer instance, which is returned as-is,
            or one of the string specifiers ``'default'``, ``'gaussian'``,
            ``'sorting'`` or ``'subsample'``.
        **kwargs : Any
            Additional keyword arguments for initializer creation.

        Returns
        -------
        SinkhornInitializer
            Configured Sinkhorn initializer.

        Raises
        ------
        NotImplementedError
            If the requested initializer type is not implemented.
        """
        # Pre-built initializers pass through untouched.
        if isinstance(initializer, init_lib.SinkhornInitializer):
            return initializer
        if initializer == "default":
            return init_lib.DefaultInitializer(**kwargs)
        if initializer == "gaussian":
            return init_lib.GaussianInitializer(**kwargs)
        if initializer == "sorting":
            return init_lib.SortingInitializer(**kwargs)
        if initializer == "subsample":
            return init_lib.SubsampleInitializer(**kwargs)
        raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.")


def sinkhorn_divergence(
point_cloud_1: ArrayLike,
point_cloud_2: ArrayLike,
Expand All @@ -45,11 +131,14 @@ def sinkhorn_divergence(
batch_size=batch_size,
a=a,
b=b,
sinkhorn_kwargs={"tau_a": tau_a, "tau_b": tau_b},
scale_cost=scale_cost,
epsilon=epsilon,
solve_kwargs={
"tau_a": tau_a,
"tau_b": tau_b,
},
**kwargs,
)
)[1]
xy_conv, xx_conv, *yy_conv = output.converged

if not xy_conv:
Expand Down Expand Up @@ -132,7 +221,7 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
return jnp.reshape(arr, (-1, 1))
if arr.ndim != 2:
raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
return arr
return arr.astype(jnp.float64)


def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO:
Expand Down
50 changes: 30 additions & 20 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
from moscot._logging import logger
from moscot._types import (
ArrayLike,
LRInitializer_t,
ProblemKind_t,
QuadInitializer_t,
SinkhornInitializer_t,
)
from moscot.backends.ott._utils import (
InitializerResolver,
Loader,
MultiLoader,
_instantiate_geodesic_cost,
Expand Down Expand Up @@ -88,16 +90,20 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
----------
jit
Whether to :func:`~jax.jit` the :attr:`solver`.
initializer_kwargs
Keyword arguments for the initializer.
"""

def __init__(self, jit: bool = True):
def __init__(self, jit: bool = True, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({})):
super().__init__()
self._solver: Optional[OTTSolver_t] = None
self._problem: Optional[OTTProblem_t] = None
self._jit = jit
self._a: Optional[jnp.ndarray] = None
self._b: Optional[jnp.ndarray] = None

self.initializer_kwargs = initializer_kwargs

def _create_geometry(
self,
x: TaggedArray,
Expand Down Expand Up @@ -170,7 +176,7 @@ def _solve( # type: ignore[override]
**kwargs: Any,
) -> Union[OTTOutput, GraphOTTOutput]:
solver = jax.jit(self.solver) if self._jit else self.solver
out = solver(prob, **kwargs)
out = solver(prob, **self.initializer_kwargs, **kwargs)
if isinstance(prob, linear_problem.LinearProblem) and isinstance(prob.geom, geodesic.Geodesic):
return GraphOTTOutput(out, shape=(len(self._a), len(self._b))) # type: ignore[arg-type]
return OTTOutput(out)
Expand Down Expand Up @@ -275,20 +281,20 @@ def __init__(
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Any,
):
super().__init__(jit=jit)
super().__init__(jit=jit, initializer_kwargs=initializer_kwargs)
if rank > -1:
kwargs.setdefault("gamma", 500)
kwargs.setdefault("gamma_rescale", True)
eps = kwargs.get("epsilon")
if eps is not None and eps > 0.0:
logger.info(f"Found `epsilon`={eps}>0. We recommend setting `epsilon`=0 for the low-rank solver.")
initializer = "rank2" if initializer is None else initializer
self._solver = sinkhorn_lr.LRSinkhorn(
rank=rank, epsilon=epsilon, initializer=initializer, kwargs_init=initializer_kwargs, **kwargs
)
if isinstance(initializer, str):
initializer = InitializerResolver.lr_from_str(initializer, rank=rank)
self._solver = sinkhorn_lr.LRSinkhorn(rank=rank, epsilon=epsilon, initializer=initializer, **kwargs)
else:
initializer = "default" if initializer is None else initializer
self._solver = sinkhorn.Sinkhorn(initializer=initializer, kwargs_init=initializer_kwargs, **kwargs)
if isinstance(initializer, str):
initializer = InitializerResolver.from_str(initializer)
self._solver = sinkhorn.Sinkhorn(initializer=initializer, **kwargs)

def _prepare(
self,
Expand Down Expand Up @@ -389,40 +395,40 @@ def __init__(
self,
jit: bool = True,
rank: int = -1,
initializer: QuadInitializer_t = None,
initializer: QuadInitializer_t | LRInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Any,
):
super().__init__(jit=jit)
super().__init__(jit=jit, initializer_kwargs=initializer_kwargs)
if rank > -1:
kwargs.setdefault("gamma", 10)
kwargs.setdefault("gamma_rescale", True)
eps = kwargs.get("epsilon")
if eps is not None and eps > 0.0:
logger.info(f"Found `epsilon`={eps}>0. We recommend setting `epsilon`=0 for the low-rank solver.")
initializer = "rank2" if initializer is None else initializer
if isinstance(initializer, str):
initializer = InitializerResolver.lr_from_str(initializer, rank=rank)
self._solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank,
initializer=initializer,
kwargs_init=initializer_kwargs,
**kwargs,
)
else:
linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
initializer = None
linear_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
if isinstance(initializer, str):
raise ValueError("Expected `initializer` to be `None` or `ott.initializers.quadratic.initializers`.")
self._solver = gromov_wasserstein.GromovWasserstein(
rank=rank,
linear_ot_solver=linear_ot_solver,
quad_initializer=initializer,
kwargs_init=initializer_kwargs,
linear_solver=linear_solver,
initializer=initializer,
**kwargs,
)

def _prepare(
self,
a: jnp.ndarray,
b: jnp.ndarray,
alpha: float,
xy: Optional[TaggedArray] = None,
x: Optional[TaggedArray] = None,
y: Optional[TaggedArray] = None,
Expand All @@ -435,7 +441,6 @@ def _prepare(
cost_matrix_rank: Optional[int] = None,
time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
# problem
alpha: float = 0.5,
**kwargs: Any,
) -> quadratic_problem.QuadraticProblem:
self._a = a
Expand All @@ -456,6 +461,11 @@ def _prepare(
geom_kwargs["cost_matrix_rank"] = cost_matrix_rank
geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs)
geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs)
if alpha <= 0.0:
raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.")
if (alpha == 1.0 and xy is not None) or (alpha != 1.0 and xy is None):
raise ValueError(f"Expected `xy` to be `None` if `alpha` is not 1.0, found xy={xy}, alpha={alpha}.")

if alpha == 1.0 or xy is None: # GW
# arbitrary fused penalty; must be positive
geom_xy, fused_penalty = None, 1.0
Expand Down
18 changes: 1 addition & 17 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,24 +518,8 @@ def solve(
solver_class = backends.get_solver(
self.problem_kind, solver_name=solver_name, backend=backend, return_class=True
)
init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs)
# if linear problem, then alpha is 0.0 by default
# if quadratic problem, then alpha is 1.0 by default
alpha = call_kwargs.get("alpha", 0.0 if self.problem_kind == "linear" else 1.0)
if alpha < 0.0 or alpha > 1.0:
raise ValueError("Expected `alpha` to be in the range `[0, 1]`, found `{alpha}`.")
if self.problem_kind == "linear" and (alpha != 0.0 or not (self.x is None or self.y is None)):
raise ValueError("Unable to solve a linear problem with `alpha != 0` or `x` and `y` supplied.")
if self.problem_kind == "quadratic":
if self.x is None or self.y is None:
raise ValueError("Unable to solve a quadratic problem without `x` and `y` supplied.")
if alpha != 1.0 and self.xy is None: # means FGW case
raise ValueError(
"`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class."
)
if alpha == 1.0 and self.xy is not None:
raise ValueError("Unable to solve a quadratic problem with `alpha = 1` and `xy` supplied.")

init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs)
self._solver = solver_class(**init_kwargs)

# note that the solver call consists of solver._prepare and solver._solve
Expand Down
9 changes: 3 additions & 6 deletions src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def solve(
jit: bool = True,
threshold: float = 1e-3,
lse_mode: bool = True,
inner_iterations: int = 10,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down Expand Up @@ -233,9 +232,7 @@ def solve(
lse_mode
Whether to use `log-sum-exp (LSE)
<https://en.wikipedia.org/wiki/LogSumExp#log-sum-exp_trick_for_log-domain_calculations>`_
computations for numerical stability.
inner_iterations
Compute the convergence criterion every ``inner_iterations``.
computations for numerical stability. Valid only for the :term:`linear problem`.
min_iterations
Minimum number of :term:`Sinkhorn` iterations.
max_iterations
Expand All @@ -253,6 +250,8 @@ def solve(
- :attr:`solutions` - the :term:`OT` solutions for each subproblem.
- :attr:`stage` - set to ``'solved'``.
"""
if self.problem_kind == "linear":
kwargs["lse_mode"] = lse_mode
return super().solve( # type:ignore[return-value]
epsilon=epsilon,
tau_a=tau_a,
Expand All @@ -265,8 +264,6 @@ def solve(
initializer_kwargs=initializer_kwargs,
jit=jit,
threshold=threshold,
lse_mode=lse_mode,
inner_iterations=inner_iterations,
min_iterations=min_iterations,
max_iterations=max_iterations,
device=device,
Expand Down
Loading

0 comments on commit 29764d4

Please sign in to comment.