diff --git a/test/test_cost.py b/test/test_cost.py
index 83b61f96cd2..ae08681d9c3 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -5952,8 +5952,13 @@ def test_ppo(
         loss = loss_fn(td)
         if reduction == "none":
-            assert loss.batch_size == td.batch_size
-            loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
+
+            def func(x):
+                if x.dtype != torch.float:
+                    return
+                return x.mean()
+
+            loss = loss.apply(func, batch_size=[])
         loss_critic = loss["loss_critic"]
         loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
@@ -6699,8 +6704,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti
         loss_fn.make_value_estimator(td_est)
         loss = loss_fn(td)
         if reduction == "none":
-            assert loss.batch_size == td.batch_size
-            loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
+
+            def func(x):
+                if x.dtype != torch.float:
+                    return
+                return x.mean()
+
+            loss = loss.apply(func, batch_size=[])
         loss_critic = loss["loss_critic"]
         loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
         loss_critic.backward(retain_graph=True)
@@ -7342,8 +7352,13 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction):
         loss = loss_fn(td)
         if reduction == "none":
-            assert loss.batch_size == td.batch_size
-            loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
+
+            def func(x):
+                if x.dtype != torch.float:
+                    return
+                return x.mean()
+
+            loss = loss.apply(func, batch_size=[])
         assert all(
             (p.grad is None) or (p.grad == 0).all()
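
The hunks above replace `loss.apply(lambda x: x.float().mean(), batch_size=[])` with a helper that returns `None` for non-float leaves, so integer entries are dropped from the reduced loss TensorDict rather than cast to float and averaged. A minimal sketch of that pattern, assuming tensordict's `TensorDict.apply` skips leaves for which the callable returns `None` and that `batch_size=[]` resets the output batch size; the entry names below are illustrative only:

```python
import torch
from tensordict import TensorDict

# Illustrative loss-like TensorDict with per-sample ("none"-reduced) values.
loss = TensorDict(
    {
        "loss_objective": torch.randn(4),  # float leaf: gets averaged
        "loss_critic": torch.randn(4),     # float leaf: gets averaged
        "step_count": torch.arange(4),     # int leaf: skipped, not cast to float
    },
    batch_size=[4],
)


def func(x):
    # Returning None (assumed apply() behaviour) filters the leaf out,
    # so non-float diagnostics are not averaged into scalar losses.
    if x.dtype != torch.float:
        return
    return x.mean()


reduced = loss.apply(func, batch_size=[])
assert "step_count" not in reduced.keys()
assert reduced["loss_objective"].shape == torch.Size([])
```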