diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 11384bda0e6..8f6be633743 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -78,8 +78,8 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such. - **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger - some information retrieved from that data. Examples include the ``Recorder`` hook, the reward - logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the + some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward + logger (``LogScaler``) and such. Hooks should return a dictionary (or a None value) containing the data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value should be displayed on the progression bar printed on the training log. @@ -174,9 +174,9 @@ Trainer and hooks BatchSubSampler ClearCudaCache CountFramesLog - LogReward + LogScaler OptimizerHook - Recorder + LogValidationReward ReplayBufferTrainer RewardNormalizer SelectKeys diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 9953fcb3112..fed4922b5a7 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -81,8 +81,8 @@ BatchSubSampler, ClearCudaCache, CountFramesLog, - LogReward, - Recorder, + LogScalar, + LogValidationReward, ReplayBufferTrainer, RewardNormalizer, Trainer, @@ -331,7 +331,7 @@ def make_trainer( if recorder is not None: # create recorder object - recorder_obj = Recorder( + recorder_obj = LogValidationReward( record_frames=cfg.logger.record_frames, frame_skip=cfg.env.frame_skip, policy_exploration=policy_exploration, @@ -347,7 +347,7 @@ def make_trainer( # call recorder - could be removed recorder_obj(None) # create explorative recorder - could be optional - recorder_obj_explore = Recorder( + recorder_obj_explore = LogValidationReward( record_frames=cfg.logger.record_frames, frame_skip=cfg.env.frame_skip, policy_exploration=policy_exploration, @@ -369,7 +369,7 @@ def make_trainer( "post_steps", UpdateWeights(collector, update_weights_interval=1) ) - trainer.register_op("pre_steps_log", LogReward()) + trainer.register_op("pre_steps_log", LogScalar()) trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.env.frame_skip)) return trainer diff --git a/test/test_trainer.py b/test/test_trainer.py index f7e4ccffdf5..caae5bbe178 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -35,14 +35,14 @@ TensorDictReplayBuffer, ) from torchrl.envs.libs.gym import _has_gym -from torchrl.trainers import Recorder, Trainer +from torchrl.trainers import LogValidationReward, Trainer from torchrl.trainers.helpers import transformed_env_constructor from torchrl.trainers.trainers import ( _has_tqdm, _has_ts, BatchSubSampler, CountFramesLog, - LogReward, + LogScalar, mask_batch, OptimizerHook, ReplayBufferTrainer, @@ -638,7 +638,7 @@ def test_log_reward(self, logname, pbar): trainer = mocking_trainer() trainer.collected_frames = 0 - log_reward = LogReward(logname, log_pbar=pbar) + log_reward = LogScalar(logname, log_pbar=pbar) trainer.register_op("pre_steps_log", log_reward) td = TensorDict({REWARD_KEY: torch.ones(3)}, [3]) trainer._pre_steps_log_hook(td) @@ -654,7 +654,7 @@ def test_log_reward_register(self, logname, pbar): trainer = mocking_trainer() trainer.collected_frames = 0 - log_reward = LogReward(logname, log_pbar=pbar) + log_reward = LogScalar(logname, log_pbar=pbar) log_reward.register(trainer) td = TensorDict({REWARD_KEY: torch.ones(3)}, [3]) trainer._pre_steps_log_hook(td) @@ -873,7 +873,7 @@ def test_recorder(self, N=8): logger=logger, )() - recorder = Recorder( + recorder = LogValidationReward( record_frames=args.record_frames, frame_skip=args.frame_skip, policy_exploration=None, @@ -919,13 +919,12 @@ def test_recorder_load(self, backend, N=8): os.environ["CKPT_BACKEND"] = backend state_dict_has_been_called = [False] load_state_dict_has_been_called = [False] - Recorder.state_dict, Recorder_state_dict = _fun_checker( - Recorder.state_dict, state_dict_has_been_called + LogValidationReward.state_dict, Recorder_state_dict = _fun_checker( + LogValidationReward.state_dict, state_dict_has_been_called + ) + (LogValidationReward.load_state_dict, Recorder_load_state_dict,) = _fun_checker( + LogValidationReward.load_state_dict, load_state_dict_has_been_called ) - ( - Recorder.load_state_dict, - Recorder_load_state_dict, - ) = _fun_checker(Recorder.load_state_dict, load_state_dict_has_been_called) args = self._get_args() @@ -948,7 +947,7 @@ def _make_recorder_and_trainer(tmpdirname): )() environment.rollout(2) - recorder = Recorder( + recorder = LogValidationReward( record_frames=args.record_frames, frame_skip=args.frame_skip, policy_exploration=None, @@ -969,8 +968,8 @@ def _make_recorder_and_trainer(tmpdirname): assert recorder2._count == 8 assert state_dict_has_been_called[0] assert load_state_dict_has_been_called[0] - Recorder.state_dict = Recorder_state_dict - Recorder.load_state_dict = Recorder_load_state_dict + LogValidationReward.state_dict = Recorder_state_dict + LogValidationReward.load_state_dict = Recorder_load_state_dict def test_updateweights(): diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 364c0dec725..9d593d64f17 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -8,6 +8,8 @@ ClearCudaCache, CountFramesLog, LogReward, + LogScalar, + LogValidationReward, mask_batch, OptimizerHook, Recorder, diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 207bcec0ffd..4819d9e07e8 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -25,8 +25,8 @@ BatchSubSampler, ClearCudaCache, CountFramesLog, - LogReward, - Recorder, + LogScalar, + LogValidationReward, ReplayBufferTrainer, RewardNormalizer, SelectKeys, @@ -259,7 +259,7 @@ def make_trainer( if recorder is not None: # create recorder object - recorder_obj = Recorder( + recorder_obj = LogValidationReward( record_frames=cfg.record_frames, frame_skip=cfg.frame_skip, policy_exploration=policy_exploration, @@ -275,7 +275,7 @@ def make_trainer( # call recorder - could be removed recorder_obj(None) # create explorative recorder - could be optional - recorder_obj_explore = Recorder( + recorder_obj_explore = LogValidationReward( record_frames=cfg.record_frames, frame_skip=cfg.frame_skip, policy_exploration=policy_exploration, @@ -297,7 +297,7 @@ def make_trainer( "post_steps", UpdateWeights(collector, update_weights_interval=1) ) - trainer.register_op("pre_steps_log", LogReward()) + trainer.register_op("pre_steps_log", LogScalar()) trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.frame_skip)) return trainer diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 7e28da45f52..83bd050ef96 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -822,7 +822,7 @@ def __call__(self, *args, **kwargs): torch.cuda.empty_cache() -class LogReward(TrainerHookBase): +class LogScalar(TrainerHookBase): """Reward logger hook. Args: @@ -833,7 +833,7 @@ class LogReward(TrainerHookBase): in the input batch. Defaults to ``("next", "reward")`` Examples: - >>> log_reward = LogReward(("next", "reward")) + >>> log_reward = LogScalar(("next", "reward")) >>> trainer.register_op("pre_steps_log", log_reward) """ @@ -870,6 +870,23 @@ def register(self, trainer: Trainer, name: str = "log_reward"): trainer.register_module(name, self) +class LogReward(LogScalar): + """Deprecated class. Use LogScalar instead.""" + + def __init__( + self, + logname="r_training", + log_pbar: bool = False, + reward_key: Union[str, tuple] = None, + ): + warnings.warn( + "The 'LogReward' class is deprecated and will be removed in v0.9. Please use 'LogScalar' instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(logname=logname, log_pbar=log_pbar, reward_key=reward_key) + + class RewardNormalizer(TrainerHookBase): """Reward normalizer hook. @@ -1127,7 +1144,7 @@ def register(self, trainer: Trainer, name: str = "batch_subsampler"): trainer.register_module(name, self) -class Recorder(TrainerHookBase): +class LogValidationReward(TrainerHookBase): """Recorder hook for :class:`~torchrl.trainers.Trainer`. Args: @@ -1264,6 +1281,44 @@ def register(self, trainer: Trainer, name: str = "recorder"): ) +class Recorder(LogValidationReward): + """Deprecated class. Use LogValidationReward instead.""" + + def __init__( + self, + *, + record_interval: int, + record_frames: int, + frame_skip: int = 1, + policy_exploration: TensorDictModule, + environment: EnvBase = None, + exploration_type: ExplorationType = ExplorationType.RANDOM, + log_keys: Optional[List[Union[str, Tuple[str]]]] = None, + out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None, + suffix: Optional[str] = None, + log_pbar: bool = False, + recorder: EnvBase = None, + ) -> None: + warnings.warn( + "The 'Recorder' class is deprecated and will be removed in v0.9. Please use 'LogValidationReward' instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__( + record_interval=record_interval, + record_frames=record_frames, + frame_skip=frame_skip, + policy_exploration=policy_exploration, + environment=environment, + exploration_type=exploration_type, + log_keys=log_keys, + out_keys=out_keys, + suffix=suffix, + log_pbar=log_pbar, + recorder=recorder, + ) + + class UpdateWeights(TrainerHookBase): """A collector weights update hook class. diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 906d162f181..70176f9de4a 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -883,12 +883,12 @@ def make_ddpg_actor( # # As the training data is obtained using some exploration strategy, the true # performance of our algorithm needs to be assessed in deterministic mode. We -# do this using a dedicated class, ``Recorder``, which executes the policy in +# do this using a dedicated class, ``LogValidationReward``, which executes the policy in # the environment at a given frequency and returns some statistics obtained # from these simulations. # # The following helper function builds this object: -from torchrl.trainers import Recorder +from torchrl.trainers import LogValidationReward def make_recorder(actor_model_explore, transform_state_dict, record_interval): @@ -899,7 +899,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval): ) # must be instantiated to load the state dict environment.transform[2].load_state_dict(transform_state_dict) - recorder_obj = Recorder( + recorder_obj = LogValidationReward( record_frames=1000, policy_exploration=actor_model_explore, environment=environment, diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 59188ad21f6..a10e8c1169a 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -140,8 +140,8 @@ from torchrl.objectives import DQNLoss, SoftUpdate from torchrl.record.loggers.csv import CSVLogger from torchrl.trainers import ( - LogReward, - Recorder, + LogScalar, + LogValidationReward, ReplayBufferTrainer, Trainer, UpdateWeights, @@ -666,7 +666,7 @@ def get_loss_module(actor, gamma): buffer_hook.register(trainer) weight_updater = UpdateWeights(collector, update_weights_interval=1) weight_updater.register(trainer) -recorder = Recorder( +recorder = LogValidationReward( record_interval=100, # log every 100 optimization steps record_frames=1000, # maximum number of frames in the record frame_skip=1, @@ -704,7 +704,7 @@ def get_loss_module(actor, gamma): # This will be reflected by the `total_rewards` value displayed in the # progress bar. # -log_reward = LogReward(log_pbar=True) +log_reward = LogScalar(log_pbar=True) log_reward.register(trainer) ###############################################################################