Skip to content

Commit

Permalink
feat: option to use sequential patching
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Sep 22, 2022
1 parent 537764c commit 4af7157
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ def __init__(
context_channels: Optional[Sequence[int]] = None,
context_embedding_features: Optional[int] = None,
use_post_out_block: bool = False,
use_sequential_patching: bool = False,
):
super().__init__()

Expand Down Expand Up @@ -873,8 +874,10 @@ def __init__(
and len(num_blocks) == num_layers
)

patching = "b c (p l)" if use_sequential_patching else "b c (l p)"

self.to_in = nn.Sequential(
Rearrange("b c (l p) -> b (c p) l", p=patch_size),
Rearrange(f"{patching} -> b (c p) l", p=patch_size),
CrossEmbed1d(
in_channels=(in_channels + context_channels[0]) * patch_size,
out_channels=channels,
Expand Down Expand Up @@ -981,7 +984,7 @@ def __init__(
out_channels=out_channels * patch_size,
kernel_size=1,
),
Rearrange("b (c p) l -> b c (l p)", p=patch_size),
Rearrange(f"b (c p) l -> {patching}", p=patch_size),
)

if self.use_post_out_block:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.38",
version="0.0.39",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 4af7157

Please sign in to comment.