reduce sampler peak memory usage

50h100a committed Dec 16, 2024
1 parent bc1a2bd commit d49ead7
Showing 1 changed file with 17 additions and 8 deletions.
aphrodite/modeling/layers/sampler.py
@@ -402,9 +402,10 @@ def forward(
                 sampling_tensors.xtc_probabilities)


-        # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # No cast required, as softmaxes internally accumulate with >= fp32.
+        # See https://github.com/pytorch/pytorch/blob/v2.4.0/aten/src/ATen/native/cuda/PersistentSoftmax.cuh#L62
+        probs = torch.softmax(logits, dim=-1)

         # skew needs to be applied post-softmax
         if do_skew:
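A quick sketch (illustrative, not from the commit) of what dropping the cast buys: with half-precision logits, omitting `dtype=torch.float` keeps `probs` in the input dtype, halving that buffer, while the softmax kernel still accumulates internally at fp32 or better. The shapes and the bfloat16 choice below are assumptions for the demo.

```python
import torch

# Hypothetical sizes: 256 sequences over a 32k vocab, half-precision logits.
logits = torch.randn(256, 32000, dtype=torch.bfloat16)

probs_fp32 = torch.softmax(logits, dim=-1, dtype=torch.float)  # old: fp32 output buffer
probs_half = torch.softmax(logits, dim=-1)                     # new: stays bfloat16

print(probs_fp32.nelement() * probs_fp32.element_size())  # 32768000 bytes
print(probs_half.nelement() * probs_half.element_size())  # 16384000 bytes
print(torch.allclose(probs_fp32, probs_half.float(), atol=1e-3))  # agrees up to bf16 rounding
```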
@@ -422,7 +423,9 @@ def forward(
             logits = torch.log(probs)

         # Compute the log probabilities.
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
+        logprobs = torch.log_softmax(logits, dim=-1)
+        # Logits unused past this point. They are HUGE, for prompt_logprobs.
+        del logits

         # Sample the next tokens.
         sample_results, maybe_sampled_tokens_tensor = _sample(
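Why `del logits` matters, in a hedged sketch (assumes a CUDA device; the shapes and the helper name `sampler_peak` are hypothetical): once the last reference to the logits tensor is dropped, the caching allocator can hand its block to the sampling temporaries instead of growing the pool.

```python
import torch

# Hypothetical demo of the peak-memory effect of freeing logits early.
def sampler_peak(free_early: bool) -> int:
    torch.cuda.reset_peak_memory_stats()
    logits = torch.randn(1024, 32000, device="cuda")   # large when prompt_logprobs is on
    logprobs = torch.log_softmax(logits, dim=-1)
    if free_early:
        del logits                                     # allocator can now reuse the block
    scratch = torch.empty_like(logprobs)               # stands in for sampling temporaries
    del scratch, logprobs
    return torch.cuda.max_memory_allocated()

# Expect sampler_peak(True) to come in roughly one logits-sized buffer
# (~131 MB at fp32 here) below sampler_peak(False).
```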
@@ -1424,7 +1427,7 @@ def _sample(
     # sampling_tensors)


-def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+def _get_ranks(logprobs: torch.Tensor, query_indices: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     """
     This function calculates the ranks of the chosen tokens in a logprob tensor.
     Args:
@@ -1436,9 +1439,14 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
         Each element in the returned tensor represents the rank
         of the chosen token in the input logprob tensor.
     """
-    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
-             indices]
-    return (x > vals[:, None]).long().sum(1).add_(1)
+    expanded_indices = torch.zeros(logprobs.shape[0], dtype=indices.dtype, device=indices.device)
+    expanded_indices.scatter_(-1, query_indices, indices)
+    vals = torch.gather(logprobs, -1, expanded_indices.unsqueeze(dim=1))
+    # A single pass is impractical for prompt_logprobs; chunk to bound the temporary.
+    outs = torch.zeros_like(expanded_indices)
+    for i in range(0, vals.shape[0], 10):
+        outs[i:i+10] = (logprobs[i:i+10] > vals[i:i+10]).sum(dim=-1, dtype=indices.dtype)
+    return outs[query_indices].add_(1)


 def _get_logprobs(
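A standalone sketch of the reworked rank computation (illustrative shapes; chunk size 2 instead of 10 for the demo): a token's rank is 1 plus the number of tokens with a strictly larger logprob in its row, and the chunked version must agree with the old one-shot comparison while never materializing the full boolean tensor.

```python
import torch

torch.manual_seed(0)
logprobs = torch.log_softmax(torch.randn(6, 50), dim=-1)
query_indices = torch.tensor([0, 2, 5])   # rows that actually need ranks
chosen = torch.tensor([7, 3, 42])         # chosen token id per queried row

# Old approach: index out the queried rows (a copy), compare all at once.
rows = logprobs[query_indices]
vals = rows[torch.arange(len(chosen)), chosen]
ranks_old = (rows > vals[:, None]).long().sum(1) + 1

# New approach: gather per-row values in place, then compare chunk by chunk.
expanded = torch.zeros(logprobs.shape[0], dtype=torch.long)
expanded.scatter_(-1, query_indices, chosen)
vals_all = torch.gather(logprobs, -1, expanded.unsqueeze(1))
outs = torch.zeros_like(expanded)
for i in range(0, logprobs.shape[0], 2):
    outs[i:i+2] = (logprobs[i:i+2] > vals_all[i:i+2]).sum(dim=-1)
ranks_new = outs[query_indices] + 1

assert torch.equal(ranks_old, ranks_new)
```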
@@ -1538,7 +1546,8 @@ def _get_logprobs(
         next_token_ids_gpu,
     ]]
     ranks = _get_ranks(
-        logprobs[query_indices_gpu],
+        logprobs,
+        query_indices_gpu,
         next_token_ids_gpu,
     )
     assert selected_logprobs.shape[0] == ranks.shape[0]
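The call-site change mirrors the new signature. The point, sketched below with made-up shapes: `logprobs[query_indices_gpu]` is advanced indexing, which materializes a fresh copy of the queried rows before any ranks are computed; passing the tensor and the indices separately lets `_get_ranks` read the rows in place.

```python
import torch

# Illustrative only: advanced (fancy) indexing allocates new storage.
logprobs = torch.randn(1024, 4096)
query_indices = torch.tensor([3, 17, 998])

subset = logprobs[query_indices]   # copies 3 rows into fresh storage
print(subset.untyped_storage().data_ptr() == logprobs.untyped_storage().data_ptr())  # False
```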