diff --git a/test/test_trainer.py b/test/test_trainer.py index 97868957f4c..caae5bbe178 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -35,7 +35,7 @@ TensorDictReplayBuffer, ) from torchrl.envs.libs.gym import _has_gym -from torchrl.trainers import Trainer, LogValidationReward +from torchrl.trainers import LogValidationReward, Trainer from torchrl.trainers.helpers import transformed_env_constructor from torchrl.trainers.trainers import ( _has_tqdm, @@ -922,10 +922,9 @@ def test_recorder_load(self, backend, N=8): LogValidationReward.state_dict, Recorder_state_dict = _fun_checker( LogValidationReward.state_dict, state_dict_has_been_called ) - ( - LogValidationReward.load_state_dict, - Recorder_load_state_dict, - ) = _fun_checker(LogValidationReward.load_state_dict, load_state_dict_has_been_called) + (LogValidationReward.load_state_dict, Recorder_load_state_dict,) = _fun_checker( + LogValidationReward.load_state_dict, load_state_dict_has_been_called + ) args = self._get_args() diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 824f0dfa848..9d593d64f17 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -7,10 +7,12 @@ BatchSubSampler, ClearCudaCache, CountFramesLog, + LogReward, LogScalar, + LogValidationReward, mask_batch, OptimizerHook, - LogValidationReward, + Recorder, ReplayBufferTrainer, RewardNormalizer, SelectKeys, diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index d0b04b92ad0..83bd050ef96 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -872,16 +872,19 @@ def register(self, trainer: Trainer, name: str = "log_reward"): class LogReward(LogScalar): """Deprecated class. Use LogScalar instead.""" - def __init__(self, - logname="r_training", - log_pbar: bool = False, - reward_key: Union[str, tuple] = None): - warnings.warn("The 'LogReward' class is deprecated and will be removed in a future release. Please use 'LogScalar' instead.", - DeprecationWarning, - stacklevel=2) - super().__init__(logname=logname, - log_pbar=log_pbar, - reward_key=reward_key) + + def __init__( + self, + logname="r_training", + log_pbar: bool = False, + reward_key: Union[str, tuple] = None, + ): + warnings.warn( + "The 'LogReward' class is deprecated and will be removed in v0.9. Please use 'LogScalar' instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(logname=logname, log_pbar=log_pbar, reward_key=reward_key) class RewardNormalizer(TrainerHookBase): @@ -1280,6 +1283,7 @@ def register(self, trainer: Trainer, name: str = "recorder"): class Recorder(LogValidationReward): """Deprecated class. Use LogValidationReward instead.""" + def __init__( self, *, @@ -1295,20 +1299,24 @@ def __init__( log_pbar: bool = False, recorder: EnvBase = None, ) -> None: - warnings.warn("The 'Recorder' class is deprecated and will be removed in a future release. Please use 'LogValidationReward' instead.", - DeprecationWarning, - stacklevel=2) - super().__init__(record_interval=record_interval, - record_frames=record_frames, - frame_skip=frame_skip, - policy_exploration=policy_exploration, - environment=environment, - exploration_type=exploration_type, - log_keys=log_keys, - out_keys=out_keys, - suffix=suffix, - log_pbar=log_pbar, - recorder=recorder) + warnings.warn( + "The 'Recorder' class is deprecated and will be removed in v0.9. Please use 'LogValidationReward' instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__( + record_interval=record_interval, + record_frames=record_frames, + frame_skip=frame_skip, + policy_exploration=policy_exploration, + environment=environment, + exploration_type=exploration_type, + log_keys=log_keys, + out_keys=out_keys, + suffix=suffix, + log_pbar=log_pbar, + recorder=recorder, + ) class UpdateWeights(TrainerHookBase):