Skip to content

Commit

Permalink
update to latest cfg lib with latest research finding + fix a warning
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 7, 2024
1 parent 0f84f8f commit 1ab5a48
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions q_transformer/q_robotic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.cuda.amp import autocast
from torch.amp import autocast
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList

Expand Down Expand Up @@ -81,7 +81,7 @@ def rotate_half(x):
x = torch.stack((-x2, x1), dim = -1)
return rearrange(x, '... d c -> ... (d c)')

# Diff artifact resolved: keep only the updated decorator. The bare
# torch.cuda.amp.autocast(enabled=False) form is deprecated; the modern
# torch.amp.autocast requires an explicit device_type ('cuda' here).
# autocast is disabled so the rotation runs in full precision — sin/cos
# mixing in half precision degrades positional accuracy.
@autocast('cuda', enabled = False)
def apply_rotary_pos_emb(pos, t):
    """Apply rotary position embedding to tensor `t`.

    pos: positional-frequency tensor (angles); broadcast against `t`.
    t:   input tensor to rotate (e.g. queries or keys).

    Returns `t` rotated by `pos` via the standard RoPE identity:
    t * cos(pos) + rotate_half(t) * sin(pos).
    """
    return t * pos.cos() + rotate_half(t) * pos.sin()

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.2.0',
version = '0.2.1',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
Expand All @@ -21,7 +21,7 @@
'accelerate',
'adam-atan2-pytorch>=0.0.12',
'beartype',
'classifier-free-guidance-pytorch>=0.6.10',
'classifier-free-guidance-pytorch>=0.7.1',
'einops>=0.8.0',
'ema-pytorch>=0.5.3',
'jaxtyping',
Expand Down

0 comments on commit 1ab5a48

Please sign in to comment.