Skip to content

Commit

Permalink
Refactor: pass more detail lint checker
Browse files Browse the repository at this point in the history
Signed-off-by: Dahai Tang <[email protected]>
  • Loading branch information
Dahai Tang committed Dec 4, 2024
1 parent f885301 commit b13d6e4
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 22 deletions.
2 changes: 1 addition & 1 deletion csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void kv_store_copy_blocks2CPU(torch::Tensor& src, torch::Tensor& dst,
}

// src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size]
// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads,
// kv_caches layout: [layers, [2, num_blocks, block_size, num_kv_heads,
// head_size]]
void kv_store_copy_blocks2GPU(
torch::Tensor& src, std::vector<torch::Tensor> const& kv_caches,
Expand Down
8 changes: 4 additions & 4 deletions csrc/kv_store/kv_store.cu
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& vec) {
}

// src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size]
// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads,
// kv_caches layout: [layers, [2, num_blocks, block_size, num_kv_heads,
// head_size]]
void KVStore::CopyBlocks2GPUBatch(torch::Tensor& src,
std::vector<torch::Tensor> const& kv_caches,
Expand Down Expand Up @@ -256,7 +256,7 @@ __global__ void kv_store_copy_blocks_kernel(
namespace {

// src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size]
// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads,
// kv_caches layout: [layers, [2, num_blocks, block_size, num_kv_heads,
// head_size]]
void CopyLayerBlocks2GPUKernelFunc(
const torch::Tensor& src, std::vector<torch::Tensor> const& kv_caches,
Expand Down Expand Up @@ -333,7 +333,7 @@ void CopyLayerBlocks2GPUKernelFunc(
}

// src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size]
// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads,
// kv_caches layout: [layers, [2, num_blocks, block_size, num_kv_heads,
// head_size]]
void CopyLayerBlocks2GPUThreadFunc(
const torch::Tensor& src, std::vector<torch::Tensor> const& kv_caches,
Expand Down Expand Up @@ -388,7 +388,7 @@ void CopyLayerBlocks2GPUThreadFunc(
}; // namespace

// src layout: [num_blocks, 2, num_layer, block_size, num_kv_heads*head_size]
// kv_caches layout: [laysers, [2, num_blocks, block_size, num_kv_heads,
// kv_caches layout: [layers, [2, num_blocks, block_size, num_kv_heads,
// head_size]]
void KVStore::CopyLayerBlocks2GPU(torch::Tensor& src,
std::vector<torch::Tensor> const& kv_caches,
Expand Down
15 changes: 8 additions & 7 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def create_empty(cls) -> "SchedulerPrefillOutputs":
seq_groups=[],
ignored_seq_groups=[],
num_lookahead_slots=0,
kv_store_block_mapping_from_cpu=None,
kv_store_block_mapping_from_cpu=BlockMappingFromCPU.null(),
)


Expand Down Expand Up @@ -935,15 +935,16 @@ def _stop_schedule_prefill(num_new_tokens_uncached,
num_new_seqs,
max_num_batched_tokens,
budget):
ret = False
if (budget.num_batched_tokens >=
self.scheduler_config.max_num_batched_tokens):
return True
ret = True
if (num_new_tokens_uncached == 0 or
not budget.can_schedule(
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs)):
return True
return False
ret = True
return ret

kv_store_tmp_queue : Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and kv_store_waiting_queue:
Expand Down Expand Up @@ -1062,13 +1063,13 @@ def _stop_schedule_prefill(num_new_tokens_uncached,
waiting_queue.popleft()
continue

if (self.kv_store_manager != None):
if (self.kv_store_manager is not None):
self.kv_store_manager.is_prefill = seq_group.is_prefill()

block_mapping_from_cpu = []
self._allocate(seq_group)

if (self.kv_store_manager != None):
if (self.kv_store_manager is not None):
block_ids = self.block_manager.get_block_table(
seq_group.get_seqs()[0])
block_mapping_from_cpu = \
Expand Down Expand Up @@ -1137,7 +1138,7 @@ def _stop_schedule_prefill(num_new_tokens_uncached,
if len(seq_groups) > 0:
self.prev_prompt = True

if (self.kv_store_manager != None) and \
if (self.kv_store_manager is not None) and \
(len(kv_store_block_mapping) > 0):
self.kv_store_manager.close_send_flags(
[items[1]
Expand Down
6 changes: 3 additions & 3 deletions vllm/model_executor/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,15 @@ def forward(
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
if (self.kv_store != None) and \
(self.kv_store.batch_layers_to_GPU == True):
if (self.kv_store is not None) and \
(self.kv_store.batch_layers_to_GPU):
self.kv_store.get_stream_sync(
attn_metadata.kv_store_meta.request_ids)

for i in range(self.start_layer, self.end_layer):
layer_id = (i - self.start_layer)
if (self.kv_store is not None) and \
(self.kv_store.batch_layers_to_GPU == False):
(not self.kv_store.batch_layers_to_GPU):
self.kv_store.get_stream_layer_sync(
layer_id, attn_metadata.kv_store_meta.request_ids)
layer = self.layers[i]
Expand Down
21 changes: 15 additions & 6 deletions vllm/store/kv_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

@dataclass
class BlockMappingFromCPU:
block_mapping: torch.Tensor # 2-D tenso
block_offset: torch.Tensor # 1-D tensor, like offset array in CSR format
# 2-D tensor
block_mapping: Optional[torch.Tensor]
# 1-D tensor, like offset array in CSR format
# the offset of each request in block_mapping
request_ids: torch.Tensor # request IDs
block_offset: Optional[torch.Tensor]
request_ids: Optional[torch.Tensor] # request IDs

def __init__(self, block_mapping: list[list[int, int]],
block_offset: list[int], request_ids: list[int]):
Expand All @@ -37,6 +39,11 @@ def __init__(self, block_mapping: list[list[int, int]],
device="cpu",
dtype=torch.int64).view(-1)

@staticmethod
def null():
return BlockMappingFromCPU(
torch.Tensor(), torch.Tensor(), torch.Tensor())

def __str__(self):
return "block_mapping: " + str(self.block_mapping) + \
" block_offset: " + str(self.block_offset) + \
Expand All @@ -45,12 +52,14 @@ def __str__(self):

@dataclass
class KVStoreMeta:
incomplete_put_block_ids: torch.Tensor # 4-D tensor:
# 4-D tensor:
# vllm_block_id,
# start_offset,end_offset,
# store_block_id
put_block_ids_mapping: torch.Tensor # 2-D tensor:
incomplete_put_block_ids: torch.Tensor
# 2-D tensor:
# vllm_block_id, store_block_id
put_block_ids_mapping: torch.Tensor
request_ids: torch.Tensor # 1-D tensor

@staticmethod
Expand Down Expand Up @@ -257,7 +266,7 @@ def get_block_mapping_from_torch(self, vllm_block_ids: torch.Tensor) \
return ret_tensor

def get_block_mapping_from_python(self, vllm_block_ids: list[int]) \
-> list[tuple[int, int]]:
-> list[list[int, int]]:
if (not self.is_prefill) or \
(len(vllm_block_ids) == 0):
return []
Expand Down
2 changes: 1 addition & 1 deletion vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata)
from vllm.store.kv_store import BlockMappingFromCPU, KVStoreMeta
from vllm.store.kv_store import KVStoreMeta
from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, update_environment_variables)
from vllm.worker.model_runner_base import (BroadcastableModelInput,
Expand Down

0 comments on commit b13d6e4

Please sign in to comment.