
Commit

data collection with ppo
Kaustubh Mani committed Aug 3, 2023
1 parent 4d35945 commit 6b3768a
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion cleanrl/ppo_continuous_action.py
@@ -17,6 +17,7 @@
from comet_ml import Experiment

from src.models.risk_models import *
from src.utils import *

import hydra
import os
@@ -202,6 +203,12 @@ def train(cfg):
cum_cost, ep_cost, ep_risk_cost_int, cum_risk_cost_int, ep_risk, cum_risk = 0, 0, 0, 0, 0, 0
cost = 0
last_step = 0
episode = 0

if cfg.ppo.collect_data:
    storage_path = os.path.join(cfg.ppo.storage_path, experiment.name)
    make_dirs(storage_path, episode)

for update in range(1, num_updates + 1):
# Annealing the rate if instructed to do so.
if cfg.ppo.anneal_lr:
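
The hunk above sets up data collection: when `cfg.ppo.collect_data` is enabled, a storage directory is derived from `cfg.ppo.storage_path` and the Comet experiment name, and `make_dirs(storage_path, episode)` prepares a folder for the first episode. `make_dirs` is imported from `src.utils`, which is not part of this diff; a minimal sketch of what such a helper could look like, assuming one subdirectory per episode index, is:

# Illustrative sketch only: make_dirs lives in src.utils and is not shown in
# this commit. Assumed behavior: create one subdirectory per episode index
# under the experiment's storage path.
import os

def make_dirs(storage_path, episode):
    os.makedirs(os.path.join(storage_path, str(episode)), exist_ok=True)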
@@ -237,11 +244,16 @@
values[step] = value.flatten()
actions[step] = action
logprobs[step] = logprob

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
done = np.logical_or(terminated, truncated)
rewards[step] = torch.tensor(reward).to(device).view(-1)

info_dict = {'reward': reward, 'done': done, 'cost': cost, 'prev_action': action}
if cfg.ppo.collect_data:
    store_data(next_obs, info_dict, storage_path, episode)

next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

if not done:
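
Inside the rollout loop, each environment step is packed into `info_dict` (reward, done flag, current cost, and the action that produced the transition) and passed to `store_data(next_obs, info_dict, storage_path, episode)`. `store_data` also comes from `src.utils` and its on-disk format is not visible in this diff; one plausible sketch, consistent with the per-episode directories assumed above, is:

# Illustrative sketch only: store_data is imported from src.utils and its real
# serialization format is not shown in this commit. Assumed behavior: write
# each transition as a separate .pt file inside the current episode's folder.
import os
import torch

def store_data(next_obs, info_dict, storage_path, episode):
    ep_dir = os.path.join(storage_path, str(episode))
    step = len(os.listdir(ep_dir))  # transitions already written this episode
    torch.save({"obs": torch.as_tensor(next_obs), **info_dict},
               os.path.join(ep_dir, f"{step}.pt"))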
@@ -287,6 +299,9 @@ def train(cfg):
experiment.log_metric("charts/episodic_cost", ep_cost, global_step)
experiment.log_metric("charts/cummulative_cost", cum_cost, global_step)
last_step = global_step
episode += 1
if cfg.ppo.collect_data:
    make_dirs(storage_path, episode)

# bootstrap value if not done
with torch.no_grad():
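When an episode ends, `episode` is incremented and `make_dirs` is called again so the next episode's transitions land in their own folder. Under the layout assumed above, the collected data could later be read back along these lines (paths and file naming are hypothetical):

# Illustrative only: reload transitions written by the assumed store_data above.
import os
import torch

episode_dir = os.path.join("<storage_path>", "0")  # first episode's folder (hypothetical path)
files = sorted(os.listdir(episode_dir), key=lambda f: int(f.split(".")[0]))
transitions = [torch.load(os.path.join(episode_dir, f)) for f in files]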
