CF.py
import torch
import torch.nn as nn

from src import ConvModule, FFModule, KMeansMHA, MHAModule, Residual


class Conformer(nn.Module):
    def __init__(
        self,
        d_model=512,
        ff1_hsize=1024,
        ff1_dropout=0,
        n_head=8,
        mha_dropout=0,
        kernel_size=3,
        conv_dropout=0,
        ff2_hsize=1024,
        ff2_dropout=0,
        batch_size=None,
        max_seq_length=512,
        window_size=128,
        decay=0.999,
        kmeans_dropout=0,
        is_left_to_right=False,
        is_share_qk=False,
        use_kmeans_mha=False,
    ):
"""Conformer Block.
Args:
d_model (int): Embedded dimension of input.
ff1_hsize (int): Hidden size of th first FFN
ff1_drop (float): Dropout rate for the first FFN
n_head (int): Number of heads for MHA
mha_dropout (float): Dropout rate for the first MHA
epsilon (float): Epsilon
kernel_size (int): Kernel_size for the Conv
conv_dropout (float): Dropout rate for the first Conv
ff2_hsize (int): Hidden size of th first FFN
ff2_drop (float): Dropout rate for the first FFN
km_config (dict): Config for KMeans Attention.
use_kmeans_mha(boolean): Flag to use KMeans Attention for multi-head attention.
"""
        super(Conformer, self).__init__()
        self.ff_module1 = Residual(
            module=FFModule(
                d_model=d_model,
                h_size=ff1_hsize,
                dropout=ff1_dropout
            ),
            half=True
        )
        if use_kmeans_mha:
            self.mha_module = Residual(
                module=KMeansMHA(
                    d_model=d_model,
                    n_head=n_head,
                    batch_size=batch_size,
                    max_seq_length=max_seq_length,
                    window_size=window_size,
                    decay=decay,
                    dropout=kmeans_dropout,
                    is_left_to_right=is_left_to_right,
                    is_share_qk=is_share_qk,
                )
            )
        else:
            self.mha_module = Residual(
                module=MHAModule(
                    d_model=d_model,
                    n_head=n_head,
                    dropout=mha_dropout
                )
            )
        self.conv_module = Residual(
            module=ConvModule(
                in_channels=d_model,
                kernel_size=kernel_size,
                dropout=conv_dropout
            )
        )
        self.ff_module2 = Residual(
            module=FFModule(
                d_model=d_model,
                h_size=ff2_hsize,
                dropout=ff2_dropout
            ),
            half=True
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs, **kwargs):
        """Forward propagation of the Conformer block.

        Args:
            inputs (torch.Tensor): Input tensor of shape [B, L, D].

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        x = self.ff_module1(inputs)
        x = self.mha_module(x, **kwargs)
        x = self.conv_module(x)
        x = self.ff_module2(x)
        x = self.layer_norm(x)
        return x


def get_conformer(config):
    """Build a Conformer block from a config object."""
    return Conformer(
        d_model=config.d_model,
        ff1_hsize=config.ff1_hsize,
        ff1_dropout=config.ff1_dropout,
        n_head=config.n_head,
        mha_dropout=config.mha_dropout,
        kernel_size=config.kernel_size,
        conv_dropout=config.conv_dropout,
        ff2_hsize=config.ff2_hsize,
        ff2_dropout=config.ff2_dropout,
        # Conformer has no km_config parameter, so passing km_config= directly
        # would raise a TypeError. Assuming config.km_config is a dict keyed by
        # the KMeans-related constructor arguments (batch_size, max_seq_length,
        # window_size, decay, kmeans_dropout, ...), unpack it instead.
        **config.km_config,
        use_kmeans_mha=config.use_kmeans_mha
    )
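

# Minimal usage sketch: builds a hypothetical config with SimpleNamespace and
# runs a single forward pass. The field names mirror what get_conformer() reads;
# km_config is assumed to be a dict of KMeansMHA-related keyword arguments and
# is left empty here so the defaults apply. Tensor sizes below are arbitrary.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        d_model=256,
        ff1_hsize=512,
        ff1_dropout=0.1,
        n_head=4,
        mha_dropout=0.1,
        kernel_size=3,
        conv_dropout=0.1,
        ff2_hsize=512,
        ff2_dropout=0.1,
        km_config={},          # empty dict: keep the KMeans-related defaults (unused here)
        use_kmeans_mha=False,  # standard MHAModule branch
    )
    block = get_conformer(config)

    dummy = torch.randn(2, 100, config.d_model)  # [B, L, D]
    out = block(dummy)
    print(out.shape)  # residual modules and LayerNorm preserve the input shape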