diff --git a/README.md b/README.md
index d85e46a..4ab7bec 100644
--- a/README.md
+++ b/README.md
@@ -36,9 +36,6 @@ Minimal training and inference code for making a humanoid robot stand up.
 - Low standard deviation for "overfitting test" does not work very well for PPO because need to converge upon actual solution. Cannot get to actual solution if do not explore a little. With that in mind, I think the model is generally getting too comfortable tricking the reward system by just losing as fast as posisble so doesn't have to deal with penalty
 - Theory that model just trying to lose (to retain `is_healthy` while not losing as much height. It wants to fall as quick as possible so can reset) can be tested by adding "wait" and removing the mask. This effectively reduces the fact that reset will work. While the model is stuck in the failed state, it still is unhealthy and therefore loses out on reward.
 
-# Currently tests:
-- Hidden layer size of 256 shows progress (loss is based on state.q[2])
-
 ## Goals
 
 - The goal for this repository is to provide a super minimal implementation of a PPO policy for making a humanoid robot stand up, with only three files:
diff --git a/environment.py b/environment.py
index 7ce65b1..5263436 100644
--- a/environment.py
+++ b/environment.py
@@ -142,8 +142,6 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr
         is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0)
         is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy)
 
-        # is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0)
-
         ctrl_cost = -jp.sum(jp.square(action))
 
         # xpos = state.subtree_com[1][0]
@@ -159,7 +157,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr
         # )
         # jax.debug.print("is_healthy {}, height {}", is_healthy, state.q[2], ordered=True)
 
-        total_reward = jp.clip(0.1 * ctrl_cost + 5 * is_healthy, -1e8, 10.0)
+        total_reward = 0.1 * ctrl_cost + 5 * is_healthy
 
         return total_reward
 
@@ -186,17 +184,7 @@ def get_obs(self, data: MjxState, action: jp.ndarray) -> jp.ndarray:
             data.qfrc_actuator,
         ]
 
-        def clean_component(component: jp.ndarray) -> jp.ndarray:
-            # Check for NaNs or Infs and replace them
-            nan_mask = jp.isnan(component)
-            inf_mask = jp.isinf(component)
-            component = jp.where(nan_mask, 0.0, component)
-            component = jp.where(inf_mask, jp.where(component > 0, 1e6, -1e6), component)
-            return component
-
-        cleaned_components = [clean_component(comp) for comp in obs_components]
-
-        return jp.concatenate(cleaned_components)
+        return jp.concatenate(obs_components)
 
 
 def run_environment_adhoc() -> None:
diff --git a/train.py b/train.py
index 86eef1c..c505c83 100644
--- a/train.py
+++ b/train.py
@@ -4,6 +4,7 @@
 import logging
 import os
 from dataclasses import dataclass, field
+from functools import partial
 from typing import Any, Dict, List, Tuple
 
 import equinox as eqx
@@ -30,10 +31,10 @@ class Config:
     num_iterations: int = field(default=15000, metadata={"help": "Number of environment simulation iterations."})
     num_envs: int = field(default=32, metadata={"help": "Number of environments to run at once with vectorization."})
     max_steps_per_episode: int = field(
-        default=128 * 32, metadata={"help": "Maximum number of steps per episode (across ALL environments)."}
+        default=512 * 32, metadata={"help": "Maximum number of steps per episode (across ALL environments)."}
     )
     max_steps_per_iteration: int = field(
-        default=512 * 32,
+        default=1024 * 32,
         metadata={
             "help": "Maximum number of steps per iteration of simulating environments (across ALL environments)."
         },
@@ -152,10 +153,14 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key:
             init_value=config.lr_critic, end_value=1e-6, transition_steps=total_timesteps
         )
 
+        clip_threshold = 1.0
         # eps below according to Trick #3
-        self.actor_optim = optax.chain(optax.adam(learning_rate=self.actor_schedule, eps=1e-5))
+        self.actor_optim = optax.chain(
+            optax.clip_by_global_norm(clip_threshold), optax.adam(learning_rate=self.actor_schedule, eps=1e-5)
+        )
         self.critic_optim = optax.chain(
-            optax.adamw(learning_rate=self.critic_schedule, weight_decay=config.l2_rate, eps=1e-5)
+            optax.clip_by_global_norm(clip_threshold),
+            optax.adamw(learning_rate=self.critic_schedule, weight_decay=config.l2_rate, eps=1e-5),
         )
 
         # Initialize optimizer states
@@ -222,8 +227,12 @@ def actor_loss_fn(actor: Actor) -> Array:
            # Clipping is done to prevent too much change if new advantages are very large
            clipped_loss_b = jnp.clip(ratio_b, 1.0 - config.epsilon, 1.0 + config.epsilon) * advants_b
 
-           actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b))
-           entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * std_b**2) + 1))
+           # Choosing the smallest magnitude loss
+           actor_loss = -jnp.mean(
+               jnp.where(jnp.abs(surrogate_loss_b) < jnp.abs(clipped_loss_b), surrogate_loss_b, clipped_loss_b)
+           )
+
+           entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * (std_b + 1e-8) ** 2) + 1))
 
            total_loss = actor_loss - config.entropy_coeff * entropy_loss
            return total_loss
@@ -294,8 +303,6 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
     critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0))
     values = critic_vmap(ppo.critic, states).squeeze()
 
-    # NOTE: are the output shapes correct?
-
     # Calculate GAE and returns
     returns, advantages = get_gae(rewards, masks, values, config)
 
@@ -520,7 +527,7 @@ def step_fn(states: State, actions: jax.Array) -> State:
                 jnp.std(states.obs, axis=1, keepdims=True) + 1e-8
             )
             next_obs, rewards, dones = norm_obs, states.reward, states.done
-            masks = (1 - dones).astype(jnp.float32)
+            masks = (1 - 0.8 * dones).astype(jnp.float32)
 
             # Update memory
             new_data = {"states": obs, "actions": actions, "rewards": rewards, "masks": masks}
@@ -528,9 +535,6 @@ def step_fn(states: State, actions: jax.Array) -> State:
 
             score += rewards
 
-            if jnp.any(jnp.isnan(rewards)):
-                print(rewards)
-
             obs = next_obs
             steps += config.num_envs
             pbar.update(config.num_envs)
@@ -558,8 +562,6 @@ def step_fn(states: State, actions: jax.Array) -> State:
                 states = reset_fn(reset_rng)
                 break
 
-        # with open("log_" + args.env_name + ".txt", "a") as outfile:
-        #     outfile.write("\t" + str(episodes) + "\t" + str(jnp.mean(score)) + "\n")
         scores.append(jnp.mean(score))
 
         memory = reorder_memory(memory, config.num_envs)
@@ -571,6 +573,7 @@ def step_fn(states: State, actions: jax.Array) -> State:
         train(ppo, train_memory, config)
 
         score_avg = float(jnp.mean(jnp.array(scores)))
+        pbar.close()
         logger.info("Episode %s score is %.2f", episodes, score_avg)
 
 
diff --git a/training.notes b/training.notes
new file mode 100644
index 0000000..66d9976
--- /dev/null
+++ b/training.notes
@@ -0,0 +1,12 @@
+
+# Currently tests:
+- Hidden layer size of 256 shows progress (loss is based on state.q[2])
+
+- setting std to zero makes rewards nans. why? I wonder if there NEEDS to be randomization in the environment
+
+- ctrl cost is what's giving nans? interesting
+- it is unrelated to randomization of the environment. i think it's gradient related
+
+- first thing to become nans seems to be actor loss and scores. after that, everything becomes nans
+
+- fixed entropy epsilon. hope this works now.
\ No newline at end of file
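
# Sketches of the changes above (reference notes, not applied by the diff)

Gradient clipping: both optimizer chains in train.py now prepend optax.clip_by_global_norm(1.0), which rescales the whole gradient pytree whenever its global L2 norm exceeds the threshold, before Adam/AdamW sees it. A minimal standalone sketch of just that behavior is below; the fixed 3e-4 learning rate and the fake_params/fake_grads names are illustrative stand-ins for the linear schedules and real model parameters in Ppo.__init__.

    import jax.numpy as jnp
    import optax

    clip_threshold = 1.0  # same hard-coded value as in the diff

    # Chain: clip the global gradient norm first, then apply the Adam update.
    actor_optim = optax.chain(
        optax.clip_by_global_norm(clip_threshold),
        optax.adam(learning_rate=3e-4, eps=1e-5),  # 3e-4 is a placeholder, not the repo's schedule
    )

    fake_params = {"w": jnp.zeros(3)}
    fake_grads = {"w": jnp.array([10.0, 0.0, 0.0])}  # global norm 10.0 >> clip_threshold

    opt_state = actor_optim.init(fake_params)
    updates, opt_state = actor_optim.update(fake_grads, opt_state, fake_params)
    new_params = optax.apply_updates(fake_params, updates)

    # The update direction is preserved but its magnitude is bounded, which is
    # the usual remedy when a loss starts producing huge or non-finite gradients.
    print(optax.global_norm(fake_grads))  # 10.0 before clipping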
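The actor-loss change replaces the standard PPO objective, which takes the elementwise minimum of the surrogate and its clipped version, with a variant that keeps whichever of the two has the smaller magnitude. The two only differ when the advantage is negative and the ratio has moved outside the clip range; a toy comparison with made-up numbers:

    import jax.numpy as jnp

    epsilon = 0.2
    ratio = jnp.array([1.3])       # new_policy / old_policy probability ratio
    advantage = jnp.array([-2.0])  # negative advantage: the action was worse than expected

    surrogate = ratio * advantage                                        # -2.6
    clipped = jnp.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage  # 1.2 * -2.0 = -2.4

    # Standard PPO: pessimistic lower bound on the objective.
    standard = jnp.minimum(surrogate, clipped)                           # -2.6

    # Variant from this diff: keep the smallest-magnitude term instead of the minimum.
    smallest_magnitude = jnp.where(
        jnp.abs(surrogate) < jnp.abs(clipped), surrogate, clipped
    )                                                                    # -2.4

    print(standard, smallest_magnitude)

For negative advantages the variant keeps the less negative term, so it is a weaker penalty than the textbook clipped lower bound; the sketch only makes that behavioral difference visible.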
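The extra 1e-8 in the entropy term relates to the NaN chain described in training.notes ("first thing to become nans seems to be actor loss"): the per-dimension entropy of a Gaussian policy is 0.5 * (log(2 * pi * std**2) + 1), so when the predicted std collapses to zero the log is -inf and its gradient is NaN, which then propagates into every parameter. A small sketch of the failure and the guard; the function names and the example std values are illustrative only.

    import jax
    import jax.numpy as jnp

    def entropy_unguarded(std):
        # 0.5 * log(2 * pi * e * std^2), written the way the old train.py line wrote it
        return jnp.mean(0.5 * (jnp.log(2 * jnp.pi * std**2) + 1))

    def entropy_guarded(std):
        # Same expression with the epsilon added in this diff
        return jnp.mean(0.5 * (jnp.log(2 * jnp.pi * (std + 1e-8) ** 2) + 1))

    std = jnp.array([0.0, 0.5])  # one dimension has collapsed to zero

    print(entropy_unguarded(std))            # -inf
    print(jax.grad(entropy_unguarded)(std))  # NaN in the collapsed slot
    print(entropy_guarded(std))              # finite
    print(jax.grad(entropy_guarded)(std))    # finite (large near std == 0, but not NaN)

The guard keeps the value and gradient finite; it does not by itself stop the policy from collapsing its variance.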