
Commit

added tricks from blogpost
nathanjzhao committed Aug 5, 2024
1 parent 69ea288 commit 88caa5a
Showing 3 changed files with 125 additions and 39 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -32,6 +32,10 @@ Minimal training and inference code for making a humanoid robot stand up.
- [ ] Implement simple PPO policy to try to make the robot stand up
- [ ] Parallelize using JAX

# Findings
- A low standard deviation for the "overfitting test" does not work very well for PPO, because the policy needs to converge on an actual solution and cannot reach one without at least a little exploration. With that in mind, I think the model is generally getting too comfortable tricking the reward system by losing as fast as possible so that it doesn't have to deal with the penalty.
- The theory that the model is just trying to lose (falling as quickly as possible so the environment resets, keeping `is_healthy` reward while giving up as little height as possible) can be tested by adding a "wait" period and removing the mask, as in the sketch below. This removes the benefit of resetting: while the model is stuck in the failed state, it is still unhealthy and therefore keeps losing reward.
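
A minimal sketch of that test, assuming a reward function shaped like the one in `environment.py` (the health bounds and weights below are placeholders, not the repository's exact values):

```python
import jax.numpy as jp

MIN_Z, MAX_Z = 0.2, 2.0  # assumed health bounds; the real ones live in environment.py

def reward_without_reset(height: jp.ndarray, ctrl_cost: jp.ndarray) -> jp.ndarray:
    """Keep charging the unhealthy penalty instead of resetting the episode.

    With no early reset, falling means paying the penalty on every remaining
    step, so "lose as fast as possible" stops being a winning strategy.
    """
    is_healthy = jp.where((height < MIN_Z) | (height > MAX_Z), 0.0, 1.0)
    return 2.0 * is_healthy + 0.1 * ctrl_cost - 1.0 * (1.0 - is_healthy)
```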

## Goals

- The goal for this repository is to provide a super minimal implementation of a PPO policy for making a humanoid robot stand up, with only three files:
4 changes: 2 additions & 2 deletions environment.py
@@ -141,7 +141,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr
is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0)
is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy)

is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0)
# is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0)

ctrl_cost = -jp.sum(jp.square(action))

@@ -158,7 +158,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr
# )
# jax.debug.print("is_healthy {}, height {}", is_healthy, state.q[2], ordered=True)

total_reward = 2.0 * is_healthy + 0.1 * ctrl_cost - 5.0 * is_bad
total_reward = 0.1 * ctrl_cost + 5 * state.q[2]

return total_reward

156 changes: 119 additions & 37 deletions train.py
@@ -24,26 +24,28 @@

@dataclass
class Config:
lr_actor: float = field(default=3e-4, metadata={"help": "Learning rate for the actor network."})
lr_critic: float = field(default=3e-4, metadata={"help": "Learning rate for the critic network."})
lr_actor: float = field(default=2.5e-4, metadata={"help": "Learning rate for the actor network."})
lr_critic: float = field(default=2.5e-4, metadata={"help": "Learning rate for the critic network."})
num_iterations: int = field(default=15000, metadata={"help": "Number of environment simulation iterations."})
num_envs: int = field(default=16, metadata={"help": "Number of environments to run at once with vectorization."})
max_steps_per_episode: int = field(
default=512 * 16, metadata={"help": "Maximum number of steps per episode (across ALL environments)."}
default=128 * 16, metadata={"help": "Maximum number of steps per episode (across ALL environments)."}
)
max_steps_per_iteration: int = field(
default=1024 * 16,
default=512 * 16,
metadata={
"help": "Maximum number of steps per iteration of simulating environments (across ALL environments)."
},
)
gamma: float = field(default=0.98, metadata={"help": "Discount factor for future rewards."})
lambd: float = field(default=0.99, metadata={"help": "Lambda parameter for GAE calculation."})
batch_size: int = field(default=64, metadata={"help": "Batch size for training updates."})
gamma: float = field(default=0.99, metadata={"help": "Discount factor for future rewards."})
lambd: float = field(default=0.95, metadata={"help": "Lambda parameter for GAE calculation."})
batch_size: int = field(default=32, metadata={"help": "Batch size for training updates."})
epsilon: float = field(default=0.2, metadata={"help": "Clipping parameter for PPO."})
l2_rate: float = field(default=0.001, metadata={"help": "L2 regularization rate for the critic."})
entropy_coeff: float = field(default=0.01, metadata={"help": "Coefficient for entropy loss."})


# NOTE: change how initialize weights?
class Actor(eqx.Module):
"""Actor network for PPO."""

@@ -59,13 +61,37 @@ def __init__(self, input_size: int, action_size: int, key: Array) -> None:
self.mu_layer = eqx.nn.Linear(64, action_size, key=keys[2])
self.log_sigma_layer = eqx.nn.Linear(64, action_size, key=keys[3])

# Parameter initialization according to Trick #2: orthogonal weights (gain sqrt(2) for tanh hidden layers, 0.01 for the output heads), zero biases
self.linear1 = self.initialize_layer(self.linear1, np.sqrt(2), keys[0])
self.linear2 = self.initialize_layer(self.linear2, np.sqrt(2), keys[1])
self.mu_layer = self.initialize_layer(self.mu_layer, 0.01, keys[2])
self.log_sigma_layer = self.initialize_layer(self.log_sigma_layer, 0.01, keys[3])

def __call__(self, x: Array) -> Tuple[Array, Array]:
x = jax.nn.tanh(self.linear1(x))
x = jax.nn.tanh(self.linear2(x))
mu = self.mu_layer(x)
log_sigma = self.log_sigma_layer(x)
return mu, jnp.exp(log_sigma)

def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eqx.nn.Linear:
weight_shape = layer.weight.shape

initializer = jax.nn.initializers.orthogonal()
new_weight = initializer(key, weight_shape, jnp.float32) * scale
new_bias = jnp.zeros(layer.bias.shape) if layer.bias is not None else None  # compare with None; truth-testing a JAX array raises an error

def where_weight(layer: eqx.nn.Linear) -> Array:
return layer.weight

def where_bias(layer: eqx.nn.Linear) -> Array | None:
return layer.bias

new_layer = eqx.tree_at(where_weight, layer, new_weight)
new_layer = eqx.tree_at(where_bias, new_layer, new_bias)

return new_layer

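Since the actor returns a mean and standard deviation, action selection (which falls outside the hunks shown) is presumably a Gaussian sample; a minimal sketch under that assumption:

```python
import jax
import jax.numpy as jnp

def select_action(actor: Actor, obs: jnp.ndarray, key: jax.Array) -> jnp.ndarray:
    # Sample from the diagonal Gaussian defined by the actor's outputs.
    mu, sigma = actor(obs)
    return mu + sigma * jax.random.normal(key, mu.shape)
```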

class Critic(eqx.Module):
"""Critic network for PPO."""
@@ -80,11 +106,34 @@ def __init__(self, input_size: int, key: Array) -> None:
self.linear2 = eqx.nn.Linear(64, 64, key=keys[1])
self.value_layer = eqx.nn.Linear(64, 1, key=keys[2])

# Parameter initialization according to Trick #2
self.linear1 = self.initialize_layer(self.linear1, np.sqrt(2), keys[0])
self.linear2 = self.initialize_layer(self.linear2, np.sqrt(2), keys[1])
self.value_layer = self.initialize_layer(self.value_layer, 1.0, keys[2])

def __call__(self, x: Array) -> Array:
x = jax.nn.tanh(self.linear1(x))
x = jax.nn.tanh(self.linear2(x))
return self.value_layer(x)

def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eqx.nn.Linear:
weight_shape = layer.weight.shape

initializer = jax.nn.initializers.orthogonal()
new_weight = initializer(key, weight_shape, jnp.float32) * scale
new_bias = jnp.zeros(layer.bias.shape) if layer.bias is not None else None  # compare with None; truth-testing a JAX array raises an error

def where_weight(layer: eqx.nn.Linear) -> Array:
return layer.weight

def where_bias(layer: eqx.nn.Linear) -> Array | None:
return layer.bias

new_layer = eqx.tree_at(where_weight, layer, new_weight)
new_layer = eqx.tree_at(where_bias, new_layer, new_bias)

return new_layer


class Ppo:
def __init__(self, observation_size: int, action_size: int, config: Config, key: Array) -> None:
@@ -93,8 +142,21 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key:
self.actor = Actor(observation_size, action_size, subkey1)
self.critic = Critic(observation_size, subkey2)

self.actor_optim = optax.adam(learning_rate=config.lr_actor)
self.critic_optim = optax.adamw(learning_rate=config.lr_critic, weight_decay=config.l2_rate)
total_timesteps = config.num_iterations

# Learning rate annealing according to Trick #4
self.actor_schedule = optax.linear_schedule(
init_value=config.lr_actor, end_value=0, transition_steps=total_timesteps
)
self.critic_schedule = optax.linear_schedule(
init_value=config.lr_critic, end_value=0, transition_steps=total_timesteps
)

# eps below according to Trick #3
self.actor_optim = optax.chain(optax.adam(learning_rate=self.actor_schedule, eps=1e-5))
self.critic_optim = optax.chain(
optax.adamw(learning_rate=self.critic_schedule, weight_decay=config.l2_rate, eps=1e-5)
)

# Initialize optimizer states
self.actor_opt_state = self.actor_optim.init(eqx.filter(self.actor, eqx.is_array))
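
As a quick sanity check (not part of the commit), the linear annealing from Trick #4 behaves like this:

```python
import optax

# Full learning rate at step 0, half at the midpoint, zero at the final step.
sched = optax.linear_schedule(init_value=2.5e-4, end_value=0.0, transition_steps=15000)
print(sched(0), sched(7500), sched(15000))  # 2.5e-4, 1.25e-4, 0.0
```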
@@ -133,8 +195,9 @@ def train_step(
params: Dict[str, Any],
states_b: Array,
actions_b: Array,
rewards_b: Array,
masks_b: Array,
returns_b: Array,
advants_b: Array,
old_log_prob_b: Array,
config: Config,
) -> Tuple[Dict[str, Any], Array, Array]:
"""Perform a single training step with PPO parameters."""
@@ -143,25 +206,27 @@
actor_vmap = jax.vmap(apply_actor, in_axes=(None, 0))
critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0))

values_b = critic_vmap(critic, states_b).squeeze()
returns_b, advants_b = get_gae(rewards_b, masks_b, values_b, config)

old_mu_b, old_std_b = actor_vmap(actor, states_b)
old_log_prob_b = actor_log_prob(old_mu_b, old_std_b, actions_b)
# Normalizing advantages *in minibatch* according to Trick #7
advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-8)

@eqx.filter_value_and_grad
def actor_loss_fn(actor: Actor) -> Array:
"""Prioritizing advantageous actions over more training."""
mu_b, std_b = actor_vmap(actor, states_b)
new_log_prob_b = actor_log_prob(mu_b, std_b, actions_b)

# Calculating the ratio of new and old probabilities
ratio_b = jnp.exp(new_log_prob_b - old_log_prob_b)
surrogate_loss_b = ratio_b * advants_b

# Clipping is done to prevent too much change if new advantages are very large
clipped_loss_b = jnp.clip(ratio_b, 1.0 - config.epsilon, 1.0 + config.epsilon) * advants_b
actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b))
return actor_loss

entropy_loss = jnp.mean(jax.scipy.stats.norm.entropy(mu_b, std_b))
total_loss = actor_loss - config.entropy_coeff * entropy_loss
return total_loss

@eqx.filter_value_and_grad
def critic_loss_fn(critic: Critic) -> Array:
@@ -205,11 +270,10 @@ def gae_step(carry: Tuple[Array, Array], inp: Tuple[Array, Array, Array]) -> Tup
_, advantages = jax.lax.scan(
f=gae_step,
init=(jnp.zeros_like(rewards[-1]), values[-1]),
xs=(rewards[::-1], masks[::-1], values[::-1]),
xs=(rewards, masks, values),  # with reverse=True, scan consumes xs back-to-front, so unreversed arrays are correct
reverse=True,
)
returns = advantages + values
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
return returns, advantages

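As a cross-check on the scan direction (an editor's sketch, assuming `gae_step` implements the standard GAE recursion and bootstraps from `values[-1]` as in the `init` above):

```python
import numpy as np

def gae_reference(rewards, masks, values, gamma=0.99, lambd=0.95):
    """Plain-loop GAE: walk the trajectory from the last step backward."""
    advantages = np.zeros_like(rewards)
    running = np.zeros_like(rewards[-1])
    next_value = values[-1]
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value * masks[t] - values[t]
        running = delta + gamma * lambd * running * masks[t]
        advantages[t] = running
        next_value = values[t]
    return advantages + values, advantages  # (returns, advantages)
```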

@@ -224,6 +288,18 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
rewards = jnp.array([e[2] for e in memory])
masks = jnp.array([e[3] for e in memory])

# Calculate old log probabilities
actor_vmap = jax.vmap(apply_actor, in_axes=(None, 0))
old_mu, old_std = actor_vmap(ppo.actor, states)
old_log_prob = actor_log_prob(old_mu, old_std, actions)

# Calculate values for all states
critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0))
values = critic_vmap(ppo.critic, states).squeeze()

# Calculate GAE and returns
returns, advantages = get_gae(rewards, masks, values, config)

n = len(states)
arr = jnp.arange(n)
key = jax.random.PRNGKey(0)
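
`actor_log_prob` is used throughout but defined outside the hunks shown; for a diagonal Gaussian policy like this `Actor`, a typical implementation (a sketch, not necessarily the file's exact code) is:

```python
import jax
import jax.numpy as jnp

def actor_log_prob(mu: jnp.ndarray, sigma: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
    # Independent Gaussian per action dimension; sum log-densities over dims.
    return jax.scipy.stats.norm.logpdf(actions, mu, sigma).sum(axis=-1)
```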
@@ -240,8 +316,9 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
batch_indices = arr[config.batch_size * i : config.batch_size * (i + 1)]
states_b = states[batch_indices]
actions_b = actions[batch_indices]
rewards_b = rewards[batch_indices]
masks_b = masks[batch_indices]
returns_b = returns[batch_indices]
advantages_b = advantages[batch_indices]
old_log_prob_b = old_log_prob[batch_indices]

params = ppo.get_params()
new_params, actor_loss, critic_loss = train_step(
Expand All @@ -250,8 +327,9 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
params,
states_b,
actions_b,
rewards_b,
masks_b,
returns_b,
advantages_b,
old_log_prob_b,
config,
)
ppo.update_params(new_params)
@@ -378,26 +456,27 @@ def step_fn(states: State, actions: jax.Array) -> State:

for i in range(1, config.num_iterations + 1):
# Initialize memory as JAX arrays
memory = {
"states": jnp.empty((0, observation_size)),
"actions": jnp.empty((0, action_size)),
"rewards": jnp.empty((0,)),
"masks": jnp.empty((0,)),
}
scores = []
steps = 0
rollout: List[MjxState] = []

rng, reset_rng = jax.random.split(rng)
states = reset_fn(reset_rng)
pbar = tqdm(total=config.max_steps_per_iteration, desc=f"Steps for iteration {i}")

while steps < config.max_steps_per_iteration:
episodes += config.num_envs

rng, reset_rng = jax.random.split(rng)
states = reset_fn(reset_rng)
obs = jax.device_put(states.obs)
score = jnp.zeros(config.num_envs)

memory = {
"states": jnp.empty((0, observation_size)),
"actions": jnp.empty((0, action_size)),
"rewards": jnp.empty((0,)),
"masks": jnp.empty((0,)),
}

for _ in range(config.max_steps_per_episode):

# Choosing actions
Expand Down Expand Up @@ -428,12 +507,21 @@ def step_fn(states: State, actions: jax.Array) -> State:
rollout.extend(unwrapped_states)

if jnp.all(dones):
rng, reset_rng = jax.random.split(rng)
states = reset_fn(reset_rng)
break

# with open("log_" + args.env_name + ".txt", "a") as outfile:
# outfile.write("\t" + str(episodes) + "\t" + str(jnp.mean(score)) + "\n")
scores.append(jnp.mean(score))

# Convert memory to the format expected by ppo.train
train_memory = [
(s, a, r, m)
for s, a, r, m in zip(memory["states"], memory["actions"], memory["rewards"], memory["masks"])
]
train(ppo, train_memory, config)

score_avg = float(jnp.mean(jnp.array(scores)))
pbar.close()
logger.info("Episode %s score is %.2f", episodes, score_avg)
@@ -456,12 +544,6 @@ def step_fn(states: State, actions: jax.Array) -> State:
logger.info("Saving video to %s for iteration %d", video_path, i)
media.write_video(video_path, images, fps=fps)

# Convert memory to the format expected by ppo.train
train_memory = [
(s, a, r, m) for s, a, r, m in zip(memory["states"], memory["actions"], memory["rewards"], memory["masks"])
]
train(ppo, train_memory, config)


if __name__ == "__main__":
main()
