[BugFix] Account for terminating data in SAC losses #2606

Merged: 1 commit, Nov 25, 2024
119 changes: 119 additions & 0 deletions test/test_cost.py
@@ -4459,6 +4459,69 @@ def test_sac_notensordict(
assert loss_actor == loss_val_td["loss_actor"]
assert loss_alpha == loss_val_td["loss_alpha"]

@pytest.mark.parametrize("action_key", ["action", "action2"])
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
@pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
def test_sac_terminating(
self, action_key, observation_key, reward_key, done_key, terminated_key, version
):
torch.manual_seed(self.seed)
td = self._create_mock_data_sac(
action_key=action_key,
observation_key=observation_key,
reward_key=reward_key,
done_key=done_key,
terminated_key=terminated_key,
)

actor = self._create_mock_actor(
observation_key=observation_key, action_key=action_key
)
qvalue = self._create_mock_qvalue(
observation_key=observation_key,
action_key=action_key,
out_keys=["state_action_value"],
)
if version == 1:
value = self._create_mock_value(observation_key=observation_key)
else:
value = None

loss = SACLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
)
loss.set_keys(
action=action_key,
reward=reward_key,
done=done_key,
terminated=terminated_key,
)

torch.manual_seed(self.seed)

SoftUpdate(loss, eps=0.5)

done = td.get(("next", done_key))
while not (done.any() and not done.all()):
done.bernoulli_(0.1)
obs_nan = td.get(("next", terminated_key))
obs_nan[done.squeeze(-1)] = float("nan")

kwargs = {
action_key: td.get(action_key),
observation_key: td.get(observation_key),
f"next_{reward_key}": td.get(("next", reward_key)),
f"next_{done_key}": done,
f"next_{terminated_key}": obs_nan,
f"next_{observation_key}": td.get(("next", observation_key)),
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
assert loss(td).isfinite().all()

def test_state_dict(self, version):
if version == 1:
pytest.skip("Test not implemented for version 1.")
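The test above works by poisoning the next-step data of terminated transitions with NaN and asserting that the resulting losses are still finite: if the loss ever pushed terminal next-step entries through a network, the NaNs would propagate into the output. A minimal, self-contained sketch of that tripwire idea (names and shapes here are illustrative, not taken from the test suite):

```python
import torch

# Illustrative sketch: poison "next" observations at done steps with NaN,
# then check that a value computed only from non-terminal rows stays finite.
batch = 8
next_obs = torch.randn(batch, 4)
done = torch.zeros(batch, 1, dtype=torch.bool)
done[::3] = True  # mark a few transitions as terminal

next_obs[done.squeeze(-1)] = float("nan")  # the tripwire

next_value = torch.zeros(batch, 1)
valid = ~done.squeeze(-1)
# A correct target computation never touches the poisoned rows.
next_value[valid] = next_obs[valid].sum(-1, keepdim=True)  # stand-in for a critic
assert torch.isfinite(next_value).all()
```

NaN is a convenient sentinel because almost every arithmetic operation propagates it, so any accidental read of terminal data surfaces immediately in the final assertion.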
@@ -5112,6 +5175,62 @@ def test_discrete_sac_notensordict(
assert loss_actor == loss_val_td["loss_actor"]
assert loss_alpha == loss_val_td["loss_alpha"]

@pytest.mark.parametrize("action_key", ["action", "action2"])
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
@pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
def test_discrete_sac_terminating(
self, action_key, observation_key, reward_key, done_key, terminated_key
):
torch.manual_seed(self.seed)
td = self._create_mock_data_sac(
action_key=action_key,
observation_key=observation_key,
reward_key=reward_key,
done_key=done_key,
terminated_key=terminated_key,
)

actor = self._create_mock_actor(
observation_key=observation_key, action_key=action_key
)
qvalue = self._create_mock_qvalue(
observation_key=observation_key,
)

loss = DiscreteSACLoss(
actor_network=actor,
qvalue_network=qvalue,
num_actions=actor.spec[action_key].space.n,
action_space="one-hot",
)
loss.set_keys(
action=action_key,
reward=reward_key,
done=done_key,
terminated=terminated_key,
)

SoftUpdate(loss, eps=0.5)

torch.manual_seed(0)
done = td.get(("next", done_key))
while not (done.any() and not done.all()):
done = done.bernoulli_(0.1)
obs_none = td.get(("next", observation_key))
obs_none[done.squeeze(-1)] = float("nan")
kwargs = {
action_key: td.get(action_key),
observation_key: td.get(observation_key),
f"next_{reward_key}": td.get(("next", reward_key)),
f"next_{done_key}": done,
f"next_{terminated_key}": td.get(("next", terminated_key)),
f"next_{observation_key}": obs_none,
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
assert loss(td).isfinite().all()

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_discrete_sac_reduction(self, reduction):
torch.manual_seed(self.seed)
51 changes: 43 additions & 8 deletions torchrl/objectives/sac.py
@@ -16,7 +16,7 @@
from tensordict import TensorDict, TensorDictBase, TensorDictParams

from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey
from tensordict.utils import expand_right, NestedKey
from torch import Tensor
from torchrl.data.tensor_specs import Composite, TensorSpec
from torchrl.data.utils import _find_action_space
@@ -711,13 +711,37 @@ def _compute_target_v2(self, tensordict) -> Tensor:
with set_exploration_type(
ExplorationType.RANDOM
), self.actor_network_params.to_module(self.actor_network):
next_tensordict = tensordict.get("next").clone(False)
next_dist = self.actor_network.get_dist(next_tensordict)
next_tensordict = tensordict.get("next").copy()
# Check done state and avoid passing these to the actor
done = next_tensordict.get(self.tensor_keys.done)
if done is not None and done.any():
next_tensordict_select = next_tensordict[~done.squeeze(-1)]
Review thread on this line:

Contributor:
    The done shape can have more dimensions than the batch shape; this line breaks in multi-agent settings.

Contributor Author (@vmoens), Nov 25, 2024:
    Then we need a test that covers this use case! Can you draft one for me?

Contributor:
    The SOTA CI picked up on this. Both SAC scripts are failing.

Contributor Author (@vmoens):
    Ah, I didn't see that (SOTA is broken because of dreamer, so I didn't check). We should have tests that are not in SOTA; SOTA is there to test that scripts run smoothly, not features. The scripts are not part of the core lib, we can arbitrarily decide to ditch them, and the rest of the lib should still work.

Contributor:
    Yeah, I have long wanted to write some tests for multi-agent data in losses and will get to it when I have time. Right now I'm just crunching on writing my thesis and satisfying BenchMARL users in my free time.

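As a rough illustration of the shape concern raised in the thread above (the layout below is a hypothetical multi-agent example, not code from this PR): when the done flag lives under a nested, per-agent key, its shape extends beyond the root tensordict's batch size, and boolean indexing with `~done.squeeze(-1)` no longer lines up with the batch dimensions.

```python
import torch
from tensordict import TensorDict

# Hypothetical multi-agent layout: 4 env steps, 3 agents per step.
batch, n_agents = 4, 3
next_td = TensorDict(
    {
        "agents": TensorDict(
            {
                "observation": torch.randn(batch, n_agents, 8),
                "done": torch.zeros(batch, n_agents, 1, dtype=torch.bool),
            },
            batch_size=[batch, n_agents],
        )
    },
    batch_size=[batch],
)

done = next_td["agents", "done"]  # shape [batch, n_agents, 1]
mask = ~done.squeeze(-1)          # shape [batch, n_agents]
# next_td has batch_size [batch], so `next_td[mask]` cannot line up with the
# two-dimensional mask: this is the breakage described in the comment above.
```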
else:
next_tensordict_select = next_tensordict
next_dist = self.actor_network.get_dist(next_tensordict_select)
next_action = next_dist.rsample()
next_tensordict.set(self.tensor_keys.action, next_action)
next_sample_log_prob = compute_log_prob(
next_dist, next_action, self.tensor_keys.log_prob
)
if next_tensordict_select is not next_tensordict:
mask = ~done.squeeze(-1)
if mask.ndim < next_action.ndim:
mask = expand_right(
mask, (*mask.shape, *next_action.shape[mask.ndim :])
)
next_action = next_action.new_zeros(mask.shape).masked_scatter_(
mask, next_action
)
mask = ~done.squeeze(-1)
if mask.ndim < next_sample_log_prob.ndim:
mask = expand_right(
mask,
(*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
)
next_sample_log_prob = next_sample_log_prob.new_zeros(
mask.shape
).masked_scatter_(mask, next_sample_log_prob)
next_tensordict.set(self.tensor_keys.action, next_action)

# get q-values
next_tensordict_expand = self._vmap_qnetworkN0(
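The hunk above runs the actor only on the non-terminal rows (next_tensordict_select) and then uses expand_right plus masked_scatter_ to write the sampled actions and log-probs back into full-batch tensors, leaving terminal rows at zero. A stripped-down sketch of that back-fill pattern (tensor names and shapes are illustrative):

```python
import torch
from tensordict.utils import expand_right

batch, action_dim = 6, 2
done = torch.zeros(batch, 1, dtype=torch.bool)
done[0] = True
mask = ~done.squeeze(-1)  # [batch], True where an action was actually sampled

# Pretend we only sampled actions for the non-terminal rows.
next_action_select = torch.randn(int(mask.sum()), action_dim)

# Broadcast the mask to the action shape, then scatter the computed rows back;
# terminal rows stay at zero and never see actor output.
mask_a = expand_right(mask, (*mask.shape, action_dim))
next_action = next_action_select.new_zeros(mask_a.shape).masked_scatter_(
    mask_a, next_action_select
)
assert next_action.shape == (batch, action_dim)
```

masked_scatter_ consumes the source values in row-major order, so with a row-wise mask each computed row lands back at its original batch index.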
@@ -1194,15 +1218,21 @@ def _compute_target(self, tensordict) -> Tensor:
with torch.no_grad():
next_tensordict = tensordict.get("next").clone(False)

done = next_tensordict.get(self.tensor_keys.done)
if done is not None and done.any():
next_tensordict_select = next_tensordict[~done.squeeze(-1)]
else:
next_tensordict_select = next_tensordict

# get probs and log probs for actions computed from "next"
with self.actor_network_params.to_module(self.actor_network):
next_dist = self.actor_network.get_dist(next_tensordict)
next_prob = next_dist.probs
next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob))
next_dist = self.actor_network.get_dist(next_tensordict_select)
next_log_prob = next_dist.logits
next_prob = next_log_prob.exp()

# get q-values for all actions
next_tensordict_expand = self._vmap_qnetworkN0(
next_tensordict, self.target_qvalue_network_params
next_tensordict_select, self.target_qvalue_network_params
)
next_action_value = next_tensordict_expand.get(
self.tensor_keys.action_value
@@ -1212,6 +1242,11 @@
next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob
# unlike in continuous SAC, we can compute the exact expectation over all discrete actions
next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
if next_tensordict_select is not next_tensordict:
mask = ~done.squeeze(-1)
next_state_value = next_state_value.new_zeros(
mask.shape
).masked_scatter_(mask, next_state_value)

tensordict.set(
("next", self.value_estimator.tensor_keys.value), next_state_value