From ccdacc1b4428d3afe6e7d379951cc46983100942 Mon Sep 17 00:00:00 2001
From: Yifan Yang
Date: Thu, 6 Jul 2023 22:04:00 +0800
Subject: [PATCH] Fix

---
 ...-pruned_transducer_stateless7_ctc_bs-2023-01-29 |  1 +
 .../frame_reducer.py                               | 14 +++++++-------
 2 files changed, 8 insertions(+), 7 deletions(-)
 create mode 160000 egs/librispeech/ASR/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29

diff --git a/egs/librispeech/ASR/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 b/egs/librispeech/ASR/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
new file mode 160000
index 0000000000..9417bd9bc4
--- /dev/null
+++ b/egs/librispeech/ASR/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
@@ -0,0 +1 @@
+Subproject commit 9417bd9bc4aae7ada8b7943f5849828eecbf3c91
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py
index 0841f7cf16..c44cb1eafa 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py
@@ -81,20 +81,20 @@ def forward(
             fake_limit_indexes = torch.topk(
                 ctc_output[:, :, blank_id], max_limit_len
             ).indices
-            T = (
+            T_arange = (
                 torch.arange(max_limit_len)
                 .expand_as(
                     fake_limit_indexes,
                 )
                 .to(device=x.device)
             )
-            T = torch.remainder(T, limit_lens.unsqueeze(1))
-            limit_indexes = torch.gather(fake_limit_indexes, 1, T)
+            T_arange = torch.remainder(T_arange, limit_lens.unsqueeze(1))
+            limit_indexes = torch.gather(fake_limit_indexes, 1, T_arange)
             limit_mask = torch.full_like(
                 non_blank_mask,
-                False,
+                0,
                 device=x.device,
-            ).scatter_(1, limit_indexes, True)
+            ).scatter_(1, limit_indexes, 1)
 
             non_blank_mask = non_blank_mask | ~limit_mask
 
@@ -108,9 +108,9 @@ def forward(
             )
             - out_lens
         )
-        max_pad_len = pad_lens_list.max()
+        max_pad_len = int(pad_lens_list.max())
 
-        out = F.pad(x, (0, 0, 0, max_pad_len))
+        out = F.pad(x, [0, 0, 0, max_pad_len])
 
         valid_pad_mask = ~make_pad_mask(pad_lens_list)
         total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
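
Note below the patch, not part of the commit: the subject says only "Fix", but the frame_reducer.py hunks read like TorchScript/export-compatibility fixes. `T` is renamed to `T_arange`, presumably because `T` is already bound to a Python int earlier in forward() and TorchScript rejects rebinding a name to a different type (int to Tensor); `.max()` is wrapped in int() to yield a Python int; and the F.pad amounts become a list, matching the pad: List[int] annotation that scripting expects. The sketch below illustrates the padding part under those assumptions; PadSketch and its arguments are hypothetical names, not code from the repo.

# Hedged sketch, not code from the patch: shows why the frame_reducer.py
# changes matter under torch.jit.script.
import torch
import torch.nn.functional as F

class PadSketch(torch.nn.Module):
    def forward(self, x: torch.Tensor, out_lens: torch.Tensor) -> torch.Tensor:
        N, T, C = x.size()  # T is a Python int here
        # Rebinding T to a Tensor (as the old code did via torch.arange)
        # would change its type, which TorchScript rejects; hence the
        # rename to T_arange in the patch.
        pad_lens = T - out_lens
        # int() extracts a Python int from the 0-dim tensor, and the pad
        # amounts go in a list, since F.pad is scripted as pad: List[int].
        max_pad_len = int(pad_lens.max())
        return F.pad(x, [0, 0, 0, max_pad_len])

scripted = torch.jit.script(PadSketch())
x = torch.randn(2, 5, 3)
print(scripted(x, torch.tensor([5, 4])).shape)  # torch.Size([2, 6, 3])

The False/True -> 0/1 change in full_like/scatter_ presumably serves the same goal; some export paths reject boolean fill values for scatter_.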