Benchmark dask #3319

Open · wants to merge 5 commits into base: main
3 changes: 3 additions & 0 deletions benchmarks/asv.conf.json
@@ -86,6 +86,9 @@
"pooch": [""],
"scikit-image": [""],
// "scikit-misc": [""],
"scikit-learn": [""],
"pip+asv_runner": [""],
"dask": [""]
},

// Combinations of libraries/python versions can be excluded/included
38 changes: 38 additions & 0 deletions benchmarks/benchmarks/_utils.py
@@ -117,6 +117,42 @@ def lung93k() -> AnnData:
    return _lung93k().copy()


@cache
def _large_synthetic_dataset(n_obs: int = 500_000, n_vars: int = 5_000, density: float = 0.01) -> AnnData:
    """Generate a synthetic dataset suitable for Dask testing.

    Parameters
    ----------
    n_obs
        Number of observations (rows, typically cells).
    n_vars
        Number of variables (columns, typically genes).
    density
        Fraction of non-zero entries in the sparse matrix.

    Returns
    -------
    AnnData
        The synthetic dataset.
    """
    X = sparse.random(n_obs, n_vars, density=density, format="csr", dtype=np.float32, random_state=42)
    obs = {"obs_names": [f"cell_{i}" for i in range(n_obs)]}
    var = {"var_names": [f"gene_{j}" for j in range(n_vars)]}
    adata = anndata.AnnData(X=X, obs=obs, var=var)
    adata.layers["counts"] = X.copy()
    sc.pp.log1p(adata)
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(
        adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
    )
    return adata


def large_synthetic_dataset(n_obs: int = 500_000, n_vars: int = 5_000, density: float = 0.01) -> AnnData:
    return _large_synthetic_dataset(n_obs, n_vars, density).copy()


def to_off_axis(x: np.ndarray | sparse.csr_matrix) -> np.ndarray | sparse.csc_matrix:
    if isinstance(x, sparse.csr_matrix):
        return x.tocsc()
@@ -138,6 +174,8 @@ def _get_dataset_raw(dataset: Dataset) -> tuple[AnnData, str | None]:
adata, batch_key = bmmc(400), "sample"
case "lung93k":
adata, batch_key = lung93k(), "PatientNumber"
case "large_synthetic":
adata, batch_key = large_synthetic_dataset(), None
case _:
msg = f"Unknown dataset {dataset}"
raise AssertionError(msg)
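Reviewer illustration (not part of the diff): a minimal usage sketch for the new dataset entry, assuming the Dataset literal in _utils.py also gains a "large_synthetic" member so get_count_dataset accepts the string; it is written as if run from another module in the same benchmarks package.

# Hypothetical usage sketch (illustration only)
from ._utils import get_count_dataset

adata, batch_key = get_count_dataset("large_synthetic", layer="counts")
# With the defaults above this yields a 500_000 x 5_000 sparse counts matrix;
# the synthetic data defines no batch column, so batch_key is None.
assert batch_key is None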
6 changes: 4 additions & 2 deletions benchmarks/benchmarks/preprocessing_counts.py
@@ -32,7 +32,8 @@ def setup(dataset: Dataset, layer: KeyCount, *_):
# ASV suite

params: tuple[list[Dataset], list[KeyCount]] = (
["pbmc68k_reduced", "pbmc3k"],
["pbmc3k"],
# ["pbmc3k", "pbmc68k_reduced", "bmmc", "lung93k", "large_synthetic"],
["counts", "counts-off-axis"],
)
param_names = ["dataset", "layer"]
@@ -78,7 +79,8 @@ class FastSuite:
"""Suite for fast preprocessing operations."""

params: tuple[list[Dataset], list[KeyCount]] = (
["pbmc3k", "pbmc68k_reduced", "bmmc", "lung93k"],
# ["pbmc3k", "pbmc68k_reduced", "bmmc", "lung93k", "large_synthetic"],
["pbmc3k"],
["counts", "counts-off-axis"],
)
param_names = ["dataset", "layer"]
143 changes: 143 additions & 0 deletions benchmarks/benchmarks/preprocessing_counts_dask.py
@@ -0,0 +1,143 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import dask.array as dd
from dask.distributed import Client, LocalCluster
import scanpy as sc
from scipy import sparse

from ._utils import get_count_dataset

if TYPE_CHECKING:
    from anndata import AnnData

    from ._utils import Dataset, KeyCount

# Setup global variables
adata: AnnData
batch_key: str | None

def setup(dataset: Dataset, layer: KeyCount, *_):
"""Setup global variables before each benchmark."""
global adata, batch_key
adata, batch_key = get_count_dataset(dataset, layer=layer)
assert "log1p" not in adata.uns


def setup_dask_cluster():
"""Set up a local Dask cluster for benchmarking."""
cluster = LocalCluster(n_workers=4, threads_per_worker=2)
client = Client(cluster)
return client
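# NOTE (reviewer annotation, not part of the diff): setup_dask_cluster() is
# called inside each benchmark body below, so LocalCluster start-up and
# shutdown are included in every time_* measurement.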


# ASV suite
params: tuple[list[Dataset], list[KeyCount]] = (
["pbmc68k_reduced"],
["counts", "counts-off-axis"],
)
param_names = ["dataset", "layer"]

### Dask-Based Benchmarks ###

def time_filter_cells_dask(*_):
    client = setup_dask_cluster()
    try:
        adata.X = dd.from_array(adata.X, chunks=(adata.X.shape[0] // 10, adata.X.shape[1] // 10))
        adata.X = adata.X.map_blocks(sparse.csr_matrix)  # Ensure sparse chunks
        sc.pp.filter_cells(adata, min_genes=100)
    finally:
        client.close()


def peakmem_filter_cells_dask(*_):
    client = setup_dask_cluster()
    try:
        adata.X = dd.from_array(adata.X, chunks=(adata.X.shape[0] // 50, adata.X.shape[1] // 50))
        sc.pp.filter_cells(adata, min_genes=100)
    finally:
        client.close()


def time_filter_genes_dask(*_):
    client = setup_dask_cluster()
    try:
        adata.X = dd.from_array(adata.X, chunks=(adata.X.shape[0] // 10, adata.X.shape[1] // 10))
        sc.pp.filter_genes(adata, min_cells=3)
    finally:
        client.close()


def peakmem_filter_genes_dask(*_):
    client = setup_dask_cluster()
    try:
        adata.X = dd.from_array(adata.X, chunks=(adata.X.shape[0] // 10, adata.X.shape[1] // 10))
        sc.pp.filter_genes(adata, min_cells=3)
    finally:
        client.close()


### General Dask Preprocessing Benchmarks ###

class FastSuite:
    """Suite for benchmarking preprocessing operations with Dask."""

    params: tuple[list[Dataset], list[KeyCount]] = (
        ["pbmc68k_reduced"],
        ["counts", "counts-off-axis"],
    )
    param_names = ["dataset", "layer"]

    def time_calculate_qc_metrics_dask(self, *_):
        client = setup_dask_cluster()
        try:
            adata.X = dd.from_array(adata.X, chunks=(adata.X.shape[0] // 10, adata.X.shape[1] // 10))
            adata.X = adata.X.map_blocks(sparse.csr_matrix)
            sc.pp.calculate_qc_metrics(
                adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
            )
        finally:
            client.close()

    def peakmem_calculate_qc_metrics_dask(self, *_):
        client = setup_dask_cluster()
        try:
            adata.X = dd.from_array(adata.X, chunks=(adata.X.shape[0] // 10, adata.X.shape[1] // 10))
            sc.pp.calculate_qc_metrics(
                adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
            )
        finally:
            client.close()

    def time_normalize_total_dask(self, *_):
        client = setup_dask_cluster()
        try:
            adata.X = dd.from_array(adata.X, chunks=(adata.X.shape[0] // 50, adata.X.shape[1] // 50))
            sc.pp.normalize_total(adata, target_sum=1e4)
        finally:
            client.close()

    def peakmem_normalize_total_dask(self, *_):
        client = setup_dask_cluster()
        try:
            adata.X = dd.from_array(adata.X, chunks=(adata.X.shape[0], adata.X.shape[1]))
            sc.pp.normalize_total(adata, target_sum=1e4)
        finally:
            client.close()

    def time_log1p_dask(self, *_):
        client = setup_dask_cluster()
        try:
            adata.uns.pop("log1p", None)
            adata.X = dd.from_array(adata.X, chunks=(adata.X.shape[0] // 50, adata.X.shape[1] // 50))
            sc.pp.log1p(adata)
        finally:
            client.close()

    def peakmem_log1p_dask(self, *_):
        client = setup_dask_cluster()
        try:
            adata.uns.pop("log1p", None)
            adata.X = dd.from_array(adata.X, chunks=(adata.X.shape[0] // 100, adata.X.shape[1] // 100))
            sc.pp.log1p(adata)
        finally:
            client.close()
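Reviewer illustration (not part of the diff): the benchmarks above chunk the CSR matrix along both axes and then rebuild sparse blocks via map_blocks(sparse.csr_matrix). Below is a minimal alternative sketch that chunks only along the observation axis, so each block is a cheap CSR row slice; the helper name to_dask_csr and the chunk count are invented for illustration, and meta/asarray are standard dask.array.from_array arguments.

# Hypothetical helper (illustration only): wrap a CSR matrix as a dask array
# chunked along obs only, keeping each block sparse instead of densifying it.
import dask.array as da
import numpy as np
from scipy import sparse


def to_dask_csr(x: sparse.csr_matrix, n_chunks: int = 10) -> da.Array:
    chunk_rows = max(1, x.shape[0] // n_chunks)
    return da.from_array(
        x,
        chunks=(chunk_rows, x.shape[1]),  # single chunk across all genes
        asarray=False,  # pass sparse blocks through unchanged
        meta=sparse.csr_matrix(np.empty((0, 0), dtype=x.dtype)),  # sparse block type
    )

Row slicing of CSR data is cheap, which is why row-only chunking is the usual layout for CSR-backed dask arrays.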