Skip to content

Commit

Permalink
adapt tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MUCDK committed Nov 5, 2023
1 parent fea3ac9 commit 7669f7b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions tests/problems/generic/test_neural_dual_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from moscot.backends.ott.nets import MLP_marginal
from moscot.backends.ott.output import NeuralDualOutput
from moscot.base.output import BaseDiscreteSolverOutput
from moscot.base.output import BaseNeuralOutput
from moscot.base.problems import NeuralOTProblem
from moscot.problems.generic import NeuralProblem # type: ignore[attr-defined]
from tests._utils import ATOL, RTOL
Expand Down Expand Up @@ -43,7 +43,7 @@ def test_solve_balanced_no_baseline(self, adata_time: ad.AnnData): # type: igno
problem = problem.solve(**neuraldual_args_1)

for key, subsol in problem.solutions.items():
assert isinstance(subsol, BaseDiscreteSolverOutput)
assert isinstance(subsol, BaseNeuralOutput)
assert key in expected_keys

def test_solve_unbalanced_with_baseline(self, adata_time: ad.AnnData):
Expand All @@ -53,7 +53,7 @@ def test_solve_unbalanced_with_baseline(self, adata_time: ad.AnnData):
problem = problem.solve(**neuraldual_args_2)

for key, subsol in problem.solutions.items():
assert isinstance(subsol, BaseDiscreteSolverOutput)
assert isinstance(subsol, BaseNeuralOutput)
assert key in expected_keys

def test_reproducibility(self, adata_time: ad.AnnData):
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_learning_rescaling_factors(self, adata_time: ad.AnnData):
adata_time = adata_time[adata_time.obs["time"].isin((0, 1))]
problem = problem.prepare(key="time", joint_attr="X_pca")
problem = problem.solve(mlp_eta=mlp_eta, mlp_xi=mlp_xi, **neuraldual_args_2)
assert isinstance(problem[0, 1].solution, BaseDiscreteSolverOutput)
assert isinstance(problem[0, 1].solution, BaseNeuralOutput)
assert isinstance(problem[0, 1].solution, NeuralDualOutput)

array = adata_time.obsm["X_pca"]
Expand Down
6 changes: 3 additions & 3 deletions tests/problems/time/test_temporal_neural_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import anndata as ad

from moscot.base.output import BaseDiscreteSolverOutput
from moscot.base.output import BaseNeuralOutput
from moscot.problems.time import TemporalNeuralProblem
from moscot.problems.time._lineage import BirthDeathProblem
from tests._utils import ATOL, RTOL
Expand Down Expand Up @@ -44,7 +44,7 @@ def test_solve_balanced_no_baseline(self, adata_time: ad.AnnData):
problem = problem.solve(**neuraldual_args_1)

for key, subsol in problem.solutions.items():
assert isinstance(subsol, BaseDiscreteSolverOutput)
assert isinstance(subsol, BaseNeuralOutput)
assert key in expected_keys

def test_solve_unbalanced_with_baseline(self, adata_time: ad.AnnData):
Expand All @@ -54,7 +54,7 @@ def test_solve_unbalanced_with_baseline(self, adata_time: ad.AnnData):
problem = problem.solve(**neuraldual_args_2)

for key, subsol in problem.solutions.items():
assert isinstance(subsol, BaseDiscreteSolverOutput)
assert isinstance(subsol, BaseNeuralOutput)
assert key in expected_keys

def test_reproducibility(self, adata_time: ad.AnnData):
Expand Down

0 comments on commit 7669f7b

Please sign in to comment.