From 0c8be0f51a43cf344bab8bf13798ca25aaf01dc4 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 27 Nov 2024 16:25:07 +0100 Subject: [PATCH] attempt to fix dtype bug --- pyproject.toml | 2 +- src/moscot/backends/ott/_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9fdbc3fa3..2fd66c9a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ dependencies = [ "scanpy>=1.9.3", "wrapt>=1.13.2", "docrep>=0.3.2", - "ott-jax[neural]>=0.4.6,<=0.4.8", + "ott-jax[neural]>=0.4.6", "cloudpickle>=2.2.0", "rich>=13.5", "docstring_inheritance>=2.0.0", diff --git a/src/moscot/backends/ott/_utils.py b/src/moscot/backends/ott/_utils.py index 2cac53b30..c4b1f6d3f 100644 --- a/src/moscot/backends/ott/_utils.py +++ b/src/moscot/backends/ott/_utils.py @@ -132,7 +132,7 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array: return jnp.reshape(arr, (-1, 1)) if arr.ndim != 2: raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.") - return arr + return arr.astype(jnp.float64) def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO: