
Commit

fully walking
nathanjzhao committed Aug 24, 2024
1 parent 0db2690 commit 01b8b55
Showing 3 changed files with 382 additions and 171 deletions.
193 changes: 152 additions & 41 deletions environment.py
@@ -1,13 +1,12 @@
"""Definition of base humanoids environment with reward system and termination conditions."""

import argparse
from functools import partial
import logging
import os
import shutil
import subprocess
import tempfile
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Any

@@ -19,23 +18,25 @@
from brax.io import mjcf
from brax.mjx.base import State as MjxState

logger = logging.getLogger(__name__)

logger = logging.getLogger(__name__)

@dataclass
class RewardConfig:
termination_height: float = field(default=-0.2)
height_min_z: float = field(default=-0.2)
height_max_z: float = field(default=2.0)
is_healthy_reward: float = field(default=5)
original_pos_reward_exp_coefficient: float = field(default=2)
original_pos_reward_subtraction_factor: float = field(default=0.2)
original_pos_reward_max_diff_norm: float = field(default=0.5)
ctrl_cost_coefficient: float = field(default=0.1)
weights_ctrl_cost: float = field(default=0.1)
weights_original_pos_reward: float = field(default=1)
weights_is_healthy: float = field(default=5)
weights_velocity: float = field(default=1.25)
REWARD_CONFIG = {
"termination_height": -0.2,
"height_limits": {"min_z": -0.2, "max_z": 2.0},
"is_healthy_reward": 5,
"original_pos_reward": {
"exp_coefficient": 2,
"subtraction_factor": 0.05,
"max_diff_norm": 0.5,
},
"weights": {
"ctrl_cost": 0.1,
"original_pos_reward": 4,
"is_healthy": 1,
"velocity": 1.25,
},
}
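
As a side note (not part of this commit), here is a minimal sketch of how the nested weights are meant to combine with the per-term values, mirroring the weighted sum in compute_reward further down; the term values are placeholders:

    # Placeholder per-term values, combined with the weights above exactly as
    # the weighted sum in compute_reward below does.
    ctrl_cost = -0.8           # hypothetical value of -sum(action**2)
    original_pos_reward = 0.6  # hypothetical
    is_healthy = 1.0           # hypothetical
    velocity = 0.3             # hypothetical

    w = REWARD_CONFIG["weights"]
    total_reward = (
        w["ctrl_cost"] * ctrl_cost
        + w["original_pos_reward"] * original_pos_reward
        + w["is_healthy"] * is_healthy
        + w["velocity"] * velocity
    )
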


REPO_DIR = "humanoid_original" # humanoid_original or stompy or dora
@@ -65,7 +66,9 @@ def download_model_files(repo_url: str, repo_dir: str, local_path: str) -> None:

# Check if the target directory already exists
if target_path.exists():
logger.info(f"Model files are already present in {target_path}. Skipping download.")
logger.info(
f"Model files are already present in {target_path}. Skipping download."
)
return

# Create a temporary directory for cloning
@@ -124,9 +127,9 @@ def __init__(self, n_frames: int = PHYSICS_FRAMES, backend: str = "mjx") -> None
mj_model: mujoco.MjModel = mujoco.MjModel.from_xml_path(xml_path)

# can definitely look at this more https://mujoco.readthedocs.io/en/latest/APIreference/APItypes.html#mjtdisablebit
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
mj_model.opt.iterations = 1
mj_model.opt.ls_iterations = 4
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6

self._action_size = mj_model.nu
sys: base.System = mjcf.load_model(mj_model)
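
The solver change above (Newton with 1 iteration and 4 line-search iterations to CG with 6 and 6) is applied directly on the MjModel options before Brax wraps it. A minimal sketch of the same setup in isolation, assuming a placeholder XML path:

    import mujoco
    from brax.io import mjcf

    mj_model = mujoco.MjModel.from_xml_path("humanoid.xml")  # placeholder path
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG  # conjugate-gradient solver
    mj_model.opt.iterations = 6                      # solver iterations per physics step
    mj_model.opt.ls_iterations = 6                   # line-search iterations
    sys = mjcf.load_model(mj_model)                  # Brax System built from the MjModel
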
@@ -135,7 +138,9 @@ def __init__(self, n_frames: int = PHYSICS_FRAMES, backend: str = "mjx") -> None

try:
if KEYFRAME_NAME:
self.initial_qpos = jnp.array(mj_model.keyframe(self.keyframe_name).qpos)
self.initial_qpos = jnp.array(
mj_model.keyframe(self.keyframe_name).qpos
)
except:
self.initial_qpos = jnp.array(sys.qpos0)
print("No keyframe found, utilizing qpos0")
@@ -153,7 +158,9 @@ def reset(self, rng: jnp.ndarray) -> State:
rng, rng1, rng2 = jax.random.split(rng, 3)

low, hi = -self.reset_noise_scale, self.reset_noise_scale
qpos = self.initial_qpos + jax.random.uniform(rng1, (self.sys.nq,), minval=low, maxval=hi)
qpos = self.initial_qpos + jax.random.uniform(
rng1, (self.sys.nq,), minval=low, maxval=hi
)
qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi)

# initialize mjx state
@@ -176,14 +183,18 @@ def step(self, env_state: State, action: jnp.ndarray, rng: jnp.ndarray) -> State
state = env_state.pipeline_state
metrics = env_state.metrics

state_step = self.pipeline_step(state, action) # because scaled action so bad...
state_step = self.pipeline_step(
state, action
) # because scaled action so bad...
obs_state = self.get_obs(state, action)

# reset env if done
rng, rng1, rng2 = jax.random.split(rng, 3)
low, hi = -self.reset_noise_scale, self.reset_noise_scale

qpos = self.initial_qpos + jax.random.uniform(rng1, (self.sys.nq,), minval=low, maxval=hi)
qpos = self.initial_qpos + jax.random.uniform(
rng1, (self.sys.nq,), minval=low, maxval=hi
)
qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi)
state_reset = self.pipeline_init(qpos, qvel)
obs_reset = self.get_obs(state, jnp.zeros(self._action_size))
Expand All @@ -202,7 +213,9 @@ def step(self, env_state: State, action: jnp.ndarray, rng: jnp.ndarray) -> State
done = jnp.logical_or(done, any_nan)

# selectively replace state/obs with reset environment based on if done
new_state = jax.tree.map(lambda x, y: jax.lax.select(done, x, y), state_reset, state_step)
new_state = jax.tree.map(
lambda x, y: jax.lax.select(done, x, y), state_reset, state_step
)
obs = jax.lax.select(done, obs_reset, obs_state)

########### METRIC TRACKING ###########
Expand All @@ -223,7 +236,9 @@ def step(self, env_state: State, action: jnp.ndarray, rng: jnp.ndarray) -> State
metrics["timestep"] = metrics["timestep"] + 1
metrics["returned_episode"] = done

return env_state.replace(pipeline_state=new_state, obs=obs, reward=reward, done=done, metrics=metrics)
return env_state.replace(
pipeline_state=new_state, obs=obs, reward=reward, done=done, metrics=metrics
)

@partial(jax.jit, static_argnums=(0,))
def compute_reward(
@@ -238,18 +253,42 @@ def compute_reward(
REWARD_CONFIG["height_limits"]["max_z"],
)

exp_coef, subtraction_factor, max_diff_norm = (
REWARD_CONFIG["original_pos_reward"]["exp_coefficient"],
REWARD_CONFIG["original_pos_reward"]["subtraction_factor"],
REWARD_CONFIG["original_pos_reward"]["max_diff_norm"],
)

# MAINTAINING ORIGINAL POSITION REWARD
qpos0_diff = self.initial_qpos - state.qpos
original_pos_reward = jnp.exp(
-exp_coef * jnp.linalg.norm(qpos0_diff)
) - subtraction_factor * jnp.clip(jnp.linalg.norm(qpos0_diff), 0, max_diff_norm)

# HEALTHY REWARD
is_healthy = jnp.where(state.q[2] < min_z, 0.0, 1.0)
is_healthy = jnp.where(state.q[2] > max_z, 0.0, is_healthy)

ctrl_cost = -jnp.sum(jnp.square(action))

xpos = state.subtree_com[1][0]
next_xpos = next_state.subtree_com[1][0]
velocity = (next_xpos - xpos) / self.dt


# Calculate and print each weight * reward pairing
ctrl_cost_weighted = REWARD_CONFIG["weights"]["ctrl_cost"] * ctrl_cost
original_pos_reward_weighted = REWARD_CONFIG["weights"]["original_pos_reward"] * original_pos_reward
velocity_weighted = REWARD_CONFIG["weights"]["velocity"] * velocity
is_healthy_weighted = REWARD_CONFIG["weights"]["is_healthy"] * is_healthy

# jax.debug.print("ctrl_cost_weighted: {}, original_pos_reward_weighted: {}, is_healthy_weighted {}, velocity_weighted: {}", ctrl_cost_weighted, original_pos_reward_weighted, is_healthy_weighted, velocity_weighted)

total_reward = (
REWARD_CONFIG["weights"]["ctrl_cost"] * ctrl_cost
+ REWARD_CONFIG["weights"]["ctrl_cost"] * velocity
+ REWARD_CONFIG["weights"]["is_healthy"] * is_healthy
ctrl_cost_weighted
+ original_pos_reward_weighted
+ velocity_weighted
+ is_healthy_weighted
)

return total_reward
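
For a concrete feel of the position-holding term, a quick worked check using the values from REWARD_CONFIG above (exp_coefficient=2, subtraction_factor=0.05, max_diff_norm=0.5):

    import jax.numpy as jnp

    # exp(-2 * ||dq||) - 0.05 * clip(||dq||, 0, 0.5), per the config values above
    for norm in (0.0, 0.25, 0.5, 1.0):
        r = jnp.exp(-2.0 * norm) - 0.05 * jnp.clip(norm, 0.0, 0.5)
        print(norm, float(r))  # -> 1.0, ~0.594, ~0.343, ~0.110

So the term is 1.0 when the pose matches the keyframe exactly and decays smoothly as the joint configuration drifts, with the clipped linear penalty capped once the drift exceeds 0.5.
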
@@ -264,7 +303,9 @@ def is_done(self, state: MjxState) -> bool:
REWARD_CONFIG["height_limits"]["min_z"],
REWARD_CONFIG["height_limits"]["max_z"],
)
height_condition = jnp.logical_not(jnp.logical_and(min_z < com_height, com_height < max_z))
height_condition = jnp.logical_not(
jnp.logical_and(min_z < com_height, com_height < max_z)
)

# Check if any element in qvel or qacc exceeds 1e5
velocity_condition = jnp.any(jnp.abs(state.qvel) > 1e5)
@@ -296,6 +337,9 @@ def get_obs(self, data: MjxState, action: jnp.ndarray) -> jnp.ndarray:
return jnp.concatenate(obs_components)


################## TEST ENVIRONMENT RUN ##################


def run_environment_adhoc() -> None:
"""Runs the environment for a few steps with random actions, for debugging."""
try:
@@ -305,15 +349,21 @@ def run_environment_adhoc() -> None:
raise ImportError("Please install `mediapy` and `tqdm` to run this script")

parser = argparse.ArgumentParser()
parser.add_argument("--actor_path", type=str, default="actor_params.pkl", help="path to actor model")
parser.add_argument(
"--actor_path", type=str, default="actor_params.pkl", help="path to actor model"
)
parser.add_argument(
"--critic_path",
type=str,
default="critic_params.pkl",
help="path to critic model",
)
parser.add_argument("--num_episodes", type=int, default=20, help="number of episodes to run")
parser.add_argument("--max_steps", type=int, default=1024, help="maximum steps per episode")
parser.add_argument(
"--num_episodes", type=int, default=20, help="number of episodes to run"
)
parser.add_argument(
"--max_steps", type=int, default=1024, help="maximum steps per episode"
)
parser.add_argument(
"--env_name",
type=str,
@@ -332,8 +382,12 @@ def run_environment_adhoc() -> None:
default=5.0,
help="desired length of video in seconds",
)
parser.add_argument("--width", type=int, default=640, help="width of the video frame")
parser.add_argument("--height", type=int, default=480, help="height of the video frame")
parser.add_argument(
"--width", type=int, default=640, help="width of the video frame"
)
parser.add_argument(
"--height", type=int, default=480, help="height of the video frame"
)
args = parser.parse_args()

env = HumanoidEnv()
@@ -356,15 +410,19 @@ def run_environment_adhoc() -> None:

total_reward = 0

for step in tqdm(range(args.max_steps), desc=f"Episode {episode + 1} Steps", leave=False):
for step in tqdm(
range(args.max_steps), desc=f"Episode {episode + 1} Steps", leave=False
):
if len(rollout) < args.video_length * fps:
rollout.append(state.pipeline_state)

#### STORED METRICS ####
metrics["qpos_2"].append(state.pipeline_state.qpos[2])

rng, action_rng = jax.random.split(rng)
action = jax.random.uniform(action_rng, (action_size,), minval=0, maxval=1.0) # placeholder for an action
action = jax.random.uniform(
action_rng, (action_size,), minval=0, maxval=1.0
) # placeholder for an action

rng, step_rng = jax.random.split(rng)
state = step_fn(state, action, step_rng)
@@ -378,18 +436,71 @@ def run_environment_adhoc() -> None:
if len(rollout) >= max_frames:
break

logger.info("Rendering video with %d frames at %d fps", len(rollout), fps)
print(f"Rendering video with {len(rollout)} frames at {fps} fps")
images = jnp.array(
env.render(
rollout[:: args.render_every],
# camera="side",
camera="side",
width=args.width,
height=args.height,
)
)

logger.info("Saving video to %s", video_path)
print("Video rendered")
media.write_video(video_path, images, fps=fps)
print(f"Video saved to {video_path}")

###### ADDS DEBUG TEXT ON TOP OF VIDEO ######

# video_path = video_path
# cap = cv2.VideoCapture(video_path)

# if not cap.isOpened():
# print(f"Error: Could not open video {video_path}")

# # Get video properties
# frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
# frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# fps = cap.get(cv2.CAP_PROP_FPS)

# # Define the codec and create VideoWriter object
# debug_video_path = args.env_name + "_debug.mp4"
# fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# out = cv2.VideoWriter(debug_video_path, fourcc, fps, (frame_width, frame_height))

# # Loop through each frame
# frame_index = 0

# if not out.isOpened():
# print("Error: Could not open VideoWriter")

# while cap.isOpened():
# ret, frame = cap.read()
# if not ret:
# break

# # Write each metric on the frame
# font = cv2.FONT_HERSHEY_SIMPLEX
# font_scale = 1
# color = (255, 0, 0) # Blue color in BGR
# thickness = 2
# y_offset = 50 # Initial y position for the text

# for key, values in metrics.items():
# if frame_index < len(values):
# text = f'{key}: {values[frame_index]}'
# org = (50, y_offset) # Position for the text
# frame = cv2.putText(frame, text, org, font, font_scale, color, thickness, cv2.LINE_AA)
# y_offset += 30 # Move to the next line for the next metric

# # Write the frame into the output video
# out.write(frame)
# frame_index += 1

# # Release everything if job is finished
# cap.release()
# out.release()
# cv2.destroyAllWindows()
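
For reference, a compact runnable version of the overlay idea sketched in the commented block above, assuming OpenCV is installed (opencv-python) and that metrics maps names to per-frame value lists; paths and styling values are placeholders:

    import cv2

    def overlay_metrics(video_path: str, out_path: str, metrics: dict) -> None:
        """Draw per-frame metric values onto an existing video."""
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video {video_path}")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

        frame_index = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            y_offset = 50  # starting text row
            for key, values in metrics.items():
                if frame_index < len(values):
                    cv2.putText(frame, f"{key}: {values[frame_index]:.3f}", (50, y_offset),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)
                    y_offset += 30  # next metric goes one line lower
            out.write(frame)
            frame_index += 1

        cap.release()
        out.release()

It would be called right after the video is written, e.g. overlay_metrics(video_path, args.env_name + "_debug.mp4", metrics).
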


if __name__ == "__main__":