Skip to content

Commit

Permalink
Add merge_tokens for ctc forced alignment (k2-fsa#1649)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored and Your Name committed Aug 9, 2024
1 parent e765d1c commit 2a56d44
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 0 deletions.
1 change: 1 addition & 0 deletions icefall/ctc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
make_lexicon_fst_with_silence,
)
from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo
from .utils import merge_tokens
87 changes: 87 additions & 0 deletions icefall/ctc/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

from typing import List

from utils import TokenSpan, merge_tokens


def inefficient_merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Compute start and end frames of each token from the given alignment.

    This is an O(n) frame-by-frame reference implementation used to
    validate :func:`merge_tokens`.

    Args:
      alignment:
        A list of token IDs.
      blank:
        ID of the blank.
    Returns:
      Return a list of TokenSpan.
    """
    ans = []
    last_token = None
    last_i = None

    for i, token in enumerate(alignment):
        if token == blank:
            # A blank frame. If a token run was in progress, it ends here;
            # otherwise (leading blank or consecutive blanks) skip it.
            if last_token is None or last_token == token:
                continue

            # end of the last token
            span = TokenSpan(token=last_token, start=last_i, end=i)
            ans.append(span)
            last_token = None
            last_i = None
            continue

        # The current token is not a blank
        if last_token is None or last_token == blank:
            last_token = token
            last_i = i
            continue

        if last_token == token:
            # Same token repeated: still inside the current run.
            continue

        # end of the last token and start of the current token
        span = TokenSpan(token=last_token, start=last_i, end=i)
        last_token = token
        last_i = i
        ans.append(span)

    if last_token is not None:
        assert last_i is not None, (last_i, last_token)
        # Note: for the last token, its end is len(alignment), i.e. one
        # past the final frame index len(alignment)-1 (end is exclusive).
        span = TokenSpan(token=last_token, start=last_i, end=len(alignment))
        ans.append(span)

    return ans


def test_merge_tokens():
    """Check merge_tokens() against the frame-by-frame reference
    implementation on alignments covering leading/trailing blanks,
    repeated tokens, and back-to-back distinct tokens."""
    # Frame:  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14
    cases = (
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [0, 1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3, 0],
        [1, 1, 1, 2, 0, 0, 0, 2, 2, 3, 2, 3, 3],
        [0, 1, 2, 3, 0],
        [1, 2, 3, 0],
        [0, 1, 2, 3],
        [1, 2, 3],
    )

    for alignment in cases:
        fast = merge_tokens(alignment)
        slow = inefficient_merge_tokens(alignment)
        assert fast == slow, (alignment, fast, slow)


def main():
    """Run every test in this file."""
    test_merge_tokens()


if __name__ == "__main__":
    main()
52 changes: 52 additions & 0 deletions icefall/ctc/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class TokenSpan:
    """A run of identical non-blank token IDs in a CTC alignment,
    covering frames ``start`` (inclusive) to ``end`` (exclusive) of the
    output log_prob."""

    # ID of the token
    token: int

    # First frame of this token in the output log_prob (inclusive)
    start: int

    # Frame just past this token in the output log_prob (exclusive)
    end: int


# See also
# https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/_alignment.py#L96
# We use torchaudio as a reference while implementing this function
def merge_tokens(alignment: List[int], blank: int = 0) -> List[TokenSpan]:
    """Compute start and end frames of each token from the given alignment.

    Args:
      alignment:
        A list of token IDs. IDs are assumed to be non-negative, so -1
        can serve as a sentinel below.
      blank:
        ID of the blank.
    Returns:
      Return a list of TokenSpan, one per maximal run of identical
      non-blank tokens. The ``end`` frame of each span is exclusive; in
      particular, the last span may have ``end == len(alignment)``.
    """
    alignment_tensor = torch.tensor(alignment, dtype=torch.int32)

    # Surround the alignment with a sentinel (-1) that differs from every
    # valid token ID so that every run boundary — including the one before
    # the first frame and the one after the last — yields a non-zero diff.
    sentinel = torch.tensor([-1], dtype=torch.int32)
    diff = torch.diff(alignment_tensor, prepend=sentinel, append=sentinel)

    # Indexes where a new run starts; the final entry is len(alignment),
    # the boundary after the last frame. flatten() (rather than squeeze())
    # guarantees a 1-D result, so tolist() always yields a list even in
    # degenerate cases such as an empty alignment.
    boundaries = torch.nonzero(diff != 0).flatten().tolist()

    ans = []
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        token = alignment[start]
        if token == blank:
            # Runs of blanks produce no span.
            continue
        ans.append(TokenSpan(token=token, start=start, end=end))
    return ans

0 comments on commit 2a56d44

Please sign in to comment.