diff --git a/test/test_cost.py b/test/test_cost.py
index caf3b372bc1..bc01df6dec8 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -5951,7 +5951,7 @@ def test_ppo(
             loss_fn.make_value_estimator(td_est)

         loss = loss_fn(td)
-        if reduction is None:
+        if reduction == "none":
             assert loss.batch_size == td.batch_size
             loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
@@ -6698,7 +6698,7 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti
         elif td_est is not None:
             loss_fn.make_value_estimator(td_est)
         loss = loss_fn(td)
-        if reduction is None:
+        if reduction == "none":
             assert loss.batch_size == td.batch_size
             loss = loss.apply(lambda x: x.float().mean(), batch_size=[])
         loss_critic = loss["loss_critic"]
@@ -7184,7 +7184,7 @@ def test_reinforce_value_net(
         elif td_est is not None:
             loss_fn.make_value_estimator(td_est)
         loss_td = loss_fn(td)
-        if reduction is None:
+        if reduction == "none":
             assert loss_td.batch_size == td.batch_size
             loss_td = loss_td.apply(lambda x: x.float().mean(), batch_size=[])
         autograd.grad(
@@ -7346,7 +7346,7 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction):
         )

         loss = loss_fn(td)
-        if reduction is None:
+        if reduction == "none":
             assert loss.batch_size == td.batch_size
             loss = loss.apply(lambda x: x.float().mean(), batch_size=[])

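
All four hunks fix the same bug: the tests gated the unreduced-loss assertions on `reduction is None`, but the `reduction` argument follows the PyTorch convention of being one of the strings `"none"`, `"mean"`, or `"sum"`, so the `is None` branch could never run. A minimal sketch of that convention, using a hypothetical `toy_loss` helper that is not part of this PR:

```python
import torch


def toy_loss(
    pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
    """Illustrative loss following the PyTorch ``reduction`` convention."""
    loss = (pred - target) ** 2
    if reduction == "none":  # correct: compare against the string sentinel
        return loss  # per-element loss, batch dimensions preserved
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"unsupported reduction: {reduction!r}")


pred, target = torch.randn(4), torch.randn(4)
# "none" keeps the batch shape; "mean" collapses to a scalar.
assert toy_loss(pred, target, reduction="none").shape == (4,)
assert toy_loss(pred, target, reduction="mean").shape == ()
```

With `reduction="none"` the loss module returns per-sample values (hence the `loss.batch_size == td.batch_size` assertions), which the tests then reduce manually via `loss.apply(lambda x: x.float().mean(), batch_size=[])` before taking gradients.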