From 232c21d46754e194ae13bef30e32081ae4e7928e Mon Sep 17 00:00:00 2001
From: torzdf <36920800+torzdf@users.noreply.github.com>
Date: Tue, 2 Apr 2024 18:14:25 +0100
Subject: [PATCH] Port AdaBelief to Keras 3

---
 lib/model/optimizers.py | 401 +++++++++++++++++-----------------------
 1 file changed, 167 insertions(+), 234 deletions(-)

diff --git a/lib/model/optimizers.py b/lib/model/optimizers.py
index bf877433e8..107bbb2394 100644
--- a/lib/model/optimizers.py
+++ b/lib/model/optimizers.py
@@ -1,15 +1,25 @@
 #!/usr/bin/env python3
 """ Custom Optimizers for TensorFlow 2.x/keras """
-
+from __future__ import annotations
 import inspect
+import logging
 import sys
+import typing as T

-import keras
+from keras import ops
 from keras.saving import get_custom_objects
-from keras.optimizers import Adam, Nadam, RMSprop  # noqa:E501,F401 pylint:disable=unused-import
+from keras.optimizers import Adam, Optimizer, Nadam, RMSprop  # noqa:E501,F401 pylint:disable=unused-import
+
+from lib.logger import parse_class_init
+
+if T.TYPE_CHECKING:
+    from keras import KerasTensor
+    from keras.src.backend.common import KerasVariable
+
+logger = logging.getLogger(__name__)


-class AdaBelief(keras.optimizers.Optimizer):
+class AdaBelief(Optimizer):
     """ Implementation of the AdaBelief Optimizer

     Inherits from: keras.optimizers.Optimizer.
@@ -32,13 +42,11 @@ class AdaBelief(keras.optimizers.Optimizer):
         The exponential decay rate for the 2nd moment estimates.
     epsilon: float
         A small constant for numerical stability.
-    weight_decay: `Tensor`, float or :class: `keras.optimizers.schedules.LearningRateSchedule`
-        Weight decay for each parameter.
-    rectify: bool
-        Whether to enable rectification as in RectifiedAdam
     amsgrad: bool
         Whether to apply AMSGrad variant of this algorithm from the paper "On the Convergence
         of Adam and beyond".
+    rectify: bool
+        Whether to enable rectification as in RectifiedAdam
     sma_threshold. float
         The threshold for simple mean average.
     total_steps: int
@@ -50,11 +58,9 @@ class AdaBelief(keras.optimizers.Optimizer):
     name: str, optional
         Name for the operations created when applying gradients. Default: ``"AdaBeliefOptimizer"``.
     **kwargs: dict
-        Standard Keras Optimizer keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
-        `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip gradients by value,
-        `decay` is included for backward compatibility to allow time inverse decay of learning
-        rate. `lr` is included for backward compatibility, recommended to use `learning_rate`
-        instead.
+        Standard Keras Optimizer keyword arguments. Allowed to be (`weight_decay`, `clipnorm`,
+        `clipvalue`, `global_clipnorm`, `use_ema`, `ema_momentum`, `ema_overwrite_frequency`,
+        `loss_scale_factor`, `gradient_accumulation_steps`)

     Examples
     --------
@@ -123,275 +129,202 @@ class AdaBelief(keras.optimizers.Optimizer):
     OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
""" - def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-14, - weight_decay=0.0, rectify=True, amsgrad=False, sma_threshold=5.0, total_steps=0, - warmup_proportion=0.1, min_lr=0.0, name="AdaBeliefOptimizer", **kwargs): - # pylint:disable=too-many-arguments - super().__init__(name, **kwargs) - self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) - self._set_hyper("beta_1", beta_1) - self._set_hyper("beta_2", beta_2) - self._set_hyper("decay", self._initial_decay) - self._set_hyper("weight_decay", weight_decay) - self._set_hyper("sma_threshold", sma_threshold) - self._set_hyper("total_steps", int(total_steps)) - self._set_hyper("warmup_proportion", warmup_proportion) - self._set_hyper("min_lr", min_lr) - self.epsilon = epsilon or keras.backend.epsilon() + def __init__(self, # pylint:disable=too-many-arguments + learning_rate: float = 0.001, + beta_1: float = 0.9, + beta_2: float = 0.999, + epsilon: float = 1e-14, + amsgrad: bool = False, + rectify: bool = True, + sma_threshold: float = 5.0, + total_steps: int = 0, + warmup_proportion: float = 0.1, + min_learning_rate: float = 0.0, + name="AdaBeliefOptimizer", + **kwargs): + logger.debug(parse_class_init(locals())) + super().__init__(learning_rate=learning_rate, name=name, **kwargs) + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon self.amsgrad = amsgrad self.rectify = rectify - self._has_weight_decay = weight_decay != 0.0 - self._initial_total_steps = total_steps + self.sma_threshold = sma_threshold + # TODO change the following 2 to "warm_up_steps" + # TODO Make learning rate warm up a global option + # Or these params can be calculated from a user "warm_up_steps" parameter + self.total_steps = total_steps + self.warmup_proportion = warmup_proportion + self.min_learning_rate = min_learning_rate + logger.debug("Initialized %s", self.__class__.__name__) - def _create_slots(self, var_list): - """ Create slots for the first and second moments + self._momentums: list[KerasVariable] = [] + self._velocities: list[KerasVariable] = [] + self._velocity_hats: list[KerasVariable] = [] # Amsgrad only - Parameters - ---------- - var_list: list - List of tensorflow variables to create slots for - """ - for var in var_list: - self.add_slot(var, "m") - self.add_slot(var, "v") - if self.amsgrad: - self.add_slot(var, "vhat") + def build(self, variables: list[KerasVariable]) -> None: + """Initialize optimizer variables. - def set_weights(self, weights): - """ Set the weights of the optimizer. - - The weights of an optimizer are its state (IE, variables). This function takes the weight - values associated with this optimizer as a list of Numpy arrays. The first value is always - the iterations count of the optimizer, followed by the optimizers state variables in the - order they are created. The passed values are used to set the new state of the optimizer. + AdaBelief optimizer has 3 types of variables: momentums, velocities and + velocity_hat (only set when amsgrad is applied), Parameters ---------- - weights: list - weight values as a list of numpy arrays. + variables: list[:class:`keras.src.backend.common.KerasVariable`] + list of model variables to build AdaBelief variables on. """ - params = self.weights - num_vars = int((len(params) - 1) / 2) - if len(weights) == 3 * num_vars + 1: - weights = weights[: len(params)] - super().set_weights(weights) + if self.built: + return + logger.debug("Building AdaBelief. 
var_list: %s", variables) + super().build(variables) + + for var in variables: + self._momentums.append(self.add_variable_from_reference( + reference_variable=var, name="momentum")) + self._velocities.append(self.add_variable_from_reference( + reference_variable=var, name="velocity")) + if self.amsgrad: + self._velocity_hats.append(self.add_variable_from_reference( + reference_variable=var, name="velocity_hat")) + logger.debug("Built AdaBelief. momentums: %s, velocities: %s, velocity_hats: %s", + len(self._momentums), len(self._velocities), len(self._velocity_hats)) - def _decayed_wd(self, var_dtype): - """ Set the weight decay + def _maybe_warmup(self, learning_rate: KerasTensor, local_step: KerasTensor) -> KerasTensor: + """ Do learning rate warm up if requested Parameters ---------- - var_dtype: str - The data type to to set up weight decay for + learning_rate: :class:`keras.KerasTensor` + The learning rate + local_step: :class:`keras.KerasTensor` + The current training step Returns ------- - Tensor - The weight decay variable + :class:`keras.KerasTensor` + Either the original learning rate or adjusted learning rate if warmup is requested """ - wd_t = self._get_hyper("weight_decay", var_dtype) - if isinstance(wd_t, keras.optimizers.schedules.LearningRateSchedule): - wd_t = tf.cast(wd_t(self.iterations), var_dtype) - return wd_t - - def _resource_apply_dense(self, grad, handle, apply_state=None): - # pylint:disable=too-many-locals,unused-argument - """ Add ops to apply dense gradients to the variable handle. + if self.total_steps <= 0: + return learning_rate + + total_steps = ops.cast(self.total_steps, learning_rate.dtype) + warmup_steps = total_steps * ops.cast(self.warmup_proportion, learning_rate.dtype) + min_lr = ops.cast(self.min_learning_rate, learning_rate.dtype) + decay_steps = ops.maximum(total_steps - warmup_steps, 1) + decay_rate = ops.divide(min_lr - learning_rate, decay_steps) + return ops.where(local_step <= warmup_steps, + ops.multiply(learning_rate, (ops.divide(local_step, warmup_steps))), + ops.multiply(learning_rate + decay_rate, + ops.minimum(local_step - warmup_steps, decay_steps))) + + def _maybe_rectify(self, + momentum: KerasTensor, + velocity: KerasTensor, + local_step: KerasTensor, + beta_2_power: KerasTensor) -> KerasTensor: + """ Apply rectification, if requested Parameters ---------- - grad: Tensor - A tensor representing the gradient. - handle: Tensor - a Tensor of dtype resource which points to the variable to be updated. - apply_state: dict - A dict which is used across multiple apply calls. + momentum: :class:`keras.Tensor` + The momentum update + velocity: :class:`keras.Tensor` + The velocity update + local_step: :class:`keras.KerasTensor` + The current training step + beta_2_power + Adjusted exponential decay rate for the 2nd moment estimates. Returns ------- - An Operation which updates the value of the variable. 
+        :class:`keras.KerasTensor`
+            The standard or rectified update (if rectification enabled)
         """
-        var_dtype = handle.dtype.base_dtype
-        lr_t = self._decayed_lr(var_dtype)
-        wd_t = self._decayed_wd(var_dtype)
-        var_m = self.get_slot(handle, "m")
-        var_v = self.get_slot(handle, "v")
-        beta_1_t = self._get_hyper("beta_1", var_dtype)
-        beta_2_t = self._get_hyper("beta_2", var_dtype)
-        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
-        local_step = tf.cast(self.iterations + 1, var_dtype)
-        beta_1_power = tf.math.pow(beta_1_t, local_step)
-        beta_2_power = tf.math.pow(beta_2_t, local_step)
-
-        if self._initial_total_steps > 0:
-            total_steps = self._get_hyper("total_steps", var_dtype)
-            warmup_steps = total_steps * self._get_hyper("warmup_proportion", var_dtype)
-            min_lr = self._get_hyper("min_lr", var_dtype)
-            decay_steps = tf.maximum(total_steps - warmup_steps, 1)
-            decay_rate = (min_lr - lr_t) / decay_steps
-            lr_t = tf.where(local_step <= warmup_steps,
-                            lr_t * (local_step / warmup_steps),
-                            lr_t + decay_rate * tf.minimum(local_step - warmup_steps, decay_steps))
-
-        m_t = var_m.assign(beta_1_t * var_m + (1.0 - beta_1_t) * grad,
-                           use_locking=self._use_locking)
-        m_corr_t = m_t / (1.0 - beta_1_power)
-
-        v_t = var_v.assign(
-            beta_2_t * var_v + (1.0 - beta_2_t) * tf.math.square(grad - m_t) + epsilon_t,
-            use_locking=self._use_locking)
-
-        if self.amsgrad:
-            vhat = self.get_slot(handle, "vhat")
-            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
-            v_corr_t = tf.math.sqrt(vhat_t / (1.0 - beta_2_power))
-        else:
-            vhat_t = None
-            v_corr_t = tf.math.sqrt(v_t / (1.0 - beta_2_power))
-
-        if self.rectify:
-            sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
-            sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)
-            r_t = tf.math.sqrt((sma_t - 4.0) / (sma_inf - 4.0) *
-                               (sma_t - 2.0) / (sma_inf - 2.0) *
-                               sma_inf / sma_t)
-            sma_threshold = self._get_hyper("sma_threshold", var_dtype)
-            var_t = tf.where(sma_t >= sma_threshold,
-                             r_t * m_corr_t / (v_corr_t + epsilon_t),
-                             m_corr_t)
-        else:
-            var_t = m_corr_t / (v_corr_t + epsilon_t)
-
-        if self._has_weight_decay:
-            var_t += wd_t * handle
-
-        var_update = handle.assign_sub(lr_t * var_t, use_locking=self._use_locking)
-        updates = [var_update, m_t, v_t]
-
-        if self.amsgrad:
-            updates.append(vhat_t)
-        return tf.group(*updates)
-
-    def _resource_apply_sparse(self, grad, handle, indices, apply_state=None):
-        # pylint:disable=too-many-locals, unused-argument
-        """ Add ops to apply sparse gradients to the variable handle.
-
-        Similar to _apply_sparse, the indices argument to this method has been de-duplicated.
-        Optimizers which deal correctly with non-unique indices may instead override
-        :func:`_resource_apply_sparse_duplicate_indices` to avoid this overhead.
+        if not self.rectify:
+            return ops.divide(momentum, ops.add(velocity, self.epsilon))
+
+        sma_inf = 2 / (1 - self.beta_2) - 1
+        sma_t = sma_inf - 2 * local_step * beta_2_power / (1 - beta_2_power)
+        rect = ops.sqrt((sma_t - 4) / (sma_inf - 4) *
+                        (sma_t - 2) / (sma_inf - 2) *
+                        sma_inf / sma_t)
+        return ops.where(sma_t >= self.sma_threshold,
+                         ops.divide(
+                             ops.multiply(rect, momentum),
+                             (ops.add(velocity, self.epsilon))),
+                         momentum)
+
+    def update_step(self,
+                    gradient: KerasTensor,
+                    variable: KerasVariable,
+                    learning_rate: KerasVariable) -> None:
+        """Update step given gradient and the associated model variable for AdaBelief.

         Parameters
         ----------
-        grad: Tensor
-            a Tensor representing the gradient for the affected indices.
-        handle: Tensor
-            a Tensor of dtype resource which points to the variable to be updated.
-        indices: Tensor
-            a Tensor of integral type representing the indices for which the gradient is nonzero.
-            Indices are unique.
-        apply_state: dict
-            A dict which is used across multiple apply calls.
-
-        Returns
-        -------
-        An Operation which updates the value of the variable.
+        gradient: :class:`keras.KerasTensor`
+            The gradient used to update the variable
+        variable: :class:`keras.src.backend.common.KerasVariable`
+            The variable to update
+        learning_rate: :class:`keras.src.backend.common.KerasVariable`
+            The learning rate
         """
-        var_dtype = handle.dtype.base_dtype
-        lr_t = self._decayed_lr(var_dtype)
-        wd_t = self._decayed_wd(var_dtype)
-        beta_1_t = self._get_hyper("beta_1", var_dtype)
-        beta_2_t = self._get_hyper("beta_2", var_dtype)
-        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
-        local_step = tf.cast(self.iterations + 1, var_dtype)
-        beta_1_power = tf.math.pow(beta_1_t, local_step)
-        beta_2_power = tf.math.pow(beta_2_t, local_step)
-
-        if self._initial_total_steps > 0:
-            total_steps = self._get_hyper("total_steps", var_dtype)
-            warmup_steps = total_steps * self._get_hyper("warmup_proportion", var_dtype)
-            min_lr = self._get_hyper("min_lr", var_dtype)
-            decay_steps = tf.maximum(total_steps - warmup_steps, 1)
-            decay_rate = (min_lr - lr_t) / decay_steps
-            lr_t = tf.where(local_step <= warmup_steps,
-                            lr_t * (local_step / warmup_steps),
-                            lr_t + decay_rate * tf.minimum(local_step - warmup_steps, decay_steps))
-
-        var_m = self.get_slot(handle, "m")
-        m_scaled_g_values = grad * (1 - beta_1_t)
-        m_t = var_m.assign(var_m * beta_1_t, use_locking=self._use_locking)
-        m_t = self._resource_scatter_add(var_m, indices, m_scaled_g_values)
-        m_corr_t = m_t / (1.0 - beta_1_power)
-
-        var_v = self.get_slot(handle, "v")
-        m_t_indices = tf.gather(m_t, indices)  # pylint:disable=no-value-for-parameter
-        v_scaled_g_values = tf.math.square(grad - m_t_indices) * (1 - beta_2_t)
-        v_t = var_v.assign(var_v * beta_2_t + epsilon_t, use_locking=self._use_locking)
-        v_t = self._resource_scatter_add(var_v, indices, v_scaled_g_values)
+        local_step = ops.cast(self.iterations + 1, variable.dtype)
+        learning_rate = self._maybe_warmup(ops.cast(learning_rate, variable.dtype), local_step)
+        gradient = ops.cast(gradient, variable.dtype)
+        beta_1_power = ops.power(ops.cast(self.beta_1, variable.dtype), local_step)
+        beta_2_power = ops.power(ops.cast(self.beta_2, variable.dtype), local_step)
+
+        # m_t = b1 * m + (1 - b1) * g
+        # => m_t = m + (g - m) * (1 - b1)
+        momentum = self._momentums[self._get_variable_index(variable)]
+        self.assign_add(momentum, ops.multiply(ops.subtract(gradient, momentum), 1 - self.beta_1))
+        momentum_corr = ops.divide(momentum, (1 - beta_1_power))
+
+        # v_t = b2 * v + (1 - b2) * (g - m_t)^2 + e
+        # => v_t = v + ((g - m_t)^2 - v) * (1 - b2) + e
+        velocity = self._velocities[self._get_variable_index(variable)]
+        self.assign_add(velocity,
+                        ops.multiply(
+                            ops.subtract(ops.square(gradient - momentum), velocity),
+                            1 - self.beta_2) +
+                        self.epsilon)

         if self.amsgrad:
-            vhat = self.get_slot(handle, "vhat")
-            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
-            v_corr_t = tf.math.sqrt(vhat_t / (1.0 - beta_2_power))
+            velocity_hat = self._velocity_hats[self._get_variable_index(variable)]
+            self.assign(velocity_hat, ops.maximum(velocity, velocity_hat))
+            velocity_corr = ops.sqrt(ops.divide(velocity_hat, (1 - beta_2_power)))
         else:
-            vhat_t = None
-            v_corr_t = tf.math.sqrt(v_t / (1.0 - beta_2_power))
-
-        if self.rectify:
-            sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
-            sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)
-            r_t = tf.math.sqrt((sma_t - 4.0) / (sma_inf - 4.0) *
-                               (sma_t - 2.0) / (sma_inf - 2.0) *
-                               sma_inf / sma_t)
-            sma_threshold = self._get_hyper("sma_threshold", var_dtype)
-            var_t = tf.where(sma_t >= sma_threshold,
-                             r_t * m_corr_t / (v_corr_t + epsilon_t),
-                             m_corr_t)
-        else:
-            var_t = m_corr_t / (v_corr_t + epsilon_t)
-
-        if self._has_weight_decay:
-            var_t += wd_t * handle
+            velocity_corr = ops.sqrt(ops.divide(velocity, (1 - beta_2_power)))

-        var_update = self._resource_scatter_add(handle,
-                                                indices,
-                                                tf.gather(  # pylint:disable=no-value-for-parameter
-                                                    tf.math.negative(lr_t) * var_t,
-                                                    indices))
+        var_t = self._maybe_rectify(momentum_corr, velocity_corr, local_step, beta_2_power)

-        updates = [var_update, m_t, v_t]
-        if self.amsgrad:
-            updates.append(vhat_t)
-        return tf.group(*updates)
+        self.assign_sub(variable, ops.multiply(learning_rate, var_t))

-    def get_config(self):
+    def get_config(self) -> dict[str, T.Any]:
         """ Returns the config of the optimizer.

-        An optimizer config is a Python dictionary (serializable) containing the configuration of
-        an optimizer. The same optimizer can be re-instantiated later (without any saved state)
-        from this configuration.
+        Optimizer configuration for AdaBelief.

         Returns
         -------
-        dict
+        dict[str, Any]
             The optimizer configuration.
         """
         config = super().get_config()
-        config.update({"learning_rate": self._serialize_hyperparameter("learning_rate"),
-                       "beta_1": self._serialize_hyperparameter("beta_1"),
-                       "beta_2": self._serialize_hyperparameter("beta_2"),
-                       "decay": self._serialize_hyperparameter("decay"),
-                       "weight_decay": self._serialize_hyperparameter("weight_decay"),
-                       "sma_threshold": self._serialize_hyperparameter("sma_threshold"),
+        config.update({"beta_1": self.beta_1,
+                       "beta_2": self.beta_2,
                        "epsilon": self.epsilon,
                        "amsgrad": self.amsgrad,
                        "rectify": self.rectify,
-                       "total_steps": self._serialize_hyperparameter("total_steps"),
-                       "warmup_proportion": self._serialize_hyperparameter("warmup_proportion"),
-                       "min_lr": self._serialize_hyperparameter("min_lr")})
+                       "sma_threshold": self.sma_threshold,
+                       "total_steps": self.total_steps,
+                       "warmup_proportion": self.warmup_proportion,
+                       "min_learning_rate": self.min_learning_rate})
         return config


-# Update layers into Keras custom objects
+# Update Optimizers into Keras custom objects
 for _name, obj in inspect.getmembers(sys.modules[__name__]):
     if inspect.isclass(obj) and obj.__module__ == __name__:
         get_custom_objects().update({_name: obj})
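
Reviewer note (not part of the patch): a minimal usage sketch of the ported optimizer follows. It assumes the patched lib/model/optimizers.py is importable from the faceswap repository root; the throwaway model, random data, and hyperparameter values are illustrative only, chosen to exercise build(), update_step() with warmup enabled, and the get_config() serialisation path.

    import numpy as np
    from keras import layers, models

    from lib.model.optimizers import AdaBelief  # module changed by this patch

    # Tiny stand-in model; only here so the optimizer has variables to build against
    model = models.Sequential([layers.Input(shape=(4,)), layers.Dense(1)])

    # total_steps > 0 enables the linear warmup/decay branch in _maybe_warmup
    optimizer = AdaBelief(learning_rate=1e-3,
                          rectify=True,
                          total_steps=1000,
                          warmup_proportion=0.1,
                          min_learning_rate=1e-6)

    model.compile(optimizer=optimizer, loss="mse")
    model.fit(np.random.rand(32, 4), np.random.rand(32, 1), epochs=1, verbose=0)

    # Round-trip through the serialisation keys emitted by get_config()
    restored = AdaBelief.from_config(optimizer.get_config())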