From db35186391a2abfc6c91d703527dac20d2488107 Mon Sep 17 00:00:00 2001 From: Peng Guanwen Date: Fri, 2 Aug 2024 15:58:26 +0800 Subject: [PATCH] [Core] Comment out unused code in sampler (#7023) --- vllm/model_executor/sampling_metadata.py | 58 +++++++++++++----------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 59cfec9ec8934..015e85b4ca81d 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -13,6 +13,8 @@ _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558 +# Some triton sampler related code is guarded before it is ready. +_USE_TRITON_SAMPLER = False @dataclass @@ -347,14 +349,16 @@ def from_sampling_metadata( repetition_penalties: List[float] = [] sampling_seeds: List[int] = [] sample_indices: List[int] = [] - prompt_best_of: List[int] = [] do_penalties = False do_top_p_top_k = False do_min_p = False - # We need one base seed per Triton slice. - seeds_to_generate = (extra_seeds_to_generate + - get_num_triton_sampler_splits(vocab_size)) + if _USE_TRITON_SAMPLER: + prompt_best_of: List[int] = [] + + # We need one base seed per Triton slice. + seeds_to_generate = (extra_seeds_to_generate + + get_num_triton_sampler_splits(vocab_size)) assert sampling_metadata.seq_groups is not None for seq_group in sampling_metadata.seq_groups: @@ -366,9 +370,6 @@ def from_sampling_metadata( r = sampling_params.repetition_penalty top_p = sampling_params.top_p min_p = sampling_params.min_p - seed = sampling_params.seed - - is_greedy = sampling_params.sampling_type == SamplingType.GREEDY # k should not be greater than the vocab size. top_k = min(sampling_params.top_k, vocab_size) @@ -389,8 +390,7 @@ def from_sampling_metadata( do_penalties = True is_prompt = seq_group.is_prompt - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): + if (is_prompt and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs query_len = seq_group.query_len @@ -415,23 +415,27 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) - if is_prompt: - prompt_best_of.append(sampling_params.best_of) - query_len = seq_group.query_len - assert query_len is not None - - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - extra_entropy = extra_entropy or () - seq_seeds = cls._get_sequence_seeds( - seed, - seq_data.get_len(), - *extra_entropy, - seq_id, - seeds_to_generate=seeds_to_generate, - is_greedy=is_greedy) - sampling_seeds.append(seq_seeds) - sample_indices.extend(seq_group.sample_indices) + if _USE_TRITON_SAMPLER: + if is_prompt: + prompt_best_of.append(sampling_params.best_of) + query_len = seq_group.query_len + assert query_len is not None + + seed = sampling_params.seed + is_greedy = sampling_params.sampling_type == SamplingType.GREEDY + + for seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + extra_entropy = extra_entropy or () + seq_seeds = cls._get_sequence_seeds( + seed, + seq_data.get_len(), + *extra_entropy, + seq_id, + seeds_to_generate=seeds_to_generate, + is_greedy=is_greedy) + sampling_seeds.append(seq_seeds) + sample_indices.extend(seq_group.sample_indices) if do_penalties: for seq_group in sampling_metadata.seq_groups: @@ -549,7 +553,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], device="cpu", dtype=torch.long, pin_memory=pin_memory, - ).T.contiguous() + ).t().contiguous() # Because the memory is pinned, we can do non-blocking # transfer to device.