Skip to content

Commit

Permalink
extract: Fix device selection for plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Apr 19, 2024
1 parent d5c0dd8 commit ac3cda3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions plugins/extract/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import numpy as np
import torch
from keras.backend import device_scope
from keras import device

from lib.logger import parse_class_init
from lib.multithreading import MultiThread
Expand Down Expand Up @@ -424,7 +424,7 @@ def get_device_context(cls, cpu: bool) -> T.ContextManager:
"""
if cpu:
logger.debug("CPU mode selected. Returning CPU device context")
return device_scope("cpu")
return device("cpu")

# TODO apple_silicon + directml
if get_backend() == "apple_silicon":
Expand All @@ -434,10 +434,10 @@ def get_device_context(cls, cpu: bool) -> T.ContextManager:

if torch.cuda.is_available():
logger.debug("Cuda available. Returning Cuda device context")
return device_scope("cuda")
return device("cuda")

logger.debug("Cuda not available. Returning CPU device context")
return device_scope("cpu")
return device("cpu")

# <<< THREADING METHODS >>> #
def start(self) -> None:
Expand Down

0 comments on commit ac3cda3

Please sign in to comment.