Skip to content

Commit

Permalink
[Refactor] Rename Recorder and LogReward : Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
raresdan authored and vmoens committed Dec 3, 2024
1 parent 56f5c3e commit cb0d890
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 30 deletions.
9 changes: 4 additions & 5 deletions test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
TensorDictReplayBuffer,
)
from torchrl.envs.libs.gym import _has_gym
from torchrl.trainers import Trainer, LogValidationReward
from torchrl.trainers import LogValidationReward, Trainer
from torchrl.trainers.helpers import transformed_env_constructor
from torchrl.trainers.trainers import (
_has_tqdm,
Expand Down Expand Up @@ -922,10 +922,9 @@ def test_recorder_load(self, backend, N=8):
LogValidationReward.state_dict, Recorder_state_dict = _fun_checker(
LogValidationReward.state_dict, state_dict_has_been_called
)
(
LogValidationReward.load_state_dict,
Recorder_load_state_dict,
) = _fun_checker(LogValidationReward.load_state_dict, load_state_dict_has_been_called)
(LogValidationReward.load_state_dict, Recorder_load_state_dict,) = _fun_checker(
LogValidationReward.load_state_dict, load_state_dict_has_been_called
)

args = self._get_args()

Expand Down
4 changes: 3 additions & 1 deletion torchrl/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
BatchSubSampler,
ClearCudaCache,
CountFramesLog,
LogReward,
LogScalar,
LogValidationReward,
mask_batch,
OptimizerHook,
LogValidationReward,
Recorder,
ReplayBufferTrainer,
RewardNormalizer,
SelectKeys,
Expand Down
56 changes: 32 additions & 24 deletions torchrl/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,16 +872,19 @@ def register(self, trainer: Trainer, name: str = "log_reward"):

class LogReward(LogScalar):
"""Deprecated class. Use LogScalar instead."""
def __init__(self,
logname="r_training",
log_pbar: bool = False,
reward_key: Union[str, tuple] = None):
warnings.warn("The 'LogReward' class is deprecated and will be removed in a future release. Please use 'LogScalar' instead.",
DeprecationWarning,
stacklevel=2)
super().__init__(logname=logname,
log_pbar=log_pbar,
reward_key=reward_key)

def __init__(
self,
logname="r_training",
log_pbar: bool = False,
reward_key: Union[str, tuple] = None,
):
warnings.warn(
"The 'LogReward' class is deprecated and will be removed in v0.9. Please use 'LogScalar' instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(logname=logname, log_pbar=log_pbar, reward_key=reward_key)


class RewardNormalizer(TrainerHookBase):
Expand Down Expand Up @@ -1280,6 +1283,7 @@ def register(self, trainer: Trainer, name: str = "recorder"):

class Recorder(LogValidationReward):
"""Deprecated class. Use LogValidationReward instead."""

def __init__(
self,
*,
Expand All @@ -1295,20 +1299,24 @@ def __init__(
log_pbar: bool = False,
recorder: EnvBase = None,
) -> None:
warnings.warn("The 'Recorder' class is deprecated and will be removed in a future release. Please use 'LogValidationReward' instead.",
DeprecationWarning,
stacklevel=2)
super().__init__(record_interval=record_interval,
record_frames=record_frames,
frame_skip=frame_skip,
policy_exploration=policy_exploration,
environment=environment,
exploration_type=exploration_type,
log_keys=log_keys,
out_keys=out_keys,
suffix=suffix,
log_pbar=log_pbar,
recorder=recorder)
warnings.warn(
"The 'Recorder' class is deprecated and will be removed in v0.9. Please use 'LogValidationReward' instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(
record_interval=record_interval,
record_frames=record_frames,
frame_skip=frame_skip,
policy_exploration=policy_exploration,
environment=environment,
exploration_type=exploration_type,
log_keys=log_keys,
out_keys=out_keys,
suffix=suffix,
log_pbar=log_pbar,
recorder=recorder,
)


class UpdateWeights(TrainerHookBase):
Expand Down

0 comments on commit cb0d890

Please sign in to comment.