Skip to content

Commit

Permalink
plugins.model - remove exclude_gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Mar 25, 2024
1 parent fb4ed01 commit be2e8c5
Showing 1 changed file with 4 additions and 31 deletions.
35 changes: 4 additions & 31 deletions plugins/train/model/_base/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import torch
import keras
from keras import backend as K, losses as k_losses
from keras import losses as k_losses
from keras.config import set_dtype_policy
from keras.dtype_policies import DTypePolicy
from keras.optimizers import LossScaleOptimizer
Expand Down Expand Up @@ -53,7 +53,8 @@ class LossClass:
kwargs: dict
Any keyword arguments to supply to the loss function at initialization.
"""
function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | T.Any = k_losses.MeanSquaredError
function: Callable[[torch.Tensor, torch.Tensor],
torch.Tensor] | T.Any = k_losses.MeanSquaredError
init: bool = True
kwargs: dict[str, T.Any] = field(default_factory=dict)

Expand Down Expand Up @@ -122,7 +123,7 @@ def _mask_shapes(self) -> list[tuple] | None:
``None`` if there is no mask input. """
if self._mask_inputs is None:
return None
return [K.int_shape(mask_input) for mask_input in self._mask_inputs]
return [mask_input.shape for mask_input in self._mask_inputs]

def configure(self, model: keras.models.Model) -> None:
""" Configure the loss functions for the given inputs and outputs.
Expand Down Expand Up @@ -366,8 +367,6 @@ def __init__(self,
is_predict: bool) -> None:
logger.debug("Initializing %s: (arguments: %s, mixed_precision: %s, is_predict: %s)",
self.__class__.__name__, arguments, mixed_precision, is_predict)
self._set_tf_settings(arguments.exclude_gpus)

use_mixed_precision = not is_predict and mixed_precision
self._use_mixed_precision = use_mixed_precision
if use_mixed_precision:
Expand Down Expand Up @@ -408,32 +407,6 @@ def loss_scale_optimizer(
"""
return LossScaleOptimizer(optimizer)

@classmethod
def _set_tf_settings(cls, exclude_devices: list[int]) -> None:
""" Specify Devices to place operations on and Allow TensorFlow to manage VRAM growth.
Parameters
----------
exclude_devices: list or ``None``
List of GPU device indices that should not be made available to Tensorflow. Pass
``None`` if all devices should be made available
"""
backend = get_backend()
if backend == "cpu":
logger.verbose("Hiding GPUs from Tensorflow") # type:ignore[attr-defined]
tf.config.set_visible_devices([], "GPU")
return

if not exclude_devices:
logger.debug("Not setting any specific Tensorflow settings")
return

gpus = tf.config.list_physical_devices('GPU')
if exclude_devices:
gpus = [gpu for idx, gpu in enumerate(gpus) if idx not in exclude_devices]
logger.debug("Filtering devices to: %s", gpus)
tf.config.set_visible_devices(gpus, "GPU")

@classmethod
def _set_keras_mixed_precision(cls, enable: bool) -> None:
""" Enable or disable Keras Mixed Precision.
Expand Down

0 comments on commit be2e8c5

Please sign in to comment.