From b150e76db2eac930bcf877ea81e9fcff6f15bc3b Mon Sep 17 00:00:00 2001
From: Nathan Zhao
Date: Thu, 8 Aug 2024 20:52:39 +0000
Subject: [PATCH 1/2] add NaN-debugging instrumentation to PPO training

---
 train.py       | 85 +++++++++++++++++++++++++++++++++-----------------
 training.notes | 14 ++++++++-
 2 files changed, 69 insertions(+), 30 deletions(-)

diff --git a/train.py b/train.py
index c505c83..4641afd 100644
--- a/train.py
+++ b/train.py
@@ -45,6 +45,7 @@ class Config:
     epsilon: float = field(default=0.2, metadata={"help": "Clipping parameter for PPO."})
     l2_rate: float = field(default=0.001, metadata={"help": "L2 regularization rate for the critic."})
     entropy_coeff: float = field(default=0.01, metadata={"help": "Coefficient for entropy loss."})
+    minimal: bool = field(default=False, metadata={"help": "Run a minimal PPO (no std) for breakpoint debugging."})
 
 
 class Actor(eqx.Module):
@@ -72,7 +73,10 @@ def __call__(self, x: Array) -> Tuple[Array, Array]:
         x = jax.nn.tanh(self.linear1(x))
         x = jax.nn.tanh(self.linear2(x))
         mu = self.mu_layer(x)
+
         log_sigma = self.log_sigma_layer(x)
+
+        # return mu, jnp.zeros_like(log_sigma)
         return mu, jnp.exp(log_sigma)
 
     def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eqx.nn.Linear:
@@ -147,19 +151,15 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key:
         # Learning rate annealing according to Trick #4
         self.actor_schedule = optax.linear_schedule(
-            init_value=config.lr_actor, end_value=1e-6, transition_steps=total_timesteps
+            init_value=config.lr_actor, end_value=0, transition_steps=total_timesteps
         )
         self.critic_schedule = optax.linear_schedule(
-            init_value=config.lr_critic, end_value=1e-6, transition_steps=total_timesteps
+            init_value=config.lr_critic, end_value=0, transition_steps=total_timesteps
        )
 
-        clip_threshold = 1.0
         # eps below according to Trick #3
-        self.actor_optim = optax.chain(
-            optax.clip_by_global_norm(clip_threshold), optax.adam(learning_rate=self.actor_schedule, eps=1e-5)
-        )
+        self.actor_optim = optax.chain(optax.adam(learning_rate=self.actor_schedule, eps=1e-5))
         self.critic_optim = optax.chain(
-            optax.clip_by_global_norm(clip_threshold),
             optax.adamw(learning_rate=self.critic_schedule, weight_decay=config.l2_rate, eps=1e-5),
         )
@@ -167,14 +167,14 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key:
         self.actor_opt_state = self.actor_optim.init(eqx.filter(self.actor, eqx.is_array))
         self.critic_opt_state = self.critic_optim.init(eqx.filter(self.critic, eqx.is_array))
 
-    def get_params(self) -> Dict[str, Any]:
+    def get_params(self) -> Tuple[Any, ...]:
         """Get the parameters of the PPO model."""
-        return {
-            "actor": self.actor,
-            "critic": self.critic,
-            "actor_opt_state": self.actor_opt_state,
-            "critic_opt_state": self.critic_opt_state,
-        }
+        return (
+            self.actor,
+            self.critic,
+            self.actor_opt_state,
+            self.critic_opt_state,
+        )
 
     def update_params(self, new_params: Dict[str, Any]) -> None:
         """Update the parameters of the PPO model."""
@@ -197,7 +197,7 @@ def apply_actor(actor: Actor, state: Array) -> Array:
 def train_step(
     actor_optim: optax.GradientTransformation,
     critic_optim: optax.GradientTransformation,
-    params: Dict[str, Any],
+    params: Tuple[Any, ...],
     states_b: Array,
     actions_b: Array,
     returns_b: Array,
@@ -206,16 +206,16 @@ def train_step(
     config: Config,
 ) -> Tuple[Dict[str, Any], Tuple[Array, Array]]:
     """Perform a single training step with PPO parameters."""
-    actor, critic, actor_opt_state, critic_opt_state = params.values()
+    actor, critic, actor_opt_state, critic_opt_state = params
     actor_vmap = jax.vmap(apply_actor, in_axes=(None, 0))
     critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0))
 
     # Normalizing advantages *in minibatch* according to Trick #7
-    advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-8)
+    advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-4)
 
-    @eqx.filter_value_and_grad
-    def actor_loss_fn(actor: Actor) -> Array:
+    @partial(eqx.filter_value_and_grad, has_aux=True)
+    def actor_loss_fn(actor: Actor, states_b, actions_b, advants_b) -> Tuple[Array, Tuple[Array, ...]]:
         """Clipped PPO surrogate loss, favoring actions with positive advantage."""
         mu_b, std_b = actor_vmap(actor, states_b)
         new_log_prob_b = actor_log_prob(mu_b, std_b, actions_b)
@@ -232,25 +232,44 @@ def actor_loss_fn(actor: Actor) -> Array:
             jnp.where(jnp.abs(surrogate_loss_b) < jnp.abs(clipped_loss_b), surrogate_loss_b, clipped_loss_b)
         )
 
-        entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * (std_b + 1e-8) ** 2) + 1))
+        # Keeping track of the number of losses that are clipped
+        clipped_losses = jnp.abs(surrogate_loss_b) >= jnp.abs(clipped_loss_b)
+        fraction_clipped = jnp.mean(clipped_losses)
+
+        entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * (std_b + 1e-4) ** 2) + 1))
         total_loss = actor_loss - config.entropy_coeff * entropy_loss
-        return total_loss
+        return total_loss, (
+            ratio_b,
+            entropy_loss,
+            actor_loss,
+            surrogate_loss_b,
+            clipped_loss_b,
+            fraction_clipped,
+            new_log_prob_b,
+            old_log_prob_b,
+        )
 
     @eqx.filter_value_and_grad
-    def critic_loss_fn(critic: Critic) -> Array:
+    def critic_loss_fn(critic: Critic, states_b, returns_b) -> Array:
         """Mean-squared error against the ground-truth returns."""
         critic_returns_b = critic_vmap(critic, states_b).squeeze()
         critic_loss = jnp.mean((critic_returns_b - returns_b) ** 2)
         return critic_loss
 
     # Calculating actor loss and updating actor parameters --- outputting auxiliary data for logging
-    actor_loss, actor_grads = actor_loss_fn(actor)
+    (actor_loss, values), actor_grads = actor_loss_fn(actor, states_b, actions_b, advants_b)
     actor_updates, new_actor_opt_state = actor_optim.update(actor_grads, actor_opt_state, params=actor)
     new_actor = eqx.apply_updates(actor, actor_updates)
 
+    if config.minimal:
+        breakpoint()
+
+    if jnp.any(jnp.isnan(actor_loss)) or jnp.any(jnp.isinf(values[0])):
+        breakpoint()
+
     # Calculating critic loss and updating critic parameters
-    critic_loss, critic_grads = critic_loss_fn(critic)
+    critic_loss, critic_grads = critic_loss_fn(critic, states_b, returns_b)
     critic_updates, new_critic_opt_state = critic_optim.update(critic_grads, critic_opt_state, params=critic)
     new_critic = eqx.apply_updates(critic, critic_updates)
 
@@ -330,6 +349,7 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
         returns_b = returns[batch_indices]
         advantages_b = advantages[batch_indices]
         old_log_prob_b = old_log_prob[batch_indices]
+        # breakpoint()
 
         params = ppo.get_params()
         new_params, (actor_loss, critic_loss) = train_step(
@@ -366,12 +386,15 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
 
 def actor_log_prob(mu: Array, sigma: Array, actions: Array) -> Array:
     """Calculate the log probability of the actions given the actor network's output."""
-    return jax.scipy.stats.norm.logpdf(actions, mu, sigma + 1e-8).sum(axis=-1)
+
+    # Summing log-densities over the action dimensions, with mu/sigma determined by the state
+    return jax.scipy.stats.norm.logpdf(actions, mu, sigma + 1e-4).sum(axis=1)
 
 
 def actor_distribution(mu: Array, sigma: Array, rng: Array) -> Array:
     """Sample an action from the actor network's Gaussian action distribution."""
-    return jax.random.normal(rng, shape=mu.shape) * sigma + mu
+    action = jax.random.normal(rng, shape=mu.shape) * (sigma + 1e-4) + mu
+    return action
 
 
 def unwrap_state_vectorization(state: State, envs_to_sample: int) -> State:
@@ -500,7 +523,7 @@ def step_fn(states: State, actions: jax.Array) -> State:
 
     # Normalizing observations quickens learning
     norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (
-        jnp.std(states.obs, axis=1, keepdims=True) + 1e-8
+        jnp.std(states.obs, axis=1, keepdims=True) + 1e-4
     )
     obs = jax.device_put(norm_obs)
     score = jnp.zeros(config.num_envs)
@@ -524,10 +547,14 @@ def step_fn(states: State, actions: jax.Array) -> State:
 
             # Normalizing observations quickens learning
             norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (
-                jnp.std(states.obs, axis=1, keepdims=True) + 1e-8
+                jnp.std(states.obs, axis=1, keepdims=True) + 1e-4
             )
             next_obs, rewards, dones = norm_obs, states.reward, states.done
-            masks = (1 - 0.8 * dones).astype(jnp.float32)
+            masks = (1 - dones).astype(jnp.float32)
+
+            # Check for NaNs and drop into the debugger if any appear
+            if jnp.any(jnp.isnan(rewards)):
+                breakpoint()
 
             # Update memory
             new_data = {"states": obs, "actions": actions, "rewards": rewards, "masks": masks}
diff --git a/training.notes b/training.notes
index 66d9976..d746e14 100644
--- a/training.notes
+++ b/training.notes
@@ -9,4 +9,16 @@
 
 - first thing to become nans seems to be actor loss and scores. after that, everything becomes nans
 
-- fixed entropy epsilon. hope this works now.
\ No newline at end of file
+- fixed entropy epsilon. hope this works now.
+
+
+- training clipped mus right now
+
+- `equinox.internal.debug_backward_nan`
+- JAX_DEBUG_NANS=1
+
+- lr = 0 still nan in minimal example
+- problem maybe in gradient clipping? as per https://github.com/google/jax/discussions/6440
+- ^nope
+
+- I think it's a problem with the clipping
\ No newline at end of file

From a32da29d5e9d3510f26c7db005ce6bedcf3565bf Mon Sep 17 00:00:00 2001
From: Nathan Zhao
Date: Thu, 8 Aug 2024 21:04:02 +0000
Subject: [PATCH 2/2] cleanup tests + lint

---
 .gitignore     |  1 +
 train.py       | 51 +++++++++++++++++---------------------------------
 training.notes | 24 ------------------------
 3 files changed, 18 insertions(+), 58 deletions(-)
 delete mode 100644 training.notes

diff --git a/.gitignore b/.gitignore
index 9a5d811..1f94e11 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,7 @@ out*/
 *.stl
 *.txt
 *.mp4
+*.notes
 
 environments/
 screenshots/
diff --git a/train.py b/train.py
index 4641afd..640e76a 100644
--- a/train.py
+++ b/train.py
@@ -4,7 +4,6 @@
 import logging
 import os
 from dataclasses import dataclass, field
-from functools import partial
 from typing import Any, Dict, List, Tuple
 
 import equinox as eqx
@@ -167,14 +166,14 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key:
         self.actor_opt_state = self.actor_optim.init(eqx.filter(self.actor, eqx.is_array))
         self.critic_opt_state = self.critic_optim.init(eqx.filter(self.critic, eqx.is_array))
 
-    def get_params(self) -> Tuple[Any, ...]:
+    def get_params(self) -> Dict[str, Any]:
         """Get the parameters of the PPO model."""
-        return (
-            self.actor,
-            self.critic,
-            self.actor_opt_state,
-            self.critic_opt_state,
-        )
+        return {
+            "actor": self.actor,
+            "critic": self.critic,
+            "actor_opt_state": self.actor_opt_state,
+            "critic_opt_state": self.critic_opt_state,
+        }
 
     def update_params(self, new_params: Dict[str, Any]) -> None:
         """Update the parameters of the PPO model."""
@@ -197,7 +196,7 @@ def apply_actor(actor: Actor, state: Array) -> Array:
 def train_step(
     actor_optim: optax.GradientTransformation,
     critic_optim: optax.GradientTransformation,
-    params: Tuple[Any, ...],
+    params: Dict[str, Any],
     states_b: Array,
     actions_b: Array,
     returns_b: Array,
@@ -206,7 +205,7 @@ def train_step(
     config: Config,
 ) -> Tuple[Dict[str, Any], Tuple[Array, Array]]:
     """Perform a single training step with PPO parameters."""
-    actor, critic, actor_opt_state, critic_opt_state = params
+    actor, critic, actor_opt_state, critic_opt_state = params.values()
     actor_vmap = jax.vmap(apply_actor, in_axes=(None, 0))
     critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0))
 
@@ -214,8 +213,8 @@ def train_step(
     # Normalizing advantages *in minibatch* according to Trick #7
     advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-4)
 
-    @partial(eqx.filter_value_and_grad, has_aux=True)
-    def actor_loss_fn(actor: Actor, states_b, actions_b, advants_b) -> Tuple[Array, Tuple[Array, ...]]:
+    @eqx.filter_value_and_grad
+    def actor_loss_fn(actor: Actor) -> Array:
         """Clipped PPO surrogate loss, favoring actions with positive advantage."""
         mu_b, std_b = actor_vmap(actor, states_b)
         new_log_prob_b = actor_log_prob(mu_b, std_b, actions_b)
@@ -228,48 +227,33 @@ def actor_loss_fn(actor: Actor, states_b, actions_b, advants_b) -> Tuple[Array, Tuple[Array, ...]]:
         clipped_loss_b = jnp.clip(ratio_b, 1.0 - config.epsilon, 1.0 + config.epsilon) * advants_b
 
         # Taking the pessimistic minimum of the surrogate and clipped losses
-        actor_loss = -jnp.mean(
-            jnp.where(jnp.abs(surrogate_loss_b) < jnp.abs(clipped_loss_b), surrogate_loss_b, clipped_loss_b)
-        )
-
-        # Keeping track of the number of losses that are clipped
-        clipped_losses = jnp.abs(surrogate_loss_b) >= jnp.abs(clipped_loss_b)
-        fraction_clipped = jnp.mean(clipped_losses)
+        actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b))
 
         entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * (std_b + 1e-4) ** 2) + 1))
         total_loss = actor_loss - config.entropy_coeff * entropy_loss
-        return total_loss, (
-            ratio_b,
-            entropy_loss,
-            actor_loss,
-            surrogate_loss_b,
-            clipped_loss_b,
-            fraction_clipped,
-            new_log_prob_b,
-            old_log_prob_b,
-        )
+        return total_loss
 
     @eqx.filter_value_and_grad
-    def critic_loss_fn(critic: Critic, states_b, returns_b) -> Array:
+    def critic_loss_fn(critic: Critic) -> Array:
         """Mean-squared error against the ground-truth returns."""
         critic_returns_b = critic_vmap(critic, states_b).squeeze()
         critic_loss = jnp.mean((critic_returns_b - returns_b) ** 2)
         return critic_loss
 
     # Calculating actor loss and updating actor parameters --- outputting auxiliary data for logging
-    (actor_loss, values), actor_grads = actor_loss_fn(actor, states_b, actions_b, advants_b)
+    actor_loss, actor_grads = actor_loss_fn(actor)
     actor_updates, new_actor_opt_state = actor_optim.update(actor_grads, actor_opt_state, params=actor)
     new_actor = eqx.apply_updates(actor, actor_updates)
 
     if config.minimal:
         breakpoint()
 
-    if jnp.any(jnp.isnan(actor_loss)) or jnp.any(jnp.isinf(values[0])):
+    if jnp.any(jnp.isnan(actor_loss)):
         breakpoint()
 
     # Calculating critic loss and updating critic parameters
-    critic_loss, critic_grads = critic_loss_fn(critic, states_b, returns_b)
+    critic_loss, critic_grads = critic_loss_fn(critic)
     critic_updates, new_critic_opt_state = critic_optim.update(critic_grads, critic_opt_state, params=critic)
     new_critic = eqx.apply_updates(critic, critic_updates)
 
@@ -386,7 +370,6 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
 
 def actor_log_prob(mu: Array, sigma: Array, actions: Array) -> Array:
     """Calculate the log probability of the actions given the actor network's output."""
-
-    # Summing log-densities over the action dimensions, with mu/sigma determined by the state
     return jax.scipy.stats.norm.logpdf(actions, mu, sigma + 1e-4).sum(axis=1)
 
diff --git a/training.notes b/training.notes
deleted file mode 100644
index d746e14..0000000
--- a/training.notes
+++ /dev/null
@@ -1,24 +0,0 @@
-
-# Current tests:
-- Hidden layer size of 256 shows progress (loss is based on state.q[2])
-
-- setting std to zero makes rewards nans. why? I wonder if there NEEDS to be randomization in the environment
-
-- ctrl cost is what's giving nans? interesting?
-- it is unrelated to randomization of environment. I think gradient related
-
-- first thing to become nans seems to be actor loss and scores. after that, everything becomes nans
-
-- fixed entropy epsilon. hope this works now.
-
-
-- training clipped mus right now
-
-- `equinox.internal.debug_backward_nan`
-- JAX_DEBUG_NANS=1
-
-- lr = 0 still nan in minimal example
-- problem maybe in gradient clipping? as per https://github.com/google/jax/discussions/6440
-- ^nope
-
-- I think it's a problem with the clipping
\ No newline at end of file
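
A note on the clipping fix in PATCH 2/2: the original `jnp.where` picked whichever of the surrogate and clipped terms had the smaller magnitude. With negative advantages that tends to select the clipped, zero-gradient term, so nothing ever pushes a runaway probability ratio back toward the trust region. The standard PPO objective instead takes the elementwise minimum, L_clip = E[min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)], which is exactly what `jnp.minimum(surrogate_loss_b, clipped_loss_b)` computes before negation.

A note on the NaN checks: Python-level guards like `if jnp.any(jnp.isnan(...)): breakpoint()` only work while `train_step` runs eagerly; if it were ever wrapped in `jax.jit`, the condition would be a tracer and the `if` would fail at trace time. Below is a minimal sketch of the JAX-native alternatives pointed at in training.notes (`JAX_DEBUG_NANS=1` and equinox's runtime error checks); `check_finite` is a hypothetical helper for illustration, not part of train.py:

    import jax
    import jax.numpy as jnp
    import equinox as eqx

    # Equivalent to exporting JAX_DEBUG_NANS=1: raise an error as soon as
    # any operation produces a NaN, even inside jit-compiled code.
    jax.config.update("jax_debug_nans", True)

    # A runtime assertion that works under jit/vmap, unlike breakpoint().
    # check_finite is a hypothetical helper, not part of the patches above.
    def check_finite(x: jax.Array, name: str) -> jax.Array:
        return eqx.error_if(x, ~jnp.isfinite(x).all(), f"{name} is not finite")

    # e.g. inside actor_loss_fn:
    #     actor_loss = check_finite(actor_loss, "actor_loss")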