
Feed forward layer, frontend and encoder #53

Merged (12 commits) on May 24, 2024
1 change: 1 addition & 0 deletions i6_models/assemblies/ffnn/__init__.py
@@ -0,0 +1 @@
from .ffnn_v1 import *
42 changes: 42 additions & 0 deletions i6_models/assemblies/ffnn/ffnn_v1.py
@@ -0,0 +1,42 @@
__all__ = ["FeedForwardEncoderV1Config", "FeedForwardEncoderV1"]

from typing import Tuple
from dataclasses import dataclass
import torch
from torch import nn

from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config
from i6_models.config import ModelConfiguration, ModuleFactoryV1


@dataclass
class FeedForwardEncoderV1Config(ModelConfiguration):
"""
Attributes:
num_layers: number of feed-forward layers
frontend: module factory for the frontend
layer_cfg: configuration object for each feed-forward layer
"""

num_layers: int
frontend: ModuleFactoryV1
layer_cfg: FeedForwardLayerV1Config


class FeedForwardEncoderV1(nn.Module):
"""
Simple feed-forward encoder.
Subsampling can be achieved by setting stride > 1 in the frontend config.
"""

def __init__(self, cfg: FeedForwardEncoderV1Config):
super().__init__()
self.frontend = cfg.frontend()
self.module_list = nn.ModuleList([FeedForwardLayerV1(cfg.layer_cfg) for _ in range(cfg.num_layers)])

def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x, sequence_mask = self.frontend(data_tensor, sequence_mask) # [B, T, F']
for module in self.module_list:
x, sequence_mask = module(x, sequence_mask) # [B, T, F']

return x, sequence_mask
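
For reference, a minimal usage sketch of the new encoder (dimensions and values are illustrative, not part of this PR; ModuleFactoryV1 constructs the frontend lazily inside the encoder):

import torch
from torch.nn import functional as F

from i6_models.assemblies.ffnn import FeedForwardEncoderV1, FeedForwardEncoderV1Config
from i6_models.config import ModuleFactoryV1
from i6_models.parts.ffnn import FeedForwardLayerV1Config
from i6_models.parts.frontend.window_convolution import (
    WindowConvolutionFrontendV1,
    WindowConvolutionFrontendV1Config,
)

# frontend projects 80 -> 512 features and subsamples time by a factor of 3
frontend = ModuleFactoryV1(
    WindowConvolutionFrontendV1,
    WindowConvolutionFrontendV1Config(
        input_dim=80, output_dim=512, kernel_size=5, stride=3, dropout=0.1, activation=F.relu
    ),
)
layer_cfg = FeedForwardLayerV1Config(input_dim=512, output_dim=512, dropout=0.1, activation=F.relu)
encoder = FeedForwardEncoderV1(FeedForwardEncoderV1Config(num_layers=6, frontend=frontend, layer_cfg=layer_cfg))

audio_features = torch.rand(4, 100, 80)  # [B, T, F]
sequence_mask = torch.ones(4, 100, dtype=torch.bool)  # [B, T]
out, out_mask = encoder(audio_features, sequence_mask)  # out: [B, 34, 512], since T' = (100 - 1) // 3 + 1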
59 changes: 59 additions & 0 deletions i6_models/parts/ffnn.py
@@ -0,0 +1,59 @@
__all__ = ["FeedForwardConfig", "FeedForwardModel"]

from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F

from i6_models.config import ModelConfiguration


@dataclass
class FeedForwardLayerV1Config(ModelConfiguration):
"""
Attributes:
input_dim: input feature dimension
output_dim: output feature dimension
dropout: dropout probability
activation: activation function applied after linear computation
"""

input_dim: int
output_dim: int
dropout: float
activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]

def __post_init__(self):
super().__post_init__()
assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"


class FeedForwardLayerV1(nn.Module):
"""
Simple feed-forward layer module consisting of:
- linear
- activation
- dropout
"""

def __init__(self, cfg: FeedForwardLayerV1Config):
super().__init__()
self.linear_ff = nn.Linear(in_features=cfg.input_dim, out_features=cfg.output_dim, bias=True)
self.activation = cfg.activation
self.dropout = nn.Dropout(cfg.dropout)

def forward(
self, tensor: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param tensor: shape [B,T,F], F=input_dim
:param sequence_mask: shape [B,T]
:return: shape [B,T,F'], F'=output_dim
"""
tensor = self.linear_ff(tensor) # [B,T,F]
tensor = self.activation(tensor) # [B,T,F]
tensor = self.dropout(tensor) # [B,T,F]
return tensor, sequence_mask
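
A quick sketch of the layer in isolation (sizes are illustrative; the mask is passed through unchanged):

import torch
from torch.nn import functional as F

from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config

layer = FeedForwardLayerV1(FeedForwardLayerV1Config(input_dim=80, output_dim=2048, dropout=0.1, activation=F.relu))
x = torch.rand(8, 50, 80)  # [B, T, F]
y, mask = layer(x, sequence_mask=None)  # linear -> relu -> dropout; mask is returned as-is
assert y.shape == (8, 50, 2048)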
90 changes: 90 additions & 0 deletions i6_models/parts/frontend/window_convolution.py
@@ -0,0 +1,90 @@
__all__ = [
    "WindowConvolutionFrontendV1Config",
    "WindowConvolutionFrontendV1",
]

from dataclasses import dataclass
from typing import Callable, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration

from .common import mask_pool, apply_same_padding


@dataclass
class WindowConvolutionFrontendV1Config(ModelConfiguration):
    """
    Attributes:
        input_dim: number of input features to the module
        output_dim: output feature dimension
        dropout: dropout probability applied after the convolution
        kernel_size: number of feature frames to convolve over
        stride: skip (stride - 1) feature frames between window positions; stride > 1 implies subsampling
        activation: activation function applied after the convolution
    """

    input_dim: int
    output_dim: int
    dropout: float
    kernel_size: int
    stride: int
    activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]

    def __post_init__(self):
        super().__post_init__()
        assert self.stride >= 1, "Choose an integer >= 1 for stride"
        assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"


class WindowConvolutionFrontendV1(nn.Module):
    """
    Simple feed-forward front-end that computes over a window
    of input features. Choosing a stride > 1 allows for subsampling
    of the features.
    """

    def __init__(self, cfg: WindowConvolutionFrontendV1Config):
        """
        :param cfg: model configuration for this module
        """
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels=cfg.input_dim,
            out_channels=cfg.output_dim,
            kernel_size=cfg.kernel_size,
            stride=cfg.stride,
            padding=0,
            bias=True,
        )
        self.activation = cfg.activation
        self.pad = lambda x: apply_same_padding(x, cfg.kernel_size)
        self.dropout = torch.nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        T may be reduced to T' when stride > 1

        :param x: input tensor of shape [B,T,F]
        :param sequence_mask: the sequence mask of shape [B,T]
        :return: tuple of the output tensor of shape [B,T',F'] and the subsampled sequence mask of shape [B,T']
        """
        # torch's 1d convolution runs over the last dim, but we want to convolve over time
        x = x.transpose(1, 2)  # [B, F, T]
        x = self.pad(x)
        x = self.conv(x).transpose(1, 2)  # [B, T', F']

        # adjust the mask according to the stride; kernel_size=1 because the
        # same padding above keeps T unchanged, so only the stride shortens the sequence
        sequence_mask = mask_pool(
            sequence_mask,
            kernel_size=1,
            stride=self.conv.stride[0],
            padding=0,  # padding is done manually above
        )
        x = self.activation(x)
        x = self.dropout(x)

        return x, sequence_mask
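
Because apply_same_padding keeps T fixed before the strided convolution, the output length depends only on the stride: T' = (T - 1) // stride + 1. For example, T=7 with stride=3 keeps the frames at positions 0, 3, and 6, so T'=3. A minimal sketch under these assumptions (illustrative sizes, not part of this PR):

import torch
from torch.nn import functional as F

from i6_models.parts.frontend.window_convolution import (
    WindowConvolutionFrontendV1,
    WindowConvolutionFrontendV1Config,
)

frontend = WindowConvolutionFrontendV1(
    WindowConvolutionFrontendV1Config(
        input_dim=80, output_dim=512, kernel_size=5, stride=3, dropout=0.0, activation=F.relu
    )
)
x = torch.rand(1, 100, 80)  # [B, T, F]
mask = torch.ones(1, 100, dtype=torch.bool)  # [B, T]
out, out_mask = frontend(x, mask)
assert out.shape == (1, (100 - 1) // 3 + 1, 512)  # T' = 34, independent of kernel_size
assert out_mask.shape == (1, 34)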
62 changes: 62 additions & 0 deletions tests/test_ffnn.py
@@ -0,0 +1,62 @@
from itertools import product

import torch
from torch.nn import functional as F

from i6_models.assemblies.ffnn import (
    FeedForwardEncoderV1,
    FeedForwardEncoderV1Config,
)

from i6_models.parts.frontend.window_convolution import WindowConvolutionFrontendV1Config, WindowConvolutionFrontendV1

from i6_models.config import ModuleFactoryV1
from i6_models.parts.ffnn import FeedForwardLayerV1Config


def test_output_shape():
    input_dim = 80
    output_dim = 2048
    dropout = 0.1
    max_seq_lens = 100

    for window_size, stride in product(range(1, 22), range(1, 5)):
        frontend = ModuleFactoryV1(
            WindowConvolutionFrontendV1,
            WindowConvolutionFrontendV1Config(
                input_dim=input_dim,
                output_dim=output_dim,
                kernel_size=window_size,
                dropout=dropout,
                stride=stride,
                activation=F.relu,
            ),
        )

        layer_cfg = FeedForwardLayerV1Config(
            input_dim=output_dim,
            output_dim=output_dim,
            dropout=dropout,
            activation=F.relu,
        )

        encoder_cfg = FeedForwardEncoderV1Config(num_layers=6, layer_cfg=layer_cfg, frontend=frontend)

        encoder = FeedForwardEncoderV1(encoder_cfg)

        # batch of max_seq_lens sequences with lengths 1, 2, ..., max_seq_lens
        feat_len = torch.arange(start=1, end=max_seq_lens + 1)
        mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])

        features = torch.rand((max_seq_lens, max_seq_lens, input_dim))

        out, out_mask = encoder(features, mask)

        # with same padding, the output length depends only on the stride
        expected_out_len = (feat_len - 1) // stride + 1
        expected_shape = (max_seq_lens, expected_out_len[-1], output_dim)
        assert out.shape == expected_shape, f"Output with shape {out.shape} not as expected {expected_shape}"
        for i in range(expected_out_len[-1] - 1):
            # the last valid frame must be unmasked and the first padded frame masked
            assert (
                out_mask[i, expected_out_len[i] - 1] and not out_mask[i, expected_out_len[i]]
            ), f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"
42 changes: 42 additions & 0 deletions tests/test_window_frontend.py
@@ -0,0 +1,42 @@
from itertools import product

import torch
from torch.nn import functional as F

from i6_models.parts.frontend.window_convolution import WindowConvolutionFrontendV1Config, WindowConvolutionFrontendV1


def test_output_shape():
    in_features = 80
    out_features = 2048
    dropout = 0.1
    max_seq_lens = 100

    for window_size, stride in product(range(1, 22), range(1, 5)):
        frontend = WindowConvolutionFrontendV1(
            WindowConvolutionFrontendV1Config(
                input_dim=in_features,
                output_dim=out_features,
                kernel_size=window_size,
                dropout=dropout,
                stride=stride,
                activation=F.relu,
            )
        )

        # batch of max_seq_lens sequences with lengths 1, 2, ..., max_seq_lens
        feat_len = torch.arange(start=1, end=max_seq_lens + 1)
        mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])

        features = torch.rand((max_seq_lens, max_seq_lens, in_features))

        out, out_mask = frontend(features, mask)

        # with same padding, the output length depends only on the stride
        expected_out_len = (feat_len - 1) // stride + 1
        expected_shape = (max_seq_lens, expected_out_len[-1], out_features)
        assert out.shape == expected_shape, f"Output with shape {out.shape} not as expected {expected_shape}"
        for i in range(expected_out_len[-1] - 1):
            # the last valid frame must be unmasked and the first padded frame masked
            assert (
                out_mask[i, expected_out_len[i] - 1] and not out_mask[i, expected_out_len[i]]
            ), f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"