Skip to content

Commit 8a4761f

Browse files
committed
initial support for Dask DataFrames in obsm/varm
1 parent b2c7a21 commit 8a4761f

File tree

10 files changed

+108
-9
lines changed

10 files changed

+108
-9
lines changed

docs/release-notes/1880.feature.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add support for Dask DataFrames in `.obsm` and `.varm` ({user}`ilia-kats`)

src/anndata/_core/aligned_mapping.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,16 @@ def to_df(self) -> pd.DataFrame:
264264
def _validate_value(self, val: Value, key: str) -> Value:
265265
if isinstance(val, pd.DataFrame):
266266
raise_value_error_if_multiindex_columns(val, f"{self.attrname}[{key!r}]")
267-
if not val.index.equals(self.dim_names):
267+
if (
268+
not val.index.equals(self.dim_names)
269+
and (
270+
val.index.dtype == "string"
271+
and self.dim_names.dtype == "O"
272+
or val.index.dtype == "O"
273+
and self.dim_names.dtype == "string"
274+
)
275+
and (val.index != self.dim_names).any()
276+
):
268277
# Could probably also re-order index if it’s contained
269278
try:
270279
pd.testing.assert_index_equal(val.index, self.dim_names)

src/anndata/_core/file_backing.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import h5py
1010

11-
from ..compat import AwkArray, DaskArray, ZarrArray, ZarrGroup
11+
from ..compat import AwkArray, DaskArray, DaskDataFrame, ZarrArray, ZarrGroup
1212
from .sparse_dataset import BaseCompressedSparseDataset
1313

1414
if TYPE_CHECKING:
@@ -143,7 +143,8 @@ def _(x: BaseCompressedSparseDataset, *, copy: bool = False):
143143

144144

145145
@to_memory.register(DaskArray)
146-
def _(x: DaskArray, *, copy: bool = False):
146+
@to_memory.register(DaskDataFrame)
147+
def _(x: DaskArray | DaskDataFrame, *, copy: bool = False):
147148
return x.compute()
148149

149150

src/anndata/_core/storage.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88
from scipy import sparse
99

10-
from anndata.compat import CSArray
10+
from anndata.compat import CSArray, DaskDataFrame
1111

1212
from .._warnings import ImplicitModificationWarning
1313
from ..utils import (
@@ -51,7 +51,7 @@ def coerce_array(
5151
if any(is_non_csc_r_array_or_matrix):
5252
msg = f"Only CSR and CSC {'matrices' if isinstance(value, sparse.spmatrix) else 'arrays'} are supported."
5353
raise ValueError(msg)
54-
if isinstance(value, pd.DataFrame):
54+
if isinstance(value, pd.DataFrame | DaskDataFrame):
5555
if allow_df:
5656
raise_value_error_if_multiindex_columns(value, name)
5757
return value if allow_df else ensure_df_homogeneous(value, name)

src/anndata/_io/specs/methods.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
CupyCSCMatrix,
3030
CupyCSRMatrix,
3131
DaskArray,
32+
DaskDataFrame,
3233
H5Array,
3334
H5File,
3435
H5Group,
@@ -896,6 +897,19 @@ def write_dataframe(
896897
)
897898

898899

900+
@_REGISTRY.register_write(H5Group, DaskDataFrame, IOSpec("dataframe", "0.2.0"))
901+
@_REGISTRY.register_write(ZarrGroup, DaskDataFrame, IOSpec("dataframe", "0.2.0"))
902+
def write_dask_dataframe(
903+
f: GroupStorageType,
904+
key: str,
905+
df: DaskDataFrame,
906+
*,
907+
_writer: Writer,
908+
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
909+
):
910+
_writer.write_elem(f, key, df.compute(), dataset_kwargs=dataset_kwargs)
911+
912+
899913
@_REGISTRY.register_read(H5Group, IOSpec("dataframe", "0.2.0"))
900914
@_REGISTRY.register_read(ZarrGroup, IOSpec("dataframe", "0.2.0"))
901915
def read_dataframe(elem: GroupStorageType, *, _reader: Reader) -> pd.DataFrame:
@@ -1051,6 +1065,12 @@ def read_partial_categorical(elem, *, items=None, indices=(slice(None),)):
10511065
@_REGISTRY.register_write(
10521066
ZarrGroup, pd.arrays.StringArray, IOSpec("nullable-string-array", "0.1.0")
10531067
)
1068+
@_REGISTRY.register_write(
1069+
H5Group, pd.arrays.ArrowStringArray, IOSpec("nullable-string-array", "0.1.0")
1070+
)
1071+
@_REGISTRY.register_write(
1072+
ZarrGroup, pd.arrays.ArrowStringArray, IOSpec("nullable-string-array", "0.1.0")
1073+
)
10541074
def write_nullable(
10551075
f: GroupStorageType,
10561076
k: str,
@@ -1073,7 +1093,7 @@ def write_nullable(
10731093
g = f.require_group(k)
10741094
values = (
10751095
v.to_numpy(na_value="")
1076-
if isinstance(v, pd.arrays.StringArray)
1096+
if isinstance(v, pd.arrays.StringArray | pd.arrays.ArrowStringArray)
10771097
else v.to_numpy(na_value=0, dtype=v.dtype.numpy_dtype)
10781098
)
10791099
_writer.write_elem(g, "values", values, dataset_kwargs=dataset_kwargs)

src/anndata/compat/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,19 @@ def __repr__():
100100
from dask.array.core import Array as DaskArray
101101
elif find_spec("dask"):
102102
from dask.array import Array as DaskArray
103+
from dask.dataframe import DataFrame as DaskDataFrame
103104
else:
104105

105106
class DaskArray:
106107
@staticmethod
107108
def __repr__():
108109
return "mock dask.array.core.Array"
109110

111+
class DaskDataFrame:
112+
@staticmethod
113+
def __repr__():
114+
return "mock dask.dataframe.dask_expr._collection.DataFrame"
115+
110116

111117
# https://github.com/scverse/anndata/issues/1749
112118
def is_cupy_importable() -> bool:

src/anndata/tests/helpers.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
CupyCSRMatrix,
3333
CupySparseMatrix,
3434
DaskArray,
35+
DaskDataFrame,
3536
ZarrArray,
3637
)
3738
from anndata.utils import asarray
@@ -644,8 +645,13 @@ def assert_equal_h5py_dataset(
644645

645646

646647
@assert_equal.register(DaskArray)
648+
@assert_equal.register(DaskDataFrame)
647649
def assert_equal_dask_array(
648-
a: DaskArray, b: object, *, exact: bool = False, elem_name: str | None = None
650+
a: DaskArray | DaskDataFrame,
651+
b: object,
652+
*,
653+
exact: bool = False,
654+
elem_name: str | None = None,
649655
):
650656
assert_equal(b, a.compute(), exact=exact, elem_name=elem_name)
651657

src/anndata/typing.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
CupyArray,
1616
CupySparseMatrix,
1717
DaskArray,
18+
DaskDataFrame,
1819
H5Array,
1920
ZappyArray,
2021
ZarrArray,
@@ -43,6 +44,7 @@
4344
| abc.CSRDataset
4445
| abc.CSCDataset
4546
| DaskArray
47+
| DaskDataFrame
4648
| CupyArray
4749
| CupySparseMatrix
4850
)

src/anndata/utils.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import anndata
1414

1515
from ._core.sparse_dataset import BaseCompressedSparseDataset
16-
from .compat import CSArray, CupyArray, CupySparseMatrix, DaskArray
16+
from .compat import CSArray, CupyArray, CupySparseMatrix, DaskArray, DaskDataFrame
1717
from .logging import get_logger
1818

1919
if TYPE_CHECKING:
@@ -115,6 +115,11 @@ def axis_len(x, axis: Literal[0, 1]) -> int | None:
115115
return x.shape[axis]
116116

117117

118+
@axis_len.register(DaskDataFrame)
119+
def axis_len_dask_df(df, axis: Literal[0, 1]) -> int | None:
120+
return df.shape[axis].compute()
121+
122+
118123
try:
119124
from .compat import awkward as ak
120125

tests/test_dask.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import anndata as ad
1313
from anndata._core.anndata import AnnData
14-
from anndata.compat import CupyArray, DaskArray
14+
from anndata.compat import CupyArray, DaskArray, DaskDataFrame
1515
from anndata.experimental.merge import as_group
1616
from anndata.tests.helpers import (
1717
GEN_ADATA_DASK_ARGS,
@@ -74,6 +74,7 @@ def test_dask_X_view():
7474

7575
def test_dask_write(adata, tmp_path, diskfmt):
7676
import dask.array as da
77+
import dask.dataframe as ddf
7778
import numpy as np
7879

7980
pth = tmp_path / f"test_write.{diskfmt}"
@@ -84,6 +85,12 @@ def test_dask_write(adata, tmp_path, diskfmt):
8485
adata.obsm["a"] = da.random.random((M, 10))
8586
adata.obsm["b"] = da.random.random((M, 10))
8687
adata.varm["a"] = da.random.random((N, 10))
88+
adata.varm["b"] = ddf.from_pandas(
89+
pd.DataFrame(
90+
{"A": np.arange(N), "B": np.random.randint(1e6, size=N)},
91+
index=adata.var_names,
92+
)
93+
)
8794

8895
orig = adata
8996
write(orig, pth)
@@ -93,6 +100,7 @@ def test_dask_write(adata, tmp_path, diskfmt):
93100
assert_equal(curr.obsm["a"], curr.obsm["b"])
94101

95102
assert_equal(curr.varm["a"], orig.varm["a"])
103+
assert_equal(orig.varm["b"], orig.varm["b"])
96104
assert_equal(curr.obsm["a"], orig.obsm["a"])
97105

98106
assert isinstance(curr.X, np.ndarray)
@@ -105,6 +113,7 @@ def test_dask_write(adata, tmp_path, diskfmt):
105113

106114
def test_dask_distributed_write(adata, tmp_path, diskfmt):
107115
import dask.array as da
116+
import dask.dataframe as ddf
108117
import dask.distributed as dd
109118
import numpy as np
110119

@@ -119,6 +128,12 @@ def test_dask_distributed_write(adata, tmp_path, diskfmt):
119128
adata.obsm["a"] = da.random.random((M, 10))
120129
adata.obsm["b"] = da.random.random((M, 10))
121130
adata.varm["a"] = da.random.random((N, 10))
131+
adata.varm["b"] = ddf.from_pandas(
132+
pd.DataFrame(
133+
{"A": np.arange(N), "B": np.random.randint(1e6, size=N)},
134+
index=adata.var_names,
135+
)
136+
)
122137
orig = adata
123138
if diskfmt == "h5ad":
124139
with pytest.raises(ValueError, match=r"Cannot write dask arrays to hdf5"):
@@ -131,6 +146,7 @@ def test_dask_distributed_write(adata, tmp_path, diskfmt):
131146
assert_equal(curr.obsm["a"], curr.obsm["b"])
132147

133148
assert_equal(curr.varm["a"], orig.varm["a"])
149+
assert_equal(orig.varm["b"], curr.varm["a"])
134150
assert_equal(curr.obsm["a"], orig.obsm["a"])
135151

136152
assert isinstance(curr.X, np.ndarray)
@@ -143,6 +159,7 @@ def test_dask_distributed_write(adata, tmp_path, diskfmt):
143159

144160
def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt):
145161
import dask.array as da
162+
import dask.dataframe as ddf
146163
import numpy as np
147164

148165
pth = tmp_path / f"test_write.{diskfmt}"
@@ -153,6 +170,12 @@ def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt):
153170
adata.obsm["a"] = da.random.random((M, 10))
154171
adata.obsm["b"] = da.random.random((M, 10))
155172
adata.varm["a"] = da.random.random((N, 10))
173+
adata.varm["b"] = ddf.from_pandas(
174+
pd.DataFrame(
175+
{"A": np.arange(N), "B": np.random.randint(1e6, size=N)},
176+
index=adata.var_names,
177+
)
178+
)
156179

157180
orig = adata
158181
write(orig, pth)
@@ -161,6 +184,7 @@ def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt):
161184
assert isinstance(orig.X, DaskArray)
162185
assert isinstance(orig.obsm["a"], DaskArray)
163186
assert isinstance(orig.varm["a"], DaskArray)
187+
assert isinstance(orig.varm["b"], DaskDataFrame)
164188

165189
mem = orig.to_memory()
166190

@@ -171,20 +195,25 @@ def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt):
171195
assert_equal(curr.obsm["a"], orig.obsm["a"])
172196
assert_equal(mem.obsm["a"], orig.obsm["a"])
173197
assert_equal(mem.varm["a"], orig.varm["a"])
198+
assert_equal(orig.varm["b"], mem.varm["b"])
174199

