Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
2 parents cefe9ef + e0f702f commit 459ce5f
Showing 1 changed file with 20 additions and 27 deletions.
47 changes: 20 additions & 27 deletions sota-implementations/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,27 +298,24 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device):
"""Make discrete IQL agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.action_spec_unbatched
# Define Actor Network
in_keys = ["observation"]

actor_net_kwargs = {
"num_cells": cfg.model.hidden_sizes,
"out_features": action_spec.shape[-1],
"activation_class": ACTIVATIONS[cfg.model.activation],
}

actor_net = MLP(**actor_net_kwargs)
actor_net = MLP(
num_cells=cfg.model.hidden_sizes,
out_features=action_spec.space.n,
activation_class=ACTIVATIONS[cfg.model.activation],
device=device,
)

actor_module = SafeModule(
module=actor_net,
in_keys=in_keys,
out_keys=["logits"],
)
actor = ProbabilisticActor(
spec=Composite(action=eval_env.action_spec),
spec=Composite(action=eval_env.action_spec_unbatched).to(device),
module=actor_module,
in_keys=["logits"],
out_keys=["action"],
Expand All @@ -329,42 +326,38 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device):
)

# Define Critic Network
qvalue_net_kwargs = {
"num_cells": cfg.model.hidden_sizes,
"out_features": action_spec.shape[-1],
"activation_class": ACTIVATIONS[cfg.model.activation],
}
qvalue_net = MLP(
**qvalue_net_kwargs,
num_cells=cfg.model.hidden_sizes,
out_features=action_spec.space.n,
activation_class=ACTIVATIONS[cfg.model.activation],
device=device,
)

qvalue = TensorDictModule(
in_keys=["observation"],
out_keys=["state_action_value"],
module=qvalue_net,
)

# Define Value Network
value_net_kwargs = {
"num_cells": cfg.model.hidden_sizes,
"out_features": 1,
"activation_class": ACTIVATIONS[cfg.model.activation],
}
value_net = MLP(**value_net_kwargs)
value_net = MLP(
num_cells=cfg.model.hidden_sizes,
out_features=1,
activation_class=ACTIVATIONS[cfg.model.activation],
device=device,
)
value_net = TensorDictModule(
in_keys=["observation"],
out_keys=["state_value"],
module=value_net,
)

model = torch.nn.ModuleList([actor, qvalue, value_net]).to(device)
model = torch.nn.ModuleList([actor, qvalue, value_net])
# init nets
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
td = eval_env.reset()
td = eval_env.fake_tensordict()
td = td.to(device)
for net in model:
net(td)
del td
eval_env.close()

return model
Expand Down

0 comments on commit 459ce5f

Please sign in to comment.