From 571b674ab34bd33a1f410eb02a2b2a5c8ce174cb Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 27 Nov 2023 16:17:09 -0800 Subject: [PATCH] may be a good project to try tensortype --- q_transformer/q_transformer.py | 29 +++++++++++++++-------------- setup.py | 3 ++- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/q_transformer/q_transformer.py b/q_transformer/q_transformer.py index 2b97639..50fbc29 100644 --- a/q_transformer/q_transformer.py +++ b/q_transformer/q_transformer.py @@ -9,6 +9,8 @@ from torch.nn import Module, ModuleList from torch.utils.data import Dataset, DataLoader +from torchtyping import TensorType + from einops import rearrange, repeat, pack, unpack from einops.layers.torch import Rearrange @@ -232,14 +234,13 @@ def get_discount_matrix(self, timestep): def q_learn( self, - instructions: Tuple[str], - states: Tensor, - actions: Tensor, - next_states: Tensor, - reward: Tensor, - done: Tensor + instructions: Tuple[str], + states: TensorType['b', 'c', 'f', 'h', 'w', float], + actions: TensorType['b', int], + next_states: TensorType['b', 'c', 'f', 'h', 'w', float], + reward: TensorType['b', float], + done: TensorType['b', bool] ) -> Tuple[Tensor, QIntermediates]: - # 'next' stands for the very next time step (whether state, q, actions etc) γ = self.discount_factor_gamma @@ -270,12 +271,12 @@ def q_learn( def n_step_q_learn( self, - instructions: Tuple[str], - states: Tensor, - actions: Tensor, - next_states: Tensor, - rewards: Tensor, - dones: Tensor + instructions: Tuple[str], + states: TensorType['b', 't', 'c', 'f', 'h', 'w', float], + actions: TensorType['b', 't', int], + next_states: TensorType['b', 'c', 'f', 'h', 'w', float], + rewards: TensorType['b', 't', float], + dones: TensorType['b', 't', bool] ) -> Tuple[Tensor, QIntermediates]: """ einops @@ -285,8 +286,8 @@ def n_step_q_learn( f - frames h - height w - width - a - action bins t - timesteps + q - q values """ num_timesteps, device = states.shape[1], states.device diff --git a/setup.py b/setup.py index 8ac0999..476902c 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'q-transformer', packages = find_packages(exclude=[]), - version = '0.0.8', + version = '0.0.9', license='MIT', description = 'Q-Transformer', author = 'Phil Wang', @@ -23,6 +23,7 @@ 'einops>=0.7.0', 'ema-pytorch>=0.3.1', 'classifier-free-guidance-pytorch>=0.1.4', + 'torchtyping', 'torch>=2.0' ], classifiers=[