Skip to content

Commit

Permalink
feat: add option to train upsampler with multiple factors
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Sep 7, 2022
1 parent 73a0d5c commit 36fc9be
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
26 changes: 18 additions & 8 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import random
from math import prod
from typing import Optional, Sequence
from typing import Optional, Sequence, Union

import torch
from torch import Tensor, nn
Expand All @@ -15,6 +16,7 @@
Schedule,
)
from .modules import Encoder1d, ResnetBlock1d, UNet1d
from .utils import default, to_list

""" Diffusion Classes (generic for 1d data) """

Expand Down Expand Up @@ -95,24 +97,32 @@ def sample(


class DiffusionUpsampler1d(Model1d):
    """Diffusion model conditioned on a naively upsampled copy of the signal.

    During training the input is decimated along its last axis (dim 2) by
    ``factor`` and stretched back by repeating every kept item; that tensor is
    handed to the diffusion objective as context. One or several factors may
    be configured; with several, each training step draws one at random.
    """

    def __init__(
        self, factor: Union[int, Sequence[int]], in_channels: int, *args, **kwargs
    ):
        """Store the candidate factor(s) and configure the parent model.

        Args:
            factor: A single upsampling factor, or a list/tuple of factors.
            in_channels: Forwarded as ``in_channels`` and also used as the
                only entry of ``context_channels``.
            *args: Forwarded positionally to the parent constructor.
            **kwargs: Forwarded to the parent constructor; entries here
                override the two defaults derived from ``in_channels``.
        """
        # Always kept as a list so forward() can draw from it uniformly.
        self.factor = to_list(factor)
        default_kwargs = dict(
            in_channels=in_channels,
            context_channels=[in_channels],
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore

    def forward(self, x: Tensor, factor: Optional[int] = None, **kwargs) -> Tensor:
        """Run the diffusion objective on ``x`` with a degraded copy as context.

        Args:
            x: Input tensor; it is indexed with three axes and resampled
                along the last one.
            factor: Decimation factor for this call. When omitted, one of the
                configured factors is chosen at random.
            **kwargs: Forwarded to ``self.diffusion``.

        Returns:
            Whatever ``self.diffusion`` returns (presumably the training
            loss; confirm against ``Model1d``).
        """
        # Either user provides factor or we pick one at random. The lambda
        # defers the draw so the RNG is only consumed when actually needed.
        factor = default(factor, lambda: random.choice(self.factor))
        # Downsample by picking every `factor` item
        downsampled = x[:, :, ::factor]
        # Upsample by interleaving to get context
        context = torch.repeat_interleave(downsampled, repeats=factor, dim=2)
        # The strided slice keeps ceil(length / factor) items, so the repeated
        # tensor can overshoot when length is not a multiple of factor; trim
        # it back to the input length (a no-op when it divides evenly).
        context = context[:, :, : x.shape[2]]
        return self.diffusion(x, context=[context], **kwargs)

    def sample(  # type: ignore
        self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
    ):
        """Sample an upsampled signal conditioned on ``undersampled``.

        Args:
            undersampled: Low-rate tensor; each item along dim 2 is repeated
                ``factor`` times to build the context.
            factor: Upsampling factor. When omitted, the first configured
                factor is used.
            *args: Accepted for signature compatibility; not forwarded.
            **kwargs: Forwarded to the parent ``sample``; a ``context`` entry
                here overrides the interleaved context.

        Returns:
            Whatever the parent ``sample`` returns for noise shaped like the
            context.
        """
        # Either user provides factor or we pick the first
        factor = default(factor, self.factor[0])
        # Upsample context by interleaving
        context = torch.repeat_interleave(undersampled, repeats=factor, dim=2)
        noise = torch.randn_like(context)
        default_kwargs = dict(context=[context])
        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
Expand Down
10 changes: 9 additions & 1 deletion audio_diffusion_pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from inspect import isfunction
from typing import Callable, Optional, TypeVar, Union
from typing import Callable, List, Optional, Sequence, TypeVar, Union

from typing_extensions import TypeGuard

Expand All @@ -22,3 +22,11 @@ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
if exists(val):
return val
return d() if isfunction(d) else d


def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    """Normalize a scalar-or-sequence argument into a list.

    A list is handed back unchanged (the very same object), a tuple is copied
    into a new list, and anything else is wrapped in a one-element list.
    """
    if isinstance(val, list):
        return val
    # Only tuples are unpacked; every other value counts as a single item.
    return list(val) if isinstance(val, tuple) else [val]  # type: ignore
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.24",
version="0.0.25",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 36fc9be

Please sign in to comment.