From c8de01b300fcb56ad8788c1ae2aabde32082f767 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 12 Feb 2024 12:07:05 +0000 Subject: [PATCH] amend --- test/test_cost.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index 11dca14eb92..5149a19a44f 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -8863,6 +8863,9 @@ def test_iql_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + # Remove warnings + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) @@ -9286,6 +9289,7 @@ def test_discrete_iql( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): @@ -9408,6 +9412,7 @@ def test_discrete_iql_state_dict( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) sd = loss_fn.state_dict() loss_fn2 = DiscreteIQLLoss( @@ -9418,6 +9423,7 @@ def test_discrete_iql_state_dict( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) loss_fn2.load_state_dict(sd) @@ -9431,6 +9437,7 @@ def test_discrete_iql_separate_losses(self, separate_losses): value_network=value, loss_function="l2", separate_losses=separate_losses, + action_space="one-hot", ) with pytest.warns(UserWarning, match="No target network updater has been"): loss = loss_fn(td) @@ -9609,6 +9616,7 @@ def test_discrete_iql_batcher( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) ms = MultiStep(gamma=gamma, n_steps=n).to(device) @@ -9624,6 +9632,8 @@ def test_discrete_iql_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) @@ -9695,6 +9705,7 @@ def test_discrete_iql_tensordict_keys(self, td_est): qvalue_network=qvalue, value_network=value, loss_function="l2", + action_space="one-hot", ) default_keys = { @@ -9720,6 +9731,7 @@ def test_discrete_iql_tensordict_keys(self, td_est): qvalue_network=qvalue, value_network=value, loss_function="l2", + action_space="one-hot", ) key_mapping = { @@ -9755,7 +9767,10 @@ def test_discrete_iql_notensordict( value = self._create_mock_value(observation_key=observation_key) loss = DiscreteIQLLoss( - actor_network=actor, qvalue_network=qvalue, value_network=value + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + action_space="one-hot", ) loss.set_keys( action=action_key,