Working on TP>1; there is a dynamic shape bug though
Signed-off-by: Thomas Parnell <[email protected]>
tdoublep committed Jan 17, 2025
1 parent 6550bcb commit 8b2e26b
Showing 4 changed files with 23 additions and 1 deletion.
1 change: 1 addition & 0 deletions vllm/__init__.py
@@ -34,6 +34,7 @@ def configure_as_vllm_process():
# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
import torch._inductor.config
torch._inductor.config.compile_threads = 1

from vllm.platforms import current_platform
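The two settings in this hunk address the same problem at different layers: TORCHINDUCTOR_COMPILE_THREADS is read when torch._inductor.config is first imported, so the environment variable alone does not help if the module is already loaded, while assigning compile_threads directly works either way. A minimal standalone sketch of the workaround (the ordering comments are illustrative, not part of this commit):

import os

# Must be set before torch._inductor.config is first imported,
# since the config module reads the variable at import time.
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'

# Direct assignment covers the case where inductor was already imported.
import torch._inductor.config
torch._inductor.config.compile_threads = 1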
8 changes: 8 additions & 0 deletions vllm/model_executor/model_loader/spyre.py
@@ -85,10 +85,12 @@ def forward(
self.past_key_value_states = past_key_value_states

# mark dynamic
'''
if self.past_key_value_states is not None:
for layer in self.past_key_value_states:
for tensor in layer:
torch._dynamo.mark_dynamic(tensor, 2)
'''

# removing batch padding sequences to compute logits
batch_size = input_ids.shape[0]
@@ -190,6 +192,12 @@ def load_weights(self, model_config: ModelConfig, max_prompt_length: int,
f"accommodate prompt size of {max_prompt_length} and "
f"decode tokens of {max_decode_length}")

if envs.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
torch._dynamo.config.assume_static_by_default = True
torch._dynamo.config.dynamic_shapes = False
torch._dynamo.config.automatic_dynamic_shapes = False


if envs.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
self.model = torch.compile(self.model,
mode=compile_mode,
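For reference, a minimal sketch contrasting the two approaches this hunk trades between; the tensor shape is a hypothetical [batch, heads, seq_len, head_dim] KV-cache layout, not taken from this commit:

import torch
import torch._dynamo

# Hypothetical KV-cache tensor: [batch, heads, seq_len, head_dim].
kv = torch.zeros(4, 8, 128, 64)

# The commented-out path: mark dim 2 (seq_len) as symbolic so the
# compiled graph accepts any sequence length without recompiling.
torch._dynamo.mark_dynamic(kv, 2)

# The path taken here for sendnn_decoder: disable dynamic shapes
# entirely, so every traced shape is specialized as static.
torch._dynamo.config.assume_static_by_default = True
torch._dynamo.config.dynamic_shapes = False
torch._dynamo.config.automatic_dynamic_shapes = False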
13 changes: 12 additions & 1 deletion vllm/platforms/spyre.py
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Optional
import torch

from vllm.logger import init_logger

@@ -53,4 +54,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on Spyre.")
return False
return False

@classmethod
def inference_mode(cls):
"""A device-specific wrapper of `torch.inference_mode`.
This wrapper is recommended because some hardware backends such as TPU
do not support `torch.inference_mode`. In such a case, they will fall
back to `torch.no_grad` by overriding this method.
"""
return torch.no_grad()
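A short usage sketch of this hook; the call site below is illustrative, not part of this commit. Code that goes through the platform instead of calling torch.inference_mode() directly picks up torch.no_grad() on Spyre automatically:

import torch
from vllm.platforms import current_platform

model = torch.nn.Linear(8, 8)    # placeholder for the real model
input_ids = torch.randn(1, 8)    # placeholder input

# On Spyre this context is torch.no_grad(); on other platforms it is
# typically torch.inference_mode().
with current_platform.inference_mode():
    logits = model(input_ids)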
2 changes: 2 additions & 0 deletions vllm/worker/spyre_worker.py
@@ -255,10 +255,12 @@ def _warmup_model_forward_pass(self, warmup_tokens_tensor,
self.model_runner._update_mask()
self.model_runner._update_position_ids()

'''
if past_key_value_states is not None:
for layer in past_key_value_states:
for tensor in layer:
torch._dynamo.mark_dynamic(tensor, 2)
'''

logits, past_key_value_states = self.model_runner.\
_raw_model_forward(
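This disables the same mark_dynamic block as in vllm/model_executor/model_loader/spyre.py above, so the warmup forward pass stays consistent with the static-shape configuration; see the sketch after that hunk.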
