
Commit

Merge branch 'main' into chex
martinkim0 authored Jul 29, 2023
2 parents 1eb0c97 + ac1743b commit 8d7c1d9
Showing 20 changed files with 369 additions and 298 deletions.
9 changes: 0 additions & 9 deletions .devcontainer/Dockerfile

This file was deleted.

26 changes: 0 additions & 26 deletions .devcontainer/devcontainer.json

This file was deleted.

5 changes: 4 additions & 1 deletion docs/release_notes/index.md
@@ -29,6 +29,7 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits
- Expose {meth}`torch.save` keyword arguments in {meth}`scvi.model.base.BaseModelClass.save`
and {meth}`scvi.external.GIMVI.save` {pr}`2200`.
- Add `model_kwargs` and `train_kwargs` arguments to {meth}`scvi.autotune.ModelTuner.fit` {pr}`2203`.
- Add `datasplitter_kwargs` to model `train` methods {pr}`2204`.

#### Changed

@@ -47,7 +48,7 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits

## Version 1.0

### 1.0.3 (2023-MM-DD)
### 1.0.3 (2023-07-DD)

#### Changed

@@ -59,6 +60,8 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits
argument is ignored {pr}`2162`.
- Fix missing docstring for `unlabeled_category` in
{meth}`scvi.model.SCANVI.setup_anndata` and reorder arguments {pr}`2189`.
- Fix Pandas 2.0 unpickling error in {meth}`scvi.model.base.BaseModelClass.convert_legacy_save`
by switching to {func}`pandas.read_pickle` for the setup dictionary {pr}`2212`.

### 1.0.2 (2023-07-05)

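The `datasplitter_kwargs` entry above ({pr}`2204`) threads a new dictionary argument through each model's `train()` into :class:`~scvi.dataloaders.DataSplitter`. A minimal sketch of how it might be used, assuming synthetic toy data and that `pin_memory` is among the keys `DataSplitter` accepts:

```python
import scvi

# Hypothetical toy data; any AnnData registered for SCVI would do.
adata = scvi.data.synthetic_iid()
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")

model = scvi.model.SCVI(adata)
# Extra keys are forwarded verbatim to scvi.dataloaders.DataSplitter.
model.train(max_epochs=10, datasplitter_kwargs={"pin_memory": True})
```
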
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
8 changes: 7 additions & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ requires = ["hatchling"]

[project]
name = "scvi-tools"
version = "1.0.2"
version = "1.0.3"
description = "Deep probabilistic analysis of single-cell omics data."
readme = "README.md"
requires-python = ">=3.9"
@@ -109,11 +109,17 @@ optional = [
] # all optional user functionality

tutorials = [
"cell2location",
"leidenalg",
"muon",
"plotnine",
"pooch",
"pynndescent",
"igraph",
"scikit-misc",
"scrublet",
"scvi-tools[optional]",
"squidpy",
] # dependencies for all tutorials

all = ["scvi-tools[dev,docs,tutorials]"] # all dependencies
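
The expanded `tutorials` extra above collects the dependencies the tutorial notebooks import (`cell2location`, `muon`, `squidpy`, and friends); `pip install "scvi-tools[tutorials]"`, or `scvi-tools[all]` for every dependency group, should pull the new additions in.
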
23 changes: 15 additions & 8 deletions scvi/external/cellassign/_model.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from typing import List, Optional, Union

import numpy as np
import pandas as pd
@@ -129,12 +130,13 @@ def train(
max_epochs: int = 400,
lr: float = 3e-3,
accelerator: str = "auto",
devices: Union[int, List[int], str] = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
validation_size: Optional[float] = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 1024,
plan_kwargs: Optional[dict] = None,
datasplitter_kwargs: dict | None = None,
plan_kwargs: dict | None = None,
early_stopping: bool = True,
early_stopping_patience: int = 15,
early_stopping_min_delta: float = 0.0,
@@ -160,6 +162,8 @@ def train(
sequential order of the data according to `validation_size` and `train_size` percentages.
batch_size
Minibatch size to use during training.
datasplitter_kwargs
Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`.
plan_kwargs
Keyword args for :class:`~scvi.train.TrainingPlan`.
early_stopping
@@ -178,6 +182,8 @@ def train(
else:
plan_kwargs = update_dict

datasplitter_kwargs = datasplitter_kwargs or {}

if "callbacks" in kwargs:
kwargs["callbacks"] += [ClampCallback()]
else:
@@ -209,6 +215,7 @@ def train(
validation_size=validation_size,
batch_size=batch_size,
shuffle_set_split=shuffle_set_split,
**datasplitter_kwargs,
)
training_plan = TrainingPlan(self.module, **plan_kwargs)
runner = TrainRunner(
@@ -228,10 +235,10 @@ def setup_anndata(
cls,
adata: AnnData,
size_factor_key: str,
batch_key: Optional[str] = None,
categorical_covariate_keys: Optional[List[str]] = None,
continuous_covariate_keys: Optional[List[str]] = None,
layer: Optional[str] = None,
batch_key: str | None = None,
categorical_covariate_keys: list[str] | None = None,
continuous_covariate_keys: list[str] | None = None,
layer: str | None = None,
**kwargs,
):
"""%(summary)s.
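
The newly added `datasplitter_kwargs = datasplitter_kwargs or {}` line is the usual idiom for defaulting an optional dict parameter; a self-contained sketch of why the default is `None` rather than `{}`:

```python
def configure(extras: dict | None = None) -> dict:
    # A literal `extras: dict = {}` default would create one shared dict at
    # function-definition time; mutating it in one call would leak into all
    # later calls. Normalizing from None inside the body avoids that.
    extras = extras or {}
    extras.setdefault("shuffle", True)
    return extras

assert configure() == {"shuffle": True}
assert configure({"pin_memory": True}) == {"pin_memory": True, "shuffle": True}
```
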
5 changes: 5 additions & 0 deletions scvi/external/gimvi/_model.py
@@ -169,6 +169,7 @@ def train(
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
datasplitter_kwargs: dict | None = None,
plan_kwargs: dict | None = None,
**kwargs,
):
@@ -193,6 +194,8 @@ def train(
sequential order of the data according to `validation_size` and `train_size` percentages.
batch_size
Minibatch size to use during training.
datasplitter_kwargs
Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`.
plan_kwargs
Keyword args for model-specific Pytorch Lightning task. Keyword arguments passed
to `train()` will overwrite values present in `plan_kwargs`, when appropriate.
@@ -204,6 +207,7 @@ def train(
devices=devices,
return_device="torch",
)
datasplitter_kwargs = datasplitter_kwargs or {}

self.trainer = Trainer(
max_epochs=max_epochs,
@@ -220,6 +224,7 @@ def train(
validation_size=validation_size,
batch_size=batch_size,
shuffle_set_split=shuffle_set_split,
**datasplitter_kwargs,
)
ds.setup()
train_dls.append(ds.train_dataloader())
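
GIMVI trains on paired datasets, so it builds one :class:`~scvi.dataloaders.DataSplitter` per AnnData and now forwards the same `datasplitter_kwargs` to each, as the `**datasplitter_kwargs` line above shows. A rough sketch of that loop under assumed names (`adata_managers` and the sizes are hypothetical):

```python
from scvi.dataloaders import DataSplitter

train_dls, val_dls = [], []
for adata_manager in adata_managers:  # hypothetical: one manager per dataset
    ds = DataSplitter(
        adata_manager,
        train_size=0.9,
        batch_size=128,
        **(datasplitter_kwargs or {}),  # identical extras for every splitter
    )
    ds.setup()  # performs the train/validation split
    train_dls.append(ds.train_dataloader())
    val_dls.append(ds.val_dataloader())
```
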
28 changes: 18 additions & 10 deletions scvi/external/scbasset/_model.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union
from typing import Literal

import numpy as np
import pandas as pd
@@ -93,16 +95,17 @@ def train(
max_epochs: int = 1000,
lr: float = 0.01,
accelerator: str = "auto",
devices: Union[int, List[int], str] = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
validation_size: Optional[float] = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
early_stopping: bool = True,
early_stopping_monitor: str = "auroc_train",
early_stopping_mode: Literal["min", "max"] = "max",
early_stopping_min_delta: float = 1e-6,
plan_kwargs: Optional[dict] = None,
datasplitter_kwargs: dict | None = None,
plan_kwargs: dict | None = None,
**trainer_kwargs,
):
"""Train the model.
@@ -137,6 +140,8 @@ def train(
early_stopping_min_delta
Minimum change in the monitored quantity to qualify as an improvement,
i.e. an absolute change of less than min_delta, will count as no improvement.
datasplitter_kwargs
Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`.
plan_kwargs
Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
`train()` will overwrite values present in `plan_kwargs`, when appropriate.
@@ -152,6 +157,8 @@ def train(
if plan_kwargs is not None:
custom_plan_kwargs.update(plan_kwargs)

datasplitter_kwargs = datasplitter_kwargs or {}

data_splitter = DataSplitter(
self.adata_manager,
train_size=train_size,
@@ -163,6 +170,7 @@ def train(
REGISTRY_KEYS.X_KEY: np.float32,
REGISTRY_KEYS.DNA_CODE_KEY: np.int64,
},
**datasplitter_kwargs,
)
training_plan = TrainingPlan(self.module, **custom_plan_kwargs)

@@ -238,8 +246,8 @@ def rename_members(tarball):

@dependencies("Bio")
def _get_motif_library(
self, tf: str, genome: str = "human", motif_dir: Optional[str] = None
) -> Tuple[List[str], List[str]]:
self, tf: str, genome: str = "human", motif_dir: str | None = None
) -> tuple[list[str], list[str]]:
"""Load sequences with a TF motif injected from a pre-computed library
Parameters
@@ -292,8 +300,8 @@ def get_tf_activity(
self,
tf: str,
genome: str = "human",
motif_dir: Optional[str] = None,
lib_size_norm: Optional[bool] = True,
motif_dir: str | None = None,
lib_size_norm: bool | None = True,
batch_size: int = 256,
) -> np.ndarray:
"""Infer transcription factor activity using a motif injection procedure.
@@ -398,8 +406,8 @@ def setup_anndata(
cls,
adata: AnnData,
dna_code_key: str,
layer: Optional[str] = None,
batch_key: Optional[str] = None,
layer: str | None = None,
batch_key: str | None = None,
**kwargs,
):
"""%(summary)s.
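
A pattern repeated across these files is swapping `typing.Optional`/`Union`/`List` for builtin generics and PEP 604 unions. With `requires-python = ">=3.9"`, the `X | Y` syntax is only legal in annotations because each touched module now begins with `from __future__ import annotations`, which turns annotations into lazily evaluated strings. A minimal illustration (the function and path are hypothetical):

```python
from __future__ import annotations

# Without the future import, Python 3.9 would raise TypeError here at
# definition time: `str | None` only works at runtime from 3.10 onward.
def fetch_motifs(tf: str, genome: str = "human", motif_dir: str | None = None) -> list[str]:
    base = motif_dir or f"/tmp/{genome}"  # hypothetical default location
    return [f"{base}/{tf}.fasta"]

print(fetch_motifs("GATA1"))  # ['/tmp/human/GATA1.fasta']
```
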
26 changes: 17 additions & 9 deletions scvi/external/solo/_model.py
@@ -1,8 +1,10 @@
from __future__ import annotations

import io
import logging
import warnings
from contextlib import redirect_stdout
from typing import List, Optional, Sequence, Union
from typing import Sequence

import anndata
import numpy as np
@@ -98,8 +100,8 @@ def __init__(
def from_scvi_model(
cls,
scvi_model: SCVI,
adata: Optional[AnnData] = None,
restrict_to_batch: Optional[str] = None,
adata: AnnData | None = None,
restrict_to_batch: str | None = None,
doublet_ratio: int = 2,
**classifier_kwargs,
):
@@ -243,7 +245,7 @@ def create_doublets(
cls,
adata_manager: AnnDataManager,
doublet_ratio: int,
indices: Optional[Sequence[int]] = None,
indices: Sequence[int] | None = None,
seed: int = 1,
) -> AnnData:
"""Simulate doublets.
@@ -290,12 +292,13 @@ def train(
max_epochs: int = 400,
lr: float = 1e-3,
accelerator: str = "auto",
devices: Union[int, List[int], str] = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
validation_size: Optional[float] = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
plan_kwargs: Optional[dict] = None,
datasplitter_kwargs: dict | None = None,
plan_kwargs: dict | None = None,
early_stopping: bool = True,
early_stopping_patience: int = 30,
early_stopping_min_delta: float = 0.0,
@@ -321,6 +324,8 @@ def train(
sequential order of the data according to `validation_size` and `train_size` percentages.
batch_size
Minibatch size to use during training.
datasplitter_kwargs
Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`.
plan_kwargs
Keyword args for :class:`~scvi.train.ClassifierTrainingPlan`. Keyword arguments passed to
`train()` will overwrite values present in `plan_kwargs`, when appropriate.
early_stopping
@@ -341,6 +346,8 @@ def train(
else:
plan_kwargs = update_dict

datasplitter_kwargs = datasplitter_kwargs or {}

if early_stopping:
early_stopping_callback = [
LoudEarlyStopping(
@@ -367,6 +374,7 @@ def train(
validation_size=validation_size,
shuffle_set_split=shuffle_set_split,
batch_size=batch_size,
**datasplitter_kwargs,
)
training_plan = ClassifierTrainingPlan(
self.module, self.n_labels, **plan_kwargs
@@ -437,8 +445,8 @@ def auto_forward(module, x):
def setup_anndata(
cls,
adata: AnnData,
labels_key: Optional[str] = None,
layer: Optional[str] = None,
labels_key: str | None = None,
layer: str | None = None,
**kwargs,
):
"""%(summary)s.
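
The `else: plan_kwargs = update_dict` branch visible in the SOLO hunk is the tail of the merge used across these models, where, per the docstrings, arguments passed to `train()` overwrite matching keys in `plan_kwargs`. A sketch with an assumed `update_dict`:

```python
lr = 1e-3                 # train() argument
update_dict = {"lr": lr}  # assumed: defaults derived from train() arguments

plan_kwargs = {"lr": 5e-4, "weight_decay": 1e-6}  # user-supplied
if plan_kwargs is not None:
    plan_kwargs.update(update_dict)  # train() arguments take precedence
else:
    plan_kwargs = update_dict

print(plan_kwargs)  # {'lr': 0.001, 'weight_decay': 1e-06}
```
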