From ce86eb845ad4d1272364d67a290a795a2a536a16 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Sun, 1 Jan 2023 20:06:11 +0100 Subject: [PATCH] feat: add inject channels item --- a_unet/apex.py | 41 +++++++++++++++++++++++++++++++++++++++-- a_unet/blocks.py | 2 +- setup.py | 2 +- 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/a_unet/apex.py b/a_unet/apex.py index dd653fd..39ecdf9 100644 --- a/a_unet/apex.py +++ b/a_unet/apex.py @@ -1,5 +1,6 @@ -from typing import Callable, List, Optional, Sequence +from typing import Callable, List, Optional, Sequence, no_type_check +import torch from torch import Tensor, nn from .blocks import ( @@ -14,6 +15,7 @@ MergeCat, MergeModulate, Modulation, + Module, Packed, ResnetBlock, Select, @@ -30,8 +32,9 @@ # Selections for item forward parameters SelectX = Select(lambda x, *_: (x,)) -SelectXE = Select(lambda x, f, e, *_: (x, e)) SelectXF = Select(lambda x, f, *_: (x, f)) +SelectXE = Select(lambda x, f, e, *_: (x, e)) +SelectXC = Select(lambda x, f, e, c, *_: (x, c)) """ Downsample / Upsample """ @@ -201,6 +204,40 @@ def FeedForwardItem( ) +def InjectChannelsItem( + dim: Optional[int] = None, + channels: Optional[int] = None, + depth: Optional[int] = None, + context_channels: Optional[int] = None, + **kwargs, +) -> nn.Module: + msg = "InjectChannelsItem requires dim, depth, channels, context_channels" + assert ( + exists(dim) and exists(depth) and exists(channels) and exists(context_channels) + ), msg + msg = "InjectChannelsItem requires context_channels > 0" + assert context_channels > 0, msg + + conv = Conv( + dim=dim, + in_channels=channels + context_channels, + out_channels=channels, + kernel_size=1, + ) + + @no_type_check + def forward(x: Tensor, channels: Sequence[Optional[Tensor]]) -> Tensor: + msg_ = f"context `channels` at depth {depth} in forward" + assert depth < len(channels), f"Required {msg_}" + context = channels[depth] + shape = torch.Size([x.shape[0], context_channels, *x.shape[2:]]) + msg = f"Required {msg_} to be tensor of shape {list(shape)}" + assert torch.is_tensor(context) and context.shape == shape, msg + return conv(torch.cat([x, context], dim=1)) + x + + return SelectXC(Module)([conv], forward) # type: ignore + + """ Skip Adapters """ diff --git a/a_unet/blocks.py b/a_unet/blocks.py index ddfadee..35cb5ac 100644 --- a/a_unet/blocks.py +++ b/a_unet/blocks.py @@ -517,7 +517,7 @@ def TextConditioningPlugin( features: int = embedder.embedding_features # type: ignore def Net(embedding_features: int = features, **kwargs) -> nn.Module: - msg = f"TextConditioningPlugin requires embedding_features={features} " + msg = f"TextConditioningPlugin requires embedding_features={features}" assert embedding_features == features, msg net = net_t(embedding_features=embedding_features, **kwargs) # type: ignore diff --git a/setup.py b/setup.py index 190de26..33c9ac2 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="a-unet", packages=find_packages(exclude=[]), - version="0.0.9", + version="0.0.10", license="MIT", description="A-UNet", long_description_content_type="text/markdown",