Perform gradient clipping on global batch when using gradient accumulation #9

Open · wants to merge 6 commits into base: main
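
Editorial note (not part of the PR conversation): when gradients are accumulated over several sub-batches, clipping each sub-batch gradient by its own norm generally produces a different update than clipping the accumulated (global-batch) gradient once, which is what the title asks for. A minimal JAX sketch of that difference; `global_norm` and `clip_by_global_norm` here are hypothetical helpers written for illustration, not paxml APIs:

    import jax
    import jax.numpy as jnp

    def global_norm(tree):
      # L2 norm over all leaves of a gradient pytree.
      return jnp.sqrt(sum(jnp.sum(g * g) for g in jax.tree_util.tree_leaves(tree)))

    def clip_by_global_norm(tree, max_norm):
      # Scale the whole tree so its global norm is at most max_norm.
      scale = jnp.minimum(1.0, max_norm / (global_norm(tree) + 1e-9))
      return jax.tree_util.tree_map(lambda g: g * scale, tree)

    # Two sub-batch gradients for a single parameter.
    micro_grads = [{'w': jnp.array([4.0])}, {'w': jnp.array([0.5])}]

    # Clipping each sub-batch gradient before accumulation:
    # mean of clip(4.0) = 1.0 and clip(0.5) = 0.5  ->  0.75.
    per_micro = jax.tree_util.tree_map(
        lambda *gs: sum(gs) / len(gs),
        *[clip_by_global_norm(g, 1.0) for g in micro_grads])

    # Clipping the accumulated (global-batch) gradient once:
    # the mean gradient is 2.25, clipped to norm 1.0  ->  1.0.
    accumulated = jax.tree_util.tree_map(
        lambda *gs: sum(gs) / len(gs), *micro_grads)
    global_clip = clip_by_global_norm(accumulated, 1.0)
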
178 changes: 120 additions & 58 deletions paxml/learners.py
@@ -185,9 +185,30 @@ def get_grad_tx(
self._hparams.repeat_prefix_sep,
)

def get_individual_grad_norms(
self,
raw_grads,
optimizer_name):
p = self._hparams
# Compute gradient norm.

if p.grad_norm_individual_vars:
grad_norms = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), raw_grads)
var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms)

def add_grad_norm_summary(key, value):
base_layer.add_global_summary(
f'per_var_grad_norm/{optimizer_name}{key}',
value,
SummaryType.AGGREGATE_SCALAR,
)

jax.tree_map(add_grad_norm_summary, var_keys, grad_norms)

def scale_gradients(
self,
raw_grads: NestedMap,
raw_grad_norm: JTensor,
optimizer_name: Optional[str] = None,
clip_gradient_norm_to_value: Optional[float] = None,
clip_gradient_single_norm_to_value: Optional[float] = None,
@@ -209,57 +230,20 @@ def scale_gradients(
have anomaly detected (e.g. Nan or Inf, or excessively big gradient norm)
and should not be skipped.
"""

p = self._hparams

if optimizer_name is None:
optimizer_name = ''
else:
optimizer_name = optimizer_name + '/'

if clip_gradient_norm_to_value is None:
clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value
if clip_gradient_single_norm_to_value is None:
clip_gradient_single_norm_to_value = (
p.optimizer.clip_gradient_single_norm_to_value
)
# Compute gradient norm.

if p.grad_norm_individual_vars:
grad_norms = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), raw_grads)
var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms)

def add_grad_norm_summary(key, value):
base_layer.add_global_summary(
f'per_var_grad_norm/{optimizer_name}{key}',
value,
SummaryType.AGGREGATE_SCALAR,
)

jax.tree_map(add_grad_norm_summary, var_keys, grad_norms)

if (
p.grad_norm_summary
or p.check_valid_step
or clip_gradient_norm_to_value
or clip_gradient_single_norm_to_value
):
raw_grad_norm = _compute_grad_norm(raw_grads)
if p.grad_norm_summary:
base_layer.add_global_summary(
'learning/' + optimizer_name + 'raw_grad_norm',
raw_grad_norm,
SummaryType.AGGREGATE_SCALAR,
)
else:
raw_grad_norm = None

def keep_step(grad_norm):
keep_threshold = p.skip_step_gradient_norm_value
if keep_threshold:
return jnp.logical_and(
jnp.all(jnp.isfinite(grad_norm)),
jnp.all(jnp.less(grad_norm, keep_threshold)),
)
else:
return jnp.all(jnp.isfinite(grad_norm))

def clip_grads(grads, grad_norm):
if clip_gradient_norm_to_value:
@@ -288,17 +272,6 @@ def scale_gradient(grad, norm):
grad_scale = jnp.array(1.0)
return grads, grad_scale

if p.check_valid_step:
# Mark the step as invalid if any gradient anomaly is detected (e.g. Nan
# or Inf, or excessively big gradient norm).
valid_step = keep_step(raw_grad_norm)
base_layer.add_global_summary(
'learning/' + optimizer_name + 'is_valid_step',
valid_step.astype(jnp.float32),
SummaryType.AGGREGATE_SCALAR,
)
else:
valid_step = True
grads, grad_scale = clip_grads(raw_grads, raw_grad_norm)
base_layer.add_global_summary(
'learning/' + optimizer_name + 'grad_scale',
@@ -313,7 +286,71 @@ def scale_gradient(grad, norm):
clipped_grad_norm,
SummaryType.AGGREGATE_SCALAR,
)
return grads, valid_step # pytype: disable=bad-return-type # jax-ndarray
return grads # pytype: disable=bad-return-type # jax-ndarray


def get_grad_norm_valid_step(
self,
raw_grads,
optimizer_name: Optional[str] = None,
clip_gradient_norm_to_value: Optional[float] = None,
clip_gradient_single_norm_to_value: Optional[float] = None
) -> Tuple[JTensor, JTensor]:

p = self._hparams

if optimizer_name is None:
optimizer_name = ''
else:
optimizer_name = optimizer_name + '/'

Review comment (Member):

I think you are missing the following code block from the original scale_gradients?

    if clip_gradient_norm_to_value is None:
      clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value
    if clip_gradient_single_norm_to_value is None:
      clip_gradient_single_norm_to_value = (
          p.optimizer.clip_gradient_single_norm_to_value
      )

self.get_individual_grad_norms(raw_grads, optimizer_name)

if clip_gradient_norm_to_value is None:
clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value
if clip_gradient_single_norm_to_value is None:
clip_gradient_single_norm_to_value = (
p.optimizer.clip_gradient_single_norm_to_value
)

if (
p.grad_norm_summary
or p.check_valid_step
or clip_gradient_norm_to_value
or clip_gradient_single_norm_to_value
):
raw_grad_norm = _compute_grad_norm(raw_grads)
if p.grad_norm_summary:
base_layer.add_global_summary(
'learning/' + optimizer_name + 'raw_grad_norm',
raw_grad_norm,
SummaryType.AGGREGATE_SCALAR,
)
else:
raw_grad_norm = None

def keep_step(grad_norm):
keep_threshold = p.skip_step_gradient_norm_value
if keep_threshold:
return jnp.logical_and(
jnp.all(jnp.isfinite(grad_norm)),
jnp.all(jnp.less(grad_norm, keep_threshold)),
)
else:
return jnp.all(jnp.isfinite(grad_norm))

if p.check_valid_step:
# Mark the step as invalid if any gradient anomaly is detected (e.g. Nan
# or Inf, or excessively big gradient norm).
valid_step = keep_step(raw_grad_norm)
base_layer.add_global_summary(
'learning/' + optimizer_name + 'is_valid_step',
valid_step.astype(jnp.float32),
SummaryType.AGGREGATE_SCALAR,
)
else:
valid_step = True

return raw_grad_norm, valid_step

def update_states(
self,
@@ -335,7 +372,15 @@ def update_states(
"""
p = self._hparams

grads, valid_step = self.scale_gradients(grads)
grad_norm, valid_step = self.get_grad_norm_valid_step(grads)

using_grad_accum = hasattr(p.optimizer, 'num_sub_batches')

# When using gradient accumulation, gradient scaling happens within base
# optimizer update
if not using_grad_accum:
grads = self.scale_gradients(grads, grad_norm)

transformed_grad, new_states = self.get_grad_tx(var_weight_hparams).update(
grads, states, old_vars
)
Expand All @@ -357,6 +402,7 @@ def _update(updated, original):
new_states = jax.tree_map(
_update, new_states, states, is_leaf=py_utils.is_optax_masked_node
)

# Final applied grad norm.
if p.grad_norm_summary:
applied_grad_norm = _compute_grad_norm(transformed_grad)
@@ -588,8 +634,16 @@ def scale_gradients_by_optimizer(
) -> Tuple[NestedMap, JTensor]:
optimizer_mask, default_mask = self.get_masks(var_weight_hparams)

all_grads, all_valid_step = self.scale_gradients(
jax.tree_map(lambda x, y: x * y, raw_grads, default_mask),
grads_after_default_mask = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask)

grad_norm, all_valid_step = self.get_grad_norm_valid_step(
grads_after_default_mask,
optimizer_name='main',
)

all_grads = self.scale_gradients(
grads_after_default_mask,
grad_norm,
optimizer_name='main',
)

@@ -600,8 +654,15 @@
):
assert optimizer.clip_gradient_norm_to_value is not None
assert optimizer.clip_gradient_single_norm_to_value is not None
grads, valid_step = self.scale_gradients(
jax.tree_map(lambda x, y: x * y, raw_grads, mask),

grads_after_mask = jax.tree_map(lambda x, y: x * y, raw_grads, mask)
grad_norm, valid_step = self.get_grad_norm_valid_step(
grads_after_mask,
optimizer_name=name,
)
grads = self.scale_gradients(
grads_after_mask,
grad_norm,
optimizer_name=name,
clip_gradient_norm_to_value=optimizer.clip_gradient_norm_to_value,
clip_gradient_single_norm_to_value=optimizer.clip_gradient_single_norm_to_value,
@@ -633,7 +694,8 @@ def update_states(
grads, var_weight_hparams
)
else:
grads, valid_step = self.scale_gradients(grads)
grad_norm, valid_step = self.get_grad_norm_valid_step(grads)
grads = self.scale_gradients(grads, grad_norm)
grad_tx = self.get_grad_tx(var_weight_hparams)
transformed_grad, new_states = grad_tx.update(grads, states, old_vars)
if self._hparams.enable_skip_step_on_gradient_anomalies:
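
Editorial sketch (not part of the PR): the `hasattr(p.optimizer, 'num_sub_batches')` check in `update_states` above defers scaling so that it runs inside the accumulating optimizer, i.e. on the accumulated gradient rather than on each sub-batch gradient. paxml's own accumulation wrapper is not shown in this diff; the same idea expressed with stock optax, purely as an illustration, looks roughly like this:

    import optax

    NUM_SUB_BATCHES = 4   # sub-batches accumulated per optimizer step
    MAX_GRAD_NORM = 1.0

    # The clip is chained *inside* the wrapped optimizer, so MultiSteps first
    # accumulates (averages) the sub-batch gradients and only then applies the
    # clip once, to the global-batch gradient.
    inner = optax.chain(
        optax.clip_by_global_norm(MAX_GRAD_NORM),
        optax.sgd(learning_rate=0.1),
    )
    tx = optax.MultiSteps(inner, every_k_schedule=NUM_SUB_BATCHES)

`tx.update` is then called once per sub-batch; the inner chain, and hence the clip, only fires on every NUM_SUB_BATCHES-th call.
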
3 changes: 2 additions & 1 deletion paxml/learners_test.py
@@ -64,7 +64,8 @@ def test_learner_clip_gradients(self, g1a, g1b, g2, global_clip_norm,
grad2=jnp.array([g2], dtype=jnp.float32))

with base_layer.JaxContext.new_context():
transformed_grads, _ = learner_instance.scale_gradients(grads)
grad_norm, valid_step = learner_instance.get_grad_norm_valid_step(grads)
transformed_grads = learner_instance.scale_gradients(grads, grad_norm)

global_norm = np.linalg.norm([g1a, g1b, g2])
local_norm1 = np.linalg.norm([g1a, g1b])
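
A hedged usage sketch of the refactored Learner interface, mirroring the test change above; `learner_instance` and `grads` are assumed to be set up as in `test_learner_clip_gradients`, and this is illustrative rather than an authoritative paxml recipe:

    # Compute the global grad norm and the valid-step flag (finite and below
    # the skip threshold) once, on the raw gradients.
    grad_norm, valid_step = learner_instance.get_grad_norm_valid_step(grads)

    # Without gradient accumulation, clip/scale immediately using that norm.
    # With accumulation (an optimizer exposing `num_sub_batches`),
    # update_states skips this call and defers scaling to the base optimizer
    # update on the accumulated gradient.
    transformed_grads = learner_instance.scale_gradients(grads, grad_norm)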