
feat: add ConvNextV2Block/Item with GRN
flavioschneider committed Jan 9, 2023
1 parent 50737ba commit db80168
Showing 3 changed files with 58 additions and 1 deletion.
12 changes: 12 additions & 0 deletions a_unet/apex.py
@@ -7,6 +7,7 @@
Attention,
Conv,
ConvBlock,
ConvNextV2Block,
CrossAttention,
Downsample,
FeedForward,
@@ -92,6 +93,17 @@ def ResnetItem(
) # type: ignore


def ConvNextV2Item(
dim: Optional[int] = None,
channels: Optional[int] = None,
**kwargs,
) -> nn.Module:
msg = "ResnetItem requires dim and channels"
assert exists(dim) and exists(channels), msg
Item = SelectX(ConvNextV2Block)
return Item(dim=dim, channels=channels) # type: ignore


def AttentionItem(
channels: Optional[int] = None,
attention_features: Optional[int] = None,
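A minimal usage sketch for the new item (hedged, and not part of the commit: SelectX is assumed here to route the first positional argument x straight to the wrapped ConvNextV2Block, as with the other *Item helpers):

import torch
from a_unet.apex import ConvNextV2Item

item = ConvNextV2Item(dim=2, channels=8)  # 2D block over 8 channels
x = torch.randn(1, 8, 32, 32)
y = item(x)  # assumption: the SelectX wrapper forwards x unchanged
assert y.shape == x.shape  # the residual design preserves shape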
45 changes: 45 additions & 0 deletions a_unet/blocks.py
@@ -182,6 +182,51 @@ def ResnetBlock(
return Module([conv_block, conv], lambda x: conv_block(x) + conv(x))


class GRN(nn.Module):
"""GRN (Global Response Normalization) layer from ConvNextV2 generic to any dim"""

    def __init__(self, dim: int, channels: int):
        super().__init__()
        ones = (1,) * dim
        self.gamma = nn.Parameter(torch.zeros(1, channels, *ones))
        self.beta = nn.Parameter(torch.zeros(1, channels, *ones))
        # Spatial dimensions start at index 2, after batch and channels
        self.norm_dims = [d + 2 for d in range(dim)]

    def forward(self, x: Tensor) -> Tensor:
        # Per-channel L2 norm aggregated over all spatial dimensions
        Gx = torch.norm(x, p=2, dim=self.norm_dims, keepdim=True)
        # Divisive normalization: each channel relative to the mean norm across channels
        Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
        # Learnable affine plus residual; gamma = beta = 0 makes this the identity at init
        return self.gamma * (x * Nx) + self.beta + x
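Since gamma and beta are initialized to zero, the layer is an exact identity at initialization: y = 0 * (x * Nx) + 0 + x = x. A quick sketch (not part of the commit) checking this property and the expected shapes:

import torch
from a_unet.blocks import GRN

grn = GRN(dim=2, channels=16)  # 2D: norms taken over dims (2, 3)
x = torch.randn(4, 16, 8, 8)
y = grn(x)
assert y.shape == x.shape
assert torch.allclose(y, x)  # gamma = beta = 0, so GRN starts as the identity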


def ConvNextV2Block(dim: int, channels: int) -> nn.Module:
block = nn.Sequential(
# Depthwise and LayerNorm
Conv(
dim=dim,
in_channels=channels,
out_channels=channels,
kernel_size=7,
padding=3,
groups=channels,
),
nn.GroupNorm(num_groups=1, num_channels=channels),
# Pointwise expand
Conv(dim=dim, in_channels=channels, out_channels=channels * 4, kernel_size=1),
# Activation and GRN
nn.GELU(),
GRN(dim=dim, channels=channels * 4),
# Pointwise contract
Conv(
dim=dim,
in_channels=channels * 4,
out_channels=channels,
kernel_size=1,
),
)

return Module([block], lambda x: x + block(x))
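A short shape check for the block (a sketch, not part of the commit, assuming Conv(dim=1, ...) resolves to nn.Conv1d as elsewhere in the library). The depthwise convolution uses kernel_size=7 with padding=3, so spatial size is preserved, which the residual sum x + block(x) requires:

import torch
from a_unet.blocks import ConvNextV2Block

block = ConvNextV2Block(dim=1, channels=32)  # 1D variant, e.g. for audio
x = torch.randn(2, 32, 128)
assert block(x).shape == x.shape  # residual output matches input shape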


def AttentionBase(features: int, head_features: int, num_heads: int) -> nn.Module:
scale = head_features**-0.5
mid_features = head_features * num_heads
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name="a-unet",
packages=find_packages(exclude=[]),
version="0.0.11",
version="0.0.12",
license="MIT",
description="A-UNet",
long_description_content_type="text/markdown",
