Changed Replay Trajectory to add options to record rewards as well as specify reward mode (#177)

* Added reward mode argument and record rewards argument to replay trajectory

* Fixed an error

---------

Co-authored-by: Arnav G <[email protected]>
Co-authored-by: Stone Tao <[email protected]>
3 people authored Dec 11, 2023
1 parent 030e5ab commit 4daf06a
Showing 2 changed files with 18 additions and 0 deletions.
11 changes: 11 additions & 0 deletions mani_skill2/trajectory/replay_trajectory.py
@@ -327,6 +327,15 @@ def parse_args(args=None):
default=None,
help="number of demonstrations to replay before exiting. By default will replay all demonstrations",
)

parser.add_argument(
"--reward-mode", type=str, help="specifies the reward type that the env should use", default="normalized_dense"
)

parser.add_argument(
"--record-rewards", action="store_true", help="whether the replayed trajectory should include rewards"
)

return parser.parse_args(args)
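With these flags, a replay can both use a specific reward function and persist per-step rewards into the saved trajectory. Assuming the script's existing --traj-path flag, an invocation might look like this (the trajectory path is a placeholder, and the set of valid reward modes, such as sparse, dense, and normalized_dense, depends on the environment):

python -m mani_skill2.trajectory.replay_trajectory --traj-path path/to/trajectory.h5 --reward-mode normalized_dense --record-rewards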


@@ -359,6 +368,7 @@ def _main(args, proc_id: int = 0, num_procs=1, pbar=None):
env_kwargs["obs_mode"] = target_obs_mode
if target_control_mode is not None:
env_kwargs["control_mode"] = target_control_mode
env_kwargs["reward_mode"] = args.reward_mode
env_kwargs[
"render_mode"
] = "rgb_array" # note this only affects the videos saved as RecordEpisode wrapper calls env.render
@@ -385,6 +395,7 @@
save_trajectory=args.save_traj,
trajectory_name=new_traj_name,
save_video=args.save_video,
record_reward=args.record_rewards,
)

if env.save_trajectory:
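Together, the two hunks above thread the new options from the CLI into the environment and the recorder. A condensed sketch of the resulting flow in _main (reconstructed from the diff; gym.make, env_id, and output_dir are illustrative assumptions, not shown in this commit):

# Reconstructed sketch, not a verbatim excerpt from _main.
env_kwargs["reward_mode"] = args.reward_mode  # new: forwarded to the env
env = gym.make(env_id, **env_kwargs)  # assumed creation call
env = RecordEpisode(
    env,
    output_dir,  # assumed positional argument
    save_trajectory=args.save_traj,
    trajectory_name=new_traj_name,
    save_video=args.save_video,
    record_reward=args.record_rewards,  # new: rewards go into the output .h5
)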
7 changes: 7 additions & 0 deletions mani_skill2/utils/wrappers/record.py
@@ -101,6 +101,7 @@ def __init__(
info_on_video=False,
save_on_reset=True,
clean_on_close=True,
record_reward=False,
video_fps=20,
):
super().__init__(env)
@@ -117,6 +118,7 @@ def __init__(

self.save_trajectory = save_trajectory
self.clean_on_close = clean_on_close
self.record_reward = record_reward
if self.save_trajectory:
if not trajectory_name:
trajectory_name = time.strftime("%Y%m%d_%H%M%S")
@@ -313,6 +315,11 @@ def flush_trajectory(self, verbose=False, ignore_empty_transition=False):
# Dump
group.create_dataset("actions", data=actions, dtype=np.float32)
group.create_dataset("success", data=dones, dtype=bool)

if self.record_reward:
rewards = np.stack([x["r"] for x in self._episode_data]).astype(np.float32)
group.create_dataset("rewards",data=rewards, dtype=np.float32)

if self.init_state_only:
group.create_dataset("env_init_state", data=env_states[0], dtype=np.float32)
else:
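On the consumer side, the new rewards dataset sits alongside actions and success in each trajectory group of the output HDF5 file. A minimal readback sketch (the file path and the traj_0 group name are illustrative; only the dataset names come from the diff):

import h5py
import numpy as np

with h5py.File("path/to/trajectory.h5", "r") as f:
    ep = f["traj_0"]  # one episode group; name is illustrative
    actions = np.asarray(ep["actions"])  # float32, one row per step
    success = np.asarray(ep["success"])  # bool flag per step
    if "rewards" in ep:  # present only when --record-rewards was set
        rewards = np.asarray(ep["rewards"])  # float32 per-step rewards
        print("episode return:", rewards.sum())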
