
Perform gradient clipping on global batch when using gradient accumulation #9

Open
wants to merge 6 commits into base: main

Conversation

@ashors1 (Contributor) commented Feb 14, 2023

Refactoring to allow gradient clipping to be performed on the full batch rather than on subbatches when using ShardedStaticAccumulator. Note that this refactor maintains support for enable_skip_step_on_gradient_anomalies. When using ShardedStaticAccumulator with x subbatches, it requires x+1 gradient norm calculations per global batch (one per subbatch to determine whether the step should be skipped, plus one when applying gradient clipping in the base optimizer update) and one gradient clip per global batch.

This PR should be taken together with the corresponding Praxis PR.
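
For illustration, here is a minimal, standalone sketch of the intended flow (this is not the Paxml/Praxis implementation; names such as accumulate_and_clip, subbatch_grads, and skip_norm_value are hypothetical): the gradient norm is still computed once per subbatch for the skip-step check, the subbatch gradients are accumulated into the global-batch gradient, and a single clip by global norm is applied when the optimizer update runs.

    import jax
    import jax.numpy as jnp
    import optax

    def grad_norm(grads):
      # Global L2 norm over a gradient pytree.
      leaves = jax.tree_util.tree_leaves(grads)
      return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))

    def accumulate_and_clip(subbatch_grads, clip_value, skip_norm_value):
      """Accumulates subbatch gradients, then clips once on the global batch."""
      num_sub_batches = len(subbatch_grads)

      # One grad-norm computation per subbatch, used only for the
      # skip-step-on-gradient-anomaly check (x norms for x subbatches).
      valid_step = jnp.array(True)
      for g in subbatch_grads:
        n = grad_norm(g)
        valid_step = valid_step & jnp.isfinite(n) & (n < skip_norm_value)

      # Average the subbatch gradients to form the global-batch gradient.
      global_grads = jax.tree_util.tree_map(
          lambda *gs: sum(gs) / num_sub_batches, *subbatch_grads)

      # One more grad-norm computation (the "+1") inside the global-batch clip.
      clipper = optax.clip_by_global_norm(clip_value)
      clipped, _ = clipper.update(global_grads, clipper.init(global_grads))
      return clipped, valid_step

With x subbatches, this costs x norm computations for the skip-step check plus one more inside the clip, which is where the x+1 count above comes from.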

@zhangqiaorjc self-assigned this Mar 3, 2023
@zhangqiaorjc added the pull ready (Used to import PR as CL) label Mar 5, 2023
@zhangqiaorjc (Member) commented:

@ashors1 sorry for the late review, could you rebase to head? I want to import it and run some internal CI, thanks!

@zhangqiaorjc (Member) commented:

There are quite a few redundant whitespace characters. Could you run a Python linter to remove them?

@zhangqiaorjc added and removed the pull ready (Used to import PR as CL) label Mar 8, 2023
if optimizer_name is None:
  optimizer_name = ''
else:
  optimizer_name = optimizer_name + '/'
Member:

I think you are missing the following code block from the original scale_gradient?

    if clip_gradient_norm_to_value is None:
      clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value
    if clip_gradient_single_norm_to_value is None:
      clip_gradient_single_norm_to_value = (
          p.optimizer.clip_gradient_single_norm_to_value
      )

else:
  optimizer_name = optimizer_name + '/'
self.get_individual_grad_norms(raw_grads,
                               optimizer_name)
Member:

nit: let's not line break here; optimizer_name can go on the previous line.

Member:

Actually, can we move get_individual_grad_norms back inline? It's not used anywhere else, and it seems more consistent with the inlined global grad norm below.

if p.check_valid_step:
  # Mark the step as invalid if any gradient anomaly is detected (e.g. Nan
  # or Inf, or excessively big gradient norm).
  valid_step = self.keep_step(raw_grad_norm)
Member:

Let's move keep_step back to a free function inside get_grad_norm_valid_step rather than a new instance method?

The original code is a bit complicated; let's avoid refactoring too much, because it might make it harder to spot whether the existing logic still holds.
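
A rough standalone sketch of the suggested shape (simplified and hypothetical; the real Praxis method reads its skip threshold from the learner hparams and handles more cases):

    import jax
    import jax.numpy as jnp

    def get_grad_norm_valid_step(grads, skip_step_gradient_norm_value=0.0):
      # keep_step stays a local free function, as in the original code, rather
      # than becoming a new instance method on the learner.
      def keep_step(grad_norm):
        keep = jnp.isfinite(grad_norm)
        if skip_step_gradient_norm_value:
          keep = jnp.logical_and(
              keep, grad_norm < skip_step_gradient_norm_value)
        return keep

      # Global norm over the gradient pytree (inlined, as in the existing code).
      leaves = jax.tree_util.tree_leaves(grads)
      grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
      return grad_norm, keep_step(grad_norm)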

grads, valid_step = self.scale_gradients(grads)
grad_norm, valid_step = self.get_grad_norm_valid_step(grads)

using_ga = hasattr(p.optimizer, 'num_sub_batches')
Member:

nit: let's use using_grad_accum

Most readers might not know what ga means.

@@ -588,8 +631,16 @@ def scale_gradients_by_optimizer(
   ) -> Tuple[NestedMap, JTensor]:
     optimizer_mask, default_mask = self.get_masks(var_weight_hparams)

-    all_grads, all_valid_step = self.scale_gradients(
-        jax.tree_map(lambda x, y: x * y, raw_grads, default_mask),
+    raw_grads = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask)
Member:

Let's not reuse raw_grads; let's call this grads_after_mask, because you've introduced a subtle bug here: if you look at line 659 inside the auxiliary_optimizers loop, you are now combining this outer mask with the inner mask.

I would not overwrite the raw_grads variable, just:

grads_after_mask = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask)
grad_norm, all_valid_step = self.get_grad_norm_valid_step(
    grads_after_mask,
    optimizer_name='main',
)

so that inside the auxiliary_optimizers loop, raw_grads is combined only with each auxiliary optimizer's mask.
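
To make the concern concrete, here is a simplified sketch (function and mask names are illustrative, not the Paxml code) of why overwriting raw_grads would combine the default mask with each auxiliary mask:

    import jax

    def scale_gradients_by_optimizer(raw_grads, default_mask, aux_masks):
      # Keep the default-masked gradients in a separate variable...
      grads_after_mask = jax.tree_util.tree_map(
          lambda g, m: g * m, raw_grads, default_mask)
      # ...so that each auxiliary optimizer masks the *original* gradients.
      aux_grads = [
          jax.tree_util.tree_map(lambda g, m: g * m, raw_grads, mask)
          for mask in aux_masks
      ]
      # If raw_grads had been overwritten with grads_after_mask above, each
      # aux_grads entry would effectively be raw_grads * default_mask * mask
      # instead of raw_grads * mask.
      return grads_after_mask, aux_grads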

@nluehr commented Jun 30, 2023

@zhangqiaorjc is there a reason this has been approved but not merged yet?

ashors1 pushed a commit to ashors1/paxml that referenced this pull request Jul 18, 2023
…kage/tensorflow-2.11.1

PiperOrigin-RevId: 524892551