Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 20, 2024
2 parents 7a6fb05 + 473ae80 commit 8053e0a
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 70 deletions.
1 change: 0 additions & 1 deletion sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
"low": torch.as_tensor(action_spec.space.low, device=device),
"high": torch.as_tensor(action_spec.space.high, device=device),
"tanh_loc": NonTensorData(False),
"safe_tanh": NonTensorData(not cfg.compile.compile),
}
),
no_convert=True,
Expand Down
133 changes: 64 additions & 69 deletions torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,16 +610,15 @@ def filter_and_repeat(name, x):
tensordict = data.named_apply(
filter_and_repeat, batch_size=batch_size, filter_empty=True
)
with torch.no_grad():
with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module(
self.actor_network
):
dist = self.actor_network.get_dist(tensordict)
action = dist.rsample()
tensordict.set(self.tensor_keys.action, action)
sample_log_prob = dist.log_prob(action)
# tensordict.del_("loc")
# tensordict.del_("scale")
with set_exploration_type(ExplorationType.RANDOM), actor_params.data.to_module(
self.actor_network
):
dist = self.actor_network.get_dist(tensordict)
action = dist.rsample()
tensordict.set(self.tensor_keys.action, action)
sample_log_prob = dist.log_prob(action)
# tensordict.del_("loc")
# tensordict.del_("scale")

return (
tensordict.select(
Expand All @@ -631,59 +630,58 @@ def filter_and_repeat(name, x):
def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
tensordict = tensordict.clone(False)
# get actions and log-probs
with torch.no_grad():
with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module(
self.actor_network
with set_exploration_type(ExplorationType.RANDOM), actor_params.data.to_module(
self.actor_network
):
next_tensordict = tensordict.get("next").clone(False)
next_dist = self.actor_network.get_dist(next_tensordict)
next_action = next_dist.rsample()
next_tensordict.set(self.tensor_keys.action, next_action)
next_sample_log_prob = next_dist.log_prob(next_action)

# get q-values
if not self.max_q_backup:
next_tensordict_expand = self._vmap_qvalue_networkN0(
next_tensordict, qval_params.data
)
next_state_value = next_tensordict_expand.get(
self.tensor_keys.state_action_value
).min(0)[0]
if (
next_state_value.shape[-len(next_sample_log_prob.shape) :]
!= next_sample_log_prob.shape
):
next_tensordict = tensordict.get("next").clone(False)
next_dist = self.actor_network.get_dist(next_tensordict)
next_action = next_dist.rsample()
next_tensordict.set(self.tensor_keys.action, next_action)
next_sample_log_prob = next_dist.log_prob(next_action)

# get q-values
if not self.max_q_backup:
next_tensordict_expand = self._vmap_qvalue_networkN0(
next_tensordict, qval_params
)
next_state_value = next_tensordict_expand.get(
self.tensor_keys.state_action_value
).min(0)[0]
if (
next_state_value.shape[-len(next_sample_log_prob.shape) :]
!= next_sample_log_prob.shape
):
next_sample_log_prob = next_sample_log_prob.unsqueeze(-1)
if not self.deterministic_backup:
next_state_value = next_state_value - _alpha * next_sample_log_prob

if self.max_q_backup:
next_tensordict, _ = self._get_policy_actions(
tensordict.get("next").copy(),
actor_params,
num_actions=self.num_random,
)
next_tensordict_expand = self._vmap_qvalue_networkN0(
next_tensordict, qval_params
)
next_sample_log_prob = next_sample_log_prob.unsqueeze(-1)
if not self.deterministic_backup:
next_state_value = next_state_value - _alpha * next_sample_log_prob

if self.max_q_backup:
next_tensordict, _ = self._get_policy_actions(
tensordict.get("next").copy(),
actor_params,
num_actions=self.num_random,
)
next_tensordict_expand = self._vmap_qvalue_networkN0(
next_tensordict, qval_params.data
)

state_action_value = next_tensordict_expand.get(
self.tensor_keys.state_action_value
state_action_value = next_tensordict_expand.get(
self.tensor_keys.state_action_value
)
# take max over actions
state_action_value = state_action_value.reshape(
torch.Size(
[self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
)
# take max over actions
state_action_value = state_action_value.reshape(
torch.Size(
[self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
)
).max(-2)[0]
# take min over qvalue nets
next_state_value = state_action_value.min(0)[0]
).max(-2)[0]
# take min over qvalue nets
next_state_value = state_action_value.min(0)[0]

tensordict.set(
("next", self.value_estimator.tensor_keys.value), next_state_value
)
target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
return target_value
tensordict.set(
("next", self.value_estimator.tensor_keys.value), next_state_value
)
target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
return target_value

def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
# we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first.
Expand Down Expand Up @@ -897,8 +895,7 @@ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
def _alpha(self):
if self.min_log_alpha is not None:
self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
with torch.no_grad():
alpha = self.log_alpha.exp()
alpha = self.log_alpha.data.exp()
return alpha


Expand Down Expand Up @@ -1188,14 +1185,12 @@ def value_loss(
pred_val_index = (pred_val * action).sum(-1)

# calculate target value
with torch.no_grad():
target_value = self.value_estimator.value_estimate(
td_copy, params=self._cached_detached_target_value_params
).squeeze(-1)

with torch.no_grad():
td_error = (pred_val_index - target_value).pow(2)
td_error = td_error.unsqueeze(-1)
target_value = self.value_estimator.value_estimate(
td_copy, params=self._cached_detached_target_value_params
).squeeze(-1)

td_error = (pred_val_index - target_value).pow(2)
td_error = td_error.unsqueeze(-1)

tensordict.set(
self.tensor_keys.priority,
Expand Down

0 comments on commit 8053e0a

Please sign in to comment.