I was just wondering if you could explain/give some motivation for why the dynamics network works as it does.
I'm looking at a simple ATARI example and when I'm inside: def dynamics(self, encoded_state, reward_hidden, action):
the encoded state is [2, 64, 6, 6] (batch size of 2, just as a test), and the action tensor is [2, 1] (integers between 1 and 4).
You then define "actions_one_hot" as torch.ones(2, 1, 6, 6) and say: actions_one_hot = actions[:, :, None, None] * actions_one_hot / self.action_space_size
which gives actions_one_hot as [2, 1, 6, 6], with the values copied along the final two dimensions (so each action value is copied 36 times here). Then you concatenate with the encoded state along dim=1 to give a final state which is [2, 65, 6, 6].
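For readers following along, the construction described above can be reproduced in a few lines. This is a minimal sketch, not the repository's actual code: the helper name `action_plane` and the example action values are assumptions made for illustration, and only the shapes and the scaling by the action space size come from the description above.

```python
import torch

def action_plane(actions, action_space_size, height, width):
    """Broadcast integer actions [B, 1] onto a scaled plane [B, 1, H, W].

    Hypothetical helper that mirrors the expression quoted in the question.
    """
    ones = torch.ones(actions.shape[0], 1, height, width)
    # [B, 1, 1, 1] * [B, 1, H, W] copies each action value over the whole plane
    return actions[:, :, None, None] * ones / action_space_size

encoded_state = torch.zeros(2, 64, 6, 6)   # stand-in for the real encoded state
actions = torch.tensor([[1], [3]])         # example values, batch size of 2
plane = action_plane(actions, action_space_size=4, height=6, width=6)

# Concatenate along the channel dimension: 64 feature channels + 1 action channel
x = torch.cat((encoded_state, plane), dim=1)
```

Each plane holds a single scalar (here 1/4 and 3/4) repeated 36 times, so the action enters the dynamics network as one extra channel.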
Is this a standard thing to do/something that's been done elsewhere? It just feels a bit weird to me. Firstly, the actions are not "one hot encoded" here, so maybe the variable names aren't perfect (but that doesn't really matter I guess). I suppose it makes sense in that you probably want to be able to apply convolutions to the joint state/action within the dynamics network. And I guess with n_actions=4 this is fine, but it feels like this approach would probably break with a larger discrete action space, right?
Anyway if you have the time I'd be interested to hear your motivation/reasoning behind this, thanks!
Here we broadcast the actions onto planes so that their shape matches the shape of the feature planes (e.g. the feature is B x 64 x 6 x 6 and the action is B x 1 x 6 x 6). We divide by the action space size (a / action_space) to scale the actions.
I agree that this is not a very good or natural way to represent the actions, because the distance between (a=1, a=2) is different from that between (a=1, a=3), even though discrete actions have no such ordering. A better way is to broadcast the action into shape B x Action_space x 6 x 6 with one-hot labels. But that is a large tensor because of the spatial features, especially when the action space is large. Moreover, the current implementation still works well when the action space is 18 or even 81, so we just keep it.
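The one-hot alternative mentioned above could look roughly like the following. This is only a sketch under stated assumptions: the helper name `one_hot_planes` is hypothetical, and the actions are assumed to be zero-indexed (0 to action_space_size - 1), which `torch.nn.functional.one_hot` requires.

```python
import torch
import torch.nn.functional as F

def one_hot_planes(actions, action_space_size, height, width):
    """Turn integer actions [B, 1] into one-hot planes [B, A, H, W].

    Hypothetical helper; assumes zero-indexed actions.
    """
    one_hot = F.one_hot(actions.squeeze(1), num_classes=action_space_size).float()  # [B, A]
    # Repeat each one-hot entry over the spatial dimensions
    return one_hot[:, :, None, None].expand(-1, -1, height, width)

encoded_state = torch.zeros(2, 64, 6, 6)   # stand-in for the real encoded state
actions = torch.tensor([[0], [3]])         # example zero-indexed actions
planes = one_hot_planes(actions, action_space_size=4, height=6, width=6)

# 64 feature channels + 4 action channels
x = torch.cat((encoded_state, planes), dim=1)
```

The channel count grows linearly with the number of actions (64 + A), which is the memory cost referred to above.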
By the way, it would be interesting to find a better method for representing actions alongside spatial features (as opposed to flattened features). Hope this helps :)
I guess my thought was that it'd be better to broadcast onto a spatial dimension, so you could one-hot encode the actions and end up with a B x 64 x 6 x 10 tensor (or B x 64 x 6 x 24 if you have 18 actions). But I can see that for a large number of actions this would become a large tensor.
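One possible reading of this idea is sketched below. It is an interpretation rather than anything from the repository: the helper name `append_one_hot_spatially` is hypothetical, the actions are assumed to be zero-indexed, and the one-hot vector is assumed to be repeated across every channel and row before being appended along the last dimension.

```python
import torch
import torch.nn.functional as F

def append_one_hot_spatially(encoded_state, actions, action_space_size):
    """Append a one-hot action code along the last spatial dimension.

    Hypothetical helper: [B, C, H, W] -> [B, C, H, W + A].
    """
    b, c, h, _ = encoded_state.shape
    one_hot = F.one_hot(actions.squeeze(1), num_classes=action_space_size).float()  # [B, A]
    # Repeat the one-hot vector for every channel and every row
    block = one_hot[:, None, None, :].expand(b, c, h, action_space_size)
    return torch.cat((encoded_state, block), dim=3)

encoded_state = torch.zeros(2, 64, 6, 6)   # stand-in for the real encoded state
actions = torch.tensor([[0], [3]])         # example zero-indexed actions
x4 = append_one_hot_spatially(encoded_state, actions, action_space_size=4)

x18 = append_one_hot_spatially(encoded_state, actions, action_space_size=18)
```

With 4 actions the last dimension grows from 6 to 10, and with 18 actions from 6 to 24, matching the shapes mentioned above.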