Skip to content

Commit

Permalink
add the missing .shape to q value networks
Browse files Browse the repository at this point in the history
Summary: Add the missing .shape to q value networks.

Reviewed By: rodrigodesalvobraz

Differential Revision: D67272397

fbshipit-source-id: 88cf32b9efa43fd9051c8de17fcc82ec3214fbe7
  • Loading branch information
yiwan-rl authored and facebook-github-bot committed Dec 19, 2024
1 parent f35e798 commit 334a08c
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def get_q_values(
q_values = self.forward(x).squeeze(
-1
) # (batch_size, number_of_actions_to_query)
return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)

@property
def state_dim(self) -> int:
Expand Down Expand Up @@ -239,7 +239,7 @@ def get_q_values(
q_values, # (batch_size x num actions x 1)
) # (batch_size x number of query actions x 1)
q_values = q_values.squeeze(-1) # (batch_size x number of query actions)
return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)

@property
def state_dim(self) -> int:
Expand Down Expand Up @@ -324,7 +324,7 @@ def get_q_value_distribution(
q_values = self.forward(
x
) # (batch_size, number_of_actions_to_query, number_of_quantiles)
return q_values if len(action_batch) == 3 else q_values.squeeze(-2)
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-2)

@property
def quantiles(self) -> Tensor:
Expand Down Expand Up @@ -503,7 +503,7 @@ def get_q_values(
state_value + advantage - advantage_mean
) # shape: (batch_size, number of query actions)

return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)


"""
Expand Down Expand Up @@ -590,7 +590,7 @@ def get_q_values(
q_values = self._interaction_features.forward(x).squeeze(
-1
) # (batch_size, number_of_actions_to_query)
return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)

@property
def state_dim(self) -> int:
Expand Down Expand Up @@ -695,7 +695,7 @@ def get_q_values(
q_values = self.forward(x, z=z, persistent=persistent).squeeze(
-1
) # (batch_size, number_of_actions_to_query)
return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)

@property
def state_dim(self) -> int:
Expand Down Expand Up @@ -806,7 +806,7 @@ def get_q_values(
q_values = self._model_fc(x).reshape(
batch_size, num_query_actions
) # (batch_size, number_of_actions_to_query)
return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)

@property
def state_dim(self) -> int:
Expand Down

0 comments on commit 334a08c

Please sign in to comment.