diff --git a/.gitignore b/.gitignore index 4a98f29..9a5d811 100644 --- a/.gitignore +++ b/.gitignore @@ -18,10 +18,10 @@ dist/ out*/ *.stl *.txt +*.mp4 -videos/ environments/ screenshots/ scratch/ -videos/ assets/ +wandb/ diff --git a/README.md b/README.md index 99a3caa..d85e46a 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,13 @@ Minimal training and inference code for making a humanoid robot stand up. - [ ] Implement simple PPO policy to try to make the robot stand up - [ ] Parallelize using JAX +# Findings +- Low standard deviation for "overfitting test" does not work very well for PPO because need to converge upon actual solution. Cannot get to actual solution if do not explore a little. With that in mind, I think the model is generally getting too comfortable tricking the reward system by just losing as fast as posisble so doesn't have to deal with penalty +- Theory that model just trying to lose (to retain `is_healthy` while not losing as much height. It wants to fall as quick as possible so can reset) can be tested by adding "wait" and removing the mask. This effectively reduces the fact that reset will work. While the model is stuck in the failed state, it still is unhealthy and therefore loses out on reward. + +# Currently tests: +- Hidden layer size of 256 shows progress (loss is based on state.q[2]) + ## Goals - The goal for this repository is to provide a super minimal implementation of a PPO policy for making a humanoid robot stand up, with only three files: diff --git a/environment.py b/environment.py index a74df81..a75d5d6 100644 --- a/environment.py +++ b/environment.py @@ -125,6 +125,7 @@ def step(self, env_state: State, action: jp.ndarray) -> State: """Run one timestep of the environment's dynamics and returns observations with rewards.""" state = env_state.pipeline_state next_state = self.pipeline_step(state, action) + obs = self.get_obs(state, action) reward = self.compute_reward(state, next_state, action) @@ -141,7 +142,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0) is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy) - is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0) + # is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0) ctrl_cost = -jp.sum(jp.square(action)) @@ -158,7 +159,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr # ) # jax.debug.print("is_healthy {}, height {}", is_healthy, state.q[2], ordered=True) - total_reward = 2.0 * is_healthy + 0.1 * ctrl_cost - 5.0 * is_bad + total_reward = 0.1 * ctrl_cost + 5 * state.q[2] return total_reward @@ -177,20 +178,25 @@ def is_done(self, state: MjxState) -> jp.ndarray: return done def get_obs(self, data: MjxState, action: jp.ndarray) -> jp.ndarray: - """Returns the observation of the environment to pass to actor/critic model.""" - position = data.qpos - position = position[2:] # excludes "current positions" - - # external_contact_forces are excluded - return jp.concatenate( - [ - position, - data.qvel, - data.cinert[1:].ravel(), - data.cvel[1:].ravel(), - data.qfrc_actuator, - ] - ) + obs_components = [ + data.qpos[2:], + data.qvel, + data.cinert[1:].ravel(), + data.cvel[1:].ravel(), + data.qfrc_actuator, + ] + + def clean_component(component: jp.ndarray) -> jp.ndarray: + # Check for NaNs or Infs and replace them + nan_mask = jp.isnan(component) + inf_mask = jp.isinf(component) + component = jp.where(nan_mask, 0.0, component) + component = jp.where(inf_mask, jp.where(component > 0, 1e6, -1e6), component) + return component + + cleaned_components = [clean_component(comp) for comp in obs_components] + + return jp.concatenate(cleaned_components) def run_environment_adhoc() -> None: diff --git a/train.py b/train.py index d1f9ac1..c482a4a 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,8 @@ """Trains a policy network to get a humanoid to stand up.""" import argparse +from functools import partial +import wandb import logging import os from dataclasses import dataclass, field @@ -17,6 +19,7 @@ from jax import Array from tqdm import tqdm +import wandb from environment import HumanoidEnv logger = logging.getLogger(__name__) @@ -24,26 +27,28 @@ @dataclass class Config: - lr_actor: float = field(default=3e-4, metadata={"help": "Learning rate for the actor network."}) - lr_critic: float = field(default=3e-4, metadata={"help": "Learning rate for the critic network."}) + lr_actor: float = field(default=2.5e-4, metadata={"help": "Learning rate for the actor network."}) + lr_critic: float = field(default=2.5e-4, metadata={"help": "Learning rate for the critic network."}) num_iterations: int = field(default=15000, metadata={"help": "Number of environment simulation iterations."}) - num_envs: int = field(default=16, metadata={"help": "Number of environments to run at once with vectorization."}) + num_envs: int = field(default=32, metadata={"help": "Number of environments to run at once with vectorization."}) max_steps_per_episode: int = field( - default=512 * 16, metadata={"help": "Maximum number of steps per episode (across ALL environments)."} + default=128 * 32, metadata={"help": "Maximum number of steps per episode (across ALL environments)."} ) max_steps_per_iteration: int = field( - default=1024 * 16, + default=512 * 32, metadata={ "help": "Maximum number of steps per iteration of simulating environments (across ALL environments)." }, ) - gamma: float = field(default=0.98, metadata={"help": "Discount factor for future rewards."}) - lambd: float = field(default=0.99, metadata={"help": "Lambda parameter for GAE calculation."}) - batch_size: int = field(default=64, metadata={"help": "Batch size for training updates."}) + gamma: float = field(default=0.99, metadata={"help": "Discount factor for future rewards."}) + lambd: float = field(default=0.95, metadata={"help": "Lambda parameter for GAE calculation."}) + batch_size: int = field(default=32, metadata={"help": "Batch size for training updates."}) epsilon: float = field(default=0.2, metadata={"help": "Clipping parameter for PPO."}) l2_rate: float = field(default=0.001, metadata={"help": "L2 regularization rate for the critic."}) + entropy_coeff: float = field(default=0.01, metadata={"help": "Coefficient for entropy loss."}) +# NOTE: change how initialize weights? class Actor(eqx.Module): """Actor network for PPO.""" @@ -54,10 +59,16 @@ class Actor(eqx.Module): def __init__(self, input_size: int, action_size: int, key: Array) -> None: keys = jax.random.split(key, 4) - self.linear1 = eqx.nn.Linear(input_size, 64, key=keys[0]) - self.linear2 = eqx.nn.Linear(64, 64, key=keys[1]) - self.mu_layer = eqx.nn.Linear(64, action_size, key=keys[2]) - self.log_sigma_layer = eqx.nn.Linear(64, action_size, key=keys[3]) + self.linear1 = eqx.nn.Linear(input_size, 256, key=keys[0]) + self.linear2 = eqx.nn.Linear(256, 256, key=keys[1]) + self.mu_layer = eqx.nn.Linear(256, action_size, key=keys[2]) + self.log_sigma_layer = eqx.nn.Linear(256, action_size, key=keys[3]) + + # Parameter initialization according to Trick #2 + self.linear1 = self.initialize_layer(self.linear1, np.sqrt(2), keys[0]) + self.linear2 = self.initialize_layer(self.linear2, np.sqrt(2), keys[1]) + self.mu_layer = self.initialize_layer(self.mu_layer, 0.01, keys[2]) + self.log_sigma_layer = self.initialize_layer(self.log_sigma_layer, 0.01, keys[3]) def __call__(self, x: Array) -> Tuple[Array, Array]: x = jax.nn.tanh(self.linear1(x)) @@ -66,6 +77,24 @@ def __call__(self, x: Array) -> Tuple[Array, Array]: log_sigma = self.log_sigma_layer(x) return mu, jnp.exp(log_sigma) + def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eqx.nn.Linear: + weight_shape = layer.weight.shape + + initializer = jax.nn.initializers.orthogonal() + new_weight = initializer(key, weight_shape, jnp.float32) * scale + new_bias = jnp.zeros(layer.bias.shape) if layer.bias is not None else None + + def where_weight(layer: eqx.nn.Linear) -> Array: + return layer.weight + + def where_bias(layer: eqx.nn.Linear) -> Array | None: + return layer.bias + + new_layer = eqx.tree_at(where_weight, layer, new_weight) + new_layer = eqx.tree_at(where_bias, new_layer, new_bias) + + return new_layer + class Critic(eqx.Module): """Critic network for PPO.""" @@ -76,15 +105,38 @@ class Critic(eqx.Module): def __init__(self, input_size: int, key: Array) -> None: keys = jax.random.split(key, 3) - self.linear1 = eqx.nn.Linear(input_size, 64, key=keys[0]) - self.linear2 = eqx.nn.Linear(64, 64, key=keys[1]) - self.value_layer = eqx.nn.Linear(64, 1, key=keys[2]) + self.linear1 = eqx.nn.Linear(input_size, 256, key=keys[0]) + self.linear2 = eqx.nn.Linear(256, 256, key=keys[1]) + self.value_layer = eqx.nn.Linear(256, 1, key=keys[2]) + + # Parameter initialization according to Trick #2 + self.linear1 = self.initialize_layer(self.linear1, np.sqrt(2), keys[0]) + self.linear2 = self.initialize_layer(self.linear2, np.sqrt(2), keys[1]) + self.value_layer = self.initialize_layer(self.value_layer, 1.0, keys[2]) def __call__(self, x: Array) -> Array: x = jax.nn.tanh(self.linear1(x)) x = jax.nn.tanh(self.linear2(x)) return self.value_layer(x) + def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eqx.nn.Linear: + weight_shape = layer.weight.shape + + initializer = jax.nn.initializers.orthogonal() + new_weight = initializer(key, weight_shape, jnp.float32) * scale + new_bias = jnp.zeros(layer.bias.shape) if layer.bias is not None else None + + def where_weight(layer: eqx.nn.Linear) -> Array: + return layer.weight + + def where_bias(layer: eqx.nn.Linear) -> Array | None: + return layer.bias + + new_layer = eqx.tree_at(where_weight, layer, new_weight) + new_layer = eqx.tree_at(where_bias, new_layer, new_bias) + + return new_layer + class Ppo: def __init__(self, observation_size: int, action_size: int, config: Config, key: Array) -> None: @@ -93,8 +145,21 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key: self.actor = Actor(observation_size, action_size, subkey1) self.critic = Critic(observation_size, subkey2) - self.actor_optim = optax.adam(learning_rate=config.lr_actor) - self.critic_optim = optax.adamw(learning_rate=config.lr_critic, weight_decay=config.l2_rate) + total_timesteps = config.num_iterations + + # Learning rate annealing according to Trick #4 + self.actor_schedule = optax.linear_schedule( + init_value=config.lr_actor, end_value=1e-6, transition_steps=total_timesteps + ) + self.critic_schedule = optax.linear_schedule( + init_value=config.lr_critic, end_value=1e-6, transition_steps=total_timesteps + ) + + # eps below according to Trick #3 + self.actor_optim = optax.chain(optax.adam(learning_rate=self.actor_schedule, eps=1e-5)) + self.critic_optim = optax.chain( + optax.adamw(learning_rate=self.critic_schedule, weight_decay=config.l2_rate, eps=1e-5) + ) # Initialize optimizer states self.actor_opt_state = self.actor_optim.init(eqx.filter(self.actor, eqx.is_array)) @@ -133,35 +198,38 @@ def train_step( params: Dict[str, Any], states_b: Array, actions_b: Array, - rewards_b: Array, - masks_b: Array, + returns_b: Array, + advants_b: Array, + old_log_prob_b: Array, config: Config, -) -> Tuple[Dict[str, Any], Array, Array]: +) -> Tuple[Dict[str, Any], Tuple[Array, Array]]: """Perform a single training step with PPO parameters.""" actor, critic, actor_opt_state, critic_opt_state = params.values() actor_vmap = jax.vmap(apply_actor, in_axes=(None, 0)) critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0)) - values_b = critic_vmap(critic, states_b).squeeze() - returns_b, advants_b = get_gae(rewards_b, masks_b, values_b, config) - - old_mu_b, old_std_b = actor_vmap(actor, states_b) - old_log_prob_b = actor_log_prob(old_mu_b, old_std_b, actions_b) + # Normalizing advantages *in minibatch* according to Trick #7 + advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-8) - @eqx.filter_value_and_grad + @partial(eqx.filter_value_and_grad, has_aux=True) def actor_loss_fn(actor: Actor) -> Array: """Prioritizing advantageous actions over more training.""" mu_b, std_b = actor_vmap(actor, states_b) new_log_prob_b = actor_log_prob(mu_b, std_b, actions_b) + # Calculating the ratio of new and old probabilities ratio_b = jnp.exp(new_log_prob_b - old_log_prob_b) surrogate_loss_b = ratio_b * advants_b # Clipping is done to prevent too much change if new advantages are very large clipped_loss_b = jnp.clip(ratio_b, 1.0 - config.epsilon, 1.0 + config.epsilon) * advants_b + actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b)) - return actor_loss + entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * std_b**2) + 1)) + + total_loss = actor_loss - config.entropy_coeff * entropy_loss + return total_loss @eqx.filter_value_and_grad def critic_loss_fn(critic: Critic) -> Array: @@ -170,7 +238,7 @@ def critic_loss_fn(critic: Critic) -> Array: critic_loss = jnp.mean((critic_returns_b - returns_b) ** 2) return critic_loss - # Calculating actor loss and updating actor parameters + # Calculating actor loss and updating actor parameters --- outputting auxillary data for logging actor_loss, actor_grads = actor_loss_fn(actor) actor_updates, new_actor_opt_state = actor_optim.update(actor_grads, actor_opt_state, params=actor) new_actor = eqx.apply_updates(actor, actor_updates) @@ -187,7 +255,7 @@ def critic_loss_fn(critic: Critic) -> Array: "critic_opt_state": new_critic_opt_state, } - return new_params, actor_loss, critic_loss + return new_params, (actor_loss, critic_loss) def get_gae(rewards: Array, masks: Array, values: Array, config: Config) -> Tuple[Array, Array]: @@ -205,25 +273,35 @@ def gae_step(carry: Tuple[Array, Array], inp: Tuple[Array, Array, Array]) -> Tup _, advantages = jax.lax.scan( f=gae_step, init=(jnp.zeros_like(rewards[-1]), values[-1]), - xs=(rewards[::-1], masks[::-1], values[::-1]), + xs=(rewards, masks, values), # NOTE: correct direction? reverse=True, ) returns = advantages + values - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) return returns, advantages def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Config) -> None: """Train the PPO model using the memory collected from the environment.""" - # NOTE: think this needs to be reimplemented for vectorization because currently, - # doesn't account that memory order is maintained - # Reorders memory according to states, actions, rewards, masks states = jnp.array([e[0] for e in memory]) actions = jnp.array([e[1] for e in memory]) rewards = jnp.array([e[2] for e in memory]) masks = jnp.array([e[3] for e in memory]) + # Calculate old log probabilities + actor_vmap = jax.vmap(apply_actor, in_axes=(None, 0)) + old_mu, old_std = actor_vmap(ppo.actor, states) + old_log_prob = actor_log_prob(old_mu, old_std, actions) + + # Calculate values for all states + critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0)) + values = critic_vmap(ppo.critic, states).squeeze() + + # NOTE: are the output shapes correct? + + # Calculate GAE and returns + returns, advantages = get_gae(rewards, masks, values, config) + n = len(states) arr = jnp.arange(n) key = jax.random.PRNGKey(0) @@ -231,8 +309,13 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con for epoch in range(1): key, subkey = jax.random.split(key) arr = jax.random.permutation(subkey, arr) + + # Calculate average advantages and returns + avg_advantages = jnp.mean(advantages) + avg_returns = jnp.mean(returns) total_actor_loss = 0.0 total_critic_loss = 0.0 + logger.info("Processing %d batches", n // config.batch_size) for i in range(n // config.batch_size): @@ -240,34 +323,50 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con batch_indices = arr[config.batch_size * i : config.batch_size * (i + 1)] states_b = states[batch_indices] actions_b = actions[batch_indices] - rewards_b = rewards[batch_indices] - masks_b = masks[batch_indices] + returns_b = returns[batch_indices] + advantages_b = advantages[batch_indices] + old_log_prob_b = old_log_prob[batch_indices] params = ppo.get_params() - new_params, actor_loss, critic_loss = train_step( + new_params, (actor_loss, critic_loss) = train_step( ppo.actor_optim, ppo.critic_optim, params, states_b, actions_b, - rewards_b, - masks_b, + returns_b, + advantages_b, + old_log_prob_b, config, ) ppo.update_params(new_params) total_actor_loss += actor_loss.mean().item() total_critic_loss += critic_loss.mean().item() + total_entropy_loss += entropy_loss.item() + total_fraction_clipped += fraction_clipped.item() mean_actor_loss = total_actor_loss / (n // config.batch_size) mean_critic_loss = total_critic_loss / (n // config.batch_size) + mean_entropy_loss = total_entropy_loss / (n // config.batch_size) + mean_fraction_clipped = total_fraction_clipped / (n // config.batch_size) logger.info(f"Mean Actor Loss: {mean_actor_loss}, Mean Critic Loss: {mean_critic_loss}") + # Log metrics to wandb + wandb.log( + { + "actor_loss": mean_actor_loss, + "critic_loss": mean_critic_loss, + "avg_advantages": avg_advantages, + "avg_returns": avg_returns, + } + ) + def actor_log_prob(mu: Array, sigma: Array, actions: Array) -> Array: """Calculate the log probability of the actions given the actor network's output.""" - return jax.scipy.stats.norm.logpdf(actions, mu, sigma).sum(axis=-1) + return jax.scipy.stats.norm.logpdf(actions, mu, sigma + 1e-8).sum(axis=-1) def actor_distribution(mu: Array, sigma: Array, rng: Array) -> Array: @@ -275,16 +374,13 @@ def actor_distribution(mu: Array, sigma: Array, rng: Array) -> Array: return jax.random.normal(rng, shape=mu.shape) * sigma + mu -def unwrap_state_vectorization(state: State, config: Config) -> State: +def unwrap_state_vectorization(state: State, envs_to_sample: int) -> State: """Unwraps one environment the vectorized rollout so that the frames in videos are correctly ordered.""" unwrapped_rollout = [] # Get all attributes of the state attributes = dir(state) - # NOTE: can change ordering of this to save runtiem if want to save more vectorized states. - # NOTE: (but anyways, the video isn't correctly ordered then) - # saves from only first vectorized state - for i in range(1): + for i in range(envs_to_sample): # Create a new state with the first element of each attribute new_state = {} for attr in attributes: @@ -330,6 +426,16 @@ def update_memory(memory: Dict[str, Array], new_data: Dict[str, Array]) -> Dict[ return jax.tree.map(lambda x, y: jnp.concatenate([x, y]), memory, new_data) +def reorder_memory(memory: Dict[str, Array], num_envs: int) -> Dict[str, Array]: + reordered_memory = { + "states": jnp.concatenate([memory["states"][i::num_envs] for i in range(num_envs)], axis=0), + "actions": jnp.concatenate([memory["actions"][i::num_envs] for i in range(num_envs)], axis=0), + "rewards": jnp.concatenate([memory["rewards"][i::num_envs] for i in range(num_envs)], axis=0), + "masks": jnp.concatenate([memory["masks"][i::num_envs] for i in range(num_envs)], axis=0), + } + return reordered_memory + + def main() -> None: logging.basicConfig( level=logging.INFO, @@ -343,11 +449,13 @@ def main() -> None: parser.add_argument("--width", type=int, default=640, help="width of the video frame") parser.add_argument("--height", type=int, default=480, help="height of the video frame") parser.add_argument("--render_every", type=int, default=2, help="render the environment every N steps") - parser.add_argument("--video_length", type=int, default=5, help="maxmimum length of video in seconds") + parser.add_argument("--video_length", type=int, default=10, help="maxmimum length of video in seconds") parser.add_argument("--save_video_every", type=int, default=100, help="save video every N iterations") + parser.add_argument("--envs_to_sample", type=int, default=4, help="number of environments to sample for video") args = parser.parse_args() config = Config() + wandb.init(project="humanoid-ppo", config=vars(config)) env = HumanoidEnv() observation_size = env.observation_size @@ -378,35 +486,47 @@ def step_fn(states: State, actions: jax.Array) -> State: for i in range(1, config.num_iterations + 1): # Initialize memory as JAX arrays - memory = { - "states": jnp.empty((0, observation_size)), - "actions": jnp.empty((0, action_size)), - "rewards": jnp.empty((0,)), - "masks": jnp.empty((0,)), - } scores = [] steps = 0 rollout: List[MjxState] = [] + rng, reset_rng = jax.random.split(rng) + states = reset_fn(reset_rng) pbar = tqdm(total=config.max_steps_per_iteration, desc=f"Steps for iteration {i}") + wait = -1 - while steps < config.max_steps_per_iteration: + while steps < config.max_steps_per_iteration // config.num_envs: episodes += config.num_envs - rng, reset_rng = jax.random.split(rng) - states = reset_fn(reset_rng) - obs = jax.device_put(states.obs) + # Normalizing observations quickens learning + norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / ( + jnp.std(states.obs, axis=1, keepdims=True) + 1e-8 + ) + obs = jax.device_put(norm_obs) score = jnp.zeros(config.num_envs) - for _ in range(config.max_steps_per_episode): + memory = { + "states": jnp.empty((0, observation_size)), + "actions": jnp.empty((0, action_size)), + "rewards": jnp.empty((0,)), + "masks": jnp.empty((0,)), + } + + for _ in range(config.max_steps_per_episode // config.num_envs): # Choosing actions choose_action_vmap = jax.vmap(choose_action, in_axes=(None, 0, 0)) rng, *action_rng = jax.random.split(rng, num=config.num_envs + 1) actions = choose_action_vmap(ppo.actor, obs, jnp.array(action_rng)) + # NOTE: disable actions when "done" --> better for current wait system? states = step_fn(states, actions) - next_obs, rewards, dones = states.obs, states.reward, states.done + + # Normalizing observations quickens learning + norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / ( + jnp.std(states.obs, axis=1, keepdims=True) + 1e-8 + ) + next_obs, rewards, dones = norm_obs, states.reward, states.done masks = (1 - dones).astype(jnp.float32) # Update memory @@ -414,6 +534,10 @@ def step_fn(states: State, actions: jax.Array) -> State: memory = update_memory(memory, new_data) score += rewards + + if jnp.any(jnp.isnan(rewards)): + print(rewards) + obs = next_obs steps += config.num_envs pbar.update(config.num_envs) @@ -424,24 +548,52 @@ def step_fn(states: State, actions: jax.Array) -> State: and i % args.save_video_every == 0 and len(rollout) < args.video_length * int(1 / env.dt) ): - unwrapped_states = unwrap_state_vectorization(states.pipeline_state, config) + unwrapped_states = unwrap_state_vectorization(states.pipeline_state, args.envs_to_sample) rollout.extend(unwrapped_states) if jnp.all(dones): - break + + # NOTE: this "waiting" penalizes systems that just want to fall quick (and reset quick), + # since prevents fast reset goes from needing 84 steps to 30 steps to get to pretty high scores. + # This isn't implemented in any other PPO though so not sure if it's a good idea. + if wait == -1: + wait = 10 + else: + wait -= 1 + if wait == 0: + rng, reset_rng = jax.random.split(rng) + states = reset_fn(reset_rng) + break # with open("log_" + args.env_name + ".txt", "a") as outfile: # outfile.write("\t" + str(episodes) + "\t" + str(jnp.mean(score)) + "\n") scores.append(jnp.mean(score)) + memory = reorder_memory(memory, config.num_envs) + + train_memory = [ + (s, a, r, m) + for s, a, r, m in zip(memory["states"], memory["actions"], memory["rewards"], memory["masks"]) + ] + train(ppo, train_memory, config) + score_avg = float(jnp.mean(jnp.array(scores))) pbar.close() logger.info("Episode %s score is %.2f", episodes, score_avg) + + wandb.log({"score": score_avg, "episode": episodes}) + + wandb.log({"score": score_avg, "episode": episodes}) # Save video for this iteration if args.save_video_every and i % args.save_video_every == 0 and rollout: + + reordered_rollout = [ + frame for i in range(args.envs_to_sample) for frame in rollout[i :: args.envs_to_sample] + ] + images = jnp.array( - env.render(rollout[:: args.render_every], camera="side", width=args.width, height=args.height) + env.render(reordered_rollout[:: args.render_every], camera="side", width=args.width, height=args.height) ) fps = int(1 / env.dt) @@ -456,12 +608,7 @@ def step_fn(states: State, actions: jax.Array) -> State: logger.info("Saving video to %s for iteration %d", video_path, i) media.write_video(video_path, images, fps=fps) - # Convert memory to the format expected by ppo.train - train_memory = [ - (s, a, r, m) for s, a, r, m in zip(memory["states"], memory["actions"], memory["rewards"], memory["masks"]) - ] - train(ppo, train_memory, config) - if __name__ == "__main__": main() + wandb.init(project="humanoid-ppo", config=vars(config))