[Feature] Introduced terminated
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 2, 2023
1 parent 21708df commit 7437288
Showing 9 changed files with 79 additions and 0 deletions.
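
The change threads TorchRL's `terminated` signal through every algorithm: each loss module's key map gains a `terminated` entry next to `done`, and each `process_batch` fills in a missing `terminated` key exactly as it already does for `done` and `reward`. (In TorchRL, following the Gymnasium convention, `terminated` marks a true terminal state where value bootstrapping must stop, whereas `done` also covers time-limit truncation.)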
9 changes: 9 additions & 0 deletions benchmarl/algorithms/iddpg.py
@@ -62,6 +62,7 @@ def _get_loss(
             reward=(group, "reward"),
             priority=(group, "td_error"),
             done=(group, "done"),
+            terminated=(group, "terminated"),
         )

         loss_module.make_value_estimator(
@@ -150,13 +151,21 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         group_shape = batch.get(group).shape

         nested_done_key = ("next", group, "done")
+        nested_terminated_key = ("next", group, "terminated")
         nested_reward_key = ("next", group, "reward")

         if nested_done_key not in keys:
             batch.set(
                 nested_done_key,
                 batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
             )
+        if nested_terminated_key not in keys:
+            batch.set(
+                nested_terminated_key,
+                batch.get(("next", "terminated"))
+                .unsqueeze(-1)
+                .expand((*group_shape, 1)),
+            )

         if nested_reward_key not in keys:
             batch.set(
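The `process_batch` hunk above shows the pattern repeated in all the per-group algorithms below: when the sampled batch carries only a root-level flag, it is broadcast to a per-agent nested key. A minimal standalone sketch of that expansion, not taken from the commit (the group name `"agents"` and all sizes are illustrative assumptions):

```python
import torch
from tensordict import TensorDict

batch_size, n_agents = 4, 3
group = "agents"  # illustrative group name

batch = TensorDict(
    {
        group: TensorDict({}, batch_size=[batch_size, n_agents]),
        "next": TensorDict(
            {"terminated": torch.zeros(batch_size, 1, dtype=torch.bool)},
            batch_size=[batch_size],
        ),
    },
    batch_size=[batch_size],
)

# Root-level ("next", "terminated") has shape (4, 1); the group has shape (4, 3).
group_shape = batch.get(group).shape
batch.set(
    ("next", group, "terminated"),
    batch.get(("next", "terminated")).unsqueeze(-1).expand((*group_shape, 1)),
)
print(batch.get(("next", group, "terminated")).shape)  # torch.Size([4, 3, 1])
```

The trailing singleton dimension keeps the flag shaped like a per-agent scalar, matching the shape the expanded `(group, "done")` entry already has.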
9 changes: 9 additions & 0 deletions benchmarl/algorithms/ippo.py
@@ -79,6 +79,7 @@ def _get_loss(
             reward=(group, "reward"),
             action=(group, "action"),
             done=(group, "done"),
+            terminated=(group, "terminated"),
             advantage=(group, "advantage"),
             value_target=(group, "value_target"),
             value=(group, "state_value"),
@@ -199,13 +200,21 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         group_shape = batch.get(group).shape

         nested_done_key = ("next", group, "done")
+        nested_terminated_key = ("next", group, "terminated")
         nested_reward_key = ("next", group, "reward")

         if nested_done_key not in keys:
             batch.set(
                 nested_done_key,
                 batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
             )
+        if nested_terminated_key not in keys:
+            batch.set(
+                nested_terminated_key,
+                batch.get(("next", "terminated"))
+                .unsqueeze(-1)
+                .expand((*group_shape, 1)),
+            )

         if nested_reward_key not in keys:
             batch.set(
9 changes: 9 additions & 0 deletions benchmarl/algorithms/iql.py
@@ -59,6 +59,7 @@ def _get_loss(
             reward=(group, "reward"),
             action=(group, "action"),
             done=(group, "done"),
+            terminated=(group, "terminated"),
             action_value=(group, "action_value"),
             value=(group, "chosen_action_value"),
             priority=(group, "td_error"),
@@ -155,13 +156,21 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         group_shape = batch.get(group).shape

         nested_done_key = ("next", group, "done")
+        nested_terminated_key = ("next", group, "terminated")
         nested_reward_key = ("next", group, "reward")

         if nested_done_key not in keys:
             batch.set(
                 nested_done_key,
                 batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
             )
+        if nested_terminated_key not in keys:
+            batch.set(
+                nested_terminated_key,
+                batch.get(("next", "terminated"))
+                .unsqueeze(-1)
+                .expand((*group_shape, 1)),
+            )

         if nested_reward_key not in keys:
             batch.set(
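The IQL keys above (`action_value`, `chosen_action_value`, `td_error` as priority) match the accepted keys of TorchRL's `DQNLoss`. Below is a rough sketch of the loss-side wiring, assuming TorchRL ≥ 0.2 (where loss modules accept a `terminated` key through `set_keys`); it is an illustration, not BenchMARL code, and uses flat single-agent keys instead of the nested `(group, ...)` keys:

```python
from torch import nn
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

# Toy Q-network: 4-dim observation -> 2 discrete action values.
qnet = QValueActor(
    nn.Linear(4, 2), in_keys=["observation"], action_space="one_hot"
)
loss = DQNLoss(qnet, action_space="one_hot")

# Point the loss at the terminated flag in addition to done, mirroring
# the set_keys calls in this commit (flat keys here for brevity).
loss.set_keys(reward="reward", done="done", terminated="terminated")
```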
10 changes: 10 additions & 0 deletions benchmarl/algorithms/isac.py
@@ -93,6 +93,7 @@ def _get_loss(
                 reward=(group, "reward"),
                 priority=(group, "td_error"),
                 done=(group, "done"),
+                terminated=(group, "terminated"),
             )

         else:
@@ -116,6 +117,7 @@ def _get_loss(
                 reward=(group, "reward"),
                 priority=(group, "td_error"),
                 done=(group, "done"),
+                terminated=(group, "terminated"),
             )

         loss_module.make_value_estimator(
@@ -232,13 +234,21 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         group_shape = batch.get(group).shape

         nested_done_key = ("next", group, "done")
+        nested_terminated_key = ("next", group, "terminated")
         nested_reward_key = ("next", group, "reward")

         if nested_done_key not in keys:
             batch.set(
                 nested_done_key,
                 batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
             )
+        if nested_terminated_key not in keys:
+            batch.set(
+                nested_terminated_key,
+                batch.get(("next", "terminated"))
+                .unsqueeze(-1)
+                .expand((*group_shape, 1)),
+            )

         if nested_reward_key not in keys:
             batch.set(
9 changes: 9 additions & 0 deletions benchmarl/algorithms/maddpg.py
@@ -62,6 +62,7 @@ def _get_loss(
             reward=(group, "reward"),
             priority=(group, "td_error"),
             done=(group, "done"),
+            terminated=(group, "terminated"),
         )

         loss_module.make_value_estimator(
@@ -152,13 +153,21 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         group_shape = batch.get(group).shape

         nested_done_key = ("next", group, "done")
+        nested_terminated_key = ("next", group, "terminated")
         nested_reward_key = ("next", group, "reward")

         if nested_done_key not in keys:
             batch.set(
                 nested_done_key,
                 batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
             )
+        if nested_terminated_key not in keys:
+            batch.set(
+                nested_terminated_key,
+                batch.get(("next", "terminated"))
+                .unsqueeze(-1)
+                .expand((*group_shape, 1)),
+            )

         if nested_reward_key not in keys:
             batch.set(
9 changes: 9 additions & 0 deletions benchmarl/algorithms/mappo.py
@@ -78,6 +78,7 @@ def _get_loss(
             reward=(group, "reward"),
             action=(group, "action"),
             done=(group, "done"),
+            terminated=(group, "terminated"),
             advantage=(group, "advantage"),
             value_target=(group, "value_target"),
             value=(group, "state_value"),
@@ -199,13 +200,21 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         group_shape = batch.get(group).shape

         nested_done_key = ("next", group, "done")
+        nested_terminated_key = ("next", group, "terminated")
         nested_reward_key = ("next", group, "reward")

         if nested_done_key not in keys:
             batch.set(
                 nested_done_key,
                 batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
             )
+        if nested_terminated_key not in keys:
+            batch.set(
+                nested_terminated_key,
+                batch.get(("next", "terminated"))
+                .unsqueeze(-1)
+                .expand((*group_shape, 1)),
+            )

         if nested_reward_key not in keys:
             batch.set(
10 changes: 10 additions & 0 deletions benchmarl/algorithms/masac.py
@@ -93,6 +93,7 @@ def _get_loss(
                 reward=(group, "reward"),
                 priority=(group, "td_error"),
                 done=(group, "done"),
+                terminated=(group, "terminated"),
             )

         else:
@@ -116,6 +117,7 @@ def _get_loss(
                 reward=(group, "reward"),
                 priority=(group, "td_error"),
                 done=(group, "done"),
+                terminated=(group, "terminated"),
             )

         loss_module.make_value_estimator(
@@ -233,13 +235,21 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         group_shape = batch.get(group).shape

         nested_done_key = ("next", group, "done")
+        nested_terminated_key = ("next", group, "terminated")
         nested_reward_key = ("next", group, "reward")

         if nested_done_key not in keys:
             batch.set(
                 nested_done_key,
                 batch.get(("next", "done")).unsqueeze(-1).expand((*group_shape, 1)),
             )
+        if nested_terminated_key not in keys:
+            batch.set(
+                nested_terminated_key,
+                batch.get(("next", "terminated"))
+                .unsqueeze(-1)
+                .expand((*group_shape, 1)),
+            )

         if nested_reward_key not in keys:
             batch.set(
7 changes: 7 additions & 0 deletions benchmarl/algorithms/qmix.py
@@ -63,6 +63,7 @@ def _get_loss(
             reward="reward",
             action=(group, "action"),
             done="done",
+            terminated="terminated",
             action_value=(group, "action_value"),
             local_value=(group, "chosen_action_value"),
             global_value="chosen_action_value",
@@ -159,13 +160,19 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         keys = list(batch.keys(True, True))

         done_key = ("next", "done")
+        terminated_key = ("next", "terminated")
         reward_key = ("next", "reward")

         if done_key not in keys:
             batch.set(
                 done_key,
                 batch.get(("next", group, "done")).any(-2),
             )
+        if terminated_key not in keys:
+            batch.set(
+                terminated_key,
+                batch.get(("next", group, "terminated")).any(-2),
+            )

         if reward_key not in keys:
             batch.set(
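QMIX and VDN run the pattern in the opposite direction: they optimize a single team value, so `process_batch` derives shared flags from per-agent ones by reducing over the agent dimension with `.any(-2)`. A standalone illustration (the shapes are assumptions for the example, not BenchMARL code):

```python
import torch

# Per-agent flags: shape (batch=2, n_agents=3, 1).
per_agent_terminated = torch.tensor(
    [
        [[False], [True], [False]],   # one agent terminated -> team terminated
        [[False], [False], [False]],  # nobody terminated
    ]
)

# Reduce over the agent dimension (-2), as process_batch does with .any(-2).
global_terminated = per_agent_terminated.any(-2)
print(global_terminated)  # tensor([[ True], [False]])
```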
7 changes: 7 additions & 0 deletions benchmarl/algorithms/vdn.py
@@ -60,6 +60,7 @@ def _get_loss(
             reward="reward",
             action=(group, "action"),
             done="done",
+            terminated="terminated",
             action_value=(group, "action_value"),
             local_value=(group, "chosen_action_value"),
             global_value="chosen_action_value",
@@ -159,13 +160,19 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
         keys = list(batch.keys(True, True))

         done_key = ("next", "done")
+        terminated_key = ("next", "terminated")
         reward_key = ("next", "reward")

         if done_key not in keys:
             batch.set(
                 done_key,
                 batch.get(("next", group, "done")).any(-2),
             )
+        if terminated_key not in keys:
+            batch.set(
+                terminated_key,
+                batch.get(("next", group, "terminated")).any(-2),
+            )

         if reward_key not in keys:
             batch.set(
