Losses and decoders for end-to-end Speech Recognition and Optical Character Recognition with PyTorch
The module focuses on experiments with CTC-loss (Connectionist Temporal Classification) and its modifications.
Under active development.
Currently implemented:
- CTC (C++, CPU)
- CTC Greedy Decoder (C++, CPU)
- CTC Beam Search Decoder (C++, CPU)
- CTC Beam Search Decoder with language model (C++, CPU)
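The greedy (best-path) decoder listed above follows the usual CTC rule: take the most likely symbol in every frame, merge consecutive repeats, then remove blanks. Below is a minimal pure-Python sketch for a single utterance. It assumes blank index 0, as in the examples in this README. `ctc_greedy_decode` is a hypothetical helper for illustration only. It is not part of the pytorch_end2end API and not the module's C++ implementation.

```python
from typing import List, Sequence


def ctc_greedy_decode(probs: Sequence[Sequence[float]], blank_idx: int = 0) -> List[int]:
    """Best-path CTC decoding for one utterance (illustrative sketch).

    probs: a (time, alphabet) table of per-frame scores. Probabilities,
    log-probabilities or raw logits all work, since only the argmax matters.
    """
    # 1. Pick the most likely symbol in every frame.
    best_path = [max(range(len(frame)), key=frame.__getitem__) for frame in probs]
    # 2. Collapse runs of the same symbol, then drop blanks.
    decoded = []
    previous = None
    for symbol in best_path:
        if symbol != previous and symbol != blank_idx:
            decoded.append(symbol)
        previous = symbol
    return decoded


if __name__ == "__main__":
    # Hypothetical alphabet: 0 = blank, 1 = "a", 2 = "b"
    frames = [
        [0.1, 0.8, 0.1],    # a
        [0.1, 0.7, 0.2],    # a (repeat, merged with the previous frame)
        [0.9, 0.05, 0.05],  # blank
        [0.2, 0.6, 0.2],    # a (a new token, because a blank separates it)
        [0.1, 0.1, 0.8],    # b
    ]
    print(ctc_greedy_decode(frames))  # [1, 1, 2]
```

The blank is what lets CTC emit the same label twice in a row ("aa"): without the blank frame in the middle, the three "a" frames would collapse into a single "a".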
Requirements:
- Python 3.6+
- Tested with PyTorch 1.6.0+ (may be compatible with other versions)
Installation:
- Install PyTorch from pytorch.org, e.g.
  pip install torch
- Install the tools needed to compile the C++ extensions (Boost and g++-7), e.g. on Ubuntu:
  sudo add-apt-repository ppa:ubuntu-toolchain-r/test -y && \
  sudo apt-get update && \
  sudo apt-get install libboost-all-dev g++-7 -y
- Install the module:
  pip install -v git+https://github.com/artbataev/end2end.git
  or build it from source and run the tests:
  git clone --recursive https://github.com/artbataev/end2end.git
  cd end2end
  python setup.py install
  python -m tests.test_ctc
  python -m tests.test_ctc_decoder
Using CTC loss:
import torch
from pytorch_end2end import CTCLoss

ctc_loss = CTCLoss(blank_idx=0, time_major=False,
                   reduce=True, size_average=True, after_logsoftmax=False)

batch_size = 4
alphabet_size = 28  # blank + 26 English characters + space
# time_major=False: logits have shape (batch, time, alphabet)
logits = torch.randn(batch_size, 50, alphabet_size).detach().requires_grad_()
# targets use indices 1..alphabet_size-1; index 0 is the blank
targets = torch.randint(1, alphabet_size, (batch_size, 30), dtype=torch.long)
logits_lengths = torch.full((batch_size,), 50, dtype=torch.long)
targets_lengths = torch.randint(10, 30, (batch_size,), dtype=torch.long)

loss = ctc_loss(logits, targets, logits_lengths, targets_lengths)
loss.backward()
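CTC loss is the negative log-probability of the target sequence, summed over every frame-level alignment that collapses to it (merge repeats, drop blanks). The sketch below computes that quantity for one utterance with the standard forward (alpha) recursion in log space. It is an illustrative pure-Python reference, not the module's C++ code. `ctc_neg_log_likelihood` is a hypothetical name. The sketch assumes its input is already per-frame log-probabilities: the example above passes raw logits with after_logsoftmax=False, which presumably makes the module apply log-softmax itself. Batching and the reduce/size_average reductions are left out.

```python
import math
from typing import List, Sequence


def log_add(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def ctc_neg_log_likelihood(log_probs: Sequence[Sequence[float]],
                           target: Sequence[int],
                           blank_idx: int = 0) -> float:
    """-log p(target | input) via the CTC forward recursion (illustrative sketch).

    log_probs: (time, alphabet) per-frame log-probabilities of one utterance.
    target: label indices, without blanks.
    """
    # Extended label sequence: blank, l1, blank, l2, ..., blank
    ext: List[int] = [blank_idx]
    for label in target:
        ext += [label, blank_idx]
    num_states, num_frames = len(ext), len(log_probs)

    # alpha[s]: log-probability of all partial alignments that end in state s
    alpha = [-math.inf] * num_states
    alpha[0] = log_probs[0][ext[0]]
    if num_states > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, num_frames):
        new_alpha = [-math.inf] * num_states
        for s in range(num_states):
            total = alpha[s]  # stay in the same state
            if s >= 1:
                total = log_add(total, alpha[s - 1])  # advance by one state
            # skip a blank: only allowed between two different labels
            if s >= 2 and ext[s] != blank_idx and ext[s] != ext[s - 2]:
                total = log_add(total, alpha[s - 2])
            new_alpha[s] = total + log_probs[t][ext[s]]
        alpha = new_alpha

    # A valid alignment ends on the last label or on the trailing blank.
    log_likelihood = alpha[num_states - 1]
    if num_states > 1:
        log_likelihood = log_add(log_likelihood, alpha[num_states - 2])
    return -log_likelihood


if __name__ == "__main__":
    # Uniform 3-symbol distribution, 2 frames, target [1]:
    # the alignments (1,1), (0,1), (1,0) each have probability 1/9.
    uniform = [[math.log(1 / 3)] * 3] * 2
    print(round(ctc_neg_log_likelihood(uniform, [1]), 4))  # 1.0986, i.e. log 3
```

Working in log space keeps long utterances from underflowing. The skip rule is why a repeated label such as [1, 1] needs at least one blank frame between its two tokens.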
Using the CTC decoder:
import torch
from pytorch_end2end import CTCDecoder

batch_size = 4
alphabet_size = 6
# labels[i] is the character for index i; index 0 ("_") is the blank
decoder = CTCDecoder(blank_idx=0, beam_width=100,
                     time_major=False, after_logsoftmax=False,
                     labels=["_", "a", "b", "c", "d", " "])
# time_major=False: logits have shape (batch, time, alphabet)
logits = torch.randn(batch_size, 50, alphabet_size).detach()
logits_lengths = torch.full((batch_size,), 50, dtype=torch.long)
decoded_targets, decoded_targets_lengths, decoded_sentences = decoder.decode(logits, logits_lengths)
for sentence in decoded_sentences:
    print(sentence)
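The decoder above is configured with beam_width=100, i.e. it searches over label prefixes rather than taking only the best path. A common formulation is CTC prefix beam search, where each prefix keeps two scores: one for paths ending in a blank and one for paths ending in a non-blank symbol. Below is a minimal single-utterance sketch without a language model. `ctc_prefix_beam_search` is a hypothetical helper, and this is not claimed to match the module's C++ decoder in detail. As in the loss sketch, the input is assumed to be per-frame log-probabilities.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

NEG_INF = -math.inf


def log_add(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def ctc_prefix_beam_search(log_probs: Sequence[Sequence[float]],
                           beam_width: int = 10,
                           blank_idx: int = 0) -> List[Tuple[Tuple[int, ...], float]]:
    """Return (prefix, log-probability) pairs, best first (illustrative sketch)."""
    # prefix -> (log p ending in blank, log p ending in non-blank)
    beams: Dict[Tuple[int, ...], Tuple[float, float]] = {(): (0.0, NEG_INF)}
    for frame in log_probs:
        next_beams: Dict[Tuple[int, ...], List[float]] = defaultdict(lambda: [NEG_INF, NEG_INF])
        for prefix, (p_b, p_nb) in beams.items():
            p_total = log_add(p_b, p_nb)
            last = prefix[-1] if prefix else None
            for symbol, lp in enumerate(frame):
                if symbol == blank_idx:
                    # blank: the prefix is unchanged and now ends in a blank
                    entry = next_beams[prefix]
                    entry[0] = log_add(entry[0], p_total + lp)
                elif symbol == last:
                    # a repeat starts a new token only if a blank came before it
                    extended = next_beams[prefix + (symbol,)]
                    extended[1] = log_add(extended[1], p_b + lp)
                    # otherwise it merges into the current last token
                    same = next_beams[prefix]
                    same[1] = log_add(same[1], p_nb + lp)
                else:
                    extended = next_beams[prefix + (symbol,)]
                    extended[1] = log_add(extended[1], p_total + lp)
        # keep only the beam_width most probable prefixes
        ranked = sorted(next_beams.items(), key=lambda kv: log_add(*kv[1]), reverse=True)
        beams = {prefix: (scores[0], scores[1]) for prefix, scores in ranked[:beam_width]}
    return sorted(((prefix, log_add(p_b, p_nb)) for prefix, (p_b, p_nb) in beams.items()),
                  key=lambda item: item[1], reverse=True)


if __name__ == "__main__":
    probs = [[0.5, 0.3, 0.2], [0.1, 0.6, 0.3], [0.4, 0.4, 0.2]]
    log_probs = [[math.log(p) for p in row] for row in probs]
    best_prefix, best_score = ctc_prefix_beam_search(log_probs, beam_width=10)[0]
    print(best_prefix, round(math.exp(best_score), 3))  # (1,) 0.416
```

With a beam wide enough to keep every prefix, the scores are exact prefix probabilities; smaller beams trade accuracy for speed. A language model, as in the module's LM-enabled decoder, is typically added as an extra weighted score whenever a prefix is extended.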
TODO:
- Gram-CTC
- CTC without blank (C++, CPU)
- CTC (CUDA)
- CTC without blank refactoring
- Dynamic segmentation with CTC refactoring
- Restrict Beam Search with vocabulary
- Allow custom transcriptions
- Gram-CTC Beam Search Decoder