From 1541567ae0477ae9f293599c8e0f3bcdb2de8446 Mon Sep 17 00:00:00 2001
From: Benjamin Bolte
Date: Tue, 15 Oct 2024 22:52:42 -0700
Subject: [PATCH] less untyped stuff

---
 environment.py | 133 ++++++++++++++++++++++++++++++-------------------
 1 file changed, 83 insertions(+), 50 deletions(-)

diff --git a/environment.py b/environment.py
index c612d8c..b694405 100644
--- a/environment.py
+++ b/environment.py
@@ -15,7 +15,7 @@ import jax.numpy as jnp
 import mujoco
 from brax import base
-from brax.envs.base import PipelineEnv, State
+from brax.envs.base import PipelineEnv
 from brax.io import mjcf
 from brax.mjx.base import State as MjxState
 from kscale import KScale
@@ -44,11 +44,8 @@ class RewardConfig:
 
 
 def load_mjcf_model(kscale_id: str) -> mujoco.MjModel:
-    if 1 < 0:
-        api = KScale()
-        mjcf_path = asyncio.run(api.mjcf_path(kscale_id))
-    else:
-        mjcf_path = Path("/Users/ben/Downloads/robot/robot.mjcf")
+    api = KScale()
+    mjcf_path = asyncio.run(api.mjcf_path(kscale_id))
 
     # We need to fix up the MJCF model to allow it to work with Brax.
     with tempfile.TemporaryDirectory() as temp_dir:
@@ -70,6 +67,52 @@ def load_mjcf_model(kscale_id: str) -> mujoco.MjModel:
     return model
 
 
+@dataclass
+class EnvMetrics:
+    episode_returns: jnp.ndarray
+    episode_lengths: jnp.ndarray
+    returned_episode_returns: jnp.ndarray
+    returned_episode_lengths: jnp.ndarray
+    timestep: jnp.ndarray
+    returned_episode: jnp.ndarray
+
+    def tree_flatten(self) -> tuple[tuple[jnp.ndarray, ...], None]:
+        return (
+            self.episode_returns,
+            self.episode_lengths,
+            self.returned_episode_returns,
+            self.returned_episode_lengths,
+            self.timestep,
+            self.returned_episode,
+        ), None
+
+    @classmethod
+    def tree_unflatten(cls, aux_data: Any, children: tuple[jnp.ndarray, ...]) -> "EnvMetrics":  # noqa: ANN401
+        return cls(*children)
+
+
+jax.tree_util.register_pytree_node_class(EnvMetrics)
+
+
+@dataclass
+class EnvState:
+    pipeline_state: Any  # Use Any for MjxState as it's not a standard JAX type
+    obs: jnp.ndarray
+    reward: jnp.ndarray
+    done: jnp.ndarray
+    metrics: EnvMetrics
+
+    def tree_flatten(self) -> tuple[tuple[Any, ...], None]:
+        return (self.pipeline_state, self.obs, self.reward, self.done, self.metrics), None
+
+    @classmethod
+    def tree_unflatten(cls, aux_data: Any, children: tuple[Any, ...]) -> "EnvState":  # noqa: ANN401
+        return cls(*children)
+
+
+jax.tree_util.register_pytree_node_class(EnvState)
+
+
 class HumanoidEnv(PipelineEnv):
     """Defines the environment for controlling a humanoid robot.
@@ -118,7 +161,7 @@ def __init__(
         self.actuator_ctrlrange = jnp.array(actuator_ctrlrange)
 
     @partial(jax.jit, static_argnums=(0,))
-    def reset(self, rng: jnp.ndarray) -> State:
+    def reset(self, rng: jnp.ndarray) -> EnvState:
         """Resets the environment to an initial state."""
         rng, rng1, rng2 = jax.random.split(rng, 3)
 
@@ -129,24 +172,25 @@ def reset(self, rng: jnp.ndarray) -> State:
 
         # initialize mjx state
         state = self.pipeline_init(qpos, qvel)
         obs = self.get_obs(state, jnp.zeros(self._action_size))
-        metrics = {
-            "episode_returns": 0,
-            "episode_lengths": 0,
-            "returned_episode_returns": 0,
-            "returned_episode_lengths": 0,
-            "timestep": 0,
-            "returned_episode": False,
-        }
-        return State(state, obs, jnp.array(0.0), False, metrics)
+        metrics = EnvMetrics(
+            episode_returns=jnp.array(0.0),
+            episode_lengths=jnp.array(0),
+            returned_episode_returns=jnp.array(0.0),
+            returned_episode_lengths=jnp.array(0),
+            timestep=jnp.array(0),
+            returned_episode=jnp.array(False),
+        )
+
+        return EnvState(state, obs, jnp.array(0.0), jnp.array(False), metrics)
 
     @partial(jax.jit, static_argnums=(0,))
-    def step(self, env_state: State, action: jnp.ndarray, rng: jnp.ndarray) -> State:
+    def step(self, env_state: EnvState, action: jnp.ndarray, rng: jnp.ndarray) -> EnvState:
         """Run one timestep of the environment's dynamics and returns observations with rewards."""
         state = env_state.pipeline_state
         metrics = env_state.metrics
 
-        state_step = self.pipeline_step(state, action)  # because scaled action so bad...
+        state_step = self.pipeline_step(state, action)
         obs_state = self.get_obs(state, action)
 
         # reset env if done
@@ -159,11 +203,7 @@ def step(self, env_state: State, action: jnp.ndarray, rng: jnp.ndarray) -> State:
         obs_reset = self.get_obs(state, jnp.zeros(self._action_size))
 
         # get obs/reward/done of action + states
-        reward = self.compute_reward(
-            state,
-            state_step,
-            action,
-        )
+        reward = self.compute_reward(state, state_step, action)
         done = self.is_done(state_step)
 
         # setting done = True if nans in next state
@@ -175,25 +215,21 @@ def step(self, env_state: State, action: jnp.ndarray, rng: jnp.ndarray) -> State:
         new_state = jax.tree.map(lambda x, y: jax.lax.select(done, x, y), state_reset, state_step)
         obs = jax.lax.select(done, obs_reset, obs_state)
 
-        ########### METRIC TRACKING ###########
-        # Calculate new episode return and length
-        new_episode_return = metrics["episode_returns"] + reward
-        new_episode_length = metrics["episode_lengths"] + 1
-
-        # Update metrics -- we only count episode
-        metrics["episode_returns"] = new_episode_return * (1 - done)
-        metrics["episode_lengths"] = new_episode_length * (1 - done)
-        metrics["returned_episode_returns"] = (
-            metrics["returned_episode_returns"] * (1 - done) + new_episode_return * done
-        )
-        metrics["returned_episode_lengths"] = (
-            metrics["returned_episode_lengths"] * (1 - done) + new_episode_length * done
+        new_episode_return = metrics.episode_returns + reward
+        new_episode_length = metrics.episode_lengths + 1
+
+        # Update metrics
+        new_metrics = EnvMetrics(
+            episode_returns=new_episode_return * (1 - done),
+            episode_lengths=new_episode_length * (1 - done),
+            returned_episode_returns=metrics.returned_episode_returns * (1 - done) + new_episode_return * done,
+            returned_episode_lengths=metrics.returned_episode_lengths * (1 - done) + new_episode_length * done,
+            timestep=metrics.timestep + 1,
+            returned_episode=done,
         )
-        metrics["timestep"] = metrics["timestep"] + 1
-        metrics["returned_episode"] = done
 
-        return env_state.replace(pipeline_state=new_state, obs=obs,
-            reward=reward, done=done, metrics=metrics)
+        return EnvState(new_state, obs, reward, done, new_metrics)
 
     @partial(jax.jit, static_argnums=(0,))
     def compute_reward(
         self,
@@ -202,8 +238,8 @@ def compute_reward(
         state: MjxState,
         next_state: MjxState,
         action: jnp.ndarray,
     ) -> jnp.ndarray:
-        """Compute the reward for standing and height."""
-        min_z, max_z = self.reward_config.height_min_z, self.reward_config.height_max_z
+        min_z = self.reward_config.height_min_z
+        max_z = self.reward_config.height_max_z
         exp_coef = self.reward_config.original_pos_reward_exp_coefficient
         subtraction_factor = self.reward_config.original_pos_reward_subtraction_factor
@@ -238,12 +274,9 @@ def compute_reward(
     @partial(jax.jit, static_argnums=(0,))
     def is_done(self, state: MjxState) -> jnp.ndarray:
         """Check if the episode should terminate."""
-        # Get the height of the robot's center of mass
         com_height = state.q[2]
-
         min_z, max_z = self.reward_config.height_min_z, self.reward_config.height_max_z
         height_condition = jnp.logical_not(jnp.logical_and(min_z < com_height, com_height < max_z))
-
         return height_condition
 
     @partial(jax.jit, static_argnums=(0,))
@@ -343,26 +376,26 @@ def run_environment_adhoc() -> None:
     fps = int(1 / env.dt)
     max_frames = int(args.video_length * fps)
 
-    rollout: list[Any] = []
+    rollout: list[MjxState] = []
 
     for episode in range(args.num_episodes):
         rng, _ = jax.random.split(rng)
-        state = reset_fn(rng)
+        env_state = reset_fn(rng)
 
         total_reward = 0
 
        for _ in tqdm(range(args.max_steps), desc=f"Episode {episode + 1} Steps", leave=False):
             if len(rollout) < args.video_length * fps:
-                rollout.append(state.pipeline_state)
+                rollout.append(env_state.pipeline_state)
 
             rng, action_rng = jax.random.split(rng)
             action = jax.random.uniform(action_rng, (action_size,), minval=0, maxval=1.0)
 
             rng, step_rng = jax.random.split(rng)
-            state = step_fn(state, action, step_rng)
-            total_reward += state.reward
+            env_state = step_fn(env_state, action, step_rng)
+            total_reward += env_state.reward
 
-            if state.done:
+            if env_state.done:
                 break
 
         logger.info("Episode %d total reward: %f", episode + 1, total_reward)
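
Note on the EnvMetrics/EnvState classes introduced above: jit-compiled functions can only pass custom containers in and out if JAX knows how to decompose them into leaf arrays, which is exactly what the tree_flatten/tree_unflatten pair plus jax.tree_util.register_pytree_node_class provide. Below is a minimal, self-contained sketch of the same pattern; the Counter class and tick function are illustrative stand-ins, not part of environment.py.

    from dataclasses import dataclass
    from typing import Any

    import jax
    import jax.numpy as jnp


    @dataclass
    class Counter:
        total: jnp.ndarray
        steps: jnp.ndarray

        # JAX calls tree_flatten to split the object into traceable leaf
        # arrays plus static aux data (none here).
        def tree_flatten(self) -> tuple[tuple[jnp.ndarray, ...], None]:
            return (self.total, self.steps), None

        # tree_unflatten rebuilds the object from those leaves afterwards.
        @classmethod
        def tree_unflatten(cls, aux_data: Any, children: tuple) -> "Counter":
            return cls(*children)


    jax.tree_util.register_pytree_node_class(Counter)


    @jax.jit
    def tick(counter: Counter, reward: jnp.ndarray) -> Counter:
        # A registered pytree can cross the jit boundary and be rebuilt
        # functionally, the same way step() rebuilds EnvMetrics/EnvState.
        return Counter(total=counter.total + reward, steps=counter.steps + 1)


    counter = Counter(total=jnp.array(0.0), steps=jnp.array(0))
    counter = tick(counter, jnp.array(1.5))
    print(counter.total, counter.steps)  # 1.5 1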
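
The step() method also relies on branch-free control flow: done is a traced array under jit, so a Python if cannot be used, and both the stepped and reset outcomes are computed and then merged with jax.lax.select via jax.tree.map, while running metrics are masked with (1 - done). A small sketch of that pattern under the same assumptions; masked_return and merge_states are illustrative names, not part of the patch.

    import jax
    import jax.numpy as jnp


    @jax.jit
    def masked_return(episode_return: jnp.ndarray, reward: jnp.ndarray, done: jnp.ndarray) -> jnp.ndarray:
        # Accumulate, then zero the running return on episode end, mirroring
        # episode_returns=new_episode_return * (1 - done) in the patch.
        return (episode_return + reward) * (1 - done)


    @jax.jit
    def merge_states(done: jnp.ndarray, reset_state: dict, stepped_state: dict) -> dict:
        # jax.tree.map applies lax.select leaf by leaf across two pytrees of
        # identical structure, which is how step() swaps in the reset state.
        return jax.tree.map(lambda x, y: jax.lax.select(done, x, y), reset_state, stepped_state)


    done = jnp.array(True)
    print(masked_return(jnp.array(3.0), jnp.array(1.0), done))          # 0.0
    print(merge_states(done, {"q": jnp.zeros(3)}, {"q": jnp.ones(3)}))  # {'q': zeros}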