diff --git a/nemo/collections/asr/parts/k2/graph_transducer.py b/nemo/collections/asr/parts/k2/graph_transducer.py index bcd49bcbd7a9..874e6e6fd2b4 100644 --- a/nemo/collections/asr/parts/k2/graph_transducer.py +++ b/nemo/collections/asr/parts/k2/graph_transducer.py @@ -15,11 +15,17 @@ import abc from contextlib import nullcontext from typing import ContextManager + import torch import torch.nn.functional as F from nemo.core.classes.loss import Loss from nemo.core.utils.k2_guard import k2 +from nemo.core.utils.optional_libs import TRITON_AVAILABLE +from nemo.utils import logging + +if TRITON_AVAILABLE: + from nemo.collections.asr.parts.k2.rnnt_logprobs_triton import rnnt_logprobs_triton def force_float32_context() -> ContextManager: @@ -129,13 +135,13 @@ def get_composed_lattice(self, units_tensor: torch.Tensor, num_frames: int, voca return composed def get_graphs_batched( - self, logits_lengths: torch.Tensor, targets: torch.Tensor, target_lengths: torch.Tensor, vocab_size: int + self, source_lengths: torch.Tensor, targets: torch.Tensor, target_lengths: torch.Tensor, vocab_size: int ) -> "k2.Fsa": """ Get batched lattice (grid or composed) for the batch of sequences. Args: - logits_lengths: tensor with lengths of logits + source_lengths: tensor with lengths of logits targets: tensor with target units target_lengths: tensor with lengths of targets vocab_size: vocab size (including blank) @@ -143,14 +149,16 @@ def get_graphs_batched( Returns: batched lattice - FsaVec (k2.Fsa) """ - batch_size = logits_lengths.shape[0] + batch_size = source_lengths.shape[0] with torch.no_grad(): if self.use_grid_implementation: + source_lengths_list = source_lengths.tolist() + target_lengths_list = target_lengths.tolist() return k2.create_fsa_vec( [ self.get_grid( - units_tensor=targets[i, : target_lengths[i].item()], - num_frames=logits_lengths[i].item(), + units_tensor=targets[i, : target_lengths_list[i]], + num_frames=source_lengths_list[i], vocab_size=vocab_size, ) for i in range(batch_size) @@ -167,7 +175,7 @@ def get_graphs_batched( ] temporal_fsas = [ self.get_temporal_schema( - num_frames=logits_lengths[i].item(), vocab_size=vocab_size, device=targets.device + num_frames=source_lengths[i].item(), vocab_size=vocab_size, device=targets.device ) for i in range(batch_size) ] @@ -175,22 +183,20 @@ def get_graphs_batched( k2.create_fsa_vec(text_fsas), k2.create_fsa_vec(temporal_fsas), treat_epsilons_specially=False ) if self.connect_composed: - k2.connect(target_fsas_vec) + target_fsas_vec = k2.connect(target_fsas_vec) return target_fsas_vec - def get_logits_indices(self, target_fsas_vec: k2.Fsa, logits_shape: torch.Size) -> torch.Tensor: + def get_batch_indices(self, target_fsas_vec: k2.Fsa) -> torch.Tensor: """ - Get indices of flatten logits for each arc in the lattices. + Get batch indices (for logits) for each arc in the lattices. 
Args: target_fsas_vec: batch of target FSAs with lattices - logits_shape: shape of the logits tensor Returns: 1d tensor with indices """ - # logits_shape: B x Time x Text+1 x Labels - batch_size = logits_shape[0] + batch_size = target_fsas_vec.shape[0] device = target_fsas_vec.device scores_to_batch_i = torch.repeat_interleave( torch.arange(batch_size, device=device, dtype=torch.int64), @@ -199,6 +205,21 @@ def get_logits_indices(self, target_fsas_vec: k2.Fsa, logits_shape: torch.Size) device=device, ), ) + return scores_to_batch_i + + def get_logits_indices(self, target_fsas_vec: k2.Fsa, logits_shape: torch.Size) -> torch.Tensor: + """ + Get indices of flatten logits for each arc in the lattices. + + Args: + target_fsas_vec: batch of target FSAs with lattices + logits_shape: shape of the logits tensor + + Returns: + 1d tensor with indices + """ + # logits_shape: B x Time x Text+1 x Labels + scores_to_batch_i = self.get_batch_indices(target_fsas_vec=target_fsas_vec) indices = ( scores_to_batch_i * logits_shape[1] * logits_shape[2] * logits_shape[3] # Batch + target_fsas_vec.aux_labels.to(torch.int64) * logits_shape[2] * logits_shape[3] # Time indices @@ -222,6 +243,8 @@ def __init__( connect_composed=False, double_scores=False, cast_to_float32=False, + return_graph=False, + use_triton=True, ): """ Init method @@ -232,8 +255,11 @@ def __init__( connect_composed: Connect graph after composing unit and temporal schemas (only for Compose-Transducer). `connect` operation is slow, it is useful for visualization, but not necessary for loss computation. double_scores: Use calculation of loss in double precision (float64) in the lattice. - Does not significantly affect memory usage since the lattice is ~V/2 times smaller than the joint tensor. + Does not significantly affect memory usage since the lattice is ~V/2 times smaller + than the joint tensor. cast_to_float32: Force cast joint tensor to float32 before log-softmax calculation. 
+ return_graph: Return graph (along with loss) from `forward` function + use_triton: use optimized log probs calculations with Triton (faster and more memory efficient) """ super().__init__( use_grid_implementation=use_grid_implementation, @@ -242,6 +268,10 @@ def __init__( cast_to_float32=cast_to_float32, ) self.blank = blank + self.return_graph = return_graph + self.use_triton = use_triton and TRITON_AVAILABLE + if not self.use_triton: + logging.warning("Triton is disabled, memory usage can be larger") def get_unit_schema(self, units_tensor: torch.Tensor, vocab_size: int) -> "k2.Fsa": """ @@ -370,13 +400,14 @@ def relabel_states(states: torch.Tensor, n: int, m: int) -> torch.Tensor: anti_diag = m + n - 1 - diag max_idx = n * m - 1 cur_diag_idx = i if m > n else m - j - 1 - states = ( + new_states = ( diag.lt(min_mn) * ((diag * (diag + 1) >> 1) + i) + torch.logical_and(diag.ge(min_mn), diag.lt(max_mn)) * ((min_mn * (min_mn + 1) >> 1) + (diag - min_mn) * min_mn + cur_diag_idx) + diag.ge(max_mn) * (max_idx - (anti_diag * (anti_diag + 1) >> 1) + m - j) ) - return states + torch.where(states >= n * m, states, new_states, out=new_states) + return new_states def get_grid(self, units_tensor: torch.Tensor, num_frames: int, vocab_size: int) -> "k2.Fsa": """ @@ -445,13 +476,76 @@ def get_grid(self, units_tensor: torch.Tensor, num_frames: int, vocab_size: int) rnnt_graph.unit_positions = unit_positions return rnnt_graph + def get_weighted_graphs( + self, + logits: torch.Tensor, + targets: torch.Tensor, + source_lengths: torch.Tensor, + target_lengths: torch.Tensor, + use_graph_weight=False, + ) -> "k2.Fsa": + """ + Get batch of graphs (FsaVec) for RNN-T loss calculation. + + Args: + logits: activations (joint tensor). NB: raw logits, not after log-softmax + targets: target labels + source_lengths: lengths of source sequences + target_lengths: length of target sequences + use_graph_weight: uses weight from graphs (if `get_graphs_batched` returns graphs with weights) + + Returns: + FsaVec containing RNN-T graphs for all utterances. 
+ """ + vocab_size = logits.shape[-1] + target_fsas_vec = self.get_graphs_batched(source_lengths, targets, target_lengths, vocab_size) + + with torch.no_grad(): + # last transitions in the graph are labeled with -1 label + last_transition_mask = target_fsas_vec.labels == -1 + batch_indices = self.get_batch_indices(target_fsas_vec=target_fsas_vec) + time_indices = target_fsas_vec.aux_labels.clone().to(torch.int64) + unit_indices = target_fsas_vec.unit_positions.clone().to(torch.int64) + text_units = target_fsas_vec.labels.clone().to(torch.int64) + # fill in the indices outside the logits with 0, replace later + text_units.masked_fill_(last_transition_mask, 0) + + cast_context = force_float32_context() if self.cast_to_float32 else nullcontext() + with cast_context: + # NB: do not assign scores -> modify, k2 will not update all scores correctly (modify -> assign) + if self.use_triton and logits.device.type == "cuda": + unit_scores, blank_scores = rnnt_logprobs_triton( + logits=logits, + targets=targets, + blank_id=self.blank, + source_lengths=source_lengths, + target_lengths=target_lengths, + ) + text_units_blank_mask = text_units == self.blank + scores = torch.where( + text_units_blank_mask, + blank_scores[batch_indices, time_indices, unit_indices], + unit_scores[batch_indices, time_indices, unit_indices], + ).to(torch.float32) + scores[last_transition_mask] = 0.0 # fix weights for the arcs to the last state + else: + log_probs = F.log_softmax(logits, dim=-1) + scores = log_probs[batch_indices, time_indices, unit_indices, text_units].to(torch.float32) + scores[last_transition_mask] = 0.0 + + if use_graph_weight: + target_fsas_vec.scores = target_fsas_vec.scores + scores + else: + target_fsas_vec.scores = scores + return target_fsas_vec + def forward( self, acts: torch.Tensor, labels: torch.Tensor, act_lens: torch.Tensor, label_lens: torch.Tensor, - ) -> torch.Tensor: + ) -> torch.Tensor | tuple[torch.Tensor, "k2.Fsa"]: """ Compute forward method for RNN-T. 
@@ -466,26 +560,11 @@ def forward( """ # argument names are consistent with NeMo, see RNNTLoss.forward: # self._loss(acts=log_probs, labels=targets, act_lens=input_lengths, label_lens=target_lengths) - logits, targets, logits_lengths, target_lengths = acts, labels, act_lens, label_lens - - # logits: B x Time x Text+1 x C - vocab_size = logits.shape[-1] - target_fsas_vec = self.get_graphs_batched(logits_lengths, targets, target_lengths, vocab_size) - - cast_context = force_float32_context() if self.cast_to_float32 else nullcontext() - with cast_context: - log_probs = F.log_softmax(logits, dim=-1) - with torch.no_grad(): - indices = self.get_logits_indices(target_fsas_vec, logits.shape) - # transition to the last state - # use 0 index (for valid index_select) and manually assign score after index_select for this case - indices[target_fsas_vec.labels == -1] = 0 - - # NB: do not assign scores -> modify, k2 will not update all scores correctly (modify -> assign) - scores = log_probs.flatten().index_select(-1, indices) - # fix weights for the arcs to the last state - scores[target_fsas_vec.labels == -1] = 0 + target_fsas_vec = self.get_weighted_graphs( + logits=acts, targets=labels, source_lengths=act_lens, target_lengths=label_lens, use_graph_weight=False + ) - target_fsas_vec.scores = scores - scores = -1 * target_fsas_vec.get_tot_scores(use_double_scores=self.double_scores, log_semiring=True) - return scores + scores = -1 * target_fsas_vec.get_tot_scores(use_double_scores=self.double_scores, log_semiring=True) + if self.return_graph: + return scores, target_fsas_vec + return scores diff --git a/nemo/collections/asr/parts/k2/rnnt_logprobs.py b/nemo/collections/asr/parts/k2/rnnt_logprobs.py new file mode 100644 index 000000000000..c41615f83bf9 --- /dev/null +++ b/nemo/collections/asr/parts/k2/rnnt_logprobs.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + + +def rnnt_logprobs_torch( + logits: torch.Tensor, targets: torch.Tensor, blank_id: int +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Given logits, calculate log probabilities for blank and target labels needed for transducer loss calculation. + Naive implementation in PyTorch, for testing and prototyping purposes. + + Args: + logits: Joint tensor of size [B, T, U+1, D] + targets: Targets of size [B, U] + blank_id: id of the blank output + + Returns: + Tuple of tensors with log probabilities for targets and blank labels, both of size [B, T, U+1]. + For the last non-existent target (U+1) output is zero. 
+ """ + device = logits.device + batch_size = logits.shape[0] + log_probs = F.log_softmax(logits, dim=-1) + blank_scores = log_probs[..., blank_id] + targets = torch.cat((targets, torch.zeros([batch_size], dtype=targets.dtype, device=device).unsqueeze(1)), dim=-1) + target_scores = torch.gather( + log_probs, dim=-1, index=targets.unsqueeze(1).expand(log_probs.shape[:-1]).unsqueeze(-1) + ).squeeze(-1) + target_scores[:, :, -1] = 0.0 + return target_scores, blank_scores diff --git a/nemo/collections/asr/parts/k2/rnnt_logprobs_triton.py b/nemo/collections/asr/parts/k2/rnnt_logprobs_triton.py new file mode 100644 index 000000000000..64bc8abbdbeb --- /dev/null +++ b/nemo/collections/asr/parts/k2/rnnt_logprobs_triton.py @@ -0,0 +1,250 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _rnnt_logprobs_fwd_kernel( + logits_ptr, + targets_ptr, + source_lengths_ptr, + target_lengths_ptr, + max_source_len: int, + max_target_len_plus_1: int, + num_labels: int, # vocab size (with blank) + blank_id: int, + target_scores_ptr, + blank_scores_ptr, + BLOCK_SIZE: tl.constexpr, +): + """ + Forward kernel for RNN-T log probs. Stores result in `target_scores_ptr` and `blank_scores_ptr`. + Calculations are performed in float32 (but original tensors can use any precision). 
+ """ + batch_i = tl.program_id(axis=0).to(tl.int64) + source_i = tl.program_id(axis=1).to(tl.int64) + target_i = tl.program_id(axis=2).to(tl.int64) + + # load lengths for source/target + source_len = tl.load(source_lengths_ptr + batch_i) + target_len = tl.load(target_lengths_ptr + batch_i) + + if source_i >= source_len or target_i > target_len: + # no calculations required + return + + # calculate offset in [B, T, U+1, V] tensor for the current vector with target logits + flat_index = ((batch_i * max_source_len + source_i) * max_target_len_plus_1 + target_i) * num_labels + logits_ptr += flat_index + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < num_labels + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + # stable log softmax calculation + logits_max = tl.max(logits, axis=0) + logits_minus_max = logits - logits_max + denominator = tl.log(tl.sum(tl.exp(logits_minus_max), axis=0)) + blank_logit = tl.load(logits_ptr + blank_id).to(tl.float32) + flat_index_output = (batch_i * max_source_len + source_i) * max_target_len_plus_1 + target_i + tl.store(blank_scores_ptr + flat_index_output, blank_logit - logits_max - denominator) + + # calculate log prob for target if needed + if target_i < target_len: + target_id = tl.load(targets_ptr + batch_i * (max_target_len_plus_1 - 1) + target_i) + target_logit = tl.load(logits_ptr + target_id).to(tl.float32) + tl.store(target_scores_ptr + flat_index_output, target_logit - logits_max - denominator) + + +@triton.jit +def _rnnt_logprobs_bwd_kernel( + logits_ptr, + grad_logits_ptr, + targets_ptr, + source_lengths_ptr, + target_lengths_ptr, + max_source_len: int, + max_target_len_plus_1: int, + num_labels: int, + blank_id: int, + grad_target_scores_ptr, + grad_blank_scores_ptr, + BLOCK_SIZE: tl.constexpr, +): + """ + Backward kernel for RNN-T log probs. Stores result in `grad_target_scores_ptr` and `grad_blank_scores_ptr`. + We recalculate part of the forward here to avoid using extra memory in forward. + Calculations are performed in float32 (but original tensors can use any precision). 
+ """ + batch_i = tl.program_id(axis=0).to(tl.int64) + source_i = tl.program_id(axis=1).to(tl.int64) + target_i = tl.program_id(axis=2).to(tl.int64) + + # load lengths for source/target + source_len = tl.load(source_lengths_ptr + batch_i) + target_len = tl.load(target_lengths_ptr + batch_i) + if source_i >= source_len or target_i > target_len: + # no calculations required + return + + # calculate offset in [B, T, U+1, V] tensor for the current vector with target logits/grad_logits + flat_index = ((batch_i * max_source_len + source_i) * max_target_len_plus_1 + target_i) * num_labels + logits_ptr += flat_index + grad_logits_ptr += flat_index + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < num_labels + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + # stable log softmax calculation + logits_max = tl.max(logits, axis=0) + logits_minus_max = logits - logits_max + denominator = tl.log(tl.sum(tl.exp(logits_minus_max), axis=0)) + log_softmax = logits_minus_max - denominator + # softmax for gradient + softmax = tl.exp(log_softmax) + + flat_index_grad = (batch_i * max_source_len + source_i) * max_target_len_plus_1 + target_i + blank_grad = tl.load(grad_blank_scores_ptr + flat_index_grad).to(tl.float32) + target_i_valid = target_i < target_len + target_grad = tl.load(grad_target_scores_ptr + flat_index_grad, mask=target_i_valid, other=0.0).to(tl.float32) + target_id = tl.load(targets_ptr + batch_i * (max_target_len_plus_1 - 1) + target_i, mask=target_i_valid, other=-1) + + grad_not_in_targets = (-softmax) * (blank_grad + target_grad) + grad = tl.where(col_offsets == blank_id, blank_grad + grad_not_in_targets, grad_not_in_targets) + grad = tl.where(col_offsets == target_id, target_grad + grad_not_in_targets, grad) + tl.store(grad_logits_ptr + col_offsets, grad, mask=mask) + + +class RnntLogProbs(torch.autograd.Function): + """ + Function to calculate log probabilities for target and blank labels for RNN-T, supporting torch.autograd. 
+ """ + + @staticmethod + def forward( + ctx, + logits: torch.Tensor, + targets: torch.Tensor, + blank_id: int, + source_lengths: torch.Tensor | None, + target_lengths: torch.Tensor | None, + ): + """ + + Args: + ctx: ctx object for storing the context + logits: Joint tensor of size [B, T, U+1, D] + targets: Targets of size [B, U] + blank_id: id of the blank output + source_lengths: optional tensor with lengths for source utterances + target_lengths: optional tensor with lengths for targets + + Returns: + + """ + assert logits.is_contiguous() # logits are huge, so here we just check if logits are contiguous + targets = targets.contiguous() + device = logits.device + float_dtype = torch.float32 + + target_scores = torch.zeros(logits.shape[:-1], dtype=float_dtype, device=device) + blank_scores = torch.zeros_like(target_scores) + if source_lengths is None: + source_lengths = torch.full([logits.shape[0]], fill_value=logits.shape[1], dtype=torch.int, device=device) + else: + source_lengths = source_lengths.contiguous() + if target_lengths is None: + target_lengths = torch.full( + [logits.shape[0]], fill_value=logits.shape[2] - 1, dtype=torch.int, device=device + ) + else: + target_lengths = target_lengths.contiguous() + + # run Triton kernel + _rnnt_logprobs_fwd_kernel[(logits.shape[0], logits.shape[1], logits.shape[2])]( + logits_ptr=logits, + targets_ptr=targets, + source_lengths_ptr=source_lengths, + target_lengths_ptr=target_lengths, + max_source_len=logits.shape[1], + max_target_len_plus_1=logits.shape[2], + num_labels=logits.shape[3], + blank_id=blank_id, + target_scores_ptr=target_scores, + blank_scores_ptr=blank_scores, + BLOCK_SIZE=triton.next_power_of_2(logits.shape[-1]), + ) + + # saving for backward + ctx.save_for_backward(logits, targets, source_lengths, target_lengths) + ctx.blank_id = blank_id + return target_scores, blank_scores + + @staticmethod + def backward(ctx, grad_target_scores, grad_blank_scores): + """ + Backward calculation for RNN-T log-probs. + + Args: + ctx: ctx object for storing the context + grad_target_scores: upstream gradient for targets + grad_blank_scores: upstream gradient for blank scores + + Returns: + gradient for logits, None for all other arguments for `forward` + """ + (logits, targets, source_lengths, target_lengths) = ctx.saved_tensors + blank_id = ctx.blank_id + grad_logits = torch.zeros_like(logits) + _rnnt_logprobs_bwd_kernel[(logits.shape[0], logits.shape[1], logits.shape[2])]( + logits_ptr=logits, + grad_logits_ptr=grad_logits, + source_lengths_ptr=source_lengths, + target_lengths_ptr=target_lengths, + targets_ptr=targets, + max_source_len=logits.shape[1], + max_target_len_plus_1=logits.shape[2], + num_labels=logits.shape[3], + blank_id=blank_id, + grad_target_scores_ptr=grad_target_scores, + grad_blank_scores_ptr=grad_blank_scores, + BLOCK_SIZE=triton.next_power_of_2(logits.shape[-1]), + ) + return grad_logits, None, None, None, None + + +def rnnt_logprobs_triton( + logits: torch.Tensor, + targets: torch.Tensor, + blank_id: int, + source_lengths: torch.Tensor | None = None, + target_lengths: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Given logits, calculate log probabilities for blank and target labels needed for transducer loss calculation. + Optimized implementation in Triton. 
+ + Args: + logits: Joint tensor of size [B, T, U+1, D] + targets: Targets of size [B, U] + blank_id: id of the blank output + source_lengths: optional tensor with lengths for source utterances + target_lengths: optional tensor with lengths for targets + + Returns: + Tuple of tensors with log probabilities for targets and blank labels, both of size [B, T, U+1]. + For the non-existent targets (U+1 or beyond target_lengths) output is zero. + """ + return RnntLogProbs.apply(logits, targets, blank_id, source_lengths, target_lengths) diff --git a/nemo/core/utils/optional_libs.py b/nemo/core/utils/optional_libs.py new file mode 100644 index 000000000000..9aa39260963c --- /dev/null +++ b/nemo/core/utils/optional_libs.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util + + +def is_lib_available(name: str) -> bool: + """ + Checks if the library/package with `name` is available in the system + NB: try/catch with importlib.import_module(name) requires importing the library, which can be slow. + So, `find_spec` should be preferred + """ + return importlib.util.find_spec(name) is not None + + +TRITON_AVAILABLE = is_lib_available("triton") + +try: + from nemo.core.utils.k2_guard import k2 as _ + + K2_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + K2_AVAILABLE = False diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 7fd5e88eebe3..1b9fc88000b9 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,5 +11,6 @@ tensorboard text-unidecode torch tqdm>=4.41.0 +triton>=3.1.0; sys_platform == 'linux' wget wrapt diff --git a/tests/collections/asr/k2/test_graph_transducer.py b/tests/collections/asr/k2/test_graph_transducer.py index 5879226e782d..592772767484 100644 --- a/tests/collections/asr/k2/test_graph_transducer.py +++ b/tests/collections/asr/k2/test_graph_transducer.py @@ -12,29 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import random from typing import List import numpy as np import pytest import torch +from nemo.collections.asr.parts.k2.rnnt_logprobs import rnnt_logprobs_torch from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_numpy import RNNTLoss as RNNTLoss_Numpy +from nemo.core.utils.optional_libs import K2_AVAILABLE, TRITON_AVAILABLE + +if K2_AVAILABLE: + import k2 -try: from nemo.collections.asr.parts.k2.graph_transducer import GraphRnntLoss - from nemo.core.utils.k2_guard import k2 -except (ImportError, ModuleNotFoundError): - pytest.skip("k2 is not installed, skipping Graph-RNNT tests.", allow_module_level=True) + +if TRITON_AVAILABLE: + from nemo.collections.asr.parts.k2.rnnt_logprobs_triton import rnnt_logprobs_triton + EPS_SM_INPUT = 1e-6 EPS_L_INPUT = 1e-4 DEVICES = ['cpu'] -if torch.cuda.is_available() and k2.with_cuda: +if K2_AVAILABLE and torch.cuda.is_available() and k2.with_cuda: DEVICES.append('cuda') +@pytest.mark.skipif(not K2_AVAILABLE, reason="k2 is not installed, skipping Graph-RNNT tests.") class TestGraphRnnt: @pytest.mark.unit @pytest.mark.parametrize("device", DEVICES) @@ -214,9 +221,12 @@ def test_small_grid_transducer(self, device, rnnt_test_helper, rnn_loss_sample_d @pytest.mark.unit @pytest.mark.parametrize("device", DEVICES) - def test_medium_grid_transducer(self, device, rnnt_test_helper, rnn_loss_sample_data): + @pytest.mark.parametrize("use_triton", [True, False]) + def test_medium_grid_transducer(self, device, use_triton: bool, rnnt_test_helper, rnn_loss_sample_data): + if use_triton and device == "cpu": + pytest.skip("Triton does not support CPU yet") sample_data = rnn_loss_sample_data.get_sample_medium() - graph_rnnt = GraphRnntLoss(blank=0, use_grid_implementation=True) + graph_rnnt = GraphRnntLoss(blank=0, use_grid_implementation=True, use_triton=use_triton) graph_cost, graph_grads = rnnt_test_helper.wrap_and_call( graph_rnnt, sample_data.logits, sample_data.targets, device ) @@ -225,9 +235,12 @@ def test_medium_grid_transducer(self, device, rnnt_test_helper, rnn_loss_sample_ @pytest.mark.unit @pytest.mark.parametrize("device", DEVICES) - def test_medium_random_var_size(self, device, rnnt_test_helper, rnn_loss_sample_data): + @pytest.mark.parametrize("use_triton", [True, False]) + def test_medium_random_var_size(self, device, use_triton: bool, rnnt_test_helper, rnn_loss_sample_data): + if use_triton and device == "cpu": + pytest.skip("Triton does not support CPU yet") sample_data = rnn_loss_sample_data.get_sample_medium_random_var_size(blank_first=True) - graph_rnnt = GraphRnntLoss(blank=0, use_grid_implementation=True) + graph_rnnt = GraphRnntLoss(blank=0, use_grid_implementation=True, use_triton=use_triton) graph_cost, graph_grads = rnnt_test_helper.wrap_and_call( graph_rnnt, sample_data.logits.detach(), @@ -261,3 +274,63 @@ def test_small_random_grid_compose_equivalent(self, device: torch.device, blank_ assert k2.is_rand_equivalent( graph_grid, graph_composed, log_semiring=True, treat_epsilons_specially=False ), "Grid and composed graphs are not equivalent." 
+ + +@pytest.mark.skipif(not TRITON_AVAILABLE, reason="Triton is not installed, skipping RNNT Log Probs tests") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is unavailable") +class TestRnntLogProbs: + @pytest.mark.parametrize( + "batch_size,num_frames,num_text_units,vocab_size", + [ + (1, 4, 2, 4), + (2, 3, 2, 5), + (2, 16, 31, 17), + (16, 129, 65, 2048), + ], + ) + @pytest.mark.parametrize( + "float_dtype", + [torch.float32] + ([torch.bfloat16] if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else []), + ) + def test_rnnt_logprobs_random( + self, batch_size: int, num_frames: int, num_text_units: int, vocab_size: int, float_dtype: torch.dtype + ): + """ + Test Triton-based implementation using etalon Torch-based implementation for RNN-T log-probs. + """ + device = torch.device("cuda") + torch.manual_seed(777) + + targets = torch.tensor( + [[random.randrange(0, vocab_size - 1) for i in range(num_text_units)] for j in range(batch_size)], + device=device, + dtype=torch.long, + ) + + logits = torch.rand( + [batch_size, num_frames, num_text_units + 1, vocab_size + 1], + dtype=float_dtype, + device=device, + requires_grad=True, + ) + + # Triton-based implementation works in float32 precision for accuracy purposes, should compare with float32 + target_scores_etalon, blank_scores_etalon = rnnt_logprobs_torch( + logits=logits.to(torch.float32), targets=targets, blank_id=vocab_size + ) + logits2 = logits.clone().detach() + logits2.requires_grad_(True) + target_scores, blank_scores = rnnt_logprobs_triton(logits=logits2, targets=targets, blank_id=vocab_size) + target_scores[..., -1:] = 0.0 + target_scores_etalon[..., -1:] = 0.0 + assert torch.allclose(blank_scores, blank_scores_etalon, atol=1e-5) + assert torch.allclose(target_scores, target_scores_etalon, atol=1e-5) + + # test backward + target_scales = torch.rand_like(target_scores, requires_grad=False) + blank_scales = torch.rand_like(blank_scores, requires_grad=False) + loss_etalon = (target_scales * target_scores_etalon + blank_scales * blank_scores_etalon).sum() + loss = (target_scales * target_scores + blank_scales * blank_scores).sum() + loss_etalon.backward() + loss.backward() + assert torch.allclose(logits.grad, logits2.grad, atol=1e-5)
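
A minimal usage sketch of the new log-probs helpers added by this patch (illustrative only; the toy shapes, device handling, and blank id below are assumptions for the example, not values taken from the diff):

```python
import torch

from nemo.collections.asr.parts.k2.rnnt_logprobs import rnnt_logprobs_torch
from nemo.collections.asr.parts.k2.rnnt_logprobs_triton import rnnt_logprobs_triton

device = torch.device("cuda")  # the Triton kernels require a CUDA device
B, T, U, V = 2, 8, 3, 16       # assumed toy sizes: batch, frames, target units, vocab (incl. blank)
blank_id = 0                   # assumed blank id for this example

logits = torch.randn(B, T, U + 1, V, device=device, requires_grad=True)  # raw joint tensor, no log-softmax
targets = torch.randint(1, V, (B, U), device=device)
source_lengths = torch.tensor([T, T - 2], device=device)
target_lengths = torch.tensor([U, U - 1], device=device)

# Triton path: per-position log-probs for target and blank labels, both of shape [B, T, U+1].
# Positions outside source_lengths/target_lengths are left as zeros by the kernel.
target_scores, blank_scores = rnnt_logprobs_triton(
    logits=logits,
    targets=targets,
    blank_id=blank_id,
    source_lengths=source_lengths,
    target_lengths=target_lengths,
)

# Reference PyTorch path (full log-softmax over the joint tensor, no length masking);
# useful for checking the Triton results on the valid positions, as the new tests do.
target_scores_ref, blank_scores_ref = rnnt_logprobs_torch(
    logits=logits.detach(), targets=targets, blank_id=blank_id
)

# Both outputs are differentiable w.r.t. the raw logits via the custom autograd Function.
(target_scores.sum() + blank_scores.sum()).backward()
assert logits.grad is not None and logits.grad.shape == logits.shape
```

Inside `GraphRnntLoss`, the same Triton path is taken automatically when `use_triton=True` (the default) and the joint tensor is on a CUDA device; otherwise the loss falls back to a full `F.log_softmax` over the joint tensor, which is the extra memory cost the new warning refers to.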