Commit 218e45c

fix: update grad_output quant to avoid redundant work
Signed-off-by: kshitij12345 <[email protected]>
1 parent: 2f61c40

File tree

1 file changed (+7, −0 lines)

transformer_engine/pytorch/module/linear.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -524,6 +524,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             columnwise=columnwise_usage,
         )
 
+        # Adjust the quantization direction approach depending
+        # on whether dgrad and wgrad calculations will be performed.
+        if not ctx.requires_dgrad and ctx.grad_output_quantizer is not None:
+            ctx.grad_output_quantizer.set_usage(rowwise=False)
+        if not ctx.requires_wgrad and ctx.grad_output_quantizer is not None:
+            ctx.grad_output_quantizer.set_usage(columnwise=False)
+
         # Prepare grad output tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
         nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
```
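For context, the quantized grad_output is generally needed in two layouts during the backward pass: a rowwise copy that feeds the dgrad GEMM (grad_output against the weight) and a columnwise/transposed copy that feeds the wgrad GEMM (grad_output against the input). If one of the two gradients is not required, producing the corresponding layout is redundant work, which is what this commit avoids by narrowing the quantizer's usage flags first. The sketch below illustrates the pattern with a toy quantizer; it is not Transformer Engine's implementation, `ToyQuantizer` and its `quantize` method are hypothetical, and only the `set_usage(rowwise=..., columnwise=...)` call mirrors the API seen in the diff above.

```python
# A minimal sketch, assuming a quantizer that can produce a rowwise copy
# (for the dgrad GEMM) and a columnwise/transposed copy (for the wgrad GEMM).
# ToyQuantizer is hypothetical; only set_usage mirrors the diff above.
import torch


class ToyQuantizer:
    def __init__(self) -> None:
        # By default, both layouts are produced.
        self.rowwise = True
        self.columnwise = True

    def set_usage(self, *, rowwise: bool | None = None,
                  columnwise: bool | None = None) -> None:
        if rowwise is not None:
            self.rowwise = rowwise
        if columnwise is not None:
            self.columnwise = columnwise

    def quantize(self, t: torch.Tensor):
        # Each copy costs a cast (and, for the columnwise copy, a transpose);
        # skipping an unused layout avoids that work entirely.
        row = t.to(torch.float16) if self.rowwise else None
        col = t.t().contiguous().to(torch.float16) if self.columnwise else None
        return row, col


grad_output = torch.randn(128, 256)
quantizer = ToyQuantizer()

# Mirrors the commit: if dgrad is not needed, drop the rowwise copy;
# if wgrad is not needed, drop the columnwise copy.
requires_dgrad, requires_wgrad = True, False
if not requires_dgrad:
    quantizer.set_usage(rowwise=False)
if not requires_wgrad:
    quantizer.set_usage(columnwise=False)

row, col = quantizer.quantize(grad_output)
print(row is not None, col is not None)  # True False: columnwise work skipped
```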
