
Commit

Merge remote-tracking branch 'upstream/main' into openvino-2024.3.0-dev
ilya-lavrenov committed Jun 27, 2024
2 parents 4f0be96 + c3dde36 commit 6bad9bf
Showing 16 changed files with 1,015 additions and 323 deletions.
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
@@ -55,6 +55,10 @@ Alongside each architecture, we include some popular models that use it.
- Gemma
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
- ✅︎
* - :code:`Gemma2ForCausalLM`
- Gemma2
- :code:`google/gemma-2-9b`, :code:`google/gemma-2-27b`, etc.
- ✅︎
* - :code:`GPT2LMHeadModel`
- GPT-2
- :code:`gpt2`, :code:`gpt2-xl`, etc.
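The table rows above register `Gemma2ForCausalLM` as a supported architecture. As a minimal sketch (not part of the diff), loading one of the newly listed checkpoints with vLLM's offline `LLM` API would look like this; the prompt and sampling values are illustrative:

```python
# Sketch: offline generation with a Gemma 2 checkpoint from the table above.
# Requires a transformers release with Gemma 2 support (see requirements below).
from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-2-9b")  # served as Gemma2ForCausalLM
params = SamplingParams(temperature=0.8, max_tokens=64)

for output in llm.generate(["The capital of France is"], params):
    print(output.outputs[0].text)
```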
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -6,7 +6,7 @@ numpy < 2.0.0
requests
tqdm
py-cpuinfo
transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
transformers >= 4.42.0 # Required for Gemma 2.
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
aiohttp
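The bumped floor is the only change in this file. A small sketch (not from the diff, and assuming the `packaging` helper is available) of asserting that an existing environment meets it before pulling a Gemma 2 checkpoint:

```python
# Sketch: check the installed transformers version against the new requirement.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("transformers"))
assert installed >= Version("4.42.0"), (
    f"transformers {installed} is too old for Gemma 2; 4.42.0 or newer is required")
```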
82 changes: 1 addition & 81 deletions vllm/block.py
@@ -1,90 +1,10 @@
"""Token blocks."""
import weakref
from collections import defaultdict
from typing import Dict, List
from typing import List

from vllm.utils import Device

_BLANK_TOKEN_ID = -1

DEFAULT_LAST_ACCESSED_TIME = -1

TokensBlock = List[int]


class BlockPool:
"""A pool of logical blocks.
When requests come, we create a lot of logical blocks;
when requests are done, we destroy a lot of logical blocks.
It turns out that creating and destroying logical blocks can be expensive,
especially for the `token_ids` field, which is a list of integers.
To avoid this overhead, we use a pool to manage the logical blocks.
When an old request is done and a new request comes, we can reuse the
logical blocks from the old request to feed the new request.
"""

def __init__(self) -> None:
# block size to list of token blocks
self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)

def alloc_block(self, block_size: int) -> TokensBlock:
if block_size in self.pool and self.pool[block_size]:
return self.pool[block_size].pop()
return [_BLANK_TOKEN_ID] * block_size

def del_block(self, block: TokensBlock) -> None:
self.pool[len(block)].append(block)


_BLOCK_POOL = BlockPool()


class LogicalTokenBlock:
"""A block that stores a contiguous chunk of tokens from left to right.
Logical blocks are used to represent the states of the corresponding
physical blocks in the KV cache.
"""

def __init__(
self,
block_number: int,
block_size: int,
) -> None:
self.block_number = block_number
self.block_size = block_size

self.token_ids = _BLOCK_POOL.alloc_block(block_size)
# this finalizer is used to return the block to the pool when the object is deleted # noqa
# NOTE: don't use __del__ because it cannot guarantee the order of finalization, # noqa
# i.e. `self.token_ids` may be deleted before `self`, and we lose
# the opportunity to return the block to the pool
self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
self.token_ids)
self.num_tokens = 0

def is_empty(self) -> bool:
return self.num_tokens == 0

def get_num_empty_slots(self) -> int:
return self.block_size - self.num_tokens

def is_full(self) -> bool:
return self.num_tokens == self.block_size

def append_tokens(self, token_ids: List[int]) -> None:
assert len(token_ids) <= self.get_num_empty_slots()
curr_idx = self.num_tokens
self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
self.num_tokens += len(token_ids)

def get_token_ids(self) -> List[int]:
return self.token_ids[:self.num_tokens]

def get_last_token_id(self) -> int:
assert self.num_tokens > 0
return self.token_ids[self.num_tokens - 1]


class PhysicalTokenBlock:
"""Represents the state of a block in the KV cache."""
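The deleted `BlockPool` existed purely to recycle `token_ids` lists, and the NOTE in the removed constructor explains the one subtle choice: `weakref.finalize` is used instead of `__del__` because attribute teardown order is unspecified. A self-contained sketch of that pattern (hypothetical names, not the current vLLM API):

```python
# Sketch of the recycle-on-finalize pattern from the removed vllm/block.py code.
# Names are hypothetical; the current vLLM code no longer uses this pattern.
import weakref

_pool = []  # recycled token buffers

class PooledBlock:
    def __init__(self, block_size: int) -> None:
        # Reuse a recycled buffer when possible instead of allocating a new list.
        self.token_ids = _pool.pop() if _pool else [-1] * block_size
        # finalize() captures the buffer directly, so it is returned to the pool
        # even though __del__ could not rely on self.token_ids still existing.
        weakref.finalize(self, _pool.append, self.token_ids)

block = PooledBlock(block_size=4)
del block          # finalizer fires; the buffer goes back to _pool
print(len(_pool))  # 1 -- available for reuse by the next block
```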
30 changes: 23 additions & 7 deletions vllm/config.py
@@ -15,7 +15,7 @@
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
update_environment_variables)
print_warning_once, update_environment_variables)

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -141,6 +141,17 @@ def __init__(
code_revision, rope_scaling, rope_theta)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None):
print_warning_once(
"Gemma 2 uses sliding window attention for every odd layer, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f"({self.hf_text_config.sliding_window}).")
self.disable_sliding_window = True

self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
@@ -257,8 +268,7 @@ def verify_with_parallel_config(
"BitAndBytes quantization with TP or PP is not supported yet.")

def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""
"""Get the sliding window size, or None if disabled."""

# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
@@ -1258,10 +1268,16 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32
# models.
logger.info("Casting torch.float32 to torch.float16.")
torch_dtype = torch.float16
if config.model_type == "gemma2":
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
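Together, the two hunks above change the out-of-the-box behavior for Gemma 2: sliding-window attention is disabled (with the max length capped to the window size), and a checkpoint whose HF config reports float32 is cast to bfloat16 rather than float16 under `dtype="auto"`. A sketch of the opt-out the log message suggests, using the standard `dtype` engine argument:

```python
# Sketch: keeping float16 for a Gemma 2 checkpoint despite the new bfloat16 default.
from vllm import LLM

# With the default dtype="auto", a float32 Gemma 2 config now resolves to
# bfloat16, and the sliding-window warning above is printed once.
llm = LLM(model="google/gemma-2-9b", dtype="float16")  # explicit opt-out
```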
19 changes: 9 additions & 10 deletions vllm/core/block_manager_v1.py
@@ -262,8 +262,7 @@ def __init__(
self.cross_block_tables: Dict[str, BlockTable] = {}

def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
return 0 if seq is None \
else len(seq.logical_token_blocks)
return 0 if seq is None else seq.n_blocks

def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
@@ -298,7 +297,7 @@ def _allocate_sequence(self, \
ref_count: int, \
is_encoder_decoder: bool = True) -> BlockTable:
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = len(seq.logical_token_blocks)
num_prompt_blocks = seq.n_blocks

block_table: BlockTable = []
for logical_idx in range(num_prompt_blocks):
@@ -367,7 +366,7 @@ def _promote_last_block(

# Compute a new hash for the block so that it can be shared by other
# Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
new_hash = seq.hash_of_block(seq.n_blocks - 1)

# if new_hash is already in the cached table, then free last_block
# and return the cached version
@@ -407,10 +406,10 @@ def _allocate_last_physical_block(
if not self.enable_caching:
return self.gpu_allocator.allocate()
block_hash: Optional[int] = None
n_blocks = seq.n_blocks
if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(
len(seq.logical_token_blocks) - 1)
block_hash = seq.hash_of_block(n_blocks - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)

# num_hashed_tokens is used to compute future hashes
# (e.g. in the hashing function, it is used to ask the sequence for
@@ -429,12 +428,12 @@ def append_slots(
num_lookahead_slots: int = 0,
) -> List[Tuple[int, int]]:
"""Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks
n_blocks = seq.n_blocks
block_table = self.block_tables[seq.seq_id]
# If we need to allocate a new physical block
if len(block_table) < len(logical_blocks):
if len(block_table) < n_blocks:
# Currently this code only supports adding one physical block
assert len(block_table) == len(logical_blocks) - 1
assert len(block_table) == n_blocks - 1

if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window):
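All of the hunks in this file swap `len(seq.logical_token_blocks)` for `seq.n_blocks`, which matches the block.py deletion above: without materialized logical blocks, the block count can be derived from the token count. The property itself lives in `vllm/sequence.py` and is not shown in this diff, so the ceil-division below is an assumed equivalent rather than the verbatim implementation:

```python
# Sketch of the block-count arithmetic that the new seq.n_blocks call sites imply.
# Assumed equivalent of the Sequence.n_blocks property (not shown in this diff).
def n_blocks(num_tokens: int, block_size: int) -> int:
    return (num_tokens + block_size - 1) // block_size

assert n_blocks(16, 16) == 1  # exactly one full block
assert n_blocks(17, 16) == 2  # one extra token spills into a second block
```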
4 changes: 4 additions & 0 deletions vllm/lora/layers.py
@@ -1069,6 +1069,10 @@ def vocab_size(self):
def scale(self):
return self.base_layer.scale

@property
def soft_cap(self):
return self.base_layer.soft_cap

@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
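The new property only forwards the base layer's `soft_cap` through the LoRA wrapper; Gemma 2 applies a final-logit soft cap, conventionally computed as `soft_cap * tanh(logits / soft_cap)`. A standalone sketch of that transform (illustrative only, not the vLLM kernel):

```python
# Standalone sketch of logit soft-capping as used by Gemma 2; illustrative only,
# not the vLLM implementation behind the forwarded soft_cap property.
import torch

def soft_cap_logits(logits: torch.Tensor, soft_cap: float) -> torch.Tensor:
    # Smoothly squash logits into (-soft_cap, soft_cap) instead of hard clipping.
    return soft_cap * torch.tanh(logits / soft_cap)

logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap_logits(logits, soft_cap=30.0))  # values bounded by ±30
```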
(Diffs for the remaining 10 changed files are not shown.)

