diff --git a/test/test_cost.py b/test/test_cost.py index 949a8a64bef..d7d5e762313 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6693,11 +6693,8 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti value, loss_critic_type="l2", functional=functional, -<<<<<<< HEAD return_tensorclass=False, -======= reduction=reduction, ->>>>>>> upstream/main ) # Check error is raised when actions require grads diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 071797b5f37..8b3d48247e8 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -271,11 +271,8 @@ def __init__( functional: bool = True, actor: ProbabilisticTensorDictSequential = None, critic: ProbabilisticTensorDictSequential = None, -<<<<<<< HEAD return_tensorclass: bool = False, -======= reduction: str = None, ->>>>>>> upstream/main ): if actor is not None: actor_network = actor diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 6de2bae2356..07dddb0dfe0 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -317,11 +317,8 @@ def __init__( functional: bool = True, actor: ProbabilisticTensorDictSequential = None, critic: ProbabilisticTensorDictSequential = None, -<<<<<<< HEAD return_tensorclass: bool = False, -======= reduction: str = None, ->>>>>>> upstream/main ): if actor is not None: actor_network = actor @@ -585,19 +582,15 @@ def forward(self, tensordict: TensorDictBase) -> PPOLosses: td_out.set("entropy", entropy.detach()) # for logging td_out.set("loss_entropy", -self.entropy_coef * entropy) if self.critic_coef: -<<<<<<< HEAD loss_critic = self.loss_critic(tensordict).mean() td_out.set("loss_critic", loss_critic.mean()) if self.return_tensorclass: return PPOLosses._from_tensordict(td_out) -======= loss_critic = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) td_out = td_out.apply( functools.partial(_reduce, reduction=self.reduction), batch_size=[] ) - ->>>>>>> upstream/main return td_out def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):