CTC API for JAX #18952
Conversation
- Refactor sparse labels into the main `ctc_batch_cost` function for TF
Codecov Report: All modified and coverable lines are covered by tests ✅
Additional details and impacted files:
@@            Coverage Diff             @@
##           master   #18952      +/-   ##
==========================================
+ Coverage   79.55%   79.58%   +0.02%
==========================================
  Files         337      337
  Lines       35056    35116      +60
  Branches     6872     6879       +7
==========================================
+ Hits        27890    27947      +57
- Misses       5587     5588       +1
- Partials     1579     1581       +2

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Thanks for the PR!
keras/ops/nn.py
Outdated
@@ -1823,7 +1823,7 @@ def ctc_loss(target, output, target_length, output_length, mask_index=0):
     Args:
         target: A tensor of shape `(batch_size, target_max_length)` containing
             the true labels in integer format.
-        output: A tensor of shape `(batch_size, output_max_length, num_classes)`
+        output: A tensor of shape `(output_max_length, batch_size, num_classes)`
That seems very counterintuitive given that all output tensors in Keras start with the batch dimension. Surely we should transpose?
Definitely, this is fixed now. Thanks!
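For context, here is a minimal sketch of the convention being settled in this exchange (names and shapes are illustrative, not the PR's actual code): the public op stays batch-first, and any backend that needs time-major input transposes internally.

```python
import numpy as np

# Illustrative shapes only: batch-first logits, as Keras ops expect.
batch_size, max_time, num_classes = 4, 50, 20
output = np.zeros((batch_size, max_time, num_classes), dtype="float32")

# A backend that needs time-major input (e.g. for a scan over time)
# transposes internally, so the public API stays batch-first.
output_time_major = np.transpose(output, (1, 0, 2))
assert output_time_major.shape == (max_time, batch_size, num_classes)
```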
keras/backend/tensorflow/nn.py
Outdated
@@ -819,6 +819,7 @@ def ctc_loss(
     target = tf.cast(target, dtype="int32")
     output = tf.convert_to_tensor(output)
     output = tf.cast(output, dtype="float32")
+    output = tf.transpose(output, perm=(1, 0, 2))
You can use the `logits_time_major=False` arg in `tf.nn.ctc_loss` to avoid this transposition.
Right, done, thanks!
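For reference, a minimal sketch of that suggestion (shapes and values are illustrative; passing `blank_index=0` here is an assumption mirroring the op's `mask_index=0` default):

```python
import tensorflow as tf

batch_size, max_time, num_classes, max_label_len = 2, 10, 5, 4

# Batch-first logits: no manual transpose is needed when
# logits_time_major=False is passed to tf.nn.ctc_loss.
output = tf.random.normal((batch_size, max_time, num_classes))
target = tf.random.uniform(
    (batch_size, max_label_len), minval=1, maxval=num_classes, dtype=tf.int32
)

loss = tf.nn.ctc_loss(
    labels=target,
    logits=output,
    label_length=tf.fill((batch_size,), max_label_len),
    logit_length=tf.fill((batch_size,), max_time),
    logits_time_major=False,  # accepts (batch, time, classes) directly
    blank_index=0,
)
```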
    batch_size, _, _ = output.shape
    batch_size, max_target_length = target.shape

    output = output.transpose((1, 0, 2))
Is this necessary for the scan op on the first dimension?
Yes AFAIK, since we're scanning along the time dimension and scan runs along the leading axis.
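A small sketch of why the transpose is needed (illustrative, not the PR's code): `jax.lax.scan` iterates over the leading axis of its input, so time must come first before scanning over timesteps.

```python
import jax
import jax.numpy as jnp

batch_size, max_time, num_classes = 2, 10, 5
output = jnp.zeros((batch_size, max_time, num_classes))

# lax.scan consumes xs along axis 0, one slice per step, so the
# logits must be time-major before scanning over timesteps.
output = output.transpose((1, 0, 2))  # -> (time, batch, classes)

def step(carry, logits_t):
    # logits_t has shape (batch, classes): one timestep for the whole batch.
    return carry, logits_t.sum()

_, per_step = jax.lax.scan(step, 0.0, output)
assert per_step.shape == (max_time,)
```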
LGTM, thank you!
CTC loss implementation for JAX. I also included a documentation fix for the op.
Thank you!
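A hedged usage sketch of the op this PR enables on the JAX backend (shapes are illustrative, and it assumes `keras.ops.ctc_loss` is the public path for the `ctc_loss` defined in `keras/ops/nn.py`):

```python
import numpy as np
import keras

batch_size, max_time, num_classes, max_label_len = 2, 10, 5, 4

# Batch-first logits and integer labels, with 0 reserved as the
# blank (mask) class per the op's mask_index=0 default.
output = np.random.randn(batch_size, max_time, num_classes).astype("float32")
target = np.random.randint(1, num_classes, (batch_size, max_label_len))

loss = keras.ops.ctc_loss(
    target=target,
    output=output,
    target_length=np.full((batch_size,), max_label_len),
    output_length=np.full((batch_size,), max_time),
)
print(loss.shape)  # expected: one loss value per batch sample
```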