Perform gradient clipping on global batch when using gradient accumulation #9
base: main
Changes from 5 commits
f6e6618
f57f77a
b69771e
fd49b3e
57e567c
2303ad8
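For context on what the PR title means in practice: with gradient accumulation, each micro-batch contributes only part of the global-batch gradient, so clipping each micro-batch gradient by its own norm generally gives a different update than clipping the accumulated (global-batch) gradient once. A minimal numerical sketch of that difference in plain JAX (not paxml code; the threshold, names, and shapes are made up):

```python
import jax
import jax.numpy as jnp

def clip_by_global_norm(grads, max_norm):
  # Scale the whole gradient pytree so its global L2 norm is at most max_norm.
  leaves = jax.tree_util.tree_leaves(grads)
  norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
  scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))
  return jax.tree_map(lambda g: g * scale, grads)

# Two micro-batch gradients whose sum is the global-batch gradient.
g1 = {'w': jnp.array([3.0, 0.0])}
g2 = {'w': jnp.array([-3.0, 4.0])}
global_grad = jax.tree_map(lambda a, b: a + b, g1, g2)

# Clipping the accumulated gradient (the behavior the PR title describes) ...
clipped_global = clip_by_global_norm(global_grad, max_norm=1.0)

# ... differs from summing per-micro-batch clipped gradients.
clipped_per_micro = jax.tree_map(
    lambda a, b: a + b,
    clip_by_global_norm(g1, 1.0),
    clip_by_global_norm(g2, 1.0),
)

print(clipped_global['w'])     # ~[0.0, 1.0]: direction of the summed gradient
print(clipped_per_micro['w'])  # ~[0.4, 0.8]: a different vector
```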
@@ -185,9 +185,43 @@ def get_grad_tx(
        self._hparams.repeat_prefix_sep,
    )

  def get_individual_grad_norms(
      self,
      raw_grads,
      optimizer_name):
    p = self._hparams
    # Compute gradient norm.

    if p.grad_norm_individual_vars:
      grad_norms = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), raw_grads)
      var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms)

      def add_grad_norm_summary(key, value):
        base_layer.add_global_summary(
            f'per_var_grad_norm/{optimizer_name}{key}',
            value,
            SummaryType.AGGREGATE_SCALAR,
        )

      jax.tree_map(add_grad_norm_summary, var_keys, grad_norms)

  def keep_step(
      self,
      grad_norm):
    p = self._hparams
    keep_threshold = p.skip_step_gradient_norm_value
    if keep_threshold:
      return jnp.logical_and(
          jnp.all(jnp.isfinite(grad_norm)),
          jnp.all(jnp.less(grad_norm, keep_threshold)),
      )
    else:
      return jnp.all(jnp.isfinite(grad_norm))

  def scale_gradients(
      self,
      raw_grads: NestedMap,
      raw_grad_norm: JTensor,
      optimizer_name: Optional[str] = None,
      clip_gradient_norm_to_value: Optional[float] = None,
      clip_gradient_single_norm_to_value: Optional[float] = None,
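As an aside, the step-skipping criterion carried over into the new `keep_step` method can be exercised standalone. A small sketch in free-function form (threshold value made up):

```python
import jax.numpy as jnp

def keep_step(grad_norm, keep_threshold=None):
  # Keep the step only if the gradient norm is finite and, when a threshold
  # is configured, strictly below that threshold.
  if keep_threshold:
    return jnp.logical_and(
        jnp.all(jnp.isfinite(grad_norm)),
        jnp.all(jnp.less(grad_norm, keep_threshold)),
    )
  return jnp.all(jnp.isfinite(grad_norm))

print(keep_step(jnp.array(3.2)))        # True: finite norm, no threshold
print(keep_step(jnp.array(jnp.inf)))    # False: non-finite norm
print(keep_step(jnp.array(3.2), 1.0))   # False: norm exceeds the threshold
```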
@@ -209,57 +243,20 @@ def scale_gradients(
      have anomaly detected (e.g. Nan or Inf, or excessively big gradient norm)
      and should not be skipped.
    """

    p = self._hparams

    if optimizer_name is None:
      optimizer_name = ''
    else:
      optimizer_name = optimizer_name + '/'

    if clip_gradient_norm_to_value is None:
      clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value
    if clip_gradient_single_norm_to_value is None:
      clip_gradient_single_norm_to_value = (
          p.optimizer.clip_gradient_single_norm_to_value
      )
    # Compute gradient norm.

    if p.grad_norm_individual_vars:
      grad_norms = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), raw_grads)
      var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms)

      def add_grad_norm_summary(key, value):
        base_layer.add_global_summary(
            f'per_var_grad_norm/{optimizer_name}{key}',
            value,
            SummaryType.AGGREGATE_SCALAR,
        )

      jax.tree_map(add_grad_norm_summary, var_keys, grad_norms)

    if (
        p.grad_norm_summary
        or p.check_valid_step
        or clip_gradient_norm_to_value
        or clip_gradient_single_norm_to_value
    ):
      raw_grad_norm = _compute_grad_norm(raw_grads)
      if p.grad_norm_summary:
        base_layer.add_global_summary(
            'learning/' + optimizer_name + 'raw_grad_norm',
            raw_grad_norm,
            SummaryType.AGGREGATE_SCALAR,
        )
    else:
      raw_grad_norm = None

    def keep_step(grad_norm):
      keep_threshold = p.skip_step_gradient_norm_value
      if keep_threshold:
        return jnp.logical_and(
            jnp.all(jnp.isfinite(grad_norm)),
            jnp.all(jnp.less(grad_norm, keep_threshold)),
        )
      else:
        return jnp.all(jnp.isfinite(grad_norm))

    def clip_grads(grads, grad_norm):
      if clip_gradient_norm_to_value:
@@ -288,17 +285,6 @@ def scale_gradient(grad, norm):
        grad_scale = jnp.array(1.0)
      return grads, grad_scale

    if p.check_valid_step:
      # Mark the step as invalid if any gradient anomaly is detected (e.g. Nan
      # or Inf, or excessively big gradient norm).
      valid_step = keep_step(raw_grad_norm)
      base_layer.add_global_summary(
          'learning/' + optimizer_name + 'is_valid_step',
          valid_step.astype(jnp.float32),
          SummaryType.AGGREGATE_SCALAR,
      )
    else:
      valid_step = True
    grads, grad_scale = clip_grads(raw_grads, raw_grad_norm)
    base_layer.add_global_summary(
        'learning/' + optimizer_name + 'grad_scale',
@@ -313,7 +299,55 @@ def scale_gradient(grad, norm):
        clipped_grad_norm,
        SummaryType.AGGREGATE_SCALAR,
    )
    return grads, valid_step  # pytype: disable=bad-return-type # jax-ndarray
    return grads  # pytype: disable=bad-return-type # jax-ndarray

  def get_grad_norm_valid_step(
      self,
      raw_grads,
      optimizer_name: Optional[str] = None,
      clip_gradient_norm_to_value: Optional[float] = None,
      clip_gradient_single_norm_to_value: Optional[float] = None
  ) -> Tuple[JTensor, JTensor]:

    p = self._hparams

    if optimizer_name is None:
      optimizer_name = ''
    else:
      optimizer_name = optimizer_name + '/'
    self.get_individual_grad_norms(raw_grads,
                                   optimizer_name)
Reviewer comment: nit: let's not line break here.
Reviewer comment: actually, can we move get_individual_grad_norms back inline? It's not used anywhere else, and it seems more consistent with the inlined global grad norm below.
    if (
        p.grad_norm_summary
        or p.check_valid_step
        or clip_gradient_norm_to_value
        or clip_gradient_single_norm_to_value
    ):
      raw_grad_norm = _compute_grad_norm(raw_grads)
      if p.grad_norm_summary:
        base_layer.add_global_summary(
            'learning/' + optimizer_name + 'raw_grad_norm',
            raw_grad_norm,
            SummaryType.AGGREGATE_SCALAR,
        )
    else:
      raw_grad_norm = None

    if p.check_valid_step:
      # Mark the step as invalid if any gradient anomaly is detected (e.g. Nan
      # or Inf, or excessively big gradient norm).
      valid_step = self.keep_step(raw_grad_norm)
Reviewer comment: let's move [...]; the original code is a bit complicated, so let's avoid refactoring too much because it might make it harder to spot whether the existing logic still holds.
      base_layer.add_global_summary(
          'learning/' + optimizer_name + 'is_valid_step',
          valid_step.astype(jnp.float32),
          SummaryType.AGGREGATE_SCALAR,
      )
    else:
      valid_step = True

    return raw_grad_norm, valid_step

  def update_states(
      self,
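`_compute_grad_norm` itself is not shown in this diff; for readers following the norm computations above, a global (all-leaves) gradient norm over a pytree is conceptually just the following standalone sketch (not paxml's actual helper):

```python
import jax
import jax.numpy as jnp

def global_grad_norm(grads):
  # L2 norm taken jointly over every leaf of the gradient pytree.
  leaves = jax.tree_util.tree_leaves(grads)
  return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))

grads = {'w': jnp.full((2, 2), 3.0), 'b': jnp.zeros((2,))}
print(global_grad_norm(grads))  # sqrt(4 * 3^2) = 6.0
```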
@@ -335,7 +369,15 @@ def update_states(
    """
    p = self._hparams

    grads, valid_step = self.scale_gradients(grads)
    grad_norm, valid_step = self.get_grad_norm_valid_step(grads)

    using_ga = hasattr(p.optimizer, 'num_sub_batches')
Reviewer comment: nit: let's use a more descriptive name; most readers might not know what `ga` means.
    # When using gradient accumulation, gradient scaling happens within base
    # optimizer update
    if not using_ga:
      grads = self.scale_gradients(grads, grad_norm)

    transformed_grad, new_states = self.get_grad_tx(var_weight_hparams).update(
        grads, states, old_vars
    )
@@ -357,6 +399,7 @@ def _update(updated, original):
    new_states = jax.tree_map(
        _update, new_states, states, is_leaf=py_utils.is_optax_masked_node
    )

    # Final applied grad norm.
    if p.grad_norm_summary:
      applied_grad_norm = _compute_grad_norm(transformed_grad)
@@ -588,8 +631,16 @@ def scale_gradients_by_optimizer(
  ) -> Tuple[NestedMap, JTensor]:
    optimizer_mask, default_mask = self.get_masks(var_weight_hparams)

    all_grads, all_valid_step = self.scale_gradients(
        jax.tree_map(lambda x, y: x * y, raw_grads, default_mask),
    raw_grads = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask)
Reviewer comment: let's not reuse `raw_grads`; I would not overwrite it here, so that inside the auxiliary_optimizers loop `raw_grads` still refers to the original gradients.
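A minimal sketch of what the reviewer seems to be suggesting, binding the masked gradients to a new, hypothetical name instead of overwriting `raw_grads` (toy pytrees for illustration only):

```python
import jax
import jax.numpy as jnp

# Toy stand-ins for the real gradient and mask pytrees.
raw_grads = {'w': jnp.ones((2, 2)), 'b': jnp.ones((2,))}
default_mask = {'w': jnp.array(1.0), 'b': jnp.array(0.0)}

# Bind the masked result to a new name (hypothetical) so the auxiliary-optimizer
# loop below keeps seeing the original, unmasked `raw_grads`.
masked_main_grads = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask)
```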
    grad_norm, all_valid_step = self.get_grad_norm_valid_step(
        raw_grads,
        optimizer_name='main',
    )

    all_grads = self.scale_gradients(
        raw_grads,
        grad_norm,
        optimizer_name='main',
    )
@@ -600,8 +651,13 @@ def scale_gradients_by_optimizer(
    ):
      assert optimizer.clip_gradient_norm_to_value is not None
      assert optimizer.clip_gradient_single_norm_to_value is not None
      grads, valid_step = self.scale_gradients(
      grad_norm, valid_step = self.get_grad_norm_valid_step(
          raw_grads,
          optimizer_name=name,
      )
      grads = self.scale_gradients(
          jax.tree_map(lambda x, y: x * y, raw_grads, mask),
          grad_norm,
          optimizer_name=name,
          clip_gradient_norm_to_value=optimizer.clip_gradient_norm_to_value,
          clip_gradient_single_norm_to_value=optimizer.clip_gradient_single_norm_to_value,
@@ -633,7 +689,8 @@ def update_states(
          grads, var_weight_hparams
      )
    else:
      grads, valid_step = self.scale_gradients(grads)
      grad_norm, valid_step = self.get_grad_norm_valid_step(grads)
      grads = self.scale_gradients(grads, grad_norm)
    grad_tx = self.get_grad_tx(var_weight_hparams)
    transformed_grad, new_states = grad_tx.update(grads, states, old_vars)
    if self._hparams.enable_skip_step_on_gradient_anomalies:
Reviewer comment: I think you are missing the following code block from the original scale_gradient?