Skip to content

Commit

Permalink
Modifying DPLossFastGradientClipping to add support for generative ta…
Browse files Browse the repository at this point in the history
…sks with ghost clipping (#716)

Summary:

Generative tasks for NLP output predictions of shape (B,T,C) i.e., (batch_size, sequence_length, vocab_size). To compute the cross-entropy loss in this case, usually the predictions are reshaped to (BxT, C) and targets to (BxT). This creates an issue with Ghost Clipping per sample loss computation as BxT is seen as the batch_size. In particular, the current implementation of Ghost Clipping results in loss_per_sample, coeff variables to have a shape of BxT and B respectively. This causes a shape mismatch error. This diff fixes that error by collapsing the loss_per_sample variable to shape B i.e., the loss across the sequence_length dim is averaged/summed.

Differential Revision: D68047256
  • Loading branch information
Aparna Aketi authored and facebook-github-bot committed Jan 25, 2025
1 parent 9741fe2 commit 3c32510
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions opacus/utils/fast_gradient_clipping_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def backward(self):
reduced_loss.backward(retain_graph=True)
self.optimizer.zero_grad()
coeff = self.module.get_clipping_coef()
second_loss_per_sample = coeff * self.loss_per_sample
second_loss_per_sample = (
coeff.to(self.loss_per_sample.device) * self.loss_per_sample
)
second_loss = torch.sum(second_loss_per_sample)
self.module.disable_hooks()
second_loss.backward()
Expand Down Expand Up @@ -104,15 +106,29 @@ def __init__(
self.loss_reduction = loss_reduction
self.criterion.reduction = "none"

def __call__(self, input, target) -> DPTensorFastGradientClipping:
def __call__(
self, input, target, shape=None
) -> DPTensorFastGradientClipping:
"""
Redefining the forward function to compute per-sample loss and wrap it in DPTensorFastGradientClipping
"""

loss_per_sample = self.criterion(
input,
target,
)
loss_per_sample = self.criterion(input, target)

if shape is not None and loss_per_sample.shape[0] == shape[0] * shape[1]:
# Note that the privacy unit for generative NLP tasks is per sequence.
# The shape variable is the shape of the logits before flattening i.e., [batch_size, sequence_lenght, vocab_size].
# This variable is necessary for ghost clipping to work with generative NLP tasks.
loss_per_sample = loss_per_sample.view(shape[0], shape[1]) # BxT
if self.loss_reduction == "mean":
loss_per_sample = loss_per_sample.mean(dim=1) # B
elif self.loss_reduction == "sum":
loss_per_sample = loss_per_sample.sum(dim=1) # B
else:
raise ValueError(
f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
)

return DPTensorFastGradientClipping(
self.module, self.optimizer, loss_per_sample, self.loss_reduction
)

0 comments on commit 3c32510

Please sign in to comment.