Skip to content

Commit

Permalink
better tests
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Feb 15, 2024
1 parent 7e6b4b2 commit 8052e33
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -5952,8 +5952,13 @@ def test_ppo(

loss = loss_fn(td)
if reduction == "none":
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

def func(x):
if x.dtype != torch.float:
return
return x.mean()

loss = loss.apply(func, batch_size=[])

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
Expand Down Expand Up @@ -6699,8 +6704,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti
loss_fn.make_value_estimator(td_est)
loss = loss_fn(td)
if reduction == "none":
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

def func(x):
if x.dtype != torch.float:
return
return x.mean()

loss = loss.apply(func, batch_size=[])
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
Expand Down Expand Up @@ -7342,8 +7352,13 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction):

loss = loss_fn(td)
if reduction == "none":
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

def func(x):
if x.dtype != torch.float:
return
return x.mean()

loss = loss.apply(func, batch_size=[])

assert all(
(p.grad is None) or (p.grad == 0).all()
Expand Down

0 comments on commit 8052e33

Please sign in to comment.