
Commit

move encoding of states into separate function for robotic transformer, to get ready for encoder / decoder
lucidrains committed Nov 29, 2023
1 parent e5573e0 commit e700f84
Showing 3 changed files with 32 additions and 8 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -16,7 +16,8 @@ Implementation of <a href="https://qtransformer.github.io/">Q-Transformer</a>, S
- [x] add n-step Q learning
- [x] build the conservative regularization

-- [ ] improvise a cross attention variant + another decoder head, instead of concatenating previous actions at the frames + learned tokens stage. in other words, using a hierarchical transformer
+- [ ] improvise a cross attention variant + another decoder head, instead of concatenating previous actions at the frames + learned tokens stage. in other words, use classic encoder - decoder

- [ ] build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last)

- [ ] build out a simple dataset creator class, taking in the environment as an iterator / generator
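The reworded todo item is the direction this commit prepares for, via the x-transformers import and dependency added below. A rough sketch of the "classic encoder - decoder" idea, assuming x-transformers' Decoder with cross attention enabled; hyperparameters and shapes are illustrative, nothing here is committed code:

import torch
from x_transformers import Decoder

# decoder whose action tokens cross attend to the encoded state
decoder = Decoder(
    dim = 512,
    depth = 2,
    heads = 8,
    cross_attend = True
)

action_tokens = torch.randn(1, 4, 512)    # autoregressive action queries (illustrative)
encoded_state = torch.randn(1, 48, 512)   # stand-in for the output of encode_state

attended = decoder(action_tokens, context = encoded_state)  # (1, 4, 512)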
34 changes: 28 additions & 6 deletions q_transformer/robotic_transformer.py
@@ -16,6 +16,11 @@

from classifier_free_guidance_pytorch import TextConditioner, AttentionTextConditioner, classifier_free_guidance

+from x_transformers import (
+    Decoder,
+    AutoregressiveWrapper
+)

# helpers

def exists(val):
@@ -705,10 +710,9 @@ def get_best_actions(

        return action_indices, max_q

-    @classifier_free_guidance
-    def forward(
+    def encode_state(
        self,
-        video,
+        video: Tensor,
        texts: Optional[Union[List[str], Tuple[str]]] = None,
        actions: Optional[Tensor] = None,
        cond_drop_prob = 0.,
@@ -758,7 +762,7 @@ def forward(

        # causal attention mask

-        attn_mask = torch.ones((frames, frames), dtype = torch.bool, device = device).triu(1)
+        attn_mask = ~torch.ones((frames, frames), dtype = torch.bool, device = device).triu(1)
        attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = self.num_learned_tokens, r2 = self.num_learned_tokens)

        # sinusoidal positional embedding
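Aside on the changed line above: the boolean inversion now happens at mask construction, so True marks attendable positions, and the repeat expands the frame-level causal mask into blocks so every learned token of a frame shares that frame's visibility. A quick standalone check, with illustrative sizes:

import torch
from einops import repeat

frames, num_learned_tokens = 3, 2

# True = attendable; each frame sees itself and earlier frames
attn_mask = ~torch.ones((frames, frames), dtype = torch.bool).triu(1)

# expand the per-frame mask to per-token blocks
attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = num_learned_tokens, r2 = num_learned_tokens)

print(attn_mask.int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])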
@@ -769,10 +773,28 @@

        # attention

-        attended_tokens = self.transformer(learned_tokens, cond_fns = transformer_cond_fns, attn_mask = ~attn_mask)
+        attended_tokens = self.transformer(learned_tokens, cond_fns = transformer_cond_fns, attn_mask = attn_mask)

+        return attended_tokens
+
+    @classifier_free_guidance
+    def forward(
+        self,
+        video: Tensor,
+        texts: Optional[Union[List[str], Tuple[str]]] = None,
+        actions: Optional[Tensor] = None,
+        cond_drop_prob = 0.,
+    ):
+
+        encoded_state = self.encode_state(
+            video = video,
+            texts = texts,
+            actions = actions,
+            cond_drop_prob = cond_drop_prob
+        )
+
        # single actions

-        q_values = self.to_q_values(attended_tokens)
+        q_values = self.to_q_values(encoded_state)

        return q_values
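The net effect of this file's changes: state encoding (conditioning plus attention over learned tokens) becomes a reusable method, and forward reduces to encode then project. A minimal self-contained sketch of that split; the class and shapes below are stand-ins, not the repo's API:

import torch
from torch import nn, Tensor

class TinyStateEncoder(nn.Module):
    def __init__(self, dim = 32, num_actions = 8):
        super().__init__()
        self.attend = nn.Linear(dim, dim)               # stands in for the conditioned transformer
        self.to_q_values = nn.Linear(dim, num_actions)

    def encode_state(self, tokens: Tensor) -> Tensor:
        # returns attended tokens, so a future decoder can cross attend to them
        return self.attend(tokens)

    def forward(self, tokens: Tensor) -> Tensor:
        encoded_state = self.encode_state(tokens)
        return self.to_q_values(encoded_state)

tokens = torch.randn(2, 6, 32)            # (batch, frames x learned tokens, dim)
q_values = TinyStateEncoder()(tokens)     # torch.Size([2, 6, 8])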
3 changes: 2 additions & 1 deletion setup.py
@@ -24,7 +24,8 @@
    'ema-pytorch>=0.3.1',
    'classifier-free-guidance-pytorch>=0.1.4',
    'torchtyping',
-    'torch>=2.0'
+    'torch>=2.0',
+    'x-transformers>=1.26.0'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
