Skip to content

Commit

Permalink
[Bugfix] update neuron for version > 0.5.0 (#7175)
Browse files Browse the repository at this point in the history
Signed-off-by: omrishiv <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
  • Loading branch information
omrishiv and DarkLight1337 authored Aug 15, 2024
1 parent fc93e56 commit 9c1f78d
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32],
choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
help='Token block size for contiguous chunks of '
'tokens.')

Expand Down
5 changes: 2 additions & 3 deletions vllm/executor/neuron_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,8 @@ async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(
self.driver_worker.execute_model
)(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output

async def check_health_async(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions vllm/worker/neuron_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def prepare_model_input(
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForNeuron:
multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
Expand Down
3 changes: 3 additions & 0 deletions vllm/worker/neuron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def prepare_worker_input(
return WorkerInput(num_seq_groups=len(
execute_model_req.seq_group_metadata_list), )

def execute_worker(self, worker_input: WorkerInput) -> None:
    """Intentionally a no-op: this worker has no per-step work of its own.

    NOTE(review): presumably implements a required hook from the worker
    base interface (cache ops / per-step bookkeeping on other backends) —
    confirm against the base class contract. ``worker_input`` is accepted
    for signature compatibility and ignored.
    """
    pass

def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block.
Expand Down

0 comments on commit 9c1f78d

Please sign in to comment.