Skip to content

Commit

Permalink
Passthrough auto device for cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Kim committed Aug 18, 2023
1 parent de59581 commit 58bfd20
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions scvi/model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,10 @@ def parse_device_args(
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

# auto accelerator should not default to mps
if accelerator == "auto" and _accelerator == "mps":
elif _accelerator == "mps" and accelerator == "auto":
# auto accelerator should not default to mps
_accelerator = "cpu"

if _accelerator == "mps":
elif _accelerator == "mps" and accelerator != "auto":
warnings.warn(
"`accelerator` has been set to `mps`. Please note that not all PyTorch "
"operations are supported with this backend. Refer to "
Expand All @@ -132,8 +130,11 @@ def parse_device_args(
else:
device_idx = _devices

# auto device should not use multiple devices for non-cpu accelerators
if devices == "auto" and _accelerator != "cpu":
if devices == "auto" and _accelerator == "cpu":
# passthrough auto device for cpu
_devices = devices
elif devices == "auto" and _accelerator != "cpu":
# auto device should not use multiple devices for non-cpu accelerators
_devices = [device_idx]

if return_device == "torch":
Expand Down

0 comments on commit 58bfd20

Please sign in to comment.