Commit 762165b: Update for bf16 and fp32

Tcc0403 committed Feb 15, 2025
1 parent 02dbb2e

Showing 4 changed files with 10 additions and 7 deletions.
5 changes: 3 additions & 2 deletions test/convergence/bf16/test_mini_models_multimodal.py
@@ -322,6 +322,7 @@ def run_mini_model_multimodal(
     batch = next(loader_iter).to(model.device)
     optimizer.zero_grad()
     output = model(**batch)
+    output.logits.retain_grad()
     output.loss.backward()
     optimizer.step()
@@ -405,8 +406,8 @@ def test_mini_model_multimodal(

     # Compare the logits from the last step
     assert_verbose_allclose(
-        expected_output["logits"],
-        actual_output["logits"],
+        expected_output["logits"].grad,
+        actual_output["logits"].grad,
         atol=logits_atol,
         rtol=logits_rtol,
     )
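Context on the retain_grad() calls added above: PyTorch only populates .grad on leaf tensors by default, and output.logits is a non-leaf intermediate of the forward pass, so its gradient would otherwise be discarded during backward(). Calling retain_grad() before backward() makes logits.grad available for the comparison below. A minimal sketch of the pattern, with illustrative tensor names standing in for the model forward pass:

import torch

# Toy stand-in for a forward pass that produces logits (names illustrative).
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 3, requires_grad=True)
logits = x @ w          # non-leaf tensor: autograd discards its .grad by default
logits.retain_grad()    # opt in to keeping the gradient on this intermediate
loss = logits.sum()
loss.backward()

print(logits.grad)      # populated only because retain_grad() was called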
2 changes: 1 addition & 1 deletion test/convergence/bf16/test_mini_models_with_logits.py
@@ -444,7 +444,7 @@ def run_mini_model(
     batch = next(loader_iter).to(model.device)
     optimizer.zero_grad()
     output = model(**batch)
-    output.logits.retain_grad()  # For comparing logits.grad
+    output.logits.retain_grad()
     output.loss.backward()
     optimizer.step()
     print(f"Step {i}, Loss: {output.loss.item()}")
5 changes: 3 additions & 2 deletions test/convergence/fp32/test_mini_models_multimodal.py
@@ -321,6 +321,7 @@ def run_mini_model_multimodal(
     batch = next(loader_iter).to(model.device)
     optimizer.zero_grad()
     output = model(**batch)
+    output.logits.retain_grad()
     output.loss.backward()
     optimizer.step()
@@ -400,8 +401,8 @@ def test_mini_model_multimodal(

     # Compare the logits from the last step
     assert_verbose_allclose(
-        expected_output["logits"],
-        actual_output["logits"],
+        expected_output["logits"].grad,
+        actual_output["logits"].grad,
         atol=logits_atol,
         rtol=logits_rtol,
     )
5 changes: 3 additions & 2 deletions test/convergence/fp32/test_mini_models_with_logits.py
@@ -443,6 +443,7 @@ def run_mini_model(
     batch = next(loader_iter).to(model.device)
     optimizer.zero_grad()
     output = model(**batch)
+    output.logits.retain_grad()
     output.loss.backward()
     optimizer.step()
     print(f"Step {i}, Loss: {output.loss.item()}")
@@ -529,8 +530,8 @@ def test_mini_model(
     # import pdb; pdb.set_trace()
     # Compare the logits from the last step
     assert_verbose_allclose(
-        expected_output["logits"],
-        actual_output["logits"],
+        expected_output["logits"].grad,
+        actual_output["logits"].grad,
         atol=logits_atol,
         rtol=logits_rtol,
     )
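The net effect across all four files is that the convergence tests now assert on logits.grad rather than on the logits themselves, which exercises the backward pass in addition to the forward pass. A minimal sketch of that comparison, using torch.allclose in place of the repository's assert_verbose_allclose helper; the function name and default tolerances below are illustrative, not taken from this commit:

import torch

def compare_logits_grads(expected_logits, actual_logits, atol=1e-5, rtol=1e-5):
    # Both tensors must have had retain_grad() called before backward(),
    # otherwise .grad is None. Tolerances here are illustrative, not the
    # values the test suite uses.
    assert expected_logits.grad is not None and actual_logits.grad is not None
    assert torch.allclose(expected_logits.grad, actual_logits.grad, atol=atol, rtol=rtol)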
