From b415115490c6b287cbe1b4c560da1973145089bf Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 19 Nov 2024 12:55:20 +0000 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/redq/utils.py | 50 ++++++++++++++++-------------- torchrl/objectives/__init__.py | 1 + 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 8312d359366..b08bcd457cc 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -21,55 +21,59 @@ from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.collectors.collectors import DataCollectorBase -from torchrl.data import ReplayBuffer, TensorDictReplayBuffer -from torchrl.data.postprocs import MultiStep -from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler -from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.data import ( + LazyMemmapStorage, + MultiStep, + PrioritizedSampler, + RandomSampler, + ReplayBuffer, + TensorDictReplayBuffer, +) from torchrl.data.utils import DEVICE_TYPING -from torchrl.envs import ParallelEnv -from torchrl.envs.common import EnvBase -from torchrl.envs.env_creator import env_creator, EnvCreator -from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.transforms import ( +from torchrl.envs import ( CatFrames, CatTensors, CenterCrop, Compose, + DMControlEnv, DoubleToFloat, + env_creator, + EnvBase, + EnvCreator, + FlattenObservation, GrayScale, + gSDENoise, + GymEnv, + InitTracker, NoopResetEnv, ObservationNorm, + ParallelEnv, Resize, RewardScaling, + StepCounter, ToTensorImage, TransformedEnv, VecNorm, ) -from torchrl.envs.transforms.transforms import ( - FlattenObservation, - gSDENoise, - InitTracker, - StepCounter, -) from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( ActorCriticOperator, ActorValueOperator, + DdpgCnnActor, + DdpgCnnQNet, + MLP, NoisyLinear, NormalParamExtractor, + ProbabilisticActor, SafeModule, SafeSequential, + TanhNormal, + ValueOperator, ) -from torchrl.modules.distributions import TanhNormal from torchrl.modules.distributions.continuous import SafeTanhTransform from torchrl.modules.models.exploration import LazygSDEModule -from torchrl.modules.models.models import DdpgCnnActor, DdpgCnnQNet, MLP -from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator -from torchrl.objectives import HardUpdate, SoftUpdate -from torchrl.objectives.common import LossModule +from torchrl.objectives import HardUpdate, LossModule, SoftUpdate, TargetNetUpdater from torchrl.objectives.deprecated import REDQLoss_deprecated -from torchrl.objectives.utils import TargetNetUpdater from torchrl.record.loggers import Logger from torchrl.record.recorder import VideoRecorder from torchrl.trainers.helpers import sync_async_collector, sync_sync_collector @@ -518,7 +522,7 @@ def make_redq_model( actor_module = SafeSequential( actor_module, SafeModule( - LazygSDEModule(transform=transform), + LazygSDEModule(transform=transform, device=device), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], ), diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 01f993e629a..f8f5636db95 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -29,5 +29,6 @@ hold_out_params, next_state_value, SoftUpdate, + TargetNetUpdater, ValueEstimators, )