forked from k2-fsa/icefall
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add merge_tokens for ctc forced alignment (k2-fsa#1649)
- Loading branch information
1 parent
e765d1c
commit 2a56d44
Showing
3 changed files
with
140 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
|
||
from typing import List | ||
|
||
from utils import TokenSpan, merge_tokens | ||
|
||
|
||
def inefficient_merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Compute start and end frames of each token from the given alignment.

    Brute-force reference implementation used to validate ``merge_tokens``:
    it scans the alignment frame by frame, opening a span when a non-blank
    token starts and closing it when the token changes or a blank is reached.

    Args:
      alignment:
        A list of token IDs.
      blank:
        ID of the blank.
    Returns:
      Return a list of TokenSpan.
    """
    ans = []
    last_token = None
    last_i = None

    for i, token in enumerate(alignment):
        if token == blank:
            if last_token is None or last_token == token:
                continue

            # end of the last token
            span = TokenSpan(token=last_token, start=last_i, end=i)
            ans.append(span)
            last_token = None
            last_i = None
            continue

        # The current token is not a blank
        if last_token is None or last_token == blank:
            last_token = token
            last_i = i
            continue

        if last_token == token:
            continue

        # end of the last token and start of the current token
        span = TokenSpan(token=last_token, start=last_i, end=i)
        last_token = token
        last_i = i
        ans.append(span)

    if last_token is not None:
        assert last_i is not None, (last_i, last_token)
        # Note for the last token, its end is len(alignment), i.e., one past
        # the last valid frame index (an exclusive bound).
        span = TokenSpan(token=last_token, start=last_i, end=len(alignment))
        ans.append(span)

    return ans
|
||
|
||
def test_merge_tokens():
    """Check that merge_tokens() agrees with the brute-force reference
    implementation on a variety of alignments (leading/trailing blanks,
    repeated tokens, and token changes without an intervening blank)."""
    alignments = [
        # frame: 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        [0, 1, 2, 3, 0],
        [1, 2, 3, 0],
        [0, 1, 2, 3],
        [1, 2, 3],
    ]

    for alignment in alignments:
        actual = merge_tokens(alignment)
        expected = inefficient_merge_tokens(alignment)
        assert actual == expected, (alignment, actual, expected)
|
||
|
||
def main():
    """Entry point: run the merge_tokens self-test."""
    test_merge_tokens()
|
||
|
||
# Run the self-test when this file is executed as a script.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
|
||
from dataclasses import dataclass | ||
from typing import List | ||
|
||
import torch | ||
|
||
|
||
@dataclass
class TokenSpan:
    """Span occupied by one non-blank token in a CTC alignment.

    Frame indexes refer to positions in the output log_prob; ``start`` is
    inclusive and ``end`` is exclusive.
    """

    # ID of the token
    token: int

    # First frame (inclusive) of this token in the output log_prob
    start: int

    # One past the last frame (exclusive) of this token in the output log_prob
    end: int
|
||
|
||
# See also
# https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/_alignment.py#L96
# We use torchaudio as a reference while implementing this function
def merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Compute start and end frames of each token from the given alignment.

    Consecutive frames holding the same token ID are merged into a single
    span; blank frames only separate spans and are dropped from the output.

    Args:
      alignment:
        A list of token IDs.
      blank:
        ID of the blank.
    Returns:
      Return a list of TokenSpan. Note that ``end`` is an exclusive bound,
      so the last span's ``end`` may equal ``len(alignment)``.
    """
    alignment_tensor = torch.tensor(alignment, dtype=torch.int32)

    # Surround the alignment with -1 (assumed never a valid token ID) so that
    # the first and last frames always register as token boundaries.
    diff = torch.diff(
        alignment_tensor,
        prepend=torch.tensor([-1]),
        append=torch.tensor([-1]),
    )

    # Frame indexes where the token changes. Use flatten() instead of
    # squeeze(): squeeze() would turn a single boundary into a 0-dim tensor,
    # making tolist() return a scalar instead of a list.
    boundary_indexes = torch.nonzero(diff != 0).flatten().tolist()

    ans = []
    # Each adjacent pair of boundaries delimits one run of identical tokens.
    for start, end in zip(boundary_indexes[:-1], boundary_indexes[1:]):
        token = alignment[start]
        if token == blank:
            continue
        span = TokenSpan(token=token, start=start, end=end)
        ans.append(span)
    return ans