diff --git a/test/test_cost.py b/test/test_cost.py
index d3eea6fd152..11dca14eb92 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -782,6 +782,8 @@ def test_dqn_notensordict(
             "action": action,
         }
         td = TensorDict(kwargs, []).unflatten_keys("_")
+        # Disable warning
+        SoftUpdate(dqn_loss, eps=0.5)
         loss_val = dqn_loss(**kwargs)
         loss_val_td = dqn_loss(td)
         torch.testing.assert_close(loss_val_td.get("loss"), loss_val)
@@ -795,7 +797,7 @@ def test_distributional_dqn_tensordict_keys(self):
             action_spec_type=action_spec_type, atoms=atoms
         )

-        loss_fn = DistributionalDQNLoss(actor, gamma=gamma)
+        loss_fn = DistributionalDQNLoss(actor, gamma=gamma, delay_value=True)

         default_keys = {
             "priority": "td_error",
@@ -830,11 +832,14 @@ def test_distributional_dqn_tensordict_run(self, action_spec_type, td_est):
             action_key=tensor_keys["action"],
             action_value_key=tensor_keys["action_value"],
         )
-        loss_fn = DistributionalDQNLoss(actor, gamma=0.9)
+        loss_fn = DistributionalDQNLoss(actor, gamma=0.9, delay_value=True)
         loss_fn.set_keys(**tensor_keys)

         loss_fn.make_value_estimator(td_est)
+        # remove warnings
+        SoftUpdate(loss_fn, eps=0.5)
+
         with _check_td_steady(td):
             _ = loss_fn(td)
         assert loss_fn.tensor_keys.priority in td.keys()

@@ -1004,6 +1009,10 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est):
         sum([item for _, item in loss.items()]).backward()
         assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

+        if delay_value:
+            # remove warning
+            SoftUpdate(loss_fn, eps=0.5)
+
         # Check param update effect on targets
         target_value = loss_fn.target_local_value_network_params.clone()
         for p in loss_fn.parameters():
@@ -1071,6 +1080,11 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9)
             else contextlib.nullcontext()
         ), _check_td_steady(ms_td):
             loss_ms = loss_fn(ms_td)
+
+        if delay_value:
+            # remove warning
+            SoftUpdate(loss_fn, eps=0.5)
+
         assert loss_fn.tensor_keys.priority in ms_td.keys()

         with torch.no_grad():
@@ -1125,7 +1139,7 @@ def test_qmix_tensordict_keys(self, td_est):
         action_spec_type = "one_hot"
         actor = self._create_mock_actor(action_spec_type=action_spec_type)
         mixer = self._create_mock_mixer()
-        loss_fn = QMixerLoss(actor, mixer)
+        loss_fn = QMixerLoss(actor, mixer, delay_value=True)

         default_keys = {
             "advantage": "advantage",
@@ -1142,7 +1156,7 @@

         self.tensordict_keys_test(loss_fn, default_keys=default_keys)

-        loss_fn = QMixerLoss(actor, mixer)
+        loss_fn = QMixerLoss(actor, mixer, delay_value=True)
         key_mapping = {
             "advantage": ("advantage", "advantage_2"),
             "value_target": ("value_target", ("value_target", "nested")),
@@ -1158,7 +1172,7 @@
         mixer = self._create_mock_mixer(
             global_chosen_action_value_key=("some", "nested")
         )
-        loss_fn = QMixerLoss(actor, mixer)
+        loss_fn = QMixerLoss(actor, mixer, delay_value=True)
         key_mapping = {
             "global_value": ("value", ("some", "nested")),
         }
@@ -1193,9 +1207,9 @@ def test_qmix_tensordict_run(self, action_spec_type, td_est):
             action_value_key=tensor_keys["action_value"],
         )

-        loss_fn = QMixerLoss(actor, mixer, loss_function="l2")
+        loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=True)
         loss_fn.set_keys(**tensor_keys)
-
+        SoftUpdate(loss_fn, eps=0.5)
         if td_est is not None:
             loss_fn.make_value_estimator(td_est)
         with _check_td_steady(td):
@@ -1251,7 +1265,9 @@ def test_mixer_keys(
         )
         td = actor(td)

-        loss = QMixerLoss(actor, mixer)
+        loss = QMixerLoss(actor, mixer, delay_value=True)
+
+        SoftUpdate(loss, eps=0.5)

         # Without setting the keys
         if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"):
@@ -1265,7 +1281,10 @@
         else:
             loss(td)

-        loss = QMixerLoss(actor, mixer)
+        loss = QMixerLoss(actor, mixer, delay_value=True)
+
+        SoftUpdate(loss, eps=0.5)
+
         # When setting the key
         loss.set_keys(global_value=mixer_global_chosen_action_value_key)
         if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"):
@@ -1486,6 +1505,10 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est):
         ):
             loss = loss_fn(td)

+        if delay_value:
+            # remove warning
+            SoftUpdate(loss_fn, eps=0.5)
+
         assert all(
             (p.grad is None) or (p.grad == 0).all()
             for p in loss_fn.value_network_params.values(True, True)
@@ -1602,6 +1625,9 @@ def test_ddpg_separate_losses(
         with pytest.warns(UserWarning, match="No target network updater has been"):
             loss = loss_fn(td)

+        # remove warning
+        SoftUpdate(loss_fn, eps=0.5)
+
         assert all(
             (p.grad is None) or (p.grad == 0).all()
             for p in loss_fn.value_network_params.values(True, True)
@@ -1722,6 +1748,11 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9):
             else contextlib.nullcontext()
         ), _check_td_steady(ms_td):
             loss_ms = loss_fn(ms_td)
+
+        if delay_value:
+            # remove warning
+            SoftUpdate(loss_fn, eps=0.5)
+
         with torch.no_grad():
             loss = loss_fn(td)
         if n == 0:
@@ -2324,10 +2355,14 @@ def test_td3_batcher(
             loss_ms = loss_fn(ms_td)
         assert loss_fn.tensor_keys.priority in ms_td.keys()

+        if delay_qvalue or delay_actor:
+            SoftUpdate(loss_fn, eps=0.5)
+
         with torch.no_grad():
             torch.manual_seed(0)  # log-prob is computed with a random action
             np.random.seed(0)
             loss = loss_fn(td)
+
         if n == 0:
             assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
             _loss = sum([item for _, item in loss.items()])
@@ -3311,6 +3346,9 @@ def test_sac_notensordict(
         loss_val = loss(**kwargs)

         torch.manual_seed(self.seed)
+
+        SoftUpdate(loss, eps=0.5)
+
         loss_val_td = loss(td)

         if version == 1:
@@ -3558,6 +3596,7 @@ def test_discrete_sac(
             target_entropy_weight=target_entropy_weight,
             target_entropy=target_entropy,
             loss_function="l2",
+            action_space="one-hot",
             **kwargs,
         )
         if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
@@ -3668,6 +3707,7 @@ def test_discrete_sac_state_dict(
             target_entropy_weight=target_entropy_weight,
             target_entropy=target_entropy,
             loss_function="l2",
+            action_space="one-hot",
             **kwargs,
         )
         sd = loss_fn.state_dict()
@@ -3679,6 +3719,7 @@
             target_entropy_weight=target_entropy_weight,
             target_entropy=target_entropy,
             loss_function="l2",
+            action_space="one-hot",
             **kwargs,
         )
         loss_fn2.load_state_dict(sd)
@@ -3716,6 +3757,7 @@ def test_discrete_sac_batcher(
             loss_function="l2",
             target_entropy_weight=target_entropy_weight,
             target_entropy=target_entropy,
+            action_space="one-hot",
             **kwargs,
         )

@@ -3732,6 +3774,8 @@
             loss_ms = loss_fn(ms_td)
         assert loss_fn.tensor_keys.priority in ms_td.keys()

+        SoftUpdate(loss_fn, eps=0.5)
+
         with torch.no_grad():
             torch.manual_seed(0)  # log-prob is computed with a random action
             np.random.seed(0)
@@ -3820,6 +3864,7 @@ def test_discrete_sac_tensordict_keys(self, td_est):
             qvalue_network=qvalue,
             num_actions=actor.spec["action"].space.n,
             loss_function="l2",
+            action_space="one-hot",
         )

         default_keys = {
@@ -3842,6 +3887,7 @@
             qvalue_network=qvalue,
             num_actions=actor.spec["action"].space.n,
             loss_function="l2",
+            action_space="one-hot",
         )

         key_mapping = {
@@ -3880,6 +3926,7 @@ def test_discrete_sac_notensordict(
             actor_network=actor,
             qvalue_network=qvalue,
             num_actions=actor.spec[action_key].space.n,
+            action_space="one-hot",
         )
         loss.set_keys(
             action=action_key,
@@ -4390,6 +4437,8 @@ def test_redq_deprecated_separate_losses(self, separate_losses):
         ):
             loss = loss_fn(td)

+        SoftUpdate(loss_fn, eps=0.5)
+
         # check that losses are independent
         for k in loss.keys():
             if not k.startswith("loss"):
@@ -5428,6 +5477,9 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est):
             loss = loss_fn(td)
         assert loss_fn.tensor_keys.priority in td.keys(True)

+        if delay_value:
+            SoftUpdate(loss_fn, eps=0.5)
+
         sum([item for key, item in loss.items() if key.startswith("loss")]).backward()
         assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

@@ -5487,6 +5539,9 @@ def test_dcql_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9)
             loss_ms = loss_fn(ms_td)
         assert loss_fn.tensor_keys.priority in ms_td.keys()

+        if delay_value:
+            SoftUpdate(loss_fn, eps=0.5)
+
         with torch.no_grad():
             loss = loss_fn(td)
         if n == 0:
@@ -5586,6 +5641,8 @@ def test_dcql_tensordict_run(self, action_spec_type, td_est):
         loss_fn = DiscreteCQLLoss(actor, loss_function="l2")
         loss_fn.set_keys(**tensor_keys)

+        SoftUpdate(loss_fn, eps=0.5)
+
         if td_est is not None:
             loss_fn.make_value_estimator(td_est)
         with _check_td_steady(td):
@@ -5610,6 +5667,9 @@ def test_dcql_notensordict(
             in_keys=[observation_key],
         )
         loss = DiscreteCQLLoss(actor)
+
+        SoftUpdate(loss, eps=0.5)
+
         loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key)
         # define data
         observation = torch.randn(n_obs)
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index 6edcda5c800..6d291bf5986 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -255,7 +255,7 @@ def __init__(

         if functional:
             self.convert_to_functional(
-                actor_network, "actor_network", funs_to_decorate=["forward", "get_dist"]
+                actor_network, "actor_network",
             )
         else:
             self.actor_network = actor_network
diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py
index a24aa4a1271..954bd0b9a42 100644
--- a/torchrl/objectives/decision_transformer.py
+++ b/torchrl/objectives/decision_transformer.py
@@ -275,7 +275,6 @@ def __init__(
             actor_network,
             "actor_network",
             create_target_params=False,
-            funs_to_decorate=["forward"],
         )

         self.loss_function = loss_function
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 053da9e53d2..5b722fd05f3 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -292,7 +292,6 @@ def __init__(
             actor_network,
             "actor_network",
             create_target_params=self.delay_actor,
-            funs_to_decorate=["forward", "get_dist"],
         )
         if separate_losses:
             # we want to make sure there are no duplicates in the params: the
@@ -980,7 +979,6 @@ def __init__(
             actor_network,
             "actor_network",
             create_target_params=self.delay_actor,
-            funs_to_decorate=["forward", "get_dist"],
         )
         if separate_losses:
             # we want to make sure there are no duplicates in the params: the
@@ -1036,7 +1034,7 @@ def __init__(
         if action_space is None:
             warnings.warn(
                 "action_space was not specified. DiscreteSACLoss will default to 'one-hot'."
-                "This behaviour will be deprecated soon and a space will have to be passed."
+                "This behaviour will be deprecated soon and a space will have to be passed. "
                 "Check the DiscreteSACLoss documentation to see how to pass the action space. "
             )
             action_space = "one-hot"
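
Note on the recurring pattern in the test changes above: each one pairs a delayed target network (delay_value=True, or delay_qvalue/delay_actor) with the construction of a SoftUpdate over the loss module. Instantiating a target-network updater is what silences the "No target network updater has been" warning that pytest.warns matches in these tests. Below is a minimal sketch of the pattern outside the test suite; the toy network, key names, and dimensions are made up for illustration, and it assumes a torchrl version that exposes SoftUpdate under torchrl.objectives:

import torch
from tensordict.nn import TensorDictModule
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss, SoftUpdate

# Toy Q-value network: 4-dim observations, 2 discrete actions.
value_net = TensorDictModule(
    torch.nn.Linear(4, 2), in_keys=["observation"], out_keys=["action_value"]
)
actor = QValueActor(value_net, action_space="one-hot")

# delay_value=True makes the loss keep a delayed (target) copy of the
# value parameters; those are the parameters the updater maintains.
loss_fn = DQNLoss(actor, action_space="one-hot", delay_value=True)

# Associating an updater with the loss defines how the target parameters
# track the trained ones and suppresses the warning.
updater = SoftUpdate(loss_fn, eps=0.5)

# After each optimizer step: target <- eps * target + (1 - eps) * param
updater.step()

The tests use eps=0.5 only because any registered updater silences the warning; in real training one would pick eps close to 1 (e.g. 0.995) so the targets move slowly. The action_space="one-hot" arguments added to the DiscreteSACLoss constructions serve the same warning-hygiene purpose: they pre-empt the default-action-space deprecation warning whose message is touched up in the final sac.py hunk.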