diff --git a/plugins/extract/_base.py b/plugins/extract/_base.py index 9ae4441407..a8c8f7e932 100644 --- a/plugins/extract/_base.py +++ b/plugins/extract/_base.py @@ -9,7 +9,7 @@ import numpy as np import torch -from keras.backend import device_scope +from keras import device from lib.logger import parse_class_init from lib.multithreading import MultiThread @@ -424,7 +424,7 @@ def get_device_context(cls, cpu: bool) -> T.ContextManager: """ if cpu: logger.debug("CPU mode selected. Returning CPU device context") - return device_scope("cpu") + return device("cpu") # TODO apple_silicon + directml if get_backend() == "apple_silicon": @@ -434,10 +434,10 @@ def get_device_context(cls, cpu: bool) -> T.ContextManager: if torch.cuda.is_available(): logger.debug("Cuda available. Returning Cuda device context") - return device_scope("cuda") + return device("cuda") logger.debug("Cuda not available. Returning CPU device context") - return device_scope("cpu") + return device("cpu") # <<< THREADING METHODS >>> # def start(self) -> None: