Commit ed204d2
Pass tensordict as an additional input of _precompute_cache
1 parent: 9c5ff00
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 1 changed file with 12 additions and 19 deletions.
@@ -3,7 +3,6 @@
 import torch
 import torch.nn as nn

 from einops import rearrange
 from tensordict import TensorDict
 from torch import Tensor
@@ -89,17 +88,11 @@ def __init__(
         self.use_graph_context = use_graph_context

         # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
-        self.project_node_embeddings = nn.Linear(
-            embedding_dim, 3 * embedding_dim, bias=linear_bias
-        )
-        self.project_fixed_context = nn.Linear(
-            embedding_dim, embedding_dim, bias=linear_bias
-        )
+        self.project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=linear_bias)
+        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=linear_bias)

         # MHA
-        self.logit_attention = LogitAttention(
-            embedding_dim, num_heads, **logit_attn_kwargs
-        )
+        self.logit_attention = LogitAttention(embedding_dim, num_heads, **logit_attn_kwargs)

         self.select_start_nodes_fn = select_start_nodes_fn

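As a side note on what these projections are for: the in-code comment above says each node gets a glimpse key, a glimpse value, and a logit key, which is why the layer maps embedding_dim to 3 * embedding_dim. A minimal, self-contained sketch of that pattern (the sizes below are illustrative, not the repository's defaults):

import torch
import torch.nn as nn

# Illustrative sizes, not taken from the repository
embedding_dim, batch_size, num_nodes = 128, 4, 20

# One linear layer produces all three tensors in a single pass
project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=False)

embeddings = torch.randn(batch_size, num_nodes, embedding_dim)
# chunk(3, dim=-1) splits the last dimension into three equal parts:
# (glimpse key, glimpse value, logit key), each of size embedding_dim
glimpse_key, glimpse_val, logit_key = project_node_embeddings(embeddings).chunk(3, dim=-1)
assert glimpse_key.shape == (batch_size, num_nodes, embedding_dim)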
@@ -143,13 +136,11 @@ def forward(
         else:
             if num_starts is not None:
                 if num_starts > 1:
-                    log.warn(
-                        f"num_starts={num_starts} is ignored for decode_type={decode_type}"
-                    )
+                    log.warn(f"num_starts={num_starts} is ignored for decode_type={decode_type}")
                 num_starts = 0

         # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
-        cached_embeds = self._precompute_cache(embeddings, num_starts=num_starts)
+        cached_embeds = self._precompute_cache(embeddings, td=td, num_starts=num_starts)

         # Collect outputs
         outputs = []
@@ -196,20 +187,24 @@ def forward(

         return outputs, actions, td

-    def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0):
+    def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0, td: TensorDict = None):
[Two comments on this line have been minimized. Junyoungpark (Author, Collaborator) commented on this line; the comment is reproduced below the diff.]
         """Compute the cached embeddings for the attention

         Args:
             embeddings: Precomputed embeddings for the nodes
             num_starts: Number of multi-starts to use. If 0, no multi-start decoding is used
+            td: TensorDict containing the environment state. This one is not used in this class.
+                However, passing the TensorDict can be useful in child classes.
         """

         # The projection of the node embeddings for the attention is calculated once up front
         (
             glimpse_key_fixed,
             glimpse_val_fixed,
             logit_key_fixed,
-        ) = self.project_node_embeddings(embeddings).chunk(3, dim=-1)
+        ) = self.project_node_embeddings(
+            embeddings
+        ).chunk(3, dim=-1)

         # Optionally disable the graph context from the initial embedding as done in POMO
         if self.use_graph_context:
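The new docstring states that td is ignored in this class but may be useful in child classes. A hypothetical sketch of what such a child class could look like; the class names and the "node_feature" key below are made up for illustration and do not come from the repository:

import torch.nn as nn
from tensordict import TensorDict
from torch import Tensor


class _BaseDecoder(nn.Module):
    # Minimal stand-in for the decoder in this file, just enough to show the override.
    def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0, td: TensorDict = None):
        # The real method builds the glimpse/logit projections; td is accepted but unused here.
        return embeddings


class EnvAwareDecoder(_BaseDecoder):
    # Hypothetical child class that actually consumes td.
    def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0, td: TensorDict = None):
        cached = super()._precompute_cache(embeddings, num_starts=num_starts, td=td)
        # Read extra environment state forwarded by the parent, e.g. a per-node feature
        # stored under a hypothetical "node_feature" key
        if td is not None and "node_feature" in td.keys():
            self._extra_context = td["node_feature"]
        return cached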
@@ -267,9 +262,7 @@ def _get_log_p(
         mask = ~td_unbatch["action_mask"]

         # Compute logits
-        log_p = self.logit_attention(
-            glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp
-        )
+        log_p = self.logit_attention(glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp)

         # Now we need to reshape the logits and log_p to [batch_size*num_starts, num_nodes]
         # Note that rearranging order is important here
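For intuition only, here is the common pattern behind a call like logit_attention(glimpse_q, ..., logit_k, mask, softmax_temp): infeasible actions are masked out and the remaining logits are turned into log-probabilities. This is a generic sketch, not the repository's LogitAttention implementation:

import torch
import torch.nn.functional as F

batch_size, num_nodes, embedding_dim = 4, 20, 128  # illustrative sizes
glimpse_q = torch.randn(batch_size, 1, embedding_dim)        # query for the current step
logit_k = torch.randn(batch_size, num_nodes, embedding_dim)  # one logit key per node
mask = torch.zeros(batch_size, num_nodes, dtype=torch.bool)  # True marks infeasible actions
softmax_temp = 1.0

# Scaled dot-product scores between the query and the logit keys
logits = (glimpse_q @ logit_k.transpose(-2, -1)).squeeze(1) / embedding_dim**0.5
logits = logits.masked_fill(mask, float("-inf"))  # forbid masked actions
log_p = F.log_softmax(logits / softmax_temp, dim=-1)  # shape [batch_size, num_nodes]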
Comment from Junyoungpark (Author, Collaborator) on the new _precompute_cache signature:

I see that here you are passing a td (I guess this instead of the PrecomputedCache?), but no TensorDict is used here. Is this the expected behavior?