Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Feb 15, 2024
1 parent 7e516f8 commit 2701bb8
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -5951,7 +5951,7 @@ def test_ppo(
loss_fn.make_value_estimator(td_est)

loss = loss_fn(td)
if reduction is None:
if reduction == "none":
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

Expand Down Expand Up @@ -6698,7 +6698,7 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti
elif td_est is not None:
loss_fn.make_value_estimator(td_est)
loss = loss_fn(td)
if reduction is None:
if reduction == "none":
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
loss_critic = loss["loss_critic"]
Expand Down Expand Up @@ -7184,7 +7184,7 @@ def test_reinforce_value_net(
elif td_est is not None:
loss_fn.make_value_estimator(td_est)
loss_td = loss_fn(td)
if reduction is None:
if reduction == "none":
assert loss_td.batch_size == td.batch_size
loss_td = loss_td.apply(lambda x: x.float().mean(), batch_size=[])
autograd.grad(
Expand Down Expand Up @@ -7346,7 +7346,7 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction):
)

loss = loss_fn(td)
if reduction is None:
if reduction == "none":
assert loss.batch_size == td.batch_size
loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

Expand Down

0 comments on commit 2701bb8

Please sign in to comment.