Scanning and gae #20

Closed · wants to merge 3 commits
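The "gae" in the title presumably refers to Generalized Advantage Estimation; the train.py side of that change is not shown in this diff view, but the standard computation is a backward scan over TD residuals, roughly as below (a generic sketch, not this repo's code; the function name and argument layout are assumptions):

import jax
import jax.numpy as jp

def compute_gae(rewards: jp.ndarray, values: jp.ndarray, dones: jp.ndarray,
                gamma: float = 0.99, lam: float = 0.95) -> jp.ndarray:
    # values has length T+1 (a bootstrap value for the final state appended);
    # rewards and dones have length T.
    deltas = rewards + gamma * values[1:] * (1.0 - dones) - values[:-1]

    def step(carry, inp):
        delta, done = inp
        advantage = delta + gamma * lam * (1.0 - done) * carry
        return advantage, advantage

    # reverse=True runs the recursion from the last step back to the first,
    # while the stacked outputs keep their original time ordering.
    _, advantages = jax.lax.scan(step, 0.0, (deltas, dones), reverse=True)
    return advantages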
.gitignore: 2 additions & 0 deletions
@@ -19,9 +19,11 @@ out*/
 *.stl
 *.txt
 *.mp4
+*.notes

 environments/
 screenshots/
 scratch/
+assets/
 wandb/
 models/
environment.py: 2 additions & 2 deletions
@@ -76,7 +76,7 @@ class HumanoidEnv(PipelineEnv):

     initial_qpos: jp.ndarray
     _action_size: int
-    reset_noise_scale: float = 0
+    reset_noise_scale: float = 1e-4

     def __init__(self, n_frames: int = 1) -> None:
         """Initializes system with initial joint positions, action size, the model, and update rate."""
@@ -155,7 +155,7 @@ def compute_reward(self, state: MjxState, next_state: MjxState, action: jp.ndarr
         #     next_xpos,
         #     ordered=True,
         # )
-        # jax.debug.print("is_healthy {}, height {}", is_healthy, state.q[2], ordered=True)
+        # jax.debug.print("ctrl_cost {}, is_healthy {}, height {}", ctrl_cost, is_healthy, state.q[2], ordered=True)

         total_reward = 0.1 * ctrl_cost + 5 * is_healthy

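For reference, the two reward terms are presumably computed along these lines (a sketch under assumptions, not the repo's exact code): is_healthy is a height window on the torso z-coordinate state.q[2] that the debug print inspects, and ctrl_cost is a negative quadratic action penalty, which would explain why it enters the sum with a positive weight:

ctrl_cost = -jp.sum(jp.square(action))                # assumed negative penalty on large controls
min_z, max_z = 0.8, 2.1                               # hypothetical healthy height bounds
is_healthy = jp.where(state.q[2] < min_z, 0.0, 1.0) * jp.where(state.q[2] > max_z, 0.0, 1.0)
total_reward = 0.1 * ctrl_cost + 5 * is_healthy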
infer.py: 16 additions & 9 deletions
@@ -10,10 +10,10 @@
 from jax import Array

 from environment import HumanoidEnv
-from train import Actor, choose_action
+from train import choose_action, Actor, Critic


-def load_models(actor_path: str, critic_path: str) -> Tuple[Array, Array]:
+def load_models(actor_path: str, critic_path: str) -> Tuple[Actor, Critic]:
     """Loads the pretrained actor and critic models from paths."""
     with open(actor_path, "rb") as f:
         actor_params = pickle.load(f)
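The new return type implies the pickle files now hold whole Actor and Critic modules rather than bare parameter arrays, so inference no longer has to rebuild the network before loading weights. The training-side counterpart presumably looks something like this (hypothetical sketch; save_models and the file naming are assumptions):

import pickle

from train import Actor, Critic

def save_models(actor: Actor, critic: Critic, suffix: str) -> None:
    # Pickle the full modules so load_models can restore them directly.
    with open(f"models/actor_{suffix}.pkl", "wb") as f:
        pickle.dump(actor, f)
    with open(f"models/critic_{suffix}.pkl", "wb") as f:
        pickle.dump(critic, f)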
@@ -25,9 +25,16 @@ def load_models(actor_path: str, critic_path: str) -> Tuple[Array, Array]:
 def main() -> None:
     """Runs inference with pretrained models."""
     parser = argparse.ArgumentParser()
-    parser.add_argument("--actor_path", type=str, default="actor_params.pkl", help="path to actor model")
-    parser.add_argument("--critic_path", type=str, default="critic_params.pkl", help="path to critic model")
-    parser.add_argument("--num_episodes", type=int, default=1, help="number of episodes to run")
+    parser.add_argument(
+        "--actor_path",
+        type=str,
+        default="models/im_goated_actor_20.pkl",
+        help="path to actor model",
+    )
+    parser.add_argument(
+        "--critic_path", type=str, default="models/height_based_reward_critic_1090.pkl", help="path to critic model"
+    )
+    parser.add_argument("--num_episodes", type=int, default=10, help="number of episodes to run")
     parser.add_argument("--max_steps", type=int, default=1000, help="maximum steps per episode")
     parser.add_argument("--video_path", type=str, default="inference_video.mp4", help="path to save video")
     parser.add_argument("--render_every", type=int, default=2, help="how many frames to skip between renders")
@@ -39,8 +46,7 @@ def main() -> None:
     env = HumanoidEnv()
     rng = jax.random.PRNGKey(0)

-    actor_params, _ = load_models(args.actor_path, args.critic_path)
-    actor = Actor(input_size=env.observation_size, action_size=env.action_size, key=rng)
+    actor, _ = load_models(args.actor_path, args.critic_path)

     reset_fn = jax.jit(env.reset)
     step_fn = jax.jit(env.step)
@@ -50,7 +56,7 @@ def main() -> None:
     rollout: list[Any] = []

     for episode in range(args.num_episodes):
-        rng, reset_rng = jax.rand.split(rng)
+        rng, reset_rng = jax.random.split(rng)
         state = reset_fn(reset_rng)
         obs = state.obs

@@ -60,7 +66,8 @@ def main() -> None:
             if len(rollout) < max_frames:
                 rollout.append(state.pipeline_state)

-            action = choose_action(actor_params, obs, actor)
+            rng, action_rng = jax.random.split(rng)
+            action = choose_action(actor, obs, action_rng)
             state = step_fn(state, action)
             obs = state.obs
             total_reward += state.reward
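The call now passes the loaded module first and a fresh PRNG key last, so choose_action presumably samples a stochastic action from the actor's output distribution. A minimal sketch under that assumption (the Gaussian-policy details are guesses, not the repo's code):

import jax
import jax.numpy as jp

from train import Actor

def choose_action(actor: Actor, obs: jp.ndarray, rng: jax.Array) -> jp.ndarray:
    # Assumed Gaussian policy: the actor maps an observation to a mean and
    # log-std, and the per-step key split off in the loop above drives sampling.
    mean, log_std = actor(obs)
    return mean + jp.exp(log_std) * jax.random.normal(rng, mean.shape)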