[Feature Request] Rename Recorder and LogReward
raresdan authored and vmoens committed Dec 2, 2024
1 parent 90c8e40 commit 56f5c3e
Showing 8 changed files with 85 additions and 38 deletions.
8 changes: 4 additions & 4 deletions docs/source/reference/trainers.rst
@@ -78,8 +78,8 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a
constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such.

- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger
some information retrieved from that data. Examples include the ``Recorder`` hook, the reward
logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the
some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward
logger (``LogScalar``)) and such. Hooks should return a dictionary (or a None value) containing the
data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value
should be displayed on the progression bar printed on the training log.
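To illustrate the contract described above, here is a minimal, hypothetical sketch of a custom logging hook (the ``BatchRewardLogger`` class and the ``"r_batch_mean"`` key are illustrative and not part of this commit; an existing ``Trainer`` instance named ``trainer`` is assumed). The hook returns a dictionary of values to log, and the reserved ``"log_pbar"`` entry controls whether a value appears on the progress bar:

>>> class BatchRewardLogger:
...     """Hypothetical hook: log the mean reward of the current batch."""
...     def __call__(self, batch):
...         reward = batch.get(("next", "reward"))
...         return {"r_batch_mean": reward.mean().item(), "log_pbar": True}
>>> trainer.register_op("pre_steps_log", BatchRewardLogger())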

@@ -174,9 +174,9 @@ Trainer and hooks
BatchSubSampler
ClearCudaCache
CountFramesLog
LogReward
LogScalar
OptimizerHook
Recorder
LogValidationReward
ReplayBufferTrainer
RewardNormalizer
SelectKeys
10 changes: 5 additions & 5 deletions sota-implementations/redq/utils.py
@@ -81,8 +81,8 @@
BatchSubSampler,
ClearCudaCache,
CountFramesLog,
LogReward,
Recorder,
LogScalar,
LogValidationReward,
ReplayBufferTrainer,
RewardNormalizer,
Trainer,
@@ -331,7 +331,7 @@ def make_trainer(

if recorder is not None:
# create recorder object
recorder_obj = Recorder(
recorder_obj = LogValidationReward(
record_frames=cfg.logger.record_frames,
frame_skip=cfg.env.frame_skip,
policy_exploration=policy_exploration,
@@ -347,7 +347,7 @@ def make_trainer(
# call recorder - could be removed
recorder_obj(None)
# create explorative recorder - could be optional
recorder_obj_explore = Recorder(
recorder_obj_explore = LogValidationReward(
record_frames=cfg.logger.record_frames,
frame_skip=cfg.env.frame_skip,
policy_exploration=policy_exploration,
@@ -369,7 +369,7 @@ def make_trainer(
"post_steps", UpdateWeights(collector, update_weights_interval=1)
)

trainer.register_op("pre_steps_log", LogReward())
trainer.register_op("pre_steps_log", LogScalar())
trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.env.frame_skip))

return trainer
24 changes: 12 additions & 12 deletions test/test_trainer.py
@@ -35,14 +35,14 @@
TensorDictReplayBuffer,
)
from torchrl.envs.libs.gym import _has_gym
from torchrl.trainers import Recorder, Trainer
from torchrl.trainers import Trainer, LogValidationReward
from torchrl.trainers.helpers import transformed_env_constructor
from torchrl.trainers.trainers import (
_has_tqdm,
_has_ts,
BatchSubSampler,
CountFramesLog,
LogReward,
LogScalar,
mask_batch,
OptimizerHook,
ReplayBufferTrainer,
@@ -638,7 +638,7 @@ def test_log_reward(self, logname, pbar):
trainer = mocking_trainer()
trainer.collected_frames = 0

log_reward = LogReward(logname, log_pbar=pbar)
log_reward = LogScalar(logname, log_pbar=pbar)
trainer.register_op("pre_steps_log", log_reward)
td = TensorDict({REWARD_KEY: torch.ones(3)}, [3])
trainer._pre_steps_log_hook(td)
@@ -654,7 +654,7 @@ def test_log_reward_register(self, logname, pbar):
trainer = mocking_trainer()
trainer.collected_frames = 0

log_reward = LogReward(logname, log_pbar=pbar)
log_reward = LogScalar(logname, log_pbar=pbar)
log_reward.register(trainer)
td = TensorDict({REWARD_KEY: torch.ones(3)}, [3])
trainer._pre_steps_log_hook(td)
@@ -873,7 +873,7 @@ def test_recorder(self, N=8):
logger=logger,
)()

recorder = Recorder(
recorder = LogValidationReward(
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
@@ -919,13 +919,13 @@ def test_recorder_load(self, backend, N=8):
os.environ["CKPT_BACKEND"] = backend
state_dict_has_been_called = [False]
load_state_dict_has_been_called = [False]
Recorder.state_dict, Recorder_state_dict = _fun_checker(
Recorder.state_dict, state_dict_has_been_called
LogValidationReward.state_dict, Recorder_state_dict = _fun_checker(
LogValidationReward.state_dict, state_dict_has_been_called
)
(
Recorder.load_state_dict,
LogValidationReward.load_state_dict,
Recorder_load_state_dict,
) = _fun_checker(Recorder.load_state_dict, load_state_dict_has_been_called)
) = _fun_checker(LogValidationReward.load_state_dict, load_state_dict_has_been_called)

args = self._get_args()

@@ -948,7 +948,7 @@ def _make_recorder_and_trainer(tmpdirname):
)()
environment.rollout(2)

recorder = Recorder(
recorder = LogValidationReward(
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
@@ -969,8 +969,8 @@ def _make_recorder_and_trainer(tmpdirname):
assert recorder2._count == 8
assert state_dict_has_been_called[0]
assert load_state_dict_has_been_called[0]
Recorder.state_dict = Recorder_state_dict
Recorder.load_state_dict = Recorder_load_state_dict
LogValidationReward.state_dict = Recorder_state_dict
LogValidationReward.load_state_dict = Recorder_load_state_dict


def test_updateweights():
4 changes: 2 additions & 2 deletions torchrl/trainers/__init__.py
@@ -7,10 +7,10 @@
BatchSubSampler,
ClearCudaCache,
CountFramesLog,
LogReward,
LogScalar,
mask_batch,
OptimizerHook,
Recorder,
LogValidationReward,
ReplayBufferTrainer,
RewardNormalizer,
SelectKeys,
10 changes: 5 additions & 5 deletions torchrl/trainers/helpers/trainers.py
@@ -25,8 +25,8 @@
BatchSubSampler,
ClearCudaCache,
CountFramesLog,
LogReward,
Recorder,
LogScalar,
LogValidationReward,
ReplayBufferTrainer,
RewardNormalizer,
SelectKeys,
@@ -259,7 +259,7 @@ def make_trainer(

if recorder is not None:
# create recorder object
recorder_obj = Recorder(
recorder_obj = LogValidationReward(
record_frames=cfg.record_frames,
frame_skip=cfg.frame_skip,
policy_exploration=policy_exploration,
@@ -275,7 +275,7 @@ def make_trainer(
# call recorder - could be removed
recorder_obj(None)
# create explorative recorder - could be optional
recorder_obj_explore = Recorder(
recorder_obj_explore = LogValidationReward(
record_frames=cfg.record_frames,
frame_skip=cfg.frame_skip,
policy_exploration=policy_exploration,
@@ -297,7 +297,7 @@ def make_trainer(
"post_steps", UpdateWeights(collector, update_weights_interval=1)
)

trainer.register_op("pre_steps_log", LogReward())
trainer.register_op("pre_steps_log", LogScalar())
trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.frame_skip))

return trainer
53 changes: 50 additions & 3 deletions torchrl/trainers/trainers.py
@@ -822,7 +822,7 @@ def __call__(self, *args, **kwargs):
torch.cuda.empty_cache()


class LogReward(TrainerHookBase):
class LogScalar(TrainerHookBase):
"""Reward logger hook.
Args:
@@ -833,7 +833,7 @@ class LogReward(TrainerHookBase):
in the input batch. Defaults to ``("next", "reward")``
Examples:
>>> log_reward = LogReward(("next", "reward"))
>>> log_reward = LogScalar(("next", "reward"))
>>> trainer.register_op("pre_steps_log", log_reward)
"""
@@ -870,6 +870,20 @@ def register(self, trainer: Trainer, name: str = "log_reward"):
trainer.register_module(name, self)


class LogReward(LogScalar):
"""Deprecated class. Use LogScalar instead."""
def __init__(self,
logname="r_training",
log_pbar: bool = False,
reward_key: Union[str, tuple] = None):
warnings.warn("The 'LogReward' class is deprecated and will be removed in a future release. Please use 'LogScalar' instead.",
DeprecationWarning,
stacklevel=2)
super().__init__(logname=logname,
log_pbar=log_pbar,
reward_key=reward_key)
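# A minimal usage sketch of the shim above (assumed import paths: LogReward still lives in
# torchrl.trainers.trainers, while the package __init__ now exports LogScalar). The old name
# keeps working but emits a DeprecationWarning; new code should switch to LogScalar.
import warnings

from torchrl.trainers import LogScalar  # new public name
from torchrl.trainers.trainers import LogReward  # deprecated shim defined above

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    hook = LogReward(logname="r_training")  # still works, but warns
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

hook = LogScalar(logname="r_training")  # preferred going forward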


class RewardNormalizer(TrainerHookBase):
"""Reward normalizer hook.
@@ -1127,7 +1141,7 @@ def register(self, trainer: Trainer, name: str = "batch_subsampler"):
trainer.register_module(name, self)


class Recorder(TrainerHookBase):
class LogValidationReward(TrainerHookBase):
"""Recorder hook for :class:`~torchrl.trainers.Trainer`.
Args:
@@ -1264,6 +1278,39 @@ def register(self, trainer: Trainer, name: str = "recorder"):
)


class Recorder(LogValidationReward):
"""Deprecated class. Use LogValidationReward instead."""
def __init__(
self,
*,
record_interval: int,
record_frames: int,
frame_skip: int = 1,
policy_exploration: TensorDictModule,
environment: EnvBase = None,
exploration_type: ExplorationType = ExplorationType.RANDOM,
log_keys: Optional[List[Union[str, Tuple[str]]]] = None,
out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None,
suffix: Optional[str] = None,
log_pbar: bool = False,
recorder: EnvBase = None,
) -> None:
warnings.warn("The 'Recorder' class is deprecated and will be removed in a future release. Please use 'LogValidationReward' instead.",
DeprecationWarning,
stacklevel=2)
super().__init__(record_interval=record_interval,
record_frames=record_frames,
frame_skip=frame_skip,
policy_exploration=policy_exploration,
environment=environment,
exploration_type=exploration_type,
log_keys=log_keys,
out_keys=out_keys,
suffix=suffix,
log_pbar=log_pbar,
recorder=recorder)


class UpdateWeights(TrainerHookBase):
"""A collector weights update hook class.
6 changes: 3 additions & 3 deletions tutorials/sphinx-tutorials/coding_ddpg.py
@@ -883,12 +883,12 @@ def make_ddpg_actor(
#
# As the training data is obtained using some exploration strategy, the true
# performance of our algorithm needs to be assessed in deterministic mode. We
# do this using a dedicated class, ``Recorder``, which executes the policy in
# do this using a dedicated class, ``LogValidationReward``, which executes the policy in
# the environment at a given frequency and returns some statistics obtained
# from these simulations.
#
# The following helper function builds this object:
from torchrl.trainers import Recorder
from torchrl.trainers import LogValidationReward


def make_recorder(actor_model_explore, transform_state_dict, record_interval):
@@ -899,7 +899,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval):
) # must be instantiated to load the state dict
environment.transform[2].load_state_dict(transform_state_dict)

recorder_obj = Recorder(
recorder_obj = LogValidationReward(
record_frames=1000,
policy_exploration=actor_model_explore,
environment=environment,
8 changes: 4 additions & 4 deletions tutorials/sphinx-tutorials/coding_dqn.py
@@ -140,8 +140,8 @@
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
from torchrl.trainers import (
LogReward,
Recorder,
LogScalar,
LogValidationReward,
ReplayBufferTrainer,
Trainer,
UpdateWeights,
@@ -666,7 +666,7 @@ def get_loss_module(actor, gamma):
buffer_hook.register(trainer)
weight_updater = UpdateWeights(collector, update_weights_interval=1)
weight_updater.register(trainer)
recorder = Recorder(
recorder = LogValidationReward(
record_interval=100, # log every 100 optimization steps
record_frames=1000, # maximum number of frames in the record
frame_skip=1,
@@ -704,7 +704,7 @@ def get_loss_module(actor, gamma):
# This will be reflected by the `total_rewards` value displayed in the
# progress bar.
#
log_reward = LogReward(log_pbar=True)
log_reward = LogScalar(log_pbar=True)
log_reward.register(trainer)

###############################################################################
