diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index 5ff3e1e464..63d4e16134 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -36,8 +36,8 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits #### Fixed -- Fix bug where `n_hidden` was not being passed into {class}`scvi.nn.Encoder` in {class}`scvi.model.AmortizedLDA` - {pr}`2229` +- Fix bug where `n_hidden` was not being passed into {class}`scvi.nn.Encoder` + in {class}`scvi.model.AmortizedLDA` {pr}`2229` #### Changed @@ -47,7 +47,8 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits method {pr}`2173`. `metrics["diff_exp"]` is now a dictionary where the `summary` stores the summary dataframe, and the `lfc_per_model_per_group` key stores the per-group LFC. -- `n_hidden` is taking effect in AmortizedLDA. +- Revalidate `devices` when automatically switching from MPS to CPU + accelerator in {func}`scvi.model._utils.parse_device_args` {pr}`2247`. #### Removed diff --git a/scvi/model/_utils.py b/scvi/model/_utils.py index 7c5de93a61..89bd71c345 100644 --- a/scvi/model/_utils.py +++ b/scvi/model/_utils.py @@ -110,12 +110,12 @@ def parse_device_args( UserWarning, stacklevel=settings.warnings_stacklevel, ) - - # auto accelerator should not default to mps - if accelerator == "auto" and _accelerator == "mps": - _accelerator = "cpu" - - if _accelerator == "mps": + elif _accelerator == "mps" and accelerator == "auto": + # auto accelerator should not default to mps + connector = _AcceleratorConnector(accelerator="cpu", devices=devices) + _accelerator = connector._accelerator_flag + _devices = connector._devices_flag + elif _accelerator == "mps" and accelerator != "auto": warnings.warn( "`accelerator` has been set to `mps`. Please note that not all PyTorch " "operations are supported with this backend. Refer to " @@ -132,8 +132,8 @@ def parse_device_args( else: device_idx = _devices - # auto device should not use multiple devices for non-cpu accelerators if devices == "auto" and _accelerator != "cpu": + # auto device should not use multiple devices for non-cpu accelerators _devices = [device_idx] if return_device == "torch":