diff --git a/skrl/utils/model_instantiators/torch/__init__.py b/skrl/utils/model_instantiators/torch/__init__.py index 4fbcb588..89189598 100644 --- a/skrl/utils/model_instantiators/torch/__init__.py +++ b/skrl/utils/model_instantiators/torch/__init__.py @@ -555,16 +555,19 @@ def act(self, inputs, role): def compute(self, inputs, role): if self.instantiator_input_type == 0: - output = self.net(inputs["states"]) + net_inputs = inputs["states"] elif self.instantiator_input_type == -1: - output = self.net(inputs["taken_actions"]) + net_inputs = inputs["taken_actions"] elif self.instantiator_input_type == -2: - output = self.net(torch.cat((inputs["states"], inputs["taken_actions"]), dim=1)) + net_inputs = torch.cat((inputs["states"], inputs["taken_actions"]), dim=1) if role == self._roles[0]: - return self.instantiator_output_scales[0] * self.mean_net(output), self.log_std_parameter, {} + self._shared_output = self.net(net_inputs) + return self.instantiator_output_scales[0] * self.mean_net(self._shared_output), self.log_std_parameter, {} elif role == self._roles[1]: - return self.instantiator_output_scales[1] * self.value_net(output), {} + shared_output = self.net(net_inputs) if self._shared_output is None else self._shared_output + self._shared_output = None + return self.instantiator_output_scales[1] * self.value_net(shared_output), {} # TODO: define the model using the specified structure