diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py index 3c506121f..5092dbe16 100644 --- a/aphrodite/modeling/layers/sampler.py +++ b/aphrodite/modeling/layers/sampler.py @@ -402,9 +402,10 @@ def forward( sampling_tensors.xtc_probabilities) - # We use float32 for probabilities and log probabilities. # Compute the probabilities. - probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # No cast required, as softmaxes internally accumulate with >= fp32. + # See https://github.com/pytorch/pytorch/blob/v2.4.0/aten/src/ATen/native/cuda/PersistentSoftmax.cuh#L62 + probs = torch.softmax(logits, dim=-1) # skew needs to be applied post-softmax if do_skew: @@ -422,7 +423,9 @@ def forward( logits = torch.log(probs) # Compute the log probabilities. - logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + logprobs = torch.log_softmax(logits, dim=-1) + # Logits unused past this point. They are HUGE, for prompt_logprobs. + del logits # Sample the next tokens. sample_results, maybe_sampled_tokens_tensor = _sample( @@ -1424,7 +1427,7 @@ def _sample( # sampling_tensors) -def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: +def _get_ranks(logprobs: torch.Tensor, query_indices: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """ This function calculates the ranks of the chosen tokens in a logprob tensor. Args: @@ -1436,9 +1439,14 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: Each element in the returned tensor represents the rank of the chosen token in the input logprob tensor. """ - vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), - indices] - return (x > vals[:, None]).long().sum(1).add_(1) + expanded_indices = torch.zeros(logprobs.shape[0], dtype=indices.dtype, device=indices.device) + expanded_indices.scatter_(-1, query_indices, indices) + vals = torch.gather(logprobs, -1, expanded_indices.unsqueeze(dim=1)) + # doing this in a single pass is not practical for prompt_logprobs. + outs = torch.zeros_like(indices) + for i in range(0, vals.shape[0], 10): + outs[i:i+10] = (logprobs[i:i+10] > vals[i:i+10]).sum(dim=-1, dtype=indices.dtype) + return outs[query_indices].add_(1) def _get_logprobs( @@ -1538,7 +1546,8 @@ def _get_logprobs( next_token_ids_gpu, ]] ranks = _get_ranks( - logprobs[query_indices_gpu], + logprobs, + query_indices_gpu, next_token_ids_gpu, ) assert selected_logprobs.shape[0] == ranks.shape[0]