CTC API for JAX #18952

Merged
merged 23 commits into master from ctc-jax on Dec 18, 2023

Commits (23)
8ca8bf8
Implement CTC loss in tensorflow backend
MaanasArora Dec 12, 2023
3b2d7ca
Implement CTC api in torch backend
MaanasArora Dec 12, 2023
ecbb28e
Add CTC loss to keras losses
MaanasArora Dec 12, 2023
e7584f7
Remove CTC from losses
MaanasArora Dec 12, 2023
f7554b4
Perform log softmax in torch CTC loss
MaanasArora Dec 12, 2023
cbbd628
Refactor reviewed code in CTC API
MaanasArora Dec 12, 2023
cb97574
Fix formatting issue in docstring
MaanasArora Dec 13, 2023
bedad98
Removed trailing space
MaanasArora Dec 13, 2023
ab2efc9
Naming changes in nn.ctc_loss backend functions
MaanasArora Dec 13, 2023
0741531
Add ctc_loss keras op
MaanasArora Dec 13, 2023
1ffa35d
Add correctness unit test for CTC loss
MaanasArora Dec 13, 2023
4edf18d
Skip test for CTC loss in JAX backend
MaanasArora Dec 13, 2023
85865ba
Update ctc_loss function to also export to ops.nn
MaanasArora Dec 13, 2023
ab8a81b
Add static type testing for CTC loss
MaanasArora Dec 13, 2023
92e6654
Fix enabled backends for CTC loss test
MaanasArora Dec 14, 2023
97873a0
Linting keras ops
MaanasArora Dec 14, 2023
cfd949e
Fix line overflow in CtcLoss class
MaanasArora Dec 14, 2023
58caf83
CTC loss implementation for JAX
MaanasArora Dec 17, 2023
56d84dd
Merge branch 'master' of github.com:keras-team/keras into ctc-jax
MaanasArora Dec 17, 2023
90801dc
Fix shape order in ctc loss documentation
MaanasArora Dec 17, 2023
28032cd
Transpose output in CTC loss
MaanasArora Dec 17, 2023
9f266e3
Use logits_time_major instead of transpose in TF CTC
MaanasArora Dec 17, 2023
5e98c5c
Fix linting issue
MaanasArora Dec 17, 2023
75 changes: 75 additions & 0 deletions keras/backend/jax/nn.py
@@ -558,3 +558,78 @@ def batch_normalization(
        res = res + offset

    return x * inv + res


def ctc_loss(
    target,
    output,
    target_length,
    output_length,
    mask_index=0,
):
    batch_size, _, _ = output.shape
    batch_size, max_target_length = target.shape

    # Time-major layout: lax.scan iterates over the leading axis.
    output = output.transpose((1, 0, 2))
Review comment (Collaborator): Is this necessary for the scan op on the first dimension?

Reply (Contributor Author): Yes AFAIK, since we're scanning along the time dimension and scan runs along the leading axis.
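
To make that point concrete, here is a tiny standalone sketch (not part of the PR) showing that `lax.scan` consumes its `xs` argument slice by slice along the leading axis, which is why the logits are transposed to (time, batch, classes):

```python
import jax.numpy as jnp
from jax import lax

# lax.scan slices `xs` along axis 0, one slice per step, so a
# (time, batch, ...) layout yields one (batch, ...) slice per time step.
xs = jnp.arange(12.0).reshape(3, 4)  # 3 "time steps", batch of 4

def step(carry, x):
    total = carry + x  # running sum across time
    return total, total

final, partials = lax.scan(step, jnp.zeros(4), xs)
print(partials.shape)  # (3, 4): one output per leading-axis slice
```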

    target = target.transpose((1, 0))

    logits = jnn.log_softmax(output)
    mgrid_t, mgrid_b = jnp.meshgrid(
        jnp.arange(max_target_length), jnp.arange(batch_size)
    )
    logprobs_emit = logits[mgrid_t, mgrid_b, target[:, :, None]]
    logprobs_mask = logits[:, :, mask_index]

    logit_paddings = jnp.array(
        jnp.arange(max_target_length) < output_length[:, None],
        dtype=jnp.float32,
    )

    # Positions where a label repeats its predecessor must pass through
    # the mask (blank) state in the CTC lattice.
    repeat = jnp.array(target[1:] == target[:-1])
    repeat = jnp.pad(repeat, ((0, 1), (0, 0))).transpose((1, 0))

    # Large negative constant standing in for log(0).
    _logepsilon = -100000.0

    def _iterate(prev, x):
        prev_mask, prev_emit = prev
        logprob_mask, logprob_emit, pad = x

        prev_mask_orig = prev_mask
        prev_mask = prev_mask.at[:, 1:].set(
            jnp.logaddexp(prev_mask[:, 1:], prev_emit + _logepsilon * repeat),
        )
        emit = jnp.logaddexp(
            prev_mask[:, :-1] + logprob_emit, prev_emit + logprob_emit
        )

        mask = prev_mask + logprob_mask[:, None]
        mask = mask.at[:, 1:].set(
            jnp.logaddexp(
                mask[:, 1:],
                prev_emit + logprob_mask[:, None] + _logepsilon * (1 - repeat),
            )
        )

        # Padded time steps carry the previous state forward unchanged.
        pad = pad[:, None]
        emit = emit * pad + prev_emit * (1 - pad)
        mask = mask * pad + prev_mask_orig * (1 - pad)

        return (mask, emit), (mask, emit)

    mask_init = jnp.full((batch_size, max_target_length + 1), _logepsilon)
    mask_init = mask_init.at[:, 0].set(0.0)
    emit_init = jnp.full((batch_size, max_target_length), _logepsilon)

    # Forward pass of the CTC dynamic program over the time axis.
    _, (alphas_mask, alphas_emit) = lax.scan(
        _iterate,
        (mask_init, emit_init),
        (logprobs_mask, logprobs_emit, logit_paddings.transpose()),
    )

    last_alpha_mask = (
        alphas_mask[-1]
        .at[:, 1:]
        .set(jnp.logaddexp(alphas_mask[-1, :, 1:], alphas_emit[-1]))
    )

    # Negative log-likelihood of each target sequence.
    return -last_alpha_mask[jnp.arange(batch_size), target_length]
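
A minimal usage sketch of the new backend function, assuming the shape conventions implied by the code above: `output` is batch-major (batch, time, classes) logits, `target` is (batch, max_target_length) with the `mask_index` class (0 here) as blank/padding. All shapes and values below are illustrative, not from the PR:

```python
import jax.numpy as jnp
from keras.backend.jax.nn import ctc_loss

# Illustrative shapes: batch of 2, 10 time steps, 5 classes (class 0 = mask).
output = jnp.zeros((2, 10, 5))                    # unnormalized logits
target = jnp.array([[1, 2, 3, 0], [2, 4, 4, 1]])  # padded label sequences
target_length = jnp.array([3, 4])
output_length = jnp.array([10, 8])

loss = ctc_loss(target, output, target_length, output_length, mask_index=0)
print(loss.shape)  # (2,): negative log-likelihood per batch entry
```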
1 change: 1 addition & 0 deletions keras/backend/tensorflow/nn.py
@@ -838,4 +838,5 @@ def ctc_loss(
        label_length=target_length,
        logit_length=output_length,
        blank_index=mask_index,
+        logits_time_major=False,
    )
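
For reference, `logits_time_major=False` tells `tf.nn.ctc_loss` that the logits are already batch-major, i.e. shaped `[batch_size, max_time, num_classes]`, so this backend no longer needs an explicit transpose. A hedged sketch of the underlying call (values illustrative):

```python
import tensorflow as tf

# With logits_time_major=False, tf.nn.ctc_loss accepts batch-major logits
# of shape [batch_size, max_time, num_classes] directly.
labels = tf.constant([[1, 2, 1], [3, 4, 2]], dtype=tf.int32)
logits = tf.zeros([2, 6, 5])  # [batch_size, max_time, num_classes]

loss = tf.nn.ctc_loss(
    labels=labels,
    logits=logits,
    label_length=tf.constant([3, 3]),
    logit_length=tf.constant([6, 6]),
    blank_index=0,
    logits_time_major=False,
)
print(loss.shape)  # (2,)
```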
1 change: 1 addition & 0 deletions keras/backend/torch/nn.py
@@ -759,6 +759,7 @@ def ctc_loss(
    target_length = convert_to_tensor(target_length)
    output_length = convert_to_tensor(output_length)

+    output = torch.transpose(output, 1, 0)
    logits = tnn.log_softmax(output, dim=-1)

    return tnn.ctc_loss(
        ...
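
The transpose added here reflects the shape contract of `torch.nn.functional.ctc_loss`: it expects time-major log-probabilities of shape `(T, N, C)`. A standalone sketch under that assumption (shapes illustrative):

```python
import torch
import torch.nn.functional as F

# F.ctc_loss expects log-probabilities shaped (T, N, C), i.e. time-major,
# hence the transpose of the batch-major (N, T, C) logits above.
output = torch.randn(2, 10, 5)                             # (N, T, C) logits
log_probs = F.log_softmax(output.transpose(1, 0), dim=-1)  # -> (T, N, C)

targets = torch.tensor([[1, 2, 3], [2, 4, 1]])
input_lengths = torch.tensor([10, 8])
target_lengths = torch.tensor([3, 3])

loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0)
print(loss)  # mean-reduced loss by default
```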
13 changes: 6 additions & 7 deletions keras/ops/nn_test.py
@@ -975,8 +975,8 @@ def test_batch_normalization(self):
        )

    @pytest.mark.skipif(
-        backend.backend() not in ["tensorflow", "torch"],
-        reason="Only TF and Torch support CTC loss",
+        backend.backend() == "numpy",
+        reason="Numpy does not support CTC loss",
    )
    def test_ctc_loss(self):
        x = KerasTensor([10, 3, 4])
@@ -1762,16 +1762,15 @@ def test_batch_normalization(self):
        self.assertEqual(tuple(output.shape), (2, 3, 3, 5))

    @pytest.mark.skipif(
-        backend.backend() not in ["tensorflow", "torch"],
-        reason="Only TF and Torch support CTC loss",
+        backend.backend() == "numpy",
+        reason="Numpy does not support CTC loss",
    )
    def test_ctc_loss(self):
        labels = np.array([[1, 2, 1], [1, 2, 2]])
        outputs = np.array(
            [
-                [[0.4, 0.8, 0.4], [0.4, 0.8, 0.4]],
-                [[0.2, 0.8, 0.3], [0.2, 0.3, 0.3]],
-                [[0.9, 0.4, 0.5], [0.4, 0.3, 0.2]],
+                [[0.4, 0.8, 0.4], [0.2, 0.8, 0.3], [0.9, 0.4, 0.5]],
+                [[0.4, 0.8, 0.4], [0.2, 0.3, 0.3], [0.4, 0.3, 0.2]],
            ]
        )

        ...
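
As context for the test above, a hedged sketch of how the public op is exercised; the expected values and the exact sequence lengths live in the collapsed part of the diff, so `label_length` and `output_length` below are illustrative assumptions:

```python
import numpy as np
from keras import ops

labels = np.array([[1, 2, 1], [1, 2, 2]])
outputs = np.array(
    [
        [[0.4, 0.8, 0.4], [0.2, 0.8, 0.3], [0.9, 0.4, 0.5]],
        [[0.4, 0.8, 0.4], [0.2, 0.3, 0.3], [0.4, 0.3, 0.2]],
    ]
)
label_length = np.array([3, 2])   # assumed lengths, not from the diff
output_length = np.array([3, 2])  # assumed lengths, not from the diff

# ctc_loss is exported under ops.nn per the commit
# "Update ctc_loss function to also export to ops.nn".
loss = ops.nn.ctc_loss(labels, outputs, label_length, output_length)
print(ops.convert_to_numpy(loss))  # one loss value per batch entry
```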