diff --git a/pearl/neural_networks/sequential_decision_making/q_value_networks.py b/pearl/neural_networks/sequential_decision_making/q_value_networks.py index bdb4425..07d5edf 100644 --- a/pearl/neural_networks/sequential_decision_making/q_value_networks.py +++ b/pearl/neural_networks/sequential_decision_making/q_value_networks.py @@ -172,7 +172,7 @@ def get_q_values( q_values = self.forward(x).squeeze( -1 ) # (batch_size, number_of_actions_to_query) - return q_values if len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) @property def state_dim(self) -> int: @@ -239,7 +239,7 @@ def get_q_values( q_values, # (batch_size x num actions x 1) ) # (batch_size x number of query actions x 1) q_values = q_values.squeeze(-1) # (batch_size x number of query actions) - return q_values if len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) @property def state_dim(self) -> int: @@ -324,7 +324,7 @@ def get_q_value_distribution( q_values = self.forward( x ) # (batch_size, number_of_actions_to_query, number_of_quantiles) - return q_values if len(action_batch) == 3 else q_values.squeeze(-2) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-2) @property def quantiles(self) -> Tensor: @@ -503,7 +503,7 @@ def get_q_values( state_value + advantage - advantage_mean ) # shape: (batch_size, number of query actions) - return q_values if len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) """ @@ -590,7 +590,7 @@ def get_q_values( q_values = self._interaction_features.forward(x).squeeze( -1 ) # (batch_size, number_of_actions_to_query) - return q_values if len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) @property def state_dim(self) -> int: @@ -695,7 +695,7 @@ def get_q_values( q_values = self.forward(x, z=z, persistent=persistent).squeeze( -1 ) # (batch_size, number_of_actions_to_query) - return q_values if len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) @property def state_dim(self) -> int: @@ -806,7 +806,7 @@ def get_q_values( q_values = self._model_fc(x).reshape( batch_size, num_query_actions ) # (batch_size, number_of_actions_to_query) - return q_values if len(action_batch) == 3 else q_values.squeeze(-1) + return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1) @property def state_dim(self) -> int: