Commit

amend
vmoens committed Feb 12, 2024
1 parent a5953cf commit 7db7d7f
Showing 4 changed files with 71 additions and 14 deletions.
78 changes: 69 additions & 9 deletions test/test_cost.py
@@ -782,6 +782,8 @@ def test_dqn_notensordict(
"action": action,
}
td = TensorDict(kwargs, []).unflatten_keys("_")
# Disable warning
SoftUpdate(dqn_loss, eps=0.5)
loss_val = dqn_loss(**kwargs)
loss_val_td = dqn_loss(td)
torch.testing.assert_close(loss_val_td.get("loss"), loss_val)
@@ -795,7 +797,7 @@ def test_distributional_dqn_tensordict_keys(self):
action_spec_type=action_spec_type, atoms=atoms
)

loss_fn = DistributionalDQNLoss(actor, gamma=gamma)
loss_fn = DistributionalDQNLoss(actor, gamma=gamma, delay_value=True)

default_keys = {
"priority": "td_error",
@@ -830,11 +832,14 @@ def test_distributional_dqn_tensordict_run(self, action_spec_type, td_est):
action_key=tensor_keys["action"],
action_value_key=tensor_keys["action_value"],
)
loss_fn = DistributionalDQNLoss(actor, gamma=0.9)
loss_fn = DistributionalDQNLoss(actor, gamma=0.9, delay_value=True)
loss_fn.set_keys(**tensor_keys)

loss_fn.make_value_estimator(td_est)

# remove warnings
SoftUpdate(loss_fn, eps=0.5)

with _check_td_steady(td):
_ = loss_fn(td)
assert loss_fn.tensor_keys.priority in td.keys()
@@ -1004,6 +1009,10 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est):
sum([item for _, item in loss.items()]).backward()
assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

if delay_value:
# remove warning
SoftUpdate(loss_fn, eps=0.5)

# Check param update effect on targets
target_value = loss_fn.target_local_value_network_params.clone()
for p in loss_fn.parameters():
@@ -1071,6 +1080,11 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9)
else contextlib.nullcontext()
), _check_td_steady(ms_td):
loss_ms = loss_fn(ms_td)

if delay_value:
# remove warning
SoftUpdate(loss_fn, eps=0.5)

assert loss_fn.tensor_keys.priority in ms_td.keys()

with torch.no_grad():
@@ -1125,7 +1139,7 @@ def test_qmix_tensordict_keys(self, td_est):
action_spec_type = "one_hot"
actor = self._create_mock_actor(action_spec_type=action_spec_type)
mixer = self._create_mock_mixer()
loss_fn = QMixerLoss(actor, mixer)
loss_fn = QMixerLoss(actor, mixer, delay_value=True)

default_keys = {
"advantage": "advantage",
@@ -1142,7 +1156,7 @@ def test_qmix_tensordict_keys(self, td_est):

self.tensordict_keys_test(loss_fn, default_keys=default_keys)

loss_fn = QMixerLoss(actor, mixer)
loss_fn = QMixerLoss(actor, mixer, delay_value=True)
key_mapping = {
"advantage": ("advantage", "advantage_2"),
"value_target": ("value_target", ("value_target", "nested")),
@@ -1158,7 +1172,7 @@ def test_qmix_tensordict_keys(self, td_est):
mixer = self._create_mock_mixer(
global_chosen_action_value_key=("some", "nested")
)
loss_fn = QMixerLoss(actor, mixer)
loss_fn = QMixerLoss(actor, mixer, delay_value=True)
key_mapping = {
"global_value": ("value", ("some", "nested")),
}
@@ -1193,9 +1207,9 @@ def test_qmix_tensordict_run(self, action_spec_type, td_est):
action_value_key=tensor_keys["action_value"],
)

loss_fn = QMixerLoss(actor, mixer, loss_function="l2")
loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=True)
loss_fn.set_keys(**tensor_keys)

SoftUpdate(loss_fn, eps=0.5)
if td_est is not None:
loss_fn.make_value_estimator(td_est)
with _check_td_steady(td):
@@ -1251,7 +1265,9 @@ def test_mixer_keys(
)
td = actor(td)

loss = QMixerLoss(actor, mixer)
loss = QMixerLoss(actor, mixer, delay_value=True)

SoftUpdate(loss, eps=0.5)

# Without setting the keys
if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"):
@@ -1265,7 +1281,10 @@ def test_mixer_keys(
else:
loss(td)

loss = QMixerLoss(actor, mixer)
loss = QMixerLoss(actor, mixer, delay_value=True)

SoftUpdate(loss, eps=0.5)

# When setting the key
loss.set_keys(global_value=mixer_global_chosen_action_value_key)
if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"):
@@ -1486,6 +1505,10 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est):
):
loss = loss_fn(td)

if delay_value:
# remove warning
SoftUpdate(loss_fn, eps=0.5)

assert all(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn.value_network_params.values(True, True)
@@ -1602,6 +1625,9 @@ def test_ddpg_separate_losses(
with pytest.warns(UserWarning, match="No target network updater has been"):
loss = loss_fn(td)

# remove warning
SoftUpdate(loss_fn, eps=0.5)

assert all(
(p.grad is None) or (p.grad == 0).all()
for p in loss_fn.value_network_params.values(True, True)
@@ -1722,6 +1748,11 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9):
else contextlib.nullcontext()
), _check_td_steady(ms_td):
loss_ms = loss_fn(ms_td)

if delay_value:
# remove warning
SoftUpdate(loss_fn, eps=0.5)

with torch.no_grad():
loss = loss_fn(td)
if n == 0:
@@ -2324,10 +2355,14 @@ def test_td3_batcher(
loss_ms = loss_fn(ms_td)
assert loss_fn.tensor_keys.priority in ms_td.keys()

if delay_qvalue or delay_actor:
SoftUpdate(loss_fn, eps=0.5)

with torch.no_grad():
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)

if n == 0:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum([item for _, item in loss.items()])
@@ -3311,6 +3346,9 @@ def test_sac_notensordict(
loss_val = loss(**kwargs)

torch.manual_seed(self.seed)

SoftUpdate(loss, eps=0.5)

loss_val_td = loss(td)

if version == 1:
@@ -3558,6 +3596,7 @@ def test_discrete_sac(
target_entropy_weight=target_entropy_weight,
target_entropy=target_entropy,
loss_function="l2",
action_space="one-hot",
**kwargs,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
@@ -3668,6 +3707,7 @@ def test_discrete_sac_state_dict(
target_entropy_weight=target_entropy_weight,
target_entropy=target_entropy,
loss_function="l2",
action_space="one-hot",
**kwargs,
)
sd = loss_fn.state_dict()
@@ -3679,6 +3719,7 @@ def test_discrete_sac_state_dict(
target_entropy_weight=target_entropy_weight,
target_entropy=target_entropy,
loss_function="l2",
action_space="one-hot",
**kwargs,
)
loss_fn2.load_state_dict(sd)
@@ -3716,6 +3757,7 @@ def test_discrete_sac_batcher(
loss_function="l2",
target_entropy_weight=target_entropy_weight,
target_entropy=target_entropy,
action_space="one-hot",
**kwargs,
)

@@ -3732,6 +3774,8 @@ def test_discrete_sac_batcher(
loss_ms = loss_fn(ms_td)
assert loss_fn.tensor_keys.priority in ms_td.keys()

SoftUpdate(loss_fn, eps=0.5)

with torch.no_grad():
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
@@ -3820,6 +3864,7 @@ def test_discrete_sac_tensordict_keys(self, td_est):
qvalue_network=qvalue,
num_actions=actor.spec["action"].space.n,
loss_function="l2",
action_space="one-hot",
)

default_keys = {
@@ -3842,6 +3887,7 @@ def test_discrete_sac_tensordict_keys(self, td_est):
qvalue_network=qvalue,
num_actions=actor.spec["action"].space.n,
loss_function="l2",
action_space="one-hot",
)

key_mapping = {
@@ -3880,6 +3926,7 @@ def test_discrete_sac_notensordict(
actor_network=actor,
qvalue_network=qvalue,
num_actions=actor.spec[action_key].space.n,
action_space="one-hot",
)
loss.set_keys(
action=action_key,
@@ -4390,6 +4437,8 @@ def test_redq_deprecated_separate_losses(self, separate_losses):
):
loss = loss_fn(td)

SoftUpdate(loss_fn, eps=0.5)

# check that losses are independent
for k in loss.keys():
if not k.startswith("loss"):
@@ -5428,6 +5477,9 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est):
loss = loss_fn(td)
assert loss_fn.tensor_keys.priority in td.keys(True)

if delay_value:
SoftUpdate(loss_fn, eps=0.5)

sum([item for key, item in loss.items() if key.startswith("loss")]).backward()
assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0

@@ -5487,6 +5539,9 @@ def test_dcql_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9)
loss_ms = loss_fn(ms_td)
assert loss_fn.tensor_keys.priority in ms_td.keys()

if delay_value:
SoftUpdate(loss_fn, eps=0.5)

with torch.no_grad():
loss = loss_fn(td)
if n == 0:
@@ -5586,6 +5641,8 @@ def test_dcql_tensordict_run(self, action_spec_type, td_est):
loss_fn = DiscreteCQLLoss(actor, loss_function="l2")
loss_fn.set_keys(**tensor_keys)

SoftUpdate(loss_fn, eps=0.5)

if td_est is not None:
loss_fn.make_value_estimator(td_est)
with _check_td_steady(td):
@@ -5610,6 +5667,9 @@ def test_dcql_notensordict(
in_keys=[observation_key],
)
loss = DiscreteCQLLoss(actor)

SoftUpdate(loss, eps=0.5)

loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key)
# define data
observation = torch.randn(n_obs)
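
Note: the additions to this test file all follow one pattern. A loss module built with delayed (target) parameters raises a UserWarning ("No target network updater has been associated") unless a target-network updater is attached to it; the tests attach a SoftUpdate to silence it. A minimal sketch of the pattern, not the exact test setup (the module shapes and keys below are illustrative assumptions):

    import torch
    from torchrl.data import OneHotDiscreteTensorSpec
    from torchrl.modules import QValueActor
    from torchrl.objectives import DQNLoss, SoftUpdate

    # Toy Q-value actor: 4 observation features, 3 discrete actions (assumed sizes).
    actor = QValueActor(
        torch.nn.Linear(4, 3),
        in_keys=["observation"],
        spec=OneHotDiscreteTensorSpec(3),
    )
    # delay_value=True makes the loss keep a delayed copy of the value network.
    loss_fn = DQNLoss(actor, delay_value=True)
    # Associating an updater before the first call prevents the warning;
    # in a real training loop one would call updater.step() after each optim step.
    updater = SoftUpdate(loss_fn, eps=0.5)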
2 changes: 1 addition & 1 deletion torchrl/objectives/a2c.py
@@ -255,7 +255,7 @@ def __init__(

if functional:
self.convert_to_functional(
actor_network, "actor_network", funs_to_decorate=["forward", "get_dist"]
actor_network, "actor_network",
)
else:
self.actor_network = actor_network
1 change: 0 additions & 1 deletion torchrl/objectives/decision_transformer.py
@@ -275,7 +275,6 @@ def __init__(
actor_network,
"actor_network",
create_target_params=False,
funs_to_decorate=["forward"],
)
self.loss_function = loss_function

4 changes: 1 addition & 3 deletions torchrl/objectives/sac.py
@@ -292,7 +292,6 @@ def __init__(
actor_network,
"actor_network",
create_target_params=self.delay_actor,
funs_to_decorate=["forward", "get_dist"],
)
if separate_losses:
# we want to make sure there are no duplicates in the params: the
@@ -980,7 +979,6 @@ def __init__(
actor_network,
"actor_network",
create_target_params=self.delay_actor,
funs_to_decorate=["forward", "get_dist"],
)
if separate_losses:
# we want to make sure there are no duplicates in the params: the
@@ -1036,7 +1034,7 @@ def __init__(
if action_space is None:
warnings.warn(
"action_space was not specified. DiscreteSACLoss will default to 'one-hot'."
"This behaviour will be deprecated soon and a space will have to be passed."
"This behaviour will be deprecated soon and a space will have to be passed. "
"Check the DiscreteSACLoss documentation to see how to pass the action space. "
)
action_space = "one-hot"
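
Note: the test changes above pair with this warning: DiscreteSACLoss falls back to a "one-hot" action space when none is given, and the tests now pass it explicitly. A sketch only; `actor` and `qvalue` stand for the probabilistic actor and Q-value network built by the test helpers (assumptions, not shown here):

    from torchrl.objectives import DiscreteSACLoss

    loss = DiscreteSACLoss(
        actor_network=actor,
        qvalue_network=qvalue,
        num_actions=actor.spec["action"].space.n,
        action_space="one-hot",  # passing the space explicitly avoids the warning
    )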
