diff --git a/atari/create_dataset.py b/atari/create_dataset.py index 96660077..c10c6fe3 100644 --- a/atari/create_dataset.py +++ b/atari/create_dataset.py @@ -83,11 +83,11 @@ def create_dataset(num_buffers, num_steps, game, data_dir_prefix, trajectories_p rtg = np.zeros_like(stepwise_returns) for i in done_idxs: i = int(i) - curr_traj_returns = stepwise_returns[start_index:i+1] # includes i + curr_traj_returns = stepwise_returns[start_index:i] for j in range(i-1, start_index-1, -1): # start from i-1 - rtg_j = curr_traj_returns[j-start_index:i+1-start_index] - rtg[j] = sum(rtg_j) # includes i - start_index = i+1 + rtg_j = curr_traj_returns[j-start_index:i-start_index] + rtg[j] = sum(rtg_j) + start_index = i print('max rtg is %d' % max(rtg)) # -- create timestep dataset