
Commit

update multi head attention
Elad Hoffer committed May 7, 2019
1 parent 4e0a3b5 commit 348276b
Showing 6 changed files with 139 additions and 41 deletions.
16 changes: 11 additions & 5 deletions main.py
@@ -13,7 +13,7 @@
from torch.utils.data.distributed import DistributedSampler
from seq2seq import models, datasets
from seq2seq.tools.utils.log import setup_logging
from seq2seq.tools.utils.misc import set_global_seeds
from seq2seq.tools.utils.misc import set_global_seeds, torch_dtypes
from seq2seq.tools.config import PAD
import seq2seq.tools.trainer as trainers

@@ -49,8 +49,10 @@
help='trainer used: ' +
' | '.join(trainers.__all__) +
' (default: Seq2SeqTrainer)')
parser.add_argument('--dtype', default='torch.float',
help='type of tensor - e.g torch.cuda.HalfTensor')
parser.add_argument('--dtype', default='float',
help='type of tensor: ' +
' | '.join(torch_dtypes.keys()) +
' (default: float)')
parser.add_argument('-j', '--workers', default=8, type=int,
help='number of data loading workers (default: 8)')
parser.add_argument('--epochs', default=100, type=int,
@@ -89,6 +91,8 @@
help='maximum grad norm value. negative for off')
parser.add_argument('--embedding-grad-clip', default=None, type=float,
help='maximum embedding grad norm value')
parser.add_argument('--loss-scale', default=1, type=float,
help='loss scale for mixed precision training.')
parser.add_argument('--label-smoothing', default=0, type=float,
help='label smoothing coefficient - default 0')
parser.add_argument('--uniform-init', default=None, type=float,
@@ -102,7 +106,7 @@
parser.add_argument('--chunk-batch', default=1, type=int,
help='chunk batch size for multiple passes (training) -- used to fit large batches in memory')
parser.add_argument('--duplicates', default=1, type=int,
help='number of duplicates over singel example')
parser.add_argument('--seed', default=123, type=int,
help='random seed (default: 123)')
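
The new --loss-scale option above feeds the trainer's mixed-precision path; the trainer change itself sits in one of the two changed files not expanded on this page. As a rough illustration only, static loss scaling usually looks like the following sketch (function and argument names are illustrative, not this repository's API):

def backward_with_loss_scale(loss, model, optimizer, loss_scale=128.0):
    # Scale the loss so small fp16 gradients do not underflow to zero,
    # then unscale the accumulated gradients before the optimizer step.
    (loss * loss_scale).backward()
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(loss_scale)
    optimizer.step()
    optimizer.zero_grad()

Note that the CLI default above is 1, i.e. no scaling unless explicitly requested.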

@@ -136,6 +140,7 @@ def main(args):
logging.debug("run arguments: %s", args)

device = args.device
dtype = torch_dtypes.get(args.dtype)
if 'cuda' in args.device:
main_gpu = 0
if isinstance(args.device_ids, tuple):
@@ -168,7 +173,7 @@

model = getattr(models, args.model)(**model_config)

model.to(device)
model.to(device, dtype=dtype)
batch_first = getattr(model, 'batch_first', False)

logging.info(model)
@@ -213,6 +218,7 @@ def main(args):
device_ids=args.device_ids,
device=device,
dtype=args.dtype,
loss_scale=args.loss_scale,
print_freq=args.print_freq,
save_freq=args.save_freq,
eval_freq=args.eval_freq)
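
main.py now resolves the --dtype string through torch_dtypes from seq2seq.tools.utils.misc, which is among the six changed files but is not expanded on this page. A minimal sketch of what such a name-to-dtype mapping could look like (the exact keys in the repository are an assumption):

import torch

# Hypothetical stand-in for seq2seq.tools.utils.misc.torch_dtypes (not shown in this diff).
torch_dtypes = {
    'float': torch.float32,
    'half': torch.float16,
    'double': torch.float64,
}

# Mirrors main.py above: resolve the CLI string once, then cast the model.
# dtype = torch_dtypes.get(args.dtype)
# model.to(device, dtype=dtype)
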
75 changes: 67 additions & 8 deletions seq2seq/models/modules/attention.py
@@ -171,28 +171,32 @@ def forward(self, q, k, v):
mask_q = self.mask_q.unsqueeze(2).expand(b, t_q, t_k)
mask = mask_q if mask is None else mask | mask_q
if mask is not None:
qk.masked_fill_(mask, -1e9)
qk.masked_fill_(mask, float('-inf'))

sm_qk = F.softmax(qk, dim=2)
sm_qk = F.softmax(qk, dim=2,
dtype=torch.float32 if qk.dtype == torch.float16 else qk.dtype)
sm_qk = self.dropout(sm_qk)
return torch.bmm(sm_qk, v), sm_qk # b x t_q x dim_v
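
The hunk above swaps the -1e9 fill for float('-inf') and computes the softmax in float32 whenever the scores are float16, which keeps the normalization numerically stable in half precision. A standalone sketch of the same idea, with an explicit cast back to the input dtype (that cast is an addition of this sketch, not part of the diff):

import torch
import torch.nn.functional as F

def fp16_safe_softmax(scores, mask=None, dim=-1):
    # Block masked positions, normalize in float32, return in the original dtype
    # so a following bmm with fp16 values keeps matching dtypes.
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    return F.softmax(scores.float(), dim=dim).to(scores.dtype)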


class MultiHeadAttention(nn.Module):
class MultiHeadAttentionV2(nn.Module):
"""
Scaled Dot-Product Attention
"""

def __init__(self, input_size, output_size, num_heads, weight_norm=False, groups=1, dropout=0, causal=False):
super(MultiHeadAttention, self).__init__()
def __init__(self, input_size, output_size, num_heads, weight_norm=False, groups=1, dropout=0, causal=False, add_bias_kv=False):
super(MultiHeadAttentionV2, self).__init__()
assert(input_size % num_heads == 0)
wn_func = wn if weight_norm else lambda x: x
self.input_size = input_size
self.output_size = output_size
self.num_heads = num_heads
self.linear_q = wn_func(Linear(input_size, input_size, groups=groups))
self.linear_k = wn_func(Linear(input_size, input_size, groups=groups))
self.linear_v = wn_func(Linear(input_size, input_size, groups=groups))
self.linear_q = wn_func(
Linear(input_size, input_size, bias=False, groups=groups))
self.linear_k = wn_func(
Linear(input_size, input_size, bias=add_bias_kv, groups=groups))
self.linear_v = wn_func(
Linear(input_size, input_size, bias=add_bias_kv, groups=groups))
self.linear_out = wn_func(
Linear(input_size, output_size, groups=groups))
self.sdp_attention = SDPAttention(dropout=dropout, causal=causal)
Expand Down Expand Up @@ -226,3 +230,58 @@ def forward(self, q, k, v):
output = torch.cat(output, 2)

return self.linear_out(output), attention_scores


class MultiHeadAttention(nn.MultiheadAttention):
"""
Scaled Dot-Product Attention
"""

def __init__(self, input_size, output_size, num_heads, dropout=0, causal=False, bias=True, add_bias_kv=False, add_zero_attn=False, batch_first=True, groups=None, weight_norm=None):
super(MultiHeadAttention, self).__init__(input_size, num_heads, dropout=dropout,
bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn)
assert(input_size % num_heads == 0)
assert(input_size == output_size)
self.causal = causal
self.batch_first = batch_first

def set_mask_q(self, masked_tq):
self.mask_q = masked_tq

def set_mask_k(self, masked_tk):
# applies a mask of b x tk length
self.mask_k = masked_tk

def forward(self, query, key, value, incremental_state=None, need_weights=False, static_kv=False):
key_padding_mask = attn_mask = None
time_dim = 1 if self.batch_first else 0
t_q = query.size(time_dim)
t_k = key.size(time_dim)
with torch.no_grad():
if self.causal and t_q > 1:
attn_mask = torch.full((t_q, t_k), float('-inf'),
device=query.device, dtype=query.dtype).triu_(1)
key_padding_mask = self.mask_k

if self.batch_first:
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
kv_same = key.data_ptr() == value.data_ptr()
key = key.transpose(0, 1)
if kv_same:
value = key
else:
value = value.transpose(0, 1)
if qkv_same:
query = key
else:
query = query.transpose(0, 1)
elif key_padding_mask is not None:
key_padding_mask.t()


attn_output, attn_output_weights = super(
MultiHeadAttention, self).forward(query, key, value, key_padding_mask=key_padding_mask, attn_mask=attn_mask,
incremental_state=incremental_state, need_weights=need_weights, static_kv=static_kv)
if self.batch_first:
attn_output = attn_output.transpose(0, 1)
return attn_output, attn_output_weights
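
The rewritten MultiHeadAttention defers to torch.nn.MultiheadAttention, transposing batch-first inputs into the (time, batch, channels) layout that module expects and building the causal mask as an upper-triangular matrix of -inf. A sketch of the same transpose-and-mask pattern against the stock PyTorch module, independent of this wrapper:

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1)

x = torch.randn(2, 10, 512)                 # batch-first: (batch, time, channels)
t = x.size(1)
# Upper-triangular -inf mask: position i may not attend to positions > i.
attn_mask = torch.full((t, t), float('-inf'), dtype=x.dtype).triu_(1)

q = k = v = x.transpose(0, 1)               # the module expects (time, batch, channels)
out, _ = mha(q, k, v, attn_mask=attn_mask, need_weights=False)
out = out.transpose(0, 1)                   # back to batch-first
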
47 changes: 32 additions & 15 deletions seq2seq/models/modules/transformer_blocks.py
@@ -8,8 +8,11 @@
from .recurrent import Recurrent


def positional_embedding(x, min_timescale=1.0, max_timescale=1.0e4, offset=0):
batch, length, channels = list(x.size())
def positional_embedding(x, min_timescale=1.0, max_timescale=1.0e4, offset=0, batch_first=True):
if batch_first:
batch, length, channels = list(x.size())
else:
length, batch, channels = list(x.size())
assert (channels % 2 == 0)
num_timescales = channels // 2
log_timescale_increment = (
@@ -24,8 +27,12 @@ def positional_embedding(x, min_timescale=1.0, max_timescale=1.0e4, offset=0):
scaled_time = position.unsqueeze(1) * inv_timescales.unsqueeze(0)
# scaled time is now length x num_timescales
# length x channels
signal = torch.cat([scaled_time.sin(), scaled_time.cos()], 1)
return signal.unsqueeze(0).expand(batch, length, channels)
signal = torch.cat(
[scaled_time.sin(), scaled_time.cos()], 1).to(dtype=x.dtype)
if batch_first:
return signal.unsqueeze(0).expand(batch, length, channels)
else:
return signal.unsqueeze(1).expand(length, batch, channels)
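
positional_embedding now accepts both layouts and casts the sinusoidal signal to the input dtype. A quick shape check of the two paths, assuming the function as defined above:

import torch
from seq2seq.models.modules.transformer_blocks import positional_embedding

x_bf = torch.zeros(4, 7, 16)                             # (batch, time, channels)
x_tf = torch.zeros(7, 4, 16)                             # (time, batch, channels)

pe_bf = positional_embedding(x_bf, batch_first=True)     # -> (4, 7, 16)
pe_tf = positional_embedding(x_tf, batch_first=False)    # -> (7, 4, 16)

assert pe_bf.shape == x_bf.shape and pe_tf.shape == x_tf.shape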


class AverageNetwork(nn.Module):
@@ -80,16 +87,18 @@ def forward(self, x, state=None):
class EncoderBlock(nn.Module):

def __init__(self, hidden_size=512, num_heads=8, inner_linear=2048, inner_groups=1,
layer_norm=True, weight_norm=False, dropout=0):
batch_first=True, layer_norm=True, weight_norm=False, dropout=0):

super(EncoderBlock, self).__init__()
wn_func = wn if weight_norm else lambda x: x
if layer_norm:
self.lnorm1 = nn.LayerNorm(hidden_size)
self.lnorm2 = nn.LayerNorm(hidden_size)
self.dropout = nn.Dropout(dropout)
self.attention = MultiHeadAttention(
hidden_size, hidden_size, num_heads, dropout=dropout, causal=False, groups=inner_groups, weight_norm=weight_norm)
self.batch_first = batch_first
self.attention = MultiHeadAttention(hidden_size, hidden_size, num_heads,
dropout=dropout, causal=False, batch_first=batch_first,
groups=inner_groups, weight_norm=weight_norm)
self.fc = nn.Sequential(wn_func(Linear(hidden_size, inner_linear, groups=inner_groups)),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
@@ -131,7 +140,7 @@ def forward(self, inputs):

class DecoderBlock(nn.Module):

def __init__(self, hidden_size=512, num_heads=8, inner_linear=2048, inner_groups=1,
def __init__(self, hidden_size=512, num_heads=8, inner_linear=2048, inner_groups=1, batch_first=True,
layer_norm=True, weight_norm=False, dropout=0, stateful=None, state_dim=None):

super(DecoderBlock, self).__init__()
@@ -143,8 +152,10 @@ def __init__(self, hidden_size=512, num_heads=8, inner_linear=2048, inner_groups
self.dropout = nn.Dropout(dropout)
self.weight_norm = weight_norm
self.stateful = stateful
self.attention = MultiHeadAttention(
hidden_size, hidden_size, num_heads, dropout=dropout, causal=False, groups=inner_groups, weight_norm=weight_norm)
self.batch_first = batch_first
self.attention = MultiHeadAttention(hidden_size, hidden_size, num_heads,
batch_first=batch_first, dropout=dropout,
causal=False, groups=inner_groups, weight_norm=weight_norm)
if stateful is not None:
residual = False
stateful_hidden = hidden_size
@@ -157,13 +168,14 @@ def __init__(self, hidden_size=512, num_heads=8, inner_linear=2048, inner_groups
residual = True
if stateful in ['RNN', 'iRNN', 'LSTM', 'GRU']:
self.state_block = Recurrent(stateful, hidden_size, stateful_hidden,
dropout=dropout, residual=residual, batch_first=True)
dropout=dropout, residual=residual, batch_first=batch_first)
else:
self.state_block = AverageNetwork(
hidden_size, hidden_size, layer_norm=layer_norm, weight_norm=weight_norm, batch_first=True)
hidden_size, hidden_size, layer_norm=layer_norm, weight_norm=weight_norm, batch_first=batch_first)
else:
self.masked_attention = MultiHeadAttention(
hidden_size, hidden_size, num_heads, dropout=dropout, causal=True, groups=inner_groups, weight_norm=weight_norm)
hidden_size, hidden_size, num_heads, dropout=dropout,
batch_first=batch_first, causal=True, groups=inner_groups, weight_norm=weight_norm)

self.fc = nn.Sequential(wn_func(Linear(hidden_size, inner_linear, groups=inner_groups)),
nn.ReLU(inplace=True),
@@ -185,10 +197,15 @@ def forward(self, inputs, context, state=None):
else: # block_state are past inputs
if state is None:
x_past = x
mask_past = self.masked_attention.mask_k
else:
x_past = torch.cat((state, x), 1)
time_dim = 1 if self.batch_first else 0
x_past, mask_past = state
x_past = torch.cat((x_past, x), time_dim)
mask_past = torch.cat((mask_past, self.masked_attention.mask_k), time_dim)
self.masked_attention.set_mask_k(mask_past)
x, _ = self.masked_attention(x, x_past, x_past)
state = x_past
state = (x_past, mask_past)
if hasattr(self, 'state_proj'):
x = self.state_proj(x)
x = self.dropout(x).add(res)
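
The decoder's incremental state now carries both the accumulated inputs and the accumulated key mask, each concatenated along the time dimension at every step. A simplified sketch of that caching pattern (helper name and signature are illustrative, not the repository's API; the mask is assumed to share the inputs' time dimension):

import torch

def append_to_cache(x, mask, state, batch_first=True):
    # state is either None (first step) or the (x_past, mask_past) pair from the
    # previous step; attention then runs over everything cached so far.
    time_dim = 1 if batch_first else 0
    if state is None:
        return x, mask
    x_past, mask_past = state
    x_past = torch.cat((x_past, x), time_dim)
    mask_past = torch.cat((mask_past, mask), time_dim)
    return x_past, mask_past
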
29 changes: 18 additions & 11 deletions seq2seq/models/transformer.py
@@ -12,7 +12,7 @@ class TransformerAttentionEncoder(nn.Module):

def __init__(self, vocab_size, hidden_size=512, embedding_size=None,
num_layers=6, num_heads=8, inner_linear=2048, inner_groups=1, prenormalized=False,
mask_symbol=PAD, layer_norm=True, weight_norm=False, dropout=0, embedder=None):
mask_symbol=PAD, batch_first=True, layer_norm=True, weight_norm=False, dropout=0, embedder=None):

super(TransformerAttentionEncoder, self).__init__()
embedding_size = embedding_size or hidden_size
@@ -21,7 +21,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None,
torch.empty(embedding_size, hidden_size))
nn.init.kaiming_uniform_(self.input_projection, a=math.sqrt(5))
self.hidden_size = hidden_size
self.batch_first = True
self.batch_first = batch_first
self.mask_symbol = mask_symbol
self.embedder = embedder or nn.Embedding(
vocab_size, embedding_size, padding_idx=PAD)
@@ -37,6 +37,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None,
inner_groups=inner_groups,
layer_norm=layer_norm,
weight_norm=weight_norm,
batch_first=batch_first,
dropout=dropout)
for _ in range(num_layers)
])
@@ -51,7 +52,7 @@ def forward(self, inputs, hidden=None):
x = self.embedder(inputs).mul_(self.scale_embedding)
if hasattr(self, 'input_projection'):
x = x @ self.input_projection
x.add_(positional_embedding(x))
x.add_(positional_embedding(x, batch_first=self.batch_first))
x = self.dropout(x)

for block in self.blocks:
@@ -61,13 +62,13 @@ def forward(self, inputs, hidden=None):
if hasattr(self, 'lnorm'):
x = self.lnorm(x)

return State(outputs=x, mask=padding_mask, batch_first=True)
return State(outputs=x, mask=padding_mask, batch_first=self.batch_first)


class TransformerAttentionDecoder(nn.Module):

def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=6,
num_heads=8, dropout=0, inner_linear=2048, inner_groups=1, prenormalized=False, stateful=None, state_dim=None,
def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=6, num_heads=8,
batch_first=True, dropout=0, inner_linear=2048, inner_groups=1, prenormalized=False, stateful=None, state_dim=None,
mask_symbol=PAD, tie_embedding=True, layer_norm=True, weight_norm=False, embedder=None, classifier=True):

super(TransformerAttentionDecoder, self).__init__()
@@ -76,7 +77,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
self.input_projection = nn.Parameter(
torch.empty(embedding_size, hidden_size))
nn.init.kaiming_uniform_(self.input_projection, a=math.sqrt(5))
self.batch_first = True
self.batch_first = batch_first
self.mask_symbol = mask_symbol
self.embedder = embedder or nn.Embedding(
vocab_size, embedding_size, padding_idx=PAD)
@@ -94,6 +95,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
layer_norm=layer_norm,
weight_norm=weight_norm,
dropout=dropout,
batch_first=batch_first,
stateful=stateful,
state_dim=state_dim)
for _ in range(num_layers)
@@ -125,7 +127,9 @@ def forward(self, inputs, state, get_attention=False):
time_step = self.time_step
else:
block_state = state.inputs
time_step = 0 if block_state is None else block_state[0].size(1)
time_dim = 1 if self.batch_first else 0
time_step = 0 if block_state is None else \
block_state[0][0].size(time_dim)

if block_state is None:
block_state = [None] * len(self.blocks)
@@ -137,7 +141,8 @@ def forward(self, inputs, state, get_attention=False):
x = self.embedder(inputs).mul_(self.scale_embedding)
if hasattr(self, 'input_projection'):
x = x @ self.input_projection
x.add_(positional_embedding(x, offset=time_step))
x.add_(positional_embedding(
x, batch_first=self.batch_first, offset=time_step))
x = self.dropout(x)

attention_scores = []
@@ -173,7 +178,7 @@ class Transformer(Seq2Seq):

def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=6, num_heads=8,
inner_linear=2048, inner_groups=1, dropout=0.1, prenormalized=False, tie_embedding=True,
encoder=None, decoder=None, layer_norm=True, weight_norm=False, stateful=None):
encoder=None, decoder=None, layer_norm=True, weight_norm=False, batch_first=True, stateful=None):
super(Transformer, self).__init__()
embedding_size = embedding_size or hidden_size
# keeping encoder, decoder None will result with default configuration
@@ -192,6 +197,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
encoder.setdefault('inner_linear', inner_linear)
encoder.setdefault('inner_groups', inner_groups)
encoder.setdefault('prenormalized', prenormalized)
encoder.setdefault('batch_first', batch_first)

decoder.setdefault('embedding_size', embedding_size)
decoder.setdefault('hidden_size', hidden_size)
@@ -204,6 +210,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
decoder.setdefault('dropout', dropout)
decoder.setdefault('inner_linear', inner_linear)
decoder.setdefault('inner_groups', inner_groups)
decoder.setdefault('batch_first', batch_first)
decoder.setdefault('prenormalized', prenormalized)
decoder.setdefault('stateful', stateful)

@@ -214,7 +221,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
decoder.setdefault('embedder', embedder)
decoder['classifier'] = False

self.batch_first = True
self.batch_first = batch_first
self.encoder = TransformerAttentionEncoder(**encoder)
self.decoder = TransformerAttentionDecoder(**decoder)

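With batch_first now threaded through the encoder and decoder defaults, the sequence layout is picked once at construction time. A hedged construction sketch (argument values are illustrative; the keyword names match the signature shown above):

from seq2seq.models.transformer import Transformer

# Time-major layout: every block, positional embedding and attention module
# downstream receives batch_first=False through the setdefault calls above.
model = Transformer(vocab_size=32000, hidden_size=512, num_layers=6,
                    num_heads=8, dropout=0.1, batch_first=False)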

