diff --git a/README.md b/README.md
index 42e5413..ef3e7c1 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,8 @@
 
 Implementation of Q-Transformer, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google Deepmind
 
+I will be keeping around the logic for Q-learning on single action just for final comparison with the proposed autoregressive discrete multiple actions. Also to serve as education for myself and the public.
+
 ## Appreciation
 
 - StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
diff --git a/q_transformer/robotic_transformer.py b/q_transformer/robotic_transformer.py
index cfa6ac4..6af2f3e 100644
--- a/q_transformer/robotic_transformer.py
+++ b/q_transformer/robotic_transformer.py
@@ -16,11 +16,6 @@
 
 from classifier_free_guidance_pytorch import TextConditioner, AttentionTextConditioner, classifier_free_guidance
 
-from x_transformers import (
-    Decoder,
-    AutoregressiveWrapper
-)
-
 # helpers
 
 def exists(val):
@@ -599,6 +594,57 @@ def forward(self, x):
         q_values = values + advantages
         return q_values.sigmoid()
 
+# Action Transformer Decoder Head Modules
+
+class SingleActionHead(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        num_learned_tokens = 8,
+        action_bins = 256,
+        dueling = False
+    ):
+        super().__init__()
+        self.action_bins = action_bins
+
+        if dueling:
+            self.to_q_values = nn.Sequential(
+                Reduce('b (f n) d -> b d', 'mean', n = num_learned_tokens),
+                DuelingHead(
+                    dim,
+                    action_bins = action_bins
+                )
+            )
+        else:
+            self.to_q_values = nn.Sequential(
+                Reduce('b (f n) d -> b d', 'mean', n = num_learned_tokens),
+                nn.LayerNorm(dim),
+                nn.Linear(dim, action_bins),
+                nn.Sigmoid()
+            )
+
+    def get_random_actions(self, batch_size):
+        return torch.randint(0, self.action_bins, (batch_size,), device = self.device)
+
+    def get_best_actions(
+        self,
+        encoded_state,
+        return_q_values = False,
+        **kwargs
+    ):
+        q_values = self.forward(encoded_state)
+
+        max_q, action_indices = q_values.max(dim = -1)
+
+        if not return_q_values:
+            return action_indices
+
+        return action_indices, max_q
+
+    def forward(self, x):
+        return self.to_q_values(x)
+
 # Robotic Transformer
 
 class QRoboticTransformer(Module):
@@ -671,28 +717,19 @@ def __init__(
 
         self.cond_drop_prob = cond_drop_prob
 
-        if dueling:
-            self.to_q_values = nn.Sequential(
-                Reduce('b (f n) d -> b d', 'mean', n = self.num_learned_tokens),
-                DuelingHead(
-                    attend_dim,
-                    action_bins = action_bins
-                )
-            )
-        else:
-            self.to_q_values = nn.Sequential(
-                Reduce('b (f n) d -> b d', 'mean', n = self.num_learned_tokens),
-                LayerNorm(attend_dim),
-                nn.Linear(attend_dim, action_bins),
-                nn.Sigmoid()
-            )
+        self.action_head = SingleActionHead(
+            attend_dim,
+            num_learned_tokens = self.num_learned_tokens,
+            action_bins = action_bins,
+            dueling = dueling
+        )
 
     @property
     def device(self):
         return next(self.parameters()).device
 
     def get_random_actions(self, batch_size = 1):
-        return torch.randint(0, self.action_bins, (batch_size,), device = self.device)
+        return self.action_head.get_random_actions(batch_size)
 
     @torch.no_grad()
     def get_best_actions(
@@ -701,14 +738,8 @@ def get_best_actions(
         return_q_values = False,
         **kwargs
     ):
-        q_values = self.forward(*args, **kwargs)
-
-        max_q, action_indices = q_values.max(dim = -1)
-
-        if not return_q_values:
-            return action_indices
-
-        return action_indices, max_q
+        encoded_state = self.encode_state(*args, **kwargs)
+        return self.action_head.get_best_actions(encoded_state, return_q_values = return_q_values)
 
     def encode_state(
         self,
@@ -793,8 +824,9 @@ def forward(
             cond_drop_prob = cond_drop_prob
         )
 
-        # single actions
+        # head that returns the q values
+        # supporting both single and multiple actions
 
-        q_values = self.to_q_values(encoded_state)
+        q_values = self.action_head(encoded_state)
 
        return q_values
diff --git a/setup.py b/setup.py
index aa53dec..7379a24 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'q-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.0.18',
+  version = '0.0.19',
  license='MIT',
  description = 'Q-Transformer',
  author = 'Phil Wang',
@@ -24,8 +24,7 @@
     'ema-pytorch>=0.3.1',
     'classifier-free-guidance-pytorch>=0.1.4',
     'torchtyping',
-    'torch>=2.0',
-    'x-transformers>=1.26.0'
+    'torch>=2.0'
   ],
  classifiers=[
    'Development Status :: 4 - Beta',
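
Below is a minimal usage sketch (not part of the diff) of the new `SingleActionHead` in isolation. The embedding dimension, frame count, and batch size are made-up illustrative values; in the full model the encoded state comes from `QRoboticTransformer.encode_state`, and `get_best_actions` on the transformer simply encodes the state and delegates to this head, as the hunks above show.

```python
import torch
from q_transformer.robotic_transformer import SingleActionHead

# illustrative sizes - the real values come from the vision backbone / transformer config
dim = 512                # stand-in for attend_dim
num_learned_tokens = 8   # learned tokens per frame
frames = 2               # frames in the video clip

head = SingleActionHead(
    dim,
    num_learned_tokens = num_learned_tokens,
    action_bins = 256,
    dueling = False      # set True to route the pooled tokens through DuelingHead instead
)

# dummy encoded state of shape (batch, frames * num_learned_tokens, dim)
encoded_state = torch.randn(1, frames * num_learned_tokens, dim)

q_values = head(encoded_state)   # (1, 256) sigmoid-bounded q-values, one per action bin

actions, max_q = head.get_best_actions(encoded_state, return_q_values = True)
# actions: (1,) index of the best discretized action bin, max_q: (1,) its q-value
```

Pulling the q-value head out of `QRoboticTransformer` leaves the vision-language encoder untouched and makes room for the planned autoregressive multi-action head to sit alongside this single-action one.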