diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 664707e9dc65d..7b02b6247b652 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -24,6 +24,7 @@ class _Backend(enum.Enum): OPENVINO = enum.auto() FLASHINFER = enum.auto() HPU_ATTN = enum.auto() + HPU_ATTN_V1 = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() NO_ATTENTION = enum.auto() @@ -174,6 +175,10 @@ def _cached_get_attn_backend( logger.info("Using HPUAttention backend.") from vllm.attention.backends.hpu_attn import HPUAttentionBackend return HPUAttentionBackend + elif backend == _Backend.HPU_ATTN_V1: + logger.info("Using HPUAttentionV1 backend.") + from vllm.v1.attention.backends.hpu_attn import HPUAttentionBackendV1 + return HPUAttentionBackendV1 elif backend == _Backend.PALLAS: logger.info("Using Pallas backend.") from vllm.attention.backends.pallas import PallasAttentionBackend @@ -249,6 +254,10 @@ def which_attn_to_use(head_size: int, return _Backend.ROCM_FLASH if current_platform.is_hpu(): + if selected_backend != _Backend.HPU_ATTN and selected_backend != _Backend.HPU_ATTN_V1: + logger.info("Cannot use %s backend on HPU.", selected_backend) + if use_v1: + return _Backend.HPU_ATTN_V1 return _Backend.HPU_ATTN if use_v1: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 87ade377266a2..2ff8aaed80923 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -984,6 +984,7 @@ def init_distributed_environment( "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", world_size, rank, local_rank, distributed_init_method, backend) + print(distributed_init_method) if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " diff --git a/vllm/v1/attention/backends/hpu_attn.py b/vllm/v1/attention/backends/hpu_attn.py new file mode 100644 index 0000000000000..611214ca7114f --- /dev/null +++ b/vllm/v1/attention/backends/hpu_attn.py @@ -0,0 +1,361 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import vllm_hpu_extension.ops as ops +from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention, + HPUPagedAttentionMetadata) +from vllm.logger import init_logger +from vllm.utils import is_fake_hpu + +logger = init_logger(__name__) + + +class HPUAttentionBackendV1(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "hpu-attn" + + @staticmethod + def get_impl_cls() -> Type["HPUAttentionImpl"]: + return HPUAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return HPUAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dsts: torch.Tensor, + ) -> None: + HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dsts: torch.Tensor, + ) -> None: + HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts) + + +@dataclass +class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): + """Metadata for HPUAttentionbackend.""" + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + attn_bias: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + seq_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] + + @classmethod + def make_prefill_metadata(cls, seq_lens_tensor, num_prefills, + num_prefill_tokens, slot_mapping): + return cls(is_prompt=True, + block_list=None, + block_mapping=None, + block_usage=None, + block_indices=None, + block_offsets=None, + block_scales=None, + block_groups=None, + attn_bias=None, + num_decode_tokens=0, + context_lens_tensor=None, + multi_modal_placeholder_index_maps=None, + seq_lens_tensor=seq_lens_tensor, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + slot_mapping=slot_mapping) + + @classmethod + def make_cached_prefill_metadata(cls, seq_lens_tensor, context_lens_tensor, num_prefills, + num_prefill_tokens, slot_mapping, block_list): + return cls(is_prompt=True, + block_list=block_list, + block_mapping=None, + block_usage=None, + block_indices=None, + block_offsets=None, + block_scales=None, + block_groups=None, + attn_bias=None, + num_decode_tokens=0, + context_lens_tensor=context_lens_tensor, + multi_modal_placeholder_index_maps=None, + seq_lens_tensor=seq_lens_tensor, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + slot_mapping=slot_mapping) + + @classmethod + def make_decode_metadata(cls, block_list, block_usage, block_groups, + num_decode_tokens, slot_mapping): + return cls(is_prompt=False, + block_mapping=None, + block_indices=None, + block_offsets=None, + block_scales=None, + attn_bias=None, + seq_lens_tensor=None, + context_lens_tensor=None, + num_prefills=0, + num_prefill_tokens=0, + multi_modal_placeholder_index_maps=None, + block_list=block_list, + block_usage=block_usage, + block_groups=block_groups, + num_decode_tokens=num_decode_tokens, + slot_mapping=slot_mapping) + + +class HPUAttentionImpl(AttentionImpl, torch.nn.Module): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + max_seq_len: int = 4096, + ) -> None: + super(AttentionImpl, self).__init__() + self.kv_cache_dtype = kv_cache_dtype + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.matmul_qk = Matmul() + self.softmax = Softmax() + self.matmul_av = Matmul() + self.batch2block_matmul = Matmul() + self.block2batch_matmul = Matmul() + self.k_cache = VLLMKVCache() + self.v_cache = VLLMKVCache() + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window + self.alibi_slopes = alibi_slopes + if alibi_slopes is not None: + alibi_slopes_tensor = torch.tensor(alibi_slopes, + dtype=torch.bfloat16) + self.alibi_slopes = alibi_slopes_tensor + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', + '1').lower() in ['1', 'true'] \ + and not is_fake_hpu() + if self.prefill_use_fusedsdpa: + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' + + suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: HPUAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with xFormers and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "HPUAttentionImpl") + batch_size, seq_len, hidden_size = query.shape + _, seq_len_kv, _ = key.shape + + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + block_indices = attn_metadata.block_indices + block_offsets = attn_metadata.block_offsets + if attn_metadata.is_prompt: + key = key.unflatten(0, (block_indices.size(0), -1)) + value = value.unflatten(0, (block_indices.size(0), -1)) + if kv_cache is not None: + key_cache, value_cache = HPUPagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + key_cache = self.k_cache(key, key_cache, block_indices, + block_offsets) + value_cache = self.v_cache(value, value_cache, block_indices, + block_offsets) + + if attn_metadata.is_prompt: + # Prompt run. + query_shape = (batch_size, seq_len, self.num_heads, self.head_size) + kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, + self.head_size) + if attn_metadata is None or attn_metadata.block_list is None: + if not self.prefill_use_fusedsdpa: + # TODO: move this outside of model + assert attn_metadata.attn_bias is not None, \ + 'attn_bias must be set before calling model.forward' + attn_bias = attn_metadata.attn_bias + if self.alibi_slopes is not None: + position_bias = _make_alibi_bias( + self.alibi_slopes, self.num_kv_heads, + attn_bias.dtype, attn_bias.shape[-1]) + attn_bias = attn_bias.tile( + (1, self.num_kv_heads, 1, 1)) + attn_bias.add_(position_bias) + else: + attn_bias = None + + out = ops.prompt_attention( + query.view(query_shape), + key.view(kv_shape), + value.view(kv_shape), + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, + ) + else: + # TODO: enable FusedSDPA + out = HPUPagedAttention.forward_prefix( + query=query.view(query_shape), + key=key.view(kv_shape), + value=value.view(kv_shape), + key_cache=key_cache, + value_cache=value_cache, + block_list=attn_metadata.block_list, + attn_bias=attn_metadata.attn_bias, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + matmul_av_op=self.matmul_av, + softmax_op=self.softmax, + keys_fetch_func=self.k_cache.fetch_from_cache, + values_fetch_func=self.v_cache.fetch_from_cache) + output = out.reshape(batch_size, seq_len, hidden_size) + else: + # Decoding run. + output = HPUPagedAttention.forward_decode( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_list=attn_metadata.block_list, + block_mapping=attn_metadata.block_mapping, + block_bias=attn_metadata.attn_bias, + block_scales=attn_metadata.block_scales, + block_groups=attn_metadata.block_groups, + scale=self.scale, + matmul_qk_op=self.matmul_qk, + matmul_av_op=self.matmul_av, + batch2block_matmul_op=self.batch2block_matmul, + block2batch_matmul_op=self.block2batch_matmul, + keys_fetch_func=self.k_cache.fetch_from_cache, + values_fetch_func=self.v_cache.fetch_from_cache) + # Reshape the output tensor. + return output.view(batch_size, seq_len, hidden_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_len: int, +) -> torch.Tensor: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + return bias diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 38f1c03a4d3ac..f7c6ac6524e15 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,9 +1,10 @@ from collections import defaultdict +import os from typing import Dict, List, Optional from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, +from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockHeapQueue, FreeKVCacheBlockQueue, KVCacheBlock, hash_block_tokens, hash_request_tokens) from vllm.v1.request import Request @@ -24,7 +25,7 @@ def __init__( self.block_size = block_size self.num_gpu_blocks = num_gpu_blocks self.sliding_window = sliding_window - self.enable_caching = enable_caching + self.enable_caching = os.environ.get('VLLM_ENABLE_PREFIX_CACHING', 'true') in ['true', '1'] # NOTE(woosuk): To avoid frequent block allocation, we preallocate some # blocks for each request. For example, when a request reaches the end # of its block table, we preallocate N blocks in advance. This way, we @@ -40,12 +41,13 @@ def __init__( # A Block pool of all kv-cache blocks. self.block_pool: List[KVCacheBlock] = [ - KVCacheBlock(idx) for idx in range(num_gpu_blocks) + KVCacheBlock(idx) for idx in range(1, num_gpu_blocks) ] # Free block queue that constructs and manipulates a doubly linked # list of free blocks (including eviction candidates when caching is # enabled). - self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool) + block_queue_impl = FreeKVCacheBlockHeapQueue if os.environ.get('VLLM_USE_HEAPQ') in ['1', 'true'] else FreeKVCacheBlockQueue + self.free_block_queue = block_queue_impl(self.block_pool) # {block_hash: {block ID: block}}. A cached block is # a full block with a block hash that can be used for prefix caching. diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 33dbfb7377bfd..5d89ad2b3bd01 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,7 +1,7 @@ """KV-Cache Utilities.""" from dataclasses import dataclass, field -from typing import List, Optional, Tuple, Union - +from typing import Deque, Dict, List, Optional, Tuple, Union +import heapq from vllm.logger import init_logger logger = init_logger(__name__) @@ -31,6 +31,9 @@ class KVCacheBlock: prev_free_block: Optional["KVCacheBlock"] = None next_free_block: Optional["KVCacheBlock"] = None + def __lt__(self, other): + return self.block_id < other.block_id + def reset(self): """Reset the block metadata.""" self.ref_cnt = 0 @@ -39,6 +42,62 @@ def reset(self): self.num_hashed_tokens = 0 +class FreeKVCacheBlockHeapQueue: + """TODO(kzawora): document this + + Args: + blocks: A list of KVCacheBlock objects. + """ + + def __init__(self, blocks: List[KVCacheBlock]) -> None: + self.num_free_blocks = len(blocks) + self._free_block_indices: List[KVCacheBlock] = blocks[:] + self.tombstone: Dict[KVCacheBlock, int] = {} + heapq.heapify(self._free_block_indices) + assert len(self._free_block_indices) == self.num_free_blocks + + def popleft(self) -> KVCacheBlock: + """Pop the first free block and reduce num_free_blocks by 1. + + Returns: + The first free block. + """ + block: KVCacheBlock = heapq.heappop(self._free_block_indices) + #logger.info(f'[HEAPQ] Popped block {block.block_id}') + return block + + def remove(self, block: KVCacheBlock) -> None: + """Remove a block in the free list and reduce num_free_blocks by 1. + + Args: + block: The block to remove. + """ + self.tombstone[block] = self.tombstone.get(block, 0) + 1 + while len(self._free_block_indices) > 0 and self._free_block_indices[0] == block and self.tombstone[block] > 0: + heapq.heappop(self._free_block_indices) + self.tombstone[block] -= 1 + + self.num_free_blocks -= 1 + + def append(self, block: KVCacheBlock) -> None: + """Put a block back into the free list and increase + num_free_blocks by 1. + + Args: + block: The block to append. + """ + heapq.heappush(self._free_block_indices, block) + self.num_free_blocks += 1 + + def get_all_free_blocks(self) -> List[KVCacheBlock]: + """Get all free blocks in the free list. Mainly used for testing. + + Returns: + A list of free blocks. + """ + return list(item for item in self._free_block_indices) + + class FreeKVCacheBlockQueue: """This class organizes a list of KVCacheBlock objects to a doubly linked list of free blocks. We implement this class instead of using Python @@ -84,6 +143,7 @@ def popleft(self) -> KVCacheBlock: block = self.free_list_head self.remove(block) + #logger.info(f'[LL] Popped block {block.block_id}') return block def remove(self, block: KVCacheBlock) -> None: diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index ba50a9786d805..0cb440480300b 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -27,9 +27,14 @@ def __init__( cache_config: CacheConfig, lora_config: Optional[LoRAConfig], ) -> None: + # TODO: properly handle for HPU. +# cache_config.enable_prefix_caching = False +# scheduler_config.chunked_prefill_enabled = False + self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config + self.disable_prefill_chunking = True # TODO: Support LoRA. assert lora_config is None, "V1 does not support LoRA yet." @@ -200,6 +205,12 @@ def schedule(self) -> "SchedulerOutput": num_computed_tokens -= 1 num_new_tokens = 1 computed_blocks.pop() + + # If chunked prefill is not enabled, breakout of the loop. + if (not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget): + break + num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 @@ -521,16 +532,14 @@ def from_request( block_ids: List[int], num_computed_tokens: int, ) -> "NewRequestData": - return cls( - req_id=request.request_id, - prompt_token_ids=request.prompt_token_ids, - prompt=request.prompt, - mm_inputs=request.mm_inputs, - mm_positions=request.mm_positions, - sampling_params=request.sampling_params, - block_ids=block_ids, - num_computed_tokens=num_computed_tokens, - ) + return cls(req_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt, + mm_inputs=request.mm_inputs, + mm_positions=request.mm_positions, + sampling_params=request.sampling_params, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens) @dataclass @@ -547,11 +556,9 @@ def from_request( block_ids: List[int], num_computed_tokens: int, ) -> "ResumedRequestData": - return cls( - req_id=request.request_id, - block_ids=block_ids, - num_computed_tokens=num_computed_tokens, - ) + return cls(req_id=request.request_id, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens) @dataclass @@ -568,11 +575,9 @@ def from_request( new_block_ids: List[int], num_computed_tokens: int, ) -> "RunningRequestData": - return cls( - req_id=request.request_id, - new_block_ids=new_block_ids, - num_computed_tokens=num_computed_tokens, - ) + return cls(req_id=request.request_id, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens) @dataclass diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 09bff9655a882..e7929ecb25149 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,5 +1,5 @@ import asyncio -from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union +from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Type, Union from vllm.config import ModelConfig, VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs @@ -20,7 +20,7 @@ from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor -from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -30,7 +30,7 @@ class AsyncLLM(EngineClient): def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Any], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, @@ -121,6 +121,10 @@ def shutdown(self): @classmethod def _get_executor_cls(cls, vllm_config: VllmConfig): + if current_platform.is_hpu(): + from vllm.v1.executor.hpu_executor import HPUExecutor + return HPUExecutor + from vllm.v1.executor.gpu_executor import GPUExecutor return GPUExecutor async def add_request( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 35ed131d50de9..f232b62f39d2e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -18,7 +18,7 @@ from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType) from vllm.v1.engine.mm_input_mapper import MMInputMapper -from vllm.v1.executor.gpu_executor import GPUExecutor +#from vllm.v1.executor.gpu_executor import GPUExecutor from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import PickleEncoder from vllm.version import __version__ as VLLM_VERSION @@ -36,17 +36,17 @@ class EngineCore: def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Any], usage_context: UsageContext, ): # Override the configs for V1. # FIXME - if usage_context == UsageContext.LLM_CLASS: - vllm_config.scheduler_config.max_num_seqs = 1024 - vllm_config.scheduler_config.max_num_batched_tokens = 8192 - elif usage_context == UsageContext.OPENAI_API_SERVER: - vllm_config.scheduler_config.max_num_seqs = 1024 - vllm_config.scheduler_config.max_num_batched_tokens = 2048 + #if usage_context == UsageContext.LLM_CLASS: + # vllm_config.scheduler_config.max_num_seqs = 1024 + # vllm_config.scheduler_config.max_num_batched_tokens = 8192 + #elif usage_context == UsageContext.OPENAI_API_SERVER: + # vllm_config.scheduler_config.max_num_seqs = 1024 + # vllm_config.scheduler_config.max_num_batched_tokens = 2048 # TODO (ywang96): Enable APC by default when VLM supports it. if not vllm_config.model_config.is_multimodal_model: @@ -135,7 +135,7 @@ class EngineCoreProc(EngineCore): def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Any], usage_context: UsageContext, input_path: str, output_path: str, @@ -220,7 +220,7 @@ def wait_for_startup( @staticmethod def make_engine_core_process( vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Any], usage_context: UsageContext, input_path: str, output_path: str, diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 75a77be750acd..b4617a43ce9d5 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Mapping, Optional, Type, Union +from typing import Any, Dict, List, Mapping, Optional, Type, Union from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs @@ -9,6 +9,7 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import RequestOutput +from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -17,7 +18,6 @@ from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor -from vllm.v1.executor.gpu_executor import GPUExecutor logger = init_logger(__name__) @@ -28,7 +28,7 @@ class LLMEngine: def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Any], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, @@ -99,6 +99,10 @@ def from_engine_args( @classmethod def _get_executor_cls(cls, vllm_config: VllmConfig): + if current_platform.is_hpu(): + from vllm.v1.executor.hpu_executor import HPUExecutor + return HPUExecutor + from vllm.v1.executor.gpu_executor import GPUExecutor return GPUExecutor def stop_remote_worker_execution_loop(self) -> None: diff --git a/vllm/v1/executor/hpu_executor.py b/vllm/v1/executor/hpu_executor.py new file mode 100644 index 0000000000000..e13a3a73ed08a --- /dev/null +++ b/vllm/v1/executor/hpu_executor.py @@ -0,0 +1,78 @@ +from typing import Optional, Tuple + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.hpu_worker import HPUWorker + +logger = init_logger(__name__) + + +class HPUExecutor: + + def __init__(self, vllm_config: VllmConfig) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.worker = self._create_worker() + self.worker.initialize() + self.worker.load_model() + + def _create_worker( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> HPUWorker: + """Return worker init args for a given rank.""" + + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + return HPUWorker( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + ) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.worker.determine_num_available_blocks() + + def initialize_cache(self, num_hpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. + logger.info("# HPU blocks: %d", num_hpu_blocks) + from vllm_hpu_extension.profiler import HabanaMemoryProfiler + with HabanaMemoryProfiler() as cache_init_m: + self.worker.initialize_cache(num_hpu_blocks) + msg = f"init_cache_engine took {cache_init_m.get_summary_string()}" + logger.info(msg) + + def execute_model( + self, + scheduler_output, + ) -> ModelRunnerOutput: + output = self.worker.execute_model(scheduler_output) + return output + + def check_health(self) -> None: + # HPUExecutor will always be healthy as long as + # it's running. + return diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 927f274541c4d..60da6976bc4d5 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -20,10 +20,11 @@ def forward( logits = self.apply_temperature(logits, sampling_metadata.temperature) logits = self.apply_top_k_top_p(logits, sampling_metadata) - probs = self.get_probs(logits) + probs = self.get_probs( + logits) if not sampling_metadata.all_greedy else logits sampled = self.sample(probs, sampling_metadata) # Use int32 to reduce the tensor size. - sampled = sampled.to(torch.int32) + sampled = sampled # .to(torch.int32) NOTE(kzawora): WHY DO WE HAVE AN UNDEFINED BEHAVIOR HERE?! IN WHICH WORLD DOES 75696 INT64 CAST TO -828218624 INT32?!? HOW CAN ARGMAX EVEN RETURN -828218624?! >_< if sampling_metadata.max_num_logprobs > 0: logprobs = self.get_logprobs(logits) @@ -36,7 +37,6 @@ def forward( else: topk_logprobs = None topk_indices = None - sampler_output = SamplerOutput( sampled_token_ids=sampled, logprob_token_ids=topk_indices, diff --git a/vllm/v1/worker/hpu_model_runner.py b/vllm/v1/worker/hpu_model_runner.py new file mode 100755 index 0000000000000..9edc13889f4b1 --- /dev/null +++ b/vllm/v1/worker/hpu_model_runner.py @@ -0,0 +1,2024 @@ +import collections +from enum import Enum +import functools +import itertools +import math +import os +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn + +from vllm import envs +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_fake_hpu, + is_pin_memory_available) +from vllm.v1.attention.backends.hpu_attn import HPUAttentionBackendV1, HPUAttentionMetadata +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, + HabanaMemoryProfiler, format_bytes) +from vllm_hpu_extension.bucketing import HPUBucketingContext +from vllm_hpu_extension.ops import batch2block, block2batch +import habana_frameworks.torch as htorch +import habana_frameworks.torch.internal.bridge_config as bc +from multiprocessing.pool import ThreadPool + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput +from vllm.v1.engine.detokenizer import Detokenizer + +logger = init_logger(__name__) + +_TYPE_CACHE = {} +# These values are assumed to be zero in several places. +# Use caution when updating them! +_PAD_SLOT_ID = 0 +_PAD_BLOCK_ID = 0 + +class PhaseType(Enum): + UNCACHED_PREFILL = 'uncached_prefill' + CACHED_PREFILL = 'cached_prefill' + DECODE = 'decode' + + +@dataclass +class PrefillInputData: + + request_ids: List + prompt_lens: List + token_ids: List + position_ids: List + attn_metadata: List + logits_indices: List + + def zipped(self): + return zip(self.request_ids, self.prompt_lens, self.token_ids, + self.position_ids, self.attn_metadata, self.logits_indices) + + +@dataclass +class DecodeInputData: + + num_decodes: int + token_ids: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + attn_metadata: HPUAttentionMetadata = None + logits_indices: Optional[torch.Tensor] = None + + + +def flatten(in_list): + return list(itertools.chain(*in_list)) + + +def gather_list(input, indices, v): + return [input[i] if i is not None else v for i in indices] + + +def _async_h2d_tensor_copy(source, device='hpu'): + assert source.device.type == 'cpu', "Source tensor is not present in host memory!" + target = torch.empty(source.shape, dtype=source.dtype, device=device) + target.copy_(source, non_blocking=True) + return target + + +class HpuModelAdapter: + + def __init__(self, model, block_size, dtype, enforce_eager): + self.model = model + self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', + '1').lower() in ['1', 'true'] \ + and not is_fake_hpu() + self.block_size = block_size + self.dtype = dtype + if not is_fake_hpu() and not htorch.utils.internal.is_lazy( + ) and not enforce_eager: + self.model = torch.compile(self.model, + backend='hpu_backend', + dynamic=False) + + def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, + dtype): + prefix_caching_enabled = attn_metadata.is_prompt and attn_metadata.block_list is not None + if (attn_metadata is None or (self.prefill_use_fusedsdpa and not prefix_caching_enabled) + or not attn_metadata.is_prompt): + return attn_metadata + + prefill_metadata = attn_metadata + + context_lens_t = prefill_metadata.context_lens_tensor + query_lens_t = prefill_metadata.seq_lens_tensor + + block_list = attn_metadata.block_list + max_context_len = (block_list.size(-1) // + batch_size if block_list is not None else 0) + max_context_len = max_context_len * self.block_size + past_mask = torch.arange(0, + max_context_len, + dtype=torch.int32, + device=device) + past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge( + context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand( + batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1)) + + len_mask = (torch.arange(0, seq_len, device=device, + dtype=torch.int32).view(1, seq_len).ge( + query_lens_t.unsqueeze(-1)).view( + batch_size, 1, 1, seq_len)) + causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), + device=device, + dtype=torch.bool), + diagonal=1) + mask = causal_mask.logical_or(len_mask) + mask = torch.concat((past_mask, mask), dim=-1) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf)) + attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) + return attn_metadata + + def _set_block_mapping(self, metadata, batch_size, device, dtype): + mask = torch.arange(0, + self.block_size, + device=device, + dtype=torch.int32).unsqueeze(0) + mask = mask >= metadata.block_usage.unsqueeze(-1) + attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( + mask, -math.inf)) + + if not is_fake_hpu() and htorch.utils.internal.is_lazy(): + block_mapping = torch.nn.functional.one_hot(metadata.block_groups, + num_classes=batch_size) + else: + # Unfortunately one_hot on CPU/torch.compile mode/eager mode + # doesn't handle out of bounds classes so we need to convert + # all negative values to 0 (block_mapping) or bs (block_groups) + block_groups = metadata.block_groups.to(torch.long) + block_mapping = torch.nn.functional.relu(block_groups) + block_mapping = torch.nn.functional.one_hot(block_mapping, + num_classes=batch_size) + oob_values = block_groups.lt(0) + block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) + block_groups.masked_fill_(oob_values, batch_size) + metadata = metadata._replace(block_groups=block_groups) + block_mapping = block_mapping.to(dtype) + metadata = metadata._replace(block_mapping=block_mapping, + attn_bias=attn_bias) + return metadata + + def _set_block_scales(self, metadata, device): + block_mapping = metadata.block_mapping + ones = torch.ones((block_mapping.size(0), ), + device=device, + dtype=block_mapping.dtype) + sums = batch2block(block2batch(ones, block_mapping), block_mapping) + block_scales = torch.reciprocal(torch.maximum(ones, sums)) + metadata = metadata._replace(block_scales=block_scales) + return metadata + + def _set_indices_and_offsets(self, metadata, block_size, is_prompt): + slot_mapping = metadata.slot_mapping.flatten() + indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + if is_prompt: + indices = indices.unflatten(0, (-1, block_size))[:, 0] + offsets = None + else: + offsets = torch.fmod(slot_mapping, block_size) + metadata = metadata._replace(block_offsets=offsets, + block_indices=indices) + return metadata + + def _update_metadata(self, attn_metadata, batch_size, seq_len, device, + dtype): + if attn_metadata.is_prompt: + attn_metadata = self._set_attn_bias(attn_metadata, batch_size, + seq_len, device, dtype) + else: + attn_metadata = self._set_block_mapping(attn_metadata, batch_size, + device, dtype) + attn_metadata = self._set_block_scales(attn_metadata, device) + attn_metadata = self._set_indices_and_offsets(attn_metadata, + self.block_size, + attn_metadata.is_prompt) + return attn_metadata + + def forward(self, *args, **kwargs): + kwargs = kwargs.copy() + input_ids = kwargs['input_ids'] + kwargs['attn_metadata'] = self._update_metadata( + kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), + input_ids.device, self.dtype) + hidden_states = self.model(*args, **kwargs) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + return hidden_states + + def compute_logits(self, *args, **kwargs): + return self.model.compute_logits(*args, **kwargs) + + def sample(self, *args, **kwargs): + return self.model.sample(*args, **kwargs) + + def generate_proposals(self, *args, **kwargs): + return self.model.generate_proposals(*args, **kwargs) + + # sampler property will be used by spec_decode_worker + # don't rename + @property + def sampler(self): + return self.model.sampler + + +def _maybe_wrap_in_hpu_graph(*args, **kwargs): + return htorch.hpu.wrap_in_hpu_graph( + HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True + ) if htorch.utils.internal.is_lazy() else HpuModelAdapter( + *args, **kwargs) + + +def subtuple(obj: object, + typename: str, + to_copy: List[str], + to_override: Optional[Dict[str, object]] = None): + if obj is None: + return None + if to_override is None: + to_override = {} + fields = set(to_copy) | set(to_override.keys()) + if type(obj) is dict: + values = {key: obj[key] for key in fields if key in obj} + else: + values = {f: to_override.get(f, getattr(obj, f)) for f in fields} + if typename not in _TYPE_CACHE: + _TYPE_CACHE[typename] = collections.namedtuple(typename, + ' '.join(fields)) + return _TYPE_CACHE[typename](**values) + + +def trim_attn_metadata(metadata: HPUAttentionMetadata) -> object: + # NOTE(kzawora): To anyone working on this in the future: + # Trimming metadata is required when using HPUGraphs. + # Attention metadata is going to be hashed by PT bridge, and + # appropriate HPUGraphs will be matched based on all inputs' hash. + + # Before you put more keys in here, make sure you know their + # value type and make sure you know how it's going to be hashed. + # You can find that information in input_hash function + # in habana_frameworks/torch/hpu/graphs.py. You can also hash + # it manually with torch.hpu.graphs.input_hash(attention_metadata) + + # If you use primitive types here - they will get hashed based + # on their value. You *will* get lots of excessive graph captures + # (and an OOM eventually) if you decide to put something like + # seq_len int here. + # If you absolutely need a scalar, put it in a tensor. Tensors + # get hashed using their metadata, not their values: + # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) + # input_hash(123) != input_hash(321) + # input_hash("abc") != input_hash("cba") + attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ + 'attn_bias', 'seq_lens_tensor', 'context_lens_tensor', 'block_list', + 'block_mapping', 'block_usage', 'slot_mapping', 'is_prompt', + 'block_indices', 'block_offsets', 'block_scales', 'block_groups' + ]) + return attention_metadata + + +def next_pow2(value: int, base: int): + res = base + while value > 1: + value = (value + 1) // 2 + res *= 2 + return res + + +def round_up(value: int, k: int): + return (value + k - 1) // k * k + + +def pad_list(list, k, v): + target_len = round_up(len(list), k) + padding = target_len - len(list) + return list + [v] * padding + + +def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt): + slot_mapping = slot_mapping.flatten() + indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + if is_prompt: + indices = indices.unflatten(0, (-1, block_size))[:, 0] + offsets = None + else: + offsets = torch.fmod(slot_mapping, block_size) + return indices, offsets + + +class HPUModelRunner: + + def __init__( + self, + vllm_config: VllmConfig, + ): + #TODO(kzawora): remove this, this is ugly and only used for diagnostics + self._ENGINE_ITER = 0 + # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + #TODO(kzawora): remove this, this is for debug purposes only + self._tokenizer = Detokenizer( + vllm_config.model_config.tokenizer).tokenizer + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] + + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + + # Model-related. + self.num_attn_layers = model_config.get_num_attention_layers( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + + # Lazy initialization + # self.model: nn.Module # Set after load_model + self.kv_caches: List[torch.Tensor] = [] + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.scheduler_config.max_num_seqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + ) + + self.use_hpu_graph = not self.model_config.enforce_eager + # TODO(woosuk): Provide an option to tune the max cudagraph batch size. + self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)] + self.max_batch_size = 256 # TODO(kzawora): fix this garbage + self.input_ids = torch.zeros( + (self.max_batch_size, self.max_num_tokens), + dtype=torch.int32, + device=self.device) + self.positions = torch.zeros( + (self.max_batch_size, self.max_num_tokens), + dtype=torch.int64, + device=self.device) + self.prefill_positions = torch.tensor( + range(self.max_model_len), + device="cpu", + ).to(torch.int32).reshape(1, -1) + + self.max_num_seqs = self.scheduler_config.max_num_seqs + self.max_prefill_batch_size = 16 # TODO(kzawora): add some knob for that + self.padding_aware_scheduling = True # TODO(kzawora): add some knob for that + self.padding_ratio_threshold = 0.9 # TODO(kzawora): add some knob for that + self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA', + 'true').lower() == 'true' + self.seen_configs: set = set() + self.enable_bucketing = os.environ.get( + 'VLLM_DISABLE_BUCKETING', 'false').lower() not in ['true', '1'] + if self.enable_bucketing: + logger.info("Bucketing is ON.") + self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs, + self.max_prefill_batch_size, + self.block_size, + self.scheduler_config.max_num_batched_tokens) + self.graphed_buckets: Set[Any] = set() + else: + logger.info("Bucketing is OFF.") + self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', + 'false').lower() == 'true' + + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + + # Remove the requests from the persistent batch. + stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + ) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. + for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + + # Update the block table. + num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_index = len(req_state.block_ids) + end_index = start_index + num_new_blocks + req_state.block_ids.extend(req_data.new_block_ids) + # assert (min(end_index, self.input_batch.block_table_cpu.shape[1]) - start_index) == len(req_data.new_block_ids): + self.input_batch.block_table_cpu[ + req_index, start_index:end_index] = req_data.new_block_ids + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for req_data in scheduler_output.scheduled_new_reqs: + req_id = req_data.req_id + sampling_params = req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=req_data.prompt_token_ids, + prompt=req_data.prompt, + mm_inputs=req_data.mm_inputs, + mm_positions=req_data.mm_positions, + sampling_params=sampling_params, + generator=generator, + block_ids=req_data.block_ids, + num_computed_tokens=req_data.num_computed_tokens, + output_token_ids=[], + ) + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. + for req_data in scheduler_output.scheduled_resumed_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = req_data.block_ids + req_state.num_computed_tokens = req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # THIS MOVES ALL THE DECODES TO THE FIRST N IN BATCH. + # Condense the batched states if there are empty indices. + removed_req_indices = sorted(removed_req_indices, reverse=True) + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + # ALL THE PREFILLS ARE THE LAST M IN THE BATCH. + # These are added at the end after the bacth is condensed. + self.input_batch.num_prefills = len(req_ids_to_add) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + self.input_batch.add_request(req_state, None) + + def _prepare_sampling(self, + scheduler_output: "SchedulerOutput", + start_idx: Optional[int] = None, + end_idx: Optional[int] = None, + pad_to: Optional[int] = None) -> SamplingMetadata: + skip_copy = True + if start_idx is None and end_idx is None: + if (scheduler_output.finished_req_ids + or scheduler_output.preempted_req_ids): + skip_copy = False + if (scheduler_output.scheduled_new_reqs + or scheduler_output.scheduled_resumed_reqs): + skip_copy = False + else: + #TODO(kzawora): something smells... kinda fishy in here + req_ids = self.input_batch.req_ids[start_idx:end_idx] + finished_req_ids = any([ + req_id in scheduler_output.finished_req_ids + for req_id in req_ids + ]) + preempted_req_ids = any([ + req_id in scheduler_output.preempted_req_ids + for req_id in req_ids + ]) + scheduled_new_reqs = any([ + req_id in scheduler_output.scheduled_new_reqs + for req_id in req_ids + ]) + scheduled_resumed_reqs = any([ + req_id in scheduler_output.scheduled_resumed_reqs + for req_id in req_ids + ]) + + if (finished_req_ids or preempted_req_ids): + skip_copy = False + if (scheduled_new_reqs or scheduled_resumed_reqs): + skip_copy = False + + # Create the sampling metadata. + sampling_metadata = self.input_batch.make_sampling_metadata( + skip_copy=skip_copy, + start_idx=start_idx, + end_idx=end_idx, + pad_to=pad_to) + return sampling_metadata + + def get_habana_paged_attn_buffers(self, + block_tables, + slot_mapping, + bucketing=True): + + last_block_usage = [ + slot[0] % self.block_size + 1 for slot in slot_mapping + ] + block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] + block_usage = [[self.block_size] * (len(bt) - 1) + [lbu] + for bt, lbu in zip(block_tables, last_block_usage) + if bt] + + block_list = flatten(block_tables) + block_groups = flatten(block_groups) + block_usage = flatten(block_usage) + assert len(block_list) == len(block_groups) + assert len(block_list) == len(block_usage) + + padding_fn = None + if self.use_contiguous_pa: + block_bucket_size = max(max(block_list) + 1, len(block_list)) + if bucketing: + block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size) + indices: List[Any] + indices = [None] * block_bucket_size + for i, bid in enumerate(block_list): + indices[bid] = i + padding_fn = lambda tensor, pad_value: gather_list( + tensor, indices, pad_value) + else: + block_bucket_size: int + if bucketing: + block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( + len(block_list)) + else: + block_bucket_size = len(block_list) + padding_fn = lambda tensor, pad_value: pad_list( + tensor, block_bucket_size, pad_value) + + block_list = padding_fn(block_list, _PAD_BLOCK_ID) + block_groups = padding_fn(block_groups, -1) + block_usage = padding_fn(block_usage, 1) + + block_list = torch.tensor(block_list, dtype=torch.long, device='cpu') + block_groups = torch.tensor(block_groups, + dtype=torch.long, + device='cpu') + block_usage = torch.tensor(block_usage, + dtype=self.model_config.dtype, + device='cpu') + + return block_list, block_groups, block_usage + + def _get_padded_prefill_dims(self, num_prefills, max_prompt_len, bucketing): + if bucketing: + padded_batch_size = self.bucketing_ctx.get_padded_batch_size( + num_prefills, True) + padded_prompt_len = self.bucketing_ctx.get_padded_prompt_seq_len( + max_prompt_len) + else: + #NOTE(kzawora): On HPU prompt length needs to be block_size + # aligned, so we're padding to that, even if bucketing + # is disabled. + padded_batch_size = num_prefills + padded_prompt_len = math.ceil( + max_prompt_len / self.block_size) * self.block_size + assert padded_prompt_len <= self.max_model_len + return padded_batch_size, padded_prompt_len + + + + def _prepare_prefill_inputs(self, + num_scheduled_tokens: List[int], + bucketing=True) -> PrefillInputData: + # Each prefill run separately with shape [1, padded_prompt_len]. + # So we create lists that will be used in execute_model(). + + prefill_request_ids = [] + prefill_prompt_lens = [] + prefill_token_ids = [] + prefill_position_ids = [] + prefill_attn_metadata = [] + prefill_logits_indices = [] + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes + # NOTE(kzawora): This loop was initially implemented as + # for batch_idx in range(num_decodes, num_reqs, max_prefill_batch_size) + # but was changed to accomodate variable loop step size for + # padding-aware scheduling + batch_idx = num_decodes + while batch_idx < num_reqs: + # Find the largest batch size in range [1, max_prefill_batch_size] + # that can fit within specified token budget + num_prefills: int + padded_batch_size: int + padded_prompt_len: int + padded_num_tokens: int + padding_ratio: float + batch_req_ids: List[int] + prompt_lens: List[int] + for possible_batch_size in reversed(range(1, self.max_prefill_batch_size + 1)): + if batch_idx + possible_batch_size > num_reqs: + continue + num_prefills = possible_batch_size + batch_req_ids = self.input_batch.req_ids[batch_idx:batch_idx + num_prefills] + prompt_lens = num_scheduled_tokens[batch_idx:batch_idx + num_prefills] + max_prompt_len = max(prompt_lens) + num_tokens = sum(prompt_lens) + padded_batch_size, padded_prompt_len = self._get_padded_prefill_dims(num_prefills, max_prompt_len, bucketing) + padded_num_tokens = padded_batch_size * padded_prompt_len + padding_ratio = 1 - (num_tokens/padded_num_tokens) + is_within_token_budget = padded_batch_size * padded_prompt_len < self.scheduler_config.max_num_batched_tokens + is_within_padding_ratio_threshold = padding_ratio < self.padding_ratio_threshold + can_schedule = is_within_token_budget and is_within_padding_ratio_threshold + # If padding aware scheduling is off, we'll break on the first + # loop iteration (==max_prefill_batch_size). + # Else, we'll break on first batch size that fits token budget. + if not self.padding_aware_scheduling or can_schedule: + break + context_lens = self.input_batch.num_computed_tokens_cpu[batch_idx:batch_idx + num_prefills] + use_prefix_caching = any(context_lens) + # TODO(kzawora): this is an ugly hack for prefix caching, remove that once batch padding works properly (idk why it doesn't) + if use_prefix_caching: + padded_batch_size = num_prefills + + padded_prompt_lens = [ + padded_prompt_len for _ in range(padded_batch_size) + ] + #logger.info(f"Using {'cached' if use_prefix_caching else 'uncached'} prefill batch size {num_prefills} padded to [{padded_batch_size}, {padded_prompt_len}] with context {max(context_lens)} (token budget: {self.scheduler_config.max_num_batched_tokens}, num_tokens: {padded_num_tokens}, padding ratio: {padding_ratio:.2f})") + + # TOKEN_IDS. + token_ids = torch.zeros((padded_batch_size, padded_prompt_len), + dtype=torch.int32, + device='cpu') + # POSITIONS. + positions = torch.zeros((padded_batch_size, padded_prompt_len), + dtype=torch.int32, + device='cpu') + # SLOT_MAPPING. + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + slot_mapping = torch.zeros((padded_batch_size, padded_prompt_len), + dtype=torch.int32, + device='cpu') + slot_mapping.fill_(_PAD_SLOT_ID) + + for i, (prompt_len, context_len) in enumerate(zip(prompt_lens, context_lens)): + # Prepare and sanitize token ids (cpu) + batch_offset = batch_idx + i + token_ids[i, :prompt_len] = torch.from_numpy(self.input_batch.token_ids_cpu[batch_offset, context_len:context_len+prompt_len]) + #token_ids[i, prompt_len:] = 0 # no need to sanitize - buffer is pre-filled with 0s + + # Prepare and sanitize positions ids (cpu) + positions[i, :prompt_len] = self.prefill_positions[:, context_len:context_len+prompt_len] + #positions[i, prompt_len:] = 0 # no need to sanitize - buffer is pre-filled with 0s + + # Prepare and sanitize slot_mapping (cpu) + flat_prefill_positions = positions[i, :prompt_len].flatten() + block_numbers = self.input_batch.block_table_cpu_tensor[batch_offset, + flat_prefill_positions // self.block_size] + block_offsets = flat_prefill_positions % self.block_size + slot_mapping[i, :prompt_len] = block_numbers * self.block_size + block_offsets + #slot_mapping[i, prompt_len:] = _PAD_SLOT_ID # no need to sanitize - buffer is pre-filled with _PAD_SLOT_IDs + slot_mapping = slot_mapping.long() + + logits_indices = torch.zeros(padded_batch_size, + dtype=torch.int32, + device='cpu') + query_start_loc = torch.empty((num_prefills + 1, ), + dtype=torch.int32, + device="cpu") + query_start_loc_np = query_start_loc.numpy() + query_start_loc_np[0] = 0 + + # logits indices in prefill must account for padding: last token logits + # will be emitted at index (idx - 1) * padded_seq_len + seq_len[idx] - 1 + np.cumsum(padded_prompt_lens[:num_prefills], + out=query_start_loc_np[1:]) + query_start_loc_np[:num_prefills] += num_scheduled_tokens[ + batch_idx:batch_idx + num_prefills] + logits_indices[:num_prefills] = query_start_loc[:num_prefills] - 1 + + # HPU should *not* sync here with CPU + seq_lens_tensor = torch.zeros((padded_batch_size), + dtype=torch.int32, + device='cpu') + seq_lens_tensor[:num_prefills] = torch.tensor(prompt_lens, + device='cpu') + token_ids_device = _async_h2d_tensor_copy(token_ids, self.device) + positions_device = _async_h2d_tensor_copy(positions, self.device) + seq_lens_tensor_device = _async_h2d_tensor_copy( + seq_lens_tensor, self.device) + slot_mapping_device = _async_h2d_tensor_copy( + slot_mapping, self.device) + logits_indices_device = _async_h2d_tensor_copy( + logits_indices, self.device) + + prefill_request_ids.append(batch_req_ids) + prefill_prompt_lens.append(prompt_lens) + prefill_token_ids.append(token_ids_device) + prefill_position_ids.append(positions_device) + prefill_logits_indices.append(logits_indices_device) + attn_metadata = None + if use_prefix_caching: + # Prefix caching + num_blocks = np.ceil(context_lens / self.block_size).astype(np.int32).tolist() + max_num_blocks = max(num_blocks) + #if bucketing: + # max_num_blocks = self.bucketing_ctx.get_padded_decode_num_blocks(max_num_blocks) + prefix_block_tables = torch.zeros((padded_batch_size, max_num_blocks), dtype=torch.int32,device='cpu') + for i, n in enumerate(num_blocks): + prefix_block_tables[i, :n] = self.input_batch.block_table_cpu_tensor[i, :n] + context_lens_tensor = torch.zeros((padded_batch_size), dtype=torch.int32, device='cpu') + context_lens_tensor[:num_prefills] = torch.tensor(context_lens, device='cpu') + + block_list_device = _async_h2d_tensor_copy(prefix_block_tables.flatten(), self.device) + context_lens_tensor_device = _async_h2d_tensor_copy(context_lens_tensor, self.device) + attn_metadata = HPUAttentionMetadata.make_cached_prefill_metadata( + seq_lens_tensor=seq_lens_tensor_device, + context_lens_tensor=context_lens_tensor_device, + num_prefills=num_prefills, + num_prefill_tokens=sum(prompt_lens), + slot_mapping=slot_mapping_device, + block_list=block_list_device + ) + else: + attn_metadata = HPUAttentionMetadata.make_prefill_metadata( + seq_lens_tensor=seq_lens_tensor_device, + num_prefills=num_prefills, + num_prefill_tokens=sum(prompt_lens), + slot_mapping=slot_mapping_device, + ) + #import pdb; pdb.set_trace() + # ATTN_METADATA. + prefill_attn_metadata.append(attn_metadata) + batch_idx += num_prefills + + return PrefillInputData(request_ids=prefill_request_ids, + prompt_lens=prefill_prompt_lens, + token_ids=prefill_token_ids, + position_ids=prefill_position_ids, + attn_metadata=prefill_attn_metadata, + logits_indices=prefill_logits_indices) + + def _prepare_decode_inputs(self, + num_scheduled_tokens, + bucketing=True) -> DecodeInputData: + # Decodes run as one single padded batch with shape [batch, 1] + # + # We need to set _PAD_SLOT_ID for the padding tokens in the + # slot_mapping, such that the attention KV cache insertion + # logic knows to ignore those indicies. Otherwise, the + # padding data can be dummy since we have a causal mask. + + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes + if num_decodes == 0: + return DecodeInputData(num_decodes=0) + + # PAD FOR STATIC SHAPES. + padded_batch_size: int + if bucketing: + padded_batch_size = self.bucketing_ctx.get_padded_batch_size( + num_decodes, False) + else: + padded_batch_size = num_decodes + + # POSITIONS. [batch, 1] + # We slice at the end, since we use the positions for gathering. + positions = torch.from_numpy( + self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) + index = positions.to(torch.int64) + positions = positions[:padded_batch_size] + + # TOKEN_IDS. [batch, 1] + token_ids = torch.zeros((padded_batch_size, 1), dtype=torch.int32) + token_ids[:num_decodes] = torch.gather( + input=torch.from_numpy(self.input_batch.token_ids_cpu), + dim=1, + index=index, + )[:num_decodes] + + # SLOT_MAPPING [batch, 1] + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_number = torch.gather( + input=self.input_batch.block_table_cpu_tensor, + dim=1, + index=(index // self.block_size)) + # NOTE(kzawora): the "-1" is what causes this entire thing to work + # properly and have good accuracy - why? beats me... + block_offsets = (index - 1) % self.block_size + slot_mapping = block_number * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[num_decodes:] = _PAD_SLOT_ID + slot_mapping = slot_mapping[:padded_batch_size] + + # BLOCK_TABLE [batch, max_num_blocks_per_req] + context_lens = self.input_batch.num_computed_tokens_cpu[:num_decodes] + num_blocks = np.ceil(context_lens / self.block_size).astype( + np.int32).tolist() + block_tables_list = [] + for i, n in enumerate(num_blocks): + block_tables_list.append( + self.input_batch.block_table_cpu_tensor[i, :n].tolist()) + + # CONTEXT_LENS [batch_size] + #context_lens = (positions.reshape(-1) + 1) + + block_list, block_groups, block_usage = self.get_habana_paged_attn_buffers( + block_tables_list, slot_mapping.tolist(), bucketing) + + logits_indices = torch.zeros(padded_batch_size, + dtype=torch.int32, + device='cpu') + query_start_loc = torch.empty((num_decodes + 1, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + query_start_loc_np = query_start_loc.numpy() + query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens[:num_decodes], + out=query_start_loc_np[1:]) + logits_indices[:num_decodes] = query_start_loc[1:] - 1 + num_decode_tokens = torch.tensor(np.sum(context_lens), device='cpu') + + # CPU<>HPU sync *should not* happen here. + token_ids_device = _async_h2d_tensor_copy(token_ids, self.device) + positions_device = _async_h2d_tensor_copy(positions, self.device) + logits_indices_device = _async_h2d_tensor_copy(logits_indices, + self.device) + block_list_device = _async_h2d_tensor_copy(block_list, self.device) + block_usage_device = _async_h2d_tensor_copy(block_usage, self.device) + block_groups_device = _async_h2d_tensor_copy(block_groups, self.device) + num_decode_tokens_device = _async_h2d_tensor_copy( + num_decode_tokens, self.device) + slot_mapping_device = _async_h2d_tensor_copy(slot_mapping, self.device) + + return DecodeInputData( + num_decodes=num_decodes, + token_ids=token_ids_device, + position_ids=positions_device, + logits_indices=logits_indices_device, + attn_metadata=HPUAttentionMetadata.make_decode_metadata( + block_list=block_list_device, + block_usage=block_usage_device, + block_groups=block_groups_device, + num_decode_tokens=num_decode_tokens_device, + slot_mapping=slot_mapping_device, + )) + + def _prepare_inputs( + self, + scheduler_output: "SchedulerOutput", + bucketing=True + ) -> Tuple[PrefillInputData, Optional[DecodeInputData]]: + + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + + # NOTE: assert that all the decodes are "decodes". + if idx < num_decodes: + assert num_tokens == 1 + + return ( + self._prepare_prefill_inputs(num_scheduled_tokens, bucketing), + self._prepare_decode_inputs(num_scheduled_tokens, bucketing), + ) + + def _seq_len(self, attn_metadata): + return attn_metadata.slot_mapping.size(1) + + def _num_blocks(self, attn_metadata): + if attn_metadata.block_list is None: + return 0 + return attn_metadata.block_list.numel() + + def _phase(self, attn_metadata): + phase_type : PhaseType + is_prompt = attn_metadata.is_prompt + is_prefix_cached = is_prompt and attn_metadata.block_list is not None + if is_prompt and is_prefix_cached: + phase_type = PhaseType.CACHED_PREFILL + elif is_prompt and not is_prefix_cached: + phase_type = PhaseType.UNCACHED_PREFILL + elif not is_prompt: + phase_type = PhaseType.DECODE + else: + raise ValueError("Unrecognized pass type, likely due to malformed attention metadata") + return phase_type + + def _check_config(self, batch_size, seq_len, num_blocks, attn_metadata, warmup_mode): + phase = self._phase(attn_metadata) + cfg = (batch_size, seq_len, num_blocks, phase) + seen = cfg in self.seen_configs + self.seen_configs.add(cfg) + if not seen and not warmup_mode: + phase = phase.value + logger.warning("Configuration: (%s, %s, %s, %s) was not warmed-up!", + phase, batch_size, seq_len, num_blocks) + + def _execute_model_generic(self, token_ids, position_ids, attn_metadata, + logits_indices, kv_caches, warmup_mode=False): + # FORWARD. + batch_size = token_ids.size(0) + seq_len = self._seq_len(attn_metadata) + num_blocks = self._num_blocks(attn_metadata) + + is_prompt = attn_metadata.is_prompt + self._check_config(batch_size, seq_len, num_blocks, attn_metadata, warmup_mode) + additional_kwargs = {} + if htorch.utils.internal.is_lazy() and not self.model_config.enforce_eager: + use_graphs = self._use_graphs(batch_size, seq_len, num_blocks, is_prompt) + additional_kwargs.update( + {"bypass_hpu_graphs": not use_graphs}) + #import pdb; pdb.set_trace() + trimmed_attn_metadata = trim_attn_metadata(attn_metadata) + hidden_states = self.model.forward(input_ids=token_ids, + positions=position_ids, + attn_metadata=trimmed_attn_metadata, + kv_caches=kv_caches) + #hidden_states = hidden_states[:num_scheduled_tokens] + hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(hidden_states, None) + return logits + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + # NOTE(kzawora): Since scheduler doesn't differentiate between prefills + # and decodes, we must handle mixed batches. In _update_states we make + # sure that first self.input_batch.num_decodes requests are decodes, + # and remaining ones until the end are prefills. _update_states also + # handles changes in request cache based on scheduler outputs and + # previous iterations (e.g. keeping block tables and context lengths up + # to date, creating, pruning and updating request caches, and some more stuff) + + # If num_decodes == self.input_batch.num_reqs, then batch is all decode, and only a single decode forward pass will be executed in this method. + # If num_decodes == 0, then batch is all prefill, and only prefill forward passes will be executed in this method. + # If neither apply, then batch is mixed, and both prefill and decode forward passes will be executed in this method. + + # First, we will execute all decodes (if any) in a single batch, + # then we'll execute prefills in batches of up to max_prefill_batch_size elements. + # All shapes used in forward passes are bucketed appropriately to mitigate risk of graph recompilations. + + # We can do sampling directly after executing each forward pass (split_sampler=True), + # or execute all forward passes, join the results and execute it once (split_sampler=False). + # Everything is done asynchronously - the only sync point is the place + # where we copy the generated tokens back to the host. + + # Example: If a batch has 6 requests, 3 prefills and 3 decodes, the unprocessed sequences in batch will be laid as follows: + # [D0, D1, D2, P0, P1, P2] + # If we assume max_prefill_batch_size=2, and split_sampler=True the flow of this method will look as follows: + # prepare_inputs: bucket [D0, D1, D2] -> [D0, D1, D2, 0] (BS=4 bucket, 1 seq padding) + # prepare_inputs: bucket [P0, P1, P2] -> [P0, P1], [P2] (BS=2 + BS=1 bucket, no seqs padding) + # decode forward pass BS4 [D0, D1, D2, 0] + # decode compute_logits BS4 [D0, D1, D2, 0] + # decode sampler BS4 [D0, D1, D2, 0] -> [tokD0, tokD1, tokD2, 0] + # prefill[iter 0] forward pass BS2 [P0, P1] + # prefill[iter 0] compute_logits BS2 [P0, P1] + # prefill[iter 0] sampler BS2 [P0, P1] -> [tokP0, tokP1] + # prefill[iter 1] forward pass BS1 [P0, P1] + # prefill[iter 1] compute_logits BS1 [P0, P1] + # prefill[iter 1] sampler BS1 [P0, P1] -> [tokP2] + # prefill concat sampler results [tokP0, tokP1], [tokP2] -> [tokP0, tokP1, tokP2] + # Join the prefill and decode on device into [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2] + # Transfer [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2] to CPU + # On CPU, sanitize [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2] -> [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2] + # Return [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2] + + # Example2: Same thing, but with max_prefill_batch_size=4: + # prepare_inputs: bucket [D0, D1, D2] -> [D0, D1, D2, 0] (BS=4 bucket, 1 seq padding) + # prepare_inputs: bucket [P0, P1, P2] -> [P0, P1, P2, 0] (BS=4 bucket, 1 seq padding) + # decode forward pass BS4 [D0, D1, D2, 0] + # decode compute_logits BS4 [D0, D1, D2, 0] + # decode sampler BS4 [D0, D1, D2, 0] -> [tokD0, tokD1, tokD2, 0] + # prefill[iter 0] forward pass BS4 [P0, P1, P2, 0] + # prefill[iter 0] compute_logits BS4 [P0, P1, P2, 0] + # prefill[iter 0] sampler BS4 [P0, P1, P2, 0] -> [tokP0, tokP1, tokP2, 0] + # Join the prefill and decode on device into [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] + # Transfer [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] to CPU + # On CPU, sanitize [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] -> [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2] + # Return [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2] + + # Example2: Same thing, but max_prefill_batch_size=4 and split_sampler=False: + # prepare_inputs: bucket [D0, D1, D2] -> [D0, D1, D2, 0] (BS=4 bucket, 1 seq padding) + # prepare_inputs: bucket [P0, P1, P2] -> [P0, P1, P2, 0] (BS=4 bucket, 1 seq padding) + # decode forward pass BS4 [D0, D1, D2, 0] + # decode compute_logits BS4 [D0, D1, D2, 0] + # prefill[iter 0] forward pass BS4 [P0, P1, P2, 0] + # prefill[iter 0] compute_logits BS4 [P0, P1, P2, 0] + # Join the prefill and decode on device into [D0, D1, D2, 0, P0, P1, P2, 0] + # joint sampler BS8 [D0, D1, D2, 0, P0, P1, P2, 0] -> [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] + # Transfer [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] to CPU + # On CPU, sanitize [tokD0, tokD1, tokD2, 0, tokP0, tokP1, tokP2, 0] -> [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2] + # Return [tokD0, tokD1, tokD2, tokP0, tokP1, tokP2] + + self._update_states(scheduler_output) + prefill_data, decode_data = self._prepare_inputs( + scheduler_output, bucketing=self.enable_bucketing) + num_reqs = self.input_batch.num_reqs + num_decodes = decode_data.num_decodes + num_prefills = num_reqs - num_decodes + num_padded_decodes = decode_data.token_ids.shape[ + 0] if num_decodes > 0 else 0 + + #FIXME(kzawora): Currently there's no handling of logprobs. Fix that later. + logprob_token_ids = None + logprobs = None + split_sampler = True + prefill_output_device = None + decode_output_device = None + + ######################### DECODES ######################### + # Decodes run as one single batch with [padded_decode_bs, 1] + if num_decodes > 0: + htorch.core.mark_step() + logits_device = self._execute_model_generic( + decode_data.token_ids, decode_data.position_ids, + decode_data.attn_metadata, decode_data.logits_indices, self.kv_caches) + htorch.core.mark_step() + if split_sampler: + sampling_metadata = self._prepare_sampling( + scheduler_output, + start_idx=0, + end_idx=num_decodes, + pad_to=num_padded_decodes) + htorch.core.mark_step() + sampler_output = self.model.sample( + logits=logits_device, sampling_metadata=sampling_metadata) + decode_output_device = sampler_output.sampled_token_ids + htorch.core.mark_step() + else: + decode_output_device = logits_device + htorch.core.mark_step() + + ######################### PREFILLS ######################### + # Prefills run with shape [padded_prefill_bs, padded_prefill_len] + if num_prefills > 0: + prefill_seq_offset_start = num_decodes + htorch.core.mark_step() + prefill_output_list = [] + for idx, (req_id, prompt_len, token_ids, position_ids, + attn_metadata, + logits_indices) in enumerate(prefill_data.zipped()): + htorch.core.mark_step() + logits_device = self._execute_model_generic( + token_ids, position_ids, attn_metadata, logits_indices, self.kv_caches) + htorch.core.mark_step() + if split_sampler: + num_curr_prefills = token_ids.shape[0] + prefill_seq_offset_end = prefill_seq_offset_start + num_curr_prefills + if prefill_seq_offset_start == prefill_seq_offset_end: + import pdb; pdb.set_trace() + sampling_metadata = self._prepare_sampling( + scheduler_output, + start_idx=prefill_seq_offset_start, + end_idx=prefill_seq_offset_end, + pad_to=num_curr_prefills) + htorch.core.mark_step() + sampler_output = self.model.sample( + logits=logits_device, + sampling_metadata=sampling_metadata) + sampled_token_ids_device = sampler_output.sampled_token_ids + htorch.core.mark_step() + prefill_seq_offset_end = prefill_seq_offset_start + prefill_output_list.append(sampled_token_ids_device) + else: + prefill_output_list.append(logits_device) + prefill_output_device = torch.cat(prefill_output_list, dim=0) + htorch.core.mark_step() + + ################### (maybe) SAMPLING ################### + #NOTE(kzawora): It might be better to do separate sampling + # for prefills and decodes, since they will have more predictable + # shapes. Or it might not. Idk. I implemented both. + # In my testing, split_sampler=False was a bit faster (Llama3.1-8B@GSM8K), + # no differences in accuracy observed. YMMV. + # HPU <-> CPU sync happens in this section + + sampled_token_ids_cpu: torch.Tensor + if split_sampler: + # If sampler was split, we already have tokens. Let's copy the data to CPU as is, and then discard padded tokens. + prefill_output_cpu = prefill_output_device.cpu( + ) if prefill_output_device is not None else None + decode_output_cpu = decode_output_device.cpu( + ) if decode_output_device is not None else None + # From this point onward, all operations are done on CPU. + + # Discard garbage tokens from prefills and/or decodes + if prefill_output_cpu is not None and decode_output_cpu is not None: + sampled_token_ids_cpu = torch.cat( + (decode_output_cpu[:num_decodes], + prefill_output_cpu[:num_prefills]), + dim=0) + else: + sampled_token_ids_cpu = decode_output_cpu[: + num_decodes] if decode_output_cpu is not None else prefill_output_cpu[: + num_prefills] + else: + # If sampler was not split, we need to sample on device before copying to CPU. + joint_logits_device: torch.Tensor + if decode_output_device is not None and prefill_output_device is not None: + joint_logits_device = torch.cat( + (decode_output_device, prefill_output_device), dim=0) + else: + joint_logits_device = decode_output_device if decode_output_device is not None else prefill_output_device + # NOTE(kzawora): this stuff is not gonna work + assert False, "imma be real, this ain't gonna work chief" + sampled_token_ids_device = torch.argmax(joint_logits_device, dim=1) + sampled_token_ids_padded_cpu = sampled_token_ids_device.cpu() + # From this point onward, all operations are done on CPU. + + # Discard garbage tokens from prefills and/or decodes + # NOTE(kzawora): If we have 3 prefills and 3 decodes, and both + # are padded to 4, the sampled tokens tensor looks as follows: + # [ D0, D1, D2, 0, P0, P1, P2, 0] + # ^___^___^ ^___^___^ + # Here, we're selecting these elements and discard the + # padding in the middle (after prefill tokens) and at the end of the + # tensor (after decode tokens) + # https://numpy.org/doc/stable/reference/generated/numpy.r_.html + sampled_token_ids_cpu = sampled_token_ids_padded_cpu[ + np.r_[:num_decodes, + num_padded_decodes:num_padded_decodes + num_prefills]] + + sampled_token_ids_list = sampled_token_ids_cpu.tolist() + ######### UPDATE REQUEST STATE WITH GENERATED TOKENS ######### + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + req_state = self.requests[req_id] + + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + ################## RETURN ################## + model_runner_output = ModelRunnerOutput( + req_ids=self.input_batch.req_ids[:num_reqs], + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids_cpu=sampled_token_ids_cpu, + logprob_token_ids_cpu=logprob_token_ids, + logprobs_cpu=logprobs, + ) + + if False: + for req_id in self.input_batch.req_ids[:num_reqs]: + req_idx = self.input_batch.req_id_to_index[req_id] + token_ids = self.input_batch.token_ids_cpu[req_idx] + prompt = self._tokenizer.decode( + token_ids[:self.input_batch. + num_prompt_tokens_cpu[req_idx]]) + + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + req_state = self.requests[req_id] + generated = self._tokenizer.decode(req_state.output_token_ids) + phase = 'prefill' if req_idx >= decode_data.num_decodes else 'decode' + logger.info( + f'[ENGINE_ITER {self._ENGINE_ITER}] REQ:{req_id} IDX:{req_idx} {phase} generated token: {self._tokenizer.decode(sampled_token_ids_cpu[req_idx])!r}, all generated so far: {generated!r}' + ) + self._ENGINE_ITER += 1 + #import pdb; pdb.set_trace() + return model_runner_output + + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with HabanaMemoryProfiler() as m: # noqa: SIM117 + self.model = get_model(vllm_config=self.vllm_config) + self.model = _maybe_wrap_in_hpu_graph( + self.model, + self.block_size, + dtype=self.model_config.dtype, + enforce_eager=self.model_config.enforce_eager) + self.model_memory_usage = m.consumed_device_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + def _use_graphs(self, batch_size, seq_len, num_blocks, phase): + if self.model_config.enforce_eager: + return False + if self.skip_warmup: + return True + return (batch_size, seq_len, num_blocks, phase) in self.graphed_buckets + + def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): + num_candidates = len(buckets) + phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' + graphed = list(c[:2] for c in self.graphed_buckets + if c[2] == is_prompt) + if num_candidates == 0: + num_candidates = 1 + msg = (f'{phase} captured:{len(graphed)} ' + f'({100 * len(graphed) / num_candidates:.1f}%) ' + f'used_mem:{format_bytes(total_mem)} ' + f'buckets:{sorted(list(graphed))}') + logger.info(msg) + + def warmup_scenario(self, + batch_size, + seq_or_block, + is_prompt, + kv_caches) -> None: + """Dummy warmup run for memory usage and graph compilation.""" + + query_seq_len = seq_or_block if is_prompt else 1 + input_ids = torch.zeros((batch_size, query_seq_len), + dtype=torch.int32, + device='cpu') + position_ids = torch.zeros((batch_size, query_seq_len), + dtype=torch.int32, + device='cpu') + slot_mapping = torch.zeros((batch_size, query_seq_len), + dtype=torch.int64, + device='cpu') + + input_ids_device = _async_h2d_tensor_copy(input_ids, self.device) + position_ids_device = _async_h2d_tensor_copy(position_ids, self.device) + slot_mapping_device = _async_h2d_tensor_copy(slot_mapping, self.device) + + if is_prompt: + seq_lens = torch.zeros((batch_size), dtype=torch.int32, device='cpu') + seq_lens.fill_(seq_or_block) + seq_lens_device = _async_h2d_tensor_copy(seq_lens, self.device) + attn_metadata = HPUAttentionMetadata.make_prefill_metadata( + seq_lens_tensor=seq_lens_device, + num_prefills=batch_size, + num_prefill_tokens=batch_size*seq_or_block, + slot_mapping=slot_mapping_device + ) + else: + block_tables = [x.tolist() for x in np.array_split(np.arange(seq_or_block), batch_size)] + block_list, block_groups, block_usage = self.get_habana_paged_attn_buffers(block_tables=block_tables, slot_mapping=slot_mapping, bucketing=True) + block_list_device = _async_h2d_tensor_copy(block_list, self.device) + block_usage_device = _async_h2d_tensor_copy(block_usage, self.device) + block_groups_device = _async_h2d_tensor_copy(block_groups, self.device) + attn_metadata = HPUAttentionMetadata.make_decode_metadata( + block_list=block_list_device, + block_usage=block_usage_device, + block_groups=block_groups_device, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping_device + ) + + logits_indices = torch.arange(0,batch_size, device='cpu') + logits_indices_device = _async_h2d_tensor_copy(logits_indices, self.device) + # Dummy run. + htorch.core.mark_step() + logits = self._execute_model_generic(input_ids_device, position_ids_device, attn_metadata, logits_indices_device, kv_caches, True) + # TODO: do sampling on logits, warmup sampler and prefill joiner + htorch.core.mark_step() + temperature = torch.ones(batch_size, dtype=torch.float32, device='cpu') + top_p = torch.ones(batch_size, dtype=torch.float32, device='cpu') + top_k = torch.ones(batch_size, dtype=torch.float32, device='cpu') + temperature_device = _async_h2d_tensor_copy(temperature, self.device) + top_p_device = _async_h2d_tensor_copy(top_p, self.device) + top_k_device = _async_h2d_tensor_copy(top_k, self.device) + generators = {i:None for i in range(batch_size)} # NOTE(kzawora): idk what to set here + max_num_logprobs = 0 # NOTE(kzawora): idk what to set here + # NOTE(kzawora: do this in a smarter way) + return None + htorch.core.mark_step() + sampling_metadata = SamplingMetadata( + temperature=temperature_device, + all_greedy=False, # hacky + all_random=True, # hacky + top_p=top_p_device, + top_k=top_k_device, + no_top_p=True, + no_top_k=True, + generators=generators, + max_num_logprobs=max_num_logprobs, + ) + tokens_all_random = self.model.sample(logits, sampling_metadata) + htorch.core.mark_step() + sampling_metadata = SamplingMetadata( + temperature=temperature_device, + all_greedy=True, # hacky + all_random=False, # hacky + top_p=top_p_device, + top_k=top_k_device, + no_top_p=True, + no_top_k=True, + generators=generators, + max_num_logprobs=max_num_logprobs, + ) + tokens_all_greedy = self.model.sample(logits, sampling_metadata) + htorch.core.mark_step() + sampling_metadata = SamplingMetadata( + temperature=temperature_device, + all_greedy=False, # hacky + all_random=False, # hacky + top_p=top_p_device, + top_k=top_k_device, + no_top_p=True, + no_top_k=True, + generators=generators, + max_num_logprobs=max_num_logprobs, + ) + tokens_mixed = self.model.sample(logits, sampling_metadata) + htorch.core.mark_step() + return tokens_all_random, tokens_all_greedy, tokens_mixed + + def log_warmup(self, phase, i, max_i, batch_size, seq_len): + free_mem = format_bytes( + HabanaMemoryProfiler.current_free_device_memory()) + dim = "num_blocks" + if phase == "Prompt": + dim = "seq_len" + msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " + f"batch_size:{batch_size} " + f"{dim}:{seq_len} " + f"free_mem:{free_mem}") + logger.info(msg) + + def warmup_all_buckets(self, buckets, is_prompt, kv_caches): + for i, (batch_size, seq_len) in enumerate(reversed(buckets)): + self.log_warmup('Prompt' if is_prompt else 'Decode', i, + len(buckets), batch_size, seq_len) + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + torch.hpu.synchronize() + + def warmup_graphs(self, + strategy, + buckets, + is_prompt, + kv_caches, + available_mem, + starting_mem=0, + total_batch_seq=0.001): + total_mem = starting_mem + idx = 0 + phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' + num_candidates = len(buckets) + ordering : Union[Callable[[Any], Tuple[Any, Any]], \ + Callable[[Any], Tuple[Any, Any, Any]]] + if strategy == 'min_tokens': + ordering = lambda b: (b[0] * b[1], b[1], b[0]) + elif strategy == 'max_bs': + ordering = lambda b: (-b[0], b[1]) + else: + raise NotImplementedError( + f'Unsupported graph allocation strategy: {strategy}') + buckets = list(sorted(buckets, key=ordering)) + captured_all = True + for idx, (batch_size, seq_len) in enumerate(buckets): + # Graph memory usage is proportional to seq dimension in a batch + batch_seq = batch_size * seq_len if is_prompt else batch_size + mem_estimate = batch_seq / total_batch_seq * total_mem + if mem_estimate >= available_mem: + captured_all = False + continue + graphed_bucket = (batch_size, seq_len, is_prompt) + if graphed_bucket in self.graphed_buckets: + continue + self.graphed_buckets.add(graphed_bucket) + self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) + with HabanaMemoryProfiler() as mem_prof: + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) + #TODO(kzawora): align_workers + used_mem = mem_prof.consumed_device_memory + available_mem -= used_mem + total_mem += used_mem + total_batch_seq += batch_seq + + return total_mem, total_batch_seq, captured_all + + @torch.inference_mode() + def warmup_model(self) -> None: + kv_caches = self.kv_caches + if profile := os.environ.get('VLLM_PT_PROFILE', None): + phase, bs, seq_len, graph = profile.split('_') + is_prompt = phase == 'prompt' + graphs = graph == 't' + if graphs: + self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) + self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, + True) + raise AssertionError("Finished profiling") + if self.skip_warmup: + logger.info("Skipping warmup...") + return + #self.profiler.start('internal', 'warmup') + max_blocks = kv_caches[0][0].size(0) + self.bucketing_ctx.generate_prompt_buckets() + self.bucketing_ctx.generate_decode_buckets(max_blocks) + + if not htorch.utils.internal.is_lazy() and not self.enforce_eager: + cache_size_limit = len(self.bucketing_ctx.prompt_buckets) + len( + self.bucketing_ctx.decode_buckets) + 1 + torch._dynamo.config.cache_size_limit = max( + cache_size_limit, torch._dynamo.config.cache_size_limit) + # Multiply by 8 to follow the original default ratio between + # the cache_size_limit and accumulated_cache_size_limit + torch._dynamo.config.accumulated_cache_size_limit = max( + cache_size_limit * 8, + torch._dynamo.config.accumulated_cache_size_limit) + + start_mem = HabanaMemoryProfiler.current_device_memory_usage() + start_time = time.perf_counter() + + compile_only_mode_context = functools.partial(bc.env_setting, + "PT_COMPILE_ONLY_MODE", + True) + can_use_compile_only_mode = True + try: + with compile_only_mode_context(): + pass + logger.debug("Using PT_COMPILE_ONLY_MODE.") + except KeyError: + can_use_compile_only_mode = False + logger.warning('Cannot use PT_COMPILE_ONLY_MODE. ' + 'Warmup time will be negatively impacted. ' + 'Please update Gaudi Software Suite.') + with compile_only_mode_context( + ) if can_use_compile_only_mode else contextlib.nullcontext(): + self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, + True, kv_caches) + self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, + False, kv_caches) + + if not self.model_config.enforce_eager and htorch.utils.internal.is_lazy(): + assert self.mem_margin is not None, \ + ("HabanaWorker.determine_num_available_blocks needs " + "to be called before warming up the model.") + free_mem = HabanaMemoryProfiler.current_free_device_memory() + graph_free_mem = free_mem - self.mem_margin + #TODO(kzawora): align_workers + graph_free_mem = graph_free_mem + prompt_graph_mem_ratio = float( + os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3')) + prompt_available_memory = (prompt_graph_mem_ratio * + graph_free_mem) + decode_available_memory = (graph_free_mem - + prompt_available_memory) + msg = ( + f"Using {format_bytes(graph_free_mem)}" + f"/{format_bytes(free_mem)} " + "of free device memory for HPUGraphs, " + f"{format_bytes(prompt_available_memory)} for prompt and " + f"{format_bytes(decode_available_memory)} for decode " + f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})") + logger.info(msg) + prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY', + 'min_tokens') + decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY', + 'max_bs') + mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ + self.warmup_graphs( + prompt_strategy, self.bucketing_ctx.prompt_buckets, + True, kv_caches, prompt_available_memory) + mem_post_decode, decode_batch_seq, decode_captured_all = \ + self.warmup_graphs( + decode_strategy, self.bucketing_ctx.decode_buckets, + False, kv_caches, decode_available_memory) + + # Not all prompt buckets were captured, but all decode buckets + # were captured and we have some free graph-allocated space + # left. Let's try to use it for capturing more prompt buckets. + if (mem_post_decode + mem_post_prompt < graph_free_mem + and not prompt_captured_all and decode_captured_all): + mem_post_prompt, _, prompt_captured_all = ( + self.warmup_graphs( + prompt_strategy, + self.bucketing_ctx.prompt_buckets, True, + kv_caches, + graph_free_mem - mem_post_prompt - mem_post_decode, + mem_post_prompt, prompt_batch_seq)) + + # Not all decode buckets were captured, but all prompt buckets + # were captured and we have some free graph-allocated space + # left. Let's try to use it for capturing more decode buckets. + if mem_post_decode + mem_post_prompt < graph_free_mem \ + and not decode_captured_all \ + and prompt_captured_all: + mem_post_decode, _, _ = self.warmup_graphs( + decode_strategy, + self.bucketing_ctx.decode_buckets, False, + kv_caches, + graph_free_mem - mem_post_prompt - mem_post_decode, + mem_post_decode, decode_batch_seq) + + self.log_graph_warmup_summary( + self.bucketing_ctx.prompt_buckets, True, + mem_post_prompt) + self.log_graph_warmup_summary( + self.bucketing_ctx.decode_buckets, False, + mem_post_decode) + + end_time = time.perf_counter() + end_mem = HabanaMemoryProfiler.current_device_memory_usage() + elapsed_time = end_time - start_time + msg = ( + f"Warmup finished in {elapsed_time:.0f} secs, " + f"allocated {format_bytes(end_mem - start_mem)} of device memory") + logger.info(msg) + #self.profiler.end() + + @torch.inference_mode() + def profile_run(self) -> None: + return + """Profile to measure peak memory during forward pass.""" + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value `None`. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + + # Run empty prefill forwards - prefill max batch and prefill max seq + self.warmup_scenario(batch_size=1, + seq_or_block=self.max_model_len, + is_prompt=True, + kv_caches=kv_caches) + max_seq_len = math.ceil((self.max_num_tokens//self.max_prefill_batch_size) / self.block_size) * self.block_size + self.warmup_scenario(batch_size=self.max_prefill_batch_size, + seq_or_block=max_seq_len, + is_prompt=True, + kv_caches=kv_caches) + + @torch.inference_mode() + def capture_model(self) -> None: + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + with set_forward_context(None): + # Trigger CUDA graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + for num_tokens in reversed(self.cudagraph_batch_sizes): + self.model( + self.input_ids[:num_tokens], + self.positions[:num_tokens], + kv_caches=self.kv_caches, + attn_metadata=None, + ) + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + # This usually takes 5~20 seconds. + logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, cuda_graph_size / (1 << 30)) + + def initialize_kv_cache(self, num_blocks: int) -> None: + assert len(self.kv_caches) == 0 + kv_cache_shape = HPUAttentionBackendV1.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] + dtype = self.dtype + if self.device != 'hpu' and not is_fake_hpu() \ + and self.dtype == torch.float8_e4m3fn: + dtype = torch.uint8 + for _ in range(self.num_attn_layers): + key_cache = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + value_cache = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + kv_layer = (key_cache, value_cache) + self.kv_caches.append(kv_layer) + if self.enable_bucketing: + self.bucketing_ctx.num_hpu_blocks = num_blocks + htorch.hpu.synchronize() + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List[MultiModalKwargs] + mm_positions: List["PlaceholderRange"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + + self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32) + self.num_output_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + self.num_prompt_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + + # Attention-related. + + # NOTE(kzawora): "+1" here prevents us from going OoB in block table + # when max model length is reached. + # Sometimes scheduler allocates two blocks ahead which can go out of + # valid seq len bounds, so e.g. if we have 16 blocks available, and + # we've just filled entirety of 15th block, sometimes scheduler assigns + # 16th and 17th block to the sequence, even though it can never + # reach block 17. I have no idea why that happens, but + # it smells like a bug. + self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req + 1), + device=self.device, + dtype=torch.int32) + self.block_table_cpu_tensor = torch.zeros( + (max_num_reqs, max_num_blocks_per_req + 1), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() + + # Sampling-related. + self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + self.num_prefills = 0 + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. + num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + #self.num_output_tokens_cpu[req_index] = request.num_output_tokens + self.num_prompt_tokens_cpu[req_index] = len(request.prompt_token_ids) + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + def make_sampling_metadata(self, + skip_copy, + start_idx: Optional[int] = None, + end_idx: Optional[int] = None, + pad_to: Optional[int] = None): + if start_idx is None and end_idx is None and pad_to is None: + return self._make_sampling_metadata_all(skip_copy=skip_copy) + return self._make_sampling_metadata_range(skip_copy, + start_idx, + end_idx, + pad_to=pad_to) + + def _make_sampling_metadata_range( + self, + skip_copy: bool = False, + start_idx: Optional[int] = None, + end_idx: Optional[int] = None, + pad_to: Optional[int] = None) -> SamplingMetadata: + if start_idx is None: + start_idx = 0 + if end_idx is None: + end_idx = self.num_reqs + max_num_reqs = len(self.req_ids) + end_idx = min(end_idx,max_num_reqs) + num_seqs = end_idx - start_idx + padding_needed = max(0, pad_to - num_seqs) + req_ids = self.req_ids[start_idx:end_idx] + if not skip_copy: + self.temperature[start_idx:end_idx].copy_( + self.temperature_cpu_tensor[start_idx:end_idx], + non_blocking=True) + self.top_p[start_idx:end_idx].copy_( + self.top_p_cpu_tensor[start_idx:end_idx], non_blocking=True) + self.top_k[start_idx:end_idx].copy_( + self.top_k_cpu_tensor[start_idx:end_idx], non_blocking=True) + + all_greedy = all([req_id in self.greedy_reqs for req_id in req_ids]) + all_random = all([req_id in self.random_reqs for req_id in req_ids]) + if all_greedy and all_random: + import pdb + pdb.set_trace() #WTF?! + no_top_p = not any([req_id in self.top_p_reqs for req_id in req_ids]) + no_top_k = not any([req_id in self.top_k_reqs for req_id in req_ids]) + # NOTE(kzawora): Generators are used by sampler row-wise. If we got a + # generator for element 5, but it's first row in a batch, + # we need to assign that generator to index 0 - hence the + # i:generators.get(req_id) rather than req_id:generators.get(req_id) + generators = { + i: self.generators.get(req_id, None) + for i, req_id in enumerate( + range(start_idx, end_idx + padding_needed)) + } + temperature_device = self.temperature[start_idx:end_idx + + padding_needed] + top_p_device = self.top_p[start_idx:end_idx + padding_needed] + tok_k_device = self.top_k[start_idx:end_idx + padding_needed] + if end_idx + padding_needed >= max_num_reqs: + # NOTE(kzawora): this is janky, but [start_idx:end_idx+padding_needed] + # falls apart once your padding exceeds max_num_reqs (and it happens pretty + # often, you could increase the temperature/topp/topk allocation, but + # you cannot really make any guarantees ahead of time on the amount of padding you'll use) + # this is kind of a temporary fix, no idea on its performance impact... + temperature_device = torch.empty(pad_to, + device=self.temperature.device, + dtype=self.temperature.dtype) + top_p_device = torch.empty(pad_to, + device=self.top_p.device, + dtype=self.top_p.dtype) + top_k_device = torch.empty(pad_to, + device=self.top_k.device, + dtype=self.top_k.dtype) + # D2D copy + temperature_device[:num_seqs].copy_( + self.temperature[start_idx:end_idx], non_blocking=True) + top_p_device[:num_seqs].copy_(self.top_p[start_idx:end_idx], + non_blocking=True) + top_k_device[:num_seqs].copy_(self.top_k[start_idx:end_idx], + non_blocking=True) + + return SamplingMetadata( + temperature=temperature_device, + all_greedy=all_greedy, + all_random=all_random, + top_p=top_p_device, + top_k=tok_k_device, + no_top_p=no_top_p, + no_top_k=no_top_k, + generators=generators, + max_num_logprobs=self.max_num_logprobs, + ) + + def _make_sampling_metadata_all( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def num_decodes(self) -> int: + return self.num_reqs - self.num_prefills + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 diff --git a/vllm/v1/worker/hpu_worker.py b/vllm/v1/worker/hpu_worker.py new file mode 100644 index 0000000000000..085e12f8700fd --- /dev/null +++ b/vllm/v1/worker/hpu_worker.py @@ -0,0 +1,226 @@ +"""A GPU worker class.""" +from contextlib import contextmanager +import gc +import os +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.distributed + +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_fake_hpu +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.hpu_model_runner import HPUModelRunner + +logger = init_logger(__name__) +from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + + +class HPUWorker: + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + ): + + # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + self.model_runner = HPUModelRunner(vllm_config) + + def initialize(self): + # Initialize the distributed environment. + init_worker_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method, + self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self) -> None: + self.model_runner.load_model() + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + if is_fake_hpu(): + cache_block_size = _get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + fake_hpu_cache_alloc = 4 * 2**30 # take 4 GiB flat on fake hpu + return fake_hpu_cache_alloc // cache_block_size, 0 + with HabanaMemoryProfiler() as m: + self.model_runner.profile_run() + torch.hpu.synchronize() + msg = ("Model profiling run " + f"took {m.get_summary_string()}") + logger.info(msg) + # At this point we should've allocated the maximum workspace for all + # recipes we will use the extra memory for graphs/blocks + free_hpu_memory = torch.hpu.mem_get_info()[0] + + cache_block_size = _get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + graph_reserved_mem = (float( + os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.1')) + if not self.model_config.enforce_eager else 0) + graph_headroom = 1 - graph_reserved_mem + available_hpu_memory = free_hpu_memory * \ + self.cache_config.gpu_memory_utilization + hpu_memory_margin = free_hpu_memory * ( + 1 - self.cache_config.gpu_memory_utilization) + self.model_runner.mem_margin = hpu_memory_margin + cache_size_bytes = available_hpu_memory * graph_headroom + graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom) + msg = ( + f"Free device memory: {format_bytes(free_hpu_memory)}, " + f"{format_bytes(available_hpu_memory)} usable " + f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization})," + f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs " + f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), " + f"{format_bytes(cache_size_bytes)} reserved for KV cache") + logger.info(msg) + num_hpu_blocks = int(cache_size_bytes // cache_block_size) + num_hpu_blocks = max(num_hpu_blocks, 0) + + gc.collect() + return num_hpu_blocks, 0 + + def initialize_cache(self, num_gpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks.""" + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_gpu_blocks + max_model_len = self.model_config.max_model_len + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + with HabanaMemoryProfiler() as m: + self.model_runner.initialize_kv_cache(num_gpu_blocks) + torch.hpu.synchronize() + msg = ("Initializing cache engine " + f"took {m.get_summary_string()}") + logger.info(msg) + self.compile_or_warm_up_model() + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.warmup_model() + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + #with track_graph_compile('HPUWorker.execute_model'): + output = self.model_runner.execute_model(scheduler_output) + # TODO(woosuk): Send the output to the engine process. + return output + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + init_distributed_environment(parallel_config.world_size, + rank, + distributed_init_method, + local_rank, + backend='hccl') + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + dummy_tensor_hpu = torch.ones(1).to('hpu') + torch.distributed.all_reduce(dummy_tensor_hpu) + assert dummy_tensor_hpu.item() == parallel_config.world_size + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + +def _get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, +) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_attention_layers( + parallel_config) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total + + +@contextmanager +def track_graph_compile(name: str): + import habana_frameworks.torch as htorch + from habana_frameworks.torch.hpu.metrics import metric_localcontext + with metric_localcontext("graph_compilation") as gc: + yield + htorch.hpu.synchronize() + if gc.stats()[0][1] != 0: + msg = f"[{name}] graph compilation detected: {gc.stats()}" + logger.warning(msg) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 7aa68d1e98abf..4f9d935686521 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -516,7 +516,7 @@ def __init__( self._set_gc_threshold() self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA', - 'true').lower() == 'true' + 'false').lower() == 'true' # For multi-step scheduling self.cached_step_outputs: List[torch.Tensor] = []