In other sections, we have seen how to use tmrl
as a standalone program, thanks to the ready-to-use training pipeline for TrackMania.
However, as soon as you want to try more advanced things (e.g., using robots, other video games, other training algorithms, etc.), you will need to get your hands dirty with some python coding.
This is when you need to start using tmrl
as a python library.
In this tutorial, we will learn from A to Z how to implement our own specialized pipeline, in our own robot environment, with our own training algorithm.
The full script of the tutorial is available here.
You may also want to read the API documentation.
In complement, you can find a minimal tmrl
pipeline for the real-time environment used in this tutorial here.
You can also find a minimal tmrl
pipeline for non-real-time environments (e.g., Pendulum) here.
Note: some modules can be implemented independently. If you are here because you wish to implement your own training algorithm in TrackMania, you should read the competition tutorial instead.
If this is not done yet:
pip install tmrl
(Note: if installing on Windows, you will be prompted to install the ViGEmBus driver, as part of the vgamepad dependency.
This is for TrackMania support.
If you think an option to install tmrl
without support for TrackMania should exist on Windows, please submit a PR, or reach out to us.)
We use this function a lot in tmrl
: it enables partially initializing the kwargs of a class.
Import it into your script:
from tmrl.util import partial
The function can then be used as:
my_partially_instantiated_class = partial(my_class,
                                          some_kwarg=some_value,
                                          some_other_kwarg=some_other_value,
                                          ...)
And the partially instantiated class can then be fully instantiated as:
my_object = my_partially_instantiated_class(missing_kwargs)
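For instance, here is a quick illustration of the mechanics with a made-up class (MyClass and its kwargs are hypothetical, used only for this example):
from tmrl.util import partial

class MyClass:  # hypothetical class, for illustration only
    def __init__(self, a, b, c):
        self.total = a + b + c

MyClassPartial = partial(MyClass, a=1, b=2)  # a and b are fixed now
my_object = MyClassPartial(c=3)              # c is provided at instantiation time
print(my_object.total)  # 6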
In case you need them, you can access the constants defined in the config.json
file via the config_constants module.
This module can be imported into your script as follows:
import tmrl.config.config_constants as cfg
You can then use the constants in your script, e.g.:
print(f"Run name: {cfg.RUN_NAME}")
(NB: read the config_constants module to see which constants are available.)
In the context of this tutorial, everything will happen on localhost and thus Internet security is not really a concern.
In real applications though, you may want to have several tmrl
entities communicating over the Internet.
Security-wise, tmrl
is based on tlspyo.
This enables authentication and encryption of your communications via a TLS key, which you first need to generate if you wish to use this option (see the tlspyo
documentation for how to do so in a couple of easy steps).
If you train over a public network, please consider using this feature, as it protects you against possible attacks from malicious users.
In the context of this tutorial, we will not enable it, as we assume our local network is safe.
Instead, we will rely only on the weak password security that is always enabled in tmrl
:
security = None # change this to "TLS" for TLS encryption (requires a TLS key)
password = "A Secure Password" # change this to a random password
(NB: When training over a public network, you should use both a secure password and TLS encryption.
Please read the tlspyo
security instructions to understand why this is important.)
In this tutorial, we will use our localhost IP (i.e., 127.0.0.1
) and port 6666
for communication between our tmrl
entities.
In an Internet application, you would adapt these to your network setup:
server_ip = "127.0.0.1" # IP of the machine where we will run our TMRL Server
server_port = 6666 # port through which our Server will be accessible
tmrl
uses a multiple clients / single server architecture.
The Server
object is the central means of communication between an instance of the Trainer
class and one to several instances of the RolloutWorker
class.
This enables tmrl
to run on, e.g., HPC clusters where the user does not have port-forwarding access, as long as they have a local machine with port-forwarding access on which the Server
can run (this can be a machine running a RolloutWorker
in parallel).
Both the Trainer
and the RolloutWorkers
connect to the Server
.
The RolloutWorkers
run the current policy to collect samples in the real world, and periodically send these samples to the Server
, which forwards them to the Trainer
.
The Trainer
uses these samples to update the current policy, and periodically sends updated policy weights to the Server
, which forwards them to all connected RolloutWorkers
.
Being the central communication point between entities, the Server
should be instantiated first.
In the context of this tutorial, we will instantiate all 3 entities on the same machine, and thus they will communicate via the localhost
address, which is "127.0.0.1" (NB: the Server does not know this; it simply listens for any incoming connection).
Instantiating a Server
object is straightforward:
from tmrl.networking import Server
# tmrl Server
# (NB: When you omit arguments,
# tmrl retrieves the default in your config.json file.
# Read the documentation of each class for more info.)
my_server = Server(security=security,
                   password=password,
                   port=server_port)
As soon as the server is instantiated, it listens for incoming connections from the Trainer
and the RolloutWorkers
.
In RL, a task is often called an "environment".
tmrl
is particularly well suited for asynchronous remote training of real-time applications such as robots.
We use Real-Time Gym (rtgym
) to wrap our robots and video games into Gymnasium environments.
But you can use an environment of your choice, as long as it is registered as a Gymnasium environment.
To build your own real-time environment (e.g., an environment for your own robot or video game), you can follow the rtgym tutorial.
If you need inspiration, you can find our rtgym
interfaces for TrackMania in custom_gym_interfaces.py.
For the sake of the tmrl
tutorial, we will be using the dummy RC drone environment from the rtgym
tutorial:
(NB: you need opencv-python
installed)
from rtgym import RealTimeGymInterface, DEFAULT_CONFIG_DICT, DummyRCDrone
import gymnasium.spaces as spaces
import numpy as np
import cv2
from threading import Thread
# rtgym interface:
class DummyRCDroneInterface(RealTimeGymInterface):

    def __init__(self):
        self.rc_drone = None
        self.target = np.array([0.0, 0.0], dtype=np.float32)
        self.initialized = False
        self.blank_image = np.ones((500, 500, 3), dtype=np.uint8) * 255
        self.rendering_thread = Thread(target=self._rendering_thread, args=(), kwargs={}, daemon=True)

    def _rendering_thread(self):
        from time import sleep
        while True:
            sleep(0.1)
            self.render()

    def get_observation_space(self):
        pos_x_space = spaces.Box(low=-1.0, high=1.0, shape=(1,))
        pos_y_space = spaces.Box(low=-1.0, high=1.0, shape=(1,))
        tar_x_space = spaces.Box(low=-0.5, high=0.5, shape=(1,))
        tar_y_space = spaces.Box(low=-0.5, high=0.5, shape=(1,))
        return spaces.Tuple((pos_x_space, pos_y_space, tar_x_space, tar_y_space))

    def get_action_space(self):
        return spaces.Box(low=-2.0, high=2.0, shape=(2,))

    def get_default_action(self):
        return np.array([0.0, 0.0], dtype='float32')

    def send_control(self, control):
        vel_x = control[0]
        vel_y = control[1]
        self.rc_drone.send_control(vel_x, vel_y)

    def reset(self, seed=None, options=None):
        if not self.initialized:
            self.rc_drone = DummyRCDrone()
            self.rendering_thread.start()  # start the background rendering thread on the first reset
            self.initialized = True
        pos_x, pos_y = self.rc_drone.get_observation()
        self.target[0] = np.random.uniform(-0.5, 0.5)
        self.target[1] = np.random.uniform(-0.5, 0.5)
        return [np.array([pos_x], dtype='float32'),
                np.array([pos_y], dtype='float32'),
                np.array([self.target[0]], dtype='float32'),
                np.array([self.target[1]], dtype='float32')], {}

    def get_obs_rew_terminated_info(self):
        pos_x, pos_y = self.rc_drone.get_observation()
        tar_x = self.target[0]
        tar_y = self.target[1]
        obs = [np.array([pos_x], dtype='float32'),
               np.array([pos_y], dtype='float32'),
               np.array([tar_x], dtype='float32'),
               np.array([tar_y], dtype='float32')]
        rew = -np.linalg.norm(np.array([pos_x, pos_y], dtype=np.float32) - self.target)
        terminated = rew > -0.01
        info = {}
        return obs, rew, terminated, info

    def wait(self):
        pass

    def render(self):
        image = self.blank_image.copy()
        pos_x, pos_y = self.rc_drone.get_observation()
        image = cv2.circle(img=image,
                           center=(int(pos_x * 200) + 250, int(pos_y * 200) + 250),
                           radius=10,
                           color=(255, 0, 0),
                           thickness=1)
        image = cv2.circle(img=image,
                           center=(int(self.target[0] * 200) + 250, int(self.target[1] * 200) + 250),
                           radius=5,
                           color=(0, 0, 255),
                           thickness=-1)
        cv2.imshow("Dummy RC drone", image)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            return
# rtgym configuration dictionary:
my_config = DEFAULT_CONFIG_DICT.copy()
my_config["interface"] = DummyRCDroneInterface
my_config["time_step_duration"] = 0.05
my_config["start_obs_capture"] = 0.05
my_config["time_step_timeout_factor"] = 1.0
my_config["ep_max_length"] = 100
my_config["act_buf_len"] = 4
my_config["reset_act_buf"] = False
my_config["benchmark"] = True
my_config["benchmark_polyak"] = 0.2
Now that we have our robot encapsulated in a Gymnasium environment, we will create an RL actor.
In tmrl
, this is done within a RolloutWorker
object.
One to several RolloutWorkers
can coexist in tmrl
, each one typically encapsulating a robot, or, in the case of a video game, an instance of the game
(each RolloutWorker
possibly running on a separate computer).
The prototype of the RolloutWorker
class is:
import tmrl.config.config_constants as cfg # constants from the config.json file
class RolloutWorker:
    def __init__(
            self,
            env_cls=None,  # class of the Gymnasium environment
            actor_module_cls=None,  # class of a module containing the policy
            sample_compressor: callable = None,  # compressor for sending samples over the Internet
            device="cpu",  # device on which the policy runs for inference
            server_ip=None,  # ip of the central server
            server_port=cfg.PORT,  # port of the server
            password=cfg.PASSWORD,  # password of the server
            max_samples_per_episode=np.inf,  # if an episode gets longer than this, it is reset
            model_path=cfg.MODEL_PATH_WORKER,  # path where a local copy of the policy will be stored
            obs_preprocessor: callable = None,  # utility for modifying observations returned by the environment
            crc_debug=False,  # can be used for debugging the pipeline
            model_path_history=cfg.MODEL_PATH_SAVE_HISTORY,  # a history of policies can be stored here
            model_history=cfg.MODEL_HISTORY,  # new policies are saved every model_history updates (0: not saved)
            standalone=False,  # if True, the worker will not try to connect to a server
            verbose=True,  # if True, the worker will log messages to the console
    ):
        # (...)
For example, the default RolloutWorker
implemented for TrackMania is instantiated here.
In this tutorial, we will implement a similar RolloutWorker
for our dummy drone environment.
The first argument of our RolloutWorker
is env_cls
.
This expects a Gymnasium environment class, which can be partially instantiated with partial()
.
Furthermore, this Gymnasium environment needs to be wrapped in the GenericGymEnv
wrapper.
With our dummy drone environment, this translates to:
from tmrl.util import partial
from tmrl.envs import GenericGymEnv
env_cls=partial(GenericGymEnv, id="real-time-gym-ts-v1", gym_kwargs={"config": my_config})
We can create a dummy environment to retrieve the action and observation spaces:
dummy_env = env_cls()
act_space = dummy_env.action_space
obs_space = dummy_env.observation_space
print(f"action space: {act_space}")
print(f"observation space: {obs_space}")
which outputs the following:
action space: Box([-2. -2.], [2. 2.], (2,), float32)
observation space: Tuple(Box([-1.], [1.], (1,), float32),
Box([-1.], [1.], (1,), float32),
Box([-0.5], [0.5], (1,), float32),
Box([-0.5], [0.5], (1,), float32),
Box([-2. -2.], [2. 2.], (2,), float32),
Box([-2. -2.], [2. 2.], (2,), float32),
Box([-2. -2.], [2. 2.], (2,), float32),
Box([-2. -2.], [2. 2.], (2,), float32))
Our dummy drone environment has a simple action space of two floats (velocities on x and y). Its observation space is a bit more complex: 4 floats representing the position (2 values) and the target position (2 values), and 4 arrays representing the 4 last actions (4 times 2 values). This history of actions is required to make the observation space Markov because the dummy RC drone has random communication delays.
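As a quick, optional sanity check, we can verify that these dimensions add up (this snippet only reuses the obs_space retrieved above and plain numpy; it is not needed by the pipeline):
# 4 scalar components + 4 buffered actions of 2 values each = 12 values once flattened
flat_dim = int(sum(np.prod(space.shape) for space in obs_space))
print(flat_dim)  # 12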
The second argument is actor_module_cls
.
This expects a class that implements the ActorModule interface.
When using pytorch, we can conveniently use TorchActorModule
instead, which partially implements this interface.
TorchActorModule
is a pytorch neural network (i.e., a subclass of torch.nn.Module
) that implements an extra act()
method on top of the usual forward()
method.
The neural network is what will be trained by the Trainer (our policy), while the act()
method is for the RolloutWorker
to interact with this policy.
On top of the act()
method, subclasses of ActorModule
must implement a __init__()
method that takes at least two arguments: observation_space
and action_space
.
This enables you to implement generic models as we will do now.
Let us implement this module for our dummy drone environment.
Here, we basically copy-paste the implementation of the SAC MLP actor from OpenAI Spinup and adapt it to the TorchActorModule
interface:
from tmrl.actor import TorchActorModule
from tmrl.util import prod
import torch
import torch.nn.functional as F
LOG_STD_MAX = 2
LOG_STD_MIN = -20
def mlp(sizes, activation, output_activation=torch.nn.Identity):
    layers = []
    for j in range(len(sizes) - 1):
        act = activation if j < len(sizes) - 2 else output_activation
        layers += [torch.nn.Linear(sizes[j], sizes[j + 1]), act()]
    return torch.nn.Sequential(*layers)


class MyActorModule(TorchActorModule):
    """
    Directly adapted from the Spinup implementation of SAC
    """
    def __init__(self, observation_space, action_space, hidden_sizes=(256, 256), activation=torch.nn.ReLU):
        super().__init__(observation_space, action_space)
        dim_obs = sum(prod(s for s in space.shape) for space in observation_space)
        dim_act = action_space.shape[0]
        act_limit = action_space.high[0]
        self.net = mlp([dim_obs] + list(hidden_sizes), activation, activation)
        self.mu_layer = torch.nn.Linear(hidden_sizes[-1], dim_act)
        self.log_std_layer = torch.nn.Linear(hidden_sizes[-1], dim_act)
        self.act_limit = act_limit

    def forward(self, obs, test=False, with_logprob=True):
        net_out = self.net(torch.cat(obs, -1))
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)
        pi_distribution = torch.distributions.normal.Normal(mu, std)
        if test:
            pi_action = mu
        else:
            pi_action = pi_distribution.rsample()
        if with_logprob:
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(axis=1)
        else:
            logp_pi = None
        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action
        pi_action = pi_action.squeeze()
        return pi_action, logp_pi

    def act(self, obs, test=False):
        with torch.no_grad():
            a, _ = self.forward(obs, test, False)
            return a.cpu().numpy()
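Before wiring this into the RolloutWorker, we can optionally check that the module accepts observations shaped like ours by feeding it a fake batched observation (this little test is not part of the pipeline):
# optional sanity check of MyActorModule with a fake batch of one observation:
test_module = MyActorModule(obs_space, act_space)
fake_obs = tuple(torch.randn(1, *space.shape) for space in obs_space)  # 4 tensors of shape (1, 1) + 4 of shape (1, 2)
with torch.no_grad():
    fake_act, _ = test_module(fake_obs, test=True, with_logprob=False)
print(fake_act.shape)  # torch.Size([2])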
Now, the actor module can be partially instantiated:
actor_module_cls = partial(MyActorModule)  # we could add parameters such as hidden_sizes here if needed
One reason for using tmrl
is that it supports implementing ad-hoc pipelines for your applications.
In particular, say you have a robot that uses CNNs to process a history of 4 concatenated images.
You certainly do not want to send all 4 images over the Internet for each sample: the content of consecutive samples would largely overlap, and a more clever pipeline could use roughly 4 times less bandwidth.
For our dummy RC drone, we have a similar (yet much less serious) issue with the action buffer that is part of the observations.
This buffer contains a history of the last 4 sent actions, and thus 3 of them overlap with the content of previous samples.
Moreover, actions are part of samples anyway because a sample is defined as (act, obs, rew, terminated, truncated, info)
.
Although this is not really an issue given the small size of actions here, let us implement an optimal pipeline for the sake of illustration.
The sample_compressor
argument enables implementing custom compression schemes for your applications.
If left to None
, no compression will happen and raw samples will be sent over the network.
For our dummy RC drone, an optimal compression scheme just removes the action buffer from observations:
def my_sample_compressor(act, obs, rew, terminated, truncated, info):
    """
    Compresses samples before sending them over the network.

    This function creates the sample that will actually be stored in local buffers for networking.
    This is to compress the sample before sending it over the Internet/local network.
    Buffers of compressed samples will be given as input to the append() method of the memory.
    When you implement a compressor, you also need to implement a decompressor in the memory.

    Args:
        act: action computed from a previous observation and applied to yield obs in the transition
        obs, rew, terminated, truncated, info: outcome of the transition
    Returns:
        act_mod: compressed act
        obs_mod: compressed obs
        rew_mod: compressed rew
        terminated_mod: compressed terminated
        truncated_mod: compressed truncated
        info_mod: compressed info
    """
    act_mod, obs_mod, rew_mod, terminated_mod, truncated_mod, info_mod = act, obs, rew, terminated, truncated, info
    obs_mod = obs_mod[:4]  # here we remove the action buffer from observations
    return act_mod, obs_mod, rew_mod, terminated_mod, truncated_mod, info_mod
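To make the effect of this compressor concrete, here is an optional check reusing the spaces retrieved from the dummy environment earlier (the reward, flags, and info passed here are arbitrary placeholders):
example_obs = obs_space.sample()  # a tuple of 8 arrays (4 scalar components + 4 buffered actions)
_, compressed_obs, *_ = my_sample_compressor(act_space.sample(), example_obs, 0.0, False, False, {})
print(len(example_obs), len(compressed_obs))  # 8 4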
We can then pass our sample compressor to the RolloutWorker
as:
sample_compressor = my_sample_compressor
The device
argument tells whether inference on the RolloutWorker
must run on CPU or GPU.
The default is "cpu"
, but if you have a large model that requires a GPU for inference (e.g., for image processing), you can set this to another device such as "cuda:0"
.
device = "cpu"
RolloutWorkers
connect to a central Server
to which they periodically send buffers of samples, while the Server
periodically broadcasts updated weights for the RolloutWorkers
.
RolloutWorkers
behave as Internet clients, and must therefore know the IP address of the Server
to be able to communicate.
Typically, the Server
lives on a machine behind your router to which you can forward ports.
Nevertheless, it is of course possible to work locally by hosting the Server
, RolloutWorkers
, and Trainer
on localhost.
This is done by setting the Server
IP as the localhost IP, i.e., "127.0.0.1"
, which we did.
(NB: We have set the values for server_ip
and server_port
earlier in this tutorial.)
In case your Gymnasium environment is never terminated
(or only after too long), tmrl
has an option for forcing a reset after a threshold number of time steps.
For instance, let us say we don't want an episode to last more than 1000 time-steps:
(Note 1: this is only for the sake of illustration; it cannot actually happen in our RC drone environment, whose rtgym configuration caps episodes at 100 time steps)
(Note 2: If the episode is stopped because of this threshold, the terminated
signal will be False
and the truncated
signal will be True
)
max_samples_per_episode = 1000
model_path
refers to the path where the RolloutWorker
will locally save its current weights.
Furthermore, if weights are already present at this path, they will be loaded on the RolloutWorker
instantiation
(this acts as a saving mechanism).
model_path_history
refers to the path where the RolloutWorker
will locally save a history of its weights during training if you set model_history > 0
.
CAUTION: model_path
and model_path_history
are weird and will probably change in future versions.
At the moment, we recommend not setting these parameters and changing the value of the "RUN_NAME"
entry in the config.json
file instead (weights will then be saved and loaded from the weights
folder).
However, if you do not want to modify the config.json
file, you can use these kwargs as follows:
import tmrl.config.config_constants as cfg
my_run_name = "tutorial"
weights_folder = cfg.WEIGHTS_FOLDER # path to the weights folder
model_path = str(weights_folder / (my_run_name + ".tmod"))
model_path_history = str(weights_folder / (my_run_name + "_"))
model_history
can be set to 0 if you do not wish to save the weight history, and to a positive value otherwise.
When model_history > 0
, incoming models from the Server
will be saved each time model_history
models have been received since the last saved model (e.g., 1 saves every model, 2 saves every other model, etc.).
For instance, let us say we want to save a checkpoint of our policy for every 10 new updated policies:
model_history = 10
(Note: this will save all these policies in model_path_history
, if you want to avoid this, set model_history = 0
instead)
A few more parameters are configurable, although they will not be useful in this tutorial. In particular:
obs_preprocessor
can be used to modify observations returned by the environment (this enables, e.g., converting RGB images into grayscale without modifying the environment).
Some examples of such preprocessors are available here.
standalone
can be set to True
for deployment, in which case the RolloutWorker
will not attempt to connect to the Server
.
Finally, crc_debug
is a useful tool for debugging your pipeline.
We will see how to use it at the end of this tutorial; you can ignore it for now.
Now we can instantiate a RolloutWorker
:
from tmrl.networking import RolloutWorker
my_worker = RolloutWorker(
    env_cls=env_cls,
    actor_module_cls=actor_module_cls,
    sample_compressor=sample_compressor,
    device=device,
    server_ip=server_ip,
    server_port=server_port,
    password=password,
    max_samples_per_episode=max_samples_per_episode,
    model_path=model_path,
    model_path_history=model_path_history,
    model_history=model_history)
This connects to the Server
, but does not start collecting experiences.
If we want to start collecting experiences, we can use the run()
method:
my_worker.run(test_episode_interval=10)
This will collect training samples and run a test episode every 10 training episodes.
Test episodes are not used as training samples, and call the act()
method of your ActorModule
with test=True
.
Note that this function runs forever and will block your script there if you don't call run()
within a new python thread.
To stop the script, you will need to press CTRL + C
.
For the moment, let us just comment this line:
# my_worker.run(test_episode_interval=10)
Note: The run()
method collects episodes continuously, as fast as allowed by the environment.
In certain cases (e.g., in very-high-frequency or non-real-time environments), this is not what you want to achieve.
Instead, you may prefer lower-level API functions.
For instance, you could do something like this:
# Collect 10 episodes, then wait for the policy to get updated before collecting more.
# (NB: in rtgym, you can tell your robot to "wait" when an episode is complete)
while True:
    episode = 0
    while episode < 10:
        my_worker.collect_train_episode(max_samples_per_episode)
        my_worker.send_and_clear_buffer()
        episode += 1
    # Block until new weights are received from the Server
    my_worker.update_actor_weights(verbose=True, blocking=True)
    # Run a test episode
    my_worker.run_episode(max_samples_per_episode, train=False)
In tmrl
, RL training per-se happens in the Trainer
entity.
The Trainer
connects to the Server
, from which it receives compressed samples gathered from connected RolloutWorkers
.
These samples are stored (possibly in compressed format) in a Memory
object.
They are decompressed either when stored, or when sampled from the Memory
, depending on the user choice.
The decompressed samples are then used by the TrainingAgent
object to optimize the policy weights, which the Trainer
periodically sends back to the Server
so they are broadcast to all connected RolloutWorkers
.
The prototype of the Trainer
class is:
import tmrl.config.config_constants as cfg
import tmrl.config.config_objects as cfg_obj
class Trainer:
    def __init__(self,
                 training_cls=cfg_obj.TRAINER,
                 server_ip=cfg.SERVER_IP_FOR_TRAINER,
                 server_port=cfg.PORT,
                 password=cfg.PASSWORD,
                 model_path=cfg.MODEL_PATH_TRAINER,
                 checkpoint_path=cfg.CHECKPOINT_PATH,
                 dump_run_instance_fn: callable = None,
                 load_run_instance_fn: callable = None):
server_ip
is the public IP address of the Server
.
Since both the Trainer
and RolloutWorker
will run on the same machine as the Server
in this tutorial, the server_ip
will also be localhost here, i.e., "127.0.0.1"
.
The server_port
and the password
are still valid for our Trainer
.
model_path
is similar to the one of the RolloutWorker
.
The Trainer will keep a local copy of its model at this path, which acts as a save file.
checkpoints_path
is similar, but this will save the entire training_cls
instance (including the Memory
).
If set to None
, training will not be checkpointed.
You could leave both paths to their default value and simply change the value of the "RUN_NAME"
entry in config.json
instead.
But again, if you do not wish to use "config.json"
, you can set these arguments as follows:
CAUTION: do not set the exact same paths as the ones used by the RolloutWorker
when running on the same machine (here, we append _t to differentiate the two).
import tmrl.config.config_constants as cfg
weights_folder = cfg.WEIGHTS_FOLDER # path to the weights folder
checkpoints_folder = cfg.CHECKPOINTS_FOLDER
my_run_name = "tutorial"
model_path = str(weights_folder / (my_run_name + "_t.tmod"))
checkpoints_path = str(checkpoints_folder / (my_run_name + "_t.tcpt"))
dump_run_instance_fn
and load_run_instance_fn
are for advanced serialization when your Trainer
content cannot be pickled.
You can ignore these for now.
Now, the real beast is the training_cls
argument.
This expects a training class, possibly partially initialized.
At the moment, tmrl
supports one training class called TrainingOffline.
This class is meant for off-policy asynchronous RL algorithms such as SAC.
With a few synchronization tricks, it can also be used for online algorithms such as PPO
(see the note at the end of the "Rollout workers" section).
When using PyTorch, the TorchTrainingOffline
subclass further enables automatic selection of an available device.
The TorchTrainingOffline
prototype is:
@dataclass(eq=0)
class TorchTrainingOffline:
    env_cls: type = GenericGymEnv  # dummy environment, used only to retrieve observation / action spaces
    memory_cls: type = TorchMemory  # replay memory
    training_agent_cls: type = TrainingAgent  # training agent
    epochs: int = 10  # total number of epochs, we save the agent every epoch
    rounds: int = 50  # number of rounds per epoch, we generate statistics every round
    steps: int = 2000  # number of training steps per round
    update_model_interval: int = 100  # number of training steps between model broadcasts
    update_buffer_interval: int = 100  # number of training steps between retrieving buffered samples
    max_training_steps_per_env_step: float = 1.0  # training will pause when above this ratio
    sleep_between_buffer_retrieval_attempts: float = 0.1  # when waiting for needed incoming samples
    profiling: bool = False  # if True, run_epoch will be profiled at the end of each epoch
    agent_scheduler: callable = None  # if not None, must be of the form f(agent: TrainingAgent, epoch: int)
    start_training: int = 0  # minimum number of samples in the replay buffer before starting training
    device: str = None  # device on which the model of the TrainingAgent will live (None for automatic)
(Note: be careful when pairing max_training_steps_per_env_step
with a similar constraint in your RolloutWorker(s)
: the Trainer
counts calls to RolloutWorker.reset()
as an "env_step")
TorchTrainingOffline
requires other (possibly partially instantiated) classes as arguments: a dummy environment, a TorchMemory
, and a TrainingAgent.
env_cls
: Most of the time, the dummy environment class that you need to pass here is the same class as for the RolloutWorker
Gymnasium environment:
from tmrl.util import partial
from tmrl.envs import GenericGymEnv
env_cls = partial(GenericGymEnv, id="real-time-gym-ts-v1", gym_kwargs={"config": my_config})
This dummy environment will only be used by the Trainer
to retrieve the observation and action spaces (reset()
will not be called).
Alternatively, you can pass this information as a tuple:
env_cls = (observation_space, action_space)
memory_cls
is the class of your replay buffer.
This must be a subclass of TorchMemory
.
The role of a TorchMemory
object is to store and decompress samples received by the Trainer
from the Server
.
In a tmrl
pipeline, the Memory
class is typically the most difficult (but also the most interesting) class to implement.
This is because, in conjunction with sample compressors, Memory
is where you can really optimize your pipeline for your specific application, and tmrl
gives you full latitude for doing this.
For people who wish to quickly test things rather than implement optimized pipelines, tmrl
also provides a ready-made GenericTorchMemory
that works in all cases precisely because it performs no application-specific optimization (see the tuto_minimal_drone.py
script for an example of usage).
In this tutorial, we will instead implement a very optimized Memory
for the sake of illustration.
While these optimizations are clearly overkill for our toy RC drone example, they are tremendously important for vision-based applications like the default tmrl
pipeline for TrackMania.
TorchMemory
has the following interface:
class TorchMemory(ABC):
    def __init__(self,
                 device,  # output tensors will be collated to this device
                 nb_steps,  # number of steps per round
                 sample_preprocessor: callable = None,  # can be used for data augmentation
                 memory_size=1000000,  # size of the circular buffer
                 batch_size=256,  # batch size of the output tensors
                 dataset_path="",  # an offline dataset may be provided here to initialize the memory
                 ...)  # unsupported stuff

    @abstractmethod
    def append_buffer(self, buffer):
        """
        Appends a buffer of samples to the memory
        """
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        """
        Returns:
            memory_length: int: maximum value (+1) that item can take in get_transition()
        """
        raise NotImplementedError

    @abstractmethod
    def get_transition(self, item):
        """
        Outputs a decompressed RL transition.

        This transition is the same as the output by the Gymnasium environment (after observation preprocessing).

        Args:
            item: int: indices of the transition that the Trainer wants to sample
        Returns:
            full transition: (last_obs, new_act, rew, new_obs, terminated, truncated, info)
        """
        raise NotImplementedError
You do not need to worry about device
and nb_steps
, as they will be set automatically by the Trainer
and are for the superclass.
sample_preprocessor
can be used if you wish to implement data augmentation before the samples are used for training.
We will not do this in the tutorial, but you can find a no-op example here (for syntax).
This argument is also only to be passed to the superclass.
memory_size
is the maximum number of transitions that can be contained in your TorchMemory
object.
When this size is exceeded, you will want to trim your memory in the append_buffer()
method.
The implementation of this trimming is left to your discretion.
Pass this to the superclass.
batch_size
is the size of the batches of tensors that the Trainer
will collate together.
In the current iteration of tmrl
, the Trainer
will call your get_transition()
method repeatedly with random item
values to retrieve samples one by one and will collate these samples together to form a batch.
Just pass to the superclass.
dataset_path
enables the user to initialize the memory with an offline dataset.
If used, this should point to a pickled file.
This file will be unpickled and put in self.data
on instantiation.
Otherwise, self.data
will be initialized with an empty list.
We will not be using this option in this tutorial, though.
Again, just pass to the superclass.
Let us implement our own TorchMemory
.
import random
from tmrl.memory import TorchMemory
class MyMemory(TorchMemory):
    # (...)
You can do whatever you want in the __init__()
method as long as you initialize the superclass with its relevant arguments.
Our compression scheme removed the action buffer from observations, so we will need to rebuild it here.
Thus, we will use the action buffer length as an additional argument to our custom class:
    def __init__(self,
                 device=None,
                 nb_steps=None,
                 sample_preprocessor: callable = None,
                 memory_size=1000000,
                 batch_size=32,
                 dataset_path="",
                 act_buf_len=my_config["act_buf_len"]):
        self.act_buf_len = act_buf_len  # length of the action buffer
        super().__init__(device=device,
                         nb_steps=nb_steps,
                         sample_preprocessor=sample_preprocessor,
                         memory_size=memory_size,
                         batch_size=batch_size,
                         dataset_path=dataset_path)
In fact, the TorchMemory
class leaves the entire storing and sampling procedures to your discretion.
This is because, when using tmrl
, you may want to do exotic things such as storing samples on your hard drive (if they contain images for instance).
If you have implemented a sample compressor for the RolloutWorker
(as we have done earlier in this tutorial), you will also need to implement a decompression scheme.
This decompression may happen either in append_buffer()
(if you privilege sampling speed) or in get_transition()
(if you privilege memory usage).
In this tutorial, we will privilege memory usage and thus we will implement our decompression scheme in get_transition()
.
The append_buffer()
method will simply store the compressed sample components in self.data
.
append_buffer()
is passed a buffer object that contains a list of compressed (act, new_obs, rew, terminated, truncated, info)
samples in its memory
attribute.
act
is the action that was sent to the step()
method of the Gymnasium environment to yield new_obs
, rew
, terminated
, truncated
, and info
.
Here, we decompose our samples into their relevant components, we append these components to the self.data
list, and we truncate self.data
when self.memory_size
is exceeded:
    def append_buffer(self, buffer):
        """
        buffer.memory is a list of compressed (act_mod, new_obs_mod, rew_mod, terminated_mod, truncated_mod, info_mod) samples
        """
        # decompose compressed samples into their relevant components:
        list_action = [b[0] for b in buffer.memory]
        list_x_position = [b[1][0] for b in buffer.memory]
        list_y_position = [b[1][1] for b in buffer.memory]
        list_x_target = [b[1][2] for b in buffer.memory]
        list_y_target = [b[1][3] for b in buffer.memory]
        list_reward = [b[2] for b in buffer.memory]
        list_terminated = [b[3] for b in buffer.memory]
        list_truncated = [b[4] for b in buffer.memory]
        list_info = [b[5] for b in buffer.memory]
        list_done = [b[3] or b[4] for b in buffer.memory]

        # append to self.data in some arbitrary way:
        if self.__len__() > 0:
            self.data[0] += list_action
            self.data[1] += list_x_position
            self.data[2] += list_y_position
            self.data[3] += list_x_target
            self.data[4] += list_y_target
            self.data[5] += list_reward
            self.data[6] += list_terminated
            self.data[7] += list_info
            self.data[8] += list_truncated
            self.data[9] += list_done
        else:
            self.data.append(list_action)
            self.data.append(list_x_position)
            self.data.append(list_y_position)
            self.data.append(list_x_target)
            self.data.append(list_y_target)
            self.data.append(list_reward)
            self.data.append(list_terminated)
            self.data.append(list_info)
            self.data.append(list_truncated)
            self.data.append(list_done)

        # trim self.data in some arbitrary way when self.__len__() > self.memory_size:
        to_trim = self.__len__() - self.memory_size
        if to_trim > 0:
            self.data[0] = self.data[0][to_trim:]
            self.data[1] = self.data[1][to_trim:]
            self.data[2] = self.data[2][to_trim:]
            self.data[3] = self.data[3][to_trim:]
            self.data[4] = self.data[4][to_trim:]
            self.data[5] = self.data[5][to_trim:]
            self.data[6] = self.data[6][to_trim:]
            self.data[7] = self.data[7][to_trim:]
            self.data[8] = self.data[8][to_trim:]
            self.data[9] = self.data[9][to_trim:]
We must also implement the __len__()
method of our memory because the content of self.data
is arbitrary and the Trainer
needs to know which items it can request from the get_transition()
method:
    def __len__(self):
        if len(self.data) == 0:
            return 0  # self.data is empty
        result = len(self.data[0]) - self.act_buf_len - 1
        if result < 0:
            return 0  # not enough samples to reconstruct the action buffer
        else:
            return result  # we can reconstruct that many samples
Now, this is becoming interesting: why is the __len__()
method so complicated?
self.data
is initially an empty list, so when its len
is 0
, our memory is empty.
But when it is not empty and we have fewer samples than the length of our action buffer, we cannot reconstruct the action buffer.
Thus our memory is still empty.
Finally, if we have enough samples, we need to remove the length of the action buffer to get the number of samples we can actually reconstruct.
Furthermore, the get_transition()
method outputs a full RL transition, which includes the previous observation.
Thus, we must subtract 1 to get the number of full transitions that we can actually output.
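As a concrete illustration (using the act_buf_len = 4 of this tutorial):
# With self.act_buf_len == 4:
#   len(self.data[0]) == 0   -> __len__() == 0  (memory is empty)
#   len(self.data[0]) == 5   -> 5 - 4 - 1 == 0  (cannot rebuild a single full transition yet)
#   len(self.data[0]) == 6   -> 6 - 4 - 1 == 1  (one full transition can be rebuilt)
#   len(self.data[0]) == 105 -> 105 - 4 - 1 == 100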
Alright, let us finally implement get_transition()
, where we have chosen to perform sample decompression.
This method outputs full transitions as if they were output by the Gymnasium environment
(after observation preprocessing if used):
    def get_transition(self, item):
        """
        Args:
            item: int: indices of the transition that the Trainer wants to sample
        Returns:
            full transition: (last_obs, new_act, rew, new_obs, terminated, truncated, info)
        """
        # if item corresponds to a transition from a terminal state to a reset state
        if self.data[9][item + self.act_buf_len - 1]:
            # this wouldn't make sense in RL, so we replace item by a neighbour transition
            if item == 0:  # if first item of the buffer
                item += 1
            elif item == self.__len__() - 1:  # if last item of the buffer
                item -= 1
            elif random.random() < 0.5:  # otherwise, sample randomly
                item += 1
            else:
                item -= 1

        idx_last = item + self.act_buf_len - 1  # index of previous observation
        idx_now = item + self.act_buf_len  # index of new observation

        # rebuild the action buffer of both observations:
        actions = self.data[0][item:(item + self.act_buf_len + 1)]
        last_act_buf = actions[:-1]  # action buffer of previous observation
        new_act_buf = actions[1:]  # action buffer of new observation

        # rebuild the previous observation:
        last_obs = (self.data[1][idx_last],  # x position
                    self.data[2][idx_last],  # y position
                    self.data[3][idx_last],  # x target
                    self.data[4][idx_last],  # y target
                    *last_act_buf)  # action buffer

        # rebuild the new observation:
        new_obs = (self.data[1][idx_now],  # x position
                   self.data[2][idx_now],  # y position
                   self.data[3][idx_now],  # x target
                   self.data[4][idx_now],  # y target
                   *new_act_buf)  # action buffer

        # other components of the transition:
        new_act = self.data[0][idx_now]  # action
        rew = np.float32(self.data[5][idx_now])  # reward
        terminated = self.data[6][idx_now]  # terminated signal
        truncated = self.data[8][idx_now]  # truncated signal
        info = self.data[7][idx_now]  # info dictionary

        return last_obs, new_act, rew, new_obs, terminated, truncated, info
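To make the indexing concrete, here is a small worked example with the act_buf_len = 4 used in this tutorial (purely illustrative):
# With self.act_buf_len == 4 and item == 10:
#   idx_last = 10 + 4 - 1 = 13  -> scalar components of the previous observation
#   idx_now  = 10 + 4     = 14  -> scalar components of the new observation
#   actions  = self.data[0][10:15]  (5 consecutive actions)
#   last_act_buf = actions[:-1] = self.data[0][10:14]
#   new_act_buf  = actions[1:]  = self.data[0][11:15]  (its last element is new_act = self.data[0][14])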
Note 1: the action buffer of new_obs
contains new_act
.
This is because at least the last computed action (new_act
) must be in the action buffer to keep a Markov state in a real-time environment. See rtgym.
Note 2: in our dummy RC drone environment, the action buffer is not reset on calls to reset()
and we do not call the special wait()
method when an episode ends.
Doing this, we keep the real-time flow of operations around reset transitions.
Nevertheless, reset transitions are discontinuities in the Gymnasium
standard: there is no action between a terminal state and the next state returned by reset()
.
This means that one action is a dummy action around reset transitions here.
More specifically, in the case of an rtgym
environment, the very last action of a trajectory is never sent to the robot.
From an RL perspective, this action is still the proper last action of the trajectory (although it has no effect).
However, because it is never applied, it is not supposed to go into the action buffer of post-reset transitions.
If you are using an rtgym
environment (as we do in this tutorial), tmrl
will set the default action of your environment as the action corresponding to the observation returned by reset()
in the buffer
argument of your append_buffer()
method.
This is because this default action will be applied right after reset()
is called, and it should therefore be part of the action buffer.
If you are instead using a vanilla Gymnasium
environment, tmrl
will set this action to None
, as you should ignore it anyway in non-real-time environments.
Because this is an intricate effect specific to real-time environments, we will ignore it for now and come back to it at the very end of this tutorial.
Just bear in mind that, if you want your pipeline to be 100% correct in real-time scenarios, you need to take special care for rebuilding transitions that happened after a terminated
or truncated
signal is set to True
.
This is done in the tmrl
implementation of TorchMemory for TrackMania.
We now have our memory_cls
argument:
memory_cls = partial(MyMemory,
                     act_buf_len=my_config["act_buf_len"])
The training_agent_cls
expects an implementation of the TrainingAgent
abstract class.
TrainingAgent
is where you can implement your actual RL training algorithm.
The interface of TrainingAgent
is:
class TrainingAgent(ABC):
    def __init__(self,
                 observation_space,
                 action_space,
                 device):
        """
        observation_space, action_space, and device are here for your convenience.

        You are free to use them or not, but your subclass must have them as args or kwargs of __init__().
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.device = device

    @abstractmethod
    def train(self, batch):
        """
        Executes a training step.

        Args:
            batch: tuple of batched torch.tensors
                (previous observation, action, reward, new observation, terminated, truncated)
        Returns:
            ret_dict: dictionary: a dictionary containing one entry per metric you wish to log
                (e.g. for wandb)
        """
        raise NotImplementedError

    @abstractmethod
    def get_actor(self):
        """
        Returns the current ActorModule to be broadcast to the RolloutWorkers.

        Returns:
            actor: ActorModule: current actor to be broadcast
        """
        raise NotImplementedError
This interface has an __init__()
method that is mostly here to remind you that your implementation needs to take at least observation_space
, action_space
, and device
as arguments.
These are for you to use in your implementation.
device
is the device that your algorithm is supposed to use for training and where the batch lives (e.g. "cpu"
or "cuda:0"
), while observation_space
and action_space
are mandatory input to the ActorModule
class (although you don't have to use them: they are simply here for convenience).
In this tutorial, we will be implementing Soft Actor-Critic (SAC) since we have already built a SAC-compatible policy as the ActorModule
of our RolloutWorker
.
First, let us implement a critic module (we already have our actor from the ActorModule section):
class MyCriticModule(torch.nn.Module):
    def __init__(self, observation_space, action_space, hidden_sizes=(256, 256), activation=torch.nn.ReLU):
        super().__init__()
        obs_dim = sum(prod(s for s in space.shape) for space in observation_space)
        act_dim = action_space.shape[0]
        self.q = mlp([obs_dim + act_dim] + list(hidden_sizes) + [1], activation)

    def forward(self, obs, act):
        x = torch.cat((*obs, act), -1)
        q = self.q(x)
        return torch.squeeze(q, -1)


class MyActorCriticModule(torch.nn.Module):
    def __init__(self, observation_space, action_space, hidden_sizes=(256, 256), activation=torch.nn.ReLU):
        super().__init__()
        # our ActorModule:
        self.actor = MyActorModule(observation_space, action_space, hidden_sizes, activation)
        # double Q networks:
        self.q1 = MyCriticModule(observation_space, action_space, hidden_sizes, activation)
        self.q2 = MyCriticModule(observation_space, action_space, hidden_sizes, activation)
Our custom TrainingAgent
subclass must take the aforementioned args/kwargs, and can take any user-defined additional kwargs.
Again, here, we simply adapt the SAC implementation from Spinup, but of course you can implement whatever you want instead:
from tmrl.training import TrainingAgent
from tmrl.custom.utils.nn import copy_shared, no_grad
from tmrl.util import cached_property
from torch.optim import Adam
from copy import deepcopy
import itertools
class MyTrainingAgent(TrainingAgent):
    model_nograd = cached_property(lambda self: no_grad(copy_shared(self.model)))

    def __init__(self,
                 observation_space=None,
                 action_space=None,
                 device=None,
                 model_cls=MyActorCriticModule,  # an actor-critic module, encapsulating our ActorModule
                 gamma=0.99,  # discount factor
                 polyak=0.995,  # exponential averaging factor for the target critic
                 alpha=0.2,  # fixed (SAC v1) or initial (SAC v2) value of the entropy coefficient
                 lr_actor=1e-3,  # learning rate for the actor
                 lr_critic=1e-3,  # learning rate for the critic
                 lr_entropy=1e-3,  # entropy autotuning coefficient (SAC v2)
                 learn_entropy_coef=True,  # if True, SAC v2 is used, else, SAC v1 is used
                 target_entropy=None):  # if None, the target entropy for SAC v2 is set automatically
        super().__init__(observation_space=observation_space,
                         action_space=action_space,
                         device=device)
        model = model_cls(observation_space, action_space)
        self.model = model.to(device)
        self.model_target = no_grad(deepcopy(self.model))
        self.gamma = gamma
        self.polyak = polyak
        self.alpha = alpha
        self.lr_actor = lr_actor
        self.lr_critic = lr_critic
        self.lr_entropy = lr_entropy
        self.learn_entropy_coef = learn_entropy_coef
        self.target_entropy = target_entropy
        self.q_params = itertools.chain(self.model.q1.parameters(), self.model.q2.parameters())
        self.pi_optimizer = Adam(self.model.actor.parameters(), lr=self.lr_actor)
        self.q_optimizer = Adam(self.q_params, lr=self.lr_critic)
        if self.target_entropy is None:
            self.target_entropy = -np.prod(action_space.shape).astype(np.float32)
        else:
            self.target_entropy = float(self.target_entropy)
        if self.learn_entropy_coef:
            self.log_alpha = torch.log(torch.ones(1, device=self.device) * self.alpha).requires_grad_(True)
            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=self.lr_entropy)
        else:
            self.alpha_t = torch.tensor(float(self.alpha)).to(self.device)
The get_actor()
method outputs the ActorModule
to be broadcast to the RolloutWorkers
:
    def get_actor(self):
        return self.model_nograd.actor
And finally, for the training algorithm itself, we simply adapt the SAC Spinup implementation to the train()
signature.
Note that train()
returns a python dictionary in which you can store the metrics you wish to be logged automatically on wandb
:
    def train(self, batch):
        """
        Adapted from the SAC implementation of OpenAI Spinup
        https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/sac
        """
        o, a, r, o2, d, _ = batch  # these tensors are collated on device
        # note that we purposefully ignore the truncated signal ( _ )
        # thus, our value estimator will not be affected by episode truncation

        pi, logp_pi = self.model.actor(o)

        # entropy coefficient loss (SAC v2 only):
        loss_alpha = None
        if self.learn_entropy_coef:
            alpha_t = torch.exp(self.log_alpha.detach())
            loss_alpha = -(self.log_alpha * (logp_pi + self.target_entropy).detach()).mean()
        else:
            alpha_t = self.alpha_t
        if loss_alpha is not None:
            self.alpha_optimizer.zero_grad()
            loss_alpha.backward()
            self.alpha_optimizer.step()

        # critic loss:
        q1 = self.model.q1(o, a)
        q2 = self.model.q2(o, a)
        with torch.no_grad():
            a2, logp_a2 = self.model.actor(o2)
            q1_pi_targ = self.model_target.q1(o2, a2)
            q2_pi_targ = self.model_target.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + self.gamma * (1 - d) * (q_pi_targ - alpha_t * logp_a2)
        loss_q1 = ((q1 - backup)**2).mean()
        loss_q2 = ((q2 - backup)**2).mean()
        loss_q = loss_q1 + loss_q2
        self.q_optimizer.zero_grad()
        loss_q.backward()
        self.q_optimizer.step()

        # actor loss (the critics are frozen during this update):
        for p in self.q_params:
            p.requires_grad = False
        q1_pi = self.model.q1(o, pi)
        q2_pi = self.model.q2(o, pi)
        q_pi = torch.min(q1_pi, q2_pi)
        loss_pi = (alpha_t * logp_pi - q_pi).mean()
        self.pi_optimizer.zero_grad()
        loss_pi.backward()
        self.pi_optimizer.step()
        for p in self.q_params:
            p.requires_grad = True

        # target critic update (polyak averaging):
        with torch.no_grad():
            for p, p_targ in zip(self.model.parameters(), self.model_target.parameters()):
                p_targ.data.mul_(self.polyak)
                p_targ.data.add_((1 - self.polyak) * p.data)

        ret_dict = dict(
            loss_actor=loss_pi.detach().item(),
            loss_critic=loss_q.detach().item(),
        )
        if self.learn_entropy_coef:
            ret_dict["loss_entropy_coef"] = loss_alpha.detach().item()
            ret_dict["entropy_coef"] = alpha_t.item()
        return ret_dict  # dictionary of metrics to be logged
This gives us our training_agent_cls
argument, e.g.:
training_agent_cls = partial(MyTrainingAgent,
                             model_cls=MyActorCriticModule,
                             gamma=0.99,
                             polyak=0.995,
                             alpha=0.2,
                             lr_actor=1e-3,
                             lr_critic=1e-3,
                             lr_entropy=1e-3,
                             learn_entropy_coef=False,
                             target_entropy=None)
There are no epochs in RL.
In tmrl
, an epoch is simply the point at which the Trainer
checkpoints the training session on the hard drive and sends training metrics to wandb
.
An epoch
is made of a fixed number of rounds
, and a round
is made of a fixed number of training steps
.
These values are fairly arbitrary, and you can set them however you like depending on how often you want to see metrics logged (they are printed at the end of each round
and logged to wandb
at the end of each epoch
):
epochs = 10 # maximum number of epochs, usually set this to np.inf
rounds = 10 # number of rounds per epoch
steps = 1000 # number of training steps per round
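With the particular values above, a complete run therefore performs at most epochs * rounds * steps training steps:
total_training_steps = epochs * rounds * steps  # 10 * 10 * 1000 = 100000 training steps at most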
update_buffer_interval
defines how often we want to check for incoming samples from the Server
.
If it is set to 100, we will check for available new samples every 100 training steps
:
update_buffer_interval = 100
update_model_interval
defines how often we want to send the model to the Server
to be broadcast to the RolloutWorkers
.
If set to 1000, the model will be sent at the end of each round in our example:
update_model_interval = 1000
max_training_steps_per_env_step
enables limiting the impact of the asynchronous nature of training in tmrl
.
If set to, e.g., 2.0, training pauses whenever the number of training steps exceeds 2.0 times the number of environment steps (i.e., samples) that the Trainer
has received, and resumes when new samples arrive:
max_training_steps_per_env_step = 2.0
start_training
is the number of samples that the Trainer
will wait for at the beginning before starting training.
If set to 500, training will start only after 500 environment steps are collected:
start_training = 500
device
is the device on which training will take place (it is the device
parameter that will be passed to TorchMemory
and TrainingAgent
).
If set to None
, the training device will be selected automatically:
device = None
A couple more options not used in this tutorial are available.
In particular, profiling
enables profiling training (but this doesn't work well with CUDA), and agent_scheduler
enables changing the TrainingAgent
parameters during training.
We finally have our training class:
from tmrl.training_offline import TorchTrainingOffline
training_cls = partial(
    TorchTrainingOffline,
    env_cls=env_cls,
    memory_cls=memory_cls,
    training_agent_cls=training_agent_cls,
    epochs=epochs,
    rounds=rounds,
    steps=steps,
    update_buffer_interval=update_buffer_interval,
    update_model_interval=update_model_interval,
    max_training_steps_per_env_step=max_training_steps_per_env_step,
    start_training=start_training,
    device=device)
We can now instantiate our Trainer
.
from tmrl.networking import Trainer
my_trainer = Trainer(
    training_cls=training_cls,
    server_ip=server_ip,
    server_port=server_port,
    password=password,
    model_path=model_path,
    checkpoint_path=checkpoints_path)  # None for not saving training checkpoints
On instantiation, the Trainer
will automatically connect to the Server
.
We can now start training.
Either by calling the simple run()
method:
my_trainer.run()
Or by calling run_with_wandb()
, which is the same thing as run()
but logs your training metrics on wandb at the end of each epoch
:
my_wandb_entity = "your_wandb_entity_here"
my_wandb_project = "your_wandb_project_here"
my_wandb_run_id = "your_wandb_run_name_here"
my_wandb_key = "your_wandb_key_here"

my_trainer.run_with_wandb(entity=my_wandb_entity,
                          project=my_wandb_project,
                          run_id=my_wandb_run_id,
                          key=my_wandb_key)
(WARNING: when using run_with_wandb
, make sure all the partially instantiated classes that are part of the Trainer
have kwargs only, no args, otherwise you will get an error complaining about invalid keywords.
When it does not make sense to have default values, just set the default values to None
as done in, e.g., MyMemory
)
But as for the RolloutWorker
, this would block the code here until all epochs
are complete, which in itself would require the RolloutWorker
to also be running.
In fact, the RolloutWorker
, Trainer
, and Server
are best run in separate terminals or machines (see TrackMania).
However, for the sake of this tutorial, we will instantiate and run all of them in the same script by using python threads
(of course, you are free to implement them in separate scripts on your end; in fact, this is even strongly recommended):
def run_worker(worker):
    worker.run(test_episode_interval=10)

def run_trainer(trainer):
    trainer.run()

daemon_thread_worker = Thread(target=run_worker, args=(my_worker, ), kwargs={}, daemon=True)
daemon_thread_worker.start()  # start the worker daemon thread

run_trainer(my_trainer)

# the worker daemon thread will be killed here.
Since we did not set epochs=np.inf
, this code will reach completion at some point (but you will not be able to train for more epochs).
The worker thread will simply be killed then.
If you have followed this tutorial carefully, you will now see the dummy RC drone (blue circle) slowly training to reach the red target.
And that is mostly all, folks! 😄
You have probably noticed that implementing your own compression/decompression pipeline is extremely error-prone.
tmrl
provides a useful tool for debugging your pipeline: "CRC debugging".
"CRC" stands for "Cyclic Redundancy Control". This is a way of checking that data has not been corrupted in the pipeline.
In fact, there is no proper CRC in tmrl
at the moment, but a tool that does essentially the same thing.
The "CRC debugging" tool should only be used for debugging as it will completely destroy the benefit of having a compression pipeline at all when turned on. Here is what it does:
In crc_debug
mode, the RolloutWorker
will store the full transition in the info
dictionary of each sample.
In crc_debug
mode, your TorchMemory
will convert each sampled transition into a python string, and compare this string with the one obtained from the transition in the info
dictionary.
If something does not match, the program will stop and you will be shown what is mismatched.
Otherwise, you will get a "CRC check passed" message printed in the terminal for each sample correctly rebuilt.
We recommend using the crc_debug
mode as a sanity check whenever you implement a compression/decompression pipeline.
To activate this mode, set the crc_debug
arguments to True
for both your RolloutWorker
and TorchMemory
instances.
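In our script, this could look like the following sketch (CRC_DEBUG is simply a flag we define ourselves; the final version of MyMemory at the end of this tutorial forwards it to its TorchMemory superclass):
CRC_DEBUG = True  # set this back to False once your pipeline is debugged

# then, pass crc_debug=CRC_DEBUG when instantiating the RolloutWorker,
# and forward crc_debug=CRC_DEBUG to super().__init__() in MyMemory.__init__()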
Let us do this for our implemented pipeline and see what happens:
INFO:root: Resuming training
DEBUG: CRC check passed. Time step: 91, since reset: 90
DEBUG: CRC check passed. Time step: 320, since reset: 16
DEBUG: CRC check passed. Time step: 135, since reset: 33
DEBUG: CRC check passed. Time step: 96, since reset: 95
DEBUG: CRC check passed. Time step: 24, since reset: 23
DEBUG: CRC check passed. Time step: 259, since reset: 56
Traceback (most recent call last):
File ".../tmrl/tmrl/tuto/tuto.py", line 535, in <module>
run_trainer(my_trainer)
File ".../tmrl/tmrl/tuto/tuto.py", line 528, in run_trainer
trainer.run()
File ".../tmrl/tmrl/networking.py", line 394, in run
run(interface=self.interface,
File ".../tmrl/tmrl/networking.py", line 327, in run
for stats in iterate_epochs(...):
File ".../tmrl/tmrl/networking.py", line 270, in iterate_epochs
yield run_instance.run_epoch(interface=interface)
File ".../tmrl/tmrl/training_offline.py", line 114, in run_epoch
for batch in self.memory:
File ".../tmrl/tmrl/memory.py", line 89, in __iter__
yield self.sample()
File ".../tmrl/tmrl/memory.py", line 152, in sample
batch = [self[idx] for idx in indices]
File ".../tmrl/tmrl/memory.py", line 152, in <listcomp>
batch = [self[idx] for idx in indices]
File ".../tmrl/tmrl/memory.py", line 169, in __getitem__
check_samples_crc(...)
File ".../tmrl/tmrl/memory.py", line 22, in check_samples_crc
AssertionError: previous observations don't match:
original:
(array([0.34098563], dtype=float32),
array([-0.63508004], dtype=float32),
array([0.44789273], dtype=float32),
array([0.39232507], dtype=float32),
array([0.29402873, 1.3678434 ], dtype=float32),
array([-1.8195685, -1.4343303], dtype=float32),
array([0., 0.], dtype=float32),
array([-0.07386027, -1.6602786 ], dtype=float32))
!= rebuilt:
(array([0.34098563], dtype=float32),
array([-0.63508004], dtype=float32),
array([0.44789273], dtype=float32),
array([0.39232507], dtype=float32),
array([-1.8195685, -1.4343303], dtype=float32),
array([0.50566 , 1.7723236], dtype=float32),
array([0., 0.], dtype=float32),
array([-0.07386027, -1.6602786 ], dtype=float32))
Time step: 306, since reset: 2
(I have cleaned up the error output a little to make it easier for you to analyze.)
Oh no! What is going on? 😱
If you have carefully read the entire tutorial, well, first, congratulations, but now you should be able to spot exactly what is wrong here and understand why this happens.
Do you have it? 😀
The "since reset" part indicates that the offending observation originates from a sample that was collected 2 time-steps after a reset transition in this examples, whereas the "CRC passed" messages indicate that many samples originating from later transitions have been reconstructed successfully.
The last four arrays of each observation are the action buffers included in real-time observations.
Notice the array([0., 0.], dtype=float32)
in both the original and the rebuilt observations: this is what the dummy RC drone environment defines as its default action.
Everything after this default action is correct, whereas everything before this action is shifted by one time-step in the action buffer.
This issue happens because we weren't careful enough when rebuilding the action buffer after reset transitions.
reset()
is a discontinuity in the Gymnasium
standard, and it needs to be handled with special care in real-time scenarios.
Specifically, in rtgym
environments, the very last action of a trajectory is never sent to the robot.
This is because the previously captured state is terminal (or truncated), so it would make no sense to send this last action to the robot in the first place
(in fact, sending it could even be harmful: this last action is essentially arbitrary, since it has no captured effect in the underlying MDP - note that it is also worthwhile to think about this for longer delays).
On the other hand, when reset()
is called in an rtgym
environment, it sends whatever is currently defined as the "default" action to the robot (i.e., array([0., 0.], dtype=float32)
in our case).
In a nutshell, the very last action of the previous trajectory is never part of the post-reset action buffer in rtgym
environments.
Therefore, we need to make a couple of modifications to our custom Memory
.
First, add this helper to your script:
# Helper function to spot reset transitions:
def last_true_in_list(li):
    """
    Returns the index of the last True element in list li, or None.
    """
    for i in reversed(range(len(li))):
        if li[i]:
            return i
    return None
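To make the helper's behavior concrete, here is a quick check on a made-up slice of a "done" list (the values are purely illustrative):
dones = [False, True, False, False]
print(last_true_in_list(dones))           # 1: index of the last reset transition in the slice
print(last_true_in_list([False, False]))  # None: no reset transition in this slice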
Then, add a "done" list to your data that checks whether terminated
or truncated
is True
in append_buffer()
.
Finally, use this new list in get_transition()
to spot reset transitions and update the action buffer accordingly:
class MyMemory(TorchMemory):
    def __init__(self,
                 act_buf_len=None,
                 device=None,
                 nb_steps=None,
                 sample_preprocessor: callable = None,
                 memory_size=1000000,
                 batch_size=32,
                 dataset_path=""):

        self.act_buf_len = act_buf_len  # length of the action buffer

        super().__init__(device=device,
                         nb_steps=nb_steps,
                         sample_preprocessor=sample_preprocessor,
                         memory_size=memory_size,
                         batch_size=batch_size,
                         dataset_path=dataset_path,
                         crc_debug=CRC_DEBUG)

    def append_buffer(self, buffer):
        """
        buffer.memory is a list of compressed (act_mod, new_obs_mod, rew_mod, terminated_mod, truncated_mod, info_mod) samples
        """
        # decompose compressed samples into their relevant components:
        list_action = [b[0] for b in buffer.memory]
        list_x_position = [b[1][0] for b in buffer.memory]
        list_y_position = [b[1][1] for b in buffer.memory]
        list_x_target = [b[1][2] for b in buffer.memory]
        list_y_target = [b[1][3] for b in buffer.memory]
        list_reward = [b[2] for b in buffer.memory]
        list_terminated = [b[3] for b in buffer.memory]
        list_truncated = [b[4] for b in buffer.memory]
        list_info = [b[5] for b in buffer.memory]
        list_done = [b[3] or b[4] for b in buffer.memory]

        # append to self.data in some arbitrary way:
        if self.__len__() > 0:
            self.data[0] += list_action
            self.data[1] += list_x_position
            self.data[2] += list_y_position
            self.data[3] += list_x_target
            self.data[4] += list_y_target
            self.data[5] += list_reward
            self.data[6] += list_terminated
            self.data[7] += list_info
            self.data[8] += list_truncated
            self.data[9] += list_done
        else:
            self.data.append(list_action)
            self.data.append(list_x_position)
            self.data.append(list_y_position)
            self.data.append(list_x_target)
            self.data.append(list_y_target)
            self.data.append(list_reward)
            self.data.append(list_terminated)
            self.data.append(list_info)
            self.data.append(list_truncated)
            self.data.append(list_done)

        # trim self.data in some arbitrary way when self.__len__() > self.memory_size:
        to_trim = self.__len__() - self.memory_size
        if to_trim > 0:
            self.data[0] = self.data[0][to_trim:]
            self.data[1] = self.data[1][to_trim:]
            self.data[2] = self.data[2][to_trim:]
            self.data[3] = self.data[3][to_trim:]
            self.data[4] = self.data[4][to_trim:]
            self.data[5] = self.data[5][to_trim:]
            self.data[6] = self.data[6][to_trim:]
            self.data[7] = self.data[7][to_trim:]
            self.data[8] = self.data[8][to_trim:]
            self.data[9] = self.data[9][to_trim:]

    def __len__(self):
        if len(self.data) == 0:
            return 0  # self.data is empty
        result = len(self.data[0]) - self.act_buf_len - 1
        if result < 0:
            return 0  # not enough samples to reconstruct the action buffer
        else:
            return result  # we can reconstruct that many samples

    def get_transition(self, item):
        """
        Args:
            item: int: index of the transition that the Trainer wants to sample
        Returns:
            full transition: (last_obs, new_act, rew, new_obs, terminated, truncated, info)
        """
        while True:  # this enables modifying item in edge cases

            # if item corresponds to a transition from a terminal state to a reset state
            if self.data[9][item + self.act_buf_len - 1]:
                # this wouldn't make sense in RL, so we replace item by a neighbour transition
                if item == 0:  # if first item of the buffer
                    item += 1
                elif item == self.__len__() - 1:  # if last item of the buffer
                    item -= 1
                elif random.random() < 0.5:  # otherwise, sample randomly
                    item += 1
                else:
                    item -= 1

            idx_last = item + self.act_buf_len - 1  # index of previous observation
            idx_now = item + self.act_buf_len  # index of new observation

            # rebuild the action buffer of both observations:
            actions = self.data[0][item:(item + self.act_buf_len + 1)]
            last_act_buf = actions[:-1]  # action buffer of previous observation
            new_act_buf = actions[1:]  # action buffer of new observation

            # correct the action buffer when it goes over a reset transition:
            # (NB: we have eliminated the case where the transition *is* the reset transition)
            eoe = last_true_in_list(self.data[9][item:(item + self.act_buf_len)])  # the last one is not important
            if eoe is not None:
                # either one or both action buffers are passing over a reset transition
                if eoe < self.act_buf_len - 1:
                    # last_act_buf is concerned
                    if item == 0:
                        # we have a problem: the previous action has been discarded; we cannot recover the buffer
                        # in this edge case, we randomly sample another item
                        item = random.randint(1, self.__len__() - 1)  # (random.randint is inclusive on both ends)
                        continue
                    last_act_buf_eoe = eoe
                    # replace everything before last_act_buf_eoe by the previous action
                    prev_act = self.data[0][item - 1]
                    for idx in range(last_act_buf_eoe + 1):
                        act_tmp = last_act_buf[idx]
                        last_act_buf[idx] = prev_act
                        prev_act = act_tmp
                if eoe > 0:
                    # new_act_buf is concerned
                    new_act_buf_eoe = eoe - 1
                    # replace everything before new_act_buf_eoe by the previous action
                    prev_act = self.data[0][item]
                    for idx in range(new_act_buf_eoe + 1):
                        act_tmp = new_act_buf[idx]
                        new_act_buf[idx] = prev_act
                        prev_act = act_tmp

            # rebuild the previous observation:
            last_obs = (self.data[1][idx_last],  # x position
                        self.data[2][idx_last],  # y position
                        self.data[3][idx_last],  # x target
                        self.data[4][idx_last],  # y target
                        *last_act_buf)  # action buffer

            # rebuild the new observation:
            new_obs = (self.data[1][idx_now],  # x position
                       self.data[2][idx_now],  # y position
                       self.data[3][idx_now],  # x target
                       self.data[4][idx_now],  # y target
                       *new_act_buf)  # action buffer

            # other components of the transition:
            new_act = self.data[0][idx_now]  # action
            rew = np.float32(self.data[5][idx_now])  # reward
            terminated = self.data[6][idx_now]  # terminated signal
            truncated = self.data[8][idx_now]  # truncated signal
            info = self.data[7][idx_now]  # info dictionary

            break

        return last_obs, new_act, rew, new_obs, terminated, truncated, info
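In case the buffer-correction loop feels opaque, here is a small standalone illustration of what it does on made-up values (the variable names below are hypothetical and not part of the pipeline):
import numpy as np

default_act = np.array([0., 0.], dtype=np.float32)   # rtgym default action, sent at reset()
a_prev = np.array([.9, .9], dtype=np.float32)        # action stored one step before the slice
a_1 = np.array([.1, .1], dtype=np.float32)
a_never_sent = np.array([.5, .5], dtype=np.float32)  # last action of the old episode, never sent
a_2 = np.array([.2, .2], dtype=np.float32)

buf = [a_1, a_never_sent, default_act, a_2]          # action buffer rebuilt naively from self.data[0]
dones = [False, True, False, False]                  # the sample flagged True ended the old episode

eoe = last_true_in_list(dones)                       # -> 1
prev_act = a_prev
for idx in range(eoe + 1):                           # same loop as in get_transition()
    buf[idx], prev_act = prev_act, buf[idx]

print(buf)  # [a_prev, a_1, default_act, a_2]: the never-sent action is gone,
            # earlier actions are shifted by one time-step, and the default action stays in place,
            # which matches the action buffer that rtgym actually produces after reset()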
Alright! This was a bit hacky, but it should work. Let us delete the automatically saved checkpoint file in TmrlData/checkpoints
and retry:
DEBUG: CRC check passed. Time step: 652, since reset: 74
DEBUG: CRC check passed. Time step: 105, since reset: 33
DEBUG: CRC check passed. Time step: 319, since reset: 1
DEBUG: CRC check passed. Time step: 847, since reset: 67
DEBUG: CRC check passed. Time step: 738, since reset: 59
DEBUG: CRC check passed. Time step: 233, since reset: 7
DEBUG: CRC check passed. Time step: 729, since reset: 50
DEBUG: CRC check passed. Time step: 565, since reset: 88
DEBUG: CRC check passed. Time step: 868, since reset: 88
DEBUG: CRC check passed. Time step: 389, since reset: 1
DEBUG: CRC check passed. Time step: 335, since reset: 17
DEBUG: CRC check passed. Time step: 217, since reset: 11
DEBUG: CRC check passed. Time step: 561, since reset: 84
DEBUG: CRC check passed. Time step: 859, since reset: 79
DEBUG: CRC check passed. Time step: 223, since reset: 17
DEBUG: CRC check passed. Time step: 829, since reset: 49
DEBUG: CRC check passed. Time step: 375, since reset: 4
DEBUG: CRC check passed. Time step: 378, since reset: 7
(...)
Finally! At this point, you can be fairly confident that your pipeline works perfectly 😎
You can now disable CRC-debugging, delete checkpoints and weights in TmrlData
, and start training seriously.
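If you prefer doing this cleanup from python rather than by hand, here is a minimal sketch, assuming the default TmrlData folder located in your home directory (adapt the path if your setup differs):
from pathlib import Path

tmrl_data = Path.home() / "TmrlData"        # default location; adapt if yours differs
for folder in ("checkpoints", "weights"):
    for f in (tmrl_data / folder).glob("*"):
        if f.is_file():
            f.unlink()                      # remove saved checkpoints / model weights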
(Note 1: I did not conduct any hyperparameter tuning when writing this tutorial and selected most values fairly arbitrarily, so you can very likely find much better training hyperparameters for this toy task if you want to try. However, be mindful that this task is much harder than it looks: the dummy RC drone has random and fairly long action and observation delays, which makes reaching the target difficult for vanilla RL algorithms such as SAC.)
(Note 2: Although in this tutorial we have run the RolloutWorker
and the Trainer
on the same CPU/GPU, this is of course not recommended in real applications.
Since the environment is real-time, training may introduce noise in the time-step duration despite the best effort of rtgym
to prevent this from happening.
If you see rtgym
warning you about time-step timeouts, this is probably because the Trainer
is slowing it down too much.)
(Note 3: If you have set model_history > 0
, you will find the model history in your weights
folder.
Note also that everything will be checkpointed, so unless you empty your checkpoints
and weights
folders or change the run name, you will not be able to restart training from scratch.)
(Note 4: Thank you for reading the tmrl
tutorial.
You are now ready for the Real Life 🚀 )