Feed forward layer, frontend and encoder (#53)
---------

Co-authored-by: Daniel Mann <[email protected]>
Co-authored-by: michelwi <[email protected]>
Co-authored-by: Benedikt Hilmes <[email protected]>
4 people authored May 24, 2024
1 parent 1264482 commit 83ff39e
Showing 6 changed files with 296 additions and 0 deletions.
1 change: 1 addition & 0 deletions i6_models/assemblies/ffnn/__init__.py
@@ -0,0 +1 @@
from .ffnn_v1 import *
42 changes: 42 additions & 0 deletions i6_models/assemblies/ffnn/ffnn_v1.py
@@ -0,0 +1,42 @@
__all__ = ["FeedForwardEncoderV1Config", "FeedForwardEncoderV1"]

from typing import Tuple
from dataclasses import dataclass
import torch
from torch import nn

from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config
from i6_models.config import ModelConfiguration, ModuleFactoryV1


@dataclass
class FeedForwardEncoderV1Config(ModelConfiguration):
"""
Attributes:
num_layers: number of feed-forward layers
frontend: module factory for the frontend
layer_cfg: configuration object for each feed-forward layer
"""

num_layers: int
frontend: ModuleFactoryV1
layer_cfg: FeedForwardLayerV1Config


class FeedForwardEncoderV1(nn.Module):
"""
Simple feed-forward encoder.
Subsampling can be achieved by setting stride > 1 in the frontend config.
"""

def __init__(self, cfg: FeedForwardEncoderV1Config):
super().__init__()
self.frontend = cfg.frontend()
self.module_list = nn.ModuleList([FeedForwardLayerV1(cfg.layer_cfg) for _ in range(cfg.num_layers)])

    def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param data_tensor: input tensor of shape [B,T,F]
        :param sequence_mask: sequence mask of shape [B,T]
        :return: tuple of output tensor of shape [B,T',F'] and the updated sequence mask
        """
        x, sequence_mask = self.frontend(data_tensor, sequence_mask)  # [B, T', F']
        for module in self.module_list:
            x, sequence_mask = module(x, sequence_mask)  # [B, T', F']

return x, sequence_mask
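For orientation, a minimal usage sketch of the encoder defined above, in the style of the tests further down; the hyperparameter values are illustrative, not repository defaults:

import torch
from torch.nn import functional as F

from i6_models.assemblies.ffnn import FeedForwardEncoderV1, FeedForwardEncoderV1Config
from i6_models.config import ModuleFactoryV1
from i6_models.parts.ffnn import FeedForwardLayerV1Config
from i6_models.parts.frontend.window_convolution import WindowConvolutionFrontendV1, WindowConvolutionFrontendV1Config

# Frontend projects 80 -> 512 features and subsamples time by stride=3.
frontend = ModuleFactoryV1(
    WindowConvolutionFrontendV1,
    WindowConvolutionFrontendV1Config(
        input_dim=80, output_dim=512, kernel_size=5, stride=3, dropout=0.1, activation=F.relu
    ),
)
layer_cfg = FeedForwardLayerV1Config(input_dim=512, output_dim=512, dropout=0.1, activation=F.relu)
encoder = FeedForwardEncoderV1(FeedForwardEncoderV1Config(num_layers=2, frontend=frontend, layer_cfg=layer_cfg))

features = torch.rand(4, 100, 80)  # [B, T, F]
mask = torch.ones(4, 100, dtype=torch.bool)  # [B, T]
out, out_mask = encoder(features, mask)  # out: [4, 34, 512], since (100 - 1) // 3 + 1 = 34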
59 changes: 59 additions & 0 deletions i6_models/parts/ffnn.py
@@ -0,0 +1,59 @@
__all__ = ["FeedForwardConfig", "FeedForwardModel"]

from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class FeedForwardLayerV1Config(ModelConfiguration):
"""
Attributes:
input_dim: input feature dimension
output_dim: output feature dimension
dropout: dropout probability
activation: activation function applied after linear computation
"""

input_dim: int
output_dim: int
dropout: float
activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]

def __post_init__(self):
super().__post_init__()
assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"


class FeedForwardLayerV1(nn.Module):
"""
Simple feed-forward layer module consisting of:
- linear
- activation
- dropout
"""

def __init__(self, cfg: FeedForwardLayerV1Config):
super().__init__()
self.linear_ff = nn.Linear(in_features=cfg.input_dim, out_features=cfg.output_dim, bias=True)
self.activation = cfg.activation
self.dropout = nn.Dropout(cfg.dropout)

def forward(
self, tensor: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param tensor: shape [B,T,F], F=input_dim
:param sequence_mask: shape [B,T]
:return: shape [B,T,F'], F'=output_dim
"""
tensor = self.linear_ff(tensor) # [B,T,F]
tensor = self.activation(tensor) # [B,T,F]
tensor = self.dropout(tensor) # [B,T,F]
return tensor, sequence_mask
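A minimal sketch of the layer in isolation (illustrative dimensions); note that the sequence mask is returned unchanged:

import torch
from torch.nn import functional as F

from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config

layer = FeedForwardLayerV1(FeedForwardLayerV1Config(input_dim=80, output_dim=64, dropout=0.1, activation=F.relu))
x = torch.rand(2, 10, 80)  # [B, T, F]
mask = torch.ones(2, 10, dtype=torch.bool)  # [B, T]
y, mask_out = layer(x, mask)
assert y.shape == (2, 10, 64)  # [B, T, F']
assert mask_out is mask  # the mask is passed through untouched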
90 changes: 90 additions & 0 deletions i6_models/parts/frontend/window_convolution.py
@@ -0,0 +1,90 @@
__all__ = [
"WindowConvolutionFrontendV1Config",
"WindowConvolutionFrontendV1",
]

from dataclasses import dataclass
from typing import Callable, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration

from .common import mask_pool, apply_same_padding


@dataclass
class WindowConvolutionFrontendV1Config(ModelConfiguration):
"""
Attributes:
input_dim: number of input features to module
output_dim: output dimension
        dropout: dropout probability applied after the activation
        kernel_size: size of the temporal window, i.e. the number of feature frames convolved at once
        stride: skip (stride - 1) feature frames; stride > 1 implies subsampling
        activation: activation function applied after the convolution
"""

input_dim: int
output_dim: int
dropout: float
kernel_size: int
stride: int
activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]

def __post_init__(self):
super().__post_init__()
assert self.stride >= 1, "Choose an integer >= 1 for stride"
assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"


class WindowConvolutionFrontendV1(nn.Module):
"""
Simple feed-forward front-end that computes over a window
of input features. Choosing a stride > 1 allows for subsampling
of the features.
"""

def __init__(self, cfg: WindowConvolutionFrontendV1Config):
"""
:param cfg: model configuration for this module
"""
super().__init__()
self.conv = torch.nn.Conv1d(
in_channels=cfg.input_dim,
out_channels=cfg.output_dim,
kernel_size=cfg.kernel_size,
stride=cfg.stride,
padding=0,
bias=True,
)
self.activation = cfg.activation
self.pad = lambda x: apply_same_padding(x, cfg.kernel_size)
self.dropout = torch.nn.Dropout(cfg.dropout)

def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
T might be reduced to T' on stride
:param x: input tensor of shape [B,T,F]
:param sequence_mask: the sequence mask for the tensor
:return: torch.Tensor of shape [B,T',F'] and the shape of the sequence mask
"""
        # torch's Conv1d convolves over the last dim, but we want convolution over time
x = x.transpose(1, 2) # [B, F, T]
x = self.pad(x)
x = self.conv(x).transpose(1, 2) # [B, T', F']

# change masking according to stride value
sequence_mask = mask_pool(
sequence_mask,
kernel_size=1,
stride=self.conv.stride[0],
padding=0, # done manually
)
x = self.activation(x)
x = self.dropout(x)

return x, sequence_mask
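With same padding, the kernel size never shortens the sequence, so only the stride reduces T; the output length is T' = ceil(T / stride) = (T - 1) // stride + 1, which is exactly the expectation the tests below check. A small self-contained sketch of that arithmetic, assuming apply_same_padding behaves as its use above suggests:

import math

def subsampled_len(t: int, stride: int) -> int:
    # With same padding, only the stride affects the output length.
    return (t - 1) // stride + 1

assert subsampled_len(100, 1) == 100  # stride 1 keeps every frame
assert subsampled_len(100, 3) == 34 == math.ceil(100 / 3)
assert subsampled_len(7, 2) == 4  # output frames anchored at t = 0, 2, 4, 6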
62 changes: 62 additions & 0 deletions tests/test_ffnn.py
@@ -0,0 +1,62 @@
from itertools import product

import torch
from torch.nn import functional as F

from i6_models.assemblies.ffnn import (
    FeedForwardEncoderV1,
    FeedForwardEncoderV1Config,
)
from i6_models.config import ModuleFactoryV1
from i6_models.parts.ffnn import FeedForwardLayerV1Config
from i6_models.parts.frontend.window_convolution import WindowConvolutionFrontendV1, WindowConvolutionFrontendV1Config


def test_output_shape():
input_dim = 80
output_dim = 2048
dropout = 0.1
max_seq_lens = 100

for window_size, stride in product(range(1, 22), range(1, 5)):
frontend = ModuleFactoryV1(
WindowConvolutionFrontendV1,
WindowConvolutionFrontendV1Config(
                input_dim=input_dim,
output_dim=output_dim,
kernel_size=window_size,
dropout=dropout,
stride=stride,
activation=F.relu,
),
)

        layer_cfg = FeedForwardLayerV1Config(
            input_dim=output_dim,
            output_dim=output_dim,
            dropout=dropout,
            activation=F.relu,
        )

encoder_cfg = FeedForwardEncoderV1Config(num_layers=6, layer_cfg=layer_cfg, frontend=frontend)

encoder = FeedForwardEncoderV1(encoder_cfg)

feat_len = torch.arange(start=1, end=max_seq_lens + 1)
mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])

        features = torch.rand((max_seq_lens, max_seq_lens, input_dim))

out, out_mask = encoder(features, mask)

expected_out_len = (feat_len - 1) // stride + 1
expected_shape = (max_seq_lens, expected_out_len[-1], output_dim)
assert out.shape == expected_shape, f"Output with shape {out.shape} not as expected {expected_shape}"
for i in range(expected_out_len[-1] - 1):
            # sequence i must be valid up to frame expected_out_len[i] - 1 and masked from expected_out_len[i] on
assert (
out_mask[i, expected_out_len[i] - 1] and not out_mask[i, expected_out_len[i]]
), f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"
42 changes: 42 additions & 0 deletions tests/test_window_frontend.py
@@ -0,0 +1,42 @@
from itertools import product

import torch
from torch.nn import functional as F

from i6_models.parts.frontend.window_convolution import WindowConvolutionFrontendV1Config, WindowConvolutionFrontendV1


def test_output_shape():
in_features = 80
out_features = 2048
dropout = 0.1
max_seq_lens = 100

for window_size, stride in product(range(1, 22), range(1, 5)):
frontend = WindowConvolutionFrontendV1(
WindowConvolutionFrontendV1Config(
                input_dim=in_features,
output_dim=out_features,
kernel_size=window_size,
dropout=dropout,
stride=stride,
activation=F.relu,
)
)

feat_len = torch.arange(start=1, end=max_seq_lens + 1)
mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])

        features = torch.rand((max_seq_lens, max_seq_lens, in_features))

out, out_mask = frontend(features, mask)

expected_out_len = (feat_len - 1) // stride + 1
expected_shape = (max_seq_lens, expected_out_len[-1], out_features)
assert out.shape == expected_shape, f"Output with shape {out.shape} not as expected {expected_shape}"
for i in range(expected_out_len[-1] - 1):
            # sequence i must be valid up to frame expected_out_len[i] - 1 and masked from expected_out_len[i] on
assert (
out_mask[i, expected_out_len[i] - 1] and not out_mask[i, expected_out_len[i]]
), f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"
