From f8c3250fb24b208957e81d0bc55b1547861f2c78 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Fri, 15 Nov 2024 07:19:21 -0800 Subject: [PATCH] get ready to expand on this work --- .github/workflows/test.yaml | 24 ++++++++++++ q_transformer/q_robotic_transformer.py | 2 +- setup.py | 2 +- tests/test_q_transformer.py | 51 ++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/test.yaml create mode 100644 tests/test_q_transformer.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..593e558 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,24 @@ +name: Tests the examples in README +on: push + +env: + TYPECHECK: True + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install dependencies + run: | + python -m pip install uv + python -m uv pip install --upgrade pip + python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu + python -m uv pip install -e .[test] + - name: Test with pytest + run: | + python -m pytest tests/ diff --git a/q_transformer/q_robotic_transformer.py b/q_transformer/q_robotic_transformer.py index 6107e81..c5c190a 100644 --- a/q_transformer/q_robotic_transformer.py +++ b/q_transformer/q_robotic_transformer.py @@ -63,7 +63,7 @@ def __init__(self, dim, omega = 10000): inv_freq = 1.0 / (omega ** (torch.arange(0, dim, 4).float() / dim)) self.register_buffer('inv_freq', inv_freq) - @autocast(enabled = False) + @autocast('cuda', enabled = False) def forward(self, height_width): device, dtype = self.inv_freq.device, self.inv_freq.dtype diff --git a/setup.py b/setup.py index cf359a2..cad5ef4 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'q-transformer', packages = find_packages(exclude=[]), - version = '0.2.1', + version = '0.2.2', license='MIT', description = 'Q-Transformer', author = 'Phil Wang', diff --git a/tests/test_q_transformer.py b/tests/test_q_transformer.py new file mode 100644 index 0000000..4ebb9b1 --- /dev/null +++ b/tests/test_q_transformer.py @@ -0,0 +1,51 @@ +import torch + +from q_transformer import ( + QRoboticTransformer, + QLearner, + ReplayMemoryDataset +) + +def test_q_transformer(): + model = QRoboticTransformer( + vit = dict( + num_classes = 1000, + dim_conv_stem = 64, + dim = 64, + dim_head = 64, + depth = (2, 2, 5, 2), + window_size = 7, + mbconv_expansion_rate = 4, + mbconv_shrinkage_rate = 0.25, + dropout = 0.1 + ), + num_actions = 8, + depth = 1, + heads = 8, + dim_head = 64, + cond_drop_prob = 0.2, + dueling = True, + weight_tie_action_bin_embed = False + ) + + video = torch.randn(2, 3, 6, 224, 224) + + instructions = [ + 'bring me that apple sitting on the table', + 'please pass the butter' + ] + + text_embeds = model.embed_texts(instructions) + best_actions = model.get_actions(video, text_embeds = text_embeds) + best_actions = model.get_optimal_actions(video, instructions, actions = best_actions[:, :1]) + + q_values = model(video, instructions, actions = best_actions) + + q_learner = QLearner( + model, + dataset = ReplayMemoryDataset('./replay_memories_data'), + n_step_q_learning = True, + num_train_steps = 10000, + learning_rate = 3e-4, + batch_size = 1 + )