diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py
index 25d15df9f915d..6a4cb52a15958 100644
--- a/vllm/engine/output_processor/multi_step.py
+++ b/vllm/engine/output_processor/multi_step.py
@@ -75,6 +75,8 @@ def process_outputs(self, sequence_group: SequenceGroup,
         assert len(seqs) == 1, (
             "Beam search not supported in multi-step decoding.")
         seq = seqs[0]
+        #for output in outputs:
+        #    print('output ' + str(output))
 
         # Since there's only one sequence per sequence group, we can take the
         # first sample.
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 6d00ea64f7cb8..3f91776a4d722 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -90,7 +90,9 @@ def forward(
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
         # Compute the log probabilities.
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
+        logprobs = None
+        # TODO(sroy) - Add flag for this.
+        #logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
 
         # Sample the next tokens.
         sample_results, maybe_sampled_tokens_tensor = _sample(
@@ -109,13 +111,21 @@ def forward(
             on_device_tensors = None
 
         # Get the logprobs query results.
+        #print('sample_results ' + str(sample_results))
         prompt_logprobs, sample_logprobs = _get_logprobs(
             logprobs, sampling_metadata, sample_results)
-        return _build_sampler_output(sample_results,
-                                     sampling_metadata,
-                                     prompt_logprobs,
-                                     sample_logprobs,
-                                     on_device_tensors=on_device_tensors)
+        output = _build_sampler_output(sample_results,
+                                       sampling_metadata,
+                                       prompt_logprobs,
+                                       sample_logprobs,
+                                       on_device_tensors=on_device_tensors)
+        #print('output ' + str(output))
+        return output
+        #return _build_sampler_output(sample_results,
+        #                             sampling_metadata,
+        #                             prompt_logprobs,
+        #                             sample_logprobs,
+        #                             on_device_tensors=on_device_tensors)
 
     @property
     def _should_modify_greedy_probs_inplace(self) -> bool:
@@ -472,10 +482,10 @@ def _sample_with_torch(
 
     # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
-        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
+        sampled_token_ids_tensor = torch.empty(probs.shape[0],
                                                1,
                                                dtype=torch.long,
-                                               device=logprobs.device)
+                                               device=probs.device)
     else:
         sampled_token_ids_tensor = None
 
@@ -492,7 +502,9 @@ def _sample_with_torch(
         sample_metadata[sampling_type] = (seq_group_id, seq_groups)
         long_sample_indices = sample_indices.long()
         if sampling_type == SamplingType.GREEDY:
-            greedy_samples = torch.argmax(logprobs[long_sample_indices],
+            #greedy_samples = torch.argmax(logprobs[long_sample_indices],
+            #                              dim=-1)
+            greedy_samples = torch.argmax(probs[long_sample_indices],
                                           dim=-1)
 
             if sampled_token_ids_tensor is not None:
@@ -720,6 +732,11 @@ def _get_logprobs(
     Returns:
         A tuple of prompt and sample logprobs per sequence group in a batch.
     """
+    if logprobs is None:
+        empty_sampled_logprob: List[SampleLogprobs] = [[None] * len(sample_results)] * (
+            len(sampling_metadata.seq_groups))
+        empty_prompt_logprob: List[PromptLogprobs] = [None] * len(sampling_metadata.seq_groups)
+        return empty_prompt_logprob, empty_sampled_logprob
     # The index of query token to calculate logprobs. It includes both
     # prompt and sample logprob indices.
     query_indices: List[int] = []
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 1cebf68d463db..601698d425d34 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -326,9 +326,12 @@ def append_token_id(
         token_id: int,
         logprobs: Dict[int, Logprob],
     ) -> None:
-        assert token_id in logprobs
+        if logprobs is not None:
+            assert token_id in logprobs
+            self.data.append_token_id(token_id, logprobs[token_id].logprob)
+        else:
+            self.data.append_token_id(token_id, 0.0)
         self.output_logprobs.append(logprobs)
-        self.data.append_token_id(token_id, logprobs[token_id].logprob)
 
     def get_len(self) -> int:
         return self.data.get_len()
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index 40516556344e9..e6cb7007e51f0 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -169,23 +169,29 @@ def _contract_batch(
         target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
         target_probs = target_probs.reshape(*target_token_ids.shape,
                                             self._vocab_size)
-        target_logprobs = target_logprobs.reshape(target_probs.shape)
+
+        if target_logprobs is not None:
+            target_logprobs = target_logprobs.reshape(target_probs.shape)
 
         all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                                fill_value=-1)
         all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
-        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
-                                                fill_value=-float("inf"))
+        all_logprobs = None
+        if target_logprobs is not None:
+            all_logprobs = target_logprobs.new_full(
+                size=all_probs.shape, fill_value=-float("inf"))
 
         if non_spec_indices:
             all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
             all_probs[non_spec_indices, :1, :] = non_spec_target_probs
-            all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
+            if all_logprobs is not None:
+                all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
 
         if spec_indices:
             all_tokens[spec_indices] = target_token_ids
             all_probs[spec_indices] = target_probs
-            all_logprobs[spec_indices] = target_logprobs
+            if all_logprobs is not None:
+                all_logprobs[spec_indices] = target_logprobs
 
         return all_tokens, all_probs, all_logprobs
 
@@ -327,10 +333,16 @@ def _split_scoring_output(
          ) = sampler_output.sampled_token_probs.split(split_sizes)
         (spec_sampled_tokens, non_spec_sampled_tokens
          ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
-        (
-            spec_logprobs,
-            non_spec_logprobs,
-        ) = sampler_output.logprobs.split(split_sizes)
+        if sampler_output.logprobs is not None:
+            (
+                spec_logprobs,
+                non_spec_logprobs,
+            ) = sampler_output.logprobs.split(split_sizes)
+        else:
+            (
+                spec_logprobs,
+                non_spec_logprobs,
+            ) = (None, None)
 
         # Convert scores to tensors.
         sampler_output.sampled_token_probs = spec_probs
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index 90bba96ee8acb..97293897a2827 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -110,9 +110,11 @@ def update_model_input(
                 seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
 
                 token_id = seq_output.output_token
-                token_logprob = seq_output.logprobs[token_id]
-
-                seq.append_token_id(token_id, token_logprob.logprob)
+                if seq_output.logprobs is not None:
+                    token_logprob = seq_output.logprobs[token_id]
+                    seq.append_token_id(token_id, token_logprob.logprob)
+                else:
+                    seq.append_token_id(token_id, 0.0)
                 seq.update_num_computed_tokens(1)
 
         return self.prepare_model_input(self.cached_seq_group_metadata_list)
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index 09a77f9e870fb..46416846f17d2 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -216,9 +216,11 @@ def _append_new_tokens(
                 seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
 
                 token_id = seq_output.output_token
-                token_logprob = seq_output.logprobs[token_id]
+                token_logprob = 0.0
+                if seq_output.logprobs is not None:
+                    token_logprob = seq_output.logprobs[token_id].logprob
 
-                seq.append_token_id(token_id, token_logprob.logprob)
+                seq.append_token_id(token_id, token_logprob)
                 seq.update_num_computed_tokens(1)
 
     @staticmethod
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 3c8e3dee46831..914dba2c11a30 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -569,25 +569,31 @@ def _create_output_sampler_list(
             the same number of outputs.
         """
         batch_size, num_steps = accepted_token_ids.shape
-
-        # Organize input tensors by step instead of by sequence.
-        target_logprobs_by_step = target_logprobs.transpose(0, 1)
         accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
+        if target_logprobs is not None:
+            # Organize input tensors by step instead of by sequence.
+            target_logprobs_by_step = target_logprobs.transpose(0, 1)
+
+            # Get the logprobs/rank of the accepted tokens.
+            (accepted_token_id_ranks_by_step,
+             accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
+                 logprob_tensor=target_logprobs_by_step,
+                 sampled_token_ids=accepted_token_ids_by_step,
+             )
 
-        # Get the logprobs/rank of the accepted tokens.
-        (accepted_token_id_ranks_by_step,
-         accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
-             logprob_tensor=target_logprobs_by_step,
-             sampled_token_ids=accepted_token_ids_by_step,
-         )
-
-        # Get the top-k logprobs (which may or may not include the logprob of
-        # the accepted token).
-        (topk_logprobs_by_step,
-         topk_indices_by_step) = target_logprobs_by_step.topk(
-             k=self.scorer_worker.model_config.max_logprobs,
-             dim=-1,
-         )
+            # Get the top-k logprobs (which may or may not include the logprob of
+            # the accepted token).
+            (topk_logprobs_by_step,
+             topk_indices_by_step) = target_logprobs_by_step.topk(
+                 k=self.scorer_worker.model_config.max_logprobs,
+                 dim=-1,
+             )
+            accepted_token_id_ranks_by_step = (
+                accepted_token_id_ranks_by_step.tolist())
+            accepted_token_id_logprobs_by_step = (
+                accepted_token_id_logprobs_by_step.tolist())
+            topk_logprobs_by_step = topk_logprobs_by_step.tolist()
+            topk_indices_by_step = topk_indices_by_step.tolist()
 
         # Get the sequence ids and num_logprobs (sampling parameter) in the
         # batch.
@@ -598,12 +604,6 @@ def _create_output_sampler_list(
 
         # Serialize all tensors to CPU Python lists.
         accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
-        accepted_token_id_ranks_by_step = (
-            accepted_token_id_ranks_by_step.tolist())
-        accepted_token_id_logprobs_by_step = (
-            accepted_token_id_logprobs_by_step.tolist())
-        topk_logprobs_by_step = topk_logprobs_by_step.tolist()
-        topk_indices_by_step = topk_indices_by_step.tolist()
 
         # Construct the output on a per-step, per-sequence basis.
         sampler_output_list: List[SamplerOutput] = []
@@ -616,20 +616,33 @@ def _create_output_sampler_list(
             for sequence_index in range(batch_size):
                 # Each sequence may have a different num_logprobs; retrieve it.
                 num_logprobs = num_logprobs_per_seq[sequence_index]
-                step_output_token_ids.append(
-                    create_sequence_group_output(
-                        token_id=accepted_token_ids_by_step[step_index]
-                        [sequence_index],
-                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
-                            step_index][sequence_index],
-                        token_id_logprob=accepted_token_id_logprobs_by_step[
-                            step_index][sequence_index],
-                        seq_id=seq_ids[sequence_index],
-                        topk_token_ids=topk_indices_by_step[step_index]
-                        [sequence_index][:num_logprobs],
-                        topk_logprobs=topk_logprobs_by_step[step_index]
-                        [sequence_index][:num_logprobs],
-                    ))
+                if target_logprobs is not None:
+                    step_output_token_ids.append(
+                        create_sequence_group_output(
+                            token_id=accepted_token_ids_by_step[step_index]
+                            [sequence_index],
+                            token_id_logprob_rank=accepted_token_id_ranks_by_step[
+                                step_index][sequence_index],
+                            token_id_logprob=accepted_token_id_logprobs_by_step[
+                                step_index][sequence_index],
+                            seq_id=seq_ids[sequence_index],
+                            topk_token_ids=topk_indices_by_step[step_index]
+                            [sequence_index][:num_logprobs],
+                            topk_logprobs=topk_logprobs_by_step[step_index]
+                            [sequence_index][:num_logprobs],
+                        ))
+                else:
+                    step_output_token_ids.append(
+                        create_sequence_group_output(
+                            token_id=accepted_token_ids_by_step[step_index]
+                            [sequence_index],
+                            token_id_logprob_rank=-1,
+                            token_id_logprob=-1,
+                            seq_id=seq_ids[sequence_index],
+                            topk_token_ids=[],
+                            topk_logprobs=[],
+                        ))
+
             sampler_output_list.append(
                 SamplerOutput(outputs=step_output_token_ids))
 
diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py
index 80710419e602d..8cf28d120a0b3 100644
--- a/vllm/spec_decode/util.py
+++ b/vllm/spec_decode/util.py
@@ -68,19 +68,21 @@ def create_sequence_group_output(
     """
     # vLLM logprobs always include the sampled token. In addition, the user may
     # request topk-logprobs (where top-k varies per user up to max_logprobs).
-    logprobs: Dict[int, Logprob] = {
-        token_id: Logprob(
-            logprob=token_id_logprob,
-            rank=token_id_logprob_rank,
-        ),
-    }
-    logprobs.update({
-        topk_token_ids[topk_logprob_index]: Logprob(
-            logprob=topk_logprobs[topk_logprob_index],
-            rank=topk_logprob_index + 1,
-        )
-        for topk_logprob_index, _ in enumerate(topk_token_ids)
-    })
+    logprobs = None
+    if token_id_logprob_rank != -1:
+        logprobs: Dict[int, Logprob] = {
+            token_id: Logprob(
+                logprob=token_id_logprob,
+                rank=token_id_logprob_rank,
+            ),
+        }
+        logprobs.update({
+            topk_token_ids[topk_logprob_index]: Logprob(
+                logprob=topk_logprobs[topk_logprob_index],
+                rank=topk_logprob_index + 1,
+            )
+            for topk_logprob_index, _ in enumerate(topk_token_ids)
+        })
 
     return CompletionSequenceGroupOutput(
         samples=[
@@ -149,12 +151,15 @@ def sampler_output_to_torch(
         sampled_token_probs = sampled_token_probs.transpose(0, 1)
 
     # shape: [batch_size, num_sampler_output, vocab_size]
-    sampled_token_logprobs = torch.stack(
-        [sampler_output.logprobs for sampler_output in sampler_output_list],
-        dim=0,
-    )
+    if sampler_output_list[0].logprobs is not None:
+        sampled_token_logprobs = torch.stack(
+            [sampler_output.logprobs for sampler_output in sampler_output_list],
+            dim=0,
+        )
+    else:
+        sampled_token_logprobs = None
 
-    if sampler_transposed:
+    if sampler_transposed and sampled_token_logprobs is not None:
         sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
 
     # shape: [batch_size, num_sampler_output]
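Note on the "TODO(sroy) - Add flag for this." comment in vllm/model_executor/layers/sampler.py: below is a minimal, standalone sketch of how the hard-coded `logprobs = None` could be gated behind a flag instead. The helper name `compute_probs_and_logprobs` and the `disable_logprobs` parameter are illustrative assumptions, not existing vLLM APIs, and the sketch is not part of the patch above.

# Sketch only, under the assumptions stated above.
from typing import Optional, Tuple

import torch


def compute_probs_and_logprobs(
    logits: torch.Tensor,
    disable_logprobs: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Return (probs, logprobs); logprobs is None when disabled."""
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    logprobs = None
    if not disable_logprobs:
        # Skipping log_softmax avoids materializing a second
        # [num_tokens, vocab_size] float tensor.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
    return probs, logprobs


if __name__ == "__main__":
    logits = torch.randn(4, 32000)
    probs, logprobs = compute_probs_and_logprobs(logits, disable_logprobs=True)
    assert logprobs is None
    # Greedy sampling still works on probs alone (as in the patched
    # _sample_with_torch): argmax over probs equals argmax over logprobs,
    # because log is monotonic.
    greedy_token_ids = torch.argmax(probs, dim=-1)
    print(greedy_token_ids.shape)

The rest of the diff then teaches the consumers of SamplerOutput (sequence.py, batch_expansion.py, draft_model_runner.py, multi_step_worker.py, spec_decode_worker.py, util.py) to tolerate `logprobs is None`, which is what the added `is not None` guards do.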