From 762165b4cc796b33133edc4f7f9595ba2e1b5d4a Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat, 15 Feb 2025 23:37:03 +0800 Subject: [PATCH] Update for bf16 and fp32 --- test/convergence/bf16/test_mini_models_multimodal.py | 5 +++-- test/convergence/bf16/test_mini_models_with_logits.py | 2 +- test/convergence/fp32/test_mini_models_multimodal.py | 5 +++-- test/convergence/fp32/test_mini_models_with_logits.py | 5 +++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/convergence/bf16/test_mini_models_multimodal.py b/test/convergence/bf16/test_mini_models_multimodal.py index 2645e26fa..fefb00dc5 100644 --- a/test/convergence/bf16/test_mini_models_multimodal.py +++ b/test/convergence/bf16/test_mini_models_multimodal.py @@ -322,6 +322,7 @@ def run_mini_model_multimodal( batch = next(loader_iter).to(model.device) optimizer.zero_grad() output = model(**batch) + output.logits.retain_grad() output.loss.backward() optimizer.step() @@ -405,8 +406,8 @@ def test_mini_model_multimodal( # Compare the logits from the last step assert_verbose_allclose( - expected_output["logits"], - actual_output["logits"], + expected_output["logits"].grad, + actual_output["logits"].grad, atol=logits_atol, rtol=logits_rtol, ) diff --git a/test/convergence/bf16/test_mini_models_with_logits.py b/test/convergence/bf16/test_mini_models_with_logits.py index 1d3788935..5f9c47073 100644 --- a/test/convergence/bf16/test_mini_models_with_logits.py +++ b/test/convergence/bf16/test_mini_models_with_logits.py @@ -444,7 +444,7 @@ def run_mini_model( batch = next(loader_iter).to(model.device) optimizer.zero_grad() output = model(**batch) - output.logits.retain_grad() # For comparing logits.grad + output.logits.retain_grad() output.loss.backward() optimizer.step() print(f"Step {i}, Loss: {output.loss.item()}") diff --git a/test/convergence/fp32/test_mini_models_multimodal.py b/test/convergence/fp32/test_mini_models_multimodal.py index ff251b9c9..efb16f02d 100644 --- a/test/convergence/fp32/test_mini_models_multimodal.py +++ b/test/convergence/fp32/test_mini_models_multimodal.py @@ -321,6 +321,7 @@ def run_mini_model_multimodal( batch = next(loader_iter).to(model.device) optimizer.zero_grad() output = model(**batch) + output.logits.retain_grad() output.loss.backward() optimizer.step() @@ -400,8 +401,8 @@ def test_mini_model_multimodal( # Compare the logits from the last step assert_verbose_allclose( - expected_output["logits"], - actual_output["logits"], + expected_output["logits"].grad, + actual_output["logits"].grad, atol=logits_atol, rtol=logits_rtol, ) diff --git a/test/convergence/fp32/test_mini_models_with_logits.py b/test/convergence/fp32/test_mini_models_with_logits.py index 75b388740..93e0494a0 100644 --- a/test/convergence/fp32/test_mini_models_with_logits.py +++ b/test/convergence/fp32/test_mini_models_with_logits.py @@ -443,6 +443,7 @@ def run_mini_model( batch = next(loader_iter).to(model.device) optimizer.zero_grad() output = model(**batch) + output.logits.retain_grad() output.loss.backward() optimizer.step() print(f"Step {i}, Loss: {output.loss.item()}") @@ -529,8 +530,8 @@ def test_mini_model( # import pdb; pdb.set_trace() # Compare the logits from the last step assert_verbose_allclose( - expected_output["logits"], - actual_output["logits"], + expected_output["logits"].grad, + actual_output["logits"].grad, atol=logits_atol, rtol=logits_rtol, )