Skip to content

Commit

Permalink
[python|gym] Performance improvement. (#242)
Browse files Browse the repository at this point in the history
* [core/python] Do not use Eigenpy to convert systemState.q/v/a to improve efficiency and preserve constness.
* [python/viewer] Enable to add legend in Meshcat.
* [python/viewer] Enable to add logo in meshcat.
* [python/viewer] Add 'legend' and 'logo_fullpath' optional argument to 'play_trajectories'.
* [python|gym] Add Python native 'is_simulation_running' attribute to Simulator for fast access. 
* [gym] Define shared memories/proxies at init when possible. 
* [gym] Replace 'engine' proxy by 'simulator' for greater flexibility.
* [gym/envs] Improve performance.

Co-authored-by: Alexis Duburcq <[email protected]>
  • Loading branch information
duburcqa and Alexis Duburcq authored Dec 7, 2020
1 parent b26fe8e commit 7248150
Show file tree
Hide file tree
Showing 17 changed files with 341 additions and 74 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
cmake_minimum_required(VERSION 3.10)

# Set the build version
set(BUILD_VERSION 1.4.21)
set(BUILD_VERSION 1.4.22)

# Add definition of Jiminy version for C++ headers
add_definitions("-DJIMINY_VERSION=\"${BUILD_VERSION}\"")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import gym

import jiminy_py.core as jiminy
from jiminy_py.simulator import Simulator

from ..utils import copy, SpaceDictNested

Expand Down Expand Up @@ -173,14 +174,14 @@ class ObserverControllerInterface(ObserverInterface, ControllerInterface):
"""Observer plus controller interface for both generic pipeline blocks,
including environments.
"""
engine: Optional[jiminy.EngineMultiRobot]
simulator: Optional[Simulator]
stepper_state: Optional[jiminy.StepperState]
system_state: Optional[jiminy.SystemState]
sensors_data: Optional[Dict[str, np.ndarray]]

def __init__(self, *args: Any, **kwargs: Any) -> None:
# Define some attributes
self.engine = None
self.simulator = None
self.stepper_state = None
self.system_state = None
self.sensors_data = None
Expand Down
25 changes: 10 additions & 15 deletions python/gym_jiminy/common/gym_jiminy/common/bases/pipeline_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import gym

import jiminy_py.core as jiminy
from jiminy_py.simulator import Simulator
from jiminy_py.controller import ObserverHandleType, ControllerHandleType

from ..utils import (
Expand Down Expand Up @@ -54,6 +55,12 @@ def __init__(self,
# Initialize base wrapper and interfaces through multiple inheritance
super().__init__(env) # Do not forward extra arguments, if any

# Refresh some proxies for fast lookup
self.simulator: Simulator = self.env.simulator
self.stepper_state: jiminy.StepperState = self.env.stepper_state
self.system_state: jiminy.SystemState = self.env.system_state
self.sensors_data: jiminy.sensorsData = self.env.sensors_data

# Define some internal buffers
self._dt_eps: Optional[float] = None
self._command = zeros(self.env.unwrapped.action_space)
Expand Down Expand Up @@ -187,9 +194,6 @@ def _setup(self) -> None:
fill(self._command, 0.0)

# Refresh some proxies for fast lookup
self.engine = self.env.engine
self.stepper_state = self.env.stepper_state
self.system_state = self.env.system_state
self.sensors_data = self.env.sensors_data

def refresh_observation(self) -> None: # type: ignore[override]
Expand Down Expand Up @@ -344,9 +348,6 @@ def refresh_observation(self) -> None: # type: ignore[override]
"""
# pylint: disable=arguments-differ

# Assertion(s) for type checker
assert self.engine is not None and self.stepper_state is not None

# Get environment observation
super().refresh_observation()

Expand All @@ -355,7 +356,7 @@ def refresh_observation(self) -> None: # type: ignore[override]
if _is_breakpoint(t, self.observe_dt, self._dt_eps):
obs = self.env.get_observation()
self.observer.refresh_observation(obs)
if not self.engine.is_simulation_running:
if not self.simulator.is_simulation_running:
features = self.observer.get_observation()
if self.augment_observation:
self._observation = obs
Expand Down Expand Up @@ -532,9 +533,6 @@ def compute_command(self,
:param measure: Observation of the environment.
:param action: High-level target to achieve.
"""
# Assertion(s) for type checker
assert self.engine is not None and self.stepper_state is not None

# Backup the action
set_value(self._action, action)

Expand All @@ -552,7 +550,7 @@ def compute_command(self,
# update the command of the right period. Ultimately, this is done
# automatically by the engine, which is calling `_controller_handle` at
# the right period.
if self.engine.is_simulation_running:
if self.simulator.is_simulation_running:
# Do not update command during the first iteration because the
# action is undefined at this point
np.core.umath.copyto(self._command, self.env.compute_command(
Expand All @@ -576,14 +574,11 @@ def refresh_observation(self) -> None: # type: ignore[override]
:returns: Original environment observation, eventually including
controllers targets if requested.
"""
# Assertion(s) for type checker
assert self.engine is not None

# Get environment observation
super().refresh_observation()

# Add target to observation if requested
if not self.engine.is_simulation_running:
if not self.simulator.is_simulation_running:
self._observation = self.env.get_observation()
if self.augment_observation:
self._observation.setdefault('targets', {})[
Expand Down
35 changes: 18 additions & 17 deletions python/gym_jiminy/common/gym_jiminy/common/envs/env_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,20 @@ def __init__(self,
"""
# pylint: disable=unused-argument

# Initialize the interfaces through multiple inheritance
super().__init__() # Do not forward extra arguments, if any

# Backup some user arguments
self.simulator = simulator
self.simulator: Simulator = simulator
self.step_dt = step_dt
self.enforce_bounded = enforce_bounded
self.debug = debug

# Initialize the interfaces through multiple inheritance
super().__init__() # Do not forward extra arguments, if any
# Define some proxies for fast access
self.engine: jiminy.EngineMultiRobot = self.simulator.engine
self.stepper_state: jiminy.StepperState = self.engine.stepper_state
self.system_state: jiminy.SystemState = self.engine.system_state
self.sensors_data: jiminy.sensorsData = dict(self.robot.sensors_data)

# Internal buffers for physics computations
self.rg = np.random.RandomState()
Expand Down Expand Up @@ -418,11 +424,8 @@ def reset(self,
# Reset the simulator
self.simulator.reset()

# Initialize shared memories.
# Re-initialize some shared memories.
# It must be done because the robot may have changed.
self.engine = self.simulator.engine
self.stepper_state = self.engine.stepper_state
self.system_state = self.engine.system_state
self.sensors_data = dict(self.robot.sensors_data)

# Make sure the environment is properly setup
Expand Down Expand Up @@ -503,14 +506,14 @@ def reset(self,
is_obs_valid = False
if not is_obs_valid:
raise RuntimeError(
"The observation returned by `refresh_observation` is "
"The observation computed by `refresh_observation` is "
"inconsistent with the observation space defined by "
"`_refresh_observation_space`.")

if self.is_done():
raise RuntimeError(
"The simulation is already done at `reset`. "
"Check the implementation of `is_done` if overloaded.")
"The simulation is already done at `reset`. Check the "
"implementation of `is_done` if overloaded.")

return self.get_observation()

Expand Down Expand Up @@ -558,7 +561,7 @@ def step(self,
not), and a dictionary of extra information
"""
# Make sure a simulation is already running
if self.engine is None or not self.engine.is_simulation_running:
if not self.simulator.is_simulation_running:
raise RuntimeError(
"No simulation running. Please call `reset` before `step`.")

Expand Down Expand Up @@ -625,10 +628,9 @@ def step(self,
def get_log(self) -> Tuple[Dict[str, np.ndarray], Dict[str, str]]:
"""Get log of recorded variable since the beginning of the episode.
"""
if self.engine is None or not self.engine.is_simulation_running:
if not self.simulator.is_simulation_running:
raise RuntimeError(
"No simulation running. Please call `reset` at least one "
"before getting log.")
"No simulation running. Please start one before getting log.")
return self.simulator.get_log()

def render(self,
Expand Down Expand Up @@ -839,11 +841,10 @@ def refresh_observation(self) -> None: # type: ignore[override]
# pylint: disable=arguments-differ

# Assertion(s) for type checker
assert (self.engine is not None and self.stepper_state is not None and
isinstance(self._observation, dict))
assert isinstance(self._observation, dict)

self._observation['t'][0] = self.stepper_state.t
if not self.engine.is_simulation_running:
if not self.simulator.is_simulation_running:
(self._observation['state']['Q'],
self._observation['state']['V']) = self.simulator.state
self._observation['sensors'] = self.sensors_data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,10 @@ def is_done(self) -> bool: # type: ignore[override]
"""
# pylint: disable=arguments-differ

# Assertion(s) for type checker
assert self.system_state is not None

if not self.simulator.is_simulation_running:
raise RuntimeError(
"No simulation running. Please start one before calling this "
"method.")
if self.system_state.q[2] < self._height_neutral * 0.75:
return True
if self.simulator.stepper_state.t >= self.simu_duration_max:
Expand All @@ -408,9 +409,6 @@ def compute_reward(self, # type: ignore[override]
"""
# pylint: disable=arguments-differ

# Assertion(s) for type checker
assert self.system_state is not None

reward_dict = info.setdefault('reward', {})

# Define some proxies
Expand Down
2 changes: 1 addition & 1 deletion python/gym_jiminy/common/gym_jiminy/common/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def __init__(self: wrapped_env_class, # type: ignore[valid-type]
def load_pipeline(fullpath: str) -> Type[BasePipelineWrapper]:
""" TODO: Write documentation.
"""
file_ext = ''.join(pathlib.Path(fullpath).suffixes)
file_ext = pathlib.Path(fullpath).suffix
with open(fullpath, 'r') as f:
if file_ext == '.json':
return build_pipeline(**json.load(f))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,12 @@ def _setup(self) -> None:
"`StackedJiminyEnv` does not support time-continuous update.")

def refresh_observation(self) -> None: # type: ignore[override]
# Assertion(s) for type checker
assert self.engine is not None and self.stepper_state is not None

# Get environment observation
self.env.refresh_observation()

# Update observed features if necessary
t = self.stepper_state.t
if self.engine.is_simulation_running and \
if self.simulator.is_simulation_running and \
_is_breakpoint(t, self.observe_dt, self._dt_eps):
self.__n_last_stack += 1
if self.__n_last_stack == self.skip_frames_ratio:
Expand Down
9 changes: 8 additions & 1 deletion python/gym_jiminy/envs/gym_jiminy/envs/acrobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def __init__(self, continuous: bool = False):
# Configure the learning environment
super().__init__(simulator, STEP_DT, debug=False)

# Create some proxies for fast access
self.__state_view = (self._observation[:self.robot.nq],
self._observation[self.robot.nv:])

def _refresh_observation_space(self) -> None:
"""Configure the observation of the environment.
Expand All @@ -131,7 +135,10 @@ def refresh_observation(self, *args: Any, **kwargs: Any) -> None:
For goal env, in addition of the current robot state, both the
desired and achieved goals are observable.
"""
np.concatenate(self.simulator.state, out=self._observation)
if not self.simulator.is_simulation_running:
self.__state = (self.system_state.q, self.system_state.v)
np.core.umath.copyto(self.__state_view[0], self.__state[0])
np.core.umath.copyto(self.__state_view[1], self.__state[1])

def _refresh_action_space(self) -> None:
"""Configure the action space of the environment.
Expand Down
10 changes: 8 additions & 2 deletions python/gym_jiminy/envs/gym_jiminy/envs/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def __init__(self, continuous: bool = False):
# Configure the learning environment
super().__init__(simulator, STEP_DT, debug=False)

# Create some proxies for fast access
self.__state_view = (self._observation[:self.robot.nq],
self._observation[self.robot.nv:])

def _setup(self) -> None:
""" TODO: Write documentation.
"""
Expand Down Expand Up @@ -176,8 +180,10 @@ def _sample_state(self) -> Tuple[np.ndarray, np.ndarray]:

def refresh_observation(self) -> None:
# @copydoc BaseJiminyEnv::refresh_observation
np.core.umath.copyto(
self._observation, np.concatenate(self.simulator.state))
if not self.simulator.is_simulation_running:
self.__state = (self.system_state.q, self.system_state.v)
np.core.umath.copyto(self.__state_view[0], self.__state[0])
np.core.umath.copyto(self.__state_view[1], self.__state[1])

def is_done(self) -> bool:
""" TODO: Write documentation.
Expand Down
2 changes: 1 addition & 1 deletion python/jiminy_py/src/jiminy_py/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def read_log(fullpath: str,
"""
# Handling of file file_format
if file_format is None:
file_ext = ''.join(pathlib.Path(fullpath).suffixes)
file_ext = pathlib.Path(fullpath).suffix
if file_ext == '.data':
file_format = 'binary'
elif file_ext == '.csv' or file_ext == '.txt':
Expand Down
33 changes: 33 additions & 0 deletions python/jiminy_py/src/jiminy_py/meshcat/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

<script type="text/javascript" src="main.min.js"></script>
<script type="text/javascript" src="webm-writer-0.3.0.js"></script>
<script type="text/javascript" src="legend.js"></script>
<script>
// Instantiate a new Meshcat viewer
var viewer = new MeshCat.Viewer(document.getElementById("meshcat-pane"), false);
Expand All @@ -22,6 +23,38 @@
viewer.handle_command = function(cmd) {
if (cmd.type == "ready") {
viewer.connection.send("meshcat:ok");
} else if (cmd.type == "legend") {
var legend = document.getElementById("legend");
if (legend == null) {
createLegend("legend");
legend = document.getElementById("legend");
}
if (cmd.text) {
setLegendItem(legend, cmd.id, cmd.text, cmd.color);
} else {
removeLegendItem(legend, cmd.id);
}
} else if (cmd.type == "logo") {
var logo = document.getElementById("logo");
if (cmd.data) {
if (logo == null) {
logo = document.createElement("img");
logo.id = "logo";
logo.draggable = false;
logo.style.position = "fixed";
logo.style.bottom = "20px";
logo.style.left = "20px";
logo.style.pointerEvents = "none";
document.body.prepend(logo);
}
logo.setAttribute('src', 'data:image/png;base64,' + cmd.data);
logo.style.width = cmd.width.toString() + "px";
logo.style.height = cmd.height.toString() + "px";
} else {
if (logo !== null) {
document.body.removeChild(logo);
}
}
} else {
handle_command.call(this, cmd);
}
Expand Down
Loading

0 comments on commit 7248150

Please sign in to comment.