(feat): _first_pass_qc single dispatch refactor #180

Open
wants to merge 4 commits into dask_mg_support
Conversation

ilan-gold

I added a separate file here containing the refactor. It drops the line count by 50, and that's only the first pass (the second pass would save more).
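
For context, the idea is to replace per-type branching in _first_pass_qc with functools.singledispatch registrations, roughly in the style below. This is a minimal sketch with assumed helper names, signatures, and return values, not the code in this PR:

from functools import singledispatch

import cupy as cp
import dask
import dask.array as da
from cupyx.scipy import sparse


# Illustrative sketch only: names and bodies are assumptions, not the PR's implementation.
@singledispatch
def _first_pass_qc(X):
    raise NotImplementedError(f"_first_pass_qc is not implemented for {type(X)}")


@_first_pass_qc.register(sparse.csr_matrix)
def _(X):
    # Single-GPU path: per-cell and per-gene sums directly on the CuPy sparse matrix.
    sums_cells = X.sum(axis=1).ravel()
    sums_genes = X.sum(axis=0).ravel()
    return sums_cells, sums_genes


@_first_pass_qc.register(da.Array)
def _(X):
    # Multi-GPU path: reuse the single-GPU implementation on every row block,
    # then stitch the partial results back together.
    partials = dask.compute(
        *[dask.delayed(_first_pass_qc)(block) for block in X.to_delayed().flatten()]
    )
    sums_cells = cp.concatenate([cells for cells, _ in partials])
    sums_genes = sum(genes for _, genes in partials)
    return sums_cells, sums_genes

With this pattern, supporting a new container type becomes one more .register call instead of another isinstance branch in every QC function.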

To reproduce my benchmark, I ran (on our current cluster setup):

CUDA_VISIBLE_DEVICES="0,1,2,3" srun --gres=gpu:4 --partition=dc-gpu --account=training2406 --reservation=gpuhack24-2024-04-25 --time=480  --pty /bin/bash -i

and then ran the following script, saved in a file called rsc_example.py, via python rsc_example.py refactored or python rsc_example.py current. The script prints the time taken and the resulting AnnData object (so one can see that the QC metrics were calculated). I get about 800-900 milliseconds with both implementations.

import time
import sys

import anndata

import dask
import dask.array
from dask_cuda import LocalCUDACluster
from dask.distributed import Client

import cudf
from cuml.dask.common.part_utils import _extract_partitions
import cupy as cp
from cupyx.scipy import sparse

import h5py
import rapids_singlecell as rsc
import rmm
from rmm.allocators.cupy import rmm_cupy_allocator



def set_mem():
    rmm.reinitialize(managed_memory=True)
    cp.cuda.set_allocator(rmm_cupy_allocator)

def read_with_filter(client,
                     sample_file, batch_size=50000):
    """
    Reads an h5ad file and applies cell and gene count filters. A Dask Array is
    used to allow partitioning of the input file. This function supports multiple GPUs.
    """

    # Path in h5 file
    _data = '/X/data'
    _index = '/X/indices'
    _indprt = '/X/indptr'
    # _genes = '/var/ensembl_ids'
    #_genes = '/var/ensembl_id'
    _genes = '/var/_index'
    #_genes = '/var/feature_id'
    _barcodes = '/obs/_index'

    @dask.delayed
    def _read_partition_to_sparse_matrix(sample_file,
                                         total_cols, batch_start, batch_end,
                                         ):
        with h5py.File(sample_file, 'r') as h5f:
            indptrs = h5f[_indprt]
            start_ptr = indptrs[batch_start]
            end_ptr = indptrs[batch_end]

            # Read all things data and index
            sub_data = cp.array(h5f[_data][start_ptr:end_ptr])
            sub_indices = cp.array(h5f[_index][start_ptr:end_ptr])

            # recompute the row pointer for the partial dataset
            sub_indptrs  = cp.array(indptrs[batch_start:(batch_end + 1)])
            sub_indptrs = sub_indptrs - sub_indptrs[0]

        # Reconstruct partial sparse array
        partial_sparse_array = sparse.csr_matrix(
            (sub_data, sub_indices, sub_indptrs),
            shape=(batch_end - batch_start, total_cols))
            
        return partial_sparse_array


    with h5py.File(sample_file, 'r') as h5f:
        # Compute the number of cells to read
        indptr = h5f[_indprt]
        vars = h5f["/var/"]
        print(vars.keys())
        genes = cudf.Series(h5f[_genes], dtype=cp.dtype('object'))

        total_cols = genes.shape[0]
        max_cells = indptr.shape[0] - 1

    dls = []
    for batch_start in range(0, max_cells, batch_size):
        actual_batch_size = min(batch_size, max_cells - batch_start)
        dls.append(dask.array.from_delayed(
                   (_read_partition_to_sparse_matrix)
                   (sample_file,
                    total_cols,
                    batch_start,
                    batch_start + actual_batch_size),
                   dtype=cp.float32,
                   meta=sparse.csr_matrix(cp.array((1.,))),
                   shape=(actual_batch_size, total_cols)))

    dask_sparse_arr = dask.array.concatenate(dls)
    dask_sparse_arr = dask_sparse_arr.persist()
    return dask_sparse_arr, genes

if __name__ == '__main__':
    preprocessing_gpus = "0,1,2,3"
    cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=preprocessing_gpus)
    client = Client(cluster)
    set_mem()
    client.run(set_mem)
        
    dask_sparse_arr, genes = read_with_filter(
        client, 
        "/p/scratch/training2406/team_scverse/scverse_data/1M_brain_cells_10X.sparse.h5ad", 
        batch_size=50000
    )
    dask_sparse_arr = dask_sparse_arr.persist()

    dask_sparse_arr.compute_chunk_sizes()
    adata = anndata.AnnData(dask_sparse_arr)
    rsc.pp.flag_gene_family(adata, gene_family_name="MT", gene_family_prefix="mt-")

    start = time.time()
    funcs = {
        "refactored": rsc.pp.calculate_qc_metrics_refactored,
        "current": rsc.pp.calculate_qc_metrics
    }
    
    funcs[sys.argv[1]](adata, qc_vars="MT", client=client)
    print('TIME TAKEN:', time.time() - start)
    print('QCed ANNDATA:', adata)
    client.retire_workers()
    client.shutdown()

Intron7 (Member) commented May 4, 2024

@ilan-gold singledispatch and docs don't work together with GPU arrays. I have no idea why that is. To get the docs to build I had to refactor normalize_total not to use it. If you figure out how the docs will build with this, we can talk about it some more.
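
For reference, the alternative being described here, explicit branching on the array type instead of functools.singledispatch, looks roughly like this. This is a minimal sketch of the pattern only, not the actual normalize_total code in rapids_singlecell:

import cupy as cp
import dask.array as da
from cupyx.scipy import sparse


def _normalize_total(X, target_sum=1e4):
    # Sketch of the explicit type-branching style; not the actual
    # rapids_singlecell implementation of normalize_total.
    if isinstance(X, da.Array):
        # Assuming row-wise chunking, each block can be scaled independently.
        return X.map_blocks(
            lambda block: _normalize_total(block, target_sum),
            dtype=X.dtype,
            meta=X._meta,
        )
    if sparse.issparse(X):
        counts = X.sum(axis=1).ravel()
        return sparse.diags(target_sum / counts, format="csr").dot(X)
    if isinstance(X, cp.ndarray):
        counts = X.sum(axis=1, keepdims=True)
        return X * (target_sum / counts)
    raise NotImplementedError(f"normalize_total is not implemented for {type(X)}")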
