diff --git a/cleanrl/ppo_rnd_envpool.py b/cleanrl/ppo_rnd_envpool.py
index 32676d08b..daa23960a 100644
--- a/cleanrl/ppo_rnd_envpool.py
+++ b/cleanrl/ppo_rnd_envpool.py
@@ -6,8 +6,9 @@
 from collections import deque
 from distutils.util import strtobool
 
-import envpool
-import gym
+# import envpool
+import gymnasium as gym
+import safety_gymnasium
 import numpy as np
 import torch
 import torch.nn as nn
@@ -16,6 +17,7 @@
 from gym.wrappers.normalize import RunningMeanStd
 from torch.distributions.categorical import Categorical
 from torch.utils.tensorboard import SummaryWriter
+from torch.distributions.normal import Normal
 
 
 def parse_args():
@@ -43,7 +45,7 @@
         help="total timesteps of the experiments")
     parser.add_argument("--learning-rate", type=float, default=1e-4,
         help="the learning rate of the optimizer")
-    parser.add_argument("--num-envs", type=int, default=128,
+    parser.add_argument("--num-envs", type=int, default=5,
         help="the number of parallel game environments")
     parser.add_argument("--num-steps", type=int, default=128,
         help="the number of steps to run in each environment per policy rollout")
@@ -83,7 +85,7 @@
         help="coefficient of intrinsic reward")
     parser.add_argument("--int-gamma", type=float, default=0.99,
         help="Intrinsic reward discount rate")
-    parser.add_argument("--num-iterations-obs-norm-init", type=int, default=50,
+    parser.add_argument("--num-iterations-obs-norm-init", type=int, default=1,
         help="number of iterations to initialize the observations normalization parameters")
 
     args = parser.parse_args()
@@ -110,13 +112,14 @@ def reset(self, **kwargs):
         return observations
 
     def step(self, action):
-        observations, rewards, dones, infos = super().step(action)
-        self.episode_returns += infos["reward"]
+        observations, rewards, terminateds, truncateds, infos = super().step(action)
+        dones = np.logical_or(terminateds, truncateds)
+        self.episode_returns += rewards
         self.episode_lengths += 1
         self.returned_episode_returns[:] = self.episode_returns
         self.returned_episode_lengths[:] = self.episode_lengths
-        self.episode_returns *= 1 - infos["terminated"]
-        self.episode_lengths *= 1 - infos["terminated"]
+        self.episode_returns *= 1 - terminateds
+        self.episode_lengths *= 1 - terminateds
         infos["r"] = self.returned_episode_returns
         infos["l"] = self.returned_episode_lengths
         return (
@@ -134,54 +137,69 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
     return layer
 
 
+def make_env(cfg, idx, capture_video, run_name, gamma):
+    def thunk():
+        if capture_video:
+            env = gym.make(cfg.env_id)#, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
+        else:
+            env = gym.make(cfg.env_id)#, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
+        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
+        env = gym.wrappers.RecordEpisodeStatistics(env)
+        if capture_video:
+            if idx == 0:
+                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
+        env = gym.wrappers.ClipAction(env)
+        env = gym.wrappers.NormalizeObservation(env)
+        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
+        env = gym.wrappers.NormalizeReward(env, gamma=gamma)
+        env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
+        return env
+
+    return thunk
+
+
 class Agent(nn.Module):
     def __init__(self, envs):
         super().__init__()
         self.network = nn.Sequential(
-            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
-            nn.ReLU(),
-            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
-            nn.ReLU(),
-            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
-            nn.ReLU(),
-            nn.Flatten(),
-            layer_init(nn.Linear(64 * 7 * 7, 256)),
-            nn.ReLU(),
-            layer_init(nn.Linear(256, 448)),
-            nn.ReLU(),
+            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
+            nn.Tanh(),
+            layer_init(nn.Linear(64, 64)),
+            nn.Tanh()
         )
-        self.extra_layer = nn.Sequential(layer_init(nn.Linear(448, 448), std=0.1), nn.ReLU())
+        self.extra_layer = nn.Sequential(layer_init(nn.Linear(64, 64), std=0.1), nn.ReLU())
         self.actor = nn.Sequential(
-            layer_init(nn.Linear(448, 448), std=0.01),
+            layer_init(nn.Linear(64, 64), std=0.01),
             nn.ReLU(),
-            layer_init(nn.Linear(448, envs.single_action_space.n), std=0.01),
+            layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),
         )
-        self.critic_ext = layer_init(nn.Linear(448, 1), std=0.01)
-        self.critic_int = layer_init(nn.Linear(448, 1), std=0.01)
+        self.critic_ext = layer_init(nn.Linear(64, 1), std=0.01)
+        self.critic_int = layer_init(nn.Linear(64, 1), std=0.01)
+        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))
 
     def get_action_and_value(self, x, action=None):
-        hidden = self.network(x / 255.0)
+        hidden = self.network(x)
         logits = self.actor(hidden)
-        probs = Categorical(logits=logits)
+        action_std = torch.exp(self.actor_logstd)
+        probs = Normal(logits, action_std)
         features = self.extra_layer(hidden)
         if action is None:
             action = probs.sample()
         return (
             action,
-            probs.log_prob(action),
-            probs.entropy(),
+            probs.log_prob(action).sum(1),
+            probs.entropy().sum(1),
             self.critic_ext(features + hidden),
             self.critic_int(features + hidden),
         )
 
     def get_value(self, x):
-        hidden = self.network(x / 255.0)
+        hidden = self.network(x)
         features = self.extra_layer(hidden)
         return self.critic_ext(features + hidden), self.critic_int(features + hidden)
 
 
 class RNDModel(nn.Module):
-    def __init__(self, input_size, output_size):
+    def __init__(self, envs, input_size, output_size):
         super().__init__()
 
         self.input_size = input_size
@@ -191,30 +209,18 @@ def __init__(self, input_size, output_size):
 
         # Prediction network
         self.predictor = nn.Sequential(
-            layer_init(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=8, stride=4)),
-            nn.LeakyReLU(),
-            layer_init(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)),
-            nn.LeakyReLU(),
-            layer_init(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)),
-            nn.LeakyReLU(),
-            nn.Flatten(),
-            layer_init(nn.Linear(feature_output, 512)),
+            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
             nn.ReLU(),
-            layer_init(nn.Linear(512, 512)),
+            layer_init(nn.Linear(64, 64)),
             nn.ReLU(),
-            layer_init(nn.Linear(512, 512)),
+            layer_init(nn.Linear(64, 64)),
         )
 
         # Target network
         self.target = nn.Sequential(
-            layer_init(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=8, stride=4)),
-            nn.LeakyReLU(),
-            layer_init(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)),
-            nn.LeakyReLU(),
-            layer_init(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)),
-            nn.LeakyReLU(),
-            nn.Flatten(),
-            layer_init(nn.Linear(feature_output, 512)),
+            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
+            nn.ReLU(),
+            layer_init(nn.Linear(64, 64)),
         )
 
         # target network is not trainable
@@ -271,23 +277,13 @@ def update(self, rews):
     device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
 
     # env setup
-    envs = envpool.make(
-        args.env_id,
-        env_type="gym",
-        num_envs=args.num_envs,
-        episodic_life=True,
-        reward_clip=True,
-        seed=args.seed,
-        repeat_action_probability=0.25,
+    envs = gym.vector.SyncVectorEnv(
+        [make_env(args, i, False, run_name, args.gamma) for i in range(args.num_envs)]
     )
-    envs.num_envs = args.num_envs
-    envs.single_action_space = envs.action_space
-    envs.single_observation_space = envs.observation_space
-    envs = RecordEpisodeStatistics(envs)
-    assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"
+    # assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"
 
     agent = Agent(envs).to(device)
-    rnd_model = RNDModel(4, envs.single_action_space.n).to(device)
+    rnd_model = RNDModel(envs, 4, 4).to(device)
     combined_parameters = list(agent.parameters()) + list(rnd_model.predictor.parameters())
     optimizer = optim.Adam(
         combined_parameters,
@@ -296,7 +292,7 @@ def update(self, rews):
     )
 
     reward_rms = RunningMeanStd()
-    obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))
+    obs_rms = RunningMeanStd(shape=(1, np.array(envs.single_observation_space.shape).prod()))
     discounted_reward = RewardForwardFilter(args.int_gamma)
 
     # ALGO Logic: Storage setup
@@ -313,16 +309,18 @@ def update(self, rews):
     # TRY NOT TO MODIFY: start the game
     global_step = 0
     start_time = time.time()
-    next_obs = torch.Tensor(envs.reset()).to(device)
+    next_obs = torch.Tensor(envs.reset()[0]).to(device)
     next_done = torch.zeros(args.num_envs).to(device)
     num_updates = args.total_timesteps // args.batch_size
 
     print("Start to initialize observation normalization parameter.....")
     next_ob = []
     for step in range(args.num_steps * args.num_iterations_obs_norm_init):
-        acs = np.random.randint(0, envs.single_action_space.n, size=(args.num_envs,))
-        s, r, d, _ = envs.step(acs)
-        next_ob += s[:, 3, :, :].reshape([-1, 1, 84, 84]).tolist()
+        acs = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
+        s, r, te, tr, _ = envs.step(acs)
+        d = np.logical_or(te, tr)
+        # print(s.shape)
+        next_ob += s.tolist()
 
         if len(next_ob) % (args.num_steps * args.num_envs) == 0:
             next_ob = np.stack(next_ob)
@@ -355,34 +353,54 @@ def update(self, rews):
             logprobs[step] = logprob
 
             # TRY NOT TO MODIFY: execute the game and log data.
-            next_obs, reward, done, info = envs.step(action.cpu().numpy())
+            next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
+            done = np.logical_or(terminated, truncated)
             rewards[step] = torch.tensor(reward).to(device).view(-1)
             next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
             rnd_next_obs = (
                 (
-                    (next_obs[:, 3, :, :].reshape(args.num_envs, 1, 84, 84) - torch.from_numpy(obs_rms.mean).to(device))
+                    (next_obs - torch.from_numpy(obs_rms.mean).to(device))
                     / torch.sqrt(torch.from_numpy(obs_rms.var).to(device))
                 ).clip(-5, 5)
             ).float()
             target_next_feature = rnd_model.target(rnd_next_obs)
             predict_next_feature = rnd_model.predictor(rnd_next_obs)
             curiosity_rewards[step] = ((target_next_feature - predict_next_feature).pow(2).sum(1) / 2).data
-            for idx, d in enumerate(done):
-                if d and info["lives"][idx] == 0:
-                    avg_returns.append(info["r"][idx])
-                    epi_ret = np.average(avg_returns)
-                    print(
-                        f"global_step={global_step}, episodic_return={info['r'][idx]}, curiosity_reward={np.mean(curiosity_rewards[step].cpu().numpy())}"
-                    )
-                    writer.add_scalar("charts/avg_episodic_return", epi_ret, global_step)
-                    writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
-                    writer.add_scalar(
-                        "charts/episode_curiosity_reward",
-                        curiosity_rewards[step][idx],
-                        global_step,
-                    )
-                    writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
-
+            # for idx, d in enumerate(done):
+            #     if d and info["lives"][idx] == 0:
+            #         avg_returns.append(info["r"][idx])
+            #         epi_ret = np.average(avg_returns)
+            #         print(
+            #             f"global_step={global_step}, episodic_return={info['r'][idx]}, curiosity_reward={np.mean(curiosity_rewards[step].cpu().numpy())}"
+            #         )
+            #         writer.add_scalar("charts/avg_episodic_return", epi_ret, global_step)
+            #         writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
+            #         writer.add_scalar(
+            #             "charts/episode_curiosity_reward",
+            #             curiosity_rewards[step][idx],
+            #             global_step,
+            #         )
+            #         writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
+
+            # Only print when at least 1 env is done
+            if "final_info" not in infos:
+                continue
+            for i, info in enumerate(infos["final_info"]):
+                # Skip the envs that are not done
+                if info is None:
+                    continue
+                avg_returns.append(info["episode"]["r"])
+                epi_ret = np.average(avg_returns)
+                print(f"Episodic Return: {info['episode']['r']}")
+                writer.add_scalar("Performance/episodic_return", info["episode"]["r"], global_step)
+                writer.add_scalar("Performance/episodic_length", info["episode"]["l"], global_step)
+                writer.add_scalar(
+                    "charts/episode_curiosity_reward",
+                    curiosity_rewards[step][i],
+                    global_step,
+                )
+                writer.add_scalar("charts/avg_episodic_return", epi_ret, global_step)
+
         curiosity_reward_per_env = np.array(
             [discounted_reward.update(reward_per_step) for reward_per_step in curiosity_rewards.cpu().data.numpy().T]
         )
@@ -428,7 +446,7 @@ def update(self, rews):
         # flatten the batch
         b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
         b_logprobs = logprobs.reshape(-1)
-        b_actions = actions.reshape(-1)
+        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
         b_ext_advantages = ext_advantages.reshape(-1)
         b_int_advantages = int_advantages.reshape(-1)
         b_ext_returns = ext_returns.reshape(-1)
@@ -437,14 +455,14 @@ def update(self, rews):
 
         b_advantages = b_int_advantages * args.int_coef + b_ext_advantages * args.ext_coef
 
-        obs_rms.update(b_obs[:, 3, :, :].reshape(-1, 1, 84, 84).cpu().numpy())
+        obs_rms.update(b_obs.cpu().numpy())
 
         # Optimizing the policy and value network
         b_inds = np.arange(args.batch_size)
         rnd_next_obs = (
             (
-                (b_obs[:, 3, :, :].reshape(-1, 1, 84, 84) - torch.from_numpy(obs_rms.mean).to(device))
+                (b_obs - torch.from_numpy(obs_rms.mean).to(device))
                 / torch.sqrt(torch.from_numpy(obs_rms.var).to(device))
             ).clip(-5, 5)
         ).float()
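
For reference, the Categorical-to-Normal change in Agent.get_action_and_value reduces to the following minimal, self-contained sketch of a diagonal-Gaussian policy head with a state-independent log-std. The names here are illustrative only and are not part of the patch:

import torch
import torch.nn as nn
from torch.distributions.normal import Normal


class DiagGaussianActor(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh())
        self.mean_head = nn.Linear(64, act_dim)
        # one learnable log-std per action dimension, shared across states
        self.actor_logstd = nn.Parameter(torch.zeros(1, act_dim))

    def forward(self, obs, action=None):
        mean = self.mean_head(self.backbone(obs))
        dist = Normal(mean, torch.exp(self.actor_logstd))
        if action is None:
            action = dist.sample()
        # sum over action dimensions so each sample yields one scalar log-prob and entropy
        return action, dist.log_prob(action).sum(1), dist.entropy().sum(1)


actor = DiagGaussianActor(obs_dim=8, act_dim=2)
action, logprob, entropy = actor(torch.randn(5, 8))  # shapes: (5, 2), (5,), (5,)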
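
Similarly, the MLP-based RND intrinsic reward for flat observations, with the same normalize-then-clip treatment applied to rnd_next_obs above, can be sketched roughly as follows. The mean and variance arrays are stand-ins for obs_rms.mean and obs_rms.var, and the layer sizes are illustrative:

import numpy as np
import torch
import torch.nn as nn

obs_dim = 8
predictor = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
target = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, 64))
for p in target.parameters():
    p.requires_grad = False  # the target network stays fixed; only the predictor is trained

obs_mean = np.zeros((1, obs_dim))  # stand-in for obs_rms.mean
obs_var = np.ones((1, obs_dim))    # stand-in for obs_rms.var

next_obs = torch.randn(16, obs_dim)  # a batch of flat observations
rnd_next_obs = (
    ((next_obs - torch.from_numpy(obs_mean)) / torch.sqrt(torch.from_numpy(obs_var))).clip(-5, 5)
).float()
# intrinsic reward: half the squared prediction error, summed over the feature dimension
curiosity_reward = ((target(rnd_next_obs) - predictor(rnd_next_obs)).pow(2).sum(1) / 2).detach()
print(curiosity_reward.shape)  # torch.Size([16])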
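
The episodic-return logging relies on gymnasium's pre-1.0 vector-env convention of reporting finished episodes under infos["final_info"]. A rough standalone sketch of that pattern on CartPole-v1, assuming a gymnasium version where that key exists:

import gymnasium as gym
import numpy as np


def make_env():
    def thunk():
        env = gym.make("CartPole-v1")
        env = gym.wrappers.RecordEpisodeStatistics(env)
        return env

    return thunk


envs = gym.vector.SyncVectorEnv([make_env() for _ in range(4)])
obs, _ = envs.reset(seed=0)
for _ in range(500):
    actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
    obs, rewards, terminations, truncations, infos = envs.step(actions)
    if "final_info" not in infos:
        continue
    for info in infos["final_info"]:
        if info is None:
            continue  # this sub-environment did not finish on this step
        print(f"episodic return={info['episode']['r']}, length={info['episode']['l']}")
envs.close()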