diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 5870070a54c75..bf4f40ca94e29 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -118,6 +118,12 @@ def send_kv_caches_and_hidden_states( start_layer = model_executable.model.start_layer end_layer = model_executable.model.end_layer + model_config = model_executable.model.config + num_heads = model_config.num_key_value_heads + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(hidden_size / num_attention_heads) + # query_lens contains new KV caches that are added to vLLM. # so we will send them to decode instance # FIXME(Kuntai): This assume that all requests are prefill. @@ -131,8 +137,6 @@ def send_kv_caches_and_hidden_states( for layer_id in range(start_layer, end_layer): kv_cache = kv_caches[layer_id - start_layer] - _, _, num_heads, head_size = kv_cache[0].shape - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) value_cache = kv_cache[1].reshape(-1, num_heads, head_size)