Pass tensordict as an additional input of _precompute_cache
Junyoungpark committed Sep 10, 2023
1 parent 9c5ff00 commit ed204d2
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions rl4co/models/zoo/common/autoregressive/decoder.py
@@ -3,7 +3,6 @@
 
 import torch
 import torch.nn as nn
-
 from einops import rearrange
 from tensordict import TensorDict
 from torch import Tensor
@@ -89,17 +88,11 @@ def __init__(
         self.use_graph_context = use_graph_context
 
         # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
-        self.project_node_embeddings = nn.Linear(
-            embedding_dim, 3 * embedding_dim, bias=linear_bias
-        )
-        self.project_fixed_context = nn.Linear(
-            embedding_dim, embedding_dim, bias=linear_bias
-        )
+        self.project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=linear_bias)
+        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=linear_bias)
 
         # MHA
-        self.logit_attention = LogitAttention(
-            embedding_dim, num_heads, **logit_attn_kwargs
-        )
+        self.logit_attention = LogitAttention(embedding_dim, num_heads, **logit_attn_kwargs)
 
         self.select_start_nodes_fn = select_start_nodes_fn

@@ -143,13 +136,11 @@ def forward(
         else:
             if num_starts is not None:
                 if num_starts > 1:
-                    log.warn(
-                        f"num_starts={num_starts} is ignored for decode_type={decode_type}"
-                    )
+                    log.warn(f"num_starts={num_starts} is ignored for decode_type={decode_type}")
                 num_starts = 0
 
         # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
-        cached_embeds = self._precompute_cache(embeddings, num_starts=num_starts)
+        cached_embeds = self._precompute_cache(embeddings, td=td, num_starts=num_starts)
 
         # Collect outputs
         outputs = []
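The comment in this hunk states the key invariant: the glimpse key/value and logit key depend only on the node embeddings, so they are projected once before the rollout and reused at every decoding step. A generic sketch of that compute-once, reuse-per-step pattern (illustrative names and shapes, not rl4co code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d = 16
proj = nn.Linear(d, d, bias=False)
embeddings = torch.randn(2, 10, d)      # [batch, num_nodes, d]

logit_key = proj(embeddings)            # projected once, before the loop
for step in range(3):                   # reused at every decoding step
    query = torch.randn(2, 1, d)        # in practice, a step-dependent context
    logits = (query @ logit_key.transpose(-2, -1)).squeeze(1)
    log_p = F.log_softmax(logits, dim=-1)   # [batch, num_nodes]
```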
@@ -196,20 +187,24 @@ def forward(
 
         return outputs, actions, td
 
-    def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0):
+    def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0, td: TensorDict = None):

@fedebotu (Member) commented Sep 11, 2023:

I see that here you are passing a td (I guess instead of the PrecomputedCache?), but no TensorDict is used here. Is this the expected behavior?

@Junyoungpark (Author, Collaborator) replied Sep 11, 2023:

Yes, this is expected behavior. The parameter keeps the input signature of this method uniform: some applications (e.g., mTSP) might need metadata (e.g., the number of agents) while preparing the precomputed values.
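As the reply above suggests, the extra td argument only pays off in child classes. Below is a minimal sketch of that pattern; the stand-in base class, the subclass name, and the num_agents key are all assumptions for illustration, not code from this repository:

```python
from tensordict import TensorDict
from torch import Tensor


class BaseDecoder:
    """Stand-in for the decoder in this file; accepts td but ignores it."""

    def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0, td: TensorDict = None):
        return embeddings  # placeholder for the real precomputed cache


class MultiAgentDecoder(BaseDecoder):
    """Hypothetical child class that actually reads metadata from td."""

    def _precompute_cache(self, embeddings: Tensor, num_starts: int = 0, td: TensorDict = None):
        num_agents = td["num_agents"]  # assumed key, e.g. for mTSP
        cache = super()._precompute_cache(embeddings, num_starts=num_starts, td=td)
        # ... adapt the cache using num_agents as needed ...
        return cache
```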

"""Compute the cached embeddings for the attention
Args:
embeddings: Precomputed embeddings for the nodes
num_starts: Number of multi-starts to use. If 0, no multi-start decoding is used
td: TensorDict containing the environment state.
This one is not used in this class. However, passing Tensordict can be useful in child classes.
"""

# The projection of the node embeddings for the attention is calculated once up front
(
glimpse_key_fixed,
glimpse_val_fixed,
logit_key_fixed,
) = self.project_node_embeddings(embeddings).chunk(3, dim=-1)
) = self.project_node_embeddings(
embeddings
).chunk(3, dim=-1)

# Optionally disable the graph context from the initial embedding as done in POMO
if self.use_graph_context:
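The fused projection above is a standard trick: one matmul produces the glimpse key, glimpse value, and logit key for every node, and chunk splits the result along the last dimension. A standalone sketch of the pattern (shapes chosen arbitrarily for illustration):

```python
import torch
import torch.nn as nn

embedding_dim = 128
project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=False)

embeddings = torch.randn(2, 50, embedding_dim)   # [batch, num_nodes, embedding_dim]
glimpse_key, glimpse_val, logit_key = project_node_embeddings(embeddings).chunk(3, dim=-1)
print(glimpse_key.shape)  # torch.Size([2, 50, 128]) -- one slice per attention role
```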
@@ -267,9 +262,7 @@ def _get_log_p(
         mask = ~td_unbatch["action_mask"]
 
         # Compute logits
-        log_p = self.logit_attention(
-            glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp
-        )
+        log_p = self.logit_attention(glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp)
 
         # Now we need to reshape the logits and log_p to [batch_size*num_starts, num_nodes]
         # Note that rearranging order is important here
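For reference, the mask here marks infeasible actions (it is the negation of action_mask), and masked logits are driven to -inf so they receive zero probability after the softmax. A generic sketch of masked pointer-style logits; this illustrates the idea only and is not rl4co's LogitAttention implementation:

```python
import torch
import torch.nn.functional as F

batch, num_nodes, d = 2, 5, 16
glimpse_q = torch.randn(batch, 1, d)               # query for the current step
logit_k = torch.randn(batch, num_nodes, d)         # one logit key per node
action_mask = torch.rand(batch, num_nodes) > 0.3   # True = feasible action

logits = (glimpse_q @ logit_k.transpose(-2, -1)).squeeze(1) / d**0.5
logits = logits.masked_fill(~action_mask, float("-inf"))  # mask = ~action_mask
log_p = F.log_softmax(logits, dim=-1)              # infeasible nodes get -inf
```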