diff --git a/.github/workflows/test_linux_cuda.yml b/.github/workflows/test_linux_cuda.yml index 492fbffeba..0eca4d65c8 100644 --- a/.github/workflows/test_linux_cuda.yml +++ b/.github/workflows/test_linux_cuda.yml @@ -1,6 +1,8 @@ name: test (cuda) on: + push: + branches: [main, "[0-9]+.[0-9]+.x"] #this is new pull_request: branches: [main, "[0-9]+.[0-9]+.x"] types: [labeled, synchronize, opened] @@ -31,6 +33,7 @@ jobs: container: image: ghcr.io/scverse/scvi-tools:py3.12-cu12-base + #image: ghcr.io/scverse/scvi-tools:py3.12-cu12-${{ env.BRANCH_NAME }}-base options: --user root --gpus all name: integration @@ -40,11 +43,24 @@ jobs: PYTHON: ${{ matrix.python }} steps: + #- name: Get the current branch name + # id: vars + # run: echo "BRANCH_NAME=$(echo $GITHUB_REF | awk -F'/' '{print $3}')" >> $GITHUB_ENV + - uses: actions/checkout@v4 - - run: | + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + cache: "pip" + cache-dependency-path: "**/pyproject.toml" + + - name: Install dependencies + run: | python -m pip install --upgrade pip wheel uv python -m uv pip install --system "scvi-tools[tests] @ ." + python -m pip install jax[cuda] + python -m pip install nvidia-nccl-cu12 - name: Run pytest env: diff --git a/src/scvi/external/poissonvi/_model.py b/src/scvi/external/poissonvi/_model.py index a27d3b0c60..c7b7bc61c0 100644 --- a/src/scvi/external/poissonvi/_model.py +++ b/src/scvi/external/poissonvi/_model.py @@ -235,8 +235,11 @@ def get_accessibility_estimates( @torch.inference_mode() def get_region_factors(self): - """Return region-specific factors.""" - region_factors = self.module.decoder.px_scale_decoder[-2].bias.numpy() + """Return region-specific factors. CPU/GPU dependent""" + if self.device.type == "cpu": + region_factors = self.module.decoder.px_scale_decoder[-2].bias.numpy() + else: + region_factors = self.module.decoder.px_scale_decoder[-2].bias.cpu().numpy() # gpu if region_factors is None: raise RuntimeError("region factors were not included in this model") return region_factors diff --git a/tests/model/test_pyro.py b/tests/model/test_pyro.py index 1c0e4a946c..bc647bb665 100644 --- a/tests/model/test_pyro.py +++ b/tests/model/test_pyro.py @@ -6,6 +6,7 @@ import numpy as np import pyro import pyro.distributions as dist +import pytest import torch from pyro import clear_param_store from pyro.infer.autoguide import AutoNormal, init_to_mean @@ -215,6 +216,7 @@ def test_pyro_bayesian_regression_low_level( ] +@pytest.mark.optional def test_pyro_bayesian_regression(accelerator: str, devices: list | str | int, save_path: str): adata = synthetic_iid() adata_manager = _create_indices_adata_manager(adata) @@ -277,6 +279,7 @@ def test_pyro_bayesian_regression(accelerator: str, devices: list | str | int, s np.testing.assert_array_equal(linear_median_new, linear_median) +@pytest.mark.optional def test_pyro_bayesian_regression_jit( accelerator: str, devices: list | str | int,