Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
  • Loading branch information
Varun Sundar Rabindranath committed Dec 17, 2024
1 parent a82d7b5 commit d4d70cc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
8 changes: 5 additions & 3 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:

# TODO(rickyx): potentially we could cache this so we don't have to
# recompute it every time.
metadata_hash = None if not request.lora_request else hash(request.lora_request.lora_int_id)
metadata_hash = (None if not request.lora_request else
request.lora_request.lora_int_id)
block_hashes = hash_request_tokens(self.block_size,
request.all_token_ids,
parent_hash=metadata_hash)
Expand Down Expand Up @@ -379,8 +380,9 @@ def _cache_full_blocks(
prev_block: The previous block in the chain.
"""
# Update the new blocks with the block hashes through the chain.
metadata_hash = None if request.lora_request is None else request.lora_request.lora_int_id
parent_hash = metadata_hash
metadata_hash = (None if request.lora_request is None else
request.lora_request.lora_int_id)
parent_hash = metadata_hash
if prev_block is not None:
# Previous block must have a block hash because it must be
# a full, cached block.
Expand Down
10 changes: 5 additions & 5 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,10 @@ def hash_block_tokens(parent_hash: Optional[int],
tuple(curr_block_token_ids))


def hash_request_tokens(block_size: int,
token_ids: Sequence[int],
parent_hash: Optional[int] = None) -> List[BlockHashType]:
def hash_request_tokens(
block_size: int,
token_ids: Sequence[int],
parent_hash: Optional[int] = None) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
Expand All @@ -204,8 +205,7 @@ def hash_request_tokens(block_size: int,
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
block_hash = hash_block_tokens(parent_hash,
block_token_ids)
block_hash = hash_block_tokens(parent_hash, block_token_ids)
ret.append(block_hash)
parent_hash = block_hash.hash_value
return ret

0 comments on commit d4d70cc

Please sign in to comment.