Skip to content

Commit

Permalink
[Feature] Introduced multi-key reward sum
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Bettini <[email protected]>
  • Loading branch information
matteobettini committed Oct 2, 2023
1 parent 7afff7d commit 21708df
Showing 1 changed file with 5 additions and 19 deletions.
24 changes: 5 additions & 19 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictSequential
from tensordict.utils import _unravel_key_to_tuple
from torchrl.collectors import SyncDataCollector
from torchrl.envs import EnvBase, RewardSum, SerialEnv, TransformedEnv
from torchrl.envs import RewardSum, SerialEnv, TransformedEnv
from torchrl.envs.transforms import Compose
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name
Expand Down Expand Up @@ -208,28 +207,15 @@ def _setup_task(self):
self.group_map = self.task.group_map(test_env)
self.max_steps = self.task.max_steps(test_env)

reward_spec = test_env.output_spec["full_reward_spec"]
transforms = []
for reward_key in reward_spec.keys(True, True):
transforms.append(
RewardSum(
in_keys=[reward_key],
out_keys=[
_unravel_key_to_tuple(reward_key)[:-1] + ("episode_reward",)
],
)
)
transforms = [RewardSum()]
transform = Compose(*transforms)

def env_func_transformed() -> EnvBase:
return TransformedEnv(env_func(), transform.clone())

if test_env.batch_size == ():
self.env_func = lambda: SerialEnv(
self.config.n_envs_per_worker, env_func_transformed
self.env_func = lambda: TransformedEnv(
SerialEnv(self.config.n_envs_per_worker, env_func), transform.clone()
)
else:
self.env_func = env_func_transformed
self.env_func = lambda: TransformedEnv(env_func(), transform.clone())

self.test_env = test_env

Expand Down

0 comments on commit 21708df

Please sign in to comment.