Add an auto_expand option to SinusoidalPositionalEmbedding (#5555)
MaigoAkisame authored Oct 18, 2024
1 parent 018621f commit ecbf110
Showing 2 changed files with 25 additions and 12 deletions.
2 changes: 2 additions & 0 deletions fairseq/modules/positional_embedding.py
@@ -14,6 +14,7 @@ def PositionalEmbedding(
     embedding_dim: int,
     padding_idx: int,
     learned: bool = False,
+    auto_expand: bool = True,
 ):
     if learned:
         # if padding_idx is specified then offset the embedding ids by
@@ -31,5 +32,6 @@ def PositionalEmbedding(
             embedding_dim,
             padding_idx,
             init_size=num_embeddings + padding_idx + 1,
+            auto_expand=auto_expand,
         )
     return m
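
As a rough usage sketch (not part of the commit), this is how the new flag might be passed through the factory when building a sinusoidal embedding. The import path and argument names come from the files in this diff; the concrete sizes are arbitrary illustration values:

```python
from fairseq.modules.positional_embedding import PositionalEmbedding

# learned=False takes the sinusoidal branch, so auto_expand is forwarded
# to SinusoidalPositionalEmbedding. Sizes below are illustrative only.
pos_emb = PositionalEmbedding(
    num_embeddings=1024,  # positions to pre-compute up front
    embedding_dim=512,
    padding_idx=1,
    learned=False,
    auto_expand=False,    # do not mutate the module when longer inputs arrive
)
```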
35 changes: 23 additions & 12 deletions fairseq/modules/sinusoidal_positional_embedding.py
@@ -18,14 +18,19 @@ class SinusoidalPositionalEmbedding(nn.Module):
     Padding symbols are ignored.
     """
 
-    def __init__(self, embedding_dim, padding_idx, init_size=1024):
+    def __init__(self, embedding_dim, padding_idx, init_size=1024, auto_expand=True):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx if padding_idx is not None else 0
-        self.register_buffer("weights", SinusoidalPositionalEmbedding.get_embedding(
-            init_size, embedding_dim, padding_idx
-        ), persistent=False)
+        self.register_buffer(
+            "weights",
+            SinusoidalPositionalEmbedding.get_embedding(
+                init_size, embedding_dim, padding_idx
+            ),
+            persistent=False,
+        )
         self.max_positions = int(1e5)
+        self.auto_expand = auto_expand
         self.onnx_trace = False
 
     def prepare_for_onnx_export_(self):
@@ -75,28 +80,36 @@ def forward(
         bspair = torch.onnx.operators.shape_as_tensor(input)
         bsz, seq_len = bspair[0], bspair[1]
         max_pos = self.padding_idx + 1 + seq_len
+        weights = self.weights
 
         if max_pos > self.weights.size(0):
-            # expand embeddings if needed
-            self.weights = SinusoidalPositionalEmbedding.get_embedding(
+            # If the input is longer than the number of pre-computed embeddings,
+            # compute the extra embeddings on the fly.
+            # Only store the expanded embeddings if auto_expand=True.
+            # In multithreading environments, mutating the weights of a module
+            # may cause trouble. Set auto_expand=False if this happens.
+            weights = SinusoidalPositionalEmbedding.get_embedding(
                 max_pos, self.embedding_dim, self.padding_idx
             ).to(self.weights)
+            if self.auto_expand:
+                self.weights = weights
 
         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
             pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
             if self.onnx_trace:
                 return (
-                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
+                    weights.index_select(index=self.padding_idx + pos, dim=0)
                     .unsqueeze(1)
                     .repeat(bsz, 1, 1)
                 )
-            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+            return weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
 
         positions = utils.make_positions(
             input, self.padding_idx, onnx_trace=self.onnx_trace
         )
         if self.onnx_trace:
-            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
+            flat_embeddings = weights.detach().index_select(0, positions.view(-1))
             embedding_shape = torch.cat(
                 (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
             )
@@ -105,7 +118,5 @@ def forward(
             )
             return embeddings
         return (
-            self.weights.index_select(0, positions.view(-1))
-            .view(bsz, seq_len, -1)
-            .detach()
+            weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
         )
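
For a minimal sketch of the behavioural difference (not part of the commit): the sizes and token ids below are made up, and the dummy tensor merely stands in for real token indices, but the class and attribute names come from the diff above.

```python
import torch
from fairseq.modules.sinusoidal_positional_embedding import SinusoidalPositionalEmbedding

# Pre-compute only 8 positions, then feed a batch of 2 sequences of length 20.
tokens = torch.full((2, 20), 5, dtype=torch.long)  # any id != padding_idx counts as a real token

frozen = SinusoidalPositionalEmbedding(
    embedding_dim=16, padding_idx=1, init_size=8, auto_expand=False
)
out = frozen(tokens)            # extra positions are computed on the fly
print(out.shape)                # torch.Size([2, 20, 16])
print(frozen.weights.size(0))   # 8 -- the registered buffer is left untouched

growing = SinusoidalPositionalEmbedding(embedding_dim=16, padding_idx=1, init_size=8)
growing(tokens)
print(growing.weights.size(0))  # 22 (padding_idx + 1 + seq_len) -- buffer was expanded
```

Setting auto_expand=False trades recomputing the table on every over-long batch for never mutating module state, which is the point of the new option when a module instance is shared across threads.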
