Fix NaNs with higher epsilon #19

Closed · wants to merge 2 commits
1 change: 1 addition & 0 deletions .gitignore

@@ -19,6 +19,7 @@ out*/
 *.stl
 *.txt
 *.mp4
+*.notes

 environments/
 screenshots/
46 changes: 28 additions & 18 deletions train.py

@@ -4,7 +4,6 @@
 import logging
 import os
 from dataclasses import dataclass, field
-from functools import partial
 from typing import Any, Dict, List, Tuple

 import equinox as eqx
@@ -45,6 +44,7 @@ class Config:
     epsilon: float = field(default=0.2, metadata={"help": "Clipping parameter for PPO."})
     l2_rate: float = field(default=0.001, metadata={"help": "L2 regularization rate for the critic."})
     entropy_coeff: float = field(default=0.01, metadata={"help": "Coefficient for entropy loss."})
+    minimal: bool = field(default=False, metadata={"help": "Make minimal PPO (no std) for breakpoint debugging"})


 class Actor(eqx.Module):
@@ -72,7 +72,10 @@ def __call__(self, x: Array) -> Tuple[Array, Array]:
         x = jax.nn.tanh(self.linear1(x))
         x = jax.nn.tanh(self.linear2(x))
         mu = self.mu_layer(x)
+
         log_sigma = self.log_sigma_layer(x)
+
+        # return mu, jnp.zeros_like(log_sigma)
         return mu, jnp.exp(log_sigma)

     def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eqx.nn.Linear:
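
A brief aside on the hunk above (illustration only, not part of the PR): the actor predicts an unconstrained log-std, and exponentiating it is what keeps the policy's standard deviation strictly positive no matter what the network outputs.

```python
# Sketch: exp maps any real-valued log_sigma to a positive std.
import jax.numpy as jnp

log_sigma = jnp.array([-2.0, 0.0, 3.0])  # unconstrained network outputs
sigma = jnp.exp(log_sigma)               # [0.135..., 1.0, 20.08...] -- strictly positive
print(sigma)
```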
@@ -147,19 +150,15 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key:

         # Learning rate annealing according to Trick #4
         self.actor_schedule = optax.linear_schedule(
-            init_value=config.lr_actor, end_value=1e-6, transition_steps=total_timesteps
+            init_value=config.lr_actor, end_value=0, transition_steps=total_timesteps
         )
         self.critic_schedule = optax.linear_schedule(
-            init_value=config.lr_critic, end_value=1e-6, transition_steps=total_timesteps
+            init_value=config.lr_critic, end_value=0, transition_steps=total_timesteps
         )

         clip_threshold = 1.0
         # eps below according to Trick #3
-        self.actor_optim = optax.chain(
-            optax.clip_by_global_norm(clip_threshold), optax.adam(learning_rate=self.actor_schedule, eps=1e-5)
-        )
+        self.actor_optim = optax.chain(optax.adam(learning_rate=self.actor_schedule, eps=1e-5))
         self.critic_optim = optax.chain(
-            optax.clip_by_global_norm(clip_threshold),
             optax.adamw(learning_rate=self.critic_schedule, weight_decay=config.l2_rate, eps=1e-5),
         )
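For context on the annealing change above, a minimal sketch of how the schedule now behaves; the 3e-4 initial rate and 1,000,000-step budget are assumed values for illustration, not taken from the repo's config:

```python
import optax

# Linear anneal from an assumed initial LR all the way to 0 over an assumed step budget.
schedule = optax.linear_schedule(init_value=3e-4, end_value=0.0, transition_steps=1_000_000)

print(schedule(0))          # 0.0003 at the first update
print(schedule(500_000))    # 0.00015 halfway through
print(schedule(1_000_000))  # 0.0 once annealing completes
```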

@@ -212,7 +211,7 @@ def train_step(
     critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0))

     # Normalizing advantages *in minibatch* according to Trick #7
-    advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-8)
+    advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-4)

     @eqx.filter_value_and_grad
     def actor_loss_fn(actor: Actor) -> Array:
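
A sketch of why the epsilon bump matters in this line, assuming half-precision values as the review thread below suggests: in a degenerate minibatch the advantage std is 0, and a 1e-8 guard underflows to exactly 0 in float16, leaving a 0/0 division:

```python
import jax.numpy as jnp

advants_b = jnp.zeros(4, dtype=jnp.float16)  # all-equal advantages -> std == 0
eps_old = jnp.float16(1e-8)                  # underflows to 0.0 in float16
eps_new = jnp.float16(1e-4)                  # still representable

print((advants_b - advants_b.mean()) / (advants_b.std() + eps_old))  # [nan nan nan nan]
print((advants_b - advants_b.mean()) / (advants_b.std() + eps_new))  # [0. 0. 0. 0.]
```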
@@ -228,11 +227,9 @@ def actor_loss_fn(actor: Actor) -> Array:
         clipped_loss_b = jnp.clip(ratio_b, 1.0 - config.epsilon, 1.0 + config.epsilon) * advants_b

         # Choosing the smallest magnitude loss
-        actor_loss = -jnp.mean(
-            jnp.where(jnp.abs(surrogate_loss_b) < jnp.abs(clipped_loss_b), surrogate_loss_b, clipped_loss_b)
-        )
+        actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b))

-        entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * (std_b + 1e-8) ** 2) + 1))
+        entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * (std_b + 1e-4) ** 2) + 1))
Member:

oh yea - you should use 1e-5 for eps when training fp16 models, because the minimum positive value is $2^{-24} \approx 6 \times 10^{-8}$

Contributor (author):

damn, knowing that would've saved me a lot of time lol
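
A quick numeric check of the reviewer's point (illustration only, using NumPy's float16):

```python
import numpy as np

print(np.float16(1e-8))   # 0.0 -- below the smallest positive float16, so it vanishes
print(np.float16(1e-5))   # ~1e-05 -- survives, as a subnormal
print(np.float16(1e-4))   # ~1e-04 -- the epsilon this PR moves to
print(np.finfo(np.float16).smallest_subnormal)  # 6e-08 == 2**-24
```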


         total_loss = actor_loss - config.entropy_coeff * entropy_loss
         return total_loss
@@ -249,6 +246,12 @@ def critic_loss_fn(critic: Critic) -> Array:
     actor_updates, new_actor_opt_state = actor_optim.update(actor_grads, actor_opt_state, params=actor)
     new_actor = eqx.apply_updates(actor, actor_updates)

+    if config.minimal:
+        breakpoint()
+
+    if jnp.any(jnp.isnan(actor_loss)):
+        breakpoint()
+
     # Calculating critic loss and updating critic parameters
     critic_loss, critic_grads = critic_loss_fn(critic)
     critic_updates, new_critic_opt_state = critic_optim.update(critic_grads, critic_opt_state, params=critic)
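
One caveat worth flagging about the added checks, hedged because it depends on whether train_step runs under jit: a plain Python `if` plus `breakpoint()` only works on concrete values, and under `jax.jit` the predicate is a tracer. A jit-compatible alternative sketch using equinox's runtime check (the `guard_nan` helper name is made up for illustration):

```python
import equinox as eqx
import jax.numpy as jnp

def guard_nan(actor_loss):
    # eqx.error_if raises at runtime, even inside jit, when the predicate is true.
    return eqx.error_if(actor_loss, jnp.isnan(actor_loss).any(), "actor_loss became NaN")
```

Setting the `EQX_ON_ERROR=breakpoint` environment variable turns the raised error into a debugger stop, which is close to what the PR's `breakpoint()` calls intend.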
@@ -330,6 +333,7 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
             returns_b = returns[batch_indices]
             advantages_b = advantages[batch_indices]
             old_log_prob_b = old_log_prob[batch_indices]
+            # breakpoint()

             params = ppo.get_params()
             new_params, (actor_loss, critic_loss) = train_step(
@@ -366,12 +370,14 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con

 def actor_log_prob(mu: Array, sigma: Array, actions: Array) -> Array:
     """Calculate the log probability of the actions given the actor network's output."""
-    return jax.scipy.stats.norm.logpdf(actions, mu, sigma + 1e-8).sum(axis=-1)
+    # Summing across the number of actions after logpdf of each relative to mu/sigma, determined by state
+    return jax.scipy.stats.norm.logpdf(actions, mu, sigma + 1e-4).sum(axis=1)


 def actor_distribution(mu: Array, sigma: Array, rng: Array) -> Array:
     """Get an action from the actor network from its probability distribution of actions."""
-    return jax.random.normal(rng, shape=mu.shape) * sigma + mu
+    action = jax.random.normal(rng, shape=mu.shape) * (sigma + 1e-4) + mu
+    return action


 def unwrap_state_vectorization(state: State, envs_to_sample: int) -> State:
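
To make the new comment in actor_log_prob concrete, a small shape check with made-up shapes (illustration only): each action dimension gets its own Normal logpdf, and summing over axis 1 yields one joint log-probability per batch element.

```python
import jax
import jax.numpy as jnp

mu = jnp.zeros((2, 3))      # (batch, action_dim)
sigma = jnp.ones((2, 3))
actions = jnp.zeros((2, 3))

log_prob = jax.scipy.stats.norm.logpdf(actions, mu, sigma + 1e-4).sum(axis=1)
print(log_prob.shape)  # (2,) -- one log-prob per batch element
print(log_prob)        # ~3 * log(1 / sqrt(2 * pi)), about -2.757 each
```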
@@ -500,7 +506,7 @@ def step_fn(states: State, actions: jax.Array) -> State:

     # Normalizing observations quickens learning
     norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (
-        jnp.std(states.obs, axis=1, keepdims=True) + 1e-8
+        jnp.std(states.obs, axis=1, keepdims=True) + 1e-4
     )
     obs = jax.device_put(norm_obs)
     score = jnp.zeros(config.num_envs)
@@ … @@

     # Normalizing observations quickens learning
     norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (
-        jnp.std(states.obs, axis=1, keepdims=True) + 1e-8
+        jnp.std(states.obs, axis=1, keepdims=True) + 1e-4
     )
     next_obs, rewards, dones = norm_obs, states.reward, states.done
-    masks = (1 - 0.8 * dones).astype(jnp.float32)
+    masks = (1 - dones).astype(jnp.float32)
+
+    # Check for NaNs and print them
+    if jnp.any(jnp.isnan(rewards)):
+        breakpoint()

     # Update memory
     new_data = {"states": obs, "actions": actions, "rewards": rewards, "masks": masks}
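
Finally, a worked sketch of the mask change above; the recursion shown is the standard discounted-return form, since the PR's own return computation isn't visible in this diff. With masks = 1 - dones, a terminal step fully stops bootstrapping, whereas the old 1 - 0.8 * dones let 20% of the future value leak across episode boundaries.

```python
import jax.numpy as jnp

dones = jnp.array([0.0, 0.0, 1.0, 0.0])

old_masks = (1 - 0.8 * dones).astype(jnp.float32)  # [1., 1., 0.2, 1.] -- leaks 20% past a terminal
new_masks = (1 - dones).astype(jnp.float32)        # [1., 1., 0.,  1.] -- cuts the episode cleanly

# In a standard discounted return R_t = r_t + gamma * mask_t * R_{t+1},
# a mask of 0 at the terminal step stops the recursion entirely.
```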
12 changes: 0 additions & 12 deletions training.notes

This file was deleted.
