Skip to content

Commit

Permalink
making changes for transfer models
Browse files Browse the repository at this point in the history
  • Loading branch information
manila95 committed Sep 11, 2023
1 parent 28a35a5 commit fd73f56
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions cleanrl/ppo_continuous_action_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def train(cfg):
torch.manual_seed(cfg.model_seed)
torch.backends.cudnn.deterministic = cfg.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and cfg.cuda else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 and cfg.cuda else "cpu")

# env setup
envs = gym.vector.SyncVectorEnv(
Expand Down Expand Up @@ -513,7 +513,17 @@ def train(cfg):
#print("button")
next_obs_risk = next_obs[:, list(range(40)) + list(range(56, 88))]
else:
next_obs_risk = next_obs
next_obs_risk = next_obs
elif "button" in cfg.risk_model_path.lower():
if "push" in cfg.env_id.lower():
#print("push")
next_obs_risk = next_obs[:, list(range(40)) + list(range(72, 88)) + list(range(40, 72))]
elif "goal" in cfg.env_id.lower():
#print("button")
next_obs_risk = next_obs[:, list(range(40)) + list(range(56, 88))]
else:
next_obs_risk = next_obs

next_risk = torch.Tensor(risk_model(next_obs_risk)).to(device)
if cfg.risk_type == "continuous":
next_risk = next_risk.unsqueeze(0)
Expand Down

0 comments on commit fd73f56

Please sign in to comment.