Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 15, 2024
1 parent 741b608 commit 8484a80
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 16 deletions.
19 changes: 9 additions & 10 deletions dev/model_user_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
"import os\n",
"import tempfile\n",
"from collections.abc import Sequence\n",
"from typing import Optional\n",
"\n",
"import numpy as np\n",
"import scvi\n",
Expand Down Expand Up @@ -294,10 +293,10 @@
" def setup_anndata(\n",
" cls,\n",
" adata: AnnData,\n",
" batch_key: Optional[str] = None,\n",
" layer: Optional[str] = None,\n",
" batch_key: str | None = None,\n",
" layer: str | None = None,\n",
" **kwargs,\n",
" ) -> Optional[AnnData]:\n",
" ) -> AnnData | None:\n",
" setup_method_args = cls._get_setup_method_args(**locals())\n",
" anndata_fields = [\n",
" LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),\n",
Expand Down Expand Up @@ -1257,9 +1256,9 @@
"@torch.inference_mode()\n",
"def get_latent_representation(\n",
" self,\n",
" adata: Optional[AnnData] = None,\n",
" indices: Optional[Sequence[int]] = None,\n",
" batch_size: Optional[int] = None,\n",
" adata: AnnData | None = None,\n",
" indices: Sequence[int] | None = None,\n",
" batch_size: int | None = None,\n",
") -> np.ndarray:\n",
" r\"\"\"Return the latent representation for each cell.\n",
"\n",
Expand Down Expand Up @@ -1466,10 +1465,10 @@
" def setup_anndata(\n",
" cls,\n",
" adata: AnnData,\n",
" batch_key: Optional[str] = None,\n",
" layer: Optional[str] = None,\n",
" batch_key: str | None = None,\n",
" layer: str | None = None,\n",
" **kwargs,\n",
" ) -> Optional[AnnData]:\n",
" ) -> AnnData | None:\n",
" setup_method_args = cls._get_setup_method_args(**locals())\n",
" anndata_fields = [\n",
" LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),\n",
Expand Down
4 changes: 2 additions & 2 deletions scrna/AutoZI_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10590,7 +10590,7 @@
"ct = adata.obs.str_labels.astype(\"category\")\n",
"codes = np.unique(ct.cat.codes)\n",
"cats = ct.cat.categories\n",
"for ind_cell_type, cell_type in zip(codes, cats):\n",
"for ind_cell_type, cell_type in zip(codes, cats, strict=False):\n",
" is_zi_pred_genelabel_here = is_zi_pred_genelabel[:, ind_cell_type]\n",
" print(\n",
" f\"Fraction of predicted ZI genes for cell type {cell_type} :\",\n",
Expand Down Expand Up @@ -10654,7 +10654,7 @@
],
"source": [
"# With avg expressions > 1\n",
"for ind_cell_type, cell_type in zip(codes, cats):\n",
"for ind_cell_type, cell_type in zip(codes, cats, strict=False):\n",
" mask_sufficient_expression = (\n",
" np.array(adata.X[adata.obs.str_labels.values.reshape(-1) == cell_type, :].mean(axis=0))\n",
" > 1.0\n",
Expand Down
8 changes: 4 additions & 4 deletions scrna/scanvi_fix.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15232,8 +15232,8 @@
" models = [model_no_fix, model_fix, model_fix_linear]\n",
" model_names = [\"No fix\", \"Fix\", \"Fix linear\"]\n",
"\n",
" for i, (metric, ylim) in enumerate(zip(metrics, ylims)):\n",
" for j, (model, model_name) in enumerate(zip(models, model_names)):\n",
" for i, (metric, ylim) in enumerate(zip(metrics, ylims, strict=False)):\n",
" for j, (model, model_name) in enumerate(zip(models, model_names, strict=False)):\n",
" plot_metric(axes[i, j], metric, model, model_name, ylim=ylim)\n",
"\n",
" fig.text(-0.01, 0.8, \"Classification loss\", va=\"center\", rotation=\"vertical\")\n",
Expand Down Expand Up @@ -15343,7 +15343,7 @@
" models = [model_no_fix, model_fix, model_fix_linear]\n",
" model_names = [\"No fix\", \"Fix\", \"Fix linear\"]\n",
"\n",
" for model, model_name, ax in zip(models, model_names, axes):\n",
" for model, model_name, ax in zip(models, model_names, axes, strict=False):\n",
" plot_confusion_matrix(ax, model, model_name, subset)\n",
"\n",
" fig.text(0.0, 0.5, \"Observed\", va=\"center\", rotation=\"vertical\")\n",
Expand Down Expand Up @@ -15492,7 +15492,7 @@
" model_names = [\"No fix\", \"Fix\", \"Fix linear\"]\n",
" legend_loc = [\"none\", \"none\", \"right margin\"]\n",
"\n",
" for model, model_name, ax, leg_loc in zip(models, model_names, axes, legend_loc):\n",
" for model, model_name, ax, leg_loc in zip(models, model_names, axes, legend_loc, strict=False):\n",
" plot_latent_mde(ax, model, model_name, subset, leg_loc)\n",
"\n",
" fig.text(0.0, 0.5, \"MDE_2\", va=\"center\", rotation=\"vertical\")\n",
Expand Down

0 comments on commit 8484a80

Please sign in to comment.