Scanning #21

Closed · wants to merge 4 commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -19,9 +19,12 @@ out*/
*.stl
*.txt
*.mp4
*.notes

environments/
screenshots/
scratch/
assets/
wandb/
models/
logs/
10 changes: 5 additions & 5 deletions environment.py
@@ -76,7 +76,7 @@ class HumanoidEnv(PipelineEnv):

initial_qpos: jp.ndarray
_action_size: int
-reset_noise_scale: float = 0
+reset_noise_scale: float = 2e-4

def __init__(self, n_frames: int = 1) -> None:
"""Initializes system with initial joint positions, action size, the model, and update rate."""
@@ -137,7 +137,7 @@ def step(self, env_state: State, action: jp.ndarray) -> State:

def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarray) -> jp.ndarray:
"""Compute the reward for standing and height."""
-min_z, max_z = 0.7, 2.0
+min_z, max_z = 0, 2.0
# min_z, max_z = -0.35, 2.0
is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0)
is_healthy = jp.where(state.q[2] > max_z, 0.0, is_healthy)
@@ -155,9 +155,9 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarray) -> jp.ndarray:
# next_xpos,
# ordered=True,
# )
# jax.debug.print("is_healthy {}, height {}", is_healthy, state.q[2], ordered=True)
# jax.debug.print("ctrl_cost {}, is_healthy {}, height {}", ctrl_cost, is_healthy, state.q[2], ordered=True)

-total_reward = 0.1 * ctrl_cost + 5 * is_healthy
+total_reward = 0.1 * ctrl_cost + 5 * state.q[2]

return total_reward

@@ -167,7 +167,7 @@ def is_done(self, state: MjxState) -> jp.ndarray:
com_height = state.q[2]

# Set a termination threshold
-termination_height = 0.7
+termination_height = 0
# termination_height = -0.35

# Episode is done if the robot falls below the termination height
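
Taken together, these hunks replace the flat healthy bonus with a reward proportional to torso height (state.q[2]) and drop the lower healthy/termination bound from 0.7 to 0. A minimal standalone sketch of the new reward, assuming the usual negative squared-action control cost and a 17-dimensional humanoid action (only the coefficients and the 5 * state.q[2] term come from the diff):

import jax.numpy as jp

def height_based_reward(height: jp.ndarray, action: jp.ndarray) -> jp.ndarray:
    # Reward now scales with torso height instead of paying a flat bonus for being healthy.
    ctrl_cost = -jp.sum(jp.square(action))  # assumed definition; not shown in this hunk
    return 0.1 * ctrl_cost + 5 * height

# Example: a torso at ~1.2 m with zero action earns about 6.0 per step.
print(height_based_reward(jp.array(1.2), jp.zeros(17)))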
28 changes: 17 additions & 11 deletions infer.py
@@ -7,13 +7,12 @@
import jax
import jax.numpy as jnp
import mediapy as media
-from jax import Array

from environment import HumanoidEnv
-from train import Actor, choose_action
+from train import Actor, Critic, choose_action


-def load_models(actor_path: str, critic_path: str) -> Tuple[Array, Array]:
+def load_models(actor_path: str, critic_path: str) -> Tuple[Actor, Critic]:
"""Loads the pretrained actor and critic models from paths."""
with open(actor_path, "rb") as f:
actor_params = pickle.load(f)
@@ -25,22 +24,28 @@ def load_models(actor_path: str, critic_path: str) -> Tuple[Array, Array]:
def main() -> None:
"""Runs inference with pretrained models."""
parser = argparse.ArgumentParser()
parser.add_argument("--actor_path", type=str, default="actor_params.pkl", help="path to actor model")
parser.add_argument("--critic_path", type=str, default="critic_params.pkl", help="path to critic model")
parser.add_argument("--num_episodes", type=int, default=1, help="number of episodes to run")
parser.add_argument(
"--actor_path",
type=str,
default="models/more_envs_actor_70.pkl",
help="path to actor model",
)
parser.add_argument(
"--critic_path", type=str, default="models/height_based_reward_critic_1090.pkl", help="path to critic model"
)
parser.add_argument("--num_episodes", type=int, default=10, help="number of episodes to run")
parser.add_argument("--max_steps", type=int, default=1000, help="maximum steps per episode")
parser.add_argument("--video_path", type=str, default="inference_video.mp4", help="path to save video")
parser.add_argument("--render_every", type=int, default=2, help="how many frames to skip between renders")
parser.add_argument("--video_length", type=float, default=10.0, help="desired length of video in seconds")
parser.add_argument("--video_length", type=float, default=20.0, help="desired length of video in seconds")
parser.add_argument("--width", type=int, default=640, help="width of the video frame")
parser.add_argument("--height", type=int, default=480, help="height of the video frame")
args = parser.parse_args()

env = HumanoidEnv()
rng = jax.random.PRNGKey(0)

-actor_params, _ = load_models(args.actor_path, args.critic_path)
-actor = Actor(input_size=env.observation_size, action_size=env.action_size, key=rng)
+actor, _ = load_models(args.actor_path, args.critic_path)

reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)
@@ -50,7 +55,7 @@ def main() -> None:
rollout: list[Any] = []

for episode in range(args.num_episodes):
-rng, reset_rng = jax.rand.split(rng)
+rng, reset_rng = jax.random.split(rng)
state = reset_fn(reset_rng)
obs = state.obs

@@ -60,7 +65,8 @@ def main() -> None:
if len(rollout) < max_frames:
rollout.append(state.pipeline_state)

-action = choose_action(actor_params, obs, actor)
+rng, action_rng = jax.random.split(rng)
+action = choose_action(actor, obs, action_rng)
state = step_fn(state, action)
obs = state.obs
total_reward += state.reward
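
With the new defaults in place, a typical invocation looks like the following (the checkpoint paths are simply the defaults from the diff; substitute your own):

python infer.py --actor_path models/more_envs_actor_70.pkl --critic_path models/height_based_reward_critic_1090.pkl --num_episodes 10 --video_path inference_video.mp4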