diff --git a/README.md b/README.md index 5a01436..3d861b3 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # A-UNet -A library that provides building blocks to customize UNets, in PyTorch. +A toolbox that provides hackable building blocks for (1D/2D/3D) UNets, in PyTorch. ## Install ```bash diff --git a/a_unet/blocks.py b/a_unet/blocks.py index c27dfed..ef54f4f 100644 --- a/a_unet/blocks.py +++ b/a_unet/blocks.py @@ -25,7 +25,7 @@ def __call__(self, *a, **ka): class Inner: def __init__(self): self.args = a - self.kwargs = ka + self.__dict__.update(**ka) def __call__(self, *b, **kb): if override: @@ -320,11 +320,11 @@ def forward(x: Tensor, features: Tensor) -> Tensor: return Module([to_scale_shift, norm], forward) -def MergeAdd(*args, **kwargs): +def MergeAdd(): return Module([], lambda x, y, *_: x + y) -def MergeCat(dim: int, channels: int, **kwargs) -> nn.Module: +def MergeCat(dim: int, channels: int) -> nn.Module: conv = Conv(dim=dim, in_channels=channels * 2, out_channels=channels, kernel_size=1) return Module([conv], lambda x, y, *_: conv(torch.cat([x, y], dim=1))) diff --git a/a_unet/unet/apex.py b/a_unet/unet/apex.py index 7cf4e2d..d4e94e9 100644 --- a/a_unet/unet/apex.py +++ b/a_unet/unet/apex.py @@ -1,10 +1,11 @@ -from typing import Callable, List, Optional, Sequence, Type, Union +from typing import Callable, List, Optional, Sequence from torch import Tensor, nn from ..blocks import ( Attention, Conv, + ConvBlock, CrossAttention, Downsample, FeedForward, @@ -23,155 +24,69 @@ exists, ) -""" Block """ - -SkipAdd = MergeAdd -SkipCat = MergeCat -SkipModulate = MergeModulate - - -class Block(nn.Module): - def __init__( - self, - dim: int, - in_channels: int, - channels: int, - factor: int, - depth: Optional[int] = None, - downsample_t: Callable = Downsample, - upsample_t: Callable = Upsample, - skip_t: Callable = SkipAdd, - items: Sequence[Union[Type, str]] = [], - items_up: Optional[Sequence[Type]] = None, - modulation_features: Optional[int] = None, - 
out_channels: Optional[int] = None, - inner_block: Optional[nn.Module] = None, - **kwargs - ): - super().__init__() - out_channels = default(out_channels, in_channels) - items_down = items - items_up = default(items_up, items_down) # type: ignore - - self.downsample = downsample_t( - dim=dim, factor=factor, in_channels=in_channels, out_channels=channels - ) - - self.upsample = upsample_t( - dim=dim, factor=factor, in_channels=channels, out_channels=out_channels - ) - - self.skip_adapter = ( - Conv( - dim=dim, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - ) - if in_channels != out_channels - else nn.Identity() - ) - - self.merge = skip_t( - dim=dim, channels=out_channels, modulation_features=modulation_features - ) - - # Kwargs forwarded to all items - item_kwargs = dict( - dim=dim, - channels=channels, - depth=depth, - modulation_features=modulation_features, - **kwargs - ) - - # Stack and build: items down -> inner block -> items up - items_all: List[nn.Module] = [] - for i, items in enumerate([items_down, items_up]): # type: ignore - for Item in items: - items_all += [Item(**item_kwargs)] # type: ignore - if i == 0 and exists(inner_block): - items_all += [inner_block] - - self.block = Sequential(*items_all) - - def forward( - self, - x: Tensor, - features: Optional[Tensor] = None, - embedding: Optional[Tensor] = None, - channels: Optional[Sequence[Tensor]] = None, - ) -> Tensor: - skip = self.skip_adapter(x) - x = self.downsample(x) - x = self.block(x, features, embedding, channels) - x = self.upsample(x) - x = self.merge(skip, x, features) - return x - - -# Block type, to be provided in UNet -BlockT = T(Block, override=False) - - -""" UNet """ +""" +Items +""" +# Selections for item forward parameters +SelectX = Select(lambda x, *_: (x,)) +SelectXE = Select(lambda x, f, e, *_: (x, e)) +SelectXF = Select(lambda x, f, *_: (x, f)) -class UNet(nn.Module): - def __init__( - self, - dim: int, - in_channels: int, - blocks: Sequence, -
out_channels: Optional[int] = None, - **kwargs - ): - super().__init__() - num_layers = len(blocks) - out_channels = default(out_channels, in_channels) - def Net(i: int) -> Optional[nn.Module]: - if i == num_layers: - return None - block_t = blocks[i] - in_ch = in_channels if i == 0 else blocks[i - 1].kwargs["channels"] - out_ch = out_channels if i == 0 else in_ch - return block_t( - dim=dim, - in_channels=in_ch, - out_channels=out_ch, - inner_block=Net(i + 1), - **kwargs - ) +""" Downsample / Upsample """ - self.net = Net(0) - def forward( - self, - x: Tensor, - features: Optional[Tensor] = None, - embedding: Optional[Tensor] = None, - channels: Optional[Sequence[Tensor]] = None, - ) -> Tensor: - return self.net(x, features, embedding, channels) # type: ignore +def DownsampleItem( + dim: Optional[int] = None, + factor: Optional[int] = None, + in_channels: Optional[int] = None, + channels: Optional[int] = None, + **kwargs +) -> nn.Module: + msg = "DownsampleItem requires dim, factor, in_channels, channels" + assert ( + exists(dim) and exists(factor) and exists(in_channels) and exists(channels) + ), msg + Item = SelectX(Downsample) + return Item( # type: ignore + dim=dim, factor=factor, in_channels=in_channels, out_channels=channels + ) -""" Items """ +def UpsampleItem( + dim: Optional[int] = None, + factor: Optional[int] = None, + channels: Optional[int] = None, + out_channels: Optional[int] = None, + **kwargs +) -> nn.Module: + msg = "UpsampleItem requires dim, factor, channels, out_channels" + assert ( + exists(dim) and exists(factor) and exists(channels) and exists(out_channels) + ), msg + Item = SelectX(Upsample) + return Item( # type: ignore + dim=dim, factor=factor, in_channels=channels, out_channels=out_channels + ) -# Selections for item forward paramters -SelectX = Select(lambda x, *_: (x,)) -SelectXE = Select(lambda x, f, e, *_: (x, e)) -SelectXF = Select(lambda x, f, *_: (x, f)) +""" Main """ def ResnetItem( - dim: Optional[int] = None, channels: 
Optional[int] = None, **kwargs + dim: Optional[int] = None, + channels: Optional[int] = None, + resnet_groups: Optional[int] = None, + **kwargs ) -> nn.Module: - msg = "ResnetItem requires dim and channels" - assert exists(dim) and exists(channels), msg + msg = "ResnetItem requires dim, channels, and resnet_groups" + assert exists(dim) and exists(channels) and exists(resnet_groups), msg Item = SelectX(ResnetBlock) - return Item(dim=dim, in_channels=channels, out_channels=channels) # type: ignore + conv_block_t = T(ConvBlock)(norm_t=T(nn.GroupNorm)(num_groups=resnet_groups)) + return Item( + dim=dim, in_channels=channels, out_channels=channels, conv_block_t=conv_block_t + ) # type: ignore def AttentionItem( @@ -181,7 +96,9 @@ def AttentionItem( **kwargs ) -> nn.Module: msg = "AttentionItem requires channels, attention_features, attention_heads" - assert exists(attention_features) and exists(attention_heads), msg + assert ( + exists(channels) and exists(attention_features) and exists(attention_heads) + ), msg Item = SelectX(Attention) return Packed( Item( # type: ignore @@ -201,7 +118,8 @@ def CrossAttentionItem( ) -> nn.Module: msg = "CrossAttentionItem requires channels, embedding_features, attention_*" assert ( - exists(embedding_features) + exists(channels) + and exists(embedding_features) and exists(attention_features) and exists(attention_heads) ), msg @@ -219,7 +137,7 @@ def CrossAttentionItem( def ModulationItem( channels: Optional[int] = None, modulation_features: Optional[int] = None, **kwargs ) -> nn.Module: - msg = "M block requires channels, modulation_features" + msg = "ModulationItem requires channels, modulation_features" assert exists(channels) and exists(modulation_features), msg Item = SelectXF(Modulation) return Packed( @@ -233,8 +151,10 @@ def LinearAttentionItem( attention_heads: Optional[int] = None, **kwargs ) -> nn.Module: - msg = "LA block requires attention_features and attention_heads" - assert exists(attention_features) and 
exists(attention_heads), msg + msg = "LinearAttentionItem requires attention_features and attention_heads" + assert ( + exists(channels) and exists(attention_features) and exists(attention_heads) + ), msg Item = SelectX(T(Attention)(attention_base_t=LinearAttentionBase)) return Packed( Item( # type: ignore @@ -254,7 +174,8 @@ def LinearCrossAttentionItem( ) -> nn.Module: msg = "LinearCrossAttentionItem requires channels, embedding_features, attention_*" assert ( - exists(embedding_features) + exists(channels) + and exists(embedding_features) and exists(attention_features) and exists(attention_heads) ), msg @@ -271,10 +192,154 @@ def LinearCrossAttentionItem( def FeedForwardItem( channels: Optional[int] = None, attention_multiplier: Optional[int] = None, **kwargs -): - msg = "FeedForwardItem block requires channels, attention_multiplier" +) -> nn.Module: + msg = "FeedForwardItem requires channels, attention_multiplier" assert exists(channels) and exists(attention_multiplier), msg Item = SelectX(FeedForward) return Packed( Item(features=channels, multiplier=attention_multiplier) # type: ignore ) + + +""" Skip Adapters """ + + +def SkipAdapterItem( + dim: Optional[int] = None, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + **kwargs +): + msg = "SkipAdapterItem requires dim, in_channels, out_channels" + assert exists(dim) and exists(in_channels) and exists(out_channels), msg + Item = SelectX(Conv) + return ( + Item( # type: ignore + dim=dim, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + ) + if in_channels != out_channels + else SelectX(nn.Identity)() + ) + + +""" Skip Connections """ + + +def SkipAddItem(**kwargs) -> nn.Module: + return MergeAdd() + + +def SkipCatItem( + dim: Optional[int] = None, out_channels: Optional[int] = None, **kwargs +) -> nn.Module: + msg = "SkipCatItem requires dim, out_channels" + assert exists(dim) and exists(out_channels), msg + return MergeCat(dim=dim, channels=out_channels) + + 
+def SkipModulateItem( + dim: Optional[int] = None, + out_channels: Optional[int] = None, + modulation_features: Optional[int] = None, + **kwargs +) -> nn.Module: + msg = "SkipModulateItem requires dim, out_channels, modulation_features" + assert exists(dim) and exists(out_channels) and exists(modulation_features), msg + return MergeModulate( + dim=dim, channels=out_channels, modulation_features=modulation_features + ) + + +""" Block """ + + +class Block(nn.Module): + def __init__( + self, + in_channels: int, + downsample_t: Callable = DownsampleItem, + upsample_t: Callable = UpsampleItem, + skip_t: Callable = SkipAddItem, + skip_adapter_t: Callable = SkipAdapterItem, + items: Sequence[Callable] = [], + items_up: Optional[Sequence[Callable]] = None, + out_channels: Optional[int] = None, + inner_block: Optional[nn.Module] = None, + **kwargs + ): + super().__init__() + out_channels = default(out_channels, in_channels) + + items_up = default(items_up, items) # type: ignore + items_down = [downsample_t] + list(items) + items_up = list(items_up) + [upsample_t] + items_kwargs = dict( + in_channels=in_channels, out_channels=out_channels, **kwargs + ) + + # Build items stack: items down -> inner block -> items up + items_all: List[nn.Module] = [] + items_all += [item_t(**items_kwargs) for item_t in items_down] + items_all += [inner_block] if exists(inner_block) else [] + items_all += [item_t(**items_kwargs) for item_t in items_up] + + self.skip_adapter = skip_adapter_t(**items_kwargs) + self.block = Sequential(*items_all) + self.skip = skip_t(**items_kwargs) + + def forward( + self, + x: Tensor, + features: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + channels: Optional[Sequence[Tensor]] = None, + ) -> Tensor: + skip = self.skip_adapter(x) + x = self.block(x, features, embedding, channels) + x = self.skip(skip, x, features) + return x + + +# Block type, to be provided in UNet +BlockT = T(Block, override=False) + + +""" UNet """ + + +class 
UNet(nn.Module): + def __init__( + self, + in_channels: int, + blocks: Sequence, + out_channels: Optional[int] = None, + **kwargs + ): + super().__init__() + num_layers = len(blocks) + out_channels = default(out_channels, in_channels) + + def Net(i: int) -> Optional[nn.Module]: + if i == num_layers: + return None # noqa + block_t = blocks[i] + in_ch = in_channels if i == 0 else blocks[i - 1].channels + out_ch = out_channels if i == 0 else in_ch + + return block_t( + in_channels=in_ch, out_channels=out_ch, inner_block=Net(i + 1), **kwargs + ) + + self.net = Net(0) + + def forward( + self, + x: Tensor, + features: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + channels: Optional[Sequence[Tensor]] = None, + ) -> Tensor: + return self.net(x, features, embedding, channels) # type: ignore