Skip to content

Commit

Permalink
make alpha mandatory
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Dec 10, 2024
1 parent 11aa435 commit 64ab534
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def _prepare(
self,
a: jnp.ndarray,
b: jnp.ndarray,
alpha: float,
xy: Optional[TaggedArray] = None,
x: Optional[TaggedArray] = None,
y: Optional[TaggedArray] = None,
Expand All @@ -430,7 +431,6 @@ def _prepare(
cost_matrix_rank: Optional[int] = None,
time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
# problem
alpha: Optional[float] = None,
**kwargs: Any,
) -> quadratic_problem.QuadraticProblem:
self._a = a
Expand Down
13 changes: 8 additions & 5 deletions tests/backends/ott/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def test_matches_ott(self, x: Geom_t, y: Geom_t, eps: Optional[float], jit: bool
x=x,
y=y,
tags={"x": "point_cloud", "y": "point_cloud"},
alpha=1.0,
)

assert solver.is_fused is False
Expand All @@ -141,6 +142,7 @@ def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[f
x=x_cost,
y=y_cost,
tags={"x": Tag.COST_MATRIX, "y": Tag.COST_MATRIX},
alpha=1.0,
)

assert solver.is_fused is False
Expand Down Expand Up @@ -171,6 +173,7 @@ def test_solver_rank(self, x: Geom_t, y: Geom_t, rank: int) -> None:
x=x,
y=y,
tags={"x": "point_cloud", "y": "point_cloud"},
alpha=1.0,
)

assert solver.is_fused is False
Expand Down Expand Up @@ -347,8 +350,8 @@ def test_pull(
b, ndim = (b, b.shape[1]) if batched else (b[:, 0], None)
xx, yy = xy
solver = solver_t()

out = solver(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, xy=(xx, yy))
additional_kwargs = {"alpha": 1.0} if xy is not None else {}
out = solver(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, xy=(xx, yy), **additional_kwargs)
p = out.pull(b, scale_by_marginals=False)

assert isinstance(out, BaseDiscreteSolverOutput)
Expand Down Expand Up @@ -389,17 +392,17 @@ def test_to_device(self, x: Geom_t, device: Optional[Device_t]) -> None:

class TestOutputPlotting(PlotTester, metaclass=PlotTesterMeta):
def test_plot_costs(self, x: Geom_t, y: Geom_t):
out = GWSolver()(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y)
out = GWSolver()(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, alpha=1.0)
out.plot_costs()

def test_plot_costs_last(self, x: Geom_t, y: Geom_t):
out = GWSolver(rank=2)(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y)
out = GWSolver(rank=2)(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, alpha=1.0)
out.plot_costs(last=3)

def test_plot_errors_sink(self, x: Geom_t, y: Geom_t):
out = SinkhornSolver()(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), xy=(x, y))
out.plot_errors()

def test_plot_errors_gw(self, x: Geom_t, y: Geom_t):
out = GWSolver(store_inner_errors=True)(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y)
out = GWSolver(store_inner_errors=True)(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, alpha=1.0)
out.plot_errors()

0 comments on commit 64ab534

Please sign in to comment.