Memory efficiency improvement to logprobs_from_logits_v2 (#220)
The existing `logprobs_from_logits_v2` doesn't achieve the memory savings it
claims, because `logsumexp` still allocates a `bs*seqlen*vocab` tensor
internally to hold the element-wise `exp`. By looping over the batch dimension
and calling `logsumexp` one row at a time, we can compute the same outputs
with a much smaller peak allocation.
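
A minimal sketch of the looped computation (illustrative only; the helper name and shapes are made up here, and the actual change is in `verl/utils/torch_functional.py` below):

import torch

def looped_logprobs(logits, labels):
    # logits: (bs, seqlen, vocab), labels: (bs, seqlen) -> log-probs of the labels: (bs, seqlen)
    label_logits = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    # call logsumexp on one (seqlen, vocab) row at a time so the full
    # (bs, seqlen, vocab) exp buffer is never materialized
    logsumexp_values = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])
    return label_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)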

Benchmarks show this computes logprobs with significantly less peak memory.

A fix is provided, along with a separate memory-efficient fallback for the
bfloat16 case, where the logsumexp approach is numerically unstable.
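
The memory claim can be spot-checked with a rough harness like the one below (a hypothetical script, not part of this commit; shapes and the reported numbers depend on the GPU):

import torch
from verl.utils.torch_functional import logprobs_from_logits_naive, logprobs_from_logits_v2

def peak_mem_mib(fn, logits, labels):
    # measure peak CUDA memory while computing logprobs once
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn(logits=logits, labels=labels)
    return torch.cuda.max_memory_allocated() / 2**20

bs, seqlen, vocab = 2, 1024, 32000
logits = torch.randn(bs, seqlen, vocab, device='cuda')
labels = torch.randint(0, vocab, (bs, seqlen), device='cuda')
print('naive:', peak_mem_mib(logprobs_from_logits_naive, logits, labels), 'MiB')
print('v2:   ', peak_mem_mib(logprobs_from_logits_v2, logits, labels), 'MiB')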
tyler-romero authored Feb 8, 2025
1 parent 958a326 commit 4b51624
Showing 2 changed files with 42 additions and 12 deletions.
20 changes: 20 additions & 0 deletions tests/gpu_utility/test_torch_functional.py
@@ -15,6 +15,7 @@
from verl.utils.model import create_random_mask
from flash_attn.bert_padding import unpad_input
import torch
import pytest


def test_log_probs_from_logits_response_rmpad():
@@ -49,6 +50,25 @@ def test_log_probs_from_logits_response_rmpad():
    assert torch.all(torch.eq(actual_output * response_mask, expected_output * response_mask))


@pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16])
def test_logprobs_from_logits_v2(dtype):
    from verl.utils.torch_functional import logprobs_from_logits_v2, logprobs_from_logits_naive
    vocab_size = 32000
    batch_size = 2
    seq_len = 512

    labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device='cuda')
    logits = torch.randn(batch_size, seq_len, vocab_size, device='cuda', dtype=dtype)

    expected_output = logprobs_from_logits_naive(labels=labels, logits=logits)
    actual_output = logprobs_from_logits_v2(labels=labels, logits=logits)

    if dtype in [torch.float16, torch.bfloat16]:  # half-precision dtypes fall back to an exactly equivalent method
        assert torch.equal(actual_output, expected_output)
    else:  # small numerical difference when using the gather / logsumexp approach
        torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5)


def test_lr_scheduler():
    from torch import nn
    model = nn.Linear(10, 10)
34 changes: 22 additions & 12 deletions verl/utils/torch_functional.py
@@ -58,7 +58,7 @@ def logprobs_from_logits(logits, labels):
        output = logprobs_from_logits_flash_attn(logits, labels)
        output = output.view(*batch_dim)
    else:
-        output = logprobs_from_logits_naive(logits, labels)
+        output = logprobs_from_logits_v2(logits, labels)
    return output


@@ -75,14 +75,24 @@ def logprobs_from_logits_naive(logits, labels):
    return logpy


-def logprobs_of_labels_v2(logits: torch.FloatTensor, labels):
+def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
    """
    A memory efficient implementation of logprobs_from_logits
    """
-    assert logits.dtype == torch.float32, 'Using bf16 logits with logprobs_of_labels_v2 may lead to divergence'
-    logprobs_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1))
-    logprobs_labels = logprobs_labels - torch.logsumexp(logits, dim=-1, keepdim=True)
-    return logprobs_labels.squeeze(-1)
+    if logits.dtype in [torch.float32, torch.float64]:
+        logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
+        # loop to reduce peak mem consumption
+        logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
+        logprobs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+    else:
+        # logsumexp approach is unstable with bfloat16, fall back to a slightly less efficient approach
+        logprobs_labels = []
+        for row_logits, row_labels in zip(logits, labels):  # loop to reduce peak mem consumption
+            row_logprobs = F.log_softmax(row_logits, dim=-1)
+            row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
+            logprobs_labels.append(row_logprobs_labels)
+        logprobs_labels = torch.stack(logprobs_labels)
+    return logprobs_labels


def clip_by_value(x, tensor_min, tensor_max):
@@ -277,7 +287,7 @@ def tokenize_and_postprocess_data(prompt: str,


def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor):
-    """ Remove the pad token.
+    """ Remove the pad token.
    Args:
        input_ids shape: [bs, seq_length]
@@ -293,13 +303,13 @@ def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor):

def log_probs_from_logits_response(input_ids, logits, response_length):
    """Compute the response log_probs from full logits. Note that logits = model(input_ids)
    Args:
        input_ids: [batch_size, seqlen]
        logits: [batch_size, seqlen, vocab_size]
    Returns:
-        response_log_prob:
+        response_log_prob:
    """
    response_logits = logits[:, -response_length - 1:-1]
    response = input_ids[:, -response_length:]
@@ -313,7 +323,7 @@ def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad
    logits and input_ids.
    The reason for this function is to compute logprobs_from_logits in rmpad mode, because it is memory-intensive
    for large vocab_size
    Args:
        input_ids: [batch_size, seqlen]
        attention_mask: [batch_size, seqlen]
@@ -341,7 +351,7 @@ def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batc
    logits and input_ids.
    The reason for this function is to compute logprobs_from_logits in rmpad mode, because it is memory-intensive
    for large vocab_size
    Args:
        input_ids_rmpad: [1, total_nnz]
        logits_rmpad: [total_nnz, vocab_size]
