Commit d70f6e2
updated code for decoding
JinZr committed Jul 26, 2023
1 parent b9a8f10 commit d70f6e2
Showing 2 changed files with 21 additions and 15 deletions.

First changed file:

@@ -825,13 +825,15 @@ def deprecated_greedy_search_batch_for_cross_attn(
         # current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
         current_encoder_out = model.joiner.label_level_am_attention(
-            encoder_out[:, : t + 1, :].unsqueeze(2),
+            encoder_out.unsqueeze(2),
             decoder_out.unsqueeze(2),
-            encoder_out_lens,
+            # encoder_out_lens,
+            None,
         )
         logits = model.joiner(
             current_encoder_out,
             decoder_out.unsqueeze(1),
+            None,
             apply_attn=False,
             project_input=False,
         )
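
Judging by the function name, the hunk above is in the greedy-search decoding path: the label-level AM attention now receives the full encoder output instead of a slice up to the current frame, and the lengths argument is passed as None so no key padding mask is built inside the attention. Below is a minimal sketch of the resulting tensor shapes, with hypothetical dimensions; the real joiner and its label_level_am_attention module live in the repository and are not reproduced here.

import torch

# Hypothetical sizes for illustration only; real values come from the model config.
batch_size, T, encoder_dim, decoder_dim = 4, 100, 384, 512

encoder_out = torch.randn(batch_size, T, encoder_dim)  # all frames, no [:, : t + 1, :] slice
decoder_out = torch.randn(batch_size, 1, decoder_dim)  # per-stream label-level decoder state

am_input = encoder_out.unsqueeze(2)  # (batch, T, 1, encoder_dim), the attention's "AM" side
lm_input = decoder_out.unsqueeze(2)  # (batch, 1, 1, decoder_dim), the attention's "LM" side
lengths = None                       # was encoder_out_lens; None skips the padding mask

print(am_input.shape, lm_input.shape, lengths)

Passing None here is what motivates the guard added around make_pad_mask in the second changed file below.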

Second changed file:

@@ -73,7 +73,7 @@ def forward(
             batch_size,
             lm_seq_len,
             am_seq_len,
-        ), f"{attn_weights.shape}"
+        ), f"{attn_weights.shape} {x.shape}"

         x = self.in_proj(x)  # (am_seq_len, batch_size, num_heads * value_head_dim)
         # print("projected x.shape", x.shape)

@@ -406,19 +406,23 @@ def forward(
         if key_padding_mask is not None:
             # (batch, max_len)

-            key_padding_mask = (
-                key_padding_mask.unsqueeze(1)
-                .expand(
-                    key_padding_mask.shape[0],  # b
-                    self.prune_range,
-                    key_padding_mask.shape[1],  # l
-                )
-                .reshape(b_p_dim, am_seq_len)
-                .unsqueeze(1)
-                .unsqueeze(0)
-            )
-            # (1, b * p, 1, T)
+            if b_p_dim == key_padding_mask.shape[0] * self.prune_range:
+                key_padding_mask = (
+                    key_padding_mask.unsqueeze(1)
+                    .expand(
+                        key_padding_mask.shape[0],  # b
+                        self.prune_range,
+                        key_padding_mask.shape[1],  # l
+                    )
+                    .reshape(b_p_dim, am_seq_len)
+                    .unsqueeze(1)
+                    .unsqueeze(0)
+                )
+                # (1, b * p, 1, T)
+            else:
+                key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(0)
+                # (1, b, 1, T)
             # print(key_padding_mask.shape)

             attn_scores = attn_scores.masked_fill(
                 key_padding_mask,
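
The new branch distinguishes the training-style case, where the attention runs over b * prune_range pruned streams and the per-utterance padding mask has to be repeated along the prune dimension, from the decoding-style case where b_p_dim equals the batch size and only broadcast dimensions are needed. A self-contained sketch of the two mask shapes, with made-up sizes (b, prune_range, T are placeholders, not the repository's values):

import torch

b, prune_range, T = 3, 5, 7                      # hypothetical sizes
key_padding_mask = torch.zeros(b, T, dtype=torch.bool)
key_padding_mask[:, 5:] = True                   # pretend the last frames are padding

def expand_mask(key_padding_mask: torch.Tensor, b_p_dim: int, prune_range: int) -> torch.Tensor:
    b, T = key_padding_mask.shape
    if b_p_dim == b * prune_range:
        # Training-style input: one pruned stream per (batch, prune) pair,
        # so repeat the per-utterance mask prune_range times.
        return (
            key_padding_mask.unsqueeze(1)
            .expand(b, prune_range, T)
            .reshape(b_p_dim, T)
            .unsqueeze(1)
            .unsqueeze(0)
        )  # (1, b * prune_range, 1, T)
    # Decoding-style input: no prune dimension, just add broadcast dims.
    return key_padding_mask.unsqueeze(1).unsqueeze(0)  # (1, b, 1, T)

print(expand_mask(key_padding_mask, b * prune_range, prune_range).shape)  # torch.Size([1, 15, 1, 7])
print(expand_mask(key_padding_mask, b, prune_range).shape)                # torch.Size([1, 3, 1, 7])

In both branches the leading and third singleton dimensions let the mask broadcast across attention heads and query positions in the masked_fill call that follows.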

@@ -492,7 +496,7 @@ def __init__(
     def forward(
         self, am_pruned: Tensor, lm_pruned: Tensor, lengths: torch.Tensor
     ) -> Tensor:
-        src_key_padding_mask = make_pad_mask(lengths)
+        src_key_padding_mask = make_pad_mask(lengths) if lengths is not None else None
         # (batch, max_len)

         if am_pruned.ndim == 4 and lm_pruned.ndim == 4:
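
Since the decoding call in the first file now passes lengths=None, the pad-mask construction has to be optional. A minimal sketch with a local stand-in for make_pad_mask (the real helper comes from icefall's utilities; this version only illustrates the True-on-padding convention):

import torch
from typing import Optional

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Stand-in for icefall's make_pad_mask: True marks padded positions."""
    max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)  # (batch, max_len)

def build_src_key_padding_mask(lengths: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Mirrors the guarded call in the diff: no lengths, no mask.
    return make_pad_mask(lengths) if lengths is not None else None

print(build_src_key_padding_mask(torch.tensor([3, 5])))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
print(build_src_key_padding_mask(None))  # None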
