Commit cf0dd38

add test
1 parent 218e45c commit cf0dd38

File tree

1 file changed: +21 −0 lines changed

tests/pytorch/test_sanity.py

Lines changed: 21 additions & 0 deletions
@@ -1338,3 +1338,24 @@ def backward(ctx, grad_output):
 
     # Assert that gradients are the same
     torch.testing.assert_close(grad_checkpoint, grad_standard)
+
+def test_linear_frozen_weights_memory_default_recipe():
+    """Test that memory usage is optimized when weights are frozen for MXFP8."""
+    dim = 1024
+    linear = Linear(dim, dim, bias=False)
+    x = torch.randn(dim, dim, requires_grad=True, device="cuda")
+
+    # Freeze weights
+    linear.weight.requires_grad = False
+
+    # Forward and backward pass with FP8
+    with fp8_autocast():
+        o = linear(x)
+    g_o = torch.randn_like(o)
+
+    max_memory_before_backward = torch.cuda.max_memory_allocated()
+    o.backward(g_o)
+    max_memory_after_backward = torch.cuda.max_memory_allocated()
+
+    memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
+    assert memory_diff < 5.5, f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the grad_output should be quantized only columnwise."
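For context on where the 5.5 MB bound comes from (a back-of-the-envelope reading of the test, not stated in the commit): with the weight frozen, the wgrad GEMM is skipped, so the backward pass should allocate only the float32 input gradient plus a single column-wise MXFP8 copy of grad_output, rather than both a row-wise and a column-wise copy. A minimal sketch of that arithmetic, assuming 1 byte per MXFP8 element and one 1-byte scaling factor per 32-element block:

# Rough memory estimate for backward of a frozen 1024x1024 Linear under MXFP8.
# Assumptions (not from the commit): float32 dgrad, 1-byte FP8 payloads,
# one scaling factor per 32-element block.
dim = 1024

dgrad_bytes = dim * dim * 4        # float32 input gradient: ~4.19 MB
fp8_copy_bytes = dim * dim * 1     # one quantized copy of grad_output: ~1.05 MB
scale_bytes = dim * dim // 32      # per-block scaling factors: ~0.03 MB

frozen_mb = (dgrad_bytes + fp8_copy_bytes + scale_bytes) / 1e6              # ~5.3 MB
both_copies_mb = (dgrad_bytes + 2 * (fp8_copy_bytes + scale_bytes)) / 1e6   # ~6.4 MB

# The 5.5 MB threshold falls between the two estimates, so the test passes
# only when grad_output is quantized column-wise alone.
print(f"frozen: {frozen_mb:.2f} MB, both copies: {both_copies_mb:.2f} MB")

Under these assumptions the threshold cleanly separates the two behaviors. To run just this test, the usual pytest filter applies: pytest tests/pytorch/test_sanity.py -k test_linear_frozen_weights_memory_default_recipe (requires a CUDA device with MXFP8 support).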

0 commit comments