diff --git a/README.md b/README.md
index 42e5413..ef3e7c1 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,8 @@
Implementation of Q-Transformer, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google Deepmind
+I will be keeping around the logic for Q-learning on a single action, both for a final comparison with the proposed autoregressive discrete multiple actions and to serve as education for myself and the public.
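+
+Below is a minimal sketch of using the single-action q-value head on its own; the hyperparameters and tensor shapes are illustrative assumptions, not values from the paper.
+
+```python
+import torch
+from q_transformer.robotic_transformer import SingleActionHead
+
+# assumed hyperparameters, for illustration only
+head = SingleActionHead(
+    dim = 512,
+    num_learned_tokens = 8,
+    action_bins = 256,
+    dueling = True
+)
+
+# encoded state tokens - (batch, frames * learned tokens, dim)
+encoded_state = torch.randn(2, 2 * 8, 512)
+
+q_values = head(encoded_state)                   # (2, 256) q-values in [0, 1]
+actions = head.get_best_actions(encoded_state)   # (2,) best action bin per batch element
+```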
+
## Appreciation
- StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
diff --git a/q_transformer/robotic_transformer.py b/q_transformer/robotic_transformer.py
index cfa6ac4..6af2f3e 100644
--- a/q_transformer/robotic_transformer.py
+++ b/q_transformer/robotic_transformer.py
@@ -16,11 +16,6 @@
from classifier_free_guidance_pytorch import TextConditioner, AttentionTextConditioner, classifier_free_guidance
-from x_transformers import (
- Decoder,
- AutoregressiveWrapper
-)
-
# helpers
def exists(val):
@@ -599,6 +594,57 @@ def forward(self, x):
q_values = values + advantages
return q_values.sigmoid()
+# Action Transformer Decoder Head Modules
+
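+# head that predicts q-values over the discretized bins of a single action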
+class SingleActionHead(Module):
+ def __init__(
+ self,
+ dim,
+ *,
+ num_learned_tokens = 8,
+ action_bins = 256,
+ dueling = False
+ ):
+ super().__init__()
+ self.action_bins = action_bins
+
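+        # mean pool the learned tokens across frames, then map the pooled embedding to per-bin q-values
+        # the dueling variant splits this into separate state-value and advantage streams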
+ if dueling:
+ self.to_q_values = nn.Sequential(
+ Reduce('b (f n) d -> b d', 'mean', n = num_learned_tokens),
+ DuelingHead(
+ dim,
+ action_bins = action_bins
+ )
+ )
+ else:
+ self.to_q_values = nn.Sequential(
+ Reduce('b (f n) d -> b d', 'mean', n = num_learned_tokens),
+ nn.LayerNorm(dim),
+ nn.Linear(dim, action_bins),
+ nn.Sigmoid()
+ )
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def get_random_actions(self, batch_size):
+        # uniformly sample a random action bin per batch element, for exploration
+        return torch.randint(0, self.action_bins, (batch_size,), device = self.device)
+
+ def get_best_actions(
+ self,
+ encoded_state,
+ return_q_values = False,
+ **kwargs
+ ):
+ q_values = self.forward(encoded_state)
+
+ max_q, action_indices = q_values.max(dim = -1)
+
+ if not return_q_values:
+ return action_indices
+
+ return action_indices, max_q
+
+ def forward(self, x):
+ return self.to_q_values(x)
+
# Robotic Transformer
class QRoboticTransformer(Module):
@@ -671,28 +717,19 @@ def __init__(
self.cond_drop_prob = cond_drop_prob
- if dueling:
- self.to_q_values = nn.Sequential(
- Reduce('b (f n) d -> b d', 'mean', n = self.num_learned_tokens),
- DuelingHead(
- attend_dim,
- action_bins = action_bins
- )
- )
- else:
- self.to_q_values = nn.Sequential(
- Reduce('b (f n) d -> b d', 'mean', n = self.num_learned_tokens),
- LayerNorm(attend_dim),
- nn.Linear(attend_dim, action_bins),
- nn.Sigmoid()
- )
+ self.action_head = SingleActionHead(
+ attend_dim,
+ num_learned_tokens = self.num_learned_tokens,
+ action_bins = action_bins,
+ dueling = dueling
+ )
@property
def device(self):
return next(self.parameters()).device
def get_random_actions(self, batch_size = 1):
- return torch.randint(0, self.action_bins, (batch_size,), device = self.device)
+ return self.action_head.get_random_actions(batch_size)
@torch.no_grad()
def get_best_actions(
@@ -701,14 +738,8 @@ def get_best_actions(
return_q_values = False,
**kwargs
):
- q_values = self.forward(*args, **kwargs)
-
- max_q, action_indices = q_values.max(dim = -1)
-
- if not return_q_values:
- return action_indices
-
- return action_indices, max_q
+ encoded_state = self.encode_state(*args, **kwargs)
+ return self.action_head.get_best_actions(encoded_state, return_q_values = return_q_values)
def encode_state(
self,
@@ -793,8 +824,9 @@ def forward(
cond_drop_prob = cond_drop_prob
)
- # single actions
+ # head that returns the q values
+ # supporting both single and multiple actions
- q_values = self.to_q_values(encoded_state)
+ q_values = self.action_head(encoded_state)
return q_values
diff --git a/setup.py b/setup.py
index aa53dec..7379a24 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
- version = '0.0.18',
+ version = '0.0.19',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
@@ -24,8 +24,7 @@
'ema-pytorch>=0.3.1',
'classifier-free-guidance-pytorch>=0.1.4',
'torchtyping',
- 'torch>=2.0',
- 'x-transformers>=1.26.0'
+ 'torch>=2.0'
],
classifiers=[
'Development Status :: 4 - Beta',