
Commit

gradient clipping and epsilons + further nan debugging. iterated over minimum example with std set to 0, large actor loss seems to be problem
nathanjzhao committed Aug 8, 2024
1 parent bbf0e3c commit 9c99714
Showing 3 changed files with 19 additions and 31 deletions.
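
The commit message mentions iterating on a minimum example with the policy standard deviation set to 0. Here is a toy sketch (not part of the commit) of why that produces NaNs, and why small epsilons such as the `std_b + 1e-8` added below help:

```python
import jax.numpy as jnp

def gaussian_log_prob(x: jnp.ndarray, mean: jnp.ndarray, std: jnp.ndarray) -> jnp.ndarray:
    # Log-density of a Gaussian; note the log of std**2 and the division by std.
    return -0.5 * (jnp.log(2 * jnp.pi * std**2) + ((x - mean) / std) ** 2)

x, mean = jnp.array(0.5), jnp.array(0.0)
print(gaussian_log_prob(x, mean, jnp.array(0.0)))   # nan: log(0) is -inf and (x - mean) / 0 is inf
print(gaussian_log_prob(x, mean, jnp.array(1e-8)))  # finite (very large), thanks to the nonzero std
```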
3 changes: 0 additions & 3 deletions README.md
@@ -36,9 +36,6 @@ Minimal training and inference code for making a humanoid robot stand up.
- A low standard deviation for the "overfitting test" does not work very well for PPO because the policy needs to converge on an actual solution, and it cannot reach one without exploring a little. With that in mind, I think the model is generally getting too comfortable tricking the reward system by losing as fast as possible so it doesn't have to deal with the penalty.
- The theory that the model is just trying to lose (to retain the `is_healthy` reward while not losing much height; it wants to fall as quickly as possible so it can reset) can be tested by adding a "wait" and removing the mask. This effectively removes the benefit of resetting: while the model is stuck in the failed state, it is still unhealthy and therefore loses out on reward. (See the toy sketch following this README diff.)

# Currently tests:
- Hidden layer size of 256 shows progress (loss is based on state.q[2])

## Goals

- The goal for this repository is to provide a super minimal implementation of a PPO policy for making a humanoid robot stand up, with only three files:
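
As a toy illustration of the second note above (all numbers except the `5 * is_healthy` reward term are made up), an agent that falls quickly and resets can still collect a large share of the discounted reward without ever learning to stand, which is the exploit the "wait without reset" test is meant to close:

```python
import jax.numpy as jnp

HEALTHY_REWARD = 5.0   # matches the 5 * is_healthy term in environment.py
GAMMA = 0.98           # illustrative discount factor, not taken from the repo

def discounted_return(rewards: list) -> jnp.ndarray:
    steps = jnp.arange(len(rewards))
    return jnp.sum(jnp.array(rewards) * GAMMA**steps)

# Hypothetical 10-step windows:
struggler = [HEALTHY_REWARD] * 10                                         # stays barely healthy throughout
exploiter = [HEALTHY_REWARD, HEALTHY_REWARD, 0.0] * 3 + [HEALTHY_REWARD]  # falls every third step, then resets

print(discounted_return(struggler))  # ~45.7
print(discounted_return(exploiter))  # ~32.2, still a sizable fraction without solving the task
```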
16 changes: 2 additions & 14 deletions environment.py
@@ -142,8 +142,6 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr
is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0)
is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy)

# is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0)

ctrl_cost = -jp.sum(jp.square(action))

# xpos = state.subtree_com[1][0]
@@ -159,7 +157,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr
# )
# jax.debug.print("is_healthy {}, height {}", is_healthy, state.q[2], ordered=True)

total_reward = jp.clip(0.1 * ctrl_cost + 5 * is_healthy, -1e8, 10.0)
total_reward = 0.1 * ctrl_cost + 5 * is_healthy

return total_reward

@@ -186,17 +184,7 @@ def get_obs(self, data: MjxState, action: jp.ndarray) -> jp.ndarray:
data.qfrc_actuator,
]

def clean_component(component: jp.ndarray) -> jp.ndarray:
# Check for NaNs or Infs and replace them
nan_mask = jp.isnan(component)
inf_mask = jp.isinf(component)
component = jp.where(nan_mask, 0.0, component)
component = jp.where(inf_mask, jp.where(component > 0, 1e6, -1e6), component)
return component

cleaned_components = [clean_component(comp) for comp in obs_components]

return jp.concatenate(cleaned_components)
return jp.concatenate(obs_components)


def run_environment_adhoc() -> None:
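
For reference, the removed `clean_component` helper above is equivalent to a single `jp.nan_to_num` call; a minimal sketch (not part of this commit):

```python
import jax.numpy as jp

def clean_component(component: jp.ndarray) -> jp.ndarray:
    # Same behavior as the deleted helper: NaN -> 0.0, +inf -> 1e6, -inf -> -1e6.
    return jp.nan_to_num(component, nan=0.0, posinf=1e6, neginf=-1e6)

print(clean_component(jp.array([1.0, jp.nan, jp.inf, -jp.inf])))  # -> [1.0, 0.0, 1e6, -1e6]
```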
31 changes: 17 additions & 14 deletions train.py
@@ -4,6 +4,7 @@
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Tuple

import equinox as eqx
@@ -30,10 +31,10 @@ class Config:
num_iterations: int = field(default=15000, metadata={"help": "Number of environment simulation iterations."})
num_envs: int = field(default=32, metadata={"help": "Number of environments to run at once with vectorization."})
max_steps_per_episode: int = field(
default=128 * 32, metadata={"help": "Maximum number of steps per episode (across ALL environments)."}
default=512 * 32, metadata={"help": "Maximum number of steps per episode (across ALL environments)."}
)
max_steps_per_iteration: int = field(
default=512 * 32,
default=1024 * 32,
metadata={
"help": "Maximum number of steps per iteration of simulating environments (across ALL environments)."
},
@@ -152,10 +153,14 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key:
init_value=config.lr_critic, end_value=1e-6, transition_steps=total_timesteps
)

clip_threshold = 1.0
# eps below according to Trick #3
self.actor_optim = optax.chain(optax.adam(learning_rate=self.actor_schedule, eps=1e-5))
self.actor_optim = optax.chain(
optax.clip_by_global_norm(clip_threshold), optax.adam(learning_rate=self.actor_schedule, eps=1e-5)
)
self.critic_optim = optax.chain(
optax.adamw(learning_rate=self.critic_schedule, weight_decay=config.l2_rate, eps=1e-5)
optax.clip_by_global_norm(clip_threshold),
optax.adamw(learning_rate=self.critic_schedule, weight_decay=config.l2_rate, eps=1e-5),
)

# Initialize optimizer states
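
Standalone sketch of the optax chaining pattern introduced above, outside the `Ppo` class (toy parameters, gradients, and learning rate are illustrative only):

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3,)), "b": jnp.zeros((3,))}
grads = {"w": jnp.full((3,), 1e4), "b": jnp.full((3,), -1e4)}  # deliberately exploding gradients

optim = optax.chain(
    optax.clip_by_global_norm(1.0),            # rescale grads so their global norm is at most 1.0
    optax.adam(learning_rate=3e-4, eps=1e-5),  # larger eps keeps Adam's denominator away from zero
)
opt_state = optim.init(params)
updates, opt_state = optim.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```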
@@ -222,8 +227,12 @@ def actor_loss_fn(actor: Actor) -> Array:
# Clipping is done to prevent too much change if new advantages are very large
clipped_loss_b = jnp.clip(ratio_b, 1.0 - config.epsilon, 1.0 + config.epsilon) * advants_b

actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b))
entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * std_b**2) + 1))
# Choosing the smallest magnitude loss
actor_loss = -jnp.mean(
jnp.where(jnp.abs(surrogate_loss_b) < jnp.abs(clipped_loss_b), surrogate_loss_b, clipped_loss_b)
)

entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * (std_b + 1e-8) ** 2) + 1))

total_loss = actor_loss - config.entropy_coeff * entropy_loss
return total_loss
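
A toy comparison (numbers are made up) of the standard PPO pessimistic minimum against the smallest-magnitude selection used above; the two differ when the surrogate term is large and negative, which appears to be aimed at the oversized actor loss mentioned in the commit message:

```python
import jax.numpy as jnp

epsilon = 0.2                   # illustrative clip range, not necessarily config.epsilon
ratio = jnp.array([1.5])        # new_prob / old_prob, outside the clip range
advantage = jnp.array([-2.0])   # negative advantage estimate

surrogate = ratio * advantage                                        # -3.0
clipped = jnp.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage  # 1.2 * -2.0 = -2.4

standard_ppo = jnp.minimum(surrogate, clipped)  # -3.0: pessimistic bound, larger magnitude
smallest_magnitude = jnp.where(
    jnp.abs(surrogate) < jnp.abs(clipped), surrogate, clipped
)                                               # -2.4: bounds the magnitude of the loss instead
```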
@@ -294,8 +303,6 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0))
values = critic_vmap(ppo.critic, states).squeeze()

# NOTE: are the output shapes correct?

# Calculate GAE and returns
returns, advantages = get_gae(rewards, masks, values, config)

@@ -520,17 +527,14 @@ def step_fn(states: State, actions: jax.Array) -> State:
jnp.std(states.obs, axis=1, keepdims=True) + 1e-8
)
next_obs, rewards, dones = norm_obs, states.reward, states.done
masks = (1 - dones).astype(jnp.float32)
masks = (1 - 0.8 * dones).astype(jnp.float32)

# Update memory
new_data = {"states": obs, "actions": actions, "rewards": rewards, "masks": masks}
memory = update_memory(memory, new_data)

score += rewards

if jnp.any(jnp.isnan(rewards)):
print(rewards)

obs = next_obs
steps += config.num_envs
pbar.update(config.num_envs)
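
To see what the softened `masks = (1 - 0.8 * dones)` line above changes downstream, here is a minimal GAE sketch (the gamma/lambda values and the loop-based implementation are illustrative, not the repo's `get_gae`): with a hard `1 - dones` mask the bootstrapped value and the running advantage are zeroed at a reset, while the 0.2 left over by the soft mask lets some credit leak across it.

```python
import jax.numpy as jnp

def gae(rewards: jnp.ndarray, masks: jnp.ndarray, values: jnp.ndarray,
        gamma: float = 0.98, lam: float = 0.95) -> jnp.ndarray:
    # Plain-Python GAE over one rollout; masks scale both bootstrapping and the running advantage.
    advantages = jnp.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(rewards.shape[0])):
        next_value = values[t + 1] if t + 1 < rewards.shape[0] else 0.0
        delta = rewards[t] + gamma * next_value * masks[t] - values[t]
        running = delta + gamma * lam * masks[t] * running
        advantages = advantages.at[t].set(running)
    return advantages

rewards = jnp.array([5.0, 5.0, 0.0, 5.0])
values = jnp.array([4.0, 4.0, 4.0, 4.0])
dones = jnp.array([0.0, 0.0, 1.0, 0.0])

print(gae(rewards, 1.0 - dones, values))        # hard mask: a reset fully cuts credit assignment
print(gae(rewards, 1.0 - 0.8 * dones, values))  # soft mask: 20% of the credit flows across the reset
```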
@@ -558,8 +562,6 @@ def step_fn(states: State, actions: jax.Array) -> State:
states = reset_fn(reset_rng)
break

# with open("log_" + args.env_name + ".txt", "a") as outfile:
# outfile.write("\t" + str(episodes) + "\t" + str(jnp.mean(score)) + "\n")
scores.append(jnp.mean(score))

memory = reorder_memory(memory, config.num_envs)
@@ -571,6 +573,7 @@ def step_fn(states: State, actions: jax.Array) -> State:
train(ppo, train_memory, config)

score_avg = float(jnp.mean(jnp.array(scores)))

pbar.close()
logger.info("Episode %s score is %.2f", episodes, score_avg)

