Skip to content

Commit

Permalink
may be a good project to try tensortype
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 28, 2023
1 parent 0c53c1f commit 571b674
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
29 changes: 15 additions & 14 deletions q_transformer/q_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from torch.nn import Module, ModuleList
from torch.utils.data import Dataset, DataLoader

from torchtyping import TensorType

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

Expand Down Expand Up @@ -232,14 +234,13 @@ def get_discount_matrix(self, timestep):

def q_learn(
self,
instructions: Tuple[str],
states: Tensor,
actions: Tensor,
next_states: Tensor,
reward: Tensor,
done: Tensor
instructions: Tuple[str],
states: TensorType['b', 'c', 'f', 'h', 'w', float],
actions: TensorType['b', int],
next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
reward: TensorType['b', float],
done: TensorType['b', bool]
) -> Tuple[Tensor, QIntermediates]:

# 'next' stands for the very next time step (whether state, q, actions etc)

γ = self.discount_factor_gamma
Expand Down Expand Up @@ -270,12 +271,12 @@ def q_learn(

def n_step_q_learn(
self,
instructions: Tuple[str],
states: Tensor,
actions: Tensor,
next_states: Tensor,
rewards: Tensor,
dones: Tensor
instructions: Tuple[str],
states: TensorType['b', 't', 'c', 'f', 'h', 'w', float],
actions: TensorType['b', 't', int],
next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
rewards: TensorType['b', 't', float],
dones: TensorType['b', 't', bool]
) -> Tuple[Tensor, QIntermediates]:
"""
einops
Expand All @@ -285,8 +286,8 @@ def n_step_q_learn(
f - frames
h - height
w - width
a - action bins
t - timesteps
q - q values
"""

num_timesteps, device = states.shape[1], states.device
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.0.8',
version = '0.0.9',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
Expand All @@ -23,6 +23,7 @@
'einops>=0.7.0',
'ema-pytorch>=0.3.1',
'classifier-free-guidance-pytorch>=0.1.4',
'torchtyping',
'torch>=2.0'
],
classifiers=[
Expand Down

0 comments on commit 571b674

Please sign in to comment.