Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-GPU support with dask #179

Open
wants to merge 100 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
100 commits
Select commit Hold shift + click to select a range
17df571
add first functions
Apr 25, 2024
40167ca
add hvg part1
Apr 25, 2024
f4db387
Merge branch 'main' into dask_mg_support
Intron7 Apr 30, 2024
6526b42
Merge branch 'main' into dask_mg_support
Intron7 May 2, 2024
0cdb85d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 2, 2024
48b68f6
reset to main for hvg
Intron7 May 2, 2024
886cafa
add support for hvg
Intron7 May 2, 2024
d7bf01e
first pass pca
Intron7 May 2, 2024
b216890
pca update
Intron7 May 2, 2024
cdffd33
fix bug with csc matrix
Intron7 May 3, 2024
177afa1
add dask to docs
Intron7 May 3, 2024
dd1377c
add tests
Intron7 May 3, 2024
e254800
update names
Intron7 May 3, 2024
77b3c34
get docs to work
Intron7 May 4, 2024
36bebf9
remove client from sparse calc
Intron7 May 4, 2024
82cc22c
need dask for docs
Intron7 May 4, 2024
7ddde9b
Merge branch 'main' into dask_mg_support
Intron7 May 6, 2024
e33821f
add scale
Intron7 May 7, 2024
e1e6c19
int64 updates
Intron7 May 8, 2024
7da41e0
For main branch
Intron7 May 8, 2024
e676dbe
Merge branch 'main' into dask_mg_support
Intron7 May 8, 2024
b6f436f
test docs
Intron7 May 8, 2024
ef00052
Merge branch 'main' into dask_mg_support
Intron7 May 13, 2024
4b22562
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 13, 2024
5ed8e68
fix import
Intron7 May 13, 2024
b879ea4
fix rebase
Intron7 May 13, 2024
e1d4e8b
Merge branch 'main' into dask_mg_support
Intron7 May 16, 2024
472ebb6
Merge branch 'main' into dask_mg_support
Intron7 Jun 11, 2024
4f45aef
Merge branch 'main' into dask_mg_support
Intron7 Jun 25, 2024
e0e22ad
Merge branch 'main' into dask_mg_support
Intron7 Jun 26, 2024
f002b9c
(fix): use `to_delayed` and `from_delayed` to submit gram matrix jobs…
ilan-gold Jul 16, 2024
a8b30d3
(fix): use `map_blocks` for job submission in `_get_target_sum_dask` …
ilan-gold Jul 16, 2024
b228345
(fix): remove `extract_partitions` from `mean`/`var` calculation (#221)
ilan-gold Jul 16, 2024
bbf71a5
Merge branch 'main' into dask_mg_support
Intron7 Jul 16, 2024
d7c34ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 16, 2024
7390dd1
remove client from hvg
Intron7 Jul 16, 2024
7fedc07
remove client
Intron7 Jul 16, 2024
36a081e
remove client from scale
Intron7 Jul 16, 2024
6be8d9b
update to fast transform
Intron7 Jul 16, 2024
35da552
(fix): `normalize_total` -> `log1p` -> `pca` with sparse (#217)
ilan-gold Jul 16, 2024
30ccfd3
fix taskgraph
Intron7 Jul 16, 2024
8cd60b4
(feat): use `map_blocks` in gram matrix calculation and and `mean_var…
ilan-gold Jul 18, 2024
edfdde0
Merge branch 'main' into dask_mg_support
Intron7 Jul 18, 2024
4bfd5a1
update pca
Intron7 Jul 18, 2024
8ddba49
use lambda
Intron7 Jul 19, 2024
dea8d3d
remove unused kernel
Intron7 Jul 19, 2024
a71116d
test removed kernel
Intron7 Jul 19, 2024
4a78e73
Merge branch 'main' into dask_mg_support
Intron7 Jul 22, 2024
523c862
Merge branch 'main' into dask_mg_support
Intron7 Aug 5, 2024
3680a68
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
f9f5661
Merge branch 'main' into dask_mg_support
Intron7 Aug 9, 2024
3a038e9
Merge branch 'main' into dask_mg_support
Intron7 Aug 13, 2024
9c0e09a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
845054d
Merge branch 'main' into dask_mg_support
Intron7 Aug 13, 2024
383bafb
add outside compute (#245)
Intron7 Aug 13, 2024
7ef0218
Merge branch 'main' into dask_mg_support
Intron7 Aug 19, 2024
98904fe
Merge branch 'main' into dask_mg_support
Intron7 Sep 10, 2024
1d192a5
Merge branch 'main' into dask_mg_support
Intron7 Sep 26, 2024
4f88abc
update utils for lazy compute
Intron7 Sep 26, 2024
727e15b
Merge branch 'dask_mg_support' of https://github.com/scverse/rapids_s…
Intron7 Sep 26, 2024
602f845
update utils
Intron7 Sep 26, 2024
f0d5124
Merge branch 'main' into dask_mg_support
Intron7 Sep 26, 2024
98441ea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2024
118a37a
move test helpers
flying-sheep Sep 26, 2024
b6c2689
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2024
1b1023a
Merge branch 'main' into dask_mg_support
Intron7 Oct 1, 2024
9235c3e
Merge branch 'main' into dask_mg_support
Intron7 Oct 2, 2024
7315d99
Merge branch 'main' into dask_mg_support
Intron7 Oct 2, 2024
6a4394d
Merge branch 'main' into dask_mg_support
Intron7 Oct 8, 2024
7945775
Merge branch 'main' into dask_mg_support
Intron7 Oct 10, 2024
0201658
Merge branch 'main' into dask_mg_support
Intron7 Oct 14, 2024
ea57084
update typing
Intron7 Oct 14, 2024
13760b7
update normalize
Intron7 Oct 15, 2024
b9e4931
go back to lambda
Intron7 Oct 15, 2024
9308e21
slim down tests
Intron7 Oct 15, 2024
f8d6269
run tests on rapids-24.08
Intron7 Oct 15, 2024
65f941a
compress hvg tests
Intron7 Oct 16, 2024
17ca2ef
remove .todelayed
Intron7 Oct 16, 2024
06ce8e5
remove dask.delayed
Intron7 Oct 16, 2024
6d94835
update qc
Intron7 Oct 17, 2024
b733ab3
Update src/rapids_singlecell/preprocessing/_pca.py
Intron7 Oct 22, 2024
75bbbb8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2024
42ac1f9
Merge branch 'main' into dask_mg_support
Intron7 Oct 22, 2024
c02b0e0
Merge branch 'main' into dask_mg_support
Intron7 Nov 12, 2024
d2d2e45
Update src/rapids_singlecell/preprocessing/_scale.py
Intron7 Nov 12, 2024
d3421ca
add error
Intron7 Nov 12, 2024
8b32156
update tree pca
Intron7 Nov 12, 2024
bb10cda
Update src/rapids_singlecell/preprocessing/_scale.py
Intron7 Nov 13, 2024
1ae74d7
add note
Intron7 Nov 13, 2024
38e4ad0
dask import
Intron7 Nov 13, 2024
9755641
update qc names
Intron7 Nov 13, 2024
fe2aa20
Merge branch 'main' into dask_mg_support
Intron7 Nov 13, 2024
5fd5a97
update
Intron7 Nov 14, 2024
af1faf5
update _check_gpu_X
Intron7 Nov 14, 2024
7366200
update docs
Intron7 Nov 14, 2024
d1a6344
docs update
Intron7 Nov 15, 2024
fb8c825
Merge branch 'main' into dask_mg_support
Intron7 Nov 21, 2024
c65585d
make sure dtype is correct PCA
Intron7 Nov 25, 2024
b7974f9
Merge branch 'main' into dask_mg_support
Intron7 Nov 26, 2024
e7a1118
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@

autosummary_generate = True
autodoc_member_order = "bysource"
autodoc_mock_imports = ["cudf", "cuml", "cugraph", "cupy", "cupyx", "pylibraft"]
autodoc_mock_imports = ["cudf", "cuml", "cugraph", "cupy", "cupyx", "pylibraft", "dask"]
default_role = "literal"
napoleon_google_docstring = False
napoleon_numpy_docstring = True
Expand Down Expand Up @@ -108,6 +108,7 @@
"rmm": ("https://docs.rapids.ai/api/rmm/stable/", None),
"statsmodels": ("https://www.statsmodels.org/stable/", None),
"omnipath": ("https://omnipath.readthedocs.io/en/latest/", None),
"dask": ("https://docs.dask.org/en/stable/", None),
}

# List of patterns, relative to source directory, that match files and
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ doc = [
"scanpydoc[typehints,theme]>=0.9.4",
"readthedocs-sphinx-ext",
"sphinx_copybutton",
"dask",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should dask[array,distributed] and dask-cuda be added to the pyproject-toml in dask section?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just for building the docs. I think cuml installs all the dask dependencies

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's generally best practice still to install them directly if you use them.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be added to the dependencies.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you. I dont want to run into issue with rapids and the conda installation.

]
test = [
"pytest",
Expand Down
35 changes: 35 additions & 0 deletions src/rapids_singlecell/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

import cupy as cp
from cupyx.scipy.sparse import csr_matrix

try:
from dask.array import Array as DaskArray
except ImportError:

class DaskArray:
pass


try:
from dask.distributed import Client as DaskClient
except ImportError:

class DaskClient:
pass
Intron7 marked this conversation as resolved.
Show resolved Hide resolved


def _get_dask_client(client=None):
from dask.distributed import default_client

if client is None or not isinstance(client, DaskClient):
return default_client()
return client
Intron7 marked this conversation as resolved.
Show resolved Hide resolved


def _meta_dense(dtype):
return cp.zeros([0], dtype=dtype)


def _meta_sparse(dtype):
return csr_matrix(cp.array((1.0,), dtype=dtype))
2 changes: 1 addition & 1 deletion src/rapids_singlecell/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from ._neighbors import neighbors
from ._normalize import log1p, normalize_pearson_residuals, normalize_total
from ._pca import pca
from ._qc import calculate_qc_metrics
from ._regress_out import regress_out
from ._scale import scale
from ._scrublet import scrublet, scrublet_simulate_doublets
from ._simple import (
calculate_qc_metrics,
filter_cells,
filter_genes,
filter_highly_variable,
Expand Down
54 changes: 39 additions & 15 deletions src/rapids_singlecell/preprocessing/_hvg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import cupy as cp
import numpy as np
import pandas as pd
from cupyx.scipy.sparse import issparse, isspmatrix_csc
from cupyx.scipy.sparse import csr_matrix, issparse, isspmatrix_csc
from scanpy.get import _get_obs_rep

from ._simple import calculate_qc_metrics
from rapids_singlecell._compat import DaskArray, DaskClient, _meta_dense, _meta_sparse

from ._qc import calculate_qc_metrics
from ._utils import _check_gpu_X, _check_nonnegative_integers, _get_mean_var

if TYPE_CHECKING:
Expand Down Expand Up @@ -47,6 +49,7 @@ def highly_variable_genes(
chunksize: int = 1000,
n_samples: int = 10000,
batch_key: str = None,
client: DaskClient | None = None,
) -> None:
"""\
Annotate highly variable genes.
Expand Down Expand Up @@ -116,6 +119,8 @@ def highly_variable_genes(
of enrichment of zeros for each gene (only for `flavor='poisson_gene_selection'`).
batch_key
If specified, highly-variable genes are selected within each batch separately and merged.
client
Dask client to use for computation. If `None`, the default client is used. Only used if `X` is a Dask array.

Returns
-------
Expand Down Expand Up @@ -188,7 +193,12 @@ def highly_variable_genes(

if batch_key is None:
df = _highly_variable_genes_single_batch(
adata, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor
adata,
layer=layer,
cutoff=cutoff,
n_bins=n_bins,
flavor=flavor,
client=client,
)
else:
df = _highly_variable_genes_batched(
Expand All @@ -198,6 +208,7 @@ def highly_variable_genes(
cutoff=cutoff,
n_bins=n_bins,
flavor=flavor,
client=client,
)

adata.uns["hvg"] = {"flavor": flavor}
Expand Down Expand Up @@ -267,6 +278,7 @@ def _highly_variable_genes_single_batch(
cutoff: _Cutoffs | int,
n_bins: int = 20,
flavor: Literal["seurat", "cell_ranger"] = "seurat",
client: DaskClient | None = None,
) -> pd.DataFrame:
"""\
See `highly_variable_genes`.
Expand All @@ -277,18 +289,24 @@ def _highly_variable_genes_single_batch(
`highly_variable`, `means`, `dispersions`, and `dispersions_norm`.
"""
X = _get_obs_rep(adata, layer=layer)

_check_gpu_X(X, allow_dask=True)
if hasattr(X, "_view_args"): # AnnData array view
# For compatibility with anndata<0.9
X = X.copy() # Doesn't actually copy memory, just removes View class wrapper
X = X.copy()

if flavor == "seurat":
X = X.copy()
if issparse(X):
X = X.expm1()
if isinstance(X, DaskArray):
if isinstance(X._meta, cp.ndarray):
X = X.map_blocks(lambda X: cp.expm1(X), meta=_meta_dense(X.dtype))
elif isinstance(X._meta, csr_matrix):
X = X.map_blocks(lambda X: X.expm1(), meta=_meta_sparse(X.dtype))
else:
X = cp.expm1(X)
mean, var = _get_mean_var(X, axis=0)
X = X.copy()
if issparse(X):
X = X.expm1()
else:
X = cp.expm1(X)

mean, var = _get_mean_var(X, axis=0, client=client)
mean[mean == 0] = 1e-12
disp = var / mean
if flavor == "seurat": # logarithmized mean as in Seurat
Expand Down Expand Up @@ -407,20 +425,26 @@ def _highly_variable_genes_batched(
n_bins: int,
flavor: Literal["seurat", "cell_ranger"],
cutoff: _Cutoffs | int,
client: DaskClient | None = None,
) -> pd.DataFrame:
adata._sanitize()
batches = adata.obs[batch_key].cat.categories
dfs = []
gene_list = adata.var_names
for batch in batches:
adata_subset = adata[adata.obs[batch_key] == batch]
adata_subset = adata[adata.obs[batch_key] == batch].copy()

calculate_qc_metrics(adata_subset, layer=layer)
calculate_qc_metrics(adata_subset, layer=layer, client=client)
filt = adata_subset.var["n_cells_by_counts"].to_numpy() > 0
adata_subset = adata_subset[:, filt]
adata_subset = adata_subset[:, filt].copy()

hvg = _highly_variable_genes_single_batch(
adata_subset, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor
adata_subset,
layer=layer,
cutoff=cutoff,
n_bins=n_bins,
flavor=flavor,
client=client,
)
hvg.reset_index(drop=False, inplace=True, names=["gene"])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""

_csr_scale_diff_kernel = r"""
(const int *indptr, const int *indices, {0} *data, const double * std, const int *mask, {0} clipper,int nrows) {
(const int *indptr, const int *indices, {0} *data, const {0} * std, const int *mask, {0} clipper,int nrows) {
int row = blockIdx.x;

if(row >= nrows){
Expand Down
Loading