JaxSCVI (#1367)
* add dependencies

* dummy collate

* initial flax

* latent rep

* history and jit get_ funcs

* updated weight init

* add dropout

* add rngs in jit

* cleaner rngs

* module fixes

* add batchnorm

* more context, fixes

* dense init

* bound module for inference

* fix imports

* update batchnorm

* no extras for jax depend

* add poisson, not implemented

* docs

* updates

* no dropout decoder

* eps

* use other NB

* back to other

* jax > numba

* better cpu support

* release notes

* Update v0.15.0.rst

* speed up with random choice

* speed up with random choice

* better outputs from vae

* jax config

* keyboard interrupt

* add level to rich handler

* Update scvi/data/_utils.py

* Update scvi/data/_utils.py

* address comments
adamgayoso authored Feb 28, 2022
1 parent 5ca65cc commit 7ba36b1
Showing 12 changed files with 570 additions and 24 deletions.
1 change: 1 addition & 0 deletions docs/api/user.rst
@@ -31,6 +31,7 @@ Model
model.TOTALVI
model.MULTIVI
model.AmortizedLDA
model.JaxSCVI



2 changes: 2 additions & 0 deletions docs/conf.py
@@ -99,6 +99,8 @@
pytorch_lightning=("https://pytorch-lightning.readthedocs.io/en/stable/", None),
pyro=("http://docs.pyro.ai/en/stable/", None),
pymde=("https://pymde.org/", None),
flax=("https://flax.readthedocs.io/en/latest/", None),
jax=("https://jax.readthedocs.io/en/latest/", None),
)


5 changes: 5 additions & 0 deletions docs/release_notes/v0.15.0.rst
@@ -20,6 +20,8 @@ for scvi-tools and stores necessary information, rather than adding additional f
:alt: Schematic of data handling strategy with AnnDataManager

Schematic of data handling strategy with :class:`~scvi.data.AnnDataManager`

We also have an exciting new experimental Jax-based scVI implementation via :class:`~scvi.model.JaxSCVI`. While this implementation has limited functionality, we have found it to be substantially faster than the PyTorch-based implementation. For example, on a 10-core Intel CPU, Jax running on the CPU alone can be as fast as PyTorch with a GPU (RTX 3090). We plan further Jax integrations in upcoming releases.
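
A minimal usage sketch (assuming an AnnData object ``adata`` with raw counts; the API is experimental and may change):

>>> import scvi
>>> scvi.model.JaxSCVI.setup_anndata(adata)
>>> model = scvi.model.JaxSCVI(adata)
>>> model.train()
>>> latent = model.get_latent_representation()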

Changes
~~~~~~~
@@ -32,6 +34,7 @@ Changes
- Fix for :class:`~scvi.external.SOLO` when :class:`~scvi.model.SCVI` was setup with a `labels_key` (`#1354`_)
- Updates to tutorials (`#1369`_, `#1371`_)
- Furo docs theme (`#1290`_)
- Add :class:`~scvi.model.JaxSCVI` and :class:`~scvi.module.JaxVAE`, and drop the Numba dependency for checking whether data is count data (`#1367`_).

Breaking changes
~~~~~~~~~~~~~~~~
@@ -75,3 +78,5 @@ Contributors
.. _`#1369`: https://github.com/YosefLab/scvi-tools/pull/1369
.. _`#1371`: https://github.com/YosefLab/scvi-tools/pull/1371
.. _`#1290`: https://github.com/YosefLab/scvi-tools/pull/1290
.. _`#1367`: https://github.com/YosefLab/scvi-tools/pull/1367

5 changes: 4 additions & 1 deletion pyproject.toml
@@ -35,22 +35,25 @@ black = {version = ">=22.1", optional = true}
codecov = {version = ">=2.0.8", optional = true}
docrep = ">=0.3.2"
flake8 = {version = ">=3.7.7", optional = true}
flax = "*"
furo = {version = ">=2022.2.14.1", optional = true}
h5py = ">=2.9.0"
importlib-metadata = {version = "^1.0", python = "<3.8"}
ipython = {version = ">=7.20", optional = true, python = ">=3.7"}
ipywidgets = "*"
isort = {version = ">=5.7", optional = true}
jax = ">=0.3"
jupyter = {version = ">=1.0", optional = true}
leidenalg = {version = "*", optional = true}
loompy = {version = ">=3.0.6", optional = true}
nbconvert = {version = ">=5.4.0", optional = true}
nbformat = {version = ">=4.4.0", optional = true}
nbsphinx = {version = "*", optional = true}
nbsphinx-link = {version = "*", optional = true}
numba = ">=0.41.0"
numpy = ">=1.17.0"
numpyro = "*"
openpyxl = ">=3.0"
optax = "*"
pandas = ">=1.0"
pre-commit = {version = ">=2.7.1", optional = true}
pymde = {version = "*", optional = true}
40 changes: 37 additions & 3 deletions scvi/_settings.py
@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path
from typing import Union

@@ -38,6 +39,10 @@ class ScviConfig:
To set the number of threads PyTorch will use
>>> scvi.settings.num_threads = 2
To prevent Jax from preallocating GPU memory on start (default)
>>> scvi.settings.jax_preallocate_gpu_memory = False
"""

def __init__(
@@ -49,9 +54,9 @@ def __init__(
logging_dir: str = "./scvi_log/",
dl_num_workers: int = 0,
dl_pin_memory_gpu_training: bool = True,
jax_preallocate_gpu_memory: bool = False,
):

self.verbosity = verbosity
self.seed = seed
self.batch_size = batch_size
if progress_bar_style not in ["rich", "tqdm"]:
@@ -61,6 +66,8 @@ def __init__(
self.dl_num_workers = dl_num_workers
self.dl_pin_memory_gpu_training = dl_pin_memory_gpu_training
self._num_threads = None
self.jax_preallocate_gpu_memory = jax_preallocate_gpu_memory
self.verbosity = verbosity

@property
def batch_size(self) -> int:
@@ -171,7 +178,9 @@ def verbosity(self, level: Union[str, int]):
console = Console(force_terminal=True)
if console.is_jupyter is True:
console.is_jupyter = False
ch = RichHandler(show_path=False, console=console, show_time=False)
ch = RichHandler(
level=level, show_path=False, console=console, show_time=False
)
formatter = logging.Formatter("%(message)s")
ch.setFormatter(formatter)
scvi_logger.addHandler(ch)
@@ -185,10 +194,35 @@ def reset_logging_handler(self):
This is useful if piping outputs to a file.
"""
scvi_logger.removeHandler(scvi_logger.handlers[0])
ch = RichHandler(show_path=False, show_time=False)
ch = RichHandler(level=self._verbosity, show_path=False, show_time=False)
formatter = logging.Formatter("%(message)s")
ch.setFormatter(formatter)
scvi_logger.addHandler(ch)

@property
def jax_preallocate_gpu_memory(self):
"""
Jax GPU memory allocation settings.
If False, Jax will allocate GPU memory on demand rather than preallocating.
If a float in (0, 1), Jax will preallocate that fraction of the GPU memory.
"""
return self._jax_gpu

@jax_preallocate_gpu_memory.setter
def jax_preallocate_gpu_memory(self, value: Union[float, bool]):
# see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html#gpu-memory-allocation
if value is False:
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
elif isinstance(value, float):
if value >= 1 or value <= 0:
raise ValueError("Need to use a value between 0 and 1")
# format is ".XX"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(value)[1:4]
else:
raise ValueError("value not understood, need bool or float in (0, 1)")
self._jax_gpu = value


settings = ScviConfig()
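
A short sketch of the new setting in use (values are validated by the setter above; a float must lie strictly between 0 and 1):

>>> import scvi
>>> scvi.settings.jax_preallocate_gpu_memory = False  # allocate on demand (default)
>>> scvi.settings.jax_preallocate_gpu_memory = 0.8  # preallocate 80% of GPU memory
>>> scvi.settings.jax_preallocate_gpu_memory = 1.5  # raises ValueError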
36 changes: 16 additions & 20 deletions scvi/data/_utils.py
@@ -5,11 +5,12 @@

import anndata
import h5py
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import scipy.sparse as sp_sparse
from anndata._core.sparse_dataset import SparseDataset
from numba import boolean, float32, float64, int32, int64, vectorize

from . import _constants

@@ -177,7 +178,8 @@ def _assign_adata_uuid(adata: anndata.AnnData, overwrite: bool = False) -> None:


def _check_nonnegative_integers(
data: Union[pd.DataFrame, np.ndarray, sp_sparse.spmatrix, h5py.Dataset]
data: Union[pd.DataFrame, np.ndarray, sp_sparse.spmatrix, h5py.Dataset],
n_to_check: int = 20,
):
"""Approximately checks values of data to ensure it is count data."""

@@ -194,24 +196,18 @@ def _check_nonnegative_integers(
else:
raise TypeError("data type not understood")

n = len(data)
inds = np.random.permutation(n)[:20]
check = data.flat[inds]
return ~np.any(_is_not_count(check))


@vectorize(
[
boolean(int32),
boolean(int64),
boolean(float32),
boolean(float64),
],
target="parallel",
cache=True,
)
def _is_not_count(d):
return d < 0 or d % 1 != 0
inds = np.random.choice(len(data), size=(n_to_check,))
check = jax.device_put(data.flat[inds], device=jax.devices("cpu")[0])
negative, non_integer = _is_not_count_val(check)
return not (negative or non_integer)


@jax.jit
def _is_not_count_val(data: jnp.ndarray):
negative = jnp.any(data < 0)
non_integer = jnp.any(data % 1 != 0)

return negative, non_integer


def _get_batch_mask_protein_data(
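
For reference, a sketch of how the rewritten check behaves (hypothetical toy arrays; the function samples ``n_to_check`` entries at random and verifies them on CPU with the jitted helper above):

>>> import numpy as np
>>> _check_nonnegative_integers(np.array([0.0, 1.0, 2.0, 3.0]))
True
>>> _check_nonnegative_integers(np.array([-1.0, 0.5, 2.0, 3.0]))
False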
11 changes: 11 additions & 0 deletions scvi/dataloaders/_ann_dataloader.py
@@ -106,6 +108,8 @@ class AnnDataLoader(DataLoader):
If ``None``, defaults to all registered data.
data_loader_kwargs
Keyword arguments for :class:`~torch.utils.data.DataLoader`
iter_ndarray
Whether to iterate over numpy arrays instead of torch tensors
"""

def __init__(
@@ -116,6 +118,7 @@
batch_size=128,
data_and_attributes: Optional[dict] = None,
drop_last: Union[bool, int] = False,
iter_ndarray: bool = False,
**data_loader_kwargs,
):

@@ -158,4 +161,12 @@
# do not touch batch size here, sampler gives batched indices
self.data_loader_kwargs.update({"sampler": sampler, "batch_size": None})

if iter_ndarray:
self.data_loader_kwargs.update({"collate_fn": _dummy_collate})

super().__init__(self.dataset, **self.data_loader_kwargs)


def _dummy_collate(b):
"""Dummy collate to have dataloader return numpy ndarrays."""
return b
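
A sketch of the new option (assuming a registered ``adata_manager`` as the loader's first argument; with ``iter_ndarray=True`` the dummy collate above leaves batches as numpy arrays instead of converting them to torch tensors):

>>> dl = AnnDataLoader(adata_manager, batch_size=128, iter_ndarray=True)
>>> batch = next(iter(dl))  # numpy ndarrays, not torch tensors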
2 changes: 2 additions & 0 deletions scvi/model/__init__.py
@@ -3,6 +3,7 @@
from ._autozi import AUTOZI
from ._condscvi import CondSCVI
from ._destvi import DestVI
from ._jaxscvi import JaxSCVI
from ._linear_scvi import LinearSCVI
from ._multivi import MULTIVI
from ._peakvi import PEAKVI
@@ -22,4 +23,5 @@
"MULTIVI",
"AmortizedLDA",
"utils",
"JaxSCVI",
]