From 88caa5a8667bd4edcf9b512c8152ff23a8b6bb3d Mon Sep 17 00:00:00 2001 From: Nathan Zhao Date: Mon, 5 Aug 2024 19:08:53 +0000 Subject: [PATCH 1/6] added tricks from blogpost --- README.md | 4 ++ environment.py | 4 +- train.py | 156 +++++++++++++++++++++++++++++++++++++------------ 3 files changed, 125 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 99a3caa..4ab7bec 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,10 @@ Minimal training and inference code for making a humanoid robot stand up. - [ ] Implement simple PPO policy to try to make the robot stand up - [ ] Parallelize using JAX +# Findings +- Low standard deviation for "overfitting test" does not work very well for PPO because need to converge upon actual solution. Cannot get to actual solution if do not explore a little. With that in mind, I think the model is generally getting too comfortable tricking the reward system by just losing as fast as posisble so doesn't have to deal with penalty +- Theory that model just trying to lose (to retain `is_healthy` while not losing as much height. It wants to fall as quick as possible so can reset) can be tested by adding "wait" and removing the mask. This effectively reduces the fact that reset will work. While the model is stuck in the failed state, it still is unhealthy and therefore loses out on reward. + ## Goals - The goal for this repository is to provide a super minimal implementation of a PPO policy for making a humanoid robot stand up, with only three files: diff --git a/environment.py b/environment.py index a74df81..a350d4f 100644 --- a/environment.py +++ b/environment.py @@ -141,7 +141,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0) is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy) - is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0) + # is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0) ctrl_cost = -jp.sum(jp.square(action)) @@ -158,7 +158,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr # ) # jax.debug.print("is_healthy {}, height {}", is_healthy, state.q[2], ordered=True) - total_reward = 2.0 * is_healthy + 0.1 * ctrl_cost - 5.0 * is_bad + total_reward = 0.1 * ctrl_cost + 5 * state.q[2] return total_reward diff --git a/train.py b/train.py index d1f9ac1..322b00b 100644 --- a/train.py +++ b/train.py @@ -24,26 +24,28 @@ @dataclass class Config: - lr_actor: float = field(default=3e-4, metadata={"help": "Learning rate for the actor network."}) - lr_critic: float = field(default=3e-4, metadata={"help": "Learning rate for the critic network."}) + lr_actor: float = field(default=2.5e-4, metadata={"help": "Learning rate for the actor network."}) + lr_critic: float = field(default=2.5e-4, metadata={"help": "Learning rate for the critic network."}) num_iterations: int = field(default=15000, metadata={"help": "Number of environment simulation iterations."}) num_envs: int = field(default=16, metadata={"help": "Number of environments to run at once with vectorization."}) max_steps_per_episode: int = field( - default=512 * 16, metadata={"help": "Maximum number of steps per episode (across ALL environments)."} + default=128 * 16, metadata={"help": "Maximum number of steps per episode (across ALL environments)."} ) max_steps_per_iteration: int = field( - default=1024 * 16, + default=512 * 16, metadata={ "help": "Maximum number of steps per iteration of simulating environments (across ALL environments)." 
}, ) - gamma: float = field(default=0.98, metadata={"help": "Discount factor for future rewards."}) - lambd: float = field(default=0.99, metadata={"help": "Lambda parameter for GAE calculation."}) - batch_size: int = field(default=64, metadata={"help": "Batch size for training updates."}) + gamma: float = field(default=0.99, metadata={"help": "Discount factor for future rewards."}) + lambd: float = field(default=0.95, metadata={"help": "Lambda parameter for GAE calculation."}) + batch_size: int = field(default=32, metadata={"help": "Batch size for training updates."}) epsilon: float = field(default=0.2, metadata={"help": "Clipping parameter for PPO."}) l2_rate: float = field(default=0.001, metadata={"help": "L2 regularization rate for the critic."}) + entropy_coeff: float = field(default=0.01, metadata={"help": "Coefficient for entropy loss."}) +# NOTE: change how initialize weights? class Actor(eqx.Module): """Actor network for PPO.""" @@ -59,6 +61,12 @@ def __init__(self, input_size: int, action_size: int, key: Array) -> None: self.mu_layer = eqx.nn.Linear(64, action_size, key=keys[2]) self.log_sigma_layer = eqx.nn.Linear(64, action_size, key=keys[3]) + # Parameter initialization according to Trick #2 + self.linear1 = self.initialize_layer(self.linear1, np.sqrt(2), keys[0]) + self.linear2 = self.initialize_layer(self.linear2, np.sqrt(2), keys[1]) + self.mu_layer = self.initialize_layer(self.mu_layer, 0.01, keys[2]) + self.log_sigma_layer = self.initialize_layer(self.log_sigma_layer, 0.01, keys[3]) + def __call__(self, x: Array) -> Tuple[Array, Array]: x = jax.nn.tanh(self.linear1(x)) x = jax.nn.tanh(self.linear2(x)) @@ -66,6 +74,24 @@ def __call__(self, x: Array) -> Tuple[Array, Array]: log_sigma = self.log_sigma_layer(x) return mu, jnp.exp(log_sigma) + def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eqx.nn.Linear: + weight_shape = layer.weight.shape + + initializer = jax.nn.initializers.orthogonal() + new_weight = initializer(key, weight_shape, jnp.float32) * scale + new_bias = jnp.zeros(layer.bias.shape) if layer.bias else None + + def where_weight(layer: eqx.nn.Linear) -> Array: + return layer.weight + + def where_bias(layer: eqx.nn.Linear) -> Array | None: + return layer.bias + + new_layer = eqx.tree_at(where_weight, layer, new_weight) + new_layer = eqx.tree_at(where_bias, new_layer, new_bias) + + return new_layer + class Critic(eqx.Module): """Critic network for PPO.""" @@ -80,11 +106,34 @@ def __init__(self, input_size: int, key: Array) -> None: self.linear2 = eqx.nn.Linear(64, 64, key=keys[1]) self.value_layer = eqx.nn.Linear(64, 1, key=keys[2]) + # Parameter initialization according to Trick #2 + self.linear1 = self.initialize_layer(self.linear1, np.sqrt(2), keys[0]) + self.linear2 = self.initialize_layer(self.linear2, np.sqrt(2), keys[1]) + self.value_layer = self.initialize_layer(self.value_layer, 1.0, keys[2]) + def __call__(self, x: Array) -> Array: x = jax.nn.tanh(self.linear1(x)) x = jax.nn.tanh(self.linear2(x)) return self.value_layer(x) + def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eqx.nn.Linear: + weight_shape = layer.weight.shape + + initializer = jax.nn.initializers.orthogonal() + new_weight = initializer(key, weight_shape, jnp.float32) * scale + new_bias = jnp.zeros(layer.bias.shape) if layer.bias else None + + def where_weight(layer: eqx.nn.Linear) -> Array: + return layer.weight + + def where_bias(layer: eqx.nn.Linear) -> Array | None: + return layer.bias + + new_layer = eqx.tree_at(where_weight, layer, 
new_weight) + new_layer = eqx.tree_at(where_bias, new_layer, new_bias) + + return new_layer + class Ppo: def __init__(self, observation_size: int, action_size: int, config: Config, key: Array) -> None: @@ -93,8 +142,21 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key: self.actor = Actor(observation_size, action_size, subkey1) self.critic = Critic(observation_size, subkey2) - self.actor_optim = optax.adam(learning_rate=config.lr_actor) - self.critic_optim = optax.adamw(learning_rate=config.lr_critic, weight_decay=config.l2_rate) + total_timesteps = config.num_iterations + + # Learning rate annealing according to Trick #4 + self.actor_schedule = optax.linear_schedule( + init_value=config.lr_actor, end_value=0, transition_steps=total_timesteps + ) + self.critic_schedule = optax.linear_schedule( + init_value=config.lr_critic, end_value=0, transition_steps=total_timesteps + ) + + # eps below according to Trick #3 + self.actor_optim = optax.chain(optax.adam(learning_rate=self.actor_schedule, eps=1e-5)) + self.critic_optim = optax.chain( + optax.adamw(learning_rate=self.critic_schedule, weight_decay=config.l2_rate, eps=1e-5) + ) # Initialize optimizer states self.actor_opt_state = self.actor_optim.init(eqx.filter(self.actor, eqx.is_array)) @@ -133,8 +195,9 @@ def train_step( params: Dict[str, Any], states_b: Array, actions_b: Array, - rewards_b: Array, - masks_b: Array, + returns_b: Array, + advants_b: Array, + old_log_prob_b: Array, config: Config, ) -> Tuple[Dict[str, Any], Array, Array]: """Perform a single training step with PPO parameters.""" @@ -143,11 +206,8 @@ def train_step( actor_vmap = jax.vmap(apply_actor, in_axes=(None, 0)) critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0)) - values_b = critic_vmap(critic, states_b).squeeze() - returns_b, advants_b = get_gae(rewards_b, masks_b, values_b, config) - - old_mu_b, old_std_b = actor_vmap(actor, states_b) - old_log_prob_b = actor_log_prob(old_mu_b, old_std_b, actions_b) + # Normalizing advantages *in minibatch* according to Trick #7 + advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-8) @eqx.filter_value_and_grad def actor_loss_fn(actor: Actor) -> Array: @@ -155,13 +215,18 @@ def actor_loss_fn(actor: Actor) -> Array: mu_b, std_b = actor_vmap(actor, states_b) new_log_prob_b = actor_log_prob(mu_b, std_b, actions_b) + # Calculating the ratio of new and old probabilities ratio_b = jnp.exp(new_log_prob_b - old_log_prob_b) + breakpoint() surrogate_loss_b = ratio_b * advants_b # Clipping is done to prevent too much change if new advantages are very large clipped_loss_b = jnp.clip(ratio_b, 1.0 - config.epsilon, 1.0 + config.epsilon) * advants_b actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b)) - return actor_loss + + entropy_loss = jnp.mean(jax.scipy.stats.norm.entropy(mu_b, std_b)) + total_loss = actor_loss - config.entropy_coeff * entropy_loss + return total_loss @eqx.filter_value_and_grad def critic_loss_fn(critic: Critic) -> Array: @@ -205,11 +270,10 @@ def gae_step(carry: Tuple[Array, Array], inp: Tuple[Array, Array, Array]) -> Tup _, advantages = jax.lax.scan( f=gae_step, init=(jnp.zeros_like(rewards[-1]), values[-1]), - xs=(rewards[::-1], masks[::-1], values[::-1]), + xs=(rewards, masks, values), # NOTE: correct direction? 
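        # NOTE: yes, this direction is correct: with reverse=True, lax.scan walks xs from the
        # last timestep to the first and still stacks its outputs in the original time order,
        # so passing the un-reversed arrays here gives the standard backward GAE recursion.
        # The old version reversed xs *and* set reverse=True, which amounts to scanning the
        # original arrays front-to-back.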
reverse=True, ) returns = advantages + values - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) return returns, advantages @@ -224,6 +288,18 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con rewards = jnp.array([e[2] for e in memory]) masks = jnp.array([e[3] for e in memory]) + # Calculate old log probabilities + actor_vmap = jax.vmap(apply_actor, in_axes=(None, 0)) + old_mu, old_std = actor_vmap(ppo.actor, states) + old_log_prob = actor_log_prob(old_mu, old_std, actions) + + # Calculate values for all states + critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0)) + values = critic_vmap(ppo.critic, states).squeeze() + + # Calculate GAE and returns + returns, advantages = get_gae(rewards, masks, values, config) + n = len(states) arr = jnp.arange(n) key = jax.random.PRNGKey(0) @@ -240,8 +316,9 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con batch_indices = arr[config.batch_size * i : config.batch_size * (i + 1)] states_b = states[batch_indices] actions_b = actions[batch_indices] - rewards_b = rewards[batch_indices] - masks_b = masks[batch_indices] + returns_b = returns[batch_indices] + advantages_b = advantages[batch_indices] + old_log_prob_b = old_log_prob[batch_indices] params = ppo.get_params() new_params, actor_loss, critic_loss = train_step( @@ -250,8 +327,9 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con params, states_b, actions_b, - rewards_b, - masks_b, + returns_b, + advantages_b, + old_log_prob_b, config, ) ppo.update_params(new_params) @@ -378,26 +456,27 @@ def step_fn(states: State, actions: jax.Array) -> State: for i in range(1, config.num_iterations + 1): # Initialize memory as JAX arrays - memory = { - "states": jnp.empty((0, observation_size)), - "actions": jnp.empty((0, action_size)), - "rewards": jnp.empty((0,)), - "masks": jnp.empty((0,)), - } scores = [] steps = 0 rollout: List[MjxState] = [] + rng, reset_rng = jax.random.split(rng) + states = reset_fn(reset_rng) pbar = tqdm(total=config.max_steps_per_iteration, desc=f"Steps for iteration {i}") while steps < config.max_steps_per_iteration: episodes += config.num_envs - rng, reset_rng = jax.random.split(rng) - states = reset_fn(reset_rng) obs = jax.device_put(states.obs) score = jnp.zeros(config.num_envs) + memory = { + "states": jnp.empty((0, observation_size)), + "actions": jnp.empty((0, action_size)), + "rewards": jnp.empty((0,)), + "masks": jnp.empty((0,)), + } + for _ in range(config.max_steps_per_episode): # Choosing actions @@ -428,12 +507,21 @@ def step_fn(states: State, actions: jax.Array) -> State: rollout.extend(unwrapped_states) if jnp.all(dones): + rng, reset_rng = jax.random.split(rng) + states = reset_fn(reset_rng) break # with open("log_" + args.env_name + ".txt", "a") as outfile: # outfile.write("\t" + str(episodes) + "\t" + str(jnp.mean(score)) + "\n") scores.append(jnp.mean(score)) + # Convert memory to the format expected by ppo.train + train_memory = [ + (s, a, r, m) + for s, a, r, m in zip(memory["states"], memory["actions"], memory["rewards"], memory["masks"]) + ] + train(ppo, train_memory, config) + score_avg = float(jnp.mean(jnp.array(scores))) pbar.close() logger.info("Episode %s score is %.2f", episodes, score_avg) @@ -456,12 +544,6 @@ def step_fn(states: State, actions: jax.Array) -> State: logger.info("Saving video to %s for iteration %d", video_path, i) media.write_video(video_path, images, fps=fps) - # Convert memory to the format expected by ppo.train - 
train_memory = [ - (s, a, r, m) for s, a, r, m in zip(memory["states"], memory["actions"], memory["rewards"], memory["masks"]) - ] - train(ppo, train_memory, config) - if __name__ == "__main__": main() From 1cd581720b361478fffc38f995d0011a80b43cf4 Mon Sep 17 00:00:00 2001 From: Nathan Zhao Date: Mon, 5 Aug 2024 20:07:45 +0000 Subject: [PATCH 2/6] reorder memory --- train.py | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/train.py b/train.py index 322b00b..3a1025e 100644 --- a/train.py +++ b/train.py @@ -56,10 +56,10 @@ class Actor(eqx.Module): def __init__(self, input_size: int, action_size: int, key: Array) -> None: keys = jax.random.split(key, 4) - self.linear1 = eqx.nn.Linear(input_size, 64, key=keys[0]) - self.linear2 = eqx.nn.Linear(64, 64, key=keys[1]) - self.mu_layer = eqx.nn.Linear(64, action_size, key=keys[2]) - self.log_sigma_layer = eqx.nn.Linear(64, action_size, key=keys[3]) + self.linear1 = eqx.nn.Linear(input_size, 256, key=keys[0]) + self.linear2 = eqx.nn.Linear(256, 256, key=keys[1]) + self.mu_layer = eqx.nn.Linear(256, action_size, key=keys[2]) + self.log_sigma_layer = eqx.nn.Linear(256, action_size, key=keys[3]) # Parameter initialization according to Trick #2 self.linear1 = self.initialize_layer(self.linear1, np.sqrt(2), keys[0]) @@ -79,7 +79,7 @@ def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eq initializer = jax.nn.initializers.orthogonal() new_weight = initializer(key, weight_shape, jnp.float32) * scale - new_bias = jnp.zeros(layer.bias.shape) if layer.bias else None + new_bias = jnp.zeros(layer.bias.shape) if layer.bias is not None else None def where_weight(layer: eqx.nn.Linear) -> Array: return layer.weight @@ -102,9 +102,9 @@ class Critic(eqx.Module): def __init__(self, input_size: int, key: Array) -> None: keys = jax.random.split(key, 3) - self.linear1 = eqx.nn.Linear(input_size, 64, key=keys[0]) - self.linear2 = eqx.nn.Linear(64, 64, key=keys[1]) - self.value_layer = eqx.nn.Linear(64, 1, key=keys[2]) + self.linear1 = eqx.nn.Linear(input_size, 256, key=keys[0]) + self.linear2 = eqx.nn.Linear(256, 256, key=keys[1]) + self.value_layer = eqx.nn.Linear(256, 1, key=keys[2]) # Parameter initialization according to Trick #2 self.linear1 = self.initialize_layer(self.linear1, np.sqrt(2), keys[0]) @@ -121,7 +121,7 @@ def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eq initializer = jax.nn.initializers.orthogonal() new_weight = initializer(key, weight_shape, jnp.float32) * scale - new_bias = jnp.zeros(layer.bias.shape) if layer.bias else None + new_bias = jnp.zeros(layer.bias.shape) if layer.bias is not None else None def where_weight(layer: eqx.nn.Linear) -> Array: return layer.weight @@ -217,14 +217,14 @@ def actor_loss_fn(actor: Actor) -> Array: # Calculating the ratio of new and old probabilities ratio_b = jnp.exp(new_log_prob_b - old_log_prob_b) - breakpoint() surrogate_loss_b = ratio_b * advants_b # Clipping is done to prevent too much change if new advantages are very large clipped_loss_b = jnp.clip(ratio_b, 1.0 - config.epsilon, 1.0 + config.epsilon) * advants_b - actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b)) - entropy_loss = jnp.mean(jax.scipy.stats.norm.entropy(mu_b, std_b)) + actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b)) + entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * std_b**2) + 1)) + total_loss = actor_loss - config.entropy_coeff * entropy_loss return total_loss @@ -279,8 +279,6 @@ def 
gae_step(carry: Tuple[Array, Array], inp: Tuple[Array, Array, Array]) -> Tup def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Config) -> None: """Train the PPO model using the memory collected from the environment.""" - # NOTE: think this needs to be reimplemented for vectorization because currently, - # doesn't account that memory order is maintained # Reorders memory according to states, actions, rewards, masks states = jnp.array([e[0] for e in memory]) @@ -297,6 +295,8 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0)) values = critic_vmap(ppo.critic, states).squeeze() + # NOTE: are the output shapes correct? + # Calculate GAE and returns returns, advantages = get_gae(rewards, masks, values, config) @@ -408,6 +408,16 @@ def update_memory(memory: Dict[str, Array], new_data: Dict[str, Array]) -> Dict[ return jax.tree.map(lambda x, y: jnp.concatenate([x, y]), memory, new_data) +def reorder_memory(memory, num_envs): + reordered_memory = { + "states": jnp.concatenate([memory["states"][i::num_envs] for i in range(num_envs)], axis=0), + "actions": jnp.concatenate([memory["actions"][i::num_envs] for i in range(num_envs)], axis=0), + "rewards": jnp.concatenate([memory["rewards"][i::num_envs] for i in range(num_envs)], axis=0), + "masks": jnp.concatenate([memory["masks"][i::num_envs] for i in range(num_envs)], axis=0), + } + return reordered_memory + + def main() -> None: logging.basicConfig( level=logging.INFO, @@ -516,6 +526,9 @@ def step_fn(states: State, actions: jax.Array) -> State: scores.append(jnp.mean(score)) # Convert memory to the format expected by ppo.train + + memory = reorder_memory(memory, config.num_envs) + train_memory = [ (s, a, r, m) for s, a, r, m in zip(memory["states"], memory["actions"], memory["rewards"], memory["masks"]) From 6462e27031e8496dee3985d73726a098eca444d1 Mon Sep 17 00:00:00 2001 From: Nathan Zhao Date: Mon, 5 Aug 2024 22:12:17 +0000 Subject: [PATCH 3/6] shown trying to stand up + increasing score? not sure if this is because i literally just increased size of actor/critic models. maybe should do some ablation tests. or get working fully first! (better reward func) --- .gitignore | 3 +-- README.md | 3 +++ train.py | 50 +++++++++++++++++++++++++++++++++++++++----------- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index 4a98f29..269b2df 100644 --- a/.gitignore +++ b/.gitignore @@ -18,10 +18,9 @@ dist/ out*/ *.stl *.txt +*.mp4 -videos/ environments/ screenshots/ scratch/ -videos/ assets/ diff --git a/README.md b/README.md index 4ab7bec..d85e46a 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,9 @@ Minimal training and inference code for making a humanoid robot stand up. - Low standard deviation for "overfitting test" does not work very well for PPO because need to converge upon actual solution. Cannot get to actual solution if do not explore a little. With that in mind, I think the model is generally getting too comfortable tricking the reward system by just losing as fast as posisble so doesn't have to deal with penalty - Theory that model just trying to lose (to retain `is_healthy` while not losing as much height. It wants to fall as quick as possible so can reset) can be tested by adding "wait" and removing the mask. This effectively reduces the fact that reset will work. While the model is stuck in the failed state, it still is unhealthy and therefore loses out on reward. 
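- For reference, GAE uses the mask to stop bootstrapping at a reset: `delta_t = r_t + gamma * mask_t * V(s_{t+1}) - V(s_t)` and `A_t = delta_t + gamma * lambda * mask_t * A_{t+1}`. Once `mask_t = 0`, nothing after the reset flows back into earlier steps, which is what makes falling over quickly attractive to the policy.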
+# Currently tests: +- Hidden layer size of 256 shows progress (loss is based on state.q[2]) + ## Goals - The goal for this repository is to provide a super minimal implementation of a PPO policy for making a humanoid robot stand up, with only three files: diff --git a/train.py b/train.py index 3a1025e..3a311b5 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,8 @@ """Trains a policy network to get a humanoid to stand up.""" import argparse +from functools import partial +import wandb import logging import os from dataclasses import dataclass, field @@ -27,12 +29,12 @@ class Config: lr_actor: float = field(default=2.5e-4, metadata={"help": "Learning rate for the actor network."}) lr_critic: float = field(default=2.5e-4, metadata={"help": "Learning rate for the critic network."}) num_iterations: int = field(default=15000, metadata={"help": "Number of environment simulation iterations."}) - num_envs: int = field(default=16, metadata={"help": "Number of environments to run at once with vectorization."}) + num_envs: int = field(default=2048, metadata={"help": "Number of environments to run at once with vectorization."}) max_steps_per_episode: int = field( - default=128 * 16, metadata={"help": "Maximum number of steps per episode (across ALL environments)."} + default=128 * 2048, metadata={"help": "Maximum number of steps per episode (across ALL environments)."} ) max_steps_per_iteration: int = field( - default=512 * 16, + default=512 * 2048, metadata={ "help": "Maximum number of steps per iteration of simulating environments (across ALL environments)." }, @@ -209,7 +211,7 @@ def train_step( # Normalizing advantages *in minibatch* according to Trick #7 advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-8) - @eqx.filter_value_and_grad + @partial(eqx.filter_value_and_grad, has_aux=True) def actor_loss_fn(actor: Actor) -> Array: """Prioritizing advantageous actions over more training.""" mu_b, std_b = actor_vmap(actor, states_b) @@ -224,9 +226,11 @@ def actor_loss_fn(actor: Actor) -> Array: actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b)) entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * std_b**2) + 1)) - + + fraction_clipped = jnp.mean(jnp.abs(ratio_b - 1.0) > config.epsilon) + total_loss = actor_loss - config.entropy_coeff * entropy_loss - return total_loss + return total_loss, (actor_loss, entropy_loss, fraction_clipped) @eqx.filter_value_and_grad def critic_loss_fn(critic: Critic) -> Array: @@ -236,7 +240,7 @@ def critic_loss_fn(critic: Critic) -> Array: return critic_loss # Calculating actor loss and updating actor parameters - actor_loss, actor_grads = actor_loss_fn(actor) + (_, actor_loss, entropy_loss, fraction_clipped), actor_grads = actor_loss_fn(actor) actor_updates, new_actor_opt_state = actor_optim.update(actor_grads, actor_opt_state, params=actor) new_actor = eqx.apply_updates(actor, actor_updates) @@ -252,7 +256,7 @@ def critic_loss_fn(critic: Critic) -> Array: "critic_opt_state": new_critic_opt_state, } - return new_params, actor_loss, critic_loss + return new_params, (actor_loss, critic_loss, entropy_loss, fraction_clipped) def get_gae(rewards: Array, masks: Array, values: Array, config: Config) -> Tuple[Array, Array]: @@ -307,8 +311,15 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con for epoch in range(1): key, subkey = jax.random.split(key) arr = jax.random.permutation(subkey, arr) + + # Calculate average advantages and returns + avg_advantages = jnp.mean(advantages) + avg_returns = jnp.mean(returns) 
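        # NOTE: whole-rollout means, logged to wandb below as a sanity check; once the critic
        # fits the returns well, avg_advantages should sit near zero.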
total_actor_loss = 0.0 total_critic_loss = 0.0 + total_entropy_loss = 0.0 + total_fraction_clipped = 0.0 + logger.info("Processing %d batches", n // config.batch_size) for i in range(n // config.batch_size): @@ -321,7 +332,7 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con old_log_prob_b = old_log_prob[batch_indices] params = ppo.get_params() - new_params, actor_loss, critic_loss = train_step( + new_params, (actor_loss, critic_loss, entropy_loss, fraction_clipped) = train_step( ppo.actor_optim, ppo.critic_optim, params, @@ -336,12 +347,28 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con total_actor_loss += actor_loss.mean().item() total_critic_loss += critic_loss.mean().item() + total_entropy_loss += entropy_loss.item() + total_fraction_clipped += fraction_clipped.item() mean_actor_loss = total_actor_loss / (n // config.batch_size) mean_critic_loss = total_critic_loss / (n // config.batch_size) + mean_entropy_loss = total_entropy_loss / (n // config.batch_size) + mean_fraction_clipped = total_fraction_clipped / (n // config.batch_size) logger.info(f"Mean Actor Loss: {mean_actor_loss}, Mean Critic Loss: {mean_critic_loss}") + # Log metrics to wandb + wandb.log( + { + "actor_loss": mean_actor_loss, + "critic_loss": mean_critic_loss, + "entropy_loss": mean_entropy_loss, + "fraction_clipped": mean_fraction_clipped, + "avg_advantages": avg_advantages, + "avg_returns": avg_returns, + } + ) + def actor_log_prob(mu: Array, sigma: Array, actions: Array) -> Array: """Calculate the log probability of the actions given the actor network's output.""" @@ -525,8 +552,6 @@ def step_fn(states: State, actions: jax.Array) -> State: # outfile.write("\t" + str(episodes) + "\t" + str(jnp.mean(score)) + "\n") scores.append(jnp.mean(score)) - # Convert memory to the format expected by ppo.train - memory = reorder_memory(memory, config.num_envs) train_memory = [ @@ -538,6 +563,8 @@ def step_fn(states: State, actions: jax.Array) -> State: score_avg = float(jnp.mean(jnp.array(scores))) pbar.close() logger.info("Episode %s score is %.2f", episodes, score_avg) + + wandb.log({"score": score_avg, "episode": episodes}) # Save video for this iteration if args.save_video_every and i % args.save_video_every == 0 and rollout: @@ -560,3 +587,4 @@ def step_fn(states: State, actions: jax.Array) -> State: if __name__ == "__main__": main() + wandb.init(project="humanoid-ppo", config=vars(config)) From 053ff79179d85194102a179be675d38ff210443f Mon Sep 17 00:00:00 2001 From: Nathan Zhao Date: Tue, 6 Aug 2024 18:18:53 +0000 Subject: [PATCH 4/6] NANs in training being fixed WIP Got nans in training and trying to fix by norm'ing obs and also adding epsilon to action_log_pdf. 
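The obs normalization added below is per-observation (each obs is centered and scaled by its own
mean/std across features). A possible follow-up, sketched here only as an idea and not what this
patch implements, is keeping running per-feature statistics shared across timesteps:

    import jax.numpy as jnp

    def update_stats(mean, var, count, batch):
        # parallel (batch) update of running per-feature mean/variance
        b_mean, b_var, b_n = batch.mean(axis=0), batch.var(axis=0), batch.shape[0]
        delta, total = b_mean - mean, count + b_n
        new_mean = mean + delta * b_n / total
        new_var = (var * count + b_var * b_n + delta**2 * count * b_n / total) / total
        return new_mean, new_var, total

    def normalize_obs(obs, mean, var):
        return (obs - mean) / jnp.sqrt(var + 1e-8)

normalize_obs would then stand in for the per-sample jnp.mean/jnp.std calls in the rollout loop.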
Also some cleaning --- .gitignore | 1 + environment.py | 36 ++++++++++++++++----------- train.py | 67 ++++++++++++++++++++++++++++++-------------------- 3 files changed, 63 insertions(+), 41 deletions(-) diff --git a/.gitignore b/.gitignore index 269b2df..9a5d811 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ environments/ screenshots/ scratch/ assets/ +wandb/ diff --git a/environment.py b/environment.py index a350d4f..74ede2b 100644 --- a/environment.py +++ b/environment.py @@ -125,6 +125,7 @@ def step(self, env_state: State, action: jp.ndarray) -> State: """Run one timestep of the environment's dynamics and returns observations with rewards.""" state = env_state.pipeline_state next_state = self.pipeline_step(state, action) + obs = self.get_obs(state, action) reward = self.compute_reward(state, next_state, action) @@ -158,7 +159,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr # ) # jax.debug.print("is_healthy {}, height {}", is_healthy, state.q[2], ordered=True) - total_reward = 0.1 * ctrl_cost + 5 * state.q[2] + total_reward = 0.1 * ctrl_cost + 5 * is_healthy + 3 * state.q[2] return total_reward @@ -177,20 +178,25 @@ def is_done(self, state: MjxState) -> jp.ndarray: return done def get_obs(self, data: MjxState, action: jp.ndarray) -> jp.ndarray: - """Returns the observation of the environment to pass to actor/critic model.""" - position = data.qpos - position = position[2:] # excludes "current positions" - - # external_contact_forces are excluded - return jp.concatenate( - [ - position, - data.qvel, - data.cinert[1:].ravel(), - data.cvel[1:].ravel(), - data.qfrc_actuator, - ] - ) + obs_components = [ + data.qpos[2:], + data.qvel, + data.cinert[1:].ravel(), + data.cvel[1:].ravel(), + data.qfrc_actuator, + ] + + def clean_component(component): + # Check for NaNs or Infs and replace them + nan_mask = jp.isnan(component) + inf_mask = jp.isinf(component) + component = jp.where(nan_mask, 0.0, component) + component = jp.where(inf_mask, jp.where(component > 0, 1e6, -1e6), component) + return component + + cleaned_components = [clean_component(comp) for comp in obs_components] + + return jp.concatenate(cleaned_components) def run_environment_adhoc() -> None: diff --git a/train.py b/train.py index 3a311b5..cbecc7b 100644 --- a/train.py +++ b/train.py @@ -1,11 +1,10 @@ """Trains a policy network to get a humanoid to stand up.""" import argparse -from functools import partial -import wandb import logging import os from dataclasses import dataclass, field +from functools import partial from typing import Any, Dict, List, Tuple import equinox as eqx @@ -19,6 +18,7 @@ from jax import Array from tqdm import tqdm +import wandb from environment import HumanoidEnv logger = logging.getLogger(__name__) @@ -29,12 +29,12 @@ class Config: lr_actor: float = field(default=2.5e-4, metadata={"help": "Learning rate for the actor network."}) lr_critic: float = field(default=2.5e-4, metadata={"help": "Learning rate for the critic network."}) num_iterations: int = field(default=15000, metadata={"help": "Number of environment simulation iterations."}) - num_envs: int = field(default=2048, metadata={"help": "Number of environments to run at once with vectorization."}) + num_envs: int = field(default=32, metadata={"help": "Number of environments to run at once with vectorization."}) max_steps_per_episode: int = field( - default=128 * 2048, metadata={"help": "Maximum number of steps per episode (across ALL environments)."} + default=128 * 32, metadata={"help": "Maximum 
number of steps per episode (across ALL environments)."} ) max_steps_per_iteration: int = field( - default=512 * 2048, + default=512 * 32, metadata={ "help": "Maximum number of steps per iteration of simulating environments (across ALL environments)." }, @@ -47,7 +47,6 @@ class Config: entropy_coeff: float = field(default=0.01, metadata={"help": "Coefficient for entropy loss."}) -# NOTE: change how initialize weights? class Actor(eqx.Module): """Actor network for PPO.""" @@ -148,10 +147,10 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key: # Learning rate annealing according to Trick #4 self.actor_schedule = optax.linear_schedule( - init_value=config.lr_actor, end_value=0, transition_steps=total_timesteps + init_value=config.lr_actor, end_value=1e-6, transition_steps=total_timesteps ) self.critic_schedule = optax.linear_schedule( - init_value=config.lr_critic, end_value=0, transition_steps=total_timesteps + init_value=config.lr_critic, end_value=1e-6, transition_steps=total_timesteps ) # eps below according to Trick #3 @@ -239,8 +238,8 @@ def critic_loss_fn(critic: Critic) -> Array: critic_loss = jnp.mean((critic_returns_b - returns_b) ** 2) return critic_loss - # Calculating actor loss and updating actor parameters - (_, actor_loss, entropy_loss, fraction_clipped), actor_grads = actor_loss_fn(actor) + # Calculating actor loss and updating actor parameters --- outputting auxillary data for logging + (_, (actor_loss, entropy_loss, fraction_clipped)), actor_grads = actor_loss_fn(actor) actor_updates, new_actor_opt_state = actor_optim.update(actor_grads, actor_opt_state, params=actor) new_actor = eqx.apply_updates(actor, actor_updates) @@ -283,7 +282,6 @@ def gae_step(carry: Tuple[Array, Array], inp: Tuple[Array, Array, Array]) -> Tup def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Config) -> None: """Train the PPO model using the memory collected from the environment.""" - # Reorders memory according to states, actions, rewards, masks states = jnp.array([e[0] for e in memory]) actions = jnp.array([e[1] for e in memory]) @@ -372,7 +370,7 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con def actor_log_prob(mu: Array, sigma: Array, actions: Array) -> Array: """Calculate the log probability of the actions given the actor network's output.""" - return jax.scipy.stats.norm.logpdf(actions, mu, sigma).sum(axis=-1) + return jax.scipy.stats.norm.logpdf(actions, mu, sigma + 1e-8).sum(axis=-1) def actor_distribution(mu: Array, sigma: Array, rng: Array) -> Array: @@ -380,7 +378,7 @@ def actor_distribution(mu: Array, sigma: Array, rng: Array) -> Array: return jax.random.normal(rng, shape=mu.shape) * sigma + mu -def unwrap_state_vectorization(state: State, config: Config) -> State: +def unwrap_state_vectorization(state: State, envs_to_sample: int) -> State: """Unwraps one environment the vectorized rollout so that the frames in videos are correctly ordered.""" unwrapped_rollout = [] # Get all attributes of the state @@ -389,7 +387,7 @@ def unwrap_state_vectorization(state: State, config: Config) -> State: # NOTE: can change ordering of this to save runtiem if want to save more vectorized states. 
# NOTE: (but anyways, the video isn't correctly ordered then) # saves from only first vectorized state - for i in range(1): + for i in range(envs_to_sample): # Create a new state with the first element of each attribute new_state = {} for attr in attributes: @@ -435,7 +433,7 @@ def update_memory(memory: Dict[str, Array], new_data: Dict[str, Array]) -> Dict[ return jax.tree.map(lambda x, y: jnp.concatenate([x, y]), memory, new_data) -def reorder_memory(memory, num_envs): +def reorder_memory(memory: Dict[str, Array], num_envs: int) -> Dict[str, Array]: reordered_memory = { "states": jnp.concatenate([memory["states"][i::num_envs] for i in range(num_envs)], axis=0), "actions": jnp.concatenate([memory["actions"][i::num_envs] for i in range(num_envs)], axis=0), @@ -458,11 +456,13 @@ def main() -> None: parser.add_argument("--width", type=int, default=640, help="width of the video frame") parser.add_argument("--height", type=int, default=480, help="height of the video frame") parser.add_argument("--render_every", type=int, default=2, help="render the environment every N steps") - parser.add_argument("--video_length", type=int, default=5, help="maxmimum length of video in seconds") + parser.add_argument("--video_length", type=int, default=10, help="maxmimum length of video in seconds") parser.add_argument("--save_video_every", type=int, default=100, help="save video every N iterations") + parser.add_argument("--envs_to_sample", type=int, default=4, help="number of environments to sample for video") args = parser.parse_args() config = Config() + wandb.init(project="humanoid-ppo", config=vars(config)) env = HumanoidEnv() observation_size = env.observation_size @@ -500,11 +500,13 @@ def step_fn(states: State, actions: jax.Array) -> State: rng, reset_rng = jax.random.split(rng) states = reset_fn(reset_rng) pbar = tqdm(total=config.max_steps_per_iteration, desc=f"Steps for iteration {i}") + wait = -1 while steps < config.max_steps_per_iteration: episodes += config.num_envs - obs = jax.device_put(states.obs) + norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (jnp.std(states.obs, axis=1, keepdims=True) + 1e-8) + obs = jax.device_put(norm_obs) score = jnp.zeros(config.num_envs) memory = { @@ -522,8 +524,9 @@ def step_fn(states: State, actions: jax.Array) -> State: actions = choose_action_vmap(ppo.actor, obs, jnp.array(action_rng)) states = step_fn(states, actions) - next_obs, rewards, dones = states.obs, states.reward, states.done - masks = (1 - dones).astype(jnp.float32) + norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (jnp.std(states.obs, axis=1, keepdims=True) + 1e-8) + next_obs, rewards, dones = norm_obs, states.reward, states.done + masks = (1 - 0.8 * dones).astype(jnp.float32) # Update memory new_data = {"states": obs, "actions": actions, "rewards": rewards, "masks": masks} @@ -540,13 +543,21 @@ def step_fn(states: State, actions: jax.Array) -> State: and i % args.save_video_every == 0 and len(rollout) < args.video_length * int(1 / env.dt) ): - unwrapped_states = unwrap_state_vectorization(states.pipeline_state, config) + unwrapped_states = unwrap_state_vectorization(states.pipeline_state, args.envs_to_sample) rollout.extend(unwrapped_states) if jnp.all(dones): - rng, reset_rng = jax.random.split(rng) - states = reset_fn(reset_rng) - break + + # NOTE: this "waiting" penalizes systems that just want to fall quick (and reset quick) + # goes from needing 84 steps to 30 steps + if wait == -1: + wait = 10 + else: + wait -= 1 + if wait == 0: + rng, reset_rng 
= jax.random.split(rng) + states = reset_fn(reset_rng) + break # with open("log_" + args.env_name + ".txt", "a") as outfile: # outfile.write("\t" + str(episodes) + "\t" + str(jnp.mean(score)) + "\n") @@ -563,13 +574,18 @@ def step_fn(states: State, actions: jax.Array) -> State: score_avg = float(jnp.mean(jnp.array(scores))) pbar.close() logger.info("Episode %s score is %.2f", episodes, score_avg) - + wandb.log({"score": score_avg, "episode": episodes}) # Save video for this iteration if args.save_video_every and i % args.save_video_every == 0 and rollout: + + reordered_rollout = [ + frame for i in range(args.envs_to_sample) for frame in rollout[i :: args.envs_to_sample] + ] + images = jnp.array( - env.render(rollout[:: args.render_every], camera="side", width=args.width, height=args.height) + env.render(reordered_rollout[:: args.render_every], camera="side", width=args.width, height=args.height) ) fps = int(1 / env.dt) @@ -587,4 +603,3 @@ def step_fn(states: State, actions: jax.Array) -> State: if __name__ == "__main__": main() - wandb.init(project="humanoid-ppo", config=vars(config)) From d105e294df43e7faf402b0071753e76d80487e64 Mon Sep 17 00:00:00 2001 From: Nathan Zhao <43712744+nathanjzhao@users.noreply.github.com> Date: Tue, 6 Aug 2024 19:51:31 -0700 Subject: [PATCH 5/6] added some tricks from blogpost (#15) * added tricks from blogpost * reorder memory * shown trying to stand up + increasing score? not sure if this is because i literally just increased size of actor/critic models. maybe should do some ablation tests. or get working fully first! (better reward func) --- .gitignore | 3 +- README.md | 7 ++ environment.py | 4 +- train.py | 225 ++++++++++++++++++++++++++++++++++++++----------- 4 files changed, 184 insertions(+), 55 deletions(-) diff --git a/.gitignore b/.gitignore index 4a98f29..269b2df 100644 --- a/.gitignore +++ b/.gitignore @@ -18,10 +18,9 @@ dist/ out*/ *.stl *.txt +*.mp4 -videos/ environments/ screenshots/ scratch/ -videos/ assets/ diff --git a/README.md b/README.md index 99a3caa..d85e46a 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,13 @@ Minimal training and inference code for making a humanoid robot stand up. - [ ] Implement simple PPO policy to try to make the robot stand up - [ ] Parallelize using JAX +# Findings +- Low standard deviation for "overfitting test" does not work very well for PPO because need to converge upon actual solution. Cannot get to actual solution if do not explore a little. With that in mind, I think the model is generally getting too comfortable tricking the reward system by just losing as fast as posisble so doesn't have to deal with penalty +- Theory that model just trying to lose (to retain `is_healthy` while not losing as much height. It wants to fall as quick as possible so can reset) can be tested by adding "wait" and removing the mask. This effectively reduces the fact that reset will work. While the model is stuck in the failed state, it still is unhealthy and therefore loses out on reward. 
+ +# Currently tests: +- Hidden layer size of 256 shows progress (loss is based on state.q[2]) + ## Goals - The goal for this repository is to provide a super minimal implementation of a PPO policy for making a humanoid robot stand up, with only three files: diff --git a/environment.py b/environment.py index a74df81..a350d4f 100644 --- a/environment.py +++ b/environment.py @@ -141,7 +141,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0) is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy) - is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0) + # is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0) ctrl_cost = -jp.sum(jp.square(action)) @@ -158,7 +158,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr # ) # jax.debug.print("is_healthy {}, height {}", is_healthy, state.q[2], ordered=True) - total_reward = 2.0 * is_healthy + 0.1 * ctrl_cost - 5.0 * is_bad + total_reward = 0.1 * ctrl_cost + 5 * state.q[2] return total_reward diff --git a/train.py b/train.py index d1f9ac1..3a311b5 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,8 @@ """Trains a policy network to get a humanoid to stand up.""" import argparse +from functools import partial +import wandb import logging import os from dataclasses import dataclass, field @@ -24,26 +26,28 @@ @dataclass class Config: - lr_actor: float = field(default=3e-4, metadata={"help": "Learning rate for the actor network."}) - lr_critic: float = field(default=3e-4, metadata={"help": "Learning rate for the critic network."}) + lr_actor: float = field(default=2.5e-4, metadata={"help": "Learning rate for the actor network."}) + lr_critic: float = field(default=2.5e-4, metadata={"help": "Learning rate for the critic network."}) num_iterations: int = field(default=15000, metadata={"help": "Number of environment simulation iterations."}) - num_envs: int = field(default=16, metadata={"help": "Number of environments to run at once with vectorization."}) + num_envs: int = field(default=2048, metadata={"help": "Number of environments to run at once with vectorization."}) max_steps_per_episode: int = field( - default=512 * 16, metadata={"help": "Maximum number of steps per episode (across ALL environments)."} + default=128 * 2048, metadata={"help": "Maximum number of steps per episode (across ALL environments)."} ) max_steps_per_iteration: int = field( - default=1024 * 16, + default=512 * 2048, metadata={ "help": "Maximum number of steps per iteration of simulating environments (across ALL environments)." }, ) - gamma: float = field(default=0.98, metadata={"help": "Discount factor for future rewards."}) - lambd: float = field(default=0.99, metadata={"help": "Lambda parameter for GAE calculation."}) - batch_size: int = field(default=64, metadata={"help": "Batch size for training updates."}) + gamma: float = field(default=0.99, metadata={"help": "Discount factor for future rewards."}) + lambd: float = field(default=0.95, metadata={"help": "Lambda parameter for GAE calculation."}) + batch_size: int = field(default=32, metadata={"help": "Batch size for training updates."}) epsilon: float = field(default=0.2, metadata={"help": "Clipping parameter for PPO."}) l2_rate: float = field(default=0.001, metadata={"help": "L2 regularization rate for the critic."}) + entropy_coeff: float = field(default=0.01, metadata={"help": "Coefficient for entropy loss."}) +# NOTE: change how initialize weights? 
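# NOTE: jax.nn.initializers.orthogonal(scale) accepts the gain directly, so the post-hoc
# multiplication by `scale` inside initialize_layer below could be folded into the initializer call.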
class Actor(eqx.Module): """Actor network for PPO.""" @@ -54,10 +58,16 @@ class Actor(eqx.Module): def __init__(self, input_size: int, action_size: int, key: Array) -> None: keys = jax.random.split(key, 4) - self.linear1 = eqx.nn.Linear(input_size, 64, key=keys[0]) - self.linear2 = eqx.nn.Linear(64, 64, key=keys[1]) - self.mu_layer = eqx.nn.Linear(64, action_size, key=keys[2]) - self.log_sigma_layer = eqx.nn.Linear(64, action_size, key=keys[3]) + self.linear1 = eqx.nn.Linear(input_size, 256, key=keys[0]) + self.linear2 = eqx.nn.Linear(256, 256, key=keys[1]) + self.mu_layer = eqx.nn.Linear(256, action_size, key=keys[2]) + self.log_sigma_layer = eqx.nn.Linear(256, action_size, key=keys[3]) + + # Parameter initialization according to Trick #2 + self.linear1 = self.initialize_layer(self.linear1, np.sqrt(2), keys[0]) + self.linear2 = self.initialize_layer(self.linear2, np.sqrt(2), keys[1]) + self.mu_layer = self.initialize_layer(self.mu_layer, 0.01, keys[2]) + self.log_sigma_layer = self.initialize_layer(self.log_sigma_layer, 0.01, keys[3]) def __call__(self, x: Array) -> Tuple[Array, Array]: x = jax.nn.tanh(self.linear1(x)) @@ -66,6 +76,24 @@ def __call__(self, x: Array) -> Tuple[Array, Array]: log_sigma = self.log_sigma_layer(x) return mu, jnp.exp(log_sigma) + def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eqx.nn.Linear: + weight_shape = layer.weight.shape + + initializer = jax.nn.initializers.orthogonal() + new_weight = initializer(key, weight_shape, jnp.float32) * scale + new_bias = jnp.zeros(layer.bias.shape) if layer.bias is not None else None + + def where_weight(layer: eqx.nn.Linear) -> Array: + return layer.weight + + def where_bias(layer: eqx.nn.Linear) -> Array | None: + return layer.bias + + new_layer = eqx.tree_at(where_weight, layer, new_weight) + new_layer = eqx.tree_at(where_bias, new_layer, new_bias) + + return new_layer + class Critic(eqx.Module): """Critic network for PPO.""" @@ -76,15 +104,38 @@ class Critic(eqx.Module): def __init__(self, input_size: int, key: Array) -> None: keys = jax.random.split(key, 3) - self.linear1 = eqx.nn.Linear(input_size, 64, key=keys[0]) - self.linear2 = eqx.nn.Linear(64, 64, key=keys[1]) - self.value_layer = eqx.nn.Linear(64, 1, key=keys[2]) + self.linear1 = eqx.nn.Linear(input_size, 256, key=keys[0]) + self.linear2 = eqx.nn.Linear(256, 256, key=keys[1]) + self.value_layer = eqx.nn.Linear(256, 1, key=keys[2]) + + # Parameter initialization according to Trick #2 + self.linear1 = self.initialize_layer(self.linear1, np.sqrt(2), keys[0]) + self.linear2 = self.initialize_layer(self.linear2, np.sqrt(2), keys[1]) + self.value_layer = self.initialize_layer(self.value_layer, 1.0, keys[2]) def __call__(self, x: Array) -> Array: x = jax.nn.tanh(self.linear1(x)) x = jax.nn.tanh(self.linear2(x)) return self.value_layer(x) + def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eqx.nn.Linear: + weight_shape = layer.weight.shape + + initializer = jax.nn.initializers.orthogonal() + new_weight = initializer(key, weight_shape, jnp.float32) * scale + new_bias = jnp.zeros(layer.bias.shape) if layer.bias is not None else None + + def where_weight(layer: eqx.nn.Linear) -> Array: + return layer.weight + + def where_bias(layer: eqx.nn.Linear) -> Array | None: + return layer.bias + + new_layer = eqx.tree_at(where_weight, layer, new_weight) + new_layer = eqx.tree_at(where_bias, new_layer, new_bias) + + return new_layer + class Ppo: def __init__(self, observation_size: int, action_size: int, config: Config, key: 
Array) -> None: @@ -93,8 +144,21 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key: self.actor = Actor(observation_size, action_size, subkey1) self.critic = Critic(observation_size, subkey2) - self.actor_optim = optax.adam(learning_rate=config.lr_actor) - self.critic_optim = optax.adamw(learning_rate=config.lr_critic, weight_decay=config.l2_rate) + total_timesteps = config.num_iterations + + # Learning rate annealing according to Trick #4 + self.actor_schedule = optax.linear_schedule( + init_value=config.lr_actor, end_value=0, transition_steps=total_timesteps + ) + self.critic_schedule = optax.linear_schedule( + init_value=config.lr_critic, end_value=0, transition_steps=total_timesteps + ) + + # eps below according to Trick #3 + self.actor_optim = optax.chain(optax.adam(learning_rate=self.actor_schedule, eps=1e-5)) + self.critic_optim = optax.chain( + optax.adamw(learning_rate=self.critic_schedule, weight_decay=config.l2_rate, eps=1e-5) + ) # Initialize optimizer states self.actor_opt_state = self.actor_optim.init(eqx.filter(self.actor, eqx.is_array)) @@ -133,8 +197,9 @@ def train_step( params: Dict[str, Any], states_b: Array, actions_b: Array, - rewards_b: Array, - masks_b: Array, + returns_b: Array, + advants_b: Array, + old_log_prob_b: Array, config: Config, ) -> Tuple[Dict[str, Any], Array, Array]: """Perform a single training step with PPO parameters.""" @@ -143,25 +208,29 @@ def train_step( actor_vmap = jax.vmap(apply_actor, in_axes=(None, 0)) critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0)) - values_b = critic_vmap(critic, states_b).squeeze() - returns_b, advants_b = get_gae(rewards_b, masks_b, values_b, config) - - old_mu_b, old_std_b = actor_vmap(actor, states_b) - old_log_prob_b = actor_log_prob(old_mu_b, old_std_b, actions_b) + # Normalizing advantages *in minibatch* according to Trick #7 + advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-8) - @eqx.filter_value_and_grad + @partial(eqx.filter_value_and_grad, has_aux=True) def actor_loss_fn(actor: Actor) -> Array: """Prioritizing advantageous actions over more training.""" mu_b, std_b = actor_vmap(actor, states_b) new_log_prob_b = actor_log_prob(mu_b, std_b, actions_b) + # Calculating the ratio of new and old probabilities ratio_b = jnp.exp(new_log_prob_b - old_log_prob_b) surrogate_loss_b = ratio_b * advants_b # Clipping is done to prevent too much change if new advantages are very large clipped_loss_b = jnp.clip(ratio_b, 1.0 - config.epsilon, 1.0 + config.epsilon) * advants_b + actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b)) - return actor_loss + entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * std_b**2) + 1)) + + fraction_clipped = jnp.mean(jnp.abs(ratio_b - 1.0) > config.epsilon) + + total_loss = actor_loss - config.entropy_coeff * entropy_loss + return total_loss, (actor_loss, entropy_loss, fraction_clipped) @eqx.filter_value_and_grad def critic_loss_fn(critic: Critic) -> Array: @@ -171,7 +240,7 @@ def critic_loss_fn(critic: Critic) -> Array: return critic_loss # Calculating actor loss and updating actor parameters - actor_loss, actor_grads = actor_loss_fn(actor) + (_, actor_loss, entropy_loss, fraction_clipped), actor_grads = actor_loss_fn(actor) actor_updates, new_actor_opt_state = actor_optim.update(actor_grads, actor_opt_state, params=actor) new_actor = eqx.apply_updates(actor, actor_updates) @@ -187,7 +256,7 @@ def critic_loss_fn(critic: Critic) -> Array: "critic_opt_state": new_critic_opt_state, } - return new_params, actor_loss, 
critic_loss + return new_params, (actor_loss, critic_loss, entropy_loss, fraction_clipped) def get_gae(rewards: Array, masks: Array, values: Array, config: Config) -> Tuple[Array, Array]: @@ -205,18 +274,15 @@ def gae_step(carry: Tuple[Array, Array], inp: Tuple[Array, Array, Array]) -> Tup _, advantages = jax.lax.scan( f=gae_step, init=(jnp.zeros_like(rewards[-1]), values[-1]), - xs=(rewards[::-1], masks[::-1], values[::-1]), + xs=(rewards, masks, values), # NOTE: correct direction? reverse=True, ) returns = advantages + values - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) return returns, advantages def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Config) -> None: """Train the PPO model using the memory collected from the environment.""" - # NOTE: think this needs to be reimplemented for vectorization because currently, - # doesn't account that memory order is maintained # Reorders memory according to states, actions, rewards, masks states = jnp.array([e[0] for e in memory]) @@ -224,6 +290,20 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con rewards = jnp.array([e[2] for e in memory]) masks = jnp.array([e[3] for e in memory]) + # Calculate old log probabilities + actor_vmap = jax.vmap(apply_actor, in_axes=(None, 0)) + old_mu, old_std = actor_vmap(ppo.actor, states) + old_log_prob = actor_log_prob(old_mu, old_std, actions) + + # Calculate values for all states + critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0)) + values = critic_vmap(ppo.critic, states).squeeze() + + # NOTE: are the output shapes correct? + + # Calculate GAE and returns + returns, advantages = get_gae(rewards, masks, values, config) + n = len(states) arr = jnp.arange(n) key = jax.random.PRNGKey(0) @@ -231,8 +311,15 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con for epoch in range(1): key, subkey = jax.random.split(key) arr = jax.random.permutation(subkey, arr) + + # Calculate average advantages and returns + avg_advantages = jnp.mean(advantages) + avg_returns = jnp.mean(returns) total_actor_loss = 0.0 total_critic_loss = 0.0 + total_entropy_loss = 0.0 + total_fraction_clipped = 0.0 + logger.info("Processing %d batches", n // config.batch_size) for i in range(n // config.batch_size): @@ -240,30 +327,48 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con batch_indices = arr[config.batch_size * i : config.batch_size * (i + 1)] states_b = states[batch_indices] actions_b = actions[batch_indices] - rewards_b = rewards[batch_indices] - masks_b = masks[batch_indices] + returns_b = returns[batch_indices] + advantages_b = advantages[batch_indices] + old_log_prob_b = old_log_prob[batch_indices] params = ppo.get_params() - new_params, actor_loss, critic_loss = train_step( + new_params, (actor_loss, critic_loss, entropy_loss, fraction_clipped) = train_step( ppo.actor_optim, ppo.critic_optim, params, states_b, actions_b, - rewards_b, - masks_b, + returns_b, + advantages_b, + old_log_prob_b, config, ) ppo.update_params(new_params) total_actor_loss += actor_loss.mean().item() total_critic_loss += critic_loss.mean().item() + total_entropy_loss += entropy_loss.item() + total_fraction_clipped += fraction_clipped.item() mean_actor_loss = total_actor_loss / (n // config.batch_size) mean_critic_loss = total_critic_loss / (n // config.batch_size) + mean_entropy_loss = total_entropy_loss / (n // config.batch_size) + mean_fraction_clipped = total_fraction_clipped / (n // 
config.batch_size) logger.info(f"Mean Actor Loss: {mean_actor_loss}, Mean Critic Loss: {mean_critic_loss}") + # Log metrics to wandb + wandb.log( + { + "actor_loss": mean_actor_loss, + "critic_loss": mean_critic_loss, + "entropy_loss": mean_entropy_loss, + "fraction_clipped": mean_fraction_clipped, + "avg_advantages": avg_advantages, + "avg_returns": avg_returns, + } + ) + def actor_log_prob(mu: Array, sigma: Array, actions: Array) -> Array: """Calculate the log probability of the actions given the actor network's output.""" @@ -330,6 +435,16 @@ def update_memory(memory: Dict[str, Array], new_data: Dict[str, Array]) -> Dict[ return jax.tree.map(lambda x, y: jnp.concatenate([x, y]), memory, new_data) +def reorder_memory(memory, num_envs): + reordered_memory = { + "states": jnp.concatenate([memory["states"][i::num_envs] for i in range(num_envs)], axis=0), + "actions": jnp.concatenate([memory["actions"][i::num_envs] for i in range(num_envs)], axis=0), + "rewards": jnp.concatenate([memory["rewards"][i::num_envs] for i in range(num_envs)], axis=0), + "masks": jnp.concatenate([memory["masks"][i::num_envs] for i in range(num_envs)], axis=0), + } + return reordered_memory + + def main() -> None: logging.basicConfig( level=logging.INFO, @@ -378,26 +493,27 @@ def step_fn(states: State, actions: jax.Array) -> State: for i in range(1, config.num_iterations + 1): # Initialize memory as JAX arrays - memory = { - "states": jnp.empty((0, observation_size)), - "actions": jnp.empty((0, action_size)), - "rewards": jnp.empty((0,)), - "masks": jnp.empty((0,)), - } scores = [] steps = 0 rollout: List[MjxState] = [] + rng, reset_rng = jax.random.split(rng) + states = reset_fn(reset_rng) pbar = tqdm(total=config.max_steps_per_iteration, desc=f"Steps for iteration {i}") while steps < config.max_steps_per_iteration: episodes += config.num_envs - rng, reset_rng = jax.random.split(rng) - states = reset_fn(reset_rng) obs = jax.device_put(states.obs) score = jnp.zeros(config.num_envs) + memory = { + "states": jnp.empty((0, observation_size)), + "actions": jnp.empty((0, action_size)), + "rewards": jnp.empty((0,)), + "masks": jnp.empty((0,)), + } + for _ in range(config.max_steps_per_episode): # Choosing actions @@ -428,15 +544,27 @@ def step_fn(states: State, actions: jax.Array) -> State: rollout.extend(unwrapped_states) if jnp.all(dones): + rng, reset_rng = jax.random.split(rng) + states = reset_fn(reset_rng) break # with open("log_" + args.env_name + ".txt", "a") as outfile: # outfile.write("\t" + str(episodes) + "\t" + str(jnp.mean(score)) + "\n") scores.append(jnp.mean(score)) + memory = reorder_memory(memory, config.num_envs) + + train_memory = [ + (s, a, r, m) + for s, a, r, m in zip(memory["states"], memory["actions"], memory["rewards"], memory["masks"]) + ] + train(ppo, train_memory, config) + score_avg = float(jnp.mean(jnp.array(scores))) pbar.close() logger.info("Episode %s score is %.2f", episodes, score_avg) + + wandb.log({"score": score_avg, "episode": episodes}) # Save video for this iteration if args.save_video_every and i % args.save_video_every == 0 and rollout: @@ -456,12 +584,7 @@ def step_fn(states: State, actions: jax.Array) -> State: logger.info("Saving video to %s for iteration %d", video_path, i) media.write_video(video_path, images, fps=fps) - # Convert memory to the format expected by ppo.train - train_memory = [ - (s, a, r, m) for s, a, r, m in zip(memory["states"], memory["actions"], memory["rewards"], memory["masks"]) - ] - train(ppo, train_memory, config) - if __name__ == "__main__": main() 
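    # NOTE: calling wandb.init here only happens after main() returns, and `config` is not
    # defined at module scope, so this line fails with a NameError once training ends; the
    # "NANs in training being fixed WIP" patch earlier in this series moves the init call
    # into main() and drops this line.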
From bbf0e3c8f1db7c6ab9e687fe566d6d7302a5d8df Mon Sep 17 00:00:00 2001
From: Nathan Zhao
Date: Wed, 7 Aug 2024 03:03:35 +0000
Subject: [PATCH 6/6] training loop fully functional but contains NaNs

---
 environment.py |  4 ++--
 train.py       | 53 +++++++++++++++++++++++++-------------------------
 2 files changed, 28 insertions(+), 29 deletions(-)

diff --git a/environment.py b/environment.py
index 74ede2b..7ce65b1 100644
--- a/environment.py
+++ b/environment.py
@@ -159,7 +159,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr
         # )
         # jax.debug.print("is_healthy {}, height {}", is_healthy, state.q[2], ordered=True)

-        total_reward = 0.1 * ctrl_cost + 5 * is_healthy + 3 * state.q[2]
+        total_reward = jp.clip(0.1 * ctrl_cost + 5 * is_healthy, -1e8, 10.0)

         return total_reward

@@ -186,7 +186,7 @@ def get_obs(self, data: MjxState, action: jp.ndarray) -> jp.ndarray:
             data.qfrc_actuator,
         ]

-        def clean_component(component):
+        def clean_component(component: jp.ndarray) -> jp.ndarray:
             # Check for NaNs or Infs and replace them
             nan_mask = jp.isnan(component)
             inf_mask = jp.isinf(component)
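The `clean_component` body is truncated in the hunk above; only the NaN/Inf masks are visible. A minimal sketch of the sanitization idea, assuming the non-finite entries are simply replaced with finite fallbacks (the patch's actual replacement values are not shown, and `jp` here is just `jax.numpy`):

import jax.numpy as jp

def clean_component(component: jp.ndarray) -> jp.ndarray:
    # Replace NaNs with 0 and clamp +/-inf to a large finite value so the observation
    # stays finite before normalization and the actor forward pass.
    nan_mask = jp.isnan(component)
    inf_mask = jp.isinf(component)
    cleaned = jp.where(nan_mask, 0.0, component)
    cleaned = jp.where(inf_mask, jp.sign(cleaned) * 1e6, cleaned)
    return cleaned

print(clean_component(jp.array([1.0, jp.nan, jp.inf, -jp.inf])))  # -> [1.0, 0.0, 1e6, -1e6]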
diff --git a/train.py b/train.py
index cbecc7b..86eef1c 100644
--- a/train.py
+++ b/train.py
@@ -4,7 +4,6 @@
 import logging
 import os
 from dataclasses import dataclass, field
-from functools import partial
 from typing import Any, Dict, List, Tuple

 import equinox as eqx
@@ -200,7 +199,7 @@ def train_step(
     advants_b: Array,
     old_log_prob_b: Array,
     config: Config,
-) -> Tuple[Dict[str, Any], Array, Array]:
+) -> Tuple[Dict[str, Any], Tuple[Array, Array]]:
     """Perform a single training step with PPO parameters."""
     actor, critic, actor_opt_state, critic_opt_state = params.values()

@@ -210,7 +209,7 @@ def train_step(
     # Normalizing advantages *in minibatch* according to Trick #7
     advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-8)

-    @partial(eqx.filter_value_and_grad, has_aux=True)
+    @eqx.filter_value_and_grad
     def actor_loss_fn(actor: Actor) -> Array:
         """Prioritizing advantageous actions over more training."""
         mu_b, std_b = actor_vmap(actor, states_b)
@@ -226,10 +225,8 @@ def actor_loss_fn(actor: Actor) -> Array:
         actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b))

         entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * std_b**2) + 1))
-        fraction_clipped = jnp.mean(jnp.abs(ratio_b - 1.0) > config.epsilon)
-
         total_loss = actor_loss - config.entropy_coeff * entropy_loss
-        return total_loss, (actor_loss, entropy_loss, fraction_clipped)
+        return total_loss

     @eqx.filter_value_and_grad
     def critic_loss_fn(critic: Critic) -> Array:
@@ -239,7 +236,7 @@ def critic_loss_fn(critic: Critic) -> Array:
         return critic_loss

     # Calculating actor loss and updating actor parameters --- outputting auxiliary data for logging
-    (_, (actor_loss, entropy_loss, fraction_clipped)), actor_grads = actor_loss_fn(actor)
+    actor_loss, actor_grads = actor_loss_fn(actor)
     actor_updates, new_actor_opt_state = actor_optim.update(actor_grads, actor_opt_state, params=actor)
     new_actor = eqx.apply_updates(actor, actor_updates)

@@ -255,7 +252,7 @@ def critic_loss_fn(critic: Critic) -> Array:
         "critic_opt_state": new_critic_opt_state,
     }

-    return new_params, (actor_loss, critic_loss, entropy_loss, fraction_clipped)
+    return new_params, (actor_loss, critic_loss)


 def get_gae(rewards: Array, masks: Array, values: Array, config: Config) -> Tuple[Array, Array]:
@@ -315,8 +312,6 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
     avg_returns = jnp.mean(returns)

     total_actor_loss = 0.0
     total_critic_loss = 0.0
-    total_entropy_loss = 0.0
-    total_fraction_clipped = 0.0

     logger.info("Processing %d batches", n // config.batch_size)
     for i in range(n // config.batch_size):
@@ -330,7 +325,7 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
         old_log_prob_b = old_log_prob[batch_indices]

         params = ppo.get_params()
-        new_params, (actor_loss, critic_loss, entropy_loss, fraction_clipped) = train_step(
+        new_params, (actor_loss, critic_loss) = train_step(
             ppo.actor_optim,
             ppo.critic_optim,
             params,
@@ -345,13 +340,9 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con

         total_actor_loss += actor_loss.mean().item()
         total_critic_loss += critic_loss.mean().item()
-        total_entropy_loss += entropy_loss.item()
-        total_fraction_clipped += fraction_clipped.item()

     mean_actor_loss = total_actor_loss / (n // config.batch_size)
     mean_critic_loss = total_critic_loss / (n // config.batch_size)
-    mean_entropy_loss = total_entropy_loss / (n // config.batch_size)
-    mean_fraction_clipped = total_fraction_clipped / (n // config.batch_size)

     logger.info(f"Mean Actor Loss: {mean_actor_loss}, Mean Critic Loss: {mean_critic_loss}")
@@ -360,8 +351,6 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
         {
             "actor_loss": mean_actor_loss,
             "critic_loss": mean_critic_loss,
-            "entropy_loss": mean_entropy_loss,
-            "fraction_clipped": mean_fraction_clipped,
             "avg_advantages": avg_advantages,
             "avg_returns": avg_returns,
         }
     )
@@ -384,9 +373,6 @@ def unwrap_state_vectorization(state: State, envs_to_sample: int) -> State:
     # Get all attributes of the state
     attributes = dir(state)

-    # NOTE: can change ordering of this to save runtiem if want to save more vectorized states.
-    # NOTE: (but anyways, the video isn't correctly ordered then)
-    # saves from only first vectorized state
     for i in range(envs_to_sample):
         # Create a new state with the first element of each attribute
         new_state = {}
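The loop hunks below (the `# Normalizing observations quickens learning` lines) whiten each observation before it is stored and fed to the actor. Note that this is a per-sample normalization over the feature axis, not a running mean/std normalizer; a standalone sketch of exactly that transform:

import jax.numpy as jnp

def normalize_obs(obs: jnp.ndarray) -> jnp.ndarray:
    # Whiten each row (one environment's observation) by its own mean and std;
    # the 1e-8 keeps the division finite when an observation happens to be constant.
    mean = jnp.mean(obs, axis=1, keepdims=True)
    std = jnp.std(obs, axis=1, keepdims=True)
    return (obs - mean) / (std + 1e-8)

obs = jnp.array([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])  # (num_envs, obs_dim)
print(normalize_obs(obs))  # both rows become roughly [-1.22, 0.0, 1.22]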
@@ -502,10 +488,13 @@ def step_fn(states: State, actions: jax.Array) -> State:

         pbar = tqdm(total=config.max_steps_per_iteration, desc=f"Steps for iteration {i}")

         wait = -1
-        while steps < config.max_steps_per_iteration:
+        while steps < config.max_steps_per_iteration // config.num_envs:
             episodes += config.num_envs

-            norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (jnp.std(states.obs, axis=1, keepdims=True) + 1e-8)
+            # Normalizing observations quickens learning
+            norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (
+                jnp.std(states.obs, axis=1, keepdims=True) + 1e-8
+            )
             obs = jax.device_put(norm_obs)
             score = jnp.zeros(config.num_envs)

             memory = {
                 "states": jnp.empty((0, observation_size)),
                 "actions": jnp.empty((0, action_size)),
                 "rewards": jnp.empty((0,)),
                 "masks": jnp.empty((0,)),
             }

-            for _ in range(config.max_steps_per_episode):
+            for _ in range(config.max_steps_per_episode // config.num_envs):
                 # Choosing actions
                 choose_action_vmap = jax.vmap(choose_action, in_axes=(None, 0, 0))
                 rng, *action_rng = jax.random.split(rng, num=config.num_envs + 1)
                 actions = choose_action_vmap(ppo.actor, obs, jnp.array(action_rng))

+                # NOTE: disable actions when "done" --> better for current wait system?
                 states = step_fn(states, actions)
-                norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (jnp.std(states.obs, axis=1, keepdims=True) + 1e-8)
+
+                # Normalizing observations quickens learning
+                norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (
+                    jnp.std(states.obs, axis=1, keepdims=True) + 1e-8
+                )
                 next_obs, rewards, dones = norm_obs, states.reward, states.done
-                masks = (1 - 0.8 * dones).astype(jnp.float32)
+                masks = (1 - dones).astype(jnp.float32)

                 # Update memory
                 new_data = {"states": obs, "actions": actions, "rewards": rewards, "masks": masks}
                 memory = update_memory(memory, new_data)
                 score += rewards
+
+                if jnp.any(jnp.isnan(rewards)):
+                    print(rewards)
+
                 obs = next_obs
                 steps += config.num_envs
                 pbar.update(config.num_envs)
@@ -548,8 +546,9 @@ def step_fn(states: State, actions: jax.Array) -> State:

                 if jnp.all(dones):

-                    # NOTE: this "waiting" penalizes systems that just want to fall quick (and reset quick)
-                    # goes from needing 84 steps to 30 steps
+                    # NOTE: this "waiting" penalizes systems that just want to fall quickly (and reset quickly);
+                    # because it prevents a fast reset, getting to pretty high scores goes from needing 84 steps to 30.
+                    # This isn't done in other PPO implementations, though, so it may not be a good idea.
                     if wait == -1:
                         wait = 10
                     else: