better tests
albertbou92 committed Feb 15, 2024
1 parent 566b2b9 commit 7e6b4b2
Showing 1 changed file with 4 additions and 9 deletions.
test/test_cost.py: 13 changes (4 additions, 9 deletions)
@@ -5897,7 +5897,7 @@ def _create_seq_mock_data_ppo(
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
     @pytest.mark.parametrize("functional", [True, False])
-    @pytest.mark.parametrize("reduction", [None, "mean", "sum"])
+    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_ppo(
         self,
         loss_class,
@@ -6644,7 +6644,7 @@ def _create_seq_mock_data_a2c(
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
     @pytest.mark.parametrize("functional", (True, False))
-    @pytest.mark.parametrize("reduction", [None, "mean", "sum"])
+    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reduction):
         torch.manual_seed(self.seed)
         td = self._create_seq_mock_data_a2c(device=device)
@@ -7102,9 +7102,8 @@ class TestReinforce(LossModuleTestBase):
     @pytest.mark.parametrize(
         "delay_value,functional", [[False, True], [False, False], [True, True]]
     )
-    @pytest.mark.parametrize("reduction", [None, "mean", "sum"])
     def test_reinforce_value_net(
-        self, advantage, gradient_mode, delay_value, td_est, functional, reduction
+        self, advantage, gradient_mode, delay_value, td_est, functional
     ):
         n_obs = 3
         n_act = 5
@@ -7152,7 +7151,6 @@ def test_reinforce_value_net(
             critic_network=value_net,
             delay_value=delay_value,
             functional=functional,
-            reduction=reduction,
         )
 
         td = TensorDict(
@@ -7184,9 +7182,6 @@ def test_reinforce_value_net(
         elif td_est is not None:
             loss_fn.make_value_estimator(td_est)
         loss_td = loss_fn(td)
-        if reduction == "none":
-            assert loss_td.batch_size == td.batch_size
-            loss_td = loss_td.apply(lambda x: x.float().mean(), batch_size=[])
         autograd.grad(
             loss_td.get("loss_actor"),
             actor_net.parameters(),
@@ -7334,7 +7329,7 @@ def _create_mock_common_layer_setup(
         return actor, critic, common, td
 
     @pytest.mark.parametrize("separate_losses", [False, True])
-    @pytest.mark.parametrize("reduction", [None, "mean", "sum"])
+    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction):
         torch.manual_seed(self.seed)
         actor, critic, common, td = self._create_mock_common_layer_setup()
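For context, the tests that keep the "none" reduction mode (such as test_ppo and test_a2c above) handle it with the same pattern that was removed from test_reinforce_value_net. The sketch below is illustrative only, not the exact test body; it assumes `loss_fn` is a loss module built with `reduction=reduction` and `td` is the test's input TensorDict, as in the surrounding tests.

# Illustrative sketch only, mirroring the pattern in this file's other loss tests.
# Assumes `loss_fn` was constructed with `reduction=reduction` and `td` is the
# input TensorDict produced by the test's data helper.
loss_td = loss_fn(td)
if reduction == "none":
    # With reduction="none", each loss entry keeps the input batch dimensions.
    assert loss_td.batch_size == td.batch_size
    # Collapse the per-sample losses to scalars before taking gradients.
    loss_td = loss_td.apply(lambda x: x.float().mean(), batch_size=[])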
