Skip to content

Commit

Permalink
added scanning to make training much quicker + edits for better debug…
Browse files Browse the repository at this point in the history
…ging
  • Loading branch information
nathanjzhao committed Aug 13, 2024
1 parent f80e9f1 commit c9e4c8d
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 119 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ scratch/
assets/
wandb/
models/
logs/
8 changes: 4 additions & 4 deletions environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class HumanoidEnv(PipelineEnv):

initial_qpos: jp.ndarray
_action_size: int
reset_noise_scale: float = 1e-4
reset_noise_scale: float = 2e-4

def __init__(self, n_frames: int = 1) -> None:
"""Initializes system with initial joint positions, action size, the model, and update rate."""
Expand Down Expand Up @@ -137,7 +137,7 @@ def step(self, env_state: State, action: jp.ndarray) -> State:

def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarray) -> jp.ndarray:
"""Compute the reward for standing and height."""
min_z, max_z = 0.7, 2.0
min_z, max_z = 0, 2.0
# min_z, max_z = -0.35, 2.0
is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0)
is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy)
Expand All @@ -157,7 +157,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr
# )
# jax.debug.print("ctrl_cost {}, is_healthy {}, height {}", ctrl_cost, is_healthy, state.q[2], ordered=True)

total_reward = 0.1 * ctrl_cost + 5 * is_healthy
total_reward = 0.1 * ctrl_cost + 5 * state.q[2]

return total_reward

Expand All @@ -167,7 +167,7 @@ def is_done(self, state: MjxState) -> jp.ndarray:
com_height = state.q[2]

# Set a termination threshold
termination_height = 0.7
termination_height = 0
# termination_height = -0.35

# Episode is done if the robot falls below the termination height
Expand Down
7 changes: 3 additions & 4 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import jax
import jax.numpy as jnp
import mediapy as media
from jax import Array

from environment import HumanoidEnv
from train import choose_action, Actor, Critic
from train import Actor, Critic, choose_action


def load_models(actor_path: str, critic_path: str) -> Tuple[Actor, Critic]:
Expand All @@ -28,7 +27,7 @@ def main() -> None:
parser.add_argument(
"--actor_path",
type=str,
default="models/im_goated_actor_20.pkl",
default="models/more_envs_actor_70.pkl",
help="path to actor model",
)
parser.add_argument(
Expand All @@ -38,7 +37,7 @@ def main() -> None:
parser.add_argument("--max_steps", type=int, default=1000, help="maximum steps per episode")
parser.add_argument("--video_path", type=str, default="inference_video.mp4", help="path to save video")
parser.add_argument("--render_every", type=int, default=2, help="how many frames to skip between renders")
parser.add_argument("--video_length", type=float, default=10.0, help="desired length of video in seconds")
parser.add_argument("--video_length", type=float, default=20.0, help="desired length of video in seconds")
parser.add_argument("--width", type=int, default=640, help="width of the video frame")
parser.add_argument("--height", type=int, default=480, help="height of the video frame")
args = parser.parse_args()
Expand Down
Loading

0 comments on commit c9e4c8d

Please sign in to comment.