diff --git a/test/convergence/bf16/test_mini_models_multimodal.py b/test/convergence/bf16/test_mini_models_multimodal.py
index 2645e26fa..a89308795 100644
--- a/test/convergence/bf16/test_mini_models_multimodal.py
+++ b/test/convergence/bf16/test_mini_models_multimodal.py
@@ -322,6 +322,7 @@ def run_mini_model_multimodal(
         batch = next(loader_iter).to(model.device)
         optimizer.zero_grad()
         output = model(**batch)
+        output.logits.retain_grad()
         output.loss.backward()
         optimizer.step()
 
@@ -403,10 +404,10 @@ def test_mini_model_multimodal(
         rtol=loss_rtol,
     )
 
-    # Compare the logits from the last step
+    # Compare the logits.grad from the last step instead of the logits; the liger implementation doesn't keep logits
     assert_verbose_allclose(
-        expected_output["logits"],
-        actual_output["logits"],
+        expected_output["logits"].grad,
+        actual_output["logits"].grad,
         atol=logits_atol,
         rtol=logits_rtol,
     )
diff --git a/test/convergence/bf16/test_mini_models_with_logits.py b/test/convergence/bf16/test_mini_models_with_logits.py
index 8605893b1..5f9c47073 100644
--- a/test/convergence/bf16/test_mini_models_with_logits.py
+++ b/test/convergence/bf16/test_mini_models_with_logits.py
@@ -444,6 +444,7 @@ def run_mini_model(
         batch = next(loader_iter).to(model.device)
         optimizer.zero_grad()
         output = model(**batch)
+        output.logits.retain_grad()
         output.loss.backward()
         optimizer.step()
         print(f"Step {i}, Loss: {output.loss.item()}")
@@ -634,12 +635,11 @@ def test_mini_model(
         rtol=loss_rtol,
     )
 
-    # No logits are materialized
     # import pdb; pdb.set_trace()
-    # Compare the logits from the last step
+    # Compare the logits.grad from the last step instead of the logits; the liger implementation doesn't keep logits
     assert_verbose_allclose(
-        expected_output["logits"],
-        actual_output["logits"],
+        expected_output["logits"].grad,
+        actual_output["logits"].grad,
         atol=logits_atol,
         rtol=logits_rtol,
     )
diff --git a/test/convergence/fp32/test_mini_models_multimodal.py b/test/convergence/fp32/test_mini_models_multimodal.py
index ff251b9c9..2fa25becd 100644
--- a/test/convergence/fp32/test_mini_models_multimodal.py
+++ b/test/convergence/fp32/test_mini_models_multimodal.py
@@ -321,6 +321,7 @@ def run_mini_model_multimodal(
         batch = next(loader_iter).to(model.device)
         optimizer.zero_grad()
         output = model(**batch)
+        output.logits.retain_grad()
         output.loss.backward()
         optimizer.step()
 
@@ -398,10 +399,10 @@ def test_mini_model_multimodal(
         rtol=loss_rtol,
     )
 
-    # Compare the logits from the last step
+    # Compare the logits.grad from the last step instead of the logits; the liger implementation doesn't keep logits
     assert_verbose_allclose(
-        expected_output["logits"],
-        actual_output["logits"],
+        expected_output["logits"].grad,
+        actual_output["logits"].grad,
         atol=logits_atol,
         rtol=logits_rtol,
     )
diff --git a/test/convergence/fp32/test_mini_models_with_logits.py b/test/convergence/fp32/test_mini_models_with_logits.py
index 75b388740..f6f6a3a6b 100644
--- a/test/convergence/fp32/test_mini_models_with_logits.py
+++ b/test/convergence/fp32/test_mini_models_with_logits.py
@@ -443,6 +443,7 @@ def run_mini_model(
         batch = next(loader_iter).to(model.device)
         optimizer.zero_grad()
         output = model(**batch)
+        output.logits.retain_grad()
         output.loss.backward()
         optimizer.step()
         print(f"Step {i}, Loss: {output.loss.item()}")
@@ -527,10 +528,10 @@ def test_mini_model(
 
     # No logits are materialized
     # import pdb; pdb.set_trace()
-    # Compare the logits from the last step
+    # Compare the logits.grad from the last step instead of the logits; the liger implementation doesn't keep logits
    assert_verbose_allclose(
-        expected_output["logits"],
-        actual_output["logits"],
+        expected_output["logits"].grad,
+        actual_output["logits"].grad,
         atol=logits_atol,
         rtol=logits_rtol,
    )
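
Note on the change above: in PyTorch, backward() populates .grad only for leaf tensors, so output.logits (a non-leaf tensor produced by the forward pass) would not retain its gradient unless retain_grad() is called on it first. A minimal standalone sketch of this behavior (illustrative only, not part of the patch):

import torch

x = torch.randn(4, 8, requires_grad=True)  # leaf tensor: .grad is kept by default
w = torch.randn(8, 3)
logits = x @ w                              # non-leaf tensor: .grad is discarded by default
logits.retain_grad()                        # ask autograd to keep logits.grad after backward()
loss = logits.sum()
loss.backward()
print(logits.grad.shape)                    # torch.Size([4, 3]); would be None without retain_grad()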