Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 19, 2024
2 parents 44278d0 + b415115 commit 2104d06
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 23 deletions.
50 changes: 27 additions & 23 deletions sota-implementations/redq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,55 +21,59 @@
from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.collectors.collectors import DataCollectorBase

from torchrl.data import ReplayBuffer, TensorDictReplayBuffer
from torchrl.data.postprocs import MultiStep
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data import (
LazyMemmapStorage,
MultiStep,
PrioritizedSampler,
RandomSampler,
ReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs import ParallelEnv
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import env_creator, EnvCreator
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
from torchrl.envs import (
CatFrames,
CatTensors,
CenterCrop,
Compose,
DMControlEnv,
DoubleToFloat,
env_creator,
EnvBase,
EnvCreator,
FlattenObservation,
GrayScale,
gSDENoise,
GymEnv,
InitTracker,
NoopResetEnv,
ObservationNorm,
ParallelEnv,
Resize,
RewardScaling,
StepCounter,
ToTensorImage,
TransformedEnv,
VecNorm,
)
from torchrl.envs.transforms.transforms import (
FlattenObservation,
gSDENoise,
InitTracker,
StepCounter,
)
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
ActorCriticOperator,
ActorValueOperator,
DdpgCnnActor,
DdpgCnnQNet,
MLP,
NoisyLinear,
NormalParamExtractor,
ProbabilisticActor,
SafeModule,
SafeSequential,
TanhNormal,
ValueOperator,
)
from torchrl.modules.distributions import TanhNormal
from torchrl.modules.distributions.continuous import SafeTanhTransform
from torchrl.modules.models.exploration import LazygSDEModule
from torchrl.modules.models.models import DdpgCnnActor, DdpgCnnQNet, MLP
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.objectives import HardUpdate, SoftUpdate
from torchrl.objectives.common import LossModule
from torchrl.objectives import HardUpdate, LossModule, SoftUpdate, TargetNetUpdater
from torchrl.objectives.deprecated import REDQLoss_deprecated
from torchrl.objectives.utils import TargetNetUpdater
from torchrl.record.loggers import Logger
from torchrl.record.recorder import VideoRecorder
from torchrl.trainers.helpers import sync_async_collector, sync_sync_collector
Expand Down Expand Up @@ -518,7 +522,7 @@ def make_redq_model(
actor_module = SafeSequential(
actor_module,
SafeModule(
LazygSDEModule(transform=transform),
LazygSDEModule(transform=transform, device=device),
in_keys=["action", gSDE_state_key, "_eps_gSDE"],
out_keys=["loc", "scale", "action", "_eps_gSDE"],
),
Expand Down
1 change: 1 addition & 0 deletions torchrl/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@
hold_out_params,
next_state_value,
SoftUpdate,
TargetNetUpdater,
ValueEstimators,
)

0 comments on commit 2104d06

Please sign in to comment.