From 2c96d6b10b2a0a9d5e6b4834644ac6044a9b7f9e Mon Sep 17 00:00:00 2001
From: Daniel Mann
Date: Mon, 13 May 2024 13:05:42 -0400
Subject: [PATCH 01/11] init

---
 i6_models/assemblies/ffnn/__init__.py   |  1 +
 i6_models/assemblies/ffnn/ffnn_v1.py    | 31 ++++++++++
 i6_models/parts/ffnn.py                 | 53 ++++++++++++++++
 i6_models/parts/frontend/window_ffnn.py | 80 +++++++++++++++++++++++++
 tests/test_ffnn.py                      | 67 +++++++++++++++++++++
 tests/test_window_frontend.py           | 49 +++++++++++++++
 6 files changed, 281 insertions(+)
 create mode 100644 i6_models/assemblies/ffnn/__init__.py
 create mode 100644 i6_models/assemblies/ffnn/ffnn_v1.py
 create mode 100644 i6_models/parts/ffnn.py
 create mode 100644 i6_models/parts/frontend/window_ffnn.py
 create mode 100644 tests/test_ffnn.py
 create mode 100644 tests/test_window_frontend.py

diff --git a/i6_models/assemblies/ffnn/__init__.py b/i6_models/assemblies/ffnn/__init__.py
new file mode 100644
index 00000000..a108f223
--- /dev/null
+++ b/i6_models/assemblies/ffnn/__init__.py
@@ -0,0 +1 @@
+from .ffnn_v1 import *
\ No newline at end of file
diff --git a/i6_models/assemblies/ffnn/ffnn_v1.py b/i6_models/assemblies/ffnn/ffnn_v1.py
new file mode 100644
index 00000000..64136648
--- /dev/null
+++ b/i6_models/assemblies/ffnn/ffnn_v1.py
@@ -0,0 +1,31 @@
+__all__ = [
+    "FeedForwardEncoderV1Config",
+    "FeedForwardEncoderV1"
+]
+
+from typing import Tuple
+from dataclasses import dataclass
+import torch
+from torch import nn
+
+from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config
+from i6_models.config import ModelConfiguration, ModuleFactoryV1
+
+@dataclass
+class FeedForwardEncoderV1Config(ModelConfiguration):
+    num_layers: int
+    frontend: ModuleFactoryV1
+    layer_cfg: FeedForwardLayerV1Config
+
+class FeedForwardEncoderV1(nn.Module):
+    def __init__(self, cfg: FeedForwardEncoderV1Config):
+        super().__init__()
+        self.frontend = cfg.frontend()
+        self.module_list = nn.ModuleList([FeedForwardLayerV1(cfg.layer_cfg) for _ in range(cfg.num_layers)])
+
+    def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        x, sequence_mask = self.frontend(data_tensor, sequence_mask)  # [B, T, F']
+        for module in self.module_list:
+            x, sequence_mask = module(x, sequence_mask) # [B, T, F']
+
+        return x, sequence_mask
diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py
new file mode 100644
index 00000000..21398574
--- /dev/null
+++ b/i6_models/parts/ffnn.py
@@ -0,0 +1,53 @@
+__all__ = [
+    "FeedForwardConfig",
+    "FeedForwardModel"
+]
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from i6_models.config import ModelConfiguration
+
+@dataclass
+class FeedForwardLayerV1Config(ModelConfiguration):
+    """
+    Attributes:
+        in_features: input feature dimension
+        hidden_dim: output feature dimension
+        dropout: dropout probability
+        activation: activation function applied after linear computation
+    """
+    input_dim: int
+    hidden_dim: int
+    dropout: float = 0.0
+    activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = F.relu
+
+
+class FeedForwardLayerV1(torch.nn.Module):
+    """
+    Simple feed-forward layer module consisting of:
+    - linear
+    - activation
+    - dropout
+    """
+
+    def __init__(self, cfg: FeedForwardLayerV1Config):
+        super().__init__()
+        self.linear_ff = torch.nn.Linear(in_features=cfg.input_dim, out_features=cfg.hidden_dim, bias=True)
+        self.activation = cfg.activation
+        self.dropout = torch.nn.Dropout(cfg.dropout)
+
+    def forward(self, tensor: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        :param tensor: shape [B,T,F], F=input_dim
+        :param sequence_mask: shape [B,T]
+        :return: shape [B,T,F'], F=input_dim
+        """
+        tensor = self.linear_ff(tensor)  # [B,T,F]
+        tensor = self.activation(tensor)  # [B,T,F]
+        tensor = self.dropout(tensor)  # [B,T,F]
+        return tensor, sequence_mask
diff --git a/i6_models/parts/frontend/window_ffnn.py b/i6_models/parts/frontend/window_ffnn.py
new file mode 100644
index 00000000..dc11e546
--- /dev/null
+++ b/i6_models/parts/frontend/window_ffnn.py
@@ -0,0 +1,80 @@
+__all__ = [
+    "WindowFeedForwardFrontendV1Config",
+    "WindowFeedForwardFrontendV1",
+]
+
+from dataclasses import dataclass
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from i6_models.config import ModelConfiguration
+
+from .common import mask_pool, get_same_padding
+
+
+@dataclass
+class WindowFeedForwardFrontendV1Config(ModelConfiguration):
+    """
+    Attributes:
+        in_features: number of input features to module
+        out_features: output dimension
+        dropout: dropout after linear layer
+        window_size: number of feature frames to convolve (kernel size)
+        stride: skip (stride - 1) feature frames; stride > 1 implies subsampling
+        activation: activation function applied after linear computation
+    """
+    in_features: int
+    out_features: int
+    dropout: float
+    window_size: int = 15
+    stride: int = 1
+    activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = F.relu
+
+class WindowFeedForwardFrontendV1(nn.Module):
+    """
+    Simple feed-forward front-end that computes over a window
+    of input features. Choosing a stride > 1 allows for subsampling
+    of the features.
+    """
+
+    def __init__(self, cfg: WindowFeedForwardFrontendV1Config):
+        """
+        :param cfg: model configuration for this module
+        """
+        super().__init__()
+        self.conv = torch.nn.Conv1d(
+            in_channels=cfg.in_features,
+            out_channels=cfg.out_features,
+            kernel_size=cfg.window_size,
+            stride=cfg.stride,
+            padding=get_same_padding(cfg.window_size),
+            bias=True,
+        )
+        self.activation = cfg.activation or (lambda x: x)
+        self.dropout = torch.nn.Dropout(cfg.dropout)
+
+    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        T might be reduced to T' on stride
+
+        :param x: input tensor of shape [B,T,F]
+        :param sequence_mask: the sequence mask for the tensor
+        :return: torch.Tensor of shape [B,T',F'] and the shape of the sequence mask
+        """
+        x = x.transpose(1, 2) # torch 1d convolution is over last dim but we want time conv
+        x = self.conv(x).transpose(1, 2)
+
+        # these settings apparently apply stride correctly to the masking whatever the kernel size
+        sequence_mask = mask_pool(
+            sequence_mask,
+            kernel_size=1,
+            stride=self.conv.stride[0],
+            padding=0, # done manually
+        )
+        x = self.activation(x)
+        x = self.dropout(x)
+
+        return x, sequence_mask
diff --git a/tests/test_ffnn.py b/tests/test_ffnn.py
new file mode 100644
index 00000000..c39c0c8f
--- /dev/null
+++ b/tests/test_ffnn.py
@@ -0,0 +1,67 @@
+from itertools import product
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+import sys
+sys.path.insert(0, "/home/dmann/setups/2024-05-06--test-ffnn-fullsum/recipe/i6_models")
+
+from i6_models.assemblies.ffnn import (
+    FeedForwardEncoderV1,
+    FeedForwardEncoderV1Config,
+)
+
+from i6_models.parts.frontend.window_ffnn import (
+    WindowFeedForwardFrontendV1Config,
+    WindowFeedForwardFrontendV1
+)
+
+from i6_models.config import ModelConfiguration, ModuleFactoryV1
+
+def test_output_shape():
+    in_features = 80
+    out_features = 2048
+    dropout = 0.1
+    max_seq_lens = 100
+
+    # skip even window sizes for now
+    for window_size, stride in product(range(1, 22, 2), range(1, 5)):
+        frontend = ModuleFactoryV1(
+            WindowFeedForwardFrontendV1,
+            WindowFeedForwardFrontendV1Config(
+                in_features=80, 
+                out_features=out_features,
+                window_size=window_size,
+                dropout=dropout,
+                stride=stride,
+                activation=F.relu,
+            )
+        )
+
+        layer_cfg = FeedForwardLayerV1Config(input_dim=2048, hidden_dim=2048, dropout=0.1)
+
+        encoder_cfg = FeedForwardEncoderV1Config(
+            num_layers=6, layer_cfg=layer_cfg, frontend=frontend
+        )
+
+        encoder = FeedForwardEncoderV1(encoder_cfg)
+
+        feat_len = torch.arange(start=1, end=max_seq_lens+1)
+        mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])
+
+        features = torch.empty((max_seq_lens, max_seq_lens, in_features))
+
+        out, out_mask = encoder(features, mask)
+
+        expected_out_len = (feat_len - 1) // stride + 1
+        expected_shape = (max_seq_lens, expected_out_len[-1], out_features)
+        assert out.shape == expected_shape, \
+            f"Output with shape {out.shape} not as expected {expected_shape}"
+        for i in range(expected_out_len[-1] - 1):
+            # check if masks are correct
+            assert out_mask[i,expected_out_len[i] - 1] and not out_mask[i,expected_out_len[i]], \
+                f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"
+
+
diff --git a/tests/test_window_frontend.py b/tests/test_window_frontend.py
new file mode 100644
index 00000000..ea6af686
--- /dev/null
+++ b/tests/test_window_frontend.py
@@ -0,0 +1,49 @@
+from itertools import product
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+import sys
+sys.path.insert(0, "/home/dmann/setups/2024-05-06--test-ffnn-fullsum/recipe/i6_models")
+
+from i6_models.parts.frontend.window_ffnn import (
+    WindowFeedForwardFrontendV1Config,
+    WindowFeedForwardFrontendV1
+)
+
+
+def test_output_shape():
+    in_features = 80
+    out_features = 2048
+    dropout = 0.1
+    max_seq_lens = 100
+
+    # skip even window sizes for now
+    for window_size, stride in product(range(1, 22, 2), range(1, 5)):
+        frontend = WindowFeedForwardFrontendV1(
+            WindowFeedForwardFrontendV1Config(
+                in_features=80, 
+                out_features=out_features,
+                window_size=window_size,
+                dropout=dropout,
+                stride=stride,
+                activation=F.relu,
+            )
+        )
+
+        feat_len = torch.arange(start=1, end=max_seq_lens+1)
+        mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])
+
+        features = torch.empty((max_seq_lens, max_seq_lens, in_features))
+
+        out, out_mask = frontend(features, mask)
+
+        expected_out_len = (feat_len - 1) // stride + 1
+        expected_shape = (max_seq_lens, expected_out_len[-1], out_features)
+        assert out.shape == expected_shape, \
+            f"Output with shape {out.shape} not as expected {expected_shape}"
+        for i in range(expected_out_len[-1] - 1):
+            # check if masks are correct
+            assert out_mask[i,expected_out_len[i] - 1] and not out_mask[i,expected_out_len[i]], \
+                f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"
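
[Editor's note, not part of the patch series]
Both tests assert that the frontend subsamples T input frames down to
(T - 1) // stride + 1 output frames. This follows from PyTorch's Conv1d
output-length formula once "same" padding of (kernel_size - 1) // 2 is used
for an odd kernel. A minimal sanity check of the relation, assuming only
stock PyTorch and odd kernel sizes as in the tests:

    import torch

    # With "same" padding, a strided Conv1d yields (t - 1) // stride + 1 frames.
    for t, k, s in [(100, 15, 1), (100, 15, 2), (99, 7, 4)]:
        conv = torch.nn.Conv1d(1, 1, kernel_size=k, stride=s, padding=(k - 1) // 2)
        out_len = conv(torch.zeros(1, 1, t)).shape[-1]
        assert out_len == (t - 1) // s + 1, (t, k, s, out_len)
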
From 0370f7b742aa9ec67f53c0d5a0a478f6c0453f5b Mon Sep 17 00:00:00 2001
From: Daniel Mann
Date: Tue, 14 May 2024 06:59:08 -0400
Subject: [PATCH 02/11] remove defaults and add checks

---
 i6_models/assemblies/ffnn/ffnn_v1.py    | 13 ++++++++++++-
 i6_models/parts/ffnn.py                 |  8 ++++++--
 i6_models/parts/frontend/window_ffnn.py | 15 +++++++++++----
 3 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/i6_models/assemblies/ffnn/ffnn_v1.py b/i6_models/assemblies/ffnn/ffnn_v1.py
index 64136648..e04066d3 100644
--- a/i6_models/assemblies/ffnn/ffnn_v1.py
+++ b/i6_models/assemblies/ffnn/ffnn_v1.py
@@ -13,11 +13,22 @@
 
 @dataclass
 class FeedForwardEncoderV1Config(ModelConfiguration):
+    """
+    Attributes:
+        num_layers: number of feed-forward layers
+        frontend: module factory for the frontend
+        layer_cfg: configuration object for each feed-forward layer
+    """
     num_layers: int
     frontend: ModuleFactoryV1
     layer_cfg: FeedForwardLayerV1Config
 
+
 class FeedForwardEncoderV1(nn.Module):
+    """
+    Simple feed-forward encoder.
+    Subsampling can be achieved by setting stride > 1 in the frontend config.
+    """
     def __init__(self, cfg: FeedForwardEncoderV1Config):
         super().__init__()
         self.frontend = cfg.frontend()
@@ -26,6 +37,6 @@ def __init__(self, cfg: FeedForwardEncoderV1Config):
     def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         x, sequence_mask = self.frontend(data_tensor, sequence_mask)  # [B, T, F']
         for module in self.module_list:
-            x, sequence_mask = module(x, sequence_mask) # [B, T, F']
+            x, sequence_mask = module(x, sequence_mask) # [B, T, F'] 
 
         return x, sequence_mask
diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py
index 21398574..a916a8d3 100644
--- a/i6_models/parts/ffnn.py
+++ b/i6_models/parts/ffnn.py
@@ -23,8 +23,12 @@ class FeedForwardLayerV1Config(ModelConfiguration):
     """
     input_dim: int
     hidden_dim: int
-    dropout: float = 0.0
-    activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = F.relu
+    dropout: float
+    activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
+
+    def __post_init__(self):
+        super().__post_init__()
+        assert 0.0 <= dropout <= 1.0, "Dropout value must be a probability"
 
 
 class FeedForwardLayerV1(torch.nn.Module):
diff --git a/i6_models/parts/frontend/window_ffnn.py b/i6_models/parts/frontend/window_ffnn.py
index dc11e546..5c2777ef 100644
--- a/i6_models/parts/frontend/window_ffnn.py
+++ b/i6_models/parts/frontend/window_ffnn.py
@@ -29,9 +29,16 @@ class WindowFeedForwardFrontendV1Config(ModelConfiguration):
     in_features: int
     out_features: int
     dropout: float
-    window_size: int = 15
-    stride: int = 1
-    activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = F.relu
+    window_size: int
+    stride: int
+    activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
+
+    def __post_init__(self):
+        super().__post_init__()
+        assert self.window_size % 2 == 1, "Only odd kernel sizes are supported so far"
+        assert stride >= 1, "Choose an integer >= 1 for stride"
+        assert 0.0 <= dropout <= 1.0, "Dropout value must be a probability"
+
 
 class WindowFeedForwardFrontendV1(nn.Module):
     """
@@ -53,7 +60,7 @@ def __init__(self, cfg: WindowFeedForwardFrontendV1Config):
             padding=get_same_padding(cfg.window_size),
             bias=True,
         )
-        self.activation = cfg.activation or (lambda x: x)
+        self.activation = cfg.activation
         self.dropout = torch.nn.Dropout(cfg.dropout)
 
     def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

From 819976bfa5aee35dacae6cc8e9cafc1c8df2281a Mon Sep 17 00:00:00 2001
From: Daniel Mann
Date: Tue, 14 May 2024 07:00:00 -0400
Subject: [PATCH 03/11] apply black

---
 i6_models/assemblies/ffnn/__init__.py   |  2 +-
 i6_models/assemblies/ffnn/ffnn_v1.py    | 14 ++++++-------
 i6_models/parts/ffnn.py                 | 11 +++++-----
 i6_models/parts/frontend/window_ffnn.py |  7 ++++---
 tests/test_ffnn.py                      | 27 ++++++++++---------------
 tests/test_window_frontend.py           | 18 ++++++++---------
 6 files changed, 37 insertions(+), 42 deletions(-)

diff --git a/i6_models/assemblies/ffnn/__init__.py b/i6_models/assemblies/ffnn/__init__.py
index a108f223..90afaff6 100644
--- a/i6_models/assemblies/ffnn/__init__.py
+++ b/i6_models/assemblies/ffnn/__init__.py
@@ -1 +1 @@
-from .ffnn_v1 import *
\ No newline at end of file
+from .ffnn_v1 import *
diff --git a/i6_models/assemblies/ffnn/ffnn_v1.py b/i6_models/assemblies/ffnn/ffnn_v1.py
index e04066d3..f6e5c239 100644
--- a/i6_models/assemblies/ffnn/ffnn_v1.py
+++ b/i6_models/assemblies/ffnn/ffnn_v1.py
@@ -1,7 +1,4 @@
-__all__ = [
-    "FeedForwardEncoderV1Config",
-    "FeedForwardEncoderV1"
-]
+__all__ = ["FeedForwardEncoderV1Config", "FeedForwardEncoderV1"]
["FeedForwardEncoderV1Config", "FeedForwardEncoderV1"] from typing import Tuple from dataclasses import dataclass @@ -11,14 +8,16 @@ from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config from i6_models.config import ModelConfiguration, ModuleFactoryV1 + @dataclass class FeedForwardEncoderV1Config(ModelConfiguration): """ Attributes: num_layers: number of feed-forward layers - frontend: module factory for the frontend + frontend: module factory for the frontend layer_cfg: configuration object for each feed-forward layer """ + num_layers: int frontend: ModuleFactoryV1 layer_cfg: FeedForwardLayerV1Config @@ -29,14 +28,15 @@ class FeedForwardEncoderV1(nn.Module): Simple feed-forward encoder. Subsampling can be achieved by setting stride > 1 in the frontend config. """ + def __init__(self, cfg: FeedForwardEncoderV1Config): super().__init__() self.frontend = cfg.frontend() self.module_list = nn.ModuleList([FeedForwardLayerV1(cfg.layer_cfg) for _ in range(cfg.num_layers)]) - + def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x, sequence_mask = self.frontend(data_tensor, sequence_mask) # [B, T, F'] for module in self.module_list: - x, sequence_mask = module(x, sequence_mask) # [B, T, F'] + x, sequence_mask = module(x, sequence_mask) # [B, T, F'] return x, sequence_mask diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py index a916a8d3..713bfd56 100644 --- a/i6_models/parts/ffnn.py +++ b/i6_models/parts/ffnn.py @@ -1,7 +1,4 @@ -__all__ = [ - "FeedForwardConfig", - "FeedForwardModel" -] +__all__ = ["FeedForwardConfig", "FeedForwardModel"] from dataclasses import dataclass from functools import partial @@ -12,6 +9,7 @@ from i6_models.config import ModelConfiguration + @dataclass class FeedForwardLayerV1Config(ModelConfiguration): """ @@ -21,6 +19,7 @@ class FeedForwardLayerV1Config(ModelConfiguration): dropout: dropout probability activation: activation function applied after linear computation """ + input_dim: int hidden_dim: int dropout: float @@ -45,7 +44,9 @@ def __init__(self, cfg: FeedForwardLayerV1Config): self.activation = cfg.activation self.dropout = torch.nn.Dropout(cfg.dropout) - def forward(self, tensor: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, tensor: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: """ :param tensor: shape [B,T,F], F=input_dim :param sequence_mask: shape [B,T] diff --git a/i6_models/parts/frontend/window_ffnn.py b/i6_models/parts/frontend/window_ffnn.py index 5c2777ef..9c0e26a6 100644 --- a/i6_models/parts/frontend/window_ffnn.py +++ b/i6_models/parts/frontend/window_ffnn.py @@ -26,6 +26,7 @@ class WindowFeedForwardFrontendV1Config(ModelConfiguration): stride: skip (stride - 1) feature frames; stride > 1 implies subsampling activation: activation function applied after linear computation """ + in_features: int out_features: int dropout: float @@ -62,7 +63,7 @@ def __init__(self, cfg: WindowFeedForwardFrontendV1Config): ) self.activation = cfg.activation self.dropout = torch.nn.Dropout(cfg.dropout) - + def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ T might be reduced to T' on stride @@ -71,7 +72,7 @@ def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torc :param sequence_mask: the sequence mask for the tensor :return: torch.Tensor of shape [B,T',F'] and the 
         """
-        x = x.transpose(1, 2) # torch 1d convolution is over last dim but we want time conv
+        x = x.transpose(1, 2)  # torch 1d convolution is over last dim but we want time conv
         x = self.conv(x).transpose(1, 2)
 
         # these settings apparently apply stride correctly to the masking whatever the kernel size
@@ -79,7 +80,7 @@ def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torc
             sequence_mask,
             kernel_size=1,
             stride=self.conv.stride[0],
-            padding=0, # done manually
+            padding=0,  # done manually
         )
         x = self.activation(x)
         x = self.dropout(x)
diff --git a/tests/test_ffnn.py b/tests/test_ffnn.py
index c39c0c8f..749b92e8 100644
--- a/tests/test_ffnn.py
+++ b/tests/test_ffnn.py
@@ -5,6 +5,7 @@
 from torch.nn import functional as F
 
 import sys
+
 sys.path.insert(0, "/home/dmann/setups/2024-05-06--test-ffnn-fullsum/recipe/i6_models")
 
 from i6_models.assemblies.ffnn import (
@@ -12,14 +13,12 @@
     FeedForwardEncoderV1Config,
 )
 
-from i6_models.parts.frontend.window_ffnn import (
-    WindowFeedForwardFrontendV1Config,
-    WindowFeedForwardFrontendV1
-)
+from i6_models.parts.frontend.window_ffnn import WindowFeedForwardFrontendV1Config, WindowFeedForwardFrontendV1
 
 from i6_models.config import ModelConfiguration, ModuleFactoryV1
 from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config
 
+
 def test_output_shape():
     in_features = 80
     out_features = 2048
@@ -31,24 +30,22 @@ def test_output_shape():
         frontend = ModuleFactoryV1(
             WindowFeedForwardFrontendV1,
             WindowFeedForwardFrontendV1Config(
-                in_features=80, 
+                in_features=80,
                 out_features=out_features,
                 window_size=window_size,
                 dropout=dropout,
                 stride=stride,
                 activation=F.relu,
-            )
+            ),
         )
 
         layer_cfg = FeedForwardLayerV1Config(input_dim=2048, hidden_dim=2048, dropout=0.1)
 
-        encoder_cfg = FeedForwardEncoderV1Config(
-            num_layers=6, layer_cfg=layer_cfg, frontend=frontend
-        )
+        encoder_cfg = FeedForwardEncoderV1Config(num_layers=6, layer_cfg=layer_cfg, frontend=frontend)
 
         encoder = FeedForwardEncoderV1(encoder_cfg)
 
-        feat_len = torch.arange(start=1, end=max_seq_lens+1)
+        feat_len = torch.arange(start=1, end=max_seq_lens + 1)
         mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])
 
         features = torch.empty((max_seq_lens, max_seq_lens, in_features))
@@ -57,11 +54,9 @@ def test_output_shape():
 
         expected_out_len = (feat_len - 1) // stride + 1
         expected_shape = (max_seq_lens, expected_out_len[-1], out_features)
-        assert out.shape == expected_shape, \
-            f"Output with shape {out.shape} not as expected {expected_shape}"
+        assert out.shape == expected_shape, f"Output with shape {out.shape} not as expected {expected_shape}"
         for i in range(expected_out_len[-1] - 1):
             # check if masks are correct
-            assert out_mask[i,expected_out_len[i] - 1] and not out_mask[i,expected_out_len[i]], \
-                f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"
+            assert (
+                out_mask[i, expected_out_len[i] - 1] and not out_mask[i, expected_out_len[i]]
+            ), f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"
diff --git a/tests/test_window_frontend.py b/tests/test_window_frontend.py
index ea6af686..3054d439 100644
--- a/tests/test_window_frontend.py
+++ b/tests/test_window_frontend.py
@@ -5,12 +5,10 @@
 from torch.nn import functional as F
 
 import sys
+
 sys.path.insert(0, "/home/dmann/setups/2024-05-06--test-ffnn-fullsum/recipe/i6_models")
 
-from i6_models.parts.frontend.window_ffnn import (
-    WindowFeedForwardFrontendV1Config,
-    WindowFeedForwardFrontendV1
-)
+from i6_models.parts.frontend.window_ffnn import WindowFeedForwardFrontendV1Config, WindowFeedForwardFrontendV1
 
 
 def test_output_shape():
@@ -23,7 +21,7 @@ def test_output_shape():
     for window_size, stride in product(range(1, 22, 2), range(1, 5)):
         frontend = WindowFeedForwardFrontendV1(
             WindowFeedForwardFrontendV1Config(
-                in_features=80, 
+                in_features=80,
                 out_features=out_features,
                 window_size=window_size,
                 dropout=dropout,
@@ -32,7 +30,7 @@ def test_output_shape():
             )
         )
 
-        feat_len = torch.arange(start=1, end=max_seq_lens+1)
+        feat_len = torch.arange(start=1, end=max_seq_lens + 1)
         mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])
 
         features = torch.empty((max_seq_lens, max_seq_lens, in_features))
@@ -41,9 +39,9 @@ def test_output_shape():
 
         expected_out_len = (feat_len - 1) // stride + 1
         expected_shape = (max_seq_lens, expected_out_len[-1], out_features)
-        assert out.shape == expected_shape, \
-            f"Output with shape {out.shape} not as expected {expected_shape}"
+        assert out.shape == expected_shape, f"Output with shape {out.shape} not as expected {expected_shape}"
         for i in range(expected_out_len[-1] - 1):
             # check if masks are correct
-            assert out_mask[i,expected_out_len[i] - 1] and not out_mask[i,expected_out_len[i]], \
-                f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"
+            assert (
+                out_mask[i, expected_out_len[i] - 1] and not out_mask[i, expected_out_len[i]]
+            ), f"Failed for {i=}, {stride=}, {window_size=}, {out_mask[i]=}, {out_mask[i].shape=}"

From 6fa109bd8bfdb6e9739219991be769a16acd52fd Mon Sep 17 00:00:00 2001
From: Daniel Mann
Date: Tue, 14 May 2024 08:48:56 -0400
Subject: [PATCH 04/11] fix import

---
 i6_models/parts/ffnn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py
index 713bfd56..2f4c6038 100644
--- a/i6_models/parts/ffnn.py
+++ b/i6_models/parts/ffnn.py
@@ -2,7 +2,7 @@
 
 from dataclasses import dataclass
 from functools import partial
-from typing import Callable, Optional, Tuple
+from typing import Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F

From d3a47550723a6a6aa439b73217d35b941a2c1a44 Mon Sep 17 00:00:00 2001
From: Daniel Mann
Date: Tue, 14 May 2024 08:59:12 -0400
Subject: [PATCH 05/11] fix more imports, remove path insert in tests

---
 i6_models/parts/ffnn.py                 |  9 +++++----
 i6_models/parts/frontend/window_ffnn.py |  4 ++--
 tests/test_ffnn.py                      | 11 ++++++-----
 tests/test_window_frontend.py           |  4 ----
 4 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py
index 2f4c6038..f161741c 100644
--- a/i6_models/parts/ffnn.py
+++ b/i6_models/parts/ffnn.py
@@ -5,6 +5,7 @@
 from typing import Callable, Optional, Tuple, Union
 
 import torch
+from torch import nn
 import torch.nn.functional as F
 
 from i6_models.config import ModelConfiguration
@@ -27,10 +28,10 @@ class FeedForwardLayerV1Config(ModelConfiguration):
 
     def __post_init__(self):
         super().__post_init__()
-        assert 0.0 <= dropout <= 1.0, "Dropout value must be a probability"
+        assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"
 
 
-class FeedForwardLayerV1(torch.nn.Module):
+class FeedForwardLayerV1(nn.Module):
     """
     Simple feed-forward layer module consisting of:
     - linear
@@ -40,9 +41,9 @@ class FeedForwardLayerV1(torch.nn.Module):
 
     def __init__(self, cfg: FeedForwardLayerV1Config):
         super().__init__()
-        self.linear_ff = torch.nn.Linear(in_features=cfg.input_dim, out_features=cfg.hidden_dim, bias=True)
+        self.linear_ff = nn.Linear(in_features=cfg.input_dim, out_features=cfg.hidden_dim, bias=True)
         self.activation = cfg.activation
-        self.dropout = torch.nn.Dropout(cfg.dropout)
+        self.dropout = nn.Dropout(cfg.dropout)
 
     def forward(
         self, tensor: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
diff --git a/i6_models/parts/frontend/window_ffnn.py b/i6_models/parts/frontend/window_ffnn.py
index 9c0e26a6..1e0d1b84 100644
--- a/i6_models/parts/frontend/window_ffnn.py
+++ b/i6_models/parts/frontend/window_ffnn.py
@@ -37,8 +37,8 @@ class WindowFeedForwardFrontendV1Config(ModelConfiguration):
     def __post_init__(self):
         super().__post_init__()
         assert self.window_size % 2 == 1, "Only odd kernel sizes are supported so far"
-        assert stride >= 1, "Choose an integer >= 1 for stride"
-        assert 0.0 <= dropout <= 1.0, "Dropout value must be a probability"
+        assert self.stride >= 1, "Choose an integer >= 1 for stride"
+        assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"
diff --git a/tests/test_ffnn.py b/tests/test_ffnn.py
index 749b92e8..6153cb3b 100644
--- a/tests/test_ffnn.py
+++ b/tests/test_ffnn.py
@@ -4,10 +4,6 @@
 from torch import nn
 from torch.nn import functional as F
 
-import sys
-
-sys.path.insert(0, "/home/dmann/setups/2024-05-06--test-ffnn-fullsum/recipe/i6_models")
-
 from i6_models.assemblies.ffnn import (
     FeedForwardEncoderV1,
     FeedForwardEncoderV1Config,
@@ -39,7 +35,12 @@ def test_output_shape():
             ),
         )
 
-        layer_cfg = FeedForwardLayerV1Config(input_dim=2048, hidden_dim=2048, dropout=0.1)
+        layer_cfg = FeedForwardLayerV1Config(
+            input_dim=2048,
+            hidden_dim=2048,
+            dropout=0.1,
+            activation=F.relu,
+        )
 
         encoder_cfg = FeedForwardEncoderV1Config(num_layers=6, layer_cfg=layer_cfg, frontend=frontend)
diff --git a/tests/test_window_frontend.py b/tests/test_window_frontend.py
index 3054d439..2db71fce 100644
--- a/tests/test_window_frontend.py
+++ b/tests/test_window_frontend.py
@@ -4,10 +4,6 @@
 from torch import nn
 from torch.nn import functional as F
 
-import sys
-
-sys.path.insert(0, "/home/dmann/setups/2024-05-06--test-ffnn-fullsum/recipe/i6_models")
-
 from i6_models.parts.frontend.window_ffnn import WindowFeedForwardFrontendV1Config, WindowFeedForwardFrontendV1

From 4a656b682dcf9bdcb0ae157042b116c138f3598f Mon Sep 17 00:00:00 2001
From: DanEnergetics
Date: Tue, 14 May 2024 18:14:58 +0200
Subject: [PATCH 06/11] better comments

Co-authored-by: michelwi
---
 i6_models/parts/frontend/window_ffnn.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/i6_models/parts/frontend/window_ffnn.py b/i6_models/parts/frontend/window_ffnn.py
index 1e0d1b84..ced7c7e1 100644
--- a/i6_models/parts/frontend/window_ffnn.py
+++ b/i6_models/parts/frontend/window_ffnn.py
@@ -72,10 +72,11 @@ def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torc
         :param sequence_mask: the sequence mask for the tensor
         :return: torch.Tensor of shape [B,T',F'] and the shape of the sequence mask
         """
-        x = x.transpose(1, 2)  # torch 1d convolution is over last dim but we want time conv
-        x = self.conv(x).transpose(1, 2)
+        # torch 1d convolution is over last dim but we want time conv
+        x = x.transpose(1, 2)  # [B, F, T]
+        x = self.conv(x).transpose(1, 2)  # [B, T', F']
 
-        # these settings apparently apply stride correctly to the masking whatever the kernel size
+        # change masking according to stride value
         sequence_mask = mask_pool(
             sequence_mask,
             kernel_size=1,
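
[Editor's note, not part of the patch series]
In the forward pass above, mask_pool is called with kernel_size=1 and
padding=0, so it reduces to plain striding of the boolean mask, which is why
it tracks the convolution's subsampling regardless of the window size. A
rough equivalent, assuming mask_pool (whose internals are not shown in this
series) amounts to pooling over the mask:

    import torch

    def strided_mask(sequence_mask: torch.Tensor, stride: int) -> torch.Tensor:
        # kernel_size=1, padding=0: pooling keeps every stride-th position
        return sequence_mask[:, ::stride]

    mask = torch.tensor([[True, True, True, True, False, False]])
    print(strided_mask(mask, 2))  # tensor([[ True,  True, False]])
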
From 5a7b05c46a2f9b80720f81413e57a0397b23e198 Mon Sep 17 00:00:00 2001
From: Daniel Mann
Date: Tue, 14 May 2024 12:25:59 -0400
Subject: [PATCH 07/11] consistent attribute naming

---
 i6_models/parts/ffnn.py | 8 ++++----
 tests/test_ffnn.py      | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py
index f161741c..5f036920 100644
--- a/i6_models/parts/ffnn.py
+++ b/i6_models/parts/ffnn.py
@@ -21,7 +21,7 @@ class FeedForwardLayerV1Config(ModelConfiguration):
         activation: activation function applied after linear computation
     """
 
-    input_dim: int
+    in_features: int
     hidden_dim: int
     dropout: float
     activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
@@ -41,7 +41,7 @@ class FeedForwardLayerV1(nn.Module):
 
     def __init__(self, cfg: FeedForwardLayerV1Config):
         super().__init__()
-        self.linear_ff = nn.Linear(in_features=cfg.input_dim, out_features=cfg.hidden_dim, bias=True)
+        self.linear_ff = nn.Linear(in_features=cfg.in_features, out_features=cfg.hidden_dim, bias=True)
         self.activation = cfg.activation
         self.dropout = nn.Dropout(cfg.dropout)
@@ -49,9 +49,9 @@ def forward(
         self, tensor: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        :param tensor: shape [B,T,F], F=input_dim
+        :param tensor: shape [B,T,F], F=in_features
         :param sequence_mask: shape [B,T]
-        :return: shape [B,T,F'], F=input_dim
+        :return: shape [B,T,F'], F=in_features
         """
         tensor = self.linear_ff(tensor)  # [B,T,F]
diff --git a/tests/test_ffnn.py b/tests/test_ffnn.py
index 6153cb3b..3c387853 100644
--- a/tests/test_ffnn.py
+++ b/tests/test_ffnn.py
@@ -36,7 +36,7 @@ def test_output_shape():
         )
 
         layer_cfg = FeedForwardLayerV1Config(
-            input_dim=2048,
+            in_features=2048,
             hidden_dim=2048,
             dropout=0.1,
             activation=F.relu,

From 8273d49072ac7da5702b1274050a0cfb0c08dc19 Mon Sep 17 00:00:00 2001
From: DanEnergetics
Date: Thu, 23 May 2024 09:59:31 +0200
Subject: [PATCH 08/11] Apply doc string suggestions

Co-authored-by: Benedikt Hilmes
---
 i6_models/parts/ffnn.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py
index 5f036920..1005702f 100644
--- a/i6_models/parts/ffnn.py
+++ b/i6_models/parts/ffnn.py
@@ -49,9 +49,9 @@ def forward(
         self, tensor: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        :param tensor: shape [B,T,F], F=in_features
+        :param tensor: shape [B,T,F], F=input_dim
         :param sequence_mask: shape [B,T]
-        :return: shape [B,T,F'], F=in_features
+        :return: shape [B,T,F'], F'=output_dim
         """
         tensor = self.linear_ff(tensor)  # [B,T,F]
         tensor = self.activation(tensor)  # [B,T,F]
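
[Editor's note, not part of the patch series]
As of PATCH 08 the layer part is self-contained; a minimal usage sketch with
the field names at this point in the series (in_features/hidden_dim, renamed
to input_dim/output_dim in PATCH 09), assuming i6_models is installed:

    import torch
    from torch.nn import functional as F
    from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config

    cfg = FeedForwardLayerV1Config(in_features=80, hidden_dim=512, dropout=0.1, activation=F.relu)
    layer = FeedForwardLayerV1(cfg)

    features = torch.rand(4, 100, 80)             # [B, T, F]
    mask = torch.ones(4, 100, dtype=torch.bool)   # [B, T]
    out, out_mask = layer(features, mask)         # out: [B, T, 512], mask passed through
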
From df19d3c3e92003d8ec80bbb097fe8c3bcfe67703 Mon Sep 17 00:00:00 2001
From: Daniel Mann
Date: Thu, 23 May 2024 04:19:16 -0400
Subject: [PATCH 09/11] parameter and other renaming

---
 i6_models/parts/ffnn.py                       | 10 +++---
 .../{window_ffnn.py => window_convolution.py} | 32 +++++++++----------
 tests/test_ffnn.py                            | 24 +++++++-------
 tests/test_window_frontend.py                 | 12 +++----
 4 files changed, 39 insertions(+), 39 deletions(-)
 rename i6_models/parts/frontend/{window_ffnn.py => window_convolution.py} (75%)

diff --git a/i6_models/parts/ffnn.py b/i6_models/parts/ffnn.py
index 1005702f..3bd30d38 100644
--- a/i6_models/parts/ffnn.py
+++ b/i6_models/parts/ffnn.py
@@ -15,14 +15,14 @@ class FeedForwardLayerV1Config(ModelConfiguration):
     """
     Attributes:
-        in_features: input feature dimension
-        hidden_dim: output feature dimension
+        input_dim: input feature dimension
+        output_dim: output feature dimension
         dropout: dropout probability
         activation: activation function applied after linear computation
     """
 
-    in_features: int
-    hidden_dim: int
+    input_dim: int
+    output_dim: int
     dropout: float
     activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
@@ -41,7 +41,7 @@ class FeedForwardLayerV1(nn.Module):
 
     def __init__(self, cfg: FeedForwardLayerV1Config):
         super().__init__()
-        self.linear_ff = nn.Linear(in_features=cfg.in_features, out_features=cfg.hidden_dim, bias=True)
+        self.linear_ff = nn.Linear(in_features=cfg.input_dim, out_features=cfg.output_dim, bias=True)
         self.activation = cfg.activation
         self.dropout = nn.Dropout(cfg.dropout)
diff --git a/i6_models/parts/frontend/window_ffnn.py b/i6_models/parts/frontend/window_convolution.py
similarity index 75%
rename from i6_models/parts/frontend/window_ffnn.py
rename to i6_models/parts/frontend/window_convolution.py
index ced7c7e1..686b16e3 100644
--- a/i6_models/parts/frontend/window_ffnn.py
+++ b/i6_models/parts/frontend/window_convolution.py
@@ -1,6 +1,6 @@
 __all__ = [
-    "WindowFeedForwardFrontendV1Config",
-    "WindowFeedForwardFrontendV1",
+    "WindowConvolutionFrontendV1Config",
+    "WindowConvolutionFrontendV1",
 ]
 
@@ -16,49 +16,49 @@
 
 @dataclass
-class WindowFeedForwardFrontendV1Config(ModelConfiguration):
+class WindowConvolutionFrontendV1Config(ModelConfiguration):
     """
     Attributes:
-        in_features: number of input features to module
-        out_features: output dimension
+        input_dim: number of input features to module
+        output_dim: output dimension
         dropout: dropout after linear layer
-        window_size: number of feature frames to convolve (kernel size)
+        kernel_size: number of feature frames to convolve (kernel size)
         stride: skip (stride - 1) feature frames; stride > 1 implies subsampling
         activation: activation function applied after linear computation
     """
 
-    in_features: int
-    out_features: int
+    input_dim: int
+    output_dim: int
     dropout: float
-    window_size: int
+    kernel_size: int
     stride: int
     activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
 
     def __post_init__(self):
         super().__post_init__()
-        assert self.window_size % 2 == 1, "Only odd kernel sizes are supported so far"
+        assert self.kernel_size % 2 == 1, "Only odd kernel sizes are supported so far"
         assert self.stride >= 1, "Choose an integer >= 1 for stride"
         assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"
 
 
-class WindowFeedForwardFrontendV1(nn.Module):
+class WindowConvolutionFrontendV1(nn.Module):
     """
     Simple feed-forward front-end that computes over a window
     of input features. Choosing a stride > 1 allows for subsampling
     of the features.
     """
 
-    def __init__(self, cfg: WindowFeedForwardFrontendV1Config):
+    def __init__(self, cfg: WindowConvolutionFrontendV1Config):
         """
         :param cfg: model configuration for this module
         """
         super().__init__()
         self.conv = torch.nn.Conv1d(
-            in_channels=cfg.in_features,
-            out_channels=cfg.out_features,
-            kernel_size=cfg.window_size,
+            in_channels=cfg.input_dim,
+            out_channels=cfg.output_dim,
+            kernel_size=cfg.kernel_size,
             stride=cfg.stride,
-            padding=get_same_padding(cfg.window_size),
+            padding=get_same_padding(cfg.kernel_size),
             bias=True,
         )
         self.activation = cfg.activation
diff --git a/tests/test_ffnn.py b/tests/test_ffnn.py
index 3c387853..4b3ee268 100644
--- a/tests/test_ffnn.py
+++ b/tests/test_ffnn.py
@@ -9,26 +9,26 @@
     FeedForwardEncoderV1Config,
 )
 
-from i6_models.parts.frontend.window_ffnn import WindowFeedForwardFrontendV1Config, WindowFeedForwardFrontendV1
+from i6_models.parts.frontend.window_convolution import WindowConvolutionFrontendV1Config, WindowConvolutionFrontendV1
 
 from i6_models.config import ModelConfiguration, ModuleFactoryV1
 from i6_models.parts.ffnn import FeedForwardLayerV1, FeedForwardLayerV1Config
 
 
 def test_output_shape():
-    in_features = 80
-    out_features = 2048
+    input_dim = 80
+    output_dim = 2048
     dropout = 0.1
     max_seq_lens = 100
 
     # skip even window sizes for now
     for window_size, stride in product(range(1, 22, 2), range(1, 5)):
         frontend = ModuleFactoryV1(
-            WindowFeedForwardFrontendV1,
-            WindowFeedForwardFrontendV1Config(
-                in_features=80,
-                out_features=out_features,
-                window_size=window_size,
+            WindowConvolutionFrontendV1,
+            WindowConvolutionFrontendV1Config(
+                input_dim=80,
+                output_dim=output_dim,
+                kernel_size=window_size,
                 dropout=dropout,
                 stride=stride,
                 activation=F.relu,
@@ -36,8 +36,8 @@ def test_output_shape():
         )
 
         layer_cfg = FeedForwardLayerV1Config(
-            in_features=2048,
-            hidden_dim=2048,
+            input_dim=2048,
+            output_dim=2048,
             dropout=0.1,
             activation=F.relu,
         )
@@ -49,12 +49,12 @@ def test_output_shape():
         feat_len = torch.arange(start=1, end=max_seq_lens + 1)
         mask = torch.less(torch.arange(max_seq_lens)[None, :], feat_len[:, None])
 
-        features = torch.empty((max_seq_lens, max_seq_lens, in_features))
+        features = torch.empty((max_seq_lens, max_seq_lens, input_dim))
 
         out, out_mask = encoder(features, mask)
 
         expected_out_len = (feat_len - 1) // stride + 1
-        expected_shape = (max_seq_lens, expected_out_len[-1], out_features)
+        expected_shape = (max_seq_lens, expected_out_len[-1], output_dim)
         assert out.shape == expected_shape, f"Output with shape {out.shape} not as expected {expected_shape}"
         for i in range(expected_out_len[-1] - 1):
             # check if masks are correct
diff --git a/tests/test_window_frontend.py b/tests/test_window_frontend.py
index 2db71fce..7f1b90e5 100644
--- a/tests/test_window_frontend.py
+++ b/tests/test_window_frontend.py
@@ -4,7 +4,7 @@
 from torch import nn
 from torch.nn import functional as F
 
-from i6_models.parts.frontend.window_ffnn import WindowFeedForwardFrontendV1Config, WindowFeedForwardFrontendV1
+from i6_models.parts.frontend.window_convolution import WindowConvolutionFrontendV1Config, WindowConvolutionFrontendV1
 
 
 def test_output_shape():
@@ -15,11 +15,11 @@ def test_output_shape():
 
     # skip even window sizes for now
     for window_size, stride in product(range(1, 22, 2), range(1, 5)):
-        frontend = WindowFeedForwardFrontendV1(
-            WindowFeedForwardFrontendV1Config(
-                in_features=80,
-                out_features=out_features,
-                window_size=window_size,
+        frontend = WindowConvolutionFrontendV1(
+            WindowConvolutionFrontendV1Config(
+                input_dim=80,
+                output_dim=out_features,
+                kernel_size=window_size,
                 dropout=dropout,
                 stride=stride,
                 activation=F.relu,

From 1b51cf6ac601998d579cb9ab34678b3796814002 Mon Sep 17 00:00:00 2001
From: Daniel Mann
Date: Thu, 23 May 2024 04:32:35 -0400
Subject: [PATCH 10/11] update padding mechanism

---
 i6_models/parts/frontend/window_convolution.py | 7 ++++---
 tests/test_ffnn.py                             | 3 +--
 tests/test_window_frontend.py                  | 3 +--
 3 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/i6_models/parts/frontend/window_convolution.py b/i6_models/parts/frontend/window_convolution.py
index 686b16e3..44ba78fa 100644
--- a/i6_models/parts/frontend/window_convolution.py
+++ b/i6_models/parts/frontend/window_convolution.py
@@ -12,7 +12,7 @@
 
 from i6_models.config import ModelConfiguration
 
-from .common import mask_pool, get_same_padding
+from .common import mask_pool, apply_same_padding
 
 
 @dataclass
@@ -36,7 +36,6 @@ class WindowConvolutionFrontendV1Config(ModelConfiguration):
 
     def __post_init__(self):
         super().__post_init__()
-        assert self.kernel_size % 2 == 1, "Only odd kernel sizes are supported so far"
         assert self.stride >= 1, "Choose an integer >= 1 for stride"
         assert 0.0 <= self.dropout <= 1.0, "Dropout value must be a probability"
 
@@ -58,10 +57,11 @@ def __init__(self, cfg: WindowConvolutionFrontendV1Config):
             out_channels=cfg.output_dim,
             kernel_size=cfg.kernel_size,
             stride=cfg.stride,
-            padding=get_same_padding(cfg.kernel_size),
+            padding=0,
             bias=True,
         )
         self.activation = cfg.activation
+        self.pad = lambda x: apply_same_padding(x, cfg.kernel_size)
         self.dropout = torch.nn.Dropout(cfg.dropout)
 
     def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -74,6 +74,7 @@ def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> Tuple[torc
         """
         # torch 1d convolution is over last dim but we want time conv
         x = x.transpose(1, 2)  # [B, F, T]
+        x = self.pad(x)
         x = self.conv(x).transpose(1, 2)  # [B, T', F']
 
         # change masking according to stride value
diff --git a/tests/test_ffnn.py b/tests/test_ffnn.py
index 4b3ee268..fe5a4f56 100644
--- a/tests/test_ffnn.py
+++ b/tests/test_ffnn.py
@@ -21,8 +21,7 @@ def test_output_shape():
     dropout = 0.1
     max_seq_lens = 100
 
-    # skip even window sizes for now
-    for window_size, stride in product(range(1, 22, 2), range(1, 5)):
+    for window_size, stride in product(range(1, 22), range(1, 5)):
         frontend = ModuleFactoryV1(
             WindowConvolutionFrontendV1,
             WindowConvolutionFrontendV1Config(
diff --git a/tests/test_window_frontend.py b/tests/test_window_frontend.py
index 7f1b90e5..aa5992bd 100644
--- a/tests/test_window_frontend.py
+++ b/tests/test_window_frontend.py
@@ -13,8 +13,7 @@ def test_output_shape():
     dropout = 0.1
     max_seq_lens = 100
 
-    # skip even window sizes for now
-    for window_size, stride in product(range(1, 22, 2), range(1, 5)):
+    for window_size, stride in product(range(1, 22), range(1, 5)):
         frontend = WindowConvolutionFrontendV1(
             WindowConvolutionFrontendV1Config(
                 input_dim=80,
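
[Editor's note, not part of the patch series]
apply_same_padding is not shown in this series; the tests, which now cover
even window sizes too, rely on it padding kernel_size - 1 frames in total so
that the unpadded strided convolution still yields (T - 1) // stride + 1
output frames. A plausible stand-in under that assumption (the real helper
lives in i6_models.parts.frontend.common):

    import torch
    from torch.nn import functional as F

    def same_padding_1d(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
        # pad the time axis (last dim of [B, F, T]) with kernel_size - 1 frames
        # in total, split as evenly as possible between left and right
        total = kernel_size - 1
        return F.pad(x, (total // 2, total - total // 2))
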
From 74768def4d98298b325e597653b52850b4a406da Mon Sep 17 00:00:00 2001
From: Benedikt Hilmes
Date: Fri, 24 May 2024 11:02:00 +0200
Subject: [PATCH 11/11] Update i6_models/parts/frontend/window_convolution.py

---
 i6_models/parts/frontend/window_convolution.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/i6_models/parts/frontend/window_convolution.py b/i6_models/parts/frontend/window_convolution.py
index 44ba78fa..a89d60ed 100644
--- a/i6_models/parts/frontend/window_convolution.py
+++ b/i6_models/parts/frontend/window_convolution.py
@@ -22,7 +22,7 @@ class WindowConvolutionFrontendV1Config(ModelConfiguration):
         input_dim: number of input features to module
         output_dim: output dimension
         dropout: dropout after linear layer
-        kernel_size: number of feature frames to convolve (kernel size)
+        kernel_size: number of feature frames to convolve
         stride: skip (stride - 1) feature frames; stride > 1 implies subsampling
         activation: activation function applied after linear computation
     """
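
[Editor's note, not part of the patch series]
Putting the pieces together as they stand after PATCH 11, a minimal
end-to-end sketch mirroring the shapes exercised in tests/test_ffnn.py:

    import torch
    from torch.nn import functional as F

    from i6_models.config import ModuleFactoryV1
    from i6_models.assemblies.ffnn import FeedForwardEncoderV1, FeedForwardEncoderV1Config
    from i6_models.parts.ffnn import FeedForwardLayerV1Config
    from i6_models.parts.frontend.window_convolution import (
        WindowConvolutionFrontendV1,
        WindowConvolutionFrontendV1Config,
    )

    frontend = ModuleFactoryV1(
        WindowConvolutionFrontendV1,
        WindowConvolutionFrontendV1Config(
            input_dim=80, output_dim=512, kernel_size=15, stride=3, dropout=0.1, activation=F.relu
        ),
    )
    layer_cfg = FeedForwardLayerV1Config(input_dim=512, output_dim=512, dropout=0.1, activation=F.relu)
    encoder = FeedForwardEncoderV1(FeedForwardEncoderV1Config(num_layers=6, layer_cfg=layer_cfg, frontend=frontend))

    features = torch.rand(4, 100, 80)            # [B, T, F]
    mask = torch.ones(4, 100, dtype=torch.bool)  # [B, T]
    out, out_mask = encoder(features, mask)      # out: [4, (100 - 1) // 3 + 1, 512] = [4, 34, 512]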