Aug 6 changes #16

Closed · wants to merge 7 commits
4 changes: 2 additions & 2 deletions .gitignore
@@ -18,10 +18,10 @@ dist/
out*/
*.stl
*.txt
*.mp4

videos/
environments/
screenshots/
scratch/
videos/
assets/
wandb/
7 changes: 7 additions & 0 deletions README.md
@@ -32,6 +32,13 @@ Minimal training and inference code for making a humanoid robot stand up.
- [ ] Implement simple PPO policy to try to make the robot stand up
- [ ] Parallelize using JAX

# Findings
- A low standard deviation for the "overfitting test" does not work very well for PPO, because the policy needs to converge on an actual solution, and it cannot get there without exploring a little. With that in mind, I think the model is generally getting too comfortable tricking the reward system by losing as fast as possible so it doesn't have to deal with the penalty.
- The theory that the model is just trying to lose (retaining `is_healthy` without giving up much height: it wants to fall as quickly as possible so it can reset) can be tested by adding a "wait" period and removing the mask, as in the sketch below. This effectively removes the benefit of resetting: while the model is stuck in the failed state, it is still unhealthy and therefore loses out on reward.
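
A possible way to run that test (my sketch, assuming a brax-style `State` with `reward`, `done`, and `.replace(...)` fields; not code from this PR): keep stepping after termination but freeze the reward at zero instead of resetting, so falling over quickly no longer escapes the penalty.

```python
import jax.numpy as jp

def step_without_reset_benefit(env, env_state, action):
    """Step the env, but once an episode has failed, hold its reward at zero."""
    next_state = env.step(env_state, action)
    already_done = env_state.done  # 1.0 once the episode has already failed
    # While "stuck" in the failed state the agent keeps earning zero reward,
    # so there is no incentive to reach the failed state as fast as possible.
    reward = jp.where(already_done > 0.5, 0.0, next_state.reward)
    done = jp.maximum(already_done, next_state.done)
    return next_state.replace(reward=reward, done=done)
```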

# Current tests
- A hidden layer size of 256 shows progress (the loss is based on `state.q[2]`); see the sketch below.
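
For reference, a hypothetical illustration of what a 256-unit policy network could look like (the framework and names here are my assumptions, not necessarily what this repo uses):

```python
import flax.linen as nn
import jax.numpy as jp

class Policy(nn.Module):
    action_size: int
    hidden_size: int = 256  # the width that showed progress

    @nn.compact
    def __call__(self, obs: jp.ndarray) -> jp.ndarray:
        x = nn.tanh(nn.Dense(self.hidden_size)(obs))
        x = nn.tanh(nn.Dense(self.hidden_size)(x))
        return nn.Dense(self.action_size)(x)  # mean of the action distribution
```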

## Goals

- The goal for this repository is to provide a super minimal implementation of a PPO policy for making a humanoid robot stand up, with only three files:
38 changes: 22 additions & 16 deletions environment.py
@@ -125,6 +125,7 @@ def step(self, env_state: State, action: jp.ndarray) -> State:
"""Run one timestep of the environment's dynamics and returns observations with rewards."""
state = env_state.pipeline_state
next_state = self.pipeline_step(state, action)

obs = self.get_obs(state, action)

reward = self.compute_reward(state, next_state, action)
@@ -141,7 +142,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr
is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0)
is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy)

is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0)
# is_bad = jp.where(state.q[2] < min_z + 0.2, 1.0, 0.0)

ctrl_cost = -jp.sum(jp.square(action))

@@ -158,7 +159,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr
# )
# jax.debug.print("is_healthy {}, height {}", is_healthy, state.q[2], ordered=True)

total_reward = 2.0 * is_healthy + 0.1 * ctrl_cost - 5.0 * is_bad
total_reward = 0.1 * ctrl_cost + 5 * state.q[2]

return total_reward

@@ -177,20 +178,25 @@ def is_done(self, state: MjxState) -> jp.ndarray:
return done

def get_obs(self, data: MjxState, action: jp.ndarray) -> jp.ndarray:
"""Returns the observation of the environment to pass to actor/critic model."""
position = data.qpos
position = position[2:] # excludes "current positions"

# external_contact_forces are excluded
return jp.concatenate(
[
position,
data.qvel,
data.cinert[1:].ravel(),
data.cvel[1:].ravel(),
data.qfrc_actuator,
]
)
obs_components = [
data.qpos[2:],
data.qvel,
data.cinert[1:].ravel(),
data.cvel[1:].ravel(),
data.qfrc_actuator,
]

def clean_component(component: jp.ndarray) -> jp.ndarray:
# Check for NaNs or Infs and replace them
nan_mask = jp.isnan(component)
inf_mask = jp.isinf(component)
component = jp.where(nan_mask, 0.0, component)
component = jp.where(inf_mask, jp.where(component > 0, 1e6, -1e6), component)
return component

cleaned_components = [clean_component(comp) for comp in obs_components]

return jp.concatenate(cleaned_components)


def run_environment_adhoc() -> None: