Skip to content

2x more memory efficient Graph-based RNN-T #11169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 117 additions & 38 deletions nemo/collections/asr/parts/k2/graph_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@
import abc
from contextlib import nullcontext
from typing import ContextManager

import torch
import torch.nn.functional as F

from nemo.core.classes.loss import Loss
from nemo.core.utils.k2_guard import k2
from nemo.core.utils.optional_libs import TRITON_AVAILABLE
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be made into a function, similar to Numba, instead of a hardcoded constant var

from nemo.utils import logging

if TRITON_AVAILABLE:
from nemo.collections.asr.parts.k2.rnnt_logprobs_triton import rnnt_logprobs_triton


def force_float32_context() -> ContextManager:
Expand Down Expand Up @@ -129,28 +135,30 @@ def get_composed_lattice(self, units_tensor: torch.Tensor, num_frames: int, voca
return composed

def get_graphs_batched(
self, logits_lengths: torch.Tensor, targets: torch.Tensor, target_lengths: torch.Tensor, vocab_size: int
self, source_lengths: torch.Tensor, targets: torch.Tensor, target_lengths: torch.Tensor, vocab_size: int
) -> "k2.Fsa":
"""
Get batched lattice (grid or composed) for the batch of sequences.

Args:
logits_lengths: tensor with lengths of logits
source_lengths: tensor with lengths of logits
targets: tensor with target units
target_lengths: tensor with lengths of targets
vocab_size: vocab size (including blank)

Returns:
batched lattice - FsaVec (k2.Fsa)
"""
batch_size = logits_lengths.shape[0]
batch_size = source_lengths.shape[0]
with torch.no_grad():
if self.use_grid_implementation:
source_lengths_list = source_lengths.tolist()
target_lengths_list = target_lengths.tolist()
return k2.create_fsa_vec(
[
self.get_grid(
units_tensor=targets[i, : target_lengths[i].item()],
num_frames=logits_lengths[i].item(),
units_tensor=targets[i, : target_lengths_list[i]],
num_frames=source_lengths_list[i],
vocab_size=vocab_size,
)
for i in range(batch_size)
Expand All @@ -167,30 +175,28 @@ def get_graphs_batched(
]
temporal_fsas = [
self.get_temporal_schema(
num_frames=logits_lengths[i].item(), vocab_size=vocab_size, device=targets.device
num_frames=source_lengths[i].item(), vocab_size=vocab_size, device=targets.device
)
for i in range(batch_size)
]
target_fsas_vec = k2.compose(
k2.create_fsa_vec(text_fsas), k2.create_fsa_vec(temporal_fsas), treat_epsilons_specially=False
)
if self.connect_composed:
k2.connect(target_fsas_vec)
target_fsas_vec = k2.connect(target_fsas_vec)
return target_fsas_vec

def get_logits_indices(self, target_fsas_vec: k2.Fsa, logits_shape: torch.Size) -> torch.Tensor:
def get_batch_indices(self, target_fsas_vec: k2.Fsa) -> torch.Tensor:
"""
Get indices of flatten logits for each arc in the lattices.
Get batch indices (for logits) for each arc in the lattices.

Args:
target_fsas_vec: batch of target FSAs with lattices
logits_shape: shape of the logits tensor

Returns:
1d tensor with indices
"""
# logits_shape: B x Time x Text+1 x Labels
batch_size = logits_shape[0]
batch_size = target_fsas_vec.shape[0]
device = target_fsas_vec.device
scores_to_batch_i = torch.repeat_interleave(
torch.arange(batch_size, device=device, dtype=torch.int64),
Expand All @@ -199,6 +205,21 @@ def get_logits_indices(self, target_fsas_vec: k2.Fsa, logits_shape: torch.Size)
device=device,
),
)
return scores_to_batch_i

def get_logits_indices(self, target_fsas_vec: k2.Fsa, logits_shape: torch.Size) -> torch.Tensor:
"""
Get indices of flatten logits for each arc in the lattices.

Args:
target_fsas_vec: batch of target FSAs with lattices
logits_shape: shape of the logits tensor

Returns:
1d tensor with indices
"""
# logits_shape: B x Time x Text+1 x Labels
scores_to_batch_i = self.get_batch_indices(target_fsas_vec=target_fsas_vec)
indices = (
scores_to_batch_i * logits_shape[1] * logits_shape[2] * logits_shape[3] # Batch
+ target_fsas_vec.aux_labels.to(torch.int64) * logits_shape[2] * logits_shape[3] # Time indices
Expand All @@ -222,6 +243,8 @@ def __init__(
connect_composed=False,
double_scores=False,
cast_to_float32=False,
return_graph=False,
use_triton=True,
):
"""
Init method
Expand All @@ -232,8 +255,11 @@ def __init__(
connect_composed: Connect graph after composing unit and temporal schemas (only for Compose-Transducer).
`connect` operation is slow, it is useful for visualization, but not necessary for loss computation.
double_scores: Use calculation of loss in double precision (float64) in the lattice.
Does not significantly affect memory usage since the lattice is ~V/2 times smaller than the joint tensor.
Does not significantly affect memory usage since the lattice is ~V/2 times smaller
than the joint tensor.
cast_to_float32: Force cast joint tensor to float32 before log-softmax calculation.
return_graph: Return graph (along with loss) from `forward` function
use_triton: use optimized log probs calculations with Triton (faster and more memory efficient)
"""
super().__init__(
use_grid_implementation=use_grid_implementation,
Expand All @@ -242,6 +268,10 @@ def __init__(
cast_to_float32=cast_to_float32,
)
self.blank = blank
self.return_graph = return_graph
self.use_triton = use_triton and TRITON_AVAILABLE
if not self.use_triton:
logging.warning("Triton is disabled, memory usage can be larger")

def get_unit_schema(self, units_tensor: torch.Tensor, vocab_size: int) -> "k2.Fsa":
"""
Expand Down Expand Up @@ -370,13 +400,14 @@ def relabel_states(states: torch.Tensor, n: int, m: int) -> torch.Tensor:
anti_diag = m + n - 1 - diag
max_idx = n * m - 1
cur_diag_idx = i if m > n else m - j - 1
states = (
new_states = (
diag.lt(min_mn) * ((diag * (diag + 1) >> 1) + i)
+ torch.logical_and(diag.ge(min_mn), diag.lt(max_mn))
* ((min_mn * (min_mn + 1) >> 1) + (diag - min_mn) * min_mn + cur_diag_idx)
+ diag.ge(max_mn) * (max_idx - (anti_diag * (anti_diag + 1) >> 1) + m - j)
)
return states
torch.where(states >= n * m, states, new_states, out=new_states)
return new_states

def get_grid(self, units_tensor: torch.Tensor, num_frames: int, vocab_size: int) -> "k2.Fsa":
"""
Expand Down Expand Up @@ -445,13 +476,76 @@ def get_grid(self, units_tensor: torch.Tensor, num_frames: int, vocab_size: int)
rnnt_graph.unit_positions = unit_positions
return rnnt_graph

def get_weighted_graphs(
self,
logits: torch.Tensor,
targets: torch.Tensor,
source_lengths: torch.Tensor,
target_lengths: torch.Tensor,
use_graph_weight=False,
) -> "k2.Fsa":
"""
Get batch of graphs (FsaVec) for RNN-T loss calculation.

Args:
logits: activations (joint tensor). NB: raw logits, not after log-softmax
targets: target labels
source_lengths: lengths of source sequences
target_lengths: length of target sequences
use_graph_weight: uses weight from graphs (if `get_graphs_batched` returns graphs with weights)

Returns:
FsaVec containing RNN-T graphs for all utterances.
"""
vocab_size = logits.shape[-1]
target_fsas_vec = self.get_graphs_batched(source_lengths, targets, target_lengths, vocab_size)

with torch.no_grad():
# last transitions in the graph are labeled with -1 label
last_transition_mask = target_fsas_vec.labels == -1
batch_indices = self.get_batch_indices(target_fsas_vec=target_fsas_vec)
time_indices = target_fsas_vec.aux_labels.clone().to(torch.int64)
unit_indices = target_fsas_vec.unit_positions.clone().to(torch.int64)
text_units = target_fsas_vec.labels.clone().to(torch.int64)
# fill in the indices outside the logits with 0, replace later
text_units.masked_fill_(last_transition_mask, 0)

cast_context = force_float32_context() if self.cast_to_float32 else nullcontext()
with cast_context:
# NB: do not assign scores -> modify, k2 will not update all scores correctly (modify -> assign)
if self.use_triton and logits.device.type == "cuda":
unit_scores, blank_scores = rnnt_logprobs_triton(
logits=logits,
targets=targets,
blank_id=self.blank,
source_lengths=source_lengths,
target_lengths=target_lengths,
)
text_units_blank_mask = text_units == self.blank
scores = torch.where(
text_units_blank_mask,
blank_scores[batch_indices, time_indices, unit_indices],
unit_scores[batch_indices, time_indices, unit_indices],
).to(torch.float32)
scores[last_transition_mask] = 0.0 # fix weights for the arcs to the last state
else:
log_probs = F.log_softmax(logits, dim=-1)
scores = log_probs[batch_indices, time_indices, unit_indices, text_units].to(torch.float32)
scores[last_transition_mask] = 0.0

if use_graph_weight:
target_fsas_vec.scores = target_fsas_vec.scores + scores
else:
target_fsas_vec.scores = scores
return target_fsas_vec

def forward(
self,
acts: torch.Tensor,
labels: torch.Tensor,
act_lens: torch.Tensor,
label_lens: torch.Tensor,
) -> torch.Tensor:
) -> torch.Tensor | tuple[torch.Tensor, "k2.Fsa"]:
"""
Compute forward method for RNN-T.

Expand All @@ -466,26 +560,11 @@ def forward(
"""
# argument names are consistent with NeMo, see RNNTLoss.forward:
# self._loss(acts=log_probs, labels=targets, act_lens=input_lengths, label_lens=target_lengths)
logits, targets, logits_lengths, target_lengths = acts, labels, act_lens, label_lens

# logits: B x Time x Text+1 x C
vocab_size = logits.shape[-1]
target_fsas_vec = self.get_graphs_batched(logits_lengths, targets, target_lengths, vocab_size)

cast_context = force_float32_context() if self.cast_to_float32 else nullcontext()
with cast_context:
log_probs = F.log_softmax(logits, dim=-1)
with torch.no_grad():
indices = self.get_logits_indices(target_fsas_vec, logits.shape)
# transition to the last state
# use 0 index (for valid index_select) and manually assign score after index_select for this case
indices[target_fsas_vec.labels == -1] = 0

# NB: do not assign scores -> modify, k2 will not update all scores correctly (modify -> assign)
scores = log_probs.flatten().index_select(-1, indices)
# fix weights for the arcs to the last state
scores[target_fsas_vec.labels == -1] = 0
target_fsas_vec = self.get_weighted_graphs(
logits=acts, targets=labels, source_lengths=act_lens, target_lengths=label_lens, use_graph_weight=False
)

target_fsas_vec.scores = scores
scores = -1 * target_fsas_vec.get_tot_scores(use_double_scores=self.double_scores, log_semiring=True)
return scores
scores = -1 * target_fsas_vec.get_tot_scores(use_double_scores=self.double_scores, log_semiring=True)
if self.return_graph:
return scores, target_fsas_vec
return scores
44 changes: 44 additions & 0 deletions nemo/collections/asr/parts/k2/rnnt_logprobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F


def rnnt_logprobs_torch(
logits: torch.Tensor, targets: torch.Tensor, blank_id: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Given logits, calculate log probabilities for blank and target labels needed for transducer loss calculation.
Naive implementation in PyTorch, for testing and prototyping purposes.

Args:
logits: Joint tensor of size [B, T, U+1, D]
targets: Targets of size [B, U]
blank_id: id of the blank output

Returns:
Tuple of tensors with log probabilities for targets and blank labels, both of size [B, T, U+1].
For the last non-existent target (U+1) output is zero.
"""
device = logits.device
batch_size = logits.shape[0]
log_probs = F.log_softmax(logits, dim=-1)
blank_scores = log_probs[..., blank_id]
targets = torch.cat((targets, torch.zeros([batch_size], dtype=targets.dtype, device=device).unsqueeze(1)), dim=-1)
target_scores = torch.gather(
log_probs, dim=-1, index=targets.unsqueeze(1).expand(log_probs.shape[:-1]).unsqueeze(-1)
).squeeze(-1)
target_scores[:, :, -1] = 0.0
return target_scores, blank_scores
Loading
Loading