CTC API for JAX #18952

Merged
merged 23 commits into from
Dec 18, 2023

MaanasArora
Contributor

CTC loss implementation for JAX. I also included a documentation fix for the op.

Thank you!

codecov-commenter commented Dec 17, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (d550552) 79.55% compared to head (5e98c5c) 79.58%.
Report is 3 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18952      +/-   ##
==========================================
+ Coverage   79.55%   79.58%   +0.02%     
==========================================
  Files         337      337              
  Lines       35056    35116      +60     
  Branches     6872     6879       +7     
==========================================
+ Hits        27890    27947      +57     
- Misses       5587     5588       +1     
- Partials     1579     1581       +2     
Flag Coverage Δ
keras 79.44% <100.00%> (+0.02%) ⬆️
keras-jax 61.27% <96.87%> (+0.08%) ⬆️
keras-numpy 55.96% <3.12%> (-0.01%) ⬇️
keras-tensorflow 63.16% <3.12%> (-0.04%) ⬇️
keras-torch 63.81% <6.25%> (-0.04%) ⬇️


Collaborator

@fchollet fchollet left a comment

Thanks for the PR!

keras/ops/nn.py Outdated
@@ -1823,7 +1823,7 @@ def ctc_loss(target, output, target_length, output_length, mask_index=0):
     Args:
         target: A tensor of shape `(batch_size, target_max_length)` containing
             the true labels in integer format.
-        output: A tensor of shape `(batch_size, output_max_length, num_classes)`
+        output: A tensor of shape `(output_max_length, batch_size, num_classes)`
Collaborator

That seems very counterintuitive given that all output tensors in Keras start with the batch dimension. Surely we should transpose?

Contributor Author

Definitely, this is fixed now. Thanks!

@@ -819,6 +819,7 @@ def ctc_loss(
     target = tf.cast(target, dtype="int32")
     output = tf.convert_to_tensor(output)
     output = tf.cast(output, dtype="float32")
+    output = tf.transpose(output, perm=(1, 0, 2))
Collaborator

You can use the logits_time_major=False arg in tf.nn.ctc_loss to avoid this transposition.

Contributor Author

Right, done, thanks!
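
For reference, a minimal sketch of the suggestion above, assuming TensorFlow 2.x is installed. The shapes, the random logits, and the choice of `blank_index=0` are illustrative assumptions and are not taken from the PR. It only checks that passing batch-major logits with `logits_time_major=False` gives the same per-example loss as transposing to time-major by hand.

```python
import numpy as np
import tensorflow as tf

# Illustrative sizes (assumptions, not values from the PR).
batch_size, max_time, num_classes = 2, 5, 4

rng = np.random.default_rng(0)
# Batch-major logits: (batch_size, max_time, num_classes).
logits = tf.constant(
    rng.normal(size=(batch_size, max_time, num_classes)), dtype=tf.float32
)
# Dense labels padded with 0; class 0 is assumed to be the blank here.
labels = tf.constant([[1, 2, 0], [3, 1, 2]], dtype=tf.int32)
label_length = tf.constant([2, 3], dtype=tf.int32)
logit_length = tf.constant([max_time, max_time], dtype=tf.int32)

# Option A: let tf.nn.ctc_loss handle the batch-major layout.
loss_batch_major = tf.nn.ctc_loss(
    labels=labels,
    logits=logits,
    label_length=label_length,
    logit_length=logit_length,
    logits_time_major=False,
    blank_index=0,
)

# Option B: transpose to (max_time, batch_size, num_classes) manually.
loss_time_major = tf.nn.ctc_loss(
    labels=labels,
    logits=tf.transpose(logits, perm=(1, 0, 2)),
    label_length=label_length,
    logit_length=logit_length,
    logits_time_major=True,
    blank_index=0,
)

# Both calls return one loss value per batch element.
print(loss_batch_major.numpy(), loss_time_major.numpy())
```

Option A keeps the wrapper free of an explicit transpose, which is the point of the review comment.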

batch_size, _, _ = output.shape
batch_size, max_target_length = target.shape

output = output.transpose((1, 0, 2))
Collaborator

Is this necessary for the scan op on the first dimension?

Contributor Author

Yes AFAIK, since we're scanning along the time dimension and scan runs along the leading axis.
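
To illustrate the point about the leading axis, here is a small sketch assuming JAX is installed. It is not the CTC recursion from the PR: the `step` function is a hypothetical stand-in (a running sum) used only to show that `jax.lax.scan` slices its input along axis 0, so a batch-major tensor has to be transposed to time-major before scanning over time.

```python
import jax
import jax.numpy as jnp


def scan_over_time(output):
    """Run a toy recurrence over the time axis of a batch-major tensor.

    `output` has shape (batch_size, max_time, num_classes).
    """
    batch_size, max_time, num_classes = output.shape

    # lax.scan iterates over axis 0 of `xs`, so move time to the front:
    # (batch_size, max_time, num_classes) -> (max_time, batch_size, num_classes).
    time_major = jnp.transpose(output, (1, 0, 2))

    def step(carry, frame):
        # `frame` is one time step for the whole batch: (batch_size, num_classes).
        carry = carry + frame
        return carry, carry

    init = jnp.zeros((batch_size, num_classes), dtype=output.dtype)
    final, per_step = jax.lax.scan(step, init, time_major)

    # Return the stacked per-step results in batch-major layout again.
    return final, jnp.transpose(per_step, (1, 0, 2))


output = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
final, running = scan_over_time(output)
print(final.shape, running.shape)
```

Without the transpose, the same `scan` call would iterate over batch elements instead of time steps.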

Collaborator

@fchollet fchollet left a comment

LGTM, thank you!

google-ml-butler bot added the "ready to pull" label Dec 18, 2023
@fchollet fchollet merged commit 6a6e4f8 into keras-team:master Dec 18, 2023
6 checks passed
google-ml-butler bot removed the "ready to pull" and "kokoro:force-run" labels Dec 18, 2023