From e0933e73a60cf9e7d42fcb68c173b23d903f258d Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 3 Feb 2023 15:35:34 +0100 Subject: [PATCH] feat: add resnet, upsample, downsample kernel_size/width, more upsample options --- a_unet/apex.py | 45 +++++++++++++++++++++++++++++++------- a_unet/blocks.py | 57 +++++++++++++++++++++++++++++++++++++++++++----- setup.py | 2 +- 3 files changed, 89 insertions(+), 15 deletions(-) diff --git a/a_unet/apex.py b/a_unet/apex.py index 02cfc7a..547fb70 100644 --- a/a_unet/apex.py +++ b/a_unet/apex.py @@ -23,6 +23,7 @@ Sequential, T, Upsample, + UpsampleInterpolate, default, exists, ) @@ -46,6 +47,7 @@ def DownsampleItem( factor: Optional[int] = None, in_channels: Optional[int] = None, channels: Optional[int] = None, + downsample_width: int = 1, **kwargs, ) -> nn.Module: msg = "DownsampleItem requires dim, factor, in_channels, channels" @@ -54,7 +56,11 @@ def DownsampleItem( ), msg Item = SelectX(Downsample) return Item( # type: ignore - dim=dim, factor=factor, in_channels=in_channels, out_channels=channels + dim=dim, + factor=factor, + width=downsample_width, + in_channels=in_channels, + out_channels=channels, ) @@ -63,16 +69,34 @@ def UpsampleItem( factor: Optional[int] = None, channels: Optional[int] = None, out_channels: Optional[int] = None, + upsample_mode: str = "nearest", + upsample_kernel_size: int = 3, # Used with upsample_mode != "transpose" + upsample_width: int = 1, # Used with upsample_mode == "transpose" **kwargs, ) -> nn.Module: msg = "UpsampleItem requires dim, factor, channels, out_channels" assert ( exists(dim) and exists(factor) and exists(channels) and exists(out_channels) ), msg - Item = SelectX(Upsample) - return Item( # type: ignore - dim=dim, factor=factor, in_channels=channels, out_channels=out_channels - ) + if upsample_mode == "transpose": + Item = SelectX(Upsample) + return Item( # type: ignore + dim=dim, + factor=factor, + width=upsample_width, + in_channels=channels, + out_channels=out_channels, + ) + else: + Item = SelectX(UpsampleInterpolate) + return Item( # type: ignore + dim=dim, + factor=factor, + mode=upsample_mode, + kernel_size=upsample_kernel_size, + in_channels=channels, + out_channels=out_channels, + ) """ Main """ @@ -82,15 +106,20 @@ def ResnetItem( dim: Optional[int] = None, channels: Optional[int] = None, resnet_groups: Optional[int] = None, + resnet_kernel_size: int = 3, **kwargs, ) -> nn.Module: msg = "ResnetItem requires dim, channels, and resnet_groups" assert exists(dim) and exists(channels) and exists(resnet_groups), msg Item = SelectX(ResnetBlock) conv_block_t = T(ConvBlock)(norm_t=T(nn.GroupNorm)(num_groups=resnet_groups)) - return Item( - dim=dim, in_channels=channels, out_channels=channels, conv_block_t=conv_block_t - ) # type: ignore + return Item( # type: ignore + dim=dim, + in_channels=channels, + out_channels=channels, + kernel_size=resnet_kernel_size, + conv_block_t=conv_block_t, + ) def ConvNextV2Item( diff --git a/a_unet/blocks.py b/a_unet/blocks.py index 62892d5..20389a1 100644 --- a/a_unet/blocks.py +++ b/a_unet/blocks.py @@ -133,16 +133,57 @@ def Conv(dim: int, *args, **kwargs) -> nn.Module: return [nn.Conv1d, nn.Conv2d, nn.Conv3d][dim - 1](*args, **kwargs) -def Downsample(dim: int, factor: int = 2, conv_t=Conv, **kwargs) -> nn.Module: - return conv_t(dim=dim, kernel_size=factor, stride=factor, **kwargs) +def ConvTranspose(dim: int, *args, **kwargs) -> nn.Module: + return [nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d][dim - 1]( + *args, **kwargs + ) + + +def Downsample( + dim: int, factor: int = 2, width: int = 1, conv_t=Conv, **kwargs +) -> nn.Module: + width = width if factor > 1 else 1 + return conv_t( + dim=dim, + kernel_size=factor * width, + stride=factor, + padding=(factor * width - factor) // 2, + **kwargs, + ) def Upsample( - dim: int, factor: int = 2, mode: str = "nearest", conv_t=Conv, **kwargs + dim: int, + factor: int = 2, + width: int = 1, + conv_t=Conv, + conv_tranpose_t=ConvTranspose, + **kwargs, +) -> nn.Module: + width = width if factor > 1 else 1 + return conv_tranpose_t( + dim=dim, + kernel_size=factor * width, + stride=factor, + padding=(factor * width - factor) // 2, + **kwargs, + ) + + +def UpsampleInterpolate( + dim: int, + factor: int = 2, + kernel_size: int = 3, + mode: str = "nearest", + conv_t=Conv, + **kwargs, ) -> nn.Module: + assert kernel_size % 2 == 1, "upsample kernel size must be odd" return nn.Sequential( - nn.Upsample(scale_factor=factor, mode="nearest"), - conv_t(dim=dim, kernel_size=3, padding=1, **kwargs), + nn.Upsample(scale_factor=factor, mode=mode), + conv_t( + dim=dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, **kwargs + ), ) @@ -165,10 +206,14 @@ def ResnetBlock( dim: int, in_channels: int, out_channels: int, + kernel_size: int = 3, conv_block_t=ConvBlock, conv_t=Conv, + **kwargs, ) -> nn.Module: - ConvBlock = T(conv_block_t)(dim=dim, kernel_size=3, padding=1) + ConvBlock = T(conv_block_t)( + dim=dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, **kwargs + ) Conv = T(conv_t)(dim=dim, kernel_size=1) conv_block = Sequential( diff --git a/setup.py b/setup.py index cb1b19a..a902cca 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="a-unet", packages=find_packages(exclude=[]), - version="0.0.15", + version="0.0.16", license="MIT", description="A-UNet", long_description_content_type="text/markdown",