diff --git a/api_docs/audiocraft/adversarial/discriminators/base.html b/api_docs/audiocraft/adversarial/discriminators/base.html new file mode 100644 index 00000000..fe4eb186 --- /dev/null +++ b/api_docs/audiocraft/adversarial/discriminators/base.html @@ -0,0 +1,205 @@ + + + + + + +audiocraft.adversarial.discriminators.base API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.adversarial.discriminators.base

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+import typing as tp
+
+import torch
+import torch.nn as nn
+
+
+FeatureMapType = tp.List[torch.Tensor]
+LogitsType = torch.Tensor
+MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+class MultiDiscriminator(ABC, nn.Module):
+    """Base implementation for discriminators composed of sub-discriminators acting at different scales.
+    """
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        ...
+
+    @property
+    @abstractmethod
+    def num_discriminators(self) -> int:
+        """Number of discriminators.
+        """
+        ...
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class MultiDiscriminator +
+
+

Base implementation for discriminators composed of sub-discriminators acting at different scales.

+
+ +Expand source code + +
class MultiDiscriminator(ABC, nn.Module):
+    """Base implementation for discriminators composed of sub-discriminators acting at different scales.
+    """
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        ...
+
+    @property
+    @abstractmethod
+    def num_discriminators(self) -> int:
+        """Number of discriminators.
+        """
+        ...
+
+

Ancestors

+
    +
  • abc.ABC
  • +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var num_discriminators : int
+
+

Number of discriminators.

+
+ +Expand source code + +
@property
+@abstractmethod
+def num_discriminators(self) -> int:
+    """Number of discriminators.
+    """
+    ...
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor) ‑> Tuple[List[torch.Tensor], List[List[torch.Tensor]]] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
@abstractmethod
+def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+    ...
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/adversarial/discriminators/index.html b/api_docs/audiocraft/adversarial/discriminators/index.html new file mode 100644 index 00000000..9eaf91b5 --- /dev/null +++ b/api_docs/audiocraft/adversarial/discriminators/index.html @@ -0,0 +1,95 @@ + + + + + + +audiocraft.adversarial.discriminators API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.adversarial.discriminators

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .mpd import MultiPeriodDiscriminator
+from .msd import MultiScaleDiscriminator
+from .msstftd import MultiScaleSTFTDiscriminator
+
+
+
+

Sub-modules

+
+
audiocraft.adversarial.discriminators.base
+
+
+
+
audiocraft.adversarial.discriminators.mpd
+
+
+
+
audiocraft.adversarial.discriminators.msd
+
+
+
+
audiocraft.adversarial.discriminators.msstftd
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/adversarial/discriminators/mpd.html b/api_docs/audiocraft/adversarial/discriminators/mpd.html new file mode 100644 index 00000000..806267f0 --- /dev/null +++ b/api_docs/audiocraft/adversarial/discriminators/mpd.html @@ -0,0 +1,446 @@ + + + + + + +audiocraft.adversarial.discriminators.mpd API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.adversarial.discriminators.mpd

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...modules import NormConv2d
+from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+def get_padding(kernel_size: int, dilation: int = 1) -> int:
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+class PeriodDiscriminator(nn.Module):
+    """Period sub-discriminator.
+
+    Args:
+        period (int): Period between samples of audio.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        n_layers (int): Number of convolutional layers.
+        kernel_sizes (list of int): Kernel sizes for convolutions.
+        stride (int): Stride for convolutions.
+        filters (int): Initial number of filters in convolutions.
+        filters_scale (int): Multiplier of number of filters as we increase depth.
+        max_filters (int): Maximum number of filters.
+        norm (str): Normalization method.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+    """
+    def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
+                 n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
+                 filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
+                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
+                 activation_params: dict = {'negative_slope': 0.2}):
+        super().__init__()
+        self.period = period
+        self.n_layers = n_layers
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.convs = nn.ModuleList()
+        in_chs = in_channels
+        for i in range(self.n_layers):
+            out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
+            eff_stride = 1 if i == self.n_layers - 1 else stride
+            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
+                                         padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
+            in_chs = out_chs
+        self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
+                                    padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), 'reflect')
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for conv in self.convs:
+            x = conv(x)
+            x = self.activation(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        # x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(MultiDiscriminator):
+    """Multi-Period (MPD) Discriminator.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
+        **kwargs: Additional args for `PeriodDiscriminator`
+    """
+    def __init__(self, in_channels: int = 1, out_channels: int = 1,
+                 periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
+        super().__init__()
+        self.discriminators = nn.ModuleList([
+            PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
+        ])
+
+    @property
+    def num_discriminators(self):
+        return len(self.discriminators)
+
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
+
+
+
+
+
+
+
+

Functions

+
+
+def get_padding(kernel_size: int, dilation: int = 1) ‑> int +
+
+
+
+ +Expand source code + +
def get_padding(kernel_size: int, dilation: int = 1) -> int:
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+
+
+
+

Classes

+
+
+class MultiPeriodDiscriminator +(in_channels: int = 1, out_channels: int = 1, periods: Sequence[int] = [2, 3, 5, 7, 11], **kwargs) +
+
+

Multi-Period (MPD) Discriminator.

+

Args

+
+
in_channels : int
+
Number of input channels.
+
out_channels : int
+
Number of output channels.
+
periods : Sequence[int]
+
Periods between samples of audio for the sub-discriminators.
+
**kwargs
+
Additional args for PeriodDiscriminator
+
+
+ +Expand source code + +
class MultiPeriodDiscriminator(MultiDiscriminator):
+    """Multi-Period (MPD) Discriminator.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
+        **kwargs: Additional args for `PeriodDiscriminator`
+    """
+    def __init__(self, in_channels: int = 1, out_channels: int = 1,
+                 periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
+        super().__init__()
+        self.discriminators = nn.ModuleList([
+            PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
+        ])
+
+    @property
+    def num_discriminators(self):
+        return len(self.discriminators)
+
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class PeriodDiscriminator +(period: int, in_channels: int = 1, out_channels: int = 1, n_layers: int = 5, kernel_sizes: List[int] = [5, 3], stride: int = 3, filters: int = 8, filters_scale: int = 4, max_filters: int = 1024, norm: str = 'weight_norm', activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}) +
+
+

Period sub-discriminator.

+

Args

+
+
period : int
+
Period between samples of audio.
+
in_channels : int
+
Number of input channels.
+
out_channels : int
+
Number of output channels.
+
n_layers : int
+
Number of convolutional layers.
+
kernel_sizes : list of int
+
Kernel sizes for convolutions.
+
stride : int
+
Stride for convolutions.
+
filters : int
+
Initial number of filters in convolutions.
+
filters_scale : int
+
Multiplier of number of filters as we increase depth.
+
max_filters : int
+
Maximum number of filters.
+
norm : str
+
Normalization method.
+
activation : str
+
Activation function.
+
activation_params : dict
+
Parameters to provide to the activation function.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class PeriodDiscriminator(nn.Module):
+    """Period sub-discriminator.
+
+    Args:
+        period (int): Period between samples of audio.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        n_layers (int): Number of convolutional layers.
+        kernel_sizes (list of int): Kernel sizes for convolutions.
+        stride (int): Stride for convolutions.
+        filters (int): Initial number of filters in convolutions.
+        filters_scale (int): Multiplier of number of filters as we increase depth.
+        max_filters (int): Maximum number of filters.
+        norm (str): Normalization method.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+    """
+    def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
+                 n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
+                 filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
+                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
+                 activation_params: dict = {'negative_slope': 0.2}):
+        super().__init__()
+        self.period = period
+        self.n_layers = n_layers
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.convs = nn.ModuleList()
+        in_chs = in_channels
+        for i in range(self.n_layers):
+            out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
+            eff_stride = 1 if i == self.n_layers - 1 else stride
+            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
+                                         padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
+            in_chs = out_chs
+        self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
+                                    padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), 'reflect')
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for conv in self.convs:
+            x = conv(x)
+            x = self.activation(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        # x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x: torch.Tensor):
+    fmap = []
+    # 1d to 2d
+    b, c, t = x.shape
+    if t % self.period != 0:  # pad first
+        n_pad = self.period - (t % self.period)
+        x = F.pad(x, (0, n_pad), 'reflect')
+        t = t + n_pad
+    x = x.view(b, c, t // self.period, self.period)
+
+    for conv in self.convs:
+        x = conv(x)
+        x = self.activation(x)
+        fmap.append(x)
+    x = self.conv_post(x)
+    fmap.append(x)
+    # x = torch.flatten(x, 1, -1)
+
+    return x, fmap
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/adversarial/discriminators/msd.html b/api_docs/audiocraft/adversarial/discriminators/msd.html new file mode 100644 index 00000000..4294eacb --- /dev/null +++ b/api_docs/audiocraft/adversarial/discriminators/msd.html @@ -0,0 +1,468 @@ + + + + + + +audiocraft.adversarial.discriminators.msd API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.adversarial.discriminators.msd

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...modules import NormConv1d
+from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+class ScaleDiscriminator(nn.Module):
+    """Waveform sub-discriminator.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
+        filters (int): Number of initial filters for convolutions.
+        max_filters (int): Maximum number of filters.
+        downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
+        inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
+        groups (Sequence[int] or None): Groups for inner convolutions.
+        strides (Sequence[int] or None): Strides for inner convolutions.
+        paddings (Sequence[int] or None): Paddings for inner convolutions.
+        norm (str): Normalization method.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        pad (str): Padding for initial convolution.
+        pad_params (dict): Parameters to provide to the padding module.
+    """
+    def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
+                 filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
+                 inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
+                 strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
+                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
+                 activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
+                 pad_params: dict = {}):
+        super().__init__()
+        assert len(kernel_sizes) == 2
+        assert kernel_sizes[0] % 2 == 1
+        assert kernel_sizes[1] % 2 == 1
+        assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
+        assert (groups is None or len(groups) == len(downsample_scales))
+        assert (strides is None or len(strides) == len(downsample_scales))
+        assert (paddings is None or len(paddings) == len(downsample_scales))
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            nn.Sequential(
+                getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
+                NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
+            )
+        )
+
+        in_chs = filters
+        for i, downsample_scale in enumerate(downsample_scales):
+            out_chs = min(in_chs * downsample_scale, max_filters)
+            default_kernel_size = downsample_scale * 10 + 1
+            default_stride = downsample_scale
+            default_padding = (default_kernel_size - 1) // 2
+            default_groups = in_chs // 4
+            self.convs.append(
+                NormConv1d(in_chs, out_chs,
+                           kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
+                           stride=strides[i] if strides else default_stride,
+                           groups=groups[i] if groups else default_groups,
+                           padding=paddings[i] if paddings else default_padding,
+                           norm=norm))
+            in_chs = out_chs
+
+        out_chs = min(in_chs * 2, max_filters)
+        self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
+                                     padding=(kernel_sizes[0] - 1) // 2, norm=norm))
+        self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
+                                    padding=(kernel_sizes[1] - 1) // 2, norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        for layer in self.convs:
+            x = layer(x)
+            x = self.activation(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        # x = torch.flatten(x, 1, -1)
+        return x, fmap
+
+
+class MultiScaleDiscriminator(MultiDiscriminator):
+    """Multi-Scale (MSD) Discriminator,
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        downsample_factor (int): Downsampling factor between the different scales.
+        scale_norms (Sequence[str]): Normalization for each sub-discriminator.
+        **kwargs: Additional args for ScaleDiscriminator.
+    """
+    def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
+                 scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
+        super().__init__()
+        self.discriminators = nn.ModuleList([
+            ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
+        ])
+        self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
+
+    @property
+    def num_discriminators(self):
+        return len(self.discriminators)
+
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        logits = []
+        fmaps = []
+        for i, disc in enumerate(self.discriminators):
+            if i != 0:
+                self.downsample(x)
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class MultiScaleDiscriminator +(in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2, scale_norms: Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs) +
+
+

Multi-Scale (MSD) Discriminator,

+

Args

+
+
in_channels : int
+
Number of input channels.
+
out_channels : int
+
Number of output channels.
+
downsample_factor : int
+
Downsampling factor between the different scales.
+
scale_norms : Sequence[str]
+
Normalization for each sub-discriminator.
+
**kwargs
+
Additional args for ScaleDiscriminator.
+
+
+ +Expand source code + +
class MultiScaleDiscriminator(MultiDiscriminator):
+    """Multi-Scale (MSD) Discriminator,
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        downsample_factor (int): Downsampling factor between the different scales.
+        scale_norms (Sequence[str]): Normalization for each sub-discriminator.
+        **kwargs: Additional args for ScaleDiscriminator.
+    """
+    def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
+                 scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
+        super().__init__()
+        self.discriminators = nn.ModuleList([
+            ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
+        ])
+        self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
+
+    @property
+    def num_discriminators(self):
+        return len(self.discriminators)
+
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        logits = []
+        fmaps = []
+        for i, disc in enumerate(self.discriminators):
+            if i != 0:
+                self.downsample(x)
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class ScaleDiscriminator +(in_channels=1, out_channels=1, kernel_sizes: Sequence[int] = [5, 3], filters: int = 16, max_filters: int = 1024, downsample_scales: Sequence[int] = [4, 4, 4, 4], inner_kernel_sizes: Optional[Sequence[int]] = None, groups: Optional[Sequence[int]] = None, strides: Optional[Sequence[int]] = None, paddings: Optional[Sequence[int]] = None, norm: str = 'weight_norm', activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d', pad_params: dict = {}) +
+
+

Waveform sub-discriminator.

+

Args

+
+
in_channels : int
+
Number of input channels.
+
out_channels : int
+
Number of output channels.
+
kernel_sizes : Sequence[int]
+
Kernel sizes for first and last convolutions.
+
filters : int
+
Number of initial filters for convolutions.
+
max_filters : int
+
Maximum number of filters.
+
downsample_scales : Sequence[int]
+
Scale for downsampling implemented as strided convolutions.
+
inner_kernel_sizes : Sequence[int] or None
+
Kernel sizes for inner convolutions.
+
groups : Sequence[int] or None
+
Groups for inner convolutions.
+
strides : Sequence[int] or None
+
Strides for inner convolutions.
+
paddings : Sequence[int] or None
+
Paddings for inner convolutions.
+
norm : str
+
Normalization method.
+
activation : str
+
Activation function.
+
activation_params : dict
+
Parameters to provide to the activation function.
+
pad : str
+
Padding for initial convolution.
+
pad_params : dict
+
Parameters to provide to the padding module.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ScaleDiscriminator(nn.Module):
+    """Waveform sub-discriminator.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
+        filters (int): Number of initial filters for convolutions.
+        max_filters (int): Maximum number of filters.
+        downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
+        inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
+        groups (Sequence[int] or None): Groups for inner convolutions.
+        strides (Sequence[int] or None): Strides for inner convolutions.
+        paddings (Sequence[int] or None): Paddings for inner convolutions.
+        norm (str): Normalization method.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        pad (str): Padding for initial convolution.
+        pad_params (dict): Parameters to provide to the padding module.
+    """
+    def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
+                 filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
+                 inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
+                 strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
+                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
+                 activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
+                 pad_params: dict = {}):
+        super().__init__()
+        assert len(kernel_sizes) == 2
+        assert kernel_sizes[0] % 2 == 1
+        assert kernel_sizes[1] % 2 == 1
+        assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
+        assert (groups is None or len(groups) == len(downsample_scales))
+        assert (strides is None or len(strides) == len(downsample_scales))
+        assert (paddings is None or len(paddings) == len(downsample_scales))
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            nn.Sequential(
+                getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
+                NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
+            )
+        )
+
+        in_chs = filters
+        for i, downsample_scale in enumerate(downsample_scales):
+            out_chs = min(in_chs * downsample_scale, max_filters)
+            default_kernel_size = downsample_scale * 10 + 1
+            default_stride = downsample_scale
+            default_padding = (default_kernel_size - 1) // 2
+            default_groups = in_chs // 4
+            self.convs.append(
+                NormConv1d(in_chs, out_chs,
+                           kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
+                           stride=strides[i] if strides else default_stride,
+                           groups=groups[i] if groups else default_groups,
+                           padding=paddings[i] if paddings else default_padding,
+                           norm=norm))
+            in_chs = out_chs
+
+        out_chs = min(in_chs * 2, max_filters)
+        self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
+                                     padding=(kernel_sizes[0] - 1) // 2, norm=norm))
+        self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
+                                    padding=(kernel_sizes[1] - 1) // 2, norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        for layer in self.convs:
+            x = layer(x)
+            x = self.activation(x)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        # x = torch.flatten(x, 1, -1)
+        return x, fmap
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x: torch.Tensor):
+    fmap = []
+    for layer in self.convs:
+        x = layer(x)
+        x = self.activation(x)
+        fmap.append(x)
+    x = self.conv_post(x)
+    fmap.append(x)
+    # x = torch.flatten(x, 1, -1)
+    return x, fmap
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/adversarial/discriminators/msstftd.html b/api_docs/audiocraft/adversarial/discriminators/msstftd.html new file mode 100644 index 00000000..0e396486 --- /dev/null +++ b/api_docs/audiocraft/adversarial/discriminators/msstftd.html @@ -0,0 +1,505 @@ + + + + + + +audiocraft.adversarial.discriminators.msstftd API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.adversarial.discriminators.msstftd

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import torchaudio
+import torch
+from torch import nn
+from einops import rearrange
+
+from ...modules import NormConv2d
+from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
+    return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
+
+
+class DiscriminatorSTFT(nn.Module):
+    """STFT sub-discriminator.
+
+    Args:
+        filters (int): Number of filters in convolutions.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        n_fft (int): Size of FFT for each scale.
+        hop_length (int): Length of hop between STFT windows for each scale.
+        kernel_size (tuple of int): Inner Conv2d kernel sizes.
+        stride (tuple of int): Inner Conv2d strides.
+        dilations (list of int): Inner Conv2d dilation on the time dimension.
+        win_length (int): Window size for each scale.
+        normalized (bool): Whether to normalize by magnitude after stft.
+        norm (str): Normalization method.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        growth (int): Growth factor for the filters.
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
+                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
+                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
+                 activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
+        super().__init__()
+        assert len(kernel_size) == 2
+        assert len(stride) == 2
+        self.filters = filters
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.normalized = normalized
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.spec_transform = torchaudio.transforms.Spectrogram(
+            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
+            normalized=self.normalized, center=False, pad_mode=None, power=None)
+        spec_channels = 2 * self.in_channels
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
+        )
+        in_chs = min(filters_scale * self.filters, max_filters)
+        for i, dilation in enumerate(dilations):
+            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
+                                         dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
+                                         norm=norm))
+            in_chs = out_chs
+        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
+        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
+                                     padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                     norm=norm))
+        self.conv_post = NormConv2d(out_chs, self.out_channels,
+                                    kernel_size=(kernel_size[0], kernel_size[0]),
+                                    padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                    norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
+        z = torch.cat([z.real, z.imag], dim=1)
+        z = rearrange(z, 'b c w t -> b c t w')
+        for i, layer in enumerate(self.convs):
+            z = layer(z)
+            z = self.activation(z)
+            fmap.append(z)
+        z = self.conv_post(z)
+        return z, fmap
+
+
+class MultiScaleSTFTDiscriminator(MultiDiscriminator):
+    """Multi-Scale STFT (MS-STFT) discriminator.
+
+    Args:
+        filters (int): Number of filters in convolutions.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        sep_channels (bool): Separate channels to distinct samples for stereo support.
+        n_ffts (Sequence[int]): Size of FFT for each scale.
+        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
+        win_lengths (Sequence[int]): Window size for each scale.
+        **kwargs: Additional args for STFTDiscriminator.
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
+                 n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
+                 win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
+        super().__init__()
+        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+        self.sep_channels = sep_channels
+        self.discriminators = nn.ModuleList([
+            DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
+                              n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
+            for i in range(len(n_ffts))
+        ])
+
+    @property
+    def num_discriminators(self):
+        return len(self.discriminators)
+
+    def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
+        B, C, T = x.shape
+        return x.view(-1, 1, T)
+
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
+
+
+
+
+
+
+
+

Functions

+
+
+def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)) +
+
+
+
+ +Expand source code + +
def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
+    return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
+
+
+
+
+
+

Classes

+
+
+class DiscriminatorSTFT +(filters: int, in_channels: int = 1, out_channels: int = 1, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, filters_scale: int = 1, kernel_size: Tuple[int, int] = (3, 9), dilations: List[~T] = [1, 2, 4], stride: Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm', activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}) +
+
+

STFT sub-discriminator.

+

Args

+
+
filters : int
+
Number of filters in convolutions.
+
in_channels : int
+
Number of input channels.
+
out_channels : int
+
Number of output channels.
+
n_fft : int
+
Size of FFT for each scale.
+
hop_length : int
+
Length of hop between STFT windows for each scale.
+
kernel_size : tuple of int
+
Inner Conv2d kernel sizes.
+
stride : tuple of int
+
Inner Conv2d strides.
+
dilations : list of int
+
Inner Conv2d dilation on the time dimension.
+
win_length : int
+
Window size for each scale.
+
normalized : bool
+
Whether to normalize by magnitude after stft.
+
norm : str
+
Normalization method.
+
activation : str
+
Activation function.
+
activation_params : dict
+
Parameters to provide to the activation function.
+
growth : int
+
Growth factor for the filters.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DiscriminatorSTFT(nn.Module):
+    """STFT sub-discriminator.
+
+    Args:
+        filters (int): Number of filters in convolutions.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        n_fft (int): Size of FFT for each scale.
+        hop_length (int): Length of hop between STFT windows for each scale.
+        kernel_size (tuple of int): Inner Conv2d kernel sizes.
+        stride (tuple of int): Inner Conv2d strides.
+        dilations (list of int): Inner Conv2d dilation on the time dimension.
+        win_length (int): Window size for each scale.
+        normalized (bool): Whether to normalize by magnitude after stft.
+        norm (str): Normalization method.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        growth (int): Growth factor for the filters.
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
+                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
+                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
+                 activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
+        super().__init__()
+        assert len(kernel_size) == 2
+        assert len(stride) == 2
+        self.filters = filters
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.normalized = normalized
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.spec_transform = torchaudio.transforms.Spectrogram(
+            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
+            normalized=self.normalized, center=False, pad_mode=None, power=None)
+        spec_channels = 2 * self.in_channels
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
+        )
+        in_chs = min(filters_scale * self.filters, max_filters)
+        for i, dilation in enumerate(dilations):
+            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
+                                         dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
+                                         norm=norm))
+            in_chs = out_chs
+        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
+        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
+                                     padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                     norm=norm))
+        self.conv_post = NormConv2d(out_chs, self.out_channels,
+                                    kernel_size=(kernel_size[0], kernel_size[0]),
+                                    padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                    norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
+        z = torch.cat([z.real, z.imag], dim=1)
+        z = rearrange(z, 'b c w t -> b c t w')
+        for i, layer in enumerate(self.convs):
+            z = layer(z)
+            z = self.activation(z)
+            fmap.append(z)
+        z = self.conv_post(z)
+        return z, fmap
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x: torch.Tensor):
+    fmap = []
+    z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
+    z = torch.cat([z.real, z.imag], dim=1)
+    z = rearrange(z, 'b c w t -> b c t w')
+    for i, layer in enumerate(self.convs):
+        z = layer(z)
+        z = self.activation(z)
+        fmap.append(z)
+    z = self.conv_post(z)
+    return z, fmap
+
+
+
+
+
+class MultiScaleSTFTDiscriminator +(filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False, n_ffts: List[int] = [1024, 2048, 512], hop_lengths: List[int] = [256, 512, 128], win_lengths: List[int] = [1024, 2048, 512], **kwargs) +
+
+

Multi-Scale STFT (MS-STFT) discriminator.

+

Args

+
+
filters : int
+
Number of filters in convolutions.
+
in_channels : int
+
Number of input channels.
+
out_channels : int
+
Number of output channels.
+
sep_channels : bool
+
Separate channels to distinct samples for stereo support.
+
n_ffts : Sequence[int]
+
Size of FFT for each scale.
+
hop_lengths : Sequence[int]
+
Length of hop between STFT windows for each scale.
+
win_lengths : Sequence[int]
+
Window size for each scale.
+
**kwargs
+
Additional args for STFTDiscriminator.
+
+
+ +Expand source code + +
class MultiScaleSTFTDiscriminator(MultiDiscriminator):
+    """Multi-Scale STFT (MS-STFT) discriminator.
+
+    Args:
+        filters (int): Number of filters in convolutions.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        sep_channels (bool): Separate channels to distinct samples for stereo support.
+        n_ffts (Sequence[int]): Size of FFT for each scale.
+        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
+        win_lengths (Sequence[int]): Window size for each scale.
+        **kwargs: Additional args for STFTDiscriminator.
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
+                 n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
+                 win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
+        super().__init__()
+        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+        self.sep_channels = sep_channels
+        self.discriminators = nn.ModuleList([
+            DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
+                              n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
+            for i in range(len(n_ffts))
+        ])
+
+    @property
+    def num_discriminators(self):
+        return len(self.discriminators)
+
+    def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
+        B, C, T = x.shape
+        return x.view(-1, 1, T)
+
+    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/adversarial/index.html b/api_docs/audiocraft/adversarial/index.html new file mode 100644 index 00000000..8d01cf1d --- /dev/null +++ b/api_docs/audiocraft/adversarial/index.html @@ -0,0 +1,98 @@ + + + + + + +audiocraft.adversarial API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.adversarial

+
+
+

Adversarial losses and discriminator architectures.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Adversarial losses and discriminator architectures."""
+
+# flake8: noqa
+from .discriminators import (
+    MultiPeriodDiscriminator,
+    MultiScaleDiscriminator,
+    MultiScaleSTFTDiscriminator
+)
+from .losses import (
+    AdversarialLoss,
+    AdvLossType,
+    get_adv_criterion,
+    get_fake_criterion,
+    get_real_criterion,
+    FeatLossType,
+    FeatureMatchingLoss
+)
+
+
+
+

Sub-modules

+
+
audiocraft.adversarial.discriminators
+
+
+
+
audiocraft.adversarial.losses
+
+

Utility module to handle adversarial losses without requiring to mess up the main training loop.

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/adversarial/losses.html b/api_docs/audiocraft/adversarial/losses.html new file mode 100644 index 00000000..7ad2d321 --- /dev/null +++ b/api_docs/audiocraft/adversarial/losses.html @@ -0,0 +1,855 @@ + + + + + + +audiocraft.adversarial.losses API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.adversarial.losses

+
+
+

Utility module to handle adversarial losses without requiring to mess up the main training loop.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility module to handle adversarial losses without requiring to mess up the main training loop.
+"""
+
+import typing as tp
+
+import flashy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
+
+
+AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
+FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
+
+
+class AdversarialLoss(nn.Module):
+    """Adversary training wrapper.
+
+    Args:
+        adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
+            We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
+            where the first item is a list of logits and the second item is a list of feature maps.
+        optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
+        loss (AdvLossType): Loss function for generator training.
+        loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
+        loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
+        loss_feat (FeatLossType): Feature matching loss function for generator training.
+        normalize (bool): Whether to normalize by number of sub-discriminators.
+
+    Example of usage:
+        adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
+        for real in loader:
+            noise = torch.randn(...)
+            fake = model(noise)
+            adv_loss.train_adv(fake, real)
+            loss, _ = adv_loss(fake, real)
+            loss.backward()
+    """
+    def __init__(self,
+                 adversary: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 loss: AdvLossType,
+                 loss_real: AdvLossType,
+                 loss_fake: AdvLossType,
+                 loss_feat: tp.Optional[FeatLossType] = None,
+                 normalize: bool = True):
+        super().__init__()
+        self.adversary: nn.Module = adversary
+        flashy.distrib.broadcast_model(self.adversary)
+        self.optimizer = optimizer
+        self.loss = loss
+        self.loss_real = loss_real
+        self.loss_fake = loss_fake
+        self.loss_feat = loss_feat
+        self.normalize = normalize
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        # Add the optimizer state dict inside our own.
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'optimizer'] = self.optimizer.state_dict()
+        return destination
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        # Load optimizer state.
+        self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+    def get_adversary_pred(self, x):
+        """Run adversary model, validating expected output format."""
+        logits, fmaps = self.adversary(x)
+        assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
+            f'Expecting a list of tensors as logits but {type(logits)} found.'
+        assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
+        for fmap in fmaps:
+            assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
+                f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
+        return logits, fmaps
+
+    def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
+        """Train the adversary with the given fake and real example.
+
+        We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
+        The first item being the logits and second item being a list of feature maps for each sub-discriminator.
+
+        This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
+        and call the optimizer.
+        """
+        loss = torch.tensor(0., device=fake.device)
+        all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
+        all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
+        n_sub_adversaries = len(all_logits_fake_is_fake)
+        for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
+            loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
+
+        if self.normalize:
+            loss /= n_sub_adversaries
+
+        self.optimizer.zero_grad()
+        with flashy.distrib.eager_sync_model(self.adversary):
+            loss.backward()
+        self.optimizer.step()
+
+        return loss
+
+    def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Return the loss for the generator, i.e. trying to fool the adversary,
+        and feature matching loss if provided.
+        """
+        adv = torch.tensor(0., device=fake.device)
+        feat = torch.tensor(0., device=fake.device)
+        with flashy.utils.readonly(self.adversary):
+            all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
+            all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
+            n_sub_adversaries = len(all_logits_fake_is_fake)
+            for logit_fake_is_fake in all_logits_fake_is_fake:
+                adv += self.loss(logit_fake_is_fake)
+            if self.loss_feat:
+                for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
+                    feat += self.loss_feat(fmap_fake, fmap_real)
+
+        if self.normalize:
+            adv /= n_sub_adversaries
+            feat /= n_sub_adversaries
+
+        return adv, feat
+
+
+def get_adv_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_loss
+    elif loss_type == 'hinge':
+        return hinge_loss
+    elif loss_type == 'hinge2':
+        return hinge2_loss
+    raise ValueError('Unsupported loss')
+
+
+def get_fake_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_fake_loss
+    elif loss_type in ['hinge', 'hinge2']:
+        return hinge_fake_loss
+    raise ValueError('Unsupported loss')
+
+
+def get_real_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_real_loss
+    elif loss_type in ['hinge', 'hinge2']:
+        return hinge_real_loss
+    raise ValueError('Unsupported loss')
+
+
+def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
+    return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
+    return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
+
+
+def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
+    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
+    return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+def mse_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0], device=x.device)
+    return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+def hinge_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0], device=x.device)
+    return -x.mean()
+
+
+def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0])
+    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+class FeatureMatchingLoss(nn.Module):
+    """Feature matching loss for adversarial training.
+
+    Args:
+        loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
+        normalize (bool): Whether to normalize the loss.
+            by number of feature maps.
+    """
+    def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
+        super().__init__()
+        self.loss = loss
+        self.normalize = normalize
+
+    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
+        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
+        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
+        feat_scale = torch.tensor(0., device=fmap_fake[0].device)
+        n_fmaps = 0
+        for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
+            assert feat_fake.shape == feat_real.shape
+            n_fmaps += 1
+            feat_loss += self.loss(feat_fake, feat_real)
+            feat_scale += torch.mean(torch.abs(feat_real))
+
+        if self.normalize:
+            feat_loss /= n_fmaps
+
+        return feat_loss
+
+
+
+
+
+
+
+

Functions

+
+
+def get_adv_criterion(loss_type: str) ‑> Callable +
+
+
+
+ +Expand source code + +
def get_adv_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_loss
+    elif loss_type == 'hinge':
+        return hinge_loss
+    elif loss_type == 'hinge2':
+        return hinge2_loss
+    raise ValueError('Unsupported loss')
+
+
+
+def get_fake_criterion(loss_type: str) ‑> Callable +
+
+
+
+ +Expand source code + +
def get_fake_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_fake_loss
+    elif loss_type in ['hinge', 'hinge2']:
+        return hinge_fake_loss
+    raise ValueError('Unsupported loss')
+
+
+
+def get_real_criterion(loss_type: str) ‑> Callable +
+
+
+
+ +Expand source code + +
def get_real_criterion(loss_type: str) -> tp.Callable:
+    assert loss_type in ADVERSARIAL_LOSSES
+    if loss_type == 'mse':
+        return mse_real_loss
+    elif loss_type in ['hinge', 'hinge2']:
+        return hinge_real_loss
+    raise ValueError('Unsupported loss')
+
+
+
+def hinge2_loss(x: torch.Tensor) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0])
+    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+
+def hinge_fake_loss(x: torch.Tensor) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
+    return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+
+def hinge_loss(x: torch.Tensor) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def hinge_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0], device=x.device)
+    return -x.mean()
+
+
+
+def hinge_real_loss(x: torch.Tensor) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
+    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+
+def mse_fake_loss(x: torch.Tensor) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
+    return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
+
+
+
+def mse_loss(x: torch.Tensor) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def mse_loss(x: torch.Tensor) -> torch.Tensor:
+    if x.numel() == 0:
+        return torch.tensor([0.0], device=x.device)
+    return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+
+def mse_real_loss(x: torch.Tensor) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
+    return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+
+
+
+

Classes

+
+
+class AdversarialLoss +(adversary: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, loss: Union[torch.nn.modules.module.Module, Callable[[torch.Tensor], torch.Tensor]], loss_real: Union[torch.nn.modules.module.Module, Callable[[torch.Tensor], torch.Tensor]], loss_fake: Union[torch.nn.modules.module.Module, Callable[[torch.Tensor], torch.Tensor]], loss_feat: Union[torch.nn.modules.module.Module, Callable[[torch.Tensor, torch.Tensor], torch.Tensor], None] = None, normalize: bool = True) +
+
+

Adversary training wrapper.

+

Args

+
+
adversary : nn.Module
+
The adversary module will be used to estimate the logits given the fake and real samples. +We assume here the adversary output is Tuple[List[torch.Tensor], List[List[torch.Tensor]]] +where the first item is a list of logits and the second item is a list of feature maps.
+
optimizer : torch.optim.Optimizer
+
Optimizer used for training the given module.
+
loss : AdvLossType
+
Loss function for generator training.
+
loss_real : AdvLossType
+
Loss function for adversarial training on logits from real samples.
+
loss_fake : AdvLossType
+
Loss function for adversarial training on logits from fake samples.
+
loss_feat : FeatLossType
+
Feature matching loss function for generator training.
+
normalize : bool
+
Whether to normalize by number of sub-discriminators.
+
+

Example of usage: +adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake) +for real in loader: +noise = torch.randn(…) +fake = model(noise) +adv_loss.train_adv(fake, real) +loss, _ = adv_loss(fake, real) +loss.backward()

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class AdversarialLoss(nn.Module):
+    """Adversary training wrapper.
+
+    Args:
+        adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
+            We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
+            where the first item is a list of logits and the second item is a list of feature maps.
+        optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
+        loss (AdvLossType): Loss function for generator training.
+        loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
+        loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
+        loss_feat (FeatLossType): Feature matching loss function for generator training.
+        normalize (bool): Whether to normalize by number of sub-discriminators.
+
+    Example of usage:
+        adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
+        for real in loader:
+            noise = torch.randn(...)
+            fake = model(noise)
+            adv_loss.train_adv(fake, real)
+            loss, _ = adv_loss(fake, real)
+            loss.backward()
+    """
+    def __init__(self,
+                 adversary: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 loss: AdvLossType,
+                 loss_real: AdvLossType,
+                 loss_fake: AdvLossType,
+                 loss_feat: tp.Optional[FeatLossType] = None,
+                 normalize: bool = True):
+        super().__init__()
+        self.adversary: nn.Module = adversary
+        flashy.distrib.broadcast_model(self.adversary)
+        self.optimizer = optimizer
+        self.loss = loss
+        self.loss_real = loss_real
+        self.loss_fake = loss_fake
+        self.loss_feat = loss_feat
+        self.normalize = normalize
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        # Add the optimizer state dict inside our own.
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'optimizer'] = self.optimizer.state_dict()
+        return destination
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        # Load optimizer state.
+        self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+    def get_adversary_pred(self, x):
+        """Run adversary model, validating expected output format."""
+        logits, fmaps = self.adversary(x)
+        assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
+            f'Expecting a list of tensors as logits but {type(logits)} found.'
+        assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
+        for fmap in fmaps:
+            assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
+                f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
+        return logits, fmaps
+
+    def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
+        """Train the adversary with the given fake and real example.
+
+        We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
+        The first item being the logits and second item being a list of feature maps for each sub-discriminator.
+
+        This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
+        and call the optimizer.
+        """
+        loss = torch.tensor(0., device=fake.device)
+        all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
+        all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
+        n_sub_adversaries = len(all_logits_fake_is_fake)
+        for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
+            loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
+
+        if self.normalize:
+            loss /= n_sub_adversaries
+
+        self.optimizer.zero_grad()
+        with flashy.distrib.eager_sync_model(self.adversary):
+            loss.backward()
+        self.optimizer.step()
+
+        return loss
+
+    def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Return the loss for the generator, i.e. trying to fool the adversary,
+        and feature matching loss if provided.
+        """
+        adv = torch.tensor(0., device=fake.device)
+        feat = torch.tensor(0., device=fake.device)
+        with flashy.utils.readonly(self.adversary):
+            all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
+            all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
+            n_sub_adversaries = len(all_logits_fake_is_fake)
+            for logit_fake_is_fake in all_logits_fake_is_fake:
+                adv += self.loss(logit_fake_is_fake)
+            if self.loss_feat:
+                for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
+                    feat += self.loss_feat(fmap_fake, fmap_real)
+
+        if self.normalize:
+            adv /= n_sub_adversaries
+            feat /= n_sub_adversaries
+
+        return adv, feat
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, fake: torch.Tensor, real: torch.Tensor) ‑> Tuple[torch.Tensor, torch.Tensor] +
+
+

Return the loss for the generator, i.e. trying to fool the adversary, +and feature matching loss if provided.

+
+ +Expand source code + +
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """Return the loss for the generator, i.e. trying to fool the adversary,
+    and feature matching loss if provided.
+    """
+    adv = torch.tensor(0., device=fake.device)
+    feat = torch.tensor(0., device=fake.device)
+    with flashy.utils.readonly(self.adversary):
+        all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
+        all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
+        n_sub_adversaries = len(all_logits_fake_is_fake)
+        for logit_fake_is_fake in all_logits_fake_is_fake:
+            adv += self.loss(logit_fake_is_fake)
+        if self.loss_feat:
+            for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
+                feat += self.loss_feat(fmap_fake, fmap_real)
+
+    if self.normalize:
+        adv /= n_sub_adversaries
+        feat /= n_sub_adversaries
+
+    return adv, feat
+
+
+
+def get_adversary_pred(self, x) +
+
+

Run adversary model, validating expected output format.

+
+ +Expand source code + +
def get_adversary_pred(self, x):
+    """Run adversary model, validating expected output format."""
+    logits, fmaps = self.adversary(x)
+    assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
+        f'Expecting a list of tensors as logits but {type(logits)} found.'
+    assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
+    for fmap in fmaps:
+        assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
+            f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
+    return logits, fmaps
+
+
+
+def train_adv(self, fake: torch.Tensor, real: torch.Tensor) ‑> torch.Tensor +
+
+

Train the adversary with the given fake and real example.

+

We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]]. +The first item being the logits and second item being a list of feature maps for each sub-discriminator.

+

This will automatically synchronize gradients (with flashy.distrib.eager_sync_model) +and call the optimizer.

+
+ +Expand source code + +
def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
+    """Train the adversary with the given fake and real example.
+
+    We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
+    The first item being the logits and second item being a list of feature maps for each sub-discriminator.
+
+    This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
+    and call the optimizer.
+    """
+    loss = torch.tensor(0., device=fake.device)
+    all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
+    all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
+    n_sub_adversaries = len(all_logits_fake_is_fake)
+    for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
+        loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
+
+    if self.normalize:
+        loss /= n_sub_adversaries
+
+    self.optimizer.zero_grad()
+    with flashy.distrib.eager_sync_model(self.adversary):
+        loss.backward()
+    self.optimizer.step()
+
+    return loss
+
+
+
+
+
+class FeatureMatchingLoss +(loss: torch.nn.modules.module.Module = L1Loss(), normalize: bool = True) +
+
+

Feature matching loss for adversarial training.

+

Args

+
+
loss : nn.Module
+
Loss to use for feature matching (default=torch.nn.L1).
+
normalize : bool
+
Whether to normalize the loss. +by number of feature maps.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class FeatureMatchingLoss(nn.Module):
+    """Feature matching loss for adversarial training.
+
+    Args:
+        loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
+        normalize (bool): Whether to normalize the loss.
+            by number of feature maps.
+    """
+    def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
+        super().__init__()
+        self.loss = loss
+        self.normalize = normalize
+
+    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
+        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
+        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
+        feat_scale = torch.tensor(0., device=fmap_fake[0].device)
+        n_fmaps = 0
+        for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
+            assert feat_fake.shape == feat_real.shape
+            n_fmaps += 1
+            feat_loss += self.loss(feat_fake, feat_real)
+            feat_scale += torch.mean(torch.abs(feat_real))
+
+        if self.normalize:
+            feat_loss /= n_fmaps
+
+        return feat_loss
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, fmap_fake: List[torch.Tensor], fmap_real: List[torch.Tensor]) ‑> torch.Tensor +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
+    assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
+    feat_loss = torch.tensor(0., device=fmap_fake[0].device)
+    feat_scale = torch.tensor(0., device=fmap_fake[0].device)
+    n_fmaps = 0
+    for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
+        assert feat_fake.shape == feat_real.shape
+        n_fmaps += 1
+        feat_loss += self.loss(feat_fake, feat_real)
+        feat_scale += torch.mean(torch.abs(feat_real))
+
+    if self.normalize:
+        feat_loss /= n_fmaps
+
+    return feat_loss
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/data/audio.html b/api_docs/audiocraft/data/audio.html new file mode 100644 index 00000000..3971c9f4 --- /dev/null +++ b/api_docs/audiocraft/data/audio.html @@ -0,0 +1,548 @@ + + + + + + +audiocraft.data.audio API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data.audio

+
+
+

Audio IO methods are defined in this module (info, read, write), +We rely on av library for faster read when possible, otherwise on torchaudio.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Audio IO methods are defined in this module (info, read, write),
+We rely on av library for faster read when possible, otherwise on torchaudio.
+"""
+
+from dataclasses import dataclass
+from pathlib import Path
+import logging
+import typing as tp
+
+import numpy as np
+import soundfile
+import torch
+from torch.nn import functional as F
+
+import av
+import subprocess as sp
+
+from .audio_utils import f32_pcm, normalize_audio
+
+
+_av_initialized = False
+
+
+def _init_av():
+    global _av_initialized
+    if _av_initialized:
+        return
+    logger = logging.getLogger('libav.mp3')
+    logger.setLevel(logging.ERROR)
+    _av_initialized = True
+
+
+@dataclass(frozen=True)
+class AudioFileInfo:
+    sample_rate: int
+    duration: float
+    channels: int
+
+
+def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+    _init_av()
+    with av.open(str(filepath)) as af:
+        stream = af.streams.audio[0]
+        sample_rate = stream.codec_context.sample_rate
+        duration = float(stream.duration * stream.time_base)
+        channels = stream.channels
+        return AudioFileInfo(sample_rate, duration, channels)
+
+
+def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+    info = soundfile.info(filepath)
+    return AudioFileInfo(info.samplerate, info.duration, info.channels)
+
+
+def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+    # torchaudio no longer returns useful duration informations for some formats like mp3s.
+    filepath = Path(filepath)
+    if filepath.suffix in ['.flac', '.ogg']:  # TODO: Validate .ogg can be safely read with av_info
+        # ffmpeg has some weird issue with flac.
+        return _soundfile_info(filepath)
+    else:
+        return _av_info(filepath)
+
+
+def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
+    """FFMPEG-based audio file reading using PyAV bindings.
+    Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
+
+    Args:
+        filepath (str or Path): Path to audio file to read.
+        seek_time (float): Time at which to start reading in the file.
+        duration (float): Duration to read from the file. If set to -1, the whole file is read.
+    Returns:
+        tuple of torch.Tensor, int: Tuple containing audio data and sample rate
+    """
+    _init_av()
+    with av.open(str(filepath)) as af:
+        stream = af.streams.audio[0]
+        sr = stream.codec_context.sample_rate
+        num_frames = int(sr * duration) if duration >= 0 else -1
+        frame_offset = int(sr * seek_time)
+        # we need a small negative offset otherwise we get some edge artifact
+        # from the mp3 decoder.
+        af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
+        frames = []
+        length = 0
+        for frame in af.decode(streams=stream.index):
+            current_offset = int(frame.rate * frame.pts * frame.time_base)
+            strip = max(0, frame_offset - current_offset)
+            buf = torch.from_numpy(frame.to_ndarray())
+            if buf.shape[0] != stream.channels:
+                buf = buf.view(-1, stream.channels).t()
+            buf = buf[:, strip:]
+            frames.append(buf)
+            length += buf.shape[1]
+            if num_frames > 0 and length >= num_frames:
+                break
+        assert frames
+        # If the above assert fails, it is likely because we seeked past the end of file point,
+        # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
+        # This will need proper debugging, in due time.
+        wav = torch.cat(frames, dim=1)
+        assert wav.shape[0] == stream.channels
+        if num_frames > 0:
+            wav = wav[:, :num_frames]
+        return f32_pcm(wav), sr
+
+
+def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
+               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
+    """Read audio by picking the most appropriate backend tool based on the audio format.
+
+    Args:
+        filepath (str or Path): Path to audio file to read.
+        seek_time (float): Time at which to start reading in the file.
+        duration (float): Duration to read from the file. If set to -1, the whole file is read.
+        pad (bool): Pad output audio if not reaching expected duration.
+    Returns:
+        tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
+    """
+    fp = Path(filepath)
+    if fp.suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
+        # There is some bug with ffmpeg and reading flac
+        info = _soundfile_info(filepath)
+        frames = -1 if duration <= 0 else int(duration * info.sample_rate)
+        frame_offset = int(seek_time * info.sample_rate)
+        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
+        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
+        wav = torch.from_numpy(wav).t().contiguous()
+        if len(wav.shape) == 1:
+            wav = torch.unsqueeze(wav, 0)
+    else:
+        wav, sr = _av_read(filepath, seek_time, duration)
+    if pad and duration > 0:
+        expected_frames = int(duration * sr)
+        wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
+    return wav, sr
+
+
+def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
+    # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
+    assert wav.dim() == 2, wav.shape
+    command = [
+        'ffmpeg',
+        '-loglevel', 'error',
+        '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
+        '-i', '-'] + flags + [str(out_path)]
+    input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
+    sp.run(command, input=input_, check=True)
+
+
+def audio_write(stem_name: tp.Union[str, Path],
+                wav: torch.Tensor, sample_rate: int,
+                format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
+                normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                loudness_compressor: bool = False,
+                log_clipping: bool = True, make_parent_dir: bool = True,
+                add_suffix: bool = True) -> Path:
+    """Convenience function for saving audio to disk. Returns the filename the audio was written to.
+
+    Args:
+        stem_name (str or Path): Filename without extension which will be added automatically.
+        wav (torch.Tensor): Audio data to save.
+        sample_rate (int): Sample rate of audio data.
+        format (str): Either "wav", "mp3", "ogg", or "flac".
+        mp3_rate (int): kbps when using mp3s.
+        ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
+        normalize (bool): if `True` (default), normalizes according to the prescribed
+            strategy (see after). If `False`, the strategy is only used in case clipping
+            would happen.
+        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+            with extra headroom to avoid clipping. 'clip' just clips.
+        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+            than the `peak_clip` one to avoid further clipping.
+        loudness_headroom_db (float): Target loudness for loudness normalization.
+        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
+         when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
+            occurs despite strategy (only for 'rms').
+        make_parent_dir (bool): Make parent directory if it doesn't exist.
+    Returns:
+        Path: Path of the saved audio.
+    """
+    assert wav.dtype.is_floating_point, "wav is not floating point"
+    if wav.dim() == 1:
+        wav = wav[None]
+    elif wav.dim() > 2:
+        raise ValueError("Input wav should be at most 2 dimension.")
+    assert wav.isfinite().all()
+    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
+                          rms_headroom_db, loudness_headroom_db, loudness_compressor,
+                          log_clipping=log_clipping, sample_rate=sample_rate,
+                          stem_name=str(stem_name))
+    if format == 'mp3':
+        suffix = '.mp3'
+        flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
+    elif format == 'wav':
+        suffix = '.wav'
+        flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
+    elif format == 'ogg':
+        suffix = '.ogg'
+        flags = ['-f', 'ogg', '-c:a', 'libvorbis']
+        if ogg_rate is not None:
+            flags += ['-b:a', f'{ogg_rate}k']
+    elif format == 'flac':
+        suffix = '.flac'
+        flags = ['-f', 'flac']
+    else:
+        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
+    if not add_suffix:
+        suffix = ''
+    path = Path(str(stem_name) + suffix)
+    if make_parent_dir:
+        path.parent.mkdir(exist_ok=True, parents=True)
+    try:
+        _piping_to_ffmpeg(path, wav, sample_rate, flags)
+    except Exception:
+        if path.exists():
+            # we do not want to leave half written files around.
+            path.unlink()
+        raise
+    return path
+
+
+
+
+
+
+
+

Functions

+
+
+def audio_info(filepath: Union[str, pathlib.Path]) ‑> AudioFileInfo +
+
+
+
+ +Expand source code + +
def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+    # torchaudio no longer returns useful duration informations for some formats like mp3s.
+    filepath = Path(filepath)
+    if filepath.suffix in ['.flac', '.ogg']:  # TODO: Validate .ogg can be safely read with av_info
+        # ffmpeg has some weird issue with flac.
+        return _soundfile_info(filepath)
+    else:
+        return _av_info(filepath)
+
+
+
+def audio_read(filepath: Union[str, pathlib.Path], seek_time: float = 0.0, duration: float = -1.0, pad: bool = False) ‑> Tuple[torch.Tensor, int] +
+
+

Read audio by picking the most appropriate backend tool based on the audio format.

+

Args

+
+
filepath : str or Path
+
Path to audio file to read.
+
seek_time : float
+
Time at which to start reading in the file.
+
duration : float
+
Duration to read from the file. If set to -1, the whole file is read.
+
pad : bool
+
Pad output audio if not reaching expected duration.
+
+

Returns

+
+
tuple of torch.Tensor, int
+
Tuple containing audio data and sample rate.
+
+
+ +Expand source code + +
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
+               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
+    """Read audio by picking the most appropriate backend tool based on the audio format.
+
+    Args:
+        filepath (str or Path): Path to audio file to read.
+        seek_time (float): Time at which to start reading in the file.
+        duration (float): Duration to read from the file. If set to -1, the whole file is read.
+        pad (bool): Pad output audio if not reaching expected duration.
+    Returns:
+        tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
+    """
+    fp = Path(filepath)
+    if fp.suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
+        # There is some bug with ffmpeg and reading flac
+        info = _soundfile_info(filepath)
+        frames = -1 if duration <= 0 else int(duration * info.sample_rate)
+        frame_offset = int(seek_time * info.sample_rate)
+        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
+        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
+        wav = torch.from_numpy(wav).t().contiguous()
+        if len(wav.shape) == 1:
+            wav = torch.unsqueeze(wav, 0)
+    else:
+        wav, sr = _av_read(filepath, seek_time, duration)
+    if pad and duration > 0:
+        expected_frames = int(duration * sr)
+        wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
+    return wav, sr
+
+
+
+def audio_write(stem_name: Union[str, pathlib.Path], wav: torch.Tensor, sample_rate: int, format: str = 'wav', mp3_rate: int = 320, ogg_rate: Optional[int] = None, normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1, rms_headroom_db: float = 18, loudness_headroom_db: float = 14, loudness_compressor: bool = False, log_clipping: bool = True, make_parent_dir: bool = True, add_suffix: bool = True) ‑> pathlib.Path +
+
+

Convenience function for saving audio to disk. Returns the filename the audio was written to.

+

Args

+
+
stem_name : str or Path
+
Filename without extension which will be added automatically.
+
wav : torch.Tensor
+
Audio data to save.
+
sample_rate : int
+
Sample rate of audio data.
+
format : str
+
Either "wav", "mp3", "ogg", or "flac".
+
mp3_rate : int
+
kbps when using mp3s.
+
ogg_rate : int
+
kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
+
normalize : bool
+
if True (default), normalizes according to the prescribed +strategy (see after). If False, the strategy is only used in case clipping +would happen.
+
strategy : str
+
Can be either 'clip', 'peak', or 'rms'. Default is 'peak', +i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square +with extra headroom to avoid clipping. 'clip' just clips.
+
peak_clip_headroom_db : float
+
Headroom in dB when doing 'peak' or 'clip' strategy.
+
rms_headroom_db : float
+
Headroom in dB when doing 'rms' strategy. This must be much larger +than the peak_clip one to avoid further clipping.
+
loudness_headroom_db : float
+
Target loudness for loudness normalization.
+
loudness_compressor : bool
+
Uses tanh for soft clipping when strategy is 'loudness'.
+
when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
+
occurs despite strategy (only for 'rms').
+
make_parent_dir : bool
+
Make parent directory if it doesn't exist.
+
+

Returns

+
+
Path
+
Path of the saved audio.
+
+
+ +Expand source code + +
def audio_write(stem_name: tp.Union[str, Path],
+                wav: torch.Tensor, sample_rate: int,
+                format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
+                normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                loudness_compressor: bool = False,
+                log_clipping: bool = True, make_parent_dir: bool = True,
+                add_suffix: bool = True) -> Path:
+    """Convenience function for saving audio to disk. Returns the filename the audio was written to.
+
+    Args:
+        stem_name (str or Path): Filename without extension which will be added automatically.
+        wav (torch.Tensor): Audio data to save.
+        sample_rate (int): Sample rate of audio data.
+        format (str): Either "wav", "mp3", "ogg", or "flac".
+        mp3_rate (int): kbps when using mp3s.
+        ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
+        normalize (bool): if `True` (default), normalizes according to the prescribed
+            strategy (see after). If `False`, the strategy is only used in case clipping
+            would happen.
+        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+            with extra headroom to avoid clipping. 'clip' just clips.
+        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+            than the `peak_clip` one to avoid further clipping.
+        loudness_headroom_db (float): Target loudness for loudness normalization.
+        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
+         when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
+            occurs despite strategy (only for 'rms').
+        make_parent_dir (bool): Make parent directory if it doesn't exist.
+    Returns:
+        Path: Path of the saved audio.
+    """
+    assert wav.dtype.is_floating_point, "wav is not floating point"
+    if wav.dim() == 1:
+        wav = wav[None]
+    elif wav.dim() > 2:
+        raise ValueError("Input wav should be at most 2 dimension.")
+    assert wav.isfinite().all()
+    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
+                          rms_headroom_db, loudness_headroom_db, loudness_compressor,
+                          log_clipping=log_clipping, sample_rate=sample_rate,
+                          stem_name=str(stem_name))
+    if format == 'mp3':
+        suffix = '.mp3'
+        flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
+    elif format == 'wav':
+        suffix = '.wav'
+        flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
+    elif format == 'ogg':
+        suffix = '.ogg'
+        flags = ['-f', 'ogg', '-c:a', 'libvorbis']
+        if ogg_rate is not None:
+            flags += ['-b:a', f'{ogg_rate}k']
+    elif format == 'flac':
+        suffix = '.flac'
+        flags = ['-f', 'flac']
+    else:
+        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
+    if not add_suffix:
+        suffix = ''
+    path = Path(str(stem_name) + suffix)
+    if make_parent_dir:
+        path.parent.mkdir(exist_ok=True, parents=True)
+    try:
+        _piping_to_ffmpeg(path, wav, sample_rate, flags)
+    except Exception:
+        if path.exists():
+            # we do not want to leave half written files around.
+            path.unlink()
+        raise
+    return path
+
+
+
+
+
+

Classes

+
+
+class AudioFileInfo +(sample_rate: int, duration: float, channels: int) +
+
+

AudioFileInfo(sample_rate: int, duration: float, channels: int)

+
+ +Expand source code + +
class AudioFileInfo:
+    sample_rate: int
+    duration: float
+    channels: int
+
+

Class variables

+
+
var channels : int
+
+
+
+
var duration : float
+
+
+
+
var sample_rate : int
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/data/audio_dataset.html b/api_docs/audiocraft/data/audio_dataset.html new file mode 100644 index 00000000..ee907984 --- /dev/null +++ b/api_docs/audiocraft/data/audio_dataset.html @@ -0,0 +1,1715 @@ + + + + + + +audiocraft.data.audio_dataset API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data.audio_dataset

+
+
+

AudioDataset support. In order to handle a larger number of files +without having to scan again the folders, we precompute some metadata +(filename, sample rate, duration), and use that to efficiently sample audio segments.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""AudioDataset support. In order to handle a larger number of files
+without having to scan again the folders, we precompute some metadata
+(filename, sample rate, duration), and use that to efficiently sample audio segments.
+"""
+import argparse
+import copy
+from concurrent.futures import ThreadPoolExecutor, Future
+from dataclasses import dataclass, fields
+from contextlib import ExitStack
+from functools import lru_cache
+import gzip
+import json
+import logging
+import os
+from pathlib import Path
+import random
+import sys
+import typing as tp
+
+import torch
+import torch.nn.functional as F
+
+from .audio import audio_read, audio_info
+from .audio_utils import convert_audio
+from .zip import PathInZip
+
+try:
+    import dora
+except ImportError:
+    dora = None  # type: ignore
+
+
+@dataclass(order=True)
+class BaseInfo:
+
+    @classmethod
+    def _dict2fields(cls, dictionary: dict):
+        return {
+            field.name: dictionary[field.name]
+            for field in fields(cls) if field.name in dictionary
+        }
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        _dictionary = cls._dict2fields(dictionary)
+        return cls(**_dictionary)
+
+    def to_dict(self):
+        return {
+            field.name: self.__getattribute__(field.name)
+            for field in fields(self)
+            }
+
+
+@dataclass(order=True)
+class AudioMeta(BaseInfo):
+    path: str
+    duration: float
+    sample_rate: int
+    amplitude: tp.Optional[float] = None
+    weight: tp.Optional[float] = None
+    # info_path is used to load additional information about the audio file that is stored in zip files.
+    info_path: tp.Optional[PathInZip] = None
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        base = cls._dict2fields(dictionary)
+        if 'info_path' in base and base['info_path'] is not None:
+            base['info_path'] = PathInZip(base['info_path'])
+        return cls(**base)
+
+    def to_dict(self):
+        d = super().to_dict()
+        if d['info_path'] is not None:
+            d['info_path'] = str(d['info_path'])
+        return d
+
+
+@dataclass(order=True)
+class SegmentInfo(BaseInfo):
+    meta: AudioMeta
+    seek_time: float
+    # The following values are given once the audio is processed, e.g.
+    # at the target sample rate and target number of channels.
+    n_frames: int      # actual number of frames without padding
+    total_frames: int  # total number of frames, padding included
+    sample_rate: int   # actual sample rate
+    channels: int      # number of audio channels.
+
+
+DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
+
+logger = logging.getLogger(__name__)
+
+
+def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
+    """AudioMeta from a path to an audio file.
+
+    Args:
+        file_path (str): Resolved path of valid audio file.
+        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
+    Returns:
+        AudioMeta: Audio file path and its metadata.
+    """
+    info = audio_info(file_path)
+    amplitude: tp.Optional[float] = None
+    if not minimal:
+        wav, sr = audio_read(file_path)
+        amplitude = wav.abs().max().item()
+    return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
+
+
+def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
+    """If Dora is available as a dependency, try to resolve potential relative paths
+    in list of AudioMeta. This method is expected to be used when loading meta from file.
+
+    Args:
+        m (AudioMeta): Audio meta to resolve.
+        fast (bool): If True, uses a really fast check for determining if a file
+            is already absolute or not. Only valid on Linux/Mac.
+    Returns:
+        AudioMeta: Audio meta with resolved path.
+    """
+    def is_abs(m):
+        if fast:
+            return str(m)[0] == '/'
+        else:
+            os.path.isabs(str(m))
+
+    if not dora:
+        return m
+
+    if not is_abs(m.path):
+        m.path = dora.git_save.to_absolute_path(m.path)
+    if m.info_path is not None and not is_abs(m.info_path.zip_path):
+        m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
+    return m
+
+
+def find_audio_files(path: tp.Union[Path, str],
+                     exts: tp.List[str] = DEFAULT_EXTS,
+                     resolve: bool = True,
+                     minimal: bool = True,
+                     progress: bool = False,
+                     workers: int = 0) -> tp.List[AudioMeta]:
+    """Build a list of AudioMeta from a given path,
+    collecting relevant audio files and fetching meta info.
+
+    Args:
+        path (str or Path): Path to folder containing audio files.
+        exts (list of str): List of file extensions to consider for audio files.
+        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
+        progress (bool): Whether to log progress on audio files collection.
+        workers (int): number of parallel workers, if 0, use only the current thread.
+    Returns:
+        list of AudioMeta: List of audio file path and its metadata.
+    """
+    audio_files = []
+    futures: tp.List[Future] = []
+    pool: tp.Optional[ThreadPoolExecutor] = None
+    with ExitStack() as stack:
+        if workers > 0:
+            pool = ThreadPoolExecutor(workers)
+            stack.enter_context(pool)
+
+        if progress:
+            print("Finding audio files...")
+        for root, folders, files in os.walk(path, followlinks=True):
+            for file in files:
+                full_path = Path(root) / file
+                if full_path.suffix.lower() in exts:
+                    audio_files.append(full_path)
+                    if pool is not None:
+                        futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
+                    if progress:
+                        print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
+
+        if progress:
+            print("Getting audio metadata...")
+        meta: tp.List[AudioMeta] = []
+        for idx, file_path in enumerate(audio_files):
+            try:
+                if pool is None:
+                    m = _get_audio_meta(str(file_path), minimal)
+                else:
+                    m = futures[idx].result()
+                if resolve:
+                    m = _resolve_audio_meta(m)
+            except Exception as err:
+                print("Error with", str(file_path), err, file=sys.stderr)
+                continue
+            meta.append(m)
+            if progress:
+                print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
+    meta.sort()
+    return meta
+
+
+def load_audio_meta(path: tp.Union[str, Path],
+                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
+    """Load list of AudioMeta from an optionally compressed json file.
+
+    Args:
+        path (str or Path): Path to JSON file.
+        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
+        fast (bool): activates some tricks to make things faster.
+    Returns:
+        list of AudioMeta: List of audio file path and its total duration.
+    """
+    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+    with open_fn(path, 'rb') as fp:  # type: ignore
+        lines = fp.readlines()
+    meta = []
+    for line in lines:
+        d = json.loads(line)
+        m = AudioMeta.from_dict(d)
+        if resolve:
+            m = _resolve_audio_meta(m, fast=fast)
+        meta.append(m)
+    return meta
+
+
+def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
+    """Save the audio metadata to the file pointer as json.
+
+    Args:
+        path (str or Path): Path to JSON file.
+        metadata (list of BaseAudioMeta): List of audio meta to save.
+    """
+    Path(path).parent.mkdir(exist_ok=True, parents=True)
+    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+    with open_fn(path, 'wb') as fp:  # type: ignore
+        for m in meta:
+            json_str = json.dumps(m.to_dict()) + '\n'
+            json_bytes = json_str.encode('utf-8')
+            fp.write(json_bytes)
+
+
+class AudioDataset:
+    """Base audio dataset.
+
+    The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
+    and potentially additional information, by creating random segments from the list of audio
+    files referenced in the metadata and applying minimal data pre-processing such as resampling,
+    mixing of channels, padding, etc.
+
+    If no segment_duration value is provided, the AudioDataset will return the full wav for each
+    audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
+    duration, applying padding if required.
+
+    By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
+    allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
+    original audio meta.
+
+    Note that you can call `start_epoch(epoch)` in order to get
+    a deterministic "randomization" for `shuffle=True`.
+    For a given epoch and dataset index, this will always return the same extract.
+    You can get back some diversity by setting the `shuffle_seed` param.
+
+    Args:
+        meta (list of AudioMeta): List of audio files metadata.
+        segment_duration (float, optional): Optional segment duration of audio to load.
+            If not specified, the dataset will load the full audio segment from the file.
+        shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
+        sample_rate (int): Target sample rate of the loaded audio samples.
+        channels (int): Target number of channels of the loaded audio samples.
+        sample_on_duration (bool): Set to `True` to sample segments with probability
+            dependent on audio file duration. This is only used if `segment_duration` is provided.
+        sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
+            `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
+            of the file duration and file weight. This is only used if `segment_duration` is provided.
+        min_segment_ratio (float): Minimum segment ratio to use when the audio file
+            is shorter than the desired segment.
+        max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
+        return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
+        min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
+            audio shorter than this will be filtered out.
+        max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
+            audio longer than this will be filtered out.
+        shuffle_seed (int): can be used to further randomize
+        load_wav (bool): if False, skip loading the wav but returns a tensor of 0
+            with the expected segment_duration (which must be provided if load_wav is False).
+        permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
+            are False. Will ensure a permutation on files when going through the dataset.
+            In that case the epoch number must be provided in order for the model
+            to continue the permutation across epochs. In that case, it is assumed
+            that `num_samples = total_batch_size * num_updates_per_epoch`, with
+            `total_batch_size` the overall batch size accounting for all gpus.
+    """
+    def __init__(self,
+                 meta: tp.List[AudioMeta],
+                 segment_duration: tp.Optional[float] = None,
+                 shuffle: bool = True,
+                 num_samples: int = 10_000,
+                 sample_rate: int = 48_000,
+                 channels: int = 2,
+                 pad: bool = True,
+                 sample_on_duration: bool = True,
+                 sample_on_weight: bool = True,
+                 min_segment_ratio: float = 0.5,
+                 max_read_retry: int = 10,
+                 return_info: bool = False,
+                 min_audio_duration: tp.Optional[float] = None,
+                 max_audio_duration: tp.Optional[float] = None,
+                 shuffle_seed: int = 0,
+                 load_wav: bool = True,
+                 permutation_on_files: bool = False,
+                 ):
+        assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
+        assert segment_duration is None or segment_duration > 0
+        assert segment_duration is None or min_segment_ratio >= 0
+        self.segment_duration = segment_duration
+        self.min_segment_ratio = min_segment_ratio
+        self.max_audio_duration = max_audio_duration
+        self.min_audio_duration = min_audio_duration
+        if self.min_audio_duration is not None and self.max_audio_duration is not None:
+            assert self.min_audio_duration <= self.max_audio_duration
+        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
+        assert len(self.meta)  # Fail fast if all data has been filtered.
+        self.total_duration = sum(d.duration for d in self.meta)
+
+        if segment_duration is None:
+            num_samples = len(self.meta)
+        self.num_samples = num_samples
+        self.shuffle = shuffle
+        self.sample_rate = sample_rate
+        self.channels = channels
+        self.pad = pad
+        self.sample_on_weight = sample_on_weight
+        self.sample_on_duration = sample_on_duration
+        self.sampling_probabilities = self._get_sampling_probabilities()
+        self.max_read_retry = max_read_retry
+        self.return_info = return_info
+        self.shuffle_seed = shuffle_seed
+        self.current_epoch: tp.Optional[int] = None
+        self.load_wav = load_wav
+        if not load_wav:
+            assert segment_duration is not None
+        self.permutation_on_files = permutation_on_files
+        if permutation_on_files:
+            assert not self.sample_on_duration
+            assert not self.sample_on_weight
+            assert self.shuffle
+
+    def start_epoch(self, epoch: int):
+        self.current_epoch = epoch
+
+    def __len__(self):
+        return self.num_samples
+
+    def _get_sampling_probabilities(self, normalized: bool = True):
+        """Return the sampling probabilities for each file inside `self.meta`."""
+        scores: tp.List[float] = []
+        for file_meta in self.meta:
+            score = 1.
+            if self.sample_on_weight and file_meta.weight is not None:
+                score *= file_meta.weight
+            if self.sample_on_duration:
+                score *= file_meta.duration
+            scores.append(score)
+        probabilities = torch.tensor(scores)
+        if normalized:
+            probabilities /= probabilities.sum()
+        return probabilities
+
+    @staticmethod
+    @lru_cache(16)
+    def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
+        # Used to keep the most recent files permutation in memory implicitely.
+        # will work unless someone is using a lot of Datasets in parallel.
+        rng = torch.Generator()
+        rng.manual_seed(base_seed + permutation_index)
+        return torch.randperm(num_files, generator=rng)
+
+    def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
+        """Sample a given file from `self.meta`. Can be overridden in subclasses.
+        This is only called if `segment_duration` is not None.
+
+        You must use the provided random number generator `rng` for reproducibility.
+        You can further make use of the index accessed.
+        """
+        if self.permutation_on_files:
+            assert self.current_epoch is not None
+            total_index = self.current_epoch * len(self) + index
+            permutation_index = total_index // len(self.meta)
+            relative_index = total_index % len(self.meta)
+            permutation = AudioDataset._get_file_permutation(
+                len(self.meta), permutation_index, self.shuffle_seed)
+            file_index = permutation[relative_index]
+            return self.meta[file_index]
+
+        if not self.sample_on_weight and not self.sample_on_duration:
+            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
+        else:
+            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
+
+        return self.meta[file_index]
+
+    def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
+        # Override this method in subclass if needed.
+        if self.load_wav:
+            return audio_read(path, seek_time, duration, pad=False)
+        else:
+            assert self.segment_duration is not None
+            n_frames = int(self.sample_rate * self.segment_duration)
+            return torch.zeros(self.channels, n_frames), self.sample_rate
+
+    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
+        if self.segment_duration is None:
+            file_meta = self.meta[index]
+            out, sr = audio_read(file_meta.path)
+            out = convert_audio(out, sr, self.sample_rate, self.channels)
+            n_frames = out.shape[-1]
+            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
+                                       sample_rate=self.sample_rate, channels=out.shape[0])
+        else:
+            rng = torch.Generator()
+            if self.shuffle:
+                # We use index, plus extra randomness, either totally random if we don't know the epoch.
+                # otherwise we make use of the epoch number and optional shuffle_seed.
+                if self.current_epoch is None:
+                    rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
+                else:
+                    rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
+            else:
+                # We only use index
+                rng.manual_seed(index)
+
+            for retry in range(self.max_read_retry):
+                file_meta = self.sample_file(index, rng)
+                # We add some variance in the file position even if audio file is smaller than segment
+                # without ending up with empty segments
+                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
+                seek_time = torch.rand(1, generator=rng).item() * max_seek
+                try:
+                    out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
+                    out = convert_audio(out, sr, self.sample_rate, self.channels)
+                    n_frames = out.shape[-1]
+                    target_frames = int(self.segment_duration * self.sample_rate)
+                    if self.pad:
+                        out = F.pad(out, (0, target_frames - n_frames))
+                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
+                                               sample_rate=self.sample_rate, channels=out.shape[0])
+                except Exception as exc:
+                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
+                    if retry == self.max_read_retry - 1:
+                        raise
+                else:
+                    break
+
+        if self.return_info:
+            # Returns the wav and additional information on the wave segment
+            return out, segment_info
+        else:
+            return out
+
+    def collater(self, samples):
+        """The collater function has to be provided to the dataloader
+        if AudioDataset has return_info=True in order to properly collate
+        the samples of a batch.
+        """
+        if self.segment_duration is None and len(samples) > 1:
+            assert self.pad, "Must allow padding when batching examples of different durations."
+
+        # In this case the audio reaching the collater is of variable length as segment_duration=None.
+        to_pad = self.segment_duration is None and self.pad
+        if to_pad:
+            max_len = max([wav.shape[-1] for wav, _ in samples])
+
+            def _pad_wav(wav):
+                return F.pad(wav, (0, max_len - wav.shape[-1]))
+
+        if self.return_info:
+            if len(samples) > 0:
+                assert len(samples[0]) == 2
+                assert isinstance(samples[0][0], torch.Tensor)
+                assert isinstance(samples[0][1], SegmentInfo)
+
+            wavs = [wav for wav, _ in samples]
+            segment_infos = [copy.deepcopy(info) for _, info in samples]
+
+            if to_pad:
+                # Each wav could be of a different duration as they are not segmented.
+                for i in range(len(samples)):
+                    # Determines the total length of the signal with padding, so we update here as we pad.
+                    segment_infos[i].total_frames = max_len
+                    wavs[i] = _pad_wav(wavs[i])
+
+            wav = torch.stack(wavs)
+            return wav, segment_infos
+        else:
+            assert isinstance(samples[0], torch.Tensor)
+            if to_pad:
+                samples = [_pad_wav(s) for s in samples]
+            return torch.stack(samples)
+
+    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+        """Filters out audio files with audio durations that will not allow to sample examples from them."""
+        orig_len = len(meta)
+
+        # Filter data that is too short.
+        if self.min_audio_duration is not None:
+            meta = [m for m in meta if m.duration >= self.min_audio_duration]
+
+        # Filter data that is too long.
+        if self.max_audio_duration is not None:
+            meta = [m for m in meta if m.duration <= self.max_audio_duration]
+
+        filtered_len = len(meta)
+        removed_percentage = 100*(1-float(filtered_len)/orig_len)
+        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
+        if removed_percentage < 10:
+            logging.debug(msg)
+        else:
+            logging.warning(msg)
+        return meta
+
+    @classmethod
+    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
+        """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
+
+        Args:
+            root (str or Path): Path to root folder containing audio files.
+            kwargs: Additional keyword arguments for the AudioDataset.
+        """
+        root = Path(root)
+        if root.is_dir():
+            if (root / 'data.jsonl').exists():
+                root = root / 'data.jsonl'
+            elif (root / 'data.jsonl.gz').exists():
+                root = root / 'data.jsonl.gz'
+            else:
+                raise ValueError("Don't know where to read metadata from in the dir. "
+                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
+        meta = load_audio_meta(root)
+        return cls(meta, **kwargs)
+
+    @classmethod
+    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
+                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
+        """Instantiate AudioDataset from a path containing (possibly nested) audio files.
+
+        Args:
+            root (str or Path): Path to root folder containing audio files.
+            minimal_meta (bool): Whether to only load minimal metadata or not.
+            exts (list of str): Extensions for audio files.
+            kwargs: Additional keyword arguments for the AudioDataset.
+        """
+        root = Path(root)
+        if root.is_file():
+            meta = load_audio_meta(root, resolve=True)
+        else:
+            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
+        return cls(meta, **kwargs)
+
+
+def main():
+    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
+    parser = argparse.ArgumentParser(
+        prog='audio_dataset',
+        description='Generate .jsonl files by scanning a folder.')
+    parser.add_argument('root', help='Root folder with all the audio files')
+    parser.add_argument('output_meta_file',
+                        help='Output file to store the metadata, ')
+    parser.add_argument('--complete',
+                        action='store_false', dest='minimal', default=True,
+                        help='Retrieve all metadata, even the one that are expansive '
+                             'to compute (e.g. normalization).')
+    parser.add_argument('--resolve',
+                        action='store_true', default=False,
+                        help='Resolve the paths to be absolute and with no symlinks.')
+    parser.add_argument('--workers',
+                        default=10, type=int,
+                        help='Number of workers.')
+    args = parser.parse_args()
+    meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
+                            resolve=args.resolve, minimal=args.minimal, workers=args.workers)
+    save_audio_meta(args.output_meta_file, meta)
+
+
+if __name__ == '__main__':
+    main()
+
+
+
+
+
+
+
+

Functions

+
+
+def find_audio_files(path: Union[str, pathlib.Path], exts: List[str] = ['.wav', '.mp3', '.flac', '.ogg', '.m4a'], resolve: bool = True, minimal: bool = True, progress: bool = False, workers: int = 0) ‑> List[AudioMeta] +
+
+

Build a list of AudioMeta from a given path, +collecting relevant audio files and fetching meta info.

+

Args

+
+
path : str or Path
+
Path to folder containing audio files.
+
exts : list of str
+
List of file extensions to consider for audio files.
+
minimal : bool
+
Whether to only load the minimal set of metadata (takes longer if not).
+
progress : bool
+
Whether to log progress on audio files collection.
+
workers : int
+
number of parallel workers, if 0, use only the current thread.
+
+

Returns

+
+
list of AudioMeta
+
List of audio file path and its metadata.
+
+
+ +Expand source code + +
def find_audio_files(path: tp.Union[Path, str],
+                     exts: tp.List[str] = DEFAULT_EXTS,
+                     resolve: bool = True,
+                     minimal: bool = True,
+                     progress: bool = False,
+                     workers: int = 0) -> tp.List[AudioMeta]:
+    """Build a list of AudioMeta from a given path,
+    collecting relevant audio files and fetching meta info.
+
+    Args:
+        path (str or Path): Path to folder containing audio files.
+        exts (list of str): List of file extensions to consider for audio files.
+        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
+        progress (bool): Whether to log progress on audio files collection.
+        workers (int): number of parallel workers, if 0, use only the current thread.
+    Returns:
+        list of AudioMeta: List of audio file path and its metadata.
+    """
+    audio_files = []
+    futures: tp.List[Future] = []
+    pool: tp.Optional[ThreadPoolExecutor] = None
+    with ExitStack() as stack:
+        if workers > 0:
+            pool = ThreadPoolExecutor(workers)
+            stack.enter_context(pool)
+
+        if progress:
+            print("Finding audio files...")
+        for root, folders, files in os.walk(path, followlinks=True):
+            for file in files:
+                full_path = Path(root) / file
+                if full_path.suffix.lower() in exts:
+                    audio_files.append(full_path)
+                    if pool is not None:
+                        futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
+                    if progress:
+                        print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
+
+        if progress:
+            print("Getting audio metadata...")
+        meta: tp.List[AudioMeta] = []
+        for idx, file_path in enumerate(audio_files):
+            try:
+                if pool is None:
+                    m = _get_audio_meta(str(file_path), minimal)
+                else:
+                    m = futures[idx].result()
+                if resolve:
+                    m = _resolve_audio_meta(m)
+            except Exception as err:
+                print("Error with", str(file_path), err, file=sys.stderr)
+                continue
+            meta.append(m)
+            if progress:
+                print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
+    meta.sort()
+    return meta
+
+
+
+def load_audio_meta(path: Union[str, pathlib.Path], resolve: bool = True, fast: bool = True) ‑> List[AudioMeta] +
+
+

Load list of AudioMeta from an optionally compressed json file.

+

Args

+
+
path : str or Path
+
Path to JSON file.
+
resolve : bool
+
Whether to resolve the path from AudioMeta (default=True).
+
fast : bool
+
activates some tricks to make things faster.
+
+

Returns

+
+
list of AudioMeta
+
List of audio file path and its total duration.
+
+
+ +Expand source code + +
def load_audio_meta(path: tp.Union[str, Path],
+                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
+    """Load list of AudioMeta from an optionally compressed json file.
+
+    Args:
+        path (str or Path): Path to JSON file.
+        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
+        fast (bool): activates some tricks to make things faster.
+    Returns:
+        list of AudioMeta: List of audio file path and its total duration.
+    """
+    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+    with open_fn(path, 'rb') as fp:  # type: ignore
+        lines = fp.readlines()
+    meta = []
+    for line in lines:
+        d = json.loads(line)
+        m = AudioMeta.from_dict(d)
+        if resolve:
+            m = _resolve_audio_meta(m, fast=fast)
+        meta.append(m)
+    return meta
+
+
+
+def main() +
+
+
+
+ +Expand source code + +
def main():
+    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
+    parser = argparse.ArgumentParser(
+        prog='audio_dataset',
+        description='Generate .jsonl files by scanning a folder.')
+    parser.add_argument('root', help='Root folder with all the audio files')
+    parser.add_argument('output_meta_file',
+                        help='Output file to store the metadata, ')
+    parser.add_argument('--complete',
+                        action='store_false', dest='minimal', default=True,
+                        help='Retrieve all metadata, even the one that are expansive '
+                             'to compute (e.g. normalization).')
+    parser.add_argument('--resolve',
+                        action='store_true', default=False,
+                        help='Resolve the paths to be absolute and with no symlinks.')
+    parser.add_argument('--workers',
+                        default=10, type=int,
+                        help='Number of workers.')
+    args = parser.parse_args()
+    meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
+                            resolve=args.resolve, minimal=args.minimal, workers=args.workers)
+    save_audio_meta(args.output_meta_file, meta)
+
+
+
+def save_audio_meta(path: Union[str, pathlib.Path], meta: List[AudioMeta]) +
+
+

Save the audio metadata to the file pointer as json.

+

Args

+
+
path : str or Path
+
Path to JSON file.
+
metadata : list of BaseAudioMeta
+
List of audio meta to save.
+
+
+ +Expand source code + +
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
+    """Save the audio metadata to the file pointer as json.
+
+    Args:
+        path (str or Path): Path to JSON file.
+        metadata (list of BaseAudioMeta): List of audio meta to save.
+    """
+    Path(path).parent.mkdir(exist_ok=True, parents=True)
+    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+    with open_fn(path, 'wb') as fp:  # type: ignore
+        for m in meta:
+            json_str = json.dumps(m.to_dict()) + '\n'
+            json_bytes = json_str.encode('utf-8')
+            fp.write(json_bytes)
+
+
+
+
+
+

Classes

+
+
+class AudioDataset +(meta: List[AudioMeta], segment_duration: Optional[float] = None, shuffle: bool = True, num_samples: int = 10000, sample_rate: int = 48000, channels: int = 2, pad: bool = True, sample_on_duration: bool = True, sample_on_weight: bool = True, min_segment_ratio: float = 0.5, max_read_retry: int = 10, return_info: bool = False, min_audio_duration: Optional[float] = None, max_audio_duration: Optional[float] = None, shuffle_seed: int = 0, load_wav: bool = True, permutation_on_files: bool = False) +
+
+

Base audio dataset.

+

The dataset takes a list of AudioMeta and create a dataset composed of segments of audio +and potentially additional information, by creating random segments from the list of audio +files referenced in the metadata and applying minimal data pre-processing such as resampling, +mixing of channels, padding, etc.

+

If no segment_duration value is provided, the AudioDataset will return the full wav for each +audio file. Otherwise, it will randomly sample audio files and create a segment of the specified +duration, applying padding if required.

+

By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True +allows to return a tuple containing the torch Tensor and additional metadata on the segment and the +original audio meta.

+

Note that you can call start_epoch(epoch) in order to get +a deterministic "randomization" for shuffle=True. +For a given epoch and dataset index, this will always return the same extract. +You can get back some diversity by setting the shuffle_seed param.

+

Args

+
+
meta : list of AudioMeta
+
List of audio files metadata.
+
segment_duration : float, optional
+
Optional segment duration of audio to load. +If not specified, the dataset will load the full audio segment from the file.
+
shuffle : bool
+
Set to True to have the data reshuffled at every epoch.
+
sample_rate : int
+
Target sample rate of the loaded audio samples.
+
channels : int
+
Target number of channels of the loaded audio samples.
+
sample_on_duration : bool
+
Set to True to sample segments with probability +dependent on audio file duration. This is only used if segment_duration is provided.
+
sample_on_weight : bool
+
Set to True to sample segments using the weight entry of +AudioMeta. If sample_on_duration is also True, the actual weight will be the product +of the file duration and file weight. This is only used if segment_duration is provided.
+
min_segment_ratio : float
+
Minimum segment ratio to use when the audio file +is shorter than the desired segment.
+
max_read_retry : int
+
Maximum number of retries to sample an audio segment from the dataset.
+
return_info : bool
+
Whether to return the wav only or return wav along with segment info and metadata.
+
min_audio_duration : float, optional
+
Minimum audio file duration, in seconds, if provided +audio shorter than this will be filtered out.
+
max_audio_duration : float, optional
+
Maximal audio file duration in seconds, if provided +audio longer than this will be filtered out.
+
shuffle_seed : int
+
can be used to further randomize
+
load_wav : bool
+
if False, skip loading the wav but returns a tensor of 0 +with the expected segment_duration (which must be provided if load_wav is False).
+
permutation_on_files : bool
+
only if sample_on_weight and sample_on_duration +are False. Will ensure a permutation on files when going through the dataset. +In that case the epoch number must be provided in order for the model +to continue the permutation across epochs. In that case, it is assumed +that num_samples = total_batch_size * num_updates_per_epoch, with +total_batch_size the overall batch size accounting for all gpus.
+
+
+ +Expand source code + +
class AudioDataset:
+    """Base audio dataset.
+
+    The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
+    and potentially additional information, by creating random segments from the list of audio
+    files referenced in the metadata and applying minimal data pre-processing such as resampling,
+    mixing of channels, padding, etc.
+
+    If no segment_duration value is provided, the AudioDataset will return the full wav for each
+    audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
+    duration, applying padding if required.
+
+    By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
+    allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
+    original audio meta.
+
+    Note that you can call `start_epoch(epoch)` in order to get
+    a deterministic "randomization" for `shuffle=True`.
+    For a given epoch and dataset index, this will always return the same extract.
+    You can get back some diversity by setting the `shuffle_seed` param.
+
+    Args:
+        meta (list of AudioMeta): List of audio files metadata.
+        segment_duration (float, optional): Optional segment duration of audio to load.
+            If not specified, the dataset will load the full audio segment from the file.
+        shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
+        sample_rate (int): Target sample rate of the loaded audio samples.
+        channels (int): Target number of channels of the loaded audio samples.
+        sample_on_duration (bool): Set to `True` to sample segments with probability
+            dependent on audio file duration. This is only used if `segment_duration` is provided.
+        sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
+            `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
+            of the file duration and file weight. This is only used if `segment_duration` is provided.
+        min_segment_ratio (float): Minimum segment ratio to use when the audio file
+            is shorter than the desired segment.
+        max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
+        return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
+        min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
+            audio shorter than this will be filtered out.
+        max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
+            audio longer than this will be filtered out.
+        shuffle_seed (int): can be used to further randomize
+        load_wav (bool): if False, skip loading the wav but returns a tensor of 0
+            with the expected segment_duration (which must be provided if load_wav is False).
+        permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
+            are False. Will ensure a permutation on files when going through the dataset.
+            In that case the epoch number must be provided in order for the model
+            to continue the permutation across epochs. In that case, it is assumed
+            that `num_samples = total_batch_size * num_updates_per_epoch`, with
+            `total_batch_size` the overall batch size accounting for all gpus.
+    """
+    def __init__(self,
+                 meta: tp.List[AudioMeta],
+                 segment_duration: tp.Optional[float] = None,
+                 shuffle: bool = True,
+                 num_samples: int = 10_000,
+                 sample_rate: int = 48_000,
+                 channels: int = 2,
+                 pad: bool = True,
+                 sample_on_duration: bool = True,
+                 sample_on_weight: bool = True,
+                 min_segment_ratio: float = 0.5,
+                 max_read_retry: int = 10,
+                 return_info: bool = False,
+                 min_audio_duration: tp.Optional[float] = None,
+                 max_audio_duration: tp.Optional[float] = None,
+                 shuffle_seed: int = 0,
+                 load_wav: bool = True,
+                 permutation_on_files: bool = False,
+                 ):
+        assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
+        assert segment_duration is None or segment_duration > 0
+        assert segment_duration is None or min_segment_ratio >= 0
+        self.segment_duration = segment_duration
+        self.min_segment_ratio = min_segment_ratio
+        self.max_audio_duration = max_audio_duration
+        self.min_audio_duration = min_audio_duration
+        if self.min_audio_duration is not None and self.max_audio_duration is not None:
+            assert self.min_audio_duration <= self.max_audio_duration
+        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
+        assert len(self.meta)  # Fail fast if all data has been filtered.
+        self.total_duration = sum(d.duration for d in self.meta)
+
+        if segment_duration is None:
+            num_samples = len(self.meta)
+        self.num_samples = num_samples
+        self.shuffle = shuffle
+        self.sample_rate = sample_rate
+        self.channels = channels
+        self.pad = pad
+        self.sample_on_weight = sample_on_weight
+        self.sample_on_duration = sample_on_duration
+        self.sampling_probabilities = self._get_sampling_probabilities()
+        self.max_read_retry = max_read_retry
+        self.return_info = return_info
+        self.shuffle_seed = shuffle_seed
+        self.current_epoch: tp.Optional[int] = None
+        self.load_wav = load_wav
+        if not load_wav:
+            assert segment_duration is not None
+        self.permutation_on_files = permutation_on_files
+        if permutation_on_files:
+            assert not self.sample_on_duration
+            assert not self.sample_on_weight
+            assert self.shuffle
+
+    def start_epoch(self, epoch: int):
+        self.current_epoch = epoch
+
+    def __len__(self):
+        return self.num_samples
+
+    def _get_sampling_probabilities(self, normalized: bool = True):
+        """Return the sampling probabilities for each file inside `self.meta`."""
+        scores: tp.List[float] = []
+        for file_meta in self.meta:
+            score = 1.
+            if self.sample_on_weight and file_meta.weight is not None:
+                score *= file_meta.weight
+            if self.sample_on_duration:
+                score *= file_meta.duration
+            scores.append(score)
+        probabilities = torch.tensor(scores)
+        if normalized:
+            probabilities /= probabilities.sum()
+        return probabilities
+
+    @staticmethod
+    @lru_cache(16)
+    def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
+        # Used to keep the most recent files permutation in memory implicitely.
+        # will work unless someone is using a lot of Datasets in parallel.
+        rng = torch.Generator()
+        rng.manual_seed(base_seed + permutation_index)
+        return torch.randperm(num_files, generator=rng)
+
+    def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
+        """Sample a given file from `self.meta`. Can be overridden in subclasses.
+        This is only called if `segment_duration` is not None.
+
+        You must use the provided random number generator `rng` for reproducibility.
+        You can further make use of the index accessed.
+        """
+        if self.permutation_on_files:
+            assert self.current_epoch is not None
+            total_index = self.current_epoch * len(self) + index
+            permutation_index = total_index // len(self.meta)
+            relative_index = total_index % len(self.meta)
+            permutation = AudioDataset._get_file_permutation(
+                len(self.meta), permutation_index, self.shuffle_seed)
+            file_index = permutation[relative_index]
+            return self.meta[file_index]
+
+        if not self.sample_on_weight and not self.sample_on_duration:
+            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
+        else:
+            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
+
+        return self.meta[file_index]
+
+    def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
+        # Override this method in subclass if needed.
+        if self.load_wav:
+            return audio_read(path, seek_time, duration, pad=False)
+        else:
+            assert self.segment_duration is not None
+            n_frames = int(self.sample_rate * self.segment_duration)
+            return torch.zeros(self.channels, n_frames), self.sample_rate
+
+    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
+        if self.segment_duration is None:
+            file_meta = self.meta[index]
+            out, sr = audio_read(file_meta.path)
+            out = convert_audio(out, sr, self.sample_rate, self.channels)
+            n_frames = out.shape[-1]
+            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
+                                       sample_rate=self.sample_rate, channels=out.shape[0])
+        else:
+            rng = torch.Generator()
+            if self.shuffle:
+                # We use index, plus extra randomness, either totally random if we don't know the epoch.
+                # otherwise we make use of the epoch number and optional shuffle_seed.
+                if self.current_epoch is None:
+                    rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
+                else:
+                    rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
+            else:
+                # We only use index
+                rng.manual_seed(index)
+
+            for retry in range(self.max_read_retry):
+                file_meta = self.sample_file(index, rng)
+                # We add some variance in the file position even if audio file is smaller than segment
+                # without ending up with empty segments
+                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
+                seek_time = torch.rand(1, generator=rng).item() * max_seek
+                try:
+                    out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
+                    out = convert_audio(out, sr, self.sample_rate, self.channels)
+                    n_frames = out.shape[-1]
+                    target_frames = int(self.segment_duration * self.sample_rate)
+                    if self.pad:
+                        out = F.pad(out, (0, target_frames - n_frames))
+                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
+                                               sample_rate=self.sample_rate, channels=out.shape[0])
+                except Exception as exc:
+                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
+                    if retry == self.max_read_retry - 1:
+                        raise
+                else:
+                    break
+
+        if self.return_info:
+            # Returns the wav and additional information on the wave segment
+            return out, segment_info
+        else:
+            return out
+
+    def collater(self, samples):
+        """The collater function has to be provided to the dataloader
+        if AudioDataset has return_info=True in order to properly collate
+        the samples of a batch.
+        """
+        if self.segment_duration is None and len(samples) > 1:
+            assert self.pad, "Must allow padding when batching examples of different durations."
+
+        # In this case the audio reaching the collater is of variable length as segment_duration=None.
+        to_pad = self.segment_duration is None and self.pad
+        if to_pad:
+            max_len = max([wav.shape[-1] for wav, _ in samples])
+
+            def _pad_wav(wav):
+                return F.pad(wav, (0, max_len - wav.shape[-1]))
+
+        if self.return_info:
+            if len(samples) > 0:
+                assert len(samples[0]) == 2
+                assert isinstance(samples[0][0], torch.Tensor)
+                assert isinstance(samples[0][1], SegmentInfo)
+
+            wavs = [wav for wav, _ in samples]
+            segment_infos = [copy.deepcopy(info) for _, info in samples]
+
+            if to_pad:
+                # Each wav could be of a different duration as they are not segmented.
+                for i in range(len(samples)):
+                    # Determines the total length of the signal with padding, so we update here as we pad.
+                    segment_infos[i].total_frames = max_len
+                    wavs[i] = _pad_wav(wavs[i])
+
+            wav = torch.stack(wavs)
+            return wav, segment_infos
+        else:
+            assert isinstance(samples[0], torch.Tensor)
+            if to_pad:
+                samples = [_pad_wav(s) for s in samples]
+            return torch.stack(samples)
+
+    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+        """Filters out audio files with audio durations that will not allow to sample examples from them."""
+        orig_len = len(meta)
+
+        # Filter data that is too short.
+        if self.min_audio_duration is not None:
+            meta = [m for m in meta if m.duration >= self.min_audio_duration]
+
+        # Filter data that is too long.
+        if self.max_audio_duration is not None:
+            meta = [m for m in meta if m.duration <= self.max_audio_duration]
+
+        filtered_len = len(meta)
+        removed_percentage = 100*(1-float(filtered_len)/orig_len)
+        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
+        if removed_percentage < 10:
+            logging.debug(msg)
+        else:
+            logging.warning(msg)
+        return meta
+
+    @classmethod
+    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
+        """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
+
+        Args:
+            root (str or Path): Path to root folder containing audio files.
+            kwargs: Additional keyword arguments for the AudioDataset.
+        """
+        root = Path(root)
+        if root.is_dir():
+            if (root / 'data.jsonl').exists():
+                root = root / 'data.jsonl'
+            elif (root / 'data.jsonl.gz').exists():
+                root = root / 'data.jsonl.gz'
+            else:
+                raise ValueError("Don't know where to read metadata from in the dir. "
+                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
+        meta = load_audio_meta(root)
+        return cls(meta, **kwargs)
+
+    @classmethod
+    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
+                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
+        """Instantiate AudioDataset from a path containing (possibly nested) audio files.
+
+        Args:
+            root (str or Path): Path to root folder containing audio files.
+            minimal_meta (bool): Whether to only load minimal metadata or not.
+            exts (list of str): Extensions for audio files.
+            kwargs: Additional keyword arguments for the AudioDataset.
+        """
+        root = Path(root)
+        if root.is_file():
+            meta = load_audio_meta(root, resolve=True)
+        else:
+            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
+        return cls(meta, **kwargs)
+
+

Subclasses

+ +

Static methods

+
+
+def from_meta(root: Union[str, pathlib.Path], **kwargs) +
+
+

Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.

+

Args

+
+
root : str or Path
+
Path to root folder containing audio files.
+
kwargs
+
Additional keyword arguments for the AudioDataset.
+
+
+ +Expand source code + +
@classmethod
+def from_meta(cls, root: tp.Union[str, Path], **kwargs):
+    """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
+
+    Args:
+        root (str or Path): Path to root folder containing audio files.
+        kwargs: Additional keyword arguments for the AudioDataset.
+    """
+    root = Path(root)
+    if root.is_dir():
+        if (root / 'data.jsonl').exists():
+            root = root / 'data.jsonl'
+        elif (root / 'data.jsonl.gz').exists():
+            root = root / 'data.jsonl.gz'
+        else:
+            raise ValueError("Don't know where to read metadata from in the dir. "
+                             "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
+    meta = load_audio_meta(root)
+    return cls(meta, **kwargs)
+
+
+
+def from_path(root: Union[str, pathlib.Path], minimal_meta: bool = True, exts: List[str] = ['.wav', '.mp3', '.flac', '.ogg', '.m4a'], **kwargs) +
+
+

Instantiate AudioDataset from a path containing (possibly nested) audio files.

+

Args

+
+
root : str or Path
+
Path to root folder containing audio files.
+
minimal_meta : bool
+
Whether to only load minimal metadata or not.
+
exts : list of str
+
Extensions for audio files.
+
kwargs
+
Additional keyword arguments for the AudioDataset.
+
+
+ +Expand source code + +
@classmethod
+def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
+              exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
+    """Instantiate AudioDataset from a path containing (possibly nested) audio files.
+
+    Args:
+        root (str or Path): Path to root folder containing audio files.
+        minimal_meta (bool): Whether to only load minimal metadata or not.
+        exts (list of str): Extensions for audio files.
+        kwargs: Additional keyword arguments for the AudioDataset.
+    """
+    root = Path(root)
+    if root.is_file():
+        meta = load_audio_meta(root, resolve=True)
+    else:
+        meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
+    return cls(meta, **kwargs)
+
+
+
+

Methods

+
+
+def collater(self, samples) +
+
+

The collater function has to be provided to the dataloader +if AudioDataset has return_info=True in order to properly collate +the samples of a batch.

+
+ +Expand source code + +
def collater(self, samples):
+    """The collater function has to be provided to the dataloader
+    if AudioDataset has return_info=True in order to properly collate
+    the samples of a batch.
+    """
+    if self.segment_duration is None and len(samples) > 1:
+        assert self.pad, "Must allow padding when batching examples of different durations."
+
+    # In this case the audio reaching the collater is of variable length as segment_duration=None.
+    to_pad = self.segment_duration is None and self.pad
+    if to_pad:
+        max_len = max([wav.shape[-1] for wav, _ in samples])
+
+        def _pad_wav(wav):
+            return F.pad(wav, (0, max_len - wav.shape[-1]))
+
+    if self.return_info:
+        if len(samples) > 0:
+            assert len(samples[0]) == 2
+            assert isinstance(samples[0][0], torch.Tensor)
+            assert isinstance(samples[0][1], SegmentInfo)
+
+        wavs = [wav for wav, _ in samples]
+        segment_infos = [copy.deepcopy(info) for _, info in samples]
+
+        if to_pad:
+            # Each wav could be of a different duration as they are not segmented.
+            for i in range(len(samples)):
+                # Determines the total length of the signal with padding, so we update here as we pad.
+                segment_infos[i].total_frames = max_len
+                wavs[i] = _pad_wav(wavs[i])
+
+        wav = torch.stack(wavs)
+        return wav, segment_infos
+    else:
+        assert isinstance(samples[0], torch.Tensor)
+        if to_pad:
+            samples = [_pad_wav(s) for s in samples]
+        return torch.stack(samples)
+
+
+
+def sample_file(self, index: int, rng: torch._C.Generator) ‑> AudioMeta +
+
+

Sample a given file from self.meta. Can be overridden in subclasses. +This is only called if segment_duration is not None.

+

You must use the provided random number generator rng for reproducibility. +You can further make use of the index accessed.

+
+ +Expand source code + +
def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
+    """Sample a given file from `self.meta`. Can be overridden in subclasses.
+    This is only called if `segment_duration` is not None.
+
+    You must use the provided random number generator `rng` for reproducibility.
+    You can further make use of the index accessed.
+    """
+    if self.permutation_on_files:
+        assert self.current_epoch is not None
+        total_index = self.current_epoch * len(self) + index
+        permutation_index = total_index // len(self.meta)
+        relative_index = total_index % len(self.meta)
+        permutation = AudioDataset._get_file_permutation(
+            len(self.meta), permutation_index, self.shuffle_seed)
+        file_index = permutation[relative_index]
+        return self.meta[file_index]
+
+    if not self.sample_on_weight and not self.sample_on_duration:
+        file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
+    else:
+        file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
+
+    return self.meta[file_index]
+
+
+
+def start_epoch(self, epoch: int) +
+
+
+
+ +Expand source code + +
def start_epoch(self, epoch: int):
+    self.current_epoch = epoch
+
+
+
+
+
+class AudioMeta +(path: str, duration: float, sample_rate: int, amplitude: Optional[float] = None, weight: Optional[float] = None, info_path: Optional[PathInZip] = None) +
+
+

AudioMeta(path: str, duration: float, sample_rate: int, amplitude: Union[float, NoneType] = None, weight: Union[float, NoneType] = None, info_path: Union[audiocraft.data.zip.PathInZip, NoneType] = None)

+
+ +Expand source code + +
class AudioMeta(BaseInfo):
+    path: str
+    duration: float
+    sample_rate: int
+    amplitude: tp.Optional[float] = None
+    weight: tp.Optional[float] = None
+    # info_path is used to load additional information about the audio file that is stored in zip files.
+    info_path: tp.Optional[PathInZip] = None
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        base = cls._dict2fields(dictionary)
+        if 'info_path' in base and base['info_path'] is not None:
+            base['info_path'] = PathInZip(base['info_path'])
+        return cls(**base)
+
+    def to_dict(self):
+        d = super().to_dict()
+        if d['info_path'] is not None:
+            d['info_path'] = str(d['info_path'])
+        return d
+
+

Ancestors

+ +

Class variables

+
+
var amplitude : Optional[float]
+
+
+
+
var duration : float
+
+
+
+
var info_path : Optional[PathInZip]
+
+
+
+
var path : str
+
+
+
+
var sample_rate : int
+
+
+
+
var weight : Optional[float]
+
+
+
+
+

Static methods

+
+
+def from_dict(dictionary: dict) +
+
+
+
+ +Expand source code + +
@classmethod
+def from_dict(cls, dictionary: dict):
+    base = cls._dict2fields(dictionary)
+    if 'info_path' in base and base['info_path'] is not None:
+        base['info_path'] = PathInZip(base['info_path'])
+    return cls(**base)
+
+
+
+

Methods

+
+
+def to_dict(self) +
+
+
+
+ +Expand source code + +
def to_dict(self):
+    d = super().to_dict()
+    if d['info_path'] is not None:
+        d['info_path'] = str(d['info_path'])
+    return d
+
+
+
+
+
+class BaseInfo +
+
+

BaseInfo()

+
+ +Expand source code + +
class BaseInfo:
+
+    @classmethod
+    def _dict2fields(cls, dictionary: dict):
+        return {
+            field.name: dictionary[field.name]
+            for field in fields(cls) if field.name in dictionary
+        }
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        _dictionary = cls._dict2fields(dictionary)
+        return cls(**_dictionary)
+
+    def to_dict(self):
+        return {
+            field.name: self.__getattribute__(field.name)
+            for field in fields(self)
+            }
+
+

Subclasses

+ +

Static methods

+
+
+def from_dict(dictionary: dict) +
+
+
+
+ +Expand source code + +
@classmethod
+def from_dict(cls, dictionary: dict):
+    _dictionary = cls._dict2fields(dictionary)
+    return cls(**_dictionary)
+
+
+
+

Methods

+
+
+def to_dict(self) +
+
+
+
+ +Expand source code + +
def to_dict(self):
+    return {
+        field.name: self.__getattribute__(field.name)
+        for field in fields(self)
+        }
+
+
+
+
+
+class SegmentInfo +(meta: AudioMeta, seek_time: float, n_frames: int, total_frames: int, sample_rate: int, channels: int) +
+
+

SegmentInfo(meta: audiocraft.data.audio_dataset.AudioMeta, seek_time: float, n_frames: int, total_frames: int, sample_rate: int, channels: int)

+
+ +Expand source code + +
class SegmentInfo(BaseInfo):
+    meta: AudioMeta
+    seek_time: float
+    # The following values are given once the audio is processed, e.g.
+    # at the target sample rate and target number of channels.
+    n_frames: int      # actual number of frames without padding
+    total_frames: int  # total number of frames, padding included
+    sample_rate: int   # actual sample rate
+    channels: int      # number of audio channels.
+
+

Ancestors

+ +

Subclasses

+ +

Class variables

+
+
var channels : int
+
+
+
+
var metaAudioMeta
+
+
+
+
var n_frames : int
+
+
+
+
var sample_rate : int
+
+
+
+
var seek_time : float
+
+
+
+
var total_frames : int
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/data/audio_utils.html b/api_docs/audiocraft/data/audio_utils.html new file mode 100644 index 00000000..b48d6db8 --- /dev/null +++ b/api_docs/audiocraft/data/audio_utils.html @@ -0,0 +1,528 @@ + + + + + + +audiocraft.data.audio_utils API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data.audio_utils

+
+
+

Various utilities for audio convertion (pcm format, sample rate and channels), +and volume normalization.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Various utilities for audio convertion (pcm format, sample rate and channels),
+and volume normalization."""
+import sys
+import typing as tp
+
+import julius
+import torch
+import torchaudio
+
+
+def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
+    """Convert audio to the given number of channels.
+
+    Args:
+        wav (torch.Tensor): Audio wave of shape [B, C, T].
+        channels (int): Expected number of channels as output.
+    Returns:
+        torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
+    """
+    *shape, src_channels, length = wav.shape
+    if src_channels == channels:
+        pass
+    elif channels == 1:
+        # Case 1:
+        # The caller asked 1-channel audio, and the stream has multiple
+        # channels, downmix all channels.
+        wav = wav.mean(dim=-2, keepdim=True)
+    elif src_channels == 1:
+        # Case 2:
+        # The caller asked for multiple channels, but the input file has
+        # a single channel, replicate the audio over all channels.
+        wav = wav.expand(*shape, channels, length)
+    elif src_channels >= channels:
+        # Case 3:
+        # The caller asked for multiple channels, and the input file has
+        # more channels than requested. In that case return the first channels.
+        wav = wav[..., :channels, :]
+    else:
+        # Case 4: What is a reasonable choice here?
+        raise ValueError('The audio file has less channels than requested but is not mono.')
+    return wav
+
+
+def convert_audio(wav: torch.Tensor, from_rate: float,
+                  to_rate: float, to_channels: int) -> torch.Tensor:
+    """Convert audio to new sample rate and number of audio channels."""
+    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
+    wav = convert_audio_channels(wav, to_channels)
+    return wav
+
+
+def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
+                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
+    """Normalize an input signal to a user loudness in dB LKFS.
+    Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
+
+    Args:
+        wav (torch.Tensor): Input multichannel audio data.
+        sample_rate (int): Sample rate.
+        loudness_headroom_db (float): Target loudness of the output in dB LUFS.
+        loudness_compressor (bool): Uses tanh for soft clipping.
+        energy_floor (float): anything below that RMS level will not be rescaled.
+    Returns:
+        torch.Tensor: Loudness normalized output data.
+    """
+    energy = wav.pow(2).mean().sqrt().item()
+    if energy < energy_floor:
+        return wav
+    transform = torchaudio.transforms.Loudness(sample_rate)
+    input_loudness_db = transform(wav).item()
+    # calculate the gain needed to scale to the desired loudness level
+    delta_loudness = -loudness_headroom_db - input_loudness_db
+    gain = 10.0 ** (delta_loudness / 20.0)
+    output = gain * wav
+    if loudness_compressor:
+        output = torch.tanh(output)
+    assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
+    return output
+
+
+def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
+    """Utility function to clip the audio with logging if specified."""
+    max_scale = wav.abs().max()
+    if log_clipping and max_scale > 1:
+        clamp_prob = (wav.abs() > 1).float().mean().item()
+        print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
+              clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
+    wav.clamp_(-1, 1)
+
+
+def normalize_audio(wav: torch.Tensor, normalize: bool = True,
+                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                    loudness_compressor: bool = False, log_clipping: bool = False,
+                    sample_rate: tp.Optional[int] = None,
+                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
+    """Normalize the audio according to the prescribed strategy (see after).
+
+    Args:
+        wav (torch.Tensor): Audio data.
+        normalize (bool): if `True` (default), normalizes according to the prescribed
+            strategy (see after). If `False`, the strategy is only used in case clipping
+            would happen.
+        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+            with extra headroom to avoid clipping. 'clip' just clips.
+        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+            than the `peak_clip` one to avoid further clipping.
+        loudness_headroom_db (float): Target loudness for loudness normalization.
+        loudness_compressor (bool): If True, uses tanh based soft clipping.
+        log_clipping (bool): If True, basic logging on stderr when clipping still
+            occurs despite strategy (only for 'rms').
+        sample_rate (int): Sample rate for the audio data (required for loudness).
+        stem_name (str, optional): Stem name for clipping logging.
+    Returns:
+        torch.Tensor: Normalized audio.
+    """
+    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
+    scale_rms = 10 ** (-rms_headroom_db / 20)
+    if strategy == 'peak':
+        rescaling = (scale_peak / wav.abs().max())
+        if normalize or rescaling < 1:
+            wav = wav * rescaling
+    elif strategy == 'clip':
+        wav = wav.clamp(-scale_peak, scale_peak)
+    elif strategy == 'rms':
+        mono = wav.mean(dim=0)
+        rescaling = scale_rms / mono.pow(2).mean().sqrt()
+        if normalize or rescaling < 1:
+            wav = wav * rescaling
+        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+    elif strategy == 'loudness':
+        assert sample_rate is not None, "Loudness normalization requires sample rate."
+        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
+        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+    else:
+        assert wav.abs().max() < 1
+        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
+    return wav
+
+
+def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
+    """Convert audio to float 32 bits PCM format.
+    """
+    if wav.dtype.is_floating_point:
+        return wav
+    elif wav.dtype == torch.int16:
+        return wav.float() / 2**15
+    elif wav.dtype == torch.int32:
+        return wav.float() / 2**31
+    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
+
+
+def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
+    """Convert audio to int 16 bits PCM format.
+
+    ..Warning:: There exist many formula for doing this conversion. None are perfect
+    due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
+    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
+    it is possible that `i16_pcm(f32_pcm)) != Identity`.
+    """
+    if wav.dtype.is_floating_point:
+        assert wav.abs().max() <= 1
+        candidate = (wav * 2 ** 15).round()
+        if candidate.max() >= 2 ** 15:  # clipping would occur
+            candidate = (wav * (2 ** 15 - 1)).round()
+        return candidate.short()
+    else:
+        assert wav.dtype == torch.int16
+        return wav
+
+
+
+
+
+
+
+

Functions

+
+
+def convert_audio(wav: torch.Tensor, from_rate: float, to_rate: float, to_channels: int) ‑> torch.Tensor +
+
+

Convert audio to new sample rate and number of audio channels.

+
+ +Expand source code + +
def convert_audio(wav: torch.Tensor, from_rate: float,
+                  to_rate: float, to_channels: int) -> torch.Tensor:
+    """Convert audio to new sample rate and number of audio channels."""
+    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
+    wav = convert_audio_channels(wav, to_channels)
+    return wav
+
+
+
+def convert_audio_channels(wav: torch.Tensor, channels: int = 2) ‑> torch.Tensor +
+
+

Convert audio to the given number of channels.

+

Args

+
+
wav : torch.Tensor
+
Audio wave of shape [B, C, T].
+
channels : int
+
Expected number of channels as output.
+
+

Returns

+
+
torch.Tensor
+
Downmixed or unchanged audio wave [B, C, T].
+
+
+ +Expand source code + +
def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
+    """Convert audio to the given number of channels.
+
+    Args:
+        wav (torch.Tensor): Audio wave of shape [B, C, T].
+        channels (int): Expected number of channels as output.
+    Returns:
+        torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
+    """
+    *shape, src_channels, length = wav.shape
+    if src_channels == channels:
+        pass
+    elif channels == 1:
+        # Case 1:
+        # The caller asked 1-channel audio, and the stream has multiple
+        # channels, downmix all channels.
+        wav = wav.mean(dim=-2, keepdim=True)
+    elif src_channels == 1:
+        # Case 2:
+        # The caller asked for multiple channels, but the input file has
+        # a single channel, replicate the audio over all channels.
+        wav = wav.expand(*shape, channels, length)
+    elif src_channels >= channels:
+        # Case 3:
+        # The caller asked for multiple channels, and the input file has
+        # more channels than requested. In that case return the first channels.
+        wav = wav[..., :channels, :]
+    else:
+        # Case 4: What is a reasonable choice here?
+        raise ValueError('The audio file has less channels than requested but is not mono.')
+    return wav
+
+
+
+def f32_pcm(wav: torch.Tensor) ‑> torch.Tensor +
+
+

Convert audio to float 32 bits PCM format.

+
+ +Expand source code + +
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
+    """Convert audio to float 32 bits PCM format.
+    """
+    if wav.dtype.is_floating_point:
+        return wav
+    elif wav.dtype == torch.int16:
+        return wav.float() / 2**15
+    elif wav.dtype == torch.int32:
+        return wav.float() / 2**31
+    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
+
+
+
+def i16_pcm(wav: torch.Tensor) ‑> torch.Tensor +
+
+

Convert audio to int 16 bits PCM format.

+
+

Warning: There exist many formula for doing this conversion. None are perfect

+
+

due to the asymmetry of the int16 range. One either have possible clipping, DC offset, +or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom, +it is possible that i16_pcm(f32_pcm)) != Identity.

+
+ +Expand source code + +
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
+    """Convert audio to int 16 bits PCM format.
+
+    ..Warning:: There exist many formula for doing this conversion. None are perfect
+    due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
+    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
+    it is possible that `i16_pcm(f32_pcm)) != Identity`.
+    """
+    if wav.dtype.is_floating_point:
+        assert wav.abs().max() <= 1
+        candidate = (wav * 2 ** 15).round()
+        if candidate.max() >= 2 ** 15:  # clipping would occur
+            candidate = (wav * (2 ** 15 - 1)).round()
+        return candidate.short()
+    else:
+        assert wav.dtype == torch.int16
+        return wav
+
+
+
+def normalize_audio(wav: torch.Tensor, normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1, rms_headroom_db: float = 18, loudness_headroom_db: float = 14, loudness_compressor: bool = False, log_clipping: bool = False, sample_rate: Optional[int] = None, stem_name: Optional[str] = None) ‑> torch.Tensor +
+
+

Normalize the audio according to the prescribed strategy (see after).

+

Args

+
+
wav : torch.Tensor
+
Audio data.
+
normalize : bool
+
if True (default), normalizes according to the prescribed +strategy (see after). If False, the strategy is only used in case clipping +would happen.
+
strategy : str
+
Can be either 'clip', 'peak', or 'rms'. Default is 'peak', +i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square +with extra headroom to avoid clipping. 'clip' just clips.
+
peak_clip_headroom_db : float
+
Headroom in dB when doing 'peak' or 'clip' strategy.
+
rms_headroom_db : float
+
Headroom in dB when doing 'rms' strategy. This must be much larger +than the peak_clip one to avoid further clipping.
+
loudness_headroom_db : float
+
Target loudness for loudness normalization.
+
loudness_compressor : bool
+
If True, uses tanh based soft clipping.
+
log_clipping : bool
+
If True, basic logging on stderr when clipping still +occurs despite strategy (only for 'rms').
+
sample_rate : int
+
Sample rate for the audio data (required for loudness).
+
stem_name : str, optional
+
Stem name for clipping logging.
+
+

Returns

+
+
torch.Tensor
+
Normalized audio.
+
+
+ +Expand source code + +
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
+                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                    loudness_compressor: bool = False, log_clipping: bool = False,
+                    sample_rate: tp.Optional[int] = None,
+                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
+    """Normalize the audio according to the prescribed strategy (see after).
+
+    Args:
+        wav (torch.Tensor): Audio data.
+        normalize (bool): if `True` (default), normalizes according to the prescribed
+            strategy (see after). If `False`, the strategy is only used in case clipping
+            would happen.
+        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+            with extra headroom to avoid clipping. 'clip' just clips.
+        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+            than the `peak_clip` one to avoid further clipping.
+        loudness_headroom_db (float): Target loudness for loudness normalization.
+        loudness_compressor (bool): If True, uses tanh based soft clipping.
+        log_clipping (bool): If True, basic logging on stderr when clipping still
+            occurs despite strategy (only for 'rms').
+        sample_rate (int): Sample rate for the audio data (required for loudness).
+        stem_name (str, optional): Stem name for clipping logging.
+    Returns:
+        torch.Tensor: Normalized audio.
+    """
+    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
+    scale_rms = 10 ** (-rms_headroom_db / 20)
+    if strategy == 'peak':
+        rescaling = (scale_peak / wav.abs().max())
+        if normalize or rescaling < 1:
+            wav = wav * rescaling
+    elif strategy == 'clip':
+        wav = wav.clamp(-scale_peak, scale_peak)
+    elif strategy == 'rms':
+        mono = wav.mean(dim=0)
+        rescaling = scale_rms / mono.pow(2).mean().sqrt()
+        if normalize or rescaling < 1:
+            wav = wav * rescaling
+        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+    elif strategy == 'loudness':
+        assert sample_rate is not None, "Loudness normalization requires sample rate."
+        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
+        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+    else:
+        assert wav.abs().max() < 1
+        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
+    return wav
+
+
+
+def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14, loudness_compressor: bool = False, energy_floor: float = 0.002) +
+
+

Normalize an input signal to a user loudness in dB LKFS. +Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.

+

Args

+
+
wav : torch.Tensor
+
Input multichannel audio data.
+
sample_rate : int
+
Sample rate.
+
loudness_headroom_db : float
+
Target loudness of the output in dB LUFS.
+
loudness_compressor : bool
+
Uses tanh for soft clipping.
+
energy_floor : float
+
anything below that RMS level will not be rescaled.
+
+

Returns

+
+
torch.Tensor
+
Loudness normalized output data.
+
+
+ +Expand source code + +
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
+                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
+    """Normalize an input signal to a user loudness in dB LKFS.
+    Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
+
+    Args:
+        wav (torch.Tensor): Input multichannel audio data.
+        sample_rate (int): Sample rate.
+        loudness_headroom_db (float): Target loudness of the output in dB LUFS.
+        loudness_compressor (bool): Uses tanh for soft clipping.
+        energy_floor (float): anything below that RMS level will not be rescaled.
+    Returns:
+        torch.Tensor: Loudness normalized output data.
+    """
+    energy = wav.pow(2).mean().sqrt().item()
+    if energy < energy_floor:
+        return wav
+    transform = torchaudio.transforms.Loudness(sample_rate)
+    input_loudness_db = transform(wav).item()
+    # calculate the gain needed to scale to the desired loudness level
+    delta_loudness = -loudness_headroom_db - input_loudness_db
+    gain = 10.0 ** (delta_loudness / 20.0)
+    output = gain * wav
+    if loudness_compressor:
+        output = torch.tanh(output)
+    assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
+    return output
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/data/index.html b/api_docs/audiocraft/data/index.html new file mode 100644 index 00000000..525d13b5 --- /dev/null +++ b/api_docs/audiocraft/data/index.html @@ -0,0 +1,118 @@ + + + + + + +audiocraft.data API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data

+
+
+

Audio loading and writing support. Datasets for raw audio +or also including some metadata.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Audio loading and writing support. Datasets for raw audio
+or also including some metadata."""
+
+# flake8: noqa
+from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
+
+
+
+

Sub-modules

+
+
audiocraft.data.audio
+
+

Audio IO methods are defined in this module (info, read, write), +We rely on av library for faster read when possible, otherwise on torchaudio.

+
+
audiocraft.data.audio_dataset
+
+

AudioDataset support. In order to handle a larger number of files +without having to scan again the folders, we precompute some metadata +(filename, …

+
+
audiocraft.data.audio_utils
+
+

Various utilities for audio convertion (pcm format, sample rate and channels), +and volume normalization.

+
+
audiocraft.data.info_audio_dataset
+
+

Base classes for the datasets that also provide non-audio metadata, +e.g. description, text transcription etc.

+
+
audiocraft.data.music_dataset
+
+

Dataset of music tracks with rich metadata.

+
+
audiocraft.data.sound_dataset
+
+

Dataset of audio with a simple description.

+
+
audiocraft.data.zip
+
+

Utility for reading some info from inside a zip file.

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/data/info_audio_dataset.html b/api_docs/audiocraft/data/info_audio_dataset.html new file mode 100644 index 00000000..c0b269b3 --- /dev/null +++ b/api_docs/audiocraft/data/info_audio_dataset.html @@ -0,0 +1,402 @@ + + + + + + +audiocraft.data.info_audio_dataset API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data.info_audio_dataset

+
+
+

Base classes for the datasets that also provide non-audio metadata, +e.g. description, text transcription etc.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Base classes for the datasets that also provide non-audio metadata,
+e.g. description, text transcription etc.
+"""
+from dataclasses import dataclass
+import logging
+import math
+import re
+import typing as tp
+
+import torch
+
+from .audio_dataset import AudioDataset, AudioMeta
+from ..environment import AudioCraftEnvironment
+from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
+
+
+logger = logging.getLogger(__name__)
+
+
+def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
+    """Monkey-patch meta to match cluster specificities."""
+    meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
+    if meta.info_path is not None:
+        meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
+    return meta
+
+
+def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+    """Monkey-patch all meta to match cluster specificities."""
+    return [_clusterify_meta(m) for m in meta]
+
+
+@dataclass
+class AudioInfo(SegmentWithAttributes):
+    """Dummy SegmentInfo with empty attributes.
+
+    The InfoAudioDataset is expected to return metadata that inherits
+    from SegmentWithAttributes class and can return conditioning attributes.
+
+    This basically guarantees all datasets will be compatible with current
+    solver that contain conditioners requiring this.
+    """
+    audio_tokens: tp.Optional[torch.Tensor] = None  # populated when using cached batch for training a LM.
+
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        return ConditioningAttributes()
+
+
+class InfoAudioDataset(AudioDataset):
+    """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
+
+    See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
+    """
+    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
+        super().__init__(clusterify_all_meta(meta), **kwargs)
+
+    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
+        if not self.return_info:
+            wav = super().__getitem__(index)
+            assert isinstance(wav, torch.Tensor)
+            return wav
+        wav, meta = super().__getitem__(index)
+        return wav, AudioInfo(**meta.to_dict())
+
+
+def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
+    """Preprocess a single keyword or possible a list of keywords."""
+    if isinstance(value, list):
+        return get_keyword_list(value)
+    else:
+        return get_keyword(value)
+
+
+def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
+    """Preprocess a single keyword."""
+    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+        return None
+    else:
+        return value.strip()
+
+
+def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
+    """Preprocess a single keyword."""
+    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+        return None
+    else:
+        return value.strip().lower()
+
+
+def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
+    """Preprocess a list of keywords."""
+    if isinstance(values, str):
+        values = [v.strip() for v in re.split(r'[,\s]', values)]
+    elif isinstance(values, float) and math.isnan(values):
+        values = []
+    if not isinstance(values, list):
+        logger.debug(f"Unexpected keyword list {values}")
+        values = [str(values)]
+
+    kws = [get_keyword(v) for v in values]
+    kw_list = [k for k in kws if k is not None]
+    if len(kw_list) == 0:
+        return None
+    else:
+        return kw_list
+
+
+
+
+
+
+
+

Functions

+
+
+def clusterify_all_meta(meta: List[AudioMeta]) ‑> List[AudioMeta] +
+
+

Monkey-patch all meta to match cluster specificities.

+
+ +Expand source code + +
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+    """Monkey-patch all meta to match cluster specificities."""
+    return [_clusterify_meta(m) for m in meta]
+
+
+
+def get_keyword(value: Optional[str]) ‑> Optional[str] +
+
+

Preprocess a single keyword.

+
+ +Expand source code + +
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
+    """Preprocess a single keyword."""
+    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+        return None
+    else:
+        return value.strip().lower()
+
+
+
+def get_keyword_list(values: Union[str, List[str]]) ‑> Optional[List[str]] +
+
+

Preprocess a list of keywords.

+
+ +Expand source code + +
def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
+    """Preprocess a list of keywords."""
+    if isinstance(values, str):
+        values = [v.strip() for v in re.split(r'[,\s]', values)]
+    elif isinstance(values, float) and math.isnan(values):
+        values = []
+    if not isinstance(values, list):
+        logger.debug(f"Unexpected keyword list {values}")
+        values = [str(values)]
+
+    kws = [get_keyword(v) for v in values]
+    kw_list = [k for k in kws if k is not None]
+    if len(kw_list) == 0:
+        return None
+    else:
+        return kw_list
+
+
+
+def get_keyword_or_keyword_list(value: Optional[str]) ‑> Union[str, None, List[str]] +
+
+

Preprocess a single keyword or possible a list of keywords.

+
+ +Expand source code + +
def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
+    """Preprocess a single keyword or possible a list of keywords."""
+    if isinstance(value, list):
+        return get_keyword_list(value)
+    else:
+        return get_keyword(value)
+
+
+
+def get_string(value: Optional[str]) ‑> Optional[str] +
+
+

Preprocess a single keyword.

+
+ +Expand source code + +
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
+    """Preprocess a single keyword."""
+    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+        return None
+    else:
+        return value.strip()
+
+
+
+
+
+

Classes

+
+
+class AudioInfo +(meta: AudioMeta, seek_time: float, n_frames: int, total_frames: int, sample_rate: int, channels: int, audio_tokens: Optional[torch.Tensor] = None) +
+
+

Dummy SegmentInfo with empty attributes.

+

The InfoAudioDataset is expected to return metadata that inherits +from SegmentWithAttributes class and can return conditioning attributes.

+

This basically guarantees all datasets will be compatible with current +solver that contain conditioners requiring this.

+
+ +Expand source code + +
class AudioInfo(SegmentWithAttributes):
+    """Dummy SegmentInfo with empty attributes.
+
+    The InfoAudioDataset is expected to return metadata that inherits
+    from SegmentWithAttributes class and can return conditioning attributes.
+
+    This basically guarantees all datasets will be compatible with current
+    solver that contain conditioners requiring this.
+    """
+    audio_tokens: tp.Optional[torch.Tensor] = None  # populated when using cached batch for training a LM.
+
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        return ConditioningAttributes()
+
+

Ancestors

+ +

Subclasses

+ +

Class variables

+
+
var audio_tokens : Optional[torch.Tensor]
+
+
+
+
+

Methods

+
+
+def to_condition_attributes(self) ‑> ConditioningAttributes +
+
+
+
+ +Expand source code + +
def to_condition_attributes(self) -> ConditioningAttributes:
+    return ConditioningAttributes()
+
+
+
+
+
+class InfoAudioDataset +(meta: List[AudioMeta], **kwargs) +
+
+

AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.

+

See AudioDataset for initialization arguments.

+
+ +Expand source code + +
class InfoAudioDataset(AudioDataset):
+    """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
+
+    See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
+    """
+    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
+        super().__init__(clusterify_all_meta(meta), **kwargs)
+
+    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
+        if not self.return_info:
+            wav = super().__getitem__(index)
+            assert isinstance(wav, torch.Tensor)
+            return wav
+        wav, meta = super().__getitem__(index)
+        return wav, AudioInfo(**meta.to_dict())
+
+

Ancestors

+ +

Subclasses

+ +

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/data/music_dataset.html b/api_docs/audiocraft/data/music_dataset.html new file mode 100644 index 00000000..7c767bba --- /dev/null +++ b/api_docs/audiocraft/data/music_dataset.html @@ -0,0 +1,913 @@ + + + + + + +audiocraft.data.music_dataset API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data.music_dataset

+
+
+

Dataset of music tracks with rich metadata.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Dataset of music tracks with rich metadata.
+"""
+from dataclasses import dataclass, field, fields, replace
+import gzip
+import json
+import logging
+from pathlib import Path
+import random
+import typing as tp
+
+import torch
+
+from .info_audio_dataset import (
+    InfoAudioDataset,
+    AudioInfo,
+    get_keyword_list,
+    get_keyword,
+    get_string
+)
+from ..modules.conditioners import (
+    ConditioningAttributes,
+    JointEmbedCondition,
+    WavCondition,
+)
+from ..utils.utils import warn_once
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MusicInfo(AudioInfo):
+    """Segment info augmented with music metadata.
+    """
+    # music-specific metadata
+    title: tp.Optional[str] = None
+    artist: tp.Optional[str] = None  # anonymized artist id, used to ensure no overlap between splits
+    key: tp.Optional[str] = None
+    bpm: tp.Optional[float] = None
+    genre: tp.Optional[str] = None
+    moods: tp.Optional[list] = None
+    keywords: tp.Optional[list] = None
+    description: tp.Optional[str] = None
+    name: tp.Optional[str] = None
+    instrument: tp.Optional[str] = None
+    # original wav accompanying the metadata
+    self_wav: tp.Optional[WavCondition] = None
+    # dict mapping attributes names to tuple of wav, text and metadata
+    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
+
+    @property
+    def has_music_meta(self) -> bool:
+        return self.name is not None
+
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        out = ConditioningAttributes()
+        for _field in fields(self):
+            key, value = _field.name, getattr(self, _field.name)
+            if key == 'self_wav':
+                out.wav[key] = value
+            elif key == 'joint_embed':
+                for embed_attribute, embed_cond in value.items():
+                    out.joint_embed[embed_attribute] = embed_cond
+            else:
+                if isinstance(value, list):
+                    value = ' '.join(value)
+                out.text[key] = value
+        return out
+
+    @staticmethod
+    def attribute_getter(attribute):
+        if attribute == 'bpm':
+            preprocess_func = get_bpm
+        elif attribute == 'key':
+            preprocess_func = get_musical_key
+        elif attribute in ['moods', 'keywords']:
+            preprocess_func = get_keyword_list
+        elif attribute in ['genre', 'name', 'instrument']:
+            preprocess_func = get_keyword
+        elif attribute in ['title', 'artist', 'description']:
+            preprocess_func = get_string
+        else:
+            preprocess_func = None
+        return preprocess_func
+
+    @classmethod
+    def from_dict(cls, dictionary: dict, fields_required: bool = False):
+        _dictionary: tp.Dict[str, tp.Any] = {}
+
+        # allow a subset of attributes to not be loaded from the dictionary
+        # these attributes may be populated later
+        post_init_attributes = ['self_wav', 'joint_embed']
+        optional_fields = ['keywords']
+
+        for _field in fields(cls):
+            if _field.name in post_init_attributes:
+                continue
+            elif _field.name not in dictionary:
+                if fields_required and _field.name not in optional_fields:
+                    raise KeyError(f"Unexpected missing key: {_field.name}")
+            else:
+                preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
+                value = dictionary[_field.name]
+                if preprocess_func:
+                    value = preprocess_func(value)
+                _dictionary[_field.name] = value
+        return cls(**_dictionary)
+
+
+def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
+                                   drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
+    """Augment MusicInfo description with additional metadata fields and potential dropout.
+    Additional textual attributes are added given probability 'merge_text_conditions_p' and
+    the original textual description is dropped from the augmented description given probability drop_desc_p.
+
+    Args:
+        music_info (MusicInfo): The music metadata to augment.
+        merge_text_p (float): Probability of merging additional metadata to the description.
+            If provided value is 0, then no merging is performed.
+        drop_desc_p (float): Probability of dropping the original description on text merge.
+            if provided value is 0, then no drop out is performed.
+        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
+    Returns:
+        MusicInfo: The MusicInfo with augmented textual description.
+    """
+    def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
+        valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
+        valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
+        keep_field = random.uniform(0, 1) < drop_other_p
+        return valid_field_name and valid_field_value and keep_field
+
+    def process_value(v: tp.Any) -> str:
+        if isinstance(v, (int, float, str)):
+            return str(v)
+        if isinstance(v, list):
+            return ", ".join(v)
+        else:
+            raise ValueError(f"Unknown type for text value! ({type(v), v})")
+
+    description = music_info.description
+
+    metadata_text = ""
+    if random.uniform(0, 1) < merge_text_p:
+        meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
+                      for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
+        random.shuffle(meta_pairs)
+        metadata_text = ". ".join(meta_pairs)
+        description = description if not random.uniform(0, 1) < drop_desc_p else None
+        logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
+
+    if description is None:
+        description = metadata_text if len(metadata_text) > 1 else None
+    else:
+        description = ". ".join([description.rstrip('.'), metadata_text])
+    description = description.strip() if description else None
+
+    music_info = replace(music_info)
+    music_info.description = description
+    return music_info
+
+
+class Paraphraser:
+    def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
+        self.paraphrase_p = paraphrase_p
+        open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
+        with open_fn(paraphrase_source, 'rb') as f:  # type: ignore
+            self.paraphrase_source = json.loads(f.read())
+        logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
+
+    def sample_paraphrase(self, audio_path: str, description: str):
+        if random.random() >= self.paraphrase_p:
+            return description
+        info_path = Path(audio_path).with_suffix('.json')
+        if info_path not in self.paraphrase_source:
+            warn_once(logger, f"{info_path} not in paraphrase source!")
+            return description
+        new_desc = random.choice(self.paraphrase_source[info_path])
+        logger.debug(f"{description} -> {new_desc}")
+        return new_desc
+
+
+class MusicDataset(InfoAudioDataset):
+    """Music dataset is an AudioDataset with music-related metadata.
+
+    Args:
+        info_fields_required (bool): Whether to enforce having required fields.
+        merge_text_p (float): Probability of merging additional metadata to the description.
+        drop_desc_p (float): Probability of dropping the original description on text merge.
+        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
+        joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
+        paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
+            paraphrases for the description. The json should be a dict with keys are the
+            original info path (e.g. track_path.json) and each value is a list of possible
+            paraphrased.
+        paraphrase_p (float): probability of taking a paraphrase.
+
+    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
+    """
+    def __init__(self, *args, info_fields_required: bool = True,
+                 merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
+                 joint_embed_attributes: tp.List[str] = [],
+                 paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
+                 **kwargs):
+        kwargs['return_info'] = True  # We require the info for each song of the dataset.
+        super().__init__(*args, **kwargs)
+        self.info_fields_required = info_fields_required
+        self.merge_text_p = merge_text_p
+        self.drop_desc_p = drop_desc_p
+        self.drop_other_p = drop_other_p
+        self.joint_embed_attributes = joint_embed_attributes
+        self.paraphraser = None
+        if paraphrase_source is not None:
+            self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
+
+    def __getitem__(self, index):
+        wav, info = super().__getitem__(index)
+        info_data = info.to_dict()
+        music_info_path = Path(info.meta.path).with_suffix('.json')
+
+        if Path(music_info_path).exists():
+            with open(music_info_path, 'r') as json_file:
+                music_data = json.load(json_file)
+                music_data.update(info_data)
+                music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
+            if self.paraphraser is not None:
+                music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
+            if self.merge_text_p:
+                music_info = augment_music_info_description(
+                    music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
+        else:
+            music_info = MusicInfo.from_dict(info_data, fields_required=False)
+
+        music_info.self_wav = WavCondition(
+            wav=wav[None], length=torch.tensor([info.n_frames]),
+            sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
+
+        for att in self.joint_embed_attributes:
+            att_value = getattr(music_info, att)
+            joint_embed_cond = JointEmbedCondition(
+                wav[None], [att_value], torch.tensor([info.n_frames]),
+                sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
+            music_info.joint_embed[att] = joint_embed_cond
+
+        return wav, music_info
+
+
+def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
+    """Preprocess key keywords, discarding them if there are multiple key defined."""
+    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+        return None
+    elif ',' in value:
+        # For now, we discard when multiple keys are defined separated with comas
+        return None
+    else:
+        return value.strip().lower()
+
+
+def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
+    """Preprocess to a float."""
+    if value is None:
+        return None
+    try:
+        return float(value)
+    except ValueError:
+        return None
+
+
+
+
+
+
+
+

Functions

+
+
+def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.0, drop_desc_p: float = 0.0, drop_other_p: float = 0.0) ‑> MusicInfo +
+
+

Augment MusicInfo description with additional metadata fields and potential dropout. +Additional textual attributes are added given probability 'merge_text_conditions_p' and +the original textual description is dropped from the augmented description given probability drop_desc_p.

+

Args

+
+
music_info : MusicInfo
+
The music metadata to augment.
+
merge_text_p : float
+
Probability of merging additional metadata to the description. +If provided value is 0, then no merging is performed.
+
drop_desc_p : float
+
Probability of dropping the original description on text merge. +if provided value is 0, then no drop out is performed.
+
drop_other_p : float
+
Probability of dropping the other fields used for text augmentation.
+
+

Returns

+
+
MusicInfo
+
The MusicInfo with augmented textual description.
+
+
+ +Expand source code + +
def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
+                                   drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
+    """Augment MusicInfo description with additional metadata fields and potential dropout.
+    Additional textual attributes are added given probability 'merge_text_conditions_p' and
+    the original textual description is dropped from the augmented description given probability drop_desc_p.
+
+    Args:
+        music_info (MusicInfo): The music metadata to augment.
+        merge_text_p (float): Probability of merging additional metadata to the description.
+            If provided value is 0, then no merging is performed.
+        drop_desc_p (float): Probability of dropping the original description on text merge.
+            if provided value is 0, then no drop out is performed.
+        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
+    Returns:
+        MusicInfo: The MusicInfo with augmented textual description.
+    """
+    def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
+        valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
+        valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
+        keep_field = random.uniform(0, 1) < drop_other_p
+        return valid_field_name and valid_field_value and keep_field
+
+    def process_value(v: tp.Any) -> str:
+        if isinstance(v, (int, float, str)):
+            return str(v)
+        if isinstance(v, list):
+            return ", ".join(v)
+        else:
+            raise ValueError(f"Unknown type for text value! ({type(v), v})")
+
+    description = music_info.description
+
+    metadata_text = ""
+    if random.uniform(0, 1) < merge_text_p:
+        meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
+                      for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
+        random.shuffle(meta_pairs)
+        metadata_text = ". ".join(meta_pairs)
+        description = description if not random.uniform(0, 1) < drop_desc_p else None
+        logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
+
+    if description is None:
+        description = metadata_text if len(metadata_text) > 1 else None
+    else:
+        description = ". ".join([description.rstrip('.'), metadata_text])
+    description = description.strip() if description else None
+
+    music_info = replace(music_info)
+    music_info.description = description
+    return music_info
+
+
+
+def get_bpm(value: Optional[str]) ‑> Optional[float] +
+
+

Preprocess to a float.

+
+ +Expand source code + +
def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
+    """Preprocess to a float."""
+    if value is None:
+        return None
+    try:
+        return float(value)
+    except ValueError:
+        return None
+
+
+
+def get_musical_key(value: Optional[str]) ‑> Optional[str] +
+
+

Preprocess key keywords, discarding them if there are multiple key defined.

+
+ +Expand source code + +
def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
+    """Preprocess key keywords, discarding them if there are multiple key defined."""
+    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+        return None
+    elif ',' in value:
+        # For now, we discard when multiple keys are defined separated with comas
+        return None
+    else:
+        return value.strip().lower()
+
+
+
+
+
+

Classes

+
+
+class MusicDataset +(*args, info_fields_required: bool = True, merge_text_p: float = 0.0, drop_desc_p: float = 0.0, drop_other_p: float = 0.0, joint_embed_attributes: List[str] = [], paraphrase_source: Optional[str] = None, paraphrase_p: float = 0, **kwargs) +
+
+

Music dataset is an AudioDataset with music-related metadata.

+

Args

+
+
info_fields_required : bool
+
Whether to enforce having required fields.
+
merge_text_p : float
+
Probability of merging additional metadata to the description.
+
drop_desc_p : float
+
Probability of dropping the original description on text merge.
+
drop_other_p : float
+
Probability of dropping the other fields used for text augmentation.
+
joint_embed_attributes : list[str]
+
A list of attributes for which joint embedding metadata is returned.
+
paraphrase_source : str, optional
+
Path to the .json or .json.gz file containing the +paraphrases for the description. The json should be a dict with keys are the +original info path (e.g. track_path.json) and each value is a list of possible +paraphrased.
+
paraphrase_p : float
+
probability of taking a paraphrase.
+
+

See InfoAudioDataset for full initialization arguments.

+
+ +Expand source code + +
class MusicDataset(InfoAudioDataset):
+    """Music dataset is an AudioDataset with music-related metadata.
+
+    Args:
+        info_fields_required (bool): Whether to enforce having required fields.
+        merge_text_p (float): Probability of merging additional metadata to the description.
+        drop_desc_p (float): Probability of dropping the original description on text merge.
+        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
+        joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
+        paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
+            paraphrases for the description. The json should be a dict with keys are the
+            original info path (e.g. track_path.json) and each value is a list of possible
+            paraphrased.
+        paraphrase_p (float): probability of taking a paraphrase.
+
+    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
+    """
+    def __init__(self, *args, info_fields_required: bool = True,
+                 merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
+                 joint_embed_attributes: tp.List[str] = [],
+                 paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
+                 **kwargs):
+        kwargs['return_info'] = True  # We require the info for each song of the dataset.
+        super().__init__(*args, **kwargs)
+        self.info_fields_required = info_fields_required
+        self.merge_text_p = merge_text_p
+        self.drop_desc_p = drop_desc_p
+        self.drop_other_p = drop_other_p
+        self.joint_embed_attributes = joint_embed_attributes
+        self.paraphraser = None
+        if paraphrase_source is not None:
+            self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
+
+    def __getitem__(self, index):
+        wav, info = super().__getitem__(index)
+        info_data = info.to_dict()
+        music_info_path = Path(info.meta.path).with_suffix('.json')
+
+        if Path(music_info_path).exists():
+            with open(music_info_path, 'r') as json_file:
+                music_data = json.load(json_file)
+                music_data.update(info_data)
+                music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
+            if self.paraphraser is not None:
+                music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description)
+            if self.merge_text_p:
+                music_info = augment_music_info_description(
+                    music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
+        else:
+            music_info = MusicInfo.from_dict(info_data, fields_required=False)
+
+        music_info.self_wav = WavCondition(
+            wav=wav[None], length=torch.tensor([info.n_frames]),
+            sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
+
+        for att in self.joint_embed_attributes:
+            att_value = getattr(music_info, att)
+            joint_embed_cond = JointEmbedCondition(
+                wav[None], [att_value], torch.tensor([info.n_frames]),
+                sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
+            music_info.joint_embed[att] = joint_embed_cond
+
+        return wav, music_info
+
+

Ancestors

+ +

Inherited members

+ +
+
+class MusicInfo +(meta: AudioMeta, seek_time: float, n_frames: int, total_frames: int, sample_rate: int, channels: int, audio_tokens: Optional[torch.Tensor] = None, title: Optional[str] = None, artist: Optional[str] = None, key: Optional[str] = None, bpm: Optional[float] = None, genre: Optional[str] = None, moods: Optional[list] = None, keywords: Optional[list] = None, description: Optional[str] = None, name: Optional[str] = None, instrument: Optional[str] = None, self_wav: Optional[WavCondition] = None, joint_embed: Dict[str, JointEmbedCondition] = <factory>) +
+
+

Segment info augmented with music metadata.

+
+ +Expand source code + +
class MusicInfo(AudioInfo):
+    """Segment info augmented with music metadata.
+    """
+    # music-specific metadata
+    title: tp.Optional[str] = None
+    artist: tp.Optional[str] = None  # anonymized artist id, used to ensure no overlap between splits
+    key: tp.Optional[str] = None
+    bpm: tp.Optional[float] = None
+    genre: tp.Optional[str] = None
+    moods: tp.Optional[list] = None
+    keywords: tp.Optional[list] = None
+    description: tp.Optional[str] = None
+    name: tp.Optional[str] = None
+    instrument: tp.Optional[str] = None
+    # original wav accompanying the metadata
+    self_wav: tp.Optional[WavCondition] = None
+    # dict mapping attributes names to tuple of wav, text and metadata
+    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
+
+    @property
+    def has_music_meta(self) -> bool:
+        return self.name is not None
+
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        out = ConditioningAttributes()
+        for _field in fields(self):
+            key, value = _field.name, getattr(self, _field.name)
+            if key == 'self_wav':
+                out.wav[key] = value
+            elif key == 'joint_embed':
+                for embed_attribute, embed_cond in value.items():
+                    out.joint_embed[embed_attribute] = embed_cond
+            else:
+                if isinstance(value, list):
+                    value = ' '.join(value)
+                out.text[key] = value
+        return out
+
+    @staticmethod
+    def attribute_getter(attribute):
+        if attribute == 'bpm':
+            preprocess_func = get_bpm
+        elif attribute == 'key':
+            preprocess_func = get_musical_key
+        elif attribute in ['moods', 'keywords']:
+            preprocess_func = get_keyword_list
+        elif attribute in ['genre', 'name', 'instrument']:
+            preprocess_func = get_keyword
+        elif attribute in ['title', 'artist', 'description']:
+            preprocess_func = get_string
+        else:
+            preprocess_func = None
+        return preprocess_func
+
+    @classmethod
+    def from_dict(cls, dictionary: dict, fields_required: bool = False):
+        _dictionary: tp.Dict[str, tp.Any] = {}
+
+        # allow a subset of attributes to not be loaded from the dictionary
+        # these attributes may be populated later
+        post_init_attributes = ['self_wav', 'joint_embed']
+        optional_fields = ['keywords']
+
+        for _field in fields(cls):
+            if _field.name in post_init_attributes:
+                continue
+            elif _field.name not in dictionary:
+                if fields_required and _field.name not in optional_fields:
+                    raise KeyError(f"Unexpected missing key: {_field.name}")
+            else:
+                preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
+                value = dictionary[_field.name]
+                if preprocess_func:
+                    value = preprocess_func(value)
+                _dictionary[_field.name] = value
+        return cls(**_dictionary)
+
+

Ancestors

+ +

Class variables

+
+
var artist : Optional[str]
+
+
+
+
var bpm : Optional[float]
+
+
+
+
var description : Optional[str]
+
+
+
+
var genre : Optional[str]
+
+
+
+
var instrument : Optional[str]
+
+
+
+
var joint_embed : Dict[str, JointEmbedCondition]
+
+
+
+
var key : Optional[str]
+
+
+
+
var keywords : Optional[list]
+
+
+
+
var moods : Optional[list]
+
+
+
+
var name : Optional[str]
+
+
+
+
var self_wav : Optional[WavCondition]
+
+
+
+
var title : Optional[str]
+
+
+
+
+

Static methods

+
+
+def attribute_getter(attribute) +
+
+
+
+ +Expand source code + +
@staticmethod
+def attribute_getter(attribute):
+    if attribute == 'bpm':
+        preprocess_func = get_bpm
+    elif attribute == 'key':
+        preprocess_func = get_musical_key
+    elif attribute in ['moods', 'keywords']:
+        preprocess_func = get_keyword_list
+    elif attribute in ['genre', 'name', 'instrument']:
+        preprocess_func = get_keyword
+    elif attribute in ['title', 'artist', 'description']:
+        preprocess_func = get_string
+    else:
+        preprocess_func = None
+    return preprocess_func
+
+
+
+def from_dict(dictionary: dict, fields_required: bool = False) +
+
+
+
+ +Expand source code + +
@classmethod
+def from_dict(cls, dictionary: dict, fields_required: bool = False):
+    _dictionary: tp.Dict[str, tp.Any] = {}
+
+    # allow a subset of attributes to not be loaded from the dictionary
+    # these attributes may be populated later
+    post_init_attributes = ['self_wav', 'joint_embed']
+    optional_fields = ['keywords']
+
+    for _field in fields(cls):
+        if _field.name in post_init_attributes:
+            continue
+        elif _field.name not in dictionary:
+            if fields_required and _field.name not in optional_fields:
+                raise KeyError(f"Unexpected missing key: {_field.name}")
+        else:
+            preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
+            value = dictionary[_field.name]
+            if preprocess_func:
+                value = preprocess_func(value)
+            _dictionary[_field.name] = value
+    return cls(**_dictionary)
+
+
+
+

Instance variables

+
+
var has_music_meta : bool
+
+
+
+ +Expand source code + +
@property
+def has_music_meta(self) -> bool:
+    return self.name is not None
+
+
+
+

Methods

+
+
+def to_condition_attributes(self) ‑> ConditioningAttributes +
+
+
+
+ +Expand source code + +
def to_condition_attributes(self) -> ConditioningAttributes:
+    out = ConditioningAttributes()
+    for _field in fields(self):
+        key, value = _field.name, getattr(self, _field.name)
+        if key == 'self_wav':
+            out.wav[key] = value
+        elif key == 'joint_embed':
+            for embed_attribute, embed_cond in value.items():
+                out.joint_embed[embed_attribute] = embed_cond
+        else:
+            if isinstance(value, list):
+                value = ' '.join(value)
+            out.text[key] = value
+    return out
+
+
+
+
+
+class Paraphraser +(paraphrase_source: Union[str, pathlib.Path], paraphrase_p: float = 0.0) +
+
+
+
+ +Expand source code + +
class Paraphraser:
+    def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
+        self.paraphrase_p = paraphrase_p
+        open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
+        with open_fn(paraphrase_source, 'rb') as f:  # type: ignore
+            self.paraphrase_source = json.loads(f.read())
+        logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
+
+    def sample_paraphrase(self, audio_path: str, description: str):
+        if random.random() >= self.paraphrase_p:
+            return description
+        info_path = Path(audio_path).with_suffix('.json')
+        if info_path not in self.paraphrase_source:
+            warn_once(logger, f"{info_path} not in paraphrase source!")
+            return description
+        new_desc = random.choice(self.paraphrase_source[info_path])
+        logger.debug(f"{description} -> {new_desc}")
+        return new_desc
+
+

Methods

+
+
+def sample_paraphrase(self, audio_path: str, description: str) +
+
+
+
+ +Expand source code + +
def sample_paraphrase(self, audio_path: str, description: str):
+    if random.random() >= self.paraphrase_p:
+        return description
+    info_path = Path(audio_path).with_suffix('.json')
+    if info_path not in self.paraphrase_source:
+        warn_once(logger, f"{info_path} not in paraphrase source!")
+        return description
+    new_desc = random.choice(self.paraphrase_source[info_path])
+    logger.debug(f"{description} -> {new_desc}")
+    return new_desc
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/data/sound_dataset.html b/api_docs/audiocraft/data/sound_dataset.html new file mode 100644 index 00000000..e89e9769 --- /dev/null +++ b/api_docs/audiocraft/data/sound_dataset.html @@ -0,0 +1,1005 @@ + + + + + + +audiocraft.data.sound_dataset API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data.sound_dataset

+
+
+

Dataset of audio with a simple description.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Dataset of audio with a simple description.
+"""
+
+from dataclasses import dataclass, fields, replace
+import json
+from pathlib import Path
+import random
+import typing as tp
+
+import numpy as np
+import torch
+
+from .info_audio_dataset import (
+    InfoAudioDataset,
+    get_keyword_or_keyword_list
+)
+from ..modules.conditioners import (
+    ConditioningAttributes,
+    SegmentWithAttributes,
+    WavCondition,
+)
+
+
+EPS = torch.finfo(torch.float32).eps
+TARGET_LEVEL_LOWER = -35
+TARGET_LEVEL_UPPER = -15
+
+
+@dataclass
+class SoundInfo(SegmentWithAttributes):
+    """Segment info augmented with Sound metadata.
+    """
+    description: tp.Optional[str] = None
+    self_wav: tp.Optional[torch.Tensor] = None
+
+    @property
+    def has_sound_meta(self) -> bool:
+        return self.description is not None
+
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        out = ConditioningAttributes()
+
+        for _field in fields(self):
+            key, value = _field.name, getattr(self, _field.name)
+            if key == 'self_wav':
+                out.wav[key] = value
+            else:
+                out.text[key] = value
+        return out
+
+    @staticmethod
+    def attribute_getter(attribute):
+        if attribute == 'description':
+            preprocess_func = get_keyword_or_keyword_list
+        else:
+            preprocess_func = None
+        return preprocess_func
+
+    @classmethod
+    def from_dict(cls, dictionary: dict, fields_required: bool = False):
+        _dictionary: tp.Dict[str, tp.Any] = {}
+
+        # allow a subset of attributes to not be loaded from the dictionary
+        # these attributes may be populated later
+        post_init_attributes = ['self_wav']
+
+        for _field in fields(cls):
+            if _field.name in post_init_attributes:
+                continue
+            elif _field.name not in dictionary:
+                if fields_required:
+                    raise KeyError(f"Unexpected missing key: {_field.name}")
+            else:
+                preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
+                value = dictionary[_field.name]
+                if preprocess_func:
+                    value = preprocess_func(value)
+                _dictionary[_field.name] = value
+        return cls(**_dictionary)
+
+
+class SoundDataset(InfoAudioDataset):
+    """Sound audio dataset: Audio dataset with environmental sound-specific metadata.
+
+    Args:
+        info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
+        external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
+            The metadata files contained in this folder are expected to match the stem of the audio file with
+            a json extension.
+        aug_p (float): Probability of performing audio mixing augmentation on the batch.
+        mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
+        mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
+        mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
+        mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
+        kwargs: Additional arguments for AudioDataset.
+
+    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
+    """
+    def __init__(
+        self,
+        *args,
+        info_fields_required: bool = True,
+        external_metadata_source: tp.Optional[str] = None,
+        aug_p: float = 0.,
+        mix_p: float = 0.,
+        mix_snr_low: int = -5,
+        mix_snr_high: int = 5,
+        mix_min_overlap: float = 0.5,
+        **kwargs
+    ):
+        kwargs['return_info'] = True  # We require the info for each song of the dataset.
+        super().__init__(*args, **kwargs)
+        self.info_fields_required = info_fields_required
+        self.external_metadata_source = external_metadata_source
+        self.aug_p = aug_p
+        self.mix_p = mix_p
+        if self.aug_p > 0:
+            assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
+            assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
+        self.mix_snr_low = mix_snr_low
+        self.mix_snr_high = mix_snr_high
+        self.mix_min_overlap = mix_min_overlap
+
+    def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
+        """Get path of JSON with metadata (description, etc.).
+        If there exists a JSON with the same name as 'path.name', then it will be used.
+        Else, such JSON will be searched for in an external json source folder if it exists.
+        """
+        info_path = Path(path).with_suffix('.json')
+        if Path(info_path).exists():
+            return info_path
+        elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
+            return Path(self.external_metadata_source) / info_path.name
+        else:
+            raise Exception(f"Unable to find a metadata JSON for path: {path}")
+
+    def __getitem__(self, index):
+        wav, info = super().__getitem__(index)
+        info_data = info.to_dict()
+        info_path = self._get_info_path(info.meta.path)
+        if Path(info_path).exists():
+            with open(info_path, 'r') as json_file:
+                sound_data = json.load(json_file)
+                sound_data.update(info_data)
+                sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
+                # if there are multiple descriptions, sample one randomly
+                if isinstance(sound_info.description, list):
+                    sound_info.description = random.choice(sound_info.description)
+        else:
+            sound_info = SoundInfo.from_dict(info_data, fields_required=False)
+
+        sound_info.self_wav = WavCondition(
+            wav=wav[None], length=torch.tensor([info.n_frames]),
+            sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
+
+        return wav, sound_info
+
+    def collater(self, samples):
+        # when training, audio mixing is performed in the collate function
+        wav, sound_info = super().collater(samples)  # SoundDataset always returns infos
+        if self.aug_p > 0:
+            wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
+                                          snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
+                                          min_overlap=self.mix_min_overlap)
+        return wav, sound_info
+
+
+def rms_f(x: torch.Tensor) -> torch.Tensor:
+    return (x ** 2).mean(1).pow(0.5)
+
+
+def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
+    """Normalize the signal to the target level."""
+    rms = rms_f(audio)
+    scalar = 10 ** (target_level / 20) / (rms + EPS)
+    audio = audio * scalar.unsqueeze(1)
+    return audio
+
+
+def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
+    return (abs(audio) > clipping_threshold).any(1)
+
+
+def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
+    start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
+    remainder = src.shape[1] - start
+    if dst.shape[1] > remainder:
+        src[:, start:] = src[:, start:] + dst[:, :remainder]
+    else:
+        src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
+    return src
+
+
+def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
+              target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
+    """Function to mix clean speech and noise at various SNR levels.
+
+    Args:
+        clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
+        noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
+        snr (int): SNR level when mixing.
+        min_overlap (float): Minimum overlap between the two mixed sources.
+        target_level (int): Gain level in dB.
+        clipping_threshold (float): Threshold for clipping the audio.
+    Returns:
+        torch.Tensor: The mixed audio, of shape [B, T].
+    """
+    if clean.shape[1] > noise.shape[1]:
+        noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
+    else:
+        noise = noise[:, :clean.shape[1]]
+
+    # normalizing to -25 dB FS
+    clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
+    clean = normalize(clean, target_level)
+    rmsclean = rms_f(clean)
+
+    noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
+    noise = normalize(noise, target_level)
+    rmsnoise = rms_f(noise)
+
+    # set the noise level for a given SNR
+    noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
+    noisenewlevel = noise * noisescalar
+
+    # mix noise and clean speech
+    noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
+
+    # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
+    # there is a chance of clipping that might happen with very less probability, which is not a major issue.
+    noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
+    rmsnoisy = rms_f(noisyspeech)
+    scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
+    noisyspeech = noisyspeech * scalarnoisy
+    clean = clean * scalarnoisy
+    noisenewlevel = noisenewlevel * scalarnoisy
+
+    # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
+    clipped = is_clipped(noisyspeech)
+    if clipped.any():
+        noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
+        noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
+
+    return noisyspeech
+
+
+def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
+    if snr_low == snr_high:
+        snr = snr_low
+    else:
+        snr = np.random.randint(snr_low, snr_high)
+    mix = snr_mixer(src, dst, snr, min_overlap)
+    return mix
+
+
+def mix_text(src_text: str, dst_text: str):
+    """Mix text from different sources by concatenating them."""
+    if src_text == dst_text:
+        return src_text
+    return src_text + " " + dst_text
+
+
+def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
+                snr_low: int, snr_high: int, min_overlap: float):
+    """Mix samples within a batch, summing the waveforms and concatenating the text infos.
+
+    Args:
+        wavs (torch.Tensor): Audio tensors of shape [B, C, T].
+        infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
+        aug_p (float): Augmentation probability.
+        mix_p (float): Proportion of items in the batch to mix (and merge) together.
+        snr_low (int): Lowerbound for sampling SNR.
+        snr_high (int): Upperbound for sampling SNR.
+        min_overlap (float): Minimum overlap between mixed samples.
+    Returns:
+        tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
+            and mixed SoundInfo for the given batch.
+    """
+    # no mixing to perform within the batch
+    if mix_p == 0:
+        return wavs, infos
+
+    if random.uniform(0, 1) < aug_p:
+        # perform all augmentations on waveforms as [B, T]
+        # randomly picking pairs of audio to mix
+        assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
+        wavs = wavs.mean(dim=1, keepdim=False)
+        B, T = wavs.shape
+        k = int(mix_p * B)
+        mixed_sources_idx = torch.randperm(B)[:k]
+        mixed_targets_idx = torch.randperm(B)[:k]
+        aug_wavs = snr_mix(
+            wavs[mixed_sources_idx],
+            wavs[mixed_targets_idx],
+            snr_low,
+            snr_high,
+            min_overlap,
+        )
+        # mixing textual descriptions in metadata
+        descriptions = [info.description for info in infos]
+        aug_infos = []
+        for i, j in zip(mixed_sources_idx, mixed_targets_idx):
+            text = mix_text(descriptions[i], descriptions[j])
+            m = replace(infos[i])
+            m.description = text
+            aug_infos.append(m)
+
+        # back to [B, C, T]
+        aug_wavs = aug_wavs.unsqueeze(1)
+        assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
+        assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
+        assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
+
+        return aug_wavs, aug_infos  # [B, C, T]
+    else:
+        # randomly pick samples in the batch to match
+        # the batch size when performing audio mixing
+        B, C, T = wavs.shape
+        k = int(mix_p * B)
+        wav_idx = torch.randperm(B)[:k]
+        wavs = wavs[wav_idx]
+        infos = [infos[i] for i in wav_idx]
+        assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
+
+        return wavs, infos  # [B, C, T]
+
+
+
+
+
+
+
+

Functions

+
+
+def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
+    return (abs(audio) > clipping_threshold).any(1)
+
+
+
+def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
+    start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
+    remainder = src.shape[1] - start
+    if dst.shape[1] > remainder:
+        src[:, start:] = src[:, start:] + dst[:, :remainder]
+    else:
+        src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
+    return src
+
+
+
+def mix_samples(wavs: torch.Tensor, infos: List[SoundInfo], aug_p: float, mix_p: float, snr_low: int, snr_high: int, min_overlap: float) +
+
+

Mix samples within a batch, summing the waveforms and concatenating the text infos.

+

Args

+
+
wavs : torch.Tensor
+
Audio tensors of shape [B, C, T].
+
infos : list[SoundInfo]
+
List of SoundInfo items corresponding to the audio.
+
aug_p : float
+
Augmentation probability.
+
mix_p : float
+
Proportion of items in the batch to mix (and merge) together.
+
snr_low : int
+
Lowerbound for sampling SNR.
+
snr_high : int
+
Upperbound for sampling SNR.
+
min_overlap : float
+
Minimum overlap between mixed samples.
+
+

Returns

+
+
tuple[torch.Tensor, list[SoundInfo]]
+
A tuple containing the mixed wavs +and mixed SoundInfo for the given batch.
+
+
+ +Expand source code + +
def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
+                snr_low: int, snr_high: int, min_overlap: float):
+    """Mix samples within a batch, summing the waveforms and concatenating the text infos.
+
+    Args:
+        wavs (torch.Tensor): Audio tensors of shape [B, C, T].
+        infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
+        aug_p (float): Augmentation probability.
+        mix_p (float): Proportion of items in the batch to mix (and merge) together.
+        snr_low (int): Lowerbound for sampling SNR.
+        snr_high (int): Upperbound for sampling SNR.
+        min_overlap (float): Minimum overlap between mixed samples.
+    Returns:
+        tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
+            and mixed SoundInfo for the given batch.
+    """
+    # no mixing to perform within the batch
+    if mix_p == 0:
+        return wavs, infos
+
+    if random.uniform(0, 1) < aug_p:
+        # perform all augmentations on waveforms as [B, T]
+        # randomly picking pairs of audio to mix
+        assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
+        wavs = wavs.mean(dim=1, keepdim=False)
+        B, T = wavs.shape
+        k = int(mix_p * B)
+        mixed_sources_idx = torch.randperm(B)[:k]
+        mixed_targets_idx = torch.randperm(B)[:k]
+        aug_wavs = snr_mix(
+            wavs[mixed_sources_idx],
+            wavs[mixed_targets_idx],
+            snr_low,
+            snr_high,
+            min_overlap,
+        )
+        # mixing textual descriptions in metadata
+        descriptions = [info.description for info in infos]
+        aug_infos = []
+        for i, j in zip(mixed_sources_idx, mixed_targets_idx):
+            text = mix_text(descriptions[i], descriptions[j])
+            m = replace(infos[i])
+            m.description = text
+            aug_infos.append(m)
+
+        # back to [B, C, T]
+        aug_wavs = aug_wavs.unsqueeze(1)
+        assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
+        assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
+        assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
+
+        return aug_wavs, aug_infos  # [B, C, T]
+    else:
+        # randomly pick samples in the batch to match
+        # the batch size when performing audio mixing
+        B, C, T = wavs.shape
+        k = int(mix_p * B)
+        wav_idx = torch.randperm(B)[:k]
+        wavs = wavs[wav_idx]
+        infos = [infos[i] for i in wav_idx]
+        assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
+
+        return wavs, infos  # [B, C, T]
+
+
+
+def mix_text(src_text: str, dst_text: str) +
+
+

Mix text from different sources by concatenating them.

+
+ +Expand source code + +
def mix_text(src_text: str, dst_text: str):
+    """Mix text from different sources by concatenating them."""
+    if src_text == dst_text:
+        return src_text
+    return src_text + " " + dst_text
+
+
+
+def normalize(audio: torch.Tensor, target_level: int = -25) ‑> torch.Tensor +
+
+

Normalize the signal to the target level.

+
+ +Expand source code + +
def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
+    """Normalize the signal to the target level."""
+    rms = rms_f(audio)
+    scalar = 10 ** (target_level / 20) / (rms + EPS)
+    audio = audio * scalar.unsqueeze(1)
+    return audio
+
+
+
+def rms_f(x: torch.Tensor) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def rms_f(x: torch.Tensor) -> torch.Tensor:
+    return (x ** 2).mean(1).pow(0.5)
+
+
+
+def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float) +
+
+
+
+ +Expand source code + +
def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
+    if snr_low == snr_high:
+        snr = snr_low
+    else:
+        snr = np.random.randint(snr_low, snr_high)
+    mix = snr_mixer(src, dst, snr, min_overlap)
+    return mix
+
+
+
+def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float, target_level: int = -25, clipping_threshold: float = 0.99) ‑> torch.Tensor +
+
+

Function to mix clean speech and noise at various SNR levels.

+

Args

+
+
clean : torch.Tensor
+
Clean audio source to mix, of shape [B, T].
+
noise : torch.Tensor
+
Noise audio source to mix, of shape [B, T].
+
snr : int
+
SNR level when mixing.
+
min_overlap : float
+
Minimum overlap between the two mixed sources.
+
target_level : int
+
Gain level in dB.
+
clipping_threshold : float
+
Threshold for clipping the audio.
+
+

Returns

+
+
torch.Tensor
+
The mixed audio, of shape [B, T].
+
+
+ +Expand source code + +
def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
+              target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
+    """Function to mix clean speech and noise at various SNR levels.
+
+    Args:
+        clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
+        noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
+        snr (int): SNR level when mixing.
+        min_overlap (float): Minimum overlap between the two mixed sources.
+        target_level (int): Gain level in dB.
+        clipping_threshold (float): Threshold for clipping the audio.
+    Returns:
+        torch.Tensor: The mixed audio, of shape [B, T].
+    """
+    if clean.shape[1] > noise.shape[1]:
+        noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
+    else:
+        noise = noise[:, :clean.shape[1]]
+
+    # normalizing to -25 dB FS
+    clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
+    clean = normalize(clean, target_level)
+    rmsclean = rms_f(clean)
+
+    noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
+    noise = normalize(noise, target_level)
+    rmsnoise = rms_f(noise)
+
+    # set the noise level for a given SNR
+    noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
+    noisenewlevel = noise * noisescalar
+
+    # mix noise and clean speech
+    noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
+
+    # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
+    # there is a chance of clipping that might happen with very less probability, which is not a major issue.
+    noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
+    rmsnoisy = rms_f(noisyspeech)
+    scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
+    noisyspeech = noisyspeech * scalarnoisy
+    clean = clean * scalarnoisy
+    noisenewlevel = noisenewlevel * scalarnoisy
+
+    # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
+    clipped = is_clipped(noisyspeech)
+    if clipped.any():
+        noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
+        noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
+
+    return noisyspeech
+
+
+
+
+
+

Classes

+
+
+class SoundDataset +(*args, info_fields_required: bool = True, external_metadata_source: Optional[str] = None, aug_p: float = 0.0, mix_p: float = 0.0, mix_snr_low: int = -5, mix_snr_high: int = 5, mix_min_overlap: float = 0.5, **kwargs) +
+
+

Sound audio dataset: Audio dataset with environmental sound-specific metadata.

+

Args

+
+
info_fields_required : bool
+
Whether all the mandatory metadata fields should be in the loaded metadata.
+
external_metadata_source : tp.Optional[str]
+
Folder containing JSON metadata for the corresponding dataset. +The metadata files contained in this folder are expected to match the stem of the audio file with +a json extension.
+
aug_p : float
+
Probability of performing audio mixing augmentation on the batch.
+
mix_p : float
+
Proportion of batch items that are mixed together when applying audio mixing augmentation.
+
mix_snr_low : int
+
Lowerbound for SNR value sampled for mixing augmentation.
+
mix_snr_high : int
+
Upperbound for SNR value sampled for mixing augmentation.
+
mix_min_overlap : float
+
Minimum overlap between audio files when performing mixing augmentation.
+
kwargs
+
Additional arguments for AudioDataset.
+
+

See InfoAudioDataset for full initialization arguments.

+
+ +Expand source code + +
class SoundDataset(InfoAudioDataset):
+    """Sound audio dataset: Audio dataset with environmental sound-specific metadata.
+
+    Args:
+        info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
+        external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
+            The metadata files contained in this folder are expected to match the stem of the audio file with
+            a json extension.
+        aug_p (float): Probability of performing audio mixing augmentation on the batch.
+        mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
+        mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
+        mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
+        mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
+        kwargs: Additional arguments for AudioDataset.
+
+    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
+    """
+    def __init__(
+        self,
+        *args,
+        info_fields_required: bool = True,
+        external_metadata_source: tp.Optional[str] = None,
+        aug_p: float = 0.,
+        mix_p: float = 0.,
+        mix_snr_low: int = -5,
+        mix_snr_high: int = 5,
+        mix_min_overlap: float = 0.5,
+        **kwargs
+    ):
+        kwargs['return_info'] = True  # We require the info for each song of the dataset.
+        super().__init__(*args, **kwargs)
+        self.info_fields_required = info_fields_required
+        self.external_metadata_source = external_metadata_source
+        self.aug_p = aug_p
+        self.mix_p = mix_p
+        if self.aug_p > 0:
+            assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
+            assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
+        self.mix_snr_low = mix_snr_low
+        self.mix_snr_high = mix_snr_high
+        self.mix_min_overlap = mix_min_overlap
+
+    def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
+        """Get path of JSON with metadata (description, etc.).
+        If there exists a JSON with the same name as 'path.name', then it will be used.
+        Else, such JSON will be searched for in an external json source folder if it exists.
+        """
+        info_path = Path(path).with_suffix('.json')
+        if Path(info_path).exists():
+            return info_path
+        elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
+            return Path(self.external_metadata_source) / info_path.name
+        else:
+            raise Exception(f"Unable to find a metadata JSON for path: {path}")
+
+    def __getitem__(self, index):
+        wav, info = super().__getitem__(index)
+        info_data = info.to_dict()
+        info_path = self._get_info_path(info.meta.path)
+        if Path(info_path).exists():
+            with open(info_path, 'r') as json_file:
+                sound_data = json.load(json_file)
+                sound_data.update(info_data)
+                sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
+                # if there are multiple descriptions, sample one randomly
+                if isinstance(sound_info.description, list):
+                    sound_info.description = random.choice(sound_info.description)
+        else:
+            sound_info = SoundInfo.from_dict(info_data, fields_required=False)
+
+        sound_info.self_wav = WavCondition(
+            wav=wav[None], length=torch.tensor([info.n_frames]),
+            sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
+
+        return wav, sound_info
+
+    def collater(self, samples):
+        # when training, audio mixing is performed in the collate function
+        wav, sound_info = super().collater(samples)  # SoundDataset always returns infos
+        if self.aug_p > 0:
+            wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
+                                          snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
+                                          min_overlap=self.mix_min_overlap)
+        return wav, sound_info
+
+

Ancestors

+ +

Inherited members

+ +
+
+class SoundInfo +(meta: AudioMeta, seek_time: float, n_frames: int, total_frames: int, sample_rate: int, channels: int, description: Optional[str] = None, self_wav: Optional[torch.Tensor] = None) +
+
+

Segment info augmented with Sound metadata.

+
+ +Expand source code + +
class SoundInfo(SegmentWithAttributes):
+    """Segment info augmented with Sound metadata.
+    """
+    description: tp.Optional[str] = None
+    self_wav: tp.Optional[torch.Tensor] = None
+
+    @property
+    def has_sound_meta(self) -> bool:
+        return self.description is not None
+
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        out = ConditioningAttributes()
+
+        for _field in fields(self):
+            key, value = _field.name, getattr(self, _field.name)
+            if key == 'self_wav':
+                out.wav[key] = value
+            else:
+                out.text[key] = value
+        return out
+
+    @staticmethod
+    def attribute_getter(attribute):
+        if attribute == 'description':
+            preprocess_func = get_keyword_or_keyword_list
+        else:
+            preprocess_func = None
+        return preprocess_func
+
+    @classmethod
+    def from_dict(cls, dictionary: dict, fields_required: bool = False):
+        _dictionary: tp.Dict[str, tp.Any] = {}
+
+        # allow a subset of attributes to not be loaded from the dictionary
+        # these attributes may be populated later
+        post_init_attributes = ['self_wav']
+
+        for _field in fields(cls):
+            if _field.name in post_init_attributes:
+                continue
+            elif _field.name not in dictionary:
+                if fields_required:
+                    raise KeyError(f"Unexpected missing key: {_field.name}")
+            else:
+                preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
+                value = dictionary[_field.name]
+                if preprocess_func:
+                    value = preprocess_func(value)
+                _dictionary[_field.name] = value
+        return cls(**_dictionary)
+
+

Ancestors

+ +

Class variables

+
+
var description : Optional[str]
+
+
+
+
var self_wav : Optional[torch.Tensor]
+
+
+
+
+

Static methods

+
+
+def attribute_getter(attribute) +
+
+
+
+ +Expand source code + +
@staticmethod
+def attribute_getter(attribute):
+    if attribute == 'description':
+        preprocess_func = get_keyword_or_keyword_list
+    else:
+        preprocess_func = None
+    return preprocess_func
+
+
+
+def from_dict(dictionary: dict, fields_required: bool = False) +
+
+
+
+ +Expand source code + +
@classmethod
+def from_dict(cls, dictionary: dict, fields_required: bool = False):
+    _dictionary: tp.Dict[str, tp.Any] = {}
+
+    # allow a subset of attributes to not be loaded from the dictionary
+    # these attributes may be populated later
+    post_init_attributes = ['self_wav']
+
+    for _field in fields(cls):
+        if _field.name in post_init_attributes:
+            continue
+        elif _field.name not in dictionary:
+            if fields_required:
+                raise KeyError(f"Unexpected missing key: {_field.name}")
+        else:
+            preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
+            value = dictionary[_field.name]
+            if preprocess_func:
+                value = preprocess_func(value)
+            _dictionary[_field.name] = value
+    return cls(**_dictionary)
+
+
+
+

Instance variables

+
+
var has_sound_meta : bool
+
+
+
+ +Expand source code + +
@property
+def has_sound_meta(self) -> bool:
+    return self.description is not None
+
+
+
+

Methods

+
+
+def to_condition_attributes(self) ‑> ConditioningAttributes +
+
+
+
+ +Expand source code + +
def to_condition_attributes(self) -> ConditioningAttributes:
+    out = ConditioningAttributes()
+
+    for _field in fields(self):
+        key, value = _field.name, getattr(self, _field.name)
+        if key == 'self_wav':
+            out.wav[key] = value
+        else:
+            out.text[key] = value
+    return out
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/data/zip.html b/api_docs/audiocraft/data/zip.html new file mode 100644 index 00000000..db13511e --- /dev/null +++ b/api_docs/audiocraft/data/zip.html @@ -0,0 +1,292 @@ + + + + + + +audiocraft.data.zip API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data.zip

+
+
+

Utility for reading some info from inside a zip file.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Utility for reading some info from inside a zip file.
+"""
+
+import typing
+import zipfile
+
+from dataclasses import dataclass
+from functools import lru_cache
+from typing_extensions import Literal
+
+
+DEFAULT_SIZE = 32
+MODE = Literal['r', 'w', 'x', 'a']
+
+
+@dataclass(order=True)
+class PathInZip:
+    """Hold a path of file within a zip file.
+
+    Args:
+        path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
+            Let's assume there is a zip file /some/location/foo.zip
+            and inside of it is a json file located at /data/file1.json,
+            Then we expect path = "/some/location/foo.zip:/data/file1.json".
+    """
+
+    INFO_PATH_SEP = ':'
+    zip_path: str
+    file_path: str
+
+    def __init__(self, path: str) -> None:
+        split_path = path.split(self.INFO_PATH_SEP)
+        assert len(split_path) == 2
+        self.zip_path, self.file_path = split_path
+
+    @classmethod
+    def from_paths(cls, zip_path: str, file_path: str):
+        return cls(zip_path + cls.INFO_PATH_SEP + file_path)
+
+    def __str__(self) -> str:
+        return self.zip_path + self.INFO_PATH_SEP + self.file_path
+
+
+def _open_zip(path: str, mode: MODE = 'r'):
+    return zipfile.ZipFile(path, mode)
+
+
+_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
+
+
+def set_zip_cache_size(max_size: int):
+    """Sets the maximal LRU caching for zip file opening.
+
+    Args:
+        max_size (int): the maximal LRU cache.
+    """
+    global _cached_open_zip
+    _cached_open_zip = lru_cache(max_size)(_open_zip)
+
+
+def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
+    """Opens a file stored inside a zip and returns a file-like object.
+
+    Args:
+        path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
+        mode (str): The mode in which to open the file with.
+    Returns:
+        A file-like object for PathInZip.
+    """
+    zf = _cached_open_zip(path_in_zip.zip_path)
+    return zf.open(path_in_zip.file_path)
+
+
+
+
+
+
+
+

Functions

+
+
+def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') ‑>  +
+
+

Opens a file stored inside a zip and returns a file-like object.

+

Args

+
+
path_in_zip : PathInZip
+
A PathInZip object representing the file to return a file-like object of.
+
mode : str
+
The mode in which to open the file with.
+
+

Returns

+

A file-like object for PathInZip.

+
+ +Expand source code + +
def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
+    """Opens a file stored inside a zip and returns a file-like object.
+
+    Args:
+        path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
+        mode (str): The mode in which to open the file with.
+    Returns:
+        A file-like object for PathInZip.
+    """
+    zf = _cached_open_zip(path_in_zip.zip_path)
+    return zf.open(path_in_zip.file_path)
+
+
+
+def set_zip_cache_size(max_size: int) +
+
+

Sets the maximal LRU caching for zip file opening.

+

Args

+
+
max_size : int
+
the maximal LRU cache.
+
+
+ +Expand source code + +
def set_zip_cache_size(max_size: int):
+    """Sets the maximal LRU caching for zip file opening.
+
+    Args:
+        max_size (int): the maximal LRU cache.
+    """
+    global _cached_open_zip
+    _cached_open_zip = lru_cache(max_size)(_open_zip)
+
+
+
+
+
+

Classes

+
+
+class PathInZip +(path: str) +
+
+

Hold a path of file within a zip file.

+

Args

+
+
path : str
+
The convention is :. +Let's assume there is a zip file /some/location/foo.zip +and inside of it is a json file located at /data/file1.json, +Then we expect path = "/some/location/foo.zip:/data/file1.json".
+
+
+ +Expand source code + +
class PathInZip:
+    """Hold a path of file within a zip file.
+
+    Args:
+        path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
+            Let's assume there is a zip file /some/location/foo.zip
+            and inside of it is a json file located at /data/file1.json,
+            Then we expect path = "/some/location/foo.zip:/data/file1.json".
+    """
+
+    INFO_PATH_SEP = ':'
+    zip_path: str
+    file_path: str
+
+    def __init__(self, path: str) -> None:
+        split_path = path.split(self.INFO_PATH_SEP)
+        assert len(split_path) == 2
+        self.zip_path, self.file_path = split_path
+
+    @classmethod
+    def from_paths(cls, zip_path: str, file_path: str):
+        return cls(zip_path + cls.INFO_PATH_SEP + file_path)
+
+    def __str__(self) -> str:
+        return self.zip_path + self.INFO_PATH_SEP + self.file_path
+
+

Class variables

+
+
var INFO_PATH_SEP
+
+
+
+
var file_path : str
+
+
+
+
var zip_path : str
+
+
+
+
+

Static methods

+
+
+def from_paths(zip_path: str, file_path: str) +
+
+
+
+ +Expand source code + +
@classmethod
+def from_paths(cls, zip_path: str, file_path: str):
+    return cls(zip_path + cls.INFO_PATH_SEP + file_path)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/environment.html b/api_docs/audiocraft/environment.html new file mode 100644 index 00000000..7d7f61ac --- /dev/null +++ b/api_docs/audiocraft/environment.html @@ -0,0 +1,669 @@ + + + + + + +audiocraft.environment API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.environment

+
+
+

Provides cluster and tools configuration across clusters (slurm, dora, utilities).

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Provides cluster and tools configuration across clusters (slurm, dora, utilities).
+"""
+
+import logging
+import os
+from pathlib import Path
+import re
+import typing as tp
+
+import omegaconf
+
+from .utils.cluster import _guess_cluster_type
+
+
+logger = logging.getLogger(__name__)
+
+
+class AudioCraftEnvironment:
+    """Environment configuration for teams and clusters.
+
+    AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
+    or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
+    provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
+    allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
+    map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
+
+    The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
+    Use the following environment variables to specify the cluster, team or configuration:
+
+        AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
+            cannot be inferred automatically.
+        AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
+            If not set, configuration is read from config/teams.yaml.
+        AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
+            Cluster configuration are shared across teams to match compute allocation,
+            specify your cluster configuration in the configuration file under a key mapping
+            your team name.
+    """
+    _instance = None
+    DEFAULT_TEAM = "default"
+
+    def __init__(self) -> None:
+        """Loads configuration."""
+        self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
+        cluster_type = _guess_cluster_type()
+        cluster = os.getenv(
+            "AUDIOCRAFT_CLUSTER", cluster_type.value
+        )
+        logger.info("Detecting cluster type %s", cluster_type)
+
+        self.cluster: str = cluster
+
+        config_path = os.getenv(
+            "AUDIOCRAFT_CONFIG",
+            Path(__file__)
+            .parent.parent.joinpath("config/teams", self.team)
+            .with_suffix(".yaml"),
+        )
+        self.config = omegaconf.OmegaConf.load(config_path)
+        self._dataset_mappers = []
+        cluster_config = self._get_cluster_config()
+        if "dataset_mappers" in cluster_config:
+            for pattern, repl in cluster_config["dataset_mappers"].items():
+                regex = re.compile(pattern)
+                self._dataset_mappers.append((regex, repl))
+
+    def _get_cluster_config(self) -> omegaconf.DictConfig:
+        assert isinstance(self.config, omegaconf.DictConfig)
+        return self.config[self.cluster]
+
+    @classmethod
+    def instance(cls):
+        if cls._instance is None:
+            cls._instance = cls()
+        return cls._instance
+
+    @classmethod
+    def reset(cls):
+        """Clears the environment and forces a reload on next invocation."""
+        cls._instance = None
+
+    @classmethod
+    def get_team(cls) -> str:
+        """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
+        If not defined, defaults to "labs".
+        """
+        return cls.instance().team
+
+    @classmethod
+    def get_cluster(cls) -> str:
+        """Gets the detected cluster.
+        This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
+        """
+        return cls.instance().cluster
+
+    @classmethod
+    def get_dora_dir(cls) -> Path:
+        """Gets the path to the dora directory for the current team and cluster.
+        Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
+        """
+        cluster_config = cls.instance()._get_cluster_config()
+        dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
+        logger.warning(f"Dora directory: {dora_dir}")
+        return Path(dora_dir)
+
+    @classmethod
+    def get_reference_dir(cls) -> Path:
+        """Gets the path to the reference directory for the current team and cluster.
+        Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
+        """
+        cluster_config = cls.instance()._get_cluster_config()
+        return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
+
+    @classmethod
+    def get_slurm_exclude(cls) -> tp.Optional[str]:
+        """Get the list of nodes to exclude for that cluster."""
+        cluster_config = cls.instance()._get_cluster_config()
+        return cluster_config.get("slurm_exclude")
+
+    @classmethod
+    def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
+        """Gets the requested partitions for the current team and cluster as a comma-separated string.
+
+        Args:
+            partition_types (list[str], optional): partition types to retrieve. Values must be
+                from ['global', 'team']. If not provided, the global partition is returned.
+        """
+        if not partition_types:
+            partition_types = ["global"]
+
+        cluster_config = cls.instance()._get_cluster_config()
+        partitions = [
+            cluster_config["partitions"][partition_type]
+            for partition_type in partition_types
+        ]
+        return ",".join(partitions)
+
+    @classmethod
+    def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
+        """Converts reference placeholder in path with configured reference dir to resolve paths.
+
+        Args:
+            path (str or Path): Path to resolve.
+        Returns:
+            Path: Resolved path.
+        """
+        path = str(path)
+
+        if path.startswith("//reference"):
+            reference_dir = cls.get_reference_dir()
+            logger.warn(f"Reference directory: {reference_dir}")
+            assert (
+                reference_dir.exists() and reference_dir.is_dir()
+            ), f"Reference directory does not exist: {reference_dir}."
+            path = re.sub("^//reference", str(reference_dir), path)
+
+        return Path(path)
+
+    @classmethod
+    def apply_dataset_mappers(cls, path: str) -> str:
+        """Applies dataset mapping regex rules as defined in the configuration.
+        If no rules are defined, the path is returned as-is.
+        """
+        instance = cls.instance()
+
+        for pattern, repl in instance._dataset_mappers:
+            path = pattern.sub(repl, path)
+
+        return path
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class AudioCraftEnvironment +
+
+

Environment configuration for teams and clusters.

+

AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment +or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment +provides pointers to a reference folder resolved automatically across clusters that is shared across team members, +allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically +map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.

+

The cluster type is identified automatically and base configuration file is read from config/teams.yaml. +Use the following environment variables to specify the cluster, team or configuration:

+
AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
+    cannot be inferred automatically.
+AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
+    If not set, configuration is read from config/teams.yaml.
+AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
+    Cluster configuration are shared across teams to match compute allocation,
+    specify your cluster configuration in the configuration file under a key mapping
+    your team name.
+
+

Loads configuration.

+
+ +Expand source code + +
class AudioCraftEnvironment:
+    """Environment configuration for teams and clusters.
+
+    AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
+    or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
+    provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
+    allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
+    map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
+
+    The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
+    Use the following environment variables to specify the cluster, team or configuration:
+
+        AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
+            cannot be inferred automatically.
+        AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
+            If not set, configuration is read from config/teams.yaml.
+        AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
+            Cluster configuration are shared across teams to match compute allocation,
+            specify your cluster configuration in the configuration file under a key mapping
+            your team name.
+    """
+    _instance = None
+    DEFAULT_TEAM = "default"
+
+    def __init__(self) -> None:
+        """Loads configuration."""
+        self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
+        cluster_type = _guess_cluster_type()
+        cluster = os.getenv(
+            "AUDIOCRAFT_CLUSTER", cluster_type.value
+        )
+        logger.info("Detecting cluster type %s", cluster_type)
+
+        self.cluster: str = cluster
+
+        config_path = os.getenv(
+            "AUDIOCRAFT_CONFIG",
+            Path(__file__)
+            .parent.parent.joinpath("config/teams", self.team)
+            .with_suffix(".yaml"),
+        )
+        self.config = omegaconf.OmegaConf.load(config_path)
+        self._dataset_mappers = []
+        cluster_config = self._get_cluster_config()
+        if "dataset_mappers" in cluster_config:
+            for pattern, repl in cluster_config["dataset_mappers"].items():
+                regex = re.compile(pattern)
+                self._dataset_mappers.append((regex, repl))
+
+    def _get_cluster_config(self) -> omegaconf.DictConfig:
+        assert isinstance(self.config, omegaconf.DictConfig)
+        return self.config[self.cluster]
+
+    @classmethod
+    def instance(cls):
+        if cls._instance is None:
+            cls._instance = cls()
+        return cls._instance
+
+    @classmethod
+    def reset(cls):
+        """Clears the environment and forces a reload on next invocation."""
+        cls._instance = None
+
+    @classmethod
+    def get_team(cls) -> str:
+        """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
+        If not defined, defaults to "labs".
+        """
+        return cls.instance().team
+
+    @classmethod
+    def get_cluster(cls) -> str:
+        """Gets the detected cluster.
+        This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
+        """
+        return cls.instance().cluster
+
+    @classmethod
+    def get_dora_dir(cls) -> Path:
+        """Gets the path to the dora directory for the current team and cluster.
+        Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
+        """
+        cluster_config = cls.instance()._get_cluster_config()
+        dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
+        logger.warning(f"Dora directory: {dora_dir}")
+        return Path(dora_dir)
+
+    @classmethod
+    def get_reference_dir(cls) -> Path:
+        """Gets the path to the reference directory for the current team and cluster.
+        Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
+        """
+        cluster_config = cls.instance()._get_cluster_config()
+        return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
+
+    @classmethod
+    def get_slurm_exclude(cls) -> tp.Optional[str]:
+        """Get the list of nodes to exclude for that cluster."""
+        cluster_config = cls.instance()._get_cluster_config()
+        return cluster_config.get("slurm_exclude")
+
+    @classmethod
+    def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
+        """Gets the requested partitions for the current team and cluster as a comma-separated string.
+
+        Args:
+            partition_types (list[str], optional): partition types to retrieve. Values must be
+                from ['global', 'team']. If not provided, the global partition is returned.
+        """
+        if not partition_types:
+            partition_types = ["global"]
+
+        cluster_config = cls.instance()._get_cluster_config()
+        partitions = [
+            cluster_config["partitions"][partition_type]
+            for partition_type in partition_types
+        ]
+        return ",".join(partitions)
+
+    @classmethod
+    def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
+        """Converts reference placeholder in path with configured reference dir to resolve paths.
+
+        Args:
+            path (str or Path): Path to resolve.
+        Returns:
+            Path: Resolved path.
+        """
+        path = str(path)
+
+        if path.startswith("//reference"):
+            reference_dir = cls.get_reference_dir()
+            logger.warn(f"Reference directory: {reference_dir}")
+            assert (
+                reference_dir.exists() and reference_dir.is_dir()
+            ), f"Reference directory does not exist: {reference_dir}."
+            path = re.sub("^//reference", str(reference_dir), path)
+
+        return Path(path)
+
+    @classmethod
+    def apply_dataset_mappers(cls, path: str) -> str:
+        """Applies dataset mapping regex rules as defined in the configuration.
+        If no rules are defined, the path is returned as-is.
+        """
+        instance = cls.instance()
+
+        for pattern, repl in instance._dataset_mappers:
+            path = pattern.sub(repl, path)
+
+        return path
+
+

Class variables

+
+
var DEFAULT_TEAM
+
+
+
+
+

Static methods

+
+
+def apply_dataset_mappers(path: str) ‑> str +
+
+

Applies dataset mapping regex rules as defined in the configuration. +If no rules are defined, the path is returned as-is.

+
+ +Expand source code + +
@classmethod
+def apply_dataset_mappers(cls, path: str) -> str:
+    """Applies dataset mapping regex rules as defined in the configuration.
+    If no rules are defined, the path is returned as-is.
+    """
+    instance = cls.instance()
+
+    for pattern, repl in instance._dataset_mappers:
+        path = pattern.sub(repl, path)
+
+    return path
+
+
+
+def get_cluster() ‑> str +
+
+

Gets the detected cluster. +This value can be overridden by the AUDIOCRAFT_CLUSTER env var.

+
+ +Expand source code + +
@classmethod
+def get_cluster(cls) -> str:
+    """Gets the detected cluster.
+    This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
+    """
+    return cls.instance().cluster
+
+
+
+def get_dora_dir() ‑> pathlib.Path +
+
+

Gets the path to the dora directory for the current team and cluster. +Value is overridden by the AUDIOCRAFT_DORA_DIR env var.

+
+ +Expand source code + +
@classmethod
+def get_dora_dir(cls) -> Path:
+    """Gets the path to the dora directory for the current team and cluster.
+    Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
+    """
+    cluster_config = cls.instance()._get_cluster_config()
+    dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
+    logger.warning(f"Dora directory: {dora_dir}")
+    return Path(dora_dir)
+
+
+
+def get_reference_dir() ‑> pathlib.Path +
+
+

Gets the path to the reference directory for the current team and cluster. +Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.

+
+ +Expand source code + +
@classmethod
+def get_reference_dir(cls) -> Path:
+    """Gets the path to the reference directory for the current team and cluster.
+    Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
+    """
+    cluster_config = cls.instance()._get_cluster_config()
+    return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
+
+
+
+def get_slurm_exclude() ‑> Optional[str] +
+
+

Get the list of nodes to exclude for that cluster.

+
+ +Expand source code + +
@classmethod
+def get_slurm_exclude(cls) -> tp.Optional[str]:
+    """Get the list of nodes to exclude for that cluster."""
+    cluster_config = cls.instance()._get_cluster_config()
+    return cluster_config.get("slurm_exclude")
+
+
+
+def get_slurm_partitions(partition_types: Optional[List[str]] = None) ‑> str +
+
+

Gets the requested partitions for the current team and cluster as a comma-separated string.

+

Args

+
+
partition_types : list[str], optional
+
partition types to retrieve. Values must be +from ['global', 'team']. If not provided, the global partition is returned.
+
+
+ +Expand source code + +
@classmethod
+def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
+    """Gets the requested partitions for the current team and cluster as a comma-separated string.
+
+    Args:
+        partition_types (list[str], optional): partition types to retrieve. Values must be
+            from ['global', 'team']. If not provided, the global partition is returned.
+    """
+    if not partition_types:
+        partition_types = ["global"]
+
+    cluster_config = cls.instance()._get_cluster_config()
+    partitions = [
+        cluster_config["partitions"][partition_type]
+        for partition_type in partition_types
+    ]
+    return ",".join(partitions)
+
+
+
+def get_team() ‑> str +
+
+

Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var. +If not defined, defaults to "labs".

+
+ +Expand source code + +
@classmethod
+def get_team(cls) -> str:
+    """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
+    If not defined, defaults to "labs".
+    """
+    return cls.instance().team
+
+
+
+def instance() +
+
+
+
+ +Expand source code + +
@classmethod
+def instance(cls):
+    if cls._instance is None:
+        cls._instance = cls()
+    return cls._instance
+
+
+
+def reset() +
+
+

Clears the environment and forces a reload on next invocation.

+
+ +Expand source code + +
@classmethod
+def reset(cls):
+    """Clears the environment and forces a reload on next invocation."""
+    cls._instance = None
+
+
+
+def resolve_reference_path(path: Union[str, pathlib.Path]) ‑> pathlib.Path +
+
+

Converts reference placeholder in path with configured reference dir to resolve paths.

+

Args

+
+
path : str or Path
+
Path to resolve.
+
+

Returns

+
+
Path
+
Resolved path.
+
+
+ +Expand source code + +
@classmethod
+def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
+    """Converts reference placeholder in path with configured reference dir to resolve paths.
+
+    Args:
+        path (str or Path): Path to resolve.
+    Returns:
+        Path: Resolved path.
+    """
+    path = str(path)
+
+    if path.startswith("//reference"):
+        reference_dir = cls.get_reference_dir()
+        logger.warn(f"Reference directory: {reference_dir}")
+        assert (
+            reference_dir.exists() and reference_dir.is_dir()
+        ), f"Reference directory does not exist: {reference_dir}."
+        path = re.sub("^//reference", str(reference_dir), path)
+
+    return Path(path)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/audiogen/audiogen_base_16khz.html b/api_docs/audiocraft/grids/audiogen/audiogen_base_16khz.html new file mode 100644 index 00000000..241ebed6 --- /dev/null +++ b/api_docs/audiocraft/grids/audiogen/audiogen_base_16khz.html @@ -0,0 +1,81 @@ + + + + + + +audiocraft.grids.audiogen.audiogen_base_16khz API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.audiogen.audiogen_base_16khz

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ..musicgen._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=64, partition=partitions)
+    launcher.bind_(solver='audiogen/audiogen_base_16khz')
+    # replace this by the desired environmental sound dataset
+    launcher.bind_(dset='internal/sounds_16khz')
+
+    fsdp = {'autocast': False, 'fsdp.use': True}
+    medium = {'model/lm/model_scale': 'medium'}
+
+    launcher.bind_(fsdp)
+    launcher(medium)
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.html b/api_docs/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.html new file mode 100644 index 00000000..dd985039 --- /dev/null +++ b/api_docs/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.html @@ -0,0 +1,173 @@ + + + + + + +audiocraft.grids.audiogen.audiogen_pretrained_16khz_eval API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.audiogen.audiogen_pretrained_16khz_eval

+
+
+

Evaluation with objective metrics for the pretrained AudioGen models. +This grid takes signature from the training grid and runs evaluation-only stage.

+

When running the grid for the first time, please use: +REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval +and re-use the REGEN=1 option when the grid is changed to force regenerating it.

+

Note that you need the proper metrics external libraries setup to use all +the objective metrics activated in this grid. Refer to the README for more information.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Evaluation with objective metrics for the pretrained AudioGen models.
+This grid takes signature from the training grid and runs evaluation-only stage.
+
+When running the grid for the first time, please use:
+REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
+and re-use the REGEN=1 option when the grid is changed to force regenerating it.
+
+Note that you need the proper metrics external libraries setup to use all
+the objective metrics activated in this grid. Refer to the README for more information.
+"""
+
+import os
+
+from ..musicgen._explorers import GenerationEvalExplorer
+from ...environment import AudioCraftEnvironment
+from ... import train
+
+
+def eval(launcher, batch_size: int = 32):
+    opts = {
+        'dset': 'audio/audiocaps_16khz',
+        'solver/audiogen/evaluation': 'objective_eval',
+        'execute_only': 'evaluate',
+        '+dataset.evaluate.batch_size': batch_size,
+        '+metrics.fad.tf.batch_size': 32,
+    }
+    # binary for FAD computation: replace this path with your own path
+    metrics_opts = {
+        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
+    }
+    opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
+    opt2 = {'transformer_lm.two_step_cfg': True}
+
+    sub = launcher.bind(opts)
+    sub.bind_(metrics_opts)
+
+    # base objective metrics
+    sub(opt1, opt2)
+
+
+@GenerationEvalExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=4, partition=partitions)
+
+    if 'REGEN' not in os.environ:
+        folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
+        with launcher.job_array():
+            for sig in folder.iterdir():
+                if not sig.is_symlink():
+                    continue
+                xp = train.main.get_xp_from_sig(sig.name)
+                launcher(xp.argv)
+        return
+
+    audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz")
+    audiogen_base.bind_({'autocast': False, 'fsdp.use': True})
+
+    audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
+    audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'})
+    eval(audiogen_base_medium, batch_size=128)
+
+
+
+
+
+
+
+

Functions

+
+
+def eval(launcher, batch_size: int = 32) +
+
+
+
+ +Expand source code + +
def eval(launcher, batch_size: int = 32):
+    opts = {
+        'dset': 'audio/audiocaps_16khz',
+        'solver/audiogen/evaluation': 'objective_eval',
+        'execute_only': 'evaluate',
+        '+dataset.evaluate.batch_size': batch_size,
+        '+metrics.fad.tf.batch_size': 32,
+    }
+    # binary for FAD computation: replace this path with your own path
+    metrics_opts = {
+        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
+    }
+    opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
+    opt2 = {'transformer_lm.two_step_cfg': True}
+
+    sub = launcher.bind(opts)
+    sub.bind_(metrics_opts)
+
+    # base objective metrics
+    sub(opt1, opt2)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/audiogen/index.html b/api_docs/audiocraft/grids/audiogen/index.html new file mode 100644 index 00000000..da0bc049 --- /dev/null +++ b/api_docs/audiocraft/grids/audiogen/index.html @@ -0,0 +1,83 @@ + + + + + + +audiocraft.grids.audiogen API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.audiogen

+
+
+

AudioGen grids.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""AudioGen grids."""
+
+
+
+

Sub-modules

+
+
audiocraft.grids.audiogen.audiogen_base_16khz
+
+
+
+
audiocraft.grids.audiogen.audiogen_pretrained_16khz_eval
+
+

Evaluation with objective metrics for the pretrained AudioGen models. +This grid takes signature from the training grid and runs evaluation-only stage …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/compression/debug.html b/api_docs/audiocraft/grids/compression/debug.html new file mode 100644 index 00000000..6e5bd044 --- /dev/null +++ b/api_docs/audiocraft/grids/compression/debug.html @@ -0,0 +1,97 @@ + + + + + + +audiocraft.grids.compression.debug API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.compression.debug

+
+
+

Grid search file, simply list all the exp you want in explorer. +Any new exp added there will be scheduled. +You can cancel and experiment by commenting its line.

+

This grid is a minimal example for debugging compression task +and how to override parameters directly in a grid. +Learn more about dora grids: https://github.com/facebookresearch/dora

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel and experiment by commenting its line.
+
+This grid is a minimal example for debugging compression task
+and how to override parameters directly in a grid.
+Learn more about dora grids: https://github.com/facebookresearch/dora
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=2, partition=partitions)
+    launcher.bind_(solver='compression/debug')
+
+    with launcher.job_array():
+        # base debug task using config from solver=compression/debug
+        launcher()
+        # we can override parameters in the grid to launch additional xps
+        launcher({'rvq.bins': 2048, 'rvq.n_q': 4})
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/compression/encodec_audiogen_16khz.html b/api_docs/audiocraft/grids/compression/encodec_audiogen_16khz.html new file mode 100644 index 00000000..37248616 --- /dev/null +++ b/api_docs/audiocraft/grids/compression/encodec_audiogen_16khz.html @@ -0,0 +1,93 @@ + + + + + + +audiocraft.grids.compression.encodec_audiogen_16khz API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.compression.encodec_audiogen_16khz

+
+
+

Grid search file, simply list all the exp you want in explorer. +Any new exp added there will be scheduled. +You can cancel and experiment by commenting its line.

+

This grid shows how to train the new AudioGen EnCodec model at 16 kHz.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel and experiment by commenting its line.
+
+This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=8, partition=partitions)
+    # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
+    # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz
+    launcher.bind_(solver='compression/encodec_audiogen_16khz')
+    # replace this by the desired sound dataset
+    launcher.bind_(dset='internal/sounds_16khz')
+    # launch xp
+    launcher()
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/compression/encodec_base_24khz.html b/api_docs/audiocraft/grids/compression/encodec_base_24khz.html new file mode 100644 index 00000000..7433cb3e --- /dev/null +++ b/api_docs/audiocraft/grids/compression/encodec_base_24khz.html @@ -0,0 +1,92 @@ + + + + + + +audiocraft.grids.compression.encodec_base_24khz API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.compression.encodec_base_24khz

+
+
+

Grid search file, simply list all the exp you want in explorer. +Any new exp added there will be scheduled. +You can cancel and experiment by commenting its line.

+

This grid shows how to train a base causal EnCodec model at 24 kHz.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel and experiment by commenting its line.
+
+This grid shows how to train a base causal EnCodec model at 24 kHz.
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=8, partition=partitions)
+    # base causal EnCodec trained on monophonic audio sampled at 24 kHz
+    launcher.bind_(solver='compression/encodec_base_24khz')
+    # replace this by the desired dataset
+    launcher.bind_(dset='audio/example')
+    # launch xp
+    launcher()
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/compression/encodec_musicgen_32khz.html b/api_docs/audiocraft/grids/compression/encodec_musicgen_32khz.html new file mode 100644 index 00000000..d8008e0d --- /dev/null +++ b/api_docs/audiocraft/grids/compression/encodec_musicgen_32khz.html @@ -0,0 +1,98 @@ + + + + + + +audiocraft.grids.compression.encodec_musicgen_32khz API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.compression.encodec_musicgen_32khz

+
+
+

Grid search file, simply list all the exp you want in explorer. +Any new exp added there will be scheduled. +You can cancel and experiment by commenting its line.

+

This grid shows how to train a MusicGen EnCodec model at 32 kHz.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Grid search file, simply list all the exp you want in `explorer`.
+Any new exp added there will be scheduled.
+You can cancel and experiment by commenting its line.
+
+This grid shows how to train a MusicGen EnCodec model at 32 kHz.
+"""
+
+from ._explorers import CompressionExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@CompressionExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=8, partition=partitions)
+    # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
+    # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
+    launcher.bind_(solver='compression/encodec_musicgen_32khz')
+    # replace this by the desired music dataset
+    launcher.bind_(dset='internal/music_400k_32khz')
+    # launch xp
+    launcher()
+    launcher({
+        'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
+        'label': 'visqol',
+        'evaluate.metrics.visqol': True
+    })
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/compression/index.html b/api_docs/audiocraft/grids/compression/index.html new file mode 100644 index 00000000..8c0d27f7 --- /dev/null +++ b/api_docs/audiocraft/grids/compression/index.html @@ -0,0 +1,100 @@ + + + + + + +audiocraft.grids.compression API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.compression

+
+
+

EnCodec grids.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""EnCodec grids."""
+
+
+
+

Sub-modules

+
+
audiocraft.grids.compression.debug
+
+

Grid search file, simply list all the exp you want in explorer. +Any new exp added there will be scheduled. +You can cancel and experiment by …

+
+
audiocraft.grids.compression.encodec_audiogen_16khz
+
+

Grid search file, simply list all the exp you want in explorer. +Any new exp added there will be scheduled. +You can cancel and experiment by …

+
+
audiocraft.grids.compression.encodec_base_24khz
+
+

Grid search file, simply list all the exp you want in explorer. +Any new exp added there will be scheduled. +You can cancel and experiment by …

+
+
audiocraft.grids.compression.encodec_musicgen_32khz
+
+

Grid search file, simply list all the exp you want in explorer. +Any new exp added there will be scheduled. +You can cancel and experiment by …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/diffusion/4_bands_base_32khz.html b/api_docs/audiocraft/grids/diffusion/4_bands_base_32khz.html new file mode 100644 index 00000000..ee50f6b9 --- /dev/null +++ b/api_docs/audiocraft/grids/diffusion/4_bands_base_32khz.html @@ -0,0 +1,90 @@ + + + + + + +audiocraft.grids.diffusion.4_bands_base_32khz API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.diffusion.4_bands_base_32khz

+
+
+

Training of the 4 diffusion models described in +"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" +(paper link).

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Training of the 4 diffusion models described in
+"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
+(paper link).
+"""
+
+from ._explorers import DiffusionExplorer
+
+
+@DiffusionExplorer
+def explorer(launcher):
+    launcher.slurm_(gpus=4, partition='learnfair')
+
+    launcher.bind_({'solver': 'diffusion/default',
+                    'dset': 'internal/music_10k_32khz'})
+
+    with launcher.job_array():
+        launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4})
+        launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4})
+        launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4})
+        launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75})
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/diffusion/index.html b/api_docs/audiocraft/grids/diffusion/index.html new file mode 100644 index 00000000..c3aa796e --- /dev/null +++ b/api_docs/audiocraft/grids/diffusion/index.html @@ -0,0 +1,79 @@ + + + + + + +audiocraft.grids.diffusion API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.diffusion

+
+
+

Diffusion grids.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Diffusion grids."""
+
+
+
+

Sub-modules

+
+
audiocraft.grids.diffusion.4_bands_base_32khz
+
+

Training of the 4 diffusion models described in +"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" +(paper link).

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/index.html b/api_docs/audiocraft/grids/index.html new file mode 100644 index 00000000..d530cfe1 --- /dev/null +++ b/api_docs/audiocraft/grids/index.html @@ -0,0 +1,92 @@ + + + + + + +audiocraft.grids API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids

+
+
+

Dora Grids.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Dora Grids."""
+
+
+
+

Sub-modules

+
+
audiocraft.grids.audiogen
+
+

AudioGen grids.

+
+
audiocraft.grids.compression
+
+

EnCodec grids.

+
+
audiocraft.grids.diffusion
+
+

Diffusion grids.

+
+
audiocraft.grids.musicgen
+
+

MusicGen grids.

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/musicgen/index.html b/api_docs/audiocraft/grids/musicgen/index.html new file mode 100644 index 00000000..c086c6ee --- /dev/null +++ b/api_docs/audiocraft/grids/musicgen/index.html @@ -0,0 +1,98 @@ + + + + + + +audiocraft.grids.musicgen API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.musicgen

+
+
+

MusicGen grids.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""MusicGen grids."""
+
+
+
+

Sub-modules

+
+
audiocraft.grids.musicgen.musicgen_base_32khz
+
+
+
+
audiocraft.grids.musicgen.musicgen_base_cached_32khz
+
+
+
+
audiocraft.grids.musicgen.musicgen_clapemb_32khz
+
+
+
+
audiocraft.grids.musicgen.musicgen_melody_32khz
+
+
+
+
audiocraft.grids.musicgen.musicgen_pretrained_32khz_eval
+
+

Evaluation with objective metrics for the pretrained MusicGen models. +This grid takes signature from the training grid and runs evaluation-only stage …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/musicgen/musicgen_base_32khz.html b/api_docs/audiocraft/grids/musicgen/musicgen_base_32khz.html new file mode 100644 index 00000000..7c676c0c --- /dev/null +++ b/api_docs/audiocraft/grids/musicgen/musicgen_base_32khz.html @@ -0,0 +1,101 @@ + + + + + + +audiocraft.grids.musicgen.musicgen_base_32khz API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.musicgen.musicgen_base_32khz

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=32, partition=partitions)
+    launcher.bind_(solver='musicgen/musicgen_base_32khz')
+    # replace this by the desired music dataset
+    launcher.bind_(dset='internal/music_400k_32khz')
+
+    fsdp = {'autocast': False, 'fsdp.use': True}
+    medium = {'model/lm/model_scale': 'medium'}
+    large = {'model/lm/model_scale': 'large'}
+
+    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
+    wd_low = {'conditioners.description.t5.word_dropout': 0.2}
+
+    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
+
+    launcher.bind_(fsdp)
+
+    launcher.slurm_(gpus=32).bind_(label='32gpus')
+    with launcher.job_array():
+        sub = launcher.bind()
+        sub()
+
+    launcher.slurm_(gpus=64).bind_(label='64gpus')
+    with launcher.job_array():
+        sub = launcher.bind()
+        sub(medium, adam)
+
+    launcher.slurm_(gpus=96).bind_(label='96gpus')
+    with launcher.job_array():
+        sub = launcher.bind()
+        sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/musicgen/musicgen_base_cached_32khz.html b/api_docs/audiocraft/grids/musicgen/musicgen_base_cached_32khz.html new file mode 100644 index 00000000..e10b1cd8 --- /dev/null +++ b/api_docs/audiocraft/grids/musicgen/musicgen_base_cached_32khz.html @@ -0,0 +1,125 @@ + + + + + + +audiocraft.grids.musicgen.musicgen_base_cached_32khz API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.musicgen.musicgen_base_cached_32khz

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=32, partition=partitions)
+    launcher.bind_(solver='musicgen/musicgen_base_32khz')
+    # replace this by the desired music dataset
+    launcher.bind_(dset='internal/music_400k_32khz')
+
+    fsdp = {'autocast': False, 'fsdp.use': True}
+    medium = {'model/lm/model_scale': 'medium'}
+    large = {'model/lm/model_scale': 'large'}
+
+    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
+    wd_low = {'conditioners.description.t5.word_dropout': 0.2}
+
+    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
+
+    # BEGINNING OF CACHE WRITING JOBS.
+    cache_write = {
+        'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
+        'cache.write': True,
+        'generate.every': 500,
+        'evaluate.every': 500,
+        'logging.log_updates': 50,
+    }
+
+    cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
+    cache_sub.bind_({'deadlock.use': True})
+    cache_sub.slurm_(gpus=8)
+    with launcher.job_array():
+        num_shards = 10  # total number of jobs running in parallel.
+        for shard in range(0, num_shards):
+            launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})
+
+    # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
+    # OR SUFFICIENTLY AHEAD.
+    return
+
+    cache = {
+        'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
+    }
+    launcher.bind_(fsdp, cache)
+
+    launcher.slurm_(gpus=32).bind_(label='32gpus')
+    with launcher.job_array():
+        sub = launcher.bind()
+        sub()
+
+    launcher.slurm_(gpus=64).bind_(label='64gpus')
+    with launcher.job_array():
+        sub = launcher.bind()
+        sub(medium, adam)
+
+    launcher.slurm_(gpus=96).bind_(label='96gpus')
+    with launcher.job_array():
+        sub = launcher.bind()
+        sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/musicgen/musicgen_clapemb_32khz.html b/api_docs/audiocraft/grids/musicgen/musicgen_clapemb_32khz.html new file mode 100644 index 00000000..2ad78dc8 --- /dev/null +++ b/api_docs/audiocraft/grids/musicgen/musicgen_clapemb_32khz.html @@ -0,0 +1,90 @@ + + + + + + +audiocraft.grids.musicgen.musicgen_clapemb_32khz API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.musicgen.musicgen_clapemb_32khz

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=32, partition=partitions)
+    launcher.bind_(solver='musicgen/musicgen_base_32khz')
+    # replace this by the desired music dataset
+    launcher.bind_(dset='internal/music_400k_32khz')
+    launcher.bind_(conditioner='clapemb2music')
+
+    fsdp = {'autocast': False, 'fsdp.use': True}
+    cache_path = {'conditioners.description.clap.cache_path':
+                  '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music'}
+    text_wav_training_opt = {'conditioners.description.clap.text_p': 0.5}
+
+    launcher.bind_(fsdp)
+
+    launcher.slurm_(gpus=32).bind_(label='32gpus')
+    with launcher.job_array():
+        launcher()
+        launcher(text_wav_training_opt)
+        launcher(cache_path)
+        launcher(cache_path, text_wav_training_opt)
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/musicgen/musicgen_melody_32khz.html b/api_docs/audiocraft/grids/musicgen/musicgen_melody_32khz.html new file mode 100644 index 00000000..e6601e89 --- /dev/null +++ b/api_docs/audiocraft/grids/musicgen/musicgen_melody_32khz.html @@ -0,0 +1,123 @@ + + + + + + +audiocraft.grids.musicgen.musicgen_melody_32khz API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.musicgen.musicgen_melody_32khz

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=32, partition=partitions)
+    launcher.bind_(solver='musicgen/musicgen_melody_32khz')
+    # replace this by the desired music dataset
+    launcher.bind_(dset='internal/music_400k_32khz')
+
+    fsdp = {'autocast': False, 'fsdp.use': True}
+    medium = {'model/lm/model_scale': 'medium'}
+    large = {'model/lm/model_scale': 'large'}
+
+    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
+    wd_low = {'conditioners.description.t5.word_dropout': 0.2}
+
+    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
+
+    cache_path = {'conditioners.self_wav.chroma_stem.cache_path':
+                  '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem'}
+
+    # CACHE GENERATION JOBS
+    n_cache_gen_jobs = 4
+    gen_sub = launcher.slurm(gpus=1)
+    gen_sub.bind_(
+        cache_path, {
+            # the cache is always computed over the whole file, so duration doesn't matter here.
+            'dataset.segment_duration': 2.,
+            'dataset.batch_size': 8,
+            'dataset.train.permutation_on_files': True,  # try to not repeat files.
+            'optim.epochs': 10,
+            'model/lm/model_scale': 'xsmall',
+
+        })
+    with gen_sub.job_array():
+        for gen_job in range(n_cache_gen_jobs):
+            gen_sub({'dataset.train.shuffle_seed': gen_job})
+
+    # ACTUAL TRAINING JOBS.
+    launcher.bind_(fsdp)
+
+    launcher.slurm_(gpus=32).bind_(label='32gpus')
+    with launcher.job_array():
+        sub = launcher.bind()
+        sub()
+        sub(cache_path)
+
+    launcher.slurm_(gpus=64).bind_(label='64gpus')
+    with launcher.job_array():
+        sub = launcher.bind()
+        sub(medium, adam)
+
+    launcher.slurm_(gpus=96).bind_(label='96gpus')
+    with launcher.job_array():
+        sub = launcher.bind()
+        sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.html b/api_docs/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.html new file mode 100644 index 00000000..c3c8a209 --- /dev/null +++ b/api_docs/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.html @@ -0,0 +1,218 @@ + + + + + + +audiocraft.grids.musicgen.musicgen_pretrained_32khz_eval API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.grids.musicgen.musicgen_pretrained_32khz_eval

+
+
+

Evaluation with objective metrics for the pretrained MusicGen models. +This grid takes signature from the training grid and runs evaluation-only stage.

+

When running the grid for the first time, please use: +REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval +and re-use the REGEN=1 option when the grid is changed to force regenerating it.

+

Note that you need the proper metrics external libraries setup to use all +the objective metrics activated in this grid. Refer to the README for more information.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Evaluation with objective metrics for the pretrained MusicGen models.
+This grid takes signature from the training grid and runs evaluation-only stage.
+
+When running the grid for the first time, please use:
+REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval
+and re-use the REGEN=1 option when the grid is changed to force regenerating it.
+
+Note that you need the proper metrics external libraries setup to use all
+the objective metrics activated in this grid. Refer to the README for more information.
+"""
+
+import os
+
+from ._explorers import GenerationEvalExplorer
+from ...environment import AudioCraftEnvironment
+from ... import train
+
+
+def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
+    opts = {
+        'dset': 'audio/musiccaps_32khz',
+        'solver/musicgen/evaluation': 'objective_eval',
+        'execute_only': 'evaluate',
+        '+dataset.evaluate.batch_size': batch_size,
+        '+metrics.fad.tf.batch_size': 16,
+    }
+    # chroma-specific evaluation
+    chroma_opts = {
+        'dset': 'internal/music_400k_32khz',
+        'dataset.evaluate.segment_duration': 30,
+        'dataset.evaluate.num_samples': 1000,
+        'evaluate.metrics.chroma_cosine': True,
+        'evaluate.metrics.fad': False,
+        'evaluate.metrics.kld': False,
+        'evaluate.metrics.text_consistency': False,
+    }
+    # binary for FAD computation: replace this path with your own path
+    metrics_opts = {
+        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
+    }
+    opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
+    opt2 = {'transformer_lm.two_step_cfg': True}
+
+    sub = launcher.bind(opts)
+    sub.bind_(metrics_opts)
+
+    # base objective metrics
+    sub(opt1, opt2)
+
+    if eval_melody:
+        # chroma-specific metrics
+        sub(opt1, opt2, chroma_opts)
+
+
+@GenerationEvalExplorer
+def explorer(launcher):
+    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+    launcher.slurm_(gpus=4, partition=partitions)
+
+    if 'REGEN' not in os.environ:
+        folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
+        with launcher.job_array():
+            for sig in folder.iterdir():
+                if not sig.is_symlink():
+                    continue
+                xp = train.main.get_xp_from_sig(sig.name)
+                launcher(xp.argv)
+        return
+
+    with launcher.job_array():
+        musicgen_base = launcher.bind(solver="musicgen/musicgen_base_32khz")
+        musicgen_base.bind_({'autocast': False, 'fsdp.use': True})
+
+        # base musicgen models
+        musicgen_base_small = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-small'})
+        eval(musicgen_base_small, batch_size=128)
+
+        musicgen_base_medium = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-medium'})
+        musicgen_base_medium.bind_({'model/lm/model_scale': 'medium'})
+        eval(musicgen_base_medium, batch_size=128)
+
+        musicgen_base_large = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-large'})
+        musicgen_base_large.bind_({'model/lm/model_scale': 'large'})
+        eval(musicgen_base_large, batch_size=128)
+
+        # melody musicgen model
+        musicgen_melody = launcher.bind(solver="musicgen/musicgen_melody_32khz")
+        musicgen_melody.bind_({'autocast': False, 'fsdp.use': True})
+
+        musicgen_melody_medium = musicgen_melody.bind({'continue_from': '//pretrained/facebook/musicgen-melody'})
+        musicgen_melody_medium.bind_({'model/lm/model_scale': 'medium'})
+        eval(musicgen_melody_medium, batch_size=128, eval_melody=True)
+
+
+
+
+
+
+
+

Functions

+
+
+def eval(launcher, batch_size: int = 32, eval_melody: bool = False) +
+
+
+
+ +Expand source code + +
def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
+    opts = {
+        'dset': 'audio/musiccaps_32khz',
+        'solver/musicgen/evaluation': 'objective_eval',
+        'execute_only': 'evaluate',
+        '+dataset.evaluate.batch_size': batch_size,
+        '+metrics.fad.tf.batch_size': 16,
+    }
+    # chroma-specific evaluation
+    chroma_opts = {
+        'dset': 'internal/music_400k_32khz',
+        'dataset.evaluate.segment_duration': 30,
+        'dataset.evaluate.num_samples': 1000,
+        'evaluate.metrics.chroma_cosine': True,
+        'evaluate.metrics.fad': False,
+        'evaluate.metrics.kld': False,
+        'evaluate.metrics.text_consistency': False,
+    }
+    # binary for FAD computation: replace this path with your own path
+    metrics_opts = {
+        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
+    }
+    opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
+    opt2 = {'transformer_lm.two_step_cfg': True}
+
+    sub = launcher.bind(opts)
+    sub.bind_(metrics_opts)
+
+    # base objective metrics
+    sub(opt1, opt2)
+
+    if eval_melody:
+        # chroma-specific metrics
+        sub(opt1, opt2, chroma_opts)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/index.html b/api_docs/audiocraft/index.html new file mode 100644 index 00000000..19d5a0d1 --- /dev/null +++ b/api_docs/audiocraft/index.html @@ -0,0 +1,172 @@ + + + + + + +audiocraft API documentation + + + + + + + + + + + +
+
+
+

Package audiocraft

+
+
+

AudioCraft is a general framework for training audio generative models. +At the moment we provide the training code for:

+
    +
  • MusicGen, a state-of-the-art +text-to-music and melody+text autoregressive generative model. +For the solver, see MusicGenSolver, and for the model, +MusicGen.
  • +
  • AudioGen, a state-of-the-art +text-to-general-audio generative model.
  • +
  • EnCodec, efficient and high fidelity +neural audio codec which provides an excellent tokenizer for autoregressive language models. +See CompressionSolver, and EncodecModel.
  • +
  • MultiBandDiffusion, alternative diffusion-based decoder compatible with EnCodec that +improves the perceived quality and reduces the artifacts coming from adversarial decoders.
  • +
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+AudioCraft is a general framework for training audio generative models.
+At the moment we provide the training code for:
+
+- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
+    text-to-music and melody+text autoregressive generative model.
+    For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
+    `audiocraft.models.musicgen.MusicGen`.
+- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
+    text-to-general-audio generative model.
+- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
+    neural audio codec which provides an excellent tokenizer for autoregressive language models.
+    See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
+- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
+    improves the perceived quality and reduces the artifacts coming from adversarial decoders.
+"""
+
+# flake8: noqa
+from . import data, modules, models
+
+__version__ = '1.1.0a1'
+
+
+
+

Sub-modules

+
+
audiocraft.adversarial
+
+

Adversarial losses and discriminator architectures.

+
+
audiocraft.data
+
+

Audio loading and writing support. Datasets for raw audio +or also including some metadata.

+
+
audiocraft.environment
+
+

Provides cluster and tools configuration across clusters (slurm, dora, utilities).

+
+
audiocraft.grids
+
+

Dora Grids.

+
+
audiocraft.losses
+
+

Loss related classes and functions. In particular the loss balancer from +EnCodec, and the usual spectral losses.

+
+
audiocraft.metrics
+
+

Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc.

+
+
audiocraft.models
+
+

Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.

+
+
audiocraft.modules
+
+

Modules used for building the models.

+
+
audiocraft.optim
+
+

Optimization stuff. In particular, optimizers (DAdaptAdam), schedulers +and Exponential Moving Average.

+
+
audiocraft.quantization
+
+

RVQ.

+
+
audiocraft.solvers
+
+

Solvers. A Solver is a training recipe, combining the dataloaders, models, +optimizer, losses etc into a single convenient object.

+
+
audiocraft.train
+
+

Entry point for dora to launch solvers for running training loops. +See more info on how to use dora: https://github.com/facebookresearch/dora

+
+
audiocraft.utils
+
+

Utilities.

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/losses/balancer.html b/api_docs/audiocraft/losses/balancer.html new file mode 100644 index 00000000..2d6685a3 --- /dev/null +++ b/api_docs/audiocraft/losses/balancer.html @@ -0,0 +1,491 @@ + + + + + + +audiocraft.losses.balancer API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.losses.balancer

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import flashy
+import torch
+from torch import autograd
+
+
+class Balancer:
+    """Loss balancer.
+
+    The loss balancer combines losses together to compute gradients for the backward.
+    Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
+    not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
+    `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
+    the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
+    going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
+    interpration of the weights even if the intrisic scale of `l1`, `l2` ... is unknown.
+
+    Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
+    (with `avg` an exponential moving average over the updates),
+
+        G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)
+
+    If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
+    standard sum of the partial gradients with the given weights.
+
+    A call to the backward method of the balancer will compute the the partial gradients,
+    combining all the losses and potentially rescaling the gradients,
+    which can help stabilize the training and reason about multiple losses with varying scales.
+    The obtained gradient with respect to `y` is then back-propagated to `f(...)`.
+
+    Expected usage:
+
+        weights = {'loss_a': 1, 'loss_b': 4}
+        balancer = Balancer(weights, ...)
+        losses: dict = {}
+        losses['loss_a'] = compute_loss_a(x, y)
+        losses['loss_b'] = compute_loss_b(x, y)
+        if model.training():
+            effective_loss = balancer.backward(losses, x)
+
+    Args:
+        weights (dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
+            from the backward method to match the weights keys to assign weight to each of the provided loss.
+        balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
+            overall gradient, rather than a constant multiplier.
+        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
+        emay_decay (float): EMA decay for averaging the norms.
+        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
+            when rescaling the gradients.
+        epsilon (float): Epsilon value for numerical stability.
+        monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
+            coming from each loss, when calling `backward()`.
+    """
+    def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
+                 ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
+                 monitor: bool = False):
+        self.weights = weights
+        self.per_batch_item = per_batch_item
+        self.total_norm = total_norm or 1.
+        self.averager = flashy.averager(ema_decay or 1.)
+        self.epsilon = epsilon
+        self.monitor = monitor
+        self.balance_grads = balance_grads
+        self._metrics: tp.Dict[str, tp.Any] = {}
+
+    @property
+    def metrics(self):
+        return self._metrics
+
+    def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
+        """Compute the backward and return the effective train loss, e.g. the loss obtained from
+        computing the effective weights. If `balance_grads` is True, the effective weights
+        are the one that needs to be applied to each gradient to respect the desired relative
+        scale of gradients coming from each loss.
+
+        Args:
+            losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
+            input (torch.Tensor): the input of the losses, typically the output of the model.
+                This should be the single point of dependence between the losses
+                and the model being trained.
+        """
+        norms = {}
+        grads = {}
+        for name, loss in losses.items():
+            # Compute partial derivative of the less with respect to the input.
+            grad, = autograd.grad(loss, [input], retain_graph=True)
+            if self.per_batch_item:
+                # We do not average the gradient over the batch dimension.
+                dims = tuple(range(1, grad.dim()))
+                norm = grad.norm(dim=dims, p=2).mean()
+            else:
+                norm = grad.norm(p=2)
+            norms[name] = norm
+            grads[name] = grad
+
+        count = 1
+        if self.per_batch_item:
+            count = len(grad)
+        # Average norms across workers. Theoretically we should average the
+        # squared norm, then take the sqrt, but it worked fine like that.
+        avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
+        # We approximate the total norm of the gradient as the sums of the norms.
+        # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
+        total = sum(avg_norms.values())
+
+        self._metrics = {}
+        if self.monitor:
+            # Store the ratio of the total gradient represented by each loss.
+            for k, v in avg_norms.items():
+                self._metrics[f'ratio_{k}'] = v / total
+
+        total_weights = sum([self.weights[k] for k in avg_norms])
+        assert total_weights > 0.
+        desired_ratios = {k: w / total_weights for k, w in self.weights.items()}
+
+        out_grad = torch.zeros_like(input)
+        effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
+        for name, avg_norm in avg_norms.items():
+            if self.balance_grads:
+                # g_balanced = g / avg(||g||) * total_norm * desired_ratio
+                scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
+            else:
+                # We just do regular weighted sum of the gradients.
+                scale = self.weights[name]
+            out_grad.add_(grads[name], alpha=scale)
+            effective_loss += scale * losses[name].detach()
+        # Send the computed partial derivative with respect to the output of the model to the model.
+        input.backward(out_grad)
+        return effective_loss
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class Balancer +(weights: Dict[str, float], balance_grads: bool = True, total_norm: float = 1.0, ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12, monitor: bool = False) +
+
+

Loss balancer.

+

The loss balancer combines losses together to compute gradients for the backward. +Given y = f(...), and a number of losses l1(y, …), l2(y, …), with +not having any dependence on f, the balancer can efficiently normalize the partial gradients +d l1 / d y, d l2 / dy before summing them in order to achieve a desired ratio between +the losses. For instance if weights = {'l1': 2, 'l2': 1}, 66% of the gradient +going into f(…) will come from l1 on average, and 33% from l2. This allows for an easy +interpration of the weights even if the intrisic scale of l1, l2 … is unknown.

+

Noting g1 = d l1 / dy, etc., the balanced gradient G will be +(with avg an exponential moving average over the updates),

+
G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)
+
+

If balance_grads is False, this is deactivated, and instead the gradient will just be the +standard sum of the partial gradients with the given weights.

+

A call to the backward method of the balancer will compute the the partial gradients, +combining all the losses and potentially rescaling the gradients, +which can help stabilize the training and reason about multiple losses with varying scales. +The obtained gradient with respect to y is then back-propagated to f(…).

+

Expected usage:

+
weights = {'loss_a': 1, 'loss_b': 4}
+balancer = Balancer(weights, ...)
+losses: dict = {}
+losses['loss_a'] = compute_loss_a(x, y)
+losses['loss_b'] = compute_loss_b(x, y)
+if model.training():
+    effective_loss = balancer.backward(losses, x)
+
+

Args

+
+
weights : dict[str, float]
+
Weight coefficient for each loss. The balancer expect the losses keys +from the backward method to match the weights keys to assign weight to each of the provided loss.
+
balance_grads : bool
+
Whether to rescale gradients so that weights reflect the fraction of the +overall gradient, rather than a constant multiplier.
+
total_norm : float
+
Reference norm when rescaling gradients, ignored otherwise.
+
emay_decay : float
+
EMA decay for averaging the norms.
+
per_batch_item : bool
+
Whether to compute the averaged norm per batch item or not. This only holds +when rescaling the gradients.
+
epsilon : float
+
Epsilon value for numerical stability.
+
monitor : bool
+
If True, stores in self.metrics the relative ratio between the norm of the gradients +coming from each loss, when calling backward().
+
+
+ +Expand source code + +
class Balancer:
+    """Loss balancer.
+
+    The loss balancer combines losses together to compute gradients for the backward.
+    Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
+    not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
+    `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
+    the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
+    going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
+    interpration of the weights even if the intrisic scale of `l1`, `l2` ... is unknown.
+
+    Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
+    (with `avg` an exponential moving average over the updates),
+
+        G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)
+
+    If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
+    standard sum of the partial gradients with the given weights.
+
+    A call to the backward method of the balancer will compute the the partial gradients,
+    combining all the losses and potentially rescaling the gradients,
+    which can help stabilize the training and reason about multiple losses with varying scales.
+    The obtained gradient with respect to `y` is then back-propagated to `f(...)`.
+
+    Expected usage:
+
+        weights = {'loss_a': 1, 'loss_b': 4}
+        balancer = Balancer(weights, ...)
+        losses: dict = {}
+        losses['loss_a'] = compute_loss_a(x, y)
+        losses['loss_b'] = compute_loss_b(x, y)
+        if model.training():
+            effective_loss = balancer.backward(losses, x)
+
+    Args:
+        weights (dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
+            from the backward method to match the weights keys to assign weight to each of the provided loss.
+        balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
+            overall gradient, rather than a constant multiplier.
+        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
+        emay_decay (float): EMA decay for averaging the norms.
+        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
+            when rescaling the gradients.
+        epsilon (float): Epsilon value for numerical stability.
+        monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
+            coming from each loss, when calling `backward()`.
+    """
+    def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
+                 ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
+                 monitor: bool = False):
+        self.weights = weights
+        self.per_batch_item = per_batch_item
+        self.total_norm = total_norm or 1.
+        self.averager = flashy.averager(ema_decay or 1.)
+        self.epsilon = epsilon
+        self.monitor = monitor
+        self.balance_grads = balance_grads
+        self._metrics: tp.Dict[str, tp.Any] = {}
+
+    @property
+    def metrics(self):
+        return self._metrics
+
+    def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
+        """Compute the backward and return the effective train loss, e.g. the loss obtained from
+        computing the effective weights. If `balance_grads` is True, the effective weights
+        are the one that needs to be applied to each gradient to respect the desired relative
+        scale of gradients coming from each loss.
+
+        Args:
+            losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
+            input (torch.Tensor): the input of the losses, typically the output of the model.
+                This should be the single point of dependence between the losses
+                and the model being trained.
+        """
+        norms = {}
+        grads = {}
+        for name, loss in losses.items():
+            # Compute partial derivative of the less with respect to the input.
+            grad, = autograd.grad(loss, [input], retain_graph=True)
+            if self.per_batch_item:
+                # We do not average the gradient over the batch dimension.
+                dims = tuple(range(1, grad.dim()))
+                norm = grad.norm(dim=dims, p=2).mean()
+            else:
+                norm = grad.norm(p=2)
+            norms[name] = norm
+            grads[name] = grad
+
+        count = 1
+        if self.per_batch_item:
+            count = len(grad)
+        # Average norms across workers. Theoretically we should average the
+        # squared norm, then take the sqrt, but it worked fine like that.
+        avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
+        # We approximate the total norm of the gradient as the sums of the norms.
+        # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
+        total = sum(avg_norms.values())
+
+        self._metrics = {}
+        if self.monitor:
+            # Store the ratio of the total gradient represented by each loss.
+            for k, v in avg_norms.items():
+                self._metrics[f'ratio_{k}'] = v / total
+
+        total_weights = sum([self.weights[k] for k in avg_norms])
+        assert total_weights > 0.
+        desired_ratios = {k: w / total_weights for k, w in self.weights.items()}
+
+        out_grad = torch.zeros_like(input)
+        effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
+        for name, avg_norm in avg_norms.items():
+            if self.balance_grads:
+                # g_balanced = g / avg(||g||) * total_norm * desired_ratio
+                scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
+            else:
+                # We just do regular weighted sum of the gradients.
+                scale = self.weights[name]
+            out_grad.add_(grads[name], alpha=scale)
+            effective_loss += scale * losses[name].detach()
+        # Send the computed partial derivative with respect to the output of the model to the model.
+        input.backward(out_grad)
+        return effective_loss
+
+

Instance variables

+
+
var metrics
+
+
+
+ +Expand source code + +
@property
+def metrics(self):
+    return self._metrics
+
+
+
+

Methods

+
+
+def backward(self, losses: Dict[str, torch.Tensor], input: torch.Tensor) ‑> torch.Tensor +
+
+

Compute the backward and return the effective train loss, e.g. the loss obtained from +computing the effective weights. If balance_grads is True, the effective weights +are the one that needs to be applied to each gradient to respect the desired relative +scale of gradients coming from each loss.

+

Args

+
+
losses : Dict[str, torch.Tensor]
+
dictionary with the same keys as self.weights.
+
input : torch.Tensor
+
the input of the losses, typically the output of the model. +This should be the single point of dependence between the losses +and the model being trained.
+
+
+ +Expand source code + +
def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
+    """Compute the backward and return the effective train loss, e.g. the loss obtained from
+    computing the effective weights. If `balance_grads` is True, the effective weights
+    are the one that needs to be applied to each gradient to respect the desired relative
+    scale of gradients coming from each loss.
+
+    Args:
+        losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
+        input (torch.Tensor): the input of the losses, typically the output of the model.
+            This should be the single point of dependence between the losses
+            and the model being trained.
+    """
+    norms = {}
+    grads = {}
+    for name, loss in losses.items():
+        # Compute partial derivative of the less with respect to the input.
+        grad, = autograd.grad(loss, [input], retain_graph=True)
+        if self.per_batch_item:
+            # We do not average the gradient over the batch dimension.
+            dims = tuple(range(1, grad.dim()))
+            norm = grad.norm(dim=dims, p=2).mean()
+        else:
+            norm = grad.norm(p=2)
+        norms[name] = norm
+        grads[name] = grad
+
+    count = 1
+    if self.per_batch_item:
+        count = len(grad)
+    # Average norms across workers. Theoretically we should average the
+    # squared norm, then take the sqrt, but it worked fine like that.
+    avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
+    # We approximate the total norm of the gradient as the sums of the norms.
+    # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
+    total = sum(avg_norms.values())
+
+    self._metrics = {}
+    if self.monitor:
+        # Store the ratio of the total gradient represented by each loss.
+        for k, v in avg_norms.items():
+            self._metrics[f'ratio_{k}'] = v / total
+
+    total_weights = sum([self.weights[k] for k in avg_norms])
+    assert total_weights > 0.
+    desired_ratios = {k: w / total_weights for k, w in self.weights.items()}
+
+    out_grad = torch.zeros_like(input)
+    effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
+    for name, avg_norm in avg_norms.items():
+        if self.balance_grads:
+            # g_balanced = g / avg(||g||) * total_norm * desired_ratio
+            scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
+        else:
+            # We just do regular weighted sum of the gradients.
+            scale = self.weights[name]
+        out_grad.add_(grads[name], alpha=scale)
+        effective_loss += scale * losses[name].detach()
+    # Send the computed partial derivative with respect to the output of the model to the model.
+    input.backward(out_grad)
+    return effective_loss
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/losses/index.html b/api_docs/audiocraft/losses/index.html new file mode 100644 index 00000000..0a4e32ae --- /dev/null +++ b/api_docs/audiocraft/losses/index.html @@ -0,0 +1,109 @@ + + + + + + +audiocraft.losses API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.losses

+
+
+

Loss related classes and functions. In particular the loss balancer from +EnCodec, and the usual spectral losses.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Loss related classes and functions. In particular the loss balancer from
+EnCodec, and the usual spectral losses."""
+
+# flake8: noqa
+from .balancer import Balancer
+from .sisnr import SISNR
+from .stftloss import (
+    LogSTFTMagnitudeLoss,
+    MRSTFTLoss,
+    SpectralConvergenceLoss,
+    STFTLoss
+)
+from .specloss import (
+    MelSpectrogramL1Loss,
+    MultiScaleMelSpectrogramLoss,
+)
+
+
+
+

Sub-modules

+
+
audiocraft.losses.balancer
+
+
+
+
audiocraft.losses.sisnr
+
+
+
+
audiocraft.losses.specloss
+
+
+
+
audiocraft.losses.stftloss
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/losses/sisnr.html b/api_docs/audiocraft/losses/sisnr.html new file mode 100644 index 00000000..b1e2be25 --- /dev/null +++ b/api_docs/audiocraft/losses/sisnr.html @@ -0,0 +1,332 @@ + + + + + + +audiocraft.losses.sisnr API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.losses.sisnr

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
+    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
+    with K the kernel size, by extracting frames with the given stride.
+    This will pad the input so that `F = ceil(T / K)`.
+    see https://github.com/pytorch/pytorch/issues/60466
+    """
+    *shape, length = a.shape
+    n_frames = math.ceil(length / stride)
+    tgt_length = (n_frames - 1) * stride + kernel_size
+    a = F.pad(a, (0, tgt_length - length))
+    strides = list(a.stride())
+    assert strides[-1] == 1, "data should be contiguous"
+    strides = strides[:-1] + [stride, 1]
+    return a.as_strided([*shape, n_frames, kernel_size], strides)
+
+
+def _center(x: torch.Tensor) -> torch.Tensor:
+    return x - x.mean(-1, True)
+
+
+def _norm2(x: torch.Tensor) -> torch.Tensor:
+    return x.pow(2).sum(-1, True)
+
+
+class SISNR(nn.Module):
+    """SISNR loss.
+
+    Input should be [B, C, T], output is scalar.
+
+    ..Warning:: This function returns the opposite of the SI-SNR (e.g. `-1 * regular_SI_SNR`).
+        Consequently, lower scores are better in terms of reconstruction quality,
+        in particular, it should be negative if training goes well. This done this way so
+        that this module can also be used as a loss function for training model.
+
+    Args:
+        sample_rate (int): Sample rate.
+        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
+            entire audio only.
+        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
+        epsilon (float): Epsilon value for numerical stability.
+    """
+    def __init__(
+        self,
+        sample_rate: int = 16000,
+        segment: tp.Optional[float] = 20,
+        overlap: float = 0.5,
+        epsilon: float = torch.finfo(torch.float32).eps,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.segment = segment
+        self.overlap = overlap
+        self.epsilon = epsilon
+
+    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
+        B, C, T = ref_sig.shape
+        assert ref_sig.shape == out_sig.shape
+
+        if self.segment is None:
+            frame = T
+            stride = T
+        else:
+            frame = int(self.segment * self.sample_rate)
+            stride = int(frame * (1 - self.overlap))
+
+        epsilon = self.epsilon * frame  # make epsilon prop to frame size.
+
+        gt = _unfold(ref_sig, frame, stride)
+        est = _unfold(out_sig, frame, stride)
+        if self.segment is None:
+            assert gt.shape[-1] == 1
+
+        gt = _center(gt)
+        est = _center(est)
+        dot = torch.einsum("bcft,bcft->bcf", gt, est)
+
+        proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
+        noise = est - proj
+
+        sisnr = 10 * (
+            torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
+        )
+        return -1 * sisnr[..., 0].mean()
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class SISNR +(sample_rate: int = 16000, segment: Optional[float] = 20, overlap: float = 0.5, epsilon: float = 1.1920928955078125e-07) +
+
+

SISNR loss.

+

Input should be [B, C, T], output is scalar.

+
+

Warning: This function returns the opposite of the SI-SNR (e.g. -1 * regular_SI_SNR).

+

Consequently, lower scores are better in terms of reconstruction quality, +in particular, it should be negative if training goes well. This done this way so +that this module can also be used as a loss function for training model.

+
+

Args

+
+
sample_rate : int
+
Sample rate.
+
segment : float or None
+
Evaluate on chunks of that many seconds. If None, evaluate on +entire audio only.
+
overlap : float
+
Overlap between chunks, i.e. 0.5 = 50 % overlap.
+
epsilon : float
+
Epsilon value for numerical stability.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class SISNR(nn.Module):
+    """SISNR loss.
+
+    Input should be [B, C, T], output is scalar.
+
+    ..Warning:: This function returns the opposite of the SI-SNR (e.g. `-1 * regular_SI_SNR`).
+        Consequently, lower scores are better in terms of reconstruction quality,
+        in particular, it should be negative if training goes well. This done this way so
+        that this module can also be used as a loss function for training model.
+
+    Args:
+        sample_rate (int): Sample rate.
+        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
+            entire audio only.
+        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
+        epsilon (float): Epsilon value for numerical stability.
+    """
+    def __init__(
+        self,
+        sample_rate: int = 16000,
+        segment: tp.Optional[float] = 20,
+        overlap: float = 0.5,
+        epsilon: float = torch.finfo(torch.float32).eps,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.segment = segment
+        self.overlap = overlap
+        self.epsilon = epsilon
+
+    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
+        B, C, T = ref_sig.shape
+        assert ref_sig.shape == out_sig.shape
+
+        if self.segment is None:
+            frame = T
+            stride = T
+        else:
+            frame = int(self.segment * self.sample_rate)
+            stride = int(frame * (1 - self.overlap))
+
+        epsilon = self.epsilon * frame  # make epsilon prop to frame size.
+
+        gt = _unfold(ref_sig, frame, stride)
+        est = _unfold(out_sig, frame, stride)
+        if self.segment is None:
+            assert gt.shape[-1] == 1
+
+        gt = _center(gt)
+        est = _center(est)
+        dot = torch.einsum("bcft,bcft->bcf", gt, est)
+
+        proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
+        noise = est - proj
+
+        sisnr = 10 * (
+            torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
+        )
+        return -1 * sisnr[..., 0].mean()
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) ‑> torch.Tensor +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
+    B, C, T = ref_sig.shape
+    assert ref_sig.shape == out_sig.shape
+
+    if self.segment is None:
+        frame = T
+        stride = T
+    else:
+        frame = int(self.segment * self.sample_rate)
+        stride = int(frame * (1 - self.overlap))
+
+    epsilon = self.epsilon * frame  # make epsilon prop to frame size.
+
+    gt = _unfold(ref_sig, frame, stride)
+    est = _unfold(out_sig, frame, stride)
+    if self.segment is None:
+        assert gt.shape[-1] == 1
+
+    gt = _center(gt)
+    est = _center(est)
+    dot = torch.einsum("bcft,bcft->bcf", gt, est)
+
+    proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
+    noise = est - proj
+
+    sisnr = 10 * (
+        torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
+    )
+    return -1 * sisnr[..., 0].mean()
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/losses/specloss.html b/api_docs/audiocraft/losses/specloss.html new file mode 100644 index 00000000..b347f9c2 --- /dev/null +++ b/api_docs/audiocraft/losses/specloss.html @@ -0,0 +1,634 @@ + + + + + + +audiocraft.losses.specloss API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.losses.specloss

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import numpy as np
+from torchaudio.transforms import MelSpectrogram
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ..modules import pad_for_conv1d
+
+
+class MelSpectrogramWrapper(nn.Module):
+    """Wrapper around MelSpectrogram torchaudio transform providing proper padding
+    and additional post-processing including log scaling.
+
+    Args:
+        n_mels (int): Number of mel bins.
+        n_fft (int): Number of fft.
+        hop_length (int): Hop size.
+        win_length (int): Window length.
+        n_mels (int): Number of mel bins.
+        sample_rate (int): Sample rate.
+        f_min (float or None): Minimum frequency.
+        f_max (float or None): Maximum frequency.
+        log (bool): Whether to scale with log.
+        normalized (bool): Whether to normalize the melspectrogram.
+        floor_level (float): Floor level based on human perception (default=1e-5).
+    """
+    def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
+                 n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
+                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
+        super().__init__()
+        self.n_fft = n_fft
+        hop_length = int(hop_length)
+        self.hop_length = hop_length
+        self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
+                                            win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized,
+                                            window_fn=torch.hann_window, center=False)
+        self.floor_level = floor_level
+        self.log = log
+
+    def forward(self, x):
+        p = int((self.n_fft - self.hop_length) // 2)
+        if len(x.shape) == 2:
+            x = x.unsqueeze(1)
+        x = F.pad(x, (p, p), "reflect")
+        # Make sure that all the frames are full.
+        # The combination of `pad_for_conv1d` and the above padding
+        # will make the output of size ceil(T / hop).
+        x = pad_for_conv1d(x, self.n_fft, self.hop_length)
+        self.mel_transform.to(x.device)
+        mel_spec = self.mel_transform(x)
+        B, C, freqs, frame = mel_spec.shape
+        if self.log:
+            mel_spec = torch.log10(self.floor_level + mel_spec)
+        return mel_spec.reshape(B, C * freqs, frame)
+
+
+class MelSpectrogramL1Loss(torch.nn.Module):
+    """L1 Loss on MelSpectrogram.
+
+    Args:
+        sample_rate (int): Sample rate.
+        n_fft (int): Number of fft.
+        hop_length (int): Hop size.
+        win_length (int): Window length.
+        n_mels (int): Number of mel bins.
+        f_min (float or None): Minimum frequency.
+        f_max (float or None): Maximum frequency.
+        log (bool): Whether to scale with log.
+        normalized (bool): Whether to normalize the melspectrogram.
+        floor_level (float): Floor level value based on human perception (default=1e-5).
+    """
+    def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
+                 n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
+                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
+        super().__init__()
+        self.l1 = torch.nn.L1Loss()
+        self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
+                                             n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
+                                             log=log, normalized=normalized, floor_level=floor_level)
+
+    def forward(self, x, y):
+        self.melspec.to(x.device)
+        s_x = self.melspec(x)
+        s_y = self.melspec(y)
+        return self.l1(s_x, s_y)
+
+
+class MultiScaleMelSpectrogramLoss(nn.Module):
+    """Multi-Scale spectrogram loss (msspec).
+
+    Args:
+        sample_rate (int): Sample rate.
+        range_start (int): Power of 2 to use for the first scale.
+        range_stop (int): Power of 2 to use for the last scale.
+        n_mels (int): Number of mel bins.
+        f_min (float): Minimum frequency.
+        f_max (float or None): Maximum frequency.
+        normalized (bool): Whether to normalize the melspectrogram.
+        alphas (bool): Whether to use alphas as coefficients or not.
+        floor_level (float): Floor level value based on human perception (default=1e-5).
+    """
+    def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
+                 n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
+                 normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
+        super().__init__()
+        l1s = list()
+        l2s = list()
+        self.alphas = list()
+        self.total = 0
+        self.normalized = normalized
+        for i in range(range_start, range_end):
+            l1s.append(
+                MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
+                                      n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
+                                      log=False, normalized=normalized, floor_level=floor_level))
+            l2s.append(
+                MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
+                                      n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
+                                      log=True, normalized=normalized, floor_level=floor_level))
+            if alphas:
+                self.alphas.append(np.sqrt(2 ** i - 1))
+            else:
+                self.alphas.append(1)
+            self.total += self.alphas[-1] + 1
+
+        self.l1s = nn.ModuleList(l1s)
+        self.l2s = nn.ModuleList(l2s)
+
+    def forward(self, x, y):
+        loss = 0.0
+        self.l1s.to(x.device)
+        self.l2s.to(x.device)
+        for i in range(len(self.alphas)):
+            s_x_1 = self.l1s[i](x)
+            s_y_1 = self.l1s[i](y)
+            s_x_2 = self.l2s[i](x)
+            s_y_2 = self.l2s[i](y)
+            loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2)
+        if self.normalized:
+            loss = loss / self.total
+        return loss
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class MelSpectrogramL1Loss +(sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, n_mels: int = 80, f_min: float = 0.0, f_max: Optional[float] = None, log: bool = True, normalized: bool = False, floor_level: float = 1e-05) +
+
+

L1 Loss on MelSpectrogram.

+

Args

+
+
sample_rate : int
+
Sample rate.
+
n_fft : int
+
Number of fft.
+
hop_length : int
+
Hop size.
+
win_length : int
+
Window length.
+
n_mels : int
+
Number of mel bins.
+
f_min : float or None
+
Minimum frequency.
+
f_max : float or None
+
Maximum frequency.
+
log : bool
+
Whether to scale with log.
+
normalized : bool
+
Whether to normalize the melspectrogram.
+
floor_level : float
+
Floor level value based on human perception (default=1e-5).
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class MelSpectrogramL1Loss(torch.nn.Module):
+    """L1 Loss on MelSpectrogram.
+
+    Args:
+        sample_rate (int): Sample rate.
+        n_fft (int): Number of fft.
+        hop_length (int): Hop size.
+        win_length (int): Window length.
+        n_mels (int): Number of mel bins.
+        f_min (float or None): Minimum frequency.
+        f_max (float or None): Maximum frequency.
+        log (bool): Whether to scale with log.
+        normalized (bool): Whether to normalize the melspectrogram.
+        floor_level (float): Floor level value based on human perception (default=1e-5).
+    """
+    def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
+                 n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
+                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
+        super().__init__()
+        self.l1 = torch.nn.L1Loss()
+        self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
+                                             n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
+                                             log=log, normalized=normalized, floor_level=floor_level)
+
+    def forward(self, x, y):
+        self.melspec.to(x.device)
+        s_x = self.melspec(x)
+        s_y = self.melspec(y)
+        return self.l1(s_x, s_y)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x, y) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x, y):
+    self.melspec.to(x.device)
+    s_x = self.melspec(x)
+    s_y = self.melspec(y)
+    return self.l1(s_x, s_y)
+
+
+
+
+
+class MelSpectrogramWrapper +(n_fft: int = 1024, hop_length: int = 256, win_length: Optional[int] = None, n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: Optional[float] = None, log: bool = True, normalized: bool = False, floor_level: float = 1e-05) +
+
+

Wrapper around MelSpectrogram torchaudio transform providing proper padding +and additional post-processing including log scaling.

+

Args

+
+
n_mels : int
+
Number of mel bins.
+
n_fft : int
+
Number of fft.
+
hop_length : int
+
Hop size.
+
win_length : int
+
Window length.
+
n_mels : int
+
Number of mel bins.
+
sample_rate : int
+
Sample rate.
+
f_min : float or None
+
Minimum frequency.
+
f_max : float or None
+
Maximum frequency.
+
log : bool
+
Whether to scale with log.
+
normalized : bool
+
Whether to normalize the melspectrogram.
+
floor_level : float
+
Floor level based on human perception (default=1e-5).
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class MelSpectrogramWrapper(nn.Module):
+    """Wrapper around MelSpectrogram torchaudio transform providing proper padding
+    and additional post-processing including log scaling.
+
+    Args:
+        n_mels (int): Number of mel bins.
+        n_fft (int): Number of fft.
+        hop_length (int): Hop size.
+        win_length (int): Window length.
+        n_mels (int): Number of mel bins.
+        sample_rate (int): Sample rate.
+        f_min (float or None): Minimum frequency.
+        f_max (float or None): Maximum frequency.
+        log (bool): Whether to scale with log.
+        normalized (bool): Whether to normalize the melspectrogram.
+        floor_level (float): Floor level based on human perception (default=1e-5).
+    """
+    def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
+                 n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
+                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
+        super().__init__()
+        self.n_fft = n_fft
+        hop_length = int(hop_length)
+        self.hop_length = hop_length
+        self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
+                                            win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized,
+                                            window_fn=torch.hann_window, center=False)
+        self.floor_level = floor_level
+        self.log = log
+
+    def forward(self, x):
+        p = int((self.n_fft - self.hop_length) // 2)
+        if len(x.shape) == 2:
+            x = x.unsqueeze(1)
+        x = F.pad(x, (p, p), "reflect")
+        # Make sure that all the frames are full.
+        # The combination of `pad_for_conv1d` and the above padding
+        # will make the output of size ceil(T / hop).
+        x = pad_for_conv1d(x, self.n_fft, self.hop_length)
+        self.mel_transform.to(x.device)
+        mel_spec = self.mel_transform(x)
+        B, C, freqs, frame = mel_spec.shape
+        if self.log:
+            mel_spec = torch.log10(self.floor_level + mel_spec)
+        return mel_spec.reshape(B, C * freqs, frame)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    p = int((self.n_fft - self.hop_length) // 2)
+    if len(x.shape) == 2:
+        x = x.unsqueeze(1)
+    x = F.pad(x, (p, p), "reflect")
+    # Make sure that all the frames are full.
+    # The combination of `pad_for_conv1d` and the above padding
+    # will make the output of size ceil(T / hop).
+    x = pad_for_conv1d(x, self.n_fft, self.hop_length)
+    self.mel_transform.to(x.device)
+    mel_spec = self.mel_transform(x)
+    B, C, freqs, frame = mel_spec.shape
+    if self.log:
+        mel_spec = torch.log10(self.floor_level + mel_spec)
+    return mel_spec.reshape(B, C * freqs, frame)
+
+
+
+
+
+class MultiScaleMelSpectrogramLoss +(sample_rate: int, range_start: int = 6, range_end: int = 11, n_mels: int = 64, f_min: float = 0.0, f_max: Optional[float] = None, normalized: bool = False, alphas: bool = True, floor_level: float = 1e-05) +
+
+

Multi-Scale spectrogram loss (msspec).

+

Args

+
+
sample_rate : int
+
Sample rate.
+
range_start : int
+
Power of 2 to use for the first scale.
+
range_stop : int
+
Power of 2 to use for the last scale.
+
n_mels : int
+
Number of mel bins.
+
f_min : float
+
Minimum frequency.
+
f_max : float or None
+
Maximum frequency.
+
normalized : bool
+
Whether to normalize the melspectrogram.
+
alphas : bool
+
Whether to use alphas as coefficients or not.
+
floor_level : float
+
Floor level value based on human perception (default=1e-5).
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class MultiScaleMelSpectrogramLoss(nn.Module):
+    """Multi-Scale spectrogram loss (msspec).
+
+    Args:
+        sample_rate (int): Sample rate.
+        range_start (int): Power of 2 to use for the first scale.
+        range_stop (int): Power of 2 to use for the last scale.
+        n_mels (int): Number of mel bins.
+        f_min (float): Minimum frequency.
+        f_max (float or None): Maximum frequency.
+        normalized (bool): Whether to normalize the melspectrogram.
+        alphas (bool): Whether to use alphas as coefficients or not.
+        floor_level (float): Floor level value based on human perception (default=1e-5).
+    """
+    def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
+                 n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
+                 normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
+        super().__init__()
+        l1s = list()
+        l2s = list()
+        self.alphas = list()
+        self.total = 0
+        self.normalized = normalized
+        for i in range(range_start, range_end):
+            l1s.append(
+                MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
+                                      n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
+                                      log=False, normalized=normalized, floor_level=floor_level))
+            l2s.append(
+                MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
+                                      n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
+                                      log=True, normalized=normalized, floor_level=floor_level))
+            if alphas:
+                self.alphas.append(np.sqrt(2 ** i - 1))
+            else:
+                self.alphas.append(1)
+            self.total += self.alphas[-1] + 1
+
+        self.l1s = nn.ModuleList(l1s)
+        self.l2s = nn.ModuleList(l2s)
+
+    def forward(self, x, y):
+        loss = 0.0
+        self.l1s.to(x.device)
+        self.l2s.to(x.device)
+        for i in range(len(self.alphas)):
+            s_x_1 = self.l1s[i](x)
+            s_y_1 = self.l1s[i](y)
+            s_x_2 = self.l2s[i](x)
+            s_y_2 = self.l2s[i](y)
+            loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2)
+        if self.normalized:
+            loss = loss / self.total
+        return loss
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x, y) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x, y):
+    loss = 0.0
+    self.l1s.to(x.device)
+    self.l2s.to(x.device)
+    for i in range(len(self.alphas)):
+        s_x_1 = self.l1s[i](x)
+        s_y_1 = self.l1s[i](y)
+        s_x_2 = self.l2s[i](x)
+        s_y_2 = self.l2s[i](y)
+        loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2)
+    if self.normalized:
+        loss = loss / self.total
+    return loss
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/losses/stftloss.html b/api_docs/audiocraft/losses/stftloss.html new file mode 100644 index 00000000..2320e3fb --- /dev/null +++ b/api_docs/audiocraft/losses/stftloss.html @@ -0,0 +1,890 @@ + + + + + + +audiocraft.losses.stftloss API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.losses.stftloss

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# Adapted from MIT code under the original license
+# Copyright 2019 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+import typing as tp
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+# TODO: Replace with torchaudio.STFT?
+def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
+          window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor:
+    """Perform STFT and convert to magnitude spectrogram.
+
+    Args:
+        x: Input signal tensor (B, C, T).
+        fft_size (int): FFT size.
+        hop_length (int): Hop size.
+        win_length (int): Window length.
+        window (torch.Tensor or None): Window function type.
+        normalized (bool): Whether to normalize the STFT or not.
+
+    Returns:
+        torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1).
+    """
+    B, C, T = x.shape
+    x_stft = torch.stft(
+        x.view(-1, T), fft_size, hop_length, win_length, window,
+        normalized=normalized, return_complex=True,
+    )
+    x_stft = x_stft.view(B, C, *x_stft.shape[1:])
+    real = x_stft.real
+    imag = x_stft.imag
+
+    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+    return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
+
+
+class SpectralConvergenceLoss(nn.Module):
+    """Spectral convergence loss.
+    """
+    def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
+        super().__init__()
+        self.epsilon = epsilon
+
+    def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
+        """Calculate forward propagation.
+
+        Args:
+            x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+        Returns:
+            torch.Tensor: Spectral convergence loss value.
+        """
+        return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon)
+
+
+class LogSTFTMagnitudeLoss(nn.Module):
+    """Log STFT magnitude loss.
+
+    Args:
+        epsilon (float): Epsilon value for numerical stability.
+    """
+    def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
+        super().__init__()
+        self.epsilon = epsilon
+
+    def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
+        """Calculate forward propagation.
+
+        Args:
+            x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+        Returns:
+            torch.Tensor: Log STFT magnitude loss value.
+        """
+        return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag))
+
+
+class STFTLosses(nn.Module):
+    """STFT losses.
+
+    Args:
+        n_fft (int): Size of FFT.
+        hop_length (int): Hop length.
+        win_length (int): Window length.
+        window (str): Window function type.
+        normalized (bool): Whether to use normalized STFT or not.
+        epsilon (float): Epsilon for numerical stability.
+    """
+    def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
+                 window: str = "hann_window", normalized: bool = False,
+                 epsilon: float = torch.finfo(torch.float32).eps):
+        super().__init__()
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.normalized = normalized
+        self.register_buffer("window", getattr(torch, window)(win_length))
+        self.spectral_convergenge_loss = SpectralConvergenceLoss(epsilon)
+        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon)
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Calculate forward propagation.
+
+        Args:
+            x (torch.Tensor): Predicted signal (B, T).
+            y (torch.Tensor): Groundtruth signal (B, T).
+        Returns:
+            torch.Tensor: Spectral convergence loss value.
+            torch.Tensor: Log STFT magnitude loss value.
+        """
+        x_mag = _stft(x, self.n_fft, self.hop_length,
+                      self.win_length, self.window, self.normalized)  # type: ignore
+        y_mag = _stft(y, self.n_fft, self.hop_length,
+                      self.win_length, self.window, self.normalized)  # type: ignore
+        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+        return sc_loss, mag_loss
+
+
+class STFTLoss(nn.Module):
+    """Single Resolution STFT loss.
+
+    Args:
+        n_fft (int): Nb of FFT.
+        hop_length (int): Hop length.
+        win_length (int): Window length.
+        window (str): Window function type.
+        normalized (bool): Whether to use normalized STFT or not.
+        epsilon (float): Epsilon for numerical stability.
+        factor_sc (float): Coefficient for the spectral loss.
+        factor_mag (float): Coefficient for the magnitude loss.
+    """
+    def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
+                 window: str = "hann_window", normalized: bool = False,
+                 factor_sc: float = 0.1, factor_mag: float = 0.1,
+                 epsilon: float = torch.finfo(torch.float32).eps):
+        super().__init__()
+        self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon)
+        self.factor_sc = factor_sc
+        self.factor_mag = factor_mag
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Calculate forward propagation.
+
+        Args:
+            x (torch.Tensor): Predicted signal (B, T).
+            y (torch.Tensor): Groundtruth signal (B, T).
+        Returns:
+            torch.Tensor: Single resolution STFT loss.
+        """
+        sc_loss, mag_loss = self.loss(x, y)
+        return self.factor_sc * sc_loss + self.factor_mag * mag_loss
+
+
+class MRSTFTLoss(nn.Module):
+    """Multi resolution STFT loss.
+
+    Args:
+        n_ffts (Sequence[int]): Sequence of FFT sizes.
+        hop_lengths (Sequence[int]): Sequence of hop sizes.
+        win_lengths (Sequence[int]): Sequence of window lengths.
+        window (str): Window function type.
+        factor_sc (float): Coefficient for the spectral loss.
+        factor_mag (float): Coefficient for the magnitude loss.
+        normalized (bool): Whether to use normalized STFT or not.
+        epsilon (float): Epsilon for numerical stability.
+    """
+    def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_lengths: tp.Sequence[int] = [120, 240, 50],
+                 win_lengths: tp.Sequence[int] = [600, 1200, 240], window: str = "hann_window",
+                 factor_sc: float = 0.1, factor_mag: float = 0.1,
+                 normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps):
+        super().__init__()
+        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+        self.stft_losses = torch.nn.ModuleList()
+        for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths):
+            self.stft_losses += [STFTLosses(fs, ss, wl, window, normalized, epsilon)]
+        self.factor_sc = factor_sc
+        self.factor_mag = factor_mag
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """Calculate forward propagation.
+
+        Args:
+            x (torch.Tensor): Predicted signal (B, T).
+            y (torch.Tensor): Groundtruth signal (B, T).
+        Returns:
+            torch.Tensor: Multi resolution STFT loss.
+        """
+        sc_loss = torch.Tensor([0.0])
+        mag_loss = torch.Tensor([0.0])
+        for f in self.stft_losses:
+            sc_l, mag_l = f(x, y)
+            sc_loss += sc_l
+            mag_loss += mag_l
+        sc_loss /= len(self.stft_losses)
+        mag_loss /= len(self.stft_losses)
+
+        return self.factor_sc * sc_loss + self.factor_mag * mag_loss
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class LogSTFTMagnitudeLoss +(epsilon: float = 1.1920928955078125e-07) +
+
+

Log STFT magnitude loss.

+

Args

+
+
epsilon : float
+
Epsilon value for numerical stability.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class LogSTFTMagnitudeLoss(nn.Module):
+    """Log STFT magnitude loss.
+
+    Args:
+        epsilon (float): Epsilon value for numerical stability.
+    """
+    def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
+        super().__init__()
+        self.epsilon = epsilon
+
+    def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
+        """Calculate forward propagation.
+
+        Args:
+            x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+        Returns:
+            torch.Tensor: Log STFT magnitude loss value.
+        """
+        return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag))
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor) ‑> Callable[..., Any] +
+
+

Calculate forward propagation.

+

Args

+
+
x_mag : torch.Tensor
+
Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+
y_mag : torch.Tensor
+
Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+
+

Returns

+
+
torch.Tensor
+
Log STFT magnitude loss value.
+
+
+ +Expand source code + +
def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
+    """Calculate forward propagation.
+
+    Args:
+        x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+        y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+    Returns:
+        torch.Tensor: Log STFT magnitude loss value.
+    """
+    return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag))
+
+
+
+
+
+class MRSTFTLoss +(n_ffts: Sequence[int] = [1024, 2048, 512], hop_lengths: Sequence[int] = [120, 240, 50], win_lengths: Sequence[int] = [600, 1200, 240], window: str = 'hann_window', factor_sc: float = 0.1, factor_mag: float = 0.1, normalized: bool = False, epsilon: float = 1.1920928955078125e-07) +
+
+

Multi resolution STFT loss.

+

Args

+
+
n_ffts : Sequence[int]
+
Sequence of FFT sizes.
+
hop_lengths : Sequence[int]
+
Sequence of hop sizes.
+
win_lengths : Sequence[int]
+
Sequence of window lengths.
+
window : str
+
Window function type.
+
factor_sc : float
+
Coefficient for the spectral loss.
+
factor_mag : float
+
Coefficient for the magnitude loss.
+
normalized : bool
+
Whether to use normalized STFT or not.
+
epsilon : float
+
Epsilon for numerical stability.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class MRSTFTLoss(nn.Module):
+    """Multi resolution STFT loss.
+
+    Args:
+        n_ffts (Sequence[int]): Sequence of FFT sizes.
+        hop_lengths (Sequence[int]): Sequence of hop sizes.
+        win_lengths (Sequence[int]): Sequence of window lengths.
+        window (str): Window function type.
+        factor_sc (float): Coefficient for the spectral loss.
+        factor_mag (float): Coefficient for the magnitude loss.
+        normalized (bool): Whether to use normalized STFT or not.
+        epsilon (float): Epsilon for numerical stability.
+    """
+    def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_lengths: tp.Sequence[int] = [120, 240, 50],
+                 win_lengths: tp.Sequence[int] = [600, 1200, 240], window: str = "hann_window",
+                 factor_sc: float = 0.1, factor_mag: float = 0.1,
+                 normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps):
+        super().__init__()
+        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+        self.stft_losses = torch.nn.ModuleList()
+        for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths):
+            self.stft_losses += [STFTLosses(fs, ss, wl, window, normalized, epsilon)]
+        self.factor_sc = factor_sc
+        self.factor_mag = factor_mag
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """Calculate forward propagation.
+
+        Args:
+            x (torch.Tensor): Predicted signal (B, T).
+            y (torch.Tensor): Groundtruth signal (B, T).
+        Returns:
+            torch.Tensor: Multi resolution STFT loss.
+        """
+        sc_loss = torch.Tensor([0.0])
+        mag_loss = torch.Tensor([0.0])
+        for f in self.stft_losses:
+            sc_l, mag_l = f(x, y)
+            sc_loss += sc_l
+            mag_loss += mag_l
+        sc_loss /= len(self.stft_losses)
+        mag_loss /= len(self.stft_losses)
+
+        return self.factor_sc * sc_loss + self.factor_mag * mag_loss
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor, y: torch.Tensor) ‑> torch.Tensor +
+
+

Calculate forward propagation.

+

Args

+
+
x : torch.Tensor
+
Predicted signal (B, T).
+
y : torch.Tensor
+
Groundtruth signal (B, T).
+
+

Returns

+
+
torch.Tensor
+
Multi resolution STFT loss.
+
+
+ +Expand source code + +
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Calculate forward propagation.
+
+    Args:
+        x (torch.Tensor): Predicted signal (B, T).
+        y (torch.Tensor): Groundtruth signal (B, T).
+    Returns:
+        torch.Tensor: Multi resolution STFT loss.
+    """
+    sc_loss = torch.Tensor([0.0])
+    mag_loss = torch.Tensor([0.0])
+    for f in self.stft_losses:
+        sc_l, mag_l = f(x, y)
+        sc_loss += sc_l
+        mag_loss += mag_l
+    sc_loss /= len(self.stft_losses)
+    mag_loss /= len(self.stft_losses)
+
+    return self.factor_sc * sc_loss + self.factor_mag * mag_loss
+
+
+
+
+
+class STFTLoss +(n_fft: int = 1024, hop_length: int = 120, win_length: int = 600, window: str = 'hann_window', normalized: bool = False, factor_sc: float = 0.1, factor_mag: float = 0.1, epsilon: float = 1.1920928955078125e-07) +
+
+

Single Resolution STFT loss.

+

Args

+
+
n_fft : int
+
Nb of FFT.
+
hop_length : int
+
Hop length.
+
win_length : int
+
Window length.
+
window : str
+
Window function type.
+
normalized : bool
+
Whether to use normalized STFT or not.
+
epsilon : float
+
Epsilon for numerical stability.
+
factor_sc : float
+
Coefficient for the spectral loss.
+
factor_mag : float
+
Coefficient for the magnitude loss.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class STFTLoss(nn.Module):
+    """Single Resolution STFT loss.
+
+    Args:
+        n_fft (int): Nb of FFT.
+        hop_length (int): Hop length.
+        win_length (int): Window length.
+        window (str): Window function type.
+        normalized (bool): Whether to use normalized STFT or not.
+        epsilon (float): Epsilon for numerical stability.
+        factor_sc (float): Coefficient for the spectral loss.
+        factor_mag (float): Coefficient for the magnitude loss.
+    """
+    def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
+                 window: str = "hann_window", normalized: bool = False,
+                 factor_sc: float = 0.1, factor_mag: float = 0.1,
+                 epsilon: float = torch.finfo(torch.float32).eps):
+        super().__init__()
+        self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon)
+        self.factor_sc = factor_sc
+        self.factor_mag = factor_mag
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Calculate forward propagation.
+
+        Args:
+            x (torch.Tensor): Predicted signal (B, T).
+            y (torch.Tensor): Groundtruth signal (B, T).
+        Returns:
+            torch.Tensor: Single resolution STFT loss.
+        """
+        sc_loss, mag_loss = self.loss(x, y)
+        return self.factor_sc * sc_loss + self.factor_mag * mag_loss
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor, y: torch.Tensor) ‑> Tuple[torch.Tensor, torch.Tensor] +
+
+

Calculate forward propagation.

+

Args

+
+
x : torch.Tensor
+
Predicted signal (B, T).
+
y : torch.Tensor
+
Groundtruth signal (B, T).
+
+

Returns

+
+
torch.Tensor
+
Single resolution STFT loss.
+
+
+ +Expand source code + +
def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """Calculate forward propagation.
+
+    Args:
+        x (torch.Tensor): Predicted signal (B, T).
+        y (torch.Tensor): Groundtruth signal (B, T).
+    Returns:
+        torch.Tensor: Single resolution STFT loss.
+    """
+    sc_loss, mag_loss = self.loss(x, y)
+    return self.factor_sc * sc_loss + self.factor_mag * mag_loss
+
+
+
+
+
+class STFTLosses +(n_fft: int = 1024, hop_length: int = 120, win_length: int = 600, window: str = 'hann_window', normalized: bool = False, epsilon: float = 1.1920928955078125e-07) +
+
+

STFT losses.

+

Args

+
+
n_fft : int
+
Size of FFT.
+
hop_length : int
+
Hop length.
+
win_length : int
+
Window length.
+
window : str
+
Window function type.
+
normalized : bool
+
Whether to use normalized STFT or not.
+
epsilon : float
+
Epsilon for numerical stability.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class STFTLosses(nn.Module):
+    """STFT losses.
+
+    Args:
+        n_fft (int): Size of FFT.
+        hop_length (int): Hop length.
+        win_length (int): Window length.
+        window (str): Window function type.
+        normalized (bool): Whether to use normalized STFT or not.
+        epsilon (float): Epsilon for numerical stability.
+    """
+    def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
+                 window: str = "hann_window", normalized: bool = False,
+                 epsilon: float = torch.finfo(torch.float32).eps):
+        super().__init__()
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.normalized = normalized
+        self.register_buffer("window", getattr(torch, window)(win_length))
+        self.spectral_convergenge_loss = SpectralConvergenceLoss(epsilon)
+        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon)
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Calculate forward propagation.
+
+        Args:
+            x (torch.Tensor): Predicted signal (B, T).
+            y (torch.Tensor): Groundtruth signal (B, T).
+        Returns:
+            torch.Tensor: Spectral convergence loss value.
+            torch.Tensor: Log STFT magnitude loss value.
+        """
+        x_mag = _stft(x, self.n_fft, self.hop_length,
+                      self.win_length, self.window, self.normalized)  # type: ignore
+        y_mag = _stft(y, self.n_fft, self.hop_length,
+                      self.win_length, self.window, self.normalized)  # type: ignore
+        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+        return sc_loss, mag_loss
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor, y: torch.Tensor) ‑> Tuple[torch.Tensor, torch.Tensor] +
+
+

Calculate forward propagation.

+

Args

+
+
x : torch.Tensor
+
Predicted signal (B, T).
+
y : torch.Tensor
+
Groundtruth signal (B, T).
+
+

Returns

+
+
torch.Tensor
+
Spectral convergence loss value.
+
torch.Tensor
+
Log STFT magnitude loss value.
+
+
+ +Expand source code + +
def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """Calculate forward propagation.
+
+    Args:
+        x (torch.Tensor): Predicted signal (B, T).
+        y (torch.Tensor): Groundtruth signal (B, T).
+    Returns:
+        torch.Tensor: Spectral convergence loss value.
+        torch.Tensor: Log STFT magnitude loss value.
+    """
+    x_mag = _stft(x, self.n_fft, self.hop_length,
+                  self.win_length, self.window, self.normalized)  # type: ignore
+    y_mag = _stft(y, self.n_fft, self.hop_length,
+                  self.win_length, self.window, self.normalized)  # type: ignore
+    sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+    mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+    return sc_loss, mag_loss
+
+
+
+
+
+class SpectralConvergenceLoss +(epsilon: float = 1.1920928955078125e-07) +
+
+

Spectral convergence loss.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class SpectralConvergenceLoss(nn.Module):
+    """Spectral convergence loss.
+    """
+    def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
+        super().__init__()
+        self.epsilon = epsilon
+
+    def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
+        """Calculate forward propagation.
+
+        Args:
+            x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+        Returns:
+            torch.Tensor: Spectral convergence loss value.
+        """
+        return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor) ‑> Callable[..., Any] +
+
+

Calculate forward propagation.

+

Args

+
+
x_mag
+
Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+
y_mag
+
Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+
+

Returns

+
+
torch.Tensor
+
Spectral convergence loss value.
+
+
+ +Expand source code + +
def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
+    """Calculate forward propagation.
+
+    Args:
+        x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+        y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+    Returns:
+        torch.Tensor: Spectral convergence loss value.
+    """
+    return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/metrics/chroma_cosinesim.html b/api_docs/audiocraft/metrics/chroma_cosinesim.html new file mode 100644 index 00000000..9b5cab5d --- /dev/null +++ b/api_docs/audiocraft/metrics/chroma_cosinesim.html @@ -0,0 +1,330 @@ + + + + + + +audiocraft.metrics.chroma_cosinesim API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.metrics.chroma_cosinesim

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torchmetrics
+
+from ..data.audio_utils import convert_audio
+from ..modules.chroma import ChromaExtractor
+
+
+class ChromaCosineSimilarityMetric(torchmetrics.Metric):
+    """Chroma cosine similarity metric.
+
+    This metric extracts a chromagram for a reference waveform and
+    a generated waveform and compares each frame using the cosine similarity
+    function. The output is the mean cosine similarity.
+
+    Args:
+        sample_rate (int): Sample rate used by the chroma extractor.
+        n_chroma (int): Number of chroma used by the chroma extractor.
+        radix2_exp (int): Exponent for the chroma extractor.
+        argmax (bool): Whether the chroma extractor uses argmax.
+        eps (float): Epsilon for cosine similarity computation.
+    """
+    def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
+        super().__init__()
+        self.chroma_sample_rate = sample_rate
+        self.n_chroma = n_chroma
+        self.eps = eps
+        self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
+                                                radix2_exp=radix2_exp, argmax=argmax)
+        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
+
+    def update(self, preds: torch.Tensor, targets: torch.Tensor,
+               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+        """Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
+        if preds.size(0) == 0:
+            return
+
+        assert preds.shape == targets.shape, (
+            f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
+        assert preds.size(0) == sizes.size(0), (
+            f"Number of items in preds ({preds.shape}) mismatch ",
+            f"with sizes ({sizes.shape})")
+        assert preds.size(0) == sample_rates.size(0), (
+            f"Number of items in preds ({preds.shape}) mismatch ",
+            f"with sample_rates ({sample_rates.shape})")
+        assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"
+
+        device = self.weight.device
+        preds, targets = preds.to(device), targets.to(device)  # type: ignore
+        sample_rate = sample_rates[0].item()
+        preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
+        targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
+        gt_chroma = self.chroma_extractor(targets)
+        gen_chroma = self.chroma_extractor(preds)
+        chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
+        for i in range(len(gt_chroma)):
+            t = int(chroma_lens[i].item())
+            cosine_sim = torch.nn.functional.cosine_similarity(
+                gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
+            self.cosine_sum += cosine_sim.sum(dim=0)  # type: ignore
+            self.weight += torch.tensor(t)  # type: ignore
+
+    def compute(self) -> float:
+        """Computes the average cosine similarty across all generated/target chromagrams pairs."""
+        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
+        return (self.cosine_sum / self.weight).item()  # type: ignore
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class ChromaCosineSimilarityMetric +(sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-08) +
+
+

Chroma cosine similarity metric.

+

This metric extracts a chromagram for a reference waveform and +a generated waveform and compares each frame using the cosine similarity +function. The output is the mean cosine similarity.

+

Args

+
+
sample_rate : int
+
Sample rate used by the chroma extractor.
+
n_chroma : int
+
Number of chroma used by the chroma extractor.
+
radix2_exp : int
+
Exponent for the chroma extractor.
+
argmax : bool
+
Whether the chroma extractor uses argmax.
+
eps : float
+
Epsilon for cosine similarity computation.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ChromaCosineSimilarityMetric(torchmetrics.Metric):
+    """Chroma cosine similarity metric.
+
+    This metric extracts a chromagram for a reference waveform and
+    a generated waveform and compares each frame using the cosine similarity
+    function. The output is the mean cosine similarity.
+
+    Args:
+        sample_rate (int): Sample rate used by the chroma extractor.
+        n_chroma (int): Number of chroma used by the chroma extractor.
+        radix2_exp (int): Exponent for the chroma extractor.
+        argmax (bool): Whether the chroma extractor uses argmax.
+        eps (float): Epsilon for cosine similarity computation.
+    """
+    def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
+        super().__init__()
+        self.chroma_sample_rate = sample_rate
+        self.n_chroma = n_chroma
+        self.eps = eps
+        self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
+                                                radix2_exp=radix2_exp, argmax=argmax)
+        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
+
+    def update(self, preds: torch.Tensor, targets: torch.Tensor,
+               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+        """Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
+        if preds.size(0) == 0:
+            return
+
+        assert preds.shape == targets.shape, (
+            f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
+        assert preds.size(0) == sizes.size(0), (
+            f"Number of items in preds ({preds.shape}) mismatch ",
+            f"with sizes ({sizes.shape})")
+        assert preds.size(0) == sample_rates.size(0), (
+            f"Number of items in preds ({preds.shape}) mismatch ",
+            f"with sample_rates ({sample_rates.shape})")
+        assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"
+
+        device = self.weight.device
+        preds, targets = preds.to(device), targets.to(device)  # type: ignore
+        sample_rate = sample_rates[0].item()
+        preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
+        targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
+        gt_chroma = self.chroma_extractor(targets)
+        gen_chroma = self.chroma_extractor(preds)
+        chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
+        for i in range(len(gt_chroma)):
+            t = int(chroma_lens[i].item())
+            cosine_sim = torch.nn.functional.cosine_similarity(
+                gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
+            self.cosine_sum += cosine_sim.sum(dim=0)  # type: ignore
+            self.weight += torch.tensor(t)  # type: ignore
+
+    def compute(self) -> float:
+        """Computes the average cosine similarty across all generated/target chromagrams pairs."""
+        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
+        return (self.cosine_sum / self.weight).item()  # type: ignore
+
+

Ancestors

+
    +
  • torchmetrics.metric.Metric
  • +
  • torch.nn.modules.module.Module
  • +
  • abc.ABC
  • +
+

Class variables

+
+
var full_state_update : Optional[bool]
+
+
+
+
var higher_is_better : Optional[bool]
+
+
+
+
var is_differentiable : Optional[bool]
+
+
+
+
var plot_legend_name : Optional[str]
+
+
+
+
var plot_lower_bound : Optional[float]
+
+
+
+
var plot_upper_bound : Optional[float]
+
+
+
+
+

Methods

+
+
+def compute(self) ‑> float +
+
+

Computes the average cosine similarty across all generated/target chromagrams pairs.

+
+ +Expand source code + +
def compute(self) -> float:
+    """Computes the average cosine similarty across all generated/target chromagrams pairs."""
+    assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
+    return (self.cosine_sum / self.weight).item()  # type: ignore
+
+
+
+def update(self, preds: torch.Tensor, targets: torch.Tensor, sizes: torch.Tensor, sample_rates: torch.Tensor) ‑> None +
+
+

Compute cosine similarity between chromagrams and accumulate scores over the dataset.

+
+ +Expand source code + +
def update(self, preds: torch.Tensor, targets: torch.Tensor,
+           sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+    """Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
+    if preds.size(0) == 0:
+        return
+
+    assert preds.shape == targets.shape, (
+        f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
+    assert preds.size(0) == sizes.size(0), (
+        f"Number of items in preds ({preds.shape}) mismatch ",
+        f"with sizes ({sizes.shape})")
+    assert preds.size(0) == sample_rates.size(0), (
+        f"Number of items in preds ({preds.shape}) mismatch ",
+        f"with sample_rates ({sample_rates.shape})")
+    assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"
+
+    device = self.weight.device
+    preds, targets = preds.to(device), targets.to(device)  # type: ignore
+    sample_rate = sample_rates[0].item()
+    preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
+    targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
+    gt_chroma = self.chroma_extractor(targets)
+    gen_chroma = self.chroma_extractor(preds)
+    chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
+    for i in range(len(gt_chroma)):
+        t = int(chroma_lens[i].item())
+        cosine_sim = torch.nn.functional.cosine_similarity(
+            gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
+        self.cosine_sum += cosine_sim.sum(dim=0)  # type: ignore
+        self.weight += torch.tensor(t)  # type: ignore
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/metrics/clap_consistency.html b/api_docs/audiocraft/metrics/clap_consistency.html new file mode 100644 index 00000000..5023e4ef --- /dev/null +++ b/api_docs/audiocraft/metrics/clap_consistency.html @@ -0,0 +1,410 @@ + + + + + + +audiocraft.metrics.clap_consistency API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.metrics.clap_consistency

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+import typing as tp
+
+import torch
+import torchmetrics
+from transformers import RobertaTokenizer  # type: ignore
+
+from ..data.audio_utils import convert_audio
+from ..environment import AudioCraftEnvironment
+from ..utils.utils import load_clap_state_dict
+
+try:
+    import laion_clap  # type: ignore
+except ImportError:
+    laion_clap = None
+
+
+class TextConsistencyMetric(torchmetrics.Metric):
+    """Text consistency metric measuring consistency between audio and text pairs."""
+
+    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+        raise NotImplementedError("implement how to update the metric from the audio and text pairs.")
+
+    def compute(self):
+        raise NotImplementedError("implement how to compute the final metric score.")
+
+
+class CLAPTextConsistencyMetric(TextConsistencyMetric):
+    """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
+
+    This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
+    or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
+
+    As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
+    similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
+    well as the generated audio based on them, and define the MCC metric as the average cosine similarity
+    between these embeddings.
+
+    Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
+    """
+    def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
+        super().__init__()
+        if laion_clap is None:
+            raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
+        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self._initialize_model(model_path, model_arch, enable_fusion)
+
+    def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
+        model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
+        self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
+        self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
+        self.model_sample_rate = 48_000
+        load_clap_state_dict(self.model, model_path)
+        self.model.eval()
+
+    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
+        # we use the default params from CLAP module here as well
+        return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
+
+    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+        """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
+        assert audio.size(0) == len(text), "Number of audio and text samples should match"
+        assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
+        sample_rate = int(sample_rates[0].item())
+        # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
+        audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
+        audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
+        text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
+        # cosine similarity between the text and the audio embedding
+        cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
+        self.cosine_sum += cosine_sim.sum(dim=0)
+        self.weight += torch.tensor(cosine_sim.size(0))
+
+    def compute(self):
+        """Computes the average cosine similarty across all audio/text pairs."""
+        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
+        return (self.cosine_sum / self.weight).item()  # type: ignore
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class CLAPTextConsistencyMetric +(model_path: Union[str, pathlib.Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False) +
+
+

Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).

+

This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf) +or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).

+

As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the +similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as +well as the generated audio based on them, and define the MCC metric as the average cosine similarity +between these embeddings.

+

Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class CLAPTextConsistencyMetric(TextConsistencyMetric):
+    """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
+
+    This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
+    or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
+
+    As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
+    similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
+    well as the generated audio based on them, and define the MCC metric as the average cosine similarity
+    between these embeddings.
+
+    Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
+    """
+    def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
+        super().__init__()
+        if laion_clap is None:
+            raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
+        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self._initialize_model(model_path, model_arch, enable_fusion)
+
+    def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
+        model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
+        self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
+        self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
+        self.model_sample_rate = 48_000
+        load_clap_state_dict(self.model, model_path)
+        self.model.eval()
+
+    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
+        # we use the default params from CLAP module here as well
+        return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
+
+    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+        """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
+        assert audio.size(0) == len(text), "Number of audio and text samples should match"
+        assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
+        sample_rate = int(sample_rates[0].item())
+        # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
+        audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
+        audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
+        text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
+        # cosine similarity between the text and the audio embedding
+        cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
+        self.cosine_sum += cosine_sim.sum(dim=0)
+        self.weight += torch.tensor(cosine_sim.size(0))
+
+    def compute(self):
+        """Computes the average cosine similarty across all audio/text pairs."""
+        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
+        return (self.cosine_sum / self.weight).item()  # type: ignore
+
+

Ancestors

+ +

Class variables

+
+
var full_state_update : Optional[bool]
+
+
+
+
var higher_is_better : Optional[bool]
+
+
+
+
var is_differentiable : Optional[bool]
+
+
+
+
var plot_legend_name : Optional[str]
+
+
+
+
var plot_lower_bound : Optional[float]
+
+
+
+
var plot_upper_bound : Optional[float]
+
+
+
+
+

Methods

+
+
+def compute(self) +
+
+

Computes the average cosine similarty across all audio/text pairs.

+
+ +Expand source code + +
def compute(self):
+    """Computes the average cosine similarty across all audio/text pairs."""
+    assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
+    return (self.cosine_sum / self.weight).item()  # type: ignore
+
+
+
+def update(self, audio: torch.Tensor, text: List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) ‑> None +
+
+

Compute cosine similarity between audio and text pairs and accumulate scores over the dataset.

+
+ +Expand source code + +
def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+    """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
+    assert audio.size(0) == len(text), "Number of audio and text samples should match"
+    assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
+    sample_rate = int(sample_rates[0].item())
+    # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
+    audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
+    audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
+    text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
+    # cosine similarity between the text and the audio embedding
+    cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
+    self.cosine_sum += cosine_sim.sum(dim=0)
+    self.weight += torch.tensor(cosine_sim.size(0))
+
+
+
+
+
+class TextConsistencyMetric +(**kwargs: Any) +
+
+

Text consistency metric measuring consistency between audio and text pairs.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class TextConsistencyMetric(torchmetrics.Metric):
+    """Text consistency metric measuring consistency between audio and text pairs."""
+
+    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+        raise NotImplementedError("implement how to update the metric from the audio and text pairs.")
+
+    def compute(self):
+        raise NotImplementedError("implement how to compute the final metric score.")
+
+

Ancestors

+
    +
  • torchmetrics.metric.Metric
  • +
  • torch.nn.modules.module.Module
  • +
  • abc.ABC
  • +
+

Subclasses

+ +

Class variables

+
+
var full_state_update : Optional[bool]
+
+
+
+
var higher_is_better : Optional[bool]
+
+
+
+
var is_differentiable : Optional[bool]
+
+
+
+
var plot_legend_name : Optional[str]
+
+
+
+
var plot_lower_bound : Optional[float]
+
+
+
+
var plot_upper_bound : Optional[float]
+
+
+
+
+

Methods

+
+
+def compute(self) +
+
+

Override this method to compute the final metric value.

+

This method will automatically synchronize state variables when running in distributed backend.

+
+ +Expand source code + +
def compute(self):
+    raise NotImplementedError("implement how to compute the final metric score.")
+
+
+
+def update(self, audio: torch.Tensor, text: List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) ‑> None +
+
+

Override this method to update the state variables of your metric class.

+
+ +Expand source code + +
def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+    raise NotImplementedError("implement how to update the metric from the audio and text pairs.")
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/metrics/fad.html b/api_docs/audiocraft/metrics/fad.html new file mode 100644 index 00000000..79c1e439 --- /dev/null +++ b/api_docs/audiocraft/metrics/fad.html @@ -0,0 +1,962 @@ + + + + + + +audiocraft.metrics.fad API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.metrics.fad

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from pathlib import Path
+import os
+import subprocess
+import tempfile
+import typing as tp
+
+from audiocraft.data.audio import audio_write
+from audiocraft.data.audio_utils import convert_audio
+import flashy
+import torch
+import torchmetrics
+
+from ..environment import AudioCraftEnvironment
+
+
+logger = logging.getLogger(__name__)
+
+VGGISH_SAMPLE_RATE = 16_000
+VGGISH_CHANNELS = 1
+
+
+class FrechetAudioDistanceMetric(torchmetrics.Metric):
+    """Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.
+
+    From: D.C. Dowson & B.V. Landau The Fréchet distance between
+    multivariate normal distributions
+    https://doi.org/10.1016/0047-259X(82)90077-X
+    The Fréchet distance between two multivariate gaussians,
+    `X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`.
+    d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
+        = (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y)
+                        - 2 * Tr(sqrt(sigma_x*sigma_y)))
+
+    To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
+    from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
+    We provide the below instructions as reference but we do not guarantee for further support
+    in frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.
+
+        We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).
+
+        1. Get the code and models following the repository instructions. We used the steps below:
+                git clone git@github.com:google-research/google-research.git
+                git clone git@github.com:tensorflow/models.git
+                mkdir google-research/tensorflow_models
+                touch google-research/tensorflow_models/__init__.py
+                cp -r models/research/audioset google-research/tensorflow_models/
+                touch google-research/tensorflow_models/audioset/__init__.py
+                echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
+                    google-research/tensorflow_models/audioset/__init__.py
+                # we can now remove the tensorflow models repository
+                # rm -r models
+                cd google-research
+           Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
+           assumes it is placed in the AudioCraft reference dir.
+
+           Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3:
+           - Update xrange for range in:
+             https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
+           - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
+             `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
+              https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
+           - Update `import vggish_params as params` to `from . import vggish_params as params` in:
+             https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
+           - Add flag to provide a given batch size for running the AudioSet model in:
+             https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
+             ```
+             flags.DEFINE_integer('batch_size', 64,
+                                  'Number of samples in the batch for AudioSet model.')
+             ```
+             Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
+             `batch_size=FLAGS.batch_size` to the provided parameters.
+
+        2. Follow instructions for the library installation and a valid TensorFlow installation
+           ```
+           # e.g. instructions from: https://www.tensorflow.org/install/pip
+           conda install -c conda-forge cudatoolkit=11.8.0
+           python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
+           mkdir -p $CONDA_PREFIX/etc/conda/activate.d
+           echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
+             >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+           echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
+             >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+           source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+           # Verify install: on a machine with GPU device
+           python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
+           ```
+
+           Now install frechet_audio_distance required dependencies:
+           ```
+           # We assume we already have TensorFlow installed from the above steps
+           pip install apache-beam numpy scipy tf_slim
+           ```
+
+           Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup
+           (you may want to specify --model_ckpt flag pointing to the model's path).
+
+        3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
+           and Tensorflow library path from the above installation steps:
+            export TF_PYTHON_EXE="<PATH_TO_THE_ENV_PYTHON_BINARY>"
+            export TF_LIBRARY_PATH="<PATH_TO_THE_ENV_CUDNN_LIBRARY>"
+
+            e.g. assuming we have installed everything in a dedicated conda env
+            with python 3.10 that is currently active:
+            export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
+            export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"
+
+            Finally you may want to export the following variable:
+            export TF_FORCE_GPU_ALLOW_GROWTH=true
+            See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
+
+            You can save those environment variables in your training conda env, when currently active:
+            `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
+            e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
+            and the training conda env is named audiocraft:
+            ```
+            # activate training env
+            conda activate audiocraft
+            # get path to all envs
+            CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
+            # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
+            touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+            echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
+                $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+            echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
+                $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+            # optionally:
+            echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+            # you may need to reactivate the audiocraft env for this to take effect
+            ```
+
+    Args:
+        bin (Path or str): Path to installed frechet audio distance code.
+        model_path (Path or str): Path to Tensorflow checkpoint for the model
+            used to compute statistics over the embedding beams.
+        format (str): Audio format used to save files.
+        log_folder (Path or str, optional): Path where to write process logs.
+    """
+    def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
+                 format: str = "wav", batch_size: tp.Optional[int] = None,
+                 log_folder: tp.Optional[tp.Union[Path, str]] = None):
+        super().__init__()
+        self.model_sample_rate = VGGISH_SAMPLE_RATE
+        self.model_channels = VGGISH_CHANNELS
+        self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
+        assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
+        self.format = format
+        self.batch_size = batch_size
+        self.bin = bin
+        self.tf_env = {"PYTHONPATH": str(self.bin)}
+        self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
+        logger.info("Python exe for TF is  %s", self.python_path)
+        if 'TF_LIBRARY_PATH' in os.environ:
+            self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH']
+        if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ:
+            self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH']
+        logger.info("Env for TF is %r", self.tf_env)
+        self.reset(log_folder)
+        self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")
+
+    def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
+        """Reset torchmetrics.Metrics state."""
+        log_folder = Path(log_folder or tempfile.mkdtemp())
+        self.tmp_dir = log_folder / 'fad'
+        self.tmp_dir.mkdir(exist_ok=True)
+        self.samples_tests_dir = self.tmp_dir / 'tests'
+        self.samples_tests_dir.mkdir(exist_ok=True)
+        self.samples_background_dir = self.tmp_dir / 'background'
+        self.samples_background_dir.mkdir(exist_ok=True)
+        self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
+        self.manifest_background = self.tmp_dir / 'files_background.cvs'
+        self.stats_tests_dir = self.tmp_dir / 'stats_tests'
+        self.stats_background_dir = self.tmp_dir / 'stats_background'
+        self.counter = 0
+
+    def update(self, preds: torch.Tensor, targets: torch.Tensor,
+               sizes: torch.Tensor, sample_rates: torch.Tensor,
+               stems: tp.Optional[tp.List[str]] = None):
+        """Update torchmetrics.Metrics by saving the audio and updating the manifest file."""
+        assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
+        num_samples = preds.shape[0]
+        assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
+        assert stems is None or num_samples == len(set(stems))
+        for i in range(num_samples):
+            self.total_files += 1  # type: ignore
+            self.counter += 1
+            wav_len = int(sizes[i].item())
+            sample_rate = int(sample_rates[i].item())
+            pred_wav = preds[i]
+            target_wav = targets[i]
+            pred_wav = pred_wav[..., :wav_len]
+            target_wav = target_wav[..., :wav_len]
+            stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
+            # dump audio files
+            try:
+                pred_wav = convert_audio(
+                    pred_wav.unsqueeze(0), from_rate=sample_rate,
+                    to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
+                audio_write(
+                    self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
+                    format=self.format, strategy="peak")
+            except Exception as e:
+                logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}")
+            try:
+                # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
+                # the original audio when writing it
+                target_wav = convert_audio(
+                    target_wav.unsqueeze(0), from_rate=sample_rate,
+                    to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
+                audio_write(
+                    self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
+                    format=self.format, strategy="peak")
+            except Exception as e:
+                logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}")
+
+    def _get_samples_name(self, is_background: bool):
+        return 'background' if is_background else 'tests'
+
+    def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
+        if is_background:
+            input_samples_dir = self.samples_background_dir
+            input_filename = self.manifest_background
+            stats_name = self.stats_background_dir
+        else:
+            input_samples_dir = self.samples_tests_dir
+            input_filename = self.manifest_tests
+            stats_name = self.stats_tests_dir
+        beams_name = self._get_samples_name(is_background)
+        log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'
+
+        logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
+        with open(input_filename, "w") as fout:
+            for path in Path(input_samples_dir).glob(f"*.{self.format}"):
+                fout.write(f"{str(path)}\n")
+
+        cmd = [
+            self.python_path, "-m",
+            "frechet_audio_distance.create_embeddings_main",
+            "--model_ckpt", f"{self.model_path}",
+            "--input_files", f"{str(input_filename)}",
+            "--stats", f"{str(stats_name)}",
+        ]
+        if self.batch_size is not None:
+            cmd += ["--batch_size", str(self.batch_size)]
+        logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
+        env = os.environ
+        if gpu_index is not None:
+            env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
+        process = subprocess.Popen(
+            cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT)
+        return process, log_file
+
+    def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
+        cmd = [
+            self.python_path, "-m", "frechet_audio_distance.compute_fad",
+            "--test_stats", f"{str(self.stats_tests_dir)}",
+            "--background_stats", f"{str(self.stats_background_dir)}",
+        ]
+        logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
+        env = os.environ
+        if gpu_index is not None:
+            env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
+        result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True)
+        if result.returncode:
+            logger.error(
+                "Error with FAD computation from stats: \n %s \n %s",
+                result.stdout.decode(), result.stderr.decode()
+            )
+            raise RuntimeError("Error while executing FAD computation from stats")
+        try:
+            # result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more
+            fad_score = float(result.stdout[4:])
+            return fad_score
+        except Exception as e:
+            raise RuntimeError(f"Error parsing FAD score from command stdout: {e}")
+
+    def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
+        beams_name = self._get_samples_name(is_background)
+        if returncode:
+            with open(log_file, "r") as f:
+                error_log = f.read()
+                logger.error(error_log)
+            os._exit(1)
+        else:
+            logger.info(f"Successfully computed embedding beams on {beams_name} samples.")
+
+    def _parallel_create_embedding_beams(self, num_of_gpus: int):
+        assert num_of_gpus > 0
+        logger.info("Creating embeddings beams in a parallel manner on different GPUs")
+        tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0)
+        bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1)
+        tests_beams_code = tests_beams_process.wait()
+        bg_beams_code = bg_beams_process.wait()
+        self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
+        self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
+
+    def _sequential_create_embedding_beams(self):
+        logger.info("Creating embeddings beams in a sequential manner")
+        tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False)
+        tests_beams_code = tests_beams_process.wait()
+        self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
+        bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True)
+        bg_beams_code = bg_beams_process.wait()
+        self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
+
+    @flashy.distrib.rank_zero_only
+    def _local_compute_frechet_audio_distance(self):
+        """Compute Frechet Audio Distance score calling TensorFlow API."""
+        num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
+        if num_of_gpus > 1:
+            self._parallel_create_embedding_beams(num_of_gpus)
+        else:
+            self._sequential_create_embedding_beams()
+        fad_score = self._compute_fad_score(gpu_index=0)
+        return fad_score
+
+    def compute(self) -> float:
+        """Compute metrics."""
+        assert self.total_files.item() > 0, "No files dumped for FAD computation!"  # type: ignore
+        fad_score = self._local_compute_frechet_audio_distance()
+        logger.warning(f"FAD score = {fad_score}")
+        fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
+        return fad_score
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class FrechetAudioDistanceMetric +(bin: Union[str, pathlib.Path], model_path: Union[str, pathlib.Path], format: str = 'wav', batch_size: Optional[int] = None, log_folder: Union[pathlib.Path, str, None] = None) +
+
+

Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.

+

From: D.C. Dowson & B.V. Landau The Fréchet distance between +multivariate normal distributions +https://doi.org/10.1016/0047-259X(82)90077-X +The Fréchet distance between two multivariate gaussians, +X ~ N(mu_x, sigma_x) and Y ~ N(mu_y, sigma_y), is d^2. +d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_xsigma_y)) += (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y) +- 2 * Tr(sqrt(sigma_xsigma_y)))

+

To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup +from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance +We provide the below instructions as reference but we do not guarantee for further support +in frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.

+
We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).
+
+1. Get the code and models following the repository instructions. We used the steps below:
+        git clone git@github.com:google-research/google-research.git
+        git clone git@github.com:tensorflow/models.git
+        mkdir google-research/tensorflow_models
+        touch google-research/tensorflow_models/__init__.py
+        cp -r models/research/audioset google-research/tensorflow_models/
+        touch google-research/tensorflow_models/audioset/__init__.py
+        echo "from .vggish import mel_features, vggish_params, vggish_slim" >                     google-research/tensorflow_models/audioset/__init__.py
+        # we can now remove the tensorflow models repository
+        # rm -r models
+        cd google-research
+   Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
+   assumes it is placed in the AudioCraft reference dir.
+
+   Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3:
+   - Update xrange for range in:
+     <https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py>
+   - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
+     `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
+      <https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py>
+   - Update <code>import vggish\_params as params</code> to <code>from . import vggish\_params as params</code> in:
+     <https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py>
+   - Add flag to provide a given batch size for running the AudioSet model in:
+     <https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py>
+     ```
+     flags.DEFINE_integer('batch_size', 64,
+                          'Number of samples in the batch for AudioSet model.')
+     ```
+     Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
+     `batch_size=FLAGS.batch_size` to the provided parameters.
+
+2. Follow instructions for the library installation and a valid TensorFlow installation
+   ```
+   # e.g. instructions from: <https://www.tensorflow.org/install/pip>
+   conda install -c conda-forge cudatoolkit=11.8.0
+   python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
+   mkdir -p $CONDA_PREFIX/etc/conda/activate.d
+   echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))'              >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+   echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib'              >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+   source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+   # Verify install: on a machine with GPU device
+   python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
+   ```
+
+   Now install frechet_audio_distance required dependencies:
+   ```
+   # We assume we already have TensorFlow installed from the above steps
+   pip install apache-beam numpy scipy tf_slim
+   ```
+
+   Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup
+   (you may want to specify --model_ckpt flag pointing to the model's path).
+
+3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
+   and Tensorflow library path from the above installation steps:
+    export TF_PYTHON_EXE="<PATH_TO_THE_ENV_PYTHON_BINARY>"
+    export TF_LIBRARY_PATH="<PATH_TO_THE_ENV_CUDNN_LIBRARY>"
+
+    e.g. assuming we have installed everything in a dedicated conda env
+    with python 3.10 that is currently active:
+    export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
+    export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"
+
+    Finally you may want to export the following variable:
+    export TF_FORCE_GPU_ALLOW_GROWTH=true
+    See: <https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth>
+
+    You can save those environment variables in your training conda env, when currently active:
+    `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
+    e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
+    and the training conda env is named audiocraft:
+    ```
+    # activate training env
+    conda activate audiocraft
+    # get path to all envs
+    CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
+    # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
+    touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+    echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >>                 $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+    echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >>                 $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+    # optionally:
+    echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+    # you may need to reactivate the audiocraft env for this to take effect
+    ```
+
+

Args

+
+
bin : Path or str
+
Path to installed frechet audio distance code.
+
model_path : Path or str
+
Path to Tensorflow checkpoint for the model +used to compute statistics over the embedding beams.
+
format : str
+
Audio format used to save files.
+
log_folder : Path or str, optional
+
Path where to write process logs.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class FrechetAudioDistanceMetric(torchmetrics.Metric):
+    """Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.
+
+    From: D.C. Dowson & B.V. Landau The Fréchet distance between
+    multivariate normal distributions
+    https://doi.org/10.1016/0047-259X(82)90077-X
+    The Fréchet distance between two multivariate gaussians,
+    `X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`.
+    d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
+        = (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y)
+                        - 2 * Tr(sqrt(sigma_x*sigma_y)))
+
+    To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
+    from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
+    We provide the below instructions as reference but we do not guarantee for further support
+    in frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.
+
+        We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).
+
+        1. Get the code and models following the repository instructions. We used the steps below:
+                git clone git@github.com:google-research/google-research.git
+                git clone git@github.com:tensorflow/models.git
+                mkdir google-research/tensorflow_models
+                touch google-research/tensorflow_models/__init__.py
+                cp -r models/research/audioset google-research/tensorflow_models/
+                touch google-research/tensorflow_models/audioset/__init__.py
+                echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
+                    google-research/tensorflow_models/audioset/__init__.py
+                # we can now remove the tensorflow models repository
+                # rm -r models
+                cd google-research
+           Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
+           assumes it is placed in the AudioCraft reference dir.
+
+           Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3:
+           - Update xrange for range in:
+             https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
+           - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
+             `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
+              https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
+           - Update `import vggish_params as params` to `from . import vggish_params as params` in:
+             https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
+           - Add flag to provide a given batch size for running the AudioSet model in:
+             https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
+             ```
+             flags.DEFINE_integer('batch_size', 64,
+                                  'Number of samples in the batch for AudioSet model.')
+             ```
+             Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
+             `batch_size=FLAGS.batch_size` to the provided parameters.
+
+        2. Follow instructions for the library installation and a valid TensorFlow installation
+           ```
+           # e.g. instructions from: https://www.tensorflow.org/install/pip
+           conda install -c conda-forge cudatoolkit=11.8.0
+           python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
+           mkdir -p $CONDA_PREFIX/etc/conda/activate.d
+           echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
+             >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+           echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
+             >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+           source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+           # Verify install: on a machine with GPU device
+           python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
+           ```
+
+           Now install frechet_audio_distance required dependencies:
+           ```
+           # We assume we already have TensorFlow installed from the above steps
+           pip install apache-beam numpy scipy tf_slim
+           ```
+
+           Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup
+           (you may want to specify --model_ckpt flag pointing to the model's path).
+
+        3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
+           and Tensorflow library path from the above installation steps:
+            export TF_PYTHON_EXE="<PATH_TO_THE_ENV_PYTHON_BINARY>"
+            export TF_LIBRARY_PATH="<PATH_TO_THE_ENV_CUDNN_LIBRARY>"
+
+            e.g. assuming we have installed everything in a dedicated conda env
+            with python 3.10 that is currently active:
+            export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
+            export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"
+
+            Finally you may want to export the following variable:
+            export TF_FORCE_GPU_ALLOW_GROWTH=true
+            See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
+
+            You can save those environment variables in your training conda env, when currently active:
+            `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
+            e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
+            and the training conda env is named audiocraft:
+            ```
+            # activate training env
+            conda activate audiocraft
+            # get path to all envs
+            CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
+            # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
+            touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+            echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
+                $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+            echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
+                $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+            # optionally:
+            echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+            # you may need to reactivate the audiocraft env for this to take effect
+            ```
+
+    Args:
+        bin (Path or str): Path to installed frechet audio distance code.
+        model_path (Path or str): Path to Tensorflow checkpoint for the model
+            used to compute statistics over the embedding beams.
+        format (str): Audio format used to save files.
+        log_folder (Path or str, optional): Path where to write process logs.
+    """
+    def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
+                 format: str = "wav", batch_size: tp.Optional[int] = None,
+                 log_folder: tp.Optional[tp.Union[Path, str]] = None):
+        super().__init__()
+        self.model_sample_rate = VGGISH_SAMPLE_RATE
+        self.model_channels = VGGISH_CHANNELS
+        self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
+        assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
+        self.format = format
+        self.batch_size = batch_size
+        self.bin = bin
+        self.tf_env = {"PYTHONPATH": str(self.bin)}
+        self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
+        logger.info("Python exe for TF is  %s", self.python_path)
+        if 'TF_LIBRARY_PATH' in os.environ:
+            self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH']
+        if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ:
+            self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH']
+        logger.info("Env for TF is %r", self.tf_env)
+        self.reset(log_folder)
+        self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")
+
+    def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
+        """Reset torchmetrics.Metrics state."""
+        log_folder = Path(log_folder or tempfile.mkdtemp())
+        self.tmp_dir = log_folder / 'fad'
+        self.tmp_dir.mkdir(exist_ok=True)
+        self.samples_tests_dir = self.tmp_dir / 'tests'
+        self.samples_tests_dir.mkdir(exist_ok=True)
+        self.samples_background_dir = self.tmp_dir / 'background'
+        self.samples_background_dir.mkdir(exist_ok=True)
+        self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
+        self.manifest_background = self.tmp_dir / 'files_background.cvs'
+        self.stats_tests_dir = self.tmp_dir / 'stats_tests'
+        self.stats_background_dir = self.tmp_dir / 'stats_background'
+        self.counter = 0
+
+    def update(self, preds: torch.Tensor, targets: torch.Tensor,
+               sizes: torch.Tensor, sample_rates: torch.Tensor,
+               stems: tp.Optional[tp.List[str]] = None):
+        """Update torchmetrics.Metrics by saving the audio and updating the manifest file."""
+        assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
+        num_samples = preds.shape[0]
+        assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
+        assert stems is None or num_samples == len(set(stems))
+        for i in range(num_samples):
+            self.total_files += 1  # type: ignore
+            self.counter += 1
+            wav_len = int(sizes[i].item())
+            sample_rate = int(sample_rates[i].item())
+            pred_wav = preds[i]
+            target_wav = targets[i]
+            pred_wav = pred_wav[..., :wav_len]
+            target_wav = target_wav[..., :wav_len]
+            stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
+            # dump audio files
+            try:
+                pred_wav = convert_audio(
+                    pred_wav.unsqueeze(0), from_rate=sample_rate,
+                    to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
+                audio_write(
+                    self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
+                    format=self.format, strategy="peak")
+            except Exception as e:
+                logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}")
+            try:
+                # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
+                # the original audio when writing it
+                target_wav = convert_audio(
+                    target_wav.unsqueeze(0), from_rate=sample_rate,
+                    to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
+                audio_write(
+                    self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
+                    format=self.format, strategy="peak")
+            except Exception as e:
+                logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}")
+
+    def _get_samples_name(self, is_background: bool):
+        return 'background' if is_background else 'tests'
+
+    def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
+        if is_background:
+            input_samples_dir = self.samples_background_dir
+            input_filename = self.manifest_background
+            stats_name = self.stats_background_dir
+        else:
+            input_samples_dir = self.samples_tests_dir
+            input_filename = self.manifest_tests
+            stats_name = self.stats_tests_dir
+        beams_name = self._get_samples_name(is_background)
+        log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'
+
+        logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
+        with open(input_filename, "w") as fout:
+            for path in Path(input_samples_dir).glob(f"*.{self.format}"):
+                fout.write(f"{str(path)}\n")
+
+        cmd = [
+            self.python_path, "-m",
+            "frechet_audio_distance.create_embeddings_main",
+            "--model_ckpt", f"{self.model_path}",
+            "--input_files", f"{str(input_filename)}",
+            "--stats", f"{str(stats_name)}",
+        ]
+        if self.batch_size is not None:
+            cmd += ["--batch_size", str(self.batch_size)]
+        logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
+        env = os.environ
+        if gpu_index is not None:
+            env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
+        process = subprocess.Popen(
+            cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT)
+        return process, log_file
+
+    def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
+        cmd = [
+            self.python_path, "-m", "frechet_audio_distance.compute_fad",
+            "--test_stats", f"{str(self.stats_tests_dir)}",
+            "--background_stats", f"{str(self.stats_background_dir)}",
+        ]
+        logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
+        env = os.environ
+        if gpu_index is not None:
+            env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
+        result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True)
+        if result.returncode:
+            logger.error(
+                "Error with FAD computation from stats: \n %s \n %s",
+                result.stdout.decode(), result.stderr.decode()
+            )
+            raise RuntimeError("Error while executing FAD computation from stats")
+        try:
+            # result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more
+            fad_score = float(result.stdout[4:])
+            return fad_score
+        except Exception as e:
+            raise RuntimeError(f"Error parsing FAD score from command stdout: {e}")
+
+    def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
+        beams_name = self._get_samples_name(is_background)
+        if returncode:
+            with open(log_file, "r") as f:
+                error_log = f.read()
+                logger.error(error_log)
+            os._exit(1)
+        else:
+            logger.info(f"Successfully computed embedding beams on {beams_name} samples.")
+
+    def _parallel_create_embedding_beams(self, num_of_gpus: int):
+        assert num_of_gpus > 0
+        logger.info("Creating embeddings beams in a parallel manner on different GPUs")
+        tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0)
+        bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1)
+        tests_beams_code = tests_beams_process.wait()
+        bg_beams_code = bg_beams_process.wait()
+        self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
+        self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
+
+    def _sequential_create_embedding_beams(self):
+        logger.info("Creating embeddings beams in a sequential manner")
+        tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False)
+        tests_beams_code = tests_beams_process.wait()
+        self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
+        bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True)
+        bg_beams_code = bg_beams_process.wait()
+        self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
+
+    @flashy.distrib.rank_zero_only
+    def _local_compute_frechet_audio_distance(self):
+        """Compute Frechet Audio Distance score calling TensorFlow API."""
+        num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
+        if num_of_gpus > 1:
+            self._parallel_create_embedding_beams(num_of_gpus)
+        else:
+            self._sequential_create_embedding_beams()
+        fad_score = self._compute_fad_score(gpu_index=0)
+        return fad_score
+
+    def compute(self) -> float:
+        """Compute metrics."""
+        assert self.total_files.item() > 0, "No files dumped for FAD computation!"  # type: ignore
+        fad_score = self._local_compute_frechet_audio_distance()
+        logger.warning(f"FAD score = {fad_score}")
+        fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
+        return fad_score
+
+

Ancestors

+
    +
  • torchmetrics.metric.Metric
  • +
  • torch.nn.modules.module.Module
  • +
  • abc.ABC
  • +
+

Class variables

+
+
var full_state_update : Optional[bool]
+
+
+
+
var higher_is_better : Optional[bool]
+
+
+
+
var is_differentiable : Optional[bool]
+
+
+
+
var plot_legend_name : Optional[str]
+
+
+
+
var plot_lower_bound : Optional[float]
+
+
+
+
var plot_upper_bound : Optional[float]
+
+
+
+
+

Methods

+
+
+def compute(self) ‑> float +
+
+

Compute metrics.

+
+ +Expand source code + +
def compute(self) -> float:
+    """Compute metrics."""
+    assert self.total_files.item() > 0, "No files dumped for FAD computation!"  # type: ignore
+    fad_score = self._local_compute_frechet_audio_distance()
+    logger.warning(f"FAD score = {fad_score}")
+    fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
+    return fad_score
+
+
+
+def reset(self, log_folder: Union[pathlib.Path, str, None] = None) +
+
+

Reset torchmetrics.Metrics state.

+
+ +Expand source code + +
def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
+    """Reset torchmetrics.Metrics state."""
+    log_folder = Path(log_folder or tempfile.mkdtemp())
+    self.tmp_dir = log_folder / 'fad'
+    self.tmp_dir.mkdir(exist_ok=True)
+    self.samples_tests_dir = self.tmp_dir / 'tests'
+    self.samples_tests_dir.mkdir(exist_ok=True)
+    self.samples_background_dir = self.tmp_dir / 'background'
+    self.samples_background_dir.mkdir(exist_ok=True)
+    self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
+    self.manifest_background = self.tmp_dir / 'files_background.cvs'
+    self.stats_tests_dir = self.tmp_dir / 'stats_tests'
+    self.stats_background_dir = self.tmp_dir / 'stats_background'
+    self.counter = 0
+
+
+
+def update(self, preds: torch.Tensor, targets: torch.Tensor, sizes: torch.Tensor, sample_rates: torch.Tensor, stems: Optional[List[str]] = None) +
+
+

Update torchmetrics.Metrics by saving the audio and updating the manifest file.

+
+ +Expand source code + +
def update(self, preds: torch.Tensor, targets: torch.Tensor,
+           sizes: torch.Tensor, sample_rates: torch.Tensor,
+           stems: tp.Optional[tp.List[str]] = None):
+    """Update torchmetrics.Metrics by saving the audio and updating the manifest file."""
+    assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
+    num_samples = preds.shape[0]
+    assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
+    assert stems is None or num_samples == len(set(stems))
+    for i in range(num_samples):
+        self.total_files += 1  # type: ignore
+        self.counter += 1
+        wav_len = int(sizes[i].item())
+        sample_rate = int(sample_rates[i].item())
+        pred_wav = preds[i]
+        target_wav = targets[i]
+        pred_wav = pred_wav[..., :wav_len]
+        target_wav = target_wav[..., :wav_len]
+        stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
+        # dump audio files
+        try:
+            pred_wav = convert_audio(
+                pred_wav.unsqueeze(0), from_rate=sample_rate,
+                to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
+            audio_write(
+                self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
+                format=self.format, strategy="peak")
+        except Exception as e:
+            logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}")
+        try:
+            # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
+            # the original audio when writing it
+            target_wav = convert_audio(
+                target_wav.unsqueeze(0), from_rate=sample_rate,
+                to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
+            audio_write(
+                self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
+                format=self.format, strategy="peak")
+        except Exception as e:
+            logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}")
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/metrics/index.html b/api_docs/audiocraft/metrics/index.html new file mode 100644 index 00000000..2bc5d486 --- /dev/null +++ b/api_docs/audiocraft/metrics/index.html @@ -0,0 +1,110 @@ + + + + + + +audiocraft.metrics API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.metrics

+
+
+

Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc.
+"""
+# flake8: noqa
+from .clap_consistency import CLAPTextConsistencyMetric, TextConsistencyMetric
+from .chroma_cosinesim import ChromaCosineSimilarityMetric
+from .fad import FrechetAudioDistanceMetric
+from .kld import KLDivergenceMetric, PasstKLDivergenceMetric
+from .rvm import RelativeVolumeMel
+from .visqol import ViSQOL
+
+
+
+

Sub-modules

+
+
audiocraft.metrics.chroma_cosinesim
+
+
+
+
audiocraft.metrics.clap_consistency
+
+
+
+
audiocraft.metrics.fad
+
+
+
+
audiocraft.metrics.kld
+
+
+
+
audiocraft.metrics.rvm
+
+
+
+
audiocraft.metrics.visqol
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/metrics/kld.html b/api_docs/audiocraft/metrics/kld.html new file mode 100644 index 00000000..11991eaf --- /dev/null +++ b/api_docs/audiocraft/metrics/kld.html @@ -0,0 +1,712 @@ + + + + + + +audiocraft.metrics.kld API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.metrics.kld

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+from functools import partial
+import logging
+import os
+import typing as tp
+
+import torch
+import torchmetrics
+
+from ..data.audio_utils import convert_audio
+
+
+logger = logging.getLogger(__name__)
+
+
+class _patch_passt_stft:
+    """Decorator to patch torch.stft in PaSST."""
+    def __init__(self):
+        self.old_stft = torch.stft
+
+    def __enter__(self):
+        # return_complex is a mandatory parameter in latest torch versions
+        # torch is throwing RuntimeErrors when not set
+        torch.stft = partial(torch.stft, return_complex=False)
+
+    def __exit__(self, *exc):
+        torch.stft = self.old_stft
+
+
+def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
+    """Computes the elementwise KL-Divergence loss between probability distributions
+    from generated samples and target samples.
+
+    Args:
+        pred_probs (torch.Tensor): Probabilities for each label obtained
+            from a classifier on generated audio. Expected shape is [B, num_classes].
+        target_probs (torch.Tensor): Probabilities for each label obtained
+            from a classifier on target audio. Expected shape is [B, num_classes].
+        epsilon (float): Epsilon value.
+    Returns:
+        kld (torch.Tensor): KLD loss between each generated sample and target pair.
+    """
+    kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none")
+    return kl_div.sum(-1)
+
+
+class KLDivergenceMetric(torchmetrics.Metric):
+    """Base implementation for KL Divergence metric.
+
+    The KL divergence is measured between probability distributions
+    of class predictions returned by a pre-trained audio classification model.
+    When the KL-divergence is low, the generated audio is expected to
+    have similar acoustic characteristics as the reference audio,
+    according to the classifier.
+    """
+    def __init__(self):
+        super().__init__()
+        self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
+                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
+        """Get model output given provided input tensor.
+
+        Args:
+            x (torch.Tensor): Input audio tensor of shape [B, C, T].
+            sizes (torch.Tensor): Actual audio sample length, of shape [B].
+            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
+        Returns:
+            probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
+        """
+        raise NotImplementedError("implement method to extract label distributions from the model.")
+
+    def update(self, preds: torch.Tensor, targets: torch.Tensor,
+               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+        """Calculates running KL-Divergence loss between batches of audio
+        preds (generated) and target (ground-truth)
+        Args:
+            preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
+            targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
+            sizes (torch.Tensor): Actual audio sample length, of shape [B].
+            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
+        """
+        assert preds.shape == targets.shape
+        assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
+        preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
+        targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
+        if preds_probs is not None and targets_probs is not None:
+            assert preds_probs.shape == targets_probs.shape
+            kld_scores = kl_divergence(preds_probs, targets_probs)
+            assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
+            self.kld_pq_sum += torch.sum(kld_scores)
+            kld_qp_scores = kl_divergence(targets_probs, preds_probs)
+            self.kld_qp_sum += torch.sum(kld_qp_scores)
+            self.weight += torch.tensor(kld_scores.size(0))
+
+    def compute(self) -> dict:
+        """Computes KL-Divergence across all evaluated pred/target pairs."""
+        weight: float = float(self.weight.item())  # type: ignore
+        assert weight > 0, "Unable to compute with total number of comparisons <= 0"
+        logger.info(f"Computing KL divergence on a total of {weight} samples")
+        kld_pq = self.kld_pq_sum.item() / weight  # type: ignore
+        kld_qp = self.kld_qp_sum.item() / weight  # type: ignore
+        kld_both = kld_pq + kld_qp
+        return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
+
+
+class PasstKLDivergenceMetric(KLDivergenceMetric):
+    """KL-Divergence metric based on pre-trained PASST classifier on AudioSet.
+
+    From: PaSST: Efficient Training of Audio Transformers with Patchout
+    Paper: https://arxiv.org/abs/2110.05069
+    Implementation: https://github.com/kkoutini/PaSST
+
+    Follow instructions from the github repo:
+    ```
+    pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
+    ```
+
+    Args:
+        pretrained_length (float, optional): Audio duration used for the pretrained model.
+    """
+    def __init__(self, pretrained_length: tp.Optional[float] = None):
+        super().__init__()
+        self._initialize_model(pretrained_length)
+
+    def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
+        """Initialize underlying PaSST audio classifier."""
+        model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
+        self.min_input_frames = min_frames
+        self.max_input_frames = max_frames
+        self.model_sample_rate = sr
+        self.model = model
+        self.model.eval()
+        self.model.to(self.device)
+
+    def _load_base_model(self, pretrained_length: tp.Optional[float]):
+        """Load pretrained model from PaSST."""
+        try:
+            if pretrained_length == 30:
+                from hear21passt.base30sec import get_basic_model  # type: ignore
+                max_duration = 30
+            elif pretrained_length == 20:
+                from hear21passt.base20sec import get_basic_model  # type: ignore
+                max_duration = 20
+            else:
+                from hear21passt.base import get_basic_model  # type: ignore
+                # Original PASST was trained on AudioSet with 10s-long audio samples
+                max_duration = 10
+            min_duration = 0.15
+            min_duration = 0.15
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "Please install hear21passt to compute KL divergence: ",
+                "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
+            )
+        model_sample_rate = 32_000
+        max_input_frames = int(max_duration * model_sample_rate)
+        min_input_frames = int(min_duration * model_sample_rate)
+        with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
+            model = get_basic_model(mode='logits')
+        return model, model_sample_rate, max_input_frames, min_input_frames
+
+    def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
+        """Process audio to feed to the pretrained model."""
+        wav = wav.unsqueeze(0)
+        wav = wav[..., :wav_len]
+        wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
+        wav = wav.squeeze(0)
+        # we don't pad but return a list of audio segments as this otherwise affects the KLD computation
+        segments = torch.split(wav, self.max_input_frames, dim=-1)
+        valid_segments = []
+        for s in segments:
+            # ignoring too small segments that are breaking the model inference
+            if s.size(-1) > self.min_input_frames:
+                valid_segments.append(s)
+        return [s[None] for s in valid_segments]
+
+    def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
+        """Run the pretrained model and get the predictions."""
+        assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
+        wav = wav.mean(dim=1)
+        # PaSST is printing a lot of garbage that we are not interested in
+        with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
+            with torch.no_grad(), _patch_passt_stft():
+                logits = self.model(wav.to(self.device))
+                probs = torch.softmax(logits, dim=-1)
+                return probs
+
+    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
+                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
+        """Get model output given provided input tensor.
+
+        Args:
+            x (torch.Tensor): Input audio tensor of shape [B, C, T].
+            sizes (torch.Tensor): Actual audio sample length, of shape [B].
+            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
+        Returns:
+            probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes].
+        """
+        all_probs: tp.List[torch.Tensor] = []
+        for i, wav in enumerate(x):
+            sample_rate = int(sample_rates[i].item())
+            wav_len = int(sizes[i].item())
+            wav_segments = self._process_audio(wav, sample_rate, wav_len)
+            for segment in wav_segments:
+                probs = self._get_model_preds(segment).mean(dim=0)
+                all_probs.append(probs)
+        if len(all_probs) > 0:
+            return torch.stack(all_probs, dim=0)
+        else:
+            return None
+
+
+
+
+
+
+
+

Functions

+
+
+def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-06) ‑> torch.Tensor +
+
+

Computes the elementwise KL-Divergence loss between probability distributions +from generated samples and target samples.

+

Args

+
+
pred_probs : torch.Tensor
+
Probabilities for each label obtained +from a classifier on generated audio. Expected shape is [B, num_classes].
+
target_probs : torch.Tensor
+
Probabilities for each label obtained +from a classifier on target audio. Expected shape is [B, num_classes].
+
epsilon : float
+
Epsilon value.
+
+

Returns

+

kld (torch.Tensor): KLD loss between each generated sample and target pair.

+
+ +Expand source code + +
def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
+    """Computes the elementwise KL-Divergence loss between probability distributions
+    from generated samples and target samples.
+
+    Args:
+        pred_probs (torch.Tensor): Probabilities for each label obtained
+            from a classifier on generated audio. Expected shape is [B, num_classes].
+        target_probs (torch.Tensor): Probabilities for each label obtained
+            from a classifier on target audio. Expected shape is [B, num_classes].
+        epsilon (float): Epsilon value.
+    Returns:
+        kld (torch.Tensor): KLD loss between each generated sample and target pair.
+    """
+    kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none")
+    return kl_div.sum(-1)
+
+
+
+
+
+

Classes

+
+
+class KLDivergenceMetric +
+
+

Base implementation for KL Divergence metric.

+

The KL divergence is measured between probability distributions +of class predictions returned by a pre-trained audio classification model. +When the KL-divergence is low, the generated audio is expected to +have similar acoustic characteristics as the reference audio, +according to the classifier.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class KLDivergenceMetric(torchmetrics.Metric):
+    """Base implementation for KL Divergence metric.
+
+    The KL divergence is measured between probability distributions
+    of class predictions returned by a pre-trained audio classification model.
+    When the KL-divergence is low, the generated audio is expected to
+    have similar acoustic characteristics as the reference audio,
+    according to the classifier.
+    """
+    def __init__(self):
+        super().__init__()
+        self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
+                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
+        """Get model output given provided input tensor.
+
+        Args:
+            x (torch.Tensor): Input audio tensor of shape [B, C, T].
+            sizes (torch.Tensor): Actual audio sample length, of shape [B].
+            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
+        Returns:
+            probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
+        """
+        raise NotImplementedError("implement method to extract label distributions from the model.")
+
+    def update(self, preds: torch.Tensor, targets: torch.Tensor,
+               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+        """Calculates running KL-Divergence loss between batches of audio
+        preds (generated) and target (ground-truth)
+        Args:
+            preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
+            targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
+            sizes (torch.Tensor): Actual audio sample length, of shape [B].
+            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
+        """
+        assert preds.shape == targets.shape
+        assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
+        preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
+        targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
+        if preds_probs is not None and targets_probs is not None:
+            assert preds_probs.shape == targets_probs.shape
+            kld_scores = kl_divergence(preds_probs, targets_probs)
+            assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
+            self.kld_pq_sum += torch.sum(kld_scores)
+            kld_qp_scores = kl_divergence(targets_probs, preds_probs)
+            self.kld_qp_sum += torch.sum(kld_qp_scores)
+            self.weight += torch.tensor(kld_scores.size(0))
+
+    def compute(self) -> dict:
+        """Computes KL-Divergence across all evaluated pred/target pairs."""
+        weight: float = float(self.weight.item())  # type: ignore
+        assert weight > 0, "Unable to compute with total number of comparisons <= 0"
+        logger.info(f"Computing KL divergence on a total of {weight} samples")
+        kld_pq = self.kld_pq_sum.item() / weight  # type: ignore
+        kld_qp = self.kld_qp_sum.item() / weight  # type: ignore
+        kld_both = kld_pq + kld_qp
+        return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
+
+

Ancestors

+
    +
  • torchmetrics.metric.Metric
  • +
  • torch.nn.modules.module.Module
  • +
  • abc.ABC
  • +
+

Subclasses

+ +

Class variables

+
+
var full_state_update : Optional[bool]
+
+
+
+
var higher_is_better : Optional[bool]
+
+
+
+
var is_differentiable : Optional[bool]
+
+
+
+
var plot_legend_name : Optional[str]
+
+
+
+
var plot_lower_bound : Optional[float]
+
+
+
+
var plot_upper_bound : Optional[float]
+
+
+
+
+

Methods

+
+
+def compute(self) ‑> dict +
+
+

Computes KL-Divergence across all evaluated pred/target pairs.

+
+ +Expand source code + +
def compute(self) -> dict:
+    """Computes KL-Divergence across all evaluated pred/target pairs."""
+    weight: float = float(self.weight.item())  # type: ignore
+    assert weight > 0, "Unable to compute with total number of comparisons <= 0"
+    logger.info(f"Computing KL divergence on a total of {weight} samples")
+    kld_pq = self.kld_pq_sum.item() / weight  # type: ignore
+    kld_qp = self.kld_qp_sum.item() / weight  # type: ignore
+    kld_both = kld_pq + kld_qp
+    return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
+
+
+
+def update(self, preds: torch.Tensor, targets: torch.Tensor, sizes: torch.Tensor, sample_rates: torch.Tensor) ‑> None +
+
+

Calculates running KL-Divergence loss between batches of audio +preds (generated) and target (ground-truth)

+

Args

+
+
preds : torch.Tensor
+
Audio samples to evaluate, of shape [B, C, T].
+
targets : torch.Tensor
+
Target samples to compare against, of shape [B, C, T].
+
sizes : torch.Tensor
+
Actual audio sample length, of shape [B].
+
sample_rates : torch.Tensor
+
Actual audio sample rate, of shape [B].
+
+
+ +Expand source code + +
def update(self, preds: torch.Tensor, targets: torch.Tensor,
+           sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
+    """Calculates running KL-Divergence loss between batches of audio
+    preds (generated) and target (ground-truth)
+    Args:
+        preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
+        targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
+        sizes (torch.Tensor): Actual audio sample length, of shape [B].
+        sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
+    """
+    assert preds.shape == targets.shape
+    assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
+    preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
+    targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
+    if preds_probs is not None and targets_probs is not None:
+        assert preds_probs.shape == targets_probs.shape
+        kld_scores = kl_divergence(preds_probs, targets_probs)
+        assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
+        self.kld_pq_sum += torch.sum(kld_scores)
+        kld_qp_scores = kl_divergence(targets_probs, preds_probs)
+        self.kld_qp_sum += torch.sum(kld_qp_scores)
+        self.weight += torch.tensor(kld_scores.size(0))
+
+
+
+
+
+class PasstKLDivergenceMetric +(pretrained_length: Optional[float] = None) +
+
+

KL-Divergence metric based on pre-trained PASST classifier on AudioSet.

+

From: PaSST: Efficient Training of Audio Transformers with Patchout +Paper: https://arxiv.org/abs/2110.05069 +Implementation: https://github.com/kkoutini/PaSST

+

Follow instructions from the github repo:

+
pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
+
+

Args

+
+
pretrained_length : float, optional
+
Audio duration used for the pretrained model.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class PasstKLDivergenceMetric(KLDivergenceMetric):
+    """KL-Divergence metric based on pre-trained PASST classifier on AudioSet.
+
+    From: PaSST: Efficient Training of Audio Transformers with Patchout
+    Paper: https://arxiv.org/abs/2110.05069
+    Implementation: https://github.com/kkoutini/PaSST
+
+    Follow instructions from the github repo:
+    ```
+    pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
+    ```
+
+    Args:
+        pretrained_length (float, optional): Audio duration used for the pretrained model.
+    """
+    def __init__(self, pretrained_length: tp.Optional[float] = None):
+        super().__init__()
+        self._initialize_model(pretrained_length)
+
+    def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
+        """Initialize underlying PaSST audio classifier."""
+        model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
+        self.min_input_frames = min_frames
+        self.max_input_frames = max_frames
+        self.model_sample_rate = sr
+        self.model = model
+        self.model.eval()
+        self.model.to(self.device)
+
+    def _load_base_model(self, pretrained_length: tp.Optional[float]):
+        """Load pretrained model from PaSST."""
+        try:
+            if pretrained_length == 30:
+                from hear21passt.base30sec import get_basic_model  # type: ignore
+                max_duration = 30
+            elif pretrained_length == 20:
+                from hear21passt.base20sec import get_basic_model  # type: ignore
+                max_duration = 20
+            else:
+                from hear21passt.base import get_basic_model  # type: ignore
+                # Original PASST was trained on AudioSet with 10s-long audio samples
+                max_duration = 10
+            min_duration = 0.15
+            min_duration = 0.15
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "Please install hear21passt to compute KL divergence: ",
+                "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
+            )
+        model_sample_rate = 32_000
+        max_input_frames = int(max_duration * model_sample_rate)
+        min_input_frames = int(min_duration * model_sample_rate)
+        with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
+            model = get_basic_model(mode='logits')
+        return model, model_sample_rate, max_input_frames, min_input_frames
+
+    def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
+        """Process audio to feed to the pretrained model."""
+        wav = wav.unsqueeze(0)
+        wav = wav[..., :wav_len]
+        wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
+        wav = wav.squeeze(0)
+        # we don't pad but return a list of audio segments as this otherwise affects the KLD computation
+        segments = torch.split(wav, self.max_input_frames, dim=-1)
+        valid_segments = []
+        for s in segments:
+            # ignoring too small segments that are breaking the model inference
+            if s.size(-1) > self.min_input_frames:
+                valid_segments.append(s)
+        return [s[None] for s in valid_segments]
+
+    def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
+        """Run the pretrained model and get the predictions."""
+        assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
+        wav = wav.mean(dim=1)
+        # PaSST is printing a lot of garbage that we are not interested in
+        with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
+            with torch.no_grad(), _patch_passt_stft():
+                logits = self.model(wav.to(self.device))
+                probs = torch.softmax(logits, dim=-1)
+                return probs
+
+    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
+                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
+        """Get model output given provided input tensor.
+
+        Args:
+            x (torch.Tensor): Input audio tensor of shape [B, C, T].
+            sizes (torch.Tensor): Actual audio sample length, of shape [B].
+            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
+        Returns:
+            probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes].
+        """
+        all_probs: tp.List[torch.Tensor] = []
+        for i, wav in enumerate(x):
+            sample_rate = int(sample_rates[i].item())
+            wav_len = int(sizes[i].item())
+            wav_segments = self._process_audio(wav, sample_rate, wav_len)
+            for segment in wav_segments:
+                probs = self._get_model_preds(segment).mean(dim=0)
+                all_probs.append(probs)
+        if len(all_probs) > 0:
+            return torch.stack(all_probs, dim=0)
+        else:
+            return None
+
+

Ancestors

+
    +
  • KLDivergenceMetric
  • +
  • torchmetrics.metric.Metric
  • +
  • torch.nn.modules.module.Module
  • +
  • abc.ABC
  • +
+

Class variables

+
+
var full_state_update : Optional[bool]
+
+
+
+
var higher_is_better : Optional[bool]
+
+
+
+
var is_differentiable : Optional[bool]
+
+
+
+
var plot_legend_name : Optional[str]
+
+
+
+
var plot_lower_bound : Optional[float]
+
+
+
+
var plot_upper_bound : Optional[float]
+
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/metrics/rvm.html b/api_docs/audiocraft/metrics/rvm.html new file mode 100644 index 00000000..5c4b98a4 --- /dev/null +++ b/api_docs/audiocraft/metrics/rvm.html @@ -0,0 +1,447 @@ + + + + + + +audiocraft.metrics.rvm API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.metrics.rvm

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+import torch
+from torch import nn
+import torchaudio
+
+
+def db_to_scale(volume: tp.Union[float, torch.Tensor]):
+    return 10 ** (volume / 20)
+
+
+def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
+    min_scale = db_to_scale(min_volume)
+    return 20 * torch.log10(scale.clamp(min=min_scale))
+
+
+class RelativeVolumeMel(nn.Module):
+    """Relative volume melspectrogram measure.
+
+    Computes a measure of distance over two mel spectrogram that is interpretable in terms
+    of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will
+    first renormalize both by the ground truth of `x_ref`.
+
+    ..Warning:: This class returns the volume of the distortion at the spectrogram level,
+        e.g. low negative values reflects lower distortion levels. For a SNR (like reported
+        in the MultiBandDiffusion paper), just take `-rvm`.
+
+    Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference
+    relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g.
+    clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`)
+    with the goal of avoiding the loss being dominated by parts where the reference is almost silent.
+    Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final
+    average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely
+    good (for a neural network output, although sound engineers typically aim for much lower attenuations).
+    Similarly, anything above +30 dB would just be completely missing the target, and there is no point
+    in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more
+    in line with what neural nets currently can achieve.
+
+    For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between
+    the target and reference mel-spec is 10 dB lower than the reference mel-spec value.
+
+    The metric can be aggregated over a given frequency band in order have different insights for
+    different region of the spectrum. `num_aggregated_bands` controls the number of bands.
+
+    ..Warning:: While this function is optimized for interpretability, nothing was done to ensure it
+        is numerically stable when computing its gradient. We thus advise against using it as a training loss.
+
+    Args:
+        sample_rate (int): Sample rate of the input audio.
+        n_mels (int): Number of mel bands to use.
+        n_fft (int): Number of frequency bins for the STFT.
+        hop_length (int): Hop length of the STFT and the mel-spectrogram.
+        min_relative_volume (float): The error `z_ref - z_est` volume is given relative to
+            the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped.
+        max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that.
+        max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain
+            to that amount, to avoid rescaling near silence. Given in dB.
+        min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume
+            bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram,
+            and anything below that will be considered equally.
+        num_aggregated_bands (int): Number of bands to keep when computing the average RVM value.
+            For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
+    """
+    def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512,
+                 hop_length: int = 128, min_relative_volume: float = -25,
+                 max_relative_volume: float = 25, max_initial_gain: float = 25,
+                 min_activity_volume: float = -25,
+                 num_aggregated_bands: int = 4) -> None:
+        super().__init__()
+        self.melspec = torchaudio.transforms.MelSpectrogram(
+            n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
+            normalized=True, sample_rate=sample_rate, power=2)
+        self.min_relative_volume = min_relative_volume
+        self.max_relative_volume = max_relative_volume
+        self.max_initial_gain = max_initial_gain
+        self.min_activity_volume = min_activity_volume
+        self.num_aggregated_bands = num_aggregated_bands
+
+    def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
+        """Compute RVM metric between estimate and reference samples.
+
+        Args:
+            estimate (torch.Tensor): Estimate sample.
+            ground_truth (torch.Tensor): Reference sample.
+
+        Returns:
+            dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
+            for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
+        """
+        min_scale = db_to_scale(-self.max_initial_gain)
+        std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale)
+        z_gt = self.melspec(ground_truth / std).sqrt()
+        z_est = self.melspec(estimate / std).sqrt()
+
+        delta = z_gt - z_est
+        ref_db = scale_to_db(z_gt, self.min_activity_volume)
+        delta_db = scale_to_db(delta.abs(), min_volume=-120)
+        relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume)
+        dims = list(range(relative_db.dim()))
+        dims.remove(dims[-2])
+        losses_per_band = relative_db.mean(dim=dims)
+        aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)]
+        metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)}
+        metrics['rvm'] = losses_per_band.mean()
+        return metrics
+
+
+
+
+
+
+
+

Functions

+
+
+def db_to_scale(volume: Union[float, torch.Tensor]) +
+
+
+
+ +Expand source code + +
def db_to_scale(volume: tp.Union[float, torch.Tensor]):
+    return 10 ** (volume / 20)
+
+
+
+def scale_to_db(scale: torch.Tensor, min_volume: float = -120) +
+
+
+
+ +Expand source code + +
def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
+    min_scale = db_to_scale(min_volume)
+    return 20 * torch.log10(scale.clamp(min=min_scale))
+
+
+
+
+
+

Classes

+
+
+class RelativeVolumeMel +(sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512, hop_length: int = 128, min_relative_volume: float = -25, max_relative_volume: float = 25, max_initial_gain: float = 25, min_activity_volume: float = -25, num_aggregated_bands: int = 4) +
+
+

Relative volume melspectrogram measure.

+

Computes a measure of distance over two mel spectrogram that is interpretable in terms +of decibels. Given x_ref and x_est two waveforms of shape [*, T], it will +first renormalize both by the ground truth of x_ref.

+
+

Warning: This class returns the volume of the distortion at the spectrogram level,

+

e.g. low negative values reflects lower distortion levels. For a SNR (like reported +in the MultiBandDiffusion paper), just take -rvm.

+
+

Then it computes the mel spectrogram z_ref and z_est and compute volume of the difference +relative to the volume of z_ref for each time-frequency bin. It further adds some limits, e.g. +clamping the values between -25 and 25 dB (controlled by min_relative_volume and max_relative_volume) +with the goal of avoiding the loss being dominated by parts where the reference is almost silent. +Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final +average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely +good (for a neural network output, although sound engineers typically aim for much lower attenuations). +Similarly, anything above +30 dB would just be completely missing the target, and there is no point +in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more +in line with what neural nets currently can achieve.

+

For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between +the target and reference mel-spec is 10 dB lower than the reference mel-spec value.

+

The metric can be aggregated over a given frequency band in order have different insights for +different region of the spectrum. num_aggregated_bands controls the number of bands.

+
+

Warning: While this function is optimized for interpretability, nothing was done to ensure it

+

is numerically stable when computing its gradient. We thus advise against using it as a training loss.

+
+

Args

+
+
sample_rate : int
+
Sample rate of the input audio.
+
n_mels : int
+
Number of mel bands to use.
+
n_fft : int
+
Number of frequency bins for the STFT.
+
hop_length : int
+
Hop length of the STFT and the mel-spectrogram.
+
min_relative_volume : float
+
The error z_ref - z_est volume is given relative to +the volume of z_ref. If error is smaller than -25 dB of z_ref, then it is clamped.
+
max_relative_volume : float
+
Same as min_relative_volume but clamping if the error is larger than that.
+
max_initial_gain : float
+
When rescaling the audio at the very beginning, we will limit the gain +to that amount, to avoid rescaling near silence. Given in dB.
+
min_activity_volume : float
+
When computing the reference level from z_ref, will clamp low volume +bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram, +and anything below that will be considered equally.
+
num_aggregated_bands : int
+
Number of bands to keep when computing the average RVM value. +For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class RelativeVolumeMel(nn.Module):
+    """Relative volume melspectrogram measure.
+
+    Computes a measure of distance over two mel spectrogram that is interpretable in terms
+    of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will
+    first renormalize both by the ground truth of `x_ref`.
+
+    ..Warning:: This class returns the volume of the distortion at the spectrogram level,
+        e.g. low negative values reflects lower distortion levels. For a SNR (like reported
+        in the MultiBandDiffusion paper), just take `-rvm`.
+
+    Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference
+    relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g.
+    clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`)
+    with the goal of avoiding the loss being dominated by parts where the reference is almost silent.
+    Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final
+    average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely
+    good (for a neural network output, although sound engineers typically aim for much lower attenuations).
+    Similarly, anything above +30 dB would just be completely missing the target, and there is no point
+    in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more
+    in line with what neural nets currently can achieve.
+
+    For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between
+    the target and reference mel-spec is 10 dB lower than the reference mel-spec value.
+
+    The metric can be aggregated over a given frequency band in order have different insights for
+    different region of the spectrum. `num_aggregated_bands` controls the number of bands.
+
+    ..Warning:: While this function is optimized for interpretability, nothing was done to ensure it
+        is numerically stable when computing its gradient. We thus advise against using it as a training loss.
+
+    Args:
+        sample_rate (int): Sample rate of the input audio.
+        n_mels (int): Number of mel bands to use.
+        n_fft (int): Number of frequency bins for the STFT.
+        hop_length (int): Hop length of the STFT and the mel-spectrogram.
+        min_relative_volume (float): The error `z_ref - z_est` volume is given relative to
+            the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped.
+        max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that.
+        max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain
+            to that amount, to avoid rescaling near silence. Given in dB.
+        min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume
+            bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram,
+            and anything below that will be considered equally.
+        num_aggregated_bands (int): Number of bands to keep when computing the average RVM value.
+            For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
+    """
+    def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512,
+                 hop_length: int = 128, min_relative_volume: float = -25,
+                 max_relative_volume: float = 25, max_initial_gain: float = 25,
+                 min_activity_volume: float = -25,
+                 num_aggregated_bands: int = 4) -> None:
+        super().__init__()
+        self.melspec = torchaudio.transforms.MelSpectrogram(
+            n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
+            normalized=True, sample_rate=sample_rate, power=2)
+        self.min_relative_volume = min_relative_volume
+        self.max_relative_volume = max_relative_volume
+        self.max_initial_gain = max_initial_gain
+        self.min_activity_volume = min_activity_volume
+        self.num_aggregated_bands = num_aggregated_bands
+
+    def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
+        """Compute RVM metric between estimate and reference samples.
+
+        Args:
+            estimate (torch.Tensor): Estimate sample.
+            ground_truth (torch.Tensor): Reference sample.
+
+        Returns:
+            dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
+            for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
+        """
+        min_scale = db_to_scale(-self.max_initial_gain)
+        std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale)
+        z_gt = self.melspec(ground_truth / std).sqrt()
+        z_est = self.melspec(estimate / std).sqrt()
+
+        delta = z_gt - z_est
+        ref_db = scale_to_db(z_gt, self.min_activity_volume)
+        delta_db = scale_to_db(delta.abs(), min_volume=-120)
+        relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume)
+        dims = list(range(relative_db.dim()))
+        dims.remove(dims[-2])
+        losses_per_band = relative_db.mean(dim=dims)
+        aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)]
+        metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)}
+        metrics['rvm'] = losses_per_band.mean()
+        return metrics
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) ‑> Dict[str, torch.Tensor] +
+
+

Compute RVM metric between estimate and reference samples.

+

Args

+
+
estimate : torch.Tensor
+
Estimate sample.
+
ground_truth : torch.Tensor
+
Reference sample.
+
+

Returns

+
+
dict[str, torch.Tensor]
+
Metrics with keys rvm for the overall average, and rvm_{k}
+
+

for the RVM over the k-th band (k=0..num_aggregated_bands - 1).

+
+ +Expand source code + +
def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
+    """Compute RVM metric between estimate and reference samples.
+
+    Args:
+        estimate (torch.Tensor): Estimate sample.
+        ground_truth (torch.Tensor): Reference sample.
+
+    Returns:
+        dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
+        for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
+    """
+    min_scale = db_to_scale(-self.max_initial_gain)
+    std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale)
+    z_gt = self.melspec(ground_truth / std).sqrt()
+    z_est = self.melspec(estimate / std).sqrt()
+
+    delta = z_gt - z_est
+    ref_db = scale_to_db(z_gt, self.min_activity_volume)
+    delta_db = scale_to_db(delta.abs(), min_volume=-120)
+    relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume)
+    dims = list(range(relative_db.dim()))
+    dims.remove(dims[-2])
+    losses_per_band = relative_db.mean(dim=dims)
+    aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)]
+    metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)}
+    metrics['rvm'] = losses_per_band.mean()
+    return metrics
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/metrics/visqol.html b/api_docs/audiocraft/metrics/visqol.html new file mode 100644 index 00000000..edd75990 --- /dev/null +++ b/api_docs/audiocraft/metrics/visqol.html @@ -0,0 +1,550 @@ + + + + + + +audiocraft.metrics.visqol API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.metrics.visqol

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import csv
+import json
+import logging
+from pathlib import Path
+import tempfile
+import typing as tp
+import subprocess
+import shutil
+
+import torch
+import torchaudio
+
+logger = logging.getLogger(__name__)
+
+
+class ViSQOL:
+    """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.
+
+    To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
+    instructions available in the open source repository: https://github.com/google/visqol
+
+    ViSQOL is capable of running in two modes:
+
+    Audio Mode:
+        When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz.
+        Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
+        Audio mode uses support vector regression, with the maximum range at ~4.75.
+
+    Speech Mode:
+        When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz.
+            Input should be resampled to 16kHz.
+        As part of the speech mode processing, a root mean square implementation for voice activity detection
+            is performed on the reference signal to determine what parts of the signal have voice activity and
+            should therefore be included in the comparison. The signal is normalized before performing the voice
+            activity detection.
+        Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
+        Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.
+
+    For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input
+
+    Args:
+        visqol_bin (str): Path to the ViSQOL binary.
+        mode (str): ViSQOL computation mode, expecting "audio" or "speech".
+        model (str): Name of the model to use for similarity to quality model.
+        debug (bool): Whether to also get debug metrics from ViSQOL or not.
+    """
+    SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000}
+    ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values())
+
+    def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
+                 model: str = "libsvm_nu_svr_model.txt", debug: bool = False):
+        assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}"
+        self.visqol_bin = str(bin)
+        self.visqol_mode = mode
+        self.target_sr = self._get_target_sr(self.visqol_mode)
+        self.model = model
+        self.debug = debug
+        assert Path(self.visqol_model).exists(), \
+            f"Could not find the specified model in ViSQOL install: {self.visqol_model}"
+
+    def _get_target_sr(self, mode: str) -> int:
+        # returns target sampling rate for the corresponding ViSQOL mode.
+        if mode not in ViSQOL.SAMPLE_RATES_MODES:
+            raise ValueError(
+                f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
+            )
+        return ViSQOL.SAMPLE_RATES_MODES[mode]
+
+    def _prepare_files(
+        self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False
+    ):
+        # prepare files for ViSQOL evaluation.
+        assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES
+        assert len(ref_sig) == len(deg_sig), (
+            "Expects same number of ref and degraded inputs",
+            f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}"
+        )
+        # resample audio if needed
+        if sr != target_sr:
+            transform = torchaudio.transforms.Resample(sr, target_sr)
+            pad = int(0.5 * target_sr)
+            rs_ref = []
+            rs_deg = []
+            for i in range(len(ref_sig)):
+                rs_ref_i = transform(ref_sig[i])
+                rs_deg_i = transform(deg_sig[i])
+                if pad_with_silence:
+                    rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0)
+                    rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0)
+                rs_ref.append(rs_ref_i)
+                rs_deg.append(rs_deg_i)
+            ref_sig = torch.stack(rs_ref)
+            deg_sig = torch.stack(rs_deg)
+        # save audio chunks to tmp dir and create csv
+        tmp_dir = Path(tempfile.mkdtemp())
+        try:
+            tmp_input_csv_path = tmp_dir / "input.csv"
+            tmp_results_csv_path = tmp_dir / "results.csv"
+            tmp_debug_json_path = tmp_dir / "debug.json"
+            with open(tmp_input_csv_path, "w") as csv_file:
+                csv_writer = csv.writer(csv_file)
+                csv_writer.writerow(["reference", "degraded"])
+                for i in range(len(ref_sig)):
+                    tmp_ref_filename = tmp_dir / f"ref_{i}.wav"
+                    tmp_deg_filename = tmp_dir / f"deg_{i}.wav"
+                    torchaudio.save(
+                        tmp_ref_filename,
+                        torch.clamp(ref_sig[i], min=-0.99, max=0.99),
+                        sample_rate=target_sr,
+                        bits_per_sample=16,
+                        encoding="PCM_S"
+                    )
+                    torchaudio.save(
+                        tmp_deg_filename,
+                        torch.clamp(deg_sig[i], min=-0.99, max=0.99),
+                        sample_rate=target_sr,
+                        bits_per_sample=16,
+                        encoding="PCM_S"
+                    )
+                    csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)])
+            return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path
+        except Exception as e:
+            logger.error("Exception occurred when preparing files for ViSQOL: %s", e)
+            return tmp_dir, None, None, None
+
+    def _flush_files(self, tmp_dir: tp.Union[Path, str]):
+        # flush tmp files used to compute ViSQOL.
+        shutil.rmtree(str(tmp_dir))
+
+    def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float:
+        # collect results for each evaluated pair and return averaged moslqo score.
+        with open(results_csv_path, "r") as csv_file:
+            reader = csv.DictReader(csv_file)
+            moslqo_scores = [float(row["moslqo"]) for row in reader]
+            if len(moslqo_scores) > 0:
+                return sum(moslqo_scores) / len(moslqo_scores)
+            else:
+                return 0.0
+
+    def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict:
+        # collect debug data for the visqol inference.
+        with open(debug_json_path, "r") as f:
+            data = json.load(f)
+            return data
+
+    @property
+    def visqol_model(self):
+        return f'{self.visqol_bin}/model/{self.model}'
+
+    def _run_visqol(
+        self,
+        input_csv_path: tp.Union[Path, str],
+        results_csv_path: tp.Union[Path, str],
+        debug_csv_path: tp.Optional[tp.Union[Path, str]],
+    ):
+        input_csv_path = str(input_csv_path)
+        results_csv_path = str(results_csv_path)
+        debug_csv_path = str(debug_csv_path)
+        cmd = [
+            f'{self.visqol_bin}/bazel-bin/visqol',
+            '--batch_input_csv', f'{input_csv_path}',
+            '--results_csv', f'{results_csv_path}'
+        ]
+        if debug_csv_path is not None:
+            cmd += ['--output_debug', f'{debug_csv_path}']
+        if self.visqol_mode == "speech":
+            cmd += ['--use_speech_mode']
+        cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
+        result = subprocess.run(cmd, capture_output=True)
+        if result.returncode:
+            logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
+            raise RuntimeError("Error while executing visqol")
+        result.check_returncode()
+
+    def __call__(
+        self,
+        ref_sig: torch.Tensor,
+        deg_sig: torch.Tensor,
+        sr: int,
+        pad_with_silence: bool = False,
+    ):
+        """Calculate the ViSQOL metric for a pair of audio signals at a given sample rate.
+        Args:
+            ref_sig (torch.Tensor): Reference signals as [B, C, T].
+            deg_sig (torch.Tensor): Degraded signals as [B, C, T].
+            sr (int): Sample rate of the two audio signals.
+            pad_with_silence (bool): Whether to pad the file with silences as recommended
+                in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input).
+        Returns:
+            float: The ViSQOL score or mean score for the batch.
+        """
+        logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples")
+        tmp_dir, input_csv, results_csv, debug_json = self._prepare_files(
+            ref_sig, deg_sig, sr, self.target_sr, pad_with_silence
+        )
+        try:
+            if input_csv and results_csv:
+                self._run_visqol(
+                    input_csv,
+                    results_csv,
+                    debug_json if self.debug else None,
+                )
+                mosqol = self._collect_moslqo_score(results_csv)
+                return mosqol
+            else:
+                raise RuntimeError("Something unexpected happened when running VISQOL!")
+        except Exception as e:
+            logger.error("Exception occurred when running ViSQOL: %s", e)
+        finally:
+            self._flush_files(tmp_dir)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class ViSQOL +(bin: Union[str, pathlib.Path], mode: str = 'audio', model: str = 'libsvm_nu_svr_model.txt', debug: bool = False) +
+
+

ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.

+

To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the +instructions available in the open source repository: https://github.com/google/visqol

+

ViSQOL is capable of running in two modes:

+

Audio Mode: +When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz. +Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison. +Audio mode uses support vector regression, with the maximum range at ~4.75.

+

Speech Mode: +When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz. +Input should be resampled to 16kHz. +As part of the speech mode processing, a root mean square implementation for voice activity detection +is performed on the reference signal to determine what parts of the signal have voice activity and +should therefore be included in the comparison. The signal is normalized before performing the voice +activity detection. +Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison. +Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.

+

For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input

+

Args

+
+
visqol_bin : str
+
Path to the ViSQOL binary.
+
mode : str
+
ViSQOL computation mode, expecting "audio" or "speech".
+
model : str
+
Name of the model to use for similarity to quality model.
+
debug : bool
+
Whether to also get debug metrics from ViSQOL or not.
+
+
+ +Expand source code + +
class ViSQOL:
+    """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.
+
+    To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
+    instructions available in the open source repository: https://github.com/google/visqol
+
+    ViSQOL is capable of running in two modes:
+
+    Audio Mode:
+        When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz.
+        Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
+        Audio mode uses support vector regression, with the maximum range at ~4.75.
+
+    Speech Mode:
+        When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz.
+            Input should be resampled to 16kHz.
+        As part of the speech mode processing, a root mean square implementation for voice activity detection
+            is performed on the reference signal to determine what parts of the signal have voice activity and
+            should therefore be included in the comparison. The signal is normalized before performing the voice
+            activity detection.
+        Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
+        Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.
+
+    For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input
+
+    Args:
+        visqol_bin (str): Path to the ViSQOL binary.
+        mode (str): ViSQOL computation mode, expecting "audio" or "speech".
+        model (str): Name of the model to use for similarity to quality model.
+        debug (bool): Whether to also get debug metrics from ViSQOL or not.
+    """
+    SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000}
+    ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values())
+
+    def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
+                 model: str = "libsvm_nu_svr_model.txt", debug: bool = False):
+        assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}"
+        self.visqol_bin = str(bin)
+        self.visqol_mode = mode
+        self.target_sr = self._get_target_sr(self.visqol_mode)
+        self.model = model
+        self.debug = debug
+        assert Path(self.visqol_model).exists(), \
+            f"Could not find the specified model in ViSQOL install: {self.visqol_model}"
+
+    def _get_target_sr(self, mode: str) -> int:
+        # returns target sampling rate for the corresponding ViSQOL mode.
+        if mode not in ViSQOL.SAMPLE_RATES_MODES:
+            raise ValueError(
+                f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
+            )
+        return ViSQOL.SAMPLE_RATES_MODES[mode]
+
+    def _prepare_files(
+        self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False
+    ):
+        # prepare files for ViSQOL evaluation.
+        assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES
+        assert len(ref_sig) == len(deg_sig), (
+            "Expects same number of ref and degraded inputs",
+            f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}"
+        )
+        # resample audio if needed
+        if sr != target_sr:
+            transform = torchaudio.transforms.Resample(sr, target_sr)
+            pad = int(0.5 * target_sr)
+            rs_ref = []
+            rs_deg = []
+            for i in range(len(ref_sig)):
+                rs_ref_i = transform(ref_sig[i])
+                rs_deg_i = transform(deg_sig[i])
+                if pad_with_silence:
+                    rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0)
+                    rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0)
+                rs_ref.append(rs_ref_i)
+                rs_deg.append(rs_deg_i)
+            ref_sig = torch.stack(rs_ref)
+            deg_sig = torch.stack(rs_deg)
+        # save audio chunks to tmp dir and create csv
+        tmp_dir = Path(tempfile.mkdtemp())
+        try:
+            tmp_input_csv_path = tmp_dir / "input.csv"
+            tmp_results_csv_path = tmp_dir / "results.csv"
+            tmp_debug_json_path = tmp_dir / "debug.json"
+            with open(tmp_input_csv_path, "w") as csv_file:
+                csv_writer = csv.writer(csv_file)
+                csv_writer.writerow(["reference", "degraded"])
+                for i in range(len(ref_sig)):
+                    tmp_ref_filename = tmp_dir / f"ref_{i}.wav"
+                    tmp_deg_filename = tmp_dir / f"deg_{i}.wav"
+                    torchaudio.save(
+                        tmp_ref_filename,
+                        torch.clamp(ref_sig[i], min=-0.99, max=0.99),
+                        sample_rate=target_sr,
+                        bits_per_sample=16,
+                        encoding="PCM_S"
+                    )
+                    torchaudio.save(
+                        tmp_deg_filename,
+                        torch.clamp(deg_sig[i], min=-0.99, max=0.99),
+                        sample_rate=target_sr,
+                        bits_per_sample=16,
+                        encoding="PCM_S"
+                    )
+                    csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)])
+            return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path
+        except Exception as e:
+            logger.error("Exception occurred when preparing files for ViSQOL: %s", e)
+            return tmp_dir, None, None, None
+
+    def _flush_files(self, tmp_dir: tp.Union[Path, str]):
+        # flush tmp files used to compute ViSQOL.
+        shutil.rmtree(str(tmp_dir))
+
+    def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float:
+        # collect results for each evaluated pair and return averaged moslqo score.
+        with open(results_csv_path, "r") as csv_file:
+            reader = csv.DictReader(csv_file)
+            moslqo_scores = [float(row["moslqo"]) for row in reader]
+            if len(moslqo_scores) > 0:
+                return sum(moslqo_scores) / len(moslqo_scores)
+            else:
+                return 0.0
+
+    def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict:
+        # collect debug data for the visqol inference.
+        with open(debug_json_path, "r") as f:
+            data = json.load(f)
+            return data
+
+    @property
+    def visqol_model(self):
+        return f'{self.visqol_bin}/model/{self.model}'
+
+    def _run_visqol(
+        self,
+        input_csv_path: tp.Union[Path, str],
+        results_csv_path: tp.Union[Path, str],
+        debug_csv_path: tp.Optional[tp.Union[Path, str]],
+    ):
+        input_csv_path = str(input_csv_path)
+        results_csv_path = str(results_csv_path)
+        debug_csv_path = str(debug_csv_path)
+        cmd = [
+            f'{self.visqol_bin}/bazel-bin/visqol',
+            '--batch_input_csv', f'{input_csv_path}',
+            '--results_csv', f'{results_csv_path}'
+        ]
+        if debug_csv_path is not None:
+            cmd += ['--output_debug', f'{debug_csv_path}']
+        if self.visqol_mode == "speech":
+            cmd += ['--use_speech_mode']
+        cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
+        result = subprocess.run(cmd, capture_output=True)
+        if result.returncode:
+            logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
+            raise RuntimeError("Error while executing visqol")
+        result.check_returncode()
+
+    def __call__(
+        self,
+        ref_sig: torch.Tensor,
+        deg_sig: torch.Tensor,
+        sr: int,
+        pad_with_silence: bool = False,
+    ):
+        """Calculate the ViSQOL metric for a pair of audio signals at a given sample rate.
+        Args:
+            ref_sig (torch.Tensor): Reference signals as [B, C, T].
+            deg_sig (torch.Tensor): Degraded signals as [B, C, T].
+            sr (int): Sample rate of the two audio signals.
+            pad_with_silence (bool): Whether to pad the file with silences as recommended
+                in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input).
+        Returns:
+            float: The ViSQOL score or mean score for the batch.
+        """
+        logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples")
+        tmp_dir, input_csv, results_csv, debug_json = self._prepare_files(
+            ref_sig, deg_sig, sr, self.target_sr, pad_with_silence
+        )
+        try:
+            if input_csv and results_csv:
+                self._run_visqol(
+                    input_csv,
+                    results_csv,
+                    debug_json if self.debug else None,
+                )
+                mosqol = self._collect_moslqo_score(results_csv)
+                return mosqol
+            else:
+                raise RuntimeError("Something unexpected happened when running VISQOL!")
+        except Exception as e:
+            logger.error("Exception occurred when running ViSQOL: %s", e)
+        finally:
+            self._flush_files(tmp_dir)
+
+

Class variables

+
+
var ALLOWED_SAMPLE_RATES
+
+
+
+
var SAMPLE_RATES_MODES
+
+
+
+
+

Instance variables

+
+
var visqol_model
+
+
+
+ +Expand source code + +
@property
+def visqol_model(self):
+    return f'{self.visqol_bin}/model/{self.model}'
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/models/audiogen.html b/api_docs/audiocraft/models/audiogen.html new file mode 100644 index 00000000..085d461d --- /dev/null +++ b/api_docs/audiocraft/models/audiogen.html @@ -0,0 +1,852 @@ + + + + + + +audiocraft.models.audiogen API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.audiogen

+
+
+

Main model for using AudioGen. This will combine all the required components +and provide easy access to the generation API.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Main model for using AudioGen. This will combine all the required components
+and provide easy access to the generation API.
+"""
+
+import typing as tp
+
+import torch
+
+from .encodec import CompressionModel
+from .lm import LMModel
+from .builders import get_debug_compression_model, get_debug_lm_model
+from .loaders import load_compression_model, load_lm_model
+from ..data.audio_utils import convert_audio
+from ..modules.conditioners import ConditioningAttributes
+from ..utils.autocast import TorchAutocast
+
+
+class AudioGen:
+    """AudioGen main model with convenient generation API.
+
+    Args:
+        name (str): name of the model.
+        compression_model (CompressionModel): Compression model
+            used to map audio to invertible discrete representations.
+        lm (LMModel): Language model over discrete representations.
+        max_duration (float, optional): maximum duration the model can produce,
+            otherwise, inferred from the training params.
+    """
+    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
+                 max_duration: tp.Optional[float] = None):
+        self.name = name
+        self.compression_model = compression_model
+        self.lm = lm
+        # Just to be safe, let's put everything in eval mode.
+        self.compression_model.eval()
+        self.lm.eval()
+
+        if max_duration is None:
+            if hasattr(lm, 'cfg'):
+                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
+            else:
+                raise ValueError("You must provide max_duration when building directly AudioGen")
+        assert max_duration is not None
+        self.max_duration: float = max_duration
+        self.device = next(iter(lm.parameters())).device
+        self.generation_params: dict = {}
+        self.set_generation_params(duration=5)  # 5 seconds by default
+        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+        if self.device.type == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+        else:
+            self.autocast = TorchAutocast(
+                enabled=True, device_type=self.device.type, dtype=torch.float16)
+
+    @property
+    def frame_rate(self) -> float:
+        """Roughly the number of AR steps per seconds."""
+        return self.compression_model.frame_rate
+
+    @property
+    def sample_rate(self) -> int:
+        """Sample rate of the generated audio."""
+        return self.compression_model.sample_rate
+
+    @property
+    def audio_channels(self) -> int:
+        """Audio channels of the generated audio."""
+        return self.compression_model.channels
+
+    @staticmethod
+    def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
+        """Return pretrained model, we provide a single model for now:
+        - facebook/audiogen-medium (1.5B), text to sound,
+          # see: https://huggingface.co/facebook/audiogen-medium
+        """
+        if device is None:
+            if torch.cuda.device_count():
+                device = 'cuda'
+            else:
+                device = 'cpu'
+
+        if name == 'debug':
+            # used only for unit tests
+            compression_model = get_debug_compression_model(device, sample_rate=16000)
+            lm = get_debug_lm_model(device)
+            return AudioGen(name, compression_model, lm, max_duration=10)
+
+        compression_model = load_compression_model(name, device=device)
+        lm = load_lm_model(name, device=device)
+        assert 'self_wav' not in lm.condition_provider.conditioners, \
+            "AudioGen do not support waveform conditioning for now"
+        return AudioGen(name, compression_model, lm)
+
+    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+                              top_p: float = 0.0, temperature: float = 1.0,
+                              duration: float = 10.0, cfg_coef: float = 3.0,
+                              two_step_cfg: bool = False, extend_stride: float = 2):
+        """Set the generation parameters for AudioGen.
+
+        Args:
+            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+            top_k (int, optional): top_k used for sampling. Defaults to 250.
+            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+            duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
+            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+                instead of batching together the two. This has some impact on how things
+                are padded but seems to have little impact in practice.
+            extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much
+                should we extend the audio each time. Larger values will mean less context is
+                preserved, and shorter value will require extra computations.
+        """
+        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+        self.extend_stride = extend_stride
+        self.duration = duration
+        self.generation_params = {
+            'use_sampling': use_sampling,
+            'temp': temperature,
+            'top_k': top_k,
+            'top_p': top_p,
+            'cfg_coef': cfg_coef,
+            'two_step_cfg': two_step_cfg,
+        }
+
+    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+        """Override the default progress callback."""
+        self._progress_callback = progress_callback
+
+    def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
+        """Generate samples conditioned on text.
+
+        Args:
+            descriptions (list of str): A list of strings used as text conditioning.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+        assert prompt_tokens is None
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+                              progress: bool = False) -> torch.Tensor:
+        """Generate samples conditioned on audio prompts.
+
+        Args:
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+            descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        if prompt.dim() == 2:
+            prompt = prompt[None]
+        if prompt.dim() != 3:
+            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+        if descriptions is None:
+            descriptions = [None] * len(prompt)
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+        assert prompt_tokens is not None
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    @torch.no_grad()
+    def _prepare_tokens_and_attributes(
+            self,
+            descriptions: tp.Sequence[tp.Optional[str]],
+            prompt: tp.Optional[torch.Tensor],
+    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
+        """Prepare model inputs.
+
+        Args:
+            descriptions (list of str): A list of strings used as text conditioning.
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+        """
+        attributes = [
+            ConditioningAttributes(text={'description': description})
+            for description in descriptions]
+
+        if prompt is not None:
+            if descriptions is not None:
+                assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
+            prompt = prompt.to(self.device)
+            prompt_tokens, scale = self.compression_model.encode(prompt)
+            assert scale is None
+        else:
+            prompt_tokens = None
+        return attributes, prompt_tokens
+
+    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
+                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+        """Generate discrete audio tokens given audio prompt and/or conditions.
+
+        Args:
+            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
+            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        Returns:
+            torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
+        """
+        total_gen_len = int(self.duration * self.frame_rate)
+        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
+        current_gen_offset: int = 0
+
+        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
+            generated_tokens += current_gen_offset
+            if self._progress_callback is not None:
+                # Note that total_gen_len might be quite wrong depending on the
+                # codebook pattern used, but with delay it is almost accurate.
+                self._progress_callback(generated_tokens, total_gen_len)
+            else:
+                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
+
+        if prompt_tokens is not None:
+            assert max_prompt_len >= prompt_tokens.shape[-1], \
+                "Prompt is longer than audio to generate"
+
+        callback = None
+        if progress:
+            callback = _progress_callback
+
+        if self.duration <= self.max_duration:
+            # generate by sampling from LM, simple case.
+            with self.autocast:
+                gen_tokens = self.lm.generate(
+                    prompt_tokens, attributes,
+                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
+
+        else:
+            all_tokens = []
+            if prompt_tokens is None:
+                prompt_length = 0
+            else:
+                all_tokens.append(prompt_tokens)
+                prompt_length = prompt_tokens.shape[-1]
+
+            stride_tokens = int(self.frame_rate * self.extend_stride)
+            while current_gen_offset + prompt_length < total_gen_len:
+                time_offset = current_gen_offset / self.frame_rate
+                chunk_duration = min(self.duration - time_offset, self.max_duration)
+                max_gen_len = int(chunk_duration * self.frame_rate)
+                with self.autocast:
+                    gen_tokens = self.lm.generate(
+                        prompt_tokens, attributes,
+                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
+                if prompt_tokens is None:
+                    all_tokens.append(gen_tokens)
+                else:
+                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
+                prompt_tokens = gen_tokens[:, :, stride_tokens:]
+                prompt_length = prompt_tokens.shape[-1]
+                current_gen_offset += stride_tokens
+
+            gen_tokens = torch.cat(all_tokens, dim=-1)
+
+        # generate audio
+        assert gen_tokens.dim() == 3
+        with torch.no_grad():
+            gen_audio = self.compression_model.decode(gen_tokens, None)
+        return gen_audio
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class AudioGen +(name: str, compression_model: CompressionModel, lm: LMModel, max_duration: Optional[float] = None) +
+
+

AudioGen main model with convenient generation API.

+

Args

+
+
name : str
+
name of the model.
+
compression_model : CompressionModel
+
Compression model +used to map audio to invertible discrete representations.
+
lm : LMModel
+
Language model over discrete representations.
+
max_duration : float, optional
+
maximum duration the model can produce, +otherwise, inferred from the training params.
+
+
+ +Expand source code + +
class AudioGen:
+    """AudioGen main model with convenient generation API.
+
+    Args:
+        name (str): name of the model.
+        compression_model (CompressionModel): Compression model
+            used to map audio to invertible discrete representations.
+        lm (LMModel): Language model over discrete representations.
+        max_duration (float, optional): maximum duration the model can produce,
+            otherwise, inferred from the training params.
+    """
+    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
+                 max_duration: tp.Optional[float] = None):
+        self.name = name
+        self.compression_model = compression_model
+        self.lm = lm
+        # Just to be safe, let's put everything in eval mode.
+        self.compression_model.eval()
+        self.lm.eval()
+
+        if max_duration is None:
+            if hasattr(lm, 'cfg'):
+                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
+            else:
+                raise ValueError("You must provide max_duration when building directly AudioGen")
+        assert max_duration is not None
+        self.max_duration: float = max_duration
+        self.device = next(iter(lm.parameters())).device
+        self.generation_params: dict = {}
+        self.set_generation_params(duration=5)  # 5 seconds by default
+        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+        if self.device.type == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+        else:
+            self.autocast = TorchAutocast(
+                enabled=True, device_type=self.device.type, dtype=torch.float16)
+
+    @property
+    def frame_rate(self) -> float:
+        """Roughly the number of AR steps per seconds."""
+        return self.compression_model.frame_rate
+
+    @property
+    def sample_rate(self) -> int:
+        """Sample rate of the generated audio."""
+        return self.compression_model.sample_rate
+
+    @property
+    def audio_channels(self) -> int:
+        """Audio channels of the generated audio."""
+        return self.compression_model.channels
+
+    @staticmethod
+    def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
+        """Return pretrained model, we provide a single model for now:
+        - facebook/audiogen-medium (1.5B), text to sound,
+          # see: https://huggingface.co/facebook/audiogen-medium
+        """
+        if device is None:
+            if torch.cuda.device_count():
+                device = 'cuda'
+            else:
+                device = 'cpu'
+
+        if name == 'debug':
+            # used only for unit tests
+            compression_model = get_debug_compression_model(device, sample_rate=16000)
+            lm = get_debug_lm_model(device)
+            return AudioGen(name, compression_model, lm, max_duration=10)
+
+        compression_model = load_compression_model(name, device=device)
+        lm = load_lm_model(name, device=device)
+        assert 'self_wav' not in lm.condition_provider.conditioners, \
+            "AudioGen do not support waveform conditioning for now"
+        return AudioGen(name, compression_model, lm)
+
+    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+                              top_p: float = 0.0, temperature: float = 1.0,
+                              duration: float = 10.0, cfg_coef: float = 3.0,
+                              two_step_cfg: bool = False, extend_stride: float = 2):
+        """Set the generation parameters for AudioGen.
+
+        Args:
+            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+            top_k (int, optional): top_k used for sampling. Defaults to 250.
+            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+            duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
+            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+                instead of batching together the two. This has some impact on how things
+                are padded but seems to have little impact in practice.
+            extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much
+                should we extend the audio each time. Larger values will mean less context is
+                preserved, and shorter value will require extra computations.
+        """
+        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+        self.extend_stride = extend_stride
+        self.duration = duration
+        self.generation_params = {
+            'use_sampling': use_sampling,
+            'temp': temperature,
+            'top_k': top_k,
+            'top_p': top_p,
+            'cfg_coef': cfg_coef,
+            'two_step_cfg': two_step_cfg,
+        }
+
+    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+        """Override the default progress callback."""
+        self._progress_callback = progress_callback
+
+    def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
+        """Generate samples conditioned on text.
+
+        Args:
+            descriptions (list of str): A list of strings used as text conditioning.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+        assert prompt_tokens is None
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+                              progress: bool = False) -> torch.Tensor:
+        """Generate samples conditioned on audio prompts.
+
+        Args:
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+            descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        if prompt.dim() == 2:
+            prompt = prompt[None]
+        if prompt.dim() != 3:
+            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+        if descriptions is None:
+            descriptions = [None] * len(prompt)
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+        assert prompt_tokens is not None
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    @torch.no_grad()
+    def _prepare_tokens_and_attributes(
+            self,
+            descriptions: tp.Sequence[tp.Optional[str]],
+            prompt: tp.Optional[torch.Tensor],
+    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
+        """Prepare model inputs.
+
+        Args:
+            descriptions (list of str): A list of strings used as text conditioning.
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+        """
+        attributes = [
+            ConditioningAttributes(text={'description': description})
+            for description in descriptions]
+
+        if prompt is not None:
+            if descriptions is not None:
+                assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
+            prompt = prompt.to(self.device)
+            prompt_tokens, scale = self.compression_model.encode(prompt)
+            assert scale is None
+        else:
+            prompt_tokens = None
+        return attributes, prompt_tokens
+
+    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
+                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+        """Generate discrete audio tokens given audio prompt and/or conditions.
+
+        Args:
+            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
+            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        Returns:
+            torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
+        """
+        total_gen_len = int(self.duration * self.frame_rate)
+        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
+        current_gen_offset: int = 0
+
+        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
+            generated_tokens += current_gen_offset
+            if self._progress_callback is not None:
+                # Note that total_gen_len might be quite wrong depending on the
+                # codebook pattern used, but with delay it is almost accurate.
+                self._progress_callback(generated_tokens, total_gen_len)
+            else:
+                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
+
+        if prompt_tokens is not None:
+            assert max_prompt_len >= prompt_tokens.shape[-1], \
+                "Prompt is longer than audio to generate"
+
+        callback = None
+        if progress:
+            callback = _progress_callback
+
+        if self.duration <= self.max_duration:
+            # generate by sampling from LM, simple case.
+            with self.autocast:
+                gen_tokens = self.lm.generate(
+                    prompt_tokens, attributes,
+                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
+
+        else:
+            all_tokens = []
+            if prompt_tokens is None:
+                prompt_length = 0
+            else:
+                all_tokens.append(prompt_tokens)
+                prompt_length = prompt_tokens.shape[-1]
+
+            stride_tokens = int(self.frame_rate * self.extend_stride)
+            while current_gen_offset + prompt_length < total_gen_len:
+                time_offset = current_gen_offset / self.frame_rate
+                chunk_duration = min(self.duration - time_offset, self.max_duration)
+                max_gen_len = int(chunk_duration * self.frame_rate)
+                with self.autocast:
+                    gen_tokens = self.lm.generate(
+                        prompt_tokens, attributes,
+                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
+                if prompt_tokens is None:
+                    all_tokens.append(gen_tokens)
+                else:
+                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
+                prompt_tokens = gen_tokens[:, :, stride_tokens:]
+                prompt_length = prompt_tokens.shape[-1]
+                current_gen_offset += stride_tokens
+
+            gen_tokens = torch.cat(all_tokens, dim=-1)
+
+        # generate audio
+        assert gen_tokens.dim() == 3
+        with torch.no_grad():
+            gen_audio = self.compression_model.decode(gen_tokens, None)
+        return gen_audio
+
+

Static methods

+
+
+def get_pretrained(name: str = 'facebook/audiogen-medium', device=None) +
+
+

Return pretrained model, we provide a single model for now: +- facebook/audiogen-medium (1.5B), text to sound, +# see: https://huggingface.co/facebook/audiogen-medium

+
+ +Expand source code + +
@staticmethod
+def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
+    """Return pretrained model, we provide a single model for now:
+    - facebook/audiogen-medium (1.5B), text to sound,
+      # see: https://huggingface.co/facebook/audiogen-medium
+    """
+    if device is None:
+        if torch.cuda.device_count():
+            device = 'cuda'
+        else:
+            device = 'cpu'
+
+    if name == 'debug':
+        # used only for unit tests
+        compression_model = get_debug_compression_model(device, sample_rate=16000)
+        lm = get_debug_lm_model(device)
+        return AudioGen(name, compression_model, lm, max_duration=10)
+
+    compression_model = load_compression_model(name, device=device)
+    lm = load_lm_model(name, device=device)
+    assert 'self_wav' not in lm.condition_provider.conditioners, \
+        "AudioGen do not support waveform conditioning for now"
+    return AudioGen(name, compression_model, lm)
+
+
+
+

Instance variables

+
+
var audio_channels : int
+
+

Audio channels of the generated audio.

+
+ +Expand source code + +
@property
+def audio_channels(self) -> int:
+    """Audio channels of the generated audio."""
+    return self.compression_model.channels
+
+
+
var frame_rate : float
+
+

Roughly the number of AR steps per seconds.

+
+ +Expand source code + +
@property
+def frame_rate(self) -> float:
+    """Roughly the number of AR steps per seconds."""
+    return self.compression_model.frame_rate
+
+
+
var sample_rate : int
+
+

Sample rate of the generated audio.

+
+ +Expand source code + +
@property
+def sample_rate(self) -> int:
+    """Sample rate of the generated audio."""
+    return self.compression_model.sample_rate
+
+
+
+

Methods

+
+
+def generate(self, descriptions: List[str], progress: bool = False) ‑> torch.Tensor +
+
+

Generate samples conditioned on text.

+

Args

+
+
descriptions : list of str
+
A list of strings used as text conditioning.
+
progress : bool, optional
+
Flag to display progress of the generation process. Defaults to False.
+
+
+ +Expand source code + +
def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
+    """Generate samples conditioned on text.
+
+    Args:
+        descriptions (list of str): A list of strings used as text conditioning.
+        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+    """
+    attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+    assert prompt_tokens is None
+    return self._generate_tokens(attributes, prompt_tokens, progress)
+
+
+
+def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, descriptions: Optional[List[Optional[str]]] = None, progress: bool = False) ‑> torch.Tensor +
+
+

Generate samples conditioned on audio prompts.

+

Args

+
+
prompt : torch.Tensor
+
A batch of waveforms used for continuation. +Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+
prompt_sample_rate : int
+
Sampling rate of the given audio waveforms.
+
descriptions : list of str, optional
+
A list of strings used as text conditioning. Defaults to None.
+
progress : bool, optional
+
Flag to display progress of the generation process. Defaults to False.
+
+
+ +Expand source code + +
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+                          descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+                          progress: bool = False) -> torch.Tensor:
+    """Generate samples conditioned on audio prompts.
+
+    Args:
+        prompt (torch.Tensor): A batch of waveforms used for continuation.
+            Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+        prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+        descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
+        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+    """
+    if prompt.dim() == 2:
+        prompt = prompt[None]
+    if prompt.dim() != 3:
+        raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+    prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+    if descriptions is None:
+        descriptions = [None] * len(prompt)
+    attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+    assert prompt_tokens is not None
+    return self._generate_tokens(attributes, prompt_tokens, progress)
+
+
+
+def set_custom_progress_callback(self, progress_callback: Optional[Callable[[int, int], None]] = None) +
+
+

Override the default progress callback.

+
+ +Expand source code + +
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+    """Override the default progress callback."""
+    self._progress_callback = progress_callback
+
+
+
+def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, top_p: float = 0.0, temperature: float = 1.0, duration: float = 10.0, cfg_coef: float = 3.0, two_step_cfg: bool = False, extend_stride: float = 2) +
+
+

Set the generation parameters for AudioGen.

+

Args

+
+
use_sampling : bool, optional
+
Use sampling if True, else do argmax decoding. Defaults to True.
+
top_k : int, optional
+
top_k used for sampling. Defaults to 250.
+
top_p : float, optional
+
top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+
temperature : float, optional
+
Softmax temperature parameter. Defaults to 1.0.
+
duration : float, optional
+
Duration of the generated waveform. Defaults to 10.0.
+
cfg_coef : float, optional
+
Coefficient used for classifier free guidance. Defaults to 3.0.
+
two_step_cfg : bool, optional
+
If True, performs 2 forward for Classifier Free Guidance, +instead of batching together the two. This has some impact on how things +are padded but seems to have little impact in practice.
+
extend_stride
+
when doing extended generation (i.e. more than 10 seconds), by how much +should we extend the audio each time. Larger values will mean less context is +preserved, and shorter value will require extra computations.
+
+
+ +Expand source code + +
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+                          top_p: float = 0.0, temperature: float = 1.0,
+                          duration: float = 10.0, cfg_coef: float = 3.0,
+                          two_step_cfg: bool = False, extend_stride: float = 2):
+    """Set the generation parameters for AudioGen.
+
+    Args:
+        use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+        top_k (int, optional): top_k used for sampling. Defaults to 250.
+        top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+        temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+        duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
+        cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+        two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+            instead of batching together the two. This has some impact on how things
+            are padded but seems to have little impact in practice.
+        extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much
+            should we extend the audio each time. Larger values will mean less context is
+            preserved, and shorter value will require extra computations.
+    """
+    assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+    self.extend_stride = extend_stride
+    self.duration = duration
+    self.generation_params = {
+        'use_sampling': use_sampling,
+        'temp': temperature,
+        'top_k': top_k,
+        'top_p': top_p,
+        'cfg_coef': cfg_coef,
+        'two_step_cfg': two_step_cfg,
+    }
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/models/builders.html b/api_docs/audiocraft/models/builders.html new file mode 100644 index 00000000..96aff900 --- /dev/null +++ b/api_docs/audiocraft/models/builders.html @@ -0,0 +1,650 @@ + + + + + + +audiocraft.models.builders API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.builders

+
+
+

All the functions to build the relevant models and modules +from the Hydra config.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+All the functions to build the relevant models and modules
+from the Hydra config.
+"""
+
+import typing as tp
+
+import audiocraft
+import omegaconf
+import torch
+
+from .encodec import CompressionModel, EncodecModel
+from .lm import LMModel
+from ..modules.codebooks_patterns import (
+    CodebooksPatternProvider,
+    DelayedPatternProvider,
+    MusicLMPattern,
+    ParallelPatternProvider,
+    UnrolledPatternProvider,
+    CoarseFirstPattern,
+)
+from ..modules.conditioners import (
+    BaseConditioner,
+    ChromaStemConditioner,
+    CLAPEmbeddingConditioner,
+    ConditionFuser,
+    ConditioningProvider,
+    LUTConditioner,
+    T5Conditioner,
+)
+from .unet import DiffusionUnet
+from .. import quantization as qt
+from ..utils.utils import dict_from_config
+from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
+
+
+def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
+    klass = {
+        'no_quant': qt.DummyQuantizer,
+        'rvq': qt.ResidualVectorQuantizer
+    }[quantizer]
+    kwargs = dict_from_config(getattr(cfg, quantizer))
+    if quantizer != 'no_quant':
+        kwargs['dimension'] = dimension
+    return klass(**kwargs)
+
+
+def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
+    if encoder_name == 'seanet':
+        kwargs = dict_from_config(getattr(cfg, 'seanet'))
+        encoder_override_kwargs = kwargs.pop('encoder')
+        decoder_override_kwargs = kwargs.pop('decoder')
+        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
+        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+        encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
+        decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
+        return encoder, decoder
+    else:
+        raise KeyError(f"Unexpected compression model {cfg.compression_model}")
+
+
+def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
+    """Instantiate a compression model."""
+    if cfg.compression_model == 'encodec':
+        kwargs = dict_from_config(getattr(cfg, 'encodec'))
+        encoder_name = kwargs.pop('autoencoder')
+        quantizer_name = kwargs.pop('quantizer')
+        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
+        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
+        frame_rate = kwargs['sample_rate'] // encoder.hop_length
+        renormalize = kwargs.pop('renormalize', False)
+        # deprecated params
+        kwargs.pop('renorm', None)
+        return EncodecModel(encoder, decoder, quantizer,
+                            frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
+    else:
+        raise KeyError(f"Unexpected compression model {cfg.compression_model}")
+
+
+def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
+    """Instantiate a transformer LM."""
+    if cfg.lm_model == 'transformer_lm':
+        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
+        n_q = kwargs['n_q']
+        q_modeling = kwargs.pop('q_modeling', None)
+        codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
+        attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
+        cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
+        cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
+        fuser = get_condition_fuser(cfg)
+        condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
+        if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-att programmatically
+            kwargs['cross_attention'] = True
+        if codebooks_pattern_cfg.modeling is None:
+            assert q_modeling is not None, \
+                "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
+            codebooks_pattern_cfg = omegaconf.OmegaConf.create(
+                {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
+            )
+        pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
+        return LMModel(
+            pattern_provider=pattern_provider,
+            condition_provider=condition_provider,
+            fuser=fuser,
+            cfg_dropout=cfg_prob,
+            cfg_coef=cfg_coef,
+            attribute_dropout=attribute_dropout,
+            dtype=getattr(torch, cfg.dtype),
+            device=cfg.device,
+            **kwargs
+        ).to(cfg.device)
+    else:
+        raise KeyError(f"Unexpected LM model {cfg.lm_model}")
+
+
+def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
+    """Instantiate a conditioning model."""
+    device = cfg.device
+    duration = cfg.dataset.segment_duration
+    cfg = getattr(cfg, 'conditioners')
+    dict_cfg = {} if cfg is None else dict_from_config(cfg)
+    conditioners: tp.Dict[str, BaseConditioner] = {}
+    condition_provider_args = dict_cfg.pop('args', {})
+    condition_provider_args.pop('merge_text_conditions_p', None)
+    condition_provider_args.pop('drop_desc_p', None)
+
+    for cond, cond_cfg in dict_cfg.items():
+        model_type = cond_cfg['model']
+        model_args = cond_cfg[model_type]
+        if model_type == 't5':
+            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
+        elif model_type == 'lut':
+            conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
+        elif model_type == 'chroma_stem':
+            conditioners[str(cond)] = ChromaStemConditioner(
+                output_dim=output_dim,
+                duration=duration,
+                device=device,
+                **model_args
+            )
+        elif model_type == 'clap':
+            conditioners[str(cond)] = CLAPEmbeddingConditioner(
+                output_dim=output_dim,
+                device=device,
+                **model_args
+            )
+        else:
+            raise ValueError(f"Unrecognized conditioning model: {model_type}")
+    conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
+    return conditioner
+
+
+def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
+    """Instantiate a condition fuser object."""
+    fuser_cfg = getattr(cfg, 'fuser')
+    fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
+    fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
+    kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
+    fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
+    return fuser
+
+
+def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
+    """Instantiate a codebooks pattern provider object."""
+    pattern_providers = {
+        'parallel': ParallelPatternProvider,
+        'delay': DelayedPatternProvider,
+        'unroll': UnrolledPatternProvider,
+        'coarse_first': CoarseFirstPattern,
+        'musiclm': MusicLMPattern,
+    }
+    name = cfg.modeling
+    kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
+    klass = pattern_providers[name]
+    return klass(n_q, **kwargs)
+
+
+def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
+    """Instantiate a debug compression model to be used for unit tests."""
+    assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
+    model_ratios = {
+        16000: [10, 8, 8],  # 25 Hz at 16kHz
+        32000: [10, 8, 16]  # 25 Hz at 32kHz
+    }
+    ratios: tp.List[int] = model_ratios[sample_rate]
+    frame_rate = 25
+    seanet_kwargs: dict = {
+        'n_filters': 4,
+        'n_residual_layers': 1,
+        'dimension': 32,
+        'ratios': ratios,
+    }
+    encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
+    decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
+    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
+    init_x = torch.randn(8, 32, 128)
+    quantizer(init_x, 1)  # initialize kmeans etc.
+    compression_model = EncodecModel(
+        encoder, decoder, quantizer,
+        frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
+    return compression_model.eval()
+
+
+def get_diffusion_model(cfg: omegaconf.DictConfig):
+    # TODO Find a way to infer the channels from dset
+    channels = cfg.channels
+    num_steps = cfg.schedule.num_steps
+    return DiffusionUnet(
+            chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
+
+
+def get_processor(cfg, sample_rate: int = 24000):
+    sample_processor = SampleProcessor()
+    if cfg.use:
+        kw = dict(cfg)
+        kw.pop('use')
+        kw.pop('name')
+        if cfg.name == "multi_band_processor":
+            sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
+    return sample_processor
+
+
+def get_debug_lm_model(device='cpu'):
+    """Instantiate a debug LM to be used for unit tests."""
+    pattern = DelayedPatternProvider(n_q=4)
+    dim = 16
+    providers = {
+        'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
+    }
+    condition_provider = ConditioningProvider(providers)
+    fuser = ConditionFuser(
+        {'cross': ['description'], 'prepend': [],
+         'sum': [], 'input_interpolate': []})
+    lm = LMModel(
+        pattern, condition_provider, fuser,
+        n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
+        cross_attention=True, causal=True)
+    return lm.to(device).eval()
+
+
+def get_wrapped_compression_model(
+        compression_model: CompressionModel,
+        cfg: omegaconf.DictConfig) -> CompressionModel:
+    # more to come.
+    return compression_model
+
+
+
+
+
+
+
+

Functions

+
+
+def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.dictconfig.DictConfig) ‑> CodebooksPatternProvider +
+
+

Instantiate a codebooks pattern provider object.

+
+ +Expand source code + +
def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
+    """Instantiate a codebooks pattern provider object."""
+    pattern_providers = {
+        'parallel': ParallelPatternProvider,
+        'delay': DelayedPatternProvider,
+        'unroll': UnrolledPatternProvider,
+        'coarse_first': CoarseFirstPattern,
+        'musiclm': MusicLMPattern,
+    }
+    name = cfg.modeling
+    kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
+    klass = pattern_providers[name]
+    return klass(n_q, **kwargs)
+
+
+
+def get_compression_model(cfg: omegaconf.dictconfig.DictConfig) ‑> CompressionModel +
+
+

Instantiate a compression model.

+
+ +Expand source code + +
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
+    """Instantiate a compression model."""
+    if cfg.compression_model == 'encodec':
+        kwargs = dict_from_config(getattr(cfg, 'encodec'))
+        encoder_name = kwargs.pop('autoencoder')
+        quantizer_name = kwargs.pop('quantizer')
+        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
+        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
+        frame_rate = kwargs['sample_rate'] // encoder.hop_length
+        renormalize = kwargs.pop('renormalize', False)
+        # deprecated params
+        kwargs.pop('renorm', None)
+        return EncodecModel(encoder, decoder, quantizer,
+                            frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
+    else:
+        raise KeyError(f"Unexpected compression model {cfg.compression_model}")
+
+
+
+def get_condition_fuser(cfg: omegaconf.dictconfig.DictConfig) ‑> ConditionFuser +
+
+

Instantiate a condition fuser object.

+
+ +Expand source code + +
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
+    """Instantiate a condition fuser object."""
+    fuser_cfg = getattr(cfg, 'fuser')
+    fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
+    fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
+    kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
+    fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
+    return fuser
+
+
+
+def get_conditioner_provider(output_dim: int, cfg: omegaconf.dictconfig.DictConfig) ‑> ConditioningProvider +
+
+

Instantiate a conditioning model.

+
+ +Expand source code + +
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
+    """Instantiate a conditioning model."""
+    device = cfg.device
+    duration = cfg.dataset.segment_duration
+    cfg = getattr(cfg, 'conditioners')
+    dict_cfg = {} if cfg is None else dict_from_config(cfg)
+    conditioners: tp.Dict[str, BaseConditioner] = {}
+    condition_provider_args = dict_cfg.pop('args', {})
+    condition_provider_args.pop('merge_text_conditions_p', None)
+    condition_provider_args.pop('drop_desc_p', None)
+
+    for cond, cond_cfg in dict_cfg.items():
+        model_type = cond_cfg['model']
+        model_args = cond_cfg[model_type]
+        if model_type == 't5':
+            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
+        elif model_type == 'lut':
+            conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
+        elif model_type == 'chroma_stem':
+            conditioners[str(cond)] = ChromaStemConditioner(
+                output_dim=output_dim,
+                duration=duration,
+                device=device,
+                **model_args
+            )
+        elif model_type == 'clap':
+            conditioners[str(cond)] = CLAPEmbeddingConditioner(
+                output_dim=output_dim,
+                device=device,
+                **model_args
+            )
+        else:
+            raise ValueError(f"Unrecognized conditioning model: {model_type}")
+    conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
+    return conditioner
+
+
+
+def get_debug_compression_model(device='cpu', sample_rate: int = 32000) +
+
+

Instantiate a debug compression model to be used for unit tests.

+
+ +Expand source code + +
def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
+    """Instantiate a debug compression model to be used for unit tests."""
+    assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
+    model_ratios = {
+        16000: [10, 8, 8],  # 25 Hz at 16kHz
+        32000: [10, 8, 16]  # 25 Hz at 32kHz
+    }
+    ratios: tp.List[int] = model_ratios[sample_rate]
+    frame_rate = 25
+    seanet_kwargs: dict = {
+        'n_filters': 4,
+        'n_residual_layers': 1,
+        'dimension': 32,
+        'ratios': ratios,
+    }
+    encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
+    decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
+    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
+    init_x = torch.randn(8, 32, 128)
+    quantizer(init_x, 1)  # initialize kmeans etc.
+    compression_model = EncodecModel(
+        encoder, decoder, quantizer,
+        frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
+    return compression_model.eval()
+
+
+
+def get_debug_lm_model(device='cpu') +
+
+

Instantiate a debug LM to be used for unit tests.

+
+ +Expand source code + +
def get_debug_lm_model(device='cpu'):
+    """Instantiate a debug LM to be used for unit tests."""
+    pattern = DelayedPatternProvider(n_q=4)
+    dim = 16
+    providers = {
+        'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
+    }
+    condition_provider = ConditioningProvider(providers)
+    fuser = ConditionFuser(
+        {'cross': ['description'], 'prepend': [],
+         'sum': [], 'input_interpolate': []})
+    lm = LMModel(
+        pattern, condition_provider, fuser,
+        n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
+        cross_attention=True, causal=True)
+    return lm.to(device).eval()
+
+
+
+def get_diffusion_model(cfg: omegaconf.dictconfig.DictConfig) +
+
+
+
+ +Expand source code + +
def get_diffusion_model(cfg: omegaconf.DictConfig):
+    # TODO Find a way to infer the channels from dset
+    channels = cfg.channels
+    num_steps = cfg.schedule.num_steps
+    return DiffusionUnet(
+            chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
+
+
+
+def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.dictconfig.DictConfig) +
+
+
+
+ +Expand source code + +
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
+    if encoder_name == 'seanet':
+        kwargs = dict_from_config(getattr(cfg, 'seanet'))
+        encoder_override_kwargs = kwargs.pop('encoder')
+        decoder_override_kwargs = kwargs.pop('decoder')
+        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
+        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+        encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
+        decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
+        return encoder, decoder
+    else:
+        raise KeyError(f"Unexpected compression model {cfg.compression_model}")
+
+
+
+def get_lm_model(cfg: omegaconf.dictconfig.DictConfig) ‑> LMModel +
+
+

Instantiate a transformer LM.

+
+ +Expand source code + +
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
+    """Instantiate a transformer LM."""
+    if cfg.lm_model == 'transformer_lm':
+        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
+        n_q = kwargs['n_q']
+        q_modeling = kwargs.pop('q_modeling', None)
+        codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
+        attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
+        cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
+        cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
+        fuser = get_condition_fuser(cfg)
+        condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
+        if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-att programmatically
+            kwargs['cross_attention'] = True
+        if codebooks_pattern_cfg.modeling is None:
+            assert q_modeling is not None, \
+                "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
+            codebooks_pattern_cfg = omegaconf.OmegaConf.create(
+                {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
+            )
+        pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
+        return LMModel(
+            pattern_provider=pattern_provider,
+            condition_provider=condition_provider,
+            fuser=fuser,
+            cfg_dropout=cfg_prob,
+            cfg_coef=cfg_coef,
+            attribute_dropout=attribute_dropout,
+            dtype=getattr(torch, cfg.dtype),
+            device=cfg.device,
+            **kwargs
+        ).to(cfg.device)
+    else:
+        raise KeyError(f"Unexpected LM model {cfg.lm_model}")
+
+
+
+def get_processor(cfg, sample_rate: int = 24000) +
+
+
+
+ +Expand source code + +
def get_processor(cfg, sample_rate: int = 24000):
+    sample_processor = SampleProcessor()
+    if cfg.use:
+        kw = dict(cfg)
+        kw.pop('use')
+        kw.pop('name')
+        if cfg.name == "multi_band_processor":
+            sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
+    return sample_processor
+
+
+
+def get_quantizer(quantizer: str, cfg: omegaconf.dictconfig.DictConfig, dimension: int) ‑> BaseQuantizer +
+
+
+
+ +Expand source code + +
def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
+    klass = {
+        'no_quant': qt.DummyQuantizer,
+        'rvq': qt.ResidualVectorQuantizer
+    }[quantizer]
+    kwargs = dict_from_config(getattr(cfg, quantizer))
+    if quantizer != 'no_quant':
+        kwargs['dimension'] = dimension
+    return klass(**kwargs)
+
+
+
+def get_wrapped_compression_model(compression_model: CompressionModel, cfg: omegaconf.dictconfig.DictConfig) ‑> CompressionModel +
+
+
+
+ +Expand source code + +
def get_wrapped_compression_model(
+        compression_model: CompressionModel,
+        cfg: omegaconf.DictConfig) -> CompressionModel:
+    # more to come.
+    return compression_model
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/models/encodec.html b/api_docs/audiocraft/models/encodec.html new file mode 100644 index 00000000..605002a8 --- /dev/null +++ b/api_docs/audiocraft/models/encodec.html @@ -0,0 +1,1622 @@ + + + + + + +audiocraft.models.encodec API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.encodec

+
+
+

Compression models or wrapper around existing models. +Also defines the main interface that a model must follow to be usable as an audio tokenizer.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Compression models or wrapper around existing models.
+Also defines the main interface that a model must follow to be usable as an audio tokenizer.
+"""
+
+from abc import ABC, abstractmethod
+import logging
+import math
+from pathlib import Path
+import typing as tp
+
+import numpy as np
+import torch
+from torch import nn
+from transformers import EncodecModel as HFEncodecModel
+
+from .. import quantization as qt
+
+
+logger = logging.getLogger()
+
+
+class CompressionModel(ABC, nn.Module):
+    """Base API for all compression model that aim at being used as audio tokenizers
+    with a language model.
+    """
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        ...
+
+    @abstractmethod
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        """See `EncodecModel.encode`."""
+        ...
+
+    @abstractmethod
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        """See `EncodecModel.decode`."""
+        ...
+
+    @abstractmethod
+    def decode_latent(self, codes: torch.Tensor):
+        """Decode from the discrete codes to continuous latent space."""
+        ...
+
+    @property
+    @abstractmethod
+    def channels(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def frame_rate(self) -> float:
+        ...
+
+    @property
+    @abstractmethod
+    def sample_rate(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def cardinality(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def num_codebooks(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def total_codebooks(self) -> int:
+        ...
+
+    @abstractmethod
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer."""
+        ...
+
+    @staticmethod
+    def get_pretrained(
+            name: str, device: tp.Union[torch.device, str] = 'cpu'
+            ) -> 'CompressionModel':
+        """Instantiate a CompressionModel from a given pretrained model.
+
+        Args:
+            name (Path or str): name of the pretrained model. See after.
+            device (torch.device or str): Device on which the model is loaded.
+
+        Pretrained models:
+            - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
+            - dac_24khz (same)
+            - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
+            - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
+            - your own model on HugginFace. Export instructions to come...
+        """
+
+        from . import builders, loaders
+        model: CompressionModel
+        if name in ['dac_44khz', 'dac_24khz']:
+            model_type = name.split('_')[1]
+            logger.info("Getting pretrained compression model from DAC %s", model_type)
+            model = DAC(model_type)
+        elif name in ['debug_compression_model']:
+            logger.info("Getting pretrained compression model for debug")
+            model = builders.get_debug_compression_model()
+        elif Path(name).exists():
+            # We assume here if the paths exist that it is in fact an AC checkpoint
+            # that was exported using `audiocraft.utils.export` functions.
+            model = loaders.load_compression_model(name, device=device)
+        else:
+            logger.info("Getting pretrained compression model from HF %s", name)
+            hf_model = HFEncodecModel.from_pretrained(name)
+            model = HFEncodecCompressionModel(hf_model).to(device)
+        return model.to(device).eval()
+
+
+class EncodecModel(CompressionModel):
+    """Encodec model operating on the raw waveform.
+
+    Args:
+        encoder (nn.Module): Encoder network.
+        decoder (nn.Module): Decoder network.
+        quantizer (qt.BaseQuantizer): Quantizer network.
+        frame_rate (int): Frame rate for the latent representation.
+        sample_rate (int): Audio sample rate.
+        channels (int): Number of audio channels.
+        causal (bool): Whether to use a causal version of the model.
+        renormalize (bool): Whether to renormalize the audio before running the model.
+    """
+    # we need assignment to override the property in the abstract class,
+    # I couldn't find a better way...
+    frame_rate: float = 0
+    sample_rate: int = 0
+    channels: int = 0
+
+    def __init__(self,
+                 encoder: nn.Module,
+                 decoder: nn.Module,
+                 quantizer: qt.BaseQuantizer,
+                 frame_rate: int,
+                 sample_rate: int,
+                 channels: int,
+                 causal: bool = False,
+                 renormalize: bool = False):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.quantizer = quantizer
+        self.frame_rate = frame_rate
+        self.sample_rate = sample_rate
+        self.channels = channels
+        self.renormalize = renormalize
+        self.causal = causal
+        if self.causal:
+            # we force disabling here to avoid handling linear overlap of segments
+            # as supported in original EnCodec codebase.
+            assert not self.renormalize, 'Causal model does not support renormalize'
+
+    @property
+    def total_codebooks(self):
+        """Total number of quantizer codebooks available."""
+        return self.quantizer.total_codebooks
+
+    @property
+    def num_codebooks(self):
+        """Active number of codebooks used by the quantizer."""
+        return self.quantizer.num_codebooks
+
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer."""
+        self.quantizer.set_num_codebooks(n)
+
+    @property
+    def cardinality(self):
+        """Cardinality of each codebook."""
+        return self.quantizer.bins
+
+    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        scale: tp.Optional[torch.Tensor]
+        if self.renormalize:
+            mono = x.mean(dim=1, keepdim=True)
+            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
+            scale = 1e-8 + volume
+            x = x / scale
+            scale = scale.view(-1, 1)
+        else:
+            scale = None
+        return x, scale
+
+    def postprocess(self,
+                    x: torch.Tensor,
+                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+        if scale is not None:
+            assert self.renormalize
+            x = x * scale.view(-1, 1, 1)
+        return x
+
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        assert x.dim() == 3
+        length = x.shape[-1]
+        x, scale = self.preprocess(x)
+
+        emb = self.encoder(x)
+        q_res = self.quantizer(emb, self.frame_rate)
+        out = self.decoder(q_res.x)
+
+        # remove extra padding added by the encoder and decoder
+        assert out.shape[-1] >= length, (out.shape[-1], length)
+        out = out[..., :length]
+
+        q_res.x = self.postprocess(out, scale)
+
+        return q_res
+
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        """Encode the given input tensor to quantized representation along with scale parameter.
+
+        Args:
+            x (torch.Tensor): Float tensor of shape [B, C, T]
+
+        Returns:
+            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
+                codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
+                scale a float tensor containing the scale for audio renormalizealization.
+        """
+        assert x.dim() == 3
+        x, scale = self.preprocess(x)
+        emb = self.encoder(x)
+        codes = self.quantizer.encode(emb)
+        return codes, scale
+
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        """Decode the given codes to a reconstructed representation, using the scale to perform
+        audio denormalization if needed.
+
+        Args:
+            codes (torch.Tensor): Int tensor of shape [B, K, T]
+            scale (torch.Tensor, optional): Float tensor containing the scale value.
+
+        Returns:
+            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
+        """
+        emb = self.decode_latent(codes)
+        out = self.decoder(emb)
+        out = self.postprocess(out, scale)
+        # out contains extra padding added by the encoder and decoder
+        return out
+
+    def decode_latent(self, codes: torch.Tensor):
+        """Decode from the discrete codes to continuous latent space."""
+        return self.quantizer.decode(codes)
+
+
+class DAC(CompressionModel):
+    def __init__(self, model_type: str = "44khz"):
+        super().__init__()
+        try:
+            import dac.utils
+        except ImportError:
+            raise RuntimeError("Could not import dac, make sure it is installed, "
+                               "please run `pip install descript-audio-codec`")
+        self.model = dac.utils.load_model(model_type=model_type)
+        self.n_quantizers = self.total_codebooks
+        self.model.eval()
+
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        # We don't support training with this.
+        raise NotImplementedError("Forward and training with DAC not supported.")
+
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        codes = self.model.encode(x, self.n_quantizers)[1]
+        return codes[:, :self.n_quantizers], None
+
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        assert scale is None
+        z_q = self.decode_latent(codes)
+        return self.model.decode(z_q)
+
+    def decode_latent(self, codes: torch.Tensor):
+        """Decode from the discrete codes to continuous latent space."""
+        return self.model.quantizer.from_codes(codes)[0]
+
+    @property
+    def channels(self) -> int:
+        return 1
+
+    @property
+    def frame_rate(self) -> float:
+        return self.model.sample_rate / self.model.hop_length
+
+    @property
+    def sample_rate(self) -> int:
+        return self.model.sample_rate
+
+    @property
+    def cardinality(self) -> int:
+        return self.model.codebook_size
+
+    @property
+    def num_codebooks(self) -> int:
+        return self.n_quantizers
+
+    @property
+    def total_codebooks(self) -> int:
+        return self.model.n_codebooks
+
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer.
+        """
+        assert n >= 1
+        assert n <= self.total_codebooks
+        self.n_quantizers = n
+
+
+class HFEncodecCompressionModel(CompressionModel):
+    """Wrapper around HuggingFace Encodec.
+    """
+    def __init__(self, model: HFEncodecModel):
+        super().__init__()
+        self.model = model
+        bws = self.model.config.target_bandwidths
+        num_codebooks = [
+            bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
+            for bw in bws
+        ]
+        deltas = [nc - int(nc) for nc in num_codebooks]
+        # Checking we didn't do some bad maths and we indeed have integers!
+        assert all(deltas) <= 1e-3, deltas
+        self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
+        self.set_num_codebooks(max(self.possible_num_codebooks))
+
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        # We don't support training with this.
+        raise NotImplementedError("Forward and training with HF EncodecModel not supported.")
+
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
+        bandwidth = self.model.config.target_bandwidths[bandwidth_index]
+        res = self.model.encode(x, None, bandwidth)
+        assert len(res[0]) == 1
+        assert len(res[1]) == 1
+        return res[0][0], res[1][0]
+
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        if scale is None:
+            scales = [None]  # type: ignore
+        else:
+            scales = scale  # type: ignore
+        res = self.model.decode(codes[None], scales)
+        return res[0]
+
+    def decode_latent(self, codes: torch.Tensor):
+        """Decode from the discrete codes to continuous latent space."""
+        return self.model.quantizer.decode(codes.transpose(0, 1))
+
+    @property
+    def channels(self) -> int:
+        return self.model.config.audio_channels
+
+    @property
+    def frame_rate(self) -> float:
+        hop_length = int(np.prod(self.model.config.upsampling_ratios))
+        return self.sample_rate / hop_length
+
+    @property
+    def sample_rate(self) -> int:
+        return self.model.config.sampling_rate
+
+    @property
+    def cardinality(self) -> int:
+        return self.model.config.codebook_size
+
+    @property
+    def num_codebooks(self) -> int:
+        return self._num_codebooks
+
+    @property
+    def total_codebooks(self) -> int:
+        return max(self.possible_num_codebooks)
+
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer.
+        """
+        if n not in self.possible_num_codebooks:
+            raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
+        self._num_codebooks = n
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class CompressionModel +(*args, **kwargs) +
+
+

Base API for all compression model that aim at being used as audio tokenizers +with a language model.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class CompressionModel(ABC, nn.Module):
+    """Base API for all compression model that aim at being used as audio tokenizers
+    with a language model.
+    """
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        ...
+
+    @abstractmethod
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        """See `EncodecModel.encode`."""
+        ...
+
+    @abstractmethod
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        """See `EncodecModel.decode`."""
+        ...
+
+    @abstractmethod
+    def decode_latent(self, codes: torch.Tensor):
+        """Decode from the discrete codes to continuous latent space."""
+        ...
+
+    @property
+    @abstractmethod
+    def channels(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def frame_rate(self) -> float:
+        ...
+
+    @property
+    @abstractmethod
+    def sample_rate(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def cardinality(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def num_codebooks(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def total_codebooks(self) -> int:
+        ...
+
+    @abstractmethod
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer."""
+        ...
+
+    @staticmethod
+    def get_pretrained(
+            name: str, device: tp.Union[torch.device, str] = 'cpu'
+            ) -> 'CompressionModel':
+        """Instantiate a CompressionModel from a given pretrained model.
+
+        Args:
+            name (Path or str): name of the pretrained model. See after.
+            device (torch.device or str): Device on which the model is loaded.
+
+        Pretrained models:
+            - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
+            - dac_24khz (same)
+            - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
+            - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
+            - your own model on HugginFace. Export instructions to come...
+        """
+
+        from . import builders, loaders
+        model: CompressionModel
+        if name in ['dac_44khz', 'dac_24khz']:
+            model_type = name.split('_')[1]
+            logger.info("Getting pretrained compression model from DAC %s", model_type)
+            model = DAC(model_type)
+        elif name in ['debug_compression_model']:
+            logger.info("Getting pretrained compression model for debug")
+            model = builders.get_debug_compression_model()
+        elif Path(name).exists():
+            # We assume here if the paths exist that it is in fact an AC checkpoint
+            # that was exported using `audiocraft.utils.export` functions.
+            model = loaders.load_compression_model(name, device=device)
+        else:
+            logger.info("Getting pretrained compression model from HF %s", name)
+            hf_model = HFEncodecModel.from_pretrained(name)
+            model = HFEncodecCompressionModel(hf_model).to(device)
+        return model.to(device).eval()
+
+

Ancestors

+
    +
  • abc.ABC
  • +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Static methods

+
+
+def get_pretrained(name: str, device: Union[torch.device, str] = 'cpu') ‑> CompressionModel +
+
+

Instantiate a CompressionModel from a given pretrained model.

+

Args

+
+
name : Path or str
+
name of the pretrained model. See after.
+
device : torch.device or str
+
Device on which the model is loaded.
+
+

Pretrained models: +- dac_44khz (https://github.com/descriptinc/descript-audio-codec) +- dac_24khz (same) +- facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz) +- facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz) +- your own model on HugginFace. Export instructions to come…

+
+ +Expand source code + +
@staticmethod
+def get_pretrained(
+        name: str, device: tp.Union[torch.device, str] = 'cpu'
+        ) -> 'CompressionModel':
+    """Instantiate a CompressionModel from a given pretrained model.
+
+    Args:
+        name (Path or str): name of the pretrained model. See after.
+        device (torch.device or str): Device on which the model is loaded.
+
+    Pretrained models:
+        - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
+        - dac_24khz (same)
+        - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
+        - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
+        - your own model on HugginFace. Export instructions to come...
+    """
+
+    from . import builders, loaders
+    model: CompressionModel
+    if name in ['dac_44khz', 'dac_24khz']:
+        model_type = name.split('_')[1]
+        logger.info("Getting pretrained compression model from DAC %s", model_type)
+        model = DAC(model_type)
+    elif name in ['debug_compression_model']:
+        logger.info("Getting pretrained compression model for debug")
+        model = builders.get_debug_compression_model()
+    elif Path(name).exists():
+        # We assume here if the paths exist that it is in fact an AC checkpoint
+        # that was exported using `audiocraft.utils.export` functions.
+        model = loaders.load_compression_model(name, device=device)
+    else:
+        logger.info("Getting pretrained compression model from HF %s", name)
+        hf_model = HFEncodecModel.from_pretrained(name)
+        model = HFEncodecCompressionModel(hf_model).to(device)
+    return model.to(device).eval()
+
+
+
+

Instance variables

+
+
var cardinality : int
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def cardinality(self) -> int:
+    ...
+
+
+
var channels : int
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def channels(self) -> int:
+    ...
+
+
+
var frame_rate : float
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def frame_rate(self) -> float:
+    ...
+
+
+
var num_codebooks : int
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def num_codebooks(self) -> int:
+    ...
+
+
+
var sample_rate : int
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def sample_rate(self) -> int:
+    ...
+
+
+
var total_codebooks : int
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def total_codebooks(self) -> int:
+    ...
+
+
+
+

Methods

+
+
+def decode(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) +
+
+ +
+ +Expand source code + +
@abstractmethod
+def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+    """See `EncodecModel.decode`."""
+    ...
+
+
+
+def decode_latent(self, codes: torch.Tensor) +
+
+

Decode from the discrete codes to continuous latent space.

+
+ +Expand source code + +
@abstractmethod
+def decode_latent(self, codes: torch.Tensor):
+    """Decode from the discrete codes to continuous latent space."""
+    ...
+
+
+
+def encode(self, x: torch.Tensor) ‑> Tuple[torch.Tensor, Optional[torch.Tensor]] +
+
+ +
+ +Expand source code + +
@abstractmethod
+def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+    """See `EncodecModel.encode`."""
+    ...
+
+
+
+def forward(self, x: torch.Tensor) ‑> QuantizedResult +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
@abstractmethod
+def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+    ...
+
+
+
+def set_num_codebooks(self, n: int) +
+
+

Set the active number of codebooks used by the quantizer.

+
+ +Expand source code + +
@abstractmethod
+def set_num_codebooks(self, n: int):
+    """Set the active number of codebooks used by the quantizer."""
+    ...
+
+
+
+
+
+class DAC +(model_type: str = '44khz') +
+
+

Base API for all compression model that aim at being used as audio tokenizers +with a language model.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DAC(CompressionModel):
+    def __init__(self, model_type: str = "44khz"):
+        super().__init__()
+        try:
+            import dac.utils
+        except ImportError:
+            raise RuntimeError("Could not import dac, make sure it is installed, "
+                               "please run `pip install descript-audio-codec`")
+        self.model = dac.utils.load_model(model_type=model_type)
+        self.n_quantizers = self.total_codebooks
+        self.model.eval()
+
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        # We don't support training with this.
+        raise NotImplementedError("Forward and training with DAC not supported.")
+
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        codes = self.model.encode(x, self.n_quantizers)[1]
+        return codes[:, :self.n_quantizers], None
+
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        assert scale is None
+        z_q = self.decode_latent(codes)
+        return self.model.decode(z_q)
+
+    def decode_latent(self, codes: torch.Tensor):
+        """Decode from the discrete codes to continuous latent space."""
+        return self.model.quantizer.from_codes(codes)[0]
+
+    @property
+    def channels(self) -> int:
+        return 1
+
+    @property
+    def frame_rate(self) -> float:
+        return self.model.sample_rate / self.model.hop_length
+
+    @property
+    def sample_rate(self) -> int:
+        return self.model.sample_rate
+
+    @property
+    def cardinality(self) -> int:
+        return self.model.codebook_size
+
+    @property
+    def num_codebooks(self) -> int:
+        return self.n_quantizers
+
+    @property
+    def total_codebooks(self) -> int:
+        return self.model.n_codebooks
+
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer.
+        """
+        assert n >= 1
+        assert n <= self.total_codebooks
+        self.n_quantizers = n
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var cardinality : int
+
+
+
+ +Expand source code + +
@property
+def cardinality(self) -> int:
+    return self.model.codebook_size
+
+
+
var channels : int
+
+
+
+ +Expand source code + +
@property
+def channels(self) -> int:
+    return 1
+
+
+
var frame_rate : float
+
+
+
+ +Expand source code + +
@property
+def frame_rate(self) -> float:
+    return self.model.sample_rate / self.model.hop_length
+
+
+
var num_codebooks : int
+
+
+
+ +Expand source code + +
@property
+def num_codebooks(self) -> int:
+    return self.n_quantizers
+
+
+
var sample_rate : int
+
+
+
+ +Expand source code + +
@property
+def sample_rate(self) -> int:
+    return self.model.sample_rate
+
+
+
var total_codebooks : int
+
+
+
+ +Expand source code + +
@property
+def total_codebooks(self) -> int:
+    return self.model.n_codebooks
+
+
+
+

Inherited members

+ +
+
+class EncodecModel +(encoder: torch.nn.modules.module.Module, decoder: torch.nn.modules.module.Module, quantizer: BaseQuantizer, frame_rate: int, sample_rate: int, channels: int, causal: bool = False, renormalize: bool = False) +
+
+

Encodec model operating on the raw waveform.

+

Args

+
+
encoder : nn.Module
+
Encoder network.
+
decoder : nn.Module
+
Decoder network.
+
quantizer : qt.BaseQuantizer
+
Quantizer network.
+
frame_rate : int
+
Frame rate for the latent representation.
+
sample_rate : int
+
Audio sample rate.
+
channels : int
+
Number of audio channels.
+
causal : bool
+
Whether to use a causal version of the model.
+
renormalize : bool
+
Whether to renormalize the audio before running the model.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class EncodecModel(CompressionModel):
+    """Encodec model operating on the raw waveform.
+
+    Args:
+        encoder (nn.Module): Encoder network.
+        decoder (nn.Module): Decoder network.
+        quantizer (qt.BaseQuantizer): Quantizer network.
+        frame_rate (int): Frame rate for the latent representation.
+        sample_rate (int): Audio sample rate.
+        channels (int): Number of audio channels.
+        causal (bool): Whether to use a causal version of the model.
+        renormalize (bool): Whether to renormalize the audio before running the model.
+    """
+    # we need assignment to override the property in the abstract class,
+    # I couldn't find a better way...
+    frame_rate: float = 0
+    sample_rate: int = 0
+    channels: int = 0
+
+    def __init__(self,
+                 encoder: nn.Module,
+                 decoder: nn.Module,
+                 quantizer: qt.BaseQuantizer,
+                 frame_rate: int,
+                 sample_rate: int,
+                 channels: int,
+                 causal: bool = False,
+                 renormalize: bool = False):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.quantizer = quantizer
+        self.frame_rate = frame_rate
+        self.sample_rate = sample_rate
+        self.channels = channels
+        self.renormalize = renormalize
+        self.causal = causal
+        if self.causal:
+            # we force disabling here to avoid handling linear overlap of segments
+            # as supported in original EnCodec codebase.
+            assert not self.renormalize, 'Causal model does not support renormalize'
+
+    @property
+    def total_codebooks(self):
+        """Total number of quantizer codebooks available."""
+        return self.quantizer.total_codebooks
+
+    @property
+    def num_codebooks(self):
+        """Active number of codebooks used by the quantizer."""
+        return self.quantizer.num_codebooks
+
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer."""
+        self.quantizer.set_num_codebooks(n)
+
+    @property
+    def cardinality(self):
+        """Cardinality of each codebook."""
+        return self.quantizer.bins
+
+    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        scale: tp.Optional[torch.Tensor]
+        if self.renormalize:
+            mono = x.mean(dim=1, keepdim=True)
+            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
+            scale = 1e-8 + volume
+            x = x / scale
+            scale = scale.view(-1, 1)
+        else:
+            scale = None
+        return x, scale
+
+    def postprocess(self,
+                    x: torch.Tensor,
+                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+        if scale is not None:
+            assert self.renormalize
+            x = x * scale.view(-1, 1, 1)
+        return x
+
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        assert x.dim() == 3
+        length = x.shape[-1]
+        x, scale = self.preprocess(x)
+
+        emb = self.encoder(x)
+        q_res = self.quantizer(emb, self.frame_rate)
+        out = self.decoder(q_res.x)
+
+        # remove extra padding added by the encoder and decoder
+        assert out.shape[-1] >= length, (out.shape[-1], length)
+        out = out[..., :length]
+
+        q_res.x = self.postprocess(out, scale)
+
+        return q_res
+
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        """Encode the given input tensor to quantized representation along with scale parameter.
+
+        Args:
+            x (torch.Tensor): Float tensor of shape [B, C, T]
+
+        Returns:
+            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
+                codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
+                scale a float tensor containing the scale for audio renormalizealization.
+        """
+        assert x.dim() == 3
+        x, scale = self.preprocess(x)
+        emb = self.encoder(x)
+        codes = self.quantizer.encode(emb)
+        return codes, scale
+
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        """Decode the given codes to a reconstructed representation, using the scale to perform
+        audio denormalization if needed.
+
+        Args:
+            codes (torch.Tensor): Int tensor of shape [B, K, T]
+            scale (torch.Tensor, optional): Float tensor containing the scale value.
+
+        Returns:
+            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
+        """
+        emb = self.decode_latent(codes)
+        out = self.decoder(emb)
+        out = self.postprocess(out, scale)
+        # out contains extra padding added by the encoder and decoder
+        return out
+
+    def decode_latent(self, codes: torch.Tensor):
+        """Decode from the discrete codes to continuous latent space."""
+        return self.quantizer.decode(codes)
+
+

Ancestors

+ +

Class variables

+
+
var channels : int
+
+
+
+
var frame_rate : float
+
+
+
+
var sample_rate : int
+
+
+
+
+

Instance variables

+
+
var cardinality
+
+

Cardinality of each codebook.

+
+ +Expand source code + +
@property
+def cardinality(self):
+    """Cardinality of each codebook."""
+    return self.quantizer.bins
+
+
+
var num_codebooks
+
+

Active number of codebooks used by the quantizer.

+
+ +Expand source code + +
@property
+def num_codebooks(self):
+    """Active number of codebooks used by the quantizer."""
+    return self.quantizer.num_codebooks
+
+
+
var total_codebooks
+
+

Total number of quantizer codebooks available.

+
+ +Expand source code + +
@property
+def total_codebooks(self):
+    """Total number of quantizer codebooks available."""
+    return self.quantizer.total_codebooks
+
+
+
+

Methods

+
+
+def decode(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) +
+
+

Decode the given codes to a reconstructed representation, using the scale to perform +audio denormalization if needed.

+

Args

+
+
codes : torch.Tensor
+
Int tensor of shape [B, K, T]
+
scale : torch.Tensor, optional
+
Float tensor containing the scale value.
+
+

Returns

+

out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.

+
+ +Expand source code + +
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+    """Decode the given codes to a reconstructed representation, using the scale to perform
+    audio denormalization if needed.
+
+    Args:
+        codes (torch.Tensor): Int tensor of shape [B, K, T]
+        scale (torch.Tensor, optional): Float tensor containing the scale value.
+
+    Returns:
+        out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
+    """
+    emb = self.decode_latent(codes)
+    out = self.decoder(emb)
+    out = self.postprocess(out, scale)
+    # out contains extra padding added by the encoder and decoder
+    return out
+
+
+
+def encode(self, x: torch.Tensor) ‑> Tuple[torch.Tensor, Optional[torch.Tensor]] +
+
+

Encode the given input tensor to quantized representation along with scale parameter.

+

Args

+
+
x : torch.Tensor
+
Float tensor of shape [B, C, T]
+
+

Returns

+

codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of: +codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. +scale a float tensor containing the scale for audio renormalizealization.

+
+ +Expand source code + +
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+    """Encode the given input tensor to quantized representation along with scale parameter.
+
+    Args:
+        x (torch.Tensor): Float tensor of shape [B, C, T]
+
+    Returns:
+        codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
+            codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
+            scale a float tensor containing the scale for audio renormalizealization.
+    """
+    assert x.dim() == 3
+    x, scale = self.preprocess(x)
+    emb = self.encoder(x)
+    codes = self.quantizer.encode(emb)
+    return codes, scale
+
+
+
+def postprocess(self, x: torch.Tensor, scale: Optional[torch.Tensor] = None) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def postprocess(self,
+                x: torch.Tensor,
+                scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+    if scale is not None:
+        assert self.renormalize
+        x = x * scale.view(-1, 1, 1)
+    return x
+
+
+
+def preprocess(self, x: torch.Tensor) ‑> Tuple[torch.Tensor, Optional[torch.Tensor]] +
+
+
+
+ +Expand source code + +
def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+    scale: tp.Optional[torch.Tensor]
+    if self.renormalize:
+        mono = x.mean(dim=1, keepdim=True)
+        volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
+        scale = 1e-8 + volume
+        x = x / scale
+        scale = scale.view(-1, 1)
+    else:
+        scale = None
+    return x, scale
+
+
+
+

Inherited members

+ +
+
+class HFEncodecCompressionModel +(model: transformers.models.encodec.modeling_encodec.EncodecModel) +
+
+

Wrapper around HuggingFace Encodec.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class HFEncodecCompressionModel(CompressionModel):
+    """Wrapper around HuggingFace Encodec.
+    """
+    def __init__(self, model: HFEncodecModel):
+        super().__init__()
+        self.model = model
+        bws = self.model.config.target_bandwidths
+        num_codebooks = [
+            bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
+            for bw in bws
+        ]
+        deltas = [nc - int(nc) for nc in num_codebooks]
+        # Checking we didn't do some bad maths and we indeed have integers!
+        assert all(deltas) <= 1e-3, deltas
+        self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
+        self.set_num_codebooks(max(self.possible_num_codebooks))
+
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        # We don't support training with this.
+        raise NotImplementedError("Forward and training with HF EncodecModel not supported.")
+
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
+        bandwidth = self.model.config.target_bandwidths[bandwidth_index]
+        res = self.model.encode(x, None, bandwidth)
+        assert len(res[0]) == 1
+        assert len(res[1]) == 1
+        return res[0][0], res[1][0]
+
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        if scale is None:
+            scales = [None]  # type: ignore
+        else:
+            scales = scale  # type: ignore
+        res = self.model.decode(codes[None], scales)
+        return res[0]
+
+    def decode_latent(self, codes: torch.Tensor):
+        """Decode from the discrete codes to continuous latent space."""
+        return self.model.quantizer.decode(codes.transpose(0, 1))
+
+    @property
+    def channels(self) -> int:
+        return self.model.config.audio_channels
+
+    @property
+    def frame_rate(self) -> float:
+        hop_length = int(np.prod(self.model.config.upsampling_ratios))
+        return self.sample_rate / hop_length
+
+    @property
+    def sample_rate(self) -> int:
+        return self.model.config.sampling_rate
+
+    @property
+    def cardinality(self) -> int:
+        return self.model.config.codebook_size
+
+    @property
+    def num_codebooks(self) -> int:
+        return self._num_codebooks
+
+    @property
+    def total_codebooks(self) -> int:
+        return max(self.possible_num_codebooks)
+
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer.
+        """
+        if n not in self.possible_num_codebooks:
+            raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
+        self._num_codebooks = n
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var cardinality : int
+
+
+
+ +Expand source code + +
@property
+def cardinality(self) -> int:
+    return self.model.config.codebook_size
+
+
+
var channels : int
+
+
+
+ +Expand source code + +
@property
+def channels(self) -> int:
+    return self.model.config.audio_channels
+
+
+
var frame_rate : float
+
+
+
+ +Expand source code + +
@property
+def frame_rate(self) -> float:
+    hop_length = int(np.prod(self.model.config.upsampling_ratios))
+    return self.sample_rate / hop_length
+
+
+
var num_codebooks : int
+
+
+
+ +Expand source code + +
@property
+def num_codebooks(self) -> int:
+    return self._num_codebooks
+
+
+
var sample_rate : int
+
+
+
+ +Expand source code + +
@property
+def sample_rate(self) -> int:
+    return self.model.config.sampling_rate
+
+
+
var total_codebooks : int
+
+
+
+ +Expand source code + +
@property
+def total_codebooks(self) -> int:
+    return max(self.possible_num_codebooks)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/models/index.html b/api_docs/audiocraft/models/index.html new file mode 100644 index 00000000..78c60d86 --- /dev/null +++ b/api_docs/audiocraft/models/index.html @@ -0,0 +1,132 @@ + + + + + + +audiocraft.models API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models

+
+
+

Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
+"""
+# flake8: noqa
+from . import builders, loaders
+from .encodec import (
+    CompressionModel, EncodecModel, DAC,
+    HFEncodecModel, HFEncodecCompressionModel)
+from .audiogen import AudioGen
+from .lm import LMModel
+from .multibanddiffusion import MultiBandDiffusion
+from .musicgen import MusicGen
+from .unet import DiffusionUnet
+
+
+
+

Sub-modules

+
+
audiocraft.models.audiogen
+
+

Main model for using AudioGen. This will combine all the required components +and provide easy access to the generation API.

+
+
audiocraft.models.builders
+
+

All the functions to build the relevant models and modules +from the Hydra config.

+
+
audiocraft.models.encodec
+
+

Compression models or wrapper around existing models. +Also defines the main interface that a model must follow to be usable as an audio tokenizer.

+
+
audiocraft.models.lm
+
+
+
+
audiocraft.models.loaders
+
+

Utility functions to load from the checkpoints. +Each checkpoint is a torch.saved dict with the following keys: +- 'xp.cfg': the hydra config as dumped …

+
+
audiocraft.models.multibanddiffusion
+
+

Multi Band Diffusion models as described in +"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" +(paper link).

+
+
audiocraft.models.musicgen
+
+

Main model for using MusicGen. This will combine all the required components +and provide easy access to the generation API.

+
+
audiocraft.models.unet
+
+

Pytorch Unet Module used for diffusion.

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/models/lm.html b/api_docs/audiocraft/models/lm.html new file mode 100644 index 00000000..facc67f2 --- /dev/null +++ b/api_docs/audiocraft/models/lm.html @@ -0,0 +1,1745 @@ + + + + + + +audiocraft.models.lm API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.lm

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from functools import partial
+import logging
+import math
+import typing as tp
+
+import torch
+from torch import nn
+
+from ..utils import utils
+from ..modules.streaming import StreamingModule, State
+from ..modules.transformer import StreamingTransformer, create_norm_fn
+from ..modules.conditioners import (
+    ConditionFuser,
+    ClassifierFreeGuidanceDropout,
+    AttributeDropout,
+    ConditioningProvider,
+    ConditioningAttributes,
+    ConditionType,
+)
+from ..modules.codebooks_patterns import CodebooksPatternProvider
+from ..modules.activations import get_activation_fn
+
+
+logger = logging.getLogger(__name__)
+ConditionTensors = tp.Dict[str, ConditionType]
+CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
+
+
+def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
+    """LM layer initialization.
+    Inspired from xlformers: https://github.com/fairinternal/xlformers
+
+    Args:
+        method (str): Method name for init function. Valid options are:
+            'gaussian', 'uniform'.
+        input_dim (int): Input dimension of the initialized module.
+        init_depth (int, optional): Optional init depth value used to rescale
+            the standard deviation if defined.
+    """
+    # Compute std
+    std = 1 / math.sqrt(input_dim)
+    # Rescale with depth
+    if init_depth is not None:
+        std = std / math.sqrt(2 * init_depth)
+
+    if method == 'gaussian':
+        return partial(
+            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
+        )
+    elif method == 'uniform':
+        bound = math.sqrt(3) * std  # ensure the standard deviation is `std`
+        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
+    else:
+        raise ValueError("Unsupported layer initialization method")
+
+
+def init_layer(m: nn.Module,
+               method: str,
+               init_depth: tp.Optional[int] = None,
+               zero_bias_init: bool = False):
+    """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
+
+    Args:
+        m (nn.Module): Module to initialize.
+        method (str): Method name for the init function.
+        init_depth (int, optional): Optional init depth value used to rescale
+            the standard deviation if defined.
+        zero_bias_init (bool): Whether to initialize the bias to 0 or not.
+    """
+    if isinstance(m, nn.Linear):
+        init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
+        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+            weight = m.weight.float()
+            init_fn(weight)
+            m.weight.data[:] = weight.half()
+        else:
+            init_fn(m.weight)
+        if zero_bias_init and m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+    elif isinstance(m, nn.Embedding):
+        init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
+        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+            weight = m.weight.float()
+            init_fn(weight)
+            m.weight.data[:] = weight.half()
+        else:
+            init_fn(m.weight)
+
+
+class ScaledEmbedding(nn.Embedding):
+    """Boost learning rate for embeddings (with `scale`).
+    """
+    def __init__(self, *args, lr=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lr = lr
+
+    def make_optim_group(self):
+        group = {"params": list(self.parameters())}
+        if self.lr is not None:
+            group["lr"] = self.lr
+        return group
+
+
+@dataclass
+class LMOutput:
+    # The logits are already re-aligned with the input codes
+    # hence no extra shift is required, e.g. when computing CE
+    logits: torch.Tensor  # [B, K, T, card]
+    mask: torch.Tensor  # [B, K, T]
+
+
+class LMModel(StreamingModule):
+    """Transformer-based language model on multiple streams of codes.
+
+    Args:
+        pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
+        condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
+        fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
+        n_q (int): Number of parallel streams to model.
+        card (int): Cardinality, vocabulary size.
+        dim (int): Dimension of the transformer encoder.
+        num_heads (int): Number of heads for the transformer encoder.
+        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
+        norm (str): Normalization method.
+        norm_first (bool): Use pre-norm instead of post-norm.
+        emb_lr (float, optional): Embedding-specific learning rate.
+        bias_proj (bool): Use bias for output projections.
+        weight_init (str, optional): Method for weight initialization.
+        depthwise_init (str, optional): Method for depthwise weight initialization.
+        zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
+        cfg_dropout (float): Classifier-free guidance dropout.
+        cfg_coef (float): Classifier-free guidance coefficient.
+        attribute_dropout (dict): Attribute dropout probabilities.
+        two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
+        **kwargs: Additional parameters for the transformer encoder.
+    """
+    def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
+                 fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
+                 hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
+                 emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
+                 weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
+                 zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
+                 attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
+                 **kwargs):
+        super().__init__()
+        self.cfg_coef = cfg_coef
+        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
+        self.att_dropout = AttributeDropout(p=attribute_dropout)
+        self.condition_provider = condition_provider
+        self.fuser = fuser
+        self.card = card
+        embed_dim = self.card + 1
+        self.n_q = n_q
+        self.dim = dim
+        self.pattern_provider = pattern_provider
+        self.two_step_cfg = two_step_cfg
+        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
+        if 'activation' in kwargs:
+            kwargs['activation'] = get_activation_fn(kwargs['activation'])
+        self.transformer = StreamingTransformer(
+            d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
+            norm=norm, norm_first=norm_first, **kwargs)
+        self.out_norm: tp.Optional[nn.Module] = None
+        if norm_first:
+            self.out_norm = create_norm_fn(norm, dim)
+        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
+        self._init_weights(weight_init, depthwise_init, zero_bias_init)
+        self._fsdp: tp.Optional[nn.Module]
+        self.__dict__['_fsdp'] = None
+
+    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
+        """Initialization of the transformer module weights.
+
+        Args:
+            weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
+            depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
+                'current' where the depth corresponds to the current layer index or 'global' where the total number
+                of layer is used as depth. If not set, no depthwise initialization strategy is used.
+            zero_bias_init (bool): Whether to initialize bias to zero or not.
+        """
+        assert depthwise_init is None or depthwise_init in ['current', 'global']
+        assert depthwise_init is None or weight_init is not None, \
+            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
+        assert not zero_bias_init or weight_init is not None, \
+            "If 'zero_bias_init', a 'weight_init' method should be provided"
+
+        if weight_init is None:
+            return
+
+        for emb_layer in self.emb:
+            init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+        for layer_idx, tr_layer in enumerate(self.transformer.layers):
+            depth = None
+            if depthwise_init == 'current':
+                depth = layer_idx + 1
+            elif depthwise_init == 'global':
+                depth = len(self.transformer.layers)
+            init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
+            tr_layer.apply(init_fn)
+
+        for linear in self.linears:
+            init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+    @property
+    def special_token_id(self) -> int:
+        return self.card
+
+    @property
+    def num_codebooks(self) -> int:
+        return self.n_q
+
+    def forward(self, sequence: torch.Tensor,
+                conditions: tp.List[ConditioningAttributes],
+                condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
+        """Apply language model on sequence and conditions.
+        Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
+        S the sequence steps, return the logits with shape [B, card, K, S].
+
+        Args:
+            indices (torch.Tensor): Indices of the codes to model.
+            conditions (list of ConditioningAttributes): Conditions to use when modeling
+                the given codes. Note that when evaluating multiple time with the same conditioning
+                you should pre-compute those and pass them as `condition_tensors`.
+            condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
+                tensors, see `conditions`.
+        Returns:
+            torch.Tensor: Logits.
+        """
+        B, K, S = sequence.shape
+        assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
+        input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
+        if condition_tensors is None:
+            assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
+            # apply dropout modules
+            conditions = self.cfg_dropout(conditions)
+            conditions = self.att_dropout(conditions)
+            tokenized = self.condition_provider.tokenize(conditions)
+            # encode conditions and fuse, both have a streaming cache to not recompute when generating.
+            condition_tensors = self.condition_provider(tokenized)
+        else:
+            assert not conditions, "Shouldn't pass both conditions and condition_tensors."
+
+        input_, cross_attention_input = self.fuser(input_, condition_tensors)
+
+        out = self.transformer(input_, cross_attention_src=cross_attention_input)
+        if self.out_norm:
+            out = self.out_norm(out)
+        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]
+
+        # remove the prefix from the model outputs
+        if len(self.fuser.fuse2cond['prepend']) > 0:
+            logits = logits[:, :, -S:]
+
+        return logits  # [B, K, S, card]
+
+    def compute_predictions(
+            self, codes: torch.Tensor,
+            conditions: tp.List[ConditioningAttributes],
+            condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
+        """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
+        forward using the specified codes interleaving pattern.
+
+        Args:
+            codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
+                K the number of codebooks and T the number of timesteps.
+            conditions (list of ConditioningAttributes): conditionings to use when modeling
+                the given codes. Note that when evaluating multiple time with the same conditioning
+                you should pre-compute those and pass them as `condition_tensors`.
+            condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
+                tensors, see `conditions`.
+        Returns:
+            LMOutput: Language model outputs
+                logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
+                    i.e. the first item corresponds to logits to predict the first code, meaning that
+                    no additional shifting of codes and logits is required.
+                mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
+                    Given the specified interleaving strategies, parts of the logits and codes should
+                    not be considered as valid predictions because of invalid context.
+        """
+        B, K, T = codes.shape
+        codes = codes.contiguous()
+        # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
+        pattern = self.pattern_provider.get_pattern(T)
+        sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
+            codes, self.special_token_id, keep_only_valid_steps=True
+        )
+        # apply model on pattern sequence
+        model = self if self._fsdp is None else self._fsdp
+        logits = model(sequence_codes, conditions, condition_tensors)  # [B, K, S, card]
+        # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
+        # and provide the corresponding mask over invalid positions of tokens
+        logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
+        # note: we use nans as special token to make it obvious if we feed unexpected logits
+        logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
+            logits, float('nan'), keep_only_valid_steps=True
+        )
+        logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
+        logits_mask = logits_mask[None, :, :].expand(B, -1, -1)  # [K, T] -> [B, K, T]
+        return LMOutput(logits, logits_mask)
+
+    def _sample_next_token(self,
+                           sequence: torch.Tensor,
+                           cfg_conditions: CFGConditions,
+                           unconditional_state: State,
+                           use_sampling: bool = False,
+                           temp: float = 1.0,
+                           top_k: int = 0,
+                           top_p: float = 0.0,
+                           cfg_coef: tp.Optional[float] = None,
+                           two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
+        """Sample next token from the model given a sequence and a set of conditions. The model supports
+        multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
+
+        Args:
+            sequence (torch.Tensor): Current sequence of shape [B, K, S]
+                with K corresponding to the number of codebooks and S the number of sequence steps.
+                S = 1 in streaming mode, except for the first step that contains a bigger prompt.
+            condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used,
+                should be twice the batch size, being the concatenation of the conditions + null conditions.
+            use_sampling (bool): Whether to use a sampling strategy or not.
+            temp (float): Sampling temperature.
+            top_k (int): K for "top-k" sampling.
+            top_p (float): P for "top-p" sampling.
+            cfg_coef (float, optional): classifier free guidance coefficient
+        Returns:
+            next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
+        """
+        B = sequence.shape[0]
+        cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
+        model = self if self._fsdp is None else self._fsdp
+        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+        if two_step_cfg and cfg_conditions != {}:
+            assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
+            condition_tensors, null_condition_tensors = cfg_conditions
+            cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
+            state = self.get_streaming_state()
+            self.set_streaming_state(unconditional_state)
+            uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
+            unconditional_state.update(self.get_streaming_state())
+            self.set_streaming_state(state)
+            logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
+        else:
+            assert isinstance(cfg_conditions, dict)
+            condition_tensors = cfg_conditions
+            if condition_tensors:
+                # Preparing for CFG, predicting both conditional and unconditional logits.
+                sequence = torch.cat([sequence, sequence], dim=0)
+            all_logits = model(
+                sequence,
+                conditions=[], condition_tensors=condition_tensors)
+            if condition_tensors:
+                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
+                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+            else:
+                logits = all_logits
+
+        logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
+        logits = logits[..., -1]  # [B x K x card]
+
+        # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
+        if use_sampling and temp > 0.0:
+            probs = torch.softmax(logits / temp, dim=-1)
+            if top_p > 0.0:
+                next_token = utils.sample_top_p(probs, p=top_p)
+            elif top_k > 0:
+                next_token = utils.sample_top_k(probs, k=top_k)
+            else:
+                next_token = utils.multinomial(probs, num_samples=1)
+        else:
+            next_token = torch.argmax(logits, dim=-1, keepdim=True)
+
+        return next_token
+
+    @torch.no_grad()
+    def generate(self,
+                 prompt: tp.Optional[torch.Tensor] = None,
+                 conditions: tp.List[ConditioningAttributes] = [],
+                 num_samples: tp.Optional[int] = None,
+                 max_gen_len: int = 256,
+                 use_sampling: bool = True,
+                 temp: float = 1.0,
+                 top_k: int = 250,
+                 top_p: float = 0.0,
+                 cfg_coef: tp.Optional[float] = None,
+                 two_step_cfg: tp.Optional[bool] = None,
+                 remove_prompts: bool = False,
+                 check: bool = False,
+                 callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
+        """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
+        be perform in a greedy fashion or using sampling with top K and top P strategies.
+
+        Args:
+            prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
+            conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
+            num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
+            max_gen_len (int): Maximum generation length.
+            use_sampling (bool): Whether to use a sampling strategy or not.
+            temp (float): Sampling temperature.
+            top_k (int): K for "top-k" sampling.
+            top_p (float): P for "top-p" sampling.
+            cfg_coeff (float, optional): Classifier-free guidance coefficient.
+            two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
+            remove_prompts (bool): Whether to remove prompts from generation or not.
+            check (bool): Whether to apply further checks on generated sequence.
+            callback (Callback, optional): Callback function to report generation progress.
+        Returns:
+            torch.Tensor: Generated tokens.
+        """
+        assert not self.training, "generation shouldn't be used in training mode."
+        first_param = next(iter(self.parameters()))
+        device = first_param.device
+
+        # Checking all input shapes are consistent.
+        possible_num_samples = []
+        if num_samples is not None:
+            possible_num_samples.append(num_samples)
+        elif prompt is not None:
+            possible_num_samples.append(prompt.shape[0])
+        elif conditions:
+            possible_num_samples.append(len(conditions))
+        else:
+            possible_num_samples.append(1)
+        assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
+        num_samples = possible_num_samples[0]
+
+        # below we create set of conditions: one conditional and one unconditional
+        # to do that we merge the regular condition together with the null condition
+        # we then do 1 forward pass instead of 2.
+        # the reason for that is two-fold:
+        # 1. it is about x2 faster than doing 2 forward passes
+        # 2. avoid the streaming API treating the 2 passes as part of different time steps
+        # We also support doing two different passes, in particular to ensure that
+        # the padding structure is exactly the same between train and test.
+        # With a batch size of 1, this can be slower though.
+        cfg_conditions: CFGConditions
+        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+        if conditions:
+            null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+            if two_step_cfg:
+                cfg_conditions = (
+                    self.condition_provider(self.condition_provider.tokenize(conditions)),
+                    self.condition_provider(self.condition_provider.tokenize(null_conditions)),
+                )
+            else:
+                conditions = conditions + null_conditions
+                tokenized = self.condition_provider.tokenize(conditions)
+                cfg_conditions = self.condition_provider(tokenized)
+        else:
+            cfg_conditions = {}
+
+        if prompt is None:
+            assert num_samples > 0
+            prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
+
+        B, K, T = prompt.shape
+        start_offset = T
+        assert start_offset < max_gen_len
+
+        pattern = self.pattern_provider.get_pattern(max_gen_len)
+        # this token is used as default value for codes that are not generated yet
+        unknown_token = -1
+
+        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
+        gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
+        # filling the gen_codes with the prompt if needed
+        gen_codes[..., :start_offset] = prompt
+        # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
+        gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+        # retrieve the start_offset in the sequence:
+        # it is the first sequence step that contains the `start_offset` timestep
+        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+        assert start_offset_sequence is not None
+
+        with self.streaming():
+            unconditional_state = self.get_streaming_state()
+            prev_offset = 0
+            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
+            for offset in range(start_offset_sequence, gen_sequence_len):
+                # get current sequence (note that the streaming API is providing the caching over previous offsets)
+                curr_sequence = gen_sequence[..., prev_offset:offset]
+                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
+                if check:
+                    # check coherence between mask and sequence
+                    assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
+                    # should never happen as gen_sequence is filled progressively
+                    assert not (curr_sequence == unknown_token).any()
+                # sample next token from the model, next token shape is [B, K, 1]
+                next_token = self._sample_next_token(
+                    curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
+                    cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
+                # ensure the tokens that should be masked are properly set to special_token_id
+                # as the model never output special_token_id
+                valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
+                next_token[~valid_mask] = self.special_token_id
+                # ensure we don't overwrite prompt tokens, we only write over unknown tokens
+                # (then mask tokens should be left as is as well, which is correct)
+                gen_sequence[..., offset:offset+1] = torch.where(
+                    gen_sequence[..., offset:offset+1] == unknown_token,
+                    next_token, gen_sequence[..., offset:offset+1]
+                )
+                prev_offset = offset
+                if callback is not None:
+                    callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
+        unconditional_state.clear()
+
+        # ensure sequence has been entirely filled
+        assert not (gen_sequence == unknown_token).any()
+        # ensure gen_sequence pattern and mask are matching
+        # which means the gen_sequence is valid according to the pattern
+        assert (
+            gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
+        ).all()
+        # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
+        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
+
+        # sanity checks over the returned codes and corresponding masks
+        assert (out_codes[..., :max_gen_len] != unknown_token).all()
+        assert (out_mask[..., :max_gen_len] == 1).all()
+
+        out_start_offset = start_offset if remove_prompts else 0
+        out_codes = out_codes[..., out_start_offset:max_gen_len]
+
+        # ensure the returned codes are all valid
+        assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+        return out_codes
+
+
+
+
+
+
+
+

Functions

+
+
+def get_init_fn(method: str, input_dim: int, init_depth: Optional[int] = None) +
+
+

LM layer initialization. +Inspired from xlformers: https://github.com/fairinternal/xlformers

+

Args

+
+
method : str
+
Method name for init function. Valid options are: +'gaussian', 'uniform'.
+
input_dim : int
+
Input dimension of the initialized module.
+
init_depth : int, optional
+
Optional init depth value used to rescale +the standard deviation if defined.
+
+
+ +Expand source code + +
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
+    """LM layer initialization.
+    Inspired from xlformers: https://github.com/fairinternal/xlformers
+
+    Args:
+        method (str): Method name for init function. Valid options are:
+            'gaussian', 'uniform'.
+        input_dim (int): Input dimension of the initialized module.
+        init_depth (int, optional): Optional init depth value used to rescale
+            the standard deviation if defined.
+    """
+    # Compute std
+    std = 1 / math.sqrt(input_dim)
+    # Rescale with depth
+    if init_depth is not None:
+        std = std / math.sqrt(2 * init_depth)
+
+    if method == 'gaussian':
+        return partial(
+            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
+        )
+    elif method == 'uniform':
+        bound = math.sqrt(3) * std  # ensure the standard deviation is `std`
+        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
+    else:
+        raise ValueError("Unsupported layer initialization method")
+
+
+
+def init_layer(m: torch.nn.modules.module.Module, method: str, init_depth: Optional[int] = None, zero_bias_init: bool = False) +
+
+

Wrapper around get_init_fn() for proper initialization of LM modules.

+

Args

+
+
m : nn.Module
+
Module to initialize.
+
method : str
+
Method name for the init function.
+
init_depth : int, optional
+
Optional init depth value used to rescale +the standard deviation if defined.
+
zero_bias_init : bool
+
Whether to initialize the bias to 0 or not.
+
+
+ +Expand source code + +
def init_layer(m: nn.Module,
+               method: str,
+               init_depth: tp.Optional[int] = None,
+               zero_bias_init: bool = False):
+    """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
+
+    Args:
+        m (nn.Module): Module to initialize.
+        method (str): Method name for the init function.
+        init_depth (int, optional): Optional init depth value used to rescale
+            the standard deviation if defined.
+        zero_bias_init (bool): Whether to initialize the bias to 0 or not.
+    """
+    if isinstance(m, nn.Linear):
+        init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
+        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+            weight = m.weight.float()
+            init_fn(weight)
+            m.weight.data[:] = weight.half()
+        else:
+            init_fn(m.weight)
+        if zero_bias_init and m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+    elif isinstance(m, nn.Embedding):
+        init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
+        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+            weight = m.weight.float()
+            init_fn(weight)
+            m.weight.data[:] = weight.half()
+        else:
+            init_fn(m.weight)
+
+
+
+
+
+

Classes

+
+
+class LMModel +(pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider, fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8, hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False, emb_lr: Optional[float] = None, bias_proj: bool = True, weight_init: Optional[str] = None, depthwise_init: Optional[str] = None, zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0, attribute_dropout: Dict[str, Dict[str, float]] = {}, two_step_cfg: bool = False, **kwargs) +
+
+

Transformer-based language model on multiple streams of codes.

+

Args

+
+
pattern_provider : CodebooksPatternProvider
+
Pattern provider for codebook interleaving.
+
condition_provider : MusicConditioningProvider
+
Conditioning provider from metadata.
+
fuser : ConditionFuser
+
Fuser handling the fusing of conditions with language model input.
+
n_q : int
+
Number of parallel streams to model.
+
card : int
+
Cardinality, vocabulary size.
+
dim : int
+
Dimension of the transformer encoder.
+
num_heads : int
+
Number of heads for the transformer encoder.
+
hidden_scale : int
+
Scale for hidden feed forward dimension of the transformer encoder.
+
norm : str
+
Normalization method.
+
norm_first : bool
+
Use pre-norm instead of post-norm.
+
emb_lr : float, optional
+
Embedding-specific learning rate.
+
bias_proj : bool
+
Use bias for output projections.
+
weight_init : str, optional
+
Method for weight initialization.
+
depthwise_init : str, optional
+
Method for depthwise weight initialization.
+
zero_bias_init : bool
+
If true and bias in Linears, initialize bias to zeros.
+
cfg_dropout : float
+
Classifier-free guidance dropout.
+
cfg_coef : float
+
Classifier-free guidance coefficient.
+
attribute_dropout : dict
+
Attribute dropout probabilities.
+
two_step_cfg : bool
+
Whether to run classifier free-guidance with 2 distinct steps.
+
**kwargs
+
Additional parameters for the transformer encoder.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class LMModel(StreamingModule):
+    """Transformer-based language model on multiple streams of codes.
+
+    Args:
+        pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
+        condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
+        fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
+        n_q (int): Number of parallel streams to model.
+        card (int): Cardinality, vocabulary size.
+        dim (int): Dimension of the transformer encoder.
+        num_heads (int): Number of heads for the transformer encoder.
+        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
+        norm (str): Normalization method.
+        norm_first (bool): Use pre-norm instead of post-norm.
+        emb_lr (float, optional): Embedding-specific learning rate.
+        bias_proj (bool): Use bias for output projections.
+        weight_init (str, optional): Method for weight initialization.
+        depthwise_init (str, optional): Method for depthwise weight initialization.
+        zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
+        cfg_dropout (float): Classifier-free guidance dropout.
+        cfg_coef (float): Classifier-free guidance coefficient.
+        attribute_dropout (dict): Attribute dropout probabilities.
+        two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
+        **kwargs: Additional parameters for the transformer encoder.
+    """
+    def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
+                 fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
+                 hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
+                 emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
+                 weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
+                 zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
+                 attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
+                 **kwargs):
+        super().__init__()
+        self.cfg_coef = cfg_coef
+        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
+        self.att_dropout = AttributeDropout(p=attribute_dropout)
+        self.condition_provider = condition_provider
+        self.fuser = fuser
+        self.card = card
+        embed_dim = self.card + 1
+        self.n_q = n_q
+        self.dim = dim
+        self.pattern_provider = pattern_provider
+        self.two_step_cfg = two_step_cfg
+        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
+        if 'activation' in kwargs:
+            kwargs['activation'] = get_activation_fn(kwargs['activation'])
+        self.transformer = StreamingTransformer(
+            d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
+            norm=norm, norm_first=norm_first, **kwargs)
+        self.out_norm: tp.Optional[nn.Module] = None
+        if norm_first:
+            self.out_norm = create_norm_fn(norm, dim)
+        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
+        self._init_weights(weight_init, depthwise_init, zero_bias_init)
+        self._fsdp: tp.Optional[nn.Module]
+        self.__dict__['_fsdp'] = None
+
+    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
+        """Initialization of the transformer module weights.
+
+        Args:
+            weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
+            depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
+                'current' where the depth corresponds to the current layer index or 'global' where the total number
+                of layer is used as depth. If not set, no depthwise initialization strategy is used.
+            zero_bias_init (bool): Whether to initialize bias to zero or not.
+        """
+        assert depthwise_init is None or depthwise_init in ['current', 'global']
+        assert depthwise_init is None or weight_init is not None, \
+            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
+        assert not zero_bias_init or weight_init is not None, \
+            "If 'zero_bias_init', a 'weight_init' method should be provided"
+
+        if weight_init is None:
+            return
+
+        for emb_layer in self.emb:
+            init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+        for layer_idx, tr_layer in enumerate(self.transformer.layers):
+            depth = None
+            if depthwise_init == 'current':
+                depth = layer_idx + 1
+            elif depthwise_init == 'global':
+                depth = len(self.transformer.layers)
+            init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
+            tr_layer.apply(init_fn)
+
+        for linear in self.linears:
+            init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+    @property
+    def special_token_id(self) -> int:
+        return self.card
+
+    @property
+    def num_codebooks(self) -> int:
+        return self.n_q
+
+    def forward(self, sequence: torch.Tensor,
+                conditions: tp.List[ConditioningAttributes],
+                condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
+        """Apply language model on sequence and conditions.
+        Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
+        S the sequence steps, return the logits with shape [B, card, K, S].
+
+        Args:
+            indices (torch.Tensor): Indices of the codes to model.
+            conditions (list of ConditioningAttributes): Conditions to use when modeling
+                the given codes. Note that when evaluating multiple time with the same conditioning
+                you should pre-compute those and pass them as `condition_tensors`.
+            condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
+                tensors, see `conditions`.
+        Returns:
+            torch.Tensor: Logits.
+        """
+        B, K, S = sequence.shape
+        assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
+        input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
+        if condition_tensors is None:
+            assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
+            # apply dropout modules
+            conditions = self.cfg_dropout(conditions)
+            conditions = self.att_dropout(conditions)
+            tokenized = self.condition_provider.tokenize(conditions)
+            # encode conditions and fuse, both have a streaming cache to not recompute when generating.
+            condition_tensors = self.condition_provider(tokenized)
+        else:
+            assert not conditions, "Shouldn't pass both conditions and condition_tensors."
+
+        input_, cross_attention_input = self.fuser(input_, condition_tensors)
+
+        out = self.transformer(input_, cross_attention_src=cross_attention_input)
+        if self.out_norm:
+            out = self.out_norm(out)
+        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]
+
+        # remove the prefix from the model outputs
+        if len(self.fuser.fuse2cond['prepend']) > 0:
+            logits = logits[:, :, -S:]
+
+        return logits  # [B, K, S, card]
+
+    def compute_predictions(
+            self, codes: torch.Tensor,
+            conditions: tp.List[ConditioningAttributes],
+            condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
+        """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
+        forward using the specified codes interleaving pattern.
+
+        Args:
+            codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
+                K the number of codebooks and T the number of timesteps.
+            conditions (list of ConditioningAttributes): conditionings to use when modeling
+                the given codes. Note that when evaluating multiple time with the same conditioning
+                you should pre-compute those and pass them as `condition_tensors`.
+            condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
+                tensors, see `conditions`.
+        Returns:
+            LMOutput: Language model outputs
+                logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
+                    i.e. the first item corresponds to logits to predict the first code, meaning that
+                    no additional shifting of codes and logits is required.
+                mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
+                    Given the specified interleaving strategies, parts of the logits and codes should
+                    not be considered as valid predictions because of invalid context.
+        """
+        B, K, T = codes.shape
+        codes = codes.contiguous()
+        # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
+        pattern = self.pattern_provider.get_pattern(T)
+        sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
+            codes, self.special_token_id, keep_only_valid_steps=True
+        )
+        # apply model on pattern sequence
+        model = self if self._fsdp is None else self._fsdp
+        logits = model(sequence_codes, conditions, condition_tensors)  # [B, K, S, card]
+        # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
+        # and provide the corresponding mask over invalid positions of tokens
+        logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
+        # note: we use nans as special token to make it obvious if we feed unexpected logits
+        logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
+            logits, float('nan'), keep_only_valid_steps=True
+        )
+        logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
+        logits_mask = logits_mask[None, :, :].expand(B, -1, -1)  # [K, T] -> [B, K, T]
+        return LMOutput(logits, logits_mask)
+
+    def _sample_next_token(self,
+                           sequence: torch.Tensor,
+                           cfg_conditions: CFGConditions,
+                           unconditional_state: State,
+                           use_sampling: bool = False,
+                           temp: float = 1.0,
+                           top_k: int = 0,
+                           top_p: float = 0.0,
+                           cfg_coef: tp.Optional[float] = None,
+                           two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
+        """Sample next token from the model given a sequence and a set of conditions. The model supports
+        multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
+
+        Args:
+            sequence (torch.Tensor): Current sequence of shape [B, K, S]
+                with K corresponding to the number of codebooks and S the number of sequence steps.
+                S = 1 in streaming mode, except for the first step that contains a bigger prompt.
+            condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used,
+                should be twice the batch size, being the concatenation of the conditions + null conditions.
+            use_sampling (bool): Whether to use a sampling strategy or not.
+            temp (float): Sampling temperature.
+            top_k (int): K for "top-k" sampling.
+            top_p (float): P for "top-p" sampling.
+            cfg_coef (float, optional): classifier free guidance coefficient
+        Returns:
+            next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
+        """
+        B = sequence.shape[0]
+        cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
+        model = self if self._fsdp is None else self._fsdp
+        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+        if two_step_cfg and cfg_conditions != {}:
+            assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
+            condition_tensors, null_condition_tensors = cfg_conditions
+            cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
+            state = self.get_streaming_state()
+            self.set_streaming_state(unconditional_state)
+            uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
+            unconditional_state.update(self.get_streaming_state())
+            self.set_streaming_state(state)
+            logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
+        else:
+            assert isinstance(cfg_conditions, dict)
+            condition_tensors = cfg_conditions
+            if condition_tensors:
+                # Preparing for CFG, predicting both conditional and unconditional logits.
+                sequence = torch.cat([sequence, sequence], dim=0)
+            all_logits = model(
+                sequence,
+                conditions=[], condition_tensors=condition_tensors)
+            if condition_tensors:
+                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
+                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+            else:
+                logits = all_logits
+
+        logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
+        logits = logits[..., -1]  # [B x K x card]
+
+        # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
+        if use_sampling and temp > 0.0:
+            probs = torch.softmax(logits / temp, dim=-1)
+            if top_p > 0.0:
+                next_token = utils.sample_top_p(probs, p=top_p)
+            elif top_k > 0:
+                next_token = utils.sample_top_k(probs, k=top_k)
+            else:
+                next_token = utils.multinomial(probs, num_samples=1)
+        else:
+            next_token = torch.argmax(logits, dim=-1, keepdim=True)
+
+        return next_token
+
+    @torch.no_grad()
+    def generate(self,
+                 prompt: tp.Optional[torch.Tensor] = None,
+                 conditions: tp.List[ConditioningAttributes] = [],
+                 num_samples: tp.Optional[int] = None,
+                 max_gen_len: int = 256,
+                 use_sampling: bool = True,
+                 temp: float = 1.0,
+                 top_k: int = 250,
+                 top_p: float = 0.0,
+                 cfg_coef: tp.Optional[float] = None,
+                 two_step_cfg: tp.Optional[bool] = None,
+                 remove_prompts: bool = False,
+                 check: bool = False,
+                 callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
+        """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
+        be perform in a greedy fashion or using sampling with top K and top P strategies.
+
+        Args:
+            prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
+            conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
+            num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
+            max_gen_len (int): Maximum generation length.
+            use_sampling (bool): Whether to use a sampling strategy or not.
+            temp (float): Sampling temperature.
+            top_k (int): K for "top-k" sampling.
+            top_p (float): P for "top-p" sampling.
+            cfg_coeff (float, optional): Classifier-free guidance coefficient.
+            two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
+            remove_prompts (bool): Whether to remove prompts from generation or not.
+            check (bool): Whether to apply further checks on generated sequence.
+            callback (Callback, optional): Callback function to report generation progress.
+        Returns:
+            torch.Tensor: Generated tokens.
+        """
+        assert not self.training, "generation shouldn't be used in training mode."
+        first_param = next(iter(self.parameters()))
+        device = first_param.device
+
+        # Checking all input shapes are consistent.
+        possible_num_samples = []
+        if num_samples is not None:
+            possible_num_samples.append(num_samples)
+        elif prompt is not None:
+            possible_num_samples.append(prompt.shape[0])
+        elif conditions:
+            possible_num_samples.append(len(conditions))
+        else:
+            possible_num_samples.append(1)
+        assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
+        num_samples = possible_num_samples[0]
+
+        # below we create set of conditions: one conditional and one unconditional
+        # to do that we merge the regular condition together with the null condition
+        # we then do 1 forward pass instead of 2.
+        # the reason for that is two-fold:
+        # 1. it is about x2 faster than doing 2 forward passes
+        # 2. avoid the streaming API treating the 2 passes as part of different time steps
+        # We also support doing two different passes, in particular to ensure that
+        # the padding structure is exactly the same between train and test.
+        # With a batch size of 1, this can be slower though.
+        cfg_conditions: CFGConditions
+        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+        if conditions:
+            null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+            if two_step_cfg:
+                cfg_conditions = (
+                    self.condition_provider(self.condition_provider.tokenize(conditions)),
+                    self.condition_provider(self.condition_provider.tokenize(null_conditions)),
+                )
+            else:
+                conditions = conditions + null_conditions
+                tokenized = self.condition_provider.tokenize(conditions)
+                cfg_conditions = self.condition_provider(tokenized)
+        else:
+            cfg_conditions = {}
+
+        if prompt is None:
+            assert num_samples > 0
+            prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
+
+        B, K, T = prompt.shape
+        start_offset = T
+        assert start_offset < max_gen_len
+
+        pattern = self.pattern_provider.get_pattern(max_gen_len)
+        # this token is used as default value for codes that are not generated yet
+        unknown_token = -1
+
+        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
+        gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
+        # filling the gen_codes with the prompt if needed
+        gen_codes[..., :start_offset] = prompt
+        # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
+        gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+        # retrieve the start_offset in the sequence:
+        # it is the first sequence step that contains the `start_offset` timestep
+        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+        assert start_offset_sequence is not None
+
+        with self.streaming():
+            unconditional_state = self.get_streaming_state()
+            prev_offset = 0
+            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
+            for offset in range(start_offset_sequence, gen_sequence_len):
+                # get current sequence (note that the streaming API is providing the caching over previous offsets)
+                curr_sequence = gen_sequence[..., prev_offset:offset]
+                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
+                if check:
+                    # check coherence between mask and sequence
+                    assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
+                    # should never happen as gen_sequence is filled progressively
+                    assert not (curr_sequence == unknown_token).any()
+                # sample next token from the model, next token shape is [B, K, 1]
+                next_token = self._sample_next_token(
+                    curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
+                    cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
+                # ensure the tokens that should be masked are properly set to special_token_id
+                # as the model never output special_token_id
+                valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
+                next_token[~valid_mask] = self.special_token_id
+                # ensure we don't overwrite prompt tokens, we only write over unknown tokens
+                # (then mask tokens should be left as is as well, which is correct)
+                gen_sequence[..., offset:offset+1] = torch.where(
+                    gen_sequence[..., offset:offset+1] == unknown_token,
+                    next_token, gen_sequence[..., offset:offset+1]
+                )
+                prev_offset = offset
+                if callback is not None:
+                    callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
+        unconditional_state.clear()
+
+        # ensure sequence has been entirely filled
+        assert not (gen_sequence == unknown_token).any()
+        # ensure gen_sequence pattern and mask are matching
+        # which means the gen_sequence is valid according to the pattern
+        assert (
+            gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
+        ).all()
+        # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
+        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
+
+        # sanity checks over the returned codes and corresponding masks
+        assert (out_codes[..., :max_gen_len] != unknown_token).all()
+        assert (out_mask[..., :max_gen_len] == 1).all()
+
+        out_start_offset = start_offset if remove_prompts else 0
+        out_codes = out_codes[..., out_start_offset:max_gen_len]
+
+        # ensure the returned codes are all valid
+        assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+        return out_codes
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var num_codebooks : int
+
+
+
+ +Expand source code + +
@property
+def num_codebooks(self) -> int:
+    return self.n_q
+
+
+
var special_token_id : int
+
+
+
+ +Expand source code + +
@property
+def special_token_id(self) -> int:
+    return self.card
+
+
+
+

Methods

+
+
+def compute_predictions(self, codes: torch.Tensor, conditions: List[ConditioningAttributes], condition_tensors: Optional[Dict[str, Tuple[torch.Tensor, torch.Tensor]]] = None) ‑> LMOutput +
+
+

Given an input tensor of codes [B, K, T] and list of conditions, runs the model +forward using the specified codes interleaving pattern.

+

Args

+
+
codes : torch.Tensor
+
Input codes of shape [B, K, T] with B the batch size, +K the number of codebooks and T the number of timesteps.
+
conditions : list of ConditioningAttributes
+
conditionings to use when modeling +the given codes. Note that when evaluating multiple time with the same conditioning +you should pre-compute those and pass them as condition_tensors.
+
condition_tensors : dict[str, ConditionType], optional
+
pre-computed conditioning +tensors, see conditions.
+
+

Returns

+
+
LMOutput
+
Language model outputs +logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes, +i.e. the first item corresponds to logits to predict the first code, meaning that +no additional shifting of codes and logits is required. +mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions. +Given the specified interleaving strategies, parts of the logits and codes should +not be considered as valid predictions because of invalid context.
+
+
+ +Expand source code + +
def compute_predictions(
+        self, codes: torch.Tensor,
+        conditions: tp.List[ConditioningAttributes],
+        condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
+    """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
+    forward using the specified codes interleaving pattern.
+
+    Args:
+        codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
+            K the number of codebooks and T the number of timesteps.
+        conditions (list of ConditioningAttributes): conditionings to use when modeling
+            the given codes. Note that when evaluating multiple time with the same conditioning
+            you should pre-compute those and pass them as `condition_tensors`.
+        condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
+            tensors, see `conditions`.
+    Returns:
+        LMOutput: Language model outputs
+            logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
+                i.e. the first item corresponds to logits to predict the first code, meaning that
+                no additional shifting of codes and logits is required.
+            mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
+                Given the specified interleaving strategies, parts of the logits and codes should
+                not be considered as valid predictions because of invalid context.
+    """
+    B, K, T = codes.shape
+    codes = codes.contiguous()
+    # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
+    pattern = self.pattern_provider.get_pattern(T)
+    sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
+        codes, self.special_token_id, keep_only_valid_steps=True
+    )
+    # apply model on pattern sequence
+    model = self if self._fsdp is None else self._fsdp
+    logits = model(sequence_codes, conditions, condition_tensors)  # [B, K, S, card]
+    # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
+    # and provide the corresponding mask over invalid positions of tokens
+    logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
+    # note: we use nans as special token to make it obvious if we feed unexpected logits
+    logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
+        logits, float('nan'), keep_only_valid_steps=True
+    )
+    logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
+    logits_mask = logits_mask[None, :, :].expand(B, -1, -1)  # [K, T] -> [B, K, T]
+    return LMOutput(logits, logits_mask)
+
+
+
+def forward(self, sequence: torch.Tensor, conditions: List[ConditioningAttributes], condition_tensors: Optional[Dict[str, Tuple[torch.Tensor, torch.Tensor]]] = None) ‑> torch.Tensor +
+
+

Apply language model on sequence and conditions. +Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and +S the sequence steps, return the logits with shape [B, card, K, S].

+

Args

+
+
indices : torch.Tensor
+
Indices of the codes to model.
+
conditions : list of ConditioningAttributes
+
Conditions to use when modeling +the given codes. Note that when evaluating multiple time with the same conditioning +you should pre-compute those and pass them as condition_tensors.
+
condition_tensors : dict[str, ConditionType], optional
+
Pre-computed conditioning +tensors, see conditions.
+
+

Returns

+
+
torch.Tensor
+
Logits.
+
+
+ +Expand source code + +
def forward(self, sequence: torch.Tensor,
+            conditions: tp.List[ConditioningAttributes],
+            condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
+    """Apply language model on sequence and conditions.
+    Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
+    S the sequence steps, return the logits with shape [B, card, K, S].
+
+    Args:
+        indices (torch.Tensor): Indices of the codes to model.
+        conditions (list of ConditioningAttributes): Conditions to use when modeling
+            the given codes. Note that when evaluating multiple time with the same conditioning
+            you should pre-compute those and pass them as `condition_tensors`.
+        condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
+            tensors, see `conditions`.
+    Returns:
+        torch.Tensor: Logits.
+    """
+    B, K, S = sequence.shape
+    assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
+    input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
+    if condition_tensors is None:
+        assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
+        # apply dropout modules
+        conditions = self.cfg_dropout(conditions)
+        conditions = self.att_dropout(conditions)
+        tokenized = self.condition_provider.tokenize(conditions)
+        # encode conditions and fuse, both have a streaming cache to not recompute when generating.
+        condition_tensors = self.condition_provider(tokenized)
+    else:
+        assert not conditions, "Shouldn't pass both conditions and condition_tensors."
+
+    input_, cross_attention_input = self.fuser(input_, condition_tensors)
+
+    out = self.transformer(input_, cross_attention_src=cross_attention_input)
+    if self.out_norm:
+        out = self.out_norm(out)
+    logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]
+
+    # remove the prefix from the model outputs
+    if len(self.fuser.fuse2cond['prepend']) > 0:
+        logits = logits[:, :, -S:]
+
+    return logits  # [B, K, S, card]
+
+
+
+def generate(self, prompt: Optional[torch.Tensor] = None, conditions: List[ConditioningAttributes] = [], num_samples: Optional[int] = None, max_gen_len: int = 256, use_sampling: bool = True, temp: float = 1.0, top_k: int = 250, top_p: float = 0.0, cfg_coef: Optional[float] = None, two_step_cfg: Optional[bool] = None, remove_prompts: bool = False, check: bool = False, callback: Optional[Callable[[int, int], None]] = None) ‑> torch.Tensor +
+
+

Generate tokens sampling from the model given a prompt or unconditionally. Generation can +be perform in a greedy fashion or using sampling with top K and top P strategies.

+

Args

+
+
prompt : torch.Tensor, optional
+
Prompt tokens of shape [B, K, T].
+
conditions_tensors : list of ConditioningAttributes, optional
+
List of conditions.
+
num_samples : int, optional
+
Number of samples to generate when no prompt and no conditions are given.
+
max_gen_len : int
+
Maximum generation length.
+
use_sampling : bool
+
Whether to use a sampling strategy or not.
+
temp : float
+
Sampling temperature.
+
top_k : int
+
K for "top-k" sampling.
+
top_p : float
+
P for "top-p" sampling.
+
cfg_coeff : float, optional
+
Classifier-free guidance coefficient.
+
two_step_cfg : bool, optional
+
Whether to perform classifier-free guidance with two steps generation.
+
remove_prompts : bool
+
Whether to remove prompts from generation or not.
+
check : bool
+
Whether to apply further checks on generated sequence.
+
callback : Callback, optional
+
Callback function to report generation progress.
+
+

Returns

+
+
torch.Tensor
+
Generated tokens.
+
+
+ +Expand source code + +
@torch.no_grad()
+def generate(self,
+             prompt: tp.Optional[torch.Tensor] = None,
+             conditions: tp.List[ConditioningAttributes] = [],
+             num_samples: tp.Optional[int] = None,
+             max_gen_len: int = 256,
+             use_sampling: bool = True,
+             temp: float = 1.0,
+             top_k: int = 250,
+             top_p: float = 0.0,
+             cfg_coef: tp.Optional[float] = None,
+             two_step_cfg: tp.Optional[bool] = None,
+             remove_prompts: bool = False,
+             check: bool = False,
+             callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
+    """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
+    be perform in a greedy fashion or using sampling with top K and top P strategies.
+
+    Args:
+        prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
+        conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
+        num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
+        max_gen_len (int): Maximum generation length.
+        use_sampling (bool): Whether to use a sampling strategy or not.
+        temp (float): Sampling temperature.
+        top_k (int): K for "top-k" sampling.
+        top_p (float): P for "top-p" sampling.
+        cfg_coeff (float, optional): Classifier-free guidance coefficient.
+        two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
+        remove_prompts (bool): Whether to remove prompts from generation or not.
+        check (bool): Whether to apply further checks on generated sequence.
+        callback (Callback, optional): Callback function to report generation progress.
+    Returns:
+        torch.Tensor: Generated tokens.
+    """
+    assert not self.training, "generation shouldn't be used in training mode."
+    first_param = next(iter(self.parameters()))
+    device = first_param.device
+
+    # Checking all input shapes are consistent.
+    possible_num_samples = []
+    if num_samples is not None:
+        possible_num_samples.append(num_samples)
+    elif prompt is not None:
+        possible_num_samples.append(prompt.shape[0])
+    elif conditions:
+        possible_num_samples.append(len(conditions))
+    else:
+        possible_num_samples.append(1)
+    assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
+    num_samples = possible_num_samples[0]
+
+    # below we create set of conditions: one conditional and one unconditional
+    # to do that we merge the regular condition together with the null condition
+    # we then do 1 forward pass instead of 2.
+    # the reason for that is two-fold:
+    # 1. it is about x2 faster than doing 2 forward passes
+    # 2. avoid the streaming API treating the 2 passes as part of different time steps
+    # We also support doing two different passes, in particular to ensure that
+    # the padding structure is exactly the same between train and test.
+    # With a batch size of 1, this can be slower though.
+    cfg_conditions: CFGConditions
+    two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+    if conditions:
+        null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+        if two_step_cfg:
+            cfg_conditions = (
+                self.condition_provider(self.condition_provider.tokenize(conditions)),
+                self.condition_provider(self.condition_provider.tokenize(null_conditions)),
+            )
+        else:
+            conditions = conditions + null_conditions
+            tokenized = self.condition_provider.tokenize(conditions)
+            cfg_conditions = self.condition_provider(tokenized)
+    else:
+        cfg_conditions = {}
+
+    if prompt is None:
+        assert num_samples > 0
+        prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
+
+    B, K, T = prompt.shape
+    start_offset = T
+    assert start_offset < max_gen_len
+
+    pattern = self.pattern_provider.get_pattern(max_gen_len)
+    # this token is used as default value for codes that are not generated yet
+    unknown_token = -1
+
+    # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
+    gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
+    # filling the gen_codes with the prompt if needed
+    gen_codes[..., :start_offset] = prompt
+    # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
+    gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+    # retrieve the start_offset in the sequence:
+    # it is the first sequence step that contains the `start_offset` timestep
+    start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+    assert start_offset_sequence is not None
+
+    with self.streaming():
+        unconditional_state = self.get_streaming_state()
+        prev_offset = 0
+        gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
+        for offset in range(start_offset_sequence, gen_sequence_len):
+            # get current sequence (note that the streaming API is providing the caching over previous offsets)
+            curr_sequence = gen_sequence[..., prev_offset:offset]
+            curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
+            if check:
+                # check coherence between mask and sequence
+                assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
+                # should never happen as gen_sequence is filled progressively
+                assert not (curr_sequence == unknown_token).any()
+            # sample next token from the model, next token shape is [B, K, 1]
+            next_token = self._sample_next_token(
+                curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
+                cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
+            # ensure the tokens that should be masked are properly set to special_token_id
+            # as the model never output special_token_id
+            valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
+            next_token[~valid_mask] = self.special_token_id
+            # ensure we don't overwrite prompt tokens, we only write over unknown tokens
+            # (then mask tokens should be left as is as well, which is correct)
+            gen_sequence[..., offset:offset+1] = torch.where(
+                gen_sequence[..., offset:offset+1] == unknown_token,
+                next_token, gen_sequence[..., offset:offset+1]
+            )
+            prev_offset = offset
+            if callback is not None:
+                callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
+    unconditional_state.clear()
+
+    # ensure sequence has been entirely filled
+    assert not (gen_sequence == unknown_token).any()
+    # ensure gen_sequence pattern and mask are matching
+    # which means the gen_sequence is valid according to the pattern
+    assert (
+        gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
+    ).all()
+    # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
+    out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
+
+    # sanity checks over the returned codes and corresponding masks
+    assert (out_codes[..., :max_gen_len] != unknown_token).all()
+    assert (out_mask[..., :max_gen_len] == 1).all()
+
+    out_start_offset = start_offset if remove_prompts else 0
+    out_codes = out_codes[..., out_start_offset:max_gen_len]
+
+    # ensure the returned codes are all valid
+    assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+    return out_codes
+
+
+
+

Inherited members

+ +
+
+class LMOutput +(logits: torch.Tensor, mask: torch.Tensor) +
+
+

LMOutput(logits: torch.Tensor, mask: torch.Tensor)

+
+ +Expand source code + +
class LMOutput:
+    # The logits are already re-aligned with the input codes
+    # hence no extra shift is required, e.g. when computing CE
+    logits: torch.Tensor  # [B, K, T, card]
+    mask: torch.Tensor  # [B, K, T]
+
+

Class variables

+
+
var logits : torch.Tensor
+
+
+
+
var mask : torch.Tensor
+
+
+
+
+
+
+class ScaledEmbedding +(*args, lr=None, **kwargs) +
+
+

Boost learning rate for embeddings (with scale).

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ScaledEmbedding(nn.Embedding):
+    """Boost learning rate for embeddings (with `scale`).
+    """
+    def __init__(self, *args, lr=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lr = lr
+
+    def make_optim_group(self):
+        group = {"params": list(self.parameters())}
+        if self.lr is not None:
+            group["lr"] = self.lr
+        return group
+
+

Ancestors

+
    +
  • torch.nn.modules.sparse.Embedding
  • +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var embedding_dim : int
+
+
+
+
var freeze : bool
+
+
+
+
var max_norm : Optional[float]
+
+
+
+
var norm_type : float
+
+
+
+
var num_embeddings : int
+
+
+
+
var padding_idx : Optional[int]
+
+
+
+
var scale_grad_by_freq : bool
+
+
+
+
var sparse : bool
+
+
+
+
var weight : torch.Tensor
+
+
+
+
+

Methods

+
+
+def make_optim_group(self) +
+
+
+
+ +Expand source code + +
def make_optim_group(self):
+    group = {"params": list(self.parameters())}
+    if self.lr is not None:
+        group["lr"] = self.lr
+    return group
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/models/loaders.html b/api_docs/audiocraft/models/loaders.html new file mode 100644 index 00000000..fa1b897c --- /dev/null +++ b/api_docs/audiocraft/models/loaders.html @@ -0,0 +1,367 @@ + + + + + + +audiocraft.models.loaders API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.loaders

+
+
+

Utility functions to load from the checkpoints. +Each checkpoint is a torch.saved dict with the following keys: +- 'xp.cfg': the hydra config as dumped during training. This should be used +to rebuild the object using the audiocraft.models.builders functions, +- 'model_best_state': a readily loadable best state for the model, including +the conditioner. The model obtained from xp.cfg should be compatible +with this state dict. In the case of a LM, the encodec model would not be +bundled along but instead provided separately.

+

Those functions also support loading from a remote location with the Torch Hub API. +They also support overriding some parameters, in particular the device and dtype +of the returned model.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility functions to load from the checkpoints.
+Each checkpoint is a torch.saved dict with the following keys:
+- 'xp.cfg': the hydra config as dumped during training. This should be used
+    to rebuild the object using the audiocraft.models.builders functions,
+- 'model_best_state': a readily loadable best state for the model, including
+    the conditioner. The model obtained from `xp.cfg` should be compatible
+    with this state dict. In the case of a LM, the encodec model would not be
+    bundled along but instead provided separately.
+
+Those functions also support loading from a remote location with the Torch Hub API.
+They also support overriding some parameters, in particular the device and dtype
+of the returned model.
+"""
+
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+import typing as tp
+import os
+
+from omegaconf import OmegaConf, DictConfig
+import torch
+
+from . import builders
+from .encodec import CompressionModel
+
+
+def get_audiocraft_cache_dir() -> tp.Optional[str]:
+    return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)
+
+
+def _get_state_dict(
+    file_or_url_or_id: tp.Union[Path, str],
+    filename: tp.Optional[str] = None,
+    device='cpu',
+    cache_dir: tp.Optional[str] = None,
+):
+    if cache_dir is None:
+        cache_dir = get_audiocraft_cache_dir()
+    # Return the state dict either from a file or url
+    file_or_url_or_id = str(file_or_url_or_id)
+    assert isinstance(file_or_url_or_id, str)
+
+    if os.path.isfile(file_or_url_or_id):
+        return torch.load(file_or_url_or_id, map_location=device)
+
+    if os.path.isdir(file_or_url_or_id):
+        file = f"{file_or_url_or_id}/{filename}"
+        return torch.load(file, map_location=device)
+
+    elif file_or_url_or_id.startswith('https://'):
+        return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
+
+    else:
+        assert filename is not None, "filename needs to be defined if using HF checkpoints"
+
+        file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir)
+        return torch.load(file, map_location=device)
+
+
+def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
+    return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
+
+
+def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+    pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
+    if 'pretrained' in pkg:
+        return CompressionModel.get_pretrained(pkg['pretrained'], device=device)
+    cfg = OmegaConf.create(pkg['xp.cfg'])
+    cfg.device = str(device)
+    model = builders.get_compression_model(cfg)
+    model.load_state_dict(pkg['best_state'])
+    model.eval()
+    return model
+
+
+def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
+    return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
+
+
+def _delete_param(cfg: DictConfig, full_name: str):
+    parts = full_name.split('.')
+    for part in parts[:-1]:
+        if part in cfg:
+            cfg = cfg[part]
+        else:
+            return
+    OmegaConf.set_struct(cfg, False)
+    if parts[-1] in cfg:
+        del cfg[parts[-1]]
+    OmegaConf.set_struct(cfg, True)
+
+
+def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+    pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
+    cfg = OmegaConf.create(pkg['xp.cfg'])
+    cfg.device = str(device)
+    if cfg.device == 'cpu':
+        cfg.dtype = 'float32'
+    else:
+        cfg.dtype = 'float16'
+    _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
+    _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
+    _delete_param(cfg, 'conditioners.args.drop_desc_p')
+    model = builders.get_lm_model(cfg)
+    model.load_state_dict(pkg['best_state'])
+    model.eval()
+    model.cfg = cfg
+    return model
+
+
+def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
+                  filename: tp.Optional[str] = None,
+                  cache_dir: tp.Optional[str] = None):
+    return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
+
+
+def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
+                          device='cpu',
+                          filename: tp.Optional[str] = None,
+                          cache_dir: tp.Optional[str] = None):
+    pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
+    models = []
+    processors = []
+    cfgs = []
+    sample_rate = pkg['sample_rate']
+    for i in range(pkg['n_bands']):
+        cfg = pkg[i]['cfg']
+        model = builders.get_diffusion_model(cfg)
+        model_dict = pkg[i]['model_state']
+        model.load_state_dict(model_dict)
+        model.to(device)
+        processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate)
+        processor_dict = pkg[i]['processor_state']
+        processor.load_state_dict(processor_dict)
+        processor.to(device)
+        models.append(model)
+        processors.append(processor)
+        cfgs.append(cfg)
+    return models, processors, cfgs
+
+
+
+
+
+
+
+

Functions

+
+
+def get_audiocraft_cache_dir() ‑> Optional[str] +
+
+
+
+ +Expand source code + +
def get_audiocraft_cache_dir() -> tp.Optional[str]:
+    return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)
+
+
+
+def load_compression_model(file_or_url_or_id: Union[str, pathlib.Path], device='cpu', cache_dir: Optional[str] = None) +
+
+
+
+ +Expand source code + +
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+    pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
+    if 'pretrained' in pkg:
+        return CompressionModel.get_pretrained(pkg['pretrained'], device=device)
+    cfg = OmegaConf.create(pkg['xp.cfg'])
+    cfg.device = str(device)
+    model = builders.get_compression_model(cfg)
+    model.load_state_dict(pkg['best_state'])
+    model.eval()
+    return model
+
+
+
+def load_compression_model_ckpt(file_or_url_or_id: Union[str, pathlib.Path], cache_dir: Optional[str] = None) +
+
+
+
+ +Expand source code + +
def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
+    return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
+
+
+
+def load_diffusion_models(file_or_url_or_id: Union[str, pathlib.Path], device='cpu', filename: Optional[str] = None, cache_dir: Optional[str] = None) +
+
+
+
+ +Expand source code + +
def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
+                          device='cpu',
+                          filename: tp.Optional[str] = None,
+                          cache_dir: tp.Optional[str] = None):
+    pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
+    models = []
+    processors = []
+    cfgs = []
+    sample_rate = pkg['sample_rate']
+    for i in range(pkg['n_bands']):
+        cfg = pkg[i]['cfg']
+        model = builders.get_diffusion_model(cfg)
+        model_dict = pkg[i]['model_state']
+        model.load_state_dict(model_dict)
+        model.to(device)
+        processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate)
+        processor_dict = pkg[i]['processor_state']
+        processor.load_state_dict(processor_dict)
+        processor.to(device)
+        models.append(model)
+        processors.append(processor)
+        cfgs.append(cfg)
+    return models, processors, cfgs
+
+
+
+def load_lm_model(file_or_url_or_id: Union[str, pathlib.Path], device='cpu', cache_dir: Optional[str] = None) +
+
+
+
+ +Expand source code + +
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+    pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
+    cfg = OmegaConf.create(pkg['xp.cfg'])
+    cfg.device = str(device)
+    if cfg.device == 'cpu':
+        cfg.dtype = 'float32'
+    else:
+        cfg.dtype = 'float16'
+    _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
+    _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
+    _delete_param(cfg, 'conditioners.args.drop_desc_p')
+    model = builders.get_lm_model(cfg)
+    model.load_state_dict(pkg['best_state'])
+    model.eval()
+    model.cfg = cfg
+    return model
+
+
+
+def load_lm_model_ckpt(file_or_url_or_id: Union[str, pathlib.Path], cache_dir: Optional[str] = None) +
+
+
+
+ +Expand source code + +
def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
+    return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
+
+
+
+def load_mbd_ckpt(file_or_url_or_id: Union[str, pathlib.Path], filename: Optional[str] = None, cache_dir: Optional[str] = None) +
+
+
+
+ +Expand source code + +
def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
+                  filename: tp.Optional[str] = None,
+                  cache_dir: tp.Optional[str] = None):
+    return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/models/multibanddiffusion.html b/api_docs/audiocraft/models/multibanddiffusion.html new file mode 100644 index 00000000..3b8552c1 --- /dev/null +++ b/api_docs/audiocraft/models/multibanddiffusion.html @@ -0,0 +1,812 @@ + + + + + + +audiocraft.models.multibanddiffusion API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.multibanddiffusion

+
+
+

Multi Band Diffusion models as described in +"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" +(paper link).

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Multi Band Diffusion models as described in
+"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
+(paper link).
+"""
+
+import typing as tp
+
+import torch
+import julius
+
+from .unet import DiffusionUnet
+from ..modules.diffusion_schedule import NoiseSchedule
+from .encodec import CompressionModel
+from ..solvers.compression import CompressionSolver
+from .loaders import load_compression_model, load_diffusion_models
+
+
+class DiffusionProcess:
+    """Sampling for a diffusion Model.
+
+    Args:
+        model (DiffusionUnet): Diffusion U-Net model.
+        noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
+    """
+    def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
+        """
+        """
+        self.model = model
+        self.schedule = noise_schedule
+
+    def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
+                 step_list: tp.Optional[tp.List[int]] = None):
+        """Perform one diffusion process to generate one of the bands.
+
+        Args:
+            condition (tensor): The embeddings form the compression model.
+            initial_noise (tensor): The initial noise to start the process/
+        """
+        return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
+                                                 condition=condition)
+
+
+class MultiBandDiffusion:
+    """Sample from multiple diffusion models.
+
+    Args:
+        DPs (list of DiffusionProcess): Diffusion processes.
+        codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
+    """
+    def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
+        self.DPs = DPs
+        self.codec_model = codec_model
+        self.device = next(self.codec_model.parameters()).device
+
+    @property
+    def sample_rate(self) -> int:
+        return self.codec_model.sample_rate
+
+    @staticmethod
+    def get_mbd_musicgen(device=None):
+        """Load our diffusion models trained for MusicGen."""
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        path = 'facebook/multiband-diffusion'
+        filename = 'mbd_musicgen_32khz.th'
+        name = 'facebook/musicgen-small'
+        codec_model = load_compression_model(name, device=device)
+        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
+        DPs = []
+        for i in range(len(models)):
+            schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
+            DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
+        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
+
+    @staticmethod
+    def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
+                      device: tp.Optional[tp.Union[torch.device, str]] = None,
+                      n_q: tp.Optional[int] = None):
+        """Get the pretrained Models for MultibandDiffusion.
+
+        Args:
+            bw (float): Bandwidth of the compression model.
+            pretrained (bool): Whether to use / download if necessary the models.
+            device (torch.device or str, optional): Device on which the models are loaded.
+            n_q (int, optional): Number of quantizers to use within the compression model.
+        """
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
+        if n_q is not None:
+            assert n_q in [2, 4, 8]
+            assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
+                f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
+        n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
+        codec_model = CompressionSolver.model_from_checkpoint(
+            '//pretrained/facebook/encodec_24khz', device=device)
+        codec_model.set_num_codebooks(n_q)
+        codec_model = codec_model.to(device)
+        path = 'facebook/multiband-diffusion'
+        filename = f'mbd_comp_{n_q}.pt'
+        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
+        DPs = []
+        for i in range(len(models)):
+            schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
+            DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
+        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
+
+        return MultiBandDiffusion(DPs, codec_model)
+
+    @torch.no_grad()
+    def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        """Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform.
+        Args:
+            wav (torch.Tensor): The audio that we want to extract the conditioning from
+            sample_rate (int): sample rate of the audio"""
+        if sample_rate != self.sample_rate:
+            wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
+        codes, scale = self.codec_model.encode(wav)
+        assert scale is None, "Scaled compression models not supported."
+        emb = self.get_emb(codes)
+        return emb
+
+    @torch.no_grad()
+    def get_emb(self, codes: torch.Tensor):
+        """Get latent representation from the discrete codes
+        Argrs:
+            codes (torch.Tensor): discrete tokens"""
+        emb = self.codec_model.decode_latent(codes)
+        return emb
+
+    def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
+                 step_list: tp.Optional[tp.List[int]] = None):
+        """Generate Wavform audio from the latent embeddings of the compression model
+        Args:
+            emb (torch.Tensor): Conditioning embeddinds
+            size (none torch.Size): size of the output
+                if None this is computed from the typical upsampling of the model
+            step_list (optional list[int]): list of Markov chain steps, defaults to 50 linearly spaced step.
+        """
+        if size is None:
+            upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
+            size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
+        assert size[0] == emb.size(0)
+        out = torch.zeros(size).to(self.device)
+        for DP in self.DPs:
+            out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
+        return out
+
+    def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
+        """match the eq to the encodec output by matching the standard deviation of some frequency bands
+        Args:
+            wav (torch.Tensor): audio to equalize
+            ref (torch.Tensor):refenrence audio from which we match the spectrogram.
+            n_bands (int): number of bands of the eq
+            strictness (float): how strict the the matching. 0 is no matching, 1 is exact matching.
+        """
+        split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
+        bands = split(wav)
+        bands_ref = split(ref)
+        out = torch.zeros_like(ref)
+        for i in range(n_bands):
+            out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
+        return out
+
+    def regenerate(self, wav: torch.Tensor, sample_rate: int):
+        """Regenerate a wavform through compression and diffusion regeneration.
+        Args:
+            wav (torch.Tensor): Original 'ground truth' audio
+            sample_rate (int): sample rate of the input (and output) wav
+        """
+        if sample_rate != self.codec_model.sample_rate:
+            wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
+        emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
+        size = wav.size()
+        out = self.generate(emb, size=size)
+        if sample_rate != self.codec_model.sample_rate:
+            out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
+        return out
+
+    def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
+        """Generate Waveform audio with diffusion from the discrete codes.
+        Args:
+            tokens (torch.Tensor): discrete codes
+            n_bands (int): bands for the eq matching.
+        """
+        wav_encodec = self.codec_model.decode(tokens)
+        condition = self.get_emb(tokens)
+        wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
+        return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class DiffusionProcess +(model: DiffusionUnet, noise_schedule: NoiseSchedule) +
+
+

Sampling for a diffusion Model.

+

Args

+
+
model : DiffusionUnet
+
Diffusion U-Net model.
+
noise_schedule : NoiseSchedule
+
Noise schedule for diffusion process.
+
+
+ +Expand source code + +
class DiffusionProcess:
+    """Sampling for a diffusion Model.
+
+    Args:
+        model (DiffusionUnet): Diffusion U-Net model.
+        noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
+    """
+    def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
+        """
+        """
+        self.model = model
+        self.schedule = noise_schedule
+
+    def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
+                 step_list: tp.Optional[tp.List[int]] = None):
+        """Perform one diffusion process to generate one of the bands.
+
+        Args:
+            condition (tensor): The embeddings form the compression model.
+            initial_noise (tensor): The initial noise to start the process/
+        """
+        return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
+                                                 condition=condition)
+
+

Methods

+
+
+def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor, step_list: Optional[List[int]] = None) +
+
+

Perform one diffusion process to generate one of the bands.

+

Args

+
+
condition : tensor
+
The embeddings form the compression model.
+
initial_noise : tensor
+
The initial noise to start the process/
+
+
+ +Expand source code + +
def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
+             step_list: tp.Optional[tp.List[int]] = None):
+    """Perform one diffusion process to generate one of the bands.
+
+    Args:
+        condition (tensor): The embeddings form the compression model.
+        initial_noise (tensor): The initial noise to start the process/
+    """
+    return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
+                                             condition=condition)
+
+
+
+
+
+class MultiBandDiffusion +(DPs: List[DiffusionProcess], codec_model: CompressionModel) +
+
+

Sample from multiple diffusion models.

+

Args

+
+
DPs : list of DiffusionProcess
+
Diffusion processes.
+
codec_model : CompressionModel
+
Underlying compression model used to obtain discrete tokens.
+
+
+ +Expand source code + +
class MultiBandDiffusion:
+    """Sample from multiple diffusion models.
+
+    Args:
+        DPs (list of DiffusionProcess): Diffusion processes.
+        codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
+    """
+    def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
+        self.DPs = DPs
+        self.codec_model = codec_model
+        self.device = next(self.codec_model.parameters()).device
+
+    @property
+    def sample_rate(self) -> int:
+        return self.codec_model.sample_rate
+
+    @staticmethod
+    def get_mbd_musicgen(device=None):
+        """Load our diffusion models trained for MusicGen."""
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        path = 'facebook/multiband-diffusion'
+        filename = 'mbd_musicgen_32khz.th'
+        name = 'facebook/musicgen-small'
+        codec_model = load_compression_model(name, device=device)
+        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
+        DPs = []
+        for i in range(len(models)):
+            schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
+            DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
+        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
+
+    @staticmethod
+    def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
+                      device: tp.Optional[tp.Union[torch.device, str]] = None,
+                      n_q: tp.Optional[int] = None):
+        """Get the pretrained Models for MultibandDiffusion.
+
+        Args:
+            bw (float): Bandwidth of the compression model.
+            pretrained (bool): Whether to use / download if necessary the models.
+            device (torch.device or str, optional): Device on which the models are loaded.
+            n_q (int, optional): Number of quantizers to use within the compression model.
+        """
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
+        if n_q is not None:
+            assert n_q in [2, 4, 8]
+            assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
+                f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
+        n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
+        codec_model = CompressionSolver.model_from_checkpoint(
+            '//pretrained/facebook/encodec_24khz', device=device)
+        codec_model.set_num_codebooks(n_q)
+        codec_model = codec_model.to(device)
+        path = 'facebook/multiband-diffusion'
+        filename = f'mbd_comp_{n_q}.pt'
+        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
+        DPs = []
+        for i in range(len(models)):
+            schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
+            DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
+        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
+
+        return MultiBandDiffusion(DPs, codec_model)
+
+    @torch.no_grad()
+    def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        """Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform.
+        Args:
+            wav (torch.Tensor): The audio that we want to extract the conditioning from
+            sample_rate (int): sample rate of the audio"""
+        if sample_rate != self.sample_rate:
+            wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
+        codes, scale = self.codec_model.encode(wav)
+        assert scale is None, "Scaled compression models not supported."
+        emb = self.get_emb(codes)
+        return emb
+
+    @torch.no_grad()
+    def get_emb(self, codes: torch.Tensor):
+        """Get latent representation from the discrete codes
+        Argrs:
+            codes (torch.Tensor): discrete tokens"""
+        emb = self.codec_model.decode_latent(codes)
+        return emb
+
+    def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
+                 step_list: tp.Optional[tp.List[int]] = None):
+        """Generate Wavform audio from the latent embeddings of the compression model
+        Args:
+            emb (torch.Tensor): Conditioning embeddinds
+            size (none torch.Size): size of the output
+                if None this is computed from the typical upsampling of the model
+            step_list (optional list[int]): list of Markov chain steps, defaults to 50 linearly spaced step.
+        """
+        if size is None:
+            upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
+            size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
+        assert size[0] == emb.size(0)
+        out = torch.zeros(size).to(self.device)
+        for DP in self.DPs:
+            out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
+        return out
+
+    def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
+        """match the eq to the encodec output by matching the standard deviation of some frequency bands
+        Args:
+            wav (torch.Tensor): audio to equalize
+            ref (torch.Tensor):refenrence audio from which we match the spectrogram.
+            n_bands (int): number of bands of the eq
+            strictness (float): how strict the the matching. 0 is no matching, 1 is exact matching.
+        """
+        split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
+        bands = split(wav)
+        bands_ref = split(ref)
+        out = torch.zeros_like(ref)
+        for i in range(n_bands):
+            out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
+        return out
+
+    def regenerate(self, wav: torch.Tensor, sample_rate: int):
+        """Regenerate a wavform through compression and diffusion regeneration.
+        Args:
+            wav (torch.Tensor): Original 'ground truth' audio
+            sample_rate (int): sample rate of the input (and output) wav
+        """
+        if sample_rate != self.codec_model.sample_rate:
+            wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
+        emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
+        size = wav.size()
+        out = self.generate(emb, size=size)
+        if sample_rate != self.codec_model.sample_rate:
+            out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
+        return out
+
+    def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
+        """Generate Waveform audio with diffusion from the discrete codes.
+        Args:
+            tokens (torch.Tensor): discrete codes
+            n_bands (int): bands for the eq matching.
+        """
+        wav_encodec = self.codec_model.decode(tokens)
+        condition = self.get_emb(tokens)
+        wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
+        return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
+
+

Static methods

+
+
+def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True, device: Union[str, torch.device, None] = None, n_q: Optional[int] = None) +
+
+

Get the pretrained Models for MultibandDiffusion.

+

Args

+
+
bw : float
+
Bandwidth of the compression model.
+
pretrained : bool
+
Whether to use / download if necessary the models.
+
device : torch.device or str, optional
+
Device on which the models are loaded.
+
n_q : int, optional
+
Number of quantizers to use within the compression model.
+
+
+ +Expand source code + +
@staticmethod
+def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
+                  device: tp.Optional[tp.Union[torch.device, str]] = None,
+                  n_q: tp.Optional[int] = None):
+    """Get the pretrained Models for MultibandDiffusion.
+
+    Args:
+        bw (float): Bandwidth of the compression model.
+        pretrained (bool): Whether to use / download if necessary the models.
+        device (torch.device or str, optional): Device on which the models are loaded.
+        n_q (int, optional): Number of quantizers to use within the compression model.
+    """
+    if device is None:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
+    if n_q is not None:
+        assert n_q in [2, 4, 8]
+        assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
+            f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
+    n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
+    codec_model = CompressionSolver.model_from_checkpoint(
+        '//pretrained/facebook/encodec_24khz', device=device)
+    codec_model.set_num_codebooks(n_q)
+    codec_model = codec_model.to(device)
+    path = 'facebook/multiband-diffusion'
+    filename = f'mbd_comp_{n_q}.pt'
+    models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
+    DPs = []
+    for i in range(len(models)):
+        schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
+        DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
+    return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
+
+    return MultiBandDiffusion(DPs, codec_model)
+
+
+
+def get_mbd_musicgen(device=None) +
+
+

Load our diffusion models trained for MusicGen.

+
+ +Expand source code + +
@staticmethod
+def get_mbd_musicgen(device=None):
+    """Load our diffusion models trained for MusicGen."""
+    if device is None:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    path = 'facebook/multiband-diffusion'
+    filename = 'mbd_musicgen_32khz.th'
+    name = 'facebook/musicgen-small'
+    codec_model = load_compression_model(name, device=device)
+    models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
+    DPs = []
+    for i in range(len(models)):
+        schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
+        DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
+    return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
+
+
+
+

Instance variables

+
+
var sample_rate : int
+
+
+
+ +Expand source code + +
@property
+def sample_rate(self) -> int:
+    return self.codec_model.sample_rate
+
+
+
+

Methods

+
+
+def generate(self, emb: torch.Tensor, size: Optional[torch.Size] = None, step_list: Optional[List[int]] = None) +
+
+

Generate Wavform audio from the latent embeddings of the compression model

+

Args

+
+
emb : torch.Tensor
+
Conditioning embeddinds
+
size : none torch.Size
+
size of the output +if None this is computed from the typical upsampling of the model
+
step_list : optional list[int]
+
list of Markov chain steps, defaults to 50 linearly spaced step.
+
+
+ +Expand source code + +
def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
+             step_list: tp.Optional[tp.List[int]] = None):
+    """Generate Wavform audio from the latent embeddings of the compression model
+    Args:
+        emb (torch.Tensor): Conditioning embeddinds
+        size (none torch.Size): size of the output
+            if None this is computed from the typical upsampling of the model
+        step_list (optional list[int]): list of Markov chain steps, defaults to 50 linearly spaced step.
+    """
+    if size is None:
+        upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
+        size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
+    assert size[0] == emb.size(0)
+    out = torch.zeros(size).to(self.device)
+    for DP in self.DPs:
+        out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
+    return out
+
+
+
+def get_condition(self, wav: torch.Tensor, sample_rate: int) ‑> torch.Tensor +
+
+

Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform.

+

Args

+
+
wav : torch.Tensor
+
The audio that we want to extract the conditioning from
+
sample_rate : int
+
sample rate of the audio
+
+
+ +Expand source code + +
@torch.no_grad()
+def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+    """Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform.
+    Args:
+        wav (torch.Tensor): The audio that we want to extract the conditioning from
+        sample_rate (int): sample rate of the audio"""
+    if sample_rate != self.sample_rate:
+        wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
+    codes, scale = self.codec_model.encode(wav)
+    assert scale is None, "Scaled compression models not supported."
+    emb = self.get_emb(codes)
+    return emb
+
+
+
+def get_emb(self, codes: torch.Tensor) +
+
+

Get latent representation from the discrete codes

+

Argrs

+

codes (torch.Tensor): discrete tokens

+
+ +Expand source code + +
@torch.no_grad()
+def get_emb(self, codes: torch.Tensor):
+    """Get latent representation from the discrete codes
+    Argrs:
+        codes (torch.Tensor): discrete tokens"""
+    emb = self.codec_model.decode_latent(codes)
+    return emb
+
+
+
+def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1) +
+
+

match the eq to the encodec output by matching the standard deviation of some frequency bands

+

Args

+
+
wav : torch.Tensor
+
audio to equalize
+
ref (torch.Tensor):refenrence audio from which we match the spectrogram.
+
n_bands : int
+
number of bands of the eq
+
strictness : float
+
how strict the the matching. 0 is no matching, 1 is exact matching.
+
+
+ +Expand source code + +
def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
+    """match the eq to the encodec output by matching the standard deviation of some frequency bands
+    Args:
+        wav (torch.Tensor): audio to equalize
+        ref (torch.Tensor):refenrence audio from which we match the spectrogram.
+        n_bands (int): number of bands of the eq
+        strictness (float): how strict the the matching. 0 is no matching, 1 is exact matching.
+    """
+    split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
+    bands = split(wav)
+    bands_ref = split(ref)
+    out = torch.zeros_like(ref)
+    for i in range(n_bands):
+        out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
+    return out
+
+
+
+def regenerate(self, wav: torch.Tensor, sample_rate: int) +
+
+

Regenerate a wavform through compression and diffusion regeneration.

+

Args

+
+
wav : torch.Tensor
+
Original 'ground truth' audio
+
sample_rate : int
+
sample rate of the input (and output) wav
+
+
+ +Expand source code + +
def regenerate(self, wav: torch.Tensor, sample_rate: int):
+    """Regenerate a wavform through compression and diffusion regeneration.
+    Args:
+        wav (torch.Tensor): Original 'ground truth' audio
+        sample_rate (int): sample rate of the input (and output) wav
+    """
+    if sample_rate != self.codec_model.sample_rate:
+        wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
+    emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
+    size = wav.size()
+    out = self.generate(emb, size=size)
+    if sample_rate != self.codec_model.sample_rate:
+        out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
+    return out
+
+
+
+def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32) +
+
+

Generate Waveform audio with diffusion from the discrete codes.

+

Args

+
+
tokens : torch.Tensor
+
discrete codes
+
n_bands : int
+
bands for the eq matching.
+
+
+ +Expand source code + +
def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
+    """Generate Waveform audio with diffusion from the discrete codes.
+    Args:
+        tokens (torch.Tensor): discrete codes
+        n_bands (int): bands for the eq matching.
+    """
+    wav_encodec = self.codec_model.decode(tokens)
+    condition = self.get_emb(tokens)
+    wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
+    return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/models/musicgen.html b/api_docs/audiocraft/models/musicgen.html new file mode 100644 index 00000000..58456e41 --- /dev/null +++ b/api_docs/audiocraft/models/musicgen.html @@ -0,0 +1,1266 @@ + + + + + + +audiocraft.models.musicgen API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.musicgen

+
+
+

Main model for using MusicGen. This will combine all the required components +and provide easy access to the generation API.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Main model for using MusicGen. This will combine all the required components
+and provide easy access to the generation API.
+"""
+
+import typing as tp
+import warnings
+
+import torch
+
+from .encodec import CompressionModel
+from .lm import LMModel
+from .builders import get_debug_compression_model, get_debug_lm_model
+from .loaders import load_compression_model, load_lm_model
+from ..data.audio_utils import convert_audio
+from ..modules.conditioners import ConditioningAttributes, WavCondition
+from ..utils.autocast import TorchAutocast
+
+
+MelodyList = tp.List[tp.Optional[torch.Tensor]]
+MelodyType = tp.Union[torch.Tensor, MelodyList]
+
+
+# backward compatible names mapping
+_HF_MODEL_CHECKPOINTS_MAP = {
+    "small": "facebook/musicgen-small",
+    "medium": "facebook/musicgen-medium",
+    "large": "facebook/musicgen-large",
+    "melody": "facebook/musicgen-melody",
+}
+
+
+class MusicGen:
+    """MusicGen main model with convenient generation API.
+
+    Args:
+        name (str): name of the model.
+        compression_model (CompressionModel): Compression model
+            used to map audio to invertible discrete representations.
+        lm (LMModel): Language model over discrete representations.
+        max_duration (float, optional): maximum duration the model can produce,
+            otherwise, inferred from the training params.
+    """
+    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
+                 max_duration: tp.Optional[float] = None):
+        self.name = name
+        self.compression_model = compression_model
+        self.lm = lm
+        # Just to be safe, let's put everything in eval mode.
+        self.compression_model.eval()
+        self.lm.eval()
+
+        if max_duration is None:
+            if hasattr(lm, 'cfg'):
+                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
+            else:
+                raise ValueError("You must provide max_duration when building directly MusicGen")
+        assert max_duration is not None
+        self.max_duration: float = max_duration
+        self.device = next(iter(lm.parameters())).device
+        self.generation_params: dict = {}
+        self.set_generation_params(duration=15)  # 15 seconds by default
+        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+        if self.device.type == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+        else:
+            self.autocast = TorchAutocast(
+                enabled=True, device_type=self.device.type, dtype=torch.float16)
+
+    @property
+    def frame_rate(self) -> float:
+        """Roughly the number of AR steps per seconds."""
+        return self.compression_model.frame_rate
+
+    @property
+    def sample_rate(self) -> int:
+        """Sample rate of the generated audio."""
+        return self.compression_model.sample_rate
+
+    @property
+    def audio_channels(self) -> int:
+        """Audio channels of the generated audio."""
+        return self.compression_model.channels
+
+    @staticmethod
+    def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
+        """Return pretrained model, we provide four models:
+        - facebook/musicgen-small (300M), text to music,
+          # see: https://huggingface.co/facebook/musicgen-small
+        - facebook/musicgen-medium (1.5B), text to music,
+          # see: https://huggingface.co/facebook/musicgen-medium
+        - facebook/musicgen-melody (1.5B) text to music and text+melody to music,
+          # see: https://huggingface.co/facebook/musicgen-melody
+        - facebook/musicgen-large (3.3B), text to music,
+          # see: https://huggingface.co/facebook/musicgen-large
+        """
+        if device is None:
+            if torch.cuda.device_count():
+                device = 'cuda'
+            else:
+                device = 'cpu'
+
+        if name == 'debug':
+            # used only for unit tests
+            compression_model = get_debug_compression_model(device)
+            lm = get_debug_lm_model(device)
+            return MusicGen(name, compression_model, lm, max_duration=30)
+
+        if name in _HF_MODEL_CHECKPOINTS_MAP:
+            warnings.warn(
+                "MusicGen pretrained model relying on deprecated checkpoint mapping. " +
+                f"Please use full pre-trained id instead: facebook/musicgen-{name}")
+            name = _HF_MODEL_CHECKPOINTS_MAP[name]
+
+        lm = load_lm_model(name, device=device)
+        compression_model = load_compression_model(name, device=device)
+        if 'self_wav' in lm.condition_provider.conditioners:
+            lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
+
+        return MusicGen(name, compression_model, lm)
+
+    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+                              top_p: float = 0.0, temperature: float = 1.0,
+                              duration: float = 30.0, cfg_coef: float = 3.0,
+                              two_step_cfg: bool = False, extend_stride: float = 18):
+        """Set the generation parameters for MusicGen.
+
+        Args:
+            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+            top_k (int, optional): top_k used for sampling. Defaults to 250.
+            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+            duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
+            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+                instead of batching together the two. This has some impact on how things
+                are padded but seems to have little impact in practice.
+            extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
+                should we extend the audio each time. Larger values will mean less context is
+                preserved, and shorter value will require extra computations.
+        """
+        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+        self.extend_stride = extend_stride
+        self.duration = duration
+        self.generation_params = {
+            'use_sampling': use_sampling,
+            'temp': temperature,
+            'top_k': top_k,
+            'top_p': top_p,
+            'cfg_coef': cfg_coef,
+            'two_step_cfg': two_step_cfg,
+        }
+
+    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+        """Override the default progress callback."""
+        self._progress_callback = progress_callback
+
+    def generate_unconditional(self, num_samples: int, progress: bool = False,
+                               return_tokens: bool = False) -> tp.Union[torch.Tensor,
+                                                                        tp.Tuple[torch.Tensor, torch.Tensor]]:
+        """Generate samples in an unconditional manner.
+
+        Args:
+            num_samples (int): Number of samples to be generated.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        if return_tokens:
+            return self.generate_audio(tokens), tokens
+        return self.generate_audio(tokens)
+
+    def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
+            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
+        """Generate samples conditioned on text.
+
+        Args:
+            descriptions (list of str): A list of strings used as text conditioning.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+        assert prompt_tokens is None
+        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        if return_tokens:
+            return self.generate_audio(tokens), tokens
+        return self.generate_audio(tokens)
+
+    def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
+                             melody_sample_rate: int, progress: bool = False,
+                             return_tokens: bool = False) -> tp.Union[torch.Tensor,
+                                                                      tp.Tuple[torch.Tensor, torch.Tensor]]:
+        """Generate samples conditioned on text and melody.
+
+        Args:
+            descriptions (list of str): A list of strings used as text conditioning.
+            melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
+                melody conditioning. Should have shape [B, C, T] with B matching the description length,
+                C=1 or 2. It can be [C, T] if there is a single description. It can also be
+                a list of [C, T] tensors.
+            melody_sample_rate: (int): Sample rate of the melody waveforms.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        if isinstance(melody_wavs, torch.Tensor):
+            if melody_wavs.dim() == 2:
+                melody_wavs = melody_wavs[None]
+            if melody_wavs.dim() != 3:
+                raise ValueError("Melody wavs should have a shape [B, C, T].")
+            melody_wavs = list(melody_wavs)
+        else:
+            for melody in melody_wavs:
+                if melody is not None:
+                    assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
+
+        melody_wavs = [
+            convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
+            if wav is not None else None
+            for wav in melody_wavs]
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
+                                                                        melody_wavs=melody_wavs)
+        assert prompt_tokens is None
+        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        if return_tokens:
+            return self.generate_audio(tokens), tokens
+        return self.generate_audio(tokens)
+
+    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+                              progress: bool = False, return_tokens: bool = False) \
+            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
+        """Generate samples conditioned on audio prompts.
+
+        Args:
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+            descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        if prompt.dim() == 2:
+            prompt = prompt[None]
+        if prompt.dim() != 3:
+            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+        if descriptions is None:
+            descriptions = [None] * len(prompt)
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+        assert prompt_tokens is not None
+        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        if return_tokens:
+            return self.generate_audio(tokens), tokens
+        return self.generate_audio(tokens)
+
+    @torch.no_grad()
+    def _prepare_tokens_and_attributes(
+            self,
+            descriptions: tp.Sequence[tp.Optional[str]],
+            prompt: tp.Optional[torch.Tensor],
+            melody_wavs: tp.Optional[MelodyList] = None,
+    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
+        """Prepare model inputs.
+
+        Args:
+            descriptions (list of str): A list of strings used as text conditioning.
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+            melody_wavs (torch.Tensor, optional): A batch of waveforms
+                used as melody conditioning. Defaults to None.
+        """
+        attributes = [
+            ConditioningAttributes(text={'description': description})
+            for description in descriptions]
+
+        if melody_wavs is None:
+            for attr in attributes:
+                attr.wav['self_wav'] = WavCondition(
+                    torch.zeros((1, 1, 1), device=self.device),
+                    torch.tensor([0], device=self.device),
+                    sample_rate=[self.sample_rate],
+                    path=[None])
+        else:
+            if 'self_wav' not in self.lm.condition_provider.conditioners:
+                raise RuntimeError("This model doesn't support melody conditioning. "
+                                   "Use the `melody` model.")
+            assert len(melody_wavs) == len(descriptions), \
+                f"number of melody wavs must match number of descriptions! " \
+                f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
+            for attr, melody in zip(attributes, melody_wavs):
+                if melody is None:
+                    attr.wav['self_wav'] = WavCondition(
+                        torch.zeros((1, 1, 1), device=self.device),
+                        torch.tensor([0], device=self.device),
+                        sample_rate=[self.sample_rate],
+                        path=[None])
+                else:
+                    attr.wav['self_wav'] = WavCondition(
+                        melody[None].to(device=self.device),
+                        torch.tensor([melody.shape[-1]], device=self.device),
+                        sample_rate=[self.sample_rate],
+                        path=[None],
+                    )
+
+        if prompt is not None:
+            if descriptions is not None:
+                assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
+            prompt = prompt.to(self.device)
+            prompt_tokens, scale = self.compression_model.encode(prompt)
+            assert scale is None
+        else:
+            prompt_tokens = None
+        return attributes, prompt_tokens
+
+    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
+                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+        """Generate discrete audio tokens given audio prompt and/or conditions.
+
+        Args:
+            attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
+            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        Returns:
+            torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
+        """
+        total_gen_len = int(self.duration * self.frame_rate)
+        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
+        current_gen_offset: int = 0
+
+        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
+            generated_tokens += current_gen_offset
+            if self._progress_callback is not None:
+                # Note that total_gen_len might be quite wrong depending on the
+                # codebook pattern used, but with delay it is almost accurate.
+                self._progress_callback(generated_tokens, total_gen_len)
+            else:
+                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
+
+        if prompt_tokens is not None:
+            assert max_prompt_len >= prompt_tokens.shape[-1], \
+                "Prompt is longer than audio to generate"
+
+        callback = None
+        if progress:
+            callback = _progress_callback
+
+        if self.duration <= self.max_duration:
+            # generate by sampling from LM, simple case.
+            with self.autocast:
+                gen_tokens = self.lm.generate(
+                    prompt_tokens, attributes,
+                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
+
+        else:
+            # now this gets a bit messier, we need to handle prompts,
+            # melody conditioning etc.
+            ref_wavs = [attr.wav['self_wav'] for attr in attributes]
+            all_tokens = []
+            if prompt_tokens is None:
+                prompt_length = 0
+            else:
+                all_tokens.append(prompt_tokens)
+                prompt_length = prompt_tokens.shape[-1]
+
+            stride_tokens = int(self.frame_rate * self.extend_stride)
+
+            while current_gen_offset + prompt_length < total_gen_len:
+                time_offset = current_gen_offset / self.frame_rate
+                chunk_duration = min(self.duration - time_offset, self.max_duration)
+                max_gen_len = int(chunk_duration * self.frame_rate)
+                for attr, ref_wav in zip(attributes, ref_wavs):
+                    wav_length = ref_wav.length.item()
+                    if wav_length == 0:
+                        continue
+                    # We will extend the wav periodically if it not long enough.
+                    # we have to do it here rather than in conditioners.py as otherwise
+                    # we wouldn't have the full wav.
+                    initial_position = int(time_offset * self.sample_rate)
+                    wav_target_length = int(self.max_duration * self.sample_rate)
+                    positions = torch.arange(initial_position,
+                                             initial_position + wav_target_length, device=self.device)
+                    attr.wav['self_wav'] = WavCondition(
+                        ref_wav[0][..., positions % wav_length],
+                        torch.full_like(ref_wav[1], wav_target_length),
+                        [self.sample_rate] * ref_wav[0].size(0),
+                        [None], [0.])
+                with self.autocast:
+                    gen_tokens = self.lm.generate(
+                        prompt_tokens, attributes,
+                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
+                if prompt_tokens is None:
+                    all_tokens.append(gen_tokens)
+                else:
+                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
+                prompt_tokens = gen_tokens[:, :, stride_tokens:]
+                prompt_length = prompt_tokens.shape[-1]
+                current_gen_offset += stride_tokens
+
+            gen_tokens = torch.cat(all_tokens, dim=-1)
+        return gen_tokens
+
+    def generate_audio(self, gen_tokens: torch.Tensor):
+        """Generate Audio from tokens"""
+        assert gen_tokens.dim() == 3
+        with torch.no_grad():
+            gen_audio = self.compression_model.decode(gen_tokens, None)
+        return gen_audio
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class MusicGen +(name: str, compression_model: CompressionModel, lm: LMModel, max_duration: Optional[float] = None) +
+
+

MusicGen main model with convenient generation API.

+

Args

+
+
name : str
+
name of the model.
+
compression_model : CompressionModel
+
Compression model +used to map audio to invertible discrete representations.
+
lm : LMModel
+
Language model over discrete representations.
+
max_duration : float, optional
+
maximum duration the model can produce, +otherwise, inferred from the training params.
+
+
+ +Expand source code + +
class MusicGen:
+    """MusicGen main model with convenient generation API.
+
+    Args:
+        name (str): name of the model.
+        compression_model (CompressionModel): Compression model
+            used to map audio to invertible discrete representations.
+        lm (LMModel): Language model over discrete representations.
+        max_duration (float, optional): maximum duration the model can produce,
+            otherwise, inferred from the training params.
+    """
+    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
+                 max_duration: tp.Optional[float] = None):
+        self.name = name
+        self.compression_model = compression_model
+        self.lm = lm
+        # Just to be safe, let's put everything in eval mode.
+        self.compression_model.eval()
+        self.lm.eval()
+
+        if max_duration is None:
+            if hasattr(lm, 'cfg'):
+                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
+            else:
+                raise ValueError("You must provide max_duration when building directly MusicGen")
+        assert max_duration is not None
+        self.max_duration: float = max_duration
+        self.device = next(iter(lm.parameters())).device
+        self.generation_params: dict = {}
+        self.set_generation_params(duration=15)  # 15 seconds by default
+        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+        if self.device.type == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+        else:
+            self.autocast = TorchAutocast(
+                enabled=True, device_type=self.device.type, dtype=torch.float16)
+
+    @property
+    def frame_rate(self) -> float:
+        """Roughly the number of AR steps per seconds."""
+        return self.compression_model.frame_rate
+
+    @property
+    def sample_rate(self) -> int:
+        """Sample rate of the generated audio."""
+        return self.compression_model.sample_rate
+
+    @property
+    def audio_channels(self) -> int:
+        """Audio channels of the generated audio."""
+        return self.compression_model.channels
+
+    @staticmethod
+    def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
+        """Return pretrained model, we provide four models:
+        - facebook/musicgen-small (300M), text to music,
+          # see: https://huggingface.co/facebook/musicgen-small
+        - facebook/musicgen-medium (1.5B), text to music,
+          # see: https://huggingface.co/facebook/musicgen-medium
+        - facebook/musicgen-melody (1.5B) text to music and text+melody to music,
+          # see: https://huggingface.co/facebook/musicgen-melody
+        - facebook/musicgen-large (3.3B), text to music,
+          # see: https://huggingface.co/facebook/musicgen-large
+        """
+        if device is None:
+            if torch.cuda.device_count():
+                device = 'cuda'
+            else:
+                device = 'cpu'
+
+        if name == 'debug':
+            # used only for unit tests
+            compression_model = get_debug_compression_model(device)
+            lm = get_debug_lm_model(device)
+            return MusicGen(name, compression_model, lm, max_duration=30)
+
+        if name in _HF_MODEL_CHECKPOINTS_MAP:
+            warnings.warn(
+                "MusicGen pretrained model relying on deprecated checkpoint mapping. " +
+                f"Please use full pre-trained id instead: facebook/musicgen-{name}")
+            name = _HF_MODEL_CHECKPOINTS_MAP[name]
+
+        lm = load_lm_model(name, device=device)
+        compression_model = load_compression_model(name, device=device)
+        if 'self_wav' in lm.condition_provider.conditioners:
+            lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
+
+        return MusicGen(name, compression_model, lm)
+
+    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+                              top_p: float = 0.0, temperature: float = 1.0,
+                              duration: float = 30.0, cfg_coef: float = 3.0,
+                              two_step_cfg: bool = False, extend_stride: float = 18):
+        """Set the generation parameters for MusicGen.
+
+        Args:
+            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+            top_k (int, optional): top_k used for sampling. Defaults to 250.
+            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+            duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
+            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+                instead of batching together the two. This has some impact on how things
+                are padded but seems to have little impact in practice.
+            extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
+                should we extend the audio each time. Larger values will mean less context is
+                preserved, and shorter value will require extra computations.
+        """
+        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+        self.extend_stride = extend_stride
+        self.duration = duration
+        self.generation_params = {
+            'use_sampling': use_sampling,
+            'temp': temperature,
+            'top_k': top_k,
+            'top_p': top_p,
+            'cfg_coef': cfg_coef,
+            'two_step_cfg': two_step_cfg,
+        }
+
+    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+        """Override the default progress callback."""
+        self._progress_callback = progress_callback
+
+    def generate_unconditional(self, num_samples: int, progress: bool = False,
+                               return_tokens: bool = False) -> tp.Union[torch.Tensor,
+                                                                        tp.Tuple[torch.Tensor, torch.Tensor]]:
+        """Generate samples in an unconditional manner.
+
+        Args:
+            num_samples (int): Number of samples to be generated.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        if return_tokens:
+            return self.generate_audio(tokens), tokens
+        return self.generate_audio(tokens)
+
+    def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
+            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
+        """Generate samples conditioned on text.
+
+        Args:
+            descriptions (list of str): A list of strings used as text conditioning.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+        assert prompt_tokens is None
+        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        if return_tokens:
+            return self.generate_audio(tokens), tokens
+        return self.generate_audio(tokens)
+
+    def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
+                             melody_sample_rate: int, progress: bool = False,
+                             return_tokens: bool = False) -> tp.Union[torch.Tensor,
+                                                                      tp.Tuple[torch.Tensor, torch.Tensor]]:
+        """Generate samples conditioned on text and melody.
+
+        Args:
+            descriptions (list of str): A list of strings used as text conditioning.
+            melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
+                melody conditioning. Should have shape [B, C, T] with B matching the description length,
+                C=1 or 2. It can be [C, T] if there is a single description. It can also be
+                a list of [C, T] tensors.
+            melody_sample_rate: (int): Sample rate of the melody waveforms.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        if isinstance(melody_wavs, torch.Tensor):
+            if melody_wavs.dim() == 2:
+                melody_wavs = melody_wavs[None]
+            if melody_wavs.dim() != 3:
+                raise ValueError("Melody wavs should have a shape [B, C, T].")
+            melody_wavs = list(melody_wavs)
+        else:
+            for melody in melody_wavs:
+                if melody is not None:
+                    assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
+
+        melody_wavs = [
+            convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
+            if wav is not None else None
+            for wav in melody_wavs]
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
+                                                                        melody_wavs=melody_wavs)
+        assert prompt_tokens is None
+        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        if return_tokens:
+            return self.generate_audio(tokens), tokens
+        return self.generate_audio(tokens)
+
+    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+                              progress: bool = False, return_tokens: bool = False) \
+            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
+        """Generate samples conditioned on audio prompts.
+
+        Args:
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+            descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        if prompt.dim() == 2:
+            prompt = prompt[None]
+        if prompt.dim() != 3:
+            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+        if descriptions is None:
+            descriptions = [None] * len(prompt)
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+        assert prompt_tokens is not None
+        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        if return_tokens:
+            return self.generate_audio(tokens), tokens
+        return self.generate_audio(tokens)
+
+    @torch.no_grad()
+    def _prepare_tokens_and_attributes(
+            self,
+            descriptions: tp.Sequence[tp.Optional[str]],
+            prompt: tp.Optional[torch.Tensor],
+            melody_wavs: tp.Optional[MelodyList] = None,
+    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
+        """Prepare model inputs.
+
+        Args:
+            descriptions (list of str): A list of strings used as text conditioning.
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+            melody_wavs (torch.Tensor, optional): A batch of waveforms
+                used as melody conditioning. Defaults to None.
+        """
+        attributes = [
+            ConditioningAttributes(text={'description': description})
+            for description in descriptions]
+
+        if melody_wavs is None:
+            for attr in attributes:
+                attr.wav['self_wav'] = WavCondition(
+                    torch.zeros((1, 1, 1), device=self.device),
+                    torch.tensor([0], device=self.device),
+                    sample_rate=[self.sample_rate],
+                    path=[None])
+        else:
+            if 'self_wav' not in self.lm.condition_provider.conditioners:
+                raise RuntimeError("This model doesn't support melody conditioning. "
+                                   "Use the `melody` model.")
+            assert len(melody_wavs) == len(descriptions), \
+                f"number of melody wavs must match number of descriptions! " \
+                f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
+            for attr, melody in zip(attributes, melody_wavs):
+                if melody is None:
+                    attr.wav['self_wav'] = WavCondition(
+                        torch.zeros((1, 1, 1), device=self.device),
+                        torch.tensor([0], device=self.device),
+                        sample_rate=[self.sample_rate],
+                        path=[None])
+                else:
+                    attr.wav['self_wav'] = WavCondition(
+                        melody[None].to(device=self.device),
+                        torch.tensor([melody.shape[-1]], device=self.device),
+                        sample_rate=[self.sample_rate],
+                        path=[None],
+                    )
+
+        if prompt is not None:
+            if descriptions is not None:
+                assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
+            prompt = prompt.to(self.device)
+            prompt_tokens, scale = self.compression_model.encode(prompt)
+            assert scale is None
+        else:
+            prompt_tokens = None
+        return attributes, prompt_tokens
+
+    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
+                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+        """Generate discrete audio tokens given audio prompt and/or conditions.
+
+        Args:
+            attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
+            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        Returns:
+            torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
+        """
+        total_gen_len = int(self.duration * self.frame_rate)
+        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
+        current_gen_offset: int = 0
+
+        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
+            generated_tokens += current_gen_offset
+            if self._progress_callback is not None:
+                # Note that total_gen_len might be quite wrong depending on the
+                # codebook pattern used, but with delay it is almost accurate.
+                self._progress_callback(generated_tokens, total_gen_len)
+            else:
+                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
+
+        if prompt_tokens is not None:
+            assert max_prompt_len >= prompt_tokens.shape[-1], \
+                "Prompt is longer than audio to generate"
+
+        callback = None
+        if progress:
+            callback = _progress_callback
+
+        if self.duration <= self.max_duration:
+            # generate by sampling from LM, simple case.
+            with self.autocast:
+                gen_tokens = self.lm.generate(
+                    prompt_tokens, attributes,
+                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
+
+        else:
+            # now this gets a bit messier, we need to handle prompts,
+            # melody conditioning etc.
+            ref_wavs = [attr.wav['self_wav'] for attr in attributes]
+            all_tokens = []
+            if prompt_tokens is None:
+                prompt_length = 0
+            else:
+                all_tokens.append(prompt_tokens)
+                prompt_length = prompt_tokens.shape[-1]
+
+            stride_tokens = int(self.frame_rate * self.extend_stride)
+
+            while current_gen_offset + prompt_length < total_gen_len:
+                time_offset = current_gen_offset / self.frame_rate
+                chunk_duration = min(self.duration - time_offset, self.max_duration)
+                max_gen_len = int(chunk_duration * self.frame_rate)
+                for attr, ref_wav in zip(attributes, ref_wavs):
+                    wav_length = ref_wav.length.item()
+                    if wav_length == 0:
+                        continue
+                    # We will extend the wav periodically if it not long enough.
+                    # we have to do it here rather than in conditioners.py as otherwise
+                    # we wouldn't have the full wav.
+                    initial_position = int(time_offset * self.sample_rate)
+                    wav_target_length = int(self.max_duration * self.sample_rate)
+                    positions = torch.arange(initial_position,
+                                             initial_position + wav_target_length, device=self.device)
+                    attr.wav['self_wav'] = WavCondition(
+                        ref_wav[0][..., positions % wav_length],
+                        torch.full_like(ref_wav[1], wav_target_length),
+                        [self.sample_rate] * ref_wav[0].size(0),
+                        [None], [0.])
+                with self.autocast:
+                    gen_tokens = self.lm.generate(
+                        prompt_tokens, attributes,
+                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
+                if prompt_tokens is None:
+                    all_tokens.append(gen_tokens)
+                else:
+                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
+                prompt_tokens = gen_tokens[:, :, stride_tokens:]
+                prompt_length = prompt_tokens.shape[-1]
+                current_gen_offset += stride_tokens
+
+            gen_tokens = torch.cat(all_tokens, dim=-1)
+        return gen_tokens
+
+    def generate_audio(self, gen_tokens: torch.Tensor):
+        """Generate Audio from tokens"""
+        assert gen_tokens.dim() == 3
+        with torch.no_grad():
+            gen_audio = self.compression_model.decode(gen_tokens, None)
+        return gen_audio
+
+

Static methods

+
+
+def get_pretrained(name: str = 'facebook/musicgen-melody', device=None) +
+
+

Return pretrained model, we provide four models: +- facebook/musicgen-small (300M), text to music, +# see: https://huggingface.co/facebook/musicgen-small +- facebook/musicgen-medium (1.5B), text to music, +# see: https://huggingface.co/facebook/musicgen-medium +- facebook/musicgen-melody (1.5B) text to music and text+melody to music, +# see: https://huggingface.co/facebook/musicgen-melody +- facebook/musicgen-large (3.3B), text to music, +# see: https://huggingface.co/facebook/musicgen-large

+
+ +Expand source code + +
@staticmethod
+def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
+    """Return pretrained model, we provide four models:
+    - facebook/musicgen-small (300M), text to music,
+      # see: https://huggingface.co/facebook/musicgen-small
+    - facebook/musicgen-medium (1.5B), text to music,
+      # see: https://huggingface.co/facebook/musicgen-medium
+    - facebook/musicgen-melody (1.5B) text to music and text+melody to music,
+      # see: https://huggingface.co/facebook/musicgen-melody
+    - facebook/musicgen-large (3.3B), text to music,
+      # see: https://huggingface.co/facebook/musicgen-large
+    """
+    if device is None:
+        if torch.cuda.device_count():
+            device = 'cuda'
+        else:
+            device = 'cpu'
+
+    if name == 'debug':
+        # used only for unit tests
+        compression_model = get_debug_compression_model(device)
+        lm = get_debug_lm_model(device)
+        return MusicGen(name, compression_model, lm, max_duration=30)
+
+    if name in _HF_MODEL_CHECKPOINTS_MAP:
+        warnings.warn(
+            "MusicGen pretrained model relying on deprecated checkpoint mapping. " +
+            f"Please use full pre-trained id instead: facebook/musicgen-{name}")
+        name = _HF_MODEL_CHECKPOINTS_MAP[name]
+
+    lm = load_lm_model(name, device=device)
+    compression_model = load_compression_model(name, device=device)
+    if 'self_wav' in lm.condition_provider.conditioners:
+        lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
+
+    return MusicGen(name, compression_model, lm)
+
+
+
+

Instance variables

+
+
var audio_channels : int
+
+

Audio channels of the generated audio.

+
+ +Expand source code + +
@property
+def audio_channels(self) -> int:
+    """Audio channels of the generated audio."""
+    return self.compression_model.channels
+
+
+
var frame_rate : float
+
+

Roughly the number of AR steps per seconds.

+
+ +Expand source code + +
@property
+def frame_rate(self) -> float:
+    """Roughly the number of AR steps per seconds."""
+    return self.compression_model.frame_rate
+
+
+
var sample_rate : int
+
+

Sample rate of the generated audio.

+
+ +Expand source code + +
@property
+def sample_rate(self) -> int:
+    """Sample rate of the generated audio."""
+    return self.compression_model.sample_rate
+
+
+
+

Methods

+
+
+def generate(self, descriptions: List[str], progress: bool = False, return_tokens: bool = False) ‑> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] +
+
+

Generate samples conditioned on text.

+

Args

+
+
descriptions : list of str
+
A list of strings used as text conditioning.
+
progress : bool, optional
+
Flag to display progress of the generation process. Defaults to False.
+
+
+ +Expand source code + +
def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
+        -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
+    """Generate samples conditioned on text.
+
+    Args:
+        descriptions (list of str): A list of strings used as text conditioning.
+        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+    """
+    attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+    assert prompt_tokens is None
+    tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+    if return_tokens:
+        return self.generate_audio(tokens), tokens
+    return self.generate_audio(tokens)
+
+
+
+def generate_audio(self, gen_tokens: torch.Tensor) +
+
+

Generate Audio from tokens

+
+ +Expand source code + +
def generate_audio(self, gen_tokens: torch.Tensor):
+    """Generate Audio from tokens"""
+    assert gen_tokens.dim() == 3
+    with torch.no_grad():
+        gen_audio = self.compression_model.decode(gen_tokens, None)
+    return gen_audio
+
+
+
+def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, descriptions: Optional[List[Optional[str]]] = None, progress: bool = False, return_tokens: bool = False) ‑> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] +
+
+

Generate samples conditioned on audio prompts.

+

Args

+
+
prompt : torch.Tensor
+
A batch of waveforms used for continuation. +Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+
prompt_sample_rate : int
+
Sampling rate of the given audio waveforms.
+
descriptions : list of str, optional
+
A list of strings used as text conditioning. Defaults to None.
+
progress : bool, optional
+
Flag to display progress of the generation process. Defaults to False.
+
+
+ +Expand source code + +
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+                          descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+                          progress: bool = False, return_tokens: bool = False) \
+        -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
+    """Generate samples conditioned on audio prompts.
+
+    Args:
+        prompt (torch.Tensor): A batch of waveforms used for continuation.
+            Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+        prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+        descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
+        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+    """
+    if prompt.dim() == 2:
+        prompt = prompt[None]
+    if prompt.dim() != 3:
+        raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+    prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+    if descriptions is None:
+        descriptions = [None] * len(prompt)
+    attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+    assert prompt_tokens is not None
+    tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+    if return_tokens:
+        return self.generate_audio(tokens), tokens
+    return self.generate_audio(tokens)
+
+
+
+def generate_unconditional(self, num_samples: int, progress: bool = False, return_tokens: bool = False) ‑> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] +
+
+

Generate samples in an unconditional manner.

+

Args

+
+
num_samples : int
+
Number of samples to be generated.
+
progress : bool, optional
+
Flag to display progress of the generation process. Defaults to False.
+
+
+ +Expand source code + +
def generate_unconditional(self, num_samples: int, progress: bool = False,
+                           return_tokens: bool = False) -> tp.Union[torch.Tensor,
+                                                                    tp.Tuple[torch.Tensor, torch.Tensor]]:
+    """Generate samples in an unconditional manner.
+
+    Args:
+        num_samples (int): Number of samples to be generated.
+        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+    """
+    descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
+    attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+    tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+    if return_tokens:
+        return self.generate_audio(tokens), tokens
+    return self.generate_audio(tokens)
+
+
+
+def generate_with_chroma(self, descriptions: List[str], melody_wavs: Union[torch.Tensor, List[Optional[torch.Tensor]]], melody_sample_rate: int, progress: bool = False, return_tokens: bool = False) ‑> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] +
+
+

Generate samples conditioned on text and melody.

+

Args

+
+
descriptions : list of str
+
A list of strings used as text conditioning.
+
melody_wavs
+
(torch.Tensor or list of Tensor): A batch of waveforms used as +melody conditioning. Should have shape [B, C, T] with B matching the description length, +C=1 or 2. It can be [C, T] if there is a single description. It can also be +a list of [C, T] tensors.
+
melody_sample_rate
+
(int): Sample rate of the melody waveforms.
+
progress : bool, optional
+
Flag to display progress of the generation process. Defaults to False.
+
+
+ +Expand source code + +
def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
+                         melody_sample_rate: int, progress: bool = False,
+                         return_tokens: bool = False) -> tp.Union[torch.Tensor,
+                                                                  tp.Tuple[torch.Tensor, torch.Tensor]]:
+    """Generate samples conditioned on text and melody.
+
+    Args:
+        descriptions (list of str): A list of strings used as text conditioning.
+        melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
+            melody conditioning. Should have shape [B, C, T] with B matching the description length,
+            C=1 or 2. It can be [C, T] if there is a single description. It can also be
+            a list of [C, T] tensors.
+        melody_sample_rate: (int): Sample rate of the melody waveforms.
+        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+    """
+    if isinstance(melody_wavs, torch.Tensor):
+        if melody_wavs.dim() == 2:
+            melody_wavs = melody_wavs[None]
+        if melody_wavs.dim() != 3:
+            raise ValueError("Melody wavs should have a shape [B, C, T].")
+        melody_wavs = list(melody_wavs)
+    else:
+        for melody in melody_wavs:
+            if melody is not None:
+                assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
+
+    melody_wavs = [
+        convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
+        if wav is not None else None
+        for wav in melody_wavs]
+    attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
+                                                                    melody_wavs=melody_wavs)
+    assert prompt_tokens is None
+    tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+    if return_tokens:
+        return self.generate_audio(tokens), tokens
+    return self.generate_audio(tokens)
+
+
+
+def set_custom_progress_callback(self, progress_callback: Optional[Callable[[int, int], None]] = None) +
+
+

Override the default progress callback.

+
+ +Expand source code + +
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+    """Override the default progress callback."""
+    self._progress_callback = progress_callback
+
+
+
+def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, top_p: float = 0.0, temperature: float = 1.0, duration: float = 30.0, cfg_coef: float = 3.0, two_step_cfg: bool = False, extend_stride: float = 18) +
+
+

Set the generation parameters for MusicGen.

+

Args

+
+
use_sampling : bool, optional
+
Use sampling if True, else do argmax decoding. Defaults to True.
+
top_k : int, optional
+
top_k used for sampling. Defaults to 250.
+
top_p : float, optional
+
top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+
temperature : float, optional
+
Softmax temperature parameter. Defaults to 1.0.
+
duration : float, optional
+
Duration of the generated waveform. Defaults to 30.0.
+
cfg_coef : float, optional
+
Coefficient used for classifier free guidance. Defaults to 3.0.
+
two_step_cfg : bool, optional
+
If True, performs 2 forward for Classifier Free Guidance, +instead of batching together the two. This has some impact on how things +are padded but seems to have little impact in practice.
+
extend_stride
+
when doing extended generation (i.e. more than 30 seconds), by how much +should we extend the audio each time. Larger values will mean less context is +preserved, and shorter value will require extra computations.
+
+
+ +Expand source code + +
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+                          top_p: float = 0.0, temperature: float = 1.0,
+                          duration: float = 30.0, cfg_coef: float = 3.0,
+                          two_step_cfg: bool = False, extend_stride: float = 18):
+    """Set the generation parameters for MusicGen.
+
+    Args:
+        use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+        top_k (int, optional): top_k used for sampling. Defaults to 250.
+        top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+        temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+        duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
+        cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+        two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+            instead of batching together the two. This has some impact on how things
+            are padded but seems to have little impact in practice.
+        extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
+            should we extend the audio each time. Larger values will mean less context is
+            preserved, and shorter value will require extra computations.
+    """
+    assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+    self.extend_stride = extend_stride
+    self.duration = duration
+    self.generation_params = {
+        'use_sampling': use_sampling,
+        'temp': temperature,
+        'top_k': top_k,
+        'top_p': top_p,
+        'cfg_coef': cfg_coef,
+        'two_step_cfg': two_step_cfg,
+    }
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/models/unet.html b/api_docs/audiocraft/models/unet.html new file mode 100644 index 00000000..defbb8c7 --- /dev/null +++ b/api_docs/audiocraft/models/unet.html @@ -0,0 +1,1004 @@ + + + + + + +audiocraft.models.unet API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.unet

+
+
+

Pytorch Unet Module used for diffusion.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Pytorch Unet Module used for diffusion.
+"""
+
+from dataclasses import dataclass
+import typing as tp
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
+
+
+@dataclass
+class Output:
+    sample: torch.Tensor
+
+
+def get_model(cfg, channels: int, side: int, num_steps: int):
+    if cfg.model == 'unet':
+        return DiffusionUnet(
+            chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
+    else:
+        raise RuntimeError('Not Implemented')
+
+
+class ResBlock(nn.Module):
+    def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
+                 dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                 dropout: float = 0.):
+        super().__init__()
+        stride = 1
+        padding = dilation * (kernel - stride) // 2
+        Conv = nn.Conv1d
+        Drop = nn.Dropout1d
+        self.norm1 = nn.GroupNorm(norm_groups, channels)
+        self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+        self.activation1 = activation()
+        self.dropout1 = Drop(dropout)
+
+        self.norm2 = nn.GroupNorm(norm_groups, channels)
+        self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+        self.activation2 = activation()
+        self.dropout2 = Drop(dropout)
+
+    def forward(self, x):
+        h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
+        h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
+        return x + h
+
+
+class DecoderLayer(nn.Module):
+    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
+                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                 dropout: float = 0.):
+        super().__init__()
+        padding = (kernel - stride) // 2
+        self.res_blocks = nn.Sequential(
+            *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
+              for idx in range(res_blocks)])
+        self.norm = nn.GroupNorm(norm_groups, chin)
+        ConvTr = nn.ConvTranspose1d
+        self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
+        self.activation = activation()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.res_blocks(x)
+        x = self.norm(x)
+        x = self.activation(x)
+        x = self.convtr(x)
+        return x
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
+                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                 dropout: float = 0.):
+        super().__init__()
+        padding = (kernel - stride) // 2
+        Conv = nn.Conv1d
+        self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
+        self.norm = nn.GroupNorm(norm_groups, chout)
+        self.activation = activation()
+        self.res_blocks = nn.Sequential(
+            *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
+              for idx in range(res_blocks)])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, C, T = x.shape
+        stride, = self.conv.stride
+        pad = (stride - (T % stride)) % stride
+        x = F.pad(x, (0, pad))
+
+        x = self.conv(x)
+        x = self.norm(x)
+        x = self.activation(x)
+        x = self.res_blocks(x)
+        return x
+
+
+class BLSTM(nn.Module):
+    """BiLSTM with same hidden units as input dim.
+    """
+    def __init__(self, dim, layers=2):
+        super().__init__()
+        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
+        self.linear = nn.Linear(2 * dim, dim)
+
+    def forward(self, x):
+        x = x.permute(2, 0, 1)
+        x = self.lstm(x)[0]
+        x = self.linear(x)
+        x = x.permute(1, 2, 0)
+        return x
+
+
+class DiffusionUnet(nn.Module):
+    def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
+                 max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
+                 bilstm: bool = False, transformer: bool = False,
+                 codec_dim: tp.Optional[int] = None, **kwargs):
+        super().__init__()
+        self.encoders = nn.ModuleList()
+        self.decoders = nn.ModuleList()
+        self.embeddings: tp.Optional[nn.ModuleList] = None
+        self.embedding = nn.Embedding(num_steps, hidden)
+        if emb_all_layers:
+            self.embeddings = nn.ModuleList()
+        self.condition_embedding: tp.Optional[nn.Module] = None
+        for d in range(depth):
+            encoder = EncoderLayer(chin, hidden, **kwargs)
+            decoder = DecoderLayer(hidden, chin, **kwargs)
+            self.encoders.append(encoder)
+            self.decoders.insert(0, decoder)
+            if emb_all_layers and d > 0:
+                assert self.embeddings is not None
+                self.embeddings.append(nn.Embedding(num_steps, hidden))
+            chin = hidden
+            hidden = min(int(chin * growth), max_channels)
+        self.bilstm: tp.Optional[nn.Module]
+        if bilstm:
+            self.bilstm = BLSTM(chin)
+        else:
+            self.bilstm = None
+        self.use_transformer = transformer
+        self.cross_attention = False
+        if transformer:
+            self.cross_attention = cross_attention
+            self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
+                                                    cross_attention=cross_attention)
+
+        self.use_codec = False
+        if codec_dim is not None:
+            self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
+            self.use_codec = True
+
+    def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
+        skips = []
+        bs = x.size(0)
+        z = x
+        view_args = [1]
+        if type(step) is torch.Tensor:
+            step_tensor = step
+        else:
+            step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
+
+        for idx, encoder in enumerate(self.encoders):
+            z = encoder(z)
+            if idx == 0:
+                z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
+            elif self.embeddings is not None:
+                z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
+
+            skips.append(z)
+
+        if self.use_codec:  # insert condition in the bottleneck
+            assert condition is not None, "Model defined for conditionnal generation"
+            condition_emb = self.conv_codec(condition)  # reshape to the bottleneck dim
+            assert condition_emb.size(-1) <= 2 * z.size(-1), \
+                f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
+            if not self.cross_attention:
+
+                condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
+                assert z.size() == condition_emb.size()
+                z += condition_emb
+                cross_attention_src = None
+            else:
+                cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
+                B, T, C = cross_attention_src.shape
+                positions = torch.arange(T, device=x.device).view(1, -1, 1)
+                pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
+                cross_attention_src = cross_attention_src + pos_emb
+        if self.use_transformer:
+            z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
+        else:
+            if self.bilstm is None:
+                z = torch.zeros_like(z)
+            else:
+                z = self.bilstm(z)
+
+        for decoder in self.decoders:
+            s = skips.pop(-1)
+            z = z[:, :, :s.shape[2]]
+            z = z + s
+            z = decoder(z)
+
+        z = z[:, :, :x.shape[2]]
+        return Output(z)
+
+
+
+
+
+
+
+

Functions

+
+
+def get_model(cfg, channels: int, side: int, num_steps: int) +
+
+
+
+ +Expand source code + +
def get_model(cfg, channels: int, side: int, num_steps: int):
+    if cfg.model == 'unet':
+        return DiffusionUnet(
+            chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
+    else:
+        raise RuntimeError('Not Implemented')
+
+
+
+
+
+

Classes

+
+
+class BLSTM +(dim, layers=2) +
+
+

BiLSTM with same hidden units as input dim.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class BLSTM(nn.Module):
+    """BiLSTM with same hidden units as input dim.
+    """
+    def __init__(self, dim, layers=2):
+        super().__init__()
+        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
+        self.linear = nn.Linear(2 * dim, dim)
+
+    def forward(self, x):
+        x = x.permute(2, 0, 1)
+        x = self.lstm(x)[0]
+        x = self.linear(x)
+        x = x.permute(1, 2, 0)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = x.permute(2, 0, 1)
+    x = self.lstm(x)[0]
+    x = self.linear(x)
+    x = x.permute(1, 2, 0)
+    return x
+
+
+
+
+
+class DecoderLayer +(chin: int, chout: int, kernel: int = 4, stride: int = 2, norm_groups: int = 4, res_blocks: int = 1, activation: Type[torch.nn.modules.module.Module] = torch.nn.modules.activation.ReLU, dropout: float = 0.0) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DecoderLayer(nn.Module):
+    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
+                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                 dropout: float = 0.):
+        super().__init__()
+        padding = (kernel - stride) // 2
+        self.res_blocks = nn.Sequential(
+            *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
+              for idx in range(res_blocks)])
+        self.norm = nn.GroupNorm(norm_groups, chin)
+        ConvTr = nn.ConvTranspose1d
+        self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
+        self.activation = activation()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.res_blocks(x)
+        x = self.norm(x)
+        x = self.activation(x)
+        x = self.convtr(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor) ‑> torch.Tensor +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x: torch.Tensor) -> torch.Tensor:
+    x = self.res_blocks(x)
+    x = self.norm(x)
+    x = self.activation(x)
+    x = self.convtr(x)
+    return x
+
+
+
+
+
+class DiffusionUnet +(chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.0, max_channels: int = 10000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False, bilstm: bool = False, transformer: bool = False, codec_dim: Optional[int] = None, **kwargs) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DiffusionUnet(nn.Module):
+    def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
+                 max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
+                 bilstm: bool = False, transformer: bool = False,
+                 codec_dim: tp.Optional[int] = None, **kwargs):
+        super().__init__()
+        self.encoders = nn.ModuleList()
+        self.decoders = nn.ModuleList()
+        self.embeddings: tp.Optional[nn.ModuleList] = None
+        self.embedding = nn.Embedding(num_steps, hidden)
+        if emb_all_layers:
+            self.embeddings = nn.ModuleList()
+        self.condition_embedding: tp.Optional[nn.Module] = None
+        for d in range(depth):
+            encoder = EncoderLayer(chin, hidden, **kwargs)
+            decoder = DecoderLayer(hidden, chin, **kwargs)
+            self.encoders.append(encoder)
+            self.decoders.insert(0, decoder)
+            if emb_all_layers and d > 0:
+                assert self.embeddings is not None
+                self.embeddings.append(nn.Embedding(num_steps, hidden))
+            chin = hidden
+            hidden = min(int(chin * growth), max_channels)
+        self.bilstm: tp.Optional[nn.Module]
+        if bilstm:
+            self.bilstm = BLSTM(chin)
+        else:
+            self.bilstm = None
+        self.use_transformer = transformer
+        self.cross_attention = False
+        if transformer:
+            self.cross_attention = cross_attention
+            self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
+                                                    cross_attention=cross_attention)
+
+        self.use_codec = False
+        if codec_dim is not None:
+            self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
+            self.use_codec = True
+
+    def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
+        skips = []
+        bs = x.size(0)
+        z = x
+        view_args = [1]
+        if type(step) is torch.Tensor:
+            step_tensor = step
+        else:
+            step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
+
+        for idx, encoder in enumerate(self.encoders):
+            z = encoder(z)
+            if idx == 0:
+                z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
+            elif self.embeddings is not None:
+                z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
+
+            skips.append(z)
+
+        if self.use_codec:  # insert condition in the bottleneck
+            assert condition is not None, "Model defined for conditionnal generation"
+            condition_emb = self.conv_codec(condition)  # reshape to the bottleneck dim
+            assert condition_emb.size(-1) <= 2 * z.size(-1), \
+                f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
+            if not self.cross_attention:
+
+                condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
+                assert z.size() == condition_emb.size()
+                z += condition_emb
+                cross_attention_src = None
+            else:
+                cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
+                B, T, C = cross_attention_src.shape
+                positions = torch.arange(T, device=x.device).view(1, -1, 1)
+                pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
+                cross_attention_src = cross_attention_src + pos_emb
+        if self.use_transformer:
+            z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
+        else:
+            if self.bilstm is None:
+                z = torch.zeros_like(z)
+            else:
+                z = self.bilstm(z)
+
+        for decoder in self.decoders:
+            s = skips.pop(-1)
+            z = z[:, :, :s.shape[2]]
+            z = z + s
+            z = decoder(z)
+
+        z = z[:, :, :x.shape[2]]
+        return Output(z)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor, step: Union[int, torch.Tensor], condition: Optional[torch.Tensor] = None) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
+    skips = []
+    bs = x.size(0)
+    z = x
+    view_args = [1]
+    if type(step) is torch.Tensor:
+        step_tensor = step
+    else:
+        step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
+
+    for idx, encoder in enumerate(self.encoders):
+        z = encoder(z)
+        if idx == 0:
+            z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
+        elif self.embeddings is not None:
+            z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
+
+        skips.append(z)
+
+    if self.use_codec:  # insert condition in the bottleneck
+        assert condition is not None, "Model defined for conditionnal generation"
+        condition_emb = self.conv_codec(condition)  # reshape to the bottleneck dim
+        assert condition_emb.size(-1) <= 2 * z.size(-1), \
+            f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
+        if not self.cross_attention:
+
+            condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
+            assert z.size() == condition_emb.size()
+            z += condition_emb
+            cross_attention_src = None
+        else:
+            cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
+            B, T, C = cross_attention_src.shape
+            positions = torch.arange(T, device=x.device).view(1, -1, 1)
+            pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
+            cross_attention_src = cross_attention_src + pos_emb
+    if self.use_transformer:
+        z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
+    else:
+        if self.bilstm is None:
+            z = torch.zeros_like(z)
+        else:
+            z = self.bilstm(z)
+
+    for decoder in self.decoders:
+        s = skips.pop(-1)
+        z = z[:, :, :s.shape[2]]
+        z = z + s
+        z = decoder(z)
+
+    z = z[:, :, :x.shape[2]]
+    return Output(z)
+
+
+
+
+
+class EncoderLayer +(chin: int, chout: int, kernel: int = 4, stride: int = 2, norm_groups: int = 4, res_blocks: int = 1, activation: Type[torch.nn.modules.module.Module] = torch.nn.modules.activation.ReLU, dropout: float = 0.0) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class EncoderLayer(nn.Module):
+    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
+                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                 dropout: float = 0.):
+        super().__init__()
+        padding = (kernel - stride) // 2
+        Conv = nn.Conv1d
+        self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
+        self.norm = nn.GroupNorm(norm_groups, chout)
+        self.activation = activation()
+        self.res_blocks = nn.Sequential(
+            *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
+              for idx in range(res_blocks)])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, C, T = x.shape
+        stride, = self.conv.stride
+        pad = (stride - (T % stride)) % stride
+        x = F.pad(x, (0, pad))
+
+        x = self.conv(x)
+        x = self.norm(x)
+        x = self.activation(x)
+        x = self.res_blocks(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor) ‑> torch.Tensor +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x: torch.Tensor) -> torch.Tensor:
+    B, C, T = x.shape
+    stride, = self.conv.stride
+    pad = (stride - (T % stride)) % stride
+    x = F.pad(x, (0, pad))
+
+    x = self.conv(x)
+    x = self.norm(x)
+    x = self.activation(x)
+    x = self.res_blocks(x)
+    return x
+
+
+
+
+
+class Output +(sample: torch.Tensor) +
+
+

Output(sample: torch.Tensor)

+
+ +Expand source code + +
class Output:
+    sample: torch.Tensor
+
+

Class variables

+
+
var sample : torch.Tensor
+
+
+
+
+
+
+class ResBlock +(channels: int, kernel: int = 3, norm_groups: int = 4, dilation: int = 1, activation: Type[torch.nn.modules.module.Module] = torch.nn.modules.activation.ReLU, dropout: float = 0.0) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ResBlock(nn.Module):
+    def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
+                 dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                 dropout: float = 0.):
+        super().__init__()
+        stride = 1
+        padding = dilation * (kernel - stride) // 2
+        Conv = nn.Conv1d
+        Drop = nn.Dropout1d
+        self.norm1 = nn.GroupNorm(norm_groups, channels)
+        self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+        self.activation1 = activation()
+        self.dropout1 = Drop(dropout)
+
+        self.norm2 = nn.GroupNorm(norm_groups, channels)
+        self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+        self.activation2 = activation()
+        self.dropout2 = Drop(dropout)
+
+    def forward(self, x):
+        h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
+        h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
+        return x + h
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
+    h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
+    return x + h
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/activations.html b/api_docs/audiocraft/modules/activations.html new file mode 100644 index 00000000..f5ec87bf --- /dev/null +++ b/api_docs/audiocraft/modules/activations.html @@ -0,0 +1,523 @@ + + + + + + +audiocraft.modules.activations API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.activations

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import Union, Callable
+
+
+class CustomGLU(nn.Module):
+    """Custom Gated Linear Unit activation.
+    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
+    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
+    function (i.e. sigmoid, swish, etc.).
+
+    Args:
+        activation (nn.Module): The custom activation to apply in the Gated Linear Unit
+        dim (int): the dimension on which to split the input. Default: -1
+
+    Shape:
+        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+    Examples::
+        >>> m = CustomGLU(nn.Sigmoid())
+        >>> input = torch.randn(4, 2)
+        >>> output = m(input)
+    """
+    def __init__(self, activation: nn.Module, dim: int = -1):
+        super(CustomGLU, self).__init__()
+        self.dim = dim
+        self.activation = activation
+
+    def forward(self, x: Tensor):
+        assert x.shape[self.dim] % 2 == 0  # M = N / 2
+        a, b = torch.chunk(x, 2, dim=self.dim)
+        return a * self.activation(b)
+
+
+class SwiGLU(CustomGLU):
+    """SiLU Gated Linear Unit activation.
+    Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(SwiGLU, self).__init__(nn.SiLU(), dim)
+
+
+class GeGLU(CustomGLU):
+    """GeLU Gated Linear Unit activation.
+    Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(GeGLU, self).__init__(nn.GELU(), dim)
+
+
+class ReGLU(CustomGLU):
+    """ReLU Gated Linear Unit activation.
+    Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(ReGLU, self).__init__(nn.ReLU(), dim)
+
+
+def get_activation_fn(
+    activation: Union[str, Callable[[Tensor], Tensor]]
+) -> Union[str, Callable[[Tensor], Tensor]]:
+    """Helper function to map an activation string to the activation class.
+    If the supplied activation is not a string that is recognized, the activation is passed back.
+
+    Args:
+        activation (str, or Callable[[Tensor], Tensor]): Activation to check
+    """
+    if isinstance(activation, str):
+        if activation == "reglu":
+            return ReGLU()
+        elif activation == "geglu":
+            return GeGLU()
+        elif activation == "swiglu":
+            return SwiGLU()
+    return activation
+
+
+
+
+
+
+
+

Functions

+
+
+def get_activation_fn(activation: Union[str, Callable[[torch.Tensor], torch.Tensor]]) ‑> Union[str, Callable[[torch.Tensor], torch.Tensor]] +
+
+

Helper function to map an activation string to the activation class. +If the supplied activation is not a string that is recognized, the activation is passed back.

+

Args

+
+
activation : str, or Callable[[Tensor], Tensor]
+
Activation to check
+
+
+ +Expand source code + +
def get_activation_fn(
+    activation: Union[str, Callable[[Tensor], Tensor]]
+) -> Union[str, Callable[[Tensor], Tensor]]:
+    """Helper function to map an activation string to the activation class.
+    If the supplied activation is not a string that is recognized, the activation is passed back.
+
+    Args:
+        activation (str, or Callable[[Tensor], Tensor]): Activation to check
+    """
+    if isinstance(activation, str):
+        if activation == "reglu":
+            return ReGLU()
+        elif activation == "geglu":
+            return GeGLU()
+        elif activation == "swiglu":
+            return SwiGLU()
+    return activation
+
+
+
+
+
+

Classes

+
+
+class CustomGLU +(activation: torch.nn.modules.module.Module, dim: int = -1) +
+
+

Custom Gated Linear Unit activation. +Applies a modified gated linear unit :math:a * f(b) where :math:a is the first half +of the input matrices, :math:b is the second half, and :math:f is a provided activation +function (i.e. sigmoid, swish, etc.).

+

Args

+
+
activation : nn.Module
+
The custom activation to apply in the Gated Linear Unit
+
dim : int
+
the dimension on which to split the input. Default: -1
+
+

Shape

+
    +
  • Input: :math:(st_1, N, st_2) where * means, any number of additional +dimensions
  • +
  • Output: :math:(st_1, M, st_2) where :math:M=N/2
  • +
+

Examples:: +>>> m = CustomGLU(nn.Sigmoid()) +>>> input = torch.randn(4, 2) +>>> output = m(input)

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class CustomGLU(nn.Module):
+    """Custom Gated Linear Unit activation.
+    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
+    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
+    function (i.e. sigmoid, swish, etc.).
+
+    Args:
+        activation (nn.Module): The custom activation to apply in the Gated Linear Unit
+        dim (int): the dimension on which to split the input. Default: -1
+
+    Shape:
+        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+    Examples::
+        >>> m = CustomGLU(nn.Sigmoid())
+        >>> input = torch.randn(4, 2)
+        >>> output = m(input)
+    """
+    def __init__(self, activation: nn.Module, dim: int = -1):
+        super(CustomGLU, self).__init__()
+        self.dim = dim
+        self.activation = activation
+
+    def forward(self, x: Tensor):
+        assert x.shape[self.dim] % 2 == 0  # M = N / 2
+        a, b = torch.chunk(x, 2, dim=self.dim)
+        return a * self.activation(b)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x: Tensor):
+    assert x.shape[self.dim] % 2 == 0  # M = N / 2
+    a, b = torch.chunk(x, 2, dim=self.dim)
+    return a * self.activation(b)
+
+
+
+
+
+class GeGLU +(dim: int = -1) +
+
+

GeLU Gated Linear Unit activation. +Applies GeLU Gated Linear Unit :math:a * GELU(b) where :math:a is +the first half of the input matrices, :math:b is the second half.

+

Args

+
+
dim : int
+
the dimension on which to split the input. Default: -1
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class GeGLU(CustomGLU):
+    """GeLU Gated Linear Unit activation.
+    Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(GeGLU, self).__init__(nn.GELU(), dim)
+
+

Ancestors

+
    +
  • CustomGLU
  • +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class ReGLU +(dim: int = -1) +
+
+

ReLU Gated Linear Unit activation. +Applies ReLU Gated Linear Unit :math:a * ReLU(b) where :math:a is +the first half of the input matrices, :math:b is the second half.

+

Args

+
+
dim : int
+
the dimension on which to split the input. Default: -1
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ReGLU(CustomGLU):
+    """ReLU Gated Linear Unit activation.
+    Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(ReGLU, self).__init__(nn.ReLU(), dim)
+
+

Ancestors

+
    +
  • CustomGLU
  • +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class SwiGLU +(dim: int = -1) +
+
+

SiLU Gated Linear Unit activation. +Applies SiLU Gated Linear Unit :math:a * SiLU(b) where :math:a is +the first half of the input matrices, :math:b is the second half.

+

Args

+
+
dim : int
+
the dimension on which to split the input. Default: -1
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class SwiGLU(CustomGLU):
+    """SiLU Gated Linear Unit activation.
+    Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(SwiGLU, self).__init__(nn.SiLU(), dim)
+
+

Ancestors

+
    +
  • CustomGLU
  • +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/chroma.html b/api_docs/audiocraft/modules/chroma.html new file mode 100644 index 00000000..a309fff9 --- /dev/null +++ b/api_docs/audiocraft/modules/chroma.html @@ -0,0 +1,285 @@ + + + + + + +audiocraft.modules.chroma API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.chroma

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import typing as tp
+
+from einops import rearrange
+from librosa import filters
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+
+
+class ChromaExtractor(nn.Module):
+    """Chroma extraction and quantization.
+
+    Args:
+        sample_rate (int): Sample rate for the chroma extraction.
+        n_chroma (int): Number of chroma bins for the chroma extraction.
+        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
+        nfft (int, optional): Number of FFT.
+        winlen (int, optional): Window length.
+        winhop (int, optional): Window hop size.
+        argmax (bool, optional): Whether to use argmax. Defaults to False.
+        norm (float, optional): Norm for chroma normalization. Defaults to inf.
+    """
+    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
+                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
+                 norm: float = torch.inf):
+        super().__init__()
+        self.winlen = winlen or 2 ** radix2_exp
+        self.nfft = nfft or self.winlen
+        self.winhop = winhop or (self.winlen // 4)
+        self.sample_rate = sample_rate
+        self.n_chroma = n_chroma
+        self.norm = norm
+        self.argmax = argmax
+        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
+                                                                       n_chroma=self.n_chroma)), persistent=False)
+        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
+                                                      hop_length=self.winhop, power=2, center=True,
+                                                      pad=0, normalized=True)
+
+    def forward(self, wav: torch.Tensor) -> torch.Tensor:
+        T = wav.shape[-1]
+        # in case we are getting a wav that was dropped out (nullified)
+        # from the conditioner, make sure wav length is no less that nfft
+        if T < self.nfft:
+            pad = self.nfft - T
+            r = 0 if pad % 2 == 0 else 1
+            wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
+
+        spec = self.spec(wav).squeeze(1)
+        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
+        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
+
+        if self.argmax:
+            idx = norm_chroma.argmax(-1, keepdim=True)
+            norm_chroma[:] = 0
+            norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+        return norm_chroma
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class ChromaExtractor +(sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: Optional[int] = None, winlen: Optional[int] = None, winhop: Optional[int] = None, argmax: bool = False, norm: float = inf) +
+
+

Chroma extraction and quantization.

+

Args

+
+
sample_rate : int
+
Sample rate for the chroma extraction.
+
n_chroma : int
+
Number of chroma bins for the chroma extraction.
+
radix2_exp : int
+
Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
+
nfft : int, optional
+
Number of FFT.
+
winlen : int, optional
+
Window length.
+
winhop : int, optional
+
Window hop size.
+
argmax : bool, optional
+
Whether to use argmax. Defaults to False.
+
norm : float, optional
+
Norm for chroma normalization. Defaults to inf.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ChromaExtractor(nn.Module):
+    """Chroma extraction and quantization.
+
+    Args:
+        sample_rate (int): Sample rate for the chroma extraction.
+        n_chroma (int): Number of chroma bins for the chroma extraction.
+        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
+        nfft (int, optional): Number of FFT.
+        winlen (int, optional): Window length.
+        winhop (int, optional): Window hop size.
+        argmax (bool, optional): Whether to use argmax. Defaults to False.
+        norm (float, optional): Norm for chroma normalization. Defaults to inf.
+    """
+    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
+                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
+                 norm: float = torch.inf):
+        super().__init__()
+        self.winlen = winlen or 2 ** radix2_exp
+        self.nfft = nfft or self.winlen
+        self.winhop = winhop or (self.winlen // 4)
+        self.sample_rate = sample_rate
+        self.n_chroma = n_chroma
+        self.norm = norm
+        self.argmax = argmax
+        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
+                                                                       n_chroma=self.n_chroma)), persistent=False)
+        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
+                                                      hop_length=self.winhop, power=2, center=True,
+                                                      pad=0, normalized=True)
+
+    def forward(self, wav: torch.Tensor) -> torch.Tensor:
+        T = wav.shape[-1]
+        # in case we are getting a wav that was dropped out (nullified)
+        # from the conditioner, make sure wav length is no less that nfft
+        if T < self.nfft:
+            pad = self.nfft - T
+            r = 0 if pad % 2 == 0 else 1
+            wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
+
+        spec = self.spec(wav).squeeze(1)
+        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
+        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
+
+        if self.argmax:
+            idx = norm_chroma.argmax(-1, keepdim=True)
+            norm_chroma[:] = 0
+            norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+        return norm_chroma
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, wav: torch.Tensor) ‑> torch.Tensor +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, wav: torch.Tensor) -> torch.Tensor:
+    T = wav.shape[-1]
+    # in case we are getting a wav that was dropped out (nullified)
+    # from the conditioner, make sure wav length is no less that nfft
+    if T < self.nfft:
+        pad = self.nfft - T
+        r = 0 if pad % 2 == 0 else 1
+        wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+        assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
+
+    spec = self.spec(wav).squeeze(1)
+    raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
+    norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+    norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
+
+    if self.argmax:
+        idx = norm_chroma.argmax(-1, keepdim=True)
+        norm_chroma[:] = 0
+        norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+    return norm_chroma
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/codebooks_patterns.html b/api_docs/audiocraft/modules/codebooks_patterns.html new file mode 100644 index 00000000..a3733feb --- /dev/null +++ b/api_docs/audiocraft/modules/codebooks_patterns.html @@ -0,0 +1,1834 @@ + + + + + + +audiocraft.modules.codebooks_patterns API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.codebooks_patterns

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import namedtuple
+from dataclasses import dataclass
+from functools import lru_cache
+import logging
+import typing as tp
+
+from abc import ABC, abstractmethod
+import torch
+
+LayoutCoord = namedtuple('LayoutCoord', ['t', 'q'])  # (timestep, codebook index)
+PatternLayout = tp.List[tp.List[LayoutCoord]]  # Sequence of coordinates
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Pattern:
+    """Base implementation of a pattern over a sequence with multiple codebooks.
+
+    The codebook pattern consists in a layout, defining for each sequence step
+    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
+    The first item of the pattern is always an empty list in order to properly insert a special token
+    to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
+    and ``timesteps`` the number of timesteps corresponding to the original sequence.
+
+    The pattern provides convenient methods to build and revert interleaved sequences from it:
+    ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
+        to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
+        K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
+        for the output sequence. The unfilled positions are replaced with a special token and the built sequence
+        is returned along with a mask indicating valid tokens.
+    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
+        of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
+        to fill and specify invalid positions if needed.
+    See the dedicated methods for more details.
+    """
+    # Pattern layout, for each sequence step, we have a list of coordinates
+    # corresponding to the original codebook timestep and position.
+    # The first list is always an empty list in order to properly insert
+    # a special token to start with.
+    layout: PatternLayout
+    timesteps: int
+    n_q: int
+
+    def __post_init__(self):
+        assert len(self.layout) > 0
+        assert self.layout[0] == []
+        self._validate_layout()
+        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
+        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
+        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
+
+    def _validate_layout(self):
+        """Runs checks on the layout to ensure a valid pattern is defined.
+        A pattern is considered invalid if:
+            - Multiple timesteps for a same codebook are defined in the same sequence step
+            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
+              (this would mean that we have future timesteps before past timesteps).
+        """
+        q_timesteps = {q: 0 for q in range(self.n_q)}
+        for s, seq_coords in enumerate(self.layout):
+            if len(seq_coords) > 0:
+                qs = set()
+                for coord in seq_coords:
+                    qs.add(coord.q)
+                    last_q_timestep = q_timesteps[coord.q]
+                    assert coord.t >= last_q_timestep, \
+                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
+                    q_timesteps[coord.q] = coord.t
+                # each sequence step contains at max 1 coordinate per codebook
+                assert len(qs) == len(seq_coords), \
+                    f"Multiple entries for a same codebook are found at step {s}"
+
+    @property
+    def num_sequence_steps(self):
+        return len(self.layout) - 1
+
+    @property
+    def max_delay(self):
+        max_t_in_seq_coords = 0
+        for seq_coords in self.layout[1:]:
+            for coords in seq_coords:
+                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
+        return max_t_in_seq_coords - self.timesteps
+
+    @property
+    def valid_layout(self):
+        valid_step = len(self.layout) - self.max_delay
+        return self.layout[:valid_step]
+
+    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
+        """Get codebook coordinates in the layout that corresponds to the specified timestep t
+        and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
+        and the actual codebook coordinates.
+        """
+        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
+        if q is not None:
+            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
+        coords = []
+        for s, seq_codes in enumerate(self.layout):
+            for code in seq_codes:
+                if code.t == t and (q is None or code.q == q):
+                    coords.append((s, code))
+        return coords
+
+    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
+        return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
+
+    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
+        steps_with_timesteps = self.get_steps_with_timestep(t, q)
+        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
+
+    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
+                                                device: tp.Union[torch.device, str] = 'cpu'):
+        """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
+
+        Args:
+            timesteps (int): Maximum number of timesteps steps to consider.
+            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
+            device (torch.device or str): Device for created tensors.
+        Returns:
+            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
+        """
+        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
+        # use the proper layout based on whether we limit ourselves to valid steps only or not,
+        # note that using the valid_layout will result in a truncated sequence up to the valid steps
+        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
+        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
+        # fill indexes with last sequence step value that will correspond to our special token
+        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
+        # which will correspond to the index: n_q * timesteps
+        indexes[:] = n_q * timesteps
+        # iterate over the pattern and fill scattered indexes and mask
+        for s, sequence_coords in enumerate(ref_layout):
+            for coords in sequence_coords:
+                if coords.t < timesteps:
+                    indexes[coords.q, s] = coords.t + coords.q * timesteps
+                    mask[coords.q, s] = 1
+        indexes = torch.from_numpy(indexes).to(device)
+        mask = torch.from_numpy(mask).to(device)
+        return indexes, mask
+
+    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+        """Build sequence corresponding to the pattern from the input tensor z.
+        The sequence is built using up to sequence_steps if specified, and non-pattern
+        coordinates are filled with the special token.
+
+        Args:
+            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
+            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
+            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+                Steps that are beyond valid steps will be replaced by the special_token in that case.
+        Returns:
+            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
+                corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
+            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
+        """
+        B, K, T = z.shape
+        indexes, mask = self._build_pattern_sequence_scatter_indexes(
+            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
+        )
+        z = z.view(B, -1)
+        # we append the special token as the last index of our flattened z tensor
+        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
+        values = z[:, indexes.view(-1)]
+        values = values.view(B, K, indexes.shape[-1])
+        return values, indexes, mask
+
+    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
+                                                 keep_only_valid_steps: bool = False,
+                                                 is_model_output: bool = False,
+                                                 device: tp.Union[torch.device, str] = 'cpu'):
+        """Builds scatter indexes required to retrieve the original multi-codebook sequence
+        from interleaving pattern.
+
+        Args:
+            sequence_steps (int): Sequence steps.
+            n_q (int): Number of codebooks.
+            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+                Steps that are beyond valid steps will be replaced by the special_token in that case.
+            is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
+            device (torch.device or str): Device for created tensors.
+        Returns:
+            indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+        """
+        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
+        timesteps = self.timesteps
+        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+        assert sequence_steps <= len(ref_layout), \
+            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
+
+        # ensure we take the appropriate indexes to keep the model output from the first special token as well
+        if is_model_output:
+            ref_layout = ref_layout[1:]
+
+        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
+        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
+        # fill indexes with last sequence step value that will correspond to our special token
+        indexes[:] = n_q * sequence_steps
+        for s, sequence_codes in enumerate(ref_layout):
+            if s < sequence_steps:
+                for code in sequence_codes:
+                    if code.t < timesteps:
+                        indexes[code.q, code.t] = s + code.q * sequence_steps
+                        mask[code.q, code.t] = 1
+        indexes = torch.from_numpy(indexes).to(device)
+        mask = torch.from_numpy(mask).to(device)
+        return indexes, mask
+
+    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
+        The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
+        are filled with the special token.
+
+        Args:
+            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
+        Returns:
+            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
+                corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
+            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+        """
+        B, K, S = s.shape
+        indexes, mask = self._build_reverted_sequence_scatter_indexes(
+            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
+        )
+        s = s.view(B, -1)
+        # we append the special token as the last index of our flattened z tensor
+        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
+        values = s[:, indexes.view(-1)]
+        values = values.view(B, K, indexes.shape[-1])
+        return values, indexes, mask
+
+    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
+        """Revert model logits obtained on a sequence built from the pattern
+        back to a tensor matching the original sequence.
+
+        This method is similar to ``revert_pattern_sequence`` with the following specificities:
+        1. It is designed to work with the extra cardinality dimension
+        2. We return the logits for the first sequence item that matches the special_token and
+        which matching target in the original sequence is the first item of the sequence,
+        while we skip the last logits as there is no matching target
+        """
+        B, card, K, S = logits.shape
+        indexes, mask = self._build_reverted_sequence_scatter_indexes(
+            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
+        )
+        logits = logits.reshape(B, card, -1)
+        # we append the special token as the last index of our flattened z tensor
+        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
+        values = logits[:, :, indexes.view(-1)]
+        values = values.view(B, card, K, indexes.shape[-1])
+        return values, indexes, mask
+
+
+class CodebooksPatternProvider(ABC):
+    """Abstraction around providing pattern for interleaving codebooks.
+
+    The CodebooksPatternProvider abstraction allows to implement various strategies to
+    define interleaving pattern of sequences composed of multiple codebooks. For a given
+    number of codebooks `n_q`, the pattern provider can generate a specified pattern
+    corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
+    can be used to construct a new sequence from the original codes respecting the specified
+    pattern. The pattern is defined as a list of list of code coordinates, code coordinate
+    being a tuple with the original timestep and codebook to build the new sequence.
+    Note that all patterns must start with an empty list that is then used to insert a first
+    sequence step of special tokens in the newly generated sequence.
+
+    Args:
+        n_q (int): number of codebooks.
+        cached (bool): if True, patterns for a given length are cached. In general
+            that should be true for efficiency reason to avoid synchronization points.
+    """
+    def __init__(self, n_q: int, cached: bool = True):
+        assert n_q > 0
+        self.n_q = n_q
+        self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore
+
+    @abstractmethod
+    def get_pattern(self, timesteps: int) -> Pattern:
+        """Builds pattern with specific interleaving between codebooks.
+
+        Args:
+            timesteps (int): Total number of timesteps.
+        """
+        raise NotImplementedError()
+
+
+class DelayedPatternProvider(CodebooksPatternProvider):
+    """Provider for delayed pattern across delayed codebooks.
+    Codebooks are delayed in the sequence and sequence steps will contain codebooks
+    from different timesteps.
+
+    Example:
+        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
+        [[1, 2, 3, 4],
+        [1, 2, 3, 4],
+        [1, 2, 3, 4]]
+        The resulting sequence obtained from the returned pattern is:
+        [[S, 1, 2, 3, 4],
+        [S, S, 1, 2, 3],
+        [S, S, S, 1, 2]]
+        (with S being a special token)
+
+    Args:
+        n_q (int): Number of codebooks.
+        delays (list of int, optional): Delay for each of the codebooks.
+            If delays not defined, each codebook is delayed by 1 compared to the previous one.
+        flatten_first (int): Flatten the first N timesteps.
+        empty_initial (int): Prepend with N empty list of coordinates.
+    """
+    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
+                 flatten_first: int = 0, empty_initial: int = 0):
+        super().__init__(n_q)
+        if delays is None:
+            delays = list(range(n_q))
+        self.delays = delays
+        self.flatten_first = flatten_first
+        self.empty_initial = empty_initial
+        assert len(self.delays) == self.n_q
+        assert sorted(self.delays) == self.delays
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        max_delay = max(self.delays)
+        if self.empty_initial:
+            out += [[] for _ in range(self.empty_initial)]
+        if self.flatten_first:
+            for t in range(min(timesteps, self.flatten_first)):
+                for q in range(self.n_q):
+                    out.append([LayoutCoord(t, q)])
+        for t in range(self.flatten_first, timesteps + max_delay):
+            v = []
+            for q, delay in enumerate(self.delays):
+                t_for_q = t - delay
+                if t_for_q >= self.flatten_first:
+                    v.append(LayoutCoord(t_for_q, q))
+            out.append(v)
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class ParallelPatternProvider(DelayedPatternProvider):
+    """Provider for parallel pattern across codebooks.
+    This pattern provider is a special case of the delayed pattern with actually no delay,
+    hence delays=repeat(0, n_q).
+
+    Args:
+        n_q (int): Number of codebooks.
+    """
+    def __init__(self, n_q: int):
+        super().__init__(n_q, [0] * n_q)
+
+
+class UnrolledPatternProvider(CodebooksPatternProvider):
+    """Provider for unrolling codebooks pattern.
+    This pattern provider enables to represent the codebook flattened completely or only to some extend
+    while also specifying a given delay between the flattened codebooks representation, allowing to
+    unroll the codebooks in the sequence.
+
+    Example:
+        1. Flattening of the codebooks.
+        By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
+        taking n_q = 3 and timesteps = 4:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
+         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+        2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
+        for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
+        taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+        3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
+        allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
+        same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
+        and delays = [0, 3, 3]:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, S, S, 1, S, 2, S, 3, S, 4],
+         [S, S, S, 1, S, 2, S, 3, S, 4],
+         [1, 2, 3, S, 4, S, 5, S, 6, S]]
+
+    Args:
+        n_q (int): Number of codebooks.
+        flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
+            the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
+            have n_q extra steps for each timestep.
+        delays (list of int, optional): Delay for each of the codebooks. If not defined,
+            no delay is added and therefore will default to [0] * ``n_q``.
+            Note that two codebooks that will be flattened to the same inner step
+            should have the same delay, otherwise the pattern is considered as invalid.
+    """
+    FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
+
+    def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
+                 delays: tp.Optional[tp.List[int]] = None):
+        super().__init__(n_q)
+        if flattening is None:
+            flattening = list(range(n_q))
+        if delays is None:
+            delays = [0] * n_q
+        assert len(flattening) == n_q
+        assert len(delays) == n_q
+        assert sorted(flattening) == flattening
+        assert sorted(delays) == delays
+        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
+        self.max_delay = max(delays)
+
+    def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
+        """Build a flattened codebooks representation as a dictionary of inner step
+        and the actual codebook indices corresponding to the flattened codebook. For convenience, we
+        also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
+        """
+        flattened_codebooks: dict = {}
+        for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
+            if inner_step not in flattened_codebooks:
+                flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
+            else:
+                flat_codebook = flattened_codebooks[inner_step]
+                assert flat_codebook.delay == delay, (
+                    "Delay and flattening between codebooks is inconsistent: ",
+                    "two codebooks flattened to the same position should have the same delay."
+                )
+                flat_codebook.codebooks.append(q)
+            flattened_codebooks[inner_step] = flat_codebook
+        return flattened_codebooks
+
+    @property
+    def _num_inner_steps(self):
+        """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
+        """
+        return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
+
+    def num_virtual_steps(self, timesteps: int) -> int:
+        return timesteps * self._num_inner_steps + 1
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        """Builds pattern for delay across codebooks.
+
+        Args:
+            timesteps (int): Total number of timesteps.
+        """
+        # the PatternLayout is built as a tuple of sequence position and list of coordinates
+        # so that it can be reordered properly given the required delay between codebooks of given timesteps
+        indexed_out: list = [(-1, [])]
+        max_timesteps = timesteps + self.max_delay
+        for t in range(max_timesteps):
+            # for each timestep, we unroll the flattened codebooks,
+            # emitting the sequence step with the corresponding delay
+            for step in range(self._num_inner_steps):
+                if step in self._flattened_codebooks:
+                    # we have codebooks at this virtual step to emit
+                    step_codebooks = self._flattened_codebooks[step]
+                    t_for_q = t + step_codebooks.delay
+                    coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
+                    if t_for_q < max_timesteps and t < max_timesteps:
+                        indexed_out.append((t_for_q, coords))
+                else:
+                    # there is no codebook in this virtual step so we emit an empty list
+                    indexed_out.append((t, []))
+        out = [coords for _, coords in sorted(indexed_out)]
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class CoarseFirstPattern(CodebooksPatternProvider):
+    """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
+    potentially with delays.
+
+    ..Warning:: You must always generate the full training duration at test time, for instance,
+        30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
+        location. This is due to the non causality of the remaining codebooks with respect to
+        the first ones.
+
+    Args:
+        n_q (int): Number of codebooks.
+        delays (list of int, optional): Delay for each of the codebooks.
+            If delays not defined, each codebook is delayed by 1 compared to the previous one.
+    """
+    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
+        super().__init__(n_q)
+        if delays is None:
+            delays = [0] * (n_q - 1)
+        self.delays = delays
+        assert len(self.delays) == self.n_q - 1
+        assert sorted(self.delays) == self.delays
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for t in range(timesteps):
+            out.append([LayoutCoord(t, 0)])
+        max_delay = max(self.delays)
+        for t in range(timesteps + max_delay):
+            v = []
+            for q, delay in enumerate(self.delays):
+                t_for_q = t - delay
+                if t_for_q >= 0:
+                    v.append(LayoutCoord(t_for_q, q + 1))
+            out.append(v)
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class MusicLMPattern(CodebooksPatternProvider):
+    """Almost MusicLM style pattern. This is equivalent to full flattening
+    but in a different order.
+
+    Args:
+        n_q (int): Number of codebooks.
+        group_by (int): Number of codebooks to group together.
+    """
+    def __init__(self, n_q: int, group_by: int = 2):
+        super().__init__(n_q)
+        self.group_by = group_by
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for offset in range(0, self.n_q, self.group_by):
+            for t in range(timesteps):
+                for q in range(offset, offset + self.group_by):
+                    out.append([LayoutCoord(t, q)])
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class CoarseFirstPattern +(n_q: int, delays: Optional[List[int]] = None) +
+
+

First generates all the codebooks #1 (e.g. coarser), then the remaining ones, +potentially with delays.

+
+

Warning: You must always generate the full training duration at test time, for instance,

+

30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected +location. This is due to the non causality of the remaining codebooks with respect to +the first ones.

+
+

Args

+
+
n_q : int
+
Number of codebooks.
+
delays : list of int, optional
+
Delay for each of the codebooks. +If delays not defined, each codebook is delayed by 1 compared to the previous one.
+
+
+ +Expand source code + +
class CoarseFirstPattern(CodebooksPatternProvider):
+    """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
+    potentially with delays.
+
+    ..Warning:: You must always generate the full training duration at test time, for instance,
+        30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
+        location. This is due to the non causality of the remaining codebooks with respect to
+        the first ones.
+
+    Args:
+        n_q (int): Number of codebooks.
+        delays (list of int, optional): Delay for each of the codebooks.
+            If delays not defined, each codebook is delayed by 1 compared to the previous one.
+    """
+    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
+        super().__init__(n_q)
+        if delays is None:
+            delays = [0] * (n_q - 1)
+        self.delays = delays
+        assert len(self.delays) == self.n_q - 1
+        assert sorted(self.delays) == self.delays
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for t in range(timesteps):
+            out.append([LayoutCoord(t, 0)])
+        max_delay = max(self.delays)
+        for t in range(timesteps + max_delay):
+            v = []
+            for q, delay in enumerate(self.delays):
+                t_for_q = t - delay
+                if t_for_q >= 0:
+                    v.append(LayoutCoord(t_for_q, q + 1))
+            out.append(v)
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+

Ancestors

+ +

Inherited members

+ +
+
+class CodebooksPatternProvider +(n_q: int, cached: bool = True) +
+
+

Abstraction around providing pattern for interleaving codebooks.

+

The CodebooksPatternProvider abstraction allows to implement various strategies to +define interleaving pattern of sequences composed of multiple codebooks. For a given +number of codebooks n_q, the pattern provider can generate a specified pattern +corresponding to a sequence of T timesteps with n_q parallel codebooks. This pattern +can be used to construct a new sequence from the original codes respecting the specified +pattern. The pattern is defined as a list of list of code coordinates, code coordinate +being a tuple with the original timestep and codebook to build the new sequence. +Note that all patterns must start with an empty list that is then used to insert a first +sequence step of special tokens in the newly generated sequence.

+

Args

+
+
n_q : int
+
number of codebooks.
+
cached : bool
+
if True, patterns for a given length are cached. In general +that should be true for efficiency reason to avoid synchronization points.
+
+
+ +Expand source code + +
class CodebooksPatternProvider(ABC):
+    """Abstraction around providing pattern for interleaving codebooks.
+
+    The CodebooksPatternProvider abstraction allows to implement various strategies to
+    define interleaving pattern of sequences composed of multiple codebooks. For a given
+    number of codebooks `n_q`, the pattern provider can generate a specified pattern
+    corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
+    can be used to construct a new sequence from the original codes respecting the specified
+    pattern. The pattern is defined as a list of list of code coordinates, code coordinate
+    being a tuple with the original timestep and codebook to build the new sequence.
+    Note that all patterns must start with an empty list that is then used to insert a first
+    sequence step of special tokens in the newly generated sequence.
+
+    Args:
+        n_q (int): number of codebooks.
+        cached (bool): if True, patterns for a given length are cached. In general
+            that should be true for efficiency reason to avoid synchronization points.
+    """
+    def __init__(self, n_q: int, cached: bool = True):
+        assert n_q > 0
+        self.n_q = n_q
+        self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore
+
+    @abstractmethod
+    def get_pattern(self, timesteps: int) -> Pattern:
+        """Builds pattern with specific interleaving between codebooks.
+
+        Args:
+            timesteps (int): Total number of timesteps.
+        """
+        raise NotImplementedError()
+
+

Ancestors

+
    +
  • abc.ABC
  • +
+

Subclasses

+ +

Methods

+
+
+def get_pattern(self, timesteps: int) ‑> Pattern +
+
+

Builds pattern with specific interleaving between codebooks.

+

Args

+
+
timesteps : int
+
Total number of timesteps.
+
+
+ +Expand source code + +
@abstractmethod
+def get_pattern(self, timesteps: int) -> Pattern:
+    """Builds pattern with specific interleaving between codebooks.
+
+    Args:
+        timesteps (int): Total number of timesteps.
+    """
+    raise NotImplementedError()
+
+
+
+
+
+class DelayedPatternProvider +(n_q: int, delays: Optional[List[int]] = None, flatten_first: int = 0, empty_initial: int = 0) +
+
+

Provider for delayed pattern across delayed codebooks. +Codebooks are delayed in the sequence and sequence steps will contain codebooks +from different timesteps.

+

Example

+

Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence: +[[1, 2, 3, 4], +[1, 2, 3, 4], +[1, 2, 3, 4]] +The resulting sequence obtained from the returned pattern is: +[[S, 1, 2, 3, 4], +[S, S, 1, 2, 3], +[S, S, S, 1, 2]] +(with S being a special token)

+

Args

+
+
n_q : int
+
Number of codebooks.
+
delays : list of int, optional
+
Delay for each of the codebooks. +If delays not defined, each codebook is delayed by 1 compared to the previous one.
+
flatten_first : int
+
Flatten the first N timesteps.
+
empty_initial : int
+
Prepend with N empty list of coordinates.
+
+
+ +Expand source code + +
class DelayedPatternProvider(CodebooksPatternProvider):
+    """Provider for delayed pattern across delayed codebooks.
+    Codebooks are delayed in the sequence and sequence steps will contain codebooks
+    from different timesteps.
+
+    Example:
+        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
+        [[1, 2, 3, 4],
+        [1, 2, 3, 4],
+        [1, 2, 3, 4]]
+        The resulting sequence obtained from the returned pattern is:
+        [[S, 1, 2, 3, 4],
+        [S, S, 1, 2, 3],
+        [S, S, S, 1, 2]]
+        (with S being a special token)
+
+    Args:
+        n_q (int): Number of codebooks.
+        delays (list of int, optional): Delay for each of the codebooks.
+            If delays not defined, each codebook is delayed by 1 compared to the previous one.
+        flatten_first (int): Flatten the first N timesteps.
+        empty_initial (int): Prepend with N empty list of coordinates.
+    """
+    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
+                 flatten_first: int = 0, empty_initial: int = 0):
+        super().__init__(n_q)
+        if delays is None:
+            delays = list(range(n_q))
+        self.delays = delays
+        self.flatten_first = flatten_first
+        self.empty_initial = empty_initial
+        assert len(self.delays) == self.n_q
+        assert sorted(self.delays) == self.delays
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        max_delay = max(self.delays)
+        if self.empty_initial:
+            out += [[] for _ in range(self.empty_initial)]
+        if self.flatten_first:
+            for t in range(min(timesteps, self.flatten_first)):
+                for q in range(self.n_q):
+                    out.append([LayoutCoord(t, q)])
+        for t in range(self.flatten_first, timesteps + max_delay):
+            v = []
+            for q, delay in enumerate(self.delays):
+                t_for_q = t - delay
+                if t_for_q >= self.flatten_first:
+                    v.append(LayoutCoord(t_for_q, q))
+            out.append(v)
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+

Ancestors

+ +

Subclasses

+ +

Inherited members

+ +
+
+class LayoutCoord +(t, q) +
+
+

LayoutCoord(t, q)

+

Ancestors

+
    +
  • builtins.tuple
  • +
+

Instance variables

+
+
var q
+
+

Alias for field number 1

+
+
var t
+
+

Alias for field number 0

+
+
+
+
+class MusicLMPattern +(n_q: int, group_by: int = 2) +
+
+

Almost MusicLM style pattern. This is equivalent to full flattening +but in a different order.

+

Args

+
+
n_q : int
+
Number of codebooks.
+
group_by : int
+
Number of codebooks to group together.
+
+
+ +Expand source code + +
class MusicLMPattern(CodebooksPatternProvider):
+    """Almost MusicLM style pattern. This is equivalent to full flattening
+    but in a different order.
+
+    Args:
+        n_q (int): Number of codebooks.
+        group_by (int): Number of codebooks to group together.
+    """
+    def __init__(self, n_q: int, group_by: int = 2):
+        super().__init__(n_q)
+        self.group_by = group_by
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for offset in range(0, self.n_q, self.group_by):
+            for t in range(timesteps):
+                for q in range(offset, offset + self.group_by):
+                    out.append([LayoutCoord(t, q)])
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+

Ancestors

+ +

Inherited members

+ +
+
+class ParallelPatternProvider +(n_q: int) +
+
+

Provider for parallel pattern across codebooks. +This pattern provider is a special case of the delayed pattern with actually no delay, +hence delays=repeat(0, n_q).

+

Args

+
+
n_q : int
+
Number of codebooks.
+
+
+ +Expand source code + +
class ParallelPatternProvider(DelayedPatternProvider):
+    """Provider for parallel pattern across codebooks.
+    This pattern provider is a special case of the delayed pattern with actually no delay,
+    hence delays=repeat(0, n_q).
+
+    Args:
+        n_q (int): Number of codebooks.
+    """
+    def __init__(self, n_q: int):
+        super().__init__(n_q, [0] * n_q)
+
+

Ancestors

+ +

Inherited members

+ +
+
+class Pattern +(layout: List[List[LayoutCoord]], timesteps: int, n_q: int) +
+
+

Base implementation of a pattern over a sequence with multiple codebooks.

+

The codebook pattern consists in a layout, defining for each sequence step +the list of coordinates of each codebook timestep in the resulting interleaved sequence. +The first item of the pattern is always an empty list in order to properly insert a special token +to start with. For convenience, we also keep track of n_q the number of codebooks used for the pattern +and timesteps the number of timesteps corresponding to the original sequence.

+

The pattern provides convenient methods to build and revert interleaved sequences from it: +build_pattern_sequence maps a given a dense input tensor of multi-codebook sequence from [B, K, T] +to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size, +K being the number of codebooks, T the number of original timesteps and S the number of sequence steps +for the output sequence. The unfilled positions are replaced with a special token and the built sequence +is returned along with a mask indicating valid tokens. +revert_pattern_sequence maps back an interleaved sequence of shape [B, K, S] to the original alignment +of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask +to fill and specify invalid positions if needed. +See the dedicated methods for more details.

+
+ +Expand source code + +
class Pattern:
+    """Base implementation of a pattern over a sequence with multiple codebooks.
+
+    The codebook pattern consists in a layout, defining for each sequence step
+    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
+    The first item of the pattern is always an empty list in order to properly insert a special token
+    to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
+    and ``timesteps`` the number of timesteps corresponding to the original sequence.
+
+    The pattern provides convenient methods to build and revert interleaved sequences from it:
+    ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
+        to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
+        K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
+        for the output sequence. The unfilled positions are replaced with a special token and the built sequence
+        is returned along with a mask indicating valid tokens.
+    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
+        of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
+        to fill and specify invalid positions if needed.
+    See the dedicated methods for more details.
+    """
+    # Pattern layout, for each sequence step, we have a list of coordinates
+    # corresponding to the original codebook timestep and position.
+    # The first list is always an empty list in order to properly insert
+    # a special token to start with.
+    layout: PatternLayout
+    timesteps: int
+    n_q: int
+
+    def __post_init__(self):
+        assert len(self.layout) > 0
+        assert self.layout[0] == []
+        self._validate_layout()
+        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
+        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
+        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
+
+    def _validate_layout(self):
+        """Runs checks on the layout to ensure a valid pattern is defined.
+        A pattern is considered invalid if:
+            - Multiple timesteps for a same codebook are defined in the same sequence step
+            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
+              (this would mean that we have future timesteps before past timesteps).
+        """
+        q_timesteps = {q: 0 for q in range(self.n_q)}
+        for s, seq_coords in enumerate(self.layout):
+            if len(seq_coords) > 0:
+                qs = set()
+                for coord in seq_coords:
+                    qs.add(coord.q)
+                    last_q_timestep = q_timesteps[coord.q]
+                    assert coord.t >= last_q_timestep, \
+                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
+                    q_timesteps[coord.q] = coord.t
+                # each sequence step contains at max 1 coordinate per codebook
+                assert len(qs) == len(seq_coords), \
+                    f"Multiple entries for a same codebook are found at step {s}"
+
+    @property
+    def num_sequence_steps(self):
+        return len(self.layout) - 1
+
+    @property
+    def max_delay(self):
+        max_t_in_seq_coords = 0
+        for seq_coords in self.layout[1:]:
+            for coords in seq_coords:
+                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
+        return max_t_in_seq_coords - self.timesteps
+
+    @property
+    def valid_layout(self):
+        valid_step = len(self.layout) - self.max_delay
+        return self.layout[:valid_step]
+
+    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
+        """Get codebook coordinates in the layout that corresponds to the specified timestep t
+        and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
+        and the actual codebook coordinates.
+        """
+        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
+        if q is not None:
+            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
+        coords = []
+        for s, seq_codes in enumerate(self.layout):
+            for code in seq_codes:
+                if code.t == t and (q is None or code.q == q):
+                    coords.append((s, code))
+        return coords
+
+    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
+        return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
+
+    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
+        steps_with_timesteps = self.get_steps_with_timestep(t, q)
+        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
+
+    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
+                                                device: tp.Union[torch.device, str] = 'cpu'):
+        """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
+
+        Args:
+            timesteps (int): Maximum number of timesteps steps to consider.
+            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
+            device (torch.device or str): Device for created tensors.
+        Returns:
+            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
+        """
+        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
+        # use the proper layout based on whether we limit ourselves to valid steps only or not,
+        # note that using the valid_layout will result in a truncated sequence up to the valid steps
+        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
+        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
+        # fill indexes with last sequence step value that will correspond to our special token
+        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
+        # which will correspond to the index: n_q * timesteps
+        indexes[:] = n_q * timesteps
+        # iterate over the pattern and fill scattered indexes and mask
+        for s, sequence_coords in enumerate(ref_layout):
+            for coords in sequence_coords:
+                if coords.t < timesteps:
+                    indexes[coords.q, s] = coords.t + coords.q * timesteps
+                    mask[coords.q, s] = 1
+        indexes = torch.from_numpy(indexes).to(device)
+        mask = torch.from_numpy(mask).to(device)
+        return indexes, mask
+
+    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+        """Build sequence corresponding to the pattern from the input tensor z.
+        The sequence is built using up to sequence_steps if specified, and non-pattern
+        coordinates are filled with the special token.
+
+        Args:
+            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
+            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
+            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+                Steps that are beyond valid steps will be replaced by the special_token in that case.
+        Returns:
+            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
+                corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
+            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
+        """
+        B, K, T = z.shape
+        indexes, mask = self._build_pattern_sequence_scatter_indexes(
+            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
+        )
+        z = z.view(B, -1)
+        # we append the special token as the last index of our flattened z tensor
+        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
+        values = z[:, indexes.view(-1)]
+        values = values.view(B, K, indexes.shape[-1])
+        return values, indexes, mask
+
+    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
+                                                 keep_only_valid_steps: bool = False,
+                                                 is_model_output: bool = False,
+                                                 device: tp.Union[torch.device, str] = 'cpu'):
+        """Builds scatter indexes required to retrieve the original multi-codebook sequence
+        from interleaving pattern.
+
+        Args:
+            sequence_steps (int): Sequence steps.
+            n_q (int): Number of codebooks.
+            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+                Steps that are beyond valid steps will be replaced by the special_token in that case.
+            is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
+            device (torch.device or str): Device for created tensors.
+        Returns:
+            indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+        """
+        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
+        timesteps = self.timesteps
+        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+        assert sequence_steps <= len(ref_layout), \
+            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
+
+        # ensure we take the appropriate indexes to keep the model output from the first special token as well
+        if is_model_output:
+            ref_layout = ref_layout[1:]
+
+        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
+        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
+        # fill indexes with last sequence step value that will correspond to our special token
+        indexes[:] = n_q * sequence_steps
+        for s, sequence_codes in enumerate(ref_layout):
+            if s < sequence_steps:
+                for code in sequence_codes:
+                    if code.t < timesteps:
+                        indexes[code.q, code.t] = s + code.q * sequence_steps
+                        mask[code.q, code.t] = 1
+        indexes = torch.from_numpy(indexes).to(device)
+        mask = torch.from_numpy(mask).to(device)
+        return indexes, mask
+
+    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
+        The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
+        are filled with the special token.
+
+        Args:
+            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
+        Returns:
+            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
+                corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
+            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+        """
+        B, K, S = s.shape
+        indexes, mask = self._build_reverted_sequence_scatter_indexes(
+            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
+        )
+        s = s.view(B, -1)
+        # we append the special token as the last index of our flattened z tensor
+        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
+        values = s[:, indexes.view(-1)]
+        values = values.view(B, K, indexes.shape[-1])
+        return values, indexes, mask
+
+    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
+        """Revert model logits obtained on a sequence built from the pattern
+        back to a tensor matching the original sequence.
+
+        This method is similar to ``revert_pattern_sequence`` with the following specificities:
+        1. It is designed to work with the extra cardinality dimension
+        2. We return the logits for the first sequence item that matches the special_token and
+        which matching target in the original sequence is the first item of the sequence,
+        while we skip the last logits as there is no matching target
+        """
+        B, card, K, S = logits.shape
+        indexes, mask = self._build_reverted_sequence_scatter_indexes(
+            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
+        )
+        logits = logits.reshape(B, card, -1)
+        # we append the special token as the last index of our flattened z tensor
+        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
+        values = logits[:, :, indexes.view(-1)]
+        values = values.view(B, card, K, indexes.shape[-1])
+        return values, indexes, mask
+
+

Class variables

+
+
var layout : List[List[LayoutCoord]]
+
+
+
+
var n_q : int
+
+
+
+
var timesteps : int
+
+
+
+
+

Instance variables

+
+
var max_delay
+
+
+
+ +Expand source code + +
@property
+def max_delay(self):
+    max_t_in_seq_coords = 0
+    for seq_coords in self.layout[1:]:
+        for coords in seq_coords:
+            max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
+    return max_t_in_seq_coords - self.timesteps
+
+
+
var num_sequence_steps
+
+
+
+ +Expand source code + +
@property
+def num_sequence_steps(self):
+    return len(self.layout) - 1
+
+
+
var valid_layout
+
+
+
+ +Expand source code + +
@property
+def valid_layout(self):
+    valid_step = len(self.layout) - self.max_delay
+    return self.layout[:valid_step]
+
+
+
+

Methods

+
+
+def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False) +
+
+

Build sequence corresponding to the pattern from the input tensor z. +The sequence is built using up to sequence_steps if specified, and non-pattern +coordinates are filled with the special token.

+

Args

+
+
z : torch.Tensor
+
Input tensor of multi-codebooks sequence, of shape [B, K, T].
+
special_token : int
+
Special token used to fill non-pattern coordinates in the new sequence.
+
keep_only_valid_steps : bool
+
Build a sequence from the pattern up to valid (= fully defined) steps. +Steps that are beyond valid steps will be replaced by the special_token in that case.
+
+

Returns

+

values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S +corresponding either to the sequence_steps if provided, otherwise to the length of the pattern. +indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S]. +mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].

+
+ +Expand source code + +
def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+    """Build sequence corresponding to the pattern from the input tensor z.
+    The sequence is built using up to sequence_steps if specified, and non-pattern
+    coordinates are filled with the special token.
+
+    Args:
+        z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
+        special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
+        keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+            Steps that are beyond valid steps will be replaced by the special_token in that case.
+    Returns:
+        values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
+            corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
+        indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
+        mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
+    """
+    B, K, T = z.shape
+    indexes, mask = self._build_pattern_sequence_scatter_indexes(
+        T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
+    )
+    z = z.view(B, -1)
+    # we append the special token as the last index of our flattened z tensor
+    z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
+    values = z[:, indexes.view(-1)]
+    values = values.view(B, K, indexes.shape[-1])
+    return values, indexes, mask
+
+
+
+def get_first_step_with_timesteps(self, t: int, q: Optional[int] = None) ‑> Optional[int] +
+
+
+
+ +Expand source code + +
def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
+    steps_with_timesteps = self.get_steps_with_timestep(t, q)
+    return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
+
+
+
+def get_sequence_coords_with_timestep(self, t: int, q: Optional[int] = None) +
+
+

Get codebook coordinates in the layout that corresponds to the specified timestep t +and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step +and the actual codebook coordinates.

+
+ +Expand source code + +
def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
+    """Get codebook coordinates in the layout that corresponds to the specified timestep t
+    and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
+    and the actual codebook coordinates.
+    """
+    assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
+    if q is not None:
+        assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
+    coords = []
+    for s, seq_codes in enumerate(self.layout):
+        for code in seq_codes:
+            if code.t == t and (q is None or code.q == q):
+                coords.append((s, code))
+    return coords
+
+
+
+def get_steps_with_timestep(self, t: int, q: Optional[int] = None) ‑> List[int] +
+
+
+
+ +Expand source code + +
def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
+    return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
+
+
+
+def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False) +
+
+

Revert model logits obtained on a sequence built from the pattern +back to a tensor matching the original sequence.

+

This method is similar to revert_pattern_sequence with the following specificities: +1. It is designed to work with the extra cardinality dimension +2. We return the logits for the first sequence item that matches the special_token and +which matching target in the original sequence is the first item of the sequence, +while we skip the last logits as there is no matching target

+
+ +Expand source code + +
def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
+    """Revert model logits obtained on a sequence built from the pattern
+    back to a tensor matching the original sequence.
+
+    This method is similar to ``revert_pattern_sequence`` with the following specificities:
+    1. It is designed to work with the extra cardinality dimension
+    2. We return the logits for the first sequence item that matches the special_token and
+    which matching target in the original sequence is the first item of the sequence,
+    while we skip the last logits as there is no matching target
+    """
+    B, card, K, S = logits.shape
+    indexes, mask = self._build_reverted_sequence_scatter_indexes(
+        S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
+    )
+    logits = logits.reshape(B, card, -1)
+    # we append the special token as the last index of our flattened z tensor
+    logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
+    values = logits[:, :, indexes.view(-1)]
+    values = values.view(B, card, K, indexes.shape[-1])
+    return values, indexes, mask
+
+
+
+def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False) +
+
+

Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving. +The sequence is reverted using up to timesteps if specified, and non-pattern coordinates +are filled with the special token.

+

Args

+
+
s : torch.Tensor
+
Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+
special_token : int or float
+
Special token used to fill non-pattern coordinates in the new sequence.
+
+

Returns

+

values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T +corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise. +indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T]. +mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].

+
+ +Expand source code + +
def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+    """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
+    The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
+    are filled with the special token.
+
+    Args:
+        s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+        special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
+    Returns:
+        values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
+            corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
+        indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
+        mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+    """
+    B, K, S = s.shape
+    indexes, mask = self._build_reverted_sequence_scatter_indexes(
+        S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
+    )
+    s = s.view(B, -1)
+    # we append the special token as the last index of our flattened z tensor
+    s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
+    values = s[:, indexes.view(-1)]
+    values = values.view(B, K, indexes.shape[-1])
+    return values, indexes, mask
+
+
+
+
+
+class UnrolledPatternProvider +(n_q: int, flattening: Optional[List[int]] = None, delays: Optional[List[int]] = None) +
+
+

Provider for unrolling codebooks pattern. +This pattern provider enables to represent the codebook flattened completely or only to some extend +while also specifying a given delay between the flattened codebooks representation, allowing to +unroll the codebooks in the sequence.

+

Example

+
    +
  1. Flattening of the codebooks. +By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q), +taking n_q = 3 and timesteps = 4: +[[1, 2, 3, 4], +[1, 2, 3, 4], +[1, 2, 3, 4]] +will result into: +[[S, S, 1, S, S, 2, S, S, 3, S, S, 4], +[S, 1, S, S, 2, S, S, 3, S, S, 4, S], +[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
  2. +
  3. Partial flattening of the codebooks. The flattening parameter allows to specify the inner step +for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example +taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]: +[[1, 2, 3, 4], +[1, 2, 3, 4], +[1, 2, 3, 4]] +will result into: +[[S, 1, S, S, 2, S, S, 3, S, S, 4, S], +[S, 1, S, S, 2, S, S, 3, S, S, 4, S], +[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
  4. +
  5. Flattening with delay. The delay parameter allows to further unroll the sequence of codebooks +allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the +same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1] +and delays = [0, 3, 3]: +[[1, 2, 3, 4], +[1, 2, 3, 4], +[1, 2, 3, 4]] +will result into: +[[S, S, S, 1, S, 2, S, 3, S, 4], +[S, S, S, 1, S, 2, S, 3, S, 4], +[1, 2, 3, S, 4, S, 5, S, 6, S]]
  6. +
+

Args

+
+
n_q : int
+
Number of codebooks.
+
flattening : list of int, optional
+
Flattening schema over the codebooks. If not defined, +the codebooks will be flattened to 1 codebook per step, meaning that the sequence will +have n_q extra steps for each timestep.
+
delays : list of int, optional
+
Delay for each of the codebooks. If not defined, +no delay is added and therefore will default to [0] * n_q. +Note that two codebooks that will be flattened to the same inner step +should have the same delay, otherwise the pattern is considered as invalid.
+
+
+ +Expand source code + +
class UnrolledPatternProvider(CodebooksPatternProvider):
+    """Provider for unrolling codebooks pattern.
+    This pattern provider enables to represent the codebook flattened completely or only to some extend
+    while also specifying a given delay between the flattened codebooks representation, allowing to
+    unroll the codebooks in the sequence.
+
+    Example:
+        1. Flattening of the codebooks.
+        By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
+        taking n_q = 3 and timesteps = 4:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
+         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+        2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
+        for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
+        taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+        3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
+        allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
+        same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
+        and delays = [0, 3, 3]:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, S, S, 1, S, 2, S, 3, S, 4],
+         [S, S, S, 1, S, 2, S, 3, S, 4],
+         [1, 2, 3, S, 4, S, 5, S, 6, S]]
+
+    Args:
+        n_q (int): Number of codebooks.
+        flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
+            the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
+            have n_q extra steps for each timestep.
+        delays (list of int, optional): Delay for each of the codebooks. If not defined,
+            no delay is added and therefore will default to [0] * ``n_q``.
+            Note that two codebooks that will be flattened to the same inner step
+            should have the same delay, otherwise the pattern is considered as invalid.
+    """
+    FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
+
+    def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
+                 delays: tp.Optional[tp.List[int]] = None):
+        super().__init__(n_q)
+        if flattening is None:
+            flattening = list(range(n_q))
+        if delays is None:
+            delays = [0] * n_q
+        assert len(flattening) == n_q
+        assert len(delays) == n_q
+        assert sorted(flattening) == flattening
+        assert sorted(delays) == delays
+        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
+        self.max_delay = max(delays)
+
+    def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
+        """Build a flattened codebooks representation as a dictionary of inner step
+        and the actual codebook indices corresponding to the flattened codebook. For convenience, we
+        also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
+        """
+        flattened_codebooks: dict = {}
+        for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
+            if inner_step not in flattened_codebooks:
+                flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
+            else:
+                flat_codebook = flattened_codebooks[inner_step]
+                assert flat_codebook.delay == delay, (
+                    "Delay and flattening between codebooks is inconsistent: ",
+                    "two codebooks flattened to the same position should have the same delay."
+                )
+                flat_codebook.codebooks.append(q)
+            flattened_codebooks[inner_step] = flat_codebook
+        return flattened_codebooks
+
+    @property
+    def _num_inner_steps(self):
+        """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
+        """
+        return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
+
+    def num_virtual_steps(self, timesteps: int) -> int:
+        return timesteps * self._num_inner_steps + 1
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        """Builds pattern for delay across codebooks.
+
+        Args:
+            timesteps (int): Total number of timesteps.
+        """
+        # the PatternLayout is built as a tuple of sequence position and list of coordinates
+        # so that it can be reordered properly given the required delay between codebooks of given timesteps
+        indexed_out: list = [(-1, [])]
+        max_timesteps = timesteps + self.max_delay
+        for t in range(max_timesteps):
+            # for each timestep, we unroll the flattened codebooks,
+            # emitting the sequence step with the corresponding delay
+            for step in range(self._num_inner_steps):
+                if step in self._flattened_codebooks:
+                    # we have codebooks at this virtual step to emit
+                    step_codebooks = self._flattened_codebooks[step]
+                    t_for_q = t + step_codebooks.delay
+                    coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
+                    if t_for_q < max_timesteps and t < max_timesteps:
+                        indexed_out.append((t_for_q, coords))
+                else:
+                    # there is no codebook in this virtual step so we emit an empty list
+                    indexed_out.append((t, []))
+        out = [coords for _, coords in sorted(indexed_out)]
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+

Ancestors

+ +

Class variables

+
+
var FlattenedCodebook
+
+
+
+
+

Methods

+
+
+def get_pattern(self, timesteps: int) ‑> Pattern +
+
+

Builds pattern for delay across codebooks.

+

Args

+
+
timesteps : int
+
Total number of timesteps.
+
+
+ +Expand source code + +
def get_pattern(self, timesteps: int) -> Pattern:
+    """Builds pattern for delay across codebooks.
+
+    Args:
+        timesteps (int): Total number of timesteps.
+    """
+    # the PatternLayout is built as a tuple of sequence position and list of coordinates
+    # so that it can be reordered properly given the required delay between codebooks of given timesteps
+    indexed_out: list = [(-1, [])]
+    max_timesteps = timesteps + self.max_delay
+    for t in range(max_timesteps):
+        # for each timestep, we unroll the flattened codebooks,
+        # emitting the sequence step with the corresponding delay
+        for step in range(self._num_inner_steps):
+            if step in self._flattened_codebooks:
+                # we have codebooks at this virtual step to emit
+                step_codebooks = self._flattened_codebooks[step]
+                t_for_q = t + step_codebooks.delay
+                coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
+                if t_for_q < max_timesteps and t < max_timesteps:
+                    indexed_out.append((t_for_q, coords))
+            else:
+                # there is no codebook in this virtual step so we emit an empty list
+                indexed_out.append((t, []))
+    out = [coords for _, coords in sorted(indexed_out)]
+    return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+
+def num_virtual_steps(self, timesteps: int) ‑> int +
+
+
+
+ +Expand source code + +
def num_virtual_steps(self, timesteps: int) -> int:
+    return timesteps * self._num_inner_steps + 1
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/conditioners.html b/api_docs/audiocraft/modules/conditioners.html new file mode 100644 index 00000000..c044da08 --- /dev/null +++ b/api_docs/audiocraft/modules/conditioners.html @@ -0,0 +1,4662 @@ + + + + + + +audiocraft.modules.conditioners API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.conditioners

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+from copy import deepcopy
+from dataclasses import dataclass, field
+from itertools import chain
+import logging
+import math
+from pathlib import Path
+import random
+import re
+import typing as tp
+import warnings
+
+import einops
+from num2words import num2words
+import spacy
+from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer  # type: ignore
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+
+from .chroma import ChromaExtractor
+from .streaming import StreamingModule
+from .transformer import create_sin_embedding
+from ..data.audio import audio_read
+from ..data.audio_dataset import SegmentInfo
+from ..data.audio_utils import convert_audio
+from ..environment import AudioCraftEnvironment
+from ..quantization import ResidualVectorQuantizer
+from ..utils.autocast import TorchAutocast
+from ..utils.cache import EmbeddingCache
+from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
+
+
+logger = logging.getLogger(__name__)
+TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
+ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
+
+
+class WavCondition(tp.NamedTuple):
+    wav: torch.Tensor
+    length: torch.Tensor
+    sample_rate: tp.List[int]
+    path: tp.List[tp.Optional[str]] = []
+    seek_time: tp.List[tp.Optional[float]] = []
+
+
+class JointEmbedCondition(tp.NamedTuple):
+    wav: torch.Tensor
+    text: tp.List[tp.Optional[str]]
+    length: torch.Tensor
+    sample_rate: tp.List[int]
+    path: tp.List[tp.Optional[str]] = []
+    seek_time: tp.List[tp.Optional[float]] = []
+
+
+@dataclass
+class ConditioningAttributes:
+    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
+    wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
+    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    @property
+    def text_attributes(self):
+        return self.text.keys()
+
+    @property
+    def wav_attributes(self):
+        return self.wav.keys()
+
+    @property
+    def joint_embed_attributes(self):
+        return self.joint_embed.keys()
+
+    @property
+    def attributes(self):
+        return {
+            "text": self.text_attributes,
+            "wav": self.wav_attributes,
+            "joint_embed": self.joint_embed_attributes,
+        }
+
+    def to_flat_dict(self):
+        return {
+            **{f"text.{k}": v for k, v in self.text.items()},
+            **{f"wav.{k}": v for k, v in self.wav.items()},
+            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
+        }
+
+    @classmethod
+    def from_flat_dict(cls, x):
+        out = cls()
+        for k, v in x.items():
+            kind, att = k.split(".")
+            out[kind][att] = v
+        return out
+
+
+class SegmentWithAttributes(SegmentInfo):
+    """Base class for all dataclasses that are used for conditioning.
+    All child classes should implement `to_condition_attributes` that converts
+    the existing attributes to a dataclass of type ConditioningAttributes.
+    """
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        raise NotImplementedError()
+
+
+def nullify_condition(condition: ConditionType, dim: int = 1):
+    """Transform an input condition to a null condition.
+    The way it is done by converting it to a single zero vector similarly
+    to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
+
+    Args:
+        condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
+        dim (int): The dimension that will be truncated (should be the time dimension)
+        WARNING!: dim should not be the batch dimension!
+    Returns:
+        ConditionType: A tuple of null condition and mask
+    """
+    assert dim != 0, "dim cannot be the batch dimension!"
+    assert isinstance(condition, tuple) and \
+        isinstance(condition[0], torch.Tensor) and \
+        isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
+    cond, mask = condition
+    B = cond.shape[0]
+    last_dim = cond.dim() - 1
+    out = cond.transpose(dim, last_dim)
+    out = 0. * out[..., :1]
+    out = out.transpose(dim, last_dim)
+    mask = torch.zeros((B, 1), device=out.device).int()
+    assert cond.dim() == out.dim()
+    return out, mask
+
+
+def nullify_wav(cond: WavCondition) -> WavCondition:
+    """Transform a WavCondition to a nullified WavCondition.
+    It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
+
+    Args:
+        cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
+    Returns:
+        WavCondition: Nullified wav condition.
+    """
+    null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
+    return WavCondition(
+        wav=null_wav,
+        length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
+        sample_rate=cond.sample_rate,
+        path=[None] * cond.wav.shape[0],
+        seek_time=[None] * cond.wav.shape[0],
+    )
+
+
+def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
+    """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
+    and replacing metadata by dummy attributes.
+
+    Args:
+        cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
+    """
+    null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
+    return JointEmbedCondition(
+        wav=null_wav, text=[None] * len(embed.text),
+        length=torch.LongTensor([0]).to(embed.wav.device),
+        sample_rate=embed.sample_rate,
+        path=[None] * embed.wav.shape[0],
+        seek_time=[0] * embed.wav.shape[0],
+    )
+
+
+class Tokenizer:
+    """Base tokenizer implementation
+    (in case we want to introduce more advances tokenizers in the future).
+    """
+    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+
+class WhiteSpaceTokenizer(Tokenizer):
+    """This tokenizer should be used for natural language descriptions.
+    For example:
+    ["he didn't, know he's going home.", 'shorter sentence'] =>
+    [[78, 62, 31,  4, 78, 25, 19, 34],
+    [59, 77,  0,  0,  0,  0,  0,  0]]
+    """
+    PUNCTUATION = "?:!.,;"
+
+    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
+                 lemma: bool = True, stopwords: bool = True) -> None:
+        self.n_bins = n_bins
+        self.pad_idx = pad_idx
+        self.lemma = lemma
+        self.stopwords = stopwords
+        try:
+            self.nlp = spacy.load(language)
+        except IOError:
+            spacy.cli.download(language)  # type: ignore
+            self.nlp = spacy.load(language)
+
+    @tp.no_type_check
+    def __call__(self, texts: tp.List[tp.Optional[str]],
+                 return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Take a list of strings and convert them to a tensor of indices.
+
+        Args:
+            texts (list[str]): List of strings.
+            return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]:
+                - Indices of words in the LUT.
+                - And a mask indicating where the padding tokens are
+        """
+        output, lengths = [], []
+        texts = deepcopy(texts)
+        for i, text in enumerate(texts):
+            # if current sample doesn't have a certain attribute, replace with pad token
+            if text is None:
+                output.append(torch.Tensor([self.pad_idx]))
+                lengths.append(0)
+                continue
+
+            # convert numbers to words
+            text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text)  # type: ignore
+            # normalize text
+            text = self.nlp(text)  # type: ignore
+            # remove stopwords
+            if self.stopwords:
+                text = [w for w in text if not w.is_stop]  # type: ignore
+            # remove punctuation
+            text = [w for w in text if w.text not in self.PUNCTUATION]  # type: ignore
+            # lemmatize if needed
+            text = [getattr(t, "lemma_" if self.lemma else "text") for t in text]  # type: ignore
+
+            texts[i] = " ".join(text)
+            lengths.append(len(text))
+            # convert to tensor
+            tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
+            output.append(tokens)
+
+        mask = length_to_mask(torch.IntTensor(lengths)).int()
+        padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
+        if return_text:
+            return padded_output, mask, texts  # type: ignore
+        return padded_output, mask
+
+
+class NoopTokenizer(Tokenizer):
+    """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
+    The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
+    strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
+    split it to ["Jeff", "Buckley"] and return an index per word.
+
+    For example:
+    ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
+    ["Metal", "Rock", "Classical"] => [0, 223, 51]
+    """
+    def __init__(self, n_bins: int, pad_idx: int = 0):
+        self.n_bins = n_bins
+        self.pad_idx = pad_idx
+
+    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        output, lengths = [], []
+        for text in texts:
+            # if current sample doesn't have a certain attribute, replace with pad token
+            if text is None:
+                output.append(self.pad_idx)
+                lengths.append(0)
+            else:
+                output.append(hash_trick(text, self.n_bins))
+                lengths.append(1)
+
+        tokens = torch.LongTensor(output).unsqueeze(1)
+        mask = length_to_mask(torch.IntTensor(lengths)).int()
+        return tokens, mask
+
+
+class BaseConditioner(nn.Module):
+    """Base model for all conditioner modules.
+    We allow the output dim to be different than the hidden dim for two reasons:
+    1) keep our LUTs small when the vocab is large;
+    2) make all condition dims consistent.
+
+    Args:
+        dim (int): Hidden dim of the model.
+        output_dim (int): Output dim of the conditioner.
+    """
+    def __init__(self, dim: int, output_dim: int):
+        super().__init__()
+        self.dim = dim
+        self.output_dim = output_dim
+        self.output_proj = nn.Linear(dim, output_dim)
+
+    def tokenize(self, *args, **kwargs) -> tp.Any:
+        """Should be any part of the processing that will lead to a synchronization
+        point, e.g. BPE tokenization with transfer to the GPU.
+
+        The returned value will be saved and return later when calling forward().
+        """
+        raise NotImplementedError()
+
+    def forward(self, inputs: tp.Any) -> ConditionType:
+        """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
+        Outputs a ConditionType, after the input data was embedded as a dense vector.
+
+        Returns:
+            ConditionType:
+                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
+                  output embedding and D is the dimension of the embedding.
+                - And a mask indicating where the padding tokens.
+        """
+        raise NotImplementedError()
+
+
+class TextConditioner(BaseConditioner):
+    ...
+
+
+class LUTConditioner(TextConditioner):
+    """Lookup table TextConditioner.
+
+    Args:
+        n_bins (int): Number of bins.
+        dim (int): Hidden dim of the model (text-encoder/LUT).
+        output_dim (int): Output dim of the conditioner.
+        tokenizer (str): Name of the tokenizer.
+        pad_idx (int, optional): Index for padding token. Defaults to 0.
+    """
+    def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
+        super().__init__(dim, output_dim)
+        self.embed = nn.Embedding(n_bins, dim)
+        self.tokenizer: Tokenizer
+        if tokenizer == 'whitespace':
+            self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
+        elif tokenizer == 'noop':
+            self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
+        else:
+            raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
+
+    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        device = self.embed.weight.device
+        tokens, mask = self.tokenizer(x)
+        tokens, mask = tokens.to(device), mask.to(device)
+        return tokens, mask
+
+    def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
+        tokens, mask = inputs
+        embeds = self.embed(tokens)
+        embeds = self.output_proj(embeds)
+        embeds = (embeds * mask.unsqueeze(-1))
+        return embeds, mask
+
+
+class T5Conditioner(TextConditioner):
+    """T5-based TextConditioner.
+
+    Args:
+        name (str): Name of the T5 model.
+        output_dim (int): Output dim of the conditioner.
+        finetune (bool): Whether to fine-tune T5 at train time.
+        device (str): Device for T5 Conditioner.
+        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
+        word_dropout (float, optional): Word dropout probability.
+        normalize_text (bool, optional): Whether to apply text normalization.
+    """
+    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
+              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
+              "google/flan-t5-xl", "google/flan-t5-xxl"]
+    MODELS_DIMS = {
+        "t5-small": 512,
+        "t5-base": 768,
+        "t5-large": 1024,
+        "t5-3b": 1024,
+        "t5-11b": 1024,
+        "google/flan-t5-small": 512,
+        "google/flan-t5-base": 768,
+        "google/flan-t5-large": 1024,
+        "google/flan-t5-3b": 1024,
+        "google/flan-t5-11b": 1024,
+    }
+
+    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
+                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
+                 normalize_text: bool = False):
+        assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
+        super().__init__(self.MODELS_DIMS[name], output_dim)
+        self.device = device
+        self.name = name
+        self.finetune = finetune
+        self.word_dropout = word_dropout
+        if autocast_dtype is None or self.device == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+            if self.device != 'cpu':
+                logger.warning("T5 has no autocast, this might lead to NaN")
+        else:
+            dtype = getattr(torch, autocast_dtype)
+            assert isinstance(dtype, torch.dtype)
+            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
+            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+        # Let's disable logging temporarily because T5 will vomit some errors otherwise.
+        # thanks https://gist.github.com/simon-weber/7853144
+        previous_level = logging.root.manager.disable
+        logging.disable(logging.ERROR)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            try:
+                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
+                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
+            finally:
+                logging.disable(previous_level)
+        if finetune:
+            self.t5 = t5
+        else:
+            # this makes sure that the t5 models is not part
+            # of the saved checkpoint
+            self.__dict__['t5'] = t5.to(device)
+
+        self.normalize_text = normalize_text
+        if normalize_text:
+            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
+
+    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
+        # if current sample doesn't have a certain attribute, replace with empty string
+        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
+        if self.normalize_text:
+            _, _, entries = self.text_normalizer(entries, return_text=True)
+        if self.word_dropout > 0. and self.training:
+            new_entries = []
+            for entry in entries:
+                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
+                new_entries.append(" ".join(words))
+            entries = new_entries
+
+        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
+
+        inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
+        mask = inputs['attention_mask']
+        mask[empty_idx, :] = 0  # zero-out index where the input is non-existant
+        return inputs
+
+    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
+        mask = inputs['attention_mask']
+        with torch.set_grad_enabled(self.finetune), self.autocast:
+            embeds = self.t5(**inputs).last_hidden_state
+        embeds = self.output_proj(embeds.to(self.output_proj.weight))
+        embeds = (embeds * mask.unsqueeze(-1))
+        return embeds, mask
+
+
+class WaveformConditioner(BaseConditioner):
+    """Base class for all conditioners that take a waveform as input.
+    Classes that inherit must implement `_get_wav_embedding` that outputs
+    a continuous tensor, and `_downsampling_factor` that returns the down-sampling
+    factor of the embedding model.
+
+    Args:
+        dim (int): The internal representation dimension.
+        output_dim (int): Output dimension.
+        device (tp.Union[torch.device, str]): Device.
+    """
+    def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
+        super().__init__(dim, output_dim)
+        self.device = device
+
+    def tokenize(self, x: WavCondition) -> WavCondition:
+        wav, length, sample_rate, path, seek_time = x
+        assert length is not None
+        return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
+
+    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
+        """Gets as input a WavCondition and returns a dense embedding."""
+        raise NotImplementedError()
+
+    def _downsampling_factor(self):
+        """Returns the downsampling factor of the embedding model."""
+        raise NotImplementedError()
+
+    def forward(self, x: WavCondition) -> ConditionType:
+        """Extract condition embedding and mask from a waveform and its metadata.
+        Args:
+            x (WavCondition): Waveform condition containing raw waveform and metadata.
+        Returns:
+            ConditionType: a dense vector representing the conditioning along with its mask
+        """
+        wav, lengths, *_ = x
+        with torch.no_grad():
+            embeds = self._get_wav_embedding(x)
+        embeds = embeds.to(self.output_proj.weight)
+        embeds = self.output_proj(embeds)
+
+        if lengths is not None:
+            lengths = lengths / self._downsampling_factor()
+            mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
+        else:
+            mask = torch.ones_like(embeds)
+        embeds = (embeds * mask.unsqueeze(2).to(self.device))
+
+        return embeds, mask
+
+
+class ChromaStemConditioner(WaveformConditioner):
+    """Chroma conditioner based on stems.
+    The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
+    the drums and bass often dominate the chroma leading to the chroma features
+    not containing information about the melody.
+
+    Args:
+        output_dim (int): Output dimension for the conditioner.
+        sample_rate (int): Sample rate for the chroma extractor.
+        n_chroma (int): Number of chroma bins for the chroma extractor.
+        radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
+        duration (int): duration used during training. This is later used for correct padding
+            in case we are using chroma as prefix.
+        match_len_on_eval (bool, optional): if True then all chromas are padded to the training
+            duration. Defaults to False.
+        eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as
+            conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
+            Defaults to None.
+        n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
+        device (tp.Union[torch.device, str], optional): Device for the conditioner.
+        **kwargs: Additional parameters for the chroma extractor.
+    """
+    def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
+                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
+                 n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
+                 device: tp.Union[torch.device, str] = 'cpu', **kwargs):
+        from demucs import pretrained
+        super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
+        self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
+        self.sample_rate = sample_rate
+        self.match_len_on_eval = match_len_on_eval
+        self.duration = duration
+        self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
+        stem_sources: list = self.demucs.sources  # type: ignore
+        self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
+        self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
+                                      radix2_exp=radix2_exp, **kwargs).to(device)
+        self.chroma_len = self._get_chroma_len()
+        self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
+        self.cache = None
+        if cache_path is not None:
+            self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
+                                        compute_embed_fn=self._get_full_chroma_for_cache,
+                                        extract_embed_fn=self._extract_chroma_chunk)
+
+    def _downsampling_factor(self) -> int:
+        return self.chroma.winhop
+
+    def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
+        """Load pre-defined waveforms from a json.
+        These waveforms will be used for chroma extraction during evaluation.
+        This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
+        """
+        if path is None:
+            return None
+
+        logger.info(f"Loading evaluation wavs from {path}")
+        from audiocraft.data.audio_dataset import AudioDataset
+        dataset: AudioDataset = AudioDataset.from_meta(
+            path, segment_duration=self.duration, min_audio_duration=self.duration,
+            sample_rate=self.sample_rate, channels=1)
+
+        if len(dataset) > 0:
+            eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
+            logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
+            return eval_wavs
+        else:
+            raise ValueError("Could not find evaluation wavs, check lengths of wavs")
+
+    def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
+        self.eval_wavs = eval_wavs
+
+    def has_eval_wavs(self) -> bool:
+        return self.eval_wavs is not None
+
+    def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
+        """Sample wavs from a predefined list."""
+        assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
+        total_eval_wavs = len(self.eval_wavs)
+        out = self.eval_wavs
+        if num_samples > total_eval_wavs:
+            out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
+        return out[torch.randperm(len(out))][:num_samples]
+
+    def _get_chroma_len(self) -> int:
+        """Get length of chroma during training."""
+        dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
+        dummy_chr = self.chroma(dummy_wav)
+        return dummy_chr.shape[1]
+
+    @torch.no_grad()
+    def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        """Get parts of the wav that holds the melody, extracting the main stems from the wav."""
+        from demucs.apply import apply_model
+        from demucs.audio import convert_audio
+        with self.autocast:
+            wav = convert_audio(
+                wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels)  # type: ignore
+            stems = apply_model(self.demucs, wav, device=self.device)
+            stems = stems[:, self.stem_indices]  # extract relevant stems for melody conditioning
+            mix_wav = stems.sum(1)  # merge extracted stems to single waveform
+            mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1)  # type: ignore
+            return mix_wav
+
+    @torch.no_grad()
+    def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
+        """Extract chroma features from the waveform."""
+        with self.autocast:
+            return self.chroma(wav)
+
+    @torch.no_grad()
+    def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        """Compute wav embedding, applying stem and chroma extraction."""
+        # avoid 0-size tensors when we are working with null conds
+        if wav.shape[-1] == 1:
+            return self._extract_chroma(wav)
+        stems = self._get_stemmed_wav(wav, sample_rate)
+        chroma = self._extract_chroma(stems)
+        return chroma
+
+    @torch.no_grad()
+    def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
+        """Extract chroma from the whole audio waveform at the given path."""
+        wav, sr = audio_read(path)
+        wav = wav[None].to(self.device)
+        wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
+        chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
+        return chroma
+
+    def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
+        """Extract a chunk of chroma from the full chroma derived from the full waveform."""
+        wav_length = x.wav.shape[-1]
+        seek_time = x.seek_time[idx]
+        assert seek_time is not None, (
+            "WavCondition seek_time is required "
+            "when extracting chroma chunks from pre-computed chroma.")
+        full_chroma = full_chroma.float()
+        frame_rate = self.sample_rate / self._downsampling_factor()
+        target_length = int(frame_rate * wav_length / self.sample_rate)
+        index = int(frame_rate * seek_time)
+        out = full_chroma[index: index + target_length]
+        out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
+        return out.to(self.device)
+
+    @torch.no_grad()
+    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
+        """Get the wav embedding from the WavCondition.
+        The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
+        or will rely on the embedding cache to load the pre-computed embedding if relevant.
+        """
+        sampled_wav: tp.Optional[torch.Tensor] = None
+        if not self.training and self.eval_wavs is not None:
+            warn_once(logger, "Using precomputed evaluation wavs!")
+            sampled_wav = self._sample_eval_wavs(len(x.wav))
+
+        no_undefined_paths = all(p is not None for p in x.path)
+        no_nullified_cond = x.wav.shape[-1] > 1
+        if sampled_wav is not None:
+            chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
+        elif self.cache is not None and no_undefined_paths and no_nullified_cond:
+            paths = [Path(p) for p in x.path if p is not None]
+            chroma = self.cache.get_embed_from_cache(paths, x)
+        else:
+            assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
+            chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
+
+        if self.match_len_on_eval:
+            B, T, C = chroma.shape
+            if T > self.chroma_len:
+                chroma = chroma[:, :self.chroma_len]
+                logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
+            elif T < self.chroma_len:
+                n_repeat = int(math.ceil(self.chroma_len / T))
+                chroma = chroma.repeat(1, n_repeat, 1)
+                chroma = chroma[:, :self.chroma_len]
+                logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
+
+        return chroma
+
+    def tokenize(self, x: WavCondition) -> WavCondition:
+        """Apply WavConditioner tokenization and populate cache if needed."""
+        x = super().tokenize(x)
+        no_undefined_paths = all(p is not None for p in x.path)
+        if self.cache is not None and no_undefined_paths:
+            paths = [Path(p) for p in x.path if p is not None]
+            self.cache.populate_embed_cache(paths, x)
+        return x
+
+
+class JointEmbeddingConditioner(BaseConditioner):
+    """Joint embedding conditioning supporting both audio or text conditioning.
+
+    Args:
+        dim (int): Dimension.
+        output_dim (int): Output dimension.
+        device (str): Device.
+        attribute (str): Attribute used by the conditioner.
+        autocast_dtype (str): Autocast for the conditioner.
+        quantize (bool): Whether to quantize the CLAP embedding.
+        n_q (int): Number of residual quantizers (used if quantize is true).
+        bins (int): Quantizers' codebooks size (used if quantize is true).
+        kwargs: Additional parameters for residual vector quantizer.
+    """
+    def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
+                 autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
+                 n_q: int = 12, bins: int = 1024, **kwargs):
+        super().__init__(dim=dim, output_dim=output_dim)
+        self.device = device
+        self.attribute = attribute
+        if autocast_dtype is None or device == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+            logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
+        else:
+            dtype = getattr(torch, autocast_dtype)
+            assert isinstance(dtype, torch.dtype)
+            logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
+            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+        # residual vector quantizer to discretize the conditioned embedding
+        self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
+        if quantize:
+            self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
+
+    def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Get joint embedding in latent space from the inputs.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
+                and corresponding empty indexes.
+        """
+        raise NotImplementedError()
+
+    def forward(self, x: JointEmbedCondition) -> ConditionType:
+        with self.autocast:
+            embed, empty_idx = self._get_embed(x)
+            if self.quantizer is not None:
+                embed = embed.view(-1, self.dim, 1)
+                q_res = self.quantizer(embed, frame_rate=1)
+                out_embed = q_res.x.view(-1, self.dim)
+            else:
+                out_embed = embed
+            out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
+            mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
+            mask[empty_idx, :] = 0  # zero-out index where the input is non-existant
+            out_embed = (out_embed * mask.unsqueeze(-1))
+            return out_embed, mask
+
+    def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
+        return x
+
+
+class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
+    """Joint Embedding conditioner based on pre-trained CLAP model.
+
+    This CLAP-based conditioner supports a caching mechanism
+    over the computed embeddings for faster training.
+
+    Args:
+        dim (int): Dimension.
+        output_dim (int): Output dimension.
+        device (str): Device.
+        attribute (str): Attribute used by the conditioner.
+        quantize (bool): Whether to quantize the CLAP embedding.
+        n_q (int): Number of residual quantizers (used if quantize is true).
+        bins (int): Quantizers' codebooks size (used if quantize is true).
+        checkpoint (str): Path to CLAP checkpoint.
+        model_arch (str): CLAP model architecture.
+        enable_fusion (bool): Enable fusion for CLAP model.
+        sample_rate (int): Sample rate used by CLAP model.
+        max_audio_length (float): Maximum audio length for CLAP model.
+        audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
+        normalize (bool): Whether to normalize the CLAP embedding.
+        text_p (float): Probability of using text representation instead of audio at train time.
+        batch_size (Optional[int]): Batch size for CLAP embedding computation.
+        autocast_dtype (str): Autocast for the conditioner.
+        cache_path (Optional[str]): Path for pre-computed embeddings caching.
+        kwargs: Additional parameters for residual vector quantizer.
+    """
+    def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
+                 quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
+                 enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
+                 normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None,
+                 autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
+        try:
+            import laion_clap  # type: ignore
+        except ImportError:
+            raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
+        warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
+                      "Please retrain all models.")
+        checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
+        clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
+        clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
+        load_clap_state_dict(clap_model, checkpoint)
+        clap_model.eval()
+        clap_model.to(device)
+        super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
+                         autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
+                         **kwargs)
+        self.checkpoint = checkpoint
+        self.enable_fusion = enable_fusion
+        self.model_arch = model_arch
+        self.clap: laion_clap.CLAP_Module
+        self.clap_tokenize: RobertaTokenizer
+        self.clap_sample_rate = sample_rate
+        self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
+        self.clap_stride = int(self.clap_sample_rate * audio_stride)
+        self.batch_size = batch_size or 1
+        self.normalize = normalize
+        self.text_p = text_p
+        self.__dict__['clap_tokenize'] = clap_tokenize
+        self.__dict__['clap'] = clap_model
+        self.wav_cache, self.text_cache = None, None
+        if cache_path is not None:
+            self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
+                                            compute_embed_fn=self._get_wav_embedding_for_cache,
+                                            extract_embed_fn=self._extract_wav_embedding_chunk)
+            self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
+                                             compute_embed_fn=self._get_text_embedding_for_cache)
+
+    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
+        # we use the default params from CLAP module here as well
+        return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
+
+    def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
+        """Compute text embedding from CLAP model on a given a batch of text.
+
+        Args:
+            text (list[str]): List of text for the batch, with B items.
+        Returns:
+            torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
+        """
+        with torch.no_grad():
+            embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
+            return embed.view(embed.size(0), 1, embed.size(-1))
+
+    def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
+                                      x: JointEmbedCondition, idx: int) -> torch.Tensor:
+        """Get text embedding function for the cache."""
+        text = x.text[idx]
+        text = text if text is not None else ""
+        return self._compute_text_embedding([text])[0]
+
+    def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
+        """Preprocess wav to expected format by CLAP model.
+
+        Args:
+            wav (torch.Tensor): Audio wav, of shape [B, C, T].
+            length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
+            sample_rates (list[int]): Sample rates for each sample in the batch
+        Returns:
+            torch.Tensor: Audio wav of shape [B, T].
+        """
+        assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
+        if sample_rates is not None:
+            _wav = []
+            for i, audio in enumerate(wav):
+                sr = sample_rates[i]
+                audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
+                _wav.append(audio)
+            wav = torch.stack(_wav, dim=0)
+        wav = wav.mean(dim=1)
+        return wav
+
+    def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
+                               sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
+        """Compute audio wave embedding from CLAP model.
+
+        Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences,
+        we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
+        average the resulting embeddings.
+
+        Args:
+            wav (torch.Tensor): Audio wav, of shape [B, C, T].
+            length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
+            sample_rates (list[int]): Sample rates for each sample in the batch.
+            reduce_mean (bool): Whether to get the average tensor.
+        Returns:
+            torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
+        """
+        with torch.no_grad():
+            wav = self._preprocess_wav(wav, length, sample_rates)
+            B, T = wav.shape
+            if T >= self.clap_max_frames:
+                wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride)  # [B, F, T]
+            else:
+                wav = wav.view(-1, 1, T)  # [B, F, T] with F=1
+            wav = einops.rearrange(wav, 'b f t -> (b f) t')
+            embed_list = []
+            for i in range(0, wav.size(0), self.batch_size):
+                _wav = wav[i:i+self.batch_size, ...]
+                _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
+                embed_list.append(_embed)
+            embed = torch.cat(embed_list, dim=0)
+            embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
+            if reduce_mean:
+                embed = embed.mean(dim=1, keepdim=True)
+            return embed  # [B, F, D] with F=1 if reduce_mean is True
+
+    def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
+                                     x: JointEmbedCondition, idx: int) -> torch.Tensor:
+        """Compute audio wave embedding for the cache.
+        The embedding is computed on a given audio read from file.
+
+        Args:
+            path (str or Path): Path to the full audio file.
+        Returns:
+            torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
+        """
+        wav, sr = audio_read(path)  # [C, T]
+        wav = wav.unsqueeze(0).to(self.device)  # [1, C, T]
+        wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
+        embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False)  # [B, F, D]
+        return embed.squeeze(0)  # [F, D]
+
+    def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
+        """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
+
+        Args:
+            full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
+            x (JointEmbedCondition): Joint embedding condition for the full batch.
+            idx (int): Index considered for the given embedding to extract.
+        Returns:
+            torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
+        """
+        sample_rate = x.sample_rate[idx]
+        seek_time = x.seek_time[idx]
+        seek_time = 0. if seek_time is None else seek_time
+        clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
+        end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
+        start_offset = int(seek_time * sample_rate // clap_stride)
+        end_offset = int(end_seek_time * sample_rate // clap_stride)
+        wav_embed = full_embed[start_offset:end_offset, ...]
+        wav_embed = wav_embed.mean(dim=0, keepdim=True)
+        return wav_embed.to(self.device)  # [F, D]
+
+    def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
+        """Get CLAP embedding from a batch of text descriptions."""
+        no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from cache when condition dropout
+        if self.text_cache is not None and no_nullified_cond:
+            assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
+            paths = [Path(p) for p in x.path if p is not None]
+            embed = self.text_cache.get_embed_from_cache(paths, x)
+        else:
+            text = [xi if xi is not None else "" for xi in x.text]
+            embed = self._compute_text_embedding(text)
+        if self.normalize:
+            embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
+        return embed
+
+    def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
+        """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
+        no_undefined_paths = all(p is not None for p in x.path)
+        no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from cache when condition dropout
+        if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
+            paths = [Path(p) for p in x.path if p is not None]
+            embed = self.wav_cache.get_embed_from_cache(paths, x)
+        else:
+            embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
+        if self.normalize:
+            embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
+        return embed
+
+    def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
+        # Trying to limit as much as possible sync points when the cache is warm.
+        no_undefined_paths = all(p is not None for p in x.path)
+        if self.wav_cache is not None and no_undefined_paths:
+            assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
+            paths = [Path(p) for p in x.path if p is not None]
+            self.wav_cache.populate_embed_cache(paths, x)
+        if self.text_cache is not None and no_undefined_paths:
+            assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
+            paths = [Path(p) for p in x.path if p is not None]
+            self.text_cache.populate_embed_cache(paths, x)
+        return x
+
+    def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Extract shared latent representation from either the wav or the text using CLAP."""
+        # decide whether to use text embedding at train time or not
+        use_text_embed = random.random() < self.text_p
+        if self.training and not use_text_embed:
+            embed = self._get_wav_embedding(x)
+            empty_idx = torch.LongTensor([])  # we assume we always have the audio wav
+        else:
+            embed = self._get_text_embedding(x)
+            empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
+        return embed, empty_idx
+
+
+def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
+    """Utility function for nullifying an attribute inside an ConditioningAttributes object.
+    If the condition is of type "wav", then nullify it using `nullify_condition` function.
+    If the condition is of any other type, set its value to None.
+    Works in-place.
+    """
+    if condition_type not in ['text', 'wav', 'joint_embed']:
+        raise ValueError(
+            "dropout_condition got an unexpected condition type!"
+            f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
+        )
+
+    if condition not in getattr(sample, condition_type):
+        raise ValueError(
+            "dropout_condition received an unexpected condition!"
+            f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
+            f" but got '{condition}' of type '{condition_type}'!"
+        )
+
+    if condition_type == 'wav':
+        wav_cond = sample.wav[condition]
+        sample.wav[condition] = nullify_wav(wav_cond)
+    elif condition_type == 'joint_embed':
+        embed = sample.joint_embed[condition]
+        sample.joint_embed[condition] = nullify_joint_embed(embed)
+    else:
+        sample.text[condition] = None
+
+    return sample
+
+
+class DropoutModule(nn.Module):
+    """Base module for all dropout modules."""
+    def __init__(self, seed: int = 1234):
+        super().__init__()
+        self.rng = torch.Generator()
+        self.rng.manual_seed(seed)
+
+
+class AttributeDropout(DropoutModule):
+    """Dropout with a given probability per attribute.
+    This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
+    to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
+    This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
+    must also be dropped.
+
+    Args:
+        p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
+            ...
+            "genre": 0.1,
+            "artist": 0.5,
+            "wav": 0.25,
+            ...
+        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
+        seed (int, optional): Random seed.
+    """
+    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
+        super().__init__(seed=seed)
+        self.active_on_eval = active_on_eval
+        # construct dict that return the values from p otherwise 0
+        self.p = {}
+        for condition_type, probs in p.items():
+            self.p[condition_type] = defaultdict(lambda: 0, probs)
+
+    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+        """
+        Args:
+            samples (list[ConditioningAttributes]): List of conditions.
+        Returns:
+            list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
+        """
+        if not self.training and not self.active_on_eval:
+            return samples
+
+        samples = deepcopy(samples)
+        for condition_type, ps in self.p.items():  # for condition types [text, wav]
+            for condition, p in ps.items():  # for attributes of each type (e.g., [artist, genre])
+                if torch.rand(1, generator=self.rng).item() < p:
+                    for sample in samples:
+                        dropout_condition(sample, condition_type, condition)
+        return samples
+
+    def __repr__(self):
+        return f"AttributeDropout({dict(self.p)})"
+
+
+class ClassifierFreeGuidanceDropout(DropoutModule):
+    """Classifier Free Guidance dropout.
+    All attributes are dropped with the same probability.
+
+    Args:
+        p (float): Probability to apply condition dropout during training.
+        seed (int): Random seed.
+    """
+    def __init__(self, p: float, seed: int = 1234):
+        super().__init__(seed=seed)
+        self.p = p
+
+    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+        """
+        Args:
+            samples (list[ConditioningAttributes]): List of conditions.
+        Returns:
+            list[ConditioningAttributes]: List of conditions after all attributes were set to None.
+        """
+        if not self.training:
+            return samples
+
+        # decide on which attributes to drop in a batched fashion
+        drop = torch.rand(1, generator=self.rng).item() < self.p
+        if not drop:
+            return samples
+
+        # nullify conditions of all attributes
+        samples = deepcopy(samples)
+        for condition_type in ["wav", "text"]:
+            for sample in samples:
+                for condition in sample.attributes[condition_type]:
+                    dropout_condition(sample, condition_type, condition)
+        return samples
+
+    def __repr__(self):
+        return f"ClassifierFreeGuidanceDropout(p={self.p})"
+
+
+class ConditioningProvider(nn.Module):
+    """Prepare and provide conditions given all the supported conditioners.
+
+    Args:
+        conditioners (dict): Dictionary of conditioners.
+        device (torch.device or str, optional): Device for conditioners and output condition types.
+    """
+    def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
+        super().__init__()
+        self.device = device
+        self.conditioners = nn.ModuleDict(conditioners)
+
+    @property
+    def joint_embed_conditions(self):
+        return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
+
+    @property
+    def has_joint_embed_conditions(self):
+        return len(self.joint_embed_conditions) > 0
+
+    @property
+    def text_conditions(self):
+        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
+
+    @property
+    def wav_conditions(self):
+        return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
+
+    @property
+    def has_wav_condition(self):
+        return len(self.wav_conditions) > 0
+
+    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
+        """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
+        This should be called before starting any real GPU work to avoid synchronization points.
+        This will return a dict matching conditioner names to their arbitrary tokenized representations.
+
+        Args:
+            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
+                text and wav conditions.
+        """
+        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
+            "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
+            f" but types were {set([type(x) for x in inputs])}"
+        )
+
+        output = {}
+        text = self._collate_text(inputs)
+        wavs = self._collate_wavs(inputs)
+        joint_embeds = self._collate_joint_embeds(inputs)
+
+        assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
+            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
+            f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
+        )
+
+        for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
+            output[attribute] = self.conditioners[attribute].tokenize(batch)
+        return output
+
+    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
+        """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
+        The output is for example:
+        {
+            "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
+            "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
+            ...
+        }
+
+        Args:
+            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
+        """
+        output = {}
+        for attribute, inputs in tokenized.items():
+            condition, mask = self.conditioners[attribute](inputs)
+            output[attribute] = (condition, mask)
+        return output
+
+    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
+        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
+        are the attributes and the values are the aggregated input per attribute.
+        For example:
+        Input:
+        [
+            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
+            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
+        ]
+        Output:
+        {
+            "genre": ["Rock", "Hip-hop"],
+            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
+        }
+
+        Args:
+            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
+        Returns:
+            dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
+        """
+        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
+        texts = [x.text for x in samples]
+        for text in texts:
+            for condition in self.text_conditions:
+                out[condition].append(text[condition])
+        return out
+
+    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
+        """Generate a dict where the keys are attributes by which we fetch similar wavs,
+        and the values are Tensors of wavs according to said attributes.
+
+        *Note*: by the time the samples reach this function, each sample should have some waveform
+        inside the "wav" attribute. It should be either:
+        1. A real waveform
+        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
+        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
+
+        Args:
+            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
+        Returns:
+            dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
+        """
+        wavs = defaultdict(list)
+        lengths = defaultdict(list)
+        sample_rates = defaultdict(list)
+        paths = defaultdict(list)
+        seek_times = defaultdict(list)
+        out: tp.Dict[str, WavCondition] = {}
+
+        for sample in samples:
+            for attribute in self.wav_conditions:
+                wav, length, sample_rate, path, seek_time = sample.wav[attribute]
+                assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
+                assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
+                # mono-channel conditioning
+                wav = wav.mean(1, keepdim=True)  # [1, 1, T]
+                wavs[attribute].append(wav.flatten())  # [T]
+                lengths[attribute].append(length)
+                sample_rates[attribute].extend(sample_rate)
+                paths[attribute].extend(path)
+                seek_times[attribute].extend(seek_time)
+
+        # stack all wavs to a single tensor
+        for attribute in self.wav_conditions:
+            stacked_wav, _ = collate(wavs[attribute], dim=0)
+            out[attribute] = WavCondition(
+                stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
+                paths[attribute], seek_times[attribute])
+
+        return out
+
+    def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
+        """Generate a dict where the keys are attributes by which we compute joint embeddings,
+        and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
+
+        Args:
+            samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
+        Returns:
+            A dictionary mapping an attribute name to joint embeddings.
+        """
+        texts = defaultdict(list)
+        wavs = defaultdict(list)
+        lengths = defaultdict(list)
+        sample_rates = defaultdict(list)
+        paths = defaultdict(list)
+        seek_times = defaultdict(list)
+        channels: int = 0
+
+        out = {}
+        for sample in samples:
+            for attribute in self.joint_embed_conditions:
+                wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
+                assert wav.dim() == 3
+                if channels == 0:
+                    channels = wav.size(1)
+                else:
+                    assert channels == wav.size(1), "not all audio has same number of channels in batch"
+                assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
+                wav = einops.rearrange(wav, "b c t -> (b c t)")  # [1, C, T] => [C * T]
+                wavs[attribute].append(wav)
+                texts[attribute].extend(text)
+                lengths[attribute].append(length)
+                sample_rates[attribute].extend(sample_rate)
+                paths[attribute].extend(path)
+                seek_times[attribute].extend(seek_time)
+
+        for attribute in self.joint_embed_conditions:
+            stacked_texts = texts[attribute]
+            stacked_paths = paths[attribute]
+            stacked_seek_times = seek_times[attribute]
+            stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
+            stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
+            stacked_sample_rates = sample_rates[attribute]
+            stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
+            assert stacked_lengths.size(0) == stacked_wavs.size(0)
+            assert len(stacked_sample_rates) == stacked_wavs.size(0)
+            assert len(stacked_texts) == stacked_wavs.size(0)
+            out[attribute] = JointEmbedCondition(
+                text=stacked_texts, wav=stacked_wavs,
+                length=stacked_lengths, sample_rate=stacked_sample_rates,
+                path=stacked_paths, seek_time=stacked_seek_times)
+
+        return out
+
+
+class ConditionFuser(StreamingModule):
+    """Condition fuser handles the logic to combine the different conditions
+    to the actual model input.
+
+    Args:
+        fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
+            each condition. For example:
+            {
+                "prepend": ["description"],
+                "sum": ["genre", "bpm"],
+                "cross": ["description"],
+            }
+        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
+        cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
+    """
+    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
+
+    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
+                 cross_attention_pos_emb_scale: float = 1.0):
+        super().__init__()
+        assert all(
+            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
+        ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
+        self.cross_attention_pos_emb = cross_attention_pos_emb
+        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
+        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
+        self.cond2fuse: tp.Dict[str, str] = {}
+        for fuse_method, conditions in fuse2cond.items():
+            for condition in conditions:
+                self.cond2fuse[condition] = fuse_method
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        conditions: tp.Dict[str, ConditionType]
+    ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        """Fuse the conditions to the provided model input.
+
+        Args:
+            input (torch.Tensor): Transformer input.
+            conditions (dict[str, ConditionType]): Dict of conditions.
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
+                after the conditions have been fused. The second output tensor is the tensor
+                used for cross-attention or None if no cross attention inputs exist.
+        """
+        B, T, _ = input.shape
+
+        if 'offsets' in self._streaming_state:
+            first_step = False
+            offsets = self._streaming_state['offsets']
+        else:
+            first_step = True
+            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
+        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
+            f"given conditions contain unknown attributes for fuser, " \
+            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
+        cross_attention_output = None
+        for cond_type, (cond, cond_mask) in conditions.items():
+            op = self.cond2fuse[cond_type]
+            if op == 'sum':
+                input += cond
+            elif op == 'input_interpolate':
+                cond = einops.rearrange(cond, "b t d -> b d t")
+                cond = F.interpolate(cond, size=input.shape[1])
+                input += einops.rearrange(cond, "b d t -> b t d")
+            elif op == 'prepend':
+                if first_step:
+                    input = torch.cat([cond, input], dim=1)
+            elif op == 'cross':
+                if cross_attention_output is not None:
+                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
+                else:
+                    cross_attention_output = cond
+            else:
+                raise ValueError(f"unknown op ({op})")
+
+        if self.cross_attention_pos_emb and cross_attention_output is not None:
+            positions = torch.arange(
+                cross_attention_output.shape[1],
+                device=cross_attention_output.device
+            ).view(1, -1, 1)
+            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
+            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+
+        if self._is_streaming:
+            self._streaming_state['offsets'] = offsets + T
+
+        return input, cross_attention_output
+
+
+
+
+
+
+
+

Functions

+
+
+def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) ‑> ConditioningAttributes +
+
+

Utility function for nullifying an attribute inside an ConditioningAttributes object. +If the condition is of type "wav", then nullify it using nullify_condition() function. +If the condition is of any other type, set its value to None. +Works in-place.

+
+ +Expand source code + +
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
+    """Utility function for nullifying an attribute inside an ConditioningAttributes object.
+    If the condition is of type "wav", then nullify it using `nullify_condition` function.
+    If the condition is of any other type, set its value to None.
+    Works in-place.
+    """
+    if condition_type not in ['text', 'wav', 'joint_embed']:
+        raise ValueError(
+            "dropout_condition got an unexpected condition type!"
+            f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
+        )
+
+    if condition not in getattr(sample, condition_type):
+        raise ValueError(
+            "dropout_condition received an unexpected condition!"
+            f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
+            f" but got '{condition}' of type '{condition_type}'!"
+        )
+
+    if condition_type == 'wav':
+        wav_cond = sample.wav[condition]
+        sample.wav[condition] = nullify_wav(wav_cond)
+    elif condition_type == 'joint_embed':
+        embed = sample.joint_embed[condition]
+        sample.joint_embed[condition] = nullify_joint_embed(embed)
+    else:
+        sample.text[condition] = None
+
+    return sample
+
+
+
+def nullify_condition(condition: Tuple[torch.Tensor, torch.Tensor], dim: int = 1) +
+
+

Transform an input condition to a null condition. +The way it is done by converting it to a single zero vector similarly +to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.

+

Args

+
+
condition : ConditionType
+
A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
+
dim : int
+
The dimension that will be truncated (should be the time dimension)
+
+

WARNING!: dim should not be the batch dimension!

+

Returns

+
+
ConditionType
+
A tuple of null condition and mask
+
+
+ +Expand source code + +
def nullify_condition(condition: ConditionType, dim: int = 1):
+    """Transform an input condition to a null condition.
+    The way it is done by converting it to a single zero vector similarly
+    to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
+
+    Args:
+        condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
+        dim (int): The dimension that will be truncated (should be the time dimension)
+        WARNING!: dim should not be the batch dimension!
+    Returns:
+        ConditionType: A tuple of null condition and mask
+    """
+    assert dim != 0, "dim cannot be the batch dimension!"
+    assert isinstance(condition, tuple) and \
+        isinstance(condition[0], torch.Tensor) and \
+        isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
+    cond, mask = condition
+    B = cond.shape[0]
+    last_dim = cond.dim() - 1
+    out = cond.transpose(dim, last_dim)
+    out = 0. * out[..., :1]
+    out = out.transpose(dim, last_dim)
+    mask = torch.zeros((B, 1), device=out.device).int()
+    assert cond.dim() == out.dim()
+    return out, mask
+
+
+
+def nullify_joint_embed(embed: JointEmbedCondition) ‑> JointEmbedCondition +
+
+

Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0, +and replacing metadata by dummy attributes.

+

Args

+
+
cond : JointEmbedCondition
+
Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
+
+
+ +Expand source code + +
def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
+    """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
+    and replacing metadata by dummy attributes.
+
+    Args:
+        cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
+    """
+    null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
+    return JointEmbedCondition(
+        wav=null_wav, text=[None] * len(embed.text),
+        length=torch.LongTensor([0]).to(embed.wav.device),
+        sample_rate=embed.sample_rate,
+        path=[None] * embed.wav.shape[0],
+        seek_time=[0] * embed.wav.shape[0],
+    )
+
+
+
+def nullify_wav(cond: WavCondition) ‑> WavCondition +
+
+

Transform a WavCondition to a nullified WavCondition. +It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.

+

Args

+
+
cond : WavCondition
+
Wav condition with wav, tensor of shape [B, T].
+
+

Returns

+
+
WavCondition
+
Nullified wav condition.
+
+
+ +Expand source code + +
def nullify_wav(cond: WavCondition) -> WavCondition:
+    """Transform a WavCondition to a nullified WavCondition.
+    It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
+
+    Args:
+        cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
+    Returns:
+        WavCondition: Nullified wav condition.
+    """
+    null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
+    return WavCondition(
+        wav=null_wav,
+        length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
+        sample_rate=cond.sample_rate,
+        path=[None] * cond.wav.shape[0],
+        seek_time=[None] * cond.wav.shape[0],
+    )
+
+
+
+
+
+

Classes

+
+
+class AttributeDropout +(p: Dict[str, Dict[str, float]], active_on_eval: bool = False, seed: int = 1234) +
+
+

Dropout with a given probability per attribute. +This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes +to be dropped out separately. For example, "artist" can be dropped while "genre" remains. +This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre" +must also be dropped.

+

Args

+
+
p : tp.Dict[str, float]
+
A dict mapping between attributes and dropout probability. For example: +… +"genre": 0.1, +"artist": 0.5, +"wav": 0.25, +…
+
active_on_eval : bool, optional
+
Whether the dropout is active at eval. Default to False.
+
seed : int, optional
+
Random seed.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class AttributeDropout(DropoutModule):
+    """Dropout with a given probability per attribute.
+    This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
+    to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
+    This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
+    must also be dropped.
+
+    Args:
+        p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
+            ...
+            "genre": 0.1,
+            "artist": 0.5,
+            "wav": 0.25,
+            ...
+        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
+        seed (int, optional): Random seed.
+    """
+    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
+        super().__init__(seed=seed)
+        self.active_on_eval = active_on_eval
+        # construct dict that return the values from p otherwise 0
+        self.p = {}
+        for condition_type, probs in p.items():
+            self.p[condition_type] = defaultdict(lambda: 0, probs)
+
+    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+        """
+        Args:
+            samples (list[ConditioningAttributes]): List of conditions.
+        Returns:
+            list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
+        """
+        if not self.training and not self.active_on_eval:
+            return samples
+
+        samples = deepcopy(samples)
+        for condition_type, ps in self.p.items():  # for condition types [text, wav]
+            for condition, p in ps.items():  # for attributes of each type (e.g., [artist, genre])
+                if torch.rand(1, generator=self.rng).item() < p:
+                    for sample in samples:
+                        dropout_condition(sample, condition_type, condition)
+        return samples
+
+    def __repr__(self):
+        return f"AttributeDropout({dict(self.p)})"
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, samples: List[ConditioningAttributes]) ‑> List[ConditioningAttributes] +
+
+

Args

+
+
samples : list[ConditioningAttributes]
+
List of conditions.
+
+

Returns

+
+
list[ConditioningAttributes]
+
List of conditions after certain attributes were set to None.
+
+
+ +Expand source code + +
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+    """
+    Args:
+        samples (list[ConditioningAttributes]): List of conditions.
+    Returns:
+        list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
+    """
+    if not self.training and not self.active_on_eval:
+        return samples
+
+    samples = deepcopy(samples)
+    for condition_type, ps in self.p.items():  # for condition types [text, wav]
+        for condition, p in ps.items():  # for attributes of each type (e.g., [artist, genre])
+            if torch.rand(1, generator=self.rng).item() < p:
+                for sample in samples:
+                    dropout_condition(sample, condition_type, condition)
+    return samples
+
+
+
+
+
+class BaseConditioner +(dim: int, output_dim: int) +
+
+

Base model for all conditioner modules. +We allow the output dim to be different than the hidden dim for two reasons: +1) keep our LUTs small when the vocab is large; +2) make all condition dims consistent.

+

Args

+
+
dim : int
+
Hidden dim of the model.
+
output_dim : int
+
Output dim of the conditioner.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class BaseConditioner(nn.Module):
+    """Base model for all conditioner modules.
+    We allow the output dim to be different than the hidden dim for two reasons:
+    1) keep our LUTs small when the vocab is large;
+    2) make all condition dims consistent.
+
+    Args:
+        dim (int): Hidden dim of the model.
+        output_dim (int): Output dim of the conditioner.
+    """
+    def __init__(self, dim: int, output_dim: int):
+        super().__init__()
+        self.dim = dim
+        self.output_dim = output_dim
+        self.output_proj = nn.Linear(dim, output_dim)
+
+    def tokenize(self, *args, **kwargs) -> tp.Any:
+        """Should be any part of the processing that will lead to a synchronization
+        point, e.g. BPE tokenization with transfer to the GPU.
+
+        The returned value will be saved and return later when calling forward().
+        """
+        raise NotImplementedError()
+
+    def forward(self, inputs: tp.Any) -> ConditionType:
+        """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
+        Outputs a ConditionType, after the input data was embedded as a dense vector.
+
+        Returns:
+            ConditionType:
+                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
+                  output embedding and D is the dimension of the embedding.
+                - And a mask indicating where the padding tokens.
+        """
+        raise NotImplementedError()
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, inputs: Any) ‑> Tuple[torch.Tensor, torch.Tensor] +
+
+

Gets input that should be used as conditioning (e.g, genre, description or a waveform). +Outputs a ConditionType, after the input data was embedded as a dense vector.

+

Returns

+

ConditionType: +- A tensor of size [B, T, D] where B is the batch size, T is the length of the +output embedding and D is the dimension of the embedding. +- And a mask indicating where the padding tokens.

+
+ +Expand source code + +
def forward(self, inputs: tp.Any) -> ConditionType:
+    """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
+    Outputs a ConditionType, after the input data was embedded as a dense vector.
+
+    Returns:
+        ConditionType:
+            - A tensor of size [B, T, D] where B is the batch size, T is the length of the
+              output embedding and D is the dimension of the embedding.
+            - And a mask indicating where the padding tokens.
+    """
+    raise NotImplementedError()
+
+
+
+def tokenize(self, *args, **kwargs) ‑> Any +
+
+

Should be any part of the processing that will lead to a synchronization +point, e.g. BPE tokenization with transfer to the GPU.

+

The returned value will be saved and return later when calling forward().

+
+ +Expand source code + +
def tokenize(self, *args, **kwargs) -> tp.Any:
+    """Should be any part of the processing that will lead to a synchronization
+    point, e.g. BPE tokenization with transfer to the GPU.
+
+    The returned value will be saved and return later when calling forward().
+    """
+    raise NotImplementedError()
+
+
+
+
+
+class CLAPEmbeddingConditioner +(dim: int, output_dim: int, device: str, attribute: str, quantize: bool, n_q: int, bins: int, checkpoint: Union[str, pathlib.Path], model_arch: str, enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int, normalize: bool, text_p: bool, batch_size: Optional[int] = None, autocast_dtype: Optional[str] = 'float32', cache_path: Optional[str] = None, **kwargs) +
+
+

Joint Embedding conditioner based on pre-trained CLAP model.

+

This CLAP-based conditioner supports a caching mechanism +over the computed embeddings for faster training.

+

Args

+
+
dim : int
+
Dimension.
+
output_dim : int
+
Output dimension.
+
device : str
+
Device.
+
attribute : str
+
Attribute used by the conditioner.
+
quantize : bool
+
Whether to quantize the CLAP embedding.
+
n_q : int
+
Number of residual quantizers (used if quantize is true).
+
bins : int
+
Quantizers' codebooks size (used if quantize is true).
+
checkpoint : str
+
Path to CLAP checkpoint.
+
model_arch : str
+
CLAP model architecture.
+
enable_fusion : bool
+
Enable fusion for CLAP model.
+
sample_rate : int
+
Sample rate used by CLAP model.
+
max_audio_length : float
+
Maximum audio length for CLAP model.
+
audio_stride : float
+
Stride to use for getting a CLAP embedding on the full sequence.
+
normalize : bool
+
Whether to normalize the CLAP embedding.
+
text_p : float
+
Probability of using text representation instead of audio at train time.
+
batch_size : Optional[int]
+
Batch size for CLAP embedding computation.
+
autocast_dtype : str
+
Autocast for the conditioner.
+
cache_path : Optional[str]
+
Path for pre-computed embeddings caching.
+
kwargs
+
Additional parameters for residual vector quantizer.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
+    """Joint Embedding conditioner based on pre-trained CLAP model.
+
+    This CLAP-based conditioner supports a caching mechanism
+    over the computed embeddings for faster training.
+
+    Args:
+        dim (int): Dimension.
+        output_dim (int): Output dimension.
+        device (str): Device.
+        attribute (str): Attribute used by the conditioner.
+        quantize (bool): Whether to quantize the CLAP embedding.
+        n_q (int): Number of residual quantizers (used if quantize is true).
+        bins (int): Quantizers' codebooks size (used if quantize is true).
+        checkpoint (str): Path to CLAP checkpoint.
+        model_arch (str): CLAP model architecture.
+        enable_fusion (bool): Enable fusion for CLAP model.
+        sample_rate (int): Sample rate used by CLAP model.
+        max_audio_length (float): Maximum audio length for CLAP model.
+        audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
+        normalize (bool): Whether to normalize the CLAP embedding.
+        text_p (float): Probability of using text representation instead of audio at train time.
+        batch_size (Optional[int]): Batch size for CLAP embedding computation.
+        autocast_dtype (str): Autocast for the conditioner.
+        cache_path (Optional[str]): Path for pre-computed embeddings caching.
+        kwargs: Additional parameters for residual vector quantizer.
+    """
+    def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
+                 quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
+                 enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
+                 normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None,
+                 autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
+        try:
+            import laion_clap  # type: ignore
+        except ImportError:
+            raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
+        warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
+                      "Please retrain all models.")
+        checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
+        clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
+        clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
+        load_clap_state_dict(clap_model, checkpoint)
+        clap_model.eval()
+        clap_model.to(device)
+        super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
+                         autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
+                         **kwargs)
+        self.checkpoint = checkpoint
+        self.enable_fusion = enable_fusion
+        self.model_arch = model_arch
+        self.clap: laion_clap.CLAP_Module
+        self.clap_tokenize: RobertaTokenizer
+        self.clap_sample_rate = sample_rate
+        self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
+        self.clap_stride = int(self.clap_sample_rate * audio_stride)
+        self.batch_size = batch_size or 1
+        self.normalize = normalize
+        self.text_p = text_p
+        self.__dict__['clap_tokenize'] = clap_tokenize
+        self.__dict__['clap'] = clap_model
+        self.wav_cache, self.text_cache = None, None
+        if cache_path is not None:
+            self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
+                                            compute_embed_fn=self._get_wav_embedding_for_cache,
+                                            extract_embed_fn=self._extract_wav_embedding_chunk)
+            self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
+                                             compute_embed_fn=self._get_text_embedding_for_cache)
+
+    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
+        # we use the default params from CLAP module here as well
+        return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
+
+    def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
+        """Compute text embedding from CLAP model on a given a batch of text.
+
+        Args:
+            text (list[str]): List of text for the batch, with B items.
+        Returns:
+            torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
+        """
+        with torch.no_grad():
+            embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
+            return embed.view(embed.size(0), 1, embed.size(-1))
+
+    def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
+                                      x: JointEmbedCondition, idx: int) -> torch.Tensor:
+        """Get text embedding function for the cache."""
+        text = x.text[idx]
+        text = text if text is not None else ""
+        return self._compute_text_embedding([text])[0]
+
+    def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
+        """Preprocess wav to expected format by CLAP model.
+
+        Args:
+            wav (torch.Tensor): Audio wav, of shape [B, C, T].
+            length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
+            sample_rates (list[int]): Sample rates for each sample in the batch
+        Returns:
+            torch.Tensor: Audio wav of shape [B, T].
+        """
+        assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
+        if sample_rates is not None:
+            _wav = []
+            for i, audio in enumerate(wav):
+                sr = sample_rates[i]
+                audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
+                _wav.append(audio)
+            wav = torch.stack(_wav, dim=0)
+        wav = wav.mean(dim=1)
+        return wav
+
+    def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
+                               sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
+        """Compute audio wave embedding from CLAP model.
+
+        Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences,
+        we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
+        average the resulting embeddings.
+
+        Args:
+            wav (torch.Tensor): Audio wav, of shape [B, C, T].
+            length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
+            sample_rates (list[int]): Sample rates for each sample in the batch.
+            reduce_mean (bool): Whether to get the average tensor.
+        Returns:
+            torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
+        """
+        with torch.no_grad():
+            wav = self._preprocess_wav(wav, length, sample_rates)
+            B, T = wav.shape
+            if T >= self.clap_max_frames:
+                wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride)  # [B, F, T]
+            else:
+                wav = wav.view(-1, 1, T)  # [B, F, T] with F=1
+            wav = einops.rearrange(wav, 'b f t -> (b f) t')
+            embed_list = []
+            for i in range(0, wav.size(0), self.batch_size):
+                _wav = wav[i:i+self.batch_size, ...]
+                _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
+                embed_list.append(_embed)
+            embed = torch.cat(embed_list, dim=0)
+            embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
+            if reduce_mean:
+                embed = embed.mean(dim=1, keepdim=True)
+            return embed  # [B, F, D] with F=1 if reduce_mean is True
+
+    def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
+                                     x: JointEmbedCondition, idx: int) -> torch.Tensor:
+        """Compute audio wave embedding for the cache.
+        The embedding is computed on a given audio read from file.
+
+        Args:
+            path (str or Path): Path to the full audio file.
+        Returns:
+            torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
+        """
+        wav, sr = audio_read(path)  # [C, T]
+        wav = wav.unsqueeze(0).to(self.device)  # [1, C, T]
+        wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
+        embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False)  # [B, F, D]
+        return embed.squeeze(0)  # [F, D]
+
+    def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
+        """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
+
+        Args:
+            full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
+            x (JointEmbedCondition): Joint embedding condition for the full batch.
+            idx (int): Index considered for the given embedding to extract.
+        Returns:
+            torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
+        """
+        sample_rate = x.sample_rate[idx]
+        seek_time = x.seek_time[idx]
+        seek_time = 0. if seek_time is None else seek_time
+        clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
+        end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
+        start_offset = int(seek_time * sample_rate // clap_stride)
+        end_offset = int(end_seek_time * sample_rate // clap_stride)
+        wav_embed = full_embed[start_offset:end_offset, ...]
+        wav_embed = wav_embed.mean(dim=0, keepdim=True)
+        return wav_embed.to(self.device)  # [F, D]
+
+    def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
+        """Get CLAP embedding from a batch of text descriptions."""
+        no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from cache when condition dropout
+        if self.text_cache is not None and no_nullified_cond:
+            assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
+            paths = [Path(p) for p in x.path if p is not None]
+            embed = self.text_cache.get_embed_from_cache(paths, x)
+        else:
+            text = [xi if xi is not None else "" for xi in x.text]
+            embed = self._compute_text_embedding(text)
+        if self.normalize:
+            embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
+        return embed
+
+    def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
+        """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
+        no_undefined_paths = all(p is not None for p in x.path)
+        no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from cache when condition dropout
+        if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
+            paths = [Path(p) for p in x.path if p is not None]
+            embed = self.wav_cache.get_embed_from_cache(paths, x)
+        else:
+            embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
+        if self.normalize:
+            embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
+        return embed
+
+    def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
+        # Trying to limit as much as possible sync points when the cache is warm.
+        no_undefined_paths = all(p is not None for p in x.path)
+        if self.wav_cache is not None and no_undefined_paths:
+            assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
+            paths = [Path(p) for p in x.path if p is not None]
+            self.wav_cache.populate_embed_cache(paths, x)
+        if self.text_cache is not None and no_undefined_paths:
+            assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
+            paths = [Path(p) for p in x.path if p is not None]
+            self.text_cache.populate_embed_cache(paths, x)
+        return x
+
+    def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Extract shared latent representation from either the wav or the text using CLAP."""
+        # decide whether to use text embedding at train time or not
+        use_text_embed = random.random() < self.text_p
+        if self.training and not use_text_embed:
+            embed = self._get_wav_embedding(x)
+            empty_idx = torch.LongTensor([])  # we assume we always have the audio wav
+        else:
+            embed = self._get_text_embedding(x)
+            empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
+        return embed, empty_idx
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class ChromaStemConditioner +(output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int, duration: float, match_len_on_eval: bool = True, eval_wavs: Optional[str] = None, n_eval_wavs: int = 0, cache_path: Union[pathlib.Path, str, None] = None, device: Union[torch.device, str] = 'cpu', **kwargs) +
+
+

Chroma conditioner based on stems. +The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as +the drums and bass often dominate the chroma leading to the chroma features +not containing information about the melody.

+

Args

+
+
output_dim : int
+
Output dimension for the conditioner.
+
sample_rate : int
+
Sample rate for the chroma extractor.
+
n_chroma : int
+
Number of chroma bins for the chroma extractor.
+
radix2_exp : int
+
Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
+
duration : int
+
duration used during training. This is later used for correct padding +in case we are using chroma as prefix.
+
match_len_on_eval : bool, optional
+
if True then all chromas are padded to the training +duration. Defaults to False.
+
eval_wavs : str, optional
+
path to a dataset manifest with waveform, this waveforms are used as +conditions during eval (for cases where we don't want to leak test conditions like MusicCaps). +Defaults to None.
+
n_eval_wavs : int, optional
+
limits the number of waveforms used for conditioning. Defaults to 0.
+
device : tp.Union[torch.device, str], optional
+
Device for the conditioner.
+
**kwargs
+
Additional parameters for the chroma extractor.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ChromaStemConditioner(WaveformConditioner):
+    """Chroma conditioner based on stems.
+    The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
+    the drums and bass often dominate the chroma leading to the chroma features
+    not containing information about the melody.
+
+    Args:
+        output_dim (int): Output dimension for the conditioner.
+        sample_rate (int): Sample rate for the chroma extractor.
+        n_chroma (int): Number of chroma bins for the chroma extractor.
+        radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
+        duration (int): duration used during training. This is later used for correct padding
+            in case we are using chroma as prefix.
+        match_len_on_eval (bool, optional): if True then all chromas are padded to the training
+            duration. Defaults to False.
+        eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as
+            conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
+            Defaults to None.
+        n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
+        device (tp.Union[torch.device, str], optional): Device for the conditioner.
+        **kwargs: Additional parameters for the chroma extractor.
+    """
+    def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
+                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
+                 n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
+                 device: tp.Union[torch.device, str] = 'cpu', **kwargs):
+        from demucs import pretrained
+        super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
+        self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
+        self.sample_rate = sample_rate
+        self.match_len_on_eval = match_len_on_eval
+        self.duration = duration
+        self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
+        stem_sources: list = self.demucs.sources  # type: ignore
+        self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
+        self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
+                                      radix2_exp=radix2_exp, **kwargs).to(device)
+        self.chroma_len = self._get_chroma_len()
+        self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
+        self.cache = None
+        if cache_path is not None:
+            self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
+                                        compute_embed_fn=self._get_full_chroma_for_cache,
+                                        extract_embed_fn=self._extract_chroma_chunk)
+
+    def _downsampling_factor(self) -> int:
+        return self.chroma.winhop
+
+    def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
+        """Load pre-defined waveforms from a json.
+        These waveforms will be used for chroma extraction during evaluation.
+        This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
+        """
+        if path is None:
+            return None
+
+        logger.info(f"Loading evaluation wavs from {path}")
+        from audiocraft.data.audio_dataset import AudioDataset
+        dataset: AudioDataset = AudioDataset.from_meta(
+            path, segment_duration=self.duration, min_audio_duration=self.duration,
+            sample_rate=self.sample_rate, channels=1)
+
+        if len(dataset) > 0:
+            eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
+            logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
+            return eval_wavs
+        else:
+            raise ValueError("Could not find evaluation wavs, check lengths of wavs")
+
+    def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
+        self.eval_wavs = eval_wavs
+
+    def has_eval_wavs(self) -> bool:
+        return self.eval_wavs is not None
+
+    def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
+        """Sample wavs from a predefined list."""
+        assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
+        total_eval_wavs = len(self.eval_wavs)
+        out = self.eval_wavs
+        if num_samples > total_eval_wavs:
+            out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
+        return out[torch.randperm(len(out))][:num_samples]
+
+    def _get_chroma_len(self) -> int:
+        """Get length of chroma during training."""
+        dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
+        dummy_chr = self.chroma(dummy_wav)
+        return dummy_chr.shape[1]
+
+    @torch.no_grad()
+    def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        """Get parts of the wav that holds the melody, extracting the main stems from the wav."""
+        from demucs.apply import apply_model
+        from demucs.audio import convert_audio
+        with self.autocast:
+            wav = convert_audio(
+                wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels)  # type: ignore
+            stems = apply_model(self.demucs, wav, device=self.device)
+            stems = stems[:, self.stem_indices]  # extract relevant stems for melody conditioning
+            mix_wav = stems.sum(1)  # merge extracted stems to single waveform
+            mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1)  # type: ignore
+            return mix_wav
+
+    @torch.no_grad()
+    def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
+        """Extract chroma features from the waveform."""
+        with self.autocast:
+            return self.chroma(wav)
+
+    @torch.no_grad()
+    def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+        """Compute wav embedding, applying stem and chroma extraction."""
+        # avoid 0-size tensors when we are working with null conds
+        if wav.shape[-1] == 1:
+            return self._extract_chroma(wav)
+        stems = self._get_stemmed_wav(wav, sample_rate)
+        chroma = self._extract_chroma(stems)
+        return chroma
+
+    @torch.no_grad()
+    def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
+        """Extract chroma from the whole audio waveform at the given path."""
+        wav, sr = audio_read(path)
+        wav = wav[None].to(self.device)
+        wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
+        chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
+        return chroma
+
+    def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
+        """Extract a chunk of chroma from the full chroma derived from the full waveform."""
+        wav_length = x.wav.shape[-1]
+        seek_time = x.seek_time[idx]
+        assert seek_time is not None, (
+            "WavCondition seek_time is required "
+            "when extracting chroma chunks from pre-computed chroma.")
+        full_chroma = full_chroma.float()
+        frame_rate = self.sample_rate / self._downsampling_factor()
+        target_length = int(frame_rate * wav_length / self.sample_rate)
+        index = int(frame_rate * seek_time)
+        out = full_chroma[index: index + target_length]
+        out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
+        return out.to(self.device)
+
+    @torch.no_grad()
+    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
+        """Get the wav embedding from the WavCondition.
+        The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
+        or will rely on the embedding cache to load the pre-computed embedding if relevant.
+        """
+        sampled_wav: tp.Optional[torch.Tensor] = None
+        if not self.training and self.eval_wavs is not None:
+            warn_once(logger, "Using precomputed evaluation wavs!")
+            sampled_wav = self._sample_eval_wavs(len(x.wav))
+
+        no_undefined_paths = all(p is not None for p in x.path)
+        no_nullified_cond = x.wav.shape[-1] > 1
+        if sampled_wav is not None:
+            chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
+        elif self.cache is not None and no_undefined_paths and no_nullified_cond:
+            paths = [Path(p) for p in x.path if p is not None]
+            chroma = self.cache.get_embed_from_cache(paths, x)
+        else:
+            assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
+            chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
+
+        if self.match_len_on_eval:
+            B, T, C = chroma.shape
+            if T > self.chroma_len:
+                chroma = chroma[:, :self.chroma_len]
+                logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
+            elif T < self.chroma_len:
+                n_repeat = int(math.ceil(self.chroma_len / T))
+                chroma = chroma.repeat(1, n_repeat, 1)
+                chroma = chroma[:, :self.chroma_len]
+                logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
+
+        return chroma
+
+    def tokenize(self, x: WavCondition) -> WavCondition:
+        """Apply WavConditioner tokenization and populate cache if needed."""
+        x = super().tokenize(x)
+        no_undefined_paths = all(p is not None for p in x.path)
+        if self.cache is not None and no_undefined_paths:
+            paths = [Path(p) for p in x.path if p is not None]
+            self.cache.populate_embed_cache(paths, x)
+        return x
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def has_eval_wavs(self) ‑> bool +
+
+
+
+ +Expand source code + +
def has_eval_wavs(self) -> bool:
+    return self.eval_wavs is not None
+
+
+
+def reset_eval_wavs(self, eval_wavs: Optional[torch.Tensor]) ‑> None +
+
+
+
+ +Expand source code + +
def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
+    self.eval_wavs = eval_wavs
+
+
+
+def tokenize(self, x: WavCondition) ‑> WavCondition +
+
+

Apply WavConditioner tokenization and populate cache if needed.

+
+ +Expand source code + +
def tokenize(self, x: WavCondition) -> WavCondition:
+    """Apply WavConditioner tokenization and populate cache if needed."""
+    x = super().tokenize(x)
+    no_undefined_paths = all(p is not None for p in x.path)
+    if self.cache is not None and no_undefined_paths:
+        paths = [Path(p) for p in x.path if p is not None]
+        self.cache.populate_embed_cache(paths, x)
+    return x
+
+
+
+

Inherited members

+ +
+
+class ClassifierFreeGuidanceDropout +(p: float, seed: int = 1234) +
+
+

Classifier Free Guidance dropout. +All attributes are dropped with the same probability.

+

Args

+
+
p : float
+
Probability to apply condition dropout during training.
+
seed : int
+
Random seed.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ClassifierFreeGuidanceDropout(DropoutModule):
+    """Classifier Free Guidance dropout.
+    All attributes are dropped with the same probability.
+
+    Args:
+        p (float): Probability to apply condition dropout during training.
+        seed (int): Random seed.
+    """
+    def __init__(self, p: float, seed: int = 1234):
+        super().__init__(seed=seed)
+        self.p = p
+
+    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+        """
+        Args:
+            samples (list[ConditioningAttributes]): List of conditions.
+        Returns:
+            list[ConditioningAttributes]: List of conditions after all attributes were set to None.
+        """
+        if not self.training:
+            return samples
+
+        # decide on which attributes to drop in a batched fashion
+        drop = torch.rand(1, generator=self.rng).item() < self.p
+        if not drop:
+            return samples
+
+        # nullify conditions of all attributes
+        samples = deepcopy(samples)
+        for condition_type in ["wav", "text"]:
+            for sample in samples:
+                for condition in sample.attributes[condition_type]:
+                    dropout_condition(sample, condition_type, condition)
+        return samples
+
+    def __repr__(self):
+        return f"ClassifierFreeGuidanceDropout(p={self.p})"
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, samples: List[ConditioningAttributes]) ‑> List[ConditioningAttributes] +
+
+

Args

+
+
samples : list[ConditioningAttributes]
+
List of conditions.
+
+

Returns

+
+
list[ConditioningAttributes]
+
List of conditions after all attributes were set to None.
+
+
+ +Expand source code + +
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+    """
+    Args:
+        samples (list[ConditioningAttributes]): List of conditions.
+    Returns:
+        list[ConditioningAttributes]: List of conditions after all attributes were set to None.
+    """
+    if not self.training:
+        return samples
+
+    # decide on which attributes to drop in a batched fashion
+    drop = torch.rand(1, generator=self.rng).item() < self.p
+    if not drop:
+        return samples
+
+    # nullify conditions of all attributes
+    samples = deepcopy(samples)
+    for condition_type in ["wav", "text"]:
+        for sample in samples:
+            for condition in sample.attributes[condition_type]:
+                dropout_condition(sample, condition_type, condition)
+    return samples
+
+
+
+
+
+class ConditionFuser +(fuse2cond: Dict[str, List[str]], cross_attention_pos_emb: bool = False, cross_attention_pos_emb_scale: float = 1.0) +
+
+

Condition fuser handles the logic to combine the different conditions +to the actual model input.

+

Args

+
+
fuse2cond : tp.Dict[str, str]
+
A dictionary that says how to fuse +each condition. For example: +{ +"prepend": ["description"], +"sum": ["genre", "bpm"], +"cross": ["description"], +}
+
cross_attention_pos_emb : bool, optional
+
Use positional embeddings in cross attention.
+
cross_attention_pos_emb_scale : int
+
Scale for positional embeddings in cross attention if used.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ConditionFuser(StreamingModule):
+    """Condition fuser handles the logic to combine the different conditions
+    to the actual model input.
+
+    Args:
+        fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
+            each condition. For example:
+            {
+                "prepend": ["description"],
+                "sum": ["genre", "bpm"],
+                "cross": ["description"],
+            }
+        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
+        cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
+    """
+    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
+
+    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
+                 cross_attention_pos_emb_scale: float = 1.0):
+        super().__init__()
+        assert all(
+            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
+        ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
+        self.cross_attention_pos_emb = cross_attention_pos_emb
+        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
+        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
+        self.cond2fuse: tp.Dict[str, str] = {}
+        for fuse_method, conditions in fuse2cond.items():
+            for condition in conditions:
+                self.cond2fuse[condition] = fuse_method
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        conditions: tp.Dict[str, ConditionType]
+    ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        """Fuse the conditions to the provided model input.
+
+        Args:
+            input (torch.Tensor): Transformer input.
+            conditions (dict[str, ConditionType]): Dict of conditions.
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
+                after the conditions have been fused. The second output tensor is the tensor
+                used for cross-attention or None if no cross attention inputs exist.
+        """
+        B, T, _ = input.shape
+
+        if 'offsets' in self._streaming_state:
+            first_step = False
+            offsets = self._streaming_state['offsets']
+        else:
+            first_step = True
+            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
+        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
+            f"given conditions contain unknown attributes for fuser, " \
+            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
+        cross_attention_output = None
+        for cond_type, (cond, cond_mask) in conditions.items():
+            op = self.cond2fuse[cond_type]
+            if op == 'sum':
+                input += cond
+            elif op == 'input_interpolate':
+                cond = einops.rearrange(cond, "b t d -> b d t")
+                cond = F.interpolate(cond, size=input.shape[1])
+                input += einops.rearrange(cond, "b d t -> b t d")
+            elif op == 'prepend':
+                if first_step:
+                    input = torch.cat([cond, input], dim=1)
+            elif op == 'cross':
+                if cross_attention_output is not None:
+                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
+                else:
+                    cross_attention_output = cond
+            else:
+                raise ValueError(f"unknown op ({op})")
+
+        if self.cross_attention_pos_emb and cross_attention_output is not None:
+            positions = torch.arange(
+                cross_attention_output.shape[1],
+                device=cross_attention_output.device
+            ).view(1, -1, 1)
+            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
+            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+
+        if self._is_streaming:
+            self._streaming_state['offsets'] = offsets + T
+
+        return input, cross_attention_output
+
+

Ancestors

+ +

Class variables

+
+
var FUSING_METHODS
+
+
+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, input: torch.Tensor, conditions: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ‑> Tuple[torch.Tensor, Optional[torch.Tensor]] +
+
+

Fuse the conditions to the provided model input.

+

Args

+
+
input : torch.Tensor
+
Transformer input.
+
conditions : dict[str, ConditionType]
+
Dict of conditions.
+
+

Returns

+
+
tuple[torch.Tensor, torch.Tensor]
+
The first tensor is the transformer input +after the conditions have been fused. The second output tensor is the tensor +used for cross-attention or None if no cross attention inputs exist.
+
+
+ +Expand source code + +
def forward(
+    self,
+    input: torch.Tensor,
+    conditions: tp.Dict[str, ConditionType]
+) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+    """Fuse the conditions to the provided model input.
+
+    Args:
+        input (torch.Tensor): Transformer input.
+        conditions (dict[str, ConditionType]): Dict of conditions.
+    Returns:
+        tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
+            after the conditions have been fused. The second output tensor is the tensor
+            used for cross-attention or None if no cross attention inputs exist.
+    """
+    B, T, _ = input.shape
+
+    if 'offsets' in self._streaming_state:
+        first_step = False
+        offsets = self._streaming_state['offsets']
+    else:
+        first_step = True
+        offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
+    assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
+        f"given conditions contain unknown attributes for fuser, " \
+        f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
+    cross_attention_output = None
+    for cond_type, (cond, cond_mask) in conditions.items():
+        op = self.cond2fuse[cond_type]
+        if op == 'sum':
+            input += cond
+        elif op == 'input_interpolate':
+            cond = einops.rearrange(cond, "b t d -> b d t")
+            cond = F.interpolate(cond, size=input.shape[1])
+            input += einops.rearrange(cond, "b d t -> b t d")
+        elif op == 'prepend':
+            if first_step:
+                input = torch.cat([cond, input], dim=1)
+        elif op == 'cross':
+            if cross_attention_output is not None:
+                cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
+            else:
+                cross_attention_output = cond
+        else:
+            raise ValueError(f"unknown op ({op})")
+
+    if self.cross_attention_pos_emb and cross_attention_output is not None:
+        positions = torch.arange(
+            cross_attention_output.shape[1],
+            device=cross_attention_output.device
+        ).view(1, -1, 1)
+        pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
+        cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+
+    if self._is_streaming:
+        self._streaming_state['offsets'] = offsets + T
+
+    return input, cross_attention_output
+
+
+
+

Inherited members

+ +
+
+class ConditioningAttributes +(text: Dict[str, Optional[str]] = <factory>, wav: Dict[str, WavCondition] = <factory>, joint_embed: Dict[str, JointEmbedCondition] = <factory>) +
+
+

ConditioningAttributes(text: Dict[str, Union[str, NoneType]] = , wav: Dict[str, audiocraft.modules.conditioners.WavCondition] = , joint_embed: Dict[str, audiocraft.modules.conditioners.JointEmbedCondition] = )

+
+ +Expand source code + +
class ConditioningAttributes:
+    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
+    wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
+    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    @property
+    def text_attributes(self):
+        return self.text.keys()
+
+    @property
+    def wav_attributes(self):
+        return self.wav.keys()
+
+    @property
+    def joint_embed_attributes(self):
+        return self.joint_embed.keys()
+
+    @property
+    def attributes(self):
+        return {
+            "text": self.text_attributes,
+            "wav": self.wav_attributes,
+            "joint_embed": self.joint_embed_attributes,
+        }
+
+    def to_flat_dict(self):
+        return {
+            **{f"text.{k}": v for k, v in self.text.items()},
+            **{f"wav.{k}": v for k, v in self.wav.items()},
+            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
+        }
+
+    @classmethod
+    def from_flat_dict(cls, x):
+        out = cls()
+        for k, v in x.items():
+            kind, att = k.split(".")
+            out[kind][att] = v
+        return out
+
+

Class variables

+
+
var joint_embed : Dict[str, JointEmbedCondition]
+
+
+
+
var text : Dict[str, Optional[str]]
+
+
+
+
var wav : Dict[str, WavCondition]
+
+
+
+
+

Static methods

+
+
+def from_flat_dict(x) +
+
+
+
+ +Expand source code + +
@classmethod
+def from_flat_dict(cls, x):
+    out = cls()
+    for k, v in x.items():
+        kind, att = k.split(".")
+        out[kind][att] = v
+    return out
+
+
+
+

Instance variables

+
+
var attributes
+
+
+
+ +Expand source code + +
@property
+def attributes(self):
+    return {
+        "text": self.text_attributes,
+        "wav": self.wav_attributes,
+        "joint_embed": self.joint_embed_attributes,
+    }
+
+
+
var joint_embed_attributes
+
+
+
+ +Expand source code + +
@property
+def joint_embed_attributes(self):
+    return self.joint_embed.keys()
+
+
+
var text_attributes
+
+
+
+ +Expand source code + +
@property
+def text_attributes(self):
+    return self.text.keys()
+
+
+
var wav_attributes
+
+
+
+ +Expand source code + +
@property
+def wav_attributes(self):
+    return self.wav.keys()
+
+
+
+

Methods

+
+
+def to_flat_dict(self) +
+
+
+
+ +Expand source code + +
def to_flat_dict(self):
+    return {
+        **{f"text.{k}": v for k, v in self.text.items()},
+        **{f"wav.{k}": v for k, v in self.wav.items()},
+        **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
+    }
+
+
+
+
+
+class ConditioningProvider +(conditioners: Dict[str, BaseConditioner], device: Union[torch.device, str] = 'cpu') +
+
+

Prepare and provide conditions given all the supported conditioners.

+

Args

+
+
conditioners : dict
+
Dictionary of conditioners.
+
device : torch.device or str, optional
+
Device for conditioners and output condition types.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ConditioningProvider(nn.Module):
+    """Prepare and provide conditions given all the supported conditioners.
+
+    Args:
+        conditioners (dict): Dictionary of conditioners.
+        device (torch.device or str, optional): Device for conditioners and output condition types.
+    """
+    def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
+        super().__init__()
+        self.device = device
+        self.conditioners = nn.ModuleDict(conditioners)
+
+    @property
+    def joint_embed_conditions(self):
+        return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
+
+    @property
+    def has_joint_embed_conditions(self):
+        return len(self.joint_embed_conditions) > 0
+
+    @property
+    def text_conditions(self):
+        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
+
+    @property
+    def wav_conditions(self):
+        return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
+
+    @property
+    def has_wav_condition(self):
+        return len(self.wav_conditions) > 0
+
+    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
+        """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
+        This should be called before starting any real GPU work to avoid synchronization points.
+        This will return a dict matching conditioner names to their arbitrary tokenized representations.
+
+        Args:
+            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
+                text and wav conditions.
+        """
+        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
+            "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
+            f" but types were {set([type(x) for x in inputs])}"
+        )
+
+        output = {}
+        text = self._collate_text(inputs)
+        wavs = self._collate_wavs(inputs)
+        joint_embeds = self._collate_joint_embeds(inputs)
+
+        assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
+            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
+            f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
+        )
+
+        for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
+            output[attribute] = self.conditioners[attribute].tokenize(batch)
+        return output
+
+    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
+        """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
+        The output is for example:
+        {
+            "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
+            "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
+            ...
+        }
+
+        Args:
+            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
+        """
+        output = {}
+        for attribute, inputs in tokenized.items():
+            condition, mask = self.conditioners[attribute](inputs)
+            output[attribute] = (condition, mask)
+        return output
+
+    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
+        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
+        are the attributes and the values are the aggregated input per attribute.
+        For example:
+        Input:
+        [
+            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
+            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
+        ]
+        Output:
+        {
+            "genre": ["Rock", "Hip-hop"],
+            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
+        }
+
+        Args:
+            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
+        Returns:
+            dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
+        """
+        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
+        texts = [x.text for x in samples]
+        for text in texts:
+            for condition in self.text_conditions:
+                out[condition].append(text[condition])
+        return out
+
+    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
+        """Generate a dict where the keys are attributes by which we fetch similar wavs,
+        and the values are Tensors of wavs according to said attributes.
+
+        *Note*: by the time the samples reach this function, each sample should have some waveform
+        inside the "wav" attribute. It should be either:
+        1. A real waveform
+        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
+        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
+
+        Args:
+            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
+        Returns:
+            dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
+        """
+        wavs = defaultdict(list)
+        lengths = defaultdict(list)
+        sample_rates = defaultdict(list)
+        paths = defaultdict(list)
+        seek_times = defaultdict(list)
+        out: tp.Dict[str, WavCondition] = {}
+
+        for sample in samples:
+            for attribute in self.wav_conditions:
+                wav, length, sample_rate, path, seek_time = sample.wav[attribute]
+                assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
+                assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
+                # mono-channel conditioning
+                wav = wav.mean(1, keepdim=True)  # [1, 1, T]
+                wavs[attribute].append(wav.flatten())  # [T]
+                lengths[attribute].append(length)
+                sample_rates[attribute].extend(sample_rate)
+                paths[attribute].extend(path)
+                seek_times[attribute].extend(seek_time)
+
+        # stack all wavs to a single tensor
+        for attribute in self.wav_conditions:
+            stacked_wav, _ = collate(wavs[attribute], dim=0)
+            out[attribute] = WavCondition(
+                stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
+                paths[attribute], seek_times[attribute])
+
+        return out
+
+    def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
+        """Generate a dict where the keys are attributes by which we compute joint embeddings,
+        and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
+
+        Args:
+            samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
+        Returns:
+            A dictionary mapping an attribute name to joint embeddings.
+        """
+        texts = defaultdict(list)
+        wavs = defaultdict(list)
+        lengths = defaultdict(list)
+        sample_rates = defaultdict(list)
+        paths = defaultdict(list)
+        seek_times = defaultdict(list)
+        channels: int = 0
+
+        out = {}
+        for sample in samples:
+            for attribute in self.joint_embed_conditions:
+                wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
+                assert wav.dim() == 3
+                if channels == 0:
+                    channels = wav.size(1)
+                else:
+                    assert channels == wav.size(1), "not all audio has same number of channels in batch"
+                assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
+                wav = einops.rearrange(wav, "b c t -> (b c t)")  # [1, C, T] => [C * T]
+                wavs[attribute].append(wav)
+                texts[attribute].extend(text)
+                lengths[attribute].append(length)
+                sample_rates[attribute].extend(sample_rate)
+                paths[attribute].extend(path)
+                seek_times[attribute].extend(seek_time)
+
+        for attribute in self.joint_embed_conditions:
+            stacked_texts = texts[attribute]
+            stacked_paths = paths[attribute]
+            stacked_seek_times = seek_times[attribute]
+            stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
+            stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
+            stacked_sample_rates = sample_rates[attribute]
+            stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
+            assert stacked_lengths.size(0) == stacked_wavs.size(0)
+            assert len(stacked_sample_rates) == stacked_wavs.size(0)
+            assert len(stacked_texts) == stacked_wavs.size(0)
+            out[attribute] = JointEmbedCondition(
+                text=stacked_texts, wav=stacked_wavs,
+                length=stacked_lengths, sample_rate=stacked_sample_rates,
+                path=stacked_paths, seek_time=stacked_seek_times)
+
+        return out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var has_joint_embed_conditions
+
+
+
+ +Expand source code + +
@property
+def has_joint_embed_conditions(self):
+    return len(self.joint_embed_conditions) > 0
+
+
+
var has_wav_condition
+
+
+
+ +Expand source code + +
@property
+def has_wav_condition(self):
+    return len(self.wav_conditions) > 0
+
+
+
var joint_embed_conditions
+
+
+
+ +Expand source code + +
@property
+def joint_embed_conditions(self):
+    return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
+
+
+
var text_conditions
+
+
+
+ +Expand source code + +
@property
+def text_conditions(self):
+    return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
+
+
+
var wav_conditions
+
+
+
+ +Expand source code + +
@property
+def wav_conditions(self):
+    return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
+
+
+
+

Methods

+
+
+def forward(self, tokenized: Dict[str, Any]) ‑> Dict[str, Tuple[torch.Tensor, torch.Tensor]] +
+
+

Compute pairs of (embedding, mask) using the configured conditioners and the tokenized representations. +The output is for example: +{ +"genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), +"description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), +… +}

+

Args

+
+
tokenized : dict
+
Dict of tokenized representations as returned by tokenize().
+
+
+ +Expand source code + +
def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
+    """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
+    The output is for example:
+    {
+        "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
+        "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
+        ...
+    }
+
+    Args:
+        tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
+    """
+    output = {}
+    for attribute, inputs in tokenized.items():
+        condition, mask = self.conditioners[attribute](inputs)
+        output[attribute] = (condition, mask)
+    return output
+
+
+
+def tokenize(self, inputs: List[ConditioningAttributes]) ‑> Dict[str, Any] +
+
+

Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly. +This should be called before starting any real GPU work to avoid synchronization points. +This will return a dict matching conditioner names to their arbitrary tokenized representations.

+

Args

+
+
inputs : list[ConditioningAttributes]
+
List of ConditioningAttributes objects containing +text and wav conditions.
+
+
+ +Expand source code + +
def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
+    """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
+    This should be called before starting any real GPU work to avoid synchronization points.
+    This will return a dict matching conditioner names to their arbitrary tokenized representations.
+
+    Args:
+        inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
+            text and wav conditions.
+    """
+    assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
+        "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
+        f" but types were {set([type(x) for x in inputs])}"
+    )
+
+    output = {}
+    text = self._collate_text(inputs)
+    wavs = self._collate_wavs(inputs)
+    joint_embeds = self._collate_joint_embeds(inputs)
+
+    assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
+        f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
+        f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
+    )
+
+    for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
+        output[attribute] = self.conditioners[attribute].tokenize(batch)
+    return output
+
+
+
+
+
+class DropoutModule +(seed: int = 1234) +
+
+

Base module for all dropout modules.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DropoutModule(nn.Module):
+    """Base module for all dropout modules."""
+    def __init__(self, seed: int = 1234):
+        super().__init__()
+        self.rng = torch.Generator()
+        self.rng.manual_seed(seed)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, *input: Any) ‑> None +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def _forward_unimplemented(self, *input: Any) -> None:
+    r"""Defines the computation performed at every call.
+
+    Should be overridden by all subclasses.
+
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`Module` instance afterwards
+        instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
+    """
+    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
+
+
+
+
+
+class JointEmbedCondition +(wav: torch.Tensor, text: List[Optional[str]], length: torch.Tensor, sample_rate: List[int], path: List[Optional[str]] = [], seek_time: List[Optional[float]] = []) +
+
+

JointEmbedCondition(wav, text, length, sample_rate, path, seek_time)

+
+ +Expand source code + +
class JointEmbedCondition(tp.NamedTuple):
+    wav: torch.Tensor
+    text: tp.List[tp.Optional[str]]
+    length: torch.Tensor
+    sample_rate: tp.List[int]
+    path: tp.List[tp.Optional[str]] = []
+    seek_time: tp.List[tp.Optional[float]] = []
+
+

Ancestors

+
    +
  • builtins.tuple
  • +
+

Instance variables

+
+
var length : torch.Tensor
+
+

Alias for field number 2

+
+
var path : List[Optional[str]]
+
+

Alias for field number 4

+
+
var sample_rate : List[int]
+
+

Alias for field number 3

+
+
var seek_time : List[Optional[float]]
+
+

Alias for field number 5

+
+
var text : List[Optional[str]]
+
+

Alias for field number 1

+
+
var wav : torch.Tensor
+
+

Alias for field number 0

+
+
+
+
+class JointEmbeddingConditioner +(dim: int, output_dim: int, device: str, attribute: str, autocast_dtype: Optional[str] = 'float32', quantize: bool = True, n_q: int = 12, bins: int = 1024, **kwargs) +
+
+

Joint embedding conditioning supporting both audio or text conditioning.

+

Args

+
+
dim : int
+
Dimension.
+
output_dim : int
+
Output dimension.
+
device : str
+
Device.
+
attribute : str
+
Attribute used by the conditioner.
+
autocast_dtype : str
+
Autocast for the conditioner.
+
quantize : bool
+
Whether to quantize the CLAP embedding.
+
n_q : int
+
Number of residual quantizers (used if quantize is true).
+
bins : int
+
Quantizers' codebooks size (used if quantize is true).
+
kwargs
+
Additional parameters for residual vector quantizer.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class JointEmbeddingConditioner(BaseConditioner):
+    """Joint embedding conditioning supporting both audio or text conditioning.
+
+    Args:
+        dim (int): Dimension.
+        output_dim (int): Output dimension.
+        device (str): Device.
+        attribute (str): Attribute used by the conditioner.
+        autocast_dtype (str): Autocast for the conditioner.
+        quantize (bool): Whether to quantize the CLAP embedding.
+        n_q (int): Number of residual quantizers (used if quantize is true).
+        bins (int): Quantizers' codebooks size (used if quantize is true).
+        kwargs: Additional parameters for residual vector quantizer.
+    """
+    def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
+                 autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
+                 n_q: int = 12, bins: int = 1024, **kwargs):
+        super().__init__(dim=dim, output_dim=output_dim)
+        self.device = device
+        self.attribute = attribute
+        if autocast_dtype is None or device == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+            logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
+        else:
+            dtype = getattr(torch, autocast_dtype)
+            assert isinstance(dtype, torch.dtype)
+            logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
+            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+        # residual vector quantizer to discretize the conditioned embedding
+        self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
+        if quantize:
+            self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
+
+    def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Get joint embedding in latent space from the inputs.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
+                and corresponding empty indexes.
+        """
+        raise NotImplementedError()
+
+    def forward(self, x: JointEmbedCondition) -> ConditionType:
+        with self.autocast:
+            embed, empty_idx = self._get_embed(x)
+            if self.quantizer is not None:
+                embed = embed.view(-1, self.dim, 1)
+                q_res = self.quantizer(embed, frame_rate=1)
+                out_embed = q_res.x.view(-1, self.dim)
+            else:
+                out_embed = embed
+            out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
+            mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
+            mask[empty_idx, :] = 0  # zero-out index where the input is non-existant
+            out_embed = (out_embed * mask.unsqueeze(-1))
+            return out_embed, mask
+
+    def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
+        return x
+
+

Ancestors

+ +

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class LUTConditioner +(n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0) +
+
+

Lookup table TextConditioner.

+

Args

+
+
n_bins : int
+
Number of bins.
+
dim : int
+
Hidden dim of the model (text-encoder/LUT).
+
output_dim : int
+
Output dim of the conditioner.
+
tokenizer : str
+
Name of the tokenizer.
+
pad_idx : int, optional
+
Index for padding token. Defaults to 0.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class LUTConditioner(TextConditioner):
+    """Lookup table TextConditioner.
+
+    Args:
+        n_bins (int): Number of bins.
+        dim (int): Hidden dim of the model (text-encoder/LUT).
+        output_dim (int): Output dim of the conditioner.
+        tokenizer (str): Name of the tokenizer.
+        pad_idx (int, optional): Index for padding token. Defaults to 0.
+    """
+    def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
+        super().__init__(dim, output_dim)
+        self.embed = nn.Embedding(n_bins, dim)
+        self.tokenizer: Tokenizer
+        if tokenizer == 'whitespace':
+            self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
+        elif tokenizer == 'noop':
+            self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
+        else:
+            raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
+
+    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        device = self.embed.weight.device
+        tokens, mask = self.tokenizer(x)
+        tokens, mask = tokens.to(device), mask.to(device)
+        return tokens, mask
+
+    def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
+        tokens, mask = inputs
+        embeds = self.embed(tokens)
+        embeds = self.output_proj(embeds)
+        embeds = (embeds * mask.unsqueeze(-1))
+        return embeds, mask
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class NoopTokenizer +(n_bins: int, pad_idx: int = 0) +
+
+

This tokenizer should be used for global conditioners such as: artist, genre, key, etc. +The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split +strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will +split it to ["Jeff", "Buckley"] and return an index per word.

+

For example: +["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101] +["Metal", "Rock", "Classical"] => [0, 223, 51]

+
+ +Expand source code + +
class NoopTokenizer(Tokenizer):
+    """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
+    The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
+    strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
+    split it to ["Jeff", "Buckley"] and return an index per word.
+
+    For example:
+    ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
+    ["Metal", "Rock", "Classical"] => [0, 223, 51]
+    """
+    def __init__(self, n_bins: int, pad_idx: int = 0):
+        self.n_bins = n_bins
+        self.pad_idx = pad_idx
+
+    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        output, lengths = [], []
+        for text in texts:
+            # if current sample doesn't have a certain attribute, replace with pad token
+            if text is None:
+                output.append(self.pad_idx)
+                lengths.append(0)
+            else:
+                output.append(hash_trick(text, self.n_bins))
+                lengths.append(1)
+
+        tokens = torch.LongTensor(output).unsqueeze(1)
+        mask = length_to_mask(torch.IntTensor(lengths)).int()
+        return tokens, mask
+
+

Ancestors

+ +
+
+class SegmentWithAttributes +(meta: AudioMeta, seek_time: float, n_frames: int, total_frames: int, sample_rate: int, channels: int) +
+
+

Base class for all dataclasses that are used for conditioning. +All child classes should implement to_condition_attributes that converts +the existing attributes to a dataclass of type ConditioningAttributes.

+
+ +Expand source code + +
class SegmentWithAttributes(SegmentInfo):
+    """Base class for all dataclasses that are used for conditioning.
+    All child classes should implement `to_condition_attributes` that converts
+    the existing attributes to a dataclass of type ConditioningAttributes.
+    """
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        raise NotImplementedError()
+
+

Ancestors

+ +

Subclasses

+ +

Class variables

+
+
var channels : int
+
+
+
+
var metaAudioMeta
+
+
+
+
var n_frames : int
+
+
+
+
var sample_rate : int
+
+
+
+
var seek_time : float
+
+
+
+
var total_frames : int
+
+
+
+
+

Methods

+
+
+def to_condition_attributes(self) ‑> ConditioningAttributes +
+
+
+
+ +Expand source code + +
def to_condition_attributes(self) -> ConditioningAttributes:
+    raise NotImplementedError()
+
+
+
+
+
+class T5Conditioner +(name: str, output_dim: int, finetune: bool, device: str, autocast_dtype: Optional[str] = 'float32', word_dropout: float = 0.0, normalize_text: bool = False) +
+
+

T5-based TextConditioner.

+

Args

+
+
name : str
+
Name of the T5 model.
+
output_dim : int
+
Output dim of the conditioner.
+
finetune : bool
+
Whether to fine-tune T5 at train time.
+
device : str
+
Device for T5 Conditioner.
+
autocast_dtype : tp.Optional[str], optional
+
Autocast dtype.
+
word_dropout : float, optional
+
Word dropout probability.
+
normalize_text : bool, optional
+
Whether to apply text normalization.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class T5Conditioner(TextConditioner):
+    """T5-based TextConditioner.
+
+    Args:
+        name (str): Name of the T5 model.
+        output_dim (int): Output dim of the conditioner.
+        finetune (bool): Whether to fine-tune T5 at train time.
+        device (str): Device for T5 Conditioner.
+        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
+        word_dropout (float, optional): Word dropout probability.
+        normalize_text (bool, optional): Whether to apply text normalization.
+    """
+    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
+              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
+              "google/flan-t5-xl", "google/flan-t5-xxl"]
+    MODELS_DIMS = {
+        "t5-small": 512,
+        "t5-base": 768,
+        "t5-large": 1024,
+        "t5-3b": 1024,
+        "t5-11b": 1024,
+        "google/flan-t5-small": 512,
+        "google/flan-t5-base": 768,
+        "google/flan-t5-large": 1024,
+        "google/flan-t5-3b": 1024,
+        "google/flan-t5-11b": 1024,
+    }
+
+    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
+                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
+                 normalize_text: bool = False):
+        assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
+        super().__init__(self.MODELS_DIMS[name], output_dim)
+        self.device = device
+        self.name = name
+        self.finetune = finetune
+        self.word_dropout = word_dropout
+        if autocast_dtype is None or self.device == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+            if self.device != 'cpu':
+                logger.warning("T5 has no autocast, this might lead to NaN")
+        else:
+            dtype = getattr(torch, autocast_dtype)
+            assert isinstance(dtype, torch.dtype)
+            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
+            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+        # Let's disable logging temporarily because T5 will vomit some errors otherwise.
+        # thanks https://gist.github.com/simon-weber/7853144
+        previous_level = logging.root.manager.disable
+        logging.disable(logging.ERROR)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            try:
+                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
+                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
+            finally:
+                logging.disable(previous_level)
+        if finetune:
+            self.t5 = t5
+        else:
+            # this makes sure that the t5 models is not part
+            # of the saved checkpoint
+            self.__dict__['t5'] = t5.to(device)
+
+        self.normalize_text = normalize_text
+        if normalize_text:
+            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
+
+    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
+        # if current sample doesn't have a certain attribute, replace with empty string
+        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
+        if self.normalize_text:
+            _, _, entries = self.text_normalizer(entries, return_text=True)
+        if self.word_dropout > 0. and self.training:
+            new_entries = []
+            for entry in entries:
+                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
+                new_entries.append(" ".join(words))
+            entries = new_entries
+
+        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
+
+        inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
+        mask = inputs['attention_mask']
+        mask[empty_idx, :] = 0  # zero-out index where the input is non-existant
+        return inputs
+
+    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
+        mask = inputs['attention_mask']
+        with torch.set_grad_enabled(self.finetune), self.autocast:
+            embeds = self.t5(**inputs).last_hidden_state
+        embeds = self.output_proj(embeds.to(self.output_proj.weight))
+        embeds = (embeds * mask.unsqueeze(-1))
+        return embeds, mask
+
+

Ancestors

+ +

Class variables

+
+
var MODELS
+
+
+
+
var MODELS_DIMS
+
+
+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class TextConditioner +(dim: int, output_dim: int) +
+
+

Base model for all conditioner modules. +We allow the output dim to be different than the hidden dim for two reasons: +1) keep our LUTs small when the vocab is large; +2) make all condition dims consistent.

+

Args

+
+
dim : int
+
Hidden dim of the model.
+
output_dim : int
+
Output dim of the conditioner.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class TextConditioner(BaseConditioner):
+    ...
+
+

Ancestors

+ +

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class Tokenizer +
+
+

Base tokenizer implementation +(in case we want to introduce more advances tokenizers in the future).

+
+ +Expand source code + +
class Tokenizer:
+    """Base tokenizer implementation
+    (in case we want to introduce more advances tokenizers in the future).
+    """
+    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+

Subclasses

+ +
+
+class WavCondition +(wav: torch.Tensor, length: torch.Tensor, sample_rate: List[int], path: List[Optional[str]] = [], seek_time: List[Optional[float]] = []) +
+
+

WavCondition(wav, length, sample_rate, path, seek_time)

+
+ +Expand source code + +
class WavCondition(tp.NamedTuple):
+    wav: torch.Tensor
+    length: torch.Tensor
+    sample_rate: tp.List[int]
+    path: tp.List[tp.Optional[str]] = []
+    seek_time: tp.List[tp.Optional[float]] = []
+
+

Ancestors

+
    +
  • builtins.tuple
  • +
+

Instance variables

+
+
var length : torch.Tensor
+
+

Alias for field number 1

+
+
var path : List[Optional[str]]
+
+

Alias for field number 3

+
+
var sample_rate : List[int]
+
+

Alias for field number 2

+
+
var seek_time : List[Optional[float]]
+
+

Alias for field number 4

+
+
var wav : torch.Tensor
+
+

Alias for field number 0

+
+
+
+
+class WaveformConditioner +(dim: int, output_dim: int, device: Union[torch.device, str]) +
+
+

Base class for all conditioners that take a waveform as input. +Classes that inherit must implement _get_wav_embedding that outputs +a continuous tensor, and _downsampling_factor that returns the down-sampling +factor of the embedding model.

+

Args

+
+
dim : int
+
The internal representation dimension.
+
output_dim : int
+
Output dimension.
+
device : tp.Union[torch.device, str]
+
Device.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class WaveformConditioner(BaseConditioner):
+    """Base class for all conditioners that take a waveform as input.
+    Classes that inherit must implement `_get_wav_embedding` that outputs
+    a continuous tensor, and `_downsampling_factor` that returns the down-sampling
+    factor of the embedding model.
+
+    Args:
+        dim (int): The internal representation dimension.
+        output_dim (int): Output dimension.
+        device (tp.Union[torch.device, str]): Device.
+    """
+    def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
+        super().__init__(dim, output_dim)
+        self.device = device
+
+    def tokenize(self, x: WavCondition) -> WavCondition:
+        wav, length, sample_rate, path, seek_time = x
+        assert length is not None
+        return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
+
+    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
+        """Gets as input a WavCondition and returns a dense embedding."""
+        raise NotImplementedError()
+
+    def _downsampling_factor(self):
+        """Returns the downsampling factor of the embedding model."""
+        raise NotImplementedError()
+
+    def forward(self, x: WavCondition) -> ConditionType:
+        """Extract condition embedding and mask from a waveform and its metadata.
+        Args:
+            x (WavCondition): Waveform condition containing raw waveform and metadata.
+        Returns:
+            ConditionType: a dense vector representing the conditioning along with its mask
+        """
+        wav, lengths, *_ = x
+        with torch.no_grad():
+            embeds = self._get_wav_embedding(x)
+        embeds = embeds.to(self.output_proj.weight)
+        embeds = self.output_proj(embeds)
+
+        if lengths is not None:
+            lengths = lengths / self._downsampling_factor()
+            mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
+        else:
+            mask = torch.ones_like(embeds)
+        embeds = (embeds * mask.unsqueeze(2).to(self.device))
+
+        return embeds, mask
+
+

Ancestors

+ +

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: WavCondition) ‑> Tuple[torch.Tensor, torch.Tensor] +
+
+

Extract condition embedding and mask from a waveform and its metadata.

+

Args

+
+
x : WavCondition
+
Waveform condition containing raw waveform and metadata.
+
+

Returns

+
+
ConditionType
+
a dense vector representing the conditioning along with its mask
+
+
+ +Expand source code + +
def forward(self, x: WavCondition) -> ConditionType:
+    """Extract condition embedding and mask from a waveform and its metadata.
+    Args:
+        x (WavCondition): Waveform condition containing raw waveform and metadata.
+    Returns:
+        ConditionType: a dense vector representing the conditioning along with its mask
+    """
+    wav, lengths, *_ = x
+    with torch.no_grad():
+        embeds = self._get_wav_embedding(x)
+    embeds = embeds.to(self.output_proj.weight)
+    embeds = self.output_proj(embeds)
+
+    if lengths is not None:
+        lengths = lengths / self._downsampling_factor()
+        mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
+    else:
+        mask = torch.ones_like(embeds)
+    embeds = (embeds * mask.unsqueeze(2).to(self.device))
+
+    return embeds, mask
+
+
+
+

Inherited members

+ +
+
+class WhiteSpaceTokenizer +(n_bins: int, pad_idx: int = 0, language: str = 'en_core_web_sm', lemma: bool = True, stopwords: bool = True) +
+
+

This tokenizer should be used for natural language descriptions. +For example: +["he didn't, know he's going home.", 'shorter sentence'] => +[[78, 62, 31, +4, 78, 25, 19, 34], +[59, 77, +0, +0, +0, +0, +0, +0]]

+
+ +Expand source code + +
class WhiteSpaceTokenizer(Tokenizer):
+    """This tokenizer should be used for natural language descriptions.
+    For example:
+    ["he didn't, know he's going home.", 'shorter sentence'] =>
+    [[78, 62, 31,  4, 78, 25, 19, 34],
+    [59, 77,  0,  0,  0,  0,  0,  0]]
+    """
+    PUNCTUATION = "?:!.,;"
+
+    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
+                 lemma: bool = True, stopwords: bool = True) -> None:
+        self.n_bins = n_bins
+        self.pad_idx = pad_idx
+        self.lemma = lemma
+        self.stopwords = stopwords
+        try:
+            self.nlp = spacy.load(language)
+        except IOError:
+            spacy.cli.download(language)  # type: ignore
+            self.nlp = spacy.load(language)
+
+    @tp.no_type_check
+    def __call__(self, texts: tp.List[tp.Optional[str]],
+                 return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """Take a list of strings and convert them to a tensor of indices.
+
+        Args:
+            texts (list[str]): List of strings.
+            return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]:
+                - Indices of words in the LUT.
+                - And a mask indicating where the padding tokens are
+        """
+        output, lengths = [], []
+        texts = deepcopy(texts)
+        for i, text in enumerate(texts):
+            # if current sample doesn't have a certain attribute, replace with pad token
+            if text is None:
+                output.append(torch.Tensor([self.pad_idx]))
+                lengths.append(0)
+                continue
+
+            # convert numbers to words
+            text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text)  # type: ignore
+            # normalize text
+            text = self.nlp(text)  # type: ignore
+            # remove stopwords
+            if self.stopwords:
+                text = [w for w in text if not w.is_stop]  # type: ignore
+            # remove punctuation
+            text = [w for w in text if w.text not in self.PUNCTUATION]  # type: ignore
+            # lemmatize if needed
+            text = [getattr(t, "lemma_" if self.lemma else "text") for t in text]  # type: ignore
+
+            texts[i] = " ".join(text)
+            lengths.append(len(text))
+            # convert to tensor
+            tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
+            output.append(tokens)
+
+        mask = length_to_mask(torch.IntTensor(lengths)).int()
+        padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
+        if return_text:
+            return padded_output, mask, texts  # type: ignore
+        return padded_output, mask
+
+

Ancestors

+ +

Class variables

+
+
var PUNCTUATION
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/conv.html b/api_docs/audiocraft/modules/conv.html new file mode 100644 index 00000000..98c6c6a8 --- /dev/null +++ b/api_docs/audiocraft/modules/conv.html @@ -0,0 +1,1044 @@ + + + + + + +audiocraft.modules.conv API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.conv

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm
+
+
+CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
+                                 'time_group_norm'])
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'weight_norm':
+        return weight_norm(module)
+    elif norm == 'spectral_norm':
+        return spectral_norm(module)
+    else:
+        # We already check was in CONV_NORMALIZATION, so any other choice
+        # doesn't need reparametrization.
+        return module
+
+
+def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
+    """Return the proper normalization module. If causal is True, this will ensure the returned
+    module is causal, or return an error if the normalization doesn't support causal evaluation.
+    """
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'time_group_norm':
+        if causal:
+            raise ValueError("GroupNorm doesn't support causal evaluation.")
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+    else:
+        return nn.Identity()
+
+
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+                                 padding_total: int = 0) -> int:
+    """See `pad_for_conv1d`."""
+    length = x.shape[-1]
+    n_frames = (length - kernel_size + padding_total) / stride + 1
+    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    return ideal_length - length
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+    """Pad for a convolution to make sure that the last window is full.
+    Extra padding is added at the end. This is required to ensure that we can rebuild
+    an output of the same length, as otherwise, even with padding, some time steps
+    might get removed.
+    For instance, with total padding = 4, kernel size = 4, stride = 2:
+        0 0 1 2 3 4 5 0 0   # (0s are padding)
+        1   2   3           # (output frames of a convolution, last 0 is never used)
+        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+            1 2 3 4         # once you removed padding, we are missing one time step !
+    """
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    return F.pad(x, (0, extra_padding))
+
+
+def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
+    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+    If this is the case, we insert extra 0 padding to the right before the reflection happen.
+    """
+    length = x.shape[-1]
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    if mode == 'reflect':
+        max_pad = max(padding_left, padding_right)
+        extra_pad = 0
+        if length <= max_pad:
+            extra_pad = max_pad - length + 1
+            x = F.pad(x, (0, extra_pad))
+        padded = F.pad(x, paddings, mode, value)
+        end = padded.shape[-1] - extra_pad
+        return padded[..., :end]
+    else:
+        return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+    """Remove padding from x, handling properly zero padding. Only for 1d!"""
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    assert (padding_left + padding_right) <= x.shape[-1]
+    end = x.shape[-1] - padding_right
+    return x[..., padding_left: end]
+
+
+class NormConv1d(nn.Module):
+    """Wrapper around Conv1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConv2d(nn.Module):
+    """Wrapper around Conv2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConvTranspose1d(nn.Module):
+    """Wrapper around ConvTranspose1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConvTranspose2d(nn.Module):
+    """Wrapper around ConvTranspose2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+
+class StreamableConv1d(nn.Module):
+    """Conv1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+    def __init__(self, in_channels: int, out_channels: int,
+                 kernel_size: int, stride: int = 1, dilation: int = 1,
+                 groups: int = 1, bias: bool = True, causal: bool = False,
+                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
+                 pad_mode: str = 'reflect'):
+        super().__init__()
+        # warn user on unusual setup between dilation and stride
+        if stride > 1 and dilation > 1:
+            warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
+                          f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
+        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
+                               dilation=dilation, groups=groups, bias=bias, causal=causal,
+                               norm=norm, norm_kwargs=norm_kwargs)
+        self.causal = causal
+        self.pad_mode = pad_mode
+
+    def forward(self, x):
+        B, C, T = x.shape
+        kernel_size = self.conv.conv.kernel_size[0]
+        stride = self.conv.conv.stride[0]
+        dilation = self.conv.conv.dilation[0]
+        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
+        padding_total = kernel_size - stride
+        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+        if self.causal:
+            # Left padding for causal
+            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+        return self.conv(x)
+
+
+class StreamableConvTranspose1d(nn.Module):
+    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+    def __init__(self, in_channels: int, out_channels: int,
+                 kernel_size: int, stride: int = 1, causal: bool = False,
+                 norm: str = 'none', trim_right_ratio: float = 1.,
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
+        super().__init__()
+        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
+                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
+        self.causal = causal
+        self.trim_right_ratio = trim_right_ratio
+        assert self.causal or self.trim_right_ratio == 1., \
+            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
+
+    def forward(self, x):
+        kernel_size = self.convtr.convtr.kernel_size[0]
+        stride = self.convtr.convtr.stride[0]
+        padding_total = kernel_size - stride
+
+        y = self.convtr(x)
+
+        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+        # removed at the very end, when keeping only the right length for the output,
+        # as removing it here would require also passing the length at the matching layer
+        # in the encoder.
+        if self.causal:
+            # Trim the padding on the right according to the specified ratio
+            # if trim_right_ratio = 1.0, trim everything from right
+            padding_right = math.ceil(padding_total * self.trim_right_ratio)
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        return y
+
+
+
+
+
+
+
+

Functions

+
+
+def apply_parametrization_norm(module: torch.nn.modules.module.Module, norm: str = 'none') +
+
+
+
+ +Expand source code + +
def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'weight_norm':
+        return weight_norm(module)
+    elif norm == 'spectral_norm':
+        return spectral_norm(module)
+    else:
+        # We already check was in CONV_NORMALIZATION, so any other choice
+        # doesn't need reparametrization.
+        return module
+
+
+
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0) ‑> int +
+
+ +
+ +Expand source code + +
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+                                 padding_total: int = 0) -> int:
+    """See `pad_for_conv1d`."""
+    length = x.shape[-1]
+    n_frames = (length - kernel_size + padding_total) / stride + 1
+    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    return ideal_length - length
+
+
+
+def get_norm_module(module: torch.nn.modules.module.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) +
+
+

Return the proper normalization module. If causal is True, this will ensure the returned +module is causal, or return an error if the normalization doesn't support causal evaluation.

+
+ +Expand source code + +
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
+    """Return the proper normalization module. If causal is True, this will ensure the returned
+    module is causal, or return an error if the normalization doesn't support causal evaluation.
+    """
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'time_group_norm':
+        if causal:
+            raise ValueError("GroupNorm doesn't support causal evaluation.")
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+    else:
+        return nn.Identity()
+
+
+
+def pad1d(x: torch.Tensor, paddings: Tuple[int, int], mode: str = 'constant', value: float = 0.0) +
+
+

Tiny wrapper around F.pad, just to allow for reflect padding on small input. +If this is the case, we insert extra 0 padding to the right before the reflection happen.

+
+ +Expand source code + +
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
+    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+    If this is the case, we insert extra 0 padding to the right before the reflection happen.
+    """
+    length = x.shape[-1]
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    if mode == 'reflect':
+        max_pad = max(padding_left, padding_right)
+        extra_pad = 0
+        if length <= max_pad:
+            extra_pad = max_pad - length + 1
+            x = F.pad(x, (0, extra_pad))
+        padded = F.pad(x, paddings, mode, value)
+        end = padded.shape[-1] - extra_pad
+        return padded[..., :end]
+    else:
+        return F.pad(x, paddings, mode, value)
+
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0) +
+
+

Pad for a convolution to make sure that the last window is full. +Extra padding is added at the end. This is required to ensure that we can rebuild +an output of the same length, as otherwise, even with padding, some time steps +might get removed. +For instance, with total padding = 4, kernel size = 4, stride = 2: +0 0 1 2 3 4 5 0 0 +# (0s are padding) +1 +2 +3 +# (output frames of a convolution, last 0 is never used) +0 0 1 2 3 4 5 0 +# (output of tr. conv., but pos. 5 is going to get removed as padding) +1 2 3 4 +# once you removed padding, we are missing one time step !

+
+ +Expand source code + +
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+    """Pad for a convolution to make sure that the last window is full.
+    Extra padding is added at the end. This is required to ensure that we can rebuild
+    an output of the same length, as otherwise, even with padding, some time steps
+    might get removed.
+    For instance, with total padding = 4, kernel size = 4, stride = 2:
+        0 0 1 2 3 4 5 0 0   # (0s are padding)
+        1   2   3           # (output frames of a convolution, last 0 is never used)
+        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+            1 2 3 4         # once you removed padding, we are missing one time step !
+    """
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    return F.pad(x, (0, extra_padding))
+
+
+
+def unpad1d(x: torch.Tensor, paddings: Tuple[int, int]) +
+
+

Remove padding from x, handling properly zero padding. Only for 1d!

+
+ +Expand source code + +
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+    """Remove padding from x, handling properly zero padding. Only for 1d!"""
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    assert (padding_left + padding_right) <= x.shape[-1]
+    end = x.shape[-1] - padding_right
+    return x[..., padding_left: end]
+
+
+
+
+
+

Classes

+
+
+class NormConv1d +(*args, causal: bool = False, norm: str = 'none', norm_kwargs: Dict[str, Any] = {}, **kwargs) +
+
+

Wrapper around Conv1d and normalization applied to this conv +to provide a uniform interface across normalization approaches.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class NormConv1d(nn.Module):
+    """Wrapper around Conv1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.conv(x)
+    x = self.norm(x)
+    return x
+
+
+
+
+
+class NormConv2d +(*args, norm: str = 'none', norm_kwargs: Dict[str, Any] = {}, **kwargs) +
+
+

Wrapper around Conv2d and normalization applied to this conv +to provide a uniform interface across normalization approaches.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class NormConv2d(nn.Module):
+    """Wrapper around Conv2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.conv(x)
+    x = self.norm(x)
+    return x
+
+
+
+
+
+class NormConvTranspose1d +(*args, causal: bool = False, norm: str = 'none', norm_kwargs: Dict[str, Any] = {}, **kwargs) +
+
+

Wrapper around ConvTranspose1d and normalization applied to this conv +to provide a uniform interface across normalization approaches.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class NormConvTranspose1d(nn.Module):
+    """Wrapper around ConvTranspose1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.convtr(x)
+    x = self.norm(x)
+    return x
+
+
+
+
+
+class NormConvTranspose2d +(*args, norm: str = 'none', norm_kwargs: Dict[str, Any] = {}, **kwargs) +
+
+

Wrapper around ConvTranspose2d and normalization applied to this conv +to provide a uniform interface across normalization approaches.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class NormConvTranspose2d(nn.Module):
+    """Wrapper around ConvTranspose2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.convtr(x)
+    x = self.norm(x)
+    return x
+
+
+
+
+
+class StreamableConv1d +(in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1, bias: bool = True, causal: bool = False, norm: str = 'none', norm_kwargs: Dict[str, Any] = {}, pad_mode: str = 'reflect') +
+
+

Conv1d with some builtin handling of asymmetric or causal padding +and normalization.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamableConv1d(nn.Module):
+    """Conv1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+    def __init__(self, in_channels: int, out_channels: int,
+                 kernel_size: int, stride: int = 1, dilation: int = 1,
+                 groups: int = 1, bias: bool = True, causal: bool = False,
+                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
+                 pad_mode: str = 'reflect'):
+        super().__init__()
+        # warn user on unusual setup between dilation and stride
+        if stride > 1 and dilation > 1:
+            warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
+                          f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
+        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
+                               dilation=dilation, groups=groups, bias=bias, causal=causal,
+                               norm=norm, norm_kwargs=norm_kwargs)
+        self.causal = causal
+        self.pad_mode = pad_mode
+
+    def forward(self, x):
+        B, C, T = x.shape
+        kernel_size = self.conv.conv.kernel_size[0]
+        stride = self.conv.conv.stride[0]
+        dilation = self.conv.conv.dilation[0]
+        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
+        padding_total = kernel_size - stride
+        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+        if self.causal:
+            # Left padding for causal
+            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+        return self.conv(x)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    B, C, T = x.shape
+    kernel_size = self.conv.conv.kernel_size[0]
+    stride = self.conv.conv.stride[0]
+    dilation = self.conv.conv.dilation[0]
+    kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
+    padding_total = kernel_size - stride
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    if self.causal:
+        # Left padding for causal
+        x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+    else:
+        # Asymmetric padding required for odd strides
+        padding_right = padding_total // 2
+        padding_left = padding_total - padding_right
+        x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+    return self.conv(x)
+
+
+
+
+
+class StreamableConvTranspose1d +(in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, causal: bool = False, norm: str = 'none', trim_right_ratio: float = 1.0, norm_kwargs: Dict[str, Any] = {}) +
+
+

ConvTranspose1d with some builtin handling of asymmetric or causal padding +and normalization.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamableConvTranspose1d(nn.Module):
+    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+    def __init__(self, in_channels: int, out_channels: int,
+                 kernel_size: int, stride: int = 1, causal: bool = False,
+                 norm: str = 'none', trim_right_ratio: float = 1.,
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
+        super().__init__()
+        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
+                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
+        self.causal = causal
+        self.trim_right_ratio = trim_right_ratio
+        assert self.causal or self.trim_right_ratio == 1., \
+            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
+
+    def forward(self, x):
+        kernel_size = self.convtr.convtr.kernel_size[0]
+        stride = self.convtr.convtr.stride[0]
+        padding_total = kernel_size - stride
+
+        y = self.convtr(x)
+
+        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+        # removed at the very end, when keeping only the right length for the output,
+        # as removing it here would require also passing the length at the matching layer
+        # in the encoder.
+        if self.causal:
+            # Trim the padding on the right according to the specified ratio
+            # if trim_right_ratio = 1.0, trim everything from right
+            padding_right = math.ceil(padding_total * self.trim_right_ratio)
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        return y
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    kernel_size = self.convtr.convtr.kernel_size[0]
+    stride = self.convtr.convtr.stride[0]
+    padding_total = kernel_size - stride
+
+    y = self.convtr(x)
+
+    # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+    # removed at the very end, when keeping only the right length for the output,
+    # as removing it here would require also passing the length at the matching layer
+    # in the encoder.
+    if self.causal:
+        # Trim the padding on the right according to the specified ratio
+        # if trim_right_ratio = 1.0, trim everything from right
+        padding_right = math.ceil(padding_total * self.trim_right_ratio)
+        padding_left = padding_total - padding_right
+        y = unpad1d(y, (padding_left, padding_right))
+    else:
+        # Asymmetric padding required for odd strides
+        padding_right = padding_total // 2
+        padding_left = padding_total - padding_right
+        y = unpad1d(y, (padding_left, padding_right))
+    return y
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/diffusion_schedule.html b/api_docs/audiocraft/modules/diffusion_schedule.html new file mode 100644 index 00000000..29e46666 --- /dev/null +++ b/api_docs/audiocraft/modules/diffusion_schedule.html @@ -0,0 +1,1145 @@ + + + + + + +audiocraft.modules.diffusion_schedule API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.diffusion_schedule

+
+
+

Functions for Noise Schedule, defines diffusion process, reverse process and data processor.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
+"""
+
+from collections import namedtuple
+import random
+import typing as tp
+import julius
+import torch
+
+TrainingItem = namedtuple("TrainingItem", "noisy noise step")
+
+
+def betas_from_alpha_bar(alpha_bar):
+    alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
+    return 1 - alphas
+
+
+class SampleProcessor(torch.nn.Module):
+    def project_sample(self, x: torch.Tensor):
+        """Project the original sample to the 'space' where the diffusion will happen."""
+        return x
+
+    def return_sample(self, z: torch.Tensor):
+        """Project back from diffusion space to the actual sample space."""
+        return z
+
+
+class MultiBandProcessor(SampleProcessor):
+    """
+    MultiBand sample processor. The input audio is splitted across
+    frequency bands evenly distributed in mel-scale.
+
+    Each band will be rescaled to match the power distribution
+    of Gaussian noise in that band, using online metrics
+    computed on the first few samples.
+
+    Args:
+        n_bands (int): Number of mel-bands to split the signal over.
+        sample_rate (int): Sample rate of the audio.
+        num_samples (int): Number of samples to use to fit the rescaling
+            for each band. The processor won't be stable
+            until it has seen that many samples.
+        power_std (float or list/tensor): The rescaling factor computed to match the
+            power of Gaussian noise in each band is taken to
+            that power, i.e. `1.` means full correction of the energy
+            in each band, and values less than `1` means only partial
+            correction. Can be used to balance the relative importance
+            of low vs. high freq in typical audio signals.
+    """
+    def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
+                 num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
+        super().__init__()
+        self.n_bands = n_bands
+        self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
+        self.num_samples = num_samples
+        self.power_std = power_std
+        if isinstance(power_std, list):
+            assert len(power_std) == n_bands
+            power_std = torch.tensor(power_std)
+        self.register_buffer('counts', torch.zeros(1))
+        self.register_buffer('sum_x', torch.zeros(n_bands))
+        self.register_buffer('sum_x2', torch.zeros(n_bands))
+        self.register_buffer('sum_target_x2', torch.zeros(n_bands))
+        self.counts: torch.Tensor
+        self.sum_x: torch.Tensor
+        self.sum_x2: torch.Tensor
+        self.sum_target_x2: torch.Tensor
+
+    @property
+    def mean(self):
+        mean = self.sum_x / self.counts
+        return mean
+
+    @property
+    def std(self):
+        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
+        return std
+
+    @property
+    def target_std(self):
+        target_std = self.sum_target_x2 / self.counts
+        return target_std
+
+    def project_sample(self, x: torch.Tensor):
+        assert x.dim() == 3
+        bands = self.split_bands(x)
+        if self.counts.item() < self.num_samples:
+            ref_bands = self.split_bands(torch.randn_like(x))
+            self.counts += len(x)
+            self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
+            self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
+            self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
+        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std  # same output size
+        bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
+        return bands.sum(dim=0)
+
+    def return_sample(self, x: torch.Tensor):
+        assert x.dim() == 3
+        bands = self.split_bands(x)
+        rescale = (self.std / self.target_std) ** self.power_std
+        bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
+        return bands.sum(dim=0)
+
+
+class NoiseSchedule:
+    """Noise schedule for diffusion.
+
+    Args:
+        beta_t0 (float): Variance of the first diffusion step.
+        beta_t1 (float): Variance of the last diffusion step.
+        beta_exp (float): Power schedule exponent
+        num_steps (int): Number of diffusion step.
+        variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
+        clip (float): clipping value for the denoising steps
+        rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1)
+        repartition (str): shape of the schedule only power schedule is supported
+        sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution
+        noise_scale (float): Scaling factor for the noise
+    """
+    def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
+                 clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
+                 repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
+                 sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
+
+        self.beta_t0 = beta_t0
+        self.beta_t1 = beta_t1
+        self.variance = variance
+        self.num_steps = num_steps
+        self.clip = clip
+        self.sample_processor = sample_processor
+        self.rescale = rescale
+        self.n_bands = n_bands
+        self.noise_scale = noise_scale
+        assert n_bands is None
+        if repartition == "power":
+            self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
+                                        device=device, dtype=torch.float) ** beta_exp
+        else:
+            raise RuntimeError('Not implemented')
+        self.rng = random.Random(1234)
+
+    def get_beta(self, step: tp.Union[int, torch.Tensor]):
+        if self.n_bands is None:
+            return self.betas[step]
+        else:
+            return self.betas[:, step]  # [n_bands, len(step)]
+
+    def get_initial_noise(self, x: torch.Tensor):
+        if self.n_bands is None:
+            return torch.randn_like(x)
+        return torch.randn((x.size(0), self.n_bands, x.size(2)))
+
+    def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
+        """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
+        if step is None:
+            return (1 - self.betas).cumprod(dim=-1)  # works for simgle and multi bands
+        if type(step) is int:
+            return (1 - self.betas[:step + 1]).prod()
+        else:
+            return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
+
+    def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
+        """Create a noisy data item for diffusion model training:
+
+        Args:
+            x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
+            tensor_step (bool): If tensor_step = false, only one step t is sample,
+                the whole batch is diffused to the same step and t is int.
+                If tensor_step = true, t is a tensor of size (x.size(0),)
+                every element of the batch is diffused to a independently sampled.
+        """
+        step: tp.Union[int, torch.Tensor]
+        if tensor_step:
+            bs = x.size(0)
+            step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
+        else:
+            step = self.rng.randrange(self.num_steps)
+        alpha_bar = self.get_alpha_bar(step)  # [batch_size, n_bands, 1]
+
+        x = self.sample_processor.project_sample(x)
+        noise = torch.randn_like(x)
+        noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
+        return TrainingItem(noisy, noise, step)
+
+    def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
+                 condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
+        """Full ddpm reverse process.
+
+        Args:
+            model (nn.Module): Diffusion model.
+            initial (tensor): Initial Noise.
+            condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
+            return_list (bool): Whether to return the whole process or only the sampled point.
+        """
+        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
+        current = initial
+        iterates = [initial]
+        for step in range(self.num_steps)[::-1]:
+            with torch.no_grad():
+                estimate = model(current, step, condition=condition).sample
+            alpha = 1 - self.betas[step]
+            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
+            previous_alpha_bar = self.get_alpha_bar(step=step - 1)
+            if step == 0:
+                sigma2 = 0
+            elif self.variance == 'beta':
+                sigma2 = 1 - alpha
+            elif self.variance == 'beta_tilde':
+                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
+            elif self.variance == 'none':
+                sigma2 = 0
+            else:
+                raise ValueError(f'Invalid variance type {self.variance}')
+
+            if sigma2 > 0:
+                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
+            if self.clip:
+                previous = previous.clamp(-self.clip, self.clip)
+            current = previous
+            alpha_bar = previous_alpha_bar
+            if step == 0:
+                previous *= self.rescale
+            if return_list:
+                iterates.append(previous.cpu())
+
+        if return_list:
+            return iterates
+        else:
+            return self.sample_processor.return_sample(previous)
+
+    def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
+                            condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
+        """Reverse process that only goes through Markov chain states in step_list."""
+        if step_list is None:
+            step_list = list(range(1000))[::-50] + [0]
+        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
+        alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
+        betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
+        current = initial * self.noise_scale
+        iterates = [current]
+        for idx, step in enumerate(step_list[:-1]):
+            with torch.no_grad():
+                estimate = model(current, step, condition=condition).sample * self.noise_scale
+            alpha = 1 - betas_subsampled[-1 - idx]
+            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
+            previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
+            if step == step_list[-2]:
+                sigma2 = 0
+                previous_alpha_bar = torch.tensor(1.0)
+            else:
+                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
+            if sigma2 > 0:
+                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
+            if self.clip:
+                previous = previous.clamp(-self.clip, self.clip)
+            current = previous
+            alpha_bar = previous_alpha_bar
+            if step == 0:
+                previous *= self.rescale
+            if return_list:
+                iterates.append(previous.cpu())
+        if return_list:
+            return iterates
+        else:
+            return self.sample_processor.return_sample(previous)
+
+
+
+
+
+
+
+

Functions

+
+
+def betas_from_alpha_bar(alpha_bar) +
+
+
+
+ +Expand source code + +
def betas_from_alpha_bar(alpha_bar):
+    alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
+    return 1 - alphas
+
+
+
+
+
+

Classes

+
+
+class MultiBandProcessor +(n_bands: int = 8, sample_rate: float = 24000, num_samples: int = 10000, power_std: Union[float, List[float], torch.Tensor] = 1.0) +
+
+

MultiBand sample processor. The input audio is splitted across +frequency bands evenly distributed in mel-scale.

+

Each band will be rescaled to match the power distribution +of Gaussian noise in that band, using online metrics +computed on the first few samples.

+

Args

+
+
n_bands : int
+
Number of mel-bands to split the signal over.
+
sample_rate : int
+
Sample rate of the audio.
+
num_samples : int
+
Number of samples to use to fit the rescaling +for each band. The processor won't be stable +until it has seen that many samples.
+
+

power_std (float or list/tensor): The rescaling factor computed to match the +power of Gaussian noise in each band is taken to +that power, i.e. 1. means full correction of the energy +in each band, and values less than 1 means only partial +correction. Can be used to balance the relative importance +of low vs. high freq in typical audio signals. +Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class MultiBandProcessor(SampleProcessor):
+    """
+    MultiBand sample processor. The input audio is splitted across
+    frequency bands evenly distributed in mel-scale.
+
+    Each band will be rescaled to match the power distribution
+    of Gaussian noise in that band, using online metrics
+    computed on the first few samples.
+
+    Args:
+        n_bands (int): Number of mel-bands to split the signal over.
+        sample_rate (int): Sample rate of the audio.
+        num_samples (int): Number of samples to use to fit the rescaling
+            for each band. The processor won't be stable
+            until it has seen that many samples.
+        power_std (float or list/tensor): The rescaling factor computed to match the
+            power of Gaussian noise in each band is taken to
+            that power, i.e. `1.` means full correction of the energy
+            in each band, and values less than `1` means only partial
+            correction. Can be used to balance the relative importance
+            of low vs. high freq in typical audio signals.
+    """
+    def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
+                 num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
+        super().__init__()
+        self.n_bands = n_bands
+        self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
+        self.num_samples = num_samples
+        self.power_std = power_std
+        if isinstance(power_std, list):
+            assert len(power_std) == n_bands
+            power_std = torch.tensor(power_std)
+        self.register_buffer('counts', torch.zeros(1))
+        self.register_buffer('sum_x', torch.zeros(n_bands))
+        self.register_buffer('sum_x2', torch.zeros(n_bands))
+        self.register_buffer('sum_target_x2', torch.zeros(n_bands))
+        self.counts: torch.Tensor
+        self.sum_x: torch.Tensor
+        self.sum_x2: torch.Tensor
+        self.sum_target_x2: torch.Tensor
+
+    @property
+    def mean(self):
+        mean = self.sum_x / self.counts
+        return mean
+
+    @property
+    def std(self):
+        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
+        return std
+
+    @property
+    def target_std(self):
+        target_std = self.sum_target_x2 / self.counts
+        return target_std
+
+    def project_sample(self, x: torch.Tensor):
+        assert x.dim() == 3
+        bands = self.split_bands(x)
+        if self.counts.item() < self.num_samples:
+            ref_bands = self.split_bands(torch.randn_like(x))
+            self.counts += len(x)
+            self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
+            self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
+            self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
+        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std  # same output size
+        bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
+        return bands.sum(dim=0)
+
+    def return_sample(self, x: torch.Tensor):
+        assert x.dim() == 3
+        bands = self.split_bands(x)
+        rescale = (self.std / self.target_std) ** self.power_std
+        bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
+        return bands.sum(dim=0)
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var mean
+
+
+
+ +Expand source code + +
@property
+def mean(self):
+    mean = self.sum_x / self.counts
+    return mean
+
+
+
var std
+
+
+
+ +Expand source code + +
@property
+def std(self):
+    std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
+    return std
+
+
+
var target_std
+
+
+
+ +Expand source code + +
@property
+def target_std(self):
+    target_std = self.sum_target_x2 / self.counts
+    return target_std
+
+
+
+

Inherited members

+ +
+
+class NoiseSchedule +(beta_t0: float = 0.0001, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta', clip: float = 5.0, rescale: float = 1.0, device='cuda', beta_exp: float = 1, repartition: str = 'power', alpha_sigmoid: dict = {}, n_bands: Optional[int] = None, sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs) +
+
+

Noise schedule for diffusion.

+

Args

+
+
beta_t0 : float
+
Variance of the first diffusion step.
+
beta_t1 : float
+
Variance of the last diffusion step.
+
beta_exp : float
+
Power schedule exponent
+
num_steps : int
+
Number of diffusion step.
+
variance : str
+
choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
+
clip : float
+
clipping value for the denoising steps
+
rescale : float
+
rescaling value to avoid vanishing signals unused by default (i.e 1)
+
repartition : str
+
shape of the schedule only power schedule is supported
+
sample_processor : SampleProcessor
+
Module that normalize data to match better the gaussian distribution
+
noise_scale : float
+
Scaling factor for the noise
+
+
+ +Expand source code + +
class NoiseSchedule:
+    """Noise schedule for diffusion.
+
+    Args:
+        beta_t0 (float): Variance of the first diffusion step.
+        beta_t1 (float): Variance of the last diffusion step.
+        beta_exp (float): Power schedule exponent
+        num_steps (int): Number of diffusion step.
+        variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
+        clip (float): clipping value for the denoising steps
+        rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1)
+        repartition (str): shape of the schedule only power schedule is supported
+        sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution
+        noise_scale (float): Scaling factor for the noise
+    """
+    def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
+                 clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
+                 repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
+                 sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
+
+        self.beta_t0 = beta_t0
+        self.beta_t1 = beta_t1
+        self.variance = variance
+        self.num_steps = num_steps
+        self.clip = clip
+        self.sample_processor = sample_processor
+        self.rescale = rescale
+        self.n_bands = n_bands
+        self.noise_scale = noise_scale
+        assert n_bands is None
+        if repartition == "power":
+            self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
+                                        device=device, dtype=torch.float) ** beta_exp
+        else:
+            raise RuntimeError('Not implemented')
+        self.rng = random.Random(1234)
+
+    def get_beta(self, step: tp.Union[int, torch.Tensor]):
+        if self.n_bands is None:
+            return self.betas[step]
+        else:
+            return self.betas[:, step]  # [n_bands, len(step)]
+
+    def get_initial_noise(self, x: torch.Tensor):
+        if self.n_bands is None:
+            return torch.randn_like(x)
+        return torch.randn((x.size(0), self.n_bands, x.size(2)))
+
+    def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
+        """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
+        if step is None:
+            return (1 - self.betas).cumprod(dim=-1)  # works for simgle and multi bands
+        if type(step) is int:
+            return (1 - self.betas[:step + 1]).prod()
+        else:
+            return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
+
+    def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
+        """Create a noisy data item for diffusion model training:
+
+        Args:
+            x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
+            tensor_step (bool): If tensor_step = false, only one step t is sample,
+                the whole batch is diffused to the same step and t is int.
+                If tensor_step = true, t is a tensor of size (x.size(0),)
+                every element of the batch is diffused to a independently sampled.
+        """
+        step: tp.Union[int, torch.Tensor]
+        if tensor_step:
+            bs = x.size(0)
+            step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
+        else:
+            step = self.rng.randrange(self.num_steps)
+        alpha_bar = self.get_alpha_bar(step)  # [batch_size, n_bands, 1]
+
+        x = self.sample_processor.project_sample(x)
+        noise = torch.randn_like(x)
+        noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
+        return TrainingItem(noisy, noise, step)
+
+    def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
+                 condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
+        """Full ddpm reverse process.
+
+        Args:
+            model (nn.Module): Diffusion model.
+            initial (tensor): Initial Noise.
+            condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
+            return_list (bool): Whether to return the whole process or only the sampled point.
+        """
+        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
+        current = initial
+        iterates = [initial]
+        for step in range(self.num_steps)[::-1]:
+            with torch.no_grad():
+                estimate = model(current, step, condition=condition).sample
+            alpha = 1 - self.betas[step]
+            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
+            previous_alpha_bar = self.get_alpha_bar(step=step - 1)
+            if step == 0:
+                sigma2 = 0
+            elif self.variance == 'beta':
+                sigma2 = 1 - alpha
+            elif self.variance == 'beta_tilde':
+                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
+            elif self.variance == 'none':
+                sigma2 = 0
+            else:
+                raise ValueError(f'Invalid variance type {self.variance}')
+
+            if sigma2 > 0:
+                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
+            if self.clip:
+                previous = previous.clamp(-self.clip, self.clip)
+            current = previous
+            alpha_bar = previous_alpha_bar
+            if step == 0:
+                previous *= self.rescale
+            if return_list:
+                iterates.append(previous.cpu())
+
+        if return_list:
+            return iterates
+        else:
+            return self.sample_processor.return_sample(previous)
+
+    def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
+                            condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
+        """Reverse process that only goes through Markov chain states in step_list."""
+        if step_list is None:
+            step_list = list(range(1000))[::-50] + [0]
+        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
+        alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
+        betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
+        current = initial * self.noise_scale
+        iterates = [current]
+        for idx, step in enumerate(step_list[:-1]):
+            with torch.no_grad():
+                estimate = model(current, step, condition=condition).sample * self.noise_scale
+            alpha = 1 - betas_subsampled[-1 - idx]
+            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
+            previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
+            if step == step_list[-2]:
+                sigma2 = 0
+                previous_alpha_bar = torch.tensor(1.0)
+            else:
+                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
+            if sigma2 > 0:
+                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
+            if self.clip:
+                previous = previous.clamp(-self.clip, self.clip)
+            current = previous
+            alpha_bar = previous_alpha_bar
+            if step == 0:
+                previous *= self.rescale
+            if return_list:
+                iterates.append(previous.cpu())
+        if return_list:
+            return iterates
+        else:
+            return self.sample_processor.return_sample(previous)
+
+

Methods

+
+
+def generate(self, model: torch.nn.modules.module.Module, initial: Optional[torch.Tensor] = None, condition: Optional[torch.Tensor] = None, return_list: bool = False) +
+
+

Full ddpm reverse process.

+

Args

+
+
model : nn.Module
+
Diffusion model.
+
initial : tensor
+
Initial Noise.
+
condition : tensor
+
Input conditionning Tensor (e.g. encodec compressed representation).
+
return_list : bool
+
Whether to return the whole process or only the sampled point.
+
+
+ +Expand source code + +
def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
+             condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
+    """Full ddpm reverse process.
+
+    Args:
+        model (nn.Module): Diffusion model.
+        initial (tensor): Initial Noise.
+        condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
+        return_list (bool): Whether to return the whole process or only the sampled point.
+    """
+    alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
+    current = initial
+    iterates = [initial]
+    for step in range(self.num_steps)[::-1]:
+        with torch.no_grad():
+            estimate = model(current, step, condition=condition).sample
+        alpha = 1 - self.betas[step]
+        previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
+        previous_alpha_bar = self.get_alpha_bar(step=step - 1)
+        if step == 0:
+            sigma2 = 0
+        elif self.variance == 'beta':
+            sigma2 = 1 - alpha
+        elif self.variance == 'beta_tilde':
+            sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
+        elif self.variance == 'none':
+            sigma2 = 0
+        else:
+            raise ValueError(f'Invalid variance type {self.variance}')
+
+        if sigma2 > 0:
+            previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
+        if self.clip:
+            previous = previous.clamp(-self.clip, self.clip)
+        current = previous
+        alpha_bar = previous_alpha_bar
+        if step == 0:
+            previous *= self.rescale
+        if return_list:
+            iterates.append(previous.cpu())
+
+    if return_list:
+        return iterates
+    else:
+        return self.sample_processor.return_sample(previous)
+
+
+
+def generate_subsampled(self, model: torch.nn.modules.module.Module, initial: torch.Tensor, step_list: Optional[list] = None, condition: Optional[torch.Tensor] = None, return_list: bool = False) +
+
+

Reverse process that only goes through Markov chain states in step_list.

+
+ +Expand source code + +
def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
+                        condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
+    """Reverse process that only goes through Markov chain states in step_list."""
+    if step_list is None:
+        step_list = list(range(1000))[::-50] + [0]
+    alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
+    alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
+    betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
+    current = initial * self.noise_scale
+    iterates = [current]
+    for idx, step in enumerate(step_list[:-1]):
+        with torch.no_grad():
+            estimate = model(current, step, condition=condition).sample * self.noise_scale
+        alpha = 1 - betas_subsampled[-1 - idx]
+        previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
+        previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
+        if step == step_list[-2]:
+            sigma2 = 0
+            previous_alpha_bar = torch.tensor(1.0)
+        else:
+            sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
+        if sigma2 > 0:
+            previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
+        if self.clip:
+            previous = previous.clamp(-self.clip, self.clip)
+        current = previous
+        alpha_bar = previous_alpha_bar
+        if step == 0:
+            previous *= self.rescale
+        if return_list:
+            iterates.append(previous.cpu())
+    if return_list:
+        return iterates
+    else:
+        return self.sample_processor.return_sample(previous)
+
+
+
+def get_alpha_bar(self, step: Union[int, torch.Tensor, None] = None) ‑> torch.Tensor +
+
+

Return 'alpha_bar', either for a given step, or as a tensor with its value for each step.

+
+ +Expand source code + +
def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
+    """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
+    if step is None:
+        return (1 - self.betas).cumprod(dim=-1)  # works for simgle and multi bands
+    if type(step) is int:
+        return (1 - self.betas[:step + 1]).prod()
+    else:
+        return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
+
+
+
+def get_beta(self, step: Union[int, torch.Tensor]) +
+
+
+
+ +Expand source code + +
def get_beta(self, step: tp.Union[int, torch.Tensor]):
+    if self.n_bands is None:
+        return self.betas[step]
+    else:
+        return self.betas[:, step]  # [n_bands, len(step)]
+
+
+
+def get_initial_noise(self, x: torch.Tensor) +
+
+
+
+ +Expand source code + +
def get_initial_noise(self, x: torch.Tensor):
+    if self.n_bands is None:
+        return torch.randn_like(x)
+    return torch.randn((x.size(0), self.n_bands, x.size(2)))
+
+
+
+def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) ‑> TrainingItem +
+
+

Create a noisy data item for diffusion model training:

+

Args

+
+
x : torch.Tensor
+
clean audio data torch.tensor(bs, 1, T)
+
tensor_step : bool
+
If tensor_step = false, only one step t is sample, +the whole batch is diffused to the same step and t is int. +If tensor_step = true, t is a tensor of size (x.size(0),) +every element of the batch is diffused to a independently sampled.
+
+
+ +Expand source code + +
def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
+    """Create a noisy data item for diffusion model training:
+
+    Args:
+        x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
+        tensor_step (bool): If tensor_step = false, only one step t is sample,
+            the whole batch is diffused to the same step and t is int.
+            If tensor_step = true, t is a tensor of size (x.size(0),)
+            every element of the batch is diffused to a independently sampled.
+    """
+    step: tp.Union[int, torch.Tensor]
+    if tensor_step:
+        bs = x.size(0)
+        step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
+    else:
+        step = self.rng.randrange(self.num_steps)
+    alpha_bar = self.get_alpha_bar(step)  # [batch_size, n_bands, 1]
+
+    x = self.sample_processor.project_sample(x)
+    noise = torch.randn_like(x)
+    noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
+    return TrainingItem(noisy, noise, step)
+
+
+
+
+
+class SampleProcessor +(*args, **kwargs) +
+
+

Base class for all neural network modules.

+

Your models should also subclass this class.

+

Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

+
import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5)
+        self.conv2 = nn.Conv2d(20, 20, 5)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        return F.relu(self.conv2(x))
+
+

Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

+
+

Note

+

As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

+
+

:ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class SampleProcessor(torch.nn.Module):
+    def project_sample(self, x: torch.Tensor):
+        """Project the original sample to the 'space' where the diffusion will happen."""
+        return x
+
+    def return_sample(self, z: torch.Tensor):
+        """Project back from diffusion space to the actual sample space."""
+        return z
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, *input: Any) ‑> None +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def _forward_unimplemented(self, *input: Any) -> None:
+    r"""Defines the computation performed at every call.
+
+    Should be overridden by all subclasses.
+
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`Module` instance afterwards
+        instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
+    """
+    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
+
+
+
+def project_sample(self, x: torch.Tensor) +
+
+

Project the original sample to the 'space' where the diffusion will happen.

+
+ +Expand source code + +
def project_sample(self, x: torch.Tensor):
+    """Project the original sample to the 'space' where the diffusion will happen."""
+    return x
+
+
+
+def return_sample(self, z: torch.Tensor) +
+
+

Project back from diffusion space to the actual sample space.

+
+ +Expand source code + +
def return_sample(self, z: torch.Tensor):
+    """Project back from diffusion space to the actual sample space."""
+    return z
+
+
+
+
+
+class TrainingItem +(noisy, noise, step) +
+
+

TrainingItem(noisy, noise, step)

+

Ancestors

+
    +
  • builtins.tuple
  • +
+

Instance variables

+
+
var noise
+
+

Alias for field number 1

+
+
var noisy
+
+

Alias for field number 0

+
+
var step
+
+

Alias for field number 2

+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/index.html b/api_docs/audiocraft/modules/index.html new file mode 100644 index 00000000..0fa41878 --- /dev/null +++ b/api_docs/audiocraft/modules/index.html @@ -0,0 +1,144 @@ + + + + + + +audiocraft.modules API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules

+
+
+

Modules used for building the models.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Modules used for building the models."""
+
+# flake8: noqa
+from .conv import (
+    NormConv1d,
+    NormConv2d,
+    NormConvTranspose1d,
+    NormConvTranspose2d,
+    StreamableConv1d,
+    StreamableConvTranspose1d,
+    pad_for_conv1d,
+    pad1d,
+    unpad1d,
+)
+from .lstm import StreamableLSTM
+from .seanet import SEANetEncoder, SEANetDecoder
+from .transformer import StreamingTransformer
+
+
+
+

Sub-modules

+
+
audiocraft.modules.activations
+
+
+
+
audiocraft.modules.chroma
+
+
+
+
audiocraft.modules.codebooks_patterns
+
+
+
+
audiocraft.modules.conditioners
+
+
+
+
audiocraft.modules.conv
+
+
+
+
audiocraft.modules.diffusion_schedule
+
+

Functions for Noise Schedule, defines diffusion process, reverse process and data processor.

+
+
audiocraft.modules.lstm
+
+
+
+
audiocraft.modules.rope
+
+
+
+
audiocraft.modules.seanet
+
+
+
+
audiocraft.modules.streaming
+
+

Streaming module API that should be implemented by all Streaming components,

+
+
audiocraft.modules.transformer
+
+

Transformer model, with streaming support, xformer attention support +and easy causal attention with a potentially finite receptive field …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/lstm.html b/api_docs/audiocraft/modules/lstm.html new file mode 100644 index 00000000..ad20d54e --- /dev/null +++ b/api_docs/audiocraft/modules/lstm.html @@ -0,0 +1,177 @@ + + + + + + +audiocraft.modules.lstm API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.lstm

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch import nn
+
+
+class StreamableLSTM(nn.Module):
+    """LSTM without worrying about the hidden state, nor the layout of the data.
+    Expects input as convolutional layout.
+    """
+    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+        super().__init__()
+        self.skip = skip
+        self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+    def forward(self, x):
+        x = x.permute(2, 0, 1)
+        y, _ = self.lstm(x)
+        if self.skip:
+            y = y + x
+        y = y.permute(1, 2, 0)
+        return y
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class StreamableLSTM +(dimension: int, num_layers: int = 2, skip: bool = True) +
+
+

LSTM without worrying about the hidden state, nor the layout of the data. +Expects input as convolutional layout.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamableLSTM(nn.Module):
+    """LSTM without worrying about the hidden state, nor the layout of the data.
+    Expects input as convolutional layout.
+    """
+    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+        super().__init__()
+        self.skip = skip
+        self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+    def forward(self, x):
+        x = x.permute(2, 0, 1)
+        y, _ = self.lstm(x)
+        if self.skip:
+            y = y + x
+        y = y.permute(1, 2, 0)
+        return y
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = x.permute(2, 0, 1)
+    y, _ = self.lstm(x)
+    if self.skip:
+        y = y + x
+    y = y.permute(1, 2, 0)
+    return y
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/rope.html b/api_docs/audiocraft/modules/rope.html new file mode 100644 index 00000000..56bef835 --- /dev/null +++ b/api_docs/audiocraft/modules/rope.html @@ -0,0 +1,600 @@ + + + + + + +audiocraft.modules.rope API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.rope

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from torch import nn
+import torch
+
+
+class XPos(nn.Module):
+    """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
+    This applies an exponential decay to the RoPE rotation matrix.
+
+    Args:
+        dim (int): Embedding dimension.
+        smoothing (float): Smoothing factor applied to the decay rates.
+        base_scale (int): Base decay rate, given in terms of scaling time.
+        device (torch.device, optional): Device on which to initialize the module.
+        dtype (torch.dtype): dtype to use to generate the embedding.
+    """
+    def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
+                 device=None, dtype: torch.dtype = torch.float32):
+        super().__init__()
+        assert dim % 2 == 0
+        assert dtype in [torch.float64, torch.float32]
+        self.dtype = dtype
+        self.base_scale = base_scale
+
+        half_dim = dim // 2
+        adim = torch.arange(half_dim, device=device, dtype=dtype)
+        decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
+        self.register_buffer("decay_rates", decay_rates)
+        self.decay: tp.Optional[torch.Tensor] = None
+
+    def get_decay(self, start: int, end: int):
+        """Create complex decay tensor, cache values for fast computation."""
+        if self.decay is None or end > self.decay.shape[0]:
+            assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
+            idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
+            power = idx / self.base_scale
+            scale = self.decay_rates ** power.unsqueeze(-1)
+            self.decay = torch.polar(scale, torch.zeros_like(scale))
+        return self.decay[start:end]  # [T, C/2]
+
+
+class RotaryEmbedding(nn.Module):
+    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
+
+    Args:
+        dim (int): Embedding dimension (twice the number of frequencies).
+        max_period (float): Maximum period of the rotation frequencies.
+        xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
+        scale (float): Scale of positional embedding, set to 0 to deactivate.
+        device (torch.device, optional): Device on which to initialize the module.
+        dtype (torch.dtype): dtype to use to generate the embedding.
+    """
+    def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
+                 scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
+        super().__init__()
+        assert dim % 2 == 0
+        self.scale = scale
+        assert dtype in [torch.float64, torch.float32]
+        self.dtype = dtype
+
+        adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
+        frequencies = 1.0 / (max_period ** (adim / dim))
+        self.register_buffer("frequencies", frequencies)
+        self.rotation: tp.Optional[torch.Tensor] = None
+
+        self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
+
+    def get_rotation(self, start: int, end: int):
+        """Create complex rotation tensor, cache values for fast computation."""
+        if self.rotation is None or end > self.rotation.shape[0]:
+            assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
+            idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
+            angles = torch.outer(idx, self.frequencies)
+            self.rotation = torch.polar(torch.ones_like(angles), angles)
+        return self.rotation[start:end]
+
+    def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
+        """Apply rope rotation to query or key tensor."""
+        T = x.shape[time_dim]
+        target_shape = [1] * x.dim()
+        target_shape[time_dim] = T
+        target_shape[-1] = -1
+        rotation = self.get_rotation(start, start + T).view(target_shape)
+
+        if self.xpos:
+            decay = self.xpos.get_decay(start, start + T).view(target_shape)
+        else:
+            decay = 1.0
+
+        if invert_decay:
+            decay = decay ** -1
+
+        x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
+        scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
+        x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
+
+        return x_out.type_as(x)
+
+    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
+        """ Apply rope rotation to both query and key tensors.
+        Supports streaming mode, in which query and key are not expected to have the same shape.
+        In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
+        query will be [C] (typically C == 1).
+
+        Args:
+            query (torch.Tensor): Query to rotate.
+            key (torch.Tensor): Key to rotate.
+            start (int): Start index of the sequence for time offset.
+            time_dim (int): which dimension represent the time steps.
+        """
+        query_timesteps = query.shape[time_dim]
+        key_timesteps = key.shape[time_dim]
+        streaming_offset = key_timesteps - query_timesteps
+
+        query_out = self.rotate(query, start + streaming_offset, time_dim)
+        key_out = self.rotate(key, start, time_dim, invert_decay=True)
+
+        return query_out, key_out
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class RotaryEmbedding +(dim: int, max_period: float = 10000.0, xpos: bool = False, scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32) +
+
+

Rotary positional embedding (RoPE) from Su et al 2022.

+

Args

+
+
dim : int
+
Embedding dimension (twice the number of frequencies).
+
max_period : float
+
Maximum period of the rotation frequencies.
+
xpos : bool
+
Use xPos, applies an exponential decay to rotation matrix.
+
scale : float
+
Scale of positional embedding, set to 0 to deactivate.
+
device : torch.device, optional
+
Device on which to initialize the module.
+
dtype : torch.dtype
+
dtype to use to generate the embedding.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class RotaryEmbedding(nn.Module):
+    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
+
+    Args:
+        dim (int): Embedding dimension (twice the number of frequencies).
+        max_period (float): Maximum period of the rotation frequencies.
+        xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
+        scale (float): Scale of positional embedding, set to 0 to deactivate.
+        device (torch.device, optional): Device on which to initialize the module.
+        dtype (torch.dtype): dtype to use to generate the embedding.
+    """
+    def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
+                 scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
+        super().__init__()
+        assert dim % 2 == 0
+        self.scale = scale
+        assert dtype in [torch.float64, torch.float32]
+        self.dtype = dtype
+
+        adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
+        frequencies = 1.0 / (max_period ** (adim / dim))
+        self.register_buffer("frequencies", frequencies)
+        self.rotation: tp.Optional[torch.Tensor] = None
+
+        self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
+
+    def get_rotation(self, start: int, end: int):
+        """Create complex rotation tensor, cache values for fast computation."""
+        if self.rotation is None or end > self.rotation.shape[0]:
+            assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
+            idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
+            angles = torch.outer(idx, self.frequencies)
+            self.rotation = torch.polar(torch.ones_like(angles), angles)
+        return self.rotation[start:end]
+
+    def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
+        """Apply rope rotation to query or key tensor."""
+        T = x.shape[time_dim]
+        target_shape = [1] * x.dim()
+        target_shape[time_dim] = T
+        target_shape[-1] = -1
+        rotation = self.get_rotation(start, start + T).view(target_shape)
+
+        if self.xpos:
+            decay = self.xpos.get_decay(start, start + T).view(target_shape)
+        else:
+            decay = 1.0
+
+        if invert_decay:
+            decay = decay ** -1
+
+        x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
+        scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
+        x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
+
+        return x_out.type_as(x)
+
+    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
+        """ Apply rope rotation to both query and key tensors.
+        Supports streaming mode, in which query and key are not expected to have the same shape.
+        In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
+        query will be [C] (typically C == 1).
+
+        Args:
+            query (torch.Tensor): Query to rotate.
+            key (torch.Tensor): Key to rotate.
+            start (int): Start index of the sequence for time offset.
+            time_dim (int): which dimension represent the time steps.
+        """
+        query_timesteps = query.shape[time_dim]
+        key_timesteps = key.shape[time_dim]
+        streaming_offset = key_timesteps - query_timesteps
+
+        query_out = self.rotate(query, start + streaming_offset, time_dim)
+        key_out = self.rotate(key, start, time_dim, invert_decay=True)
+
+        return query_out, key_out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, *input: Any) ‑> None +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def _forward_unimplemented(self, *input: Any) -> None:
+    r"""Defines the computation performed at every call.
+
+    Should be overridden by all subclasses.
+
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`Module` instance afterwards
+        instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
+    """
+    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
+
+
+
+def get_rotation(self, start: int, end: int) +
+
+

Create complex rotation tensor, cache values for fast computation.

+
+ +Expand source code + +
def get_rotation(self, start: int, end: int):
+    """Create complex rotation tensor, cache values for fast computation."""
+    if self.rotation is None or end > self.rotation.shape[0]:
+        assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
+        idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
+        angles = torch.outer(idx, self.frequencies)
+        self.rotation = torch.polar(torch.ones_like(angles), angles)
+    return self.rotation[start:end]
+
+
+
+def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False) +
+
+

Apply rope rotation to query or key tensor.

+
+ +Expand source code + +
def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
+    """Apply rope rotation to query or key tensor."""
+    T = x.shape[time_dim]
+    target_shape = [1] * x.dim()
+    target_shape[time_dim] = T
+    target_shape[-1] = -1
+    rotation = self.get_rotation(start, start + T).view(target_shape)
+
+    if self.xpos:
+        decay = self.xpos.get_decay(start, start + T).view(target_shape)
+    else:
+        decay = 1.0
+
+    if invert_decay:
+        decay = decay ** -1
+
+    x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
+    scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
+    x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
+
+    return x_out.type_as(x)
+
+
+
+def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1) +
+
+

Apply rope rotation to both query and key tensors. +Supports streaming mode, in which query and key are not expected to have the same shape. +In streaming mode, key will be of length [P + C] with P the cached past timesteps, but +query will be [C] (typically C == 1).

+

Args

+
+
query : torch.Tensor
+
Query to rotate.
+
key : torch.Tensor
+
Key to rotate.
+
start : int
+
Start index of the sequence for time offset.
+
time_dim : int
+
which dimension represent the time steps.
+
+
+ +Expand source code + +
def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
+    """ Apply rope rotation to both query and key tensors.
+    Supports streaming mode, in which query and key are not expected to have the same shape.
+    In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
+    query will be [C] (typically C == 1).
+
+    Args:
+        query (torch.Tensor): Query to rotate.
+        key (torch.Tensor): Key to rotate.
+        start (int): Start index of the sequence for time offset.
+        time_dim (int): which dimension represent the time steps.
+    """
+    query_timesteps = query.shape[time_dim]
+    key_timesteps = key.shape[time_dim]
+    streaming_offset = key_timesteps - query_timesteps
+
+    query_out = self.rotate(query, start + streaming_offset, time_dim)
+    key_out = self.rotate(key, start, time_dim, invert_decay=True)
+
+    return query_out, key_out
+
+
+
+
+
+class XPos +(dim: int, smoothing: float = 0.4, base_scale: int = 512, device=None, dtype: torch.dtype = torch.float32) +
+
+

Length-extrapolatable positional embedding (xPos) from Sun et al 2022. +This applies an exponential decay to the RoPE rotation matrix.

+

Args

+
+
dim : int
+
Embedding dimension.
+
smoothing : float
+
Smoothing factor applied to the decay rates.
+
base_scale : int
+
Base decay rate, given in terms of scaling time.
+
device : torch.device, optional
+
Device on which to initialize the module.
+
dtype : torch.dtype
+
dtype to use to generate the embedding.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class XPos(nn.Module):
+    """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
+    This applies an exponential decay to the RoPE rotation matrix.
+
+    Args:
+        dim (int): Embedding dimension.
+        smoothing (float): Smoothing factor applied to the decay rates.
+        base_scale (int): Base decay rate, given in terms of scaling time.
+        device (torch.device, optional): Device on which to initialize the module.
+        dtype (torch.dtype): dtype to use to generate the embedding.
+    """
+    def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
+                 device=None, dtype: torch.dtype = torch.float32):
+        super().__init__()
+        assert dim % 2 == 0
+        assert dtype in [torch.float64, torch.float32]
+        self.dtype = dtype
+        self.base_scale = base_scale
+
+        half_dim = dim // 2
+        adim = torch.arange(half_dim, device=device, dtype=dtype)
+        decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
+        self.register_buffer("decay_rates", decay_rates)
+        self.decay: tp.Optional[torch.Tensor] = None
+
+    def get_decay(self, start: int, end: int):
+        """Create complex decay tensor, cache values for fast computation."""
+        if self.decay is None or end > self.decay.shape[0]:
+            assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
+            idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
+            power = idx / self.base_scale
+            scale = self.decay_rates ** power.unsqueeze(-1)
+            self.decay = torch.polar(scale, torch.zeros_like(scale))
+        return self.decay[start:end]  # [T, C/2]
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, *input: Any) ‑> None +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def _forward_unimplemented(self, *input: Any) -> None:
+    r"""Defines the computation performed at every call.
+
+    Should be overridden by all subclasses.
+
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`Module` instance afterwards
+        instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
+    """
+    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
+
+
+
+def get_decay(self, start: int, end: int) +
+
+

Create complex decay tensor, cache values for fast computation.

+
+ +Expand source code + +
def get_decay(self, start: int, end: int):
+    """Create complex decay tensor, cache values for fast computation."""
+    if self.decay is None or end > self.decay.shape[0]:
+        assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
+        idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
+        power = idx / self.base_scale
+        scale = self.decay_rates ** power.unsqueeze(-1)
+        self.decay = torch.polar(scale, torch.zeros_like(scale))
+    return self.decay[start:end]  # [T, C/2]
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/seanet.html b/api_docs/audiocraft/modules/seanet.html new file mode 100644 index 00000000..831a462b --- /dev/null +++ b/api_docs/audiocraft/modules/seanet.html @@ -0,0 +1,879 @@ + + + + + + +audiocraft.modules.seanet API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.seanet

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import numpy as np
+import torch.nn as nn
+
+from .conv import StreamableConv1d, StreamableConvTranspose1d
+from .lstm import StreamableLSTM
+
+
+class SEANetResnetBlock(nn.Module):
+    """Residual block from SEANet model.
+
+    Args:
+        dim (int): Dimension of the input/output.
+        kernel_sizes (list): List of kernel sizes for the convolutions.
+        dilations (list): List of dilations for the convolutions.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection.
+    """
+    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
+                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
+                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
+        super().__init__()
+        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
+        act = getattr(nn, activation)
+        hidden = dim // compress
+        block = []
+        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+            in_chs = dim if i == 0 else hidden
+            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+            block += [
+                act(**activation_params),
+                StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
+                                 norm=norm, norm_kwargs=norm_params,
+                                 causal=causal, pad_mode=pad_mode),
+            ]
+        self.block = nn.Sequential(*block)
+        self.shortcut: nn.Module
+        if true_skip:
+            self.shortcut = nn.Identity()
+        else:
+            self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
+                                             causal=causal, pad_mode=pad_mode)
+
+    def forward(self, x):
+        return self.shortcut(x) + self.block(x)
+
+
+class SEANetEncoder(nn.Module):
+    """SEANet encoder.
+
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): nb of residual layers.
+        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
+            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
+            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the initial convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the end of the encoder.
+        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+            For the encoder, it corresponds to the N first blocks.
+    """
+    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+                 disable_norm_outer_blocks: int = 0):
+        super().__init__()
+        self.channels = channels
+        self.dimension = dimension
+        self.n_filters = n_filters
+        self.ratios = list(reversed(ratios))
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+        self.disable_norm_outer_blocks = disable_norm_outer_blocks
+        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+            "Number of blocks for which to disable norm is invalid." \
+            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+        act = getattr(nn, activation)
+        mult = 1
+        model: tp.List[nn.Module] = [
+            StreamableConv1d(channels, mult * n_filters, kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+        # Downsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
+                                      dilations=[dilation_base ** j, 1],
+                                      norm=block_norm, norm_params=norm_params,
+                                      activation=activation, activation_params=activation_params,
+                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+            # Add downsampling layers
+            model += [
+                act(**activation_params),
+                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
+                                 kernel_size=ratio * 2, stride=ratio,
+                                 norm=block_norm, norm_kwargs=norm_params,
+                                 causal=causal, pad_mode=pad_mode),
+            ]
+            mult *= 2
+
+        if lstm:
+            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+        model += [
+            act(**activation_params),
+            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class SEANetDecoder(nn.Module):
+    """SEANet decoder.
+
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): nb of residual layers.
+        ratios (Sequence[int]): kernel size and stride ratios.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        final_activation (str): Final activation function after all convolutions.
+        final_activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the initial convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple.
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the end of the encoder.
+        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+            For the decoder, it corresponds to the N last blocks.
+        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+            If equal to 1.0, it means that all the trimming is done at the right.
+    """
+    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
+        super().__init__()
+        self.dimension = dimension
+        self.channels = channels
+        self.n_filters = n_filters
+        self.ratios = ratios
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+        self.disable_norm_outer_blocks = disable_norm_outer_blocks
+        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+            "Number of blocks for which to disable norm is invalid." \
+            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+        act = getattr(nn, activation)
+        mult = int(2 ** len(self.ratios))
+        model: tp.List[nn.Module] = [
+            StreamableConv1d(dimension, mult * n_filters, kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+
+        if lstm:
+            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+        # Upsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
+            # Add upsampling layers
+            model += [
+                act(**activation_params),
+                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
+                                          kernel_size=ratio * 2, stride=ratio,
+                                          norm=block_norm, norm_kwargs=norm_params,
+                                          causal=causal, trim_right_ratio=trim_right_ratio),
+            ]
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
+                                      dilations=[dilation_base ** j, 1],
+                                      activation=activation, activation_params=activation_params,
+                                      norm=block_norm, norm_params=norm_params, causal=causal,
+                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+            mult //= 2
+
+        # Add final layers
+        model += [
+            act(**activation_params),
+            StreamableConv1d(n_filters, channels, last_kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+        # Add optional final activation to decoder (eg. tanh)
+        if final_activation is not None:
+            final_act = getattr(nn, final_activation)
+            final_activation_params = final_activation_params or {}
+            model += [
+                final_act(**final_activation_params)
+            ]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, z):
+        y = self.model(z)
+        return y
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class SEANetDecoder +(channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, ratios: List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, final_activation: Optional[str] = None, final_activation_params: Optional[dict] = None, norm: str = 'none', norm_params: Dict[str, Any] = {}, kernel_size: int = 7, last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0) +
+
+

SEANet decoder.

+

Args

+
+
channels : int
+
Audio channels.
+
dimension : int
+
Intermediate representation dimension.
+
n_filters : int
+
Base width for the model.
+
n_residual_layers : int
+
nb of residual layers.
+
ratios : Sequence[int]
+
kernel size and stride ratios.
+
activation : str
+
Activation function.
+
activation_params : dict
+
Parameters to provide to the activation function.
+
final_activation : str
+
Final activation function after all convolutions.
+
final_activation_params : dict
+
Parameters to provide to the activation function.
+
norm : str
+
Normalization method.
+
norm_params : dict
+
Parameters to provide to the underlying normalization used along with the convolution.
+
kernel_size : int
+
Kernel size for the initial convolution.
+
last_kernel_size : int
+
Kernel size for the initial convolution.
+
residual_kernel_size : int
+
Kernel size for the residual layers.
+
dilation_base : int
+
How much to increase the dilation with each layer.
+
causal : bool
+
Whether to use fully causal convolution.
+
pad_mode : str
+
Padding mode for the convolutions.
+
true_skip : bool
+
Whether to use true skip connection or a simple. +(streamable) convolution as the skip connection in the residual network blocks.
+
compress : int
+
Reduced dimensionality in residual branches (from Demucs v3).
+
lstm : int
+
Number of LSTM layers at the end of the encoder.
+
disable_norm_outer_blocks : int
+
Number of blocks for which we don't apply norm. +For the decoder, it corresponds to the N last blocks.
+
trim_right_ratio : float
+
Ratio for trimming at the right of the transposed convolution under the causal setup. +If equal to 1.0, it means that all the trimming is done at the right.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class SEANetDecoder(nn.Module):
+    """SEANet decoder.
+
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): nb of residual layers.
+        ratios (Sequence[int]): kernel size and stride ratios.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        final_activation (str): Final activation function after all convolutions.
+        final_activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the initial convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple.
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the end of the encoder.
+        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+            For the decoder, it corresponds to the N last blocks.
+        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+            If equal to 1.0, it means that all the trimming is done at the right.
+    """
+    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
+        super().__init__()
+        self.dimension = dimension
+        self.channels = channels
+        self.n_filters = n_filters
+        self.ratios = ratios
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+        self.disable_norm_outer_blocks = disable_norm_outer_blocks
+        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+            "Number of blocks for which to disable norm is invalid." \
+            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+        act = getattr(nn, activation)
+        mult = int(2 ** len(self.ratios))
+        model: tp.List[nn.Module] = [
+            StreamableConv1d(dimension, mult * n_filters, kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+
+        if lstm:
+            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+        # Upsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
+            # Add upsampling layers
+            model += [
+                act(**activation_params),
+                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
+                                          kernel_size=ratio * 2, stride=ratio,
+                                          norm=block_norm, norm_kwargs=norm_params,
+                                          causal=causal, trim_right_ratio=trim_right_ratio),
+            ]
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
+                                      dilations=[dilation_base ** j, 1],
+                                      activation=activation, activation_params=activation_params,
+                                      norm=block_norm, norm_params=norm_params, causal=causal,
+                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+            mult //= 2
+
+        # Add final layers
+        model += [
+            act(**activation_params),
+            StreamableConv1d(n_filters, channels, last_kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+        # Add optional final activation to decoder (eg. tanh)
+        if final_activation is not None:
+            final_act = getattr(nn, final_activation)
+            final_activation_params = final_activation_params or {}
+            model += [
+                final_act(**final_activation_params)
+            ]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, z):
+        y = self.model(z)
+        return y
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, z) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, z):
+    y = self.model(z)
+    return y
+
+
+
+
+
+class SEANetEncoder +(channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, ratios: List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, norm: str = 'none', norm_params: Dict[str, Any] = {}, kernel_size: int = 7, last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, disable_norm_outer_blocks: int = 0) +
+
+

SEANet encoder.

+

Args

+
+
channels : int
+
Audio channels.
+
dimension : int
+
Intermediate representation dimension.
+
n_filters : int
+
Base width for the model.
+
n_residual_layers : int
+
nb of residual layers.
+
ratios : Sequence[int]
+
kernel size and stride ratios. The encoder uses downsampling ratios instead of +upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here +that must match the decoder order. We use the decoder order as some models may only employ the decoder.
+
activation : str
+
Activation function.
+
activation_params : dict
+
Parameters to provide to the activation function.
+
norm : str
+
Normalization method.
+
norm_params : dict
+
Parameters to provide to the underlying normalization used along with the convolution.
+
kernel_size : int
+
Kernel size for the initial convolution.
+
last_kernel_size : int
+
Kernel size for the initial convolution.
+
residual_kernel_size : int
+
Kernel size for the residual layers.
+
dilation_base : int
+
How much to increase the dilation with each layer.
+
causal : bool
+
Whether to use fully causal convolution.
+
pad_mode : str
+
Padding mode for the convolutions.
+
true_skip : bool
+
Whether to use true skip connection or a simple +(streamable) convolution as the skip connection in the residual network blocks.
+
compress : int
+
Reduced dimensionality in residual branches (from Demucs v3).
+
lstm : int
+
Number of LSTM layers at the end of the encoder.
+
disable_norm_outer_blocks : int
+
Number of blocks for which we don't apply norm. +For the encoder, it corresponds to the N first blocks.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class SEANetEncoder(nn.Module):
+    """SEANet encoder.
+
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): nb of residual layers.
+        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
+            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
+            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the initial convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the end of the encoder.
+        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+            For the encoder, it corresponds to the N first blocks.
+    """
+    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+                 disable_norm_outer_blocks: int = 0):
+        super().__init__()
+        self.channels = channels
+        self.dimension = dimension
+        self.n_filters = n_filters
+        self.ratios = list(reversed(ratios))
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+        self.disable_norm_outer_blocks = disable_norm_outer_blocks
+        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+            "Number of blocks for which to disable norm is invalid." \
+            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+        act = getattr(nn, activation)
+        mult = 1
+        model: tp.List[nn.Module] = [
+            StreamableConv1d(channels, mult * n_filters, kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+        # Downsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
+                                      dilations=[dilation_base ** j, 1],
+                                      norm=block_norm, norm_params=norm_params,
+                                      activation=activation, activation_params=activation_params,
+                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+            # Add downsampling layers
+            model += [
+                act(**activation_params),
+                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
+                                 kernel_size=ratio * 2, stride=ratio,
+                                 norm=block_norm, norm_kwargs=norm_params,
+                                 causal=causal, pad_mode=pad_mode),
+            ]
+            mult *= 2
+
+        if lstm:
+            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+        model += [
+            act(**activation_params),
+            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        return self.model(x)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    return self.model(x)
+
+
+
+
+
+class SEANetResnetBlock +(dim: int, kernel_sizes: List[int] = [3, 1], dilations: List[int] = [1, 1], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, norm: str = 'none', norm_params: Dict[str, Any] = {}, causal: bool = False, pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True) +
+
+

Residual block from SEANet model.

+

Args

+
+
dim : int
+
Dimension of the input/output.
+
kernel_sizes : list
+
List of kernel sizes for the convolutions.
+
dilations : list
+
List of dilations for the convolutions.
+
activation : str
+
Activation function.
+
activation_params : dict
+
Parameters to provide to the activation function.
+
norm : str
+
Normalization method.
+
norm_params : dict
+
Parameters to provide to the underlying normalization used along with the convolution.
+
causal : bool
+
Whether to use fully causal convolution.
+
pad_mode : str
+
Padding mode for the convolutions.
+
compress : int
+
Reduced dimensionality in residual branches (from Demucs v3).
+
true_skip : bool
+
Whether to use true skip connection or a simple +(streamable) convolution as the skip connection.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class SEANetResnetBlock(nn.Module):
+    """Residual block from SEANet model.
+
+    Args:
+        dim (int): Dimension of the input/output.
+        kernel_sizes (list): List of kernel sizes for the convolutions.
+        dilations (list): List of dilations for the convolutions.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection.
+    """
+    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
+                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
+                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
+        super().__init__()
+        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
+        act = getattr(nn, activation)
+        hidden = dim // compress
+        block = []
+        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+            in_chs = dim if i == 0 else hidden
+            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+            block += [
+                act(**activation_params),
+                StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
+                                 norm=norm, norm_kwargs=norm_params,
+                                 causal=causal, pad_mode=pad_mode),
+            ]
+        self.block = nn.Sequential(*block)
+        self.shortcut: nn.Module
+        if true_skip:
+            self.shortcut = nn.Identity()
+        else:
+            self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
+                                             causal=causal, pad_mode=pad_mode)
+
+    def forward(self, x):
+        return self.shortcut(x) + self.block(x)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    return self.shortcut(x) + self.block(x)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/streaming.html b/api_docs/audiocraft/modules/streaming.html new file mode 100644 index 00000000..a3334924 --- /dev/null +++ b/api_docs/audiocraft/modules/streaming.html @@ -0,0 +1,561 @@ + + + + + + +audiocraft.modules.streaming API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.streaming

+
+
+

Streaming module API that should be implemented by all Streaming components,

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Streaming module API that should be implemented by all Streaming components,
+"""
+
+from contextlib import contextmanager
+import typing as tp
+from torch import nn
+import torch
+
+
+State = tp.Dict[str, torch.Tensor]
+
+
+class StreamingModule(nn.Module):
+    """Common API for streaming components.
+
+    Each streaming component has a streaming state, which is just a dict[str, Tensor].
+    By convention, the first dim of each tensor must be the batch size.
+    Don't use dots in the key names, as this would clash with submodules
+    (like in state_dict).
+
+    If `self._is_streaming` is True, the component should use and remember
+    the proper state inside `self._streaming_state`.
+
+    To set a streaming component in streaming state, use
+
+        with module.streaming():
+            ...
+
+    This will automatically reset the streaming state when exiting the context manager.
+    This also automatically propagates to all streaming children module.
+
+    Some module might also implement the `StreamingModule.flush` method, although
+    this one is trickier, as all parents module must be StreamingModule and implement
+    it as well for it to work properly. See `StreamingSequential` after.
+    """
+    def __init__(self) -> None:
+        super().__init__()
+        self._streaming_state: State = {}
+        self._is_streaming = False
+
+    def _apply_named_streaming(self, fn: tp.Any):
+        for name, module in self.named_modules():
+            if isinstance(module, StreamingModule):
+                fn(name, module)
+
+    def _set_streaming(self, streaming: bool):
+        def _set_streaming(name, module):
+            module._is_streaming = streaming
+        self._apply_named_streaming(_set_streaming)
+
+    @contextmanager
+    def streaming(self):
+        """Context manager to enter streaming mode. Reset streaming state on exit."""
+        self._set_streaming(True)
+        try:
+            yield
+        finally:
+            self._set_streaming(False)
+            self.reset_streaming()
+
+    def reset_streaming(self):
+        """Reset the streaming state."""
+        def _reset(name: str, module: StreamingModule):
+            module._streaming_state.clear()
+
+        self._apply_named_streaming(_reset)
+
+    def get_streaming_state(self) -> State:
+        """Return the streaming state, including that of sub-modules."""
+        state: State = {}
+
+        def _add(name: str, module: StreamingModule):
+            if name:
+                name += "."
+            for key, value in module._streaming_state.items():
+                state[name + key] = value
+
+        self._apply_named_streaming(_add)
+        return state
+
+    def set_streaming_state(self, state: State):
+        """Set the streaming state, including that of sub-modules."""
+        state = dict(state)
+
+        def _set(name: str, module: StreamingModule):
+            if name:
+                name += "."
+            module._streaming_state.clear()
+            for key, value in list(state.items()):
+                # complexity is not ideal here, but probably fine.
+                if key.startswith(name):
+                    local_key = key[len(name):]
+                    if '.' not in local_key:
+                        module._streaming_state[local_key] = value
+                        del state[key]
+
+        self._apply_named_streaming(_set)
+        assert len(state) == 0, list(state.keys())
+
+    def flush(self, x: tp.Optional[torch.Tensor] = None):
+        """Flush any remaining outputs that were waiting for completion.
+        Typically, for convolutions, this will add the final padding
+        and process the last buffer.
+
+        This should take an optional argument `x`, which will be provided
+        if a module before this one in the streaming pipeline has already
+        spitted out a flushed out buffer.
+        """
+        if x is None:
+            return None
+        else:
+            return self(x)
+
+
+class StreamingSequential(StreamingModule, nn.Sequential):
+    """A streaming compatible alternative of `nn.Sequential`.
+    """
+    def flush(self, x: tp.Optional[torch.Tensor] = None):
+        for module in self:
+            if isinstance(module, StreamingModule):
+                x = module.flush(x)
+            elif x is not None:
+                x = module(x)
+        return x
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class StreamingModule +
+
+

Common API for streaming components.

+

Each streaming component has a streaming state, which is just a dict[str, Tensor]. +By convention, the first dim of each tensor must be the batch size. +Don't use dots in the key names, as this would clash with submodules +(like in state_dict).

+

If self._is_streaming is True, the component should use and remember +the proper state inside self._streaming_state.

+

To set a streaming component in streaming state, use

+
with module.streaming():
+    ...
+
+

This will automatically reset the streaming state when exiting the context manager. +This also automatically propagates to all streaming children module.

+

Some module might also implement the StreamingModule.flush() method, although +this one is trickier, as all parents module must be StreamingModule and implement +it as well for it to work properly. See StreamingSequential after.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamingModule(nn.Module):
+    """Common API for streaming components.
+
+    Each streaming component has a streaming state, which is just a dict[str, Tensor].
+    By convention, the first dim of each tensor must be the batch size.
+    Don't use dots in the key names, as this would clash with submodules
+    (like in state_dict).
+
+    If `self._is_streaming` is True, the component should use and remember
+    the proper state inside `self._streaming_state`.
+
+    To set a streaming component in streaming state, use
+
+        with module.streaming():
+            ...
+
+    This will automatically reset the streaming state when exiting the context manager.
+    This also automatically propagates to all streaming children module.
+
+    Some module might also implement the `StreamingModule.flush` method, although
+    this one is trickier, as all parents module must be StreamingModule and implement
+    it as well for it to work properly. See `StreamingSequential` after.
+    """
+    def __init__(self) -> None:
+        super().__init__()
+        self._streaming_state: State = {}
+        self._is_streaming = False
+
+    def _apply_named_streaming(self, fn: tp.Any):
+        for name, module in self.named_modules():
+            if isinstance(module, StreamingModule):
+                fn(name, module)
+
+    def _set_streaming(self, streaming: bool):
+        def _set_streaming(name, module):
+            module._is_streaming = streaming
+        self._apply_named_streaming(_set_streaming)
+
+    @contextmanager
+    def streaming(self):
+        """Context manager to enter streaming mode. Reset streaming state on exit."""
+        self._set_streaming(True)
+        try:
+            yield
+        finally:
+            self._set_streaming(False)
+            self.reset_streaming()
+
+    def reset_streaming(self):
+        """Reset the streaming state."""
+        def _reset(name: str, module: StreamingModule):
+            module._streaming_state.clear()
+
+        self._apply_named_streaming(_reset)
+
+    def get_streaming_state(self) -> State:
+        """Return the streaming state, including that of sub-modules."""
+        state: State = {}
+
+        def _add(name: str, module: StreamingModule):
+            if name:
+                name += "."
+            for key, value in module._streaming_state.items():
+                state[name + key] = value
+
+        self._apply_named_streaming(_add)
+        return state
+
+    def set_streaming_state(self, state: State):
+        """Set the streaming state, including that of sub-modules."""
+        state = dict(state)
+
+        def _set(name: str, module: StreamingModule):
+            if name:
+                name += "."
+            module._streaming_state.clear()
+            for key, value in list(state.items()):
+                # complexity is not ideal here, but probably fine.
+                if key.startswith(name):
+                    local_key = key[len(name):]
+                    if '.' not in local_key:
+                        module._streaming_state[local_key] = value
+                        del state[key]
+
+        self._apply_named_streaming(_set)
+        assert len(state) == 0, list(state.keys())
+
+    def flush(self, x: tp.Optional[torch.Tensor] = None):
+        """Flush any remaining outputs that were waiting for completion.
+        Typically, for convolutions, this will add the final padding
+        and process the last buffer.
+
+        This should take an optional argument `x`, which will be provided
+        if a module before this one in the streaming pipeline has already
+        spitted out a flushed out buffer.
+        """
+        if x is None:
+            return None
+        else:
+            return self(x)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def flush(self, x: Optional[torch.Tensor] = None) +
+
+

Flush any remaining outputs that were waiting for completion. +Typically, for convolutions, this will add the final padding +and process the last buffer.

+

This should take an optional argument x, which will be provided +if a module before this one in the streaming pipeline has already +spitted out a flushed out buffer.

+
+ +Expand source code + +
def flush(self, x: tp.Optional[torch.Tensor] = None):
+    """Flush any remaining outputs that were waiting for completion.
+    Typically, for convolutions, this will add the final padding
+    and process the last buffer.
+
+    This should take an optional argument `x`, which will be provided
+    if a module before this one in the streaming pipeline has already
+    spitted out a flushed out buffer.
+    """
+    if x is None:
+        return None
+    else:
+        return self(x)
+
+
+
+def forward(self, *input: Any) ‑> None +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def _forward_unimplemented(self, *input: Any) -> None:
+    r"""Defines the computation performed at every call.
+
+    Should be overridden by all subclasses.
+
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`Module` instance afterwards
+        instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
+    """
+    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
+
+
+
+def get_streaming_state(self) ‑> Dict[str, torch.Tensor] +
+
+

Return the streaming state, including that of sub-modules.

+
+ +Expand source code + +
def get_streaming_state(self) -> State:
+    """Return the streaming state, including that of sub-modules."""
+    state: State = {}
+
+    def _add(name: str, module: StreamingModule):
+        if name:
+            name += "."
+        for key, value in module._streaming_state.items():
+            state[name + key] = value
+
+    self._apply_named_streaming(_add)
+    return state
+
+
+
+def reset_streaming(self) +
+
+

Reset the streaming state.

+
+ +Expand source code + +
def reset_streaming(self):
+    """Reset the streaming state."""
+    def _reset(name: str, module: StreamingModule):
+        module._streaming_state.clear()
+
+    self._apply_named_streaming(_reset)
+
+
+
+def set_streaming_state(self, state: Dict[str, torch.Tensor]) +
+
+

Set the streaming state, including that of sub-modules.

+
+ +Expand source code + +
def set_streaming_state(self, state: State):
+    """Set the streaming state, including that of sub-modules."""
+    state = dict(state)
+
+    def _set(name: str, module: StreamingModule):
+        if name:
+            name += "."
+        module._streaming_state.clear()
+        for key, value in list(state.items()):
+            # complexity is not ideal here, but probably fine.
+            if key.startswith(name):
+                local_key = key[len(name):]
+                if '.' not in local_key:
+                    module._streaming_state[local_key] = value
+                    del state[key]
+
+    self._apply_named_streaming(_set)
+    assert len(state) == 0, list(state.keys())
+
+
+
+def streaming(self) +
+
+

Context manager to enter streaming mode. Reset streaming state on exit.

+
+ +Expand source code + +
@contextmanager
+def streaming(self):
+    """Context manager to enter streaming mode. Reset streaming state on exit."""
+    self._set_streaming(True)
+    try:
+        yield
+    finally:
+        self._set_streaming(False)
+        self.reset_streaming()
+
+
+
+
+
+class StreamingSequential +
+
+

A streaming compatible alternative of nn.Sequential.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamingSequential(StreamingModule, nn.Sequential):
+    """A streaming compatible alternative of `nn.Sequential`.
+    """
+    def flush(self, x: tp.Optional[torch.Tensor] = None):
+        for module in self:
+            if isinstance(module, StreamingModule):
+                x = module.flush(x)
+            elif x is not None:
+                x = module(x)
+        return x
+
+

Ancestors

+
    +
  • StreamingModule
  • +
  • torch.nn.modules.container.Sequential
  • +
  • torch.nn.modules.module.Module
  • +
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/modules/transformer.html b/api_docs/audiocraft/modules/transformer.html new file mode 100644 index 00000000..a0b997a1 --- /dev/null +++ b/api_docs/audiocraft/modules/transformer.html @@ -0,0 +1,2015 @@ + + + + + + +audiocraft.modules.transformer API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.transformer

+
+
+

Transformer model, with streaming support, xformer attention support +and easy causal attention with a potentially finite receptive field.

+

See StreamingTransformer for more information.

+

Unlike regular PyTorch Transformer, we make the hard choice that batches are first.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Transformer model, with streaming support, xformer attention support
+and easy causal attention with a potentially finite receptive field.
+
+See `StreamingTransformer` for more information.
+
+Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
+"""
+
+import typing as tp
+
+from einops import rearrange
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint as torch_checkpoint
+from xformers import ops
+
+from .rope import RotaryEmbedding
+from .streaming import StreamingModule
+
+_efficient_attention_backend: str = 'torch'
+
+
+def set_efficient_attention_backend(backend: str = 'torch'):
+    # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
+    global _efficient_attention_backend
+    assert _efficient_attention_backend in ['xformers', 'torch']
+    _efficient_attention_backend = backend
+
+
+def _get_attention_time_dimension(memory_efficient: bool) -> int:
+    if _efficient_attention_backend == 'torch' and memory_efficient:
+        return 2
+    else:
+        return 1
+
+
+def _is_profiled() -> bool:
+    # Return true if we are currently running with a xformers profiler activated.
+    try:
+        from xformers.profiler import profiler
+    except ImportError:
+        return False
+    return profiler._Profiler._CURRENT_PROFILER is not None
+
+
+def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
+    """Create normalization module for transformer encoder layer.
+
+    Args:
+        norm_type (str): Normalization method.
+        dim (int): Dimension of the normalized layer.
+        **kwargs (dict): Additional parameters for normalization layer.
+    Returns:
+        nn.Module: Normalization module.
+    """
+    if norm_type == 'layer_norm':
+        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
+    else:
+        raise ValueError(f"Unknown norm type: {norm_type}")
+
+
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
+                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    """Create sinusoidal positional embedding, with shape `[B, T, C]`.
+
+    Args:
+        positions (torch.Tensor): LongTensor of positions.
+        dim (int): Dimension of the embedding.
+        max_period (float): Maximum period of the cosine/sine functions.
+        dtype (torch.dtype or str): dtype to use to generate the embedding.
+    Returns:
+        torch.Tensor: Sinusoidal positional embedding.
+    """
+    # We aim for BTC format
+    assert dim % 2 == 0
+    half_dim = dim // 2
+    positions = positions.to(dtype)
+    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
+    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
+    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
+    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
+
+
+def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
+    if n_rep == 1:
+        return x
+    if _efficient_attention_backend == 'torch' and memory_efficient:
+        bs, n_kv_heads, slen, head_dim = x.shape
+        return (
+            x[:, :, None, :, :]
+            .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+            .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+        )
+    else:
+        bs, slen, n_kv_heads, head_dim = x.shape
+        return (
+            x[:, :, :, None, :]
+            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+        )
+
+
+class LayerScale(nn.Module):
+    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
+    This rescales diagonally the residual outputs close to 0, with a learnt scale.
+
+    Args:
+        channels (int): Number of channels.
+        init (float): Initial scale.
+        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
+        device (torch.device or str, optional): Device on which to initialize the module.
+        dtype (torch.dtype, optional): dtype to use to initialize the module.
+    """
+    def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
+                 device=None, dtype=None):
+        super().__init__()
+        self.channel_last = channel_last
+        self.scale = nn.Parameter(
+            torch.full((channels,), init,
+                       requires_grad=True, device=device, dtype=dtype))
+
+    def forward(self, x: torch.Tensor):
+        if self.channel_last:
+            return self.scale * x
+        else:
+            return self.scale[:, None] * x
+
+
+class StreamingMultiheadAttention(StreamingModule):
+    """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
+
+    Args:
+        embed_dim (int): Dimension to project to.
+        num_heads (int): Number of heads.
+        dropout (float): Dropout level.
+        bias (bool): Use bias in projections.
+        causal (bool): Causal mask applied automatically.
+        past_context (int, optional): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        rope (`RotaryEmbedding`, optional): Rope embedding to use.
+        cross_attention: Should be true when used as a cross attention.
+            All keys and values must be available at once, streaming is only for the queries.
+            Cannot be used with `causal` or `rope` (as it wouldn't make sens to
+            interpret the time steps in the keys relative to those in the queries).
+        safe_streaming (bool): Bug fix, will go away with xformers update.
+        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
+        kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
+            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+        device (torch.device, optional): Device on which to initialize.
+        dtype (torch.dtype, optional): dtype to use.
+    """
+    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
+                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
+                 memory_efficient: bool = False, attention_as_float32: bool = False,
+                 rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
+                 safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
+                 device=None, dtype=None):
+        super().__init__()
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        if past_context is not None:
+            assert causal
+
+        self.embed_dim = embed_dim
+        self.causal = causal
+        self.past_context = past_context
+        self.memory_efficient = memory_efficient
+        self.attention_as_float32 = attention_as_float32
+        self.rope = rope
+        self.cross_attention = cross_attention
+        self.safe_streaming = safe_streaming
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.kv_repeat = kv_repeat
+        if cross_attention:
+            assert not causal, "Causal cannot work with cross attention."
+            assert rope is None, "Rope cannot work with cross attention."
+
+        if memory_efficient:
+            _verify_xformers_memory_efficient_compat()
+
+        self.custom = _is_custom(custom, memory_efficient)
+        if self.custom:
+            out_dim = embed_dim
+            assert num_heads % kv_repeat == 0
+            assert not cross_attention or kv_repeat == 1
+            num_kv = num_heads // kv_repeat
+            kv_dim = (embed_dim // num_heads) * num_kv
+            out_dim += 2 * kv_dim
+            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
+            # We try to follow the default PyTorch MHA convention, to easily compare results.
+            self.in_proj_weight = in_proj.weight
+            self.in_proj_bias = in_proj.bias
+            if bias:
+                self.in_proj_bias.data.zero_()  # Following Pytorch convention
+            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+            if bias:
+                self.out_proj.bias.data.zero_()
+        else:
+            assert not qk_layer_norm
+            assert kv_repeat == 1
+            self.mha = nn.MultiheadAttention(
+                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
+                **factory_kwargs)
+        self.qk_layer_norm = qk_layer_norm
+        if qk_layer_norm:
+            assert self.custom
+            assert kv_repeat == 1
+            ln_dim = embed_dim
+            self.q_layer_norm = nn.LayerNorm(ln_dim)
+            self.k_layer_norm = nn.LayerNorm(ln_dim)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        if not self.custom:
+            # Support compat with regular MHA
+            keys = [n for n, _ in self.mha.named_parameters()]
+            for key in keys:
+                if prefix + key in state_dict:
+                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+    def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
+        # Return a causal mask, accounting for potentially stored past keys/values
+        # We actually return a bias for the attention score, as this has the same
+        # convention both in the builtin MHA in Pytorch, and Xformers functions.
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
+        if self.memory_efficient:
+            from xformers.ops import LowerTriangularMask
+            if current_steps == 1:
+                # If we only have one step, then we do not need a mask.
+                return None
+            elif 'past_keys' in self._streaming_state:
+                raise RuntimeError("Not supported at the moment")
+            else:
+                # Then we can safely use a lower triangular mask
+                return LowerTriangularMask()
+        if self._streaming_state:
+            past_keys = self._streaming_state['past_keys']
+            past_steps = past_keys.shape[time_dim]
+        else:
+            past_steps = 0
+
+        queries_pos = torch.arange(
+            past_steps, current_steps + past_steps, device=device).view(-1, 1)
+        keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
+        delta = queries_pos - keys_pos
+        valid = delta >= 0
+        if self.past_context is not None:
+            valid &= (delta <= self.past_context)
+        return torch.where(
+            valid,
+            torch.zeros([], device=device, dtype=dtype),
+            torch.full([], float('-inf'), device=device, dtype=dtype))
+
+    def _complete_kv(self, k, v):
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
+        if self.cross_attention:
+            # With cross attention we assume all keys and values
+            # are already available, and streaming is with respect
+            # to the queries only.
+            return k, v
+        # Complete the key/value pair using the streaming state.
+        if self._streaming_state:
+            pk = self._streaming_state['past_keys']
+            nk = torch.cat([pk, k], dim=time_dim)
+            if v is k:
+                nv = nk
+            else:
+                pv = self._streaming_state['past_values']
+                nv = torch.cat([pv, v], dim=time_dim)
+        else:
+            nk = k
+            nv = v
+
+        assert nk.shape[time_dim] == nv.shape[time_dim]
+        offset = 0
+        if self.past_context is not None:
+            offset = max(0, nk.shape[time_dim] - self.past_context)
+        if self._is_streaming:
+            self._streaming_state['past_keys'] = nk[:, offset:]
+            if v is not k:
+                self._streaming_state['past_values'] = nv[:, offset:]
+            if 'offset' in self._streaming_state:
+                self._streaming_state['offset'] += offset
+            else:
+                self._streaming_state['offset'] = torch.tensor(0)
+        return nk, nv
+
+    def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
+        # Apply rope embeddings to query and key tensors.
+        assert self.rope is not None
+        if 'past_keys' in self._streaming_state:
+            past_keys_offset = self._streaming_state['past_keys'].shape[1]
+        else:
+            past_keys_offset = 0
+        if 'offset' in self._streaming_state:
+            past_context_offset = int(self._streaming_state['offset'].item())
+        else:
+            past_context_offset = 0
+        streaming_offset = past_context_offset + past_keys_offset
+        return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
+                key_padding_mask=None, need_weights=False, attn_mask=None,
+                average_attn_weights=True, is_causal=False):
+        assert attn_mask is None
+        assert not is_causal, ("New param added in torch 2.0.1 not supported, "
+                               "use the causal args in the constructor.")
+
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
+        if time_dim == 2:
+            layout = "b h t d"
+        else:
+            layout = "b t h d"
+        dtype = query.dtype
+        if self._is_streaming:
+            assert self.causal or self.cross_attention, \
+                "Streaming only available for causal or cross attention"
+
+        if self.causal:
+            # At the moment we specialize only for the self-attention case.
+            assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+            assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+            attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
+
+        if self.custom:
+            # custom implementation
+            assert need_weights is False
+            assert key_padding_mask is None
+            if self.cross_attention:
+                # Different queries, keys, values, we have to spit manually the weights
+                # before applying the linear.
+                dim = self.in_proj_weight.shape[0] // 3
+                if self.in_proj_bias is None:
+                    bias_q, bias_k, bias_v = None, None, None
+                else:
+                    bias_q = self.in_proj_bias[:dim]
+                    bias_k = self.in_proj_bias[dim: 2 * dim]
+                    bias_v = self.in_proj_bias[2 * dim:]
+                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
+                # todo: when streaming, we could actually save k, v and check the shape actually match.
+                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
+                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
+                if self.qk_layer_norm is True:
+                    q = self.q_layer_norm(q)
+                    k = self.k_layer_norm(k)
+                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
+            else:
+                if not _is_profiled():
+                    # profiling breaks that propertysomehow.
+                    assert query is key, "specialized implementation"
+                    assert value is key, "specialized implementation"
+                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
+                if self.kv_repeat == 1:
+                    if time_dim == 2:
+                        bound_layout = "b h p t d"
+                    else:
+                        bound_layout = "b t p h d"
+                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
+                    q, k, v = ops.unbind(packed, dim=2)
+                else:
+                    embed_dim = self.embed_dim
+                    per_head_dim = (embed_dim // self.num_heads)
+                    kv_heads = self.num_heads // self.kv_repeat
+                    q = projected[:, :, :embed_dim]
+                    start = embed_dim
+                    end = start + per_head_dim * kv_heads
+                    k = projected[:, :, start: end]
+                    v = projected[:, :, end:]
+                    q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
+                    k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
+                    v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
+
+                if self.qk_layer_norm is True:
+                    assert self.kv_repeat == 1
+                    q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
+                    q = self.q_layer_norm(q)
+                    k = self.k_layer_norm(k)
+                    q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
+                if self.rope:
+                    q, k = self._apply_rope(q, k)
+                k, v = self._complete_kv(k, v)
+                if self.kv_repeat > 1:
+                    k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
+                    v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
+            if self.attention_as_float32:
+                q, k, v = [x.float() for x in [q, k, v]]
+            if self.memory_efficient:
+                p = self.dropout if self.training else 0
+                if _efficient_attention_backend == 'torch':
+                    x = torch.nn.functional.scaled_dot_product_attention(
+                        q, k, v, is_causal=attn_mask is not None, dropout_p=p)
+                else:
+                    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
+            else:
+                # We include the dot product as float32, for consistency
+                # with the other implementations that include that step
+                # as part of the attention. Note that when using `autocast`,
+                # the einsums would be done as bfloat16, but the softmax
+                # would be done as bfloat16, so `attention_as_float32` will
+                # extend a bit the range of operations done in float32,
+                # although this should make no difference.
+                q = q / q.shape[-1] ** 0.5
+                key_layout = layout.replace('t', 'k')
+                query_layout = layout
+                if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
+                    with torch.autocast(device_type=q.device.type, dtype=torch.float32):
+                        pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+                else:
+                    pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+                if attn_mask is not None:
+                    pre_w = pre_w + attn_mask
+                w = torch.softmax(pre_w, dim=-1)
+                w = F.dropout(w, self.dropout, training=self.training).to(v)
+                # Key and value have the same format.
+                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
+            x = x.to(dtype)
+            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
+            x = self.out_proj(x)
+        else:
+            key, value = self._complete_kv(key, value)
+            if self.attention_as_float32:
+                query, key, value = [x.float() for x in [query, key, value]]
+            x, _ = self.mha(
+                query, key, value, key_padding_mask,
+                need_weights, attn_mask, average_attn_weights)
+            x = x.to(dtype)
+
+        return x, None
+
+
+class StreamingTransformerLayer(nn.TransformerEncoderLayer):
+    """TransformerLayer with Streaming / Causal support.
+    This also integrates cross_attention, when passing `cross_attention=True`,
+    rather than having two separate classes like in PyTorch.
+
+    Args:
+        d_model (int): Dimension of the data.
+        num_heads (int): Number of heads.
+        dim_feedforward (int): Intermediate dimension of FF module.
+        dropout (float): Dropout both for MHA and FF.
+        bias_ff (bool): Use bias for FF.
+        bias_attn (bool): Use bias for MHA.
+        causal (bool): Causal mask applied automatically.
+        past_context (int, optional): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
+        qk_layer_norm_cross (bool): Same for the cross attention.
+        cross_attention (bool): If True, expect to get secondary input for cross-attention.
+            Cross attention will use the default MHA, as it typically won't require
+            special treatment.
+        layer_scale (float, optional): If not None, LayerScale will be used with
+            the given value as initial scale.
+        rope (`RotaryEmbedding`, optional): Rope embedding to use.
+        attention_dropout (float, optional): If not None, separate the value of the dimension dropout
+            in FFN and of the attention dropout.
+        kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
+            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+        device (torch.device, optional): Device on which to initialize.
+        dtype (torch.dtype, optional): dtype to use.
+        **kwargs: See `nn.TransformerEncoderLayer`.
+    """
+    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
+                 bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
+                 past_context: tp.Optional[int] = None, custom: bool = False,
+                 memory_efficient: bool = False, attention_as_float32: bool = False,
+                 qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
+                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+                 rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
+                 kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
+        super().__init__(d_model, num_heads, dim_feedforward, dropout,
+                         device=device, dtype=dtype, batch_first=True, **kwargs)
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        # Redefine self_attn to our streaming multi-head attention
+        attn_kwargs: tp.Dict[str, tp.Any] = {
+            'embed_dim': d_model,
+            'num_heads': num_heads,
+            'dropout': dropout if attention_dropout is None else attention_dropout,
+            'bias': bias_attn,
+            'custom': custom,
+            'memory_efficient': memory_efficient,
+            'attention_as_float32': attention_as_float32,
+        }
+        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
+            causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
+            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
+        # Redefine feedforward layers to expose bias parameter
+        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
+        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
+
+        self.layer_scale_1: nn.Module
+        self.layer_scale_2: nn.Module
+        if layer_scale is None:
+            self.layer_scale_1 = nn.Identity()
+            self.layer_scale_2 = nn.Identity()
+        else:
+            self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
+            self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
+
+        self.cross_attention: tp.Optional[nn.Module] = None
+        if cross_attention:
+            self.cross_attention = StreamingMultiheadAttention(
+                cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
+                **attn_kwargs, **factory_kwargs)
+            # Norm and dropout
+            self.dropout_cross = nn.Dropout(dropout)
+            # eps value matching that used in PyTorch reference implementation.
+            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
+            self.layer_scale_cross: nn.Module
+            if layer_scale is None:
+                self.layer_scale_cross = nn.Identity()
+            else:
+                self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
+        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
+        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
+
+    def _cross_attention_block(self, src: torch.Tensor,
+                               cross_attention_src: torch.Tensor) -> torch.Tensor:
+        assert self.cross_attention is not None
+        # queries are from src, keys and values from cross_attention_src.
+        x = self.cross_attention(
+            src, cross_attention_src, cross_attention_src, need_weights=False)[0]
+        return self.dropout_cross(x)  # type: ignore
+
+    def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
+                src_key_padding_mask: tp.Optional[torch.Tensor] = None,
+                cross_attention_src: tp.Optional[torch.Tensor] = None):
+        if self.cross_attention is None:
+            assert cross_attention_src is None
+        else:
+            assert cross_attention_src is not None
+        x = src
+        if self.norm_first:
+            x = x + self.layer_scale_1(
+                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
+            if cross_attention_src is not None:
+                x = x + self.layer_scale_cross(
+                    self._cross_attention_block(
+                        self.norm_cross(x), cross_attention_src))
+            x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
+        else:
+            x = self.norm1(x + self.layer_scale_1(
+                self._sa_block(x, src_mask, src_key_padding_mask)))
+            if cross_attention_src is not None:
+                x = self.norm_cross(
+                    x + self.layer_scale_cross(
+                        self._cross_attention_block(src, cross_attention_src)))
+            x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
+        return x
+
+
+class StreamingTransformer(StreamingModule):
+    """Transformer with Streaming / Causal support.
+
+    Args:
+        d_model (int): Dimension of the data.
+        num_heads (int): Number of heads.
+        dim_feedforward (int): Intermediate dimension of FF module.
+        dropout (float): Dropout both for MHA and FF.
+        bias_ff (bool): Use bias for FF.
+        bias_attn (bool): Use bias for MHA.
+        causal (bool): Causal mask applied automatically.
+        past_context (int, optional): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        cross_attention (bool): If True, expect to get secondary input for cross-attention.
+        layer_scale (float, optional): If not None, LayerScale will be used
+            with the given value as initial scale.
+        positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
+        max_period (float): Maximum period of the time embedding.
+        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
+        xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
+        lr (float, optional): learning rate override through the `make_optim_group` API.
+        weight_decay (float, optional): Weight_decay override through the `make_optim_group` API.
+        layer_class: (subclass of `StreamingTransformerLayer): class to use
+            to initialize the layers, allowing further customization outside of AudioCraft.
+        checkpointing (str): Checkpointing strategy to reduce memory usage.
+            No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
+            if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
+            minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
+            a policy for opting-out some operations of the checkpointing like
+            linear layers and attention, providing a middle ground between speed and memory.
+        device (torch.device, optional): Device on which to initialize.
+        dtype (torch.dtype, optional): dtype to use.
+        **kwargs: See `nn.TransformerEncoderLayer`.
+    """
+    def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
+                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
+                 causal: bool = False, past_context: tp.Optional[int] = None,
+                 custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
+                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+                 positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
+                 xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
+                 layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
+                 checkpointing: str = 'none', device=None, dtype=None, **kwargs):
+        super().__init__()
+        assert d_model % num_heads == 0
+
+        self.positional_embedding = positional_embedding
+        self.max_period = max_period
+        self.positional_scale = positional_scale
+        self.weight_decay = weight_decay
+        self.lr = lr
+
+        assert positional_embedding in ['sin', 'rope', 'sin_rope']
+        self.rope: tp.Optional[RotaryEmbedding] = None
+        if self.positional_embedding in ['rope', 'sin_rope']:
+            assert _is_custom(custom, memory_efficient)
+            self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
+                                        xpos=xpos, scale=positional_scale, device=device)
+
+        self.checkpointing = checkpointing
+
+        assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
+        if self.checkpointing.startswith('xformers'):
+            _verify_xformers_internal_compat()
+
+        self.layers = nn.ModuleList()
+        for idx in range(num_layers):
+            self.layers.append(
+                layer_class(
+                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
+                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
+                    causal=causal, past_context=past_context, custom=custom,
+                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
+                    cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
+                    device=device, dtype=dtype, **kwargs))
+
+        if self.checkpointing != 'none':
+            for layer in self.layers:
+                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
+                # backward hook inside of FSDP...
+                layer._magma_checkpointed = True  # type: ignore
+                assert layer.layer_drop == 0., "Need further checking"  # type: ignore
+
+    def _apply_layer(self, layer, *args, **kwargs):
+        method = self.checkpointing
+        if method == 'none':
+            return layer(*args, **kwargs)
+        elif method == 'torch':
+            return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
+        elif method.startswith('xformers'):
+            from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
+            if method == 'xformers_default':
+                # those operations will be saved, and not recomputed.
+                # According to Francisco we can get smarter policies but this is a good start.
+                allow_list = [
+                    "xformers.efficient_attention_forward_cutlass.default",
+                    "xformers_flash.flash_fwd.default",
+                    "aten.addmm.default",
+                    "aten.mm.default",
+                ]
+            elif method == 'xformers_mm':
+                # those operations will be saved, and not recomputed.
+                # According to Francisco we can get smarter policies but this is a good start.
+                allow_list = [
+                    "aten.addmm.default",
+                    "aten.mm.default",
+                ]
+            else:
+                raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
+            policy_fn = _get_default_policy(allow_list)
+            return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
+        else:
+            raise ValueError(f"Checkpointing method {method} is unknown.")
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        B, T, C = x.shape
+
+        if 'offsets' in self._streaming_state:
+            offsets = self._streaming_state['offsets']
+        else:
+            offsets = torch.zeros(B, dtype=torch.long, device=x.device)
+
+        if self.positional_embedding in ['sin', 'sin_rope']:
+            positions = torch.arange(T, device=x.device).view(1, -1, 1)
+            positions = positions + offsets.view(-1, 1, 1)
+            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
+            x = x + self.positional_scale * pos_emb
+
+        for layer in self.layers:
+            x = self._apply_layer(layer, x, *args, **kwargs)
+
+        if self._is_streaming:
+            self._streaming_state['offsets'] = offsets + T
+
+        return x
+
+    def make_optim_group(self):
+        group = {"params": list(self.parameters())}
+        if self.lr is not None:
+            group["lr"] = self.lr
+        if self.weight_decay is not None:
+            group["weight_decay"] = self.weight_decay
+        return group
+
+
+# special attention related function
+
+def _verify_xformers_memory_efficient_compat():
+    try:
+        from xformers.ops import memory_efficient_attention, LowerTriangularMask  # noqa
+    except ImportError:
+        raise ImportError(
+            "xformers is not installed. Please install it and try again.\n"
+            "To install on AWS and Azure, run \n"
+            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
+            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
+            "To install on FAIR Cluster, run \n"
+            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
+            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
+
+
+def _verify_xformers_internal_compat():
+    try:
+        from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy  # noqa
+    except ImportError:
+        raise ImportError(
+            "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
+            "To install on AWS and Azure, run \n"
+            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
+            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
+            "To install on FAIR Cluster, run \n"
+            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
+            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
+
+
+def _is_custom(custom: bool, memory_efficient: bool):
+    return custom or memory_efficient
+
+
+
+
+
+
+
+

Functions

+
+
+def create_norm_fn(norm_type: str, dim: int, **kwargs) ‑> torch.nn.modules.module.Module +
+
+

Create normalization module for transformer encoder layer.

+

Args

+
+
norm_type : str
+
Normalization method.
+
dim : int
+
Dimension of the normalized layer.
+
**kwargs : dict
+
Additional parameters for normalization layer.
+
+

Returns

+
+
nn.Module
+
Normalization module.
+
+
+ +Expand source code + +
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
+    """Create normalization module for transformer encoder layer.
+
+    Args:
+        norm_type (str): Normalization method.
+        dim (int): Dimension of the normalized layer.
+        **kwargs (dict): Additional parameters for normalization layer.
+    Returns:
+        nn.Module: Normalization module.
+    """
+    if norm_type == 'layer_norm':
+        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
+    else:
+        raise ValueError(f"Unknown norm type: {norm_type}")
+
+
+
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000, dtype: torch.dtype = torch.float32) ‑> torch.Tensor +
+
+

Create sinusoidal positional embedding, with shape [B, T, C].

+

Args

+
+
positions : torch.Tensor
+
LongTensor of positions.
+
dim : int
+
Dimension of the embedding.
+
max_period : float
+
Maximum period of the cosine/sine functions.
+
dtype : torch.dtype or str
+
dtype to use to generate the embedding.
+
+

Returns

+
+
torch.Tensor
+
Sinusoidal positional embedding.
+
+
+ +Expand source code + +
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
+                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    """Create sinusoidal positional embedding, with shape `[B, T, C]`.
+
+    Args:
+        positions (torch.Tensor): LongTensor of positions.
+        dim (int): Dimension of the embedding.
+        max_period (float): Maximum period of the cosine/sine functions.
+        dtype (torch.dtype or str): dtype to use to generate the embedding.
+    Returns:
+        torch.Tensor: Sinusoidal positional embedding.
+    """
+    # We aim for BTC format
+    assert dim % 2 == 0
+    half_dim = dim // 2
+    positions = positions.to(dtype)
+    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
+    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
+    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
+    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
+
+
+
+def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) ‑> torch.Tensor +
+
+

torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers.

+
+ +Expand source code + +
def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
+    if n_rep == 1:
+        return x
+    if _efficient_attention_backend == 'torch' and memory_efficient:
+        bs, n_kv_heads, slen, head_dim = x.shape
+        return (
+            x[:, :, None, :, :]
+            .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+            .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+        )
+    else:
+        bs, slen, n_kv_heads, head_dim = x.shape
+        return (
+            x[:, :, :, None, :]
+            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+        )
+
+
+
+def set_efficient_attention_backend(backend: str = 'torch') +
+
+
+
+ +Expand source code + +
def set_efficient_attention_backend(backend: str = 'torch'):
+    # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
+    global _efficient_attention_backend
+    assert _efficient_attention_backend in ['xformers', 'torch']
+    _efficient_attention_backend = backend
+
+
+
+
+
+

Classes

+
+
+class LayerScale +(channels: int, init: float = 0.0001, channel_last: bool = True, device=None, dtype=None) +
+
+

Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). +This rescales diagonally the residual outputs close to 0, with a learnt scale.

+

Args

+
+
channels : int
+
Number of channels.
+
init : float
+
Initial scale.
+
channel_last : bool
+
If True, expect [*, C] shaped tensors, otherwise, [*, C, T].
+
device : torch.device or str, optional
+
Device on which to initialize the module.
+
dtype : torch.dtype, optional
+
dtype to use to initialize the module.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class LayerScale(nn.Module):
+    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
+    This rescales diagonally the residual outputs close to 0, with a learnt scale.
+
+    Args:
+        channels (int): Number of channels.
+        init (float): Initial scale.
+        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
+        device (torch.device or str, optional): Device on which to initialize the module.
+        dtype (torch.dtype, optional): dtype to use to initialize the module.
+    """
+    def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
+                 device=None, dtype=None):
+        super().__init__()
+        self.channel_last = channel_last
+        self.scale = nn.Parameter(
+            torch.full((channels,), init,
+                       requires_grad=True, device=device, dtype=dtype))
+
+    def forward(self, x: torch.Tensor):
+        if self.channel_last:
+            return self.scale * x
+        else:
+            return self.scale[:, None] * x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x: torch.Tensor):
+    if self.channel_last:
+        return self.scale * x
+    else:
+        return self.scale[:, None] * x
+
+
+
+
+
+class StreamingMultiheadAttention +(embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, causal: bool = False, past_context: Optional[int] = None, custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False, rope: Optional[RotaryEmbedding] = None, cross_attention: bool = False, safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1, device=None, dtype=None) +
+
+

Similar to nn.MultiheadAttention but with support for streaming, causal evaluation.

+

Args

+
+
embed_dim : int
+
Dimension to project to.
+
num_heads : int
+
Number of heads.
+
dropout : float
+
Dropout level.
+
bias : bool
+
Use bias in projections.
+
causal : bool
+
Causal mask applied automatically.
+
past_context : int, optional
+
Receptive field for the causal mask, infinite if None.
+
custom : bool
+
Use custom MHA implementation, for testing / benchmarking.
+
memory_efficient : bool
+
Use xformers based memory efficient attention.
+
attention_as_float32 : bool
+
Perform the attention as float32 +(especially important with memory_efficient as autocast won't do this automatically).
+
rope (RotaryEmbedding, optional): Rope embedding to use.
+
cross_attention
+
Should be true when used as a cross attention. +All keys and values must be available at once, streaming is only for the queries. +Cannot be used with causal or rope (as it wouldn't make sens to +interpret the time steps in the keys relative to those in the queries).
+
safe_streaming : bool
+
Bug fix, will go away with xformers update.
+
qk_layer_norm : bool
+
Layer normalization applied to queries and keys before dot product.
+
kv_repeat : int
+
If > 1, will repeat keys and queries multiple times (need to divide num_heads). +This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+
device : torch.device, optional
+
Device on which to initialize.
+
dtype : torch.dtype, optional
+
dtype to use.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamingMultiheadAttention(StreamingModule):
+    """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
+
+    Args:
+        embed_dim (int): Dimension to project to.
+        num_heads (int): Number of heads.
+        dropout (float): Dropout level.
+        bias (bool): Use bias in projections.
+        causal (bool): Causal mask applied automatically.
+        past_context (int, optional): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        rope (`RotaryEmbedding`, optional): Rope embedding to use.
+        cross_attention: Should be true when used as a cross attention.
+            All keys and values must be available at once, streaming is only for the queries.
+            Cannot be used with `causal` or `rope` (as it wouldn't make sens to
+            interpret the time steps in the keys relative to those in the queries).
+        safe_streaming (bool): Bug fix, will go away with xformers update.
+        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
+        kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
+            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+        device (torch.device, optional): Device on which to initialize.
+        dtype (torch.dtype, optional): dtype to use.
+    """
+    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
+                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
+                 memory_efficient: bool = False, attention_as_float32: bool = False,
+                 rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
+                 safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
+                 device=None, dtype=None):
+        super().__init__()
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        if past_context is not None:
+            assert causal
+
+        self.embed_dim = embed_dim
+        self.causal = causal
+        self.past_context = past_context
+        self.memory_efficient = memory_efficient
+        self.attention_as_float32 = attention_as_float32
+        self.rope = rope
+        self.cross_attention = cross_attention
+        self.safe_streaming = safe_streaming
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.kv_repeat = kv_repeat
+        if cross_attention:
+            assert not causal, "Causal cannot work with cross attention."
+            assert rope is None, "Rope cannot work with cross attention."
+
+        if memory_efficient:
+            _verify_xformers_memory_efficient_compat()
+
+        self.custom = _is_custom(custom, memory_efficient)
+        if self.custom:
+            out_dim = embed_dim
+            assert num_heads % kv_repeat == 0
+            assert not cross_attention or kv_repeat == 1
+            num_kv = num_heads // kv_repeat
+            kv_dim = (embed_dim // num_heads) * num_kv
+            out_dim += 2 * kv_dim
+            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
+            # We try to follow the default PyTorch MHA convention, to easily compare results.
+            self.in_proj_weight = in_proj.weight
+            self.in_proj_bias = in_proj.bias
+            if bias:
+                self.in_proj_bias.data.zero_()  # Following Pytorch convention
+            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+            if bias:
+                self.out_proj.bias.data.zero_()
+        else:
+            assert not qk_layer_norm
+            assert kv_repeat == 1
+            self.mha = nn.MultiheadAttention(
+                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
+                **factory_kwargs)
+        self.qk_layer_norm = qk_layer_norm
+        if qk_layer_norm:
+            assert self.custom
+            assert kv_repeat == 1
+            ln_dim = embed_dim
+            self.q_layer_norm = nn.LayerNorm(ln_dim)
+            self.k_layer_norm = nn.LayerNorm(ln_dim)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        if not self.custom:
+            # Support compat with regular MHA
+            keys = [n for n, _ in self.mha.named_parameters()]
+            for key in keys:
+                if prefix + key in state_dict:
+                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+    def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
+        # Return a causal mask, accounting for potentially stored past keys/values
+        # We actually return a bias for the attention score, as this has the same
+        # convention both in the builtin MHA in Pytorch, and Xformers functions.
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
+        if self.memory_efficient:
+            from xformers.ops import LowerTriangularMask
+            if current_steps == 1:
+                # If we only have one step, then we do not need a mask.
+                return None
+            elif 'past_keys' in self._streaming_state:
+                raise RuntimeError("Not supported at the moment")
+            else:
+                # Then we can safely use a lower triangular mask
+                return LowerTriangularMask()
+        if self._streaming_state:
+            past_keys = self._streaming_state['past_keys']
+            past_steps = past_keys.shape[time_dim]
+        else:
+            past_steps = 0
+
+        queries_pos = torch.arange(
+            past_steps, current_steps + past_steps, device=device).view(-1, 1)
+        keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
+        delta = queries_pos - keys_pos
+        valid = delta >= 0
+        if self.past_context is not None:
+            valid &= (delta <= self.past_context)
+        return torch.where(
+            valid,
+            torch.zeros([], device=device, dtype=dtype),
+            torch.full([], float('-inf'), device=device, dtype=dtype))
+
+    def _complete_kv(self, k, v):
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
+        if self.cross_attention:
+            # With cross attention we assume all keys and values
+            # are already available, and streaming is with respect
+            # to the queries only.
+            return k, v
+        # Complete the key/value pair using the streaming state.
+        if self._streaming_state:
+            pk = self._streaming_state['past_keys']
+            nk = torch.cat([pk, k], dim=time_dim)
+            if v is k:
+                nv = nk
+            else:
+                pv = self._streaming_state['past_values']
+                nv = torch.cat([pv, v], dim=time_dim)
+        else:
+            nk = k
+            nv = v
+
+        assert nk.shape[time_dim] == nv.shape[time_dim]
+        offset = 0
+        if self.past_context is not None:
+            offset = max(0, nk.shape[time_dim] - self.past_context)
+        if self._is_streaming:
+            self._streaming_state['past_keys'] = nk[:, offset:]
+            if v is not k:
+                self._streaming_state['past_values'] = nv[:, offset:]
+            if 'offset' in self._streaming_state:
+                self._streaming_state['offset'] += offset
+            else:
+                self._streaming_state['offset'] = torch.tensor(0)
+        return nk, nv
+
+    def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
+        # Apply rope embeddings to query and key tensors.
+        assert self.rope is not None
+        if 'past_keys' in self._streaming_state:
+            past_keys_offset = self._streaming_state['past_keys'].shape[1]
+        else:
+            past_keys_offset = 0
+        if 'offset' in self._streaming_state:
+            past_context_offset = int(self._streaming_state['offset'].item())
+        else:
+            past_context_offset = 0
+        streaming_offset = past_context_offset + past_keys_offset
+        return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
+                key_padding_mask=None, need_weights=False, attn_mask=None,
+                average_attn_weights=True, is_causal=False):
+        assert attn_mask is None
+        assert not is_causal, ("New param added in torch 2.0.1 not supported, "
+                               "use the causal args in the constructor.")
+
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
+        if time_dim == 2:
+            layout = "b h t d"
+        else:
+            layout = "b t h d"
+        dtype = query.dtype
+        if self._is_streaming:
+            assert self.causal or self.cross_attention, \
+                "Streaming only available for causal or cross attention"
+
+        if self.causal:
+            # At the moment we specialize only for the self-attention case.
+            assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+            assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+            attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
+
+        if self.custom:
+            # custom implementation
+            assert need_weights is False
+            assert key_padding_mask is None
+            if self.cross_attention:
+                # Different queries, keys, values, we have to spit manually the weights
+                # before applying the linear.
+                dim = self.in_proj_weight.shape[0] // 3
+                if self.in_proj_bias is None:
+                    bias_q, bias_k, bias_v = None, None, None
+                else:
+                    bias_q = self.in_proj_bias[:dim]
+                    bias_k = self.in_proj_bias[dim: 2 * dim]
+                    bias_v = self.in_proj_bias[2 * dim:]
+                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
+                # todo: when streaming, we could actually save k, v and check the shape actually match.
+                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
+                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
+                if self.qk_layer_norm is True:
+                    q = self.q_layer_norm(q)
+                    k = self.k_layer_norm(k)
+                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
+            else:
+                if not _is_profiled():
+                    # profiling breaks that propertysomehow.
+                    assert query is key, "specialized implementation"
+                    assert value is key, "specialized implementation"
+                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
+                if self.kv_repeat == 1:
+                    if time_dim == 2:
+                        bound_layout = "b h p t d"
+                    else:
+                        bound_layout = "b t p h d"
+                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
+                    q, k, v = ops.unbind(packed, dim=2)
+                else:
+                    embed_dim = self.embed_dim
+                    per_head_dim = (embed_dim // self.num_heads)
+                    kv_heads = self.num_heads // self.kv_repeat
+                    q = projected[:, :, :embed_dim]
+                    start = embed_dim
+                    end = start + per_head_dim * kv_heads
+                    k = projected[:, :, start: end]
+                    v = projected[:, :, end:]
+                    q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
+                    k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
+                    v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
+
+                if self.qk_layer_norm is True:
+                    assert self.kv_repeat == 1
+                    q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
+                    q = self.q_layer_norm(q)
+                    k = self.k_layer_norm(k)
+                    q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
+                if self.rope:
+                    q, k = self._apply_rope(q, k)
+                k, v = self._complete_kv(k, v)
+                if self.kv_repeat > 1:
+                    k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
+                    v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
+            if self.attention_as_float32:
+                q, k, v = [x.float() for x in [q, k, v]]
+            if self.memory_efficient:
+                p = self.dropout if self.training else 0
+                if _efficient_attention_backend == 'torch':
+                    x = torch.nn.functional.scaled_dot_product_attention(
+                        q, k, v, is_causal=attn_mask is not None, dropout_p=p)
+                else:
+                    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
+            else:
+                # We include the dot product as float32, for consistency
+                # with the other implementations that include that step
+                # as part of the attention. Note that when using `autocast`,
+                # the einsums would be done as bfloat16, but the softmax
+                # would be done as bfloat16, so `attention_as_float32` will
+                # extend a bit the range of operations done in float32,
+                # although this should make no difference.
+                q = q / q.shape[-1] ** 0.5
+                key_layout = layout.replace('t', 'k')
+                query_layout = layout
+                if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
+                    with torch.autocast(device_type=q.device.type, dtype=torch.float32):
+                        pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+                else:
+                    pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+                if attn_mask is not None:
+                    pre_w = pre_w + attn_mask
+                w = torch.softmax(pre_w, dim=-1)
+                w = F.dropout(w, self.dropout, training=self.training).to(v)
+                # Key and value have the same format.
+                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
+            x = x.to(dtype)
+            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
+            x = self.out_proj(x)
+        else:
+            key, value = self._complete_kv(key, value)
+            if self.attention_as_float32:
+                query, key, value = [x.float() for x in [query, key, value]]
+            x, _ = self.mha(
+                query, key, value, key_padding_mask,
+                need_weights, attn_mask, average_attn_weights)
+            x = x.to(dtype)
+
+        return x, None
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class StreamingTransformer +(d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048, dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True, causal: bool = False, past_context: Optional[int] = None, custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False, cross_attention: bool = False, layer_scale: Optional[float] = None, positional_embedding: str = 'sin', max_period: float = 10000, positional_scale: float = 1.0, xpos: bool = False, lr: Optional[float] = None, weight_decay: Optional[float] = None, layer_class: Type[StreamingTransformerLayer] = audiocraft.modules.transformer.StreamingTransformerLayer, checkpointing: str = 'none', device=None, dtype=None, **kwargs) +
+
+

Transformer with Streaming / Causal support.

+

Args

+
+
d_model : int
+
Dimension of the data.
+
num_heads : int
+
Number of heads.
+
dim_feedforward : int
+
Intermediate dimension of FF module.
+
dropout : float
+
Dropout both for MHA and FF.
+
bias_ff : bool
+
Use bias for FF.
+
bias_attn : bool
+
Use bias for MHA.
+
causal : bool
+
Causal mask applied automatically.
+
past_context : int, optional
+
Receptive field for the causal mask, infinite if None.
+
custom : bool
+
Use custom MHA implementation, for testing / benchmarking.
+
memory_efficient : bool
+
Use xformers based memory efficient attention.
+
attention_as_float32 : bool
+
Perform the attention as float32 +(especially important with memory_efficient as autocast won't do this automatically).
+
cross_attention : bool
+
If True, expect to get secondary input for cross-attention.
+
layer_scale : float, optional
+
If not None, LayerScale will be used +with the given value as initial scale.
+
positional_embedding : str
+
Positional embedding strategy (sin, rope, or sin_rope).
+
max_period : float
+
Maximum period of the time embedding.
+
positional_scale : float
+
Scale of positional embedding, set to 0 to deactivate.
+
xpos : bool
+
Apply xpos exponential decay to positional embedding (rope only).
+
lr : float, optional
+
learning rate override through the make_optim_group API.
+
weight_decay : float, optional
+
Weight_decay override through the make_optim_group API.
+
layer_class
+
(subclass of `StreamingTransformerLayer): class to use +to initialize the layers, allowing further customization outside of AudioCraft.
+
checkpointing : str
+
Checkpointing strategy to reduce memory usage. +No checkpointing if set to 'none'. Per layer checkpointing using PyTorch +if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice, +minimal memory usage, but maximal runtime). Finally, xformers_default provide +a policy for opting-out some operations of the checkpointing like +linear layers and attention, providing a middle ground between speed and memory.
+
device : torch.device, optional
+
Device on which to initialize.
+
dtype : torch.dtype, optional
+
dtype to use.
+
**kwargs
+
See nn.TransformerEncoderLayer.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamingTransformer(StreamingModule):
+    """Transformer with Streaming / Causal support.
+
+    Args:
+        d_model (int): Dimension of the data.
+        num_heads (int): Number of heads.
+        dim_feedforward (int): Intermediate dimension of FF module.
+        dropout (float): Dropout both for MHA and FF.
+        bias_ff (bool): Use bias for FF.
+        bias_attn (bool): Use bias for MHA.
+        causal (bool): Causal mask applied automatically.
+        past_context (int, optional): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        cross_attention (bool): If True, expect to get secondary input for cross-attention.
+        layer_scale (float, optional): If not None, LayerScale will be used
+            with the given value as initial scale.
+        positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
+        max_period (float): Maximum period of the time embedding.
+        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
+        xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
+        lr (float, optional): learning rate override through the `make_optim_group` API.
+        weight_decay (float, optional): Weight_decay override through the `make_optim_group` API.
+        layer_class: (subclass of `StreamingTransformerLayer): class to use
+            to initialize the layers, allowing further customization outside of AudioCraft.
+        checkpointing (str): Checkpointing strategy to reduce memory usage.
+            No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
+            if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
+            minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
+            a policy for opting-out some operations of the checkpointing like
+            linear layers and attention, providing a middle ground between speed and memory.
+        device (torch.device, optional): Device on which to initialize.
+        dtype (torch.dtype, optional): dtype to use.
+        **kwargs: See `nn.TransformerEncoderLayer`.
+    """
+    def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
+                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
+                 causal: bool = False, past_context: tp.Optional[int] = None,
+                 custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
+                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+                 positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
+                 xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
+                 layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
+                 checkpointing: str = 'none', device=None, dtype=None, **kwargs):
+        super().__init__()
+        assert d_model % num_heads == 0
+
+        self.positional_embedding = positional_embedding
+        self.max_period = max_period
+        self.positional_scale = positional_scale
+        self.weight_decay = weight_decay
+        self.lr = lr
+
+        assert positional_embedding in ['sin', 'rope', 'sin_rope']
+        self.rope: tp.Optional[RotaryEmbedding] = None
+        if self.positional_embedding in ['rope', 'sin_rope']:
+            assert _is_custom(custom, memory_efficient)
+            self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
+                                        xpos=xpos, scale=positional_scale, device=device)
+
+        self.checkpointing = checkpointing
+
+        assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
+        if self.checkpointing.startswith('xformers'):
+            _verify_xformers_internal_compat()
+
+        self.layers = nn.ModuleList()
+        for idx in range(num_layers):
+            self.layers.append(
+                layer_class(
+                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
+                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
+                    causal=causal, past_context=past_context, custom=custom,
+                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
+                    cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
+                    device=device, dtype=dtype, **kwargs))
+
+        if self.checkpointing != 'none':
+            for layer in self.layers:
+                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
+                # backward hook inside of FSDP...
+                layer._magma_checkpointed = True  # type: ignore
+                assert layer.layer_drop == 0., "Need further checking"  # type: ignore
+
+    def _apply_layer(self, layer, *args, **kwargs):
+        method = self.checkpointing
+        if method == 'none':
+            return layer(*args, **kwargs)
+        elif method == 'torch':
+            return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
+        elif method.startswith('xformers'):
+            from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
+            if method == 'xformers_default':
+                # those operations will be saved, and not recomputed.
+                # According to Francisco we can get smarter policies but this is a good start.
+                allow_list = [
+                    "xformers.efficient_attention_forward_cutlass.default",
+                    "xformers_flash.flash_fwd.default",
+                    "aten.addmm.default",
+                    "aten.mm.default",
+                ]
+            elif method == 'xformers_mm':
+                # those operations will be saved, and not recomputed.
+                # According to Francisco we can get smarter policies but this is a good start.
+                allow_list = [
+                    "aten.addmm.default",
+                    "aten.mm.default",
+                ]
+            else:
+                raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
+            policy_fn = _get_default_policy(allow_list)
+            return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
+        else:
+            raise ValueError(f"Checkpointing method {method} is unknown.")
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        B, T, C = x.shape
+
+        if 'offsets' in self._streaming_state:
+            offsets = self._streaming_state['offsets']
+        else:
+            offsets = torch.zeros(B, dtype=torch.long, device=x.device)
+
+        if self.positional_embedding in ['sin', 'sin_rope']:
+            positions = torch.arange(T, device=x.device).view(1, -1, 1)
+            positions = positions + offsets.view(-1, 1, 1)
+            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
+            x = x + self.positional_scale * pos_emb
+
+        for layer in self.layers:
+            x = self._apply_layer(layer, x, *args, **kwargs)
+
+        if self._is_streaming:
+            self._streaming_state['offsets'] = offsets + T
+
+        return x
+
+    def make_optim_group(self):
+        group = {"params": list(self.parameters())}
+        if self.lr is not None:
+            group["lr"] = self.lr
+        if self.weight_decay is not None:
+            group["weight_decay"] = self.weight_decay
+        return group
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def make_optim_group(self) +
+
+
+
+ +Expand source code + +
def make_optim_group(self):
+    group = {"params": list(self.parameters())}
+    if self.lr is not None:
+        group["lr"] = self.lr
+    if self.weight_decay is not None:
+        group["weight_decay"] = self.weight_decay
+    return group
+
+
+
+

Inherited members

+ +
+
+class StreamingTransformerLayer +(d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True, causal: bool = False, past_context: Optional[int] = None, custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False, qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False, cross_attention: bool = False, layer_scale: Optional[float] = None, rope: Optional[RotaryEmbedding] = None, attention_dropout: Optional[float] = None, kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs) +
+
+

TransformerLayer with Streaming / Causal support. +This also integrates cross_attention, when passing cross_attention=True, +rather than having two separate classes like in PyTorch.

+

Args

+
+
d_model : int
+
Dimension of the data.
+
num_heads : int
+
Number of heads.
+
dim_feedforward : int
+
Intermediate dimension of FF module.
+
dropout : float
+
Dropout both for MHA and FF.
+
bias_ff : bool
+
Use bias for FF.
+
bias_attn : bool
+
Use bias for MHA.
+
causal : bool
+
Causal mask applied automatically.
+
past_context : int, optional
+
Receptive field for the causal mask, infinite if None.
+
custom : bool
+
Use custom MHA implementation, for testing / benchmarking.
+
memory_efficient : bool
+
Use xformers based memory efficient attention.
+
attention_as_float32 : bool
+
Perform the attention as float32 +(especially important with memory_efficient as autocast won't do this automatically).
+
qk_layer_norm : bool
+
Layer normalization applied to queries and keys before dot product in attention.
+
qk_layer_norm_cross : bool
+
Same for the cross attention.
+
cross_attention : bool
+
If True, expect to get secondary input for cross-attention. +Cross attention will use the default MHA, as it typically won't require +special treatment.
+
layer_scale : float, optional
+
If not None, LayerScale will be used with +the given value as initial scale.
+
rope (RotaryEmbedding, optional): Rope embedding to use.
+
attention_dropout : float, optional
+
If not None, separate the value of the dimension dropout +in FFN and of the attention dropout.
+
kv_repeat : int
+
If > 1, will repeat keys and queries multiple times (need to divide num_heads). +This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+
device : torch.device, optional
+
Device on which to initialize.
+
dtype : torch.dtype, optional
+
dtype to use.
+
**kwargs
+
See nn.TransformerEncoderLayer.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamingTransformerLayer(nn.TransformerEncoderLayer):
+    """TransformerLayer with Streaming / Causal support.
+    This also integrates cross_attention, when passing `cross_attention=True`,
+    rather than having two separate classes like in PyTorch.
+
+    Args:
+        d_model (int): Dimension of the data.
+        num_heads (int): Number of heads.
+        dim_feedforward (int): Intermediate dimension of FF module.
+        dropout (float): Dropout both for MHA and FF.
+        bias_ff (bool): Use bias for FF.
+        bias_attn (bool): Use bias for MHA.
+        causal (bool): Causal mask applied automatically.
+        past_context (int, optional): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
+        qk_layer_norm_cross (bool): Same for the cross attention.
+        cross_attention (bool): If True, expect to get secondary input for cross-attention.
+            Cross attention will use the default MHA, as it typically won't require
+            special treatment.
+        layer_scale (float, optional): If not None, LayerScale will be used with
+            the given value as initial scale.
+        rope (`RotaryEmbedding`, optional): Rope embedding to use.
+        attention_dropout (float, optional): If not None, separate the value of the dimension dropout
+            in FFN and of the attention dropout.
+        kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
+            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+        device (torch.device, optional): Device on which to initialize.
+        dtype (torch.dtype, optional): dtype to use.
+        **kwargs: See `nn.TransformerEncoderLayer`.
+    """
+    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
+                 bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
+                 past_context: tp.Optional[int] = None, custom: bool = False,
+                 memory_efficient: bool = False, attention_as_float32: bool = False,
+                 qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
+                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+                 rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
+                 kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
+        super().__init__(d_model, num_heads, dim_feedforward, dropout,
+                         device=device, dtype=dtype, batch_first=True, **kwargs)
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        # Redefine self_attn to our streaming multi-head attention
+        attn_kwargs: tp.Dict[str, tp.Any] = {
+            'embed_dim': d_model,
+            'num_heads': num_heads,
+            'dropout': dropout if attention_dropout is None else attention_dropout,
+            'bias': bias_attn,
+            'custom': custom,
+            'memory_efficient': memory_efficient,
+            'attention_as_float32': attention_as_float32,
+        }
+        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
+            causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
+            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
+        # Redefine feedforward layers to expose bias parameter
+        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
+        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
+
+        self.layer_scale_1: nn.Module
+        self.layer_scale_2: nn.Module
+        if layer_scale is None:
+            self.layer_scale_1 = nn.Identity()
+            self.layer_scale_2 = nn.Identity()
+        else:
+            self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
+            self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
+
+        self.cross_attention: tp.Optional[nn.Module] = None
+        if cross_attention:
+            self.cross_attention = StreamingMultiheadAttention(
+                cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
+                **attn_kwargs, **factory_kwargs)
+            # Norm and dropout
+            self.dropout_cross = nn.Dropout(dropout)
+            # eps value matching that used in PyTorch reference implementation.
+            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
+            self.layer_scale_cross: nn.Module
+            if layer_scale is None:
+                self.layer_scale_cross = nn.Identity()
+            else:
+                self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
+        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
+        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
+
+    def _cross_attention_block(self, src: torch.Tensor,
+                               cross_attention_src: torch.Tensor) -> torch.Tensor:
+        assert self.cross_attention is not None
+        # queries are from src, keys and values from cross_attention_src.
+        x = self.cross_attention(
+            src, cross_attention_src, cross_attention_src, need_weights=False)[0]
+        return self.dropout_cross(x)  # type: ignore
+
+    def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
+                src_key_padding_mask: tp.Optional[torch.Tensor] = None,
+                cross_attention_src: tp.Optional[torch.Tensor] = None):
+        if self.cross_attention is None:
+            assert cross_attention_src is None
+        else:
+            assert cross_attention_src is not None
+        x = src
+        if self.norm_first:
+            x = x + self.layer_scale_1(
+                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
+            if cross_attention_src is not None:
+                x = x + self.layer_scale_cross(
+                    self._cross_attention_block(
+                        self.norm_cross(x), cross_attention_src))
+            x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
+        else:
+            x = self.norm1(x + self.layer_scale_1(
+                self._sa_block(x, src_mask, src_key_padding_mask)))
+            if cross_attention_src is not None:
+                x = self.norm_cross(
+                    x + self.layer_scale_cross(
+                        self._cross_attention_block(src, cross_attention_src)))
+            x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.transformer.TransformerEncoderLayer
  • +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, cross_attention_src: Optional[torch.Tensor] = None) ‑> Callable[..., Any] +
+
+

Pass the input through the encoder layer.

+

Args

+
+
src
+
the sequence to the encoder layer (required).
+
src_mask
+
the mask for the src sequence (optional).
+
src_key_padding_mask
+
the mask for the src keys per batch (optional).
+
is_causal
+
If specified, applies a causal mask as src mask. +Default: False. +Warning: +is_causal provides a hint that src_mask is the +causal mask. Providing incorrect hints can result in +incorrect execution, including forward and backward +compatibility.
+
+

Shape

+

see the docs in Transformer class.

+
+ +Expand source code + +
def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
+            src_key_padding_mask: tp.Optional[torch.Tensor] = None,
+            cross_attention_src: tp.Optional[torch.Tensor] = None):
+    if self.cross_attention is None:
+        assert cross_attention_src is None
+    else:
+        assert cross_attention_src is not None
+    x = src
+    if self.norm_first:
+        x = x + self.layer_scale_1(
+            self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
+        if cross_attention_src is not None:
+            x = x + self.layer_scale_cross(
+                self._cross_attention_block(
+                    self.norm_cross(x), cross_attention_src))
+        x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
+    else:
+        x = self.norm1(x + self.layer_scale_1(
+            self._sa_block(x, src_mask, src_key_padding_mask)))
+        if cross_attention_src is not None:
+            x = self.norm_cross(
+                x + self.layer_scale_cross(
+                    self._cross_attention_block(src, cross_attention_src)))
+        x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
+    return x
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/optim/cosine_lr_scheduler.html b/api_docs/audiocraft/optim/cosine_lr_scheduler.html new file mode 100644 index 00000000..5a35dae1 --- /dev/null +++ b/api_docs/audiocraft/optim/cosine_lr_scheduler.html @@ -0,0 +1,201 @@ + + + + + + +audiocraft.optim.cosine_lr_scheduler API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.optim.cosine_lr_scheduler

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class CosineLRScheduler(_LRScheduler):
+    """Cosine LR scheduler.
+
+    Args:
+        optimizer (Optimizer): Torch optimizer.
+        warmup_steps (int): Number of warmup steps.
+        total_steps (int): Total number of steps.
+        lr_min_ratio (float): Minimum learning rate.
+        cycle_length (float): Cycle length.
+    """
+    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
+                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
+        self.warmup_steps = warmup_steps
+        assert self.warmup_steps >= 0
+        self.total_steps = total_steps
+        assert self.total_steps >= 0
+        self.lr_min_ratio = lr_min_ratio
+        self.cycle_length = cycle_length
+        super().__init__(optimizer)
+
+    def _get_sched_lr(self, lr: float, step: int):
+        if step < self.warmup_steps:
+            lr_ratio = step / self.warmup_steps
+            lr = lr_ratio * lr
+        elif step <= self.total_steps:
+            s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
+            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
+                (1. + math.cos(math.pi * s / self.cycle_length))
+            lr = lr_ratio * lr
+        else:
+            lr_ratio = self.lr_min_ratio
+            lr = lr_ratio * lr
+        return lr
+
+    def get_lr(self):
+        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class CosineLRScheduler +(optimizer: torch.optim.optimizer.Optimizer, total_steps: int, warmup_steps: int, lr_min_ratio: float = 0.0, cycle_length: float = 1.0) +
+
+

Cosine LR scheduler.

+

Args

+
+
optimizer : Optimizer
+
Torch optimizer.
+
warmup_steps : int
+
Number of warmup steps.
+
total_steps : int
+
Total number of steps.
+
lr_min_ratio : float
+
Minimum learning rate.
+
cycle_length : float
+
Cycle length.
+
+
+ +Expand source code + +
class CosineLRScheduler(_LRScheduler):
+    """Cosine LR scheduler.
+
+    Args:
+        optimizer (Optimizer): Torch optimizer.
+        warmup_steps (int): Number of warmup steps.
+        total_steps (int): Total number of steps.
+        lr_min_ratio (float): Minimum learning rate.
+        cycle_length (float): Cycle length.
+    """
+    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
+                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
+        self.warmup_steps = warmup_steps
+        assert self.warmup_steps >= 0
+        self.total_steps = total_steps
+        assert self.total_steps >= 0
+        self.lr_min_ratio = lr_min_ratio
+        self.cycle_length = cycle_length
+        super().__init__(optimizer)
+
+    def _get_sched_lr(self, lr: float, step: int):
+        if step < self.warmup_steps:
+            lr_ratio = step / self.warmup_steps
+            lr = lr_ratio * lr
+        elif step <= self.total_steps:
+            s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
+            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
+                (1. + math.cos(math.pi * s / self.cycle_length))
+            lr = lr_ratio * lr
+        else:
+            lr_ratio = self.lr_min_ratio
+            lr = lr_ratio * lr
+        return lr
+
+    def get_lr(self):
+        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
+
+

Ancestors

+
    +
  • torch.optim.lr_scheduler._LRScheduler
  • +
  • torch.optim.lr_scheduler.LRScheduler
  • +
+

Methods

+
+
+def get_lr(self) +
+
+
+
+ +Expand source code + +
def get_lr(self):
+    return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/optim/dadam.html b/api_docs/audiocraft/optim/dadam.html new file mode 100644 index 00000000..72b5c5b6 --- /dev/null +++ b/api_docs/audiocraft/optim/dadam.html @@ -0,0 +1,823 @@ + + + + + + +audiocraft.optim.dadam API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.optim.dadam

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Any
+
+import torch
+import torch.optim
+import torch.distributed as dist
+
+
+logger = logging.getLogger(__name__)
+_params_t = Any
+
+
+def to_real(x):
+    if torch.is_complex(x):
+        return x.real
+    else:
+        return x
+
+
+class DAdaptAdam(torch.optim.Optimizer):
+    """Adam with D-Adaptation automatic step-sizes.
+    Leave LR set to 1 unless you encounter instability.
+
+    Args:
+        params (iterable):
+            Iterable of parameters to optimize or dicts defining parameter groups.
+        lr (float):
+            Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
+        betas (tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        momentum (float):
+            Momentum value in  the range [0,1) (default: 0.9).
+        eps (float):
+            Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8).
+        weight_decay (float):
+            Weight decay, i.e. a L2 penalty (default: 0).
+        log_every (int):
+            Log using print every k steps, default 0 (no logging).
+        decouple (boolean):
+            Use AdamW style decoupled weight decay
+        d0 (float):
+            Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
+        growth_rate (float):
+            prevent the D estimate from growing faster than this multiplicative rate.
+            Default is inf, for unrestricted. Values like 1.02 give a kind of learning
+            rate warmup effect.
+        fsdp_in_use (bool):
+            If you're using sharded parameters, this should be set to True. The optimizer
+            will attempt to auto-detect this, but if you're using an implementation other
+            than PyTorch's builtin version, the auto-detection won't work.
+    """
+    def __init__(self, params, lr=1.0,
+                 betas=(0.9, 0.999),
+                 eps=1e-8,
+                 weight_decay=0,
+                 log_every=0,
+                 decouple=True,
+                 d0=1e-6,
+                 growth_rate=float('inf')):
+        if not 0.0 < d0:
+            raise ValueError("Invalid d0 value: {}".format(d0))
+        if not 0.0 < lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 < eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        if decouple:
+            logger.info("Using decoupled weight decay")
+
+        from .fsdp import is_fsdp_used
+        fsdp_in_use = is_fsdp_used()
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay,
+                        d=d0,
+                        k=0,
+                        gsq_weighted=0.0,
+                        log_every=log_every,
+                        decouple=decouple,
+                        growth_rate=growth_rate,
+                        fsdp_in_use=fsdp_in_use)
+
+        super().__init__(params, defaults)
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return False
+
+    @property
+    def supports_flat_params(self):
+        return True
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        g_sq = 0.0
+        sksq_weighted = 0.0
+        sk_l1 = 0.0
+
+        lr = max(group['lr'] for group in self.param_groups)
+
+        group = self.param_groups[0]
+        gsq_weighted = group['gsq_weighted']
+        d = group['d']
+        dlr = d*lr
+
+        growth_rate = group['growth_rate']
+        decouple = group['decouple']
+        fsdp_in_use = group['fsdp_in_use']
+        log_every = group['log_every']
+
+        beta1, beta2 = group['betas']
+
+        for group in self.param_groups:
+            group_lr = group['lr']
+            decay = group['weight_decay']
+            k = group['k']
+            eps = group['eps']
+
+            if group_lr not in [lr, 0.0]:
+                raise RuntimeError("Setting different lr values in different parameter "
+                                   "groups is only supported for values of 0")
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                if hasattr(p, "_fsdp_flattened"):
+                    fsdp_in_use = True
+                grad = p.grad.data
+
+                # Apply weight decay (coupled variant)
+                if decay != 0 and not decouple:
+                    grad.add_(p.data, alpha=decay)
+
+                state = self.state[p]
+
+                # State initialization
+                if 'step' not in state:
+                    state['step'] = 0
+                    state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(
+                        to_real(p.data), memory_format=torch.preserve_format).detach()
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                grad_grad = to_real(grad * grad.conj())
+
+                # Adam EMA updates
+                if group_lr > 0:
+                    exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1-beta1))
+                    exp_avg_sq.mul_(beta2).add_(grad_grad, alpha=1-beta2)
+
+                    denom = exp_avg_sq.sqrt().add_(eps)
+
+                    g_sq += grad_grad.div_(denom).sum().item()
+
+                    s = state['s']
+                    s.mul_(beta2).add_(grad, alpha=dlr*(1-beta2))
+                    sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item()
+                    sk_l1 += s.abs().sum().item()
+
+            ######
+
+        gsq_weighted = beta2*gsq_weighted + g_sq*(dlr**2)*(1-beta2)
+        d_hat = d
+
+        # if we have not done any progres, return
+        # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0)
+        if sk_l1 == 0:
+            return loss
+
+        if lr > 0.0:
+            if fsdp_in_use:
+                dist_tensor = torch.zeros(3, device='cuda')
+                dist_tensor[0] = sksq_weighted
+                dist_tensor[1] = gsq_weighted
+                dist_tensor[2] = sk_l1
+                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+                global_sksq_weighted = dist_tensor[0]
+                global_gsq_weighted = dist_tensor[1]
+                global_sk_l1 = dist_tensor[2]
+            else:
+                global_sksq_weighted = sksq_weighted
+                global_gsq_weighted = gsq_weighted
+                global_sk_l1 = sk_l1
+
+            d_hat = (global_sksq_weighted/(1-beta2) - global_gsq_weighted)/global_sk_l1
+            d = max(d, min(d_hat, d*growth_rate))
+
+        if log_every > 0 and k % log_every == 0:
+            logger.info(
+                f"(k={k}) dlr: {dlr:1.1e} d_hat: {d_hat:1.1e}, d: {d:1.8}. "
+                f"sksq_weighted={global_sksq_weighted:1.1e} gsq_weighted={global_gsq_weighted:1.1e} "
+                f"sk_l1={global_sk_l1:1.1e}{' (FSDP)' if fsdp_in_use else ''}")
+
+        for group in self.param_groups:
+            group['gsq_weighted'] = gsq_weighted
+            group['d'] = d
+
+            group_lr = group['lr']
+            decay = group['weight_decay']
+            k = group['k']
+            eps = group['eps']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+
+                state = self.state[p]
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                state['step'] += 1
+
+                denom = exp_avg_sq.sqrt().add_(eps)
+                denom = denom.type(p.type())
+
+                # Apply weight decay (decoupled variant)
+                if decay != 0 and decouple and group_lr > 0:
+                    p.data.add_(p.data, alpha=-decay * dlr)
+
+                # Take step
+                p.data.addcdiv_(exp_avg, denom, value=-1)
+
+            group['k'] = k + 1
+
+        return loss
+
+
+
+
+
+
+
+

Functions

+
+
+def to_real(x) +
+
+
+
+ +Expand source code + +
def to_real(x):
+    if torch.is_complex(x):
+        return x.real
+    else:
+        return x
+
+
+
+
+
+

Classes

+
+
+class DAdaptAdam +(params, lr=1.0, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, log_every=0, decouple=True, d0=1e-06, growth_rate=inf) +
+
+

Adam with D-Adaptation automatic step-sizes. +Leave LR set to 1 unless you encounter instability.

+

Args

+
+
params (iterable):
+
Iterable of parameters to optimize or dicts defining parameter groups.
+
lr (float):
+
Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
+
betas : tuple[float, float], optional
+
coefficients used for computing +running averages of gradient and its square (default: (0.9, 0.999))
+
+

momentum (float): +Momentum value in +the range [0,1) (default: 0.9). +eps (float): +Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). +weight_decay (float): +Weight decay, i.e. a L2 penalty (default: 0). +log_every (int): +Log using print every k steps, default 0 (no logging). +decouple (boolean): +Use AdamW style decoupled weight decay +d0 (float): +Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. +growth_rate (float): +prevent the D estimate from growing faster than this multiplicative rate. +Default is inf, for unrestricted. Values like 1.02 give a kind of learning +rate warmup effect. +fsdp_in_use (bool): +If you're using sharded parameters, this should be set to True. The optimizer +will attempt to auto-detect this, but if you're using an implementation other +than PyTorch's builtin version, the auto-detection won't work.

+
+ +Expand source code + +
class DAdaptAdam(torch.optim.Optimizer):
+    """Adam with D-Adaptation automatic step-sizes.
+    Leave LR set to 1 unless you encounter instability.
+
+    Args:
+        params (iterable):
+            Iterable of parameters to optimize or dicts defining parameter groups.
+        lr (float):
+            Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
+        betas (tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        momentum (float):
+            Momentum value in  the range [0,1) (default: 0.9).
+        eps (float):
+            Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8).
+        weight_decay (float):
+            Weight decay, i.e. a L2 penalty (default: 0).
+        log_every (int):
+            Log using print every k steps, default 0 (no logging).
+        decouple (boolean):
+            Use AdamW style decoupled weight decay
+        d0 (float):
+            Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
+        growth_rate (float):
+            prevent the D estimate from growing faster than this multiplicative rate.
+            Default is inf, for unrestricted. Values like 1.02 give a kind of learning
+            rate warmup effect.
+        fsdp_in_use (bool):
+            If you're using sharded parameters, this should be set to True. The optimizer
+            will attempt to auto-detect this, but if you're using an implementation other
+            than PyTorch's builtin version, the auto-detection won't work.
+    """
+    def __init__(self, params, lr=1.0,
+                 betas=(0.9, 0.999),
+                 eps=1e-8,
+                 weight_decay=0,
+                 log_every=0,
+                 decouple=True,
+                 d0=1e-6,
+                 growth_rate=float('inf')):
+        if not 0.0 < d0:
+            raise ValueError("Invalid d0 value: {}".format(d0))
+        if not 0.0 < lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 < eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        if decouple:
+            logger.info("Using decoupled weight decay")
+
+        from .fsdp import is_fsdp_used
+        fsdp_in_use = is_fsdp_used()
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay,
+                        d=d0,
+                        k=0,
+                        gsq_weighted=0.0,
+                        log_every=log_every,
+                        decouple=decouple,
+                        growth_rate=growth_rate,
+                        fsdp_in_use=fsdp_in_use)
+
+        super().__init__(params, defaults)
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return False
+
+    @property
+    def supports_flat_params(self):
+        return True
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        g_sq = 0.0
+        sksq_weighted = 0.0
+        sk_l1 = 0.0
+
+        lr = max(group['lr'] for group in self.param_groups)
+
+        group = self.param_groups[0]
+        gsq_weighted = group['gsq_weighted']
+        d = group['d']
+        dlr = d*lr
+
+        growth_rate = group['growth_rate']
+        decouple = group['decouple']
+        fsdp_in_use = group['fsdp_in_use']
+        log_every = group['log_every']
+
+        beta1, beta2 = group['betas']
+
+        for group in self.param_groups:
+            group_lr = group['lr']
+            decay = group['weight_decay']
+            k = group['k']
+            eps = group['eps']
+
+            if group_lr not in [lr, 0.0]:
+                raise RuntimeError("Setting different lr values in different parameter "
+                                   "groups is only supported for values of 0")
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                if hasattr(p, "_fsdp_flattened"):
+                    fsdp_in_use = True
+                grad = p.grad.data
+
+                # Apply weight decay (coupled variant)
+                if decay != 0 and not decouple:
+                    grad.add_(p.data, alpha=decay)
+
+                state = self.state[p]
+
+                # State initialization
+                if 'step' not in state:
+                    state['step'] = 0
+                    state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(
+                        to_real(p.data), memory_format=torch.preserve_format).detach()
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                grad_grad = to_real(grad * grad.conj())
+
+                # Adam EMA updates
+                if group_lr > 0:
+                    exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1-beta1))
+                    exp_avg_sq.mul_(beta2).add_(grad_grad, alpha=1-beta2)
+
+                    denom = exp_avg_sq.sqrt().add_(eps)
+
+                    g_sq += grad_grad.div_(denom).sum().item()
+
+                    s = state['s']
+                    s.mul_(beta2).add_(grad, alpha=dlr*(1-beta2))
+                    sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item()
+                    sk_l1 += s.abs().sum().item()
+
+            ######
+
+        gsq_weighted = beta2*gsq_weighted + g_sq*(dlr**2)*(1-beta2)
+        d_hat = d
+
+        # if we have not done any progres, return
+        # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0)
+        if sk_l1 == 0:
+            return loss
+
+        if lr > 0.0:
+            if fsdp_in_use:
+                dist_tensor = torch.zeros(3, device='cuda')
+                dist_tensor[0] = sksq_weighted
+                dist_tensor[1] = gsq_weighted
+                dist_tensor[2] = sk_l1
+                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+                global_sksq_weighted = dist_tensor[0]
+                global_gsq_weighted = dist_tensor[1]
+                global_sk_l1 = dist_tensor[2]
+            else:
+                global_sksq_weighted = sksq_weighted
+                global_gsq_weighted = gsq_weighted
+                global_sk_l1 = sk_l1
+
+            d_hat = (global_sksq_weighted/(1-beta2) - global_gsq_weighted)/global_sk_l1
+            d = max(d, min(d_hat, d*growth_rate))
+
+        if log_every > 0 and k % log_every == 0:
+            logger.info(
+                f"(k={k}) dlr: {dlr:1.1e} d_hat: {d_hat:1.1e}, d: {d:1.8}. "
+                f"sksq_weighted={global_sksq_weighted:1.1e} gsq_weighted={global_gsq_weighted:1.1e} "
+                f"sk_l1={global_sk_l1:1.1e}{' (FSDP)' if fsdp_in_use else ''}")
+
+        for group in self.param_groups:
+            group['gsq_weighted'] = gsq_weighted
+            group['d'] = d
+
+            group_lr = group['lr']
+            decay = group['weight_decay']
+            k = group['k']
+            eps = group['eps']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+
+                state = self.state[p]
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                state['step'] += 1
+
+                denom = exp_avg_sq.sqrt().add_(eps)
+                denom = denom.type(p.type())
+
+                # Apply weight decay (decoupled variant)
+                if decay != 0 and decouple and group_lr > 0:
+                    p.data.add_(p.data, alpha=-decay * dlr)
+
+                # Take step
+                p.data.addcdiv_(exp_avg, denom, value=-1)
+
+            group['k'] = k + 1
+
+        return loss
+
+

Ancestors

+
    +
  • torch.optim.optimizer.Optimizer
  • +
+

Class variables

+
+
var OptimizerPostHook : typing_extensions.TypeAlias
+
+
+
+
var OptimizerPreHook : typing_extensions.TypeAlias
+
+
+
+
+

Instance variables

+
+
var supports_flat_params
+
+
+
+ +Expand source code + +
@property
+def supports_flat_params(self):
+    return True
+
+
+
var supports_memory_efficient_fp16
+
+
+
+ +Expand source code + +
@property
+def supports_memory_efficient_fp16(self):
+    return False
+
+
+
+

Methods

+
+
+def step(self, closure=None) +
+
+

Performs a single optimization step.

+

Args

+
+
closure : callable, optional
+
A closure that reevaluates the model +and returns the loss.
+
+
+ +Expand source code + +
def step(self, closure=None):
+    """Performs a single optimization step.
+
+    Args:
+        closure (callable, optional): A closure that reevaluates the model
+            and returns the loss.
+    """
+    loss = None
+    if closure is not None:
+        loss = closure()
+
+    g_sq = 0.0
+    sksq_weighted = 0.0
+    sk_l1 = 0.0
+
+    lr = max(group['lr'] for group in self.param_groups)
+
+    group = self.param_groups[0]
+    gsq_weighted = group['gsq_weighted']
+    d = group['d']
+    dlr = d*lr
+
+    growth_rate = group['growth_rate']
+    decouple = group['decouple']
+    fsdp_in_use = group['fsdp_in_use']
+    log_every = group['log_every']
+
+    beta1, beta2 = group['betas']
+
+    for group in self.param_groups:
+        group_lr = group['lr']
+        decay = group['weight_decay']
+        k = group['k']
+        eps = group['eps']
+
+        if group_lr not in [lr, 0.0]:
+            raise RuntimeError("Setting different lr values in different parameter "
+                               "groups is only supported for values of 0")
+
+        for p in group['params']:
+            if p.grad is None:
+                continue
+            if hasattr(p, "_fsdp_flattened"):
+                fsdp_in_use = True
+            grad = p.grad.data
+
+            # Apply weight decay (coupled variant)
+            if decay != 0 and not decouple:
+                grad.add_(p.data, alpha=decay)
+
+            state = self.state[p]
+
+            # State initialization
+            if 'step' not in state:
+                state['step'] = 0
+                state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
+                # Exponential moving average of gradient values
+                state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
+                # Exponential moving average of squared gradient values
+                state['exp_avg_sq'] = torch.zeros_like(
+                    to_real(p.data), memory_format=torch.preserve_format).detach()
+
+            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+            grad_grad = to_real(grad * grad.conj())
+
+            # Adam EMA updates
+            if group_lr > 0:
+                exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1-beta1))
+                exp_avg_sq.mul_(beta2).add_(grad_grad, alpha=1-beta2)
+
+                denom = exp_avg_sq.sqrt().add_(eps)
+
+                g_sq += grad_grad.div_(denom).sum().item()
+
+                s = state['s']
+                s.mul_(beta2).add_(grad, alpha=dlr*(1-beta2))
+                sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item()
+                sk_l1 += s.abs().sum().item()
+
+        ######
+
+    gsq_weighted = beta2*gsq_weighted + g_sq*(dlr**2)*(1-beta2)
+    d_hat = d
+
+    # if we have not done any progres, return
+    # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0)
+    if sk_l1 == 0:
+        return loss
+
+    if lr > 0.0:
+        if fsdp_in_use:
+            dist_tensor = torch.zeros(3, device='cuda')
+            dist_tensor[0] = sksq_weighted
+            dist_tensor[1] = gsq_weighted
+            dist_tensor[2] = sk_l1
+            dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+            global_sksq_weighted = dist_tensor[0]
+            global_gsq_weighted = dist_tensor[1]
+            global_sk_l1 = dist_tensor[2]
+        else:
+            global_sksq_weighted = sksq_weighted
+            global_gsq_weighted = gsq_weighted
+            global_sk_l1 = sk_l1
+
+        d_hat = (global_sksq_weighted/(1-beta2) - global_gsq_weighted)/global_sk_l1
+        d = max(d, min(d_hat, d*growth_rate))
+
+    if log_every > 0 and k % log_every == 0:
+        logger.info(
+            f"(k={k}) dlr: {dlr:1.1e} d_hat: {d_hat:1.1e}, d: {d:1.8}. "
+            f"sksq_weighted={global_sksq_weighted:1.1e} gsq_weighted={global_gsq_weighted:1.1e} "
+            f"sk_l1={global_sk_l1:1.1e}{' (FSDP)' if fsdp_in_use else ''}")
+
+    for group in self.param_groups:
+        group['gsq_weighted'] = gsq_weighted
+        group['d'] = d
+
+        group_lr = group['lr']
+        decay = group['weight_decay']
+        k = group['k']
+        eps = group['eps']
+
+        for p in group['params']:
+            if p.grad is None:
+                continue
+            grad = p.grad.data
+
+            state = self.state[p]
+
+            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+            state['step'] += 1
+
+            denom = exp_avg_sq.sqrt().add_(eps)
+            denom = denom.type(p.type())
+
+            # Apply weight decay (decoupled variant)
+            if decay != 0 and decouple and group_lr > 0:
+                p.data.add_(p.data, alpha=-decay * dlr)
+
+            # Take step
+            p.data.addcdiv_(exp_avg, denom, value=-1)
+
+        group['k'] = k + 1
+
+    return loss
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/optim/ema.html b/api_docs/audiocraft/optim/ema.html new file mode 100644 index 00000000..6386e2c5 --- /dev/null +++ b/api_docs/audiocraft/optim/ema.html @@ -0,0 +1,273 @@ + + + + + + +audiocraft.optim.ema API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.optim.ema

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# ModelEMA implementation is taken from
+# https://github.com/facebookresearch/demucs
+
+from collections import defaultdict
+import typing as tp
+
+import torch
+import torch.nn as nn
+
+
+def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set:
+    names: set = set()
+    for (name, sub_module) in module.named_modules():
+        if name == '':
+            buffer_names = module._non_persistent_buffers_set
+            buffer_names = {f"{root}.{buff_name}" if len(root) > 0 else buff_name
+                            for buff_name in buffer_names}
+            names.update(buffer_names)
+        else:
+            sub_name = f"{root}.{name}" if len(root) > 0 else name
+            sub_buffer_names = _get_all_non_persistent_buffers_set(sub_module, sub_name)
+            names.update(sub_buffer_names)
+    return names
+
+
+def _get_named_tensors(module: nn.Module):
+    non_persistent_buffers_set = _get_all_non_persistent_buffers_set(module)
+    named_buffers = [(name, buffer) for (name, buffer) in module.named_buffers()
+                     if name not in non_persistent_buffers_set]
+    named_parameters = list(module.named_parameters())
+    return named_parameters + named_buffers
+
+
+class ModuleDictEMA:
+    """Exponential Moving Average over a nn.ModuleDict.
+
+    You can switch to the EMA weights temporarily.
+    """
+    def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999,
+                 unbias: bool = True, device: tp.Union[torch.device, str] = 'cpu'):
+        self.decay = decay
+        self.module_dict = module_dict
+        self.state: dict = defaultdict(dict)
+        self.count = 0
+        self.device = device
+        self.unbias = unbias
+        self._init()
+
+    def _init(self):
+        for module_name, module in self.module_dict.items():
+            for key, val in _get_named_tensors(module):
+                if not val.is_floating_point():
+                    continue
+                device = self.device or val.device
+                if key not in self.state[module_name]:
+                    self.state[module_name][key] = val.detach().to(device, copy=True)
+
+    def step(self):
+        if self.unbias:
+            self.count = self.count * self.decay + 1
+            w = 1 / self.count
+        else:
+            w = 1 - self.decay
+        for module_name, module in self.module_dict.items():
+            for key, val in _get_named_tensors(module):
+                if not val.is_floating_point():
+                    continue
+                device = self.device or val.device
+                self.state[module_name][key].mul_(1 - w)
+                self.state[module_name][key].add_(val.detach().to(device), alpha=w)
+
+    def state_dict(self):
+        return {'state': self.state, 'count': self.count}
+
+    def load_state_dict(self, state):
+        self.count = state['count']
+        for module_name, module in state['state'].items():
+            for key, val in module.items():
+                self.state[module_name][key].copy_(val)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class ModuleDictEMA +(module_dict: torch.nn.modules.container.ModuleDict, decay: float = 0.999, unbias: bool = True, device: Union[torch.device, str] = 'cpu') +
+
+

Exponential Moving Average over a nn.ModuleDict.

+

You can switch to the EMA weights temporarily.

+
+ +Expand source code + +
class ModuleDictEMA:
+    """Exponential Moving Average over a nn.ModuleDict.
+
+    You can switch to the EMA weights temporarily.
+    """
+    def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999,
+                 unbias: bool = True, device: tp.Union[torch.device, str] = 'cpu'):
+        self.decay = decay
+        self.module_dict = module_dict
+        self.state: dict = defaultdict(dict)
+        self.count = 0
+        self.device = device
+        self.unbias = unbias
+        self._init()
+
+    def _init(self):
+        for module_name, module in self.module_dict.items():
+            for key, val in _get_named_tensors(module):
+                if not val.is_floating_point():
+                    continue
+                device = self.device or val.device
+                if key not in self.state[module_name]:
+                    self.state[module_name][key] = val.detach().to(device, copy=True)
+
+    def step(self):
+        if self.unbias:
+            self.count = self.count * self.decay + 1
+            w = 1 / self.count
+        else:
+            w = 1 - self.decay
+        for module_name, module in self.module_dict.items():
+            for key, val in _get_named_tensors(module):
+                if not val.is_floating_point():
+                    continue
+                device = self.device or val.device
+                self.state[module_name][key].mul_(1 - w)
+                self.state[module_name][key].add_(val.detach().to(device), alpha=w)
+
+    def state_dict(self):
+        return {'state': self.state, 'count': self.count}
+
+    def load_state_dict(self, state):
+        self.count = state['count']
+        for module_name, module in state['state'].items():
+            for key, val in module.items():
+                self.state[module_name][key].copy_(val)
+
+

Methods

+
+
+def load_state_dict(self, state) +
+
+
+
+ +Expand source code + +
def load_state_dict(self, state):
+    self.count = state['count']
+    for module_name, module in state['state'].items():
+        for key, val in module.items():
+            self.state[module_name][key].copy_(val)
+
+
+
+def state_dict(self) +
+
+
+
+ +Expand source code + +
def state_dict(self):
+    return {'state': self.state, 'count': self.count}
+
+
+
+def step(self) +
+
+
+
+ +Expand source code + +
def step(self):
+    if self.unbias:
+        self.count = self.count * self.decay + 1
+        w = 1 / self.count
+    else:
+        w = 1 - self.decay
+    for module_name, module in self.module_dict.items():
+        for key, val in _get_named_tensors(module):
+            if not val.is_floating_point():
+                continue
+            device = self.device or val.device
+            self.state[module_name][key].mul_(1 - w)
+            self.state[module_name][key].add_(val.detach().to(device), alpha=w)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/optim/fsdp.html b/api_docs/audiocraft/optim/fsdp.html new file mode 100644 index 00000000..f3be8879 --- /dev/null +++ b/api_docs/audiocraft/optim/fsdp.html @@ -0,0 +1,428 @@ + + + + + + +audiocraft.optim.fsdp API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.optim.fsdp

+
+
+

Wrapper around FSDP for more convenient use in the training loops.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Wrapper around FSDP for more convenient use in the training loops.
+"""
+
+from contextlib import contextmanager
+import typing as tp
+import dora
+import torch
+
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import (
+    MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType)
+from torch.distributed._shard.sharded_tensor.api import ShardedTensor
+
+
+def is_fsdp_used() -> bool:
+    """Return whether we are using FSDP."""
+    # A bit of a hack but should work from anywhere.
+    if dora.is_xp():
+        cfg = dora.get_xp().cfg
+        if hasattr(cfg, 'fsdp'):
+            return cfg.fsdp.use
+    return False
+
+
+def is_sharded_tensor(x: tp.Any) -> bool:
+    return isinstance(x, ShardedTensor)
+
+
+@contextmanager
+def switch_to_full_state_dict(models: tp.List[FSDP]):
+    # Another bug in FSDP makes it that we cannot use the `state_dict_type` API,
+    # so let's do thing manually.
+    for model in models:
+        FSDP.set_state_dict_type(  # type: ignore
+            model, StateDictType.FULL_STATE_DICT,
+            FullStateDictConfig(offload_to_cpu=True, rank0_only=True))
+    try:
+        yield
+    finally:
+        for model in models:
+            FSDP.set_state_dict_type(model, StateDictType.LOCAL_STATE_DICT)  # type: ignore
+
+
+def wrap_with_fsdp(cfg, model: torch.nn.Module,
+                   block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
+    """Wraps a model with FSDP."""
+    # Some of the typing is disabled until this gets integrated
+    # into the stable version of PyTorch.
+    from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # type: ignore
+
+    # we import this here to prevent circular import.
+    from ..modules.transformer import StreamingTransformerLayer
+    from ..modules.conditioners import ConditioningProvider
+
+    _fix_post_backward_hook()
+
+    assert cfg.use
+    sharding_strategy_dict = {
+        "no_shard": ShardingStrategy.NO_SHARD,
+        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
+        "full_shard": ShardingStrategy.FULL_SHARD,
+    }
+
+    dtype_dict = {
+        "float32": torch.float32,
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+    }
+
+    mixed_precision_config = MixedPrecision(
+        param_dtype=dtype_dict[cfg.param_dtype],
+        reduce_dtype=dtype_dict[cfg.reduce_dtype],
+        buffer_dtype=dtype_dict[cfg.buffer_dtype],
+    )
+
+    sharding_strategy_config = sharding_strategy_dict[cfg.sharding_strategy]
+    # The following is going to require being a bit smart
+    # when doing LM, because this would flush the weights for every time step
+    # during generation. One possiblity is to use hybrid sharding:
+    # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy
+    assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \
+        "Not supported at the moment, requires a bit more work."
+
+    local_rank = dora.distrib.get_distrib_spec().local_rank
+    assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"
+
+    auto_wrap_policy = None
+    if block_classes is None:
+        block_classes = {StreamingTransformerLayer, ConditioningProvider}
+    if cfg.per_block:
+        auto_wrap_policy = ModuleWrapPolicy(block_classes)
+    wrapped = _FSDPFixStateDict(
+        model,
+        sharding_strategy=sharding_strategy_config,
+        mixed_precision=mixed_precision_config,
+        device_id=local_rank,
+        sync_module_states=True,
+        use_orig_params=True,
+        auto_wrap_policy=auto_wrap_policy,
+    )  # type: ignore
+    FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)  # type: ignore
+
+    # Let the wrapped model know about the wrapping!
+    # We use __dict__ to avoid it going into the state dict.
+    # This is a bit dirty, but needed during generation, as otherwise
+    # the wrapped model would call itself and bypass FSDP.
+    for module in FSDP.fsdp_modules(wrapped):
+        original = module._fsdp_wrapped_module
+        original.__dict__['_fsdp'] = module
+    return wrapped
+
+
+def purge_fsdp(model: FSDP):
+    """Purge the FSDP cached shard inside the model. This should
+    allow setting the best state or switching to the EMA.
+    """
+    from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore
+    for module in FSDP.fsdp_modules(model):
+        handles = module._handles
+        if not handles:
+            continue
+        handle = handles[0]
+        unsharded_flat_param = handle._get_padded_unsharded_flat_param()
+        storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
+        if storage_size == 0:
+            continue
+        true_list = [True for h in handles]
+        _reshard(module, handles, true_list)
+
+
+class _FSDPFixStateDict(FSDP):
+    @staticmethod
+    def _name_without_fsdp_prefix(name: str) -> str:
+        from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE  # type: ignore
+        parts = name.split('.')
+        new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE]
+        return '.'.join(new_parts)
+
+    def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]:  # type: ignore
+        state = dict(super().state_dict(*args, **kwargs))
+        for key, value in list(state.items()):
+            if is_sharded_tensor(value):
+                del state[key]
+        return state
+
+    def load_state_dict(self, state: tp.Dict[str, tp.Any]):  # type: ignore
+        if self._state_dict_type is StateDictType.FULL_STATE_DICT:
+            super().load_state_dict(state)
+            purge_fsdp(self)
+            return
+        # Fix FSDP load state dict in all situation.
+        # Use this only with LOCAL_STATE_DICT !!!
+        current_state = dict(super().state_dict())
+        for key, value in state.items():
+            key = _FSDPFixStateDict._name_without_fsdp_prefix(key)
+            if key not in current_state:
+                # Emulate strict loading manually.
+                raise RuntimeError(f"Unknown state key {key}")
+            current_state[key].copy_(value)
+
+        # Purging cached weights from previous forward.
+        purge_fsdp(self)
+
+
+_hook_fixed = False
+
+
+def _fix_post_backward_hook():
+    global _hook_fixed
+    if _hook_fixed:
+        return
+    _hook_fixed = True
+
+    from torch.distributed.fsdp import _runtime_utils
+    from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState
+    old_hook = _runtime_utils._post_backward_hook
+
+    def _post_backward_hook(state, handle, *args, **kwargs):
+        checkpointed = getattr(state._fsdp_wrapped_module, '_audiocraft_checkpointed', False)
+        if checkpointed:
+            # there will be one more forward in the backward with checkpointing and that will
+            # massively confuse FSDP, so we have to make it think everything
+            # is going according to the plan.
+            state.training_state = TrainingState.FORWARD_BACKWARD
+            handle._training_state = HandleTrainingState.BACKWARD_PRE
+        old_hook(state, handle, *args, **kwargs)
+
+    _runtime_utils._post_backward_hook = _post_backward_hook
+
+
+
+
+
+
+
+

Functions

+
+
+def is_fsdp_used() ‑> bool +
+
+

Return whether we are using FSDP.

+
+ +Expand source code + +
def is_fsdp_used() -> bool:
+    """Return whether we are using FSDP."""
+    # A bit of a hack but should work from anywhere.
+    if dora.is_xp():
+        cfg = dora.get_xp().cfg
+        if hasattr(cfg, 'fsdp'):
+            return cfg.fsdp.use
+    return False
+
+
+
+def is_sharded_tensor(x: Any) ‑> bool +
+
+
+
+ +Expand source code + +
def is_sharded_tensor(x: tp.Any) -> bool:
+    return isinstance(x, ShardedTensor)
+
+
+
+def purge_fsdp(model: torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel) +
+
+

Purge the FSDP cached shard inside the model. This should +allow setting the best state or switching to the EMA.

+
+ +Expand source code + +
def purge_fsdp(model: FSDP):
+    """Purge the FSDP cached shard inside the model. This should
+    allow setting the best state or switching to the EMA.
+    """
+    from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore
+    for module in FSDP.fsdp_modules(model):
+        handles = module._handles
+        if not handles:
+            continue
+        handle = handles[0]
+        unsharded_flat_param = handle._get_padded_unsharded_flat_param()
+        storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
+        if storage_size == 0:
+            continue
+        true_list = [True for h in handles]
+        _reshard(module, handles, true_list)
+
+
+
+def switch_to_full_state_dict(models: List[torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel]) +
+
+
+
+ +Expand source code + +
@contextmanager
+def switch_to_full_state_dict(models: tp.List[FSDP]):
+    # Another bug in FSDP makes it that we cannot use the `state_dict_type` API,
+    # so let's do thing manually.
+    for model in models:
+        FSDP.set_state_dict_type(  # type: ignore
+            model, StateDictType.FULL_STATE_DICT,
+            FullStateDictConfig(offload_to_cpu=True, rank0_only=True))
+    try:
+        yield
+    finally:
+        for model in models:
+            FSDP.set_state_dict_type(model, StateDictType.LOCAL_STATE_DICT)  # type: ignore
+
+
+
+def wrap_with_fsdp(cfg, model: torch.nn.modules.module.Module, block_classes: Optional[Set[Type[+CT_co]]] = None) ‑> torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel +
+
+

Wraps a model with FSDP.

+
+ +Expand source code + +
def wrap_with_fsdp(cfg, model: torch.nn.Module,
+                   block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
+    """Wraps a model with FSDP."""
+    # Some of the typing is disabled until this gets integrated
+    # into the stable version of PyTorch.
+    from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # type: ignore
+
+    # we import this here to prevent circular import.
+    from ..modules.transformer import StreamingTransformerLayer
+    from ..modules.conditioners import ConditioningProvider
+
+    _fix_post_backward_hook()
+
+    assert cfg.use
+    sharding_strategy_dict = {
+        "no_shard": ShardingStrategy.NO_SHARD,
+        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
+        "full_shard": ShardingStrategy.FULL_SHARD,
+    }
+
+    dtype_dict = {
+        "float32": torch.float32,
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+    }
+
+    mixed_precision_config = MixedPrecision(
+        param_dtype=dtype_dict[cfg.param_dtype],
+        reduce_dtype=dtype_dict[cfg.reduce_dtype],
+        buffer_dtype=dtype_dict[cfg.buffer_dtype],
+    )
+
+    sharding_strategy_config = sharding_strategy_dict[cfg.sharding_strategy]
+    # The following is going to require being a bit smart
+    # when doing LM, because this would flush the weights for every time step
+    # during generation. One possiblity is to use hybrid sharding:
+    # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy
+    assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \
+        "Not supported at the moment, requires a bit more work."
+
+    local_rank = dora.distrib.get_distrib_spec().local_rank
+    assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"
+
+    auto_wrap_policy = None
+    if block_classes is None:
+        block_classes = {StreamingTransformerLayer, ConditioningProvider}
+    if cfg.per_block:
+        auto_wrap_policy = ModuleWrapPolicy(block_classes)
+    wrapped = _FSDPFixStateDict(
+        model,
+        sharding_strategy=sharding_strategy_config,
+        mixed_precision=mixed_precision_config,
+        device_id=local_rank,
+        sync_module_states=True,
+        use_orig_params=True,
+        auto_wrap_policy=auto_wrap_policy,
+    )  # type: ignore
+    FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)  # type: ignore
+
+    # Let the wrapped model know about the wrapping!
+    # We use __dict__ to avoid it going into the state dict.
+    # This is a bit dirty, but needed during generation, as otherwise
+    # the wrapped model would call itself and bypass FSDP.
+    for module in FSDP.fsdp_modules(wrapped):
+        original = module._fsdp_wrapped_module
+        original.__dict__['_fsdp'] = module
+    return wrapped
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/optim/index.html b/api_docs/audiocraft/optim/index.html new file mode 100644 index 00000000..ccd2354b --- /dev/null +++ b/api_docs/audiocraft/optim/index.html @@ -0,0 +1,119 @@ + + + + + + +audiocraft.optim API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.optim

+
+
+

Optimization stuff. In particular, optimizers (DAdaptAdam), schedulers +and Exponential Moving Average.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Optimization stuff. In particular, optimizers (DAdaptAdam), schedulers
+and Exponential Moving Average.
+"""
+
+# flake8: noqa
+from .cosine_lr_scheduler import CosineLRScheduler
+from .dadam import DAdaptAdam
+from .inverse_sqrt_lr_scheduler import InverseSquareRootLRScheduler
+from .linear_warmup_lr_scheduler import LinearWarmupLRScheduler
+from .polynomial_decay_lr_scheduler import PolynomialDecayLRScheduler
+from .ema import ModuleDictEMA
+
+
+
+

Sub-modules

+
+
audiocraft.optim.cosine_lr_scheduler
+
+
+
+
audiocraft.optim.dadam
+
+
+
+
audiocraft.optim.ema
+
+
+
+
audiocraft.optim.fsdp
+
+

Wrapper around FSDP for more convenient use in the training loops.

+
+
audiocraft.optim.inverse_sqrt_lr_scheduler
+
+
+
+
audiocraft.optim.linear_warmup_lr_scheduler
+
+
+
+
audiocraft.optim.polynomial_decay_lr_scheduler
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/optim/inverse_sqrt_lr_scheduler.html b/api_docs/audiocraft/optim/inverse_sqrt_lr_scheduler.html new file mode 100644 index 00000000..d65c9231 --- /dev/null +++ b/api_docs/audiocraft/optim/inverse_sqrt_lr_scheduler.html @@ -0,0 +1,178 @@ + + + + + + +audiocraft.optim.inverse_sqrt_lr_scheduler API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.optim.inverse_sqrt_lr_scheduler

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class InverseSquareRootLRScheduler(_LRScheduler):
+    """Inverse square root LR scheduler.
+
+    Args:
+        optimizer (Optimizer): Torch optimizer.
+        warmup_steps (int): Number of warmup steps.
+        warmup_init_lr (tp.Optional[float]): Initial learning rate
+            during warmup phase. When not set, use the provided learning rate.
+    """
+    def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
+        self.warmup_steps = warmup_steps
+        self.warmup_init_lr = warmup_init_lr
+        super().__init__(optimizer)
+
+    def _get_sched_lr(self, lr: float, step: int):
+        if step < self.warmup_steps:
+            warmup_init_lr = self.warmup_init_lr or 0
+            lr_step = (lr - warmup_init_lr) / self.warmup_steps
+            lr = warmup_init_lr + step * lr_step
+        else:
+            decay_factor = lr * self.warmup_steps**0.5
+            lr = decay_factor * step**-0.5
+        return lr
+
+    def get_lr(self):
+        return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs]
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class InverseSquareRootLRScheduler +(optimizer: torch.optim.optimizer.Optimizer, warmup_steps: int, warmup_init_lr: Optional[float] = 0) +
+
+

Inverse square root LR scheduler.

+

Args

+
+
optimizer : Optimizer
+
Torch optimizer.
+
warmup_steps : int
+
Number of warmup steps.
+
warmup_init_lr : tp.Optional[float]
+
Initial learning rate +during warmup phase. When not set, use the provided learning rate.
+
+
+ +Expand source code + +
class InverseSquareRootLRScheduler(_LRScheduler):
+    """Inverse square root LR scheduler.
+
+    Args:
+        optimizer (Optimizer): Torch optimizer.
+        warmup_steps (int): Number of warmup steps.
+        warmup_init_lr (tp.Optional[float]): Initial learning rate
+            during warmup phase. When not set, use the provided learning rate.
+    """
+    def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
+        self.warmup_steps = warmup_steps
+        self.warmup_init_lr = warmup_init_lr
+        super().__init__(optimizer)
+
+    def _get_sched_lr(self, lr: float, step: int):
+        if step < self.warmup_steps:
+            warmup_init_lr = self.warmup_init_lr or 0
+            lr_step = (lr - warmup_init_lr) / self.warmup_steps
+            lr = warmup_init_lr + step * lr_step
+        else:
+            decay_factor = lr * self.warmup_steps**0.5
+            lr = decay_factor * step**-0.5
+        return lr
+
+    def get_lr(self):
+        return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs]
+
+

Ancestors

+
    +
  • torch.optim.lr_scheduler._LRScheduler
  • +
  • torch.optim.lr_scheduler.LRScheduler
  • +
+

Methods

+
+
+def get_lr(self) +
+
+
+
+ +Expand source code + +
def get_lr(self):
+    return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs]
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/optim/linear_warmup_lr_scheduler.html b/api_docs/audiocraft/optim/linear_warmup_lr_scheduler.html new file mode 100644 index 00000000..810765e7 --- /dev/null +++ b/api_docs/audiocraft/optim/linear_warmup_lr_scheduler.html @@ -0,0 +1,172 @@ + + + + + + +audiocraft.optim.linear_warmup_lr_scheduler API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.optim.linear_warmup_lr_scheduler

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class LinearWarmupLRScheduler(_LRScheduler):
+    """Inverse square root LR scheduler.
+
+    Args:
+        optimizer (Optimizer): Torch optimizer.
+        warmup_steps (int): Number of warmup steps.
+        warmup_init_lr (tp.Optional[float]): Initial learning rate
+            during warmup phase. When not set, use the provided learning rate.
+    """
+    def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
+        self.warmup_steps = warmup_steps
+        self.warmup_init_lr = warmup_init_lr
+        super().__init__(optimizer)
+
+    def _get_sched_lr(self, lr: float, step: int):
+        if step < self.warmup_steps:
+            warmup_init_lr = self.warmup_init_lr or 0
+            lr_step = (lr - warmup_init_lr) / self.warmup_steps
+            lr = warmup_init_lr + step * lr_step
+        return lr
+
+    def get_lr(self):
+        return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class LinearWarmupLRScheduler +(optimizer: torch.optim.optimizer.Optimizer, warmup_steps: int, warmup_init_lr: Optional[float] = 0) +
+
+

Inverse square root LR scheduler.

+

Args

+
+
optimizer : Optimizer
+
Torch optimizer.
+
warmup_steps : int
+
Number of warmup steps.
+
warmup_init_lr : tp.Optional[float]
+
Initial learning rate +during warmup phase. When not set, use the provided learning rate.
+
+
+ +Expand source code + +
class LinearWarmupLRScheduler(_LRScheduler):
+    """Inverse square root LR scheduler.
+
+    Args:
+        optimizer (Optimizer): Torch optimizer.
+        warmup_steps (int): Number of warmup steps.
+        warmup_init_lr (tp.Optional[float]): Initial learning rate
+            during warmup phase. When not set, use the provided learning rate.
+    """
+    def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
+        self.warmup_steps = warmup_steps
+        self.warmup_init_lr = warmup_init_lr
+        super().__init__(optimizer)
+
+    def _get_sched_lr(self, lr: float, step: int):
+        if step < self.warmup_steps:
+            warmup_init_lr = self.warmup_init_lr or 0
+            lr_step = (lr - warmup_init_lr) / self.warmup_steps
+            lr = warmup_init_lr + step * lr_step
+        return lr
+
+    def get_lr(self):
+        return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
+
+

Ancestors

+
    +
  • torch.optim.lr_scheduler._LRScheduler
  • +
  • torch.optim.lr_scheduler.LRScheduler
  • +
+

Methods

+
+
+def get_lr(self) +
+
+
+
+ +Expand source code + +
def get_lr(self):
+    return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/optim/polynomial_decay_lr_scheduler.html b/api_docs/audiocraft/optim/polynomial_decay_lr_scheduler.html new file mode 100644 index 00000000..e6dd3000 --- /dev/null +++ b/api_docs/audiocraft/optim/polynomial_decay_lr_scheduler.html @@ -0,0 +1,203 @@ + + + + + + +audiocraft.optim.polynomial_decay_lr_scheduler API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.optim.polynomial_decay_lr_scheduler

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class PolynomialDecayLRScheduler(_LRScheduler):
+    """Polynomial decay LR scheduler.
+
+    Args:
+        optimizer (Optimizer): Torch optimizer.
+        warmup_steps (int): Number of warmup steps.
+        total_steps (int): Total number of steps.
+        end_lr (float): Final learning rate to achieve over total number of steps.
+        zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0.
+        power (float): Decay exponent.
+    """
+    def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int,
+                 end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.):
+        self.warmup_steps = warmup_steps
+        self.total_steps = total_steps
+        self.end_lr = end_lr
+        self.zero_lr_warmup_steps = zero_lr_warmup_steps
+        self.power = power
+        super().__init__(optimizer)
+
+    def _get_sched_lr(self, lr: float, step: int):
+        if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps:
+            lr = 0
+        elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps:
+            lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps)
+            lr = lr_ratio * lr
+        elif step >= self.total_steps:
+            lr = self.end_lr
+        else:
+            total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps
+            lr_range = lr - self.end_lr
+            pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps)
+            lr = lr_range * pct_remaining ** self.power + self.end_lr
+        return lr
+
+    def get_lr(self):
+        return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class PolynomialDecayLRScheduler +(optimizer: torch.optim.optimizer.Optimizer, warmup_steps: int, total_steps: int, end_lr: float = 0.0, zero_lr_warmup_steps: int = 0, power: float = 1.0) +
+
+

Polynomial decay LR scheduler.

+

Args

+
+
optimizer : Optimizer
+
Torch optimizer.
+
warmup_steps : int
+
Number of warmup steps.
+
total_steps : int
+
Total number of steps.
+
end_lr : float
+
Final learning rate to achieve over total number of steps.
+
zero_lr_warmup_steps : int
+
Number of steps with a learning rate of value 0.
+
power : float
+
Decay exponent.
+
+
+ +Expand source code + +
class PolynomialDecayLRScheduler(_LRScheduler):
+    """Polynomial decay LR scheduler.
+
+    Args:
+        optimizer (Optimizer): Torch optimizer.
+        warmup_steps (int): Number of warmup steps.
+        total_steps (int): Total number of steps.
+        end_lr (float): Final learning rate to achieve over total number of steps.
+        zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0.
+        power (float): Decay exponent.
+    """
+    def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int,
+                 end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.):
+        self.warmup_steps = warmup_steps
+        self.total_steps = total_steps
+        self.end_lr = end_lr
+        self.zero_lr_warmup_steps = zero_lr_warmup_steps
+        self.power = power
+        super().__init__(optimizer)
+
+    def _get_sched_lr(self, lr: float, step: int):
+        if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps:
+            lr = 0
+        elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps:
+            lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps)
+            lr = lr_ratio * lr
+        elif step >= self.total_steps:
+            lr = self.end_lr
+        else:
+            total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps
+            lr_range = lr - self.end_lr
+            pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps)
+            lr = lr_range * pct_remaining ** self.power + self.end_lr
+        return lr
+
+    def get_lr(self):
+        return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
+
+

Ancestors

+
    +
  • torch.optim.lr_scheduler._LRScheduler
  • +
  • torch.optim.lr_scheduler.LRScheduler
  • +
+

Methods

+
+
+def get_lr(self) +
+
+
+
+ +Expand source code + +
def get_lr(self):
+    return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/quantization/base.html b/api_docs/audiocraft/quantization/base.html new file mode 100644 index 00000000..d23cea1c --- /dev/null +++ b/api_docs/audiocraft/quantization/base.html @@ -0,0 +1,544 @@ + + + + + + +audiocraft.quantization.base API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.quantization.base

+
+
+

Base class for all quantizers.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Base class for all quantizers.
+"""
+
+from dataclasses import dataclass, field
+import typing as tp
+
+import torch
+from torch import nn
+
+
+@dataclass
+class QuantizedResult:
+    x: torch.Tensor
+    codes: torch.Tensor
+    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
+    penalty: tp.Optional[torch.Tensor] = None
+    metrics: dict = field(default_factory=dict)
+
+
+class BaseQuantizer(nn.Module):
+    """Base class for quantizers.
+    """
+
+    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
+        """
+        Given input tensor x, returns first the quantized (or approximately quantized)
+        representation along with quantized codes, bandwidth, and any penalty term for the loss.
+        Finally, this returns a dict of metrics to update logging etc.
+        Frame rate must be passed so that the bandwidth is properly computed.
+        """
+        raise NotImplementedError()
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified sample rate at the given bandwidth."""
+        raise NotImplementedError()
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation."""
+        raise NotImplementedError()
+
+    @property
+    def total_codebooks(self):
+        """Total number of codebooks."""
+        raise NotImplementedError()
+
+    @property
+    def num_codebooks(self):
+        """Number of active codebooks."""
+        raise NotImplementedError()
+
+    def set_num_codebooks(self, n: int):
+        """Set the number of active codebooks."""
+        raise NotImplementedError()
+
+
+class DummyQuantizer(BaseQuantizer):
+    """Fake quantizer that actually does not perform any quantization.
+    """
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor, frame_rate: int):
+        q = x.unsqueeze(1)
+        return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified sample rate at the given bandwidth.
+        In the case of the DummyQuantizer, the codes are actually identical
+        to the input and resulting quantized representation as no quantization is done.
+        """
+        return x.unsqueeze(1)
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation.
+        In the case of the DummyQuantizer, the codes are actually identical
+        to the input and resulting quantized representation as no quantization is done.
+        """
+        return codes.squeeze(1)
+
+    @property
+    def total_codebooks(self):
+        """Total number of codebooks."""
+        return 1
+
+    @property
+    def num_codebooks(self):
+        """Total number of codebooks."""
+        return self.total_codebooks
+
+    def set_num_codebooks(self, n: int):
+        """Set the number of active codebooks."""
+        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class BaseQuantizer +(*args, **kwargs) +
+
+

Base class for quantizers.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class BaseQuantizer(nn.Module):
+    """Base class for quantizers.
+    """
+
+    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
+        """
+        Given input tensor x, returns first the quantized (or approximately quantized)
+        representation along with quantized codes, bandwidth, and any penalty term for the loss.
+        Finally, this returns a dict of metrics to update logging etc.
+        Frame rate must be passed so that the bandwidth is properly computed.
+        """
+        raise NotImplementedError()
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified sample rate at the given bandwidth."""
+        raise NotImplementedError()
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation."""
+        raise NotImplementedError()
+
+    @property
+    def total_codebooks(self):
+        """Total number of codebooks."""
+        raise NotImplementedError()
+
+    @property
+    def num_codebooks(self):
+        """Number of active codebooks."""
+        raise NotImplementedError()
+
+    def set_num_codebooks(self, n: int):
+        """Set the number of active codebooks."""
+        raise NotImplementedError()
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var num_codebooks
+
+

Number of active codebooks.

+
+ +Expand source code + +
@property
+def num_codebooks(self):
+    """Number of active codebooks."""
+    raise NotImplementedError()
+
+
+
var total_codebooks
+
+

Total number of codebooks.

+
+ +Expand source code + +
@property
+def total_codebooks(self):
+    """Total number of codebooks."""
+    raise NotImplementedError()
+
+
+
+

Methods

+
+
+def decode(self, codes: torch.Tensor) ‑> torch.Tensor +
+
+

Decode the given codes to the quantized representation.

+
+ +Expand source code + +
def decode(self, codes: torch.Tensor) -> torch.Tensor:
+    """Decode the given codes to the quantized representation."""
+    raise NotImplementedError()
+
+
+
+def encode(self, x: torch.Tensor) ‑> torch.Tensor +
+
+

Encode a given input tensor with the specified sample rate at the given bandwidth.

+
+ +Expand source code + +
def encode(self, x: torch.Tensor) -> torch.Tensor:
+    """Encode a given input tensor with the specified sample rate at the given bandwidth."""
+    raise NotImplementedError()
+
+
+
+def forward(self, x: torch.Tensor, frame_rate: int) ‑> QuantizedResult +
+
+

Given input tensor x, returns first the quantized (or approximately quantized) +representation along with quantized codes, bandwidth, and any penalty term for the loss. +Finally, this returns a dict of metrics to update logging etc. +Frame rate must be passed so that the bandwidth is properly computed.

+
+ +Expand source code + +
def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
+    """
+    Given input tensor x, returns first the quantized (or approximately quantized)
+    representation along with quantized codes, bandwidth, and any penalty term for the loss.
+    Finally, this returns a dict of metrics to update logging etc.
+    Frame rate must be passed so that the bandwidth is properly computed.
+    """
+    raise NotImplementedError()
+
+
+
+def set_num_codebooks(self, n: int) +
+
+

Set the number of active codebooks.

+
+ +Expand source code + +
def set_num_codebooks(self, n: int):
+    """Set the number of active codebooks."""
+    raise NotImplementedError()
+
+
+
+
+
+class DummyQuantizer +
+
+

Fake quantizer that actually does not perform any quantization.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DummyQuantizer(BaseQuantizer):
+    """Fake quantizer that actually does not perform any quantization.
+    """
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor, frame_rate: int):
+        q = x.unsqueeze(1)
+        return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified sample rate at the given bandwidth.
+        In the case of the DummyQuantizer, the codes are actually identical
+        to the input and resulting quantized representation as no quantization is done.
+        """
+        return x.unsqueeze(1)
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation.
+        In the case of the DummyQuantizer, the codes are actually identical
+        to the input and resulting quantized representation as no quantization is done.
+        """
+        return codes.squeeze(1)
+
+    @property
+    def total_codebooks(self):
+        """Total number of codebooks."""
+        return 1
+
+    @property
+    def num_codebooks(self):
+        """Total number of codebooks."""
+        return self.total_codebooks
+
+    def set_num_codebooks(self, n: int):
+        """Set the number of active codebooks."""
+        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var num_codebooks
+
+

Total number of codebooks.

+
+ +Expand source code + +
@property
+def num_codebooks(self):
+    """Total number of codebooks."""
+    return self.total_codebooks
+
+
+
+

Methods

+
+
+def decode(self, codes: torch.Tensor) ‑> torch.Tensor +
+
+

Decode the given codes to the quantized representation. +In the case of the DummyQuantizer, the codes are actually identical +to the input and resulting quantized representation as no quantization is done.

+
+ +Expand source code + +
def decode(self, codes: torch.Tensor) -> torch.Tensor:
+    """Decode the given codes to the quantized representation.
+    In the case of the DummyQuantizer, the codes are actually identical
+    to the input and resulting quantized representation as no quantization is done.
+    """
+    return codes.squeeze(1)
+
+
+
+def encode(self, x: torch.Tensor) ‑> torch.Tensor +
+
+

Encode a given input tensor with the specified sample rate at the given bandwidth. +In the case of the DummyQuantizer, the codes are actually identical +to the input and resulting quantized representation as no quantization is done.

+
+ +Expand source code + +
def encode(self, x: torch.Tensor) -> torch.Tensor:
+    """Encode a given input tensor with the specified sample rate at the given bandwidth.
+    In the case of the DummyQuantizer, the codes are actually identical
+    to the input and resulting quantized representation as no quantization is done.
+    """
+    return x.unsqueeze(1)
+
+
+
+

Inherited members

+ +
+
+class QuantizedResult +(x: torch.Tensor, codes: torch.Tensor, bandwidth: torch.Tensor, penalty: Optional[torch.Tensor] = None, metrics: dict = <factory>) +
+
+

QuantizedResult(x: torch.Tensor, codes: torch.Tensor, bandwidth: torch.Tensor, penalty: Union[torch.Tensor, NoneType] = None, metrics: dict = )

+
+ +Expand source code + +
class QuantizedResult:
+    x: torch.Tensor
+    codes: torch.Tensor
+    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
+    penalty: tp.Optional[torch.Tensor] = None
+    metrics: dict = field(default_factory=dict)
+
+

Class variables

+
+
var bandwidth : torch.Tensor
+
+
+
+
var codes : torch.Tensor
+
+
+
+
var metrics : dict
+
+
+
+
var penalty : Optional[torch.Tensor]
+
+
+
+
var x : torch.Tensor
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/quantization/core_vq.html b/api_docs/audiocraft/quantization/core_vq.html new file mode 100644 index 00000000..198a3ea4 --- /dev/null +++ b/api_docs/audiocraft/quantization/core_vq.html @@ -0,0 +1,1538 @@ + + + + + + +audiocraft.quantization.core_vq API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.quantization.core_vq

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from einops import rearrange, repeat
+import flashy
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+
+def exists(val: tp.Optional[tp.Any]) -> bool:
+    return val is not None
+
+
+def default(val: tp.Any, d: tp.Any) -> tp.Any:
+    return val if exists(val) else d
+
+
+def l2norm(t):
+    return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg, new, decay: float):
+    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+    return (x + epsilon) / (x.sum() + n_categories * epsilon)
+
+
+def uniform_init(*shape: int):
+    t = torch.empty(shape)
+    nn.init.kaiming_uniform_(t)
+    return t
+
+
+def sample_vectors(samples, num: int):
+    num_samples, device = samples.shape[0], samples.device
+
+    if num_samples >= num:
+        indices = torch.randperm(num_samples, device=device)[:num]
+    else:
+        indices = torch.randint(0, num_samples, (num,), device=device)
+
+    return samples[indices]
+
+
+def kmeans(samples, num_clusters: int, num_iters: int = 10):
+    dim, dtype = samples.shape[-1], samples.dtype
+
+    means = sample_vectors(samples, num_clusters)
+
+    for _ in range(num_iters):
+        diffs = rearrange(samples, "n d -> n () d") - rearrange(
+            means, "c d -> () c d"
+        )
+        dists = -(diffs ** 2).sum(dim=-1)
+
+        buckets = dists.max(dim=-1).indices
+        bins = torch.bincount(buckets, minlength=num_clusters)
+        zero_mask = bins == 0
+        bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+        new_means = new_means / bins_min_clamped[..., None]
+
+        means = torch.where(zero_mask[..., None], means, new_means)
+
+    return means, bins
+
+
+def orthogonal_loss_fn(t):
+    # eq (2) from https://arxiv.org/abs/2112.00384
+    n = t.shape[0]
+    normed_codes = l2norm(t)
+    identity = torch.eye(n, device=t.device)
+    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
+    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
+
+
+class EuclideanCodebook(nn.Module):
+    """Codebook with Euclidean distance.
+
+    Args:
+        dim (int): Dimension.
+        codebook_size (int): Codebook size.
+        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
+            If set to true, run the k-means algorithm on the first training batch and use
+            the learned centroids as initialization.
+        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
+        decay (float): Decay for exponential moving average over the codebooks.
+        epsilon (float): Epsilon value for numerical stability.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+    """
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        kmeans_init: int = False,
+        kmeans_iters: int = 10,
+        decay: float = 0.8,
+        epsilon: float = 1e-5,
+        threshold_ema_dead_code: int = 2,
+    ):
+        super().__init__()
+        self.decay = decay
+        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
+        embed = init_fn(codebook_size, dim)
+
+        self.codebook_size = codebook_size
+
+        self.kmeans_iters = kmeans_iters
+        self.epsilon = epsilon
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+
+        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
+        self.register_buffer("cluster_size", torch.zeros(codebook_size))
+        self.register_buffer("embed", embed)
+        self.register_buffer("embed_avg", embed.clone())
+
+    @torch.jit.ignore
+    def init_embed_(self, data):
+        if self.inited:
+            return
+
+        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+        self.embed.data.copy_(embed)
+        self.embed_avg.data.copy_(embed.clone())
+        self.cluster_size.data.copy_(cluster_size)
+        self.inited.data.copy_(torch.Tensor([True]))
+        # Make sure all buffers across workers are in sync after initialization
+        flashy.distrib.broadcast_tensors(self.buffers())
+
+    def replace_(self, samples, mask):
+        modified_codebook = torch.where(
+            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+        )
+        self.embed.data.copy_(modified_codebook)
+
+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if not torch.any(expired_codes):
+            return
+
+        batch_samples = rearrange(batch_samples, "... d -> (...) d")
+        self.replace_(batch_samples, mask=expired_codes)
+        flashy.distrib.broadcast_tensors(self.buffers())
+
+    def preprocess(self, x):
+        x = rearrange(x, "... d -> (...) d")
+        return x
+
+    def quantize(self, x):
+        embed = self.embed.t()
+        dist = -(
+            x.pow(2).sum(1, keepdim=True)
+            - 2 * x @ embed
+            + embed.pow(2).sum(0, keepdim=True)
+        )
+        embed_ind = dist.max(dim=-1).indices
+        return embed_ind
+
+    def postprocess_emb(self, embed_ind, shape):
+        return embed_ind.view(*shape[:-1])
+
+    def dequantize(self, embed_ind):
+        quantize = F.embedding(embed_ind, self.embed)
+        return quantize
+
+    def encode(self, x):
+        shape = x.shape
+        # pre-process
+        x = self.preprocess(x)
+        # quantize
+        embed_ind = self.quantize(x)
+        # post-process
+        embed_ind = self.postprocess_emb(embed_ind, shape)
+        return embed_ind
+
+    def decode(self, embed_ind):
+        quantize = self.dequantize(embed_ind)
+        return quantize
+
+    def forward(self, x):
+        shape, dtype = x.shape, x.dtype
+        x = self.preprocess(x)
+        self.init_embed_(x)
+
+        embed_ind = self.quantize(x)
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind = self.postprocess_emb(embed_ind, shape)
+        quantize = self.dequantize(embed_ind)
+
+        if self.training:
+            # We do the expiry of code at that point as buffers are in sync
+            # and all the workers will take the same decision.
+            self.expire_codes_(x)
+            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+            embed_sum = x.t() @ embed_onehot
+            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+            cluster_size = (
+                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+                * self.cluster_size.sum()
+            )
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+            self.embed.data.copy_(embed_normalized)
+
+        return quantize, embed_ind
+
+
+class VectorQuantization(nn.Module):
+    """Vector quantization implementation.
+    Currently supports only euclidean distance.
+
+    Args:
+        dim (int): Dimension
+        codebook_size (int): Codebook size
+        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+        decay (float): Decay for exponential moving average over the codebooks.
+        epsilon (float): Epsilon value for numerical stability.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int):
+        channels_last (bool): Channels are the last dimension in the input tensors.
+        commitment_weight (float): Weight for commitment loss.
+        orthogonal_reg_weight (float): Orthogonal regularization weights.
+        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
+            for orthogonal regularization.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+    """
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        codebook_dim: tp.Optional[int] = None,
+        decay: float = 0.8,
+        epsilon: float = 1e-5,
+        kmeans_init: bool = False,
+        kmeans_iters: int = 10,
+        threshold_ema_dead_code: int = 2,
+        channels_last: bool = False,
+        commitment_weight: float = 1.,
+        orthogonal_reg_weight: float = 0.0,
+        orthogonal_reg_active_codes_only: bool = False,
+        orthogonal_reg_max_codes: tp.Optional[int] = None,
+    ):
+        super().__init__()
+        _codebook_dim: int = default(codebook_dim, dim)
+
+        requires_projection = _codebook_dim != dim
+        self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+        self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+
+        self.epsilon = epsilon
+        self.commitment_weight = commitment_weight
+
+        self.orthogonal_reg_weight = orthogonal_reg_weight
+        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+
+        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
+                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
+                                           decay=decay, epsilon=epsilon,
+                                           threshold_ema_dead_code=threshold_ema_dead_code)
+        self.codebook_size = codebook_size
+
+        self.channels_last = channels_last
+
+    @property
+    def codebook(self):
+        return self._codebook.embed
+
+    @property
+    def inited(self):
+        return self._codebook.inited
+
+    def _preprocess(self, x):
+        if not self.channels_last:
+            x = rearrange(x, "b d n -> b n d")
+        return x
+
+    def _postprocess(self, quantize):
+        if not self.channels_last:
+            quantize = rearrange(quantize, "b n d -> b d n")
+        return quantize
+
+    def encode(self, x):
+        x = self._preprocess(x)
+        x = self.project_in(x)
+        embed_in = self._codebook.encode(x)
+        return embed_in
+
+    def decode(self, embed_ind):
+        quantize = self._codebook.decode(embed_ind)
+        quantize = self.project_out(quantize)
+        quantize = self._postprocess(quantize)
+        return quantize
+
+    def forward(self, x):
+        device = x.device
+        x = self._preprocess(x)
+
+        x = self.project_in(x)
+        quantize, embed_ind = self._codebook(x)
+
+        if self.training:
+            quantize = x + (quantize - x).detach()
+
+        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+        if self.training:
+            if self.commitment_weight > 0:
+                commit_loss = F.mse_loss(quantize.detach(), x)
+                loss = loss + commit_loss * self.commitment_weight
+
+            if self.orthogonal_reg_weight > 0:
+                codebook = self.codebook
+
+                if self.orthogonal_reg_active_codes_only:
+                    # only calculate orthogonal loss for the activated codes for this batch
+                    unique_code_ids = torch.unique(embed_ind)
+                    codebook = codebook[unique_code_ids]
+
+                num_codes = codebook.shape[0]
+                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
+                    rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
+                    codebook = codebook[rand_ids]
+
+                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
+                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
+
+        quantize = self.project_out(quantize)
+        quantize = self._postprocess(quantize)
+
+        return quantize, embed_ind, loss
+
+
+class ResidualVectorQuantization(nn.Module):
+    """Residual vector quantization implementation.
+
+    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+    """
+    def __init__(self, *, num_quantizers, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+        )
+
+    def forward(self, x, n_q: tp.Optional[int] = None):
+        quantized_out = 0.0
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        n_q = n_q or len(self.layers)
+
+        for i, layer in enumerate(self.layers[:n_q]):
+            quantized, indices, loss = layer(residual)
+            residual = residual - quantized
+            quantized_out = quantized_out + quantized
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, out_indices, out_losses
+
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+        residual = x
+        all_indices = []
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            indices = layer.encode(residual)
+            quantized = layer.decode(indices)
+            residual = residual - quantized
+            all_indices.append(indices)
+        out_indices = torch.stack(all_indices)
+        return out_indices
+
+    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+        quantized_out = torch.tensor(0.0, device=q_indices.device)
+        for i, indices in enumerate(q_indices):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+        return quantized_out
+
+
+
+
+
+
+
+

Functions

+
+
+def default(val: Any, d: Any) ‑> Any +
+
+
+
+ +Expand source code + +
def default(val: tp.Any, d: tp.Any) -> tp.Any:
+    return val if exists(val) else d
+
+
+
+def ema_inplace(moving_avg, new, decay: float) +
+
+
+
+ +Expand source code + +
def ema_inplace(moving_avg, new, decay: float):
+    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+
+def exists(val: Optional[Any]) ‑> bool +
+
+
+
+ +Expand source code + +
def exists(val: tp.Optional[tp.Any]) -> bool:
+    return val is not None
+
+
+
+def kmeans(samples, num_clusters: int, num_iters: int = 10) +
+
+
+
+ +Expand source code + +
def kmeans(samples, num_clusters: int, num_iters: int = 10):
+    dim, dtype = samples.shape[-1], samples.dtype
+
+    means = sample_vectors(samples, num_clusters)
+
+    for _ in range(num_iters):
+        diffs = rearrange(samples, "n d -> n () d") - rearrange(
+            means, "c d -> () c d"
+        )
+        dists = -(diffs ** 2).sum(dim=-1)
+
+        buckets = dists.max(dim=-1).indices
+        bins = torch.bincount(buckets, minlength=num_clusters)
+        zero_mask = bins == 0
+        bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+        new_means = new_means / bins_min_clamped[..., None]
+
+        means = torch.where(zero_mask[..., None], means, new_means)
+
+    return means, bins
+
+
+
+def l2norm(t) +
+
+
+
+ +Expand source code + +
def l2norm(t):
+    return F.normalize(t, p=2, dim=-1)
+
+
+
+def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-05) +
+
+
+
+ +Expand source code + +
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+    return (x + epsilon) / (x.sum() + n_categories * epsilon)
+
+
+
+def orthogonal_loss_fn(t) +
+
+
+
+ +Expand source code + +
def orthogonal_loss_fn(t):
+    # eq (2) from https://arxiv.org/abs/2112.00384
+    n = t.shape[0]
+    normed_codes = l2norm(t)
+    identity = torch.eye(n, device=t.device)
+    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
+    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
+
+
+
+def sample_vectors(samples, num: int) +
+
+
+
+ +Expand source code + +
def sample_vectors(samples, num: int):
+    num_samples, device = samples.shape[0], samples.device
+
+    if num_samples >= num:
+        indices = torch.randperm(num_samples, device=device)[:num]
+    else:
+        indices = torch.randint(0, num_samples, (num,), device=device)
+
+    return samples[indices]
+
+
+
+def uniform_init(*shape: int) +
+
+
+
+ +Expand source code + +
def uniform_init(*shape: int):
+    t = torch.empty(shape)
+    nn.init.kaiming_uniform_(t)
+    return t
+
+
+
+
+
+

Classes

+
+
+class EuclideanCodebook +(dim: int, codebook_size: int, kmeans_init: int = False, kmeans_iters: int = 10, decay: float = 0.8, epsilon: float = 1e-05, threshold_ema_dead_code: int = 2) +
+
+

Codebook with Euclidean distance.

+

Args

+
+
dim : int
+
Dimension.
+
codebook_size : int
+
Codebook size.
+
kmeans_init : bool
+
Whether to use k-means to initialize the codebooks. +If set to true, run the k-means algorithm on the first training batch and use +the learned centroids as initialization.
+
kmeans_iters : int
+
Number of iterations used for k-means algorithm at initialization.
+
decay : float
+
Decay for exponential moving average over the codebooks.
+
epsilon : float
+
Epsilon value for numerical stability.
+
threshold_ema_dead_code : int
+
Threshold for dead code expiration. Replace any codes +that have an exponential moving average cluster size less than the specified threshold with +randomly selected vector from the current batch.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class EuclideanCodebook(nn.Module):
+    """Codebook with Euclidean distance.
+
+    Args:
+        dim (int): Dimension.
+        codebook_size (int): Codebook size.
+        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
+            If set to true, run the k-means algorithm on the first training batch and use
+            the learned centroids as initialization.
+        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
+        decay (float): Decay for exponential moving average over the codebooks.
+        epsilon (float): Epsilon value for numerical stability.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+    """
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        kmeans_init: int = False,
+        kmeans_iters: int = 10,
+        decay: float = 0.8,
+        epsilon: float = 1e-5,
+        threshold_ema_dead_code: int = 2,
+    ):
+        super().__init__()
+        self.decay = decay
+        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
+        embed = init_fn(codebook_size, dim)
+
+        self.codebook_size = codebook_size
+
+        self.kmeans_iters = kmeans_iters
+        self.epsilon = epsilon
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+
+        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
+        self.register_buffer("cluster_size", torch.zeros(codebook_size))
+        self.register_buffer("embed", embed)
+        self.register_buffer("embed_avg", embed.clone())
+
+    @torch.jit.ignore
+    def init_embed_(self, data):
+        if self.inited:
+            return
+
+        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+        self.embed.data.copy_(embed)
+        self.embed_avg.data.copy_(embed.clone())
+        self.cluster_size.data.copy_(cluster_size)
+        self.inited.data.copy_(torch.Tensor([True]))
+        # Make sure all buffers across workers are in sync after initialization
+        flashy.distrib.broadcast_tensors(self.buffers())
+
+    def replace_(self, samples, mask):
+        modified_codebook = torch.where(
+            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+        )
+        self.embed.data.copy_(modified_codebook)
+
+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if not torch.any(expired_codes):
+            return
+
+        batch_samples = rearrange(batch_samples, "... d -> (...) d")
+        self.replace_(batch_samples, mask=expired_codes)
+        flashy.distrib.broadcast_tensors(self.buffers())
+
+    def preprocess(self, x):
+        x = rearrange(x, "... d -> (...) d")
+        return x
+
+    def quantize(self, x):
+        embed = self.embed.t()
+        dist = -(
+            x.pow(2).sum(1, keepdim=True)
+            - 2 * x @ embed
+            + embed.pow(2).sum(0, keepdim=True)
+        )
+        embed_ind = dist.max(dim=-1).indices
+        return embed_ind
+
+    def postprocess_emb(self, embed_ind, shape):
+        return embed_ind.view(*shape[:-1])
+
+    def dequantize(self, embed_ind):
+        quantize = F.embedding(embed_ind, self.embed)
+        return quantize
+
+    def encode(self, x):
+        shape = x.shape
+        # pre-process
+        x = self.preprocess(x)
+        # quantize
+        embed_ind = self.quantize(x)
+        # post-process
+        embed_ind = self.postprocess_emb(embed_ind, shape)
+        return embed_ind
+
+    def decode(self, embed_ind):
+        quantize = self.dequantize(embed_ind)
+        return quantize
+
+    def forward(self, x):
+        shape, dtype = x.shape, x.dtype
+        x = self.preprocess(x)
+        self.init_embed_(x)
+
+        embed_ind = self.quantize(x)
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind = self.postprocess_emb(embed_ind, shape)
+        quantize = self.dequantize(embed_ind)
+
+        if self.training:
+            # We do the expiry of code at that point as buffers are in sync
+            # and all the workers will take the same decision.
+            self.expire_codes_(x)
+            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+            embed_sum = x.t() @ embed_onehot
+            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+            cluster_size = (
+                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+                * self.cluster_size.sum()
+            )
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+            self.embed.data.copy_(embed_normalized)
+
+        return quantize, embed_ind
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def decode(self, embed_ind) +
+
+
+
+ +Expand source code + +
def decode(self, embed_ind):
+    quantize = self.dequantize(embed_ind)
+    return quantize
+
+
+
+def dequantize(self, embed_ind) +
+
+
+
+ +Expand source code + +
def dequantize(self, embed_ind):
+    quantize = F.embedding(embed_ind, self.embed)
+    return quantize
+
+
+
+def encode(self, x) +
+
+
+
+ +Expand source code + +
def encode(self, x):
+    shape = x.shape
+    # pre-process
+    x = self.preprocess(x)
+    # quantize
+    embed_ind = self.quantize(x)
+    # post-process
+    embed_ind = self.postprocess_emb(embed_ind, shape)
+    return embed_ind
+
+
+
+def expire_codes_(self, batch_samples) +
+
+
+
+ +Expand source code + +
def expire_codes_(self, batch_samples):
+    if self.threshold_ema_dead_code == 0:
+        return
+
+    expired_codes = self.cluster_size < self.threshold_ema_dead_code
+    if not torch.any(expired_codes):
+        return
+
+    batch_samples = rearrange(batch_samples, "... d -> (...) d")
+    self.replace_(batch_samples, mask=expired_codes)
+    flashy.distrib.broadcast_tensors(self.buffers())
+
+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    shape, dtype = x.shape, x.dtype
+    x = self.preprocess(x)
+    self.init_embed_(x)
+
+    embed_ind = self.quantize(x)
+    embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+    embed_ind = self.postprocess_emb(embed_ind, shape)
+    quantize = self.dequantize(embed_ind)
+
+    if self.training:
+        # We do the expiry of code at that point as buffers are in sync
+        # and all the workers will take the same decision.
+        self.expire_codes_(x)
+        ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+        embed_sum = x.t() @ embed_onehot
+        ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+        cluster_size = (
+            laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+            * self.cluster_size.sum()
+        )
+        embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+        self.embed.data.copy_(embed_normalized)
+
+    return quantize, embed_ind
+
+
+
+def init_embed_(self, data) +
+
+
+
+ +Expand source code + +
@torch.jit.ignore
+def init_embed_(self, data):
+    if self.inited:
+        return
+
+    embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+    self.embed.data.copy_(embed)
+    self.embed_avg.data.copy_(embed.clone())
+    self.cluster_size.data.copy_(cluster_size)
+    self.inited.data.copy_(torch.Tensor([True]))
+    # Make sure all buffers across workers are in sync after initialization
+    flashy.distrib.broadcast_tensors(self.buffers())
+
+
+
+def postprocess_emb(self, embed_ind, shape) +
+
+
+
+ +Expand source code + +
def postprocess_emb(self, embed_ind, shape):
+    return embed_ind.view(*shape[:-1])
+
+
+
+def preprocess(self, x) +
+
+
+
+ +Expand source code + +
def preprocess(self, x):
+    x = rearrange(x, "... d -> (...) d")
+    return x
+
+
+
+def quantize(self, x) +
+
+
+
+ +Expand source code + +
def quantize(self, x):
+    embed = self.embed.t()
+    dist = -(
+        x.pow(2).sum(1, keepdim=True)
+        - 2 * x @ embed
+        + embed.pow(2).sum(0, keepdim=True)
+    )
+    embed_ind = dist.max(dim=-1).indices
+    return embed_ind
+
+
+
+def replace_(self, samples, mask) +
+
+
+
+ +Expand source code + +
def replace_(self, samples, mask):
+    modified_codebook = torch.where(
+        mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+    )
+    self.embed.data.copy_(modified_codebook)
+
+
+
+
+
+class ResidualVectorQuantization +(*, num_quantizers, **kwargs) +
+
+

Residual vector quantization implementation.

+

Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ResidualVectorQuantization(nn.Module):
+    """Residual vector quantization implementation.
+
+    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+    """
+    def __init__(self, *, num_quantizers, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+        )
+
+    def forward(self, x, n_q: tp.Optional[int] = None):
+        quantized_out = 0.0
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        n_q = n_q or len(self.layers)
+
+        for i, layer in enumerate(self.layers[:n_q]):
+            quantized, indices, loss = layer(residual)
+            residual = residual - quantized
+            quantized_out = quantized_out + quantized
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, out_indices, out_losses
+
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+        residual = x
+        all_indices = []
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            indices = layer.encode(residual)
+            quantized = layer.decode(indices)
+            residual = residual - quantized
+            all_indices.append(indices)
+        out_indices = torch.stack(all_indices)
+        return out_indices
+
+    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+        quantized_out = torch.tensor(0.0, device=q_indices.device)
+        for i, indices in enumerate(q_indices):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+        return quantized_out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def decode(self, q_indices: torch.Tensor) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+    quantized_out = torch.tensor(0.0, device=q_indices.device)
+    for i, indices in enumerate(q_indices):
+        layer = self.layers[i]
+        quantized = layer.decode(indices)
+        quantized_out = quantized_out + quantized
+    return quantized_out
+
+
+
+def encode(self, x: torch.Tensor, n_q: Optional[int] = None) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+    residual = x
+    all_indices = []
+    n_q = n_q or len(self.layers)
+    for layer in self.layers[:n_q]:
+        indices = layer.encode(residual)
+        quantized = layer.decode(indices)
+        residual = residual - quantized
+        all_indices.append(indices)
+    out_indices = torch.stack(all_indices)
+    return out_indices
+
+
+
+def forward(self, x, n_q: Optional[int] = None) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x, n_q: tp.Optional[int] = None):
+    quantized_out = 0.0
+    residual = x
+
+    all_losses = []
+    all_indices = []
+
+    n_q = n_q or len(self.layers)
+
+    for i, layer in enumerate(self.layers[:n_q]):
+        quantized, indices, loss = layer(residual)
+        residual = residual - quantized
+        quantized_out = quantized_out + quantized
+        all_indices.append(indices)
+        all_losses.append(loss)
+
+    out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+    return quantized_out, out_indices, out_losses
+
+
+
+
+
+class VectorQuantization +(dim: int, codebook_size: int, codebook_dim: Optional[int] = None, decay: float = 0.8, epsilon: float = 1e-05, kmeans_init: bool = False, kmeans_iters: int = 10, threshold_ema_dead_code: int = 2, channels_last: bool = False, commitment_weight: float = 1.0, orthogonal_reg_weight: float = 0.0, orthogonal_reg_active_codes_only: bool = False, orthogonal_reg_max_codes: Optional[int] = None) +
+
+

Vector quantization implementation. +Currently supports only euclidean distance.

+

Args

+
+
dim : int
+
Dimension
+
codebook_size : int
+
Codebook size
+
codebook_dim : int
+
Codebook dimension. If not defined, uses the specified dimension in dim.
+
decay : float
+
Decay for exponential moving average over the codebooks.
+
epsilon : float
+
Epsilon value for numerical stability.
+
kmeans_init : bool
+
Whether to use kmeans to initialize the codebooks.
+
kmeans_iters : int
+
Number of iterations used for kmeans initialization.
+
threshold_ema_dead_code (int):
+
channels_last : bool
+
Channels are the last dimension in the input tensors.
+
commitment_weight : float
+
Weight for commitment loss.
+
orthogonal_reg_weight : float
+
Orthogonal regularization weights.
+
orthogonal_reg_active_codes_only : bool
+
Apply orthogonal regularization only on active codes.
+
orthogonal_reg_max_codes : optional int
+
Maximum number of codes to consider +for orthogonal regularization.
+
threshold_ema_dead_code : int
+
Threshold for dead code expiration. Replace any codes +that have an exponential moving average cluster size less than the specified threshold with +randomly selected vector from the current batch.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class VectorQuantization(nn.Module):
+    """Vector quantization implementation.
+    Currently supports only euclidean distance.
+
+    Args:
+        dim (int): Dimension
+        codebook_size (int): Codebook size
+        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+        decay (float): Decay for exponential moving average over the codebooks.
+        epsilon (float): Epsilon value for numerical stability.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int):
+        channels_last (bool): Channels are the last dimension in the input tensors.
+        commitment_weight (float): Weight for commitment loss.
+        orthogonal_reg_weight (float): Orthogonal regularization weights.
+        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
+            for orthogonal regularization.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+    """
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        codebook_dim: tp.Optional[int] = None,
+        decay: float = 0.8,
+        epsilon: float = 1e-5,
+        kmeans_init: bool = False,
+        kmeans_iters: int = 10,
+        threshold_ema_dead_code: int = 2,
+        channels_last: bool = False,
+        commitment_weight: float = 1.,
+        orthogonal_reg_weight: float = 0.0,
+        orthogonal_reg_active_codes_only: bool = False,
+        orthogonal_reg_max_codes: tp.Optional[int] = None,
+    ):
+        super().__init__()
+        _codebook_dim: int = default(codebook_dim, dim)
+
+        requires_projection = _codebook_dim != dim
+        self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+        self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+
+        self.epsilon = epsilon
+        self.commitment_weight = commitment_weight
+
+        self.orthogonal_reg_weight = orthogonal_reg_weight
+        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+
+        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
+                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
+                                           decay=decay, epsilon=epsilon,
+                                           threshold_ema_dead_code=threshold_ema_dead_code)
+        self.codebook_size = codebook_size
+
+        self.channels_last = channels_last
+
+    @property
+    def codebook(self):
+        return self._codebook.embed
+
+    @property
+    def inited(self):
+        return self._codebook.inited
+
+    def _preprocess(self, x):
+        if not self.channels_last:
+            x = rearrange(x, "b d n -> b n d")
+        return x
+
+    def _postprocess(self, quantize):
+        if not self.channels_last:
+            quantize = rearrange(quantize, "b n d -> b d n")
+        return quantize
+
+    def encode(self, x):
+        x = self._preprocess(x)
+        x = self.project_in(x)
+        embed_in = self._codebook.encode(x)
+        return embed_in
+
+    def decode(self, embed_ind):
+        quantize = self._codebook.decode(embed_ind)
+        quantize = self.project_out(quantize)
+        quantize = self._postprocess(quantize)
+        return quantize
+
+    def forward(self, x):
+        device = x.device
+        x = self._preprocess(x)
+
+        x = self.project_in(x)
+        quantize, embed_ind = self._codebook(x)
+
+        if self.training:
+            quantize = x + (quantize - x).detach()
+
+        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+        if self.training:
+            if self.commitment_weight > 0:
+                commit_loss = F.mse_loss(quantize.detach(), x)
+                loss = loss + commit_loss * self.commitment_weight
+
+            if self.orthogonal_reg_weight > 0:
+                codebook = self.codebook
+
+                if self.orthogonal_reg_active_codes_only:
+                    # only calculate orthogonal loss for the activated codes for this batch
+                    unique_code_ids = torch.unique(embed_ind)
+                    codebook = codebook[unique_code_ids]
+
+                num_codes = codebook.shape[0]
+                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
+                    rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
+                    codebook = codebook[rand_ids]
+
+                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
+                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
+
+        quantize = self.project_out(quantize)
+        quantize = self._postprocess(quantize)
+
+        return quantize, embed_ind, loss
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var codebook
+
+
+
+ +Expand source code + +
@property
+def codebook(self):
+    return self._codebook.embed
+
+
+
var inited
+
+
+
+ +Expand source code + +
@property
+def inited(self):
+    return self._codebook.inited
+
+
+
+

Methods

+
+
+def decode(self, embed_ind) +
+
+
+
+ +Expand source code + +
def decode(self, embed_ind):
+    quantize = self._codebook.decode(embed_ind)
+    quantize = self.project_out(quantize)
+    quantize = self._postprocess(quantize)
+    return quantize
+
+
+
+def encode(self, x) +
+
+
+
+ +Expand source code + +
def encode(self, x):
+    x = self._preprocess(x)
+    x = self.project_in(x)
+    embed_in = self._codebook.encode(x)
+    return embed_in
+
+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    device = x.device
+    x = self._preprocess(x)
+
+    x = self.project_in(x)
+    quantize, embed_ind = self._codebook(x)
+
+    if self.training:
+        quantize = x + (quantize - x).detach()
+
+    loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+    if self.training:
+        if self.commitment_weight > 0:
+            commit_loss = F.mse_loss(quantize.detach(), x)
+            loss = loss + commit_loss * self.commitment_weight
+
+        if self.orthogonal_reg_weight > 0:
+            codebook = self.codebook
+
+            if self.orthogonal_reg_active_codes_only:
+                # only calculate orthogonal loss for the activated codes for this batch
+                unique_code_ids = torch.unique(embed_ind)
+                codebook = codebook[unique_code_ids]
+
+            num_codes = codebook.shape[0]
+            if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
+                rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
+                codebook = codebook[rand_ids]
+
+            orthogonal_reg_loss = orthogonal_loss_fn(codebook)
+            loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
+
+    quantize = self.project_out(quantize)
+    quantize = self._postprocess(quantize)
+
+    return quantize, embed_ind, loss
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/quantization/index.html b/api_docs/audiocraft/quantization/index.html new file mode 100644 index 00000000..629669b4 --- /dev/null +++ b/api_docs/audiocraft/quantization/index.html @@ -0,0 +1,90 @@ + + + + + + +audiocraft.quantization API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.quantization

+
+
+

RVQ.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""RVQ."""
+# flake8: noqa
+from .vq import ResidualVectorQuantizer
+from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
+
+
+
+

Sub-modules

+
+
audiocraft.quantization.base
+
+

Base class for all quantizers.

+
+
audiocraft.quantization.core_vq
+
+
+
+
audiocraft.quantization.vq
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/quantization/vq.html b/api_docs/audiocraft/quantization/vq.html new file mode 100644 index 00000000..eb2508e2 --- /dev/null +++ b/api_docs/audiocraft/quantization/vq.html @@ -0,0 +1,388 @@ + + + + + + +audiocraft.quantization.vq API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.quantization.vq

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+
+import torch
+
+from .base import BaseQuantizer, QuantizedResult
+from .core_vq import ResidualVectorQuantization
+
+
+class ResidualVectorQuantizer(BaseQuantizer):
+    """Residual Vector Quantizer.
+
+    Args:
+        dimension (int): Dimension of the codebooks.
+        n_q (int): Number of residual vector quantizers used.
+        q_dropout (bool): Random quantizer drop out at train time.
+        bins (int): Codebook size.
+        decay (float): Decay for exponential moving average over the codebooks.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+        orthogonal_reg_weight (float): Orthogonal regularization weights.
+        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider.
+            for orthogonal regularization.
+    """
+    def __init__(
+        self,
+        dimension: int = 256,
+        n_q: int = 8,
+        q_dropout: bool = False,
+        bins: int = 1024,
+        decay: float = 0.99,
+        kmeans_init: bool = True,
+        kmeans_iters: int = 10,
+        threshold_ema_dead_code: int = 2,
+        orthogonal_reg_weight: float = 0.0,
+        orthogonal_reg_active_codes_only: bool = False,
+        orthogonal_reg_max_codes: tp.Optional[int] = None,
+    ):
+        super().__init__()
+        self.max_n_q = n_q
+        self.n_q = n_q
+        self.q_dropout = q_dropout
+        self.dimension = dimension
+        self.bins = bins
+        self.decay = decay
+        self.kmeans_init = kmeans_init
+        self.kmeans_iters = kmeans_iters
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+        self.orthogonal_reg_weight = orthogonal_reg_weight
+        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+        self.vq = ResidualVectorQuantization(
+            dim=self.dimension,
+            codebook_size=self.bins,
+            num_quantizers=self.n_q,
+            decay=self.decay,
+            kmeans_init=self.kmeans_init,
+            kmeans_iters=self.kmeans_iters,
+            threshold_ema_dead_code=self.threshold_ema_dead_code,
+            orthogonal_reg_weight=self.orthogonal_reg_weight,
+            orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
+            orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
+            channels_last=False
+        )
+
+    def forward(self, x: torch.Tensor, frame_rate: int):
+        n_q = self.n_q
+        if self.training and self.q_dropout:
+            n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
+        bw_per_q = math.log2(self.bins) * frame_rate / 1000
+        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        bw = torch.tensor(n_q * bw_per_q).to(x)
+        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified frame rate at the given bandwidth.
+        The RVQ encode method sets the appropriate number of quantizer to use
+        and returns indices for each quantizer.
+        """
+        n_q = self.n_q
+        codes = self.vq.encode(x, n_q=n_q)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        return codes
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation."""
+        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
+        codes = codes.transpose(0, 1)
+        quantized = self.vq.decode(codes)
+        return quantized
+
+    @property
+    def total_codebooks(self):
+        return self.max_n_q
+
+    @property
+    def num_codebooks(self):
+        return self.n_q
+
+    def set_num_codebooks(self, n: int):
+        assert n > 0 and n <= self.max_n_q
+        self.n_q = n
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class ResidualVectorQuantizer +(dimension: int = 256, n_q: int = 8, q_dropout: bool = False, bins: int = 1024, decay: float = 0.99, kmeans_init: bool = True, kmeans_iters: int = 10, threshold_ema_dead_code: int = 2, orthogonal_reg_weight: float = 0.0, orthogonal_reg_active_codes_only: bool = False, orthogonal_reg_max_codes: Optional[int] = None) +
+
+

Residual Vector Quantizer.

+

Args

+
+
dimension : int
+
Dimension of the codebooks.
+
n_q : int
+
Number of residual vector quantizers used.
+
q_dropout : bool
+
Random quantizer drop out at train time.
+
bins : int
+
Codebook size.
+
decay : float
+
Decay for exponential moving average over the codebooks.
+
kmeans_init : bool
+
Whether to use kmeans to initialize the codebooks.
+
kmeans_iters : int
+
Number of iterations used for kmeans initialization.
+
threshold_ema_dead_code : int
+
Threshold for dead code expiration. Replace any codes +that have an exponential moving average cluster size less than the specified threshold with +randomly selected vector from the current batch.
+
orthogonal_reg_weight : float
+
Orthogonal regularization weights.
+
orthogonal_reg_active_codes_only : bool
+
Apply orthogonal regularization only on active codes.
+
orthogonal_reg_max_codes : optional int
+
Maximum number of codes to consider. +for orthogonal regularization.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ResidualVectorQuantizer(BaseQuantizer):
+    """Residual Vector Quantizer.
+
+    Args:
+        dimension (int): Dimension of the codebooks.
+        n_q (int): Number of residual vector quantizers used.
+        q_dropout (bool): Random quantizer drop out at train time.
+        bins (int): Codebook size.
+        decay (float): Decay for exponential moving average over the codebooks.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+        orthogonal_reg_weight (float): Orthogonal regularization weights.
+        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider.
+            for orthogonal regularization.
+    """
+    def __init__(
+        self,
+        dimension: int = 256,
+        n_q: int = 8,
+        q_dropout: bool = False,
+        bins: int = 1024,
+        decay: float = 0.99,
+        kmeans_init: bool = True,
+        kmeans_iters: int = 10,
+        threshold_ema_dead_code: int = 2,
+        orthogonal_reg_weight: float = 0.0,
+        orthogonal_reg_active_codes_only: bool = False,
+        orthogonal_reg_max_codes: tp.Optional[int] = None,
+    ):
+        super().__init__()
+        self.max_n_q = n_q
+        self.n_q = n_q
+        self.q_dropout = q_dropout
+        self.dimension = dimension
+        self.bins = bins
+        self.decay = decay
+        self.kmeans_init = kmeans_init
+        self.kmeans_iters = kmeans_iters
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+        self.orthogonal_reg_weight = orthogonal_reg_weight
+        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+        self.vq = ResidualVectorQuantization(
+            dim=self.dimension,
+            codebook_size=self.bins,
+            num_quantizers=self.n_q,
+            decay=self.decay,
+            kmeans_init=self.kmeans_init,
+            kmeans_iters=self.kmeans_iters,
+            threshold_ema_dead_code=self.threshold_ema_dead_code,
+            orthogonal_reg_weight=self.orthogonal_reg_weight,
+            orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
+            orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
+            channels_last=False
+        )
+
+    def forward(self, x: torch.Tensor, frame_rate: int):
+        n_q = self.n_q
+        if self.training and self.q_dropout:
+            n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
+        bw_per_q = math.log2(self.bins) * frame_rate / 1000
+        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        bw = torch.tensor(n_q * bw_per_q).to(x)
+        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified frame rate at the given bandwidth.
+        The RVQ encode method sets the appropriate number of quantizer to use
+        and returns indices for each quantizer.
+        """
+        n_q = self.n_q
+        codes = self.vq.encode(x, n_q=n_q)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        return codes
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation."""
+        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
+        codes = codes.transpose(0, 1)
+        quantized = self.vq.decode(codes)
+        return quantized
+
+    @property
+    def total_codebooks(self):
+        return self.max_n_q
+
+    @property
+    def num_codebooks(self):
+        return self.n_q
+
+    def set_num_codebooks(self, n: int):
+        assert n > 0 and n <= self.max_n_q
+        self.n_q = n
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def encode(self, x: torch.Tensor) ‑> torch.Tensor +
+
+

Encode a given input tensor with the specified frame rate at the given bandwidth. +The RVQ encode method sets the appropriate number of quantizer to use +and returns indices for each quantizer.

+
+ +Expand source code + +
def encode(self, x: torch.Tensor) -> torch.Tensor:
+    """Encode a given input tensor with the specified frame rate at the given bandwidth.
+    The RVQ encode method sets the appropriate number of quantizer to use
+    and returns indices for each quantizer.
+    """
+    n_q = self.n_q
+    codes = self.vq.encode(x, n_q=n_q)
+    codes = codes.transpose(0, 1)
+    # codes is [B, K, T], with T frames, K nb of codebooks.
+    return codes
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/solvers/audiogen.html b/api_docs/audiocraft/solvers/audiogen.html new file mode 100644 index 00000000..a81c5eee --- /dev/null +++ b/api_docs/audiocraft/solvers/audiogen.html @@ -0,0 +1,166 @@ + + + + + + +audiocraft.solvers.audiogen API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.solvers.audiogen

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import builders, musicgen
+
+
+class AudioGenSolver(musicgen.MusicGenSolver):
+    """Solver for AudioGen re-implementation training task.
+
+    Note that this implementation does not strictly follows
+    the method proposed in https://arxiv.org/abs/2209.15352
+    but is derived from MusicGen's training pipeline.
+
+    More information can be found in the AudioGen model card.
+    """
+    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class AudioGenSolver +(cfg: omegaconf.dictconfig.DictConfig) +
+
+

Solver for AudioGen re-implementation training task.

+

Note that this implementation does not strictly follows +the method proposed in https://arxiv.org/abs/2209.15352 +but is derived from MusicGen's training pipeline.

+

More information can be found in the AudioGen model card.

+
+ +Expand source code + +
class AudioGenSolver(musicgen.MusicGenSolver):
+    """Solver for AudioGen re-implementation training task.
+
+    Note that this implementation does not strictly follows
+    the method proposed in https://arxiv.org/abs/2209.15352
+    but is derived from MusicGen's training pipeline.
+
+    More information can be found in the AudioGen model card.
+    """
+    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND
+
+

Ancestors

+ +

Class variables

+
+
var DATASET_TYPEDatasetType
+
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/solvers/base.html b/api_docs/audiocraft/solvers/base.html new file mode 100644 index 00000000..fceae734 --- /dev/null +++ b/api_docs/audiocraft/solvers/base.html @@ -0,0 +1,2301 @@ + + + + + + +audiocraft.solvers.base API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.solvers.base

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+import typing as tp
+
+import flashy
+import omegaconf
+import torch
+from torch import nn
+
+from .. import optim
+from ..optim import fsdp
+from ..utils import checkpoint
+from ..utils.autocast import TorchAutocast
+from ..utils.best_state import BestStateDictManager
+from ..utils.deadlock import DeadlockDetect
+from ..utils.profiler import Profiler
+from ..utils.utils import copy_state, dict_from_config, model_hash, with_rank_rng
+
+
+class StandardSolver(ABC, flashy.BaseSolver):
+    """Standard solver for AudioCraft.
+
+    The standard solver implements a base training loop with the following stages:
+    train, valid, evaluate and generate that are expected to be all defined for
+    solvers in AudioCraft. It also provides a nice default management of Dora history replay,
+    checkpoint management across epoch, and logging configuration.
+
+    AudioCraft solvers must inherit from the StandardSolver and define the methods
+    associated to each stage as well as the show, build_model and build_dataloaders methods.
+    """
+    def __init__(self, cfg: omegaconf.DictConfig):
+        super().__init__()
+        self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}")
+        self.logger.info(f"All XP logs are stored in {self.xp.folder}")
+        self.cfg = cfg
+        self.device = cfg.device
+        self.model: nn.Module
+        self._continue_best_source_keys = ['best_state', 'fsdp_best_state']
+        self._fsdp_modules: tp.List[fsdp.FSDP] = []
+        self._ema_sources: nn.ModuleDict = nn.ModuleDict()
+        self.ema: tp.Optional[optim.ModuleDictEMA] = None
+        self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict()
+        self._log_updates = self.cfg.logging.get('log_updates', 10)
+        if self.cfg.logging.log_tensorboard:
+            self.init_tensorboard(**self.cfg.get('tensorboard'))
+        if self.cfg.logging.log_wandb and self:
+            self.init_wandb(**self.cfg.get('wandb'))
+        # keep a copy of the best performing state for stateful objects
+        # used for evaluation and generation stages
+        dtype_best: tp.Optional[torch.dtype] = None
+        if self.cfg.fsdp.use:
+            dtype_best = getattr(torch, self.cfg.fsdp.param_dtype)  # type: ignore
+            assert isinstance(dtype_best, torch.dtype)
+        elif self.cfg.autocast:
+            dtype_best = getattr(torch, self.cfg.autocast_dtype)  # type: ignore
+            assert isinstance(dtype_best, torch.dtype)
+        self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best)
+        # Hacky support for keeping a copy of the full best state in rank0.
+        self.fsdp_best_state: tp.Dict[str, tp.Any] = {}
+        self.register_stateful('best_state', 'fsdp_best_state')  # register best_state object to keep it in state_dict
+        self._new_best_state: bool = False  # should save a new checkpoint
+        # instantiate datasets and appropriate number of updates per epoch
+        self.build_dataloaders()
+        if self.cfg.execute_only is None:
+            assert 'train' in self.dataloaders, "The train dataset split must be provided."
+            assert 'valid' in self.dataloaders, "The valid dataset split must be provided."
+        self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0
+        if self.cfg.optim.updates_per_epoch:
+            self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch
+        self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs
+        # instantiate model & exponential moving average on the model
+        self.build_model()
+        self.logger.info("Model hash: %s", model_hash(self.model))
+        assert 'model' in self.stateful.sources, \
+            "Please register the model to stateful with self.register_stateful('model') in build_model."
+        self.profiler = Profiler(self.model, **self.cfg.profiler)
+        self.initialize_ema()
+        self.register_stateful('ema')
+        assert self.ema is None or 'ema' in self.stateful.sources, \
+            "Please register the ema to stateful with self.register_stateful('ema') in build_model."
+        self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock)
+        # basic statistics on the trained model
+        model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6
+        # one copy of grad, one copy of momentum, one copy of denominator and model weights.
+        # and 4 bytes for each float!
+        mem_usage = model_size * 4 * 4 / 1000
+        self.logger.info("Model size: %.2f M params", model_size)
+        self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage)
+
+    @property
+    def autocast(self):
+        """Convenient autocast (or not) using the solver configuration."""
+        return TorchAutocast(enabled=self.cfg.autocast, device_type=self.device, dtype=self.autocast_dtype)
+
+    def _get_state_source(self, name) -> flashy.state.StateDictSource:
+        # Internal utility to get a state source from the solver
+        return self.stateful.sources[name]
+
+    @property
+    def best_metric_name(self) -> tp.Optional[str]:
+        """Metric name used to identify the best state. This metric should be stored in the metrics
+        used on the stage for best state identification (most likely, `valid`). If None, then
+        no best state is saved.
+        """
+        return None
+
+    def register_best_state(self, *args: str):
+        """Register state sources in `BestStateDictManager` to keep their best states along with their
+        latest states. The best state will be used at evaluation stages instead of the latest states.
+
+        Shortcut around `BestStateDictManager.register` method. You can pass any number of
+        attribute, included nested attributes and those will be included into the checkpoints
+        and automatically restored when `BaseSolver.restore` is called.
+        """
+        for name in args:
+            state_source = self._get_state_source(name)
+            assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!"
+            self.best_state.register(name, state_source)
+
+    def register_ema(self, *args: str):
+        """Register state sources for exponential moving average.
+
+        The registered sources are used to instantiate a ModuleDictEMA instance.
+        The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called
+        and swapped with the original state sources with self.swap_ema_state() method.
+
+        Usage:
+            self.register_ema('model')
+        """
+        assert self.ema is None, "Cannot register state source to already instantiated EMA."
+        for name in args:
+            self._ema_sources[name] = getattr(self, name)
+
+    def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
+        model = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs)
+        if isinstance(model, fsdp.FSDP):
+            self._fsdp_modules.append(model)
+        return model
+
+    def update_best_state_from_stage(self, stage_name: str = 'valid'):
+        """Update latest best state based on pending metrics of a given stage. This method relies
+        on the `BestStateDictManager.update` method to update the best state_dict with latest weights
+        if the registered states happen to match to the best performing setup.
+        """
+        if self.best_metric_name is None:
+            # when no best metric is defined, the last state is always the best
+            self._new_best_state = True
+            self.logger.info("Updating best state with current state.")
+        else:
+            assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found."
+            assert self.best_metric_name in self._pending_metrics[stage_name], \
+                f"Best metric not found in {stage_name} metrics. Cannot register best state"
+            current_score = self._pending_metrics[stage_name][self.best_metric_name]
+            all_best_metric_scores = [
+                past_metrics[stage_name][self.best_metric_name]
+                for past_metrics in self.history
+            ]
+            all_best_metric_scores.append(current_score)
+            best_score = min(all_best_metric_scores)
+            self._new_best_state = current_score == best_score
+            if self._new_best_state:
+                old_best = min(all_best_metric_scores[:-1] + [float('inf')])
+                self.logger.info(
+                    f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})")
+
+        if self._new_best_state:
+            if self.cfg.fsdp.use:
+                # this will give an empty state dict on all ranks but the rank 0
+                # which will have a copy in memory of the full model.
+                with fsdp.switch_to_full_state_dict(self._fsdp_modules):
+                    for name in self.best_state.states.keys():
+                        state_source = self._get_state_source(name)
+                        self.best_state.update(name, state_source)
+                    # we save to a different dict.
+                    self.fsdp_best_state.update(self.best_state.state_dict())
+                # We cannot efficiently load fsdp_best_state when using FSDP,
+                # so we have do do a second pass, with the local shards.
+            for name in self.best_state.states.keys():
+                state_source = self._get_state_source(name)
+                self.best_state.update(name, state_source)
+
+    def _load_new_state_dict(self, state_dict: dict) -> dict:
+        old_states = {}
+        for name, new_state in state_dict.items():
+            state_source = self._get_state_source(name)
+            old_states[name] = copy_state(state_source.state_dict())
+            state_source.load_state_dict(new_state)
+        return old_states
+
+    @contextmanager
+    def swap_best_state(self):
+        self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}")
+        old_states = self._load_new_state_dict(self.best_state.state_dict())
+        try:
+            yield
+        finally:
+            self.logger.debug("Swapping back from best to original state")
+            for name, old_state in old_states.items():
+                state_source = self._get_state_source(name)
+                state_source.load_state_dict(old_state)
+
+    @contextmanager
+    def swap_ema_state(self):
+        if self.ema is None:
+            yield
+        else:
+            ema_state_dict = self.ema.state_dict()['state']
+            self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}")
+            old_states = self._load_new_state_dict(ema_state_dict)
+            try:
+                yield
+            finally:
+                self.logger.debug("Swapping back from EMA state to original state")
+                for name, old_state in old_states.items():
+                    state_source = self._get_state_source(name)
+                    state_source.load_state_dict(old_state)
+
+    @property
+    def is_training(self):
+        return self.current_stage == 'train'
+
+    def log_model_summary(self, model: nn.Module):
+        """Log model summary, architecture and size of the model."""
+        self.logger.info(model)
+        mb = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20
+        self.logger.info("Size: %.1f MB", mb)
+
+    @abstractmethod
+    def build_model(self):
+        """Method to implement to initialize model."""
+        ...
+
+    def initialize_ema(self):
+        """Initialize exponential moving average with the registered sources.
+        EMA object is created if the optim.ema.model.decay value is non-null.
+        """
+        from .builders import get_ema
+        self.ema = get_ema(self._ema_sources, self.cfg.optim.ema)
+        if self.ema is None:
+            self.logger.info('No EMA on the model.')
+        else:
+            assert self.cfg.optim.ema.updates > 0
+            self.logger.info(
+                f'Initializing EMA on the model with decay = {self.ema.decay}'
+                f' every {self.cfg.optim.ema.updates} updates'
+            )
+
+    @abstractmethod
+    def build_dataloaders(self):
+        """Method to implement to initialize dataloaders."""
+        ...
+
+    @abstractmethod
+    def show(self):
+        """Method to log any information without running the job."""
+        ...
+
+    @property
+    def log_updates(self):
+        # convenient access to log updates
+        return self._log_updates
+
+    def checkpoint_path(self, **kwargs):
+        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+        return self.folder / checkpoint.checkpoint_name(**kwargs)
+
+    def epoch_checkpoint_path(self, epoch: int, **kwargs):
+        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+        return self.folder / checkpoint.checkpoint_name(str(epoch), **kwargs)
+
+    def checkpoint_path_with_name(self, name: str, **kwargs):
+        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+        return self.folder / checkpoint.checkpoint_name(name=name, **kwargs)
+
+    def save_checkpoints(self):
+        """Save checkpoint, optionally keeping a copy for a given epoch."""
+        is_sharded = self.cfg.fsdp.use
+        if not flashy.distrib.is_rank_zero() and not is_sharded:
+            return
+        self.logger.info("Model hash: %s", model_hash(self.model))
+        state = self.state_dict()
+        epoch = self.epoch - 1  # pushing metrics will increase the epoch in Flashy, so we do -1 here
+
+        # save minimal state_dict as new checkpoint every X epoch
+        if self.cfg.checkpoint.save_every:
+            if epoch % self.cfg.checkpoint.save_every == 0:
+                minimal_state = state
+                if self.cfg.checkpoint.keep_every_states is not None and len(self.cfg.checkpoint.keep_every_states) > 0:
+                    minimal_state = {
+                        name: source for name, source in state.items()
+                        if name in self.cfg.checkpoint.keep_every_states
+                    }
+                epoch_checkpoint_path = self.epoch_checkpoint_path(epoch)
+                checkpoint.save_checkpoint(minimal_state, epoch_checkpoint_path, is_sharded)
+
+        # save checkpoint as latest checkpoint
+        if self.cfg.checkpoint.save_last:
+            last_checkpoint_path = self.checkpoint_path()
+            checkpoint.save_checkpoint(state, last_checkpoint_path, is_sharded)
+
+        # flush any stale checkpoint to reduce disk footprint
+        checkpoint.flush_stale_checkpoints(self.checkpoint_path())
+
+    def load_from_pretrained(self, name: str) -> dict:
+        raise NotImplementedError("Solver does not provide a way to load pretrained models.")
+
+    def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]:
+        """Load last checkpoint or the one specified in continue_from.
+
+        Args:
+            load_best (bool): Whether to load from best state dict or not.
+                Best state dict is always used when not loading the current xp.
+            ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`.
+        Returns:
+            state (dict, optional): The loaded state dictionary.
+        """
+        # load checkpoints from xp folder or cfg.continue_from
+        is_sharded = self.cfg.fsdp.use
+        load_from_path: tp.Optional[Path] = None
+        checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None
+
+        if load_best:
+            self.logger.info("Trying to load state_dict from best state.")
+
+        state: tp.Optional[dict] = None
+        rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False)
+        current_checkpoint_path = self.checkpoint_path()
+        _pretrained_prefix = '//pretrained/'
+        continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix)
+        if rank0_checkpoint_path.exists():
+            self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}")
+            load_from_path = current_checkpoint_path
+            checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path)
+            checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP
+        elif self.cfg.continue_from and not continue_pretrained:
+            self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}")
+            # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best
+            load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False)
+            if load_from_path is None:
+                self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from)
+                raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}')
+            checkpoint_source = checkpoint.CheckpointSource.OTHER
+
+        if load_from_path is not None:
+            state = checkpoint.load_checkpoint(load_from_path, is_sharded)
+        elif continue_pretrained:
+            self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.")
+            state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):])
+            checkpoint_source = checkpoint.CheckpointSource.PRETRAINED
+            load_best = True
+
+        # checkpoints are not from the current xp, we only retrieve the best state
+        if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
+            assert state is not None
+            self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
+            load_best = True
+            state = {key: state[key] for key in self._continue_best_source_keys if key in state}
+            # loaded checkpoints are FSDP checkpoints: we're reading the best state
+            # from FSDP and we drop the regular best_state
+            if 'fsdp_best_state' in state and state['fsdp_best_state']:
+                state.pop('best_state', None)
+                self.logger.info("... Loaded checkpoint has FSDP best state")
+            # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support
+            # then we're initializing FSDP best state with the regular best state
+            elif self.cfg.fsdp.use:
+                if 'fsdp_best_state' not in state or not state['fsdp_best_state']:
+                    # we swap non-FSDP checkpoints best_state to FSDP-compatible best state
+                    state['fsdp_best_state'] = state.pop('best_state')
+                    self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state")
+
+        if state is not None:
+            if load_best:
+                self.logger.info("Ignoring keys when loading best %r", ignore_state_keys)
+                for key in set(ignore_state_keys):
+                    if key in state:
+                        state.pop(key)
+                has_best_state = 'best_state' in state or 'fsdp_best_state' in state
+                assert has_best_state, ("Trying to load best state but neither 'best_state'",
+                                        " or 'fsdp_best_state' found in checkpoints.")
+            self.load_state_dict(state)
+
+        # for FSDP, let's make extra sure nothing bad happened with out of sync
+        # checkpoints across workers.
+        epoch = float(self.epoch)
+        avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch']
+        if avg_epoch != epoch:
+            raise RuntimeError(
+                f"Inconsistent loading of checkpoints happened, our epoch is {epoch} "
+                f"but average of epochs is {avg_epoch}, at least one gpu must have a "
+                "different epoch number.")
+
+        # on load_best, properly reinitialize state_dict, best states and ema
+        # otherwise we load from the current xp and don't alter anything
+        if load_best:
+            self.logger.info("Loading state_dict from best state.")
+            if not self.cfg.fsdp.use and self.fsdp_best_state:
+                # loading from an FSDP checkpoint but with FSDP deactivated
+                self.logger.info("... Loading from FSDP best state dict.")
+                self.best_state.load_state_dict(self.fsdp_best_state)
+
+            # if load_best, we permanently override the regular state_dict with the best state
+            if self.cfg.fsdp.use:
+                self.logger.info("FSDP is used, loading from FSDP best state.")
+                with fsdp.switch_to_full_state_dict(self._fsdp_modules):
+                    # this might be really fragile but okay for now.
+                    self.load_state_dict(self.fsdp_best_state)
+            else:
+                # we permanently swap the stateful objects to their best state
+                self._load_new_state_dict(self.best_state.state_dict())
+
+            # the EMA modules should also be instantiated with best state.
+            # the easiest way to do so is to reinitialize a new EMA with best state loaded.
+            if self.ema is not None:
+                self.logger.info("Re-initializing EMA from best state")
+                self.initialize_ema()
+
+            if self.cfg.fsdp.use:
+                self.logger.info("Re-initializing best state after using FSDP best state.")
+                for name in self.best_state.states.keys():
+                    state_source = self._get_state_source(name)
+                    self.best_state.update(name, state_source)
+
+        return state
+
+    def restore(self, load_best: bool = False, replay_metrics: bool = False,
+                ignore_state_keys: tp.List[str] = []) -> bool:
+        """Restore the status of a solver for a given xp.
+
+        Args:
+            load_best (bool): if `True`, load the best state from the checkpoint.
+            replay_metrics (bool): if `True`, logs all the metrics from past epochs.
+            ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`.
+        """
+        self.logger.info("Restoring weights and history.")
+        restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys)
+
+        self.logger.info("Model hash: %s", model_hash(self.model))
+
+        if replay_metrics and len(self.history) > 0:
+            self.logger.info("Replaying past metrics...")
+            for epoch, stages in enumerate(self.history):
+                for stage_name, metrics in stages.items():
+                    # We manually log the metrics summary to the result logger
+                    # as we don't want to add them to the pending metrics
+                    self.result_logger._log_summary(stage_name, metrics, step=epoch + 1, step_name='epoch',
+                                                    formatter=self.get_formatter(stage_name))
+        return restored_checkpoints is not None
+
+    def commit(self, save_checkpoints: bool = True):
+        """Commit metrics to dora and save checkpoints at the end of an epoch."""
+        # we override commit to introduce more complex checkpoint saving behaviors
+        self.history.append(self._pending_metrics)  # This will increase self.epoch
+        if save_checkpoints:
+            self.save_checkpoints()
+        self._start_epoch()
+        if flashy.distrib.is_rank_zero():
+            self.xp.link.update_history(self.history)
+
+    def run_epoch(self):
+        """Run a single epoch with all stages.
+
+        Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards.
+        Children solvers can extend this method with custom behavior, e.g.:
+
+            def run_epoch(self):
+                ... # custom code
+                super().run_epoch()
+                ... # custom code
+        """
+        self.run_stage('train', self.train)
+        with torch.no_grad():
+            with self.swap_ema_state():
+                self.run_stage('valid', self.valid)
+                # the best state is updated with EMA states if available
+                self.update_best_state_from_stage('valid')
+            with self.swap_best_state():
+                if self.should_run_stage('evaluate'):
+                    self.run_stage('evaluate', self.evaluate)
+                if self.should_run_stage('generate'):
+                    self.run_stage('generate', with_rank_rng()(self.generate))
+
+    def run(self):
+        """Training loop."""
+        assert len(self.state_dict()) > 0
+        self.restore(replay_metrics=True)  # load checkpoint and replay history
+        self.log_hyperparams(dict_from_config(self.cfg))
+        for epoch in range(self.epoch, self.cfg.optim.epochs + 1):
+            if self.should_stop_training():
+                return
+            self.run_epoch()
+            # Commit will send the metrics to Dora and save checkpoints by default.
+            self.commit()
+
+    def should_stop_training(self) -> bool:
+        """Check whether we should stop training or not."""
+        return self.epoch > self.cfg.optim.epochs
+
+    def should_run_stage(self, stage_name) -> bool:
+        """Check whether we want to run the specified stages."""
+        stage_every = self.cfg[stage_name].get('every', None)
+        is_last_epoch = self.epoch == self.cfg.optim.epochs
+        is_epoch_every = (stage_every and self.epoch % stage_every == 0)
+        return is_last_epoch or is_epoch_every
+
+    @abstractmethod
+    def run_step(self, idx: int, batch: tp.Any, metrics: dict):
+        """Perform one training or valid step on a given batch."""
+        ...
+
+    def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
+        """Common logic for train and valid stages."""
+        self.model.train(self.is_training)
+
+        loader = self.dataloaders[dataset_split]
+        # get a different order for distributed training, otherwise this will get ignored
+        if flashy.distrib.world_size() > 1 \
+           and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler):
+            loader.sampler.set_epoch(self.epoch)
+        updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader)
+        if self.cfg.benchmark_no_load:
+            self.logger.warning("Fake loading for benchmarking: re-using first batch")
+            batch = next(iter(loader))
+            loader = [batch] * updates_per_epoch  # type: ignore
+        lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates)
+        average = flashy.averager()  # epoch wise average
+        instant_average = flashy.averager()  # average between two logging
+        metrics: dict = {}
+
+        with self.profiler, self.deadlock_detect:  # profiler will only run for the first 20 updates.
+            for idx, batch in enumerate(lp):
+                self.deadlock_detect.update('batch')
+                if idx >= updates_per_epoch:
+                    break
+                metrics = {}
+                metrics = self.run_step(idx, batch, metrics)
+                self.deadlock_detect.update('step')
+                # run EMA step
+                if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0:
+                    self.logger.debug("EMA model step")
+                    self.ema.step()
+                self.deadlock_detect.update('ema')
+                self.profiler.step()
+                instant_metrics = instant_average(metrics)
+                if lp.update(**instant_metrics):
+                    instant_average = flashy.averager()  # reset averager between two logging
+                metrics = average(metrics)  # epoch wise average
+                self.deadlock_detect.update('end_batch')
+
+        metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch)
+        return metrics
+
+    def train(self):
+        """Train stage."""
+        return self.common_train_valid('train')
+
+    def valid(self):
+        """Valid stage."""
+        return self.common_train_valid('valid')
+
+    @abstractmethod
+    def evaluate(self):
+        """Evaluate stage."""
+        ...
+
+    @abstractmethod
+    def generate(self):
+        """Generate stage."""
+        ...
+
+    def run_one_stage(self, stage_name: str):
+        """Run only the specified stage.
+        This method is useful to only generate samples from a trained experiment
+        or rerun the validation or evaluation stages.
+        """
+        fn = {
+            'generate': with_rank_rng()(self.generate),
+            'evaluate': self.evaluate,
+            'valid': self.valid,
+        }
+        if stage_name not in fn:
+            raise ValueError(f'Trying to run stage {stage_name} is not supported.')
+        assert len(self.state_dict()) > 0
+        self._start_epoch()
+        with torch.no_grad(), self.swap_best_state():
+            self.run_stage(stage_name, fn[stage_name])
+        if not self.cfg.execute_inplace:
+            self.commit(save_checkpoints=False)
+
+    @staticmethod
+    def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
+                                 device: tp.Optional[str] = None, autocast: bool = True,
+                                 batch_size: tp.Optional[int] = None,
+                                 override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+                                 **kwargs):
+        """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
+        populating all the proper param, deactivating EMA, FSDP, loading the best state,
+        basically all you need to get a solver ready to "play" with in single GPU mode
+        and with minimal memory overhead.
+
+        Args:
+            sig (str): signature to load.
+            dtype (str or None): potential dtype, as a string, i.e. 'float16'.
+            device (str or None): potential device, as a string, i.e. 'cuda'.
+            override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'.
+        """
+        from audiocraft import train
+        our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
+        our_override_cfg['autocast'] = autocast
+        if dtype is not None:
+            our_override_cfg['dtype'] = dtype
+        if device is not None:
+            our_override_cfg['device'] = device
+        if batch_size is not None:
+            our_override_cfg['dataset'] = {'batch_size': batch_size}
+        if override_cfg is None:
+            override_cfg = {}
+        override_cfg = omegaconf.OmegaConf.merge(
+            omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
+        solver = train.get_solver_from_sig(
+            sig, override_cfg=override_cfg,
+            load_best=True, disable_fsdp=True,
+            ignore_state_keys=['optimizer', 'ema'], **kwargs)
+        solver.model.eval()
+        return solver
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class StandardSolver +(cfg: omegaconf.dictconfig.DictConfig) +
+
+

Standard solver for AudioCraft.

+

The standard solver implements a base training loop with the following stages: +train, valid, evaluate and generate that are expected to be all defined for +solvers in AudioCraft. It also provides a nice default management of Dora history replay, +checkpoint management across epoch, and logging configuration.

+

AudioCraft solvers must inherit from the StandardSolver and define the methods +associated to each stage as well as the show, build_model and build_dataloaders methods.

+
+ +Expand source code + +
class StandardSolver(ABC, flashy.BaseSolver):
+    """Standard solver for AudioCraft.
+
+    The standard solver implements a base training loop with the following stages:
+    train, valid, evaluate and generate that are expected to be all defined for
+    solvers in AudioCraft. It also provides a nice default management of Dora history replay,
+    checkpoint management across epoch, and logging configuration.
+
+    AudioCraft solvers must inherit from the StandardSolver and define the methods
+    associated to each stage as well as the show, build_model and build_dataloaders methods.
+    """
+    def __init__(self, cfg: omegaconf.DictConfig):
+        super().__init__()
+        self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}")
+        self.logger.info(f"All XP logs are stored in {self.xp.folder}")
+        self.cfg = cfg
+        self.device = cfg.device
+        self.model: nn.Module
+        self._continue_best_source_keys = ['best_state', 'fsdp_best_state']
+        self._fsdp_modules: tp.List[fsdp.FSDP] = []
+        self._ema_sources: nn.ModuleDict = nn.ModuleDict()
+        self.ema: tp.Optional[optim.ModuleDictEMA] = None
+        self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict()
+        self._log_updates = self.cfg.logging.get('log_updates', 10)
+        if self.cfg.logging.log_tensorboard:
+            self.init_tensorboard(**self.cfg.get('tensorboard'))
+        if self.cfg.logging.log_wandb and self:
+            self.init_wandb(**self.cfg.get('wandb'))
+        # keep a copy of the best performing state for stateful objects
+        # used for evaluation and generation stages
+        dtype_best: tp.Optional[torch.dtype] = None
+        if self.cfg.fsdp.use:
+            dtype_best = getattr(torch, self.cfg.fsdp.param_dtype)  # type: ignore
+            assert isinstance(dtype_best, torch.dtype)
+        elif self.cfg.autocast:
+            dtype_best = getattr(torch, self.cfg.autocast_dtype)  # type: ignore
+            assert isinstance(dtype_best, torch.dtype)
+        self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best)
+        # Hacky support for keeping a copy of the full best state in rank0.
+        self.fsdp_best_state: tp.Dict[str, tp.Any] = {}
+        self.register_stateful('best_state', 'fsdp_best_state')  # register best_state object to keep it in state_dict
+        self._new_best_state: bool = False  # should save a new checkpoint
+        # instantiate datasets and appropriate number of updates per epoch
+        self.build_dataloaders()
+        if self.cfg.execute_only is None:
+            assert 'train' in self.dataloaders, "The train dataset split must be provided."
+            assert 'valid' in self.dataloaders, "The valid dataset split must be provided."
+        self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0
+        if self.cfg.optim.updates_per_epoch:
+            self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch
+        self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs
+        # instantiate model & exponential moving average on the model
+        self.build_model()
+        self.logger.info("Model hash: %s", model_hash(self.model))
+        assert 'model' in self.stateful.sources, \
+            "Please register the model to stateful with self.register_stateful('model') in build_model."
+        self.profiler = Profiler(self.model, **self.cfg.profiler)
+        self.initialize_ema()
+        self.register_stateful('ema')
+        assert self.ema is None or 'ema' in self.stateful.sources, \
+            "Please register the ema to stateful with self.register_stateful('ema') in build_model."
+        self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock)
+        # basic statistics on the trained model
+        model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6
+        # one copy of grad, one copy of momentum, one copy of denominator and model weights.
+        # and 4 bytes for each float!
+        mem_usage = model_size * 4 * 4 / 1000
+        self.logger.info("Model size: %.2f M params", model_size)
+        self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage)
+
+    @property
+    def autocast(self):
+        """Convenient autocast (or not) using the solver configuration."""
+        return TorchAutocast(enabled=self.cfg.autocast, device_type=self.device, dtype=self.autocast_dtype)
+
+    def _get_state_source(self, name) -> flashy.state.StateDictSource:
+        # Internal utility to get a state source from the solver
+        return self.stateful.sources[name]
+
+    @property
+    def best_metric_name(self) -> tp.Optional[str]:
+        """Metric name used to identify the best state. This metric should be stored in the metrics
+        used on the stage for best state identification (most likely, `valid`). If None, then
+        no best state is saved.
+        """
+        return None
+
+    def register_best_state(self, *args: str):
+        """Register state sources in `BestStateDictManager` to keep their best states along with their
+        latest states. The best state will be used at evaluation stages instead of the latest states.
+
+        Shortcut around `BestStateDictManager.register` method. You can pass any number of
+        attribute, included nested attributes and those will be included into the checkpoints
+        and automatically restored when `BaseSolver.restore` is called.
+        """
+        for name in args:
+            state_source = self._get_state_source(name)
+            assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!"
+            self.best_state.register(name, state_source)
+
+    def register_ema(self, *args: str):
+        """Register state sources for exponential moving average.
+
+        The registered sources are used to instantiate a ModuleDictEMA instance.
+        The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called
+        and swapped with the original state sources with self.swap_ema_state() method.
+
+        Usage:
+            self.register_ema('model')
+        """
+        assert self.ema is None, "Cannot register state source to already instantiated EMA."
+        for name in args:
+            self._ema_sources[name] = getattr(self, name)
+
+    def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
+        model = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs)
+        if isinstance(model, fsdp.FSDP):
+            self._fsdp_modules.append(model)
+        return model
+
+    def update_best_state_from_stage(self, stage_name: str = 'valid'):
+        """Update latest best state based on pending metrics of a given stage. This method relies
+        on the `BestStateDictManager.update` method to update the best state_dict with latest weights
+        if the registered states happen to match to the best performing setup.
+        """
+        if self.best_metric_name is None:
+            # when no best metric is defined, the last state is always the best
+            self._new_best_state = True
+            self.logger.info("Updating best state with current state.")
+        else:
+            assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found."
+            assert self.best_metric_name in self._pending_metrics[stage_name], \
+                f"Best metric not found in {stage_name} metrics. Cannot register best state"
+            current_score = self._pending_metrics[stage_name][self.best_metric_name]
+            all_best_metric_scores = [
+                past_metrics[stage_name][self.best_metric_name]
+                for past_metrics in self.history
+            ]
+            all_best_metric_scores.append(current_score)
+            best_score = min(all_best_metric_scores)
+            self._new_best_state = current_score == best_score
+            if self._new_best_state:
+                old_best = min(all_best_metric_scores[:-1] + [float('inf')])
+                self.logger.info(
+                    f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})")
+
+        if self._new_best_state:
+            if self.cfg.fsdp.use:
+                # this will give an empty state dict on all ranks but the rank 0
+                # which will have a copy in memory of the full model.
+                with fsdp.switch_to_full_state_dict(self._fsdp_modules):
+                    for name in self.best_state.states.keys():
+                        state_source = self._get_state_source(name)
+                        self.best_state.update(name, state_source)
+                    # we save to a different dict.
+                    self.fsdp_best_state.update(self.best_state.state_dict())
+                # We cannot efficiently load fsdp_best_state when using FSDP,
+                # so we have do do a second pass, with the local shards.
+            for name in self.best_state.states.keys():
+                state_source = self._get_state_source(name)
+                self.best_state.update(name, state_source)
+
+    def _load_new_state_dict(self, state_dict: dict) -> dict:
+        old_states = {}
+        for name, new_state in state_dict.items():
+            state_source = self._get_state_source(name)
+            old_states[name] = copy_state(state_source.state_dict())
+            state_source.load_state_dict(new_state)
+        return old_states
+
+    @contextmanager
+    def swap_best_state(self):
+        self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}")
+        old_states = self._load_new_state_dict(self.best_state.state_dict())
+        try:
+            yield
+        finally:
+            self.logger.debug("Swapping back from best to original state")
+            for name, old_state in old_states.items():
+                state_source = self._get_state_source(name)
+                state_source.load_state_dict(old_state)
+
+    @contextmanager
+    def swap_ema_state(self):
+        if self.ema is None:
+            yield
+        else:
+            ema_state_dict = self.ema.state_dict()['state']
+            self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}")
+            old_states = self._load_new_state_dict(ema_state_dict)
+            try:
+                yield
+            finally:
+                self.logger.debug("Swapping back from EMA state to original state")
+                for name, old_state in old_states.items():
+                    state_source = self._get_state_source(name)
+                    state_source.load_state_dict(old_state)
+
+    @property
+    def is_training(self):
+        return self.current_stage == 'train'
+
+    def log_model_summary(self, model: nn.Module):
+        """Log model summary, architecture and size of the model."""
+        self.logger.info(model)
+        mb = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20
+        self.logger.info("Size: %.1f MB", mb)
+
+    @abstractmethod
+    def build_model(self):
+        """Method to implement to initialize model."""
+        ...
+
+    def initialize_ema(self):
+        """Initialize exponential moving average with the registered sources.
+        EMA object is created if the optim.ema.model.decay value is non-null.
+        """
+        from .builders import get_ema
+        self.ema = get_ema(self._ema_sources, self.cfg.optim.ema)
+        if self.ema is None:
+            self.logger.info('No EMA on the model.')
+        else:
+            assert self.cfg.optim.ema.updates > 0
+            self.logger.info(
+                f'Initializing EMA on the model with decay = {self.ema.decay}'
+                f' every {self.cfg.optim.ema.updates} updates'
+            )
+
+    @abstractmethod
+    def build_dataloaders(self):
+        """Method to implement to initialize dataloaders."""
+        ...
+
+    @abstractmethod
+    def show(self):
+        """Method to log any information without running the job."""
+        ...
+
+    @property
+    def log_updates(self):
+        # convenient access to log updates
+        return self._log_updates
+
+    def checkpoint_path(self, **kwargs):
+        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+        return self.folder / checkpoint.checkpoint_name(**kwargs)
+
+    def epoch_checkpoint_path(self, epoch: int, **kwargs):
+        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+        return self.folder / checkpoint.checkpoint_name(str(epoch), **kwargs)
+
+    def checkpoint_path_with_name(self, name: str, **kwargs):
+        kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+        return self.folder / checkpoint.checkpoint_name(name=name, **kwargs)
+
+    def save_checkpoints(self):
+        """Save checkpoint, optionally keeping a copy for a given epoch."""
+        is_sharded = self.cfg.fsdp.use
+        if not flashy.distrib.is_rank_zero() and not is_sharded:
+            return
+        self.logger.info("Model hash: %s", model_hash(self.model))
+        state = self.state_dict()
+        epoch = self.epoch - 1  # pushing metrics will increase the epoch in Flashy, so we do -1 here
+
+        # save minimal state_dict as new checkpoint every X epoch
+        if self.cfg.checkpoint.save_every:
+            if epoch % self.cfg.checkpoint.save_every == 0:
+                minimal_state = state
+                if self.cfg.checkpoint.keep_every_states is not None and len(self.cfg.checkpoint.keep_every_states) > 0:
+                    minimal_state = {
+                        name: source for name, source in state.items()
+                        if name in self.cfg.checkpoint.keep_every_states
+                    }
+                epoch_checkpoint_path = self.epoch_checkpoint_path(epoch)
+                checkpoint.save_checkpoint(minimal_state, epoch_checkpoint_path, is_sharded)
+
+        # save checkpoint as latest checkpoint
+        if self.cfg.checkpoint.save_last:
+            last_checkpoint_path = self.checkpoint_path()
+            checkpoint.save_checkpoint(state, last_checkpoint_path, is_sharded)
+
+        # flush any stale checkpoint to reduce disk footprint
+        checkpoint.flush_stale_checkpoints(self.checkpoint_path())
+
+    def load_from_pretrained(self, name: str) -> dict:
+        raise NotImplementedError("Solver does not provide a way to load pretrained models.")
+
+    def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]:
+        """Load last checkpoint or the one specified in continue_from.
+
+        Args:
+            load_best (bool): Whether to load from best state dict or not.
+                Best state dict is always used when not loading the current xp.
+            ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`.
+        Returns:
+            state (dict, optional): The loaded state dictionary.
+        """
+        # load checkpoints from xp folder or cfg.continue_from
+        is_sharded = self.cfg.fsdp.use
+        load_from_path: tp.Optional[Path] = None
+        checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None
+
+        if load_best:
+            self.logger.info("Trying to load state_dict from best state.")
+
+        state: tp.Optional[dict] = None
+        rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False)
+        current_checkpoint_path = self.checkpoint_path()
+        _pretrained_prefix = '//pretrained/'
+        continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix)
+        if rank0_checkpoint_path.exists():
+            self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}")
+            load_from_path = current_checkpoint_path
+            checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path)
+            checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP
+        elif self.cfg.continue_from and not continue_pretrained:
+            self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}")
+            # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best
+            load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False)
+            if load_from_path is None:
+                self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from)
+                raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}')
+            checkpoint_source = checkpoint.CheckpointSource.OTHER
+
+        if load_from_path is not None:
+            state = checkpoint.load_checkpoint(load_from_path, is_sharded)
+        elif continue_pretrained:
+            self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.")
+            state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):])
+            checkpoint_source = checkpoint.CheckpointSource.PRETRAINED
+            load_best = True
+
+        # checkpoints are not from the current xp, we only retrieve the best state
+        if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
+            assert state is not None
+            self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
+            load_best = True
+            state = {key: state[key] for key in self._continue_best_source_keys if key in state}
+            # loaded checkpoints are FSDP checkpoints: we're reading the best state
+            # from FSDP and we drop the regular best_state
+            if 'fsdp_best_state' in state and state['fsdp_best_state']:
+                state.pop('best_state', None)
+                self.logger.info("... Loaded checkpoint has FSDP best state")
+            # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support
+            # then we're initializing FSDP best state with the regular best state
+            elif self.cfg.fsdp.use:
+                if 'fsdp_best_state' not in state or not state['fsdp_best_state']:
+                    # we swap non-FSDP checkpoints best_state to FSDP-compatible best state
+                    state['fsdp_best_state'] = state.pop('best_state')
+                    self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state")
+
+        if state is not None:
+            if load_best:
+                self.logger.info("Ignoring keys when loading best %r", ignore_state_keys)
+                for key in set(ignore_state_keys):
+                    if key in state:
+                        state.pop(key)
+                has_best_state = 'best_state' in state or 'fsdp_best_state' in state
+                assert has_best_state, ("Trying to load best state but neither 'best_state'",
+                                        " or 'fsdp_best_state' found in checkpoints.")
+            self.load_state_dict(state)
+
+        # for FSDP, let's make extra sure nothing bad happened with out of sync
+        # checkpoints across workers.
+        epoch = float(self.epoch)
+        avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch']
+        if avg_epoch != epoch:
+            raise RuntimeError(
+                f"Inconsistent loading of checkpoints happened, our epoch is {epoch} "
+                f"but average of epochs is {avg_epoch}, at least one gpu must have a "
+                "different epoch number.")
+
+        # on load_best, properly reinitialize state_dict, best states and ema
+        # otherwise we load from the current xp and don't alter anything
+        if load_best:
+            self.logger.info("Loading state_dict from best state.")
+            if not self.cfg.fsdp.use and self.fsdp_best_state:
+                # loading from an FSDP checkpoint but with FSDP deactivated
+                self.logger.info("... Loading from FSDP best state dict.")
+                self.best_state.load_state_dict(self.fsdp_best_state)
+
+            # if load_best, we permanently override the regular state_dict with the best state
+            if self.cfg.fsdp.use:
+                self.logger.info("FSDP is used, loading from FSDP best state.")
+                with fsdp.switch_to_full_state_dict(self._fsdp_modules):
+                    # this might be really fragile but okay for now.
+                    self.load_state_dict(self.fsdp_best_state)
+            else:
+                # we permanently swap the stateful objects to their best state
+                self._load_new_state_dict(self.best_state.state_dict())
+
+            # the EMA modules should also be instantiated with best state.
+            # the easiest way to do so is to reinitialize a new EMA with best state loaded.
+            if self.ema is not None:
+                self.logger.info("Re-initializing EMA from best state")
+                self.initialize_ema()
+
+            if self.cfg.fsdp.use:
+                self.logger.info("Re-initializing best state after using FSDP best state.")
+                for name in self.best_state.states.keys():
+                    state_source = self._get_state_source(name)
+                    self.best_state.update(name, state_source)
+
+        return state
+
+    def restore(self, load_best: bool = False, replay_metrics: bool = False,
+                ignore_state_keys: tp.List[str] = []) -> bool:
+        """Restore the status of a solver for a given xp.
+
+        Args:
+            load_best (bool): if `True`, load the best state from the checkpoint.
+            replay_metrics (bool): if `True`, logs all the metrics from past epochs.
+            ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`.
+        """
+        self.logger.info("Restoring weights and history.")
+        restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys)
+
+        self.logger.info("Model hash: %s", model_hash(self.model))
+
+        if replay_metrics and len(self.history) > 0:
+            self.logger.info("Replaying past metrics...")
+            for epoch, stages in enumerate(self.history):
+                for stage_name, metrics in stages.items():
+                    # We manually log the metrics summary to the result logger
+                    # as we don't want to add them to the pending metrics
+                    self.result_logger._log_summary(stage_name, metrics, step=epoch + 1, step_name='epoch',
+                                                    formatter=self.get_formatter(stage_name))
+        return restored_checkpoints is not None
+
+    def commit(self, save_checkpoints: bool = True):
+        """Commit metrics to dora and save checkpoints at the end of an epoch."""
+        # we override commit to introduce more complex checkpoint saving behaviors
+        self.history.append(self._pending_metrics)  # This will increase self.epoch
+        if save_checkpoints:
+            self.save_checkpoints()
+        self._start_epoch()
+        if flashy.distrib.is_rank_zero():
+            self.xp.link.update_history(self.history)
+
+    def run_epoch(self):
+        """Run a single epoch with all stages.
+
+        Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards.
+        Children solvers can extend this method with custom behavior, e.g.:
+
+            def run_epoch(self):
+                ... # custom code
+                super().run_epoch()
+                ... # custom code
+        """
+        self.run_stage('train', self.train)
+        with torch.no_grad():
+            with self.swap_ema_state():
+                self.run_stage('valid', self.valid)
+                # the best state is updated with EMA states if available
+                self.update_best_state_from_stage('valid')
+            with self.swap_best_state():
+                if self.should_run_stage('evaluate'):
+                    self.run_stage('evaluate', self.evaluate)
+                if self.should_run_stage('generate'):
+                    self.run_stage('generate', with_rank_rng()(self.generate))
+
+    def run(self):
+        """Training loop."""
+        assert len(self.state_dict()) > 0
+        self.restore(replay_metrics=True)  # load checkpoint and replay history
+        self.log_hyperparams(dict_from_config(self.cfg))
+        for epoch in range(self.epoch, self.cfg.optim.epochs + 1):
+            if self.should_stop_training():
+                return
+            self.run_epoch()
+            # Commit will send the metrics to Dora and save checkpoints by default.
+            self.commit()
+
+    def should_stop_training(self) -> bool:
+        """Check whether we should stop training or not."""
+        return self.epoch > self.cfg.optim.epochs
+
+    def should_run_stage(self, stage_name) -> bool:
+        """Check whether we want to run the specified stages."""
+        stage_every = self.cfg[stage_name].get('every', None)
+        is_last_epoch = self.epoch == self.cfg.optim.epochs
+        is_epoch_every = (stage_every and self.epoch % stage_every == 0)
+        return is_last_epoch or is_epoch_every
+
+    @abstractmethod
+    def run_step(self, idx: int, batch: tp.Any, metrics: dict):
+        """Perform one training or valid step on a given batch."""
+        ...
+
+    def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
+        """Common logic for train and valid stages."""
+        self.model.train(self.is_training)
+
+        loader = self.dataloaders[dataset_split]
+        # get a different order for distributed training, otherwise this will get ignored
+        if flashy.distrib.world_size() > 1 \
+           and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler):
+            loader.sampler.set_epoch(self.epoch)
+        updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader)
+        if self.cfg.benchmark_no_load:
+            self.logger.warning("Fake loading for benchmarking: re-using first batch")
+            batch = next(iter(loader))
+            loader = [batch] * updates_per_epoch  # type: ignore
+        lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates)
+        average = flashy.averager()  # epoch wise average
+        instant_average = flashy.averager()  # average between two logging
+        metrics: dict = {}
+
+        with self.profiler, self.deadlock_detect:  # profiler will only run for the first 20 updates.
+            for idx, batch in enumerate(lp):
+                self.deadlock_detect.update('batch')
+                if idx >= updates_per_epoch:
+                    break
+                metrics = {}
+                metrics = self.run_step(idx, batch, metrics)
+                self.deadlock_detect.update('step')
+                # run EMA step
+                if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0:
+                    self.logger.debug("EMA model step")
+                    self.ema.step()
+                self.deadlock_detect.update('ema')
+                self.profiler.step()
+                instant_metrics = instant_average(metrics)
+                if lp.update(**instant_metrics):
+                    instant_average = flashy.averager()  # reset averager between two logging
+                metrics = average(metrics)  # epoch wise average
+                self.deadlock_detect.update('end_batch')
+
+        metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch)
+        return metrics
+
+    def train(self):
+        """Train stage."""
+        return self.common_train_valid('train')
+
+    def valid(self):
+        """Valid stage."""
+        return self.common_train_valid('valid')
+
+    @abstractmethod
+    def evaluate(self):
+        """Evaluate stage."""
+        ...
+
+    @abstractmethod
+    def generate(self):
+        """Generate stage."""
+        ...
+
+    def run_one_stage(self, stage_name: str):
+        """Run only the specified stage.
+        This method is useful to only generate samples from a trained experiment
+        or rerun the validation or evaluation stages.
+        """
+        fn = {
+            'generate': with_rank_rng()(self.generate),
+            'evaluate': self.evaluate,
+            'valid': self.valid,
+        }
+        if stage_name not in fn:
+            raise ValueError(f'Trying to run stage {stage_name} is not supported.')
+        assert len(self.state_dict()) > 0
+        self._start_epoch()
+        with torch.no_grad(), self.swap_best_state():
+            self.run_stage(stage_name, fn[stage_name])
+        if not self.cfg.execute_inplace:
+            self.commit(save_checkpoints=False)
+
+    @staticmethod
+    def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
+                                 device: tp.Optional[str] = None, autocast: bool = True,
+                                 batch_size: tp.Optional[int] = None,
+                                 override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+                                 **kwargs):
+        """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
+        populating all the proper param, deactivating EMA, FSDP, loading the best state,
+        basically all you need to get a solver ready to "play" with in single GPU mode
+        and with minimal memory overhead.
+
+        Args:
+            sig (str): signature to load.
+            dtype (str or None): potential dtype, as a string, i.e. 'float16'.
+            device (str or None): potential device, as a string, i.e. 'cuda'.
+            override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'.
+        """
+        from audiocraft import train
+        our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
+        our_override_cfg['autocast'] = autocast
+        if dtype is not None:
+            our_override_cfg['dtype'] = dtype
+        if device is not None:
+            our_override_cfg['device'] = device
+        if batch_size is not None:
+            our_override_cfg['dataset'] = {'batch_size': batch_size}
+        if override_cfg is None:
+            override_cfg = {}
+        override_cfg = omegaconf.OmegaConf.merge(
+            omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
+        solver = train.get_solver_from_sig(
+            sig, override_cfg=override_cfg,
+            load_best=True, disable_fsdp=True,
+            ignore_state_keys=['optimizer', 'ema'], **kwargs)
+        solver.model.eval()
+        return solver
+
+

Ancestors

+
    +
  • abc.ABC
  • +
  • flashy.solver.BaseSolver
  • +
+

Subclasses

+ +

Static methods

+
+
+def get_eval_solver_from_sig(sig: str, dtype: Optional[str] = None, device: Optional[str] = None, autocast: bool = True, batch_size: Optional[int] = None, override_cfg: Union[dict, omegaconf.dictconfig.DictConfig, None] = None, **kwargs) +
+
+

Mostly a convenience function around audiocraft.train.get_solver_from_sig, +populating all the proper param, deactivating EMA, FSDP, loading the best state, +basically all you need to get a solver ready to "play" with in single GPU mode +and with minimal memory overhead.

+

Args

+
+
sig : str
+
signature to load.
+
dtype : str or None
+
potential dtype, as a string, i.e. 'float16'.
+
device : str or None
+
potential device, as a string, i.e. 'cuda'.
+
override_cfg : dict or omegaconf.DictConfig or None
+
potential device, as a string, i.e. 'cuda'.
+
+
+ +Expand source code + +
@staticmethod
+def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
+                             device: tp.Optional[str] = None, autocast: bool = True,
+                             batch_size: tp.Optional[int] = None,
+                             override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+                             **kwargs):
+    """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
+    populating all the proper param, deactivating EMA, FSDP, loading the best state,
+    basically all you need to get a solver ready to "play" with in single GPU mode
+    and with minimal memory overhead.
+
+    Args:
+        sig (str): signature to load.
+        dtype (str or None): potential dtype, as a string, i.e. 'float16'.
+        device (str or None): potential device, as a string, i.e. 'cuda'.
+        override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'.
+    """
+    from audiocraft import train
+    our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
+    our_override_cfg['autocast'] = autocast
+    if dtype is not None:
+        our_override_cfg['dtype'] = dtype
+    if device is not None:
+        our_override_cfg['device'] = device
+    if batch_size is not None:
+        our_override_cfg['dataset'] = {'batch_size': batch_size}
+    if override_cfg is None:
+        override_cfg = {}
+    override_cfg = omegaconf.OmegaConf.merge(
+        omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
+    solver = train.get_solver_from_sig(
+        sig, override_cfg=override_cfg,
+        load_best=True, disable_fsdp=True,
+        ignore_state_keys=['optimizer', 'ema'], **kwargs)
+    solver.model.eval()
+    return solver
+
+
+
+

Instance variables

+
+
var autocast
+
+

Convenient autocast (or not) using the solver configuration.

+
+ +Expand source code + +
@property
+def autocast(self):
+    """Convenient autocast (or not) using the solver configuration."""
+    return TorchAutocast(enabled=self.cfg.autocast, device_type=self.device, dtype=self.autocast_dtype)
+
+
+
var best_metric_name : Optional[str]
+
+

Metric name used to identify the best state. This metric should be stored in the metrics +used on the stage for best state identification (most likely, valid). If None, then +no best state is saved.

+
+ +Expand source code + +
@property
+def best_metric_name(self) -> tp.Optional[str]:
+    """Metric name used to identify the best state. This metric should be stored in the metrics
+    used on the stage for best state identification (most likely, `valid`). If None, then
+    no best state is saved.
+    """
+    return None
+
+
+
var is_training
+
+
+
+ +Expand source code + +
@property
+def is_training(self):
+    return self.current_stage == 'train'
+
+
+
var log_updates
+
+
+
+ +Expand source code + +
@property
+def log_updates(self):
+    # convenient access to log updates
+    return self._log_updates
+
+
+
+

Methods

+
+
+def build_dataloaders(self) +
+
+

Method to implement to initialize dataloaders.

+
+ +Expand source code + +
@abstractmethod
+def build_dataloaders(self):
+    """Method to implement to initialize dataloaders."""
+    ...
+
+
+
+def build_model(self) +
+
+

Method to implement to initialize model.

+
+ +Expand source code + +
@abstractmethod
+def build_model(self):
+    """Method to implement to initialize model."""
+    ...
+
+
+
+def checkpoint_path(self, **kwargs) +
+
+
+
+ +Expand source code + +
def checkpoint_path(self, **kwargs):
+    kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+    return self.folder / checkpoint.checkpoint_name(**kwargs)
+
+
+
+def checkpoint_path_with_name(self, name: str, **kwargs) +
+
+
+
+ +Expand source code + +
def checkpoint_path_with_name(self, name: str, **kwargs):
+    kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+    return self.folder / checkpoint.checkpoint_name(name=name, **kwargs)
+
+
+
+def commit(self, save_checkpoints: bool = True) +
+
+

Commit metrics to dora and save checkpoints at the end of an epoch.

+
+ +Expand source code + +
def commit(self, save_checkpoints: bool = True):
+    """Commit metrics to dora and save checkpoints at the end of an epoch."""
+    # we override commit to introduce more complex checkpoint saving behaviors
+    self.history.append(self._pending_metrics)  # This will increase self.epoch
+    if save_checkpoints:
+        self.save_checkpoints()
+    self._start_epoch()
+    if flashy.distrib.is_rank_zero():
+        self.xp.link.update_history(self.history)
+
+
+
+def common_train_valid(self, dataset_split: str, **kwargs: Any) +
+
+

Common logic for train and valid stages.

+
+ +Expand source code + +
def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
+    """Common logic for train and valid stages."""
+    self.model.train(self.is_training)
+
+    loader = self.dataloaders[dataset_split]
+    # get a different order for distributed training, otherwise this will get ignored
+    if flashy.distrib.world_size() > 1 \
+       and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler):
+        loader.sampler.set_epoch(self.epoch)
+    updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader)
+    if self.cfg.benchmark_no_load:
+        self.logger.warning("Fake loading for benchmarking: re-using first batch")
+        batch = next(iter(loader))
+        loader = [batch] * updates_per_epoch  # type: ignore
+    lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates)
+    average = flashy.averager()  # epoch wise average
+    instant_average = flashy.averager()  # average between two logging
+    metrics: dict = {}
+
+    with self.profiler, self.deadlock_detect:  # profiler will only run for the first 20 updates.
+        for idx, batch in enumerate(lp):
+            self.deadlock_detect.update('batch')
+            if idx >= updates_per_epoch:
+                break
+            metrics = {}
+            metrics = self.run_step(idx, batch, metrics)
+            self.deadlock_detect.update('step')
+            # run EMA step
+            if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0:
+                self.logger.debug("EMA model step")
+                self.ema.step()
+            self.deadlock_detect.update('ema')
+            self.profiler.step()
+            instant_metrics = instant_average(metrics)
+            if lp.update(**instant_metrics):
+                instant_average = flashy.averager()  # reset averager between two logging
+            metrics = average(metrics)  # epoch wise average
+            self.deadlock_detect.update('end_batch')
+
+    metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch)
+    return metrics
+
+
+
+def epoch_checkpoint_path(self, epoch: int, **kwargs) +
+
+
+
+ +Expand source code + +
def epoch_checkpoint_path(self, epoch: int, **kwargs):
+    kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+    return self.folder / checkpoint.checkpoint_name(str(epoch), **kwargs)
+
+
+
+def evaluate(self) +
+
+

Evaluate stage.

+
+ +Expand source code + +
@abstractmethod
+def evaluate(self):
+    """Evaluate stage."""
+    ...
+
+
+
+def generate(self) +
+
+

Generate stage.

+
+ +Expand source code + +
@abstractmethod
+def generate(self):
+    """Generate stage."""
+    ...
+
+
+
+def initialize_ema(self) +
+
+

Initialize exponential moving average with the registered sources. +EMA object is created if the optim.ema.model.decay value is non-null.

+
+ +Expand source code + +
def initialize_ema(self):
+    """Initialize exponential moving average with the registered sources.
+    EMA object is created if the optim.ema.model.decay value is non-null.
+    """
+    from .builders import get_ema
+    self.ema = get_ema(self._ema_sources, self.cfg.optim.ema)
+    if self.ema is None:
+        self.logger.info('No EMA on the model.')
+    else:
+        assert self.cfg.optim.ema.updates > 0
+        self.logger.info(
+            f'Initializing EMA on the model with decay = {self.ema.decay}'
+            f' every {self.cfg.optim.ema.updates} updates'
+        )
+
+
+
+def load_checkpoints(self, load_best: bool = False, ignore_state_keys: List[str] = []) ‑> Optional[dict] +
+
+

Load last checkpoint or the one specified in continue_from.

+

Args

+
+
load_best : bool
+
Whether to load from best state dict or not. +Best state dict is always used when not loading the current xp.
+
ignore_state_keys : list of str
+
List of sources to ignore when loading the state, e.g. optimizer.
+
+

Returns

+

state (dict, optional): The loaded state dictionary.

+
+ +Expand source code + +
def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]:
+    """Load last checkpoint or the one specified in continue_from.
+
+    Args:
+        load_best (bool): Whether to load from best state dict or not.
+            Best state dict is always used when not loading the current xp.
+        ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`.
+    Returns:
+        state (dict, optional): The loaded state dictionary.
+    """
+    # load checkpoints from xp folder or cfg.continue_from
+    is_sharded = self.cfg.fsdp.use
+    load_from_path: tp.Optional[Path] = None
+    checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None
+
+    if load_best:
+        self.logger.info("Trying to load state_dict from best state.")
+
+    state: tp.Optional[dict] = None
+    rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False)
+    current_checkpoint_path = self.checkpoint_path()
+    _pretrained_prefix = '//pretrained/'
+    continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix)
+    if rank0_checkpoint_path.exists():
+        self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}")
+        load_from_path = current_checkpoint_path
+        checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path)
+        checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP
+    elif self.cfg.continue_from and not continue_pretrained:
+        self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}")
+        # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best
+        load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False)
+        if load_from_path is None:
+            self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from)
+            raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}')
+        checkpoint_source = checkpoint.CheckpointSource.OTHER
+
+    if load_from_path is not None:
+        state = checkpoint.load_checkpoint(load_from_path, is_sharded)
+    elif continue_pretrained:
+        self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.")
+        state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):])
+        checkpoint_source = checkpoint.CheckpointSource.PRETRAINED
+        load_best = True
+
+    # checkpoints are not from the current xp, we only retrieve the best state
+    if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
+        assert state is not None
+        self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
+        load_best = True
+        state = {key: state[key] for key in self._continue_best_source_keys if key in state}
+        # loaded checkpoints are FSDP checkpoints: we're reading the best state
+        # from FSDP and we drop the regular best_state
+        if 'fsdp_best_state' in state and state['fsdp_best_state']:
+            state.pop('best_state', None)
+            self.logger.info("... Loaded checkpoint has FSDP best state")
+        # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support
+        # then we're initializing FSDP best state with the regular best state
+        elif self.cfg.fsdp.use:
+            if 'fsdp_best_state' not in state or not state['fsdp_best_state']:
+                # we swap non-FSDP checkpoints best_state to FSDP-compatible best state
+                state['fsdp_best_state'] = state.pop('best_state')
+                self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state")
+
+    if state is not None:
+        if load_best:
+            self.logger.info("Ignoring keys when loading best %r", ignore_state_keys)
+            for key in set(ignore_state_keys):
+                if key in state:
+                    state.pop(key)
+            has_best_state = 'best_state' in state or 'fsdp_best_state' in state
+            assert has_best_state, ("Trying to load best state but neither 'best_state'",
+                                    " or 'fsdp_best_state' found in checkpoints.")
+        self.load_state_dict(state)
+
+    # for FSDP, let's make extra sure nothing bad happened with out of sync
+    # checkpoints across workers.
+    epoch = float(self.epoch)
+    avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch']
+    if avg_epoch != epoch:
+        raise RuntimeError(
+            f"Inconsistent loading of checkpoints happened, our epoch is {epoch} "
+            f"but average of epochs is {avg_epoch}, at least one gpu must have a "
+            "different epoch number.")
+
+    # on load_best, properly reinitialize state_dict, best states and ema
+    # otherwise we load from the current xp and don't alter anything
+    if load_best:
+        self.logger.info("Loading state_dict from best state.")
+        if not self.cfg.fsdp.use and self.fsdp_best_state:
+            # loading from an FSDP checkpoint but with FSDP deactivated
+            self.logger.info("... Loading from FSDP best state dict.")
+            self.best_state.load_state_dict(self.fsdp_best_state)
+
+        # if load_best, we permanently override the regular state_dict with the best state
+        if self.cfg.fsdp.use:
+            self.logger.info("FSDP is used, loading from FSDP best state.")
+            with fsdp.switch_to_full_state_dict(self._fsdp_modules):
+                # this might be really fragile but okay for now.
+                self.load_state_dict(self.fsdp_best_state)
+        else:
+            # we permanently swap the stateful objects to their best state
+            self._load_new_state_dict(self.best_state.state_dict())
+
+        # the EMA modules should also be instantiated with best state.
+        # the easiest way to do so is to reinitialize a new EMA with best state loaded.
+        if self.ema is not None:
+            self.logger.info("Re-initializing EMA from best state")
+            self.initialize_ema()
+
+        if self.cfg.fsdp.use:
+            self.logger.info("Re-initializing best state after using FSDP best state.")
+            for name in self.best_state.states.keys():
+                state_source = self._get_state_source(name)
+                self.best_state.update(name, state_source)
+
+    return state
+
+
+
+def load_from_pretrained(self, name: str) ‑> dict +
+
+
+
+ +Expand source code + +
def load_from_pretrained(self, name: str) -> dict:
+    raise NotImplementedError("Solver does not provide a way to load pretrained models.")
+
+
+
+def log_model_summary(self, model: torch.nn.modules.module.Module) +
+
+

Log model summary, architecture and size of the model.

+
+ +Expand source code + +
def log_model_summary(self, model: nn.Module):
+    """Log model summary, architecture and size of the model."""
+    self.logger.info(model)
+    mb = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20
+    self.logger.info("Size: %.1f MB", mb)
+
+
+
+def register_best_state(self, *args: str) +
+
+

Register state sources in BestStateDictManager to keep their best states along with their +latest states. The best state will be used at evaluation stages instead of the latest states.

+

Shortcut around BestStateDictManager.register method. You can pass any number of +attribute, included nested attributes and those will be included into the checkpoints +and automatically restored when BaseSolver.restore is called.

+
+ +Expand source code + +
def register_best_state(self, *args: str):
+    """Register state sources in `BestStateDictManager` to keep their best states along with their
+    latest states. The best state will be used at evaluation stages instead of the latest states.
+
+    Shortcut around `BestStateDictManager.register` method. You can pass any number of
+    attribute, included nested attributes and those will be included into the checkpoints
+    and automatically restored when `BaseSolver.restore` is called.
+    """
+    for name in args:
+        state_source = self._get_state_source(name)
+        assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!"
+        self.best_state.register(name, state_source)
+
+
+
+def register_ema(self, *args: str) +
+
+

Register state sources for exponential moving average.

+

The registered sources are used to instantiate a ModuleDictEMA instance. +The ModuleDictEMA keeps a nn.ModuleDict module that is updated when self.ema.step() is called +and swapped with the original state sources with self.swap_ema_state() method.

+

Usage

+

self.register_ema('model')

+
+ +Expand source code + +
def register_ema(self, *args: str):
+    """Register state sources for exponential moving average.
+
+    The registered sources are used to instantiate a ModuleDictEMA instance.
+    The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called
+    and swapped with the original state sources with self.swap_ema_state() method.
+
+    Usage:
+        self.register_ema('model')
+    """
+    assert self.ema is None, "Cannot register state source to already instantiated EMA."
+    for name in args:
+        self._ema_sources[name] = getattr(self, name)
+
+
+
+def restore(self, load_best: bool = False, replay_metrics: bool = False, ignore_state_keys: List[str] = []) ‑> bool +
+
+

Restore the status of a solver for a given xp.

+

Args

+
+
load_best : bool
+
if True, load the best state from the checkpoint.
+
replay_metrics : bool
+
if True, logs all the metrics from past epochs.
+
ignore_state_keys : list of str
+
list of sources to ignore when loading the state, e.g. optimizer.
+
+
+ +Expand source code + +
def restore(self, load_best: bool = False, replay_metrics: bool = False,
+            ignore_state_keys: tp.List[str] = []) -> bool:
+    """Restore the status of a solver for a given xp.
+
+    Args:
+        load_best (bool): if `True`, load the best state from the checkpoint.
+        replay_metrics (bool): if `True`, logs all the metrics from past epochs.
+        ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`.
+    """
+    self.logger.info("Restoring weights and history.")
+    restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys)
+
+    self.logger.info("Model hash: %s", model_hash(self.model))
+
+    if replay_metrics and len(self.history) > 0:
+        self.logger.info("Replaying past metrics...")
+        for epoch, stages in enumerate(self.history):
+            for stage_name, metrics in stages.items():
+                # We manually log the metrics summary to the result logger
+                # as we don't want to add them to the pending metrics
+                self.result_logger._log_summary(stage_name, metrics, step=epoch + 1, step_name='epoch',
+                                                formatter=self.get_formatter(stage_name))
+    return restored_checkpoints is not None
+
+
+
+def run(self) +
+
+

Training loop.

+
+ +Expand source code + +
def run(self):
+    """Training loop."""
+    assert len(self.state_dict()) > 0
+    self.restore(replay_metrics=True)  # load checkpoint and replay history
+    self.log_hyperparams(dict_from_config(self.cfg))
+    for epoch in range(self.epoch, self.cfg.optim.epochs + 1):
+        if self.should_stop_training():
+            return
+        self.run_epoch()
+        # Commit will send the metrics to Dora and save checkpoints by default.
+        self.commit()
+
+
+
+def run_epoch(self) +
+
+

Run a single epoch with all stages.

+

Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards. +Children solvers can extend this method with custom behavior, e.g.:

+
def run_epoch(self):
+    ... # custom code
+    super().run_epoch()
+    ... # custom code
+
+
+ +Expand source code + +
def run_epoch(self):
+    """Run a single epoch with all stages.
+
+    Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards.
+    Children solvers can extend this method with custom behavior, e.g.:
+
+        def run_epoch(self):
+            ... # custom code
+            super().run_epoch()
+            ... # custom code
+    """
+    self.run_stage('train', self.train)
+    with torch.no_grad():
+        with self.swap_ema_state():
+            self.run_stage('valid', self.valid)
+            # the best state is updated with EMA states if available
+            self.update_best_state_from_stage('valid')
+        with self.swap_best_state():
+            if self.should_run_stage('evaluate'):
+                self.run_stage('evaluate', self.evaluate)
+            if self.should_run_stage('generate'):
+                self.run_stage('generate', with_rank_rng()(self.generate))
+
+
+
+def run_one_stage(self, stage_name: str) +
+
+

Run only the specified stage. +This method is useful to only generate samples from a trained experiment +or rerun the validation or evaluation stages.

+
+ +Expand source code + +
def run_one_stage(self, stage_name: str):
+    """Run only the specified stage.
+    This method is useful to only generate samples from a trained experiment
+    or rerun the validation or evaluation stages.
+    """
+    fn = {
+        'generate': with_rank_rng()(self.generate),
+        'evaluate': self.evaluate,
+        'valid': self.valid,
+    }
+    if stage_name not in fn:
+        raise ValueError(f'Trying to run stage {stage_name} is not supported.')
+    assert len(self.state_dict()) > 0
+    self._start_epoch()
+    with torch.no_grad(), self.swap_best_state():
+        self.run_stage(stage_name, fn[stage_name])
+    if not self.cfg.execute_inplace:
+        self.commit(save_checkpoints=False)
+
+
+
+def run_step(self, idx: int, batch: Any, metrics: dict) +
+
+

Perform one training or valid step on a given batch.

+
+ +Expand source code + +
@abstractmethod
+def run_step(self, idx: int, batch: tp.Any, metrics: dict):
+    """Perform one training or valid step on a given batch."""
+    ...
+
+
+
+def save_checkpoints(self) +
+
+

Save checkpoint, optionally keeping a copy for a given epoch.

+
+ +Expand source code + +
def save_checkpoints(self):
+    """Save checkpoint, optionally keeping a copy for a given epoch."""
+    is_sharded = self.cfg.fsdp.use
+    if not flashy.distrib.is_rank_zero() and not is_sharded:
+        return
+    self.logger.info("Model hash: %s", model_hash(self.model))
+    state = self.state_dict()
+    epoch = self.epoch - 1  # pushing metrics will increase the epoch in Flashy, so we do -1 here
+
+    # save minimal state_dict as new checkpoint every X epoch
+    if self.cfg.checkpoint.save_every:
+        if epoch % self.cfg.checkpoint.save_every == 0:
+            minimal_state = state
+            if self.cfg.checkpoint.keep_every_states is not None and len(self.cfg.checkpoint.keep_every_states) > 0:
+                minimal_state = {
+                    name: source for name, source in state.items()
+                    if name in self.cfg.checkpoint.keep_every_states
+                }
+            epoch_checkpoint_path = self.epoch_checkpoint_path(epoch)
+            checkpoint.save_checkpoint(minimal_state, epoch_checkpoint_path, is_sharded)
+
+    # save checkpoint as latest checkpoint
+    if self.cfg.checkpoint.save_last:
+        last_checkpoint_path = self.checkpoint_path()
+        checkpoint.save_checkpoint(state, last_checkpoint_path, is_sharded)
+
+    # flush any stale checkpoint to reduce disk footprint
+    checkpoint.flush_stale_checkpoints(self.checkpoint_path())
+
+
+
+def should_run_stage(self, stage_name) ‑> bool +
+
+

Check whether we want to run the specified stages.

+
+ +Expand source code + +
def should_run_stage(self, stage_name) -> bool:
+    """Check whether we want to run the specified stages."""
+    stage_every = self.cfg[stage_name].get('every', None)
+    is_last_epoch = self.epoch == self.cfg.optim.epochs
+    is_epoch_every = (stage_every and self.epoch % stage_every == 0)
+    return is_last_epoch or is_epoch_every
+
+
+
+def should_stop_training(self) ‑> bool +
+
+

Check whether we should stop training or not.

+
+ +Expand source code + +
def should_stop_training(self) -> bool:
+    """Check whether we should stop training or not."""
+    return self.epoch > self.cfg.optim.epochs
+
+
+
+def show(self) +
+
+

Method to log any information without running the job.

+
+ +Expand source code + +
@abstractmethod
+def show(self):
+    """Method to log any information without running the job."""
+    ...
+
+
+
+def swap_best_state(self) +
+
+
+
+ +Expand source code + +
@contextmanager
+def swap_best_state(self):
+    self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}")
+    old_states = self._load_new_state_dict(self.best_state.state_dict())
+    try:
+        yield
+    finally:
+        self.logger.debug("Swapping back from best to original state")
+        for name, old_state in old_states.items():
+            state_source = self._get_state_source(name)
+            state_source.load_state_dict(old_state)
+
+
+
+def swap_ema_state(self) +
+
+
+
+ +Expand source code + +
@contextmanager
+def swap_ema_state(self):
+    if self.ema is None:
+        yield
+    else:
+        ema_state_dict = self.ema.state_dict()['state']
+        self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}")
+        old_states = self._load_new_state_dict(ema_state_dict)
+        try:
+            yield
+        finally:
+            self.logger.debug("Swapping back from EMA state to original state")
+            for name, old_state in old_states.items():
+                state_source = self._get_state_source(name)
+                state_source.load_state_dict(old_state)
+
+
+
+def train(self) +
+
+

Train stage.

+
+ +Expand source code + +
def train(self):
+    """Train stage."""
+    return self.common_train_valid('train')
+
+
+
+def update_best_state_from_stage(self, stage_name: str = 'valid') +
+
+

Update latest best state based on pending metrics of a given stage. This method relies +on the BestStateDictManager.update method to update the best state_dict with latest weights +if the registered states happen to match to the best performing setup.

+
+ +Expand source code + +
def update_best_state_from_stage(self, stage_name: str = 'valid'):
+    """Update latest best state based on pending metrics of a given stage. This method relies
+    on the `BestStateDictManager.update` method to update the best state_dict with latest weights
+    if the registered states happen to match to the best performing setup.
+    """
+    if self.best_metric_name is None:
+        # when no best metric is defined, the last state is always the best
+        self._new_best_state = True
+        self.logger.info("Updating best state with current state.")
+    else:
+        assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found."
+        assert self.best_metric_name in self._pending_metrics[stage_name], \
+            f"Best metric not found in {stage_name} metrics. Cannot register best state"
+        current_score = self._pending_metrics[stage_name][self.best_metric_name]
+        all_best_metric_scores = [
+            past_metrics[stage_name][self.best_metric_name]
+            for past_metrics in self.history
+        ]
+        all_best_metric_scores.append(current_score)
+        best_score = min(all_best_metric_scores)
+        self._new_best_state = current_score == best_score
+        if self._new_best_state:
+            old_best = min(all_best_metric_scores[:-1] + [float('inf')])
+            self.logger.info(
+                f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})")
+
+    if self._new_best_state:
+        if self.cfg.fsdp.use:
+            # this will give an empty state dict on all ranks but the rank 0
+            # which will have a copy in memory of the full model.
+            with fsdp.switch_to_full_state_dict(self._fsdp_modules):
+                for name in self.best_state.states.keys():
+                    state_source = self._get_state_source(name)
+                    self.best_state.update(name, state_source)
+                # we save to a different dict.
+                self.fsdp_best_state.update(self.best_state.state_dict())
+            # We cannot efficiently load fsdp_best_state when using FSDP,
+            # so we have do do a second pass, with the local shards.
+        for name in self.best_state.states.keys():
+            state_source = self._get_state_source(name)
+            self.best_state.update(name, state_source)
+
+
+
+def valid(self) +
+
+

Valid stage.

+
+ +Expand source code + +
def valid(self):
+    """Valid stage."""
+    return self.common_train_valid('valid')
+
+
+
+def wrap_with_fsdp(self, model: torch.nn.modules.module.Module, *args, **kwargs) +
+
+
+
+ +Expand source code + +
def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
+    model = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs)
+    if isinstance(model, fsdp.FSDP):
+        self._fsdp_modules.append(model)
+    return model
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/solvers/builders.html b/api_docs/audiocraft/solvers/builders.html new file mode 100644 index 00000000..28826cd2 --- /dev/null +++ b/api_docs/audiocraft/solvers/builders.html @@ -0,0 +1,1009 @@ + + + + + + +audiocraft.solvers.builders API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.solvers.builders

+
+
+

All the functions to build the relevant solvers and used objects +from the Hydra config.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+All the functions to build the relevant solvers and used objects
+from the Hydra config.
+"""
+
+from enum import Enum
+import logging
+import typing as tp
+
+import dora
+import flashy
+import omegaconf
+import torch
+from torch import nn
+from torch.optim import Optimizer
+# LRScheduler was renamed in some torch versions
+try:
+    from torch.optim.lr_scheduler import LRScheduler  # type: ignore
+except ImportError:
+    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+from .base import StandardSolver
+from .. import adversarial, data, losses, metrics, optim
+from ..utils.utils import dict_from_config, get_loader
+
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetType(Enum):
+    AUDIO = "audio"
+    MUSIC = "music"
+    SOUND = "sound"
+
+
+def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver:
+    """Instantiate solver from config."""
+    from .audiogen import AudioGenSolver
+    from .compression import CompressionSolver
+    from .musicgen import MusicGenSolver
+    from .diffusion import DiffusionSolver
+    klass = {
+        'compression': CompressionSolver,
+        'musicgen': MusicGenSolver,
+        'audiogen': AudioGenSolver,
+        'lm': MusicGenSolver,  # backward compatibility
+        'diffusion': DiffusionSolver,
+        'sound_lm': AudioGenSolver,  # backward compatibility
+    }[cfg.solver]
+    return klass(cfg)  # type: ignore
+
+
+def get_optim_parameter_groups(model: nn.Module):
+    """Create parameter groups for the model using the appropriate method
+    if defined for each modules, to create the different groups.
+
+    Args:
+        model (nn.Module): torch model
+    Returns:
+        List of parameter groups
+    """
+    seen_params: tp.Set[nn.parameter.Parameter] = set()
+    other_params = []
+    groups = []
+    for name, module in model.named_modules():
+        if hasattr(module, 'make_optim_group'):
+            group = module.make_optim_group()
+            params = set(group['params'])
+            assert params.isdisjoint(seen_params)
+            seen_params |= set(params)
+            groups.append(group)
+    for param in model.parameters():
+        if param not in seen_params:
+            other_params.append(param)
+    groups.insert(0, {'params': other_params})
+    parameters = groups
+    return parameters
+
+
+def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]], cfg: omegaconf.DictConfig) -> Optimizer:
+    """Build torch optimizer from config and set of parameters.
+    Supported optimizers: Adam, AdamW
+
+    Args:
+        params (nn.Module or iterable of torch.Tensor): Parameters to optimize.
+        cfg (DictConfig): Optimization-related configuration.
+    Returns:
+        torch.optim.Optimizer.
+    """
+    if 'optimizer' not in cfg:
+        if getattr(cfg, 'optim', None) is not None:
+            raise KeyError("Optimizer not found in config. Try instantiating optimizer from cfg.optim?")
+        else:
+            raise KeyError("Optimizer not found in config.")
+
+    parameters = get_optim_parameter_groups(params) if isinstance(params, nn.Module) else params
+    optimizer: torch.optim.Optimizer
+    if cfg.optimizer == 'adam':
+        optimizer = torch.optim.Adam(parameters, lr=cfg.lr, **cfg.adam)
+    elif cfg.optimizer == 'adamw':
+        optimizer = torch.optim.AdamW(parameters, lr=cfg.lr, **cfg.adam)
+    elif cfg.optimizer == 'dadam':
+        optimizer = optim.DAdaptAdam(parameters, lr=cfg.lr, **cfg.adam)
+    else:
+        raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}")
+    return optimizer
+
+
+def get_lr_scheduler(optimizer: torch.optim.Optimizer,
+                     cfg: omegaconf.DictConfig,
+                     total_updates: int) -> tp.Optional[LRScheduler]:
+    """Build torch learning rate scheduler from config and associated optimizer.
+    Supported learning rate schedulers: ExponentialLRScheduler, PlateauLRScheduler
+
+    Args:
+        optimizer (torch.optim.Optimizer): Optimizer.
+        cfg (DictConfig): Schedule-related configuration.
+        total_updates (int): Total number of updates.
+    Returns:
+        torch.optim.Optimizer.
+    """
+    if 'lr_scheduler' not in cfg:
+        raise KeyError("LR Scheduler not found in config")
+
+    lr_sched: tp.Optional[LRScheduler] = None
+    if cfg.lr_scheduler == 'step':
+        lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, **cfg.step)
+    elif cfg.lr_scheduler == 'exponential':
+        lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.exponential)
+    elif cfg.lr_scheduler == 'cosine':
+        kwargs = dict_from_config(cfg.cosine)
+        warmup_steps = kwargs.pop('warmup')
+        lr_sched = optim.CosineLRScheduler(
+            optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs)
+    elif cfg.lr_scheduler == 'polynomial_decay':
+        kwargs = dict_from_config(cfg.polynomial_decay)
+        warmup_steps = kwargs.pop('warmup')
+        lr_sched = optim.PolynomialDecayLRScheduler(
+            optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs)
+    elif cfg.lr_scheduler == 'inverse_sqrt':
+        kwargs = dict_from_config(cfg.inverse_sqrt)
+        warmup_steps = kwargs.pop('warmup')
+        lr_sched = optim.InverseSquareRootLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs)
+    elif cfg.lr_scheduler == 'linear_warmup':
+        kwargs = dict_from_config(cfg.linear_warmup)
+        warmup_steps = kwargs.pop('warmup')
+        lr_sched = optim.LinearWarmupLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs)
+    elif cfg.lr_scheduler is not None:
+        raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}")
+    return lr_sched
+
+
+def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp.Optional[optim.ModuleDictEMA]:
+    """Initialize Exponential Moving Average.
+
+    Args:
+        module_dict (nn.ModuleDict): ModuleDict for which to compute the EMA.
+        cfg (omegaconf.DictConfig): Optim EMA configuration.
+    Returns:
+        optim.ModuleDictEMA: EMA version of the ModuleDict.
+    """
+    kw: tp.Dict[str, tp.Any] = dict(cfg)
+    use = kw.pop('use', False)
+    decay = kw.pop('decay', None)
+    device = kw.pop('device', None)
+    if not use:
+        return None
+    if len(module_dict) == 0:
+        raise ValueError("Trying to build EMA but an empty module_dict source is provided!")
+    ema_module = optim.ModuleDictEMA(module_dict, decay=decay, device=device)
+    return ema_module
+
+
+def get_loss(loss_name: str, cfg: omegaconf.DictConfig):
+    """Instantiate loss from configuration."""
+    klass = {
+        'l1': torch.nn.L1Loss,
+        'l2': torch.nn.MSELoss,
+        'mel': losses.MelSpectrogramL1Loss,
+        'mrstft': losses.MRSTFTLoss,
+        'msspec': losses.MultiScaleMelSpectrogramLoss,
+        'sisnr': losses.SISNR,
+    }[loss_name]
+    kwargs = dict(getattr(cfg, loss_name))
+    return klass(**kwargs)
+
+
+def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictConfig) -> losses.Balancer:
+    """Instantiate loss balancer from configuration for the provided weights."""
+    kwargs: tp.Dict[str, tp.Any] = dict_from_config(cfg)
+    return losses.Balancer(loss_weights, **kwargs)
+
+
+def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module:
+    """Initialize adversary from config."""
+    klass = {
+        'msd': adversarial.MultiScaleDiscriminator,
+        'mpd': adversarial.MultiPeriodDiscriminator,
+        'msstftd': adversarial.MultiScaleSTFTDiscriminator,
+    }[name]
+    adv_cfg: tp.Dict[str, tp.Any] = dict(getattr(cfg, name))
+    return klass(**adv_cfg)
+
+
+def get_adversarial_losses(cfg) -> nn.ModuleDict:
+    """Initialize dict of adversarial losses from config."""
+    device = cfg.device
+    adv_cfg = getattr(cfg, 'adversarial')
+    adversaries = adv_cfg.get('adversaries', [])
+    adv_loss_name = adv_cfg['adv_loss']
+    feat_loss_name = adv_cfg.get('feat_loss')
+    normalize = adv_cfg.get('normalize', True)
+    feat_loss: tp.Optional[adversarial.FeatureMatchingLoss] = None
+    if feat_loss_name:
+        assert feat_loss_name in ['l1', 'l2'], f"Feature loss only support L1 or L2 but {feat_loss_name} found."
+        loss = get_loss(feat_loss_name, cfg)
+        feat_loss = adversarial.FeatureMatchingLoss(loss, normalize)
+    loss = adversarial.get_adv_criterion(adv_loss_name)
+    loss_real = adversarial.get_real_criterion(adv_loss_name)
+    loss_fake = adversarial.get_fake_criterion(adv_loss_name)
+    adv_losses = nn.ModuleDict()
+    for adv_name in adversaries:
+        adversary = get_adversary(adv_name, cfg).to(device)
+        optimizer = get_optimizer(adversary.parameters(), cfg.optim)
+        adv_loss = adversarial.AdversarialLoss(
+            adversary,
+            optimizer,
+            loss=loss,
+            loss_real=loss_real,
+            loss_fake=loss_fake,
+            loss_feat=feat_loss,
+            normalize=normalize
+        )
+        adv_losses[adv_name] = adv_loss
+    return adv_losses
+
+
+def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL:
+    """Instantiate ViSQOL metric from config."""
+    kwargs = dict_from_config(cfg)
+    return metrics.ViSQOL(**kwargs)
+
+
+def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMetric:
+    """Instantiate Frechet Audio Distance metric from config."""
+    kwargs = dict_from_config(cfg.tf)
+    xp = dora.get_xp()
+    kwargs['log_folder'] = xp.folder
+    return metrics.FrechetAudioDistanceMetric(**kwargs)
+
+
+def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric:
+    """Instantiate KL-Divergence metric from config."""
+    kld_metrics = {
+        'passt': metrics.PasstKLDivergenceMetric,
+    }
+    klass = kld_metrics[cfg.model]
+    kwargs = dict_from_config(cfg.get(cfg.model))
+    return klass(**kwargs)
+
+
+def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsistencyMetric:
+    """Instantiate Text Consistency metric from config."""
+    text_consistency_metrics = {
+        'clap': metrics.CLAPTextConsistencyMetric
+    }
+    klass = text_consistency_metrics[cfg.model]
+    kwargs = dict_from_config(cfg.get(cfg.model))
+    return klass(**kwargs)
+
+
+def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.ChromaCosineSimilarityMetric:
+    """Instantiate Chroma Cosine Similarity metric from config."""
+    assert cfg.model == 'chroma_base', "Only support 'chroma_base' method for chroma cosine similarity metric"
+    kwargs = dict_from_config(cfg.get(cfg.model))
+    return metrics.ChromaCosineSimilarityMetric(**kwargs)
+
+
+def get_audio_datasets(cfg: omegaconf.DictConfig,
+                       dataset_type: DatasetType = DatasetType.AUDIO) -> tp.Dict[str, torch.utils.data.DataLoader]:
+    """Build AudioDataset from configuration.
+
+    Args:
+        cfg (omegaconf.DictConfig): Configuration.
+        dataset_type: The type of dataset to create.
+    Returns:
+        dict[str, torch.utils.data.DataLoader]: Map of dataloader for each data split.
+    """
+    dataloaders: dict = {}
+
+    sample_rate = cfg.sample_rate
+    channels = cfg.channels
+    seed = cfg.seed
+    max_sample_rate = cfg.datasource.max_sample_rate
+    max_channels = cfg.datasource.max_channels
+
+    assert cfg.dataset is not None, "Could not find dataset definition in config"
+
+    dataset_cfg = dict_from_config(cfg.dataset)
+    splits_cfg: dict = {}
+    splits_cfg['train'] = dataset_cfg.pop('train')
+    splits_cfg['valid'] = dataset_cfg.pop('valid')
+    splits_cfg['evaluate'] = dataset_cfg.pop('evaluate')
+    splits_cfg['generate'] = dataset_cfg.pop('generate')
+    execute_only_stage = cfg.get('execute_only', None)
+
+    for split, path in cfg.datasource.items():
+        if not isinstance(path, str):
+            continue  # skipping this as not a path
+        if execute_only_stage is not None and split != execute_only_stage:
+            continue
+        logger.info(f"Loading audio data split {split}: {str(path)}")
+        assert (
+            cfg.sample_rate <= max_sample_rate
+        ), f"Expecting a max sample rate of {max_sample_rate} for datasource but {sample_rate} found."
+        assert (
+            cfg.channels <= max_channels
+        ), f"Expecting a max number of channels of {max_channels} for datasource but {channels} found."
+
+        split_cfg = splits_cfg[split]
+        split_kwargs = {k: v for k, v in split_cfg.items()}
+        kwargs = {**dataset_cfg, **split_kwargs}  # split kwargs overrides default dataset_cfg
+        kwargs['sample_rate'] = sample_rate
+        kwargs['channels'] = channels
+
+        if kwargs.get('permutation_on_files') and cfg.optim.updates_per_epoch:
+            kwargs['num_samples'] = (
+                flashy.distrib.world_size() * cfg.dataset.batch_size * cfg.optim.updates_per_epoch)
+
+        num_samples = kwargs['num_samples']
+        shuffle = kwargs['shuffle']
+
+        return_info = kwargs.pop('return_info')
+        batch_size = kwargs.pop('batch_size', None)
+        num_workers = kwargs.pop('num_workers')
+
+        if dataset_type == DatasetType.MUSIC:
+            dataset = data.music_dataset.MusicDataset.from_meta(path, **kwargs)
+        elif dataset_type == DatasetType.SOUND:
+            dataset = data.sound_dataset.SoundDataset.from_meta(path, **kwargs)
+        elif dataset_type == DatasetType.AUDIO:
+            dataset = data.info_audio_dataset.InfoAudioDataset.from_meta(path, return_info=return_info, **kwargs)
+        else:
+            raise ValueError(f"Dataset type is unsupported: {dataset_type}")
+
+        loader = get_loader(
+            dataset,
+            num_samples,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            seed=seed,
+            collate_fn=dataset.collater if return_info else None,
+            shuffle=shuffle,
+        )
+        dataloaders[split] = loader
+
+    return dataloaders
+
+
+
+
+
+
+
+

Functions

+
+
+def get_adversarial_losses(cfg) ‑> torch.nn.modules.container.ModuleDict +
+
+

Initialize dict of adversarial losses from config.

+
+ +Expand source code + +
def get_adversarial_losses(cfg) -> nn.ModuleDict:
+    """Initialize dict of adversarial losses from config."""
+    device = cfg.device
+    adv_cfg = getattr(cfg, 'adversarial')
+    adversaries = adv_cfg.get('adversaries', [])
+    adv_loss_name = adv_cfg['adv_loss']
+    feat_loss_name = adv_cfg.get('feat_loss')
+    normalize = adv_cfg.get('normalize', True)
+    feat_loss: tp.Optional[adversarial.FeatureMatchingLoss] = None
+    if feat_loss_name:
+        assert feat_loss_name in ['l1', 'l2'], f"Feature loss only support L1 or L2 but {feat_loss_name} found."
+        loss = get_loss(feat_loss_name, cfg)
+        feat_loss = adversarial.FeatureMatchingLoss(loss, normalize)
+    loss = adversarial.get_adv_criterion(adv_loss_name)
+    loss_real = adversarial.get_real_criterion(adv_loss_name)
+    loss_fake = adversarial.get_fake_criterion(adv_loss_name)
+    adv_losses = nn.ModuleDict()
+    for adv_name in adversaries:
+        adversary = get_adversary(adv_name, cfg).to(device)
+        optimizer = get_optimizer(adversary.parameters(), cfg.optim)
+        adv_loss = adversarial.AdversarialLoss(
+            adversary,
+            optimizer,
+            loss=loss,
+            loss_real=loss_real,
+            loss_fake=loss_fake,
+            loss_feat=feat_loss,
+            normalize=normalize
+        )
+        adv_losses[adv_name] = adv_loss
+    return adv_losses
+
+
+
+def get_adversary(name: str, cfg: omegaconf.dictconfig.DictConfig) ‑> torch.nn.modules.module.Module +
+
+

Initialize adversary from config.

+
+ +Expand source code + +
def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module:
+    """Initialize adversary from config."""
+    klass = {
+        'msd': adversarial.MultiScaleDiscriminator,
+        'mpd': adversarial.MultiPeriodDiscriminator,
+        'msstftd': adversarial.MultiScaleSTFTDiscriminator,
+    }[name]
+    adv_cfg: tp.Dict[str, tp.Any] = dict(getattr(cfg, name))
+    return klass(**adv_cfg)
+
+
+
+def get_audio_datasets(cfg: omegaconf.dictconfig.DictConfig, dataset_type: DatasetType = DatasetType.AUDIO) ‑> Dict[str, torch.utils.data.dataloader.DataLoader] +
+
+

Build AudioDataset from configuration.

+

Args

+
+
cfg : omegaconf.DictConfig
+
Configuration.
+
dataset_type
+
The type of dataset to create.
+
+

Returns

+
+
dict[str, torch.utils.data.DataLoader]
+
Map of dataloader for each data split.
+
+
+ +Expand source code + +
def get_audio_datasets(cfg: omegaconf.DictConfig,
+                       dataset_type: DatasetType = DatasetType.AUDIO) -> tp.Dict[str, torch.utils.data.DataLoader]:
+    """Build AudioDataset from configuration.
+
+    Args:
+        cfg (omegaconf.DictConfig): Configuration.
+        dataset_type: The type of dataset to create.
+    Returns:
+        dict[str, torch.utils.data.DataLoader]: Map of dataloader for each data split.
+    """
+    dataloaders: dict = {}
+
+    sample_rate = cfg.sample_rate
+    channels = cfg.channels
+    seed = cfg.seed
+    max_sample_rate = cfg.datasource.max_sample_rate
+    max_channels = cfg.datasource.max_channels
+
+    assert cfg.dataset is not None, "Could not find dataset definition in config"
+
+    dataset_cfg = dict_from_config(cfg.dataset)
+    splits_cfg: dict = {}
+    splits_cfg['train'] = dataset_cfg.pop('train')
+    splits_cfg['valid'] = dataset_cfg.pop('valid')
+    splits_cfg['evaluate'] = dataset_cfg.pop('evaluate')
+    splits_cfg['generate'] = dataset_cfg.pop('generate')
+    execute_only_stage = cfg.get('execute_only', None)
+
+    for split, path in cfg.datasource.items():
+        if not isinstance(path, str):
+            continue  # skipping this as not a path
+        if execute_only_stage is not None and split != execute_only_stage:
+            continue
+        logger.info(f"Loading audio data split {split}: {str(path)}")
+        assert (
+            cfg.sample_rate <= max_sample_rate
+        ), f"Expecting a max sample rate of {max_sample_rate} for datasource but {sample_rate} found."
+        assert (
+            cfg.channels <= max_channels
+        ), f"Expecting a max number of channels of {max_channels} for datasource but {channels} found."
+
+        split_cfg = splits_cfg[split]
+        split_kwargs = {k: v for k, v in split_cfg.items()}
+        kwargs = {**dataset_cfg, **split_kwargs}  # split kwargs overrides default dataset_cfg
+        kwargs['sample_rate'] = sample_rate
+        kwargs['channels'] = channels
+
+        if kwargs.get('permutation_on_files') and cfg.optim.updates_per_epoch:
+            kwargs['num_samples'] = (
+                flashy.distrib.world_size() * cfg.dataset.batch_size * cfg.optim.updates_per_epoch)
+
+        num_samples = kwargs['num_samples']
+        shuffle = kwargs['shuffle']
+
+        return_info = kwargs.pop('return_info')
+        batch_size = kwargs.pop('batch_size', None)
+        num_workers = kwargs.pop('num_workers')
+
+        if dataset_type == DatasetType.MUSIC:
+            dataset = data.music_dataset.MusicDataset.from_meta(path, **kwargs)
+        elif dataset_type == DatasetType.SOUND:
+            dataset = data.sound_dataset.SoundDataset.from_meta(path, **kwargs)
+        elif dataset_type == DatasetType.AUDIO:
+            dataset = data.info_audio_dataset.InfoAudioDataset.from_meta(path, return_info=return_info, **kwargs)
+        else:
+            raise ValueError(f"Dataset type is unsupported: {dataset_type}")
+
+        loader = get_loader(
+            dataset,
+            num_samples,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            seed=seed,
+            collate_fn=dataset.collater if return_info else None,
+            shuffle=shuffle,
+        )
+        dataloaders[split] = loader
+
+    return dataloaders
+
+
+
+def get_balancer(loss_weights: Dict[str, float], cfg: omegaconf.dictconfig.DictConfig) ‑> Balancer +
+
+

Instantiate loss balancer from configuration for the provided weights.

+
+ +Expand source code + +
def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictConfig) -> losses.Balancer:
+    """Instantiate loss balancer from configuration for the provided weights."""
+    kwargs: tp.Dict[str, tp.Any] = dict_from_config(cfg)
+    return losses.Balancer(loss_weights, **kwargs)
+
+
+
+def get_chroma_cosine_similarity(cfg: omegaconf.dictconfig.DictConfig) ‑> ChromaCosineSimilarityMetric +
+
+

Instantiate Chroma Cosine Similarity metric from config.

+
+ +Expand source code + +
def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.ChromaCosineSimilarityMetric:
+    """Instantiate Chroma Cosine Similarity metric from config."""
+    assert cfg.model == 'chroma_base', "Only support 'chroma_base' method for chroma cosine similarity metric"
+    kwargs = dict_from_config(cfg.get(cfg.model))
+    return metrics.ChromaCosineSimilarityMetric(**kwargs)
+
+
+
+def get_ema(module_dict: torch.nn.modules.container.ModuleDict, cfg: omegaconf.dictconfig.DictConfig) ‑> Optional[ModuleDictEMA] +
+
+

Initialize Exponential Moving Average.

+

Args

+
+
module_dict : nn.ModuleDict
+
ModuleDict for which to compute the EMA.
+
cfg : omegaconf.DictConfig
+
Optim EMA configuration.
+
+

Returns

+
+
optim.ModuleDictEMA
+
EMA version of the ModuleDict.
+
+
+ +Expand source code + +
def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp.Optional[optim.ModuleDictEMA]:
+    """Initialize Exponential Moving Average.
+
+    Args:
+        module_dict (nn.ModuleDict): ModuleDict for which to compute the EMA.
+        cfg (omegaconf.DictConfig): Optim EMA configuration.
+    Returns:
+        optim.ModuleDictEMA: EMA version of the ModuleDict.
+    """
+    kw: tp.Dict[str, tp.Any] = dict(cfg)
+    use = kw.pop('use', False)
+    decay = kw.pop('decay', None)
+    device = kw.pop('device', None)
+    if not use:
+        return None
+    if len(module_dict) == 0:
+        raise ValueError("Trying to build EMA but an empty module_dict source is provided!")
+    ema_module = optim.ModuleDictEMA(module_dict, decay=decay, device=device)
+    return ema_module
+
+
+
+def get_fad(cfg: omegaconf.dictconfig.DictConfig) ‑> FrechetAudioDistanceMetric +
+
+

Instantiate Frechet Audio Distance metric from config.

+
+ +Expand source code + +
def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMetric:
+    """Instantiate Frechet Audio Distance metric from config."""
+    kwargs = dict_from_config(cfg.tf)
+    xp = dora.get_xp()
+    kwargs['log_folder'] = xp.folder
+    return metrics.FrechetAudioDistanceMetric(**kwargs)
+
+
+
+def get_kldiv(cfg: omegaconf.dictconfig.DictConfig) ‑> KLDivergenceMetric +
+
+

Instantiate KL-Divergence metric from config.

+
+ +Expand source code + +
def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric:
+    """Instantiate KL-Divergence metric from config."""
+    kld_metrics = {
+        'passt': metrics.PasstKLDivergenceMetric,
+    }
+    klass = kld_metrics[cfg.model]
+    kwargs = dict_from_config(cfg.get(cfg.model))
+    return klass(**kwargs)
+
+
+
+def get_loss(loss_name: str, cfg: omegaconf.dictconfig.DictConfig) +
+
+

Instantiate loss from configuration.

+
+ +Expand source code + +
def get_loss(loss_name: str, cfg: omegaconf.DictConfig):
+    """Instantiate loss from configuration."""
+    klass = {
+        'l1': torch.nn.L1Loss,
+        'l2': torch.nn.MSELoss,
+        'mel': losses.MelSpectrogramL1Loss,
+        'mrstft': losses.MRSTFTLoss,
+        'msspec': losses.MultiScaleMelSpectrogramLoss,
+        'sisnr': losses.SISNR,
+    }[loss_name]
+    kwargs = dict(getattr(cfg, loss_name))
+    return klass(**kwargs)
+
+
+
+def get_lr_scheduler(optimizer: torch.optim.optimizer.Optimizer, cfg: omegaconf.dictconfig.DictConfig, total_updates: int) ‑> Optional[torch.optim.lr_scheduler.LRScheduler] +
+
+

Build torch learning rate scheduler from config and associated optimizer. +Supported learning rate schedulers: ExponentialLRScheduler, PlateauLRScheduler

+

Args

+
+
optimizer : torch.optim.Optimizer
+
Optimizer.
+
cfg : DictConfig
+
Schedule-related configuration.
+
total_updates : int
+
Total number of updates.
+
+

Returns

+

torch.optim.Optimizer.

+
+ +Expand source code + +
def get_lr_scheduler(optimizer: torch.optim.Optimizer,
+                     cfg: omegaconf.DictConfig,
+                     total_updates: int) -> tp.Optional[LRScheduler]:
+    """Build torch learning rate scheduler from config and associated optimizer.
+    Supported learning rate schedulers: ExponentialLRScheduler, PlateauLRScheduler
+
+    Args:
+        optimizer (torch.optim.Optimizer): Optimizer.
+        cfg (DictConfig): Schedule-related configuration.
+        total_updates (int): Total number of updates.
+    Returns:
+        torch.optim.Optimizer.
+    """
+    if 'lr_scheduler' not in cfg:
+        raise KeyError("LR Scheduler not found in config")
+
+    lr_sched: tp.Optional[LRScheduler] = None
+    if cfg.lr_scheduler == 'step':
+        lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, **cfg.step)
+    elif cfg.lr_scheduler == 'exponential':
+        lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.exponential)
+    elif cfg.lr_scheduler == 'cosine':
+        kwargs = dict_from_config(cfg.cosine)
+        warmup_steps = kwargs.pop('warmup')
+        lr_sched = optim.CosineLRScheduler(
+            optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs)
+    elif cfg.lr_scheduler == 'polynomial_decay':
+        kwargs = dict_from_config(cfg.polynomial_decay)
+        warmup_steps = kwargs.pop('warmup')
+        lr_sched = optim.PolynomialDecayLRScheduler(
+            optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs)
+    elif cfg.lr_scheduler == 'inverse_sqrt':
+        kwargs = dict_from_config(cfg.inverse_sqrt)
+        warmup_steps = kwargs.pop('warmup')
+        lr_sched = optim.InverseSquareRootLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs)
+    elif cfg.lr_scheduler == 'linear_warmup':
+        kwargs = dict_from_config(cfg.linear_warmup)
+        warmup_steps = kwargs.pop('warmup')
+        lr_sched = optim.LinearWarmupLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs)
+    elif cfg.lr_scheduler is not None:
+        raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}")
+    return lr_sched
+
+
+
+def get_optim_parameter_groups(model: torch.nn.modules.module.Module) +
+
+

Create parameter groups for the model using the appropriate method +if defined for each modules, to create the different groups.

+

Args

+
+
model : nn.Module
+
torch model
+
+

Returns

+

List of parameter groups

+
+ +Expand source code + +
def get_optim_parameter_groups(model: nn.Module):
+    """Create parameter groups for the model using the appropriate method
+    if defined for each modules, to create the different groups.
+
+    Args:
+        model (nn.Module): torch model
+    Returns:
+        List of parameter groups
+    """
+    seen_params: tp.Set[nn.parameter.Parameter] = set()
+    other_params = []
+    groups = []
+    for name, module in model.named_modules():
+        if hasattr(module, 'make_optim_group'):
+            group = module.make_optim_group()
+            params = set(group['params'])
+            assert params.isdisjoint(seen_params)
+            seen_params |= set(params)
+            groups.append(group)
+    for param in model.parameters():
+        if param not in seen_params:
+            other_params.append(param)
+    groups.insert(0, {'params': other_params})
+    parameters = groups
+    return parameters
+
+
+
+def get_optimizer(params: Union[torch.nn.modules.module.Module, Iterable[torch.Tensor]], cfg: omegaconf.dictconfig.DictConfig) ‑> torch.optim.optimizer.Optimizer +
+
+

Build torch optimizer from config and set of parameters. +Supported optimizers: Adam, AdamW

+

Args

+
+
params : nn.Module or iterable of torch.Tensor
+
Parameters to optimize.
+
cfg : DictConfig
+
Optimization-related configuration.
+
+

Returns

+

torch.optim.Optimizer.

+
+ +Expand source code + +
def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]], cfg: omegaconf.DictConfig) -> Optimizer:
+    """Build torch optimizer from config and set of parameters.
+    Supported optimizers: Adam, AdamW
+
+    Args:
+        params (nn.Module or iterable of torch.Tensor): Parameters to optimize.
+        cfg (DictConfig): Optimization-related configuration.
+    Returns:
+        torch.optim.Optimizer.
+    """
+    if 'optimizer' not in cfg:
+        if getattr(cfg, 'optim', None) is not None:
+            raise KeyError("Optimizer not found in config. Try instantiating optimizer from cfg.optim?")
+        else:
+            raise KeyError("Optimizer not found in config.")
+
+    parameters = get_optim_parameter_groups(params) if isinstance(params, nn.Module) else params
+    optimizer: torch.optim.Optimizer
+    if cfg.optimizer == 'adam':
+        optimizer = torch.optim.Adam(parameters, lr=cfg.lr, **cfg.adam)
+    elif cfg.optimizer == 'adamw':
+        optimizer = torch.optim.AdamW(parameters, lr=cfg.lr, **cfg.adam)
+    elif cfg.optimizer == 'dadam':
+        optimizer = optim.DAdaptAdam(parameters, lr=cfg.lr, **cfg.adam)
+    else:
+        raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}")
+    return optimizer
+
+
+
+def get_solver(cfg: omegaconf.dictconfig.DictConfig) ‑> StandardSolver +
+
+

Instantiate solver from config.

+
+ +Expand source code + +
def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver:
+    """Instantiate solver from config."""
+    from .audiogen import AudioGenSolver
+    from .compression import CompressionSolver
+    from .musicgen import MusicGenSolver
+    from .diffusion import DiffusionSolver
+    klass = {
+        'compression': CompressionSolver,
+        'musicgen': MusicGenSolver,
+        'audiogen': AudioGenSolver,
+        'lm': MusicGenSolver,  # backward compatibility
+        'diffusion': DiffusionSolver,
+        'sound_lm': AudioGenSolver,  # backward compatibility
+    }[cfg.solver]
+    return klass(cfg)  # type: ignore
+
+
+
+def get_text_consistency(cfg: omegaconf.dictconfig.DictConfig) ‑> TextConsistencyMetric +
+
+

Instantiate Text Consistency metric from config.

+
+ +Expand source code + +
def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsistencyMetric:
+    """Instantiate Text Consistency metric from config."""
+    text_consistency_metrics = {
+        'clap': metrics.CLAPTextConsistencyMetric
+    }
+    klass = text_consistency_metrics[cfg.model]
+    kwargs = dict_from_config(cfg.get(cfg.model))
+    return klass(**kwargs)
+
+
+
+def get_visqol(cfg: omegaconf.dictconfig.DictConfig) ‑> ViSQOL +
+
+

Instantiate ViSQOL metric from config.

+
+ +Expand source code + +
def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL:
+    """Instantiate ViSQOL metric from config."""
+    kwargs = dict_from_config(cfg)
+    return metrics.ViSQOL(**kwargs)
+
+
+
+
+
+

Classes

+
+
+class DatasetType +(value, names=None, *, module=None, qualname=None, type=None, start=1) +
+
+

An enumeration.

+
+ +Expand source code + +
class DatasetType(Enum):
+    AUDIO = "audio"
+    MUSIC = "music"
+    SOUND = "sound"
+
+

Ancestors

+
    +
  • enum.Enum
  • +
+

Class variables

+
+
var AUDIO
+
+
+
+
var MUSIC
+
+
+
+
var SOUND
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/solvers/compression.html b/api_docs/audiocraft/solvers/compression.html new file mode 100644 index 00000000..2dc643fc --- /dev/null +++ b/api_docs/audiocraft/solvers/compression.html @@ -0,0 +1,1010 @@ + + + + + + +audiocraft.solvers.compression API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.solvers.compression

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import multiprocessing
+from pathlib import Path
+import typing as tp
+
+import flashy
+import omegaconf
+import torch
+from torch import nn
+
+from . import base, builders
+from .. import models, quantization
+from ..utils import checkpoint
+from ..utils.samples.manager import SampleManager
+from ..utils.utils import get_pool_executor
+
+
+logger = logging.getLogger(__name__)
+
+
+class CompressionSolver(base.StandardSolver):
+    """Solver for compression task.
+
+    The compression task combines a set of perceptual and objective losses
+    to train an EncodecModel (composed of an encoder-decoder and a quantizer)
+    to perform high fidelity audio reconstruction.
+    """
+    def __init__(self, cfg: omegaconf.DictConfig):
+        super().__init__(cfg)
+        self.rng: torch.Generator  # set at each epoch
+        self.adv_losses = builders.get_adversarial_losses(self.cfg)
+        self.aux_losses = nn.ModuleDict()
+        self.info_losses = nn.ModuleDict()
+        assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
+        loss_weights = dict()
+        for loss_name, weight in self.cfg.losses.items():
+            if loss_name in ['adv', 'feat']:
+                for adv_name, _ in self.adv_losses.items():
+                    loss_weights[f'{loss_name}_{adv_name}'] = weight
+            elif weight > 0:
+                self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
+                loss_weights[loss_name] = weight
+            else:
+                self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
+        self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)
+        self.register_stateful('adv_losses')
+
+    @property
+    def best_metric_name(self) -> tp.Optional[str]:
+        # best model is the last for the compression model
+        return None
+
+    def build_model(self):
+        """Instantiate model and optimizer."""
+        # Model and optimizer
+        self.model = models.builders.get_compression_model(self.cfg).to(self.device)
+        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
+        self.register_stateful('model', 'optimizer')
+        self.register_best_state('model')
+        self.register_ema('model')
+
+    def build_dataloaders(self):
+        """Instantiate audio dataloaders for each stage."""
+        self.dataloaders = builders.get_audio_datasets(self.cfg)
+
+    def show(self):
+        """Show the compression model and employed adversarial loss."""
+        self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:")
+        self.log_model_summary(self.model)
+        self.logger.info("Adversarial loss:")
+        self.log_model_summary(self.adv_losses)
+        self.logger.info("Auxiliary losses:")
+        self.logger.info(self.aux_losses)
+        self.logger.info("Info losses:")
+        self.logger.info(self.info_losses)
+
+    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
+        """Perform one training or valid step on a given batch."""
+        x = batch.to(self.device)
+        y = x.clone()
+
+        qres = self.model(x)
+        assert isinstance(qres, quantization.QuantizedResult)
+        y_pred = qres.x
+        # Log bandwidth in kb/s
+        metrics['bandwidth'] = qres.bandwidth.mean()
+
+        if self.is_training:
+            d_losses: dict = {}
+            if len(self.adv_losses) > 0 and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every:
+                for adv_name, adversary in self.adv_losses.items():
+                    disc_loss = adversary.train_adv(y_pred, y)
+                    d_losses[f'd_{adv_name}'] = disc_loss
+                metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values())))
+            metrics.update(d_losses)
+
+        balanced_losses: dict = {}
+        other_losses: dict = {}
+
+        # penalty from quantization
+        if qres.penalty is not None and qres.penalty.requires_grad:
+            other_losses['penalty'] = qres.penalty  # penalty term from the quantizer
+
+        # adversarial losses
+        for adv_name, adversary in self.adv_losses.items():
+            adv_loss, feat_loss = adversary(y_pred, y)
+            balanced_losses[f'adv_{adv_name}'] = adv_loss
+            balanced_losses[f'feat_{adv_name}'] = feat_loss
+
+        # auxiliary losses
+        for loss_name, criterion in self.aux_losses.items():
+            loss = criterion(y_pred, y)
+            balanced_losses[loss_name] = loss
+
+        # weighted losses
+        metrics.update(balanced_losses)
+        metrics.update(other_losses)
+        metrics.update(qres.metrics)
+
+        if self.is_training:
+            # backprop losses that are not handled by balancer
+            other_loss = torch.tensor(0., device=self.device)
+            if 'penalty' in other_losses:
+                other_loss += other_losses['penalty']
+            if other_loss.requires_grad:
+                other_loss.backward(retain_graph=True)
+                ratio1 = sum(p.grad.data.norm(p=2).pow(2)
+                             for p in self.model.parameters() if p.grad is not None)
+                assert isinstance(ratio1, torch.Tensor)
+                metrics['ratio1'] = ratio1.sqrt()
+
+            # balancer losses backward, returns effective training loss
+            # with effective weights at the current batch.
+            metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred)
+            # add metrics corresponding to weight ratios
+            metrics.update(self.balancer.metrics)
+            ratio2 = sum(p.grad.data.norm(p=2).pow(2)
+                         for p in self.model.parameters() if p.grad is not None)
+            assert isinstance(ratio2, torch.Tensor)
+            metrics['ratio2'] = ratio2.sqrt()
+
+            # optim
+            flashy.distrib.sync_model(self.model)
+            if self.cfg.optim.max_norm:
+                torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(), self.cfg.optim.max_norm
+                )
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+
+        # informative losses only
+        info_losses: dict = {}
+        with torch.no_grad():
+            for loss_name, criterion in self.info_losses.items():
+                loss = criterion(y_pred, y)
+                info_losses[loss_name] = loss
+
+        metrics.update(info_losses)
+
+        # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups
+        adv_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('adv')]
+        if len(adv_losses) > 0:
+            metrics['adv'] = torch.sum(torch.stack(adv_losses))
+        feat_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('feat')]
+        if len(feat_losses) > 0:
+            metrics['feat'] = torch.sum(torch.stack(feat_losses))
+
+        return metrics
+
+    def run_epoch(self):
+        # reset random seed at the beginning of the epoch
+        self.rng = torch.Generator()
+        self.rng.manual_seed(1234 + self.epoch)
+        # run epoch
+        super().run_epoch()
+
+    def evaluate(self):
+        """Evaluate stage. Runs audio reconstruction evaluation."""
+        self.model.eval()
+        evaluate_stage_name = str(self.current_stage)
+
+        loader = self.dataloaders['evaluate']
+        updates = len(loader)
+        lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
+        average = flashy.averager()
+
+        pendings = []
+        ctx = multiprocessing.get_context('spawn')
+        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
+            for idx, batch in enumerate(lp):
+                x = batch.to(self.device)
+                with torch.no_grad():
+                    qres = self.model(x)
+
+                y_pred = qres.x.cpu()
+                y = batch.cpu()  # should already be on CPU but just in case
+                pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))
+
+            metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
+            for pending in metrics_lp:
+                metrics = pending.result()
+                metrics = average(metrics)
+
+        metrics = flashy.distrib.average_metrics(metrics, len(loader))
+        return metrics
+
+    def generate(self):
+        """Generate stage."""
+        self.model.eval()
+        sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
+        generate_stage_name = str(self.current_stage)
+
+        loader = self.dataloaders['generate']
+        updates = len(loader)
+        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
+
+        for batch in lp:
+            reference, _ = batch
+            reference = reference.to(self.device)
+            with torch.no_grad():
+                qres = self.model(reference)
+            assert isinstance(qres, quantization.QuantizedResult)
+
+            reference = reference.cpu()
+            estimate = qres.x.cpu()
+            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
+
+        flashy.distrib.barrier()
+
+    def load_from_pretrained(self, name: str) -> dict:
+        model = models.CompressionModel.get_pretrained(name)
+        if isinstance(model, models.DAC):
+            raise RuntimeError("Cannot fine tune a DAC model.")
+        elif isinstance(model, models.HFEncodecCompressionModel):
+            self.logger.warning('Trying to automatically convert a HuggingFace model '
+                                'to AudioCraft, this might fail!')
+            state = model.model.state_dict()
+            new_state = {}
+            for k, v in state.items():
+                if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
+                    # We need to determine if this a convtr or a regular conv.
+                    layer = int(k.split('.')[2])
+                    if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
+
+                        k = k.replace('.conv.', '.convtr.')
+                k = k.replace('encoder.layers.', 'encoder.model.')
+                k = k.replace('decoder.layers.', 'decoder.model.')
+                k = k.replace('conv.', 'conv.conv.')
+                k = k.replace('convtr.', 'convtr.convtr.')
+                k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
+                k = k.replace('.codebook.', '._codebook.')
+                new_state[k] = v
+            state = new_state
+        elif isinstance(model, models.EncodecModel):
+            state = model.state_dict()
+        else:
+            raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
+        return {
+            'best_state': {'model': state}
+        }
+
+    @staticmethod
+    def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
+                              device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
+        """Instantiate a CompressionModel from a given checkpoint path or dora sig.
+        This method is a convenient endpoint to load a CompressionModel to use in other solvers.
+
+        Args:
+            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
+                This also supports pre-trained models by using a path of the form //pretrained/NAME.
+                See `model_from_pretrained` for a list of supported pretrained models.
+            use_ema (bool): Use EMA variant of the model instead of the actual model.
+            device (torch.device or str): Device on which the model is loaded.
+        """
+        checkpoint_path = str(checkpoint_path)
+        if checkpoint_path.startswith('//pretrained/'):
+            name = checkpoint_path.split('/', 3)[-1]
+            return models.CompressionModel.get_pretrained(name, device)
+        logger = logging.getLogger(__name__)
+        logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
+        _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
+        assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
+        state = checkpoint.load_checkpoint(_checkpoint_path)
+        assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
+        cfg = state['xp.cfg']
+        cfg.device = device
+        compression_model = models.builders.get_compression_model(cfg).to(device)
+        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
+
+        assert 'best_state' in state and state['best_state'] != {}
+        assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
+        compression_model.load_state_dict(state['best_state']['model'])
+        compression_model.eval()
+        logger.info("Compression model loaded!")
+        return compression_model
+
+    @staticmethod
+    def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
+                                      checkpoint_path: tp.Union[Path, str],
+                                      device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
+        """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.
+
+        Args:
+            cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
+            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
+            use_ema (bool): Use EMA variant of the model instead of the actual model.
+            device (torch.device or str): Device on which the model is loaded.
+        """
+        compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
+        compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
+        return compression_model
+
+
+def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict:
+    """Audio reconstruction evaluation method that can be conveniently pickled."""
+    metrics = {}
+    if cfg.evaluate.metrics.visqol:
+        visqol = builders.get_visqol(cfg.metrics.visqol)
+        metrics['visqol'] = visqol(y_pred, y, cfg.sample_rate)
+    sisnr = builders.get_loss('sisnr', cfg)
+    metrics['sisnr'] = sisnr(y_pred, y)
+    return metrics
+
+
+
+
+
+
+
+

Functions

+
+
+def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.dictconfig.DictConfig) ‑> dict +
+
+

Audio reconstruction evaluation method that can be conveniently pickled.

+
+ +Expand source code + +
def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict:
+    """Audio reconstruction evaluation method that can be conveniently pickled."""
+    metrics = {}
+    if cfg.evaluate.metrics.visqol:
+        visqol = builders.get_visqol(cfg.metrics.visqol)
+        metrics['visqol'] = visqol(y_pred, y, cfg.sample_rate)
+    sisnr = builders.get_loss('sisnr', cfg)
+    metrics['sisnr'] = sisnr(y_pred, y)
+    return metrics
+
+
+
+
+
+

Classes

+
+
+class CompressionSolver +(cfg: omegaconf.dictconfig.DictConfig) +
+
+

Solver for compression task.

+

The compression task combines a set of perceptual and objective losses +to train an EncodecModel (composed of an encoder-decoder and a quantizer) +to perform high fidelity audio reconstruction.

+
+ +Expand source code + +
class CompressionSolver(base.StandardSolver):
+    """Solver for compression task.
+
+    The compression task combines a set of perceptual and objective losses
+    to train an EncodecModel (composed of an encoder-decoder and a quantizer)
+    to perform high fidelity audio reconstruction.
+    """
+    def __init__(self, cfg: omegaconf.DictConfig):
+        super().__init__(cfg)
+        self.rng: torch.Generator  # set at each epoch
+        self.adv_losses = builders.get_adversarial_losses(self.cfg)
+        self.aux_losses = nn.ModuleDict()
+        self.info_losses = nn.ModuleDict()
+        assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
+        loss_weights = dict()
+        for loss_name, weight in self.cfg.losses.items():
+            if loss_name in ['adv', 'feat']:
+                for adv_name, _ in self.adv_losses.items():
+                    loss_weights[f'{loss_name}_{adv_name}'] = weight
+            elif weight > 0:
+                self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
+                loss_weights[loss_name] = weight
+            else:
+                self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
+        self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)
+        self.register_stateful('adv_losses')
+
+    @property
+    def best_metric_name(self) -> tp.Optional[str]:
+        # best model is the last for the compression model
+        return None
+
+    def build_model(self):
+        """Instantiate model and optimizer."""
+        # Model and optimizer
+        self.model = models.builders.get_compression_model(self.cfg).to(self.device)
+        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
+        self.register_stateful('model', 'optimizer')
+        self.register_best_state('model')
+        self.register_ema('model')
+
+    def build_dataloaders(self):
+        """Instantiate audio dataloaders for each stage."""
+        self.dataloaders = builders.get_audio_datasets(self.cfg)
+
+    def show(self):
+        """Show the compression model and employed adversarial loss."""
+        self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:")
+        self.log_model_summary(self.model)
+        self.logger.info("Adversarial loss:")
+        self.log_model_summary(self.adv_losses)
+        self.logger.info("Auxiliary losses:")
+        self.logger.info(self.aux_losses)
+        self.logger.info("Info losses:")
+        self.logger.info(self.info_losses)
+
+    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
+        """Perform one training or valid step on a given batch."""
+        x = batch.to(self.device)
+        y = x.clone()
+
+        qres = self.model(x)
+        assert isinstance(qres, quantization.QuantizedResult)
+        y_pred = qres.x
+        # Log bandwidth in kb/s
+        metrics['bandwidth'] = qres.bandwidth.mean()
+
+        if self.is_training:
+            d_losses: dict = {}
+            if len(self.adv_losses) > 0 and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every:
+                for adv_name, adversary in self.adv_losses.items():
+                    disc_loss = adversary.train_adv(y_pred, y)
+                    d_losses[f'd_{adv_name}'] = disc_loss
+                metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values())))
+            metrics.update(d_losses)
+
+        balanced_losses: dict = {}
+        other_losses: dict = {}
+
+        # penalty from quantization
+        if qres.penalty is not None and qres.penalty.requires_grad:
+            other_losses['penalty'] = qres.penalty  # penalty term from the quantizer
+
+        # adversarial losses
+        for adv_name, adversary in self.adv_losses.items():
+            adv_loss, feat_loss = adversary(y_pred, y)
+            balanced_losses[f'adv_{adv_name}'] = adv_loss
+            balanced_losses[f'feat_{adv_name}'] = feat_loss
+
+        # auxiliary losses
+        for loss_name, criterion in self.aux_losses.items():
+            loss = criterion(y_pred, y)
+            balanced_losses[loss_name] = loss
+
+        # weighted losses
+        metrics.update(balanced_losses)
+        metrics.update(other_losses)
+        metrics.update(qres.metrics)
+
+        if self.is_training:
+            # backprop losses that are not handled by balancer
+            other_loss = torch.tensor(0., device=self.device)
+            if 'penalty' in other_losses:
+                other_loss += other_losses['penalty']
+            if other_loss.requires_grad:
+                other_loss.backward(retain_graph=True)
+                ratio1 = sum(p.grad.data.norm(p=2).pow(2)
+                             for p in self.model.parameters() if p.grad is not None)
+                assert isinstance(ratio1, torch.Tensor)
+                metrics['ratio1'] = ratio1.sqrt()
+
+            # balancer losses backward, returns effective training loss
+            # with effective weights at the current batch.
+            metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred)
+            # add metrics corresponding to weight ratios
+            metrics.update(self.balancer.metrics)
+            ratio2 = sum(p.grad.data.norm(p=2).pow(2)
+                         for p in self.model.parameters() if p.grad is not None)
+            assert isinstance(ratio2, torch.Tensor)
+            metrics['ratio2'] = ratio2.sqrt()
+
+            # optim
+            flashy.distrib.sync_model(self.model)
+            if self.cfg.optim.max_norm:
+                torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(), self.cfg.optim.max_norm
+                )
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+
+        # informative losses only
+        info_losses: dict = {}
+        with torch.no_grad():
+            for loss_name, criterion in self.info_losses.items():
+                loss = criterion(y_pred, y)
+                info_losses[loss_name] = loss
+
+        metrics.update(info_losses)
+
+        # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups
+        adv_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('adv')]
+        if len(adv_losses) > 0:
+            metrics['adv'] = torch.sum(torch.stack(adv_losses))
+        feat_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('feat')]
+        if len(feat_losses) > 0:
+            metrics['feat'] = torch.sum(torch.stack(feat_losses))
+
+        return metrics
+
+    def run_epoch(self):
+        # reset random seed at the beginning of the epoch
+        self.rng = torch.Generator()
+        self.rng.manual_seed(1234 + self.epoch)
+        # run epoch
+        super().run_epoch()
+
+    def evaluate(self):
+        """Evaluate stage. Runs audio reconstruction evaluation."""
+        self.model.eval()
+        evaluate_stage_name = str(self.current_stage)
+
+        loader = self.dataloaders['evaluate']
+        updates = len(loader)
+        lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
+        average = flashy.averager()
+
+        pendings = []
+        ctx = multiprocessing.get_context('spawn')
+        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
+            for idx, batch in enumerate(lp):
+                x = batch.to(self.device)
+                with torch.no_grad():
+                    qres = self.model(x)
+
+                y_pred = qres.x.cpu()
+                y = batch.cpu()  # should already be on CPU but just in case
+                pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))
+
+            metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
+            for pending in metrics_lp:
+                metrics = pending.result()
+                metrics = average(metrics)
+
+        metrics = flashy.distrib.average_metrics(metrics, len(loader))
+        return metrics
+
+    def generate(self):
+        """Generate stage."""
+        self.model.eval()
+        sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
+        generate_stage_name = str(self.current_stage)
+
+        loader = self.dataloaders['generate']
+        updates = len(loader)
+        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
+
+        for batch in lp:
+            reference, _ = batch
+            reference = reference.to(self.device)
+            with torch.no_grad():
+                qres = self.model(reference)
+            assert isinstance(qres, quantization.QuantizedResult)
+
+            reference = reference.cpu()
+            estimate = qres.x.cpu()
+            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
+
+        flashy.distrib.barrier()
+
+    def load_from_pretrained(self, name: str) -> dict:
+        model = models.CompressionModel.get_pretrained(name)
+        if isinstance(model, models.DAC):
+            raise RuntimeError("Cannot fine tune a DAC model.")
+        elif isinstance(model, models.HFEncodecCompressionModel):
+            self.logger.warning('Trying to automatically convert a HuggingFace model '
+                                'to AudioCraft, this might fail!')
+            state = model.model.state_dict()
+            new_state = {}
+            for k, v in state.items():
+                if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
+                    # We need to determine if this a convtr or a regular conv.
+                    layer = int(k.split('.')[2])
+                    if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
+
+                        k = k.replace('.conv.', '.convtr.')
+                k = k.replace('encoder.layers.', 'encoder.model.')
+                k = k.replace('decoder.layers.', 'decoder.model.')
+                k = k.replace('conv.', 'conv.conv.')
+                k = k.replace('convtr.', 'convtr.convtr.')
+                k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
+                k = k.replace('.codebook.', '._codebook.')
+                new_state[k] = v
+            state = new_state
+        elif isinstance(model, models.EncodecModel):
+            state = model.state_dict()
+        else:
+            raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
+        return {
+            'best_state': {'model': state}
+        }
+
+    @staticmethod
+    def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
+                              device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
+        """Instantiate a CompressionModel from a given checkpoint path or dora sig.
+        This method is a convenient endpoint to load a CompressionModel to use in other solvers.
+
+        Args:
+            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
+                This also supports pre-trained models by using a path of the form //pretrained/NAME.
+                See `model_from_pretrained` for a list of supported pretrained models.
+            use_ema (bool): Use EMA variant of the model instead of the actual model.
+            device (torch.device or str): Device on which the model is loaded.
+        """
+        checkpoint_path = str(checkpoint_path)
+        if checkpoint_path.startswith('//pretrained/'):
+            name = checkpoint_path.split('/', 3)[-1]
+            return models.CompressionModel.get_pretrained(name, device)
+        logger = logging.getLogger(__name__)
+        logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
+        _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
+        assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
+        state = checkpoint.load_checkpoint(_checkpoint_path)
+        assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
+        cfg = state['xp.cfg']
+        cfg.device = device
+        compression_model = models.builders.get_compression_model(cfg).to(device)
+        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
+
+        assert 'best_state' in state and state['best_state'] != {}
+        assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
+        compression_model.load_state_dict(state['best_state']['model'])
+        compression_model.eval()
+        logger.info("Compression model loaded!")
+        return compression_model
+
+    @staticmethod
+    def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
+                                      checkpoint_path: tp.Union[Path, str],
+                                      device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
+        """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.
+
+        Args:
+            cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
+            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
+            use_ema (bool): Use EMA variant of the model instead of the actual model.
+            device (torch.device or str): Device on which the model is loaded.
+        """
+        compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
+        compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
+        return compression_model
+
+

Ancestors

+ +

Static methods

+
+
+def model_from_checkpoint(checkpoint_path: Union[str, pathlib.Path], device: Union[torch.device, str] = 'cpu') ‑> CompressionModel +
+
+

Instantiate a CompressionModel from a given checkpoint path or dora sig. +This method is a convenient endpoint to load a CompressionModel to use in other solvers.

+

Args

+
+
checkpoint_path : Path or str
+
Path to checkpoint or dora sig from where the checkpoint is resolved. +This also supports pre-trained models by using a path of the form //pretrained/NAME. +See model_from_pretrained for a list of supported pretrained models.
+
use_ema : bool
+
Use EMA variant of the model instead of the actual model.
+
device : torch.device or str
+
Device on which the model is loaded.
+
+
+ +Expand source code + +
@staticmethod
+def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
+                          device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
+    """Instantiate a CompressionModel from a given checkpoint path or dora sig.
+    This method is a convenient endpoint to load a CompressionModel to use in other solvers.
+
+    Args:
+        checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
+            This also supports pre-trained models by using a path of the form //pretrained/NAME.
+            See `model_from_pretrained` for a list of supported pretrained models.
+        use_ema (bool): Use EMA variant of the model instead of the actual model.
+        device (torch.device or str): Device on which the model is loaded.
+    """
+    checkpoint_path = str(checkpoint_path)
+    if checkpoint_path.startswith('//pretrained/'):
+        name = checkpoint_path.split('/', 3)[-1]
+        return models.CompressionModel.get_pretrained(name, device)
+    logger = logging.getLogger(__name__)
+    logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
+    _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
+    assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
+    state = checkpoint.load_checkpoint(_checkpoint_path)
+    assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
+    cfg = state['xp.cfg']
+    cfg.device = device
+    compression_model = models.builders.get_compression_model(cfg).to(device)
+    assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
+
+    assert 'best_state' in state and state['best_state'] != {}
+    assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
+    compression_model.load_state_dict(state['best_state']['model'])
+    compression_model.eval()
+    logger.info("Compression model loaded!")
+    return compression_model
+
+
+
+def wrapped_model_from_checkpoint(cfg: omegaconf.dictconfig.DictConfig, checkpoint_path: Union[str, pathlib.Path], device: Union[torch.device, str] = 'cpu') ‑> CompressionModel +
+
+

Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.

+

Args

+
+
cfg : omegaconf.DictConfig
+
Configuration to read from for wrapped mode.
+
checkpoint_path : Path or str
+
Path to checkpoint or dora sig from where the checkpoint is resolved.
+
use_ema : bool
+
Use EMA variant of the model instead of the actual model.
+
device : torch.device or str
+
Device on which the model is loaded.
+
+
+ +Expand source code + +
@staticmethod
+def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
+                                  checkpoint_path: tp.Union[Path, str],
+                                  device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
+    """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.
+
+    Args:
+        cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
+        checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
+        use_ema (bool): Use EMA variant of the model instead of the actual model.
+        device (torch.device or str): Device on which the model is loaded.
+    """
+    compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
+    compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
+    return compression_model
+
+
+
+

Methods

+
+
+def build_dataloaders(self) +
+
+

Instantiate audio dataloaders for each stage.

+
+ +Expand source code + +
def build_dataloaders(self):
+    """Instantiate audio dataloaders for each stage."""
+    self.dataloaders = builders.get_audio_datasets(self.cfg)
+
+
+
+def build_model(self) +
+
+

Instantiate model and optimizer.

+
+ +Expand source code + +
def build_model(self):
+    """Instantiate model and optimizer."""
+    # Model and optimizer
+    self.model = models.builders.get_compression_model(self.cfg).to(self.device)
+    self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
+    self.register_stateful('model', 'optimizer')
+    self.register_best_state('model')
+    self.register_ema('model')
+
+
+
+def evaluate(self) +
+
+

Evaluate stage. Runs audio reconstruction evaluation.

+
+ +Expand source code + +
def evaluate(self):
+    """Evaluate stage. Runs audio reconstruction evaluation."""
+    self.model.eval()
+    evaluate_stage_name = str(self.current_stage)
+
+    loader = self.dataloaders['evaluate']
+    updates = len(loader)
+    lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
+    average = flashy.averager()
+
+    pendings = []
+    ctx = multiprocessing.get_context('spawn')
+    with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
+        for idx, batch in enumerate(lp):
+            x = batch.to(self.device)
+            with torch.no_grad():
+                qres = self.model(x)
+
+            y_pred = qres.x.cpu()
+            y = batch.cpu()  # should already be on CPU but just in case
+            pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))
+
+        metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
+        for pending in metrics_lp:
+            metrics = pending.result()
+            metrics = average(metrics)
+
+    metrics = flashy.distrib.average_metrics(metrics, len(loader))
+    return metrics
+
+
+
+def load_from_pretrained(self, name: str) ‑> dict +
+
+
+
+ +Expand source code + +
def load_from_pretrained(self, name: str) -> dict:
+    model = models.CompressionModel.get_pretrained(name)
+    if isinstance(model, models.DAC):
+        raise RuntimeError("Cannot fine tune a DAC model.")
+    elif isinstance(model, models.HFEncodecCompressionModel):
+        self.logger.warning('Trying to automatically convert a HuggingFace model '
+                            'to AudioCraft, this might fail!')
+        state = model.model.state_dict()
+        new_state = {}
+        for k, v in state.items():
+            if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
+                # We need to determine if this a convtr or a regular conv.
+                layer = int(k.split('.')[2])
+                if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
+
+                    k = k.replace('.conv.', '.convtr.')
+            k = k.replace('encoder.layers.', 'encoder.model.')
+            k = k.replace('decoder.layers.', 'decoder.model.')
+            k = k.replace('conv.', 'conv.conv.')
+            k = k.replace('convtr.', 'convtr.convtr.')
+            k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
+            k = k.replace('.codebook.', '._codebook.')
+            new_state[k] = v
+        state = new_state
+    elif isinstance(model, models.EncodecModel):
+        state = model.state_dict()
+    else:
+        raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
+    return {
+        'best_state': {'model': state}
+    }
+
+
+
+def show(self) +
+
+

Show the compression model and employed adversarial loss.

+
+ +Expand source code + +
def show(self):
+    """Show the compression model and employed adversarial loss."""
+    self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:")
+    self.log_model_summary(self.model)
+    self.logger.info("Adversarial loss:")
+    self.log_model_summary(self.adv_losses)
+    self.logger.info("Auxiliary losses:")
+    self.logger.info(self.aux_losses)
+    self.logger.info("Info losses:")
+    self.logger.info(self.info_losses)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/solvers/diffusion.html b/api_docs/audiocraft/solvers/diffusion.html new file mode 100644 index 00000000..6fc7dc6b --- /dev/null +++ b/api_docs/audiocraft/solvers/diffusion.html @@ -0,0 +1,887 @@ + + + + + + +audiocraft.solvers.diffusion API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.solvers.diffusion

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import flashy
+import julius
+import omegaconf
+import torch
+import torch.nn.functional as F
+
+from . import builders
+from . import base
+from .. import models
+from ..modules.diffusion_schedule import NoiseSchedule
+from ..metrics import RelativeVolumeMel
+from ..models.builders import get_processor
+from ..utils.samples.manager import SampleManager
+from ..solvers.compression import CompressionSolver
+
+
+class PerStageMetrics:
+    """Handle prompting the metrics per stage.
+    It outputs the metrics per range of diffusion states.
+    e.g. avg loss when t in [250, 500]
+    """
+    def __init__(self, num_steps: int, num_stages: int = 4):
+        self.num_steps = num_steps
+        self.num_stages = num_stages
+
+    def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
+        if type(step) is int:
+            stage = int((step / self.num_steps) * self.num_stages)
+            return {f"{name}_{stage}": loss for name, loss in losses.items()}
+        elif type(step) is torch.Tensor:
+            stage_tensor = ((step / self.num_steps) * self.num_stages).long()
+            out: tp.Dict[str, float] = {}
+            for stage_idx in range(self.num_stages):
+                mask = (stage_tensor == stage_idx)
+                N = mask.sum()
+                stage_out = {}
+                if N > 0:  # pass if no elements in the stage
+                    for name, loss in losses.items():
+                        stage_loss = (mask * loss).sum() / N
+                        stage_out[f"{name}_{stage_idx}"] = stage_loss
+                out = {**out, **stage_out}
+            return out
+
+
+class DataProcess:
+    """Apply filtering or resampling.
+
+    Args:
+        initial_sr (int): Initial sample rate.
+        target_sr (int): Target sample rate.
+        use_resampling: Whether to use resampling or not.
+        use_filter (bool):
+        n_bands (int): Number of bands to consider.
+        idx_band (int):
+        device (torch.device or str):
+        cutoffs ():
+        boost (bool):
+    """
+    def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
+                 use_filter: bool = False, n_bands: int = 4,
+                 idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
+        """Apply filtering or resampling
+        Args:
+            initial_sr (int): sample rate of the dataset
+            target_sr (int): sample rate after resampling
+            use_resampling (bool): whether or not performs resampling
+            use_filter (bool): when True filter the data to keep only one frequency band
+            n_bands (int): Number of bands used
+            cuts (none or list): The cutoff frequencies of the band filtering
+                                if None then we use mel scale bands.
+            idx_band (int): index of the frequency band. 0 are lows ... (n_bands - 1) highs
+            boost (bool): make the data scale match our music dataset.
+        """
+        assert idx_band < n_bands
+        self.idx_band = idx_band
+        if use_filter:
+            if cutoffs is not None:
+                self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
+            else:
+                self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
+        self.use_filter = use_filter
+        self.use_resampling = use_resampling
+        self.target_sr = target_sr
+        self.initial_sr = initial_sr
+        self.boost = boost
+
+    def process_data(self, x, metric=False):
+        if x is None:
+            return None
+        if self.boost:
+            x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
+            x * 0.22
+        if self.use_filter and not metric:
+            x = self.filter(x)[self.idx_band]
+        if self.use_resampling:
+            x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
+        return x
+
+    def inverse_process(self, x):
+        """Upsampling only."""
+        if self.use_resampling:
+            x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.target_sr)
+        return x
+
+
+class DiffusionSolver(base.StandardSolver):
+    """Solver for compression task.
+
+    The diffusion task allows for MultiBand diffusion model training.
+
+    Args:
+        cfg (DictConfig): Configuration.
+    """
+    def __init__(self, cfg: omegaconf.DictConfig):
+        super().__init__(cfg)
+        self.cfg = cfg
+        self.device = cfg.device
+        self.sample_rate: int = self.cfg.sample_rate
+        self.codec_model = CompressionSolver.model_from_checkpoint(
+            cfg.compression_model_checkpoint, device=self.device)
+
+        self.codec_model.set_num_codebooks(cfg.n_q)
+        assert self.codec_model.sample_rate == self.cfg.sample_rate, (
+            f"Codec model sample rate is {self.codec_model.sample_rate} but "
+            f"Solver sample rate is {self.cfg.sample_rate}."
+            )
+        assert self.codec_model.sample_rate == self.sample_rate, \
+            f"Sample rate of solver {self.sample_rate} and codec {self.codec_model.sample_rate} " \
+            "don't match."
+
+        self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
+        self.register_stateful('sample_processor')
+        self.sample_processor.to(self.device)
+
+        self.schedule = NoiseSchedule(
+            **cfg.schedule, device=self.device, sample_processor=self.sample_processor)
+
+        self.eval_metric: tp.Optional[torch.nn.Module] = None
+
+        self.rvm = RelativeVolumeMel()
+        self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
+                                          use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
+                                          use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
+                                          idx_band=cfg.filter.idx_band, device=self.device)
+
+    @property
+    def best_metric_name(self) -> tp.Optional[str]:
+        if self._current_stage == "evaluate":
+            return 'rvm'
+        else:
+            return 'loss'
+
+    @torch.no_grad()
+    def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
+        codes, scale = self.codec_model.encode(wav)
+        assert scale is None, "Scaled compression models not supported."
+        emb = self.codec_model.decode_latent(codes)
+        return emb
+
+    def build_model(self):
+        """Build model and optimizer as well as optional Exponential Moving Average of the model.
+        """
+        # Model and optimizer
+        self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
+        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
+        self.register_stateful('model', 'optimizer')
+        self.register_best_state('model')
+        self.register_ema('model')
+
+    def build_dataloaders(self):
+        """Build audio dataloaders for each stage."""
+        self.dataloaders = builders.get_audio_datasets(self.cfg)
+
+    def show(self):
+        # TODO
+        raise NotImplementedError()
+
+    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
+        """Perform one training or valid step on a given batch."""
+        x = batch.to(self.device)
+        loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss
+
+        condition = self.get_condition(x)  # [bs, 128, T/hop, n_emb]
+        sample = self.data_processor.process_data(x)
+
+        input_, target, step = self.schedule.get_training_item(sample,
+                                                               tensor_step=self.cfg.schedule.variable_step_batch)
+        out = self.model(input_, step, condition=condition).sample
+
+        base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
+        reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
+        loss = base_loss / reference_loss ** self.cfg.loss.norm_power
+
+        if self.is_training:
+            loss.mean().backward()
+            flashy.distrib.sync_model(self.model)
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+        metrics = {
+            'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(),
+            }
+        metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step))
+        metrics.update({
+            'std_in': input_.std(), 'std_out': out.std()})
+        return metrics
+
+    def run_epoch(self):
+        # reset random seed at the beginning of the epoch
+        self.rng = torch.Generator()
+        self.rng.manual_seed(1234 + self.epoch)
+        self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)
+        # run epoch
+        super().run_epoch()
+
+    def evaluate(self):
+        """Evaluate stage.
+        Runs audio reconstruction evaluation.
+        """
+        self.model.eval()
+        evaluate_stage_name = f'{self.current_stage}'
+        loader = self.dataloaders['evaluate']
+        updates = len(loader)
+        lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)
+
+        metrics = {}
+        n = 1
+        for idx, batch in enumerate(lp):
+            x = batch.to(self.device)
+            with torch.no_grad():
+                y_pred = self.regenerate(x)
+
+            y_pred = y_pred.cpu()
+            y = batch.cpu()  # should already be on CPU but just in case
+            rvm = self.rvm(y_pred, y)
+            lp.update(**rvm)
+            if len(metrics) == 0:
+                metrics = rvm
+            else:
+                for key in rvm.keys():
+                    metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1)
+        metrics = flashy.distrib.average_metrics(metrics)
+        return metrics
+
+    @torch.no_grad()
+    def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
+        """Regenerate the given waveform."""
+        condition = self.get_condition(wav)
+        initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav))  # sampling rate changes.
+        result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
+                                                   step_list=step_list)
+        result = self.data_processor.inverse_process(result)
+        return result
+
+    def generate(self):
+        """Generate stage."""
+        sample_manager = SampleManager(self.xp)
+        self.model.eval()
+        generate_stage_name = f'{self.current_stage}'
+
+        loader = self.dataloaders['generate']
+        updates = len(loader)
+        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
+
+        for batch in lp:
+            reference, _ = batch
+            reference = reference.to(self.device)
+            estimate = self.regenerate(reference)
+            reference = reference.cpu()
+            estimate = estimate.cpu()
+            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
+        flashy.distrib.barrier()
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class DataProcess +(initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False, use_filter: bool = False, n_bands: int = 4, idx_band: int = 0, device: torch.device = device(type='cpu'), cutoffs=None, boost=False) +
+
+

Apply filtering or resampling.

+

Args

+
+
initial_sr : int
+
Initial sample rate.
+
target_sr : int
+
Target sample rate.
+
use_resampling
+
Whether to use resampling or not.
+
use_filter (bool):
+
n_bands : int
+
Number of bands to consider.
+
+

idx_band (int): +device (torch.device or str): +cutoffs (): +boost (bool): +Apply filtering or resampling

+

Args

+
+
initial_sr : int
+
sample rate of the dataset
+
target_sr : int
+
sample rate after resampling
+
use_resampling : bool
+
whether or not performs resampling
+
use_filter : bool
+
when True filter the data to keep only one frequency band
+
n_bands : int
+
Number of bands used
+
cuts : none or list
+
The cutoff frequencies of the band filtering +if None then we use mel scale bands.
+
idx_band : int
+
index of the frequency band. 0 are lows … (n_bands - 1) highs
+
boost : bool
+
make the data scale match our music dataset.
+
+
+ +Expand source code + +
class DataProcess:
+    """Apply filtering or resampling.
+
+    Args:
+        initial_sr (int): Initial sample rate.
+        target_sr (int): Target sample rate.
+        use_resampling: Whether to use resampling or not.
+        use_filter (bool):
+        n_bands (int): Number of bands to consider.
+        idx_band (int):
+        device (torch.device or str):
+        cutoffs ():
+        boost (bool):
+    """
+    def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
+                 use_filter: bool = False, n_bands: int = 4,
+                 idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
+        """Apply filtering or resampling
+        Args:
+            initial_sr (int): sample rate of the dataset
+            target_sr (int): sample rate after resampling
+            use_resampling (bool): whether or not performs resampling
+            use_filter (bool): when True filter the data to keep only one frequency band
+            n_bands (int): Number of bands used
+            cuts (none or list): The cutoff frequencies of the band filtering
+                                if None then we use mel scale bands.
+            idx_band (int): index of the frequency band. 0 are lows ... (n_bands - 1) highs
+            boost (bool): make the data scale match our music dataset.
+        """
+        assert idx_band < n_bands
+        self.idx_band = idx_band
+        if use_filter:
+            if cutoffs is not None:
+                self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
+            else:
+                self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
+        self.use_filter = use_filter
+        self.use_resampling = use_resampling
+        self.target_sr = target_sr
+        self.initial_sr = initial_sr
+        self.boost = boost
+
+    def process_data(self, x, metric=False):
+        if x is None:
+            return None
+        if self.boost:
+            x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
+            x * 0.22
+        if self.use_filter and not metric:
+            x = self.filter(x)[self.idx_band]
+        if self.use_resampling:
+            x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
+        return x
+
+    def inverse_process(self, x):
+        """Upsampling only."""
+        if self.use_resampling:
+            x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.target_sr)
+        return x
+
+

Methods

+
+
+def inverse_process(self, x) +
+
+

Upsampling only.

+
+ +Expand source code + +
def inverse_process(self, x):
+    """Upsampling only."""
+    if self.use_resampling:
+        x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.target_sr)
+    return x
+
+
+
+def process_data(self, x, metric=False) +
+
+
+
+ +Expand source code + +
def process_data(self, x, metric=False):
+    if x is None:
+        return None
+    if self.boost:
+        x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
+        x * 0.22
+    if self.use_filter and not metric:
+        x = self.filter(x)[self.idx_band]
+    if self.use_resampling:
+        x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
+    return x
+
+
+
+
+
+class DiffusionSolver +(cfg: omegaconf.dictconfig.DictConfig) +
+
+

Solver for compression task.

+

The diffusion task allows for MultiBand diffusion model training.

+

Args

+
+
cfg : DictConfig
+
Configuration.
+
+
+ +Expand source code + +
class DiffusionSolver(base.StandardSolver):
+    """Solver for compression task.
+
+    The diffusion task allows for MultiBand diffusion model training.
+
+    Args:
+        cfg (DictConfig): Configuration.
+    """
+    def __init__(self, cfg: omegaconf.DictConfig):
+        super().__init__(cfg)
+        self.cfg = cfg
+        self.device = cfg.device
+        self.sample_rate: int = self.cfg.sample_rate
+        self.codec_model = CompressionSolver.model_from_checkpoint(
+            cfg.compression_model_checkpoint, device=self.device)
+
+        self.codec_model.set_num_codebooks(cfg.n_q)
+        assert self.codec_model.sample_rate == self.cfg.sample_rate, (
+            f"Codec model sample rate is {self.codec_model.sample_rate} but "
+            f"Solver sample rate is {self.cfg.sample_rate}."
+            )
+        assert self.codec_model.sample_rate == self.sample_rate, \
+            f"Sample rate of solver {self.sample_rate} and codec {self.codec_model.sample_rate} " \
+            "don't match."
+
+        self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
+        self.register_stateful('sample_processor')
+        self.sample_processor.to(self.device)
+
+        self.schedule = NoiseSchedule(
+            **cfg.schedule, device=self.device, sample_processor=self.sample_processor)
+
+        self.eval_metric: tp.Optional[torch.nn.Module] = None
+
+        self.rvm = RelativeVolumeMel()
+        self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
+                                          use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
+                                          use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
+                                          idx_band=cfg.filter.idx_band, device=self.device)
+
+    @property
+    def best_metric_name(self) -> tp.Optional[str]:
+        if self._current_stage == "evaluate":
+            return 'rvm'
+        else:
+            return 'loss'
+
+    @torch.no_grad()
+    def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
+        codes, scale = self.codec_model.encode(wav)
+        assert scale is None, "Scaled compression models not supported."
+        emb = self.codec_model.decode_latent(codes)
+        return emb
+
+    def build_model(self):
+        """Build model and optimizer as well as optional Exponential Moving Average of the model.
+        """
+        # Model and optimizer
+        self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
+        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
+        self.register_stateful('model', 'optimizer')
+        self.register_best_state('model')
+        self.register_ema('model')
+
+    def build_dataloaders(self):
+        """Build audio dataloaders for each stage."""
+        self.dataloaders = builders.get_audio_datasets(self.cfg)
+
+    def show(self):
+        # TODO
+        raise NotImplementedError()
+
+    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
+        """Perform one training or valid step on a given batch."""
+        x = batch.to(self.device)
+        loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss
+
+        condition = self.get_condition(x)  # [bs, 128, T/hop, n_emb]
+        sample = self.data_processor.process_data(x)
+
+        input_, target, step = self.schedule.get_training_item(sample,
+                                                               tensor_step=self.cfg.schedule.variable_step_batch)
+        out = self.model(input_, step, condition=condition).sample
+
+        base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
+        reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
+        loss = base_loss / reference_loss ** self.cfg.loss.norm_power
+
+        if self.is_training:
+            loss.mean().backward()
+            flashy.distrib.sync_model(self.model)
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+        metrics = {
+            'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(),
+            }
+        metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step))
+        metrics.update({
+            'std_in': input_.std(), 'std_out': out.std()})
+        return metrics
+
+    def run_epoch(self):
+        # reset random seed at the beginning of the epoch
+        self.rng = torch.Generator()
+        self.rng.manual_seed(1234 + self.epoch)
+        self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)
+        # run epoch
+        super().run_epoch()
+
+    def evaluate(self):
+        """Evaluate stage.
+        Runs audio reconstruction evaluation.
+        """
+        self.model.eval()
+        evaluate_stage_name = f'{self.current_stage}'
+        loader = self.dataloaders['evaluate']
+        updates = len(loader)
+        lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)
+
+        metrics = {}
+        n = 1
+        for idx, batch in enumerate(lp):
+            x = batch.to(self.device)
+            with torch.no_grad():
+                y_pred = self.regenerate(x)
+
+            y_pred = y_pred.cpu()
+            y = batch.cpu()  # should already be on CPU but just in case
+            rvm = self.rvm(y_pred, y)
+            lp.update(**rvm)
+            if len(metrics) == 0:
+                metrics = rvm
+            else:
+                for key in rvm.keys():
+                    metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1)
+        metrics = flashy.distrib.average_metrics(metrics)
+        return metrics
+
+    @torch.no_grad()
+    def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
+        """Regenerate the given waveform."""
+        condition = self.get_condition(wav)
+        initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav))  # sampling rate changes.
+        result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
+                                                   step_list=step_list)
+        result = self.data_processor.inverse_process(result)
+        return result
+
+    def generate(self):
+        """Generate stage."""
+        sample_manager = SampleManager(self.xp)
+        self.model.eval()
+        generate_stage_name = f'{self.current_stage}'
+
+        loader = self.dataloaders['generate']
+        updates = len(loader)
+        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
+
+        for batch in lp:
+            reference, _ = batch
+            reference = reference.to(self.device)
+            estimate = self.regenerate(reference)
+            reference = reference.cpu()
+            estimate = estimate.cpu()
+            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
+        flashy.distrib.barrier()
+
+

Ancestors

+ +

Methods

+
+
+def build_dataloaders(self) +
+
+

Build audio dataloaders for each stage.

+
+ +Expand source code + +
def build_dataloaders(self):
+    """Build audio dataloaders for each stage."""
+    self.dataloaders = builders.get_audio_datasets(self.cfg)
+
+
+
+def build_model(self) +
+
+

Build model and optimizer as well as optional Exponential Moving Average of the model.

+
+ +Expand source code + +
def build_model(self):
+    """Build model and optimizer as well as optional Exponential Moving Average of the model.
+    """
+    # Model and optimizer
+    self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
+    self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
+    self.register_stateful('model', 'optimizer')
+    self.register_best_state('model')
+    self.register_ema('model')
+
+
+
+def evaluate(self) +
+
+

Evaluate stage. +Runs audio reconstruction evaluation.

+
+ +Expand source code + +
def evaluate(self):
+    """Evaluate stage.
+    Runs audio reconstruction evaluation.
+    """
+    self.model.eval()
+    evaluate_stage_name = f'{self.current_stage}'
+    loader = self.dataloaders['evaluate']
+    updates = len(loader)
+    lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)
+
+    metrics = {}
+    n = 1
+    for idx, batch in enumerate(lp):
+        x = batch.to(self.device)
+        with torch.no_grad():
+            y_pred = self.regenerate(x)
+
+        y_pred = y_pred.cpu()
+        y = batch.cpu()  # should already be on CPU but just in case
+        rvm = self.rvm(y_pred, y)
+        lp.update(**rvm)
+        if len(metrics) == 0:
+            metrics = rvm
+        else:
+            for key in rvm.keys():
+                metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1)
+    metrics = flashy.distrib.average_metrics(metrics)
+    return metrics
+
+
+
+def get_condition(self, wav: torch.Tensor) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
@torch.no_grad()
+def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
+    codes, scale = self.codec_model.encode(wav)
+    assert scale is None, "Scaled compression models not supported."
+    emb = self.codec_model.decode_latent(codes)
+    return emb
+
+
+
+def regenerate(self, wav: torch.Tensor, step_list: Optional[list] = None) +
+
+

Regenerate the given waveform.

+
+ +Expand source code + +
@torch.no_grad()
+def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
+    """Regenerate the given waveform."""
+    condition = self.get_condition(wav)
+    initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav))  # sampling rate changes.
+    result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
+                                               step_list=step_list)
+    result = self.data_processor.inverse_process(result)
+    return result
+
+
+
+

Inherited members

+ +
+
+class PerStageMetrics +(num_steps: int, num_stages: int = 4) +
+
+

Handle prompting the metrics per stage. +It outputs the metrics per range of diffusion states. +e.g. avg loss when t in [250, 500]

+
+ +Expand source code + +
class PerStageMetrics:
+    """Handle prompting the metrics per stage.
+    It outputs the metrics per range of diffusion states.
+    e.g. avg loss when t in [250, 500]
+    """
+    def __init__(self, num_steps: int, num_stages: int = 4):
+        self.num_steps = num_steps
+        self.num_stages = num_stages
+
+    def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
+        if type(step) is int:
+            stage = int((step / self.num_steps) * self.num_stages)
+            return {f"{name}_{stage}": loss for name, loss in losses.items()}
+        elif type(step) is torch.Tensor:
+            stage_tensor = ((step / self.num_steps) * self.num_stages).long()
+            out: tp.Dict[str, float] = {}
+            for stage_idx in range(self.num_stages):
+                mask = (stage_tensor == stage_idx)
+                N = mask.sum()
+                stage_out = {}
+                if N > 0:  # pass if no elements in the stage
+                    for name, loss in losses.items():
+                        stage_loss = (mask * loss).sum() / N
+                        stage_out[f"{name}_{stage_idx}"] = stage_loss
+                out = {**out, **stage_out}
+            return out
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/solvers/index.html b/api_docs/audiocraft/solvers/index.html new file mode 100644 index 00000000..c40e2426 --- /dev/null +++ b/api_docs/audiocraft/solvers/index.html @@ -0,0 +1,116 @@ + + + + + + +audiocraft.solvers API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.solvers

+
+
+

Solvers. A Solver is a training recipe, combining the dataloaders, models, +optimizer, losses etc into a single convenient object.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Solvers. A Solver is a training recipe, combining the dataloaders, models,
+optimizer, losses etc into a single convenient object.
+"""
+
+# flake8: noqa
+from .audiogen import AudioGenSolver
+from .builders import get_solver
+from .base import StandardSolver
+from .compression import CompressionSolver
+from .musicgen import MusicGenSolver
+from .diffusion import DiffusionSolver
+
+
+
+

Sub-modules

+
+
audiocraft.solvers.audiogen
+
+
+
+
audiocraft.solvers.base
+
+
+
+
audiocraft.solvers.builders
+
+

All the functions to build the relevant solvers and used objects +from the Hydra config.

+
+
audiocraft.solvers.compression
+
+
+
+
audiocraft.solvers.diffusion
+
+
+
+
audiocraft.solvers.musicgen
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/solvers/musicgen.html b/api_docs/audiocraft/solvers/musicgen.html new file mode 100644 index 00000000..9a3c4c94 --- /dev/null +++ b/api_docs/audiocraft/solvers/musicgen.html @@ -0,0 +1,2048 @@ + + + + + + +audiocraft.solvers.musicgen API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.solvers.musicgen

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+import time
+import typing as tp
+import warnings
+
+import flashy
+import math
+import omegaconf
+import torch
+from torch.nn import functional as F
+
+from . import base, builders
+from .compression import CompressionSolver
+from .. import metrics as eval_metrics
+from .. import models
+from ..data.audio_dataset import AudioDataset
+from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo
+from ..data.audio_utils import normalize_audio
+from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition
+from ..utils.cache import CachedBatchWriter, CachedBatchLoader
+from ..utils.samples.manager import SampleManager
+from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once
+
+
+class MusicGenSolver(base.StandardSolver):
+    """Solver for MusicGen training task.
+
+    Used in: https://arxiv.org/abs/2306.05284
+    """
+    DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC
+
+    def __init__(self, cfg: omegaconf.DictConfig):
+        super().__init__(cfg)
+        # easier access to sampling parameters
+        self.generation_params = {
+            'use_sampling': self.cfg.generate.lm.use_sampling,
+            'temp': self.cfg.generate.lm.temp,
+            'top_k': self.cfg.generate.lm.top_k,
+            'top_p': self.cfg.generate.lm.top_p,
+        }
+        self._best_metric_name: tp.Optional[str] = 'ce'
+
+        self._cached_batch_writer = None
+        self._cached_batch_loader = None
+        if cfg.cache.path:
+            if cfg.cache.write:
+                self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path))
+                if self.cfg.cache.write_num_shards:
+                    self.logger.warning("Multiple shard cache, best_metric_name will be set to None.")
+                    self._best_metric_name = None
+            else:
+                self._cached_batch_loader = CachedBatchLoader(
+                    Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers,
+                    min_length=self.cfg.optim.updates_per_epoch or 1)
+                self.dataloaders['original_train'] = self.dataloaders['train']
+                self.dataloaders['train'] = self._cached_batch_loader  # type: ignore
+
+    @staticmethod
+    def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
+                                 device: tp.Optional[str] = None, autocast: bool = True,
+                                 batch_size: tp.Optional[int] = None,
+                                 override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+                                 **kwargs):
+        """Mostly a convenience function around magma.train.get_solver_from_sig,
+        populating all the proper param, deactivating EMA, FSDP, loading the best state,
+        basically all you need to get a solver ready to "play" with in single GPU mode
+        and with minimal memory overhead.
+
+        Args:
+            sig (str): signature to load.
+            dtype (str or None): potential dtype, as a string, i.e. 'float16'.
+            device (str or None): potential device, as a string, i.e. 'cuda'.
+            override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'.
+        """
+        from audiocraft import train
+        our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
+        our_override_cfg['autocast'] = autocast
+        if dtype is not None:
+            our_override_cfg['dtype'] = dtype
+        if device is not None:
+            our_override_cfg['device'] = device
+        if batch_size is not None:
+            our_override_cfg['dataset'] = {'batch_size': batch_size}
+        if override_cfg is None:
+            override_cfg = {}
+        override_cfg = omegaconf.OmegaConf.merge(
+            omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
+        solver = train.get_solver_from_sig(
+            sig, override_cfg=override_cfg,
+            load_best=True, disable_fsdp=True,
+            ignore_state_keys=['optimizer', 'ema'], **kwargs)
+        solver.model.eval()
+        return solver
+
+    def get_formatter(self, stage_name: str) -> flashy.Formatter:
+        return flashy.Formatter({
+            'lr': '.2E',
+            'ce': '.3f',
+            'ppl': '.3f',
+            'grad_norm': '.3E',
+        }, exclude_keys=['ce_q*', 'ppl_q*'])
+
+    @property
+    def best_metric_name(self) -> tp.Optional[str]:
+        return self._best_metric_name
+
+    def build_model(self) -> None:
+        """Instantiate models and optimizer."""
+        # we can potentially not use all quantizers with which the EnCodec model was trained
+        # (e.g. we trained the model with quantizers dropout)
+        self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
+            self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
+        assert self.compression_model.sample_rate == self.cfg.sample_rate, (
+            f"Compression model sample rate is {self.compression_model.sample_rate} but "
+            f"Solver sample rate is {self.cfg.sample_rate}."
+            )
+        # ensure we have matching configuration between LM and compression model
+        assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
+            "Cardinalities of the LM and compression model don't match: ",
+            f"LM cardinality is {self.cfg.transformer_lm.card} vs ",
+            f"compression model cardinality is {self.compression_model.cardinality}"
+        )
+        assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
+            "Numbers of codebooks of the LM and compression models don't match: ",
+            f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ",
+            f"compression model numer of codebooks is {self.compression_model.num_codebooks}"
+        )
+        self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
+                         self.compression_model.num_codebooks, self.compression_model.cardinality,
+                         self.compression_model.frame_rate)
+        # instantiate LM model
+        self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device)
+        if self.cfg.fsdp.use:
+            assert not self.cfg.autocast, "Cannot use autocast with fsdp"
+            self.model = self.wrap_with_fsdp(self.model)
+        self.register_ema('model')
+        # initialize optimization
+        self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
+        self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
+        self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
+        self.register_best_state('model')
+        self.autocast_dtype = {
+            'float16': torch.float16, 'bfloat16': torch.bfloat16
+        }[self.cfg.autocast_dtype]
+        self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
+        if self.cfg.fsdp.use:
+            need_scaler = self.cfg.fsdp.param_dtype == 'float16'
+        else:
+            need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
+        if need_scaler:
+            if self.cfg.fsdp.use:
+                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+                self.scaler = ShardedGradScaler()  # type: ignore
+            else:
+                self.scaler = torch.cuda.amp.GradScaler()
+            self.register_stateful('scaler')
+
+    def build_dataloaders(self) -> None:
+        """Instantiate audio dataloaders for each stage."""
+        self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)
+
+    def show(self) -> None:
+        """Show the compression model and LM model."""
+        self.logger.info("Compression model:")
+        self.log_model_summary(self.compression_model)
+        self.logger.info("LM model:")
+        self.log_model_summary(self.model)
+
+    def load_state_dict(self, state: dict) -> None:
+        if 'condition_provider' in state:
+            model_state = state['model']
+            condition_provider_state = state.pop('condition_provider')
+            prefix = 'condition_provider.'
+            for key, value in condition_provider_state.items():
+                key = prefix + key
+                assert key not in model_state
+                model_state[key] = value
+        super().load_state_dict(state)
+
+    def load_from_pretrained(self, name: str):
+        # TODO: support native HF versions of MusicGen.
+        lm_pkg = models.loaders.load_lm_model_ckpt(name)
+        state: dict = {
+            'best_state': {
+                'model': lm_pkg['best_state'],
+            },
+        }
+        return state
+
+    def _compute_cross_entropy(
+        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
+    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
+        """Compute cross entropy between multi-codebook targets and model's logits.
+        The cross entropy is computed per codebook to provide codebook-level cross entropy.
+        Valid timesteps for each of the codebook are pulled from the mask, where invalid
+        timesteps are set to 0.
+
+        Args:
+            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
+            targets (torch.Tensor): Target codes, of shape [B, K, T].
+            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
+        Returns:
+            ce (torch.Tensor): Cross entropy averaged over the codebooks
+            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
+        """
+        B, K, T = targets.shape
+        assert logits.shape[:-1] == targets.shape
+        assert mask.shape == targets.shape
+        ce = torch.zeros([], device=targets.device)
+        ce_per_codebook: tp.List[torch.Tensor] = []
+        for k in range(K):
+            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
+            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
+            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
+            ce_targets = targets_k[mask_k]
+            ce_logits = logits_k[mask_k]
+            q_ce = F.cross_entropy(ce_logits, ce_targets)
+            ce += q_ce
+            ce_per_codebook.append(q_ce.detach())
+        # average cross entropy across codebooks
+        ce = ce / K
+        return ce, ce_per_codebook
+
+    def _prepare_tokens_and_attributes(
+        self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
+        check_synchronization_points: bool = False
+    ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]:
+        """Prepare input batchs for language model training.
+
+        Args:
+            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
+                and corresponding metadata as SegmentWithAttributes (with B items).
+            check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
+        Returns:
+            Condition tensors (dict[str, any]): Preprocessed condition attributes.
+            Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s],
+                with B the batch size, K the number of codebooks, T_s the token timesteps.
+            Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
+        """
+        if self.model.training:
+            warnings.warn(
+                "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. "
+                "This is inconsistent with how model were trained in the MusicGen paper. We removed the "
+                "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
+                "Really sorry about that.")
+        if self._cached_batch_loader is None or self.current_stage != "train":
+            audio, infos = batch
+            audio = audio.to(self.device)
+            audio_tokens = None
+            assert audio.size(0) == len(infos), (
+                f"Mismatch between number of items in audio batch ({audio.size(0)})",
+                f" and in metadata ({len(infos)})"
+            )
+        else:
+            audio = None
+            # In that case the batch will be a tuple coming from the _cached_batch_writer bit below.
+            infos, = batch  # type: ignore
+            assert all([isinstance(info, AudioInfo) for info in infos])
+            assert all([info.audio_tokens is not None for info in infos])  # type: ignore
+            audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device)  # type: ignore
+            audio_tokens = audio_tokens.long()
+            for info in infos:
+                if isinstance(info, MusicInfo):
+                    # Careful here, if you want to use this condition_wav (e.b. chroma conditioning),
+                    # then you must be using the chroma cache! otherwise the code will try
+                    # to use this segment and fail (by that I mean you will see NaN everywhere).
+                    info.self_wav = WavCondition(
+                        torch.full([1, info.channels, info.total_frames], float('NaN')),
+                        length=torch.tensor([info.n_frames]),
+                        sample_rate=[info.sample_rate],
+                        path=[info.meta.path],
+                        seek_time=[info.seek_time])
+                    dataset = get_dataset_from_loader(self.dataloaders['original_train'])
+                    assert isinstance(dataset, MusicDataset), type(dataset)
+                    if dataset.paraphraser is not None and info.description is not None:
+                        # Hackingly reapplying paraphraser when using cache.
+                        info.description = dataset.paraphraser.sample_paraphrase(
+                            info.meta.path, info.description)
+        # prepare attributes
+        attributes = [info.to_condition_attributes() for info in infos]
+        attributes = self.model.cfg_dropout(attributes)
+        attributes = self.model.att_dropout(attributes)
+        tokenized = self.model.condition_provider.tokenize(attributes)
+
+        # Now we should be synchronization free.
+        if self.device == "cuda" and check_synchronization_points:
+            torch.cuda.set_sync_debug_mode("warn")
+
+        if audio_tokens is None:
+            with torch.no_grad():
+                audio_tokens, scale = self.compression_model.encode(audio)
+                assert scale is None, "Scaled compression model not supported with LM."
+
+        with self.autocast:
+            condition_tensors = self.model.condition_provider(tokenized)
+
+        # create a padding mask to hold valid vs invalid positions
+        padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
+        # replace encodec tokens from padded audio with special_token_id
+        if self.cfg.tokens.padding_with_special_token:
+            audio_tokens = audio_tokens.clone()
+            padding_mask = padding_mask.clone()
+            token_sample_rate = self.compression_model.frame_rate
+            B, K, T_s = audio_tokens.shape
+            for i in range(B):
+                n_samples = infos[i].n_frames
+                audio_sample_rate = infos[i].sample_rate
+                # take the last token generated from actual audio frames (non-padded audio)
+                valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate)
+                audio_tokens[i, :, valid_tokens:] = self.model.special_token_id
+                padding_mask[i, :, valid_tokens:] = 0
+
+        if self.device == "cuda" and check_synchronization_points:
+            torch.cuda.set_sync_debug_mode("default")
+
+        if self._cached_batch_writer is not None and self.current_stage == 'train':
+            assert self._cached_batch_loader is None
+            assert audio_tokens is not None
+            for info, one_audio_tokens in zip(infos, audio_tokens):
+                assert isinstance(info, AudioInfo)
+                if isinstance(info, MusicInfo):
+                    assert not info.joint_embed, "joint_embed and cache not supported yet."
+                    info.self_wav = None
+                assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item()
+                info.audio_tokens = one_audio_tokens.short().cpu()
+            self._cached_batch_writer.save(infos)
+
+        return condition_tensors, audio_tokens, padding_mask
+
+    def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
+        """Perform one training or valid step on a given batch."""
+        check_synchronization_points = idx == 1 and self.device == 'cuda'
+
+        condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
+            batch, check_synchronization_points)
+
+        self.deadlock_detect.update('tokens_and_conditions')
+
+        if check_synchronization_points:
+            torch.cuda.set_sync_debug_mode('warn')
+
+        with self.autocast:
+            model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors)  # type: ignore
+            logits = model_output.logits
+            mask = padding_mask & model_output.mask
+            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
+            loss = ce
+        self.deadlock_detect.update('loss')
+
+        if check_synchronization_points:
+            torch.cuda.set_sync_debug_mode('default')
+
+        if self.is_training:
+            metrics['lr'] = self.optimizer.param_groups[0]['lr']
+            if self.scaler is not None:
+                loss = self.scaler.scale(loss)
+            self.deadlock_detect.update('scale')
+            if self.cfg.fsdp.use:
+                loss.backward()
+                flashy.distrib.average_tensors(self.model.buffers())
+            elif self.cfg.optim.eager_sync:
+                with flashy.distrib.eager_sync_model(self.model):
+                    loss.backward()
+            else:
+                # this should always be slower but can be useful
+                # for weird use cases like multiple backwards.
+                loss.backward()
+                flashy.distrib.sync_model(self.model)
+            self.deadlock_detect.update('backward')
+
+            if self.scaler is not None:
+                self.scaler.unscale_(self.optimizer)
+            if self.cfg.optim.max_norm:
+                if self.cfg.fsdp.use:
+                    metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
+                else:
+                    metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
+                        self.model.parameters(), self.cfg.optim.max_norm
+                    )
+            if self.scaler is None:
+                self.optimizer.step()
+            else:
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+            if self.lr_scheduler:
+                self.lr_scheduler.step()
+            self.optimizer.zero_grad()
+            self.deadlock_detect.update('optim')
+            if self.scaler is not None:
+                scale = self.scaler.get_scale()
+                metrics['grad_scale'] = scale
+            if not loss.isfinite().all():
+                raise RuntimeError("Model probably diverged.")
+
+        metrics['ce'] = ce
+        metrics['ppl'] = torch.exp(ce)
+        for k, ce_q in enumerate(ce_per_codebook):
+            metrics[f'ce_q{k + 1}'] = ce_q
+            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)
+
+        return metrics
+
+    @torch.no_grad()
+    def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
+                          gen_duration: float, prompt_duration: tp.Optional[float] = None,
+                          remove_prompt: bool = False,
+                          **generation_params) -> dict:
+        """Run generate step on a batch of optional audio tensor and corresponding attributes.
+
+        Args:
+            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]):
+            use_prompt (bool): Whether to do audio continuation generation with prompt from audio batch.
+            gen_duration (float): Target audio duration for the generation.
+            prompt_duration (float, optional): Duration for the audio prompt to use for continuation.
+            remove_prompt (bool, optional): Whether to remove the prompt from the generated audio.
+            generation_params: Additional generation parameters.
+        Returns:
+            gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation
+                and the prompt along with additional information.
+        """
+        bench_start = time.time()
+        audio, meta = batch
+        assert audio.size(0) == len(meta), (
+            f"Mismatch between number of items in audio batch ({audio.size(0)})",
+            f" and in metadata ({len(meta)})"
+        )
+        # prepare attributes
+        attributes = [x.to_condition_attributes() for x in meta]
+        # TODO: Add dropout for chroma?
+
+        # prepare audio prompt
+        if prompt_duration is None:
+            prompt_audio = None
+        else:
+            assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
+            prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
+            prompt_audio = audio[..., :prompt_audio_frames]
+
+        # get audio tokens from compression model
+        if prompt_audio is None or prompt_audio.nelement() == 0:
+            num_samples = len(attributes)
+            prompt_tokens = None
+        else:
+            num_samples = None
+            prompt_audio = prompt_audio.to(self.device)
+            prompt_tokens, scale = self.compression_model.encode(prompt_audio)
+            assert scale is None, "Compression model in MusicGen should not require rescaling."
+
+        # generate by sampling from the LM
+        with self.autocast:
+            total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
+            gen_tokens = self.model.generate(
+                prompt_tokens, attributes, max_gen_len=total_gen_len,
+                num_samples=num_samples, **self.generation_params)
+
+        # generate audio from tokens
+        assert gen_tokens.dim() == 3
+        gen_audio = self.compression_model.decode(gen_tokens, None)
+
+        bench_end = time.time()
+        gen_outputs = {
+            'rtf': (bench_end - bench_start) / gen_duration,
+            'ref_audio': audio,
+            'gen_audio': gen_audio,
+            'gen_tokens': gen_tokens,
+            'prompt_audio': prompt_audio,
+            'prompt_tokens': prompt_tokens,
+        }
+        return gen_outputs
+
+    def generate_audio(self) -> dict:
+        """Audio generation stage."""
+        generate_stage_name = f'{self.current_stage}'
+        sample_manager = SampleManager(self.xp)
+        self.logger.info(f"Generating samples in {sample_manager.base_folder}")
+        loader = self.dataloaders['generate']
+        updates = len(loader)
+        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
+
+        dataset = get_dataset_from_loader(loader)
+        dataset_duration = dataset.segment_duration
+        assert dataset_duration is not None
+        assert isinstance(dataset, AudioDataset)
+        target_duration = self.cfg.generate.lm.gen_duration
+        prompt_duration = self.cfg.generate.lm.prompt_duration
+        if target_duration is None:
+            target_duration = dataset_duration
+        if prompt_duration is None:
+            prompt_duration = dataset_duration / 4
+        assert prompt_duration < dataset_duration, (
+            f"Specified prompt duration ({prompt_duration}s) is longer",
+            f" than reference audio duration ({dataset_duration}s)"
+        )
+
+        def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
+            hydrated_conditions = []
+            for sample in [x.to_condition_attributes() for x in meta]:
+                cond_dict = {}
+                for cond_type in sample.__annotations__.keys():
+                    for cond_key, cond_val in getattr(sample, cond_type).items():
+                        if cond_key not in self.model.condition_provider.conditioners.keys():
+                            continue
+                        if is_jsonable(cond_val):
+                            cond_dict[cond_key] = cond_val
+                        elif isinstance(cond_val, WavCondition):
+                            cond_dict[cond_key] = cond_val.path
+                        elif isinstance(cond_val, JointEmbedCondition):
+                            cond_dict[cond_key] = cond_val.text  # only support text at inference for now
+                        else:
+                            # if we reached this point, it is not clear how to log the condition
+                            # so we just log the type.
+                            cond_dict[cond_key] = str(type(cond_val))
+                            continue
+                hydrated_conditions.append(cond_dict)
+            return hydrated_conditions
+
+        metrics: dict = {}
+        average = flashy.averager()
+        for batch in lp:
+            audio, meta = batch
+            # metadata for sample manager
+            hydrated_conditions = get_hydrated_conditions(meta)
+            sample_generation_params = {
+                **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
+                **self.generation_params
+            }
+            if self.cfg.generate.lm.unprompted_samples:
+                if self.cfg.generate.lm.gen_gt_samples:
+                    # get the ground truth instead of generation
+                    self.logger.warn(
+                        "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
+                    gen_unprompted_audio = audio
+                    rtf = 1.
+                else:
+                    gen_unprompted_outputs = self.run_generate_step(
+                        batch, gen_duration=target_duration, prompt_duration=None,
+                        **self.generation_params)
+                    gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
+                    rtf = gen_unprompted_outputs['rtf']
+                sample_manager.add_samples(
+                    gen_unprompted_audio, self.epoch, hydrated_conditions,
+                    ground_truth_wavs=audio, generation_args=sample_generation_params)
+
+            if self.cfg.generate.lm.prompted_samples:
+                gen_outputs = self.run_generate_step(
+                    batch, gen_duration=target_duration, prompt_duration=prompt_duration,
+                    **self.generation_params)
+                gen_audio = gen_outputs['gen_audio'].cpu()
+                prompt_audio = gen_outputs['prompt_audio'].cpu()
+                sample_manager.add_samples(
+                    gen_audio, self.epoch, hydrated_conditions,
+                    prompt_wavs=prompt_audio, ground_truth_wavs=audio,
+                    generation_args=sample_generation_params)
+
+            metrics['rtf'] = rtf
+            metrics = average(metrics)
+
+        flashy.distrib.barrier()
+        return metrics
+
+    def generate(self) -> dict:
+        """Generate stage."""
+        self.model.eval()
+        with torch.no_grad():
+            return self.generate_audio()
+
+    def run_epoch(self):
+        if self.cfg.cache.write:
+            if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard:
+                return
+        super().run_epoch()
+
+    def train(self):
+        """Train stage.
+        """
+        if self._cached_batch_writer is not None:
+            self._cached_batch_writer.start_epoch(self.epoch)
+        if self._cached_batch_loader is None:
+            dataset = get_dataset_from_loader(self.dataloaders['train'])
+            assert isinstance(dataset, AudioDataset)
+            dataset.current_epoch = self.epoch
+        else:
+            self._cached_batch_loader.start_epoch(self.epoch)
+        return super().train()
+
+    def evaluate_audio_generation(self) -> dict:
+        """Evaluate audio generation with off-the-shelf metrics."""
+        evaluate_stage_name = f'{self.current_stage}_generation'
+        # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation
+        fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
+        kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
+        text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
+        chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
+        should_run_eval = False
+        eval_chroma_wavs: tp.Optional[torch.Tensor] = None
+        if self.cfg.evaluate.metrics.fad:
+            fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
+            should_run_eval = True
+        if self.cfg.evaluate.metrics.kld:
+            kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
+            should_run_eval = True
+        if self.cfg.evaluate.metrics.text_consistency:
+            text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
+            should_run_eval = True
+        if self.cfg.evaluate.metrics.chroma_cosine:
+            chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
+            # if we have predefind wavs for chroma we should purge them for computing the cosine metric
+            has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
+                                          self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
+            if has_predefined_eval_chromas:
+                warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
+                                       'Resetting eval chromas to None for evaluation.')
+                eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs  # type: ignore
+                self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None)  # type: ignore
+            should_run_eval = True
+
+        def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
+            audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
+            compressed_audio = self.compression_model.decode(audio_tokens, scale)
+            return compressed_audio[..., :audio.shape[-1]]
+
+        metrics: dict = {}
+        if should_run_eval:
+            loader = self.dataloaders['evaluate']
+            updates = len(loader)
+            lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
+            average = flashy.averager()
+            dataset = get_dataset_from_loader(loader)
+            assert isinstance(dataset, AudioDataset)
+            self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")
+
+            for idx, batch in enumerate(lp):
+                audio, meta = batch
+                assert all([self.cfg.sample_rate == m.sample_rate for m in meta])
+
+                target_duration = audio.shape[-1] / self.cfg.sample_rate
+                if self.cfg.evaluate.fixed_generation_duration:
+                    target_duration = self.cfg.evaluate.fixed_generation_duration
+
+                gen_outputs = self.run_generate_step(
+                    batch, gen_duration=target_duration,
+                    **self.generation_params
+                )
+                y_pred = gen_outputs['gen_audio'].detach()
+                y_pred = y_pred[..., :audio.shape[-1]]
+
+                normalize_kwargs = dict(self.cfg.generate.audio)
+                normalize_kwargs.pop('format', None)
+                y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
+                y = audio.cpu()  # should already be on CPU but just in case
+                sizes = torch.tensor([m.n_frames for m in meta])  # actual sizes without padding
+                sample_rates = torch.tensor([m.sample_rate for m in meta])  # sample rates for audio samples
+                audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]
+
+                if fad is not None:
+                    if self.cfg.metrics.fad.use_gt:
+                        y_pred = get_compressed_audio(y).cpu()
+                    fad.update(y_pred, y, sizes, sample_rates, audio_stems)
+                if kldiv is not None:
+                    if self.cfg.metrics.kld.use_gt:
+                        y_pred = get_compressed_audio(y).cpu()
+                    kldiv.update(y_pred, y, sizes, sample_rates)
+                if text_consistency is not None:
+                    texts = [m.description for m in meta]
+                    if self.cfg.metrics.text_consistency.use_gt:
+                        y_pred = y
+                    text_consistency.update(y_pred, texts, sizes, sample_rates)
+                if chroma_cosine is not None:
+                    if self.cfg.metrics.chroma_cosine.use_gt:
+                        y_pred = get_compressed_audio(y).cpu()
+                    chroma_cosine.update(y_pred, y, sizes, sample_rates)
+                    # restore chroma conditioner's eval chroma wavs
+                    if eval_chroma_wavs is not None:
+                        self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)
+
+            flashy.distrib.barrier()
+            if fad is not None:
+                metrics['fad'] = fad.compute()
+            if kldiv is not None:
+                kld_metrics = kldiv.compute()
+                metrics.update(kld_metrics)
+            if text_consistency is not None:
+                metrics['text_consistency'] = text_consistency.compute()
+            if chroma_cosine is not None:
+                metrics['chroma_cosine'] = chroma_cosine.compute()
+            metrics = average(metrics)
+            metrics = flashy.distrib.average_metrics(metrics, len(loader))
+
+        return metrics
+
+    def evaluate(self) -> dict:
+        """Evaluate stage."""
+        self.model.eval()
+        with torch.no_grad():
+            metrics: dict = {}
+            if self.cfg.evaluate.metrics.base:
+                metrics.update(self.common_train_valid('evaluate'))
+            gen_metrics = self.evaluate_audio_generation()
+            return {**metrics, **gen_metrics}
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class MusicGenSolver +(cfg: omegaconf.dictconfig.DictConfig) +
+
+

Solver for MusicGen training task.

+

Used in: https://arxiv.org/abs/2306.05284

+
+ +Expand source code + +
class MusicGenSolver(base.StandardSolver):
+    """Solver for MusicGen training task.
+
+    Used in: https://arxiv.org/abs/2306.05284
+    """
+    DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC
+
+    def __init__(self, cfg: omegaconf.DictConfig):
+        super().__init__(cfg)
+        # easier access to sampling parameters
+        self.generation_params = {
+            'use_sampling': self.cfg.generate.lm.use_sampling,
+            'temp': self.cfg.generate.lm.temp,
+            'top_k': self.cfg.generate.lm.top_k,
+            'top_p': self.cfg.generate.lm.top_p,
+        }
+        self._best_metric_name: tp.Optional[str] = 'ce'
+
+        self._cached_batch_writer = None
+        self._cached_batch_loader = None
+        if cfg.cache.path:
+            if cfg.cache.write:
+                self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path))
+                if self.cfg.cache.write_num_shards:
+                    self.logger.warning("Multiple shard cache, best_metric_name will be set to None.")
+                    self._best_metric_name = None
+            else:
+                self._cached_batch_loader = CachedBatchLoader(
+                    Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers,
+                    min_length=self.cfg.optim.updates_per_epoch or 1)
+                self.dataloaders['original_train'] = self.dataloaders['train']
+                self.dataloaders['train'] = self._cached_batch_loader  # type: ignore
+
+    @staticmethod
+    def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
+                                 device: tp.Optional[str] = None, autocast: bool = True,
+                                 batch_size: tp.Optional[int] = None,
+                                 override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+                                 **kwargs):
+        """Mostly a convenience function around magma.train.get_solver_from_sig,
+        populating all the proper param, deactivating EMA, FSDP, loading the best state,
+        basically all you need to get a solver ready to "play" with in single GPU mode
+        and with minimal memory overhead.
+
+        Args:
+            sig (str): signature to load.
+            dtype (str or None): potential dtype, as a string, i.e. 'float16'.
+            device (str or None): potential device, as a string, i.e. 'cuda'.
+            override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'.
+        """
+        from audiocraft import train
+        our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
+        our_override_cfg['autocast'] = autocast
+        if dtype is not None:
+            our_override_cfg['dtype'] = dtype
+        if device is not None:
+            our_override_cfg['device'] = device
+        if batch_size is not None:
+            our_override_cfg['dataset'] = {'batch_size': batch_size}
+        if override_cfg is None:
+            override_cfg = {}
+        override_cfg = omegaconf.OmegaConf.merge(
+            omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
+        solver = train.get_solver_from_sig(
+            sig, override_cfg=override_cfg,
+            load_best=True, disable_fsdp=True,
+            ignore_state_keys=['optimizer', 'ema'], **kwargs)
+        solver.model.eval()
+        return solver
+
+    def get_formatter(self, stage_name: str) -> flashy.Formatter:
+        return flashy.Formatter({
+            'lr': '.2E',
+            'ce': '.3f',
+            'ppl': '.3f',
+            'grad_norm': '.3E',
+        }, exclude_keys=['ce_q*', 'ppl_q*'])
+
+    @property
+    def best_metric_name(self) -> tp.Optional[str]:
+        return self._best_metric_name
+
+    def build_model(self) -> None:
+        """Instantiate models and optimizer."""
+        # we can potentially not use all quantizers with which the EnCodec model was trained
+        # (e.g. we trained the model with quantizers dropout)
+        self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
+            self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
+        assert self.compression_model.sample_rate == self.cfg.sample_rate, (
+            f"Compression model sample rate is {self.compression_model.sample_rate} but "
+            f"Solver sample rate is {self.cfg.sample_rate}."
+            )
+        # ensure we have matching configuration between LM and compression model
+        assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
+            "Cardinalities of the LM and compression model don't match: ",
+            f"LM cardinality is {self.cfg.transformer_lm.card} vs ",
+            f"compression model cardinality is {self.compression_model.cardinality}"
+        )
+        assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
+            "Numbers of codebooks of the LM and compression models don't match: ",
+            f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ",
+            f"compression model numer of codebooks is {self.compression_model.num_codebooks}"
+        )
+        self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
+                         self.compression_model.num_codebooks, self.compression_model.cardinality,
+                         self.compression_model.frame_rate)
+        # instantiate LM model
+        self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device)
+        if self.cfg.fsdp.use:
+            assert not self.cfg.autocast, "Cannot use autocast with fsdp"
+            self.model = self.wrap_with_fsdp(self.model)
+        self.register_ema('model')
+        # initialize optimization
+        self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
+        self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
+        self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
+        self.register_best_state('model')
+        self.autocast_dtype = {
+            'float16': torch.float16, 'bfloat16': torch.bfloat16
+        }[self.cfg.autocast_dtype]
+        self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
+        if self.cfg.fsdp.use:
+            need_scaler = self.cfg.fsdp.param_dtype == 'float16'
+        else:
+            need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
+        if need_scaler:
+            if self.cfg.fsdp.use:
+                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+                self.scaler = ShardedGradScaler()  # type: ignore
+            else:
+                self.scaler = torch.cuda.amp.GradScaler()
+            self.register_stateful('scaler')
+
+    def build_dataloaders(self) -> None:
+        """Instantiate audio dataloaders for each stage."""
+        self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)
+
+    def show(self) -> None:
+        """Show the compression model and LM model."""
+        self.logger.info("Compression model:")
+        self.log_model_summary(self.compression_model)
+        self.logger.info("LM model:")
+        self.log_model_summary(self.model)
+
+    def load_state_dict(self, state: dict) -> None:
+        if 'condition_provider' in state:
+            model_state = state['model']
+            condition_provider_state = state.pop('condition_provider')
+            prefix = 'condition_provider.'
+            for key, value in condition_provider_state.items():
+                key = prefix + key
+                assert key not in model_state
+                model_state[key] = value
+        super().load_state_dict(state)
+
+    def load_from_pretrained(self, name: str):
+        # TODO: support native HF versions of MusicGen.
+        lm_pkg = models.loaders.load_lm_model_ckpt(name)
+        state: dict = {
+            'best_state': {
+                'model': lm_pkg['best_state'],
+            },
+        }
+        return state
+
+    def _compute_cross_entropy(
+        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
+    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
+        """Compute cross entropy between multi-codebook targets and model's logits.
+        The cross entropy is computed per codebook to provide codebook-level cross entropy.
+        Valid timesteps for each of the codebook are pulled from the mask, where invalid
+        timesteps are set to 0.
+
+        Args:
+            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
+            targets (torch.Tensor): Target codes, of shape [B, K, T].
+            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
+        Returns:
+            ce (torch.Tensor): Cross entropy averaged over the codebooks
+            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
+        """
+        B, K, T = targets.shape
+        assert logits.shape[:-1] == targets.shape
+        assert mask.shape == targets.shape
+        ce = torch.zeros([], device=targets.device)
+        ce_per_codebook: tp.List[torch.Tensor] = []
+        for k in range(K):
+            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
+            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
+            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
+            ce_targets = targets_k[mask_k]
+            ce_logits = logits_k[mask_k]
+            q_ce = F.cross_entropy(ce_logits, ce_targets)
+            ce += q_ce
+            ce_per_codebook.append(q_ce.detach())
+        # average cross entropy across codebooks
+        ce = ce / K
+        return ce, ce_per_codebook
+
+    def _prepare_tokens_and_attributes(
+        self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
+        check_synchronization_points: bool = False
+    ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]:
+        """Prepare input batchs for language model training.
+
+        Args:
+            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
+                and corresponding metadata as SegmentWithAttributes (with B items).
+            check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
+        Returns:
+            Condition tensors (dict[str, any]): Preprocessed condition attributes.
+            Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s],
+                with B the batch size, K the number of codebooks, T_s the token timesteps.
+            Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
+        """
+        if self.model.training:
+            warnings.warn(
+                "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. "
+                "This is inconsistent with how model were trained in the MusicGen paper. We removed the "
+                "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
+                "Really sorry about that.")
+        if self._cached_batch_loader is None or self.current_stage != "train":
+            audio, infos = batch
+            audio = audio.to(self.device)
+            audio_tokens = None
+            assert audio.size(0) == len(infos), (
+                f"Mismatch between number of items in audio batch ({audio.size(0)})",
+                f" and in metadata ({len(infos)})"
+            )
+        else:
+            audio = None
+            # In that case the batch will be a tuple coming from the _cached_batch_writer bit below.
+            infos, = batch  # type: ignore
+            assert all([isinstance(info, AudioInfo) for info in infos])
+            assert all([info.audio_tokens is not None for info in infos])  # type: ignore
+            audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device)  # type: ignore
+            audio_tokens = audio_tokens.long()
+            for info in infos:
+                if isinstance(info, MusicInfo):
+                    # Careful here, if you want to use this condition_wav (e.b. chroma conditioning),
+                    # then you must be using the chroma cache! otherwise the code will try
+                    # to use this segment and fail (by that I mean you will see NaN everywhere).
+                    info.self_wav = WavCondition(
+                        torch.full([1, info.channels, info.total_frames], float('NaN')),
+                        length=torch.tensor([info.n_frames]),
+                        sample_rate=[info.sample_rate],
+                        path=[info.meta.path],
+                        seek_time=[info.seek_time])
+                    dataset = get_dataset_from_loader(self.dataloaders['original_train'])
+                    assert isinstance(dataset, MusicDataset), type(dataset)
+                    if dataset.paraphraser is not None and info.description is not None:
+                        # Hackingly reapplying paraphraser when using cache.
+                        info.description = dataset.paraphraser.sample_paraphrase(
+                            info.meta.path, info.description)
+        # prepare attributes
+        attributes = [info.to_condition_attributes() for info in infos]
+        attributes = self.model.cfg_dropout(attributes)
+        attributes = self.model.att_dropout(attributes)
+        tokenized = self.model.condition_provider.tokenize(attributes)
+
+        # Now we should be synchronization free.
+        if self.device == "cuda" and check_synchronization_points:
+            torch.cuda.set_sync_debug_mode("warn")
+
+        if audio_tokens is None:
+            with torch.no_grad():
+                audio_tokens, scale = self.compression_model.encode(audio)
+                assert scale is None, "Scaled compression model not supported with LM."
+
+        with self.autocast:
+            condition_tensors = self.model.condition_provider(tokenized)
+
+        # create a padding mask to hold valid vs invalid positions
+        padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
+        # replace encodec tokens from padded audio with special_token_id
+        if self.cfg.tokens.padding_with_special_token:
+            audio_tokens = audio_tokens.clone()
+            padding_mask = padding_mask.clone()
+            token_sample_rate = self.compression_model.frame_rate
+            B, K, T_s = audio_tokens.shape
+            for i in range(B):
+                n_samples = infos[i].n_frames
+                audio_sample_rate = infos[i].sample_rate
+                # take the last token generated from actual audio frames (non-padded audio)
+                valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate)
+                audio_tokens[i, :, valid_tokens:] = self.model.special_token_id
+                padding_mask[i, :, valid_tokens:] = 0
+
+        if self.device == "cuda" and check_synchronization_points:
+            torch.cuda.set_sync_debug_mode("default")
+
+        if self._cached_batch_writer is not None and self.current_stage == 'train':
+            assert self._cached_batch_loader is None
+            assert audio_tokens is not None
+            for info, one_audio_tokens in zip(infos, audio_tokens):
+                assert isinstance(info, AudioInfo)
+                if isinstance(info, MusicInfo):
+                    assert not info.joint_embed, "joint_embed and cache not supported yet."
+                    info.self_wav = None
+                assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item()
+                info.audio_tokens = one_audio_tokens.short().cpu()
+            self._cached_batch_writer.save(infos)
+
+        return condition_tensors, audio_tokens, padding_mask
+
+    def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
+        """Perform one training or valid step on a given batch."""
+        check_synchronization_points = idx == 1 and self.device == 'cuda'
+
+        condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
+            batch, check_synchronization_points)
+
+        self.deadlock_detect.update('tokens_and_conditions')
+
+        if check_synchronization_points:
+            torch.cuda.set_sync_debug_mode('warn')
+
+        with self.autocast:
+            model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors)  # type: ignore
+            logits = model_output.logits
+            mask = padding_mask & model_output.mask
+            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
+            loss = ce
+        self.deadlock_detect.update('loss')
+
+        if check_synchronization_points:
+            torch.cuda.set_sync_debug_mode('default')
+
+        if self.is_training:
+            metrics['lr'] = self.optimizer.param_groups[0]['lr']
+            if self.scaler is not None:
+                loss = self.scaler.scale(loss)
+            self.deadlock_detect.update('scale')
+            if self.cfg.fsdp.use:
+                loss.backward()
+                flashy.distrib.average_tensors(self.model.buffers())
+            elif self.cfg.optim.eager_sync:
+                with flashy.distrib.eager_sync_model(self.model):
+                    loss.backward()
+            else:
+                # this should always be slower but can be useful
+                # for weird use cases like multiple backwards.
+                loss.backward()
+                flashy.distrib.sync_model(self.model)
+            self.deadlock_detect.update('backward')
+
+            if self.scaler is not None:
+                self.scaler.unscale_(self.optimizer)
+            if self.cfg.optim.max_norm:
+                if self.cfg.fsdp.use:
+                    metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
+                else:
+                    metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
+                        self.model.parameters(), self.cfg.optim.max_norm
+                    )
+            if self.scaler is None:
+                self.optimizer.step()
+            else:
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+            if self.lr_scheduler:
+                self.lr_scheduler.step()
+            self.optimizer.zero_grad()
+            self.deadlock_detect.update('optim')
+            if self.scaler is not None:
+                scale = self.scaler.get_scale()
+                metrics['grad_scale'] = scale
+            if not loss.isfinite().all():
+                raise RuntimeError("Model probably diverged.")
+
+        metrics['ce'] = ce
+        metrics['ppl'] = torch.exp(ce)
+        for k, ce_q in enumerate(ce_per_codebook):
+            metrics[f'ce_q{k + 1}'] = ce_q
+            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)
+
+        return metrics
+
+    @torch.no_grad()
+    def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
+                          gen_duration: float, prompt_duration: tp.Optional[float] = None,
+                          remove_prompt: bool = False,
+                          **generation_params) -> dict:
+        """Run generate step on a batch of optional audio tensor and corresponding attributes.
+
+        Args:
+            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]):
+            use_prompt (bool): Whether to do audio continuation generation with prompt from audio batch.
+            gen_duration (float): Target audio duration for the generation.
+            prompt_duration (float, optional): Duration for the audio prompt to use for continuation.
+            remove_prompt (bool, optional): Whether to remove the prompt from the generated audio.
+            generation_params: Additional generation parameters.
+        Returns:
+            gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation
+                and the prompt along with additional information.
+        """
+        bench_start = time.time()
+        audio, meta = batch
+        assert audio.size(0) == len(meta), (
+            f"Mismatch between number of items in audio batch ({audio.size(0)})",
+            f" and in metadata ({len(meta)})"
+        )
+        # prepare attributes
+        attributes = [x.to_condition_attributes() for x in meta]
+        # TODO: Add dropout for chroma?
+
+        # prepare audio prompt
+        if prompt_duration is None:
+            prompt_audio = None
+        else:
+            assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
+            prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
+            prompt_audio = audio[..., :prompt_audio_frames]
+
+        # get audio tokens from compression model
+        if prompt_audio is None or prompt_audio.nelement() == 0:
+            num_samples = len(attributes)
+            prompt_tokens = None
+        else:
+            num_samples = None
+            prompt_audio = prompt_audio.to(self.device)
+            prompt_tokens, scale = self.compression_model.encode(prompt_audio)
+            assert scale is None, "Compression model in MusicGen should not require rescaling."
+
+        # generate by sampling from the LM
+        with self.autocast:
+            total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
+            gen_tokens = self.model.generate(
+                prompt_tokens, attributes, max_gen_len=total_gen_len,
+                num_samples=num_samples, **self.generation_params)
+
+        # generate audio from tokens
+        assert gen_tokens.dim() == 3
+        gen_audio = self.compression_model.decode(gen_tokens, None)
+
+        bench_end = time.time()
+        gen_outputs = {
+            'rtf': (bench_end - bench_start) / gen_duration,
+            'ref_audio': audio,
+            'gen_audio': gen_audio,
+            'gen_tokens': gen_tokens,
+            'prompt_audio': prompt_audio,
+            'prompt_tokens': prompt_tokens,
+        }
+        return gen_outputs
+
+    def generate_audio(self) -> dict:
+        """Audio generation stage."""
+        generate_stage_name = f'{self.current_stage}'
+        sample_manager = SampleManager(self.xp)
+        self.logger.info(f"Generating samples in {sample_manager.base_folder}")
+        loader = self.dataloaders['generate']
+        updates = len(loader)
+        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
+
+        dataset = get_dataset_from_loader(loader)
+        dataset_duration = dataset.segment_duration
+        assert dataset_duration is not None
+        assert isinstance(dataset, AudioDataset)
+        target_duration = self.cfg.generate.lm.gen_duration
+        prompt_duration = self.cfg.generate.lm.prompt_duration
+        if target_duration is None:
+            target_duration = dataset_duration
+        if prompt_duration is None:
+            prompt_duration = dataset_duration / 4
+        assert prompt_duration < dataset_duration, (
+            f"Specified prompt duration ({prompt_duration}s) is longer",
+            f" than reference audio duration ({dataset_duration}s)"
+        )
+
+        def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
+            hydrated_conditions = []
+            for sample in [x.to_condition_attributes() for x in meta]:
+                cond_dict = {}
+                for cond_type in sample.__annotations__.keys():
+                    for cond_key, cond_val in getattr(sample, cond_type).items():
+                        if cond_key not in self.model.condition_provider.conditioners.keys():
+                            continue
+                        if is_jsonable(cond_val):
+                            cond_dict[cond_key] = cond_val
+                        elif isinstance(cond_val, WavCondition):
+                            cond_dict[cond_key] = cond_val.path
+                        elif isinstance(cond_val, JointEmbedCondition):
+                            cond_dict[cond_key] = cond_val.text  # only support text at inference for now
+                        else:
+                            # if we reached this point, it is not clear how to log the condition
+                            # so we just log the type.
+                            cond_dict[cond_key] = str(type(cond_val))
+                            continue
+                hydrated_conditions.append(cond_dict)
+            return hydrated_conditions
+
+        metrics: dict = {}
+        average = flashy.averager()
+        for batch in lp:
+            audio, meta = batch
+            # metadata for sample manager
+            hydrated_conditions = get_hydrated_conditions(meta)
+            sample_generation_params = {
+                **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
+                **self.generation_params
+            }
+            if self.cfg.generate.lm.unprompted_samples:
+                if self.cfg.generate.lm.gen_gt_samples:
+                    # get the ground truth instead of generation
+                    self.logger.warn(
+                        "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
+                    gen_unprompted_audio = audio
+                    rtf = 1.
+                else:
+                    gen_unprompted_outputs = self.run_generate_step(
+                        batch, gen_duration=target_duration, prompt_duration=None,
+                        **self.generation_params)
+                    gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
+                    rtf = gen_unprompted_outputs['rtf']
+                sample_manager.add_samples(
+                    gen_unprompted_audio, self.epoch, hydrated_conditions,
+                    ground_truth_wavs=audio, generation_args=sample_generation_params)
+
+            if self.cfg.generate.lm.prompted_samples:
+                gen_outputs = self.run_generate_step(
+                    batch, gen_duration=target_duration, prompt_duration=prompt_duration,
+                    **self.generation_params)
+                gen_audio = gen_outputs['gen_audio'].cpu()
+                prompt_audio = gen_outputs['prompt_audio'].cpu()
+                sample_manager.add_samples(
+                    gen_audio, self.epoch, hydrated_conditions,
+                    prompt_wavs=prompt_audio, ground_truth_wavs=audio,
+                    generation_args=sample_generation_params)
+
+            metrics['rtf'] = rtf
+            metrics = average(metrics)
+
+        flashy.distrib.barrier()
+        return metrics
+
+    def generate(self) -> dict:
+        """Generate stage."""
+        self.model.eval()
+        with torch.no_grad():
+            return self.generate_audio()
+
+    def run_epoch(self):
+        if self.cfg.cache.write:
+            if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard:
+                return
+        super().run_epoch()
+
+    def train(self):
+        """Train stage.
+        """
+        if self._cached_batch_writer is not None:
+            self._cached_batch_writer.start_epoch(self.epoch)
+        if self._cached_batch_loader is None:
+            dataset = get_dataset_from_loader(self.dataloaders['train'])
+            assert isinstance(dataset, AudioDataset)
+            dataset.current_epoch = self.epoch
+        else:
+            self._cached_batch_loader.start_epoch(self.epoch)
+        return super().train()
+
+    def evaluate_audio_generation(self) -> dict:
+        """Evaluate audio generation with off-the-shelf metrics."""
+        evaluate_stage_name = f'{self.current_stage}_generation'
+        # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation
+        fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
+        kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
+        text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
+        chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
+        should_run_eval = False
+        eval_chroma_wavs: tp.Optional[torch.Tensor] = None
+        if self.cfg.evaluate.metrics.fad:
+            fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
+            should_run_eval = True
+        if self.cfg.evaluate.metrics.kld:
+            kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
+            should_run_eval = True
+        if self.cfg.evaluate.metrics.text_consistency:
+            text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
+            should_run_eval = True
+        if self.cfg.evaluate.metrics.chroma_cosine:
+            chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
+            # if we have predefind wavs for chroma we should purge them for computing the cosine metric
+            has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
+                                          self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
+            if has_predefined_eval_chromas:
+                warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
+                                       'Resetting eval chromas to None for evaluation.')
+                eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs  # type: ignore
+                self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None)  # type: ignore
+            should_run_eval = True
+
+        def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
+            audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
+            compressed_audio = self.compression_model.decode(audio_tokens, scale)
+            return compressed_audio[..., :audio.shape[-1]]
+
+        metrics: dict = {}
+        if should_run_eval:
+            loader = self.dataloaders['evaluate']
+            updates = len(loader)
+            lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
+            average = flashy.averager()
+            dataset = get_dataset_from_loader(loader)
+            assert isinstance(dataset, AudioDataset)
+            self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")
+
+            for idx, batch in enumerate(lp):
+                audio, meta = batch
+                assert all([self.cfg.sample_rate == m.sample_rate for m in meta])
+
+                target_duration = audio.shape[-1] / self.cfg.sample_rate
+                if self.cfg.evaluate.fixed_generation_duration:
+                    target_duration = self.cfg.evaluate.fixed_generation_duration
+
+                gen_outputs = self.run_generate_step(
+                    batch, gen_duration=target_duration,
+                    **self.generation_params
+                )
+                y_pred = gen_outputs['gen_audio'].detach()
+                y_pred = y_pred[..., :audio.shape[-1]]
+
+                normalize_kwargs = dict(self.cfg.generate.audio)
+                normalize_kwargs.pop('format', None)
+                y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
+                y = audio.cpu()  # should already be on CPU but just in case
+                sizes = torch.tensor([m.n_frames for m in meta])  # actual sizes without padding
+                sample_rates = torch.tensor([m.sample_rate for m in meta])  # sample rates for audio samples
+                audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]
+
+                if fad is not None:
+                    if self.cfg.metrics.fad.use_gt:
+                        y_pred = get_compressed_audio(y).cpu()
+                    fad.update(y_pred, y, sizes, sample_rates, audio_stems)
+                if kldiv is not None:
+                    if self.cfg.metrics.kld.use_gt:
+                        y_pred = get_compressed_audio(y).cpu()
+                    kldiv.update(y_pred, y, sizes, sample_rates)
+                if text_consistency is not None:
+                    texts = [m.description for m in meta]
+                    if self.cfg.metrics.text_consistency.use_gt:
+                        y_pred = y
+                    text_consistency.update(y_pred, texts, sizes, sample_rates)
+                if chroma_cosine is not None:
+                    if self.cfg.metrics.chroma_cosine.use_gt:
+                        y_pred = get_compressed_audio(y).cpu()
+                    chroma_cosine.update(y_pred, y, sizes, sample_rates)
+                    # restore chroma conditioner's eval chroma wavs
+                    if eval_chroma_wavs is not None:
+                        self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)
+
+            flashy.distrib.barrier()
+            if fad is not None:
+                metrics['fad'] = fad.compute()
+            if kldiv is not None:
+                kld_metrics = kldiv.compute()
+                metrics.update(kld_metrics)
+            if text_consistency is not None:
+                metrics['text_consistency'] = text_consistency.compute()
+            if chroma_cosine is not None:
+                metrics['chroma_cosine'] = chroma_cosine.compute()
+            metrics = average(metrics)
+            metrics = flashy.distrib.average_metrics(metrics, len(loader))
+
+        return metrics
+
+    def evaluate(self) -> dict:
+        """Evaluate stage."""
+        self.model.eval()
+        with torch.no_grad():
+            metrics: dict = {}
+            if self.cfg.evaluate.metrics.base:
+                metrics.update(self.common_train_valid('evaluate'))
+            gen_metrics = self.evaluate_audio_generation()
+            return {**metrics, **gen_metrics}
+
+

Ancestors

+ +

Subclasses

+ +

Class variables

+
+
var DATASET_TYPEDatasetType
+
+
+
+
+

Static methods

+
+
+def get_eval_solver_from_sig(sig: str, dtype: Optional[str] = None, device: Optional[str] = None, autocast: bool = True, batch_size: Optional[int] = None, override_cfg: Union[dict, omegaconf.dictconfig.DictConfig, None] = None, **kwargs) +
+
+

Mostly a convenience function around magma.train.get_solver_from_sig, +populating all the proper param, deactivating EMA, FSDP, loading the best state, +basically all you need to get a solver ready to "play" with in single GPU mode +and with minimal memory overhead.

+

Args

+
+
sig : str
+
signature to load.
+
dtype : str or None
+
potential dtype, as a string, i.e. 'float16'.
+
device : str or None
+
potential device, as a string, i.e. 'cuda'.
+
override_cfg : dict or omegaconf.DictConfig or None
+
potential device, as a string, i.e. 'cuda'.
+
+
+ +Expand source code + +
@staticmethod
+def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
+                             device: tp.Optional[str] = None, autocast: bool = True,
+                             batch_size: tp.Optional[int] = None,
+                             override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+                             **kwargs):
+    """Mostly a convenience function around magma.train.get_solver_from_sig,
+    populating all the proper param, deactivating EMA, FSDP, loading the best state,
+    basically all you need to get a solver ready to "play" with in single GPU mode
+    and with minimal memory overhead.
+
+    Args:
+        sig (str): signature to load.
+        dtype (str or None): potential dtype, as a string, i.e. 'float16'.
+        device (str or None): potential device, as a string, i.e. 'cuda'.
+        override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'.
+    """
+    from audiocraft import train
+    our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
+    our_override_cfg['autocast'] = autocast
+    if dtype is not None:
+        our_override_cfg['dtype'] = dtype
+    if device is not None:
+        our_override_cfg['device'] = device
+    if batch_size is not None:
+        our_override_cfg['dataset'] = {'batch_size': batch_size}
+    if override_cfg is None:
+        override_cfg = {}
+    override_cfg = omegaconf.OmegaConf.merge(
+        omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
+    solver = train.get_solver_from_sig(
+        sig, override_cfg=override_cfg,
+        load_best=True, disable_fsdp=True,
+        ignore_state_keys=['optimizer', 'ema'], **kwargs)
+    solver.model.eval()
+    return solver
+
+
+
+

Methods

+
+
+def build_dataloaders(self) ‑> None +
+
+

Instantiate audio dataloaders for each stage.

+
+ +Expand source code + +
def build_dataloaders(self) -> None:
+    """Instantiate audio dataloaders for each stage."""
+    self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)
+
+
+
+def build_model(self) ‑> None +
+
+

Instantiate models and optimizer.

+
+ +Expand source code + +
def build_model(self) -> None:
+    """Instantiate models and optimizer."""
+    # we can potentially not use all quantizers with which the EnCodec model was trained
+    # (e.g. we trained the model with quantizers dropout)
+    self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
+        self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
+    assert self.compression_model.sample_rate == self.cfg.sample_rate, (
+        f"Compression model sample rate is {self.compression_model.sample_rate} but "
+        f"Solver sample rate is {self.cfg.sample_rate}."
+        )
+    # ensure we have matching configuration between LM and compression model
+    assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
+        "Cardinalities of the LM and compression model don't match: ",
+        f"LM cardinality is {self.cfg.transformer_lm.card} vs ",
+        f"compression model cardinality is {self.compression_model.cardinality}"
+    )
+    assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
+        "Numbers of codebooks of the LM and compression models don't match: ",
+        f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ",
+        f"compression model numer of codebooks is {self.compression_model.num_codebooks}"
+    )
+    self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
+                     self.compression_model.num_codebooks, self.compression_model.cardinality,
+                     self.compression_model.frame_rate)
+    # instantiate LM model
+    self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device)
+    if self.cfg.fsdp.use:
+        assert not self.cfg.autocast, "Cannot use autocast with fsdp"
+        self.model = self.wrap_with_fsdp(self.model)
+    self.register_ema('model')
+    # initialize optimization
+    self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
+    self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
+    self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
+    self.register_best_state('model')
+    self.autocast_dtype = {
+        'float16': torch.float16, 'bfloat16': torch.bfloat16
+    }[self.cfg.autocast_dtype]
+    self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
+    if self.cfg.fsdp.use:
+        need_scaler = self.cfg.fsdp.param_dtype == 'float16'
+    else:
+        need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
+    if need_scaler:
+        if self.cfg.fsdp.use:
+            from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+            self.scaler = ShardedGradScaler()  # type: ignore
+        else:
+            self.scaler = torch.cuda.amp.GradScaler()
+        self.register_stateful('scaler')
+
+
+
+def evaluate_audio_generation(self) ‑> dict +
+
+

Evaluate audio generation with off-the-shelf metrics.

+
+ +Expand source code + +
def evaluate_audio_generation(self) -> dict:
+    """Evaluate audio generation with off-the-shelf metrics."""
+    evaluate_stage_name = f'{self.current_stage}_generation'
+    # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation
+    fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
+    kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
+    text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
+    chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
+    should_run_eval = False
+    eval_chroma_wavs: tp.Optional[torch.Tensor] = None
+    if self.cfg.evaluate.metrics.fad:
+        fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
+        should_run_eval = True
+    if self.cfg.evaluate.metrics.kld:
+        kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
+        should_run_eval = True
+    if self.cfg.evaluate.metrics.text_consistency:
+        text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
+        should_run_eval = True
+    if self.cfg.evaluate.metrics.chroma_cosine:
+        chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
+        # if we have predefind wavs for chroma we should purge them for computing the cosine metric
+        has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
+                                      self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
+        if has_predefined_eval_chromas:
+            warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
+                                   'Resetting eval chromas to None for evaluation.')
+            eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs  # type: ignore
+            self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None)  # type: ignore
+        should_run_eval = True
+
+    def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
+        audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
+        compressed_audio = self.compression_model.decode(audio_tokens, scale)
+        return compressed_audio[..., :audio.shape[-1]]
+
+    metrics: dict = {}
+    if should_run_eval:
+        loader = self.dataloaders['evaluate']
+        updates = len(loader)
+        lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
+        average = flashy.averager()
+        dataset = get_dataset_from_loader(loader)
+        assert isinstance(dataset, AudioDataset)
+        self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")
+
+        for idx, batch in enumerate(lp):
+            audio, meta = batch
+            assert all([self.cfg.sample_rate == m.sample_rate for m in meta])
+
+            target_duration = audio.shape[-1] / self.cfg.sample_rate
+            if self.cfg.evaluate.fixed_generation_duration:
+                target_duration = self.cfg.evaluate.fixed_generation_duration
+
+            gen_outputs = self.run_generate_step(
+                batch, gen_duration=target_duration,
+                **self.generation_params
+            )
+            y_pred = gen_outputs['gen_audio'].detach()
+            y_pred = y_pred[..., :audio.shape[-1]]
+
+            normalize_kwargs = dict(self.cfg.generate.audio)
+            normalize_kwargs.pop('format', None)
+            y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
+            y = audio.cpu()  # should already be on CPU but just in case
+            sizes = torch.tensor([m.n_frames for m in meta])  # actual sizes without padding
+            sample_rates = torch.tensor([m.sample_rate for m in meta])  # sample rates for audio samples
+            audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]
+
+            if fad is not None:
+                if self.cfg.metrics.fad.use_gt:
+                    y_pred = get_compressed_audio(y).cpu()
+                fad.update(y_pred, y, sizes, sample_rates, audio_stems)
+            if kldiv is not None:
+                if self.cfg.metrics.kld.use_gt:
+                    y_pred = get_compressed_audio(y).cpu()
+                kldiv.update(y_pred, y, sizes, sample_rates)
+            if text_consistency is not None:
+                texts = [m.description for m in meta]
+                if self.cfg.metrics.text_consistency.use_gt:
+                    y_pred = y
+                text_consistency.update(y_pred, texts, sizes, sample_rates)
+            if chroma_cosine is not None:
+                if self.cfg.metrics.chroma_cosine.use_gt:
+                    y_pred = get_compressed_audio(y).cpu()
+                chroma_cosine.update(y_pred, y, sizes, sample_rates)
+                # restore chroma conditioner's eval chroma wavs
+                if eval_chroma_wavs is not None:
+                    self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)
+
+        flashy.distrib.barrier()
+        if fad is not None:
+            metrics['fad'] = fad.compute()
+        if kldiv is not None:
+            kld_metrics = kldiv.compute()
+            metrics.update(kld_metrics)
+        if text_consistency is not None:
+            metrics['text_consistency'] = text_consistency.compute()
+        if chroma_cosine is not None:
+            metrics['chroma_cosine'] = chroma_cosine.compute()
+        metrics = average(metrics)
+        metrics = flashy.distrib.average_metrics(metrics, len(loader))
+
+    return metrics
+
+
+
+def generate_audio(self) ‑> dict +
+
+

Audio generation stage.

+
+ +Expand source code + +
def generate_audio(self) -> dict:
+    """Audio generation stage."""
+    generate_stage_name = f'{self.current_stage}'
+    sample_manager = SampleManager(self.xp)
+    self.logger.info(f"Generating samples in {sample_manager.base_folder}")
+    loader = self.dataloaders['generate']
+    updates = len(loader)
+    lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
+
+    dataset = get_dataset_from_loader(loader)
+    dataset_duration = dataset.segment_duration
+    assert dataset_duration is not None
+    assert isinstance(dataset, AudioDataset)
+    target_duration = self.cfg.generate.lm.gen_duration
+    prompt_duration = self.cfg.generate.lm.prompt_duration
+    if target_duration is None:
+        target_duration = dataset_duration
+    if prompt_duration is None:
+        prompt_duration = dataset_duration / 4
+    assert prompt_duration < dataset_duration, (
+        f"Specified prompt duration ({prompt_duration}s) is longer",
+        f" than reference audio duration ({dataset_duration}s)"
+    )
+
+    def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
+        hydrated_conditions = []
+        for sample in [x.to_condition_attributes() for x in meta]:
+            cond_dict = {}
+            for cond_type in sample.__annotations__.keys():
+                for cond_key, cond_val in getattr(sample, cond_type).items():
+                    if cond_key not in self.model.condition_provider.conditioners.keys():
+                        continue
+                    if is_jsonable(cond_val):
+                        cond_dict[cond_key] = cond_val
+                    elif isinstance(cond_val, WavCondition):
+                        cond_dict[cond_key] = cond_val.path
+                    elif isinstance(cond_val, JointEmbedCondition):
+                        cond_dict[cond_key] = cond_val.text  # only support text at inference for now
+                    else:
+                        # if we reached this point, it is not clear how to log the condition
+                        # so we just log the type.
+                        cond_dict[cond_key] = str(type(cond_val))
+                        continue
+            hydrated_conditions.append(cond_dict)
+        return hydrated_conditions
+
+    metrics: dict = {}
+    average = flashy.averager()
+    for batch in lp:
+        audio, meta = batch
+        # metadata for sample manager
+        hydrated_conditions = get_hydrated_conditions(meta)
+        sample_generation_params = {
+            **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
+            **self.generation_params
+        }
+        if self.cfg.generate.lm.unprompted_samples:
+            if self.cfg.generate.lm.gen_gt_samples:
+                # get the ground truth instead of generation
+                self.logger.warn(
+                    "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
+                gen_unprompted_audio = audio
+                rtf = 1.
+            else:
+                gen_unprompted_outputs = self.run_generate_step(
+                    batch, gen_duration=target_duration, prompt_duration=None,
+                    **self.generation_params)
+                gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
+                rtf = gen_unprompted_outputs['rtf']
+            sample_manager.add_samples(
+                gen_unprompted_audio, self.epoch, hydrated_conditions,
+                ground_truth_wavs=audio, generation_args=sample_generation_params)
+
+        if self.cfg.generate.lm.prompted_samples:
+            gen_outputs = self.run_generate_step(
+                batch, gen_duration=target_duration, prompt_duration=prompt_duration,
+                **self.generation_params)
+            gen_audio = gen_outputs['gen_audio'].cpu()
+            prompt_audio = gen_outputs['prompt_audio'].cpu()
+            sample_manager.add_samples(
+                gen_audio, self.epoch, hydrated_conditions,
+                prompt_wavs=prompt_audio, ground_truth_wavs=audio,
+                generation_args=sample_generation_params)
+
+        metrics['rtf'] = rtf
+        metrics = average(metrics)
+
+    flashy.distrib.barrier()
+    return metrics
+
+
+
+def get_formatter(self, stage_name: str) ‑> flashy.formatter.Formatter +
+
+
+
+ +Expand source code + +
def get_formatter(self, stage_name: str) -> flashy.Formatter:
+    return flashy.Formatter({
+        'lr': '.2E',
+        'ce': '.3f',
+        'ppl': '.3f',
+        'grad_norm': '.3E',
+    }, exclude_keys=['ce_q*', 'ppl_q*'])
+
+
+
+def load_from_pretrained(self, name: str) +
+
+
+
+ +Expand source code + +
def load_from_pretrained(self, name: str):
+    # TODO: support native HF versions of MusicGen.
+    lm_pkg = models.loaders.load_lm_model_ckpt(name)
+    state: dict = {
+        'best_state': {
+            'model': lm_pkg['best_state'],
+        },
+    }
+    return state
+
+
+
+def load_state_dict(self, state: dict) ‑> None +
+
+
+
+ +Expand source code + +
def load_state_dict(self, state: dict) -> None:
+    if 'condition_provider' in state:
+        model_state = state['model']
+        condition_provider_state = state.pop('condition_provider')
+        prefix = 'condition_provider.'
+        for key, value in condition_provider_state.items():
+            key = prefix + key
+            assert key not in model_state
+            model_state[key] = value
+    super().load_state_dict(state)
+
+
+
+def run_generate_step(self, batch: Tuple[torch.Tensor, List[SegmentWithAttributes]], gen_duration: float, prompt_duration: Optional[float] = None, remove_prompt: bool = False, **generation_params) ‑> dict +
+
+

Run generate step on a batch of optional audio tensor and corresponding attributes.

+

Args

+
+
batch (tuple[torch.Tensor, list[SegmentWithAttributes]]):
+
use_prompt : bool
+
Whether to do audio continuation generation with prompt from audio batch.
+
gen_duration : float
+
Target audio duration for the generation.
+
prompt_duration : float, optional
+
Duration for the audio prompt to use for continuation.
+
remove_prompt : bool, optional
+
Whether to remove the prompt from the generated audio.
+
generation_params
+
Additional generation parameters.
+
+

Returns

+

gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation +and the prompt along with additional information.

+
+ +Expand source code + +
@torch.no_grad()
+def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
+                      gen_duration: float, prompt_duration: tp.Optional[float] = None,
+                      remove_prompt: bool = False,
+                      **generation_params) -> dict:
+    """Run generate step on a batch of optional audio tensor and corresponding attributes.
+
+    Args:
+        batch (tuple[torch.Tensor, list[SegmentWithAttributes]]):
+        use_prompt (bool): Whether to do audio continuation generation with prompt from audio batch.
+        gen_duration (float): Target audio duration for the generation.
+        prompt_duration (float, optional): Duration for the audio prompt to use for continuation.
+        remove_prompt (bool, optional): Whether to remove the prompt from the generated audio.
+        generation_params: Additional generation parameters.
+    Returns:
+        gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation
+            and the prompt along with additional information.
+    """
+    bench_start = time.time()
+    audio, meta = batch
+    assert audio.size(0) == len(meta), (
+        f"Mismatch between number of items in audio batch ({audio.size(0)})",
+        f" and in metadata ({len(meta)})"
+    )
+    # prepare attributes
+    attributes = [x.to_condition_attributes() for x in meta]
+    # TODO: Add dropout for chroma?
+
+    # prepare audio prompt
+    if prompt_duration is None:
+        prompt_audio = None
+    else:
+        assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
+        prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
+        prompt_audio = audio[..., :prompt_audio_frames]
+
+    # get audio tokens from compression model
+    if prompt_audio is None or prompt_audio.nelement() == 0:
+        num_samples = len(attributes)
+        prompt_tokens = None
+    else:
+        num_samples = None
+        prompt_audio = prompt_audio.to(self.device)
+        prompt_tokens, scale = self.compression_model.encode(prompt_audio)
+        assert scale is None, "Compression model in MusicGen should not require rescaling."
+
+    # generate by sampling from the LM
+    with self.autocast:
+        total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
+        gen_tokens = self.model.generate(
+            prompt_tokens, attributes, max_gen_len=total_gen_len,
+            num_samples=num_samples, **self.generation_params)
+
+    # generate audio from tokens
+    assert gen_tokens.dim() == 3
+    gen_audio = self.compression_model.decode(gen_tokens, None)
+
+    bench_end = time.time()
+    gen_outputs = {
+        'rtf': (bench_end - bench_start) / gen_duration,
+        'ref_audio': audio,
+        'gen_audio': gen_audio,
+        'gen_tokens': gen_tokens,
+        'prompt_audio': prompt_audio,
+        'prompt_tokens': prompt_tokens,
+    }
+    return gen_outputs
+
+
+
+def show(self) ‑> None +
+
+

Show the compression model and LM model.

+
+ +Expand source code + +
def show(self) -> None:
+    """Show the compression model and LM model."""
+    self.logger.info("Compression model:")
+    self.log_model_summary(self.compression_model)
+    self.logger.info("LM model:")
+    self.log_model_summary(self.model)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/train.html b/api_docs/audiocraft/train.html new file mode 100644 index 00000000..f899c5bd --- /dev/null +++ b/api_docs/audiocraft/train.html @@ -0,0 +1,404 @@ + + + + + + +audiocraft.train API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.train

+
+
+

Entry point for dora to launch solvers for running training loops. +See more info on how to use dora: https://github.com/facebookresearch/dora

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Entry point for dora to launch solvers for running training loops.
+See more info on how to use dora: https://github.com/facebookresearch/dora
+"""
+
+import logging
+import multiprocessing
+import os
+from pathlib import Path
+import sys
+import typing as tp
+
+from dora import git_save, hydra_main, XP
+import flashy
+import hydra
+import omegaconf
+
+from .environment import AudioCraftEnvironment
+from .utils.cluster import get_slurm_parameters
+
+logger = logging.getLogger(__name__)
+
+
+def resolve_config_dset_paths(cfg):
+    """Enable Dora to load manifest from git clone repository."""
+    # manifest files for the different splits
+    for key, value in cfg.datasource.items():
+        if isinstance(value, str):
+            cfg.datasource[key] = git_save.to_absolute_path(value)
+
+
+def get_solver(cfg):
+    from . import solvers
+    # Convert batch size to batch size for each GPU
+    assert cfg.dataset.batch_size % flashy.distrib.world_size() == 0
+    cfg.dataset.batch_size //= flashy.distrib.world_size()
+    for split in ['train', 'valid', 'evaluate', 'generate']:
+        if hasattr(cfg.dataset, split) and hasattr(cfg.dataset[split], 'batch_size'):
+            assert cfg.dataset[split].batch_size % flashy.distrib.world_size() == 0
+            cfg.dataset[split].batch_size //= flashy.distrib.world_size()
+    resolve_config_dset_paths(cfg)
+    solver = solvers.get_solver(cfg)
+    return solver
+
+
+def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+                       restore: bool = True, load_best: bool = True,
+                       ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True):
+    """Given a XP, return the Solver object.
+
+    Args:
+        xp (XP): Dora experiment for which to retrieve the solver.
+        override_cfg (dict or None): If not None, should be a dict used to
+            override some values in the config of `xp`. This will not impact
+            the XP signature or folder. The format is different
+            than the one used in Dora grids, nested keys should actually be nested dicts,
+            not flattened, e.g. `{'optim': {'batch_size': 32}}`.
+        restore (bool): If `True` (the default), restore state from the last checkpoint.
+        load_best (bool): If `True` (the default), load the best state from the checkpoint.
+        ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`.
+        disable_fsdp (bool): if True, disables FSDP entirely. This will
+            also automatically skip loading the EMA. For solver specific
+            state sources, like the optimizer, you might want to
+            use along `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`.
+    """
+    logger.info(f"Loading solver from XP {xp.sig}. "
+                f"Overrides used: {xp.argv}")
+    cfg = xp.cfg
+    if override_cfg is not None:
+        cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))
+    if disable_fsdp and cfg.fsdp.use:
+        cfg.fsdp.use = False
+        assert load_best is True
+        # ignoring some keys that were FSDP sharded like model, ema, and best_state.
+        # fsdp_best_state will be used in that case. When using a specific solver,
+        # one is responsible for adding the relevant keys, e.g. 'optimizer'.
+        # We could make something to automatically register those inside the solver, but that
+        # seem overkill at this point.
+        ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state']
+
+    try:
+        with xp.enter():
+            solver = get_solver(cfg)
+            if restore:
+                solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys)
+        return solver
+    finally:
+        hydra.core.global_hydra.GlobalHydra.instance().clear()
+
+
+def get_solver_from_sig(sig: str, *args, **kwargs):
+    """Return Solver object from Dora signature, i.e. to play with it from a notebook.
+    See `get_solver_from_xp` for more information.
+    """
+    xp = main.get_xp_from_sig(sig)
+    return get_solver_from_xp(xp, *args, **kwargs)
+
+
+def init_seed_and_system(cfg):
+    import numpy as np
+    import torch
+    import random
+    from audiocraft.modules.transformer import set_efficient_attention_backend
+
+    multiprocessing.set_start_method(cfg.mp_start_method)
+    logger.debug('Setting mp start method to %s', cfg.mp_start_method)
+    random.seed(cfg.seed)
+    np.random.seed(cfg.seed)
+    # torch also initialize cuda seed if available
+    torch.manual_seed(cfg.seed)
+    torch.set_num_threads(cfg.num_threads)
+    os.environ['MKL_NUM_THREADS'] = str(cfg.num_threads)
+    os.environ['OMP_NUM_THREADS'] = str(cfg.num_threads)
+    logger.debug('Setting num threads to %d', cfg.num_threads)
+    set_efficient_attention_backend(cfg.efficient_attention_backend)
+    logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
+    if 'SLURM_JOB_ID' in os.environ:
+        tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID'])
+        if tmpdir.exists():
+            logger.info("Changing tmpdir to %s", tmpdir)
+            os.environ['TMPDIR'] = str(tmpdir)
+
+
+@hydra_main(config_path='../config', config_name='config', version_base='1.1')
+def main(cfg):
+    init_seed_and_system(cfg)
+
+    # Setup logging both to XP specific folder, and to stderr.
+    log_name = '%s.log.{rank}' % cfg.execute_only if cfg.execute_only else 'solver.log.{rank}'
+    flashy.setup_logging(level=str(cfg.logging.level).upper(), log_name=log_name)
+    # Initialize distributed training, no need to specify anything when using Dora.
+    flashy.distrib.init()
+    solver = get_solver(cfg)
+    if cfg.show:
+        solver.show()
+        return
+
+    if cfg.execute_only:
+        assert cfg.execute_inplace or cfg.continue_from is not None, \
+            "Please explicitly specify the checkpoint to continue from with continue_from=<sig_or_path> " + \
+            "when running with execute_only or set execute_inplace to True."
+        solver.restore(replay_metrics=False)  # load checkpoint
+        solver.run_one_stage(cfg.execute_only)
+        return
+
+    return solver.run()
+
+
+main.dora.dir = AudioCraftEnvironment.get_dora_dir()
+main._base_cfg.slurm = get_slurm_parameters(main._base_cfg.slurm)
+
+if main.dora.shared is not None and not os.access(main.dora.shared, os.R_OK):
+    print("No read permission on dora.shared folder, ignoring it.", file=sys.stderr)
+    main.dora.shared = None
+
+if __name__ == '__main__':
+    main()
+
+
+
+
+
+
+
+

Functions

+
+
+def get_solver(cfg) +
+
+
+
+ +Expand source code + +
def get_solver(cfg):
+    from . import solvers
+    # Convert batch size to batch size for each GPU
+    assert cfg.dataset.batch_size % flashy.distrib.world_size() == 0
+    cfg.dataset.batch_size //= flashy.distrib.world_size()
+    for split in ['train', 'valid', 'evaluate', 'generate']:
+        if hasattr(cfg.dataset, split) and hasattr(cfg.dataset[split], 'batch_size'):
+            assert cfg.dataset[split].batch_size % flashy.distrib.world_size() == 0
+            cfg.dataset[split].batch_size //= flashy.distrib.world_size()
+    resolve_config_dset_paths(cfg)
+    solver = solvers.get_solver(cfg)
+    return solver
+
+
+
+def get_solver_from_sig(sig: str, *args, **kwargs) +
+
+

Return Solver object from Dora signature, i.e. to play with it from a notebook. +See get_solver_from_xp() for more information.

+
+ +Expand source code + +
def get_solver_from_sig(sig: str, *args, **kwargs):
+    """Return Solver object from Dora signature, i.e. to play with it from a notebook.
+    See `get_solver_from_xp` for more information.
+    """
+    xp = main.get_xp_from_sig(sig)
+    return get_solver_from_xp(xp, *args, **kwargs)
+
+
+
+def get_solver_from_xp(xp: dora.xp.XP, override_cfg: Union[dict, omegaconf.dictconfig.DictConfig, None] = None, restore: bool = True, load_best: bool = True, ignore_state_keys: List[str] = [], disable_fsdp: bool = True) +
+
+

Given a XP, return the Solver object.

+

Args

+
+
xp : XP
+
Dora experiment for which to retrieve the solver.
+
override_cfg : dict or None
+
If not None, should be a dict used to +override some values in the config of xp. This will not impact +the XP signature or folder. The format is different +than the one used in Dora grids, nested keys should actually be nested dicts, +not flattened, e.g. {'optim': {'batch_size': 32}}.
+
restore : bool
+
If True (the default), restore state from the last checkpoint.
+
load_best : bool
+
If True (the default), load the best state from the checkpoint.
+
ignore_state_keys : list[str]
+
List of sources to ignore when loading the state, e.g. optimizer.
+
disable_fsdp : bool
+
if True, disables FSDP entirely. This will +also automatically skip loading the EMA. For solver specific +state sources, like the optimizer, you might want to +use along ignore_state_keys=['optimizer']. Must be used with load_best=True.
+
+
+ +Expand source code + +
def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+                       restore: bool = True, load_best: bool = True,
+                       ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True):
+    """Given a XP, return the Solver object.
+
+    Args:
+        xp (XP): Dora experiment for which to retrieve the solver.
+        override_cfg (dict or None): If not None, should be a dict used to
+            override some values in the config of `xp`. This will not impact
+            the XP signature or folder. The format is different
+            than the one used in Dora grids, nested keys should actually be nested dicts,
+            not flattened, e.g. `{'optim': {'batch_size': 32}}`.
+        restore (bool): If `True` (the default), restore state from the last checkpoint.
+        load_best (bool): If `True` (the default), load the best state from the checkpoint.
+        ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`.
+        disable_fsdp (bool): if True, disables FSDP entirely. This will
+            also automatically skip loading the EMA. For solver specific
+            state sources, like the optimizer, you might want to
+            use along `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`.
+    """
+    logger.info(f"Loading solver from XP {xp.sig}. "
+                f"Overrides used: {xp.argv}")
+    cfg = xp.cfg
+    if override_cfg is not None:
+        cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))
+    if disable_fsdp and cfg.fsdp.use:
+        cfg.fsdp.use = False
+        assert load_best is True
+        # ignoring some keys that were FSDP sharded like model, ema, and best_state.
+        # fsdp_best_state will be used in that case. When using a specific solver,
+        # one is responsible for adding the relevant keys, e.g. 'optimizer'.
+        # We could make something to automatically register those inside the solver, but that
+        # seem overkill at this point.
+        ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state']
+
+    try:
+        with xp.enter():
+            solver = get_solver(cfg)
+            if restore:
+                solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys)
+        return solver
+    finally:
+        hydra.core.global_hydra.GlobalHydra.instance().clear()
+
+
+
+def init_seed_and_system(cfg) +
+
+
+
+ +Expand source code + +
def init_seed_and_system(cfg):
+    import numpy as np
+    import torch
+    import random
+    from audiocraft.modules.transformer import set_efficient_attention_backend
+
+    multiprocessing.set_start_method(cfg.mp_start_method)
+    logger.debug('Setting mp start method to %s', cfg.mp_start_method)
+    random.seed(cfg.seed)
+    np.random.seed(cfg.seed)
+    # torch also initialize cuda seed if available
+    torch.manual_seed(cfg.seed)
+    torch.set_num_threads(cfg.num_threads)
+    os.environ['MKL_NUM_THREADS'] = str(cfg.num_threads)
+    os.environ['OMP_NUM_THREADS'] = str(cfg.num_threads)
+    logger.debug('Setting num threads to %d', cfg.num_threads)
+    set_efficient_attention_backend(cfg.efficient_attention_backend)
+    logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
+    if 'SLURM_JOB_ID' in os.environ:
+        tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID'])
+        if tmpdir.exists():
+            logger.info("Changing tmpdir to %s", tmpdir)
+            os.environ['TMPDIR'] = str(tmpdir)
+
+
+
+def resolve_config_dset_paths(cfg) +
+
+

Enable Dora to load manifest from git clone repository.

+
+ +Expand source code + +
def resolve_config_dset_paths(cfg):
+    """Enable Dora to load manifest from git clone repository."""
+    # manifest files for the different splits
+    for key, value in cfg.datasource.items():
+        if isinstance(value, str):
+            cfg.datasource[key] = git_save.to_absolute_path(value)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/autocast.html b/api_docs/audiocraft/utils/autocast.html new file mode 100644 index 00000000..bbf4554e --- /dev/null +++ b/api_docs/audiocraft/utils/autocast.html @@ -0,0 +1,163 @@ + + + + + + +audiocraft.utils.autocast API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.autocast

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+class TorchAutocast:
+    """TorchAutocast utility class.
+    Allows you to enable and disable autocast. This is specially useful
+    when dealing with different architectures and clusters with different
+    levels of support.
+
+    Args:
+        enabled (bool): Whether to enable torch.autocast or not.
+        args: Additional args for torch.autocast.
+        kwargs: Additional kwargs for torch.autocast
+    """
+    def __init__(self, enabled: bool, *args, **kwargs):
+        self.autocast = torch.autocast(*args, **kwargs) if enabled else None
+
+    def __enter__(self):
+        if self.autocast is None:
+            return
+        try:
+            self.autocast.__enter__()
+        except RuntimeError:
+            device = self.autocast.device
+            dtype = self.autocast.fast_dtype
+            raise RuntimeError(
+                f"There was an error autocasting with dtype={dtype} device={device}\n"
+                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
+            )
+
+    def __exit__(self, *args, **kwargs):
+        if self.autocast is None:
+            return
+        self.autocast.__exit__(*args, **kwargs)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class TorchAutocast +(enabled: bool, *args, **kwargs) +
+
+

TorchAutocast utility class. +Allows you to enable and disable autocast. This is specially useful +when dealing with different architectures and clusters with different +levels of support.

+

Args

+
+
enabled : bool
+
Whether to enable torch.autocast or not.
+
args
+
Additional args for torch.autocast.
+
kwargs
+
Additional kwargs for torch.autocast
+
+
+ +Expand source code + +
class TorchAutocast:
+    """TorchAutocast utility class.
+    Allows you to enable and disable autocast. This is specially useful
+    when dealing with different architectures and clusters with different
+    levels of support.
+
+    Args:
+        enabled (bool): Whether to enable torch.autocast or not.
+        args: Additional args for torch.autocast.
+        kwargs: Additional kwargs for torch.autocast
+    """
+    def __init__(self, enabled: bool, *args, **kwargs):
+        self.autocast = torch.autocast(*args, **kwargs) if enabled else None
+
+    def __enter__(self):
+        if self.autocast is None:
+            return
+        try:
+            self.autocast.__enter__()
+        except RuntimeError:
+            device = self.autocast.device
+            dtype = self.autocast.fast_dtype
+            raise RuntimeError(
+                f"There was an error autocasting with dtype={dtype} device={device}\n"
+                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
+            )
+
+    def __exit__(self, *args, **kwargs):
+        if self.autocast is None:
+            return
+        self.autocast.__exit__(*args, **kwargs)
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/best_state.html b/api_docs/audiocraft/utils/best_state.html new file mode 100644 index 00000000..b450a6cf --- /dev/null +++ b/api_docs/audiocraft/utils/best_state.html @@ -0,0 +1,321 @@ + + + + + + +audiocraft.utils.best_state API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.best_state

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+import logging
+import typing as tp
+
+import flashy
+import torch
+
+from ..optim import ModuleDictEMA
+from .utils import copy_state
+
+
+logger = logging.getLogger(__name__)
+
+
+class BestStateDictManager(flashy.state.StateDictSource):
+    """BestStateDictManager maintains a copy of best state_dict() for registered sources.
+
+    BestStateDictManager has two main attributes:
+        states (dict): State dict of the registered StateDictSource.
+        param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources.
+
+    When registering new sources, the BestStateDictManager will ensure two conflicting sources between
+    ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about
+    what to consider for best state.
+
+    Args:
+        device (torch.device or str): Device on which we keep the copy.
+        dtype (torch.dtype): Data type for the state parameters.
+    """
+    def __init__(self, device: tp.Union[torch.device, str] = 'cpu',
+                 dtype: tp.Optional[torch.dtype] = None):
+        self.device = device
+        self.states: dict = {}
+        self.param_ids: dict = defaultdict(dict)
+        self.dtype = dtype
+
+    def _get_parameter_ids(self, state_dict):
+        return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)}
+
+    def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
+        for registered_name, registered_param_ids in self.param_ids.items():
+            if registered_name != name:
+                overlap = set.intersection(registered_param_ids.keys(), param_ids.keys())
+                assert len(overlap) == 0, f"Found {len(overlap)} / {len(param_ids.keys())} overlapping parameters"
+                f" in {name} and already registered {registered_name}: {' '.join(overlap)}"
+
+    def update(self, name: str, source: flashy.state.StateDictSource):
+        if name not in self.states:
+            raise ValueError(f"{name} missing from registered states.")
+        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)
+
+    def register(self, name: str, source: flashy.state.StateDictSource):
+        if name in self.states:
+            raise ValueError(f"{name} already present in states.")
+        # Registering parameter ids for EMA and non-EMA states allows us to check that
+        # there is no overlap that would create ambiguity about how to handle the best state
+        param_ids = self._get_parameter_ids(source.state_dict())
+        if isinstance(source, ModuleDictEMA):
+            logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params")
+            self._validate_no_parameter_ids_overlap(name, param_ids)
+            self.param_ids[name] = param_ids
+        else:
+            logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params")
+            self._validate_no_parameter_ids_overlap('base', param_ids)
+            self.param_ids['base'].update(param_ids)
+        # Register state
+        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)
+
+    def state_dict(self) -> flashy.state.StateDict:
+        return self.states
+
+    def load_state_dict(self, state: flashy.state.StateDict):
+        for name, sub_state in state.items():
+            for k, v in sub_state.items():
+                self.states[name][k].copy_(v)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class BestStateDictManager +(device: Union[torch.device, str] = 'cpu', dtype: Optional[torch.dtype] = None) +
+
+

BestStateDictManager maintains a copy of best state_dict() for registered sources.

+

BestStateDictManager has two main attributes: +states (dict): State dict of the registered StateDictSource. +param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources.

+

When registering new sources, the BestStateDictManager will ensure two conflicting sources between +ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about +what to consider for best state.

+

Args

+
+
device : torch.device or str
+
Device on which we keep the copy.
+
dtype : torch.dtype
+
Data type for the state parameters.
+
+
+ +Expand source code + +
class BestStateDictManager(flashy.state.StateDictSource):
+    """BestStateDictManager maintains a copy of best state_dict() for registered sources.
+
+    BestStateDictManager has two main attributes:
+        states (dict): State dict of the registered StateDictSource.
+        param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources.
+
+    When registering new sources, the BestStateDictManager will ensure two conflicting sources between
+    ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about
+    what to consider for best state.
+
+    Args:
+        device (torch.device or str): Device on which we keep the copy.
+        dtype (torch.dtype): Data type for the state parameters.
+    """
+    def __init__(self, device: tp.Union[torch.device, str] = 'cpu',
+                 dtype: tp.Optional[torch.dtype] = None):
+        self.device = device
+        self.states: dict = {}
+        self.param_ids: dict = defaultdict(dict)
+        self.dtype = dtype
+
+    def _get_parameter_ids(self, state_dict):
+        return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)}
+
+    def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
+        for registered_name, registered_param_ids in self.param_ids.items():
+            if registered_name != name:
+                overlap = set.intersection(registered_param_ids.keys(), param_ids.keys())
+                assert len(overlap) == 0, f"Found {len(overlap)} / {len(param_ids.keys())} overlapping parameters"
+                f" in {name} and already registered {registered_name}: {' '.join(overlap)}"
+
+    def update(self, name: str, source: flashy.state.StateDictSource):
+        if name not in self.states:
+            raise ValueError(f"{name} missing from registered states.")
+        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)
+
+    def register(self, name: str, source: flashy.state.StateDictSource):
+        if name in self.states:
+            raise ValueError(f"{name} already present in states.")
+        # Registering parameter ids for EMA and non-EMA states allows us to check that
+        # there is no overlap that would create ambiguity about how to handle the best state
+        param_ids = self._get_parameter_ids(source.state_dict())
+        if isinstance(source, ModuleDictEMA):
+            logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params")
+            self._validate_no_parameter_ids_overlap(name, param_ids)
+            self.param_ids[name] = param_ids
+        else:
+            logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params")
+            self._validate_no_parameter_ids_overlap('base', param_ids)
+            self.param_ids['base'].update(param_ids)
+        # Register state
+        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)
+
+    def state_dict(self) -> flashy.state.StateDict:
+        return self.states
+
+    def load_state_dict(self, state: flashy.state.StateDict):
+        for name, sub_state in state.items():
+            for k, v in sub_state.items():
+                self.states[name][k].copy_(v)
+
+

Ancestors

+
    +
  • flashy.state.StateDictSource
  • +
  • typing.Protocol
  • +
  • typing.Generic
  • +
+

Methods

+
+
+def load_state_dict(self, state: Any) +
+
+
+
+ +Expand source code + +
def load_state_dict(self, state: flashy.state.StateDict):
+    for name, sub_state in state.items():
+        for k, v in sub_state.items():
+            self.states[name][k].copy_(v)
+
+
+
+def register(self, name: str, source: flashy.state.StateDictSource) +
+
+

Register a virtual subclass of an ABC.

+

Returns the subclass, to allow usage as a class decorator.

+
+ +Expand source code + +
def register(self, name: str, source: flashy.state.StateDictSource):
+    if name in self.states:
+        raise ValueError(f"{name} already present in states.")
+    # Registering parameter ids for EMA and non-EMA states allows us to check that
+    # there is no overlap that would create ambiguity about how to handle the best state
+    param_ids = self._get_parameter_ids(source.state_dict())
+    if isinstance(source, ModuleDictEMA):
+        logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params")
+        self._validate_no_parameter_ids_overlap(name, param_ids)
+        self.param_ids[name] = param_ids
+    else:
+        logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params")
+        self._validate_no_parameter_ids_overlap('base', param_ids)
+        self.param_ids['base'].update(param_ids)
+    # Register state
+    self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)
+
+
+
+def state_dict(self) ‑> Any +
+
+
+
+ +Expand source code + +
def state_dict(self) -> flashy.state.StateDict:
+    return self.states
+
+
+
+def update(self, name: str, source: flashy.state.StateDictSource) +
+
+
+
+ +Expand source code + +
def update(self, name: str, source: flashy.state.StateDictSource):
+    if name not in self.states:
+        raise ValueError(f"{name} missing from registered states.")
+    self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/cache.html b/api_docs/audiocraft/utils/cache.html new file mode 100644 index 00000000..eaf43de1 --- /dev/null +++ b/api_docs/audiocraft/utils/cache.html @@ -0,0 +1,1003 @@ + + + + + + +audiocraft.utils.cache API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.cache

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from concurrent.futures import ThreadPoolExecutor
+from collections import deque
+from functools import partial
+from hashlib import sha1
+import logging
+from pathlib import Path
+import sys
+import typing as tp
+import zipfile
+
+import flashy
+import torch
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor:
+    """Utility function for the EmbeddingCache, returning the full embedding without any chunking.
+    This method can be used in case there is no need in extracting a chunk of the full embedding
+    read from the cache.
+
+    Args:
+        full_embed (torch.Tensor): The full embedding.
+        x (any): Batch object from which the full embedding is derived.
+        idx (torch.Tensor): Index of object to consider in the batch object.
+    Returns:
+        full_embed (torch.Tensor): The full embedding
+    """
+    return full_embed.to(device)
+
+
+class EmbeddingCache:
+    """Cache around embeddings computation for faster execution.
+    The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API
+    to retrieve the pre-computed embeddings on full inputs and extract only a given chunk
+    using a user-provided function. When the cache is warm (all embeddings are pre-computed),
+    the EmbeddingCache allows for faster training as it removes the need of computing the embeddings.
+    Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint
+    and synchronization points in the forward calls.
+
+    Args:
+        cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk.
+        device (str or torch.device): Device on which the embedding is returned.
+        compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute
+            the embedding from a given object and path. This user provided function can compute the
+            embedding from the provided object or using the provided path as entry point. The last parameter
+            specify the index corresponding to the current embedding in the object that can represent batch metadata.
+        extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract
+            the desired embedding chunk from the full embedding loaded from the cache. The last parameter
+            specify the index corresponding to the current embedding in the object that can represent batch metadata.
+            If not specified, will return the full embedding unmodified.
+    """
+    def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device],
+                 compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
+                 extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
+        self.cache_path = Path(cache_path)
+        self.device = device
+        self._compute_embed_fn = compute_embed_fn
+        self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]
+        if extract_embed_fn is not None:
+            self._extract_embed_fn = extract_embed_fn
+        else:
+            self._extract_embed_fn = partial(get_full_embed, device=device)
+        if self.cache_path is not None:
+            self.cache_path.mkdir(exist_ok=True, parents=True)
+            logger.info(f"Cache instantiated at: {self.cache_path}")
+            self.pool = ThreadPoolExecutor(8)
+            self.pool.__enter__()
+        self._current_batch_cache: dict = {}
+        self._memory_cache: dict = {}
+
+    def _get_cache_path(self, path: tp.Union[Path, str]):
+        """Get cache path for the given file path."""
+        sig = sha1(str(path).encode()).hexdigest()
+        return self.cache_path / sig
+
+    @staticmethod
+    def _get_full_embed_from_cache(cache: Path):
+        """Loads full pre-computed embedding from the cache."""
+        try:
+            embed = torch.load(cache, 'cpu')
+        except Exception as exc:
+            logger.error("Error loading %s: %r", cache, exc)
+            embed = None
+        return embed
+
+    def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor:
+        """Get embedding from cache, computing and storing it to cache if not already cached.
+        The EmbeddingCache first tries to load the embedding from the in-memory cache
+        containing the pre-computed chunks populated through `populate_embed_cache`.
+        If not found, the full embedding is computed and stored on disk to be later accessed
+        to populate the in-memory cache, and the desired embedding chunk is extracted and returned.
+
+        Args:
+            paths (list[Path or str]): List of paths from where the embeddings can be loaded.
+            x (any): Object from which the embedding is extracted.
+        """
+        embeds = []
+        for idx, path in enumerate(paths):
+            cache = self._get_cache_path(path)
+            if cache in self._current_batch_cache:
+                embed = self._current_batch_cache[cache]
+            else:
+                full_embed = self._compute_embed_fn(path, x, idx)
+                try:
+                    with flashy.utils.write_and_rename(cache, pid=True) as f:
+                        torch.save(full_embed.cpu(), f)
+                except Exception as exc:
+                    logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc)
+                else:
+                    logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape)
+                    embed = self._extract_embed_fn(full_embed, x, idx)
+            embeds.append(embed)
+        embed = torch.stack(embeds, dim=0)
+        return embed
+
+    def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
+        """Populate in-memory caches for embeddings reading from the embeddings stored on disk.
+        The in-memory caches consist in a cache for the full embedding and another cache for the
+        final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings
+        and reduce the IO footprint and synchronization points during forward passes.
+
+        Args:
+            paths (list[Path]): List of paths from where the embeddings can be loaded.
+            x (any): Object from which the embedding is extracted.
+        """
+        self._current_batch_cache.clear()
+        if self.cache_path is not None:
+            futures: list = []
+            for path in paths:
+                assert path is not None, "Path is required for computation from cache"
+                cache = self._get_cache_path(path)
+                if cache in self._memory_cache or not cache.exists():
+                    futures.append(None)
+                else:
+                    futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache))
+            for idx, (path, future) in enumerate(zip(paths, futures)):
+                assert path is not None
+                cache = self._get_cache_path(path)
+                full_embed = None
+                if future is None:
+                    if cache in self._memory_cache:
+                        full_embed = self._memory_cache[cache]
+                else:
+                    full_embed = future.result()
+                    if full_embed is not None:
+                        self._memory_cache[cache] = full_embed
+                        full_embed = full_embed.to(self.device)
+                if full_embed is not None:
+                    embed = self._extract_embed_fn(full_embed, x, idx)
+                    self._current_batch_cache[cache] = embed
+
+
+class CachedBatchWriter:
+    """Write pre computed caches for mini batches. This can
+    make loading a lot more efficient depending on your filesystem.
+
+    Args:
+        cache_folder (Path): folder in which the cached minibatches
+            will be stored.
+
+    Inside cache folder, the structure is the following:
+    `epoch_number / update_number.zip`
+    And the zip file contains one entry per batch item.
+
+    It is possible to use the cache with a batch size smaller than
+    created with but obviously not larger. Make sure to call the
+    `start_epoch(epoch)` method for indicating changes of epochs.
+
+    See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py`
+    for an example of how to warmup the cache.
+    """
+    def __init__(self, cache_folder: Path):
+        self.cache_folder = cache_folder
+        self._current_epoch: tp.Optional[int] = None
+        self._current_index = 0
+
+    def start_epoch(self, epoch: int):
+        """Call at the beginning of each epoch.
+        """
+        self._current_epoch = epoch
+        self._current_index = 0
+        self._zip_path.parent.mkdir(exist_ok=True, parents=True)
+
+    @staticmethod
+    def _get_zip_path(cache_folder: Path, epoch: int, index: int):
+        return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip"
+
+    @property
+    def _zip_path(self):
+        assert self._current_epoch is not None
+        return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index)
+
+    def save(self, *content):
+        """Save one mini batch. This function is distributed-aware
+        and will automatically merge all the items from the different
+        workers.
+        """
+        all_contents = []
+        for rank in range(flashy.distrib.world_size()):
+            their_content = flashy.distrib.broadcast_object(content, src=rank)
+            all_contents.append(their_content)
+
+        if flashy.distrib.is_rank_zero():
+            idx = 0
+            with flashy.utils.write_and_rename(self._zip_path) as tmp:
+                with zipfile.ZipFile(tmp, 'w') as zf:
+                    for content in all_contents:
+                        for vals in zip(*content):
+                            with zf.open(f'{idx}', 'w') as f:  # type: ignore
+                                torch.save(vals, f)
+                            idx += 1
+        flashy.distrib.barrier()
+        self._current_index += 1
+
+
+class CachedBatchLoader:
+    """Loader for cached mini-batches dumped with `CachedBatchWriter`.
+
+    Args:
+        cache_folder (Path): folder in which the cached minibatches are stored.
+        batch_size (int): batch size (per GPU) expected.
+        num_workers (int): number of workers to use for loading.
+        min_length (int): minimum expected length for each epoch. If some
+            mini-batches are missing, and error is raised.
+
+    This is iterable just like a regular DataLoader.
+    """
+
+    def __init__(self, cache_folder: Path, batch_size: int,
+                 num_workers: int = 10, min_length: int = 1):
+        self.cache_folder = cache_folder
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.min_length = min_length
+        self._current_epoch: tp.Optional[int] = None
+        self.sampler = None  # for compatibility with the regular DataLoader
+
+    def __len__(self):
+        path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent
+        return len([p for p in path.iterdir() if p.suffix == ".zip"])
+
+    def start_epoch(self, epoch: int):
+        """Call at the beginning of each epoch.
+        """
+        self._current_epoch = epoch
+
+    def _zip_path(self, index: int):
+        assert self._current_epoch is not None
+        return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index)
+
+    def _load_one(self, index: int):
+        zip_path = self._zip_path(index)
+        if not zip_path.exists():
+            if index < self.min_length:
+                raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist")
+
+            return None
+        mode = "rb" if sys.version_info >= (3, 9) else "r"
+        try:
+            with zipfile.ZipFile(zip_path, 'r') as zf:
+                rank = flashy.distrib.rank()
+                world_size = flashy.distrib.world_size()
+                root = zipfile.Path(zf)
+                items = list(root.iterdir())
+                total_batch_size = self.batch_size * world_size
+                if len(items) < total_batch_size:
+                    raise RuntimeError(
+                        f"The cache can handle a max batch size of {len(items)}, "
+                        f"but {total_batch_size} is needed.")
+                start = rank * self.batch_size
+                items = items[start: start + self.batch_size]
+                assert len(items) == self.batch_size
+                entries = []
+                entries = [torch.load(item.open(mode), 'cpu') for item in items]  # type: ignore
+                transposed = zip(*entries)
+                out = []
+                for part in transposed:
+                    assert len(part) > 0
+                    if isinstance(part[0], torch.Tensor):
+                        out.append(torch.stack(part))
+                    else:
+                        out.append(part)
+                return out
+        except Exception:
+            logger.error("Error when reading zip path %s", zip_path)
+            raise
+
+    def __iter__(self):
+        """This will yields tuples, exactly as provided to the
+        `CachedBatchWriter.save` method.
+        """
+        pool = ThreadPoolExecutor(self.num_workers)
+        next_index = 0
+        queue = deque()
+
+        def _get_next():
+            nonlocal next_index
+            r = queue.popleft().result()
+            if r is None:
+                return None
+            else:
+                queue.append(pool.submit(self._load_one, next_index))
+                next_index += 1
+            return r
+
+        with pool:
+            # fill the buffer of fetching jobs.
+            for _ in range(2 * self.num_workers):
+                queue.append(pool.submit(self._load_one, next_index))
+                next_index += 1
+            while True:
+                batch = _get_next()
+                if batch is None:
+                    return
+                yield batch
+
+
+
+
+
+
+
+

Functions

+
+
+def get_full_embed(full_embed: torch.Tensor, x: Any, idx: int, device: Union[torch.device, str]) ‑> torch.Tensor +
+
+

Utility function for the EmbeddingCache, returning the full embedding without any chunking. +This method can be used in case there is no need in extracting a chunk of the full embedding +read from the cache.

+

Args

+
+
full_embed : torch.Tensor
+
The full embedding.
+
x : any
+
Batch object from which the full embedding is derived.
+
idx : torch.Tensor
+
Index of object to consider in the batch object.
+
+

Returns

+

full_embed (torch.Tensor): The full embedding

+
+ +Expand source code + +
def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor:
+    """Utility function for the EmbeddingCache, returning the full embedding without any chunking.
+    This method can be used in case there is no need in extracting a chunk of the full embedding
+    read from the cache.
+
+    Args:
+        full_embed (torch.Tensor): The full embedding.
+        x (any): Batch object from which the full embedding is derived.
+        idx (torch.Tensor): Index of object to consider in the batch object.
+    Returns:
+        full_embed (torch.Tensor): The full embedding
+    """
+    return full_embed.to(device)
+
+
+
+
+
+

Classes

+
+
+class CachedBatchLoader +(cache_folder: pathlib.Path, batch_size: int, num_workers: int = 10, min_length: int = 1) +
+
+

Loader for cached mini-batches dumped with CachedBatchWriter.

+

Args

+
+
cache_folder : Path
+
folder in which the cached minibatches are stored.
+
batch_size : int
+
batch size (per GPU) expected.
+
num_workers : int
+
number of workers to use for loading.
+
min_length : int
+
minimum expected length for each epoch. If some +mini-batches are missing, and error is raised.
+
+

This is iterable just like a regular DataLoader.

+
+ +Expand source code + +
class CachedBatchLoader:
+    """Loader for cached mini-batches dumped with `CachedBatchWriter`.
+
+    Args:
+        cache_folder (Path): folder in which the cached minibatches are stored.
+        batch_size (int): batch size (per GPU) expected.
+        num_workers (int): number of workers to use for loading.
+        min_length (int): minimum expected length for each epoch. If some
+            mini-batches are missing, and error is raised.
+
+    This is iterable just like a regular DataLoader.
+    """
+
+    def __init__(self, cache_folder: Path, batch_size: int,
+                 num_workers: int = 10, min_length: int = 1):
+        self.cache_folder = cache_folder
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.min_length = min_length
+        self._current_epoch: tp.Optional[int] = None
+        self.sampler = None  # for compatibility with the regular DataLoader
+
+    def __len__(self):
+        path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent
+        return len([p for p in path.iterdir() if p.suffix == ".zip"])
+
+    def start_epoch(self, epoch: int):
+        """Call at the beginning of each epoch.
+        """
+        self._current_epoch = epoch
+
+    def _zip_path(self, index: int):
+        assert self._current_epoch is not None
+        return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index)
+
+    def _load_one(self, index: int):
+        zip_path = self._zip_path(index)
+        if not zip_path.exists():
+            if index < self.min_length:
+                raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist")
+
+            return None
+        mode = "rb" if sys.version_info >= (3, 9) else "r"
+        try:
+            with zipfile.ZipFile(zip_path, 'r') as zf:
+                rank = flashy.distrib.rank()
+                world_size = flashy.distrib.world_size()
+                root = zipfile.Path(zf)
+                items = list(root.iterdir())
+                total_batch_size = self.batch_size * world_size
+                if len(items) < total_batch_size:
+                    raise RuntimeError(
+                        f"The cache can handle a max batch size of {len(items)}, "
+                        f"but {total_batch_size} is needed.")
+                start = rank * self.batch_size
+                items = items[start: start + self.batch_size]
+                assert len(items) == self.batch_size
+                entries = []
+                entries = [torch.load(item.open(mode), 'cpu') for item in items]  # type: ignore
+                transposed = zip(*entries)
+                out = []
+                for part in transposed:
+                    assert len(part) > 0
+                    if isinstance(part[0], torch.Tensor):
+                        out.append(torch.stack(part))
+                    else:
+                        out.append(part)
+                return out
+        except Exception:
+            logger.error("Error when reading zip path %s", zip_path)
+            raise
+
+    def __iter__(self):
+        """This will yields tuples, exactly as provided to the
+        `CachedBatchWriter.save` method.
+        """
+        pool = ThreadPoolExecutor(self.num_workers)
+        next_index = 0
+        queue = deque()
+
+        def _get_next():
+            nonlocal next_index
+            r = queue.popleft().result()
+            if r is None:
+                return None
+            else:
+                queue.append(pool.submit(self._load_one, next_index))
+                next_index += 1
+            return r
+
+        with pool:
+            # fill the buffer of fetching jobs.
+            for _ in range(2 * self.num_workers):
+                queue.append(pool.submit(self._load_one, next_index))
+                next_index += 1
+            while True:
+                batch = _get_next()
+                if batch is None:
+                    return
+                yield batch
+
+

Methods

+
+
+def start_epoch(self, epoch: int) +
+
+

Call at the beginning of each epoch.

+
+ +Expand source code + +
def start_epoch(self, epoch: int):
+    """Call at the beginning of each epoch.
+    """
+    self._current_epoch = epoch
+
+
+
+
+
+class CachedBatchWriter +(cache_folder: pathlib.Path) +
+
+

Write pre computed caches for mini batches. This can +make loading a lot more efficient depending on your filesystem.

+

Args

+
+
cache_folder : Path
+
folder in which the cached minibatches +will be stored.
+
+

Inside cache folder, the structure is the following: +epoch_number / update_number.zip +And the zip file contains one entry per batch item.

+

It is possible to use the cache with a batch size smaller than +created with but obviously not larger. Make sure to call the +start_epoch(epoch) method for indicating changes of epochs.

+

See the grid audiocraft/grids/musicgen/musicgen_warmup_cache.py +for an example of how to warmup the cache.

+
+ +Expand source code + +
class CachedBatchWriter:
+    """Write pre computed caches for mini batches. This can
+    make loading a lot more efficient depending on your filesystem.
+
+    Args:
+        cache_folder (Path): folder in which the cached minibatches
+            will be stored.
+
+    Inside cache folder, the structure is the following:
+    `epoch_number / update_number.zip`
+    And the zip file contains one entry per batch item.
+
+    It is possible to use the cache with a batch size smaller than
+    created with but obviously not larger. Make sure to call the
+    `start_epoch(epoch)` method for indicating changes of epochs.
+
+    See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py`
+    for an example of how to warmup the cache.
+    """
+    def __init__(self, cache_folder: Path):
+        self.cache_folder = cache_folder
+        self._current_epoch: tp.Optional[int] = None
+        self._current_index = 0
+
+    def start_epoch(self, epoch: int):
+        """Call at the beginning of each epoch.
+        """
+        self._current_epoch = epoch
+        self._current_index = 0
+        self._zip_path.parent.mkdir(exist_ok=True, parents=True)
+
+    @staticmethod
+    def _get_zip_path(cache_folder: Path, epoch: int, index: int):
+        return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip"
+
+    @property
+    def _zip_path(self):
+        assert self._current_epoch is not None
+        return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index)
+
+    def save(self, *content):
+        """Save one mini batch. This function is distributed-aware
+        and will automatically merge all the items from the different
+        workers.
+        """
+        all_contents = []
+        for rank in range(flashy.distrib.world_size()):
+            their_content = flashy.distrib.broadcast_object(content, src=rank)
+            all_contents.append(their_content)
+
+        if flashy.distrib.is_rank_zero():
+            idx = 0
+            with flashy.utils.write_and_rename(self._zip_path) as tmp:
+                with zipfile.ZipFile(tmp, 'w') as zf:
+                    for content in all_contents:
+                        for vals in zip(*content):
+                            with zf.open(f'{idx}', 'w') as f:  # type: ignore
+                                torch.save(vals, f)
+                            idx += 1
+        flashy.distrib.barrier()
+        self._current_index += 1
+
+

Methods

+
+
+def save(self, *content) +
+
+

Save one mini batch. This function is distributed-aware +and will automatically merge all the items from the different +workers.

+
+ +Expand source code + +
def save(self, *content):
+    """Save one mini batch. This function is distributed-aware
+    and will automatically merge all the items from the different
+    workers.
+    """
+    all_contents = []
+    for rank in range(flashy.distrib.world_size()):
+        their_content = flashy.distrib.broadcast_object(content, src=rank)
+        all_contents.append(their_content)
+
+    if flashy.distrib.is_rank_zero():
+        idx = 0
+        with flashy.utils.write_and_rename(self._zip_path) as tmp:
+            with zipfile.ZipFile(tmp, 'w') as zf:
+                for content in all_contents:
+                    for vals in zip(*content):
+                        with zf.open(f'{idx}', 'w') as f:  # type: ignore
+                            torch.save(vals, f)
+                        idx += 1
+    flashy.distrib.barrier()
+    self._current_index += 1
+
+
+
+def start_epoch(self, epoch: int) +
+
+

Call at the beginning of each epoch.

+
+ +Expand source code + +
def start_epoch(self, epoch: int):
+    """Call at the beginning of each epoch.
+    """
+    self._current_epoch = epoch
+    self._current_index = 0
+    self._zip_path.parent.mkdir(exist_ok=True, parents=True)
+
+
+
+
+
+class EmbeddingCache +(cache_path: Union[str, pathlib.Path], device: Union[torch.device, str], compute_embed_fn: Callable[[pathlib.Path, Any, int], torch.Tensor], extract_embed_fn: Optional[Callable[[torch.Tensor, Any, int], torch.Tensor]] = None) +
+
+

Cache around embeddings computation for faster execution. +The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API +to retrieve the pre-computed embeddings on full inputs and extract only a given chunk +using a user-provided function. When the cache is warm (all embeddings are pre-computed), +the EmbeddingCache allows for faster training as it removes the need of computing the embeddings. +Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint +and synchronization points in the forward calls.

+

Args

+
+
cache_path : Path
+
Path to folder where all pre-computed embeddings are saved on disk.
+
device : str or torch.device
+
Device on which the embedding is returned.
+
compute_embed_fn : callable[[Path, any, int], torch.Tensor], optional
+
Function to compute +the embedding from a given object and path. This user provided function can compute the +embedding from the provided object or using the provided path as entry point. The last parameter +specify the index corresponding to the current embedding in the object that can represent batch metadata.
+
extract_embed_fn : callable[[torch.Tensor, any, int], torch.Tensor], optional
+
Function to extract +the desired embedding chunk from the full embedding loaded from the cache. The last parameter +specify the index corresponding to the current embedding in the object that can represent batch metadata. +If not specified, will return the full embedding unmodified.
+
+
+ +Expand source code + +
class EmbeddingCache:
+    """Cache around embeddings computation for faster execution.
+    The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API
+    to retrieve the pre-computed embeddings on full inputs and extract only a given chunk
+    using a user-provided function. When the cache is warm (all embeddings are pre-computed),
+    the EmbeddingCache allows for faster training as it removes the need of computing the embeddings.
+    Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint
+    and synchronization points in the forward calls.
+
+    Args:
+        cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk.
+        device (str or torch.device): Device on which the embedding is returned.
+        compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute
+            the embedding from a given object and path. This user provided function can compute the
+            embedding from the provided object or using the provided path as entry point. The last parameter
+            specify the index corresponding to the current embedding in the object that can represent batch metadata.
+        extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract
+            the desired embedding chunk from the full embedding loaded from the cache. The last parameter
+            specify the index corresponding to the current embedding in the object that can represent batch metadata.
+            If not specified, will return the full embedding unmodified.
+    """
+    def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device],
+                 compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
+                 extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
+        self.cache_path = Path(cache_path)
+        self.device = device
+        self._compute_embed_fn = compute_embed_fn
+        self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]
+        if extract_embed_fn is not None:
+            self._extract_embed_fn = extract_embed_fn
+        else:
+            self._extract_embed_fn = partial(get_full_embed, device=device)
+        if self.cache_path is not None:
+            self.cache_path.mkdir(exist_ok=True, parents=True)
+            logger.info(f"Cache instantiated at: {self.cache_path}")
+            self.pool = ThreadPoolExecutor(8)
+            self.pool.__enter__()
+        self._current_batch_cache: dict = {}
+        self._memory_cache: dict = {}
+
+    def _get_cache_path(self, path: tp.Union[Path, str]):
+        """Get cache path for the given file path."""
+        sig = sha1(str(path).encode()).hexdigest()
+        return self.cache_path / sig
+
+    @staticmethod
+    def _get_full_embed_from_cache(cache: Path):
+        """Loads full pre-computed embedding from the cache."""
+        try:
+            embed = torch.load(cache, 'cpu')
+        except Exception as exc:
+            logger.error("Error loading %s: %r", cache, exc)
+            embed = None
+        return embed
+
+    def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor:
+        """Get embedding from cache, computing and storing it to cache if not already cached.
+        The EmbeddingCache first tries to load the embedding from the in-memory cache
+        containing the pre-computed chunks populated through `populate_embed_cache`.
+        If not found, the full embedding is computed and stored on disk to be later accessed
+        to populate the in-memory cache, and the desired embedding chunk is extracted and returned.
+
+        Args:
+            paths (list[Path or str]): List of paths from where the embeddings can be loaded.
+            x (any): Object from which the embedding is extracted.
+        """
+        embeds = []
+        for idx, path in enumerate(paths):
+            cache = self._get_cache_path(path)
+            if cache in self._current_batch_cache:
+                embed = self._current_batch_cache[cache]
+            else:
+                full_embed = self._compute_embed_fn(path, x, idx)
+                try:
+                    with flashy.utils.write_and_rename(cache, pid=True) as f:
+                        torch.save(full_embed.cpu(), f)
+                except Exception as exc:
+                    logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc)
+                else:
+                    logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape)
+                    embed = self._extract_embed_fn(full_embed, x, idx)
+            embeds.append(embed)
+        embed = torch.stack(embeds, dim=0)
+        return embed
+
+    def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
+        """Populate in-memory caches for embeddings reading from the embeddings stored on disk.
+        The in-memory caches consist in a cache for the full embedding and another cache for the
+        final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings
+        and reduce the IO footprint and synchronization points during forward passes.
+
+        Args:
+            paths (list[Path]): List of paths from where the embeddings can be loaded.
+            x (any): Object from which the embedding is extracted.
+        """
+        self._current_batch_cache.clear()
+        if self.cache_path is not None:
+            futures: list = []
+            for path in paths:
+                assert path is not None, "Path is required for computation from cache"
+                cache = self._get_cache_path(path)
+                if cache in self._memory_cache or not cache.exists():
+                    futures.append(None)
+                else:
+                    futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache))
+            for idx, (path, future) in enumerate(zip(paths, futures)):
+                assert path is not None
+                cache = self._get_cache_path(path)
+                full_embed = None
+                if future is None:
+                    if cache in self._memory_cache:
+                        full_embed = self._memory_cache[cache]
+                else:
+                    full_embed = future.result()
+                    if full_embed is not None:
+                        self._memory_cache[cache] = full_embed
+                        full_embed = full_embed.to(self.device)
+                if full_embed is not None:
+                    embed = self._extract_embed_fn(full_embed, x, idx)
+                    self._current_batch_cache[cache] = embed
+
+

Methods

+
+
+def get_embed_from_cache(self, paths: List[pathlib.Path], x: Any) ‑> torch.Tensor +
+
+

Get embedding from cache, computing and storing it to cache if not already cached. +The EmbeddingCache first tries to load the embedding from the in-memory cache +containing the pre-computed chunks populated through populate_embed_cache. +If not found, the full embedding is computed and stored on disk to be later accessed +to populate the in-memory cache, and the desired embedding chunk is extracted and returned.

+

Args

+
+
paths : list[Path or str]
+
List of paths from where the embeddings can be loaded.
+
x : any
+
Object from which the embedding is extracted.
+
+
+ +Expand source code + +
def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor:
+    """Get embedding from cache, computing and storing it to cache if not already cached.
+    The EmbeddingCache first tries to load the embedding from the in-memory cache
+    containing the pre-computed chunks populated through `populate_embed_cache`.
+    If not found, the full embedding is computed and stored on disk to be later accessed
+    to populate the in-memory cache, and the desired embedding chunk is extracted and returned.
+
+    Args:
+        paths (list[Path or str]): List of paths from where the embeddings can be loaded.
+        x (any): Object from which the embedding is extracted.
+    """
+    embeds = []
+    for idx, path in enumerate(paths):
+        cache = self._get_cache_path(path)
+        if cache in self._current_batch_cache:
+            embed = self._current_batch_cache[cache]
+        else:
+            full_embed = self._compute_embed_fn(path, x, idx)
+            try:
+                with flashy.utils.write_and_rename(cache, pid=True) as f:
+                    torch.save(full_embed.cpu(), f)
+            except Exception as exc:
+                logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc)
+            else:
+                logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape)
+                embed = self._extract_embed_fn(full_embed, x, idx)
+        embeds.append(embed)
+    embed = torch.stack(embeds, dim=0)
+    return embed
+
+
+
+def populate_embed_cache(self, paths: List[pathlib.Path], x: Any) ‑> None +
+
+

Populate in-memory caches for embeddings reading from the embeddings stored on disk. +The in-memory caches consist in a cache for the full embedding and another cache for the +final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings +and reduce the IO footprint and synchronization points during forward passes.

+

Args

+
+
paths : list[Path]
+
List of paths from where the embeddings can be loaded.
+
x : any
+
Object from which the embedding is extracted.
+
+
+ +Expand source code + +
def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
+    """Populate in-memory caches for embeddings reading from the embeddings stored on disk.
+    The in-memory caches consist in a cache for the full embedding and another cache for the
+    final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings
+    and reduce the IO footprint and synchronization points during forward passes.
+
+    Args:
+        paths (list[Path]): List of paths from where the embeddings can be loaded.
+        x (any): Object from which the embedding is extracted.
+    """
+    self._current_batch_cache.clear()
+    if self.cache_path is not None:
+        futures: list = []
+        for path in paths:
+            assert path is not None, "Path is required for computation from cache"
+            cache = self._get_cache_path(path)
+            if cache in self._memory_cache or not cache.exists():
+                futures.append(None)
+            else:
+                futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache))
+        for idx, (path, future) in enumerate(zip(paths, futures)):
+            assert path is not None
+            cache = self._get_cache_path(path)
+            full_embed = None
+            if future is None:
+                if cache in self._memory_cache:
+                    full_embed = self._memory_cache[cache]
+            else:
+                full_embed = future.result()
+                if full_embed is not None:
+                    self._memory_cache[cache] = full_embed
+                    full_embed = full_embed.to(self.device)
+            if full_embed is not None:
+                embed = self._extract_embed_fn(full_embed, x, idx)
+                self._current_batch_cache[cache] = embed
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/checkpoint.html b/api_docs/audiocraft/utils/checkpoint.html new file mode 100644 index 00000000..d6a9ab66 --- /dev/null +++ b/api_docs/audiocraft/utils/checkpoint.html @@ -0,0 +1,492 @@ + + + + + + +audiocraft.utils.checkpoint API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.checkpoint

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+import logging
+from pathlib import Path
+import re
+import typing as tp
+
+import flashy
+import torch
+
+from ..environment import AudioCraftEnvironment
+
+
+logger = logging.getLogger(__name__)
+
+
+class CheckpointSource(Enum):
+    CURRENT_XP = "current_xp"
+    PRETRAINED = "pretrained"
+    OTHER = "other"
+
+
+def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
+    """Checkpoint name formatted for all use in AudioCraft codebase and has the following format:
+    `checkpoint_<name>.th(.<rank>)`. By convention, name is expected to be empty for last checkpoint,
+    'best' for the best checkpoint or the epoch number.
+
+    Args:
+        name (str, optional): Name suffix for the checkpoint file stem.
+        rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
+        use_fsdp (bool): Whether the calling solver relies on FSDP.
+    Returns:
+        str: The checkpoint name.
+    """
+    suffix = ''
+    if rank is None:
+        rank = flashy.distrib.rank()
+    if rank > 0 and use_fsdp:
+        suffix = '.' + str(rank)
+    name_part = ''
+    if name is not None:
+        name_part = f'_{name}'
+    return f'checkpoint{name_part}.th{suffix}'
+
+
+def is_sharded_checkpoint(path: Path) -> bool:
+    """Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank."""
+    return re.search(r'\.th\.\d+$', path.name) is not None
+
+
+def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
+                            use_fsdp: bool = False) -> tp.Optional[Path]:
+    """Resolve a given checkpoint path for a provided dora sig or path.
+
+    Args:
+        sig_or_path (Path or str): Checkpoint path or dora signature.
+        name (str, optional): Name suffix for the checkpoint file stem.
+        rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
+        use_fsdp (bool): Whether the calling solver relies on FSDP.
+    Returns:
+        Path, optional: Resolved checkpoint path, if it exists.
+    """
+    from audiocraft import train
+    xps_root = train.main.dora.dir / 'xps'
+    sig_or_path = str(sig_or_path)
+    if sig_or_path.startswith('//sig/'):
+        sig = sig_or_path[len('//sig/'):]
+        path = xps_root / sig
+    else:
+        path = Path(sig_or_path)
+        path = AudioCraftEnvironment.resolve_reference_path(path)
+
+    if path.is_dir():
+        path = path / checkpoint_name(name, use_fsdp=use_fsdp)
+
+    if path.exists():
+        return path
+    else:
+        return None
+
+
+def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
+    """Load state from checkpoints at the specified checkpoint path."""
+    if is_sharded:
+        rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
+        if rank0_checkpoint_path.exists():
+            check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path)
+    state = torch.load(checkpoint_path, 'cpu')
+    logger.info("Checkpoint loaded from %s", checkpoint_path)
+    return state
+
+
+def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
+    """Save state to disk to the specified checkpoint_path."""
+    _safe_save_checkpoint(state, checkpoint_path, is_sharded)
+    logger.info("Checkpoint saved to %s", checkpoint_path)
+
+
+def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
+    """Flush checkpoints to only keep last N checkpoints."""
+    if keep_last is None or keep_last <= 0:
+        return
+    checkpoint_dir = checkpoint_path.parent
+    suffix = ''
+    if flashy.distrib.rank() > 0:
+        suffix = f'.{flashy.distrib.rank()}'
+    checkpoint_files_with_epoch = []
+    for path in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'):
+        epoch_part = path.name.split('.', 1)[0].split('_', 1)[1]
+        if epoch_part.isdigit():
+            checkpoint_files_with_epoch.append((path, int(epoch_part)))
+    checkpoint_files = [path for path, _ in list(sorted(checkpoint_files_with_epoch, key=lambda t: t[1]))]
+    total_to_flush = max(0, len(checkpoint_files) - keep_last)
+    files_to_flush = checkpoint_files[:total_to_flush]
+    for path in files_to_flush:
+        logger.debug("Removing checkpoint: %s", str(path))
+        path.unlink(missing_ok=True)
+
+
+def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
+    """Check sharded checkpoint state, ensuring the checkpoints are not corrupted."""
+    # Finish the work of a previous run that got interrupted while dumping.
+    old_path = Path(str(checkpoint_path) + '.old')
+    if old_path.exists():
+        raise RuntimeError(
+            f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.")
+    token = Path(str(rank0_checkpoint_path) + '.tmp.done')
+    tmp_path = Path(str(checkpoint_path) + '.tmp')
+    if token.exists():
+        if tmp_path.exists():
+            tmp_path.rename(checkpoint_path)
+    flashy.distrib.barrier()
+    if flashy.distrib.is_rank_zero() and token.exists():
+        token.unlink()
+
+
+def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
+    """Save checkpoints in a safe manner even with when sharded checkpoints across nodes."""
+    def _barrier_if_sharded():
+        if is_sharded:
+            flashy.distrib.barrier()
+
+    if flashy.distrib.is_rank_zero():
+        token = Path(str(checkpoint_path) + '.tmp.done')
+        if token.exists():
+            token.unlink()
+    _barrier_if_sharded()
+    with flashy.utils.write_and_rename(checkpoint_path) as f:
+        torch.save(state, f)
+        _barrier_if_sharded()
+        if flashy.distrib.is_rank_zero():
+            token.touch()
+        _barrier_if_sharded()
+    _barrier_if_sharded()
+    if flashy.distrib.rank() == 0:
+        token.unlink()
+
+
+
+
+
+
+
+

Functions

+
+
+def check_sharded_checkpoint(checkpoint_path: pathlib.Path, rank0_checkpoint_path: pathlib.Path) ‑> None +
+
+

Check sharded checkpoint state, ensuring the checkpoints are not corrupted.

+
+ +Expand source code + +
def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
+    """Check sharded checkpoint state, ensuring the checkpoints are not corrupted."""
+    # Finish the work of a previous run that got interrupted while dumping.
+    old_path = Path(str(checkpoint_path) + '.old')
+    if old_path.exists():
+        raise RuntimeError(
+            f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.")
+    token = Path(str(rank0_checkpoint_path) + '.tmp.done')
+    tmp_path = Path(str(checkpoint_path) + '.tmp')
+    if token.exists():
+        if tmp_path.exists():
+            tmp_path.rename(checkpoint_path)
+    flashy.distrib.barrier()
+    if flashy.distrib.is_rank_zero() and token.exists():
+        token.unlink()
+
+
+
+def checkpoint_name(name: Optional[str] = None, rank: Optional[int] = None, use_fsdp: bool = False) ‑> str +
+
+

Checkpoint name formatted for all use in AudioCraft codebase and has the following format: +checkpoint_<name>.th(.<rank>). By convention, name is expected to be empty for last checkpoint, +'best' for the best checkpoint or the epoch number.

+

Args

+
+
name : str, optional
+
Name suffix for the checkpoint file stem.
+
rank : optional, int
+
Rank for distributed processing, retrieved with flashy if not provided.
+
use_fsdp : bool
+
Whether the calling solver relies on FSDP.
+
+

Returns

+
+
str
+
The checkpoint name.
+
+
+ +Expand source code + +
def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
+    """Checkpoint name formatted for all use in AudioCraft codebase and has the following format:
+    `checkpoint_<name>.th(.<rank>)`. By convention, name is expected to be empty for last checkpoint,
+    'best' for the best checkpoint or the epoch number.
+
+    Args:
+        name (str, optional): Name suffix for the checkpoint file stem.
+        rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
+        use_fsdp (bool): Whether the calling solver relies on FSDP.
+    Returns:
+        str: The checkpoint name.
+    """
+    suffix = ''
+    if rank is None:
+        rank = flashy.distrib.rank()
+    if rank > 0 and use_fsdp:
+        suffix = '.' + str(rank)
+    name_part = ''
+    if name is not None:
+        name_part = f'_{name}'
+    return f'checkpoint{name_part}.th{suffix}'
+
+
+
+def flush_stale_checkpoints(checkpoint_path: pathlib.Path, keep_last: Optional[int] = None) ‑> None +
+
+

Flush checkpoints to only keep last N checkpoints.

+
+ +Expand source code + +
def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
+    """Flush checkpoints to only keep last N checkpoints."""
+    if keep_last is None or keep_last <= 0:
+        return
+    checkpoint_dir = checkpoint_path.parent
+    suffix = ''
+    if flashy.distrib.rank() > 0:
+        suffix = f'.{flashy.distrib.rank()}'
+    checkpoint_files_with_epoch = []
+    for path in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'):
+        epoch_part = path.name.split('.', 1)[0].split('_', 1)[1]
+        if epoch_part.isdigit():
+            checkpoint_files_with_epoch.append((path, int(epoch_part)))
+    checkpoint_files = [path for path, _ in list(sorted(checkpoint_files_with_epoch, key=lambda t: t[1]))]
+    total_to_flush = max(0, len(checkpoint_files) - keep_last)
+    files_to_flush = checkpoint_files[:total_to_flush]
+    for path in files_to_flush:
+        logger.debug("Removing checkpoint: %s", str(path))
+        path.unlink(missing_ok=True)
+
+
+
+def is_sharded_checkpoint(path: pathlib.Path) ‑> bool +
+
+

Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank.

+
+ +Expand source code + +
def is_sharded_checkpoint(path: Path) -> bool:
+    """Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank."""
+    return re.search(r'\.th\.\d+$', path.name) is not None
+
+
+
+def load_checkpoint(checkpoint_path: pathlib.Path, is_sharded: bool = False) ‑> Any +
+
+

Load state from checkpoints at the specified checkpoint path.

+
+ +Expand source code + +
def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
+    """Load state from checkpoints at the specified checkpoint path."""
+    if is_sharded:
+        rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
+        if rank0_checkpoint_path.exists():
+            check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path)
+    state = torch.load(checkpoint_path, 'cpu')
+    logger.info("Checkpoint loaded from %s", checkpoint_path)
+    return state
+
+
+
+def resolve_checkpoint_path(sig_or_path: Union[str, pathlib.Path], name: Optional[str] = None, use_fsdp: bool = False) ‑> Optional[pathlib.Path] +
+
+

Resolve a given checkpoint path for a provided dora sig or path.

+

Args

+
+
sig_or_path : Path or str
+
Checkpoint path or dora signature.
+
name : str, optional
+
Name suffix for the checkpoint file stem.
+
rank : optional, int
+
Rank for distributed processing, retrieved with flashy if not provided.
+
use_fsdp : bool
+
Whether the calling solver relies on FSDP.
+
+

Returns

+
+
Path, optional
+
Resolved checkpoint path, if it exists.
+
+
+ +Expand source code + +
def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
+                            use_fsdp: bool = False) -> tp.Optional[Path]:
+    """Resolve a given checkpoint path for a provided dora sig or path.
+
+    Args:
+        sig_or_path (Path or str): Checkpoint path or dora signature.
+        name (str, optional): Name suffix for the checkpoint file stem.
+        rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
+        use_fsdp (bool): Whether the calling solver relies on FSDP.
+    Returns:
+        Path, optional: Resolved checkpoint path, if it exists.
+    """
+    from audiocraft import train
+    xps_root = train.main.dora.dir / 'xps'
+    sig_or_path = str(sig_or_path)
+    if sig_or_path.startswith('//sig/'):
+        sig = sig_or_path[len('//sig/'):]
+        path = xps_root / sig
+    else:
+        path = Path(sig_or_path)
+        path = AudioCraftEnvironment.resolve_reference_path(path)
+
+    if path.is_dir():
+        path = path / checkpoint_name(name, use_fsdp=use_fsdp)
+
+    if path.exists():
+        return path
+    else:
+        return None
+
+
+
+def save_checkpoint(state: Any, checkpoint_path: pathlib.Path, is_sharded: bool = False) ‑> None +
+
+

Save state to disk to the specified checkpoint_path.

+
+ +Expand source code + +
def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
+    """Save state to disk to the specified checkpoint_path."""
+    _safe_save_checkpoint(state, checkpoint_path, is_sharded)
+    logger.info("Checkpoint saved to %s", checkpoint_path)
+
+
+
+
+
+

Classes

+
+
+class CheckpointSource +(value, names=None, *, module=None, qualname=None, type=None, start=1) +
+
+

An enumeration.

+
+ +Expand source code + +
class CheckpointSource(Enum):
+    CURRENT_XP = "current_xp"
+    PRETRAINED = "pretrained"
+    OTHER = "other"
+
+

Ancestors

+
    +
  • enum.Enum
  • +
+

Class variables

+
+
var CURRENT_XP
+
+
+
+
var OTHER
+
+
+
+
var PRETRAINED
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/cluster.html b/api_docs/audiocraft/utils/cluster.html new file mode 100644 index 00000000..c58f030d --- /dev/null +++ b/api_docs/audiocraft/utils/cluster.html @@ -0,0 +1,257 @@ + + + + + + +audiocraft.utils.cluster API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.cluster

+
+
+

Utility functions for SLURM configuration and cluster settings.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility functions for SLURM configuration and cluster settings.
+"""
+
+from enum import Enum
+import os
+import socket
+import typing as tp
+
+import omegaconf
+
+
+class ClusterType(Enum):
+    AWS = "aws"
+    FAIR = "fair"
+    RSC = "rsc"
+    LOCAL_DARWIN = "darwin"
+    DEFAULT = "default"  # used for any other cluster.
+
+
+def _guess_cluster_type() -> ClusterType:
+    uname = os.uname()
+    fqdn = socket.getfqdn()
+    if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn):
+        return ClusterType.AWS
+
+    if fqdn.endswith(".fair"):
+        return ClusterType.FAIR
+
+    if fqdn.endswith(".facebook.com"):
+        return ClusterType.RSC
+
+    if uname.sysname == "Darwin":
+        return ClusterType.LOCAL_DARWIN
+
+    return ClusterType.DEFAULT
+
+
+def get_cluster_type(
+    cluster_type: tp.Optional[ClusterType] = None,
+) -> tp.Optional[ClusterType]:
+    if cluster_type is None:
+        return _guess_cluster_type()
+
+    return cluster_type
+
+
+def get_slurm_parameters(
+    cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
+) -> omegaconf.DictConfig:
+    """Update SLURM parameters in configuration based on cluster type.
+    If the cluster type is not specify, it infers it automatically.
+    """
+    from ..environment import AudioCraftEnvironment
+    cluster_type = get_cluster_type(cluster_type)
+    # apply cluster-specific adjustments
+    if cluster_type == ClusterType.AWS:
+        cfg["mem_per_gpu"] = None
+        cfg["constraint"] = None
+        cfg["setup"] = []
+    elif cluster_type == ClusterType.RSC:
+        cfg["mem_per_gpu"] = None
+        cfg["setup"] = []
+        cfg["constraint"] = None
+        cfg["partition"] = "learn"
+    slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
+    if slurm_exclude is not None:
+        cfg["exclude"] = slurm_exclude
+    return cfg
+
+
+
+
+
+
+
+

Functions

+
+
+def get_cluster_type(cluster_type: Optional[ClusterType] = None) ‑> Optional[ClusterType] +
+
+
+
+ +Expand source code + +
def get_cluster_type(
+    cluster_type: tp.Optional[ClusterType] = None,
+) -> tp.Optional[ClusterType]:
+    if cluster_type is None:
+        return _guess_cluster_type()
+
+    return cluster_type
+
+
+
+def get_slurm_parameters(cfg: omegaconf.dictconfig.DictConfig, cluster_type: Optional[ClusterType] = None) ‑> omegaconf.dictconfig.DictConfig +
+
+

Update SLURM parameters in configuration based on cluster type. +If the cluster type is not specify, it infers it automatically.

+
+ +Expand source code + +
def get_slurm_parameters(
+    cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
+) -> omegaconf.DictConfig:
+    """Update SLURM parameters in configuration based on cluster type.
+    If the cluster type is not specify, it infers it automatically.
+    """
+    from ..environment import AudioCraftEnvironment
+    cluster_type = get_cluster_type(cluster_type)
+    # apply cluster-specific adjustments
+    if cluster_type == ClusterType.AWS:
+        cfg["mem_per_gpu"] = None
+        cfg["constraint"] = None
+        cfg["setup"] = []
+    elif cluster_type == ClusterType.RSC:
+        cfg["mem_per_gpu"] = None
+        cfg["setup"] = []
+        cfg["constraint"] = None
+        cfg["partition"] = "learn"
+    slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
+    if slurm_exclude is not None:
+        cfg["exclude"] = slurm_exclude
+    return cfg
+
+
+
+
+
+

Classes

+
+
+class ClusterType +(value, names=None, *, module=None, qualname=None, type=None, start=1) +
+
+

An enumeration.

+
+ +Expand source code + +
class ClusterType(Enum):
+    AWS = "aws"
+    FAIR = "fair"
+    RSC = "rsc"
+    LOCAL_DARWIN = "darwin"
+    DEFAULT = "default"  # used for any other cluster.
+
+

Ancestors

+
    +
  • enum.Enum
  • +
+

Class variables

+
+
var AWS
+
+
+
+
var DEFAULT
+
+
+
+
var FAIR
+
+
+
+
var LOCAL_DARWIN
+
+
+
+
var RSC
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/deadlock.html b/api_docs/audiocraft/utils/deadlock.html new file mode 100644 index 00000000..d09ffc66 --- /dev/null +++ b/api_docs/audiocraft/utils/deadlock.html @@ -0,0 +1,199 @@ + + + + + + +audiocraft.utils.deadlock API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.deadlock

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+from queue import Queue, Empty
+import signal
+import sys
+import threading
+import traceback
+
+logger = logging.getLogger(__name__)
+
+
+class DeadlockDetect:
+    def __init__(self, use: bool = False, timeout: float = 120.):
+        self.use = use
+        self.timeout = timeout
+        self._queue: Queue = Queue()
+
+    def update(self, stage: str):
+        if self.use:
+            self._queue.put(stage)
+
+    def __enter__(self):
+        if self.use:
+            self._thread = threading.Thread(target=self._detector_thread)
+            self._thread.start()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.use:
+            self._queue.put(None)
+            self._thread.join()
+
+    def _detector_thread(self):
+        logger.debug("Deadlock detector started")
+        last_stage = "init"
+        while True:
+            try:
+                stage = self._queue.get(timeout=self.timeout)
+            except Empty:
+                break
+            if stage is None:
+                logger.debug("Exiting deadlock detector thread")
+                return
+            else:
+                last_stage = stage
+        logger.error("Deadlock detector timed out, last stage was %s", last_stage)
+        for th in threading.enumerate():
+            print(th, file=sys.stderr)
+            traceback.print_stack(sys._current_frames()[th.ident])
+            print(file=sys.stderr)
+        sys.stdout.flush()
+        sys.stderr.flush()
+        os.kill(os.getpid(), signal.SIGKILL)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class DeadlockDetect +(use: bool = False, timeout: float = 120.0) +
+
+
+
+ +Expand source code + +
class DeadlockDetect:
+    def __init__(self, use: bool = False, timeout: float = 120.):
+        self.use = use
+        self.timeout = timeout
+        self._queue: Queue = Queue()
+
+    def update(self, stage: str):
+        if self.use:
+            self._queue.put(stage)
+
+    def __enter__(self):
+        if self.use:
+            self._thread = threading.Thread(target=self._detector_thread)
+            self._thread.start()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.use:
+            self._queue.put(None)
+            self._thread.join()
+
+    def _detector_thread(self):
+        logger.debug("Deadlock detector started")
+        last_stage = "init"
+        while True:
+            try:
+                stage = self._queue.get(timeout=self.timeout)
+            except Empty:
+                break
+            if stage is None:
+                logger.debug("Exiting deadlock detector thread")
+                return
+            else:
+                last_stage = stage
+        logger.error("Deadlock detector timed out, last stage was %s", last_stage)
+        for th in threading.enumerate():
+            print(th, file=sys.stderr)
+            traceback.print_stack(sys._current_frames()[th.ident])
+            print(file=sys.stderr)
+        sys.stdout.flush()
+        sys.stderr.flush()
+        os.kill(os.getpid(), signal.SIGKILL)
+
+

Methods

+
+
+def update(self, stage: str) +
+
+
+
+ +Expand source code + +
def update(self, stage: str):
+    if self.use:
+        self._queue.put(stage)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/export.html b/api_docs/audiocraft/utils/export.html new file mode 100644 index 00000000..8560cfe3 --- /dev/null +++ b/api_docs/audiocraft/utils/export.html @@ -0,0 +1,243 @@ + + + + + + +audiocraft.utils.export API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.export

+
+
+

Utility to export a training checkpoint to a lightweight release checkpoint.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility to export a training checkpoint to a lightweight release checkpoint.
+"""
+
+from pathlib import Path
+import typing as tp
+
+from omegaconf import OmegaConf
+import torch
+
+from audiocraft import __version__
+
+
+def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
+    """Export only the best state from the given EnCodec checkpoint. This
+    should be used if you trained your own EnCodec model.
+    """
+    pkg = torch.load(checkpoint_path, 'cpu')
+    new_pkg = {
+        'best_state': pkg['best_state']['model'],
+        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+        'version': __version__,
+        'exported': True,
+    }
+    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
+    """Export a compression model (potentially EnCodec) from a pretrained model.
+    This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
+    Do not include the //pretrained/ prefix. For instance if you trained a model
+    with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.
+
+    In that case, this will not actually include a copy of the model, simply the reference
+    to the model used.
+    """
+    if Path(pretrained_encodec).exists():
+        pkg = torch.load(pretrained_encodec)
+        assert 'best_state' in pkg
+        assert 'xp.cfg' in pkg
+        assert 'version' in pkg
+        assert 'exported' in pkg
+    else:
+        pkg = {
+            'pretrained': pretrained_encodec,
+            'exported': True,
+            'version': __version__,
+        }
+    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
+    torch.save(pkg, out_file)
+
+
+def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
+    """Export only the best state from the given MusicGen or AudioGen checkpoint.
+    """
+    pkg = torch.load(checkpoint_path, 'cpu')
+    if pkg['fsdp_best_state']:
+        best_state = pkg['fsdp_best_state']['model']
+    else:
+        assert pkg['best_state']
+        best_state = pkg['best_state']['model']
+    new_pkg = {
+        'best_state': best_state,
+        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+        'version': __version__,
+        'exported': True,
+    }
+
+    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+
+
+
+
+
+

Functions

+
+
+def export_encodec(checkpoint_path: Union[str, pathlib.Path], out_file: Union[str, pathlib.Path]) +
+
+

Export only the best state from the given EnCodec checkpoint. This +should be used if you trained your own EnCodec model.

+
+ +Expand source code + +
def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
+    """Export only the best state from the given EnCodec checkpoint. This
+    should be used if you trained your own EnCodec model.
+    """
+    pkg = torch.load(checkpoint_path, 'cpu')
+    new_pkg = {
+        'best_state': pkg['best_state']['model'],
+        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+        'version': __version__,
+        'exported': True,
+    }
+    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+
+def export_lm(checkpoint_path: Union[str, pathlib.Path], out_file: Union[str, pathlib.Path]) +
+
+

Export only the best state from the given MusicGen or AudioGen checkpoint.

+
+ +Expand source code + +
def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
+    """Export only the best state from the given MusicGen or AudioGen checkpoint.
+    """
+    pkg = torch.load(checkpoint_path, 'cpu')
+    if pkg['fsdp_best_state']:
+        best_state = pkg['fsdp_best_state']['model']
+    else:
+        assert pkg['best_state']
+        best_state = pkg['best_state']['model']
+    new_pkg = {
+        'best_state': best_state,
+        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+        'version': __version__,
+        'exported': True,
+    }
+
+    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+
+def export_pretrained_compression_model(pretrained_encodec: str, out_file: Union[str, pathlib.Path]) +
+
+

Export a compression model (potentially EnCodec) from a pretrained model. +This is required for packaging the audio tokenizer along a MusicGen or AudioGen model. +Do not include the //pretrained/ prefix. For instance if you trained a model +with facebook/encodec_32khz, just put that as a name. Same for dac_44khz.

+

In that case, this will not actually include a copy of the model, simply the reference +to the model used.

+
+ +Expand source code + +
def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
+    """Export a compression model (potentially EnCodec) from a pretrained model.
+    This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
+    Do not include the //pretrained/ prefix. For instance if you trained a model
+    with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.
+
+    In that case, this will not actually include a copy of the model, simply the reference
+    to the model used.
+    """
+    if Path(pretrained_encodec).exists():
+        pkg = torch.load(pretrained_encodec)
+        assert 'best_state' in pkg
+        assert 'xp.cfg' in pkg
+        assert 'version' in pkg
+        assert 'exported' in pkg
+    else:
+        pkg = {
+            'pretrained': pretrained_encodec,
+            'exported': True,
+            'version': __version__,
+        }
+    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
+    torch.save(pkg, out_file)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/export_legacy.html b/api_docs/audiocraft/utils/export_legacy.html new file mode 100644 index 00000000..1722da97 --- /dev/null +++ b/api_docs/audiocraft/utils/export_legacy.html @@ -0,0 +1,168 @@ + + + + + + +audiocraft.utils.export_legacy API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.export_legacy

+
+
+

Legacy functions used at the time of the first release, kept for referencd.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Legacy functions used at the time of the first release, kept for referencd.
+"""
+
+from pathlib import Path
+import typing as tp
+
+from omegaconf import OmegaConf, DictConfig
+import torch
+
+
+def _clean_lm_cfg(cfg: DictConfig):
+    OmegaConf.set_struct(cfg, False)
+    # This used to be set automatically in the LM solver, need a more robust solution
+    # for the future.
+    cfg['transformer_lm']['card'] = 2048
+    cfg['transformer_lm']['n_q'] = 4
+    # Experimental params no longer supported.
+    bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
+                  'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
+    for name in bad_params:
+        del cfg['transformer_lm'][name]
+    OmegaConf.set_struct(cfg, True)
+    return cfg
+
+
+def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+    sig = Path(checkpoint_path).parent.name
+    assert len(sig) == 8, "Not a valid Dora signature"
+    pkg = torch.load(checkpoint_path, 'cpu')
+    new_pkg = {
+        'best_state': pkg['ema']['state']['model'],
+        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+    }
+    out_file = Path(out_folder) / f'{sig}.th'
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+    sig = Path(checkpoint_path).parent.name
+    assert len(sig) == 8, "Not a valid Dora signature"
+    pkg = torch.load(checkpoint_path, 'cpu')
+    new_pkg = {
+        'best_state': pkg['fsdp_best_state']['model'],
+        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
+    }
+    out_file = Path(out_folder) / f'{sig}.th'
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+
+
+
+
+
+

Functions

+
+
+def export_encodec(checkpoint_path: Union[str, pathlib.Path], out_folder: Union[str, pathlib.Path]) +
+
+
+
+ +Expand source code + +
def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+    sig = Path(checkpoint_path).parent.name
+    assert len(sig) == 8, "Not a valid Dora signature"
+    pkg = torch.load(checkpoint_path, 'cpu')
+    new_pkg = {
+        'best_state': pkg['ema']['state']['model'],
+        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+    }
+    out_file = Path(out_folder) / f'{sig}.th'
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+
+def export_lm(checkpoint_path: Union[str, pathlib.Path], out_folder: Union[str, pathlib.Path]) +
+
+
+
+ +Expand source code + +
def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+    sig = Path(checkpoint_path).parent.name
+    assert len(sig) == 8, "Not a valid Dora signature"
+    pkg = torch.load(checkpoint_path, 'cpu')
+    new_pkg = {
+        'best_state': pkg['fsdp_best_state']['model'],
+        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
+    }
+    out_file = Path(out_folder) / f'{sig}.th'
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/index.html b/api_docs/audiocraft/utils/index.html new file mode 100644 index 00000000..f5ce143a --- /dev/null +++ b/api_docs/audiocraft/utils/index.html @@ -0,0 +1,132 @@ + + + + + + +audiocraft.utils API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils

+
+
+

Utilities.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Utilities."""
+
+
+
+

Sub-modules

+
+
audiocraft.utils.autocast
+
+
+
+
audiocraft.utils.best_state
+
+
+
+
audiocraft.utils.cache
+
+
+
+
audiocraft.utils.checkpoint
+
+
+
+
audiocraft.utils.cluster
+
+

Utility functions for SLURM configuration and cluster settings.

+
+
audiocraft.utils.deadlock
+
+
+
+
audiocraft.utils.export
+
+

Utility to export a training checkpoint to a lightweight release checkpoint.

+
+
audiocraft.utils.export_legacy
+
+

Legacy functions used at the time of the first release, kept for referencd.

+
+
audiocraft.utils.notebook
+
+
+
+
audiocraft.utils.profiler
+
+
+
+
audiocraft.utils.samples
+
+
+
+
audiocraft.utils.utils
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/notebook.html b/api_docs/audiocraft/utils/notebook.html new file mode 100644 index 00000000..075a78d7 --- /dev/null +++ b/api_docs/audiocraft/utils/notebook.html @@ -0,0 +1,133 @@ + + + + + + +audiocraft.utils.notebook API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.notebook

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+try:
+    import IPython.display as ipd  # type: ignore
+except ImportError:
+    # Note in a notebook...
+    pass
+
+
+import torch
+
+
+def display_audio(samples: torch.Tensor, sample_rate: int):
+    """Renders an audio player for the given audio samples.
+
+    Args:
+        samples (torch.Tensor): a Tensor of decoded audio samples
+            with shapes [B, C, T] or [C, T]
+        sample_rate (int): sample rate audio should be displayed with.
+    """
+    assert samples.dim() == 2 or samples.dim() == 3
+
+    samples = samples.detach().cpu()
+    if samples.dim() == 2:
+        samples = samples[None, ...]
+
+    for audio in samples:
+        ipd.display(ipd.Audio(audio, rate=sample_rate))
+
+
+
+
+
+
+
+

Functions

+
+
+def display_audio(samples: torch.Tensor, sample_rate: int) +
+
+

Renders an audio player for the given audio samples.

+

Args

+
+
samples : torch.Tensor
+
a Tensor of decoded audio samples +with shapes [B, C, T] or [C, T]
+
sample_rate : int
+
sample rate audio should be displayed with.
+
+
+ +Expand source code + +
def display_audio(samples: torch.Tensor, sample_rate: int):
+    """Renders an audio player for the given audio samples.
+
+    Args:
+        samples (torch.Tensor): a Tensor of decoded audio samples
+            with shapes [B, C, T] or [C, T]
+        sample_rate (int): sample rate audio should be displayed with.
+    """
+    assert samples.dim() == 2 or samples.dim() == 3
+
+    samples = samples.detach().cpu()
+    if samples.dim() == 2:
+        samples = samples[None, ...]
+
+    for audio in samples:
+        ipd.display(ipd.Audio(audio, rate=sample_rate))
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/profiler.html b/api_docs/audiocraft/utils/profiler.html new file mode 100644 index 00000000..19067e21 --- /dev/null +++ b/api_docs/audiocraft/utils/profiler.html @@ -0,0 +1,160 @@ + + + + + + +audiocraft.utils.profiler API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.profiler

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import typing as tp
+
+import dora
+import torch
+
+
+logger = logging.getLogger(__name__)
+
+
+class Profiler:
+    """Context manager wrapper for xformers profiler.
+    """
+    def __init__(self, module: torch.nn.Module, enabled: bool = False):
+        self.profiler: tp.Optional[tp.Any] = None
+        if enabled:
+            from xformers.profiler import profile
+            output_dir = dora.get_xp().folder / 'profiler_data'
+            logger.info("Profiling activated, results with be saved to %s", output_dir)
+            self.profiler = profile(output_dir=output_dir, module=module)
+
+    def step(self):
+        if self.profiler is not None:
+            self.profiler.step()  # type: ignore
+
+    def __enter__(self):
+        if self.profiler is not None:
+            return self.profiler.__enter__()  # type: ignore
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        if self.profiler is not None:
+            return self.profiler.__exit__(exc_type, exc_value, exc_tb)  # type: ignore
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class Profiler +(module: torch.nn.modules.module.Module, enabled: bool = False) +
+
+

Context manager wrapper for xformers profiler.

+
+ +Expand source code + +
class Profiler:
+    """Context manager wrapper for xformers profiler.
+    """
+    def __init__(self, module: torch.nn.Module, enabled: bool = False):
+        self.profiler: tp.Optional[tp.Any] = None
+        if enabled:
+            from xformers.profiler import profile
+            output_dir = dora.get_xp().folder / 'profiler_data'
+            logger.info("Profiling activated, results with be saved to %s", output_dir)
+            self.profiler = profile(output_dir=output_dir, module=module)
+
+    def step(self):
+        if self.profiler is not None:
+            self.profiler.step()  # type: ignore
+
+    def __enter__(self):
+        if self.profiler is not None:
+            return self.profiler.__enter__()  # type: ignore
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        if self.profiler is not None:
+            return self.profiler.__exit__(exc_type, exc_value, exc_tb)  # type: ignore
+
+

Methods

+
+
+def step(self) +
+
+
+
+ +Expand source code + +
def step(self):
+    if self.profiler is not None:
+        self.profiler.step()  # type: ignore
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/samples/index.html b/api_docs/audiocraft/utils/samples/index.html new file mode 100644 index 00000000..93a81991 --- /dev/null +++ b/api_docs/audiocraft/utils/samples/index.html @@ -0,0 +1,75 @@ + + + + + + +audiocraft.utils.samples API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.samples

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+
+

Sub-modules

+
+
audiocraft.utils.samples.manager
+
+

API that can manage the storage and retrieval of generated samples produced by experiments …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/samples/manager.html b/api_docs/audiocraft/utils/samples/manager.html new file mode 100644 index 00000000..2eabe288 --- /dev/null +++ b/api_docs/audiocraft/utils/samples/manager.html @@ -0,0 +1,1233 @@ + + + + + + +audiocraft.utils.samples.manager API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.samples.manager

+
+
+

API that can manage the storage and retrieval of generated samples produced by experiments.

+

It offers the following benefits: +* Samples are stored in a consistent way across epoch +* Metadata about the samples can be stored and retrieved +* Can retrieve audio +* Identifiers are reliable and deterministic for prompted and conditioned samples +* Can request the samples for multiple XPs, grouped by sample identifier +* For no-input samples (not prompt and no conditions), samples across XPs are matched +by sorting their identifiers

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+API that can manage the storage and retrieval of generated samples produced by experiments.
+
+It offers the following benefits:
+* Samples are stored in a consistent way across epoch
+* Metadata about the samples can be stored and retrieved
+* Can retrieve audio
+* Identifiers are reliable and deterministic for prompted and conditioned samples
+* Can request the samples for multiple XPs, grouped by sample identifier
+* For no-input samples (not prompt and no conditions), samples across XPs are matched
+  by sorting their identifiers
+"""
+
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import asdict, dataclass
+from functools import lru_cache
+import hashlib
+import json
+import logging
+from pathlib import Path
+import re
+import typing as tp
+import unicodedata
+import uuid
+
+import dora
+import torch
+
+from ...data.audio import audio_read, audio_write
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ReferenceSample:
+    id: str
+    path: str
+    duration: float
+
+
+@dataclass
+class Sample:
+    id: str
+    path: str
+    epoch: int
+    duration: float
+    conditioning: tp.Optional[tp.Dict[str, tp.Any]]
+    prompt: tp.Optional[ReferenceSample]
+    reference: tp.Optional[ReferenceSample]
+    generation_args: tp.Optional[tp.Dict[str, tp.Any]]
+
+    def __hash__(self):
+        return hash(self.id)
+
+    def audio(self) -> tp.Tuple[torch.Tensor, int]:
+        return audio_read(self.path)
+
+    def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
+        return audio_read(self.prompt.path) if self.prompt is not None else None
+
+    def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
+        return audio_read(self.reference.path) if self.reference is not None else None
+
+
+class SampleManager:
+    """Audio samples IO handling within a given dora xp.
+
+    The sample manager handles the dumping and loading logic for generated and
+    references samples across epochs for a given xp, providing a simple API to
+    store, retrieve and compare audio samples.
+
+    Args:
+        xp (dora.XP): Dora experiment object. The XP contains information on the XP folder
+            where all outputs are stored and the configuration of the experiment,
+            which is useful to retrieve audio-related parameters.
+        map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples
+            instead of generating a dedicated hash id. This is useful to allow easier comparison
+            with ground truth sample from the files directly without having to read the JSON metadata
+            to do the mapping (at the cost of potentially dumping duplicate prompts/references
+            depending on the task).
+    """
+    def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False):
+        self.xp = xp
+        self.base_folder: Path = xp.folder / xp.cfg.generate.path
+        self.reference_folder = self.base_folder / 'reference'
+        self.map_reference_to_sample_id = map_reference_to_sample_id
+        self.samples: tp.List[Sample] = []
+        self._load_samples()
+
+    @property
+    def latest_epoch(self):
+        """Latest epoch across all samples."""
+        return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0
+
+    def _load_samples(self):
+        """Scan the sample folder and load existing samples."""
+        jsons = self.base_folder.glob('**/*.json')
+        with ThreadPoolExecutor(6) as pool:
+            self.samples = list(pool.map(self._load_sample, jsons))
+
+    @staticmethod
+    @lru_cache(2**26)
+    def _load_sample(json_file: Path) -> Sample:
+        with open(json_file, 'r') as f:
+            data: tp.Dict[str, tp.Any] = json.load(f)
+        # fetch prompt data
+        prompt_data = data.get('prompt')
+        prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'],
+                                 duration=prompt_data['duration']) if prompt_data else None
+        # fetch reference data
+        reference_data = data.get('reference')
+        reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'],
+                                    duration=reference_data['duration']) if reference_data else None
+        # build sample object
+        return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'],
+                      prompt=prompt, conditioning=data.get('conditioning'), reference=reference,
+                      generation_args=data.get('generation_args'))
+
+    def _init_hash(self):
+        return hashlib.sha1()
+
+    def _get_tensor_id(self, tensor: torch.Tensor) -> str:
+        hash_id = self._init_hash()
+        hash_id.update(tensor.numpy().data)
+        return hash_id.hexdigest()
+
+    def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor],
+                       conditions: tp.Optional[tp.Dict[str, str]]) -> str:
+        """Computes an id for a sample given its input data.
+        This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input.
+        Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned.
+
+        Args:
+            index (int): Batch index, Helpful to differentiate samples from the same batch.
+            prompt_wav (torch.Tensor): Prompt used during generation.
+            conditions (dict[str, str]): Conditioning used during generation.
+        """
+        # For totally unconditioned generations we will just use a random UUID.
+        # The function get_samples_for_xps will do a simple ordered match with a custom key.
+        if prompt_wav is None and not conditions:
+            return f"noinput_{uuid.uuid4().hex}"
+
+        # Human readable portion
+        hr_label = ""
+        # Create a deterministic id using hashing
+        hash_id = self._init_hash()
+        hash_id.update(f"{index}".encode())
+        if prompt_wav is not None:
+            hash_id.update(prompt_wav.numpy().data)
+            hr_label += "_prompted"
+        else:
+            hr_label += "_unprompted"
+        if conditions:
+            encoded_json = json.dumps(conditions, sort_keys=True).encode()
+            hash_id.update(encoded_json)
+            cond_str = "-".join([f"{key}={slugify(value)}"
+                                 for key, value in sorted(conditions.items())])
+            cond_str = cond_str[:100]  # some raw text might be too long to be a valid filename
+            cond_str = cond_str if len(cond_str) > 0 else "unconditioned"
+            hr_label += f"_{cond_str}"
+        else:
+            hr_label += "_unconditioned"
+
+        return hash_id.hexdigest() + hr_label
+
+    def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path:
+        """Stores the audio with the given stem path using the XP's configuration.
+
+        Args:
+            wav (torch.Tensor): Audio to store.
+            stem_path (Path): Path in sample output directory with file stem to use.
+            overwrite (bool): When False (default), skips storing an existing audio file.
+        Returns:
+            Path: The path at which the audio is stored.
+        """
+        existing_paths = [
+            path for path in stem_path.parent.glob(stem_path.stem + '.*')
+            if path.suffix != '.json'
+        ]
+        exists = len(existing_paths) > 0
+        if exists and overwrite:
+            logger.warning(f"Overwriting existing audio file with stem path {stem_path}")
+        elif exists:
+            return existing_paths[0]
+
+        audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio)
+        return audio_path
+
+    def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
+                   conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
+                   ground_truth_wav: tp.Optional[torch.Tensor] = None,
+                   generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
+        """Adds a single sample.
+        The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
+        Each sample is assigned an id which is computed using the input data. In addition to the
+        sample itself, a json file containing associated metadata is stored next to it.
+
+        Args:
+            sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape].
+            epoch (int): current training epoch.
+            index (int): helpful to differentiate samples from the same batch.
+            conditions (dict[str, str], optional): conditioning used during generation.
+            prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape].
+            ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from.
+                Tensor of shape [channels, shape].
+            generation_args (dict[str, any], optional): dictionary of other arguments used during generation.
+        Returns:
+            Sample: The saved sample.
+        """
+        sample_id = self._get_sample_id(index, prompt_wav, conditions)
+        reuse_id = self.map_reference_to_sample_id
+        prompt, ground_truth = None, None
+        if prompt_wav is not None:
+            prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
+            prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
+            prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
+            prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
+        if ground_truth_wav is not None:
+            ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
+            ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
+            ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id)
+            ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
+        sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
+        duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
+        sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
+        self.samples.append(sample)
+        with open(sample_path.with_suffix('.json'), 'w') as f:
+            json.dump(asdict(sample), f, indent=2)
+        return sample
+
+    def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
+                    conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
+                    prompt_wavs: tp.Optional[torch.Tensor] = None,
+                    ground_truth_wavs: tp.Optional[torch.Tensor] = None,
+                    generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
+        """Adds a batch of samples.
+        The samples are stored in the XP's sample output directory, under a corresponding
+        epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
+        In addition to the sample itself, a json file containing associated metadata is stored next to it.
+
+        Args:
+            sample_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
+            epoch (int): Current training epoch.
+            conditioning (list of dict[str, str], optional): List of conditions used during generation,
+                one per sample in the batch.
+            prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
+                [batch_size, channels, shape].
+            ground_truth_wav (torch.Tensor, optional): Reference audio where prompts were extracted from.
+                Tensor of shape [batch_size, channels, shape].
+            generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
+        Returns:
+            samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
+        """
+        samples = []
+        for idx, wav in enumerate(samples_wavs):
+            prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
+            gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
+            conditions = conditioning[idx] if conditioning is not None else None
+            samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
+        return samples
+
+    def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
+                    exclude_unprompted: bool = False, exclude_conditioned: bool = False,
+                    exclude_unconditioned: bool = False) -> tp.Set[Sample]:
+        """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.
+        Please note that existing samples are loaded during the manager's initialization, and added samples through this
+        manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager
+        is the only way detect them.
+
+        Args:
+            epoch (int): If provided, only return samples corresponding to this epoch.
+            max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
+            exclude_prompted (bool): If True, does not include samples that used a prompt.
+            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
+            exclude_conditioned (bool): If True, excludes samples that used conditioning.
+            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
+        Returns:
+            Samples (set of Sample): The retrieved samples matching the provided filters.
+        """
+        if max_epoch >= 0:
+            samples_epoch = max(sample.epoch for sample in self.samples if sample.epoch <= max_epoch)
+        else:
+            samples_epoch = self.latest_epoch if epoch < 0 else epoch
+        samples = {
+            sample
+            for sample in self.samples
+            if (
+                (sample.epoch == samples_epoch) and
+                (not exclude_prompted or sample.prompt is None) and
+                (not exclude_unprompted or sample.prompt is not None) and
+                (not exclude_conditioned or not sample.conditioning) and
+                (not exclude_unconditioned or sample.conditioning)
+            )
+        }
+        return samples
+
+
+def slugify(value: tp.Any, allow_unicode: bool = False):
+    """Process string for safer file naming.
+
+    Taken from https://github.com/django/django/blob/master/django/utils/text.py
+
+    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
+    dashes to single dashes. Remove characters that aren't alphanumerics,
+    underscores, or hyphens. Convert to lowercase. Also strip leading and
+    trailing whitespace, dashes, and underscores.
+    """
+    value = str(value)
+    if allow_unicode:
+        value = unicodedata.normalize("NFKC", value)
+    else:
+        value = (
+            unicodedata.normalize("NFKD", value)
+            .encode("ascii", "ignore")
+            .decode("ascii")
+        )
+    value = re.sub(r"[^\w\s-]", "", value.lower())
+    return re.sub(r"[-\s]+", "-", value).strip("-_")
+
+
+def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
+    # Create a dictionary of stable id -> sample per XP
+    stable_samples_per_xp = [{
+        sample.id: sample for sample in samples
+        if sample.prompt is not None or sample.conditioning
+    } for samples in samples_per_xp]
+    # Set of all stable ids
+    stable_ids = {id for samples in stable_samples_per_xp for id in samples.keys()}
+    # Dictionary of stable id -> list of samples. If an XP does not have it, assign None
+    stable_samples = {id: [xp.get(id) for xp in stable_samples_per_xp] for id in stable_ids}
+    # Filter out ids that contain None values (we only want matched samples after all)
+    # cast is necessary to avoid mypy linter errors.
+    return {id: tp.cast(tp.List[Sample], samples) for id, samples in stable_samples.items() if None not in samples}
+
+
+def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
+    # For unstable ids, we use a sorted list since we'll match them in order
+    unstable_samples_per_xp = [[
+        sample for sample in sorted(samples, key=lambda x: x.id)
+        if sample.prompt is None and not sample.conditioning
+    ] for samples in samples_per_xp]
+    # Trim samples per xp so all samples can have a match
+    min_len = min([len(samples) for samples in unstable_samples_per_xp])
+    unstable_samples_per_xp = [samples[:min_len] for samples in unstable_samples_per_xp]
+    # Dictionary of index -> list of matched samples
+    return {
+        f'noinput_{i}': [samples[i] for samples in unstable_samples_per_xp] for i in range(min_len)
+    }
+
+
+def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]:
+    """Gets a dictionary of matched samples across the given XPs.
+    Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id
+    will always match the number of XPs provided and will correspond to each XP in the same order given.
+    In other words, only samples that can be match across all provided XPs will be returned
+    in order to satisfy this rule.
+
+    There are two types of ids that can be returned: stable and unstable.
+    * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs
+      (prompts/conditioning). This is why we can match them across XPs.
+    * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples
+      that used non-deterministic, random ids. This is the case for samples that did not use prompts or
+      conditioning for their generation. This function will sort these samples by their id and match them
+      by their index.
+
+    Args:
+        xps: a list of XPs to match samples from.
+        start_epoch (int): If provided, only return samples corresponding to this epoch or newer.
+        end_epoch (int): If provided, only return samples corresponding to this epoch or older.
+        exclude_prompted (bool): If True, does not include samples that used a prompt.
+        exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
+        exclude_conditioned (bool): If True, excludes samples that used conditioning.
+        exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
+    """
+    managers = [SampleManager(xp) for xp in xps]
+    samples_per_xp = [manager.get_samples(**kwargs) for manager in managers]
+    stable_samples = _match_stable_samples(samples_per_xp)
+    unstable_samples = _match_unstable_samples(samples_per_xp)
+    return dict(stable_samples, **unstable_samples)
+
+
+
+
+
+
+
+

Functions

+
+
+def get_samples_for_xps(xps: List[dora.xp.XP], **kwargs) ‑> Dict[str, List[Sample]] +
+
+

Gets a dictionary of matched samples across the given XPs. +Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id +will always match the number of XPs provided and will correspond to each XP in the same order given. +In other words, only samples that can be match across all provided XPs will be returned +in order to satisfy this rule.

+

There are two types of ids that can be returned: stable and unstable. +* Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs +(prompts/conditioning). This is why we can match them across XPs. +* Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples +that used non-deterministic, random ids. This is the case for samples that did not use prompts or +conditioning for their generation. This function will sort these samples by their id and match them +by their index.

+

Args

+
+
xps
+
a list of XPs to match samples from.
+
start_epoch : int
+
If provided, only return samples corresponding to this epoch or newer.
+
end_epoch : int
+
If provided, only return samples corresponding to this epoch or older.
+
exclude_prompted : bool
+
If True, does not include samples that used a prompt.
+
exclude_unprompted : bool
+
If True, does not include samples that did not use a prompt.
+
exclude_conditioned : bool
+
If True, excludes samples that used conditioning.
+
exclude_unconditioned : bool
+
If True, excludes samples that did not use conditioning.
+
+
+ +Expand source code + +
def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]:
+    """Gets a dictionary of matched samples across the given XPs.
+    Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id
+    will always match the number of XPs provided and will correspond to each XP in the same order given.
+    In other words, only samples that can be match across all provided XPs will be returned
+    in order to satisfy this rule.
+
+    There are two types of ids that can be returned: stable and unstable.
+    * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs
+      (prompts/conditioning). This is why we can match them across XPs.
+    * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples
+      that used non-deterministic, random ids. This is the case for samples that did not use prompts or
+      conditioning for their generation. This function will sort these samples by their id and match them
+      by their index.
+
+    Args:
+        xps: a list of XPs to match samples from.
+        start_epoch (int): If provided, only return samples corresponding to this epoch or newer.
+        end_epoch (int): If provided, only return samples corresponding to this epoch or older.
+        exclude_prompted (bool): If True, does not include samples that used a prompt.
+        exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
+        exclude_conditioned (bool): If True, excludes samples that used conditioning.
+        exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
+    """
+    managers = [SampleManager(xp) for xp in xps]
+    samples_per_xp = [manager.get_samples(**kwargs) for manager in managers]
+    stable_samples = _match_stable_samples(samples_per_xp)
+    unstable_samples = _match_unstable_samples(samples_per_xp)
+    return dict(stable_samples, **unstable_samples)
+
+
+
+def slugify(value: Any, allow_unicode: bool = False) +
+
+

Process string for safer file naming.

+

Taken from https://github.com/django/django/blob/master/django/utils/text.py

+

Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated +dashes to single dashes. Remove characters that aren't alphanumerics, +underscores, or hyphens. Convert to lowercase. Also strip leading and +trailing whitespace, dashes, and underscores.

+
+ +Expand source code + +
def slugify(value: tp.Any, allow_unicode: bool = False):
+    """Process string for safer file naming.
+
+    Taken from https://github.com/django/django/blob/master/django/utils/text.py
+
+    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
+    dashes to single dashes. Remove characters that aren't alphanumerics,
+    underscores, or hyphens. Convert to lowercase. Also strip leading and
+    trailing whitespace, dashes, and underscores.
+    """
+    value = str(value)
+    if allow_unicode:
+        value = unicodedata.normalize("NFKC", value)
+    else:
+        value = (
+            unicodedata.normalize("NFKD", value)
+            .encode("ascii", "ignore")
+            .decode("ascii")
+        )
+    value = re.sub(r"[^\w\s-]", "", value.lower())
+    return re.sub(r"[-\s]+", "-", value).strip("-_")
+
+
+
+
+
+

Classes

+
+
+class ReferenceSample +(id: str, path: str, duration: float) +
+
+

ReferenceSample(id: str, path: str, duration: float)

+
+ +Expand source code + +
class ReferenceSample:
+    id: str
+    path: str
+    duration: float
+
+

Class variables

+
+
var duration : float
+
+
+
+
var id : str
+
+
+
+
var path : str
+
+
+
+
+
+
+class Sample +(id: str, path: str, epoch: int, duration: float, conditioning: Optional[Dict[str, Any]], prompt: Optional[ReferenceSample], reference: Optional[ReferenceSample], generation_args: Optional[Dict[str, Any]]) +
+
+

Sample(id: str, path: str, epoch: int, duration: float, conditioning: Union[Dict[str, Any], NoneType], prompt: Union[audiocraft.utils.samples.manager.ReferenceSample, NoneType], reference: Union[audiocraft.utils.samples.manager.ReferenceSample, NoneType], generation_args: Union[Dict[str, Any], NoneType])

+
+ +Expand source code + +
class Sample:
+    id: str
+    path: str
+    epoch: int
+    duration: float
+    conditioning: tp.Optional[tp.Dict[str, tp.Any]]
+    prompt: tp.Optional[ReferenceSample]
+    reference: tp.Optional[ReferenceSample]
+    generation_args: tp.Optional[tp.Dict[str, tp.Any]]
+
+    def __hash__(self):
+        return hash(self.id)
+
+    def audio(self) -> tp.Tuple[torch.Tensor, int]:
+        return audio_read(self.path)
+
+    def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
+        return audio_read(self.prompt.path) if self.prompt is not None else None
+
+    def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
+        return audio_read(self.reference.path) if self.reference is not None else None
+
+

Class variables

+
+
var conditioning : Optional[Dict[str, Any]]
+
+
+
+
var duration : float
+
+
+
+
var epoch : int
+
+
+
+
var generation_args : Optional[Dict[str, Any]]
+
+
+
+
var id : str
+
+
+
+
var path : str
+
+
+
+
var prompt : Optional[ReferenceSample]
+
+
+
+
var reference : Optional[ReferenceSample]
+
+
+
+
+

Methods

+
+
+def audio(self) ‑> Tuple[torch.Tensor, int] +
+
+
+
+ +Expand source code + +
def audio(self) -> tp.Tuple[torch.Tensor, int]:
+    return audio_read(self.path)
+
+
+
+def audio_prompt(self) ‑> Optional[Tuple[torch.Tensor, int]] +
+
+
+
+ +Expand source code + +
def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
+    return audio_read(self.prompt.path) if self.prompt is not None else None
+
+
+
+def audio_reference(self) ‑> Optional[Tuple[torch.Tensor, int]] +
+
+
+
+ +Expand source code + +
def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
+    return audio_read(self.reference.path) if self.reference is not None else None
+
+
+
+
+
+class SampleManager +(xp: dora.xp.XP, map_reference_to_sample_id: bool = False) +
+
+

Audio samples IO handling within a given dora xp.

+

The sample manager handles the dumping and loading logic for generated and +references samples across epochs for a given xp, providing a simple API to +store, retrieve and compare audio samples.

+

Args

+
+
xp : dora.XP
+
Dora experiment object. The XP contains information on the XP folder +where all outputs are stored and the configuration of the experiment, +which is useful to retrieve audio-related parameters.
+
map_reference_to_sample_id : bool
+
Whether to use the sample_id for all reference samples +instead of generating a dedicated hash id. This is useful to allow easier comparison +with ground truth sample from the files directly without having to read the JSON metadata +to do the mapping (at the cost of potentially dumping duplicate prompts/references +depending on the task).
+
+
+ +Expand source code + +
class SampleManager:
+    """Audio samples IO handling within a given dora xp.
+
+    The sample manager handles the dumping and loading logic for generated and
+    references samples across epochs for a given xp, providing a simple API to
+    store, retrieve and compare audio samples.
+
+    Args:
+        xp (dora.XP): Dora experiment object. The XP contains information on the XP folder
+            where all outputs are stored and the configuration of the experiment,
+            which is useful to retrieve audio-related parameters.
+        map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples
+            instead of generating a dedicated hash id. This is useful to allow easier comparison
+            with ground truth sample from the files directly without having to read the JSON metadata
+            to do the mapping (at the cost of potentially dumping duplicate prompts/references
+            depending on the task).
+    """
+    def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False):
+        self.xp = xp
+        self.base_folder: Path = xp.folder / xp.cfg.generate.path
+        self.reference_folder = self.base_folder / 'reference'
+        self.map_reference_to_sample_id = map_reference_to_sample_id
+        self.samples: tp.List[Sample] = []
+        self._load_samples()
+
+    @property
+    def latest_epoch(self):
+        """Latest epoch across all samples."""
+        return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0
+
+    def _load_samples(self):
+        """Scan the sample folder and load existing samples."""
+        jsons = self.base_folder.glob('**/*.json')
+        with ThreadPoolExecutor(6) as pool:
+            self.samples = list(pool.map(self._load_sample, jsons))
+
+    @staticmethod
+    @lru_cache(2**26)
+    def _load_sample(json_file: Path) -> Sample:
+        with open(json_file, 'r') as f:
+            data: tp.Dict[str, tp.Any] = json.load(f)
+        # fetch prompt data
+        prompt_data = data.get('prompt')
+        prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'],
+                                 duration=prompt_data['duration']) if prompt_data else None
+        # fetch reference data
+        reference_data = data.get('reference')
+        reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'],
+                                    duration=reference_data['duration']) if reference_data else None
+        # build sample object
+        return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'],
+                      prompt=prompt, conditioning=data.get('conditioning'), reference=reference,
+                      generation_args=data.get('generation_args'))
+
+    def _init_hash(self):
+        return hashlib.sha1()
+
+    def _get_tensor_id(self, tensor: torch.Tensor) -> str:
+        hash_id = self._init_hash()
+        hash_id.update(tensor.numpy().data)
+        return hash_id.hexdigest()
+
+    def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor],
+                       conditions: tp.Optional[tp.Dict[str, str]]) -> str:
+        """Computes an id for a sample given its input data.
+        This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input.
+        Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned.
+
+        Args:
+            index (int): Batch index, Helpful to differentiate samples from the same batch.
+            prompt_wav (torch.Tensor): Prompt used during generation.
+            conditions (dict[str, str]): Conditioning used during generation.
+        """
+        # For totally unconditioned generations we will just use a random UUID.
+        # The function get_samples_for_xps will do a simple ordered match with a custom key.
+        if prompt_wav is None and not conditions:
+            return f"noinput_{uuid.uuid4().hex}"
+
+        # Human readable portion
+        hr_label = ""
+        # Create a deterministic id using hashing
+        hash_id = self._init_hash()
+        hash_id.update(f"{index}".encode())
+        if prompt_wav is not None:
+            hash_id.update(prompt_wav.numpy().data)
+            hr_label += "_prompted"
+        else:
+            hr_label += "_unprompted"
+        if conditions:
+            encoded_json = json.dumps(conditions, sort_keys=True).encode()
+            hash_id.update(encoded_json)
+            cond_str = "-".join([f"{key}={slugify(value)}"
+                                 for key, value in sorted(conditions.items())])
+            cond_str = cond_str[:100]  # some raw text might be too long to be a valid filename
+            cond_str = cond_str if len(cond_str) > 0 else "unconditioned"
+            hr_label += f"_{cond_str}"
+        else:
+            hr_label += "_unconditioned"
+
+        return hash_id.hexdigest() + hr_label
+
+    def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path:
+        """Stores the audio with the given stem path using the XP's configuration.
+
+        Args:
+            wav (torch.Tensor): Audio to store.
+            stem_path (Path): Path in sample output directory with file stem to use.
+            overwrite (bool): When False (default), skips storing an existing audio file.
+        Returns:
+            Path: The path at which the audio is stored.
+        """
+        existing_paths = [
+            path for path in stem_path.parent.glob(stem_path.stem + '.*')
+            if path.suffix != '.json'
+        ]
+        exists = len(existing_paths) > 0
+        if exists and overwrite:
+            logger.warning(f"Overwriting existing audio file with stem path {stem_path}")
+        elif exists:
+            return existing_paths[0]
+
+        audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio)
+        return audio_path
+
+    def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
+                   conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
+                   ground_truth_wav: tp.Optional[torch.Tensor] = None,
+                   generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
+        """Adds a single sample.
+        The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
+        Each sample is assigned an id which is computed using the input data. In addition to the
+        sample itself, a json file containing associated metadata is stored next to it.
+
+        Args:
+            sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape].
+            epoch (int): current training epoch.
+            index (int): helpful to differentiate samples from the same batch.
+            conditions (dict[str, str], optional): conditioning used during generation.
+            prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape].
+            ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from.
+                Tensor of shape [channels, shape].
+            generation_args (dict[str, any], optional): dictionary of other arguments used during generation.
+        Returns:
+            Sample: The saved sample.
+        """
+        sample_id = self._get_sample_id(index, prompt_wav, conditions)
+        reuse_id = self.map_reference_to_sample_id
+        prompt, ground_truth = None, None
+        if prompt_wav is not None:
+            prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
+            prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
+            prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
+            prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
+        if ground_truth_wav is not None:
+            ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
+            ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
+            ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id)
+            ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
+        sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
+        duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
+        sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
+        self.samples.append(sample)
+        with open(sample_path.with_suffix('.json'), 'w') as f:
+            json.dump(asdict(sample), f, indent=2)
+        return sample
+
+    def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
+                    conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
+                    prompt_wavs: tp.Optional[torch.Tensor] = None,
+                    ground_truth_wavs: tp.Optional[torch.Tensor] = None,
+                    generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
+        """Adds a batch of samples.
+        The samples are stored in the XP's sample output directory, under a corresponding
+        epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
+        In addition to the sample itself, a json file containing associated metadata is stored next to it.
+
+        Args:
+            sample_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
+            epoch (int): Current training epoch.
+            conditioning (list of dict[str, str], optional): List of conditions used during generation,
+                one per sample in the batch.
+            prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
+                [batch_size, channels, shape].
+            ground_truth_wav (torch.Tensor, optional): Reference audio where prompts were extracted from.
+                Tensor of shape [batch_size, channels, shape].
+            generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
+        Returns:
+            samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
+        """
+        samples = []
+        for idx, wav in enumerate(samples_wavs):
+            prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
+            gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
+            conditions = conditioning[idx] if conditioning is not None else None
+            samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
+        return samples
+
+    def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
+                    exclude_unprompted: bool = False, exclude_conditioned: bool = False,
+                    exclude_unconditioned: bool = False) -> tp.Set[Sample]:
+        """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.
+        Please note that existing samples are loaded during the manager's initialization, and added samples through this
+        manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager
+        is the only way detect them.
+
+        Args:
+            epoch (int): If provided, only return samples corresponding to this epoch.
+            max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
+            exclude_prompted (bool): If True, does not include samples that used a prompt.
+            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
+            exclude_conditioned (bool): If True, excludes samples that used conditioning.
+            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
+        Returns:
+            Samples (set of Sample): The retrieved samples matching the provided filters.
+        """
+        if max_epoch >= 0:
+            samples_epoch = max(sample.epoch for sample in self.samples if sample.epoch <= max_epoch)
+        else:
+            samples_epoch = self.latest_epoch if epoch < 0 else epoch
+        samples = {
+            sample
+            for sample in self.samples
+            if (
+                (sample.epoch == samples_epoch) and
+                (not exclude_prompted or sample.prompt is None) and
+                (not exclude_unprompted or sample.prompt is not None) and
+                (not exclude_conditioned or not sample.conditioning) and
+                (not exclude_unconditioned or sample.conditioning)
+            )
+        }
+        return samples
+
+

Instance variables

+
+
var latest_epoch
+
+

Latest epoch across all samples.

+
+ +Expand source code + +
@property
+def latest_epoch(self):
+    """Latest epoch across all samples."""
+    return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0
+
+
+
+

Methods

+
+
+def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0, conditions: Optional[Dict[str, str]] = None, prompt_wav: Optional[torch.Tensor] = None, ground_truth_wav: Optional[torch.Tensor] = None, generation_args: Optional[Dict[str, Any]] = None) ‑> Sample +
+
+

Adds a single sample. +The sample is stored in the XP's sample output directory, under a corresponding epoch folder. +Each sample is assigned an id which is computed using the input data. In addition to the +sample itself, a json file containing associated metadata is stored next to it.

+

Args

+
+
sample_wav : torch.Tensor
+
sample audio to store. Tensor of shape [channels, shape].
+
epoch : int
+
current training epoch.
+
index : int
+
helpful to differentiate samples from the same batch.
+
conditions : dict[str, str], optional
+
conditioning used during generation.
+
prompt_wav : torch.Tensor, optional
+
prompt used during generation. Tensor of shape [channels, shape].
+
ground_truth_wav : torch.Tensor, optional
+
reference audio where prompt was extracted from. +Tensor of shape [channels, shape].
+
generation_args : dict[str, any], optional
+
dictionary of other arguments used during generation.
+
+

Returns

+
+
Sample
+
The saved sample.
+
+
+ +Expand source code + +
def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
+               conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
+               ground_truth_wav: tp.Optional[torch.Tensor] = None,
+               generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
+    """Adds a single sample.
+    The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
+    Each sample is assigned an id which is computed using the input data. In addition to the
+    sample itself, a json file containing associated metadata is stored next to it.
+
+    Args:
+        sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape].
+        epoch (int): current training epoch.
+        index (int): helpful to differentiate samples from the same batch.
+        conditions (dict[str, str], optional): conditioning used during generation.
+        prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape].
+        ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from.
+            Tensor of shape [channels, shape].
+        generation_args (dict[str, any], optional): dictionary of other arguments used during generation.
+    Returns:
+        Sample: The saved sample.
+    """
+    sample_id = self._get_sample_id(index, prompt_wav, conditions)
+    reuse_id = self.map_reference_to_sample_id
+    prompt, ground_truth = None, None
+    if prompt_wav is not None:
+        prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
+        prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
+        prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
+        prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
+    if ground_truth_wav is not None:
+        ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
+        ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
+        ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id)
+        ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
+    sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
+    duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
+    sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
+    self.samples.append(sample)
+    with open(sample_path.with_suffix('.json'), 'w') as f:
+        json.dump(asdict(sample), f, indent=2)
+    return sample
+
+
+
+def add_samples(self, samples_wavs: torch.Tensor, epoch: int, conditioning: Optional[List[Dict[str, Any]]] = None, prompt_wavs: Optional[torch.Tensor] = None, ground_truth_wavs: Optional[torch.Tensor] = None, generation_args: Optional[Dict[str, Any]] = None) ‑> List[Sample] +
+
+

Adds a batch of samples. +The samples are stored in the XP's sample output directory, under a corresponding +epoch folder. Each sample is assigned an id which is computed using the input data and their batch index. +In addition to the sample itself, a json file containing associated metadata is stored next to it.

+

Args

+
+
sample_wavs : torch.Tensor
+
Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
+
epoch : int
+
Current training epoch.
+
conditioning : list of dict[str, str], optional
+
List of conditions used during generation, +one per sample in the batch.
+
prompt_wavs : torch.Tensor, optional
+
Prompts used during generation. Tensor of shape +[batch_size, channels, shape].
+
ground_truth_wav : torch.Tensor, optional
+
Reference audio where prompts were extracted from. +Tensor of shape [batch_size, channels, shape].
+
generation_args : dict[str, Any], optional
+
Dictionary of other arguments used during generation.
+
+

Returns

+

samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.

+
+ +Expand source code + +
def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
+                conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
+                prompt_wavs: tp.Optional[torch.Tensor] = None,
+                ground_truth_wavs: tp.Optional[torch.Tensor] = None,
+                generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
+    """Adds a batch of samples.
+    The samples are stored in the XP's sample output directory, under a corresponding
+    epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
+    In addition to the sample itself, a json file containing associated metadata is stored next to it.
+
+    Args:
+        sample_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
+        epoch (int): Current training epoch.
+        conditioning (list of dict[str, str], optional): List of conditions used during generation,
+            one per sample in the batch.
+        prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
+            [batch_size, channels, shape].
+        ground_truth_wav (torch.Tensor, optional): Reference audio where prompts were extracted from.
+            Tensor of shape [batch_size, channels, shape].
+        generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
+    Returns:
+        samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
+    """
+    samples = []
+    for idx, wav in enumerate(samples_wavs):
+        prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
+        gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
+        conditions = conditioning[idx] if conditioning is not None else None
+        samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
+    return samples
+
+
+
+def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False, exclude_unprompted: bool = False, exclude_conditioned: bool = False, exclude_unconditioned: bool = False) ‑> Set[Sample] +
+
+

Returns a set of samples for this XP. Optionally, you can filter which samples to obtain. +Please note that existing samples are loaded during the manager's initialization, and added samples through this +manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager +is the only way detect them.

+

Args

+
+
epoch : int
+
If provided, only return samples corresponding to this epoch.
+
max_epoch : int
+
If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
+
exclude_prompted : bool
+
If True, does not include samples that used a prompt.
+
exclude_unprompted : bool
+
If True, does not include samples that did not use a prompt.
+
exclude_conditioned : bool
+
If True, excludes samples that used conditioning.
+
exclude_unconditioned : bool
+
If True, excludes samples that did not use conditioning.
+
+

Returns

+

Samples (set of Sample): The retrieved samples matching the provided filters.

+
+ +Expand source code + +
def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
+                exclude_unprompted: bool = False, exclude_conditioned: bool = False,
+                exclude_unconditioned: bool = False) -> tp.Set[Sample]:
+    """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.
+    Please note that existing samples are loaded during the manager's initialization, and added samples through this
+    manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager
+    is the only way detect them.
+
+    Args:
+        epoch (int): If provided, only return samples corresponding to this epoch.
+        max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
+        exclude_prompted (bool): If True, does not include samples that used a prompt.
+        exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
+        exclude_conditioned (bool): If True, excludes samples that used conditioning.
+        exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
+    Returns:
+        Samples (set of Sample): The retrieved samples matching the provided filters.
+    """
+    if max_epoch >= 0:
+        samples_epoch = max(sample.epoch for sample in self.samples if sample.epoch <= max_epoch)
+    else:
+        samples_epoch = self.latest_epoch if epoch < 0 else epoch
+    samples = {
+        sample
+        for sample in self.samples
+        if (
+            (sample.epoch == samples_epoch) and
+            (not exclude_prompted or sample.prompt is None) and
+            (not exclude_unprompted or sample.prompt is not None) and
+            (not exclude_conditioned or not sample.conditioning) and
+            (not exclude_unconditioned or sample.conditioning)
+        )
+    }
+    return samples
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/api_docs/audiocraft/utils/utils.html b/api_docs/audiocraft/utils/utils.html new file mode 100644 index 00000000..0e89a93b --- /dev/null +++ b/api_docs/audiocraft/utils/utils.html @@ -0,0 +1,983 @@ + + + + + + +audiocraft.utils.utils API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.utils

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from concurrent.futures import ProcessPoolExecutor
+from contextlib import contextmanager
+from functools import wraps, lru_cache
+import hashlib
+import json
+import logging
+from pathlib import Path
+import typing as tp
+
+import flashy
+import flashy.distrib
+import omegaconf
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+logger = logging.getLogger(__name__)
+
+
+def model_hash(model: torch.nn.Module) -> str:
+    """Return a model hash. This should allow us to track regressions in model init
+    from the logs of past experiments.
+    """
+    hasher = hashlib.sha1()
+    for p in model.parameters():
+        hasher.update(p.data.cpu().numpy().tobytes())
+    return hasher.hexdigest()
+
+
+def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
+    """Convenience function to map an omegaconf configuration to a dictionary.
+
+    Args:
+        cfg (omegaconf.DictConfig): Original configuration to map to dict.
+    Returns:
+        dict: Config as dictionary object.
+    """
+    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
+    assert isinstance(dct, dict)
+    return dct
+
+
+def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
+    if max_samples >= len(dataset):
+        return dataset
+
+    generator = torch.Generator().manual_seed(seed)
+    perm = torch.randperm(len(dataset), generator=generator)
+    return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())
+
+
+def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
+               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
+    """Convenience function to load dataset into a dataloader with optional subset sampling.
+
+    Args:
+        dataset: Dataset to load.
+        num_samples (Optional[int]): Number of samples to limit subset size.
+        batch_size (int): Batch size.
+        num_workers (int): Number of workers for data loading.
+        seed (int): Random seed.
+    """
+    if num_samples is not None:
+        dataset = random_subset(dataset, num_samples, seed)
+
+    dataloader = flashy.distrib.loader(
+        dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        **kwargs
+    )
+    return dataloader
+
+
+def get_dataset_from_loader(dataloader):
+    dataset = dataloader.dataset
+    if isinstance(dataset, torch.utils.data.Subset):
+        return dataset.dataset
+    else:
+        return dataset
+
+
+def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
+    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
+
+    Args:
+        input (torch.Tensor): The input tensor containing probabilities.
+        num_samples (int): Number of samples to draw.
+        replacement (bool): Whether to draw with replacement or not.
+    Keywords args:
+        generator (torch.Generator): A pseudorandom number generator for sampling.
+    Returns:
+        torch.Tensor: Last dimension contains num_samples indices
+            sampled from the multinomial probability distribution
+            located in the last dimension of tensor input.
+    """
+    input_ = input.reshape(-1, input.shape[-1])
+    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
+    output = output_.reshape(*list(input.shape[:-1]), -1)
+    return output
+
+
+def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
+    """Sample next token from top K values along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        k (int): The k in “top-k”.
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    top_k_value, _ = torch.topk(probs, k, dim=-1)
+    min_value_top_k = top_k_value[..., [-1]]
+    probs *= (probs >= min_value_top_k).float()
+    probs.div_(probs.sum(dim=-1, keepdim=True))
+    next_token = multinomial(probs, num_samples=1)
+    return next_token
+
+
+def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
+    """Sample next token from top P probabilities along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        p (int): The p in “top-p”.
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort *= (~mask).float()
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
+
+
+class DummyPoolExecutor:
+    """Dummy pool executor to use when we actually have only 1 worker.
+    (e.g. instead of ProcessPoolExecutor).
+    """
+    class DummyResult:
+        def __init__(self, func, *args, **kwargs):
+            self.func = func
+            self.args = args
+            self.kwargs = kwargs
+
+        def result(self):
+            return self.func(*self.args, **self.kwargs)
+
+    def __init__(self, workers, mp_context=None):
+        pass
+
+    def submit(self, func, *args, **kwargs):
+        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        return
+
+
+def get_pool_executor(num_workers: int, mp_context=None):
+    return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
+
+
+def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
+    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
+    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
+
+    Args:
+        lengths (torch.Tensor): tensor with lengths
+        max_len (int): can set the max length manually. Defaults to None.
+    Returns:
+        torch.Tensor: mask with 0s where there is pad tokens else 1s
+    """
+    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
+    final_length = lengths.max().item() if not max_len else max_len
+    final_length = max(final_length, 1)  # if all seqs are of len zero we don't want a zero-size tensor
+    return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
+
+
+def hash_trick(word: str, vocab_size: int) -> int:
+    """Hash trick to pair each word with an index
+
+    Args:
+        word (str): word we wish to convert to an index
+        vocab_size (int): size of the vocabulary
+    Returns:
+        int: index of the word in the embedding LUT
+    """
+    hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
+    return hash % vocab_size
+
+
+def with_rank_rng(base_seed: int = 1234):
+    """Decorator for a function so that the function will use a Random Number Generator
+    whose state depend on the GPU rank. The original RNG state is restored upon returning.
+
+    Args:
+        base_seed (int): Random seed.
+    """
+    def _decorator(fun: tp.Callable):
+        @wraps(fun)
+        def _decorated(*args, **kwargs):
+            state = torch.get_rng_state()
+            seed = base_seed ^ flashy.distrib.rank()
+            torch.manual_seed(seed)
+            logger.debug('Rank dependent seed set to %d', seed)
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                torch.set_rng_state(state)
+                logger.debug('RNG state restored.')
+        return _decorated
+    return _decorator
+
+
+def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """Get a list of tensors and collate them to a single tensor. according to the following logic:
+    - `dim` specifies the time dimension which will be stacked and padded.
+    - The output will contain 1 new dimension (dimension index 0) which will be the size of
+    of the original list.
+
+    Args:
+        tensors (tp.List[torch.Tensor]): List of tensors to collate.
+        dim (int): Dimension which will be stacked and padded.
+    Returns:
+        tp.Tuple[torch.Tensor, torch.Tensor]:
+            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
+                (dimension index 0) which will be the size of the original list.
+            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
+    """
+    tensors = [x.transpose(0, dim) for x in tensors]
+    lens = torch.LongTensor([len(x) for x in tensors])
+    padded_tensors = pad_sequence(tensors)
+    padded_tensors = padded_tensors.transpose(0, 1)
+    padded_tensors = padded_tensors.transpose(1, dim + 1)
+    return padded_tensors, lens
+
+
+# TODO: Move to flashy?
+def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
+               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
+    if isinstance(state, torch.Tensor):
+        if dtype is None or not state.is_floating_point():
+            dtype = state.dtype
+        return state.detach().to(device=device, dtype=dtype, copy=True)
+    elif isinstance(state, dict):
+        return {k: copy_state(v, device, dtype) for k, v in state.items()}
+    elif isinstance(state, list):
+        return [copy_state(v, device, dtype) for v in state]
+
+
+# TODO: Move to flashy?
+@contextmanager
+def swap_state(model, state, **kwargs):
+    old_state = copy_state(model.state_dict())
+    model.load_state_dict(state, **kwargs)
+    try:
+        yield
+    finally:
+        model.load_state_dict(old_state)
+
+
+@lru_cache(None)
+def warn_once(logger, msg):
+    """Warn about a given message only once."""
+    logger.warning(msg)
+
+
+def is_jsonable(x: tp.Any):
+    """Check if an object can be serialized into a json:"""
+    try:
+        json.dumps(x)
+        return True
+    except (TypeError, OverflowError):
+        return False
+
+
+def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
+    """Wrapper around state dict loading of CLAP model
+    addressing compatibility issues between CLAP and AudioCraft
+    HuggingFace transformer version.
+    See: https://github.com/LAION-AI/CLAP/issues/118
+    """
+    from clap_module.factory import load_state_dict  # type: ignore
+    pkg = load_state_dict(path)
+    pkg.pop('text_branch.embeddings.position_ids', None)
+    clap_model.model.load_state_dict(pkg)
+
+
+
+
+
+
+
+

Functions

+
+
+def collate(tensors: List[torch.Tensor], dim: int = 0) ‑> Tuple[torch.Tensor, torch.Tensor] +
+
+

Get a list of tensors and collate them to a single tensor. according to the following logic: +- dim specifies the time dimension which will be stacked and padded. +- The output will contain 1 new dimension (dimension index 0) which will be the size of +of the original list.

+

Args

+
+
tensors : tp.List[torch.Tensor]
+
List of tensors to collate.
+
dim : int
+
Dimension which will be stacked and padded.
+
+

Returns

+
+
tp.Tuple[torch.Tensor, torch.Tensor]:
+
+torch.Tensor
+
Stacked and padded tensor. The output will contain 1 new dimension +(dimension index 0) which will be the size of the original list. +torch.Tensor: Tensor containing length of original tensor sizes (without padding).
+
+
+ +Expand source code + +
def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """Get a list of tensors and collate them to a single tensor. according to the following logic:
+    - `dim` specifies the time dimension which will be stacked and padded.
+    - The output will contain 1 new dimension (dimension index 0) which will be the size of
+    of the original list.
+
+    Args:
+        tensors (tp.List[torch.Tensor]): List of tensors to collate.
+        dim (int): Dimension which will be stacked and padded.
+    Returns:
+        tp.Tuple[torch.Tensor, torch.Tensor]:
+            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
+                (dimension index 0) which will be the size of the original list.
+            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
+    """
+    tensors = [x.transpose(0, dim) for x in tensors]
+    lens = torch.LongTensor([len(x) for x in tensors])
+    padded_tensors = pad_sequence(tensors)
+    padded_tensors = padded_tensors.transpose(0, 1)
+    padded_tensors = padded_tensors.transpose(1, dim + 1)
+    return padded_tensors, lens
+
+
+
+def copy_state(state: Any, device: Union[torch.device, str] = 'cpu', dtype: Optional[torch.dtype] = None) ‑> Any +
+
+
+
+ +Expand source code + +
def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
+               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
+    if isinstance(state, torch.Tensor):
+        if dtype is None or not state.is_floating_point():
+            dtype = state.dtype
+        return state.detach().to(device=device, dtype=dtype, copy=True)
+    elif isinstance(state, dict):
+        return {k: copy_state(v, device, dtype) for k, v in state.items()}
+    elif isinstance(state, list):
+        return [copy_state(v, device, dtype) for v in state]
+
+
+
+def dict_from_config(cfg: omegaconf.dictconfig.DictConfig) ‑> dict +
+
+

Convenience function to map an omegaconf configuration to a dictionary.

+

Args

+
+
cfg : omegaconf.DictConfig
+
Original configuration to map to dict.
+
+

Returns

+
+
dict
+
Config as dictionary object.
+
+
+ +Expand source code + +
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
+    """Convenience function to map an omegaconf configuration to a dictionary.
+
+    Args:
+        cfg (omegaconf.DictConfig): Original configuration to map to dict.
+    Returns:
+        dict: Config as dictionary object.
+    """
+    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
+    assert isinstance(dct, dict)
+    return dct
+
+
+
+def get_dataset_from_loader(dataloader) +
+
+
+
+ +Expand source code + +
def get_dataset_from_loader(dataloader):
+    dataset = dataloader.dataset
+    if isinstance(dataset, torch.utils.data.Subset):
+        return dataset.dataset
+    else:
+        return dataset
+
+
+
+def get_loader(dataset, num_samples: Optional[int], batch_size: int, num_workers: int, seed: int, **kwargs) ‑> torch.utils.data.dataloader.DataLoader +
+
+

Convenience function to load dataset into a dataloader with optional subset sampling.

+

Args

+
+
dataset
+
Dataset to load.
+
num_samples : Optional[int]
+
Number of samples to limit subset size.
+
batch_size : int
+
Batch size.
+
num_workers : int
+
Number of workers for data loading.
+
seed : int
+
Random seed.
+
+
+ +Expand source code + +
def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
+               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
+    """Convenience function to load dataset into a dataloader with optional subset sampling.
+
+    Args:
+        dataset: Dataset to load.
+        num_samples (Optional[int]): Number of samples to limit subset size.
+        batch_size (int): Batch size.
+        num_workers (int): Number of workers for data loading.
+        seed (int): Random seed.
+    """
+    if num_samples is not None:
+        dataset = random_subset(dataset, num_samples, seed)
+
+    dataloader = flashy.distrib.loader(
+        dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        **kwargs
+    )
+    return dataloader
+
+
+
+def get_pool_executor(num_workers: int, mp_context=None) +
+
+
+
+ +Expand source code + +
def get_pool_executor(num_workers: int, mp_context=None):
+    return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
+
+
+
+def hash_trick(word: str, vocab_size: int) ‑> int +
+
+

Hash trick to pair each word with an index

+

Args

+
+
word : str
+
word we wish to convert to an index
+
vocab_size : int
+
size of the vocabulary
+
+

Returns

+
+
int
+
index of the word in the embedding LUT
+
+
+ +Expand source code + +
def hash_trick(word: str, vocab_size: int) -> int:
+    """Hash trick to pair each word with an index
+
+    Args:
+        word (str): word we wish to convert to an index
+        vocab_size (int): size of the vocabulary
+    Returns:
+        int: index of the word in the embedding LUT
+    """
+    hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
+    return hash % vocab_size
+
+
+
+def is_jsonable(x: Any) +
+
+

Check if an object can be serialized into a json:

+
+ +Expand source code + +
def is_jsonable(x: tp.Any):
+    """Check if an object can be serialized into a json:"""
+    try:
+        json.dumps(x)
+        return True
+    except (TypeError, OverflowError):
+        return False
+
+
+
+def length_to_mask(lengths: torch.Tensor, max_len: Optional[int] = None) ‑> torch.Tensor +
+
+

Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences). +For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]

+

Args

+
+
lengths : torch.Tensor
+
tensor with lengths
+
max_len : int
+
can set the max length manually. Defaults to None.
+
+

Returns

+
+
torch.Tensor
+
mask with 0s where there is pad tokens else 1s
+
+
+ +Expand source code + +
def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
+    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
+    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
+
+    Args:
+        lengths (torch.Tensor): tensor with lengths
+        max_len (int): can set the max length manually. Defaults to None.
+    Returns:
+        torch.Tensor: mask with 0s where there is pad tokens else 1s
+    """
+    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
+    final_length = lengths.max().item() if not max_len else max_len
+    final_length = max(final_length, 1)  # if all seqs are of len zero we don't want a zero-size tensor
+    return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
+
+
+
+def load_clap_state_dict(clap_model, path: Union[str, pathlib.Path]) +
+
+

Wrapper around state dict loading of CLAP model +addressing compatibility issues between CLAP and AudioCraft +HuggingFace transformer version. +See: https://github.com/LAION-AI/CLAP/issues/118

+
+ +Expand source code + +
def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
+    """Wrapper around state dict loading of CLAP model
+    addressing compatibility issues between CLAP and AudioCraft
+    HuggingFace transformer version.
+    See: https://github.com/LAION-AI/CLAP/issues/118
+    """
+    from clap_module.factory import load_state_dict  # type: ignore
+    pkg = load_state_dict(path)
+    pkg.pop('text_branch.embeddings.position_ids', None)
+    clap_model.model.load_state_dict(pkg)
+
+
+
+def model_hash(model: torch.nn.modules.module.Module) ‑> str +
+
+

Return a model hash. This should allow us to track regressions in model init +from the logs of past experiments.

+
+ +Expand source code + +
def model_hash(model: torch.nn.Module) -> str:
+    """Return a model hash. This should allow us to track regressions in model init
+    from the logs of past experiments.
+    """
+    hasher = hashlib.sha1()
+    for p in model.parameters():
+        hasher.update(p.data.cpu().numpy().tobytes())
+    return hasher.hexdigest()
+
+
+
+def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None) +
+
+

torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

+

Args

+
+
input : torch.Tensor
+
The input tensor containing probabilities.
+
num_samples : int
+
Number of samples to draw.
+
replacement : bool
+
Whether to draw with replacement or not.
+
+

Keywords args: +generator (torch.Generator): A pseudorandom number generator for sampling.

+

Returns

+
+
torch.Tensor
+
Last dimension contains num_samples indices +sampled from the multinomial probability distribution +located in the last dimension of tensor input.
+
+
+ +Expand source code + +
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
+    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
+
+    Args:
+        input (torch.Tensor): The input tensor containing probabilities.
+        num_samples (int): Number of samples to draw.
+        replacement (bool): Whether to draw with replacement or not.
+    Keywords args:
+        generator (torch.Generator): A pseudorandom number generator for sampling.
+    Returns:
+        torch.Tensor: Last dimension contains num_samples indices
+            sampled from the multinomial probability distribution
+            located in the last dimension of tensor input.
+    """
+    input_ = input.reshape(-1, input.shape[-1])
+    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
+    output = output_.reshape(*list(input.shape[:-1]), -1)
+    return output
+
+
+
+def random_subset(dataset, max_samples: int, seed: int = 42) ‑> torch.utils.data.dataset.Subset +
+
+
+
+ +Expand source code + +
def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
+    if max_samples >= len(dataset):
+        return dataset
+
+    generator = torch.Generator().manual_seed(seed)
+    perm = torch.randperm(len(dataset), generator=generator)
+    return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())
+
+
+
+def sample_top_k(probs: torch.Tensor, k: int) ‑> torch.Tensor +
+
+

Sample next token from top K values along the last dimension of the input probs tensor.

+

Args

+
+
probs : torch.Tensor
+
Input probabilities with token candidates on the last dimension.
+
k : int
+
The k in “top-k”.
+
+

Returns

+
+
torch.Tensor
+
Sampled tokens.
+
+
+ +Expand source code + +
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
+    """Sample next token from top K values along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        k (int): The k in “top-k”.
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    top_k_value, _ = torch.topk(probs, k, dim=-1)
+    min_value_top_k = top_k_value[..., [-1]]
+    probs *= (probs >= min_value_top_k).float()
+    probs.div_(probs.sum(dim=-1, keepdim=True))
+    next_token = multinomial(probs, num_samples=1)
+    return next_token
+
+
+
+def sample_top_p(probs: torch.Tensor, p: float) ‑> torch.Tensor +
+
+

Sample next token from top P probabilities along the last dimension of the input probs tensor.

+

Args

+
+
probs : torch.Tensor
+
Input probabilities with token candidates on the last dimension.
+
p : int
+
The p in “top-p”.
+
+

Returns

+
+
torch.Tensor
+
Sampled tokens.
+
+
+ +Expand source code + +
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
+    """Sample next token from top P probabilities along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        p (int): The p in “top-p”.
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort *= (~mask).float()
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
+
+
+
+def swap_state(model, state, **kwargs) +
+
+
+
+ +Expand source code + +
@contextmanager
+def swap_state(model, state, **kwargs):
+    old_state = copy_state(model.state_dict())
+    model.load_state_dict(state, **kwargs)
+    try:
+        yield
+    finally:
+        model.load_state_dict(old_state)
+
+
+
+def warn_once(logger, msg) +
+
+

Warn about a given message only once.

+
+ +Expand source code + +
@lru_cache(None)
+def warn_once(logger, msg):
+    """Warn about a given message only once."""
+    logger.warning(msg)
+
+
+
+def with_rank_rng(base_seed: int = 1234) +
+
+

Decorator for a function so that the function will use a Random Number Generator +whose state depend on the GPU rank. The original RNG state is restored upon returning.

+

Args

+
+
base_seed : int
+
Random seed.
+
+
+ +Expand source code + +
def with_rank_rng(base_seed: int = 1234):
+    """Decorator for a function so that the function will use a Random Number Generator
+    whose state depend on the GPU rank. The original RNG state is restored upon returning.
+
+    Args:
+        base_seed (int): Random seed.
+    """
+    def _decorator(fun: tp.Callable):
+        @wraps(fun)
+        def _decorated(*args, **kwargs):
+            state = torch.get_rng_state()
+            seed = base_seed ^ flashy.distrib.rank()
+            torch.manual_seed(seed)
+            logger.debug('Rank dependent seed set to %d', seed)
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                torch.set_rng_state(state)
+                logger.debug('RNG state restored.')
+        return _decorated
+    return _decorator
+
+
+
+
+
+

Classes

+
+
+class DummyPoolExecutor +(workers, mp_context=None) +
+
+

Dummy pool executor to use when we actually have only 1 worker. +(e.g. instead of ProcessPoolExecutor).

+
+ +Expand source code + +
class DummyPoolExecutor:
+    """Dummy pool executor to use when we actually have only 1 worker.
+    (e.g. instead of ProcessPoolExecutor).
+    """
+    class DummyResult:
+        def __init__(self, func, *args, **kwargs):
+            self.func = func
+            self.args = args
+            self.kwargs = kwargs
+
+        def result(self):
+            return self.func(*self.args, **self.kwargs)
+
+    def __init__(self, workers, mp_context=None):
+        pass
+
+    def submit(self, func, *args, **kwargs):
+        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        return
+
+

Class variables

+
+
var DummyResult
+
+
+
+
+

Methods

+
+
+def submit(self, func, *args, **kwargs) +
+
+
+
+ +Expand source code + +
def submit(self, func, *args, **kwargs):
+    return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file