Skip to content

Commit

Permalink
Revalidate devices when auto switching mps to cpu (#2247)
Browse files Browse the repository at this point in the history
* Passthrough auto device for cpu

* Reconfigure mps to cpu

* Add release note
  • Loading branch information
martinkim0 authored Aug 18, 2023
1 parent de59581 commit 682a7f7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
7 changes: 4 additions & 3 deletions docs/release_notes/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits

#### Fixed

- Fix bug where `n_hidden` was not being passed into {class}`scvi.nn.Encoder` in {class}`scvi.model.AmortizedLDA`
{pr}`2229`
- Fix bug where `n_hidden` was not being passed into {class}`scvi.nn.Encoder`
in {class}`scvi.model.AmortizedLDA` {pr}`2229`

#### Changed

Expand All @@ -47,7 +47,8 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits
method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary`
stores the summary dataframe, and the `lfc_per_model_per_group` key stores the
per-group LFC.
- `n_hidden` is taking effect in AmortizedLDA.
- Revalidate `devices` when automatically switching from MPS to CPU
accelerator in {func}`scvi.model._utils.parse_device_args` {pr}`2247`.

#### Removed

Expand Down
14 changes: 7 additions & 7 deletions scvi/model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ def parse_device_args(
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

# auto accelerator should not default to mps
if accelerator == "auto" and _accelerator == "mps":
_accelerator = "cpu"

if _accelerator == "mps":
elif _accelerator == "mps" and accelerator == "auto":
# auto accelerator should not default to mps
connector = _AcceleratorConnector(accelerator="cpu", devices=devices)
_accelerator = connector._accelerator_flag
_devices = connector._devices_flag
elif _accelerator == "mps" and accelerator != "auto":
warnings.warn(
"`accelerator` has been set to `mps`. Please note that not all PyTorch "
"operations are supported with this backend. Refer to "
Expand All @@ -132,8 +132,8 @@ def parse_device_args(
else:
device_idx = _devices

# auto device should not use multiple devices for non-cpu accelerators
if devices == "auto" and _accelerator != "cpu":
# auto device should not use multiple devices for non-cpu accelerators
_devices = [device_idx]

if return_device == "torch":
Expand Down

0 comments on commit 682a7f7

Please sign in to comment.