[Refactor] Rename Recorder and LogReward (#2616)
raresdan authored Dec 3, 2024
1 parent d22266d commit 607ebc5
Showing 8 changed files with 94 additions and 38 deletions.
8 changes: 4 additions & 4 deletions docs/source/reference/trainers.rst
@@ -78,8 +78,8 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a
constants update), data subsampling (:class:`~torchrl.trainers.BatchSubSampler`) and such.

- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger
some information retrieved from that data. Examples include the ``Recorder`` hook, the reward
logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the
some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward
logger (``LogScalar``) and such. Hooks should return a dictionary (or a ``None`` value) containing the
data to log. The key ``"log_pbar"`` is reserved for boolean values indicating whether the logged
value should be displayed on the progress bar printed in the training log.
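
For context, a logging hook in this category is just a callable that maps a batch to a dict of scalars; a minimal sketch follows (the hook name and logged key are illustrative, not part of this diff):

    from tensordict import TensorDict

    def log_mean_reward(batch: TensorDict) -> dict:
        # pull a scalar statistic out of the collected batch
        value = batch.get(("next", "reward")).mean().item()
        # "log_pbar" asks the trainer to also show this value on the progress bar
        return {"mean_reward": value, "log_pbar": True}

    # registration mirrors the hooks touched by this commit:
    # trainer.register_op("pre_steps_log", log_mean_reward)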

@@ -174,9 +174,9 @@ Trainer and hooks
BatchSubSampler
ClearCudaCache
CountFramesLog
LogReward
LogScalar
OptimizerHook
Recorder
LogValidationReward
ReplayBufferTrainer
RewardNormalizer
SelectKeys
10 changes: 5 additions & 5 deletions sota-implementations/redq/utils.py
@@ -81,8 +81,8 @@
BatchSubSampler,
ClearCudaCache,
CountFramesLog,
LogReward,
Recorder,
LogScalar,
LogValidationReward,
ReplayBufferTrainer,
RewardNormalizer,
Trainer,
@@ -331,7 +331,7 @@ def make_trainer(

if recorder is not None:
# create recorder object
recorder_obj = Recorder(
recorder_obj = LogValidationReward(
record_frames=cfg.logger.record_frames,
frame_skip=cfg.env.frame_skip,
policy_exploration=policy_exploration,
@@ -347,7 +347,7 @@ def make_trainer(
# call recorder - could be removed
recorder_obj(None)
# create explorative recorder - could be optional
recorder_obj_explore = Recorder(
recorder_obj_explore = LogValidationReward(
record_frames=cfg.logger.record_frames,
frame_skip=cfg.env.frame_skip,
policy_exploration=policy_exploration,
@@ -369,7 +369,7 @@ def make_trainer(
"post_steps", UpdateWeights(collector, update_weights_interval=1)
)

trainer.register_op("pre_steps_log", LogReward())
trainer.register_op("pre_steps_log", LogScalar())
trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.env.frame_skip))

return trainer
27 changes: 13 additions & 14 deletions test/test_trainer.py
@@ -35,14 +35,14 @@
TensorDictReplayBuffer,
)
from torchrl.envs.libs.gym import _has_gym
from torchrl.trainers import Recorder, Trainer
from torchrl.trainers import LogValidationReward, Trainer
from torchrl.trainers.helpers import transformed_env_constructor
from torchrl.trainers.trainers import (
_has_tqdm,
_has_ts,
BatchSubSampler,
CountFramesLog,
LogReward,
LogScalar,
mask_batch,
OptimizerHook,
ReplayBufferTrainer,
@@ -638,7 +638,7 @@ def test_log_reward(self, logname, pbar):
trainer = mocking_trainer()
trainer.collected_frames = 0

log_reward = LogReward(logname, log_pbar=pbar)
log_reward = LogScalar(logname, log_pbar=pbar)
trainer.register_op("pre_steps_log", log_reward)
td = TensorDict({REWARD_KEY: torch.ones(3)}, [3])
trainer._pre_steps_log_hook(td)
@@ -654,7 +654,7 @@ def test_log_reward_register(self, logname, pbar):
trainer = mocking_trainer()
trainer.collected_frames = 0

log_reward = LogReward(logname, log_pbar=pbar)
log_reward = LogScalar(logname, log_pbar=pbar)
log_reward.register(trainer)
td = TensorDict({REWARD_KEY: torch.ones(3)}, [3])
trainer._pre_steps_log_hook(td)
@@ -873,7 +873,7 @@ def test_recorder(self, N=8):
logger=logger,
)()

recorder = Recorder(
recorder = LogValidationReward(
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
@@ -919,13 +919,12 @@ def test_recorder_load(self, backend, N=8):
os.environ["CKPT_BACKEND"] = backend
state_dict_has_been_called = [False]
load_state_dict_has_been_called = [False]
Recorder.state_dict, Recorder_state_dict = _fun_checker(
Recorder.state_dict, state_dict_has_been_called
)
(
Recorder.load_state_dict,
Recorder_load_state_dict,
) = _fun_checker(Recorder.load_state_dict, load_state_dict_has_been_called)
LogValidationReward.state_dict, Recorder_state_dict = _fun_checker(
LogValidationReward.state_dict, state_dict_has_been_called
)
(LogValidationReward.load_state_dict, Recorder_load_state_dict,) = _fun_checker(
LogValidationReward.load_state_dict, load_state_dict_has_been_called
)

args = self._get_args()

@@ -948,7 +947,7 @@ def _make_recorder_and_trainer(tmpdirname):
)()
environment.rollout(2)

recorder = Recorder(
recorder = LogValidationReward(
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
@@ -969,8 +968,8 @@ def _make_recorder_and_trainer(tmpdirname):
assert recorder2._count == 8
assert state_dict_has_been_called[0]
assert load_state_dict_has_been_called[0]
Recorder.state_dict = Recorder_state_dict
Recorder.load_state_dict = Recorder_load_state_dict
LogValidationReward.state_dict = Recorder_state_dict
LogValidationReward.load_state_dict = Recorder_load_state_dict


def test_updateweights():
2 changes: 2 additions & 0 deletions torchrl/trainers/__init__.py
@@ -8,6 +8,8 @@
ClearCudaCache,
CountFramesLog,
LogReward,
LogScalar,
LogValidationReward,
mask_batch,
OptimizerHook,
Recorder,
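
Both the old and new names are exported from torchrl.trainers during the deprecation window, so existing imports keep working; a quick sanity check (sketch):

    # old names still import (and warn at construction); new names are canonical
    from torchrl.trainers import LogReward, LogScalar, LogValidationReward, Recorder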
10 changes: 5 additions & 5 deletions torchrl/trainers/helpers/trainers.py
@@ -25,8 +25,8 @@
BatchSubSampler,
ClearCudaCache,
CountFramesLog,
LogReward,
Recorder,
LogScalar,
LogValidationReward,
ReplayBufferTrainer,
RewardNormalizer,
SelectKeys,
@@ -259,7 +259,7 @@ def make_trainer(

if recorder is not None:
# create recorder object
recorder_obj = Recorder(
recorder_obj = LogValidationReward(
record_frames=cfg.record_frames,
frame_skip=cfg.frame_skip,
policy_exploration=policy_exploration,
@@ -275,7 +275,7 @@ def make_trainer(
# call recorder - could be removed
recorder_obj(None)
# create explorative recorder - could be optional
recorder_obj_explore = Recorder(
recorder_obj_explore = LogValidationReward(
record_frames=cfg.record_frames,
frame_skip=cfg.frame_skip,
policy_exploration=policy_exploration,
@@ -297,7 +297,7 @@ def make_trainer(
"post_steps", UpdateWeights(collector, update_weights_interval=1)
)

trainer.register_op("pre_steps_log", LogReward())
trainer.register_op("pre_steps_log", LogScalar())
trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.frame_skip))

return trainer
61 changes: 58 additions & 3 deletions torchrl/trainers/trainers.py
@@ -822,7 +822,7 @@ def __call__(self, *args, **kwargs):
torch.cuda.empty_cache()


class LogReward(TrainerHookBase):
class LogScalar(TrainerHookBase):
"""Reward logger hook.
Args:
@@ -833,7 +833,7 @@ class LogReward(TrainerHookBase):
in the input batch. Defaults to ``("next", "reward")``
Examples:
>>> log_reward = LogReward(("next", "reward"))
>>> log_reward = LogScalar(("next", "reward"))
>>> trainer.register_op("pre_steps_log", log_reward)
"""
@@ -870,6 +870,23 @@ def register(self, trainer: Trainer, name: str = "log_reward"):
trainer.register_module(name, self)


class LogReward(LogScalar):
"""Deprecated class. Use LogScalar instead."""

def __init__(
self,
logname="r_training",
log_pbar: bool = False,
reward_key: Union[str, tuple] = None,
):
warnings.warn(
"The 'LogReward' class is deprecated and will be removed in v0.9. Please use 'LogScalar' instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(logname=logname, log_pbar=log_pbar, reward_key=reward_key)

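A quick check that the shim still constructs a working hook while warning, as a sketch (assuming the class is importable as before):

    import warnings

    from torchrl.trainers import LogReward

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        hook = LogReward()  # behaves as a LogScalar underneath
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)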

class RewardNormalizer(TrainerHookBase):
"""Reward normalizer hook.
@@ -1127,7 +1144,7 @@ def register(self, trainer: Trainer, name: str = "batch_subsampler"):
trainer.register_module(name, self)


class Recorder(TrainerHookBase):
class LogValidationReward(TrainerHookBase):
"""Recorder hook for :class:`~torchrl.trainers.Trainer`.
Args:
@@ -1264,6 +1281,44 @@ def register(self, trainer: Trainer, name: str = "recorder"):
)


class Recorder(LogValidationReward):
"""Deprecated class. Use LogValidationReward instead."""

def __init__(
self,
*,
record_interval: int,
record_frames: int,
frame_skip: int = 1,
policy_exploration: TensorDictModule,
environment: EnvBase = None,
exploration_type: ExplorationType = ExplorationType.RANDOM,
log_keys: Optional[List[Union[str, Tuple[str]]]] = None,
out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None,
suffix: Optional[str] = None,
log_pbar: bool = False,
recorder: EnvBase = None,
) -> None:
warnings.warn(
"The 'Recorder' class is deprecated and will be removed in v0.9. Please use 'LogValidationReward' instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(
record_interval=record_interval,
record_frames=record_frames,
frame_skip=frame_skip,
policy_exploration=policy_exploration,
environment=environment,
exploration_type=exploration_type,
log_keys=log_keys,
out_keys=out_keys,
suffix=suffix,
log_pbar=log_pbar,
recorder=recorder,
)

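Migration for downstream code is a pure rename, since the shim forwards every argument unchanged; a before/after sketch (argument values are illustrative):

    # before -- emits a DeprecationWarning at construction:
    # recorder = Recorder(record_interval=100, record_frames=1000,
    #                     policy_exploration=policy, environment=env)
    # after -- same signature, no warning:
    # recorder = LogValidationReward(record_interval=100, record_frames=1000,
    #                                policy_exploration=policy, environment=env)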

class UpdateWeights(TrainerHookBase):
"""A collector weights update hook class.
6 changes: 3 additions & 3 deletions tutorials/sphinx-tutorials/coding_ddpg.py
@@ -883,12 +883,12 @@ def make_ddpg_actor(
#
# As the training data is obtained using some exploration strategy, the true
# performance of our algorithm needs to be assessed in deterministic mode. We
# do this using a dedicated class, ``Recorder``, which executes the policy in
# do this using a dedicated class, ``LogValidationReward``, which executes the policy in
# the environment at a given frequency and returns some statistics obtained
# from these simulations.
#
# The following helper function builds this object:
from torchrl.trainers import Recorder
from torchrl.trainers import LogValidationReward


def make_recorder(actor_model_explore, transform_state_dict, record_interval):
@@ -899,7 +899,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval):
) # must be instantiated to load the state dict
environment.transform[2].load_state_dict(transform_state_dict)

recorder_obj = Recorder(
recorder_obj = LogValidationReward(
record_frames=1000,
policy_exploration=actor_model_explore,
environment=environment,
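
Outside of a Trainer, the hook can also be triggered by hand, as the redq utils above do with recorder_obj(None); a sketch (return-value handling is assumed from that usage):

    # trigger a validation rollout manually (sketch; mirrors the
    # recorder_obj(None) call in sota-implementations/redq/utils.py above)
    logs = recorder_obj(None)
    if logs is not None:
        print(logs)  # validation statistics gathered from the rollout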
8 changes: 4 additions & 4 deletions tutorials/sphinx-tutorials/coding_dqn.py
@@ -140,8 +140,8 @@
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
from torchrl.trainers import (
LogReward,
Recorder,
LogScalar,
LogValidationReward,
ReplayBufferTrainer,
Trainer,
UpdateWeights,
@@ -666,7 +666,7 @@ def get_loss_module(actor, gamma):
buffer_hook.register(trainer)
weight_updater = UpdateWeights(collector, update_weights_interval=1)
weight_updater.register(trainer)
recorder = Recorder(
recorder = LogValidationReward(
record_interval=100, # log every 100 optimization steps
record_frames=1000, # maximum number of frames in the record
frame_skip=1,
@@ -704,7 +704,7 @@ def get_loss_module(actor, gamma):
# This will be reflected by the `total_rewards` value displayed in the
# progress bar.
#
log_reward = LogReward(log_pbar=True)
log_reward = LogScalar(log_pbar=True)
log_reward.register(trainer)

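In effect, on each pre_steps_log pass the hook reads the reward from the batch and logs its mean; a paraphrased sketch (defaults taken from the LogScalar shim above, not a verbatim excerpt):

    # roughly what LogScalar(log_pbar=True) contributes per call (sketch):
    # reward = batch.get(("next", "reward")).mean().item()
    # return {"r_training": reward, "log_pbar": True}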
###############################################################################

1 comment on commit 607ebc5

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous result by more than the threshold factor of 2.

Benchmark suite: benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400]
Current (607ebc5): 39.912925100474624 iter/sec (stddev: 0.1486640141965768)
Previous (d22266d): 251.12279639015262 iter/sec (stddev: 0.0005260375686386335)
Ratio: 6.29

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
