Skip to content

Commit

Permalink
Add Tensorflow support for variable scatter_update in optimizers.
Browse files Browse the repository at this point in the history
This should have been added along with `scatter_add` and `scatter_sub` as part of #18692 as it was added in the super class: https://github.com/keras-team/keras/blob/master/keras/optimizers/base_optimizer.py#L222

Fixes #19281
  • Loading branch information
hertschuh committed Mar 14, 2024
1 parent 818c9fa commit bb9551b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
9 changes: 9 additions & 0 deletions keras/backend/tensorflow/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ def stateless_apply(self, optimizer_variables, grads, trainable_variables):
"(as it is incompatible with tf.distribute)."
)

def assign(self, variable, value):
if isinstance(variable, KerasVariable):
variable = variable.value
value = tf.cast(value, variable.dtype)
if isinstance(value, tf.IndexedSlices):
variable.scatter_update(value)
else:
variable.assign(value)

def assign_add(self, variable, value):
if isinstance(variable, KerasVariable):
variable = variable.value
Expand Down
26 changes: 26 additions & 0 deletions keras/optimizers/optimizer_sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,26 @@
from keras import optimizers
from keras import testing


class ScatterUpdateOptimizer(optimizers.Optimizer):
def __init__(self):
super().__init__(learning_rate=0.001)

def build(self, variables):
if self.built:
return
super().build(variables)
self.momentums = [
self.add_variable_from_reference(v, name="momentum")
for v in variables
]

def update_step(self, grad, variable, learning_rate):
momentum = self.momentums[self._get_variable_index(variable)]
self.assign(momentum, ops.cast(grad, momentum.dtype))
self.assign(variable, ops.cast(grad, variable.dtype))


TEST_CASES = [
{
"testcase_name": "adadelta",
Expand Down Expand Up @@ -97,6 +117,12 @@
"optimizer_class": optimizers.SGD,
"init_kwargs": {"momentum": 0.05, "nesterov": True},
},
{
"testcase_name": "scatter_update",
"optimizer_class": ScatterUpdateOptimizer,
"expect_model_sparse_variable_updates": True,
"expect_optimizer_sparse_variable_updates": True,
},
]


Expand Down

0 comments on commit bb9551b

Please sign in to comment.