diff --git a/.gitignore b/.gitignore
index 9a5d811..1f94e11 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,7 @@ out*/
 *.stl
 *.txt
 *.mp4
+*.notes
 environments/
 screenshots/
 
diff --git a/train.py b/train.py
index c505c83..640e76a 100644
--- a/train.py
+++ b/train.py
@@ -4,7 +4,6 @@
 import logging
 import os
 from dataclasses import dataclass, field
-from functools import partial
 from typing import Any, Dict, List, Tuple
 
 import equinox as eqx
@@ -45,6 +44,7 @@ class Config:
     epsilon: float = field(default=0.2, metadata={"help": "Clipping parameter for PPO."})
     l2_rate: float = field(default=0.001, metadata={"help": "L2 regularization rate for the critic."})
     entropy_coeff: float = field(default=0.01, metadata={"help": "Coefficient for entropy loss."})
+    minimal: bool = field(default=False, metadata={"help": "Make minimal PPO (no std) for breakpoint debugging"})
 
 
 class Actor(eqx.Module):
@@ -72,7 +72,10 @@ def __call__(self, x: Array) -> Tuple[Array, Array]:
         x = jax.nn.tanh(self.linear1(x))
         x = jax.nn.tanh(self.linear2(x))
         mu = self.mu_layer(x)
+
         log_sigma = self.log_sigma_layer(x)
+
+        # return mu, jnp.zeros_like(log_sigma)
         return mu, jnp.exp(log_sigma)
 
     def initialize_layer(self, layer: eqx.nn.Linear, scale: float, key: Array) -> eqx.nn.Linear:
@@ -147,19 +150,15 @@ def __init__(self, observation_size: int, action_size: int, config: Config, key:
 
         # Learning rate annealing according to Trick #4
         self.actor_schedule = optax.linear_schedule(
-            init_value=config.lr_actor, end_value=1e-6, transition_steps=total_timesteps
+            init_value=config.lr_actor, end_value=0, transition_steps=total_timesteps
         )
         self.critic_schedule = optax.linear_schedule(
-            init_value=config.lr_critic, end_value=1e-6, transition_steps=total_timesteps
+            init_value=config.lr_critic, end_value=0, transition_steps=total_timesteps
         )
 
-        clip_threshold = 1.0
         # eps below according to Trick #3
-        self.actor_optim = optax.chain(
-            optax.clip_by_global_norm(clip_threshold), optax.adam(learning_rate=self.actor_schedule, eps=1e-5)
-        )
+        self.actor_optim = optax.chain(optax.adam(learning_rate=self.actor_schedule, eps=1e-5))
         self.critic_optim = optax.chain(
-            optax.clip_by_global_norm(clip_threshold),
             optax.adamw(learning_rate=self.critic_schedule, weight_decay=config.l2_rate, eps=1e-5),
         )
 
@@ -212,7 +211,7 @@ def train_step(
     critic_vmap = jax.vmap(apply_critic, in_axes=(None, 0))
 
     # Normalizing advantages *in minibatch* according to Trick #7
-    advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-8)
+    advants_b = (advants_b - advants_b.mean()) / (advants_b.std() + 1e-4)
 
     @eqx.filter_value_and_grad
     def actor_loss_fn(actor: Actor) -> Array:
@@ -228,11 +227,9 @@ def actor_loss_fn(actor: Actor) -> Array:
         clipped_loss_b = jnp.clip(ratio_b, 1.0 - config.epsilon, 1.0 + config.epsilon) * advants_b
 
         # Choosing the smallest magnitude loss
-        actor_loss = -jnp.mean(
-            jnp.where(jnp.abs(surrogate_loss_b) < jnp.abs(clipped_loss_b), surrogate_loss_b, clipped_loss_b)
-        )
+        actor_loss = -jnp.mean(jnp.minimum(surrogate_loss_b, clipped_loss_b))
 
-        entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * (std_b + 1e-8) ** 2) + 1))
+        entropy_loss = jnp.mean(0.5 * (jnp.log(2 * jnp.pi * (std_b + 1e-4) ** 2) + 1))
 
         total_loss = actor_loss - config.entropy_coeff * entropy_loss
         return total_loss
@@ -249,6 +246,12 @@ def critic_loss_fn(critic: Critic) -> Array:
 
     actor_updates, new_actor_opt_state = actor_optim.update(actor_grads, actor_opt_state, params=actor)
    new_actor = eqx.apply_updates(actor, actor_updates)
+    if config.minimal:
+        breakpoint()
+
+    if jnp.any(jnp.isnan(actor_loss)):
+        breakpoint()
+
     # Calculating critic loss and updating critic parameters
     critic_loss, critic_grads = critic_loss_fn(critic)
     critic_updates, new_critic_opt_state = critic_optim.update(critic_grads, critic_opt_state, params=critic)
@@ -330,6 +333,7 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
             returns_b = returns[batch_indices]
             advantages_b = advantages[batch_indices]
             old_log_prob_b = old_log_prob[batch_indices]
+            # breakpoint()
 
             params = ppo.get_params()
             new_params, (actor_loss, critic_loss) = train_step(
@@ -366,12 +370,14 @@ def train(ppo: Ppo, memory: List[Tuple[Array, Array, Array, Array]], config: Con
 
 def actor_log_prob(mu: Array, sigma: Array, actions: Array) -> Array:
     """Calculate the log probability of the actions given the actor network's output."""
-    return jax.scipy.stats.norm.logpdf(actions, mu, sigma + 1e-8).sum(axis=-1)
+    # Summing across the number of actions after logpdf of each relative to mu/sigma, determined by state
+    return jax.scipy.stats.norm.logpdf(actions, mu, sigma + 1e-4).sum(axis=1)
 
 
 def actor_distribution(mu: Array, sigma: Array, rng: Array) -> Array:
     """Get an action from the actor network from its probability distribution of actions."""
-    return jax.random.normal(rng, shape=mu.shape) * sigma + mu
+    action = jax.random.normal(rng, shape=mu.shape) * (sigma + 1e-4) + mu
+    return action
 
 
 def unwrap_state_vectorization(state: State, envs_to_sample: int) -> State:
@@ -500,7 +506,7 @@ def step_fn(states: State, actions: jax.Array) -> State:
 
     # Normalizing observations quickens learning
    norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (
-        jnp.std(states.obs, axis=1, keepdims=True) + 1e-8
+        jnp.std(states.obs, axis=1, keepdims=True) + 1e-4
     )
     obs = jax.device_put(norm_obs)
     score = jnp.zeros(config.num_envs)
@@ -524,10 +530,14 @@ def step_fn(states: State, actions: jax.Array) -> State:
 
         # Normalizing observations quickens learning
         norm_obs = (states.obs - jnp.mean(states.obs, axis=1, keepdims=True)) / (
-            jnp.std(states.obs, axis=1, keepdims=True) + 1e-8
+            jnp.std(states.obs, axis=1, keepdims=True) + 1e-4
         )
         next_obs, rewards, dones = norm_obs, states.reward, states.done
-        masks = (1 - 0.8 * dones).astype(jnp.float32)
+        masks = (1 - dones).astype(jnp.float32)
+
+        # Check for NaNs and print them
+        if jnp.any(jnp.isnan(rewards)):
+            breakpoint()
 
         # Update memory
         new_data = {"states": obs, "actions": actions, "rewards": rewards, "masks": masks}
diff --git a/training.notes b/training.notes
deleted file mode 100644
index 66d9976..0000000
--- a/training.notes
+++ /dev/null
@@ -1,12 +0,0 @@
-
-# Currently tests:
-- Hidden layer size of 256 shows progress (loss is based on state.q[2])
-
-- setting std to zero makes rewards nans why. I wonder if there NEEDS to be randomization in the enviornment
-
-- ctrl cost is whats giving nans? interesting?
-- it is unrelated to randomization of enviornmnet. i think gradient related
-
-- first thing to become nans seems to be actor loss and scores. after that, everything becomes nans
-
-- fixed entropy epsilon. hope this works now.
\ No newline at end of file
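Note on the surrogate-objective change in train_step (an illustrative sketch, not code from the repo): the old update picked whichever of the unclipped and clipped terms had the smaller magnitude, while jnp.minimum gives the standard PPO pessimistic bound; the two differ whenever the advantage is negative and the ratio leaves the clip range.

import jax.numpy as jnp

ratio = jnp.array([0.5, 1.5])        # pi_new / pi_old
advantage = jnp.array([-1.0, -1.0])
eps = 0.2

surrogate = ratio * advantage
clipped = jnp.clip(ratio, 1.0 - eps, 1.0 + eps) * advantage

old_style = jnp.where(jnp.abs(surrogate) < jnp.abs(clipped), surrogate, clipped)
ppo_style = jnp.minimum(surrogate, clipped)

print(old_style)  # [-0.5 -1.2]: smaller-magnitude pick is less pessimistic
print(ppo_style)  # [-0.8 -1.5]: elementwise minimum lower-bounds the surrogate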
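Sketch of why the entropy epsilon bump from 1e-8 to 1e-4 matters (illustrative only; gaussian_entropy below is a hypothetical helper, not from train.py): 0.5 * (log(2*pi*sigma^2) + 1) is the differential entropy of a Gaussian, and its gradient with respect to sigma is 1 / (sigma + eps), so a collapsed policy std with a 1e-8 epsilon yields gradients near 1e8, while 1e-4 caps them near 1e4.

import jax
import jax.numpy as jnp

def gaussian_entropy(std: jnp.ndarray, eps: float) -> jnp.ndarray:
    # Differential entropy of N(mu, (std + eps)^2), averaged over action dimensions.
    return jnp.mean(0.5 * (jnp.log(2 * jnp.pi * (std + eps) ** 2) + 1))

std = jnp.array([0.0])  # fully collapsed policy std
print(jax.grad(gaussian_entropy)(std, 1e-8))  # ~[1e8]
print(jax.grad(gaussian_entropy)(std, 1e-4))  # ~[1e4]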
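The NaN checks added above use Python's breakpoint(), which only fires in eagerly executed code; if the surrounding step runs under jax.jit (an assumption about how the rollout and train_step are compiled), a traced `if jnp.any(...)` would fail instead. A minimal sketch of two JAX-native alternatives (masked_reward is a hypothetical stand-in for the rollout step):

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)  # re-run the offending op eagerly and raise on NaN

@jax.jit
def masked_reward(reward: jnp.ndarray, done: jnp.ndarray) -> jnp.ndarray:
    # jax.debug.print works under tracing, unlike a Python `if jnp.any(...)` check.
    jax.debug.print("nan in rewards: {}", jnp.any(jnp.isnan(reward)))
    return reward * (1.0 - done)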