Skip to content

Commit

Permalink
actually remove all traces of torchtyping and add sentencepiece
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 16, 2024
1 parent 233c165 commit 0b5c5ab
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
10 changes: 5 additions & 5 deletions q_transformer/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from q_transformer.q_robotic_transformer import QRoboticTransformer

from torchtyping import TensorType
from q_transformer.tensor_typing import Float, Bool

from beartype import beartype
from beartype.typing import Iterator, Tuple
Expand Down Expand Up @@ -146,11 +146,11 @@ def init(self) -> Tuple[str, Tensor]: # (instruction, initial state)

def forward(
self,
actions: Tensor
actions: Int['...']
) -> Tuple[
TensorType[(), float], # reward
Tensor, # next state
TensorType[(), bool] # done
Float[''], # reward
Float['...'], # next state
Bool[''] # done
]:
raise NotImplementedError

Expand Down
10 changes: 5 additions & 5 deletions q_transformer/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@

from beartype.typing import Tuple

from torchtyping import TensorType
from q_transformer.tensor_typing import Float, Int, Bool
from q_transformer.agent import BaseEnvironment

class MockEnvironment(BaseEnvironment):
def init(self) -> Tuple[
str | None,
TensorType[float]
Float['...']
]:
return 'please clean the kitchen', torch.randn(self.state_shape, device = self.device)

def forward(self, actions) -> Tuple[
TensorType[(), float],
TensorType[float],
TensorType[(), bool]
Float[''],
Float['...'],
Bool['']
]:
rewards = torch.randn((), device = self.device)
next_states = torch.randn(self.state_shape, device = self.device)
Expand Down
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.1.16',
version = '0.1.17',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
Expand All @@ -20,11 +20,12 @@
install_requires=[
'accelerate',
'beartype',
'classifier-free-guidance-pytorch>=0.4.2',
'einops>=0.7.0',
'ema-pytorch>=0.3.1',
'classifier-free-guidance-pytorch>=0.6.10',
'einops>=0.8.0',
'ema-pytorch>=0.5.3',
'jaxtyping',
'numpy',
'sentencepiece',
'torch>=2.0'
],
classifiers=[
Expand Down

0 comments on commit 0b5c5ab

Please sign in to comment.