Skip to content

Commit

Permalink
get ready to expand on this work
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 15, 2024
1 parent 1ab5a48 commit f8c3250
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 2 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# CI workflow: on every push, install the package with its `test` extra
# and run the pytest suite under tests/.
name: Tests the examples in README
on: push

# Workflow-level environment variable, visible to every step below.
# NOTE(review): presumably read by the package to toggle runtime type
# checking — confirm against the library code.
env:
  TYPECHECK: True

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Install Python
        uses: actions/setup-python@v5
        with:
          # quoted so YAML keeps it a string rather than parsing it as a float
          python-version: "3.11"

      - name: Install dependencies
        run: |
          python -m pip install uv
          python -m uv pip install --upgrade pip
          python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
          python -m uv pip install -e .[test]

      - name: Test with pytest
        run: python -m pytest tests/
2 changes: 1 addition & 1 deletion q_transformer/q_robotic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, dim, omega = 10000):
inv_freq = 1.0 / (omega ** (torch.arange(0, dim, 4).float() / dim))
self.register_buffer('inv_freq', inv_freq)

@autocast(enabled = False)
@autocast('cuda', enabled = False)
def forward(self, height_width):
device, dtype = self.inv_freq.device, self.inv_freq.dtype

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.2.1',
version = '0.2.2',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
Expand Down
51 changes: 51 additions & 0 deletions tests/test_q_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch

from q_transformer import (
QRoboticTransformer,
QLearner,
ReplayMemoryDataset
)

def test_q_transformer():
    """Smoke-test the QRoboticTransformer / QLearner workflow end to end.

    Builds a model, embeds two text instructions, queries actions and
    Q-values for a random video tensor, then constructs a QLearner on top
    of a ReplayMemoryDataset. There are no assertions: the test passes as
    long as none of these calls raise.
    """
    # keyword settings handed to the model through its `vit` argument
    vit_config = {
        'num_classes': 1000,
        'dim_conv_stem': 64,
        'dim': 64,
        'dim_head': 64,
        'depth': (2, 2, 5, 2),
        'window_size': 7,
        'mbconv_expansion_rate': 4,
        'mbconv_shrinkage_rate': 0.25,
        'dropout': 0.1,
    }

    q_model = QRoboticTransformer(
        vit = vit_config,
        num_actions = 8,
        depth = 1,
        heads = 8,
        dim_head = 64,
        cond_drop_prob = 0.2,
        dueling = True,
        weight_tie_action_bin_embed = False
    )

    # random input of shape (2, 3, 6, 224, 224)
    # presumably (batch, channels, frames, height, width) — TODO confirm
    frames = torch.randn(2, 3, 6, 224, 224)

    # one instruction per item along the first dimension of `frames`
    instructions = [
        'bring me that apple sitting on the table',
        'please pass the butter'
    ]

    # first pass uses precomputed text embeddings instead of raw strings
    instruction_embeds = q_model.embed_texts(instructions)
    initial_actions = q_model.get_actions(frames, text_embeds = instruction_embeds)

    # second pass passes raw strings plus only the first column of the
    # previously selected actions
    first_action = initial_actions[:, :1]
    best_actions = q_model.get_optimal_actions(frames, instructions, actions = first_action)

    # forward pass with the chosen actions; the result is not inspected
    q_values = q_model(frames, instructions, actions = best_actions)

    # dataset is read from a path relative to the current working directory
    replay_dataset = ReplayMemoryDataset('./replay_memories_data')

    learner = QLearner(
        q_model,
        dataset = replay_dataset,
        n_step_q_learning = True,
        num_train_steps = 10000,
        learning_rate = 3e-4,
        batch_size = 1
    )

0 comments on commit f8c3250

Please sign in to comment.