Skip to content

Commit

Permalink
extract: Fix keras.device to call keras.backend.device_scope
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Apr 16, 2024
1 parent 61bd910 commit 2328c5e
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions plugins/extract/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import typing as T
from dataclasses import dataclass, field

import keras
import numpy as np
import torch
from keras.backend import device_scope

from lib.logger import parse_class_init
from lib.multithreading import MultiThread
Expand Down Expand Up @@ -424,7 +424,7 @@ def get_device_context(cls, cpu: bool) -> T.ContextManager:
"""
if cpu:
logger.debug("CPU mode selected. Returning CPU device context")
return keras.device("cpu")
return device_scope("cpu")

# TODO apple_silicon + directml
if get_backend() == "apple_silicon":
Expand All @@ -434,10 +434,10 @@ def get_device_context(cls, cpu: bool) -> T.ContextManager:

if torch.cuda.is_available():
logger.debug("Cuda available. Returning Cuda device context")
return keras.device("cuda")
return device_scope("cuda")

logger.debug("Cuda not available. Returning CPU device context")
return keras.device("cpu")
return device_scope("cpu")

# <<< THREADING METHODS >>> #
def start(self) -> None:
Expand Down

0 comments on commit 2328c5e

Please sign in to comment.