diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 5c376797a054f..121458f8156a1 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -220,7 +220,7 @@ def _apply_min_tokens_penalty(
             seqs_to_penalize: List[int] = []
             for j, seq_id in enumerate(seq_ids):
                 seq_data = seq_group.seq_data[seq_id]
-                if len(seq_data.output_token_ids) < min_tokens:
+                if len(seq_data.output_token_ids_array) < min_tokens:
                     seqs_to_penalize.append(j)
 
             if seqs_to_penalize:
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index 390b5d173ebcd..27b37a9d53470 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -1,4 +1,5 @@
 import random
+from array import array
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
@@ -329,8 +330,8 @@ def from_sampling_metadata(
             user-defined seed for each sequence.
         extra_entropy: extra entropy to use when generating seeds.
         """
-        prompt_tokens: List[List[int]] = []
-        output_tokens: List[List[int]] = []
+        prompt_tokens: List[array] = []
+        output_tokens: List[array] = []
         top_ks: List[int] = []
         temperatures: List[float] = []
         top_ps: List[float] = []
@@ -432,13 +433,15 @@ def from_sampling_metadata(
             if (seq_group.is_prompt
                     and sampling_params.prompt_logprobs is not None):
                 prefill_len = len(seq_group.prompt_logprob_indices)
-                prompt_tokens.extend([] for _ in range(prefill_len))
-                output_tokens.extend([] for _ in range(prefill_len))
+                prompt_tokens.extend(
+                    array('l') for _ in range(prefill_len))
+                output_tokens.extend(
+                    array('l') for _ in range(prefill_len))
             if seq_group.do_sample:
                 for seq_id in seq_ids:
                     seq_data = seq_group.seq_data[seq_id]
-                    prompt_tokens.append(list(seq_data.prompt_token_ids))
-                    output_tokens.append(list(seq_data.output_token_ids))
+                    prompt_tokens.append(seq_data.prompt_token_ids_array)
+                    output_tokens.append(seq_data.output_token_ids_array)
 
         sampling_tensors = SamplingTensors.from_lists(
             temperatures, top_ps, top_ks, min_ps, presence_penalties,
@@ -454,9 +457,9 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                    frequency_penalties: List[float],
                    repetition_penalties: List[float],
                    sampling_seeds: List[int], sample_indices: List[int],
-                   prompt_tokens: List[List[int]],
-                   output_tokens: List[List[int]], vocab_size: int,
-                   extra_seeds_to_generate: int, device: torch.device,
+                   prompt_tokens: List[array], output_tokens: List[array],
+                   vocab_size: int, extra_seeds_to_generate: int,
+                   device: torch.device,
                    dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 0cd4c7e71d78d..72821ecea0f47 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -3,6 +3,7 @@
 import enum
 import math
 from abc import ABC, abstractmethod
+from array import array
 from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
@@ -119,10 +120,10 @@ def __init__(
         prompt_token_ids: List[int],
         output_token_ids: Optional[List[int]] = None,
     ) -> None:
-        self._prompt_token_ids: List[int] = list(prompt_token_ids)
+        self._prompt_token_ids = array('l', prompt_token_ids)
         self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
-        self._output_token_ids: List[int] = (
-            list(output_token_ids) if output_token_ids is not None else [])
+        self._output_token_ids = array(
+            'l', output_token_ids if output_token_ids is not None else [])
 
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
@@ -132,8 +133,8 @@ def __init__(
         self._update_cached_all_tokens()
 
     def _update_cached_all_tokens(self):
-        self._cached_all_token_ids: List[int] = (self._prompt_token_ids +
-                                                 self._output_token_ids)
+        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
+                                                     self._output_token_ids)
 
     @property
     def prompt_token_ids(self) -> Tuple[int, ...]:
@@ -141,19 +142,27 @@ def prompt_token_ids(self) -> Tuple[int, ...]:
 
     @prompt_token_ids.setter
     def prompt_token_ids(self, new_prompt_token_ids) -> None:
-        self._prompt_token_ids = list(new_prompt_token_ids)
+        self._prompt_token_ids = array('l', new_prompt_token_ids)
         self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
         self._update_cached_all_tokens()
 
+    @property
+    def prompt_token_ids_array(self) -> array:
+        return self._prompt_token_ids
+
     @property
     def output_token_ids(self) -> Tuple[int, ...]:
         return tuple(self._output_token_ids)
 
     @output_token_ids.setter
     def output_token_ids(self, new_output_token_ids) -> None:
-        self._output_token_ids = list(new_output_token_ids)
+        self._output_token_ids = array('l', new_output_token_ids)
         self._update_cached_all_tokens()
 
+    @property
+    def output_token_ids_array(self) -> array:
+        return self._output_token_ids
+
     def append_token_id(self, token_id: int, logprob: float) -> None:
         self._output_token_ids.append(token_id)
         self._cached_all_token_ids.append(token_id)
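
Background note (not part of the patch): the sketch below illustrates the list -> array('l') storage swap this diff applies to SequenceData and SamplingTensors.from_lists. It uses only the standard-library array type; the variable names mirroring SequenceData fields are illustrative, and the exact memory/speed benefit depends on sequence length, platform long size, and Python build.

# Illustrative sketch of the list -> array('l') swap used in this diff.
# array('l') keeps signed C longs in one contiguous buffer instead of a
# list of boxed Python int objects, while still supporting the operations
# SequenceData relies on (append, slicing, concatenation, len()).
from array import array

prompt_token_ids = [101, 7, 2025, 13]
output_token_ids = [9, 512]

prompt_arr = array('l', prompt_token_ids)   # cf. SequenceData._prompt_token_ids
output_arr = array('l', output_token_ids)   # cf. SequenceData._output_token_ids

# append_token_id(): appending to an array mirrors list.append().
output_arr.append(77)

# _update_cached_all_tokens(): concatenating two arrays of the same typecode
# yields a new array, which the diff converts back to a list for the cache.
cached_all_token_ids = list(prompt_arr + output_arr)
assert cached_all_token_ids == prompt_token_ids + [9, 512, 77]

# The public properties still expose plain tuples, so callers such as
# _apply_min_tokens_penalty can keep using len() and iteration unchanged,
# while the new *_array properties hand out the raw buffers without copying.
assert len(output_arr) == 3
print(tuple(prompt_arr), tuple(output_arr))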