Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
2 parents 880d0bf + c24bc73 commit e1183df
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions sota-implementations/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch.optim
from tensordict.nn import InteractionType, TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch.distributions import Categorical

from torchrl.collectors import SyncDataCollector
from torchrl.data import (
Expand All @@ -36,15 +37,13 @@
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
MLP,
OneHotCategorical,
ProbabilisticActor,
SafeModule,
TanhNormal,
ValueOperator,
)
from torchrl.objectives import DiscreteIQLLoss, HardUpdate, IQLLoss, SoftUpdate
from torchrl.record import VideoRecorder

from torchrl.trainers.helpers.models import ACTIVATIONS


Expand All @@ -58,7 +57,11 @@ def env_maker(cfg, device="cpu", from_pixels=False):
if lib in ("gym", "gymnasium"):
with set_gym_backend(lib):
return GymEnv(
cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False
cfg.env.name,
device=device,
from_pixels=from_pixels,
pixels_only=False,
categorical_action_encoding=True,
)
elif lib == "dm_control":
env = DMControlEnv(
Expand Down Expand Up @@ -221,8 +224,8 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"):
spec=action_spec,
distribution_class=TanhNormal,
distribution_kwargs={
"low": action_spec.space.low,
"high": action_spec.space.high,
"low": action_spec.space.low.to(device),
"high": action_spec.space.high.to(device),
"tanh_loc": False,
},
default_interaction_type=ExplorationType.RANDOM,
Expand Down Expand Up @@ -318,7 +321,7 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device):
module=actor_module,
in_keys=["logits"],
out_keys=["action"],
distribution_class=OneHotCategorical,
distribution_class=Categorical,
distribution_kwargs={},
default_interaction_type=InteractionType.RANDOM,
return_log_prob=False,
Expand Down

0 comments on commit e1183df

Please sign in to comment.