[Refactor] Rename Recorder and LogReward #2616

Merged 2 commits on Dec 3, 2024
8 changes: 4 additions & 4 deletions docs/source/reference/trainers.rst
@@ -78,8 +78,8 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a
constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such.

- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger
some information retrieved from that data. Examples include the ``Recorder`` hook, the reward
logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the
some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward
logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the
data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value
should be displayed on the progression bar printed on the training log.
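
A minimal sketch of a custom logging hook that satisfies this contract (the hook name and the ``("next", "reward")`` key are illustrative assumptions, not part of this diff):

```python
from tensordict import TensorDict


def log_max_reward(batch: TensorDict) -> dict:
    """Hypothetical logging hook: log the largest reward in the sampled batch."""
    max_reward = batch.get(("next", "reward")).max().item()
    # "log_pbar" asks the trainer to also display the value on the progress bar
    return {"max_reward": max_reward, "log_pbar": True}


# Registration on an existing Trainer instance (assumed to exist):
# trainer.register_op("pre_steps_log", log_max_reward)
```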

@@ -174,9 +174,9 @@ Trainer and hooks
BatchSubSampler
ClearCudaCache
CountFramesLog
LogReward
LogScalar
OptimizerHook
Recorder
LogValidationReward
ReplayBufferTrainer
RewardNormalizer
SelectKeys
10 changes: 5 additions & 5 deletions sota-implementations/redq/utils.py
@@ -81,8 +81,8 @@
BatchSubSampler,
ClearCudaCache,
CountFramesLog,
LogReward,
Recorder,
LogScalar,
LogValidationReward,
ReplayBufferTrainer,
RewardNormalizer,
Trainer,
@@ -331,7 +331,7 @@ def make_trainer(

if recorder is not None:
# create recorder object
recorder_obj = Recorder(
recorder_obj = LogValidationReward(
record_frames=cfg.logger.record_frames,
frame_skip=cfg.env.frame_skip,
policy_exploration=policy_exploration,
@@ -347,7 +347,7 @@
# call recorder - could be removed
recorder_obj(None)
# create explorative recorder - could be optional
recorder_obj_explore = Recorder(
recorder_obj_explore = LogValidationReward(
record_frames=cfg.logger.record_frames,
frame_skip=cfg.env.frame_skip,
policy_exploration=policy_exploration,
@@ -369,7 +369,7 @@
"post_steps", UpdateWeights(collector, update_weights_interval=1)
)

trainer.register_op("pre_steps_log", LogReward())
trainer.register_op("pre_steps_log", LogScalar())
trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.env.frame_skip))

return trainer
27 changes: 13 additions & 14 deletions test/test_trainer.py
@@ -35,14 +35,14 @@
TensorDictReplayBuffer,
)
from torchrl.envs.libs.gym import _has_gym
from torchrl.trainers import Recorder, Trainer
from torchrl.trainers import LogValidationReward, Trainer
from torchrl.trainers.helpers import transformed_env_constructor
from torchrl.trainers.trainers import (
_has_tqdm,
_has_ts,
BatchSubSampler,
CountFramesLog,
LogReward,
LogScalar,
mask_batch,
OptimizerHook,
ReplayBufferTrainer,
@@ -638,7 +638,7 @@ def test_log_reward(self, logname, pbar):
trainer = mocking_trainer()
trainer.collected_frames = 0

log_reward = LogReward(logname, log_pbar=pbar)
log_reward = LogScalar(logname, log_pbar=pbar)
trainer.register_op("pre_steps_log", log_reward)
td = TensorDict({REWARD_KEY: torch.ones(3)}, [3])
trainer._pre_steps_log_hook(td)
@@ -654,7 +654,7 @@ def test_log_reward_register(self, logname, pbar):
trainer = mocking_trainer()
trainer.collected_frames = 0

log_reward = LogReward(logname, log_pbar=pbar)
log_reward = LogScalar(logname, log_pbar=pbar)
log_reward.register(trainer)
td = TensorDict({REWARD_KEY: torch.ones(3)}, [3])
trainer._pre_steps_log_hook(td)
@@ -873,7 +873,7 @@ def test_recorder(self, N=8):
logger=logger,
)()

recorder = Recorder(
recorder = LogValidationReward(
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
@@ -919,13 +919,12 @@ def test_recorder_load(self, backend, N=8):
os.environ["CKPT_BACKEND"] = backend
state_dict_has_been_called = [False]
load_state_dict_has_been_called = [False]
Recorder.state_dict, Recorder_state_dict = _fun_checker(
Recorder.state_dict, state_dict_has_been_called
LogValidationReward.state_dict, Recorder_state_dict = _fun_checker(
LogValidationReward.state_dict, state_dict_has_been_called
)
(LogValidationReward.load_state_dict, Recorder_load_state_dict,) = _fun_checker(
LogValidationReward.load_state_dict, load_state_dict_has_been_called
)
(
Recorder.load_state_dict,
Recorder_load_state_dict,
) = _fun_checker(Recorder.load_state_dict, load_state_dict_has_been_called)

args = self._get_args()

@@ -948,7 +947,7 @@ def _make_recorder_and_trainer(tmpdirname):
)()
environment.rollout(2)

recorder = Recorder(
recorder = LogValidationReward(
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
@@ -969,8 +968,8 @@ def _make_recorder_and_trainer(tmpdirname):
assert recorder2._count == 8
assert state_dict_has_been_called[0]
assert load_state_dict_has_been_called[0]
Recorder.state_dict = Recorder_state_dict
Recorder.load_state_dict = Recorder_load_state_dict
LogValidationReward.state_dict = Recorder_state_dict
LogValidationReward.load_state_dict = Recorder_load_state_dict


def test_updateweights():
2 changes: 2 additions & 0 deletions torchrl/trainers/__init__.py
@@ -8,6 +8,8 @@
ClearCudaCache,
CountFramesLog,
LogReward,
LogScalar,
LogValidationReward,
mask_batch,
OptimizerHook,
Recorder,
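
Because ``__init__.py`` keeps re-exporting the old names next to the new ones, downstream code can switch to the new names while staying importable on older releases; a hedged compatibility sketch (not part of this diff):

```python
try:  # torchrl with this PR: the renamed hooks are available
    from torchrl.trainers import LogScalar, LogValidationReward
except ImportError:  # older torchrl: fall back to the pre-rename names
    from torchrl.trainers import LogReward as LogScalar
    from torchrl.trainers import Recorder as LogValidationReward
```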
10 changes: 5 additions & 5 deletions torchrl/trainers/helpers/trainers.py
@@ -25,8 +25,8 @@
BatchSubSampler,
ClearCudaCache,
CountFramesLog,
LogReward,
Recorder,
LogScalar,
LogValidationReward,
ReplayBufferTrainer,
RewardNormalizer,
SelectKeys,
@@ -259,7 +259,7 @@ def make_trainer(

if recorder is not None:
# create recorder object
recorder_obj = Recorder(
recorder_obj = LogValidationReward(
record_frames=cfg.record_frames,
frame_skip=cfg.frame_skip,
policy_exploration=policy_exploration,
@@ -275,7 +275,7 @@
# call recorder - could be removed
recorder_obj(None)
# create explorative recorder - could be optional
recorder_obj_explore = Recorder(
recorder_obj_explore = LogValidationReward(
record_frames=cfg.record_frames,
frame_skip=cfg.frame_skip,
policy_exploration=policy_exploration,
@@ -297,7 +297,7 @@
"post_steps", UpdateWeights(collector, update_weights_interval=1)
)

trainer.register_op("pre_steps_log", LogReward())
trainer.register_op("pre_steps_log", LogScalar())
trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.frame_skip))

return trainer
61 changes: 58 additions & 3 deletions torchrl/trainers/trainers.py
@@ -822,7 +822,7 @@ def __call__(self, *args, **kwargs):
torch.cuda.empty_cache()


class LogReward(TrainerHookBase):
class LogScalar(TrainerHookBase):
"""Reward logger hook.

Args:
@@ -833,7 +833,7 @@ class LogReward(TrainerHookBase):
in the input batch. Defaults to ``("next", "reward")``

Examples:
>>> log_reward = LogReward(("next", "reward"))
>>> log_reward = LogScalar(("next", "reward"))
>>> trainer.register_op("pre_steps_log", log_reward)

"""
@@ -870,6 +870,23 @@ def register(self, trainer: Trainer, name: str = "log_reward"):
trainer.register_module(name, self)


class LogReward(LogScalar):
"""Deprecated class. Use LogScalar instead."""

def __init__(
self,
logname="r_training",
log_pbar: bool = False,
reward_key: Union[str, tuple] = None,
):
warnings.warn(
"The 'LogReward' class is deprecated and will be removed in v0.9. Please use 'LogScalar' instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(logname=logname, log_pbar=log_pbar, reward_key=reward_key)
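
A quick sketch of how this shim behaves for callers (assuming a build that includes it): the old name still constructs a working hook, but it now emits a ``DeprecationWarning`` and the resulting object is a ``LogScalar``.

```python
import warnings

from torchrl.trainers import LogReward, LogScalar

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    hook = LogReward("r_training")  # old entry point, kept for backward compatibility

assert isinstance(hook, LogScalar)  # LogReward is now a thin subclass of LogScalar
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```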


class RewardNormalizer(TrainerHookBase):
"""Reward normalizer hook.

@@ -1127,7 +1144,7 @@ def register(self, trainer: Trainer, name: str = "batch_subsampler"):
trainer.register_module(name, self)


class Recorder(TrainerHookBase):
class LogValidationReward(TrainerHookBase):
"""Recorder hook for :class:`~torchrl.trainers.Trainer`.

Args:
@@ -1264,6 +1281,44 @@ def register(self, trainer: Trainer, name: str = "recorder"):
)


class Recorder(LogValidationReward):
"""Deprecated class. Use LogValidationReward instead."""

def __init__(
self,
*,
record_interval: int,
record_frames: int,
frame_skip: int = 1,
policy_exploration: TensorDictModule,
environment: EnvBase = None,
exploration_type: ExplorationType = ExplorationType.RANDOM,
log_keys: Optional[List[Union[str, Tuple[str]]]] = None,
out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None,
suffix: Optional[str] = None,
log_pbar: bool = False,
recorder: EnvBase = None,
) -> None:
warnings.warn(
"The 'Recorder' class is deprecated and will be removed in v0.9. Please use 'LogValidationReward' instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(
record_interval=record_interval,
record_frames=record_frames,
frame_skip=frame_skip,
policy_exploration=policy_exploration,
environment=environment,
exploration_type=exploration_type,
log_keys=log_keys,
out_keys=out_keys,
suffix=suffix,
log_pbar=log_pbar,
recorder=recorder,
)
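
For existing training scripts the migration is a keyword-for-keyword rename; a hedged before/after sketch (``policy``, ``test_env`` and ``trainer`` are assumed to already exist, and the keyword values are illustrative):

```python
from torchrl.trainers import LogValidationReward

# Before: recorder = Recorder(record_interval=100, record_frames=1000,
#                             policy_exploration=policy, environment=test_env)
# After: same keyword arguments, new class name
recorder = LogValidationReward(
    record_interval=100,
    record_frames=1000,
    policy_exploration=policy,
    environment=test_env,
)
recorder.register(trainer)  # hooks itself into the trainer's logging stages
```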


class UpdateWeights(TrainerHookBase):
"""A collector weights update hook class.

6 changes: 3 additions & 3 deletions tutorials/sphinx-tutorials/coding_ddpg.py
@@ -883,12 +883,12 @@ def make_ddpg_actor(
#
# As the training data is obtained using some exploration strategy, the true
# performance of our algorithm needs to be assessed in deterministic mode. We
# do this using a dedicated class, ``Recorder``, which executes the policy in
# do this using a dedicated class, ``LogValidationReward``, which executes the policy in
# the environment at a given frequency and returns some statistics obtained
# from these simulations.
#
# The following helper function builds this object:
from torchrl.trainers import Recorder
from torchrl.trainers import LogValidationReward


def make_recorder(actor_model_explore, transform_state_dict, record_interval):
@@ -899,7 +899,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval):
) # must be instantiated to load the state dict
environment.transform[2].load_state_dict(transform_state_dict)

recorder_obj = Recorder(
recorder_obj = LogValidationReward(
record_frames=1000,
policy_exploration=actor_model_explore,
environment=environment,
8 changes: 4 additions & 4 deletions tutorials/sphinx-tutorials/coding_dqn.py
@@ -140,8 +140,8 @@
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.record.loggers.csv import CSVLogger
from torchrl.trainers import (
LogReward,
Recorder,
LogScalar,
LogValidationReward,
ReplayBufferTrainer,
Trainer,
UpdateWeights,
@@ -666,7 +666,7 @@ def get_loss_module(actor, gamma):
buffer_hook.register(trainer)
weight_updater = UpdateWeights(collector, update_weights_interval=1)
weight_updater.register(trainer)
recorder = Recorder(
recorder = LogValidationReward(
record_interval=100, # log every 100 optimization steps
record_frames=1000, # maximum number of frames in the record
frame_skip=1,
@@ -704,7 +704,7 @@
# This will be reflected by the `total_rewards` value displayed in the
# progress bar.
#
log_reward = LogReward(log_pbar=True)
log_reward = LogScalar(log_pbar=True)
log_reward.register(trainer)

###############################################################################
Expand Down