diff --git a/q_transformer/q_robotic_transformer.py b/q_transformer/q_robotic_transformer.py index 5b1f603..81a6457 100644 --- a/q_transformer/q_robotic_transformer.py +++ b/q_transformer/q_robotic_transformer.py @@ -347,6 +347,7 @@ def __init__( heads = 8, dim_head = 64, dim_conv_stem = None, + conv_stem_downsample = True, window_size = 7, mbconv_expansion_rate = 4, mbconv_shrinkage_rate = 0.25, @@ -356,14 +357,22 @@ def __init__( flash_attn = True ): super().__init__() + + self.depth = depth + # convolutional stem dim_conv_stem = default(dim_conv_stem, dim) - self.conv_stem = nn.Sequential( - nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1), - nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1) - ) + self.conv_stem_downsample = conv_stem_downsample + + if conv_stem_downsample: + self.conv_stem = nn.Sequential( + nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1), + nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1) + ) + else: + self.conv_stem = nn.Conv2d(channels, dim_conv_stem, 7, padding = 3) # variables @@ -433,6 +442,10 @@ def __init__( nn.Linear(embed_dim, num_classes) ) + @property + def downsample_factor(self): + return (2 if self.conv_stem_downsample else 1) * (2 ** len(self.depth)) + @beartype def forward( self, @@ -442,7 +455,9 @@ def forward( cond_drop_prob = 0., return_embeddings = False ): - assert all([divisible_by(d, self.window_size) for d in img.shape[-2:]]) + hw = img.shape[-2:] + assert all([divisible_by(d, self.window_size) for d in hw]), f'height and width of video frames {tuple(hw)} must be divisible by window size {self.window_size}' + assert all([divisible_by(d, self.downsample_factor) for d in hw]), f'height and width of video frames {tuple(hw)} must be divisible by total downsample factor {self.downsample_factor}' x = self.conv_stem(img) diff --git a/setup.py b/setup.py index da1e771..19b3b9a 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'q-transformer', packages = find_packages(exclude=[]), - version = '0.1.14', + version = '0.1.15', license='MIT', description = 'Q-Transformer', author = 'Phil Wang',