move single action q head into own module
lucidrains committed Nov 29, 2023
1 parent e700f84 commit c496cd4
Showing 3 changed files with 67 additions and 34 deletions.
README.md (2 additions, 0 deletions)
@@ -4,6 +4,8 @@

Implementation of <a href="https://qtransformer.github.io/">Q-Transformer</a>, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google Deepmind

+I will be keeping around the logic for Q-learning on a single action, just for a final comparison with the proposed autoregressive discrete multiple actions, and to serve as education for myself and the public.
+
## Appreciation

- <a href="https://stability.ai/">StabilityAI</a>, <a href="https://a16z.com/supporting-the-open-source-ai-community/">A16Z Open Source AI Grant Program</a>, and <a href="https://huggingface.co/">🤗 Huggingface</a> for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
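
To make the comparison mentioned in the README concrete, here is a minimal toy sketch (not this library's API; all names and sizes below are illustrative assumptions) of greedy action selection over discretized bins in the single-action case. The autoregressive multi-action scheme from the paper would instead pick one bin per action dimension, conditioning each pick on the previous ones.

```python
import torch

# toy single-action setup: one q-value per discretized action bin
batch_size, action_bins = 4, 256
q_values = torch.rand(batch_size, action_bins)  # in [0, 1], matching the sigmoid-bounded head below

greedy_action = q_values.argmax(dim = -1)       # (batch_size,) best bin index per sample

# the autoregressive multi-action generalization would loop over action
# dimensions, selecting a bin for each while conditioning on prior picks
```
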
q_transformer/robotic_transformer.py (63 additions, 31 deletions)
@@ -16,11 +16,6 @@

from classifier_free_guidance_pytorch import TextConditioner, AttentionTextConditioner, classifier_free_guidance

-from x_transformers import (
-    Decoder,
-    AutoregressiveWrapper
-)

# helpers

def exists(val):
@@ -599,6 +594,57 @@ def forward(self, x):
        q_values = values + advantages
        return q_values.sigmoid()

+# Action Transformer Decoder Head Modules
+
+class SingleActionHead(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        num_learned_tokens = 8,
+        action_bins = 256,
+        dueling = False
+    ):
+        super().__init__()
+        self.action_bins = action_bins
+
+        if dueling:
+            self.to_q_values = nn.Sequential(
+                Reduce('b (f n) d -> b d', 'mean', n = num_learned_tokens),
+                DuelingHead(
+                    dim,
+                    action_bins = action_bins
+                )
+            )
+        else:
+            self.to_q_values = nn.Sequential(
+                Reduce('b (f n) d -> b d', 'mean', n = num_learned_tokens),
+                nn.LayerNorm(dim),
+                nn.Linear(dim, action_bins),
+                nn.Sigmoid()
+            )
+
+    def get_random_actions(self, batch_size):
+        return torch.randint(0, self.action_bins, (batch_size,), device = self.device)
+
+    def get_best_actions(
+        self,
+        encoded_state,
+        return_q_values = False,
+        **kwargs
+    ):
+        q_values = self.forward(encoded_state)
+
+        max_q, action_indices = q_values.max(dim = -1)
+
+        if not return_q_values:
+            return action_indices
+
+        return action_indices, max_q
+
+    def forward(self, x):
+        return self.to_q_values(x)
+
# Robotic Transformer

class QRoboticTransformer(Module):
@@ -671,28 +717,19 @@ def __init__(

        self.cond_drop_prob = cond_drop_prob

-        if dueling:
-            self.to_q_values = nn.Sequential(
-                Reduce('b (f n) d -> b d', 'mean', n = self.num_learned_tokens),
-                DuelingHead(
-                    attend_dim,
-                    action_bins = action_bins
-                )
-            )
-        else:
-            self.to_q_values = nn.Sequential(
-                Reduce('b (f n) d -> b d', 'mean', n = self.num_learned_tokens),
-                LayerNorm(attend_dim),
-                nn.Linear(attend_dim, action_bins),
-                nn.Sigmoid()
-            )
+        self.action_head = SingleActionHead(
+            attend_dim,
+            num_learned_tokens = self.num_learned_tokens,
+            action_bins = action_bins,
+            dueling = dueling
+        )

    @property
    def device(self):
        return next(self.parameters()).device

    def get_random_actions(self, batch_size = 1):
-        return torch.randint(0, self.action_bins, (batch_size,), device = self.device)
+        return self.action_head.get_random_actions(batch_size)

    @torch.no_grad()
    def get_best_actions(
@@ -701,14 +738,8 @@ def get_best_actions(
        return_q_values = False,
        **kwargs
    ):
-        q_values = self.forward(*args, **kwargs)
-
-        max_q, action_indices = q_values.max(dim = -1)
-
-        if not return_q_values:
-            return action_indices
-
-        return action_indices, max_q
+        encoded_state = self.encode_state(*args, **kwargs)
+        return self.action_head.get_best_actions(encoded_state, return_q_values = return_q_values)

    def encode_state(
        self,
@@ -793,8 +824,9 @@ def forward(
            cond_drop_prob = cond_drop_prob
        )

-        # single actions
+        # head that returns the q values
+        # supporting both single and multiple actions

-        q_values = self.to_q_values(encoded_state)
+        q_values = self.action_head(encoded_state)

        return q_values
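
As a quick sanity check of the new module's interface, here is a hedged usage sketch based only on the signatures visible in this diff. The import path is an assumption inferred from the file touched above, and the dimensions are arbitrary example values.

```python
import torch
from q_transformer.robotic_transformer import SingleActionHead  # assumed import path

head = SingleActionHead(
    dim = 512,            # the transformer's attend_dim; 512 is an arbitrary example
    num_learned_tokens = 8,
    action_bins = 256,
    dueling = False       # dueling = True would route q-values through DuelingHead instead
)

# encoded state tokens: (batch, frames * num_learned_tokens, dim)
encoded_state = torch.randn(2, 4 * 8, 512)

q_values = head(encoded_state)                   # (2, 256), sigmoid-bounded in [0, 1]
actions = head.get_best_actions(encoded_state)   # (2,) greedy bin indices

actions, max_q = head.get_best_actions(encoded_state, return_q_values = True)
```
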
setup.py (2 additions, 3 deletions)
@@ -3,7 +3,7 @@
setup(
  name = 'q-transformer',
  packages = find_packages(exclude=[]),
-  version = '0.0.18',
+  version = '0.0.19',
  license='MIT',
  description = 'Q-Transformer',
  author = 'Phil Wang',
@@ -24,8 +24,7 @@
    'ema-pytorch>=0.3.1',
    'classifier-free-guidance-pytorch>=0.1.4',
    'torchtyping',
-    'torch>=2.0',
-    'x-transformers>=1.26.0'
+    'torch>=2.0'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
