Commit 83ff39e: Feed forward layer, frontend and encoder (#53)

Co-authored-by: Daniel Mann <[email protected]>
Co-authored-by: michelwi <[email protected]>
Co-authored-by: Benedikt Hilmes <[email protected]>

Parent: 1264482

Showing 6 changed files with 296 additions and 0 deletions.
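The six files add, in order: a package `__init__` re-exporting the new layer module, the `FeedForwardEncoderV1` assembly, the `FeedForwardLayerV1` part, the `WindowConvolutionFrontendV1` frontend, and one output-shape test each for the encoder and the frontend.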
@@ -0,0 +1 @@

from .ffnn_v1 import *
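Because this is a star import, the package surface is exactly the set of names `ffnn_v1` declares in its `__all__`. The tests in this commit rely on that import path, e.g.:

# layer classes become importable through the package itself
from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config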
@@ -0,0 +1,42 @@

__all__ = ["FeedForwardEncoderV1Config", "FeedForwardEncoderV1"]

from typing import Tuple
from dataclasses import dataclass

import torch
from torch import nn

from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config
from i6_models.config import ModelConfiguration, ModuleFactoryV1


@dataclass
class FeedForwardEncoderV1Config(ModelConfiguration):
    """
    Attributes:
        num_layers: number of feed-forward layers
        frontend: module factory for the frontend
        layer_cfg: configuration object for each feed-forward layer
    """

    num_layers: int
    frontend: ModuleFactoryV1
    layer_cfg: FeedForwardLayerV1Config


class FeedForwardEncoderV1(nn.Module):
    """
    Simple feed-forward encoder.
    Subsampling can be achieved by setting stride > 1 in the frontend config.
    """

    def __init__(self, cfg: FeedForwardEncoderV1Config):
        super().__init__()
        self.frontend = cfg.frontend()
        self.module_list = nn.ModuleList([FeedForwardLayerV1(cfg.layer_cfg) for _ in range(cfg.num_layers)])

    def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x, sequence_mask = self.frontend(data_tensor, sequence_mask)  # [B, T, F']
        for module in self.module_list:
            x, sequence_mask = module(x, sequence_mask)  # [B, T, F']

        return x, sequence_mask
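A minimal usage sketch for the new encoder; the dimensions, kernel size, and stride below are illustrative choices, not values from the diff:

import torch
from torch.nn import functional as F

from i6_models.assemblies.ffnn import FeedForwardEncoderV1, FeedForwardEncoderV1Config
from i6_models.config import ModuleFactoryV1
from i6_models.parts.ffnn import FeedForwardLayerV1Config
from i6_models.parts.frontend.window_convolution import (
    WindowConvolutionFrontendV1,
    WindowConvolutionFrontendV1Config,
)

# frontend: windowed 1D convolution over time; stride=2 halves the time axis
frontend = ModuleFactoryV1(
    WindowConvolutionFrontendV1,
    WindowConvolutionFrontendV1Config(
        input_dim=80, output_dim=512, dropout=0.1, kernel_size=5, stride=2, activation=F.relu
    ),
)
layer_cfg = FeedForwardLayerV1Config(input_dim=512, output_dim=512, dropout=0.1, activation=F.relu)
encoder = FeedForwardEncoderV1(FeedForwardEncoderV1Config(num_layers=6, layer_cfg=layer_cfg, frontend=frontend))

feats = torch.randn(4, 100, 80)              # [B, T, F]
mask = torch.ones(4, 100, dtype=torch.bool)  # all frames valid
out, out_mask = encoder(feats, mask)         # out: [4, 50, 512], out_mask: [4, 50]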
@@ -0,0 +1,59 @@

__all__ = ["FeedForwardLayerV1Config", "FeedForwardLayerV1"]

from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class FeedForwardLayerV1Config(ModelConfiguration):
    """
    Attributes:
        input_dim: input feature dimension
        output_dim: output feature dimension
        dropout: dropout probability
        activation: activation function applied after the linear computation
    """

    input_dim: int
    output_dim: int
    dropout: float
    activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]

    def __post_init__(self):
        super().__post_init__()
        assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"


class FeedForwardLayerV1(nn.Module):
    """
    Simple feed-forward layer module consisting of:
        - linear
        - activation
        - dropout
    """

    def __init__(self, cfg: FeedForwardLayerV1Config):
        super().__init__()
        self.linear_ff = nn.Linear(in_features=cfg.input_dim, out_features=cfg.output_dim, bias=True)
        self.activation = cfg.activation
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(
        self, tensor: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param tensor: shape [B,T,F], F=input_dim
        :param sequence_mask: shape [B,T]
        :return: tensor of shape [B,T,F'], F'=output_dim, and the unchanged sequence_mask
        """
        tensor = self.linear_ff(tensor)  # [B,T,F']
        tensor = self.activation(tensor)  # [B,T,F']
        tensor = self.dropout(tensor)  # [B,T,F']
        return tensor, sequence_mask
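A quick sketch of a single layer in isolation (shapes are illustrative); the mask is passed through untouched, which is what lets the encoder chain layers uniformly:

import torch
from torch.nn import functional as F

from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config

layer = FeedForwardLayerV1(FeedForwardLayerV1Config(input_dim=80, output_dim=32, dropout=0.1, activation=F.relu))
x = torch.randn(2, 7, 80)                  # [B, T, F]
mask = torch.ones(2, 7, dtype=torch.bool)  # [B, T]
y, y_mask = layer(x, mask)                 # y: [2, 7, 32]; y_mask is the input mask, returned unchanged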
@@ -0,0 +1,90 @@

__all__ = [
    "WindowConvolutionFrontendV1Config",
    "WindowConvolutionFrontendV1",
]

from dataclasses import dataclass
from typing import Callable, Tuple, Union

import torch
from torch import nn

from i6_models.config import ModelConfiguration

from .common import mask_pool, apply_same_padding


@dataclass
class WindowConvolutionFrontendV1Config(ModelConfiguration):
    """
    Attributes:
        input_dim: number of input features to the module
        output_dim: output feature dimension
        dropout: dropout probability applied after convolution and activation
        kernel_size: number of feature frames to convolve over
        stride: skip (stride - 1) feature frames; stride > 1 implies subsampling
        activation: activation function applied after the linear computation
    """

    input_dim: int
    output_dim: int
    dropout: float
    kernel_size: int
    stride: int
    activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]

    def __post_init__(self):
        super().__post_init__()
        assert self.stride >= 1, "Choose an integer >= 1 for stride"
        assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"


class WindowConvolutionFrontendV1(nn.Module):
    """
    Simple feed-forward front-end that computes over a window
    of input features. Choosing a stride > 1 allows for subsampling
    of the features.
    """

    def __init__(self, cfg: WindowConvolutionFrontendV1Config):
        """
        :param cfg: model configuration for this module
        """
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels=cfg.input_dim,
            out_channels=cfg.output_dim,
            kernel_size=cfg.kernel_size,
            stride=cfg.stride,
            padding=0,
            bias=True,
        )
        self.activation = cfg.activation
        self.pad = lambda x: apply_same_padding(x, cfg.kernel_size)
        self.dropout = torch.nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        T may be reduced to T' depending on the stride.

        :param x: input tensor of shape [B,T,F]
        :param sequence_mask: sequence mask of shape [B,T]
        :return: tensor of shape [B,T',F'] and the subsampled sequence mask of shape [B,T']
        """
        # torch's 1D convolution runs over the last dim, but we want to convolve over time
        x = x.transpose(1, 2)  # [B, F, T]
        x = self.pad(x)
        x = self.conv(x).transpose(1, 2)  # [B, T', F']

        # adjust the masking according to the stride value
        sequence_mask = mask_pool(
            sequence_mask,
            kernel_size=1,
            stride=self.conv.stride[0],
            padding=0,  # padding is done manually above
        )

        x = self.activation(x)
        x = self.dropout(x)

        return x, sequence_mask
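With manual same padding, the conv's output length is T' = (T - 1) // stride + 1, and the kernel_size=1 mask pooling yields exactly the same length. A sketch with illustrative dimensions:

import torch
from torch.nn import functional as F

from i6_models.parts.frontend.window_convolution import (
    WindowConvolutionFrontendV1,
    WindowConvolutionFrontendV1Config,
)

frontend = WindowConvolutionFrontendV1(
    WindowConvolutionFrontendV1Config(
        input_dim=80, output_dim=512, dropout=0.1, kernel_size=3, stride=2, activation=F.relu
    )
)
x = torch.randn(1, 10, 80)                  # [B, T=10, F]
mask = torch.ones(1, 10, dtype=torch.bool)  # [B, T]
y, y_mask = frontend(x, mask)               # T' = (10 - 1) // 2 + 1 = 5, so y: [1, 5, 512], y_mask: [1, 5]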
@@ -0,0 +1,62 @@

from itertools import product

import torch
from torch.nn import functional as F

from i6_models.assemblies.ffnn import (
    FeedForwardEncoderV1,
    FeedForwardEncoderV1Config,
)
from i6_models.config import ModuleFactoryV1
from i6_models.parts.ffnn import FeedForwardLayerV1Config
from i6_models.parts.frontend.window_convolution import WindowConvolutionFrontendV1Config, WindowConvolutionFrontendV1


def test_output_shape():
    input_dim = 80
    output_dim = 2048
    dropout = 0.1
    max_seq_lens = 100

    for window_size, stride in product(range(1, 22), range(1, 5)):
        frontend = ModuleFactoryV1(
            WindowConvolutionFrontendV1,
            WindowConvolutionFrontendV1Config(
                input_dim=input_dim,
                output_dim=output_dim,
                kernel_size=window_size,
                dropout=dropout,
                stride=stride,
                activation=F.relu,
            ),
        )

        layer_cfg = FeedForwardLayerV1Config(
            input_dim=output_dim,
            output_dim=output_dim,
            dropout=dropout,
            activation=F.relu,
        )

        encoder_cfg = FeedForwardEncoderV1Config(num_layers=6, layer_cfg=layer_cfg, frontend=frontend)
        encoder = FeedForwardEncoderV1(encoder_cfg)

        # batch of sequences with lengths 1..max_seq_lens and the matching boolean mask
        feat_len = torch.arange(start=1, end=max_seq_lens + 1)
        mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])
        features = torch.empty((max_seq_lens, max_seq_lens, input_dim))

        out, out_mask = encoder(features, mask)

        expected_out_len = (feat_len - 1) // stride + 1
        expected_shape = (max_seq_lens, expected_out_len[-1], output_dim)
        assert out.shape == expected_shape, f"Output with shape {out.shape} not as expected {expected_shape}"
        for i in range(expected_out_len[-1] - 1):
            # the mask must be True at the last valid frame and False right after it
            assert (
                out_mask[i, expected_out_len[i] - 1] and not out_mask[i, expected_out_len[i]]
            ), f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"
@@ -0,0 +1,42 @@

from itertools import product

import torch
from torch.nn import functional as F

from i6_models.parts.frontend.window_convolution import WindowConvolutionFrontendV1Config, WindowConvolutionFrontendV1


def test_output_shape():
    in_features = 80
    out_features = 2048
    dropout = 0.1
    max_seq_lens = 100

    for window_size, stride in product(range(1, 22), range(1, 5)):
        frontend = WindowConvolutionFrontendV1(
            WindowConvolutionFrontendV1Config(
                input_dim=in_features,
                output_dim=out_features,
                kernel_size=window_size,
                dropout=dropout,
                stride=stride,
                activation=F.relu,
            )
        )

        # batch of sequences with lengths 1..max_seq_lens and the matching boolean mask
        feat_len = torch.arange(start=1, end=max_seq_lens + 1)
        mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])
        features = torch.empty((max_seq_lens, max_seq_lens, in_features))

        out, out_mask = frontend(features, mask)

        expected_out_len = (feat_len - 1) // stride + 1
        expected_shape = (max_seq_lens, expected_out_len[-1], out_features)
        assert out.shape == expected_shape, f"Output with shape {out.shape} not as expected {expected_shape}"
        for i in range(expected_out_len[-1] - 1):
            # the mask must be True at the last valid frame and False right after it
            assert (
                out_mask[i, expected_out_len[i] - 1] and not out_mask[i, expected_out_len[i]]
            ), f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"