Skip to content

Commit

Permalink
fix: convout dual channel, add skip connection
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Sep 19, 2022
1 parent 7004f00 commit 4e6ee2c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
18 changes: 8 additions & 10 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:


class ConvOut1d(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
):
def __init__(self, channels: int, kernel_sizes: Sequence[int]):
super().__init__()
mid_channels = in_channels * 16
mid_channels = channels * 16

self.convs_in = nn.ModuleList(
Conv1d(
in_channels=in_channels,
in_channels=channels,
out_channels=mid_channels,
kernel_size=kernel_size,
padding=(kernel_size - 1) // 2,
Expand All @@ -49,14 +47,15 @@ def __init__(
)

self.conv_out = Conv1d(
in_channels=mid_channels, out_channels=out_channels, kernel_size=1
in_channels=mid_channels, out_channels=channels, kernel_size=1
)

def forward(self, x: Tensor) -> Tensor:
skip = x
xs = torch.stack([conv(x) for conv in self.convs_in])
x = reduce(xs, "n b c t -> b c t", "sum") + x
x = reduce(xs, "n b c t -> b c t", "sum")
x = self.conv_mid(x)
x = self.conv_out(x)
x = self.conv_out(x) + skip
return x


Expand Down Expand Up @@ -932,8 +931,7 @@ def __init__(
),
Rearrange("b (c p) l -> b c (l p)", p=patch_size),
ConvOut1d(
in_channels=out_channels,
out_channels=out_channels,
channels=out_channels,
kernel_sizes=kernel_sizes_out,
)
if exists(kernel_sizes_out)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.32",
version="0.0.33",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 4e6ee2c

Please sign in to comment.