
Commit

cleanup
christophmluscher committed Dec 18, 2024
1 parent 7514b1c commit ae964a0
Showing 1 changed file with 1 addition and 60 deletions.
61 changes: 1 addition & 60 deletions i6_models/assemblies/lstm/lstm_v1.py
@@ -1,18 +1,15 @@
__all__ = [
"LstmEncoderV1Config",
"LstmEncoderV1",
"LstmEncoderV2Config",
"LstmEncoderV2",
]

from dataclasses import dataclass
import math
import torch
from torch import nn
from typing import Any, Dict, Optional, Tuple, Union

from i6_models.config import ModelConfiguration
from i6_models.parts.lstm import LstmBlockV1Config, LstmBlockV1, LstmBlockV2Config, LstmBlockV2
from i6_models.parts.lstm import LstmBlockV1Config, LstmBlockV1


@dataclass
@@ -46,62 +43,6 @@ def __init__(self, model_cfg: Union[LstmEncoderV1Config, Dict], **kwargs):
        if self.cfg.init_args is not None:
            self._param_init(**self.cfg.init_args)

    def _param_init(self, init_args_w=None, init_args_b=None):
        for m in self.modules():
            for name, param in m.named_parameters():
                if "bias" in name:
                    if init_args_b["func"] == "normal":
                        init_func = nn.init.normal_
                    else:
                        raise NotImplementedError
                    hyp = init_args_b["arg"]
                else:
                    if init_args_w["func"] == "normal":
                        init_func = nn.init.normal_
                    else:
                        raise NotImplementedError
                    hyp = init_args_w["arg"]
                init_func(param, **hyp)

            if m is self.output:
                m.bias.data.fill_(-math.log(self.cfg.input_dim))

    def forward(self, data_tensor: torch.Tensor, seq_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        embed = self.embedding(data_tensor)
        embed = self.embed_dropout(embed)

        out, _ = self.lstm_block(embed, seq_len)
        out = self.lstm_dropout(out)

        return out, seq_len


@dataclass
class LstmEncoderV2Config(LstmEncoderV1Config):
    lstm_layers_cfg: LstmBlockV2Config

    @classmethod
    def from_dict(cls, model_cfg_dict: Dict):
        model_cfg_dict = model_cfg_dict.copy()
        model_cfg_dict["lstm_layers_cfg"] = LstmBlockV2Config.from_dict(model_cfg_dict["lstm_layers_cfg"])
        return cls(**model_cfg_dict)


class LstmEncoderV2(nn.Module):
    def __init__(self, model_cfg: Union[LstmEncoderV2Config, Dict], **kwargs):
        super().__init__()

        self.cfg = LstmEncoderV2Config.from_dict(model_cfg) if isinstance(model_cfg, Dict) else model_cfg

        self.embedding = nn.Embedding(self.cfg.input_dim, self.cfg.embed_dim)
        self.embed_dropout = nn.Dropout(self.cfg.embed_dropout)

        self.lstm_block = LstmBlockV2(self.cfg.lstm_layers_cfg)
        self.lstm_dropout = nn.Dropout(self.cfg.lstm_dropout)

        if self.cfg.init_args is not None:
            self._param_init(**self.cfg.init_args)

    def _param_init(self, init_args_w=None, init_args_b=None):
        for m in self.modules():
            for name, param in m.named_parameters():
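Note (not part of the commit, for orientation only): the encoder calls self._param_init(**self.cfg.init_args), so init_args is expected to be a dict with init_args_w and init_args_b entries, each naming an init function ("normal" is the only one handled; anything else raises NotImplementedError) and carrying an arg dict of keyword arguments for that function. A minimal standalone sketch of this dispatch, with placeholder mean/std values and a plain nn.Linear standing in for the encoder's modules:

from torch import nn

# Hypothetical init_args; the concrete values are placeholders, only the layout
# ({"init_args_w"/"init_args_b": {"func": ..., "arg": {...}}}) is taken from the diff.
init_args = {
    "init_args_w": {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}},
    "init_args_b": {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}},
}

layer = nn.Linear(8, 8)  # stand-in; the encoder loops over self.modules() instead
for name, param in layer.named_parameters():
    # biases use the init_args_b spec, all other parameters use init_args_w
    spec = init_args["init_args_b"] if "bias" in name else init_args["init_args_w"]
    if spec["func"] == "normal":
        nn.init.normal_(param, **spec["arg"])  # e.g. nn.init.normal_(param, mean=0.0, std=0.1)
    else:
        raise NotImplementedError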
