From 7711d160e4086809f2481aaeae79287dd41c60fd Mon Sep 17 00:00:00 2001
From: Desh Raj
Date: Tue, 19 Dec 2023 06:19:27 +0530
Subject: [PATCH] Hybrid autoregressive transducer (HAT) (#1244)

* removed workflow

* minor fix in nbest str representation

* initial commit for HAT loss

* add HAT loss

* remove unnecessary style changes

* fix style issue

* put hat option at end
---
 k2/python/k2/nbest.py             |   2 +-
 k2/python/k2/rnnt_loss.py         | 221 ++++++++++++++++++++++++++++--
 k2/python/tests/rnnt_loss_test.py |  68 +++++++++
 3 files changed, 282 insertions(+), 9 deletions(-)

diff --git a/k2/python/k2/nbest.py b/k2/python/k2/nbest.py
index f71f9fefa..d2e7814c9 100644
--- a/k2/python/k2/nbest.py
+++ b/k2/python/k2/nbest.py
@@ -93,7 +93,7 @@ def __init__(self,
 
     def __str__(self):
         s = 'Nbest('
-        s += f'num_seqs:{self.shape.dim0()}, '
+        s += f'num_seqs:{self.shape.dim0}, '
         s += f'num_fsas:{self.fsa.shape[0]})'
         return s
 
diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py
index 7e07271d4..9a80c5f9d 100644
--- a/k2/python/k2/rnnt_loss.py
+++ b/k2/python/k2/rnnt_loss.py
@@ -989,6 +989,194 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
     return torch.gather(src, 2, index)
 
 
+def get_hat_logprobs_pruned(
+    logits: Tensor,
+    symbols: Tensor,
+    ranges: Tensor,
+    termination_symbol: int,
+    boundary: Tensor,
+    rnnt_type: str = "regular",
+) -> Tuple[Tensor, Tensor]:
+    """Construct px, py for mutual_information_recursion with pruned output.
+
+    This is a variant of get_rnnt_logprobs_pruned based on the Hybrid
+    Autoregressive Transducer (HAT) model proposed in
+    https://arxiv.org/abs/2003.07705.
+
+    NOTE: We assume that the RNNT blank is the zeroth symbol.
+
+    Args:
+      logits:
+        The pruned output of joiner network, with shape (B, T, s_range, C)
+      symbols:
+        The symbol sequences, a LongTensor of shape [B][S], and elements in
+        {0..C-1}.
+      ranges:
+        A tensor containing the symbol ids for each frame that we want to
+        keep. It is a LongTensor of shape ``[B][T][s_range]``, where
+        ``ranges[b,t,0]`` contains the begin symbol ``0 <= s <= S - s_range + 1``,
+        such that ``logits[b,t,:,:]`` represents the logits with positions
+        ``s, s + 1, ... s + s_range - 1``.
+        See docs in :func:`get_rnnt_prune_ranges` for more details of what
+        ranges contains.
+      termination_symbol:
+        the termination symbol, which must be 0 here, since HAT assumes the
+        blank is the zeroth symbol.
+      boundary:
+        an optional LongTensor of shape [B, 4] with elements interpreted as
+        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
+        [0, 0, S, T] if boundary is not supplied.
+        Most likely you will want begin_symbol and begin_frame to be zero.
+      rnnt_type:
+        Specifies the type of rnnt paths: `regular`, `modified` or `constrained`.
+        `regular`: The regular rnnt that takes you to the next frame only if
+                   emitting a blank (i.e., emitting a symbol does not take you
+                   to the next frame).
+        `modified`: A modified version of rnnt that will take you to the next
+                    frame whether emitting a blank or a non-blank symbol.
+        `constrained`: A version like the modified one that will go to the next
+                       frame when you emit a non-blank symbol, but this is done
+                       by "forcing" you to take the blank transition from the
+                       *next* context on the *current* frame, e.g. if we emit
+                       c given "a b" context, we are forced to emit "blank"
+                       given "b c" context on the current frame.
+    Returns:
+      (px, py) (the names are quite arbitrary)::
+
+        px: logprobs, of shape [B][S][T+1] if rnnt_type is regular,
+                      [B][S][T] if rnnt_type is not regular.
+        py: logprobs, of shape [B][S+1][T]
+
+      in the recursion::
+
+        p[b,0,0] = 0.0
+        if rnnt_type == "regular":
+           p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
+                              p[b,s,t-1] + py[b,s,t-1])
+        if rnnt_type != "regular":
+           p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
+                              p[b,s,t-1] + py[b,s,t-1])
+
+      .. where p[b][s][t] is the "joint score" of the pair of subsequences of
+      length s and t respectively. px[b][s][t] represents the probability of
+      extending the subsequences of length (s,t) by one in the s direction,
+      given the particular symbol, and py[b][s][t] represents the probability
+      of extending the subsequences of length (s,t) by one in the t direction,
+      i.e. of emitting the termination/next-frame symbol.
+
+      if `rnnt_type == "regular"`, px[:,:,T] equals -infinity, meaning on the
+      "one-past-the-last" frame we cannot emit any symbols. This is simply a
+      way of incorporating the probability of the termination symbol on the
+      last frame.
+    """
+    # logits (B, T, s_range, C)
+    # symbols (B, S)
+    # ranges (B, T, s_range)
+    assert logits.ndim == 4, logits.shape
+    (B, T, s_range, C) = logits.shape
+    assert ranges.shape == (B, T, s_range), (ranges.shape, B, T, s_range)
+    (B, S) = symbols.shape
+    assert S >= 0, S
+    assert (
+        rnnt_type != "modified" or T >= S
+    ), f"Modified transducer requires T >= S, but got T={T} and S={S}"
+    assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
+    assert (
+        termination_symbol == 0
+    ), f"Termination symbol must be 0, but got {termination_symbol}"
+
+    # For the blank symbol, the log-prob is the log-sigmoid of its score.
+    logp_b = torch.nn.functional.logsigmoid(logits[..., 0])
+
+    # For non-blanks, we will compute the log-probs using log-softmax, for
+    # which we will need the following normalization factor.
+    nb_normalizers = torch.logsumexp(logits[..., 1:], dim=3)
+
+    # Additionally, to ensure that the probs of blank and non-blank sum to 1,
+    # we need to add the following term to the log-probs of non-blank symbols.
+    # It is equivalent to log(1 - sigmoid(logits[..., 0])).
+    nb_shift = logp_b - logits[..., 0]
+
+    symbols_with_terminal = torch.cat(
+        (
+            symbols,
+            torch.tensor(
+                [termination_symbol] * B,
+                dtype=torch.int64,
+                device=symbols.device,
+            ).reshape((B, 1)),
+        ),
+        dim=1,
+    )
+
+    # (B, T, s_range)
+    pruned_symbols = torch.gather(
+        symbols_with_terminal.unsqueeze(1).expand((B, T, S + 1)),
+        dim=2,
+        index=ranges,
+    )
+
+    # (B, T, s_range)
+    px = torch.gather(
+        logits, dim=3, index=pruned_symbols.reshape(B, T, s_range, 1)
+    ).squeeze(-1)
+    px = px - nb_normalizers + nb_shift
+
+    # (B, T, S + 1), with indices >= s_range in dim 2 filled with -inf
+    px = torch.cat(
+        (
+            px,
+            torch.full(
+                (B, T, S + 1 - s_range),
+                float("-inf"),
+                device=px.device,
+                dtype=px.dtype,
+            ),
+        ),
+        dim=2,
+    )
+
+    # (B, T, S), with indices outside the pruning range in dim 2 filled
+    # with -inf
+    px = _roll_by_shifts(px, ranges[:, :, 0])[:, :, :S]
+
+    px = px.permute((0, 2, 1))
+
+    if rnnt_type == "regular":
+        px = torch.cat(
+            (
+                px,
+                torch.full(
+                    (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
+                ),
+            ),
+            dim=2,
+        )  # now: [B][S][T+1], index [:,:,T] has -inf.
+
+    py = logp_b.clone()  # (B, T, s_range)
+    # py holds the blank log-probs; these are already normalized via the
+    # sigmoid, so no normalizer or shift is needed here.
+    # Note that py denotes the horizontal (blank transition) arcs on the
+    # RNNT lattice.
+
+    # (B, T, S + 1), with indices >= s_range in dim 2 filled with -inf
+    py = torch.cat(
+        (
+            py,
+            torch.full(
+                (B, T, S + 1 - s_range),
+                float("-inf"),
+                device=py.device,
+                dtype=py.dtype,
+            ),
+        ),
+        dim=2,
+    )
+
+    # (B, T, S + 1), with indices outside the pruning range in dim 2 filled
+    # with -inf
+    py = _roll_by_shifts(py, ranges[:, :, 0])
+    # (B, S + 1, T)
+    py = py.permute((0, 2, 1))
+
+    if rnnt_type == "regular":
+        px = fix_for_boundary(px, boundary)
+    elif rnnt_type == "constrained":
+        px += py[:, 1:, :]
+
+    return (px, py)
+
+
 def get_rnnt_logprobs_pruned(
     logits: Tensor,
     symbols: Tensor,
@@ -1169,6 +1357,7 @@ def rnnt_loss_pruned(
     rnnt_type: str = "regular",
     delay_penalty: float = 0.0,
     reduction: Optional[str] = "mean",
+    use_hat_loss: bool = False,
 ) -> Tensor:
     """A RNN-T loss with pruning, which uses the output of a pruned 'joiner'
     network as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
@@ -1219,19 +1408,35 @@ def rnnt_loss_pruned(
         `mean`: apply `torch.mean` over the batches.
         `sum`: the output will be summed.
         Default: `mean`
+      use_hat_loss:
+        If True, we compute the Hybrid Autoregressive Transducer (HAT) loss
+        from https://arxiv.org/abs/2003.07705. This is a variant of RNN-T that
+        models the blank distribution separately as a Bernoulli distribution,
+        and the non-blank symbols as a multinomial distribution. This
+        formulation may be useful for performing internal LM estimation, as
+        described in the paper.
     Returns:
       If reduction is `none`, returns a tensor of shape (B,), containing the
       total RNN-T loss values for each sequence of the batch, otherwise a
       scalar with the reduction applied.
     """
-    px, py = get_rnnt_logprobs_pruned(
-        logits=logits,
-        symbols=symbols,
-        ranges=ranges,
-        termination_symbol=termination_symbol,
-        boundary=boundary,
-        rnnt_type=rnnt_type,
-    )
+    if not use_hat_loss:
+        px, py = get_rnnt_logprobs_pruned(
+            logits=logits,
+            symbols=symbols,
+            ranges=ranges,
+            termination_symbol=termination_symbol,
+            boundary=boundary,
+            rnnt_type=rnnt_type,
+        )
+    else:
+        px, py = get_hat_logprobs_pruned(
+            logits=logits,
+            symbols=symbols,
+            ranges=ranges,
+            termination_symbol=termination_symbol,
+            boundary=boundary,
+            rnnt_type=rnnt_type,
+        )
 
     if delay_penalty > 0.0:
         B, S, T0 = px.shape
diff --git a/k2/python/tests/rnnt_loss_test.py b/k2/python/tests/rnnt_loss_test.py
index 917d2f936..c8814c0f5 100644
--- a/k2/python/tests/rnnt_loss_test.py
+++ b/k2/python/tests/rnnt_loss_test.py
@@ -856,6 +856,74 @@ def test_rnnt_loss_pruned_small_s_range(self):
                 ), f"Pruned loss is inf for r={r}, S={S}, T={T}."
print(f"Pruned loss with range {r} : {pruned_loss}") + def test_hat_loss_pruned(self): + B = 4 + T = 300 + S = 50 + C = 10 + + frames = torch.randint(S, T, (B,)) + seq_length = torch.randint(3, S - 1, (B,)) + T = torch.max(frames) + S = torch.max(seq_length) + + am_ = torch.randn((B, T, C), dtype=torch.float64) + lm_ = torch.randn((B, S + 1, C), dtype=torch.float64) + symbols_ = torch.randint(1, C, (B, S)) + terminal_symbol = 0 + + boundary_ = torch.zeros((B, 4), dtype=torch.int64) + boundary_[:, 2] = seq_length + boundary_[:, 3] = frames + + for rnnt_type in ["regular", "modified", "constrained"]: + for device in self.devices: + # normal rnnt + am = am_.to(device) + lm = lm_.to(device) + symbols = symbols_.to(device) + boundary = boundary_.to(device) + + # pruning + k2_simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple( + lm=lm, + am=am, + symbols=symbols, + termination_symbol=terminal_symbol, + boundary=boundary, + rnnt_type=rnnt_type, + return_grad=True, + reduction="none", + ) + + for r in range(2, 50, 5): + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=r, + ) + # (B, T, r, C) + pruned_am, pruned_lm = k2.do_rnnt_pruning( + am=am, lm=lm, ranges=ranges + ) + + logits = pruned_am + pruned_lm + # nonlinear transform + logits = torch.tanh(logits) + + pruned_loss = k2.rnnt_loss_pruned( + logits=logits, + symbols=symbols, + ranges=ranges, + termination_symbol=terminal_symbol, + boundary=boundary, + rnnt_type=rnnt_type, + reduction="none", + use_hat_loss=True, + ) + print(f"Pruned HAT loss with range {r} : {pruned_loss}") + # Check that training with an empty reference does not cause a crash. def _test_rnnt_loss_empty_reference(self): B = 1