From c44b2686f40987da8f4ecc35658f428f86379183 Mon Sep 17 00:00:00 2001 From: Lukas Lopatovsky Date: Mon, 10 Jun 2024 18:42:22 +0200 Subject: [PATCH 1/2] Use single forward pass in shared model --- skrl/utils/model_instantiators/torch/__init__.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/skrl/utils/model_instantiators/torch/__init__.py b/skrl/utils/model_instantiators/torch/__init__.py index 4fbcb588..4d80b521 100644 --- a/skrl/utils/model_instantiators/torch/__init__.py +++ b/skrl/utils/model_instantiators/torch/__init__.py @@ -555,15 +555,18 @@ def act(self, inputs, role): def compute(self, inputs, role): if self.instantiator_input_type == 0: - output = self.net(inputs["states"]) + net_inputs = inputs["states"] elif self.instantiator_input_type == -1: - output = self.net(inputs["taken_actions"]) + net_inputs = inputs["taken_actions"] elif self.instantiator_input_type == -2: - output = self.net(torch.cat((inputs["states"], inputs["taken_actions"]), dim=1)) + net_inputs = torch.cat((inputs["states"], inputs["taken_actions"]), dim=1) if role == self._roles[0]: - return self.instantiator_output_scales[0] * self.mean_net(output), self.log_std_parameter, {} + self.output = self.net(net_inputs) + return self.instantiator_output_scales[0] * self.mean_net(self.output), self.log_std_parameter, {} elif role == self._roles[1]: + output = self.net(net_inputs) if self.output is None else self.output + self.output = None return self.instantiator_output_scales[1] * self.value_net(output), {} # TODO: define the model using the specified structure From d3395c6c5d131ea280f49b20b52838c3be07d701 Mon Sep 17 00:00:00 2001 From: Toni-SM Date: Sat, 15 Jun 2024 12:07:50 -0400 Subject: [PATCH 2/2] Rename cached shared layer/network output --- skrl/utils/model_instantiators/torch/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/skrl/utils/model_instantiators/torch/__init__.py b/skrl/utils/model_instantiators/torch/__init__.py index 4d80b521..89189598 100644 --- a/skrl/utils/model_instantiators/torch/__init__.py +++ b/skrl/utils/model_instantiators/torch/__init__.py @@ -562,12 +562,12 @@ def compute(self, inputs, role): net_inputs = torch.cat((inputs["states"], inputs["taken_actions"]), dim=1) if role == self._roles[0]: - self.output = self.net(net_inputs) - return self.instantiator_output_scales[0] * self.mean_net(self.output), self.log_std_parameter, {} + self._shared_output = self.net(net_inputs) + return self.instantiator_output_scales[0] * self.mean_net(self._shared_output), self.log_std_parameter, {} elif role == self._roles[1]: - output = self.net(net_inputs) if self.output is None else self.output - self.output = None - return self.instantiator_output_scales[1] * self.value_net(output), {} + shared_output = self.net(net_inputs) if self._shared_output is None else self._shared_output + self._shared_output = None + return self.instantiator_output_scales[1] * self.value_net(shared_output), {} # TODO: define the model using the specified structure