Skip to content

Commit

Permalink
review changes add doctests for tensorclass | merging to main
Browse files Browse the repository at this point in the history
  • Loading branch information
SandishKumarHN committed Feb 17, 2024
1 parent 387953f commit 7d6e08f
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 13 deletions.
3 changes: 0 additions & 3 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -6693,11 +6693,8 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti
value,
loss_critic_type="l2",
functional=functional,
<<<<<<< HEAD
return_tensorclass=False,
=======
reduction=reduction,
>>>>>>> upstream/main
)

# Check error is raised when actions require grads
Expand Down
3 changes: 0 additions & 3 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,8 @@ def __init__(
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
<<<<<<< HEAD
return_tensorclass: bool = False,
=======
reduction: str = None,
>>>>>>> upstream/main
):
if actor is not None:
actor_network = actor
Expand Down
7 changes: 0 additions & 7 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,8 @@ def __init__(
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
<<<<<<< HEAD
return_tensorclass: bool = False,
=======
reduction: str = None,
>>>>>>> upstream/main
):
if actor is not None:
actor_network = actor
Expand Down Expand Up @@ -585,19 +582,15 @@ def forward(self, tensordict: TensorDictBase) -> PPOLosses:
td_out.set("entropy", entropy.detach()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef:
<<<<<<< HEAD
loss_critic = self.loss_critic(tensordict).mean()
td_out.set("loss_critic", loss_critic.mean())
if self.return_tensorclass:
return PPOLosses._from_tensordict(td_out)
=======
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
td_out = td_out.apply(
functools.partial(_reduce, reduction=self.reduction), batch_size=[]
)

>>>>>>> upstream/main
return td_out

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
Expand Down

0 comments on commit 7d6e08f

Please sign in to comment.