Perform gradient clipping on global batch when using gradient accumulation #9

Open
wants to merge 6 commits into base: main
Changes from 5 commits
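For context, gradient accumulation splits a global batch into sub-batches and combines their gradients before the optimizer update; this change applies gradient clipping to that accumulated global-batch gradient rather than to each sub-batch gradient. Below is a minimal standalone sketch of that ordering; the names (grad_fn, sub_batches, clip_norm, accumulate_then_clip, global_grad_norm) are hypothetical and this is not paxml code.

    import jax
    import jax.numpy as jnp

    def global_grad_norm(grads):
      # L2 norm across all leaves of the gradient pytree.
      leaves = jax.tree_util.tree_leaves(grads)
      return jnp.sqrt(sum(jnp.sum(g * g) for g in leaves))

    def accumulate_then_clip(grad_fn, params, sub_batches, clip_norm=1.0):
      # Accumulate per-sub-batch gradients first (gradient accumulation).
      accum = None
      for batch in sub_batches:
        g = grad_fn(params, batch)
        accum = g if accum is None else jax.tree_util.tree_map(jnp.add, accum, g)
      # Average over sub-batches to form the global-batch gradient.
      accum = jax.tree_util.tree_map(lambda x: x / len(sub_batches), accum)
      # Clip once, on the global-batch gradient, not once per sub-batch.
      norm = global_grad_norm(accum)
      scale = jnp.minimum(1.0, clip_norm / (norm + 1e-6))
      return jax.tree_util.tree_map(lambda x: x * scale, accum), norm

In the diff below, this shows up as splitting the old scale_gradients into get_grad_norm_valid_step (norm computation and validity check) plus scale_gradients (the actual scaling), with the scaling skipped in update_states when the optimizer accumulates over num_sub_batches.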
171 changes: 114 additions & 57 deletions paxml/learners.py
@@ -185,9 +185,43 @@ def get_grad_tx(
self._hparams.repeat_prefix_sep,
)

def get_individual_grad_norms(
self,
raw_grads,
optimizer_name):
p = self._hparams
# Compute gradient norm.

if p.grad_norm_individual_vars:
grad_norms = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), raw_grads)
var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms)

def add_grad_norm_summary(key, value):
base_layer.add_global_summary(
f'per_var_grad_norm/{optimizer_name}{key}',
value,
SummaryType.AGGREGATE_SCALAR,
)

jax.tree_map(add_grad_norm_summary, var_keys, grad_norms)

def keep_step(
self,
grad_norm):
p = self._hparams
keep_threshold = p.skip_step_gradient_norm_value
if keep_threshold:
return jnp.logical_and(
jnp.all(jnp.isfinite(grad_norm)),
jnp.all(jnp.less(grad_norm, keep_threshold)),
)
else:
return jnp.all(jnp.isfinite(grad_norm))

def scale_gradients(
self,
raw_grads: NestedMap,
raw_grad_norm: JTensor,
optimizer_name: Optional[str] = None,
clip_gradient_norm_to_value: Optional[float] = None,
clip_gradient_single_norm_to_value: Optional[float] = None,
@@ -209,57 +243,20 @@ def scale_gradients(
have anomaly detected (e.g. Nan or Inf, or excessively big gradient norm)
and should not be skipped.
"""

p = self._hparams

if optimizer_name is None:
optimizer_name = ''
else:
optimizer_name = optimizer_name + '/'

if clip_gradient_norm_to_value is None:
clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value
if clip_gradient_single_norm_to_value is None:
clip_gradient_single_norm_to_value = (
p.optimizer.clip_gradient_single_norm_to_value
)
# Compute gradient norm.

if p.grad_norm_individual_vars:
grad_norms = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), raw_grads)
var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms)

def add_grad_norm_summary(key, value):
base_layer.add_global_summary(
f'per_var_grad_norm/{optimizer_name}{key}',
value,
SummaryType.AGGREGATE_SCALAR,
)

jax.tree_map(add_grad_norm_summary, var_keys, grad_norms)

if (
p.grad_norm_summary
or p.check_valid_step
or clip_gradient_norm_to_value
or clip_gradient_single_norm_to_value
):
raw_grad_norm = _compute_grad_norm(raw_grads)
if p.grad_norm_summary:
base_layer.add_global_summary(
'learning/' + optimizer_name + 'raw_grad_norm',
raw_grad_norm,
SummaryType.AGGREGATE_SCALAR,
)
else:
raw_grad_norm = None

def keep_step(grad_norm):
keep_threshold = p.skip_step_gradient_norm_value
if keep_threshold:
return jnp.logical_and(
jnp.all(jnp.isfinite(grad_norm)),
jnp.all(jnp.less(grad_norm, keep_threshold)),
)
else:
return jnp.all(jnp.isfinite(grad_norm))

def clip_grads(grads, grad_norm):
if clip_gradient_norm_to_value:
@@ -288,17 +285,6 @@ def scale_gradient(grad, norm):
grad_scale = jnp.array(1.0)
return grads, grad_scale

if p.check_valid_step:
# Mark the step as invalid if any gradient anomaly is detected (e.g. Nan
# or Inf, or excessively big gradient norm).
valid_step = keep_step(raw_grad_norm)
base_layer.add_global_summary(
'learning/' + optimizer_name + 'is_valid_step',
valid_step.astype(jnp.float32),
SummaryType.AGGREGATE_SCALAR,
)
else:
valid_step = True
grads, grad_scale = clip_grads(raw_grads, raw_grad_norm)
base_layer.add_global_summary(
'learning/' + optimizer_name + 'grad_scale',
@@ -313,7 +299,55 @@ def scale_gradient(grad, norm):
clipped_grad_norm,
SummaryType.AGGREGATE_SCALAR,
)
return grads, valid_step # pytype: disable=bad-return-type # jax-ndarray
return grads # pytype: disable=bad-return-type # jax-ndarray


def get_grad_norm_valid_step(
self,
raw_grads,
optimizer_name: Optional[str] = None,
clip_gradient_norm_to_value: Optional[float] = None,
clip_gradient_single_norm_to_value: Optional[float] = None
) -> Tuple[JTensor, JTensor]:

p = self._hparams

if optimizer_name is None:
optimizer_name = ''
else:
optimizer_name = optimizer_name + '/'
Review comment (Member):

I think you are missing the following code block from the original scale_gradients:

    if clip_gradient_norm_to_value is None:
      clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value
    if clip_gradient_single_norm_to_value is None:
      clip_gradient_single_norm_to_value = (
          p.optimizer.clip_gradient_single_norm_to_value
      )

self.get_individual_grad_norms(raw_grads,
optimizer_name)
Review comment (Member):

Nit: let's not line-break here; optimizer_name can go on the previous line.
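That is, the call would presumably collapse to:

    self.get_individual_grad_norms(raw_grads, optimizer_name)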

Review comment (Member):

Actually, can we move get_individual_grad_norms back inline? It's not used anywhere else, and inlining it is more consistent with the inlined global grad-norm computation below.
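In other words, roughly this shape, sketched from code already in this diff (it assumes the surrounding paxml module, so it is not standalone-runnable):

    def get_grad_norm_valid_step(self, raw_grads, optimizer_name=None,
                                 clip_gradient_norm_to_value=None,
                                 clip_gradient_single_norm_to_value=None):
      p = self._hparams
      optimizer_name = '' if optimizer_name is None else optimizer_name + '/'

      # Per-variable grad norm summaries, inlined rather than a separate method.
      if p.grad_norm_individual_vars:
        grad_norms = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), raw_grads)
        var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms)

        def add_grad_norm_summary(key, value):
          base_layer.add_global_summary(
              f'per_var_grad_norm/{optimizer_name}{key}',
              value,
              SummaryType.AGGREGATE_SCALAR,
          )

        jax.tree_map(add_grad_norm_summary, var_keys, grad_norms)

      # ...global grad norm and valid-step logic continue as in the diff below...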


if (
p.grad_norm_summary
or p.check_valid_step
or clip_gradient_norm_to_value
or clip_gradient_single_norm_to_value
):
raw_grad_norm = _compute_grad_norm(raw_grads)
if p.grad_norm_summary:
base_layer.add_global_summary(
'learning/' + optimizer_name + 'raw_grad_norm',
raw_grad_norm,
SummaryType.AGGREGATE_SCALAR,
)
else:
raw_grad_norm = None

if p.check_valid_step:
# Mark the step as invalid if any gradient anomaly is detected (e.g. Nan
# or Inf, or excessively big gradient norm).
valid_step = self.keep_step(raw_grad_norm)
Review comment (Member):

Let's move keep_step back to being a local function inside get_grad_norm_valid_step rather than a new instance method.

The original code is a bit complicated; let's avoid refactoring too much, since that makes it harder to spot whether the existing logic still holds.
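Concretely, restoring the original local definition (copied from the removed code above) inside get_grad_norm_valid_step:

    def keep_step(grad_norm):
      keep_threshold = p.skip_step_gradient_norm_value
      if keep_threshold:
        return jnp.logical_and(
            jnp.all(jnp.isfinite(grad_norm)),
            jnp.all(jnp.less(grad_norm, keep_threshold)),
        )
      else:
        return jnp.all(jnp.isfinite(grad_norm))

so that the call below becomes keep_step(raw_grad_norm) instead of self.keep_step(raw_grad_norm).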

base_layer.add_global_summary(
'learning/' + optimizer_name + 'is_valid_step',
valid_step.astype(jnp.float32),
SummaryType.AGGREGATE_SCALAR,
)
else:
valid_step = True

return raw_grad_norm, valid_step

def update_states(
self,
@@ -335,7 +369,15 @@ def update_states(
"""
p = self._hparams

grads, valid_step = self.scale_gradients(grads)
grad_norm, valid_step = self.get_grad_norm_valid_step(grads)

using_ga = hasattr(p.optimizer, 'num_sub_batches')
Review comment (Member):

Nit: let's use using_grad_accum; most readers might not know what 'ga' means.
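In other words, presumably just the rename applied to the lines below:

    using_grad_accum = hasattr(p.optimizer, 'num_sub_batches')

    # When using gradient accumulation, gradient scaling happens within the
    # base optimizer update.
    if not using_grad_accum:
      grads = self.scale_gradients(grads, grad_norm)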


# When using gradient accumulation, gradient scaling happens within base
# optimizer update
if not using_ga:
grads = self.scale_gradients(grads, grad_norm)

transformed_grad, new_states = self.get_grad_tx(var_weight_hparams).update(
grads, states, old_vars
)
@@ -357,6 +399,7 @@ def _update(updated, original):
new_states = jax.tree_map(
_update, new_states, states, is_leaf=py_utils.is_optax_masked_node
)

# Final applied grad norm.
if p.grad_norm_summary:
applied_grad_norm = _compute_grad_norm(transformed_grad)
@@ -588,8 +631,16 @@ def scale_gradients_by_optimizer(
) -> Tuple[NestedMap, JTensor]:
optimizer_mask, default_mask = self.get_masks(var_weight_hparams)

all_grads, all_valid_step = self.scale_gradients(
jax.tree_map(lambda x, y: x * y, raw_grads, default_mask),
raw_grads = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask)
Review comment (Member):

Let's not reuse raw_grads; call this grads_after_mask. You've introduced a subtle bug here: if you look at line 659 inside the auxiliary_optimizers loop, you are now combining this outer mask with the inner mask.

I would not overwrite the raw_grads variable; instead:

    grads_after_mask = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask)
    grad_norm, all_valid_step = self.get_grad_norm_valid_step(
        grads_after_mask,
        optimizer_name='main',
    )

so that inside the auxiliary_optimizers loop, raw_grads is only combined with each auxiliary optimizer's mask.


grad_norm, all_valid_step = self.get_grad_norm_valid_step(
raw_grads,
optimizer_name='main',
)

all_grads = self.scale_gradients(
raw_grads,
grad_norm,
optimizer_name='main',
)

@@ -600,8 +651,13 @@
):
assert optimizer.clip_gradient_norm_to_value is not None
assert optimizer.clip_gradient_single_norm_to_value is not None
grads, valid_step = self.scale_gradients(
grad_norm, valid_step = self.get_grad_norm_valid_step(
raw_grads,
optimizer_name=name,
)
grads = self.scale_gradients(
jax.tree_map(lambda x, y: x * y, raw_grads, mask),
grad_norm,
optimizer_name=name,
clip_gradient_norm_to_value=optimizer.clip_gradient_norm_to_value,
clip_gradient_single_norm_to_value=optimizer.clip_gradient_single_norm_to_value,
@@ -633,7 +689,8 @@ def update_states(
grads, var_weight_hparams
)
else:
grads, valid_step = self.scale_gradients(grads)
grad_norm, valid_step = self.get_grad_norm_valid_step(grads)
grads = self.scale_gradients(grads, grad_norm)
grad_tx = self.get_grad_tx(var_weight_hparams)
transformed_grad, new_states = grad_tx.update(grads, states, old_vars)
if self._hparams.enable_skip_step_on_gradient_anomalies:
3 changes: 2 additions & 1 deletion paxml/learners_test.py
@@ -64,7 +64,8 @@ def test_learner_clip_gradients(self, g1a, g1b, g2, global_clip_norm,
grad2=jnp.array([g2], dtype=jnp.float32))

with base_layer.JaxContext.new_context():
transformed_grads, _ = learner_instance.scale_gradients(grads)
grad_norm, valid_step = learner_instance.get_grad_norm_valid_step(grads)
transformed_grads = learner_instance.scale_gradients(grads, grad_norm)

global_norm = np.linalg.norm([g1a, g1b, g2])
local_norm1 = np.linalg.norm([g1a, g1b])