I am eager to integrate this work into my own project. However, I have a question regarding certain parts of the code and would greatly appreciate the author's assistance.
When calculating the logits of the Q-values (in `def get_q_values`), why is `torch.roll` used to shift the action bin embeddings? I am struggling to understand why this step is necessary, as it seems to occur only during backpropagation.
```python
action_bin_embeddings = self.action_bin_embeddings[:num_actions]
action_bin_embeddings = torch.roll(action_bin_embeddings, shifts = -1, dims = 1)
logits = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings)
```
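For context, here is a minimal toy example (shapes and values are made up, not taken from the repo) of what `torch.roll(..., shifts=-1, dims=1)` does along the bin dimension in isolation:

```python
import torch

# Toy tensor with shape (num_actions, num_bins, dim) = (1, 4, 3), purely illustrative.
action_bin_embeddings = torch.arange(12.).reshape(1, 4, 3)

# shifts=-1 along dims=1 moves every bin embedding one position earlier,
# so bin i ends up paired with the embedding that was originally at bin i+1,
# and the last bin wraps around to the embedding of bin 0.
rolled = torch.roll(action_bin_embeddings, shifts=-1, dims=1)

print(action_bin_embeddings[0, :, 0])  # tensor([0., 3., 6., 9.])
print(rolled[0, :, 0])                 # tensor([3., 6., 9., 0.])
```

My question is why this one-bin shift is needed before the `einsum` that produces the logits.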