Skip to content

Commit

Permalink
Enable CPU mode for compatible extractors
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Mar 26, 2024
1 parent 8c79afd commit ed0fe85
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 9 deletions.
36 changes: 35 additions & 1 deletion plugins/extract/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
from __future__ import annotations
import logging
import typing as T

from dataclasses import dataclass, field

import keras
import numpy as np
import torch

from lib.multithreading import MultiThread
from lib.queue_manager import queue_manager
from lib.utils import GetModel
from lib.utils import get_backend
from ._config import Config
from .pipeline import ExtractMedia

Expand Down Expand Up @@ -360,6 +362,38 @@ def get_batch(self, queue: Queue) -> tuple[bool, BatchType]:
"""
raise NotImplementedError

@classmethod
def get_device_context(cls, cpu: bool) -> T.ContextManager:
    """ Obtain the Keras device scope that inference should be executed within

    Parameters
    ----------
    cpu: bool
        ``True`` to force ops onto the CPU. ``False`` to use Cuda when PyTorch reports it as
        available, falling back to the CPU when it does not

    Returns
    -------
    ContextManager
        The ``keras.device`` context manager for the selected device
    """
    if cpu:
        logger.debug("CPU mode selected. Returning CPU device context")
        device_name = "cpu"
    else:
        # TODO apple_silicon + directml
        # Placeholders only: neither backend selects a dedicated device yet, so both fall
        # through to the Cuda availability check below
        if get_backend() == "apple_silicon":
            pass
        if get_backend() == "directml":
            pass

        if torch.cuda.is_available():
            logger.debug("Cuda available. Returning Cuda device context")
            device_name = "cuda"
        else:
            logger.debug("Cuda not available. Returning CPU device context")
            device_name = "cpu"

    return keras.device(device_name)

# <<< THREADING METHODS >>> #
def start(self) -> None:
""" Start all threads
Expand Down
10 changes: 8 additions & 2 deletions plugins/extract/detect/mtcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ def _validate_kwargs(self) -> dict[str, int | float | list[float]]:
def init_model(self) -> None:
    """ Build the MTCNN model inside the configured device context and run face detection
    once on an all-zero placeholder batch. """
    assert isinstance(self.model_path, list)
    # Placeholder batch of shape (batchsize, input_size, input_size, 3)
    dummy_batch = np.zeros((self.batchsize, self.input_size, self.input_size, 3),
                           dtype="float32")

    # The "cpu" config option decides which device scope the model is created and run in
    use_cpu = self.config["cpu"]
    with self.get_device_context(use_cpu):
        self.model = MTCNN(self.model_path, self.batchsize, **self.kwargs)
        self.model.detect_faces(dummy_batch)

def process_input(self, batch: BatchType) -> None:
""" Compile the detection image(s) for prediction
Expand All @@ -85,7 +90,8 @@ def predict(self, feed: np.ndarray) -> np.ndarray:
The batch with the predictions added to the dictionary
"""
assert isinstance(self.model, MTCNN)
prediction, points = self.model.detect_faces(feed)
with self.get_device_context(self.config["cpu"]):
prediction, points = self.model.detect_faces(feed)
logger.trace("prediction: %s, mtcnn_points: %s", # type:ignore
prediction, points)
return prediction
Expand Down
3 changes: 3 additions & 0 deletions plugins/extract/detect/s3fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def init_model(self) -> None:
assert isinstance(self.model_path, str)
confidence = self.config["confidence"] / 100
self.model = S3fd(self.model_path, self.batchsize, confidence)
placeholder_shape = (self.batchsize, self.input_size, self.input_size, 3)
placeholder = np.zeros(placeholder_shape, dtype="float32")
self.model(placeholder)

def process_input(self, batch: BatchType) -> None:
""" Compile the detection image(s) for prediction """
Expand Down
9 changes: 6 additions & 3 deletions plugins/extract/mask/bisenet_fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,12 @@ def init_model(self) -> None:
""" Initialize the BiSeNet Face Parsing model. """
assert isinstance(self.model_path, str)
lbls = 5 if self._is_faceswap else 19
self.model = BiSeNet(self.model_path, self.batchsize, self.input_size, lbls)
placeholder = np.zeros((self.batchsize, self.input_size, self.input_size, 3),
dtype="float32")
self.model(placeholder)

with self.get_device_context(self.config["cpu"]):
self.model = BiSeNet(self.model_path, self.batchsize, self.input_size, lbls)
self.model(placeholder)

def process_input(self, batch: BatchType) -> None:
""" Compile the detected faces for prediction """
Expand All @@ -124,7 +126,8 @@ def process_input(self, batch: BatchType) -> None:

def predict(self, feed: np.ndarray) -> np.ndarray:
    """ Run the model on the batch inside the configured device context

    Parameters
    ----------
    feed: :class:`numpy.ndarray`
        The batch to pass through the model

    Returns
    -------
    :class:`numpy.ndarray`
        The first item of the model's output
    """
    device_context = self.get_device_context(self.config["cpu"])
    with device_context:
        outputs = self.model(feed)
        # Only the first output of the model is used
        return outputs[0]

def process_output(self, batch: BatchType) -> None:
""" Compile found faces for output """
Expand Down
9 changes: 6 additions & 3 deletions plugins/extract/recognition/vgg_face2.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,12 @@ def __init__(self, *args, **kwargs) -> None: # pylint:disable=unused-argument
def init_model(self) -> None:
    """ Build the VGG Face 2 model inside the configured device context and run it once on an
    all-zero placeholder batch. """
    assert isinstance(self.model_path, str)
    # Placeholder batch of shape (batchsize, input_size, input_size, 3)
    batch_shape = (self.batchsize, self.input_size, self.input_size, 3)
    dummy_batch = np.zeros(batch_shape, dtype="float32")

    # The "cpu" config option decides which device scope the model is created and run in
    use_cpu = self.config["cpu"]
    with self.get_device_context(use_cpu):
        self.model = VGGFace2(self.input_size, self.model_path, self.batchsize)
        self.model(dummy_batch)

def process_input(self, batch: BatchType) -> None:
""" Compile the detected faces for prediction """
Expand All @@ -90,7 +92,8 @@ def predict(self, feed: np.ndarray) -> np.ndarray:
numpy.ndarray
The encodings for the face
"""
retval = self.model(feed)
with self.get_device_context(self.config["cpu"]):
retval = self.model(feed)
assert isinstance(retval, np.ndarray)
return retval

Expand Down

0 comments on commit ed0fe85

Please sign in to comment.