175200
assert isinstance(curr.X, np.ndarray)
176201
assert isinstance(curr.obsm["a"], np.ndarray)
177202
assert isinstance(curr.varm["a"], np.ndarray)
203+
assert isinstance(curr.varm["b"], pd.DataFrame)
178204
assert isinstance(mem.X, np.ndarray)
179205
assert isinstance(mem.obsm["a"], np.ndarray)
180206
assert isinstance(mem.varm["a"], np.ndarray)
207+
assert isinstance(mem.varm["b"], pd.DataFrame)
181208
assert isinstance(orig.X, DaskArray)
182209
assert isinstance(orig.obsm["a"], DaskArray)
183210
assert isinstance(orig.varm["a"], DaskArray)
211+
assert isinstance(orig.varm["b"], DaskDataFrame)
184212

185213

186214
def test_dask_to_memory_copy_check_array_types(adata, tmp_path, diskfmt):
187215
import dask.array as da
216+
import dask.dataframe as ddf
188217
import numpy as np
189218

190219
pth = tmp_path / f"test_write.{diskfmt}"
@@ -195,6 +224,12 @@ def test_dask_to_memory_copy_check_array_types(adata, tmp_path, diskfmt):
195224
adata.obsm["a"] = da.random.random((M, 10))
196225
adata.obsm["b"] = da.random.random((M, 10))
197226
adata.varm["a"] = da.random.random((N, 10))
227+
adata.varm["b"] = ddf.from_pandas(
228+
pd.DataFrame(
229+
{"A": np.arange(N), "B": np.random.randint(1e6, size=N)},
230+
index=adata.var_names,
231+
)
232+
)
198233

199234
orig = adata
200235
write(orig, pth)
@@ -209,25 +244,36 @@ def test_dask_to_memory_copy_check_array_types(adata, tmp_path, diskfmt):
209244
assert_equal(curr.obsm["a"], orig.obsm["a"])
210245
assert_equal(mem.obsm["a"], orig.obsm["a"])
211246
assert_equal(mem.varm["a"], orig.varm["a"])
247+
assert_equal(orig.varm["b"], mem.varm["b"])
212248

213249
assert isinstance(curr.X, np.ndarray)
214250
assert isinstance(curr.obsm["a"], np.ndarray)
215251
assert isinstance(curr.varm["a"], np.ndarray)
252+
assert isinstance(curr.varm["b"], pd.DataFrame)
216253
assert isinstance(mem.X, np.ndarray)
217254
assert isinstance(mem.obsm["a"], np.ndarray)
218255
assert isinstance(mem.varm["a"], np.ndarray)
256+
assert isinstance(mem.varm["b"], pd.DataFrame)
219257
assert isinstance(orig.X, DaskArray)
220258
assert isinstance(orig.obsm["a"], DaskArray)
221259
assert isinstance(orig.varm["a"], DaskArray)
260+
assert isinstance(orig.varm["b"], DaskDataFrame)
222261

223262

224263
def test_dask_copy_check_array_types(adata):
225264
import dask.array as da
265+
import dask.dataframe as ddf
226266

227267
M, N = adata.X.shape
228268
adata.obsm["a"] = da.random.random((M, 10))
229269
adata.obsm["b"] = da.random.random((M, 10))
230270
adata.varm["a"] = da.random.random((N, 10))
271+
adata.varm["b"] = ddf.from_pandas(
272+
pd.DataFrame(
273+
{"A": np.arange(N), "B": np.random.randint(1e6, size=N)},
274+
index=adata.var_names,
275+
)
276+
)
231277

232278
orig = adata
233279
curr = adata.copy()
@@ -236,14 +282,17 @@ def test_dask_copy_check_array_types(adata):
236282
assert_equal(curr.obsm["a"], curr.obsm["b"])
237283

238284
assert_equal(curr.varm["a"], orig.varm["a"])
285+
assert_equal(orig.varm["b"], curr.varm["b"])
239286
assert_equal(curr.obsm["a"], orig.obsm["a"])
240287

241288
assert isinstance(curr.X, DaskArray)
242289
assert isinstance(curr.obsm["a"], DaskArray)
243290
assert isinstance(curr.varm["a"], DaskArray)
291+
assert isinstance(curr.varm["b"], DaskDataFrame)
244292
assert isinstance(orig.X, DaskArray)
245293
assert isinstance(orig.obsm["a"], DaskArray)
246294
assert isinstance(orig.varm["a"], DaskArray)
295+
assert isinstance(orig.varm["b"], DaskDataFrame)
247296

248297

249298
def test_assign_X(adata):

0 commit comments

Comments
 (0)