From ae964a0d21e6a699d631f0194f406a6d332c0a5c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Christoph=20M=2E=20L=C3=BCscher?=
Date: Wed, 18 Dec 2024 17:48:46 +0100
Subject: [PATCH] cleanup

---
 i6_models/assemblies/lstm/lstm_v1.py | 61 +---------------------------
 1 file changed, 1 insertion(+), 60 deletions(-)

diff --git a/i6_models/assemblies/lstm/lstm_v1.py b/i6_models/assemblies/lstm/lstm_v1.py
index f19fb574..9c18ce70 100644
--- a/i6_models/assemblies/lstm/lstm_v1.py
+++ b/i6_models/assemblies/lstm/lstm_v1.py
@@ -1,18 +1,15 @@
 __all__ = [
     "LstmEncoderV1Config",
     "LstmEncoderV1",
-    "LstmEncoderV2Config",
-    "LstmEncoderV2",
 ]
 
 from dataclasses import dataclass
-import math
 import torch
 from torch import nn
 from typing import Any, Dict, Optional, Tuple, Union
 
 from i6_models.config import ModelConfiguration
-from i6_models.parts.lstm import LstmBlockV1Config, LstmBlockV1, LstmBlockV2Config, LstmBlockV2
+from i6_models.parts.lstm import LstmBlockV1Config, LstmBlockV1
 
 
 @dataclass
@@ -46,62 +43,6 @@ def __init__(self, model_cfg: Union[LstmEncoderV1Config, Dict], **kwargs):
         if self.cfg.init_args is not None:
             self._param_init(**self.cfg.init_args)
 
-    def _param_init(self, init_args_w=None, init_args_b=None):
-        for m in self.modules():
-            for name, param in m.named_parameters():
-                if "bias" in name:
-                    if init_args_b["func"] == "normal":
-                        init_func = nn.init.normal_
-                    else:
-                        raise NotImplementedError
-                    hyp = init_args_b["arg"]
-                else:
-                    if init_args_w["func"] == "normal":
-                        init_func = nn.init.normal_
-                    else:
-                        raise NotImplementedError
-                    hyp = init_args_w["arg"]
-                init_func(param, **hyp)
-
-            if m is self.output:
-                m.bias.data.fill_(-math.log(self.cfg.input_dim))
-
-    def forward(self, data_tensor: torch.Tensor, seq_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        embed = self.embedding(data_tensor)
-        embed = self.embed_dropout(embed)
-
-        out, _ = self.lstm_block(embed, seq_len)
-        out = self.lstm_dropout(out)
-
-        return out, seq_len
-
-
-@dataclass
-class LstmEncoderV2Config(LstmEncoderV1Config):
-    lstm_layers_cfg: LstmBlockV2Config
-
-    @classmethod
-    def from_dict(cls, model_cfg_dict: Dict):
-        model_cfg_dict = model_cfg_dict.copy()
-        model_cfg_dict["lstm_layers_cfg"] = LstmBlockV2Config.from_dict(model_cfg_dict["lstm_layers_cfg"])
-        return cls(**model_cfg_dict)
-
-
-class LstmEncoderV2(nn.Module):
-    def __init__(self, model_cfg: Union[LstmEncoderV2Config, Dict], **kwargs):
-        super().__init__()
-
-        self.cfg = LstmEncoderV2Config.from_dict(model_cfg) if isinstance(model_cfg, Dict) else model_cfg
-
-        self.embedding = nn.Embedding(self.cfg.input_dim, self.cfg.embed_dim)
-        self.embed_dropout = nn.Dropout(self.cfg.embed_dropout)
-
-        self.lstm_block = LstmBlockV2(self.cfg.lstm_layers_cfg)
-        self.lstm_dropout = nn.Dropout(self.cfg.lstm_dropout)
-
-        if self.cfg.init_args is not None:
-            self._param_init(**self.cfg.init_args)
-
     def _param_init(self, init_args_w=None, init_args_b=None):
         for m in self.modules():
             for name, param in m.named_parameters():