From 4e6ee2c0299d79faab7d483a73756523700195ca Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Mon, 19 Sep 2022 10:51:29 +0200 Subject: [PATCH] fix: convout dual channel, add skip connection --- audio_diffusion_pytorch/modules.py | 18 ++++++++---------- setup.py | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py index ab17d0e..68cc3b2 100644 --- a/audio_diffusion_pytorch/modules.py +++ b/audio_diffusion_pytorch/modules.py @@ -25,15 +25,13 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module: class ConvOut1d(nn.Module): - def __init__( - self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int] - ): + def __init__(self, channels: int, kernel_sizes: Sequence[int]): super().__init__() - mid_channels = in_channels * 16 + mid_channels = channels * 16 self.convs_in = nn.ModuleList( Conv1d( - in_channels=in_channels, + in_channels=channels, out_channels=mid_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, @@ -49,14 +47,15 @@ def __init__( ) self.conv_out = Conv1d( - in_channels=mid_channels, out_channels=out_channels, kernel_size=1 + in_channels=mid_channels, out_channels=channels, kernel_size=1 ) def forward(self, x: Tensor) -> Tensor: + skip = x xs = torch.stack([conv(x) for conv in self.convs_in]) - x = reduce(xs, "n b c t -> b c t", "sum") + x + x = reduce(xs, "n b c t -> b c t", "sum") x = self.conv_mid(x) - x = self.conv_out(x) + x = self.conv_out(x) + skip return x @@ -932,8 +931,7 @@ def __init__( ), Rearrange("b (c p) l -> b c (l p)", p=patch_size), ConvOut1d( - in_channels=out_channels, - out_channels=out_channels, + channels=out_channels, kernel_sizes=kernel_sizes_out, ) if exists(kernel_sizes_out) diff --git a/setup.py b/setup.py index 73d78de..03897df 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.32", + version="0.0.33", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown",