From 334a08c6f9240fd93cbcd68b594c6bdc78f2952e Mon Sep 17 00:00:00 2001 From: Yi Wan Date: Thu, 19 Dec 2024 14:13:52 -0800 Subject: [PATCH] add the missing .shape to q value networks Summary: Add the missing .shape to the rank checks at the end of the q value networks' get_q_values / get_q_value_distribution methods. len(action_batch) returns the size of the tensor's first dimension rather than its number of dimensions, so the `== 3` checks now use len(action_batch.shape). Reviewed By: rodrigodesalvobraz Differential Revision: D67272397 fbshipit-source-id: 88cf32b9efa43fd9051c8de17fcc82ec3214fbe7 --- .../sequential_decision_making/q_value_networks.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pearl/neural_networks/sequential_decision_making/q_value_networks.py b/pearl/neural_networks/sequential_decision_making/q_value_networks.py index bdb4425..07d5edf 100644 --- a/pearl/neural_networks/sequential_decision_making/q_value_networks.py +++ b/pearl/neural_networks/sequential_decision_making/q_value_networks.py @@ -172,7 +172,7 @@ def get_q_values( q_values = self.forward(x).squeeze( -1 ) # (batch_size, number_of_actions_to_query) - return q_values if len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) @property def state_dim(self) -> int: @@ -239,7 +239,7 @@ def get_q_values( q_values, # (batch_size x num actions x 1) ) # (batch_size x number of query actions x 1) q_values = q_values.squeeze(-1) # (batch_size x number of query actions) - return q_values if len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) @property def state_dim(self) -> int: @@ -324,7 +324,7 @@ def get_q_value_distribution( q_values = self.forward( x ) # (batch_size, number_of_actions_to_query, number_of_quantiles) - return q_values if len(action_batch) == 3 else q_values.squeeze(-2) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-2) @property def quantiles(self) -> Tensor: @@ -503,7 +503,7 @@ def get_q_values( state_value + advantage - advantage_mean ) # shape: (batch_size, number of query actions) - return q_values if 
len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) """ @@ -590,7 +590,7 @@ def get_q_values( q_values = self._interaction_features.forward(x).squeeze( -1 ) # (batch_size, number_of_actions_to_query) - return q_values if len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) @property def state_dim(self) -> int: @@ -695,7 +695,7 @@ def get_q_values( q_values = self.forward(x, z=z, persistent=persistent).squeeze( -1 ) # (batch_size, number_of_actions_to_query) - return q_values if len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) @property def state_dim(self) -> int: @@ -806,7 +806,7 @@ def get_q_values( q_values = self._model_fc(x).reshape( batch_size, num_query_actions ) # (batch_size, number_of_actions_to_query) - return q_values if len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) @property def state_dim(self) -> int: