diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py
index 489ce9c..0860929 100644
--- a/audio_diffusion_pytorch/model.py
+++ b/audio_diffusion_pytorch/model.py
@@ -1,5 +1,6 @@
+import random
 from math import prod
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Union
 
 import torch
 from torch import Tensor, nn
@@ -15,6 +16,7 @@
     Schedule,
 )
 from .modules import Encoder1d, ResnetBlock1d, UNet1d
+from .utils import default, to_list
 
 """ Diffusion Classes (generic for 1d data) """
 
@@ -95,24 +97,32 @@ def sample(
 
 
 class DiffusionUpsampler1d(Model1d):
-    def __init__(self, factor: int, in_channels: int, *args, **kwargs):
-        self.factor = factor
+    def __init__(
+        self, factor: Union[int, Sequence[int]], in_channels: int, *args, **kwargs
+    ):
+        self.factor = to_list(factor)
         default_kwargs = dict(
             in_channels=in_channels,
             context_channels=[in_channels],
         )
         super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
 
-    def forward(self, x: Tensor, **kwargs) -> Tensor:
+    def forward(self, x: Tensor, factor: Optional[int] = None, **kwargs) -> Tensor:
+        # Either user provides factor or we pick one at random
+        factor = default(factor, random.choice(self.factor))
         # Downsample by picking every `factor` item
-        downsampled = x[:, :, :: self.factor]
+        downsampled = x[:, :, ::factor]
         # Upsample by interleaving to get context
-        context = torch.repeat_interleave(downsampled, repeats=self.factor, dim=2)
+        context = torch.repeat_interleave(downsampled, repeats=factor, dim=2)
         return self.diffusion(x, context=[context], **kwargs)
 
-    def sample(self, undersampled: Tensor, *args, **kwargs):  # type: ignore
+    def sample(  # type: ignore
+        self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
+    ):
+        # Either user provides factor or we pick the first
+        factor = default(factor, self.factor[0])
         # Upsample context by interleaving
-        context = torch.repeat_interleave(undersampled, repeats=self.factor, dim=2)
+        context = torch.repeat_interleave(undersampled, repeats=factor, dim=2)
         noise = torch.randn_like(context)
         default_kwargs = dict(context=[context])
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
diff --git a/audio_diffusion_pytorch/utils.py b/audio_diffusion_pytorch/utils.py
index 8debb37..48d80c4 100644
--- a/audio_diffusion_pytorch/utils.py
+++ b/audio_diffusion_pytorch/utils.py
@@ -1,5 +1,5 @@
 from inspect import isfunction
-from typing import Callable, Optional, TypeVar, Union
+from typing import Callable, List, Optional, Sequence, TypeVar, Union
 
 from typing_extensions import TypeGuard
 
@@ -22,3 +22,11 @@ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
     if exists(val):
         return val
     return d() if isfunction(d) else d
+
+
+def to_list(val: Union[T, Sequence[T]]) -> List[T]:
+    if isinstance(val, tuple):
+        return list(val)
+    if isinstance(val, list):
+        return val
+    return [val]  # type: ignore
diff --git a/setup.py b/setup.py
index ff6e9b0..4339d90 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.24",
+    version="0.0.25",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
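
A minimal sketch of the context construction the changed methods perform for a given factor, using only torch (the tensor shape is an arbitrary example, not taken from the diff): decimate the waveform by `factor`, then repeat-interleave it back to the original length. With the change above, `forward` defaults `factor` to a random choice from the configured list and `sample` defaults to the first entry.

    import torch

    factor = 4
    x = torch.randn(1, 1, 2 ** 15)    # [batch, channels, samples], example shape
    # Downsample by picking every `factor`-th sample
    downsampled = x[:, :, ::factor]   # [1, 1, 2 ** 15 // factor]
    # Upsample by interleaving to get the conditioning context
    context = torch.repeat_interleave(downsampled, repeats=factor, dim=2)
    assert context.shape == x.shape   # context lines up with the target x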