Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Yifan Yang committed Jul 6, 2023
1 parent 438fef7 commit ccdacc1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
Submodule icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 added at 9417bd
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,20 @@ def forward(
fake_limit_indexes = torch.topk(
ctc_output[:, :, blank_id], max_limit_len
).indices
T = (
T_arange = (
torch.arange(max_limit_len)
.expand_as(
fake_limit_indexes,
)
.to(device=x.device)
)
T = torch.remainder(T, limit_lens.unsqueeze(1))
limit_indexes = torch.gather(fake_limit_indexes, 1, T)
T_arange = torch.remainder(T_arange, limit_lens.unsqueeze(1))
limit_indexes = torch.gather(fake_limit_indexes, 1, T_arange)
limit_mask = torch.full_like(
non_blank_mask,
False,
0,
device=x.device,
).scatter_(1, limit_indexes, True)
).scatter_(1, limit_indexes, 1)

non_blank_mask = non_blank_mask | ~limit_mask

Expand All @@ -108,9 +108,9 @@ def forward(
)
- out_lens
)
max_pad_len = pad_lens_list.max()
max_pad_len = int(pad_lens_list.max())

out = F.pad(x, (0, 0, 0, max_pad_len))
out = F.pad(x, [0, 0, 0, max_pad_len])

valid_pad_mask = ~make_pad_mask(pad_lens_list)
total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
Expand Down

0 comments on commit ccdacc1

Please sign in to comment.