forked from iamjanvijay/rnnt

An implementation of RNN-Transducer loss in TF-2.0.

RNN-Transducer Loss

This package provides an implementation of the RNN-Transducer loss in TensorFlow 2.0.

Using the package

First, install the module with pip:

pip install rnnt

Then use the "rnnt_loss" function from the "rnnt" module, as described in the sample script: Sample Train Script

import tensorflow as tf

from rnnt import rnnt_loss

def loss_grad_gradtape(logits, labels, label_lengths, logit_lengths):
    # Record operations on logits so the loss can be differentiated w.r.t. them.
    with tf.GradientTape() as g:
        g.watch(logits)
        loss = rnnt_loss(logits, labels, label_lengths, logit_lengths)
    grad = g.gradient(loss, logits)
    return loss, grad

pred_loss, pred_grads = loss_grad_gradtape(logits, labels, label_lengths, logit_lengths)

The following are the expected shapes of the input parameters to the rnnt_loss method -
logits - (batch_size, input_time_steps, output_time_steps+1, vocab_size+1)
labels - (batch_size, output_time_steps)
label_lengths - (batch_size) - number of time steps for each output sequence in the minibatch.
logit_lengths - (batch_size) - number of time steps for each input sequence in the minibatch.
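To make the shapes above concrete, the snippet below builds a toy batch with NumPy. The sizes (batch of 2, 5 input frames, 3 output labels, vocabulary of 4) are arbitrary assumptions for illustration, not values required by the package; note the extra "+1" dimensions in logits for the blank transition step and the blank symbol.

```python
import numpy as np

# Illustrative sizes (assumptions for this sketch, not fixed by the package):
batch_size = 2
input_time_steps = 5   # T: number of acoustic frames
output_time_steps = 3  # U: maximum label sequence length
vocab_size = 4         # output alphabet size, excluding the blank symbol

# logits has one extra output step and one extra vocab slot for the blank.
logits = np.random.randn(
    batch_size, input_time_steps, output_time_steps + 1, vocab_size + 1
).astype(np.float32)

# labels are integer indices in [1, vocab_size]; 0 is reserved for blank here.
labels = np.random.randint(
    1, vocab_size + 1, size=(batch_size, output_time_steps)
).astype(np.int32)

# Per-example valid lengths; each entry must not exceed the padded dimension.
label_lengths = np.array([3, 2], dtype=np.int32)
logit_lengths = np.array([5, 4], dtype=np.int32)

print(logits.shape)  # (2, 5, 4, 5)
print(labels.shape)  # (2, 3)
```

These arrays (converted to tensors) match the shapes that rnnt_loss expects for its four arguments.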
