
Commit

less untyped stuff
codekansas committed Oct 16, 2024
1 parent 7d314e6 commit 1541567
Showing 1 changed file with 83 additions and 50 deletions.
133 changes: 83 additions & 50 deletions environment.py
@@ -15,7 +15,7 @@
 import jax.numpy as jnp
 import mujoco
 from brax import base
-from brax.envs.base import PipelineEnv, State
+from brax.envs.base import PipelineEnv
 from brax.io import mjcf
 from brax.mjx.base import State as MjxState
 from kscale import KScale
@@ -44,11 +44,8 @@ class RewardConfig:
 
 
 def load_mjcf_model(kscale_id: str) -> mujoco.MjModel:
-    if 1 < 0:
-        api = KScale()
-        mjcf_path = asyncio.run(api.mjcf_path(kscale_id))
-    else:
-        mjcf_path = Path("/Users/ben/Downloads/robot/robot.mjcf")
+    api = KScale()
+    mjcf_path = asyncio.run(api.mjcf_path(kscale_id))
 
     # We need to fix up the MJCF model to allow it to work with Brax.
     with tempfile.TemporaryDirectory() as temp_dir:
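The change above deletes a dead `if 1 < 0:` branch with a hard-coded local path, so the MJCF file is now always resolved through the K-Scale API. As a rough sketch of the overall loading flow, assuming `KScale.mjcf_path` is a coroutine (as the `asyncio.run` call implies) and that `brax.io.mjcf.load_model` accepts a `mujoco.MjModel`; the temp-dir fix-up pass from this file is elided:

import asyncio

import mujoco
from brax.io import mjcf
from kscale import KScale

def load_brax_system(kscale_id: str):
    # Resolve the robot's MJCF file through the K-Scale API; mjcf_path is
    # async, so drive it to completion with asyncio.run.
    mjcf_path = asyncio.run(KScale().mjcf_path(kscale_id))
    model = mujoco.MjModel.from_xml_path(str(mjcf_path))
    # (The fix-up step that environment.py applies in a temp dir is omitted.)
    return mjcf.load_model(model)  # convert the MuJoCo model to a Brax system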
@@ -70,6 +67,52 @@ def load_mjcf_model(kscale_id: str) -> mujoco.MjModel:
     return model
 
 
+@dataclass
+class EnvMetrics:
+    episode_returns: jnp.ndarray
+    episode_lengths: jnp.ndarray
+    returned_episode_returns: jnp.ndarray
+    returned_episode_lengths: jnp.ndarray
+    timestep: jnp.ndarray
+    returned_episode: jnp.ndarray
+
+    def tree_flatten(self) -> tuple[list[jnp.ndarray], None]:
+        return (
+            self.episode_returns,
+            self.episode_lengths,
+            self.returned_episode_returns,
+            self.returned_episode_lengths,
+            self.timestep,
+            self.returned_episode,
+        ), None
+
+    @classmethod
+    def tree_unflatten(cls, aux_data: Any, children: list[jnp.ndarray]) -> "EnvMetrics":  # noqa: ANN401
+        return cls(*children)
+
+
+jax.tree_util.register_pytree_node_class(EnvMetrics)
+
+
+@dataclass
+class EnvState:
+    pipeline_state: Any  # Use Any for MjxState as it's not a standard JAX type
+    obs: jnp.ndarray
+    reward: jnp.ndarray
+    done: jnp.ndarray
+    metrics: EnvMetrics
+
+    def tree_flatten(self) -> tuple[list[Any], None]:
+        return (self.pipeline_state, self.obs, self.reward, self.done, self.metrics), None
+
+    @classmethod
+    def tree_unflatten(cls, aux_data: Any, children: list[Any]) -> "EnvState":  # noqa: ANN401
+        return cls(*children)
+
+
+jax.tree_util.register_pytree_node_class(EnvState)
+
+
 class HumanoidEnv(PipelineEnv):
     """Defines the environment for controlling a humanoid robot.
@@ -118,7 +161,7 @@ def __init__(
         self.actuator_ctrlrange = jnp.array(actuator_ctrlrange)
 
     @partial(jax.jit, static_argnums=(0,))
-    def reset(self, rng: jnp.ndarray) -> State:
+    def reset(self, rng: jnp.ndarray) -> EnvState:
         """Resets the environment to an initial state."""
        rng, rng1, rng2 = jax.random.split(rng, 3)
 
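The `partial(jax.jit, static_argnums=(0,))` decorator marks argument 0, i.e. `self`, as static, so the environment instance is baked into the compiled trace instead of being treated as a traced value (the instance just has to be hashable). A toy illustration of the pattern:

from functools import partial

import jax
import jax.numpy as jnp

class Scaler:
    def __init__(self, factor: float) -> None:
        self.factor = factor

    @partial(jax.jit, static_argnums=(0,))  # treat `self` as static
    def scale(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.factor * x

print(Scaler(2.0).scale(jnp.arange(3)))  # [0. 2. 4.]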
@@ -129,24 +172,25 @@ def reset(self, rng: jnp.ndarray) -> State:
         # initialize mjx state
         state = self.pipeline_init(qpos, qvel)
         obs = self.get_obs(state, jnp.zeros(self._action_size))
-        metrics = {
-            "episode_returns": 0,
-            "episode_lengths": 0,
-            "returned_episode_returns": 0,
-            "returned_episode_lengths": 0,
-            "timestep": 0,
-            "returned_episode": False,
-        }
-
-        return State(state, obs, jnp.array(0.0), False, metrics)
+        metrics = EnvMetrics(
+            episode_returns=jnp.array(0.0),
+            episode_lengths=jnp.array(0),
+            returned_episode_returns=jnp.array(0.0),
+            returned_episode_lengths=jnp.array(0),
+            timestep=jnp.array(0),
+            returned_episode=jnp.array(False),
+        )
+
+        return EnvState(state, obs, jnp.array(0.0), jnp.array(False), metrics)
 
     @partial(jax.jit, static_argnums=(0,))
-    def step(self, env_state: State, action: jnp.ndarray, rng: jnp.ndarray) -> State:
+    def step(self, env_state: EnvState, action: jnp.ndarray, rng: jnp.ndarray) -> EnvState:
         """Run one timestep of the environment's dynamics and returns observations with rewards."""
         state = env_state.pipeline_state
         metrics = env_state.metrics
 
-        state_step = self.pipeline_step(state, action)  # because scaled action so bad...
+        state_step = self.pipeline_step(state, action)
         obs_state = self.get_obs(state, action)
 
         # reset env if done
@@ -159,11 +203,7 @@ def step(self, env_state: State, action: jnp.ndarray, rng: jnp.ndarray) -> State
         obs_reset = self.get_obs(state, jnp.zeros(self._action_size))
 
         # get obs/reward/done of action + states
-        reward = self.compute_reward(
-            state,
-            state_step,
-            action,
-        )
+        reward = self.compute_reward(state, state_step, action)
         done = self.is_done(state_step)
 
         # setting done = True if nans in next state
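The auto-reset around this hunk follows the standard jitted-environment pattern: compute both a stepped state and a freshly reset state, then pick between them per leaf with `jax.lax.select`, since a Python `if done:` cannot appear under `jit`. Schematically:

import jax
import jax.numpy as jnp

def autoreset(done, state_reset, state_step):
    # Pick the reset state where done is True, the stepped state otherwise,
    # leaf by leaf across the two (pytree) states.
    return jax.tree.map(lambda r, s: jax.lax.select(done, r, s), state_reset, state_step)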
@@ -175,25 +215,21 @@ def step(self, env_state: State, action: jnp.ndarray, rng: jnp.ndarray) -> State
         new_state = jax.tree.map(lambda x, y: jax.lax.select(done, x, y), state_reset, state_step)
         obs = jax.lax.select(done, obs_reset, obs_state)
 
-        ########### METRIC TRACKING ###########
-
-        # Calculate new episode return and length
-        new_episode_return = metrics["episode_returns"] + reward
-        new_episode_length = metrics["episode_lengths"] + 1
-
-        # Update metrics -- we only count episode
-        metrics["episode_returns"] = new_episode_return * (1 - done)
-        metrics["episode_lengths"] = new_episode_length * (1 - done)
-        metrics["returned_episode_returns"] = (
-            metrics["returned_episode_returns"] * (1 - done) + new_episode_return * done
-        )
-        metrics["returned_episode_lengths"] = (
-            metrics["returned_episode_lengths"] * (1 - done) + new_episode_length * done
-        )
-        metrics["timestep"] = metrics["timestep"] + 1
-        metrics["returned_episode"] = done
+        new_episode_return = metrics.episode_returns + reward
+        new_episode_length = metrics.episode_lengths + 1
+
+        # Update metrics
+        new_metrics = EnvMetrics(
+            episode_returns=new_episode_return * (1 - done),
+            episode_lengths=new_episode_length * (1 - done),
+            returned_episode_returns=metrics.returned_episode_returns * (1 - done) + new_episode_return * done,
+            returned_episode_lengths=metrics.returned_episode_lengths * (1 - done) + new_episode_length * done,
+            timestep=metrics.timestep + 1,
+            returned_episode=done,
+        )
 
-        return env_state.replace(pipeline_state=new_state, obs=obs, reward=reward, done=done, metrics=metrics)
+        return EnvState(new_state, obs, reward, done, new_metrics)
 
     @partial(jax.jit, static_argnums=(0,))
     def compute_reward(
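Inside `jit`, `done` is a traced array, so the metric update cannot branch in Python; it blends the "episode continues" and "episode just ended" values with `(1 - done)` and `done` masks instead, the usual bookkeeping in jitted RL loops. The same arithmetic in isolation, with scalar example values:

import jax.numpy as jnp

def update_returns(episode_return, returned_return, reward, done):
    new_return = episode_return + reward
    running = new_return * (1 - done)            # zeroed once the episode ends
    reported = returned_return * (1 - done) + new_return * done
    return running, reported

# Mid-episode: the running return accumulates, the reported value is unchanged.
print(update_returns(jnp.array(3.0), jnp.array(0.0), 1.0, jnp.array(0.0)))
# Terminal step: the running return resets to 0 and the final 4.0 is reported.
print(update_returns(jnp.array(3.0), jnp.array(0.0), 1.0, jnp.array(1.0)))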
@@ -202,8 +238,8 @@ def compute_reward(
         next_state: MjxState,
         action: jnp.ndarray,
     ) -> jnp.ndarray:
         """Compute the reward for standing and height."""
-        min_z, max_z = self.reward_config.height_min_z, self.reward_config.height_max_z
+        min_z = self.reward_config.height_min_z
+        max_z = self.reward_config.height_max_z
 
         exp_coef = self.reward_config.original_pos_reward_exp_coefficient
         subtraction_factor = self.reward_config.original_pos_reward_subtraction_factor
@@ -238,12 +274,9 @@
     @partial(jax.jit, static_argnums=(0,))
     def is_done(self, state: MjxState) -> jnp.ndarray:
         """Check if the episode should terminate."""
-        # Get the height of the robot's center of mass
         com_height = state.q[2]
-
         min_z, max_z = self.reward_config.height_min_z, self.reward_config.height_max_z
         height_condition = jnp.logical_not(jnp.logical_and(min_z < com_height, com_height < max_z))
-
         return height_condition
 
     @partial(jax.jit, static_argnums=(0,))
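Termination is purely a height window on `state.q[2]`: the episode ends once the trunk leaves the `(min_z, max_z)` band. The same predicate as a free function, with made-up bounds rather than the config's actual values:

import jax.numpy as jnp

def is_done(com_height: jnp.ndarray, min_z: float = 0.5, max_z: float = 2.0) -> jnp.ndarray:
    inside = jnp.logical_and(min_z < com_height, com_height < max_z)
    return jnp.logical_not(inside)

print(is_done(jnp.array(1.0)))  # False: still inside the window
print(is_done(jnp.array(0.2)))  # True: fallen below it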
@@ -343,26 +376,26 @@ def run_environment_adhoc() -> None:
 
     fps = int(1 / env.dt)
     max_frames = int(args.video_length * fps)
-    rollout: list[Any] = []
+    rollout: list[MjxState] = []
 
     for episode in range(args.num_episodes):
         rng, _ = jax.random.split(rng)
-        state = reset_fn(rng)
+        env_state = reset_fn(rng)
 
         total_reward = 0
 
         for _ in tqdm(range(args.max_steps), desc=f"Episode {episode + 1} Steps", leave=False):
             if len(rollout) < args.video_length * fps:
-                rollout.append(state.pipeline_state)
+                rollout.append(env_state.pipeline_state)
 
             rng, action_rng = jax.random.split(rng)
             action = jax.random.uniform(action_rng, (action_size,), minval=0, maxval=1.0)
 
             rng, step_rng = jax.random.split(rng)
-            state = step_fn(state, action, step_rng)
-            total_reward += state.reward
+            env_state = step_fn(env_state, action, step_rng)
+            total_reward += env_state.reward
 
-            if state.done:
+            if env_state.done:
                 break
 
         logger.info("Episode %d total reward: %f", episode + 1, total_reward)
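`reset_fn` and `step_fn` are defined outside the visible hunks; presumably they are just the jitted environment methods, along these lines (an assumption, not part of the diff):

import jax

env = HumanoidEnv()            # constructor arguments elided; see environment.py
reset_fn = jax.jit(env.reset)  # reset/step are already jit-wrapped above,
step_fn = jax.jit(env.step)    # so this adds at most a cheap extra layer
rng = jax.random.PRNGKey(0)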
