Skip to content

Commit

Permalink
neuron
Browse files (browse the repository at this point in the history)
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao committed Nov 1, 2024
1 parent d076d08 commit 14b8af4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 26 deletions.
16 changes: 4 additions & 12 deletions vllm/worker/neuron_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from torch import nn
from transformers_neuronx.config import GenerationConfig

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
Expand Down Expand Up @@ -57,20 +56,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
vllm_config: ModelConfig,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config

ModelRunnerBase.__init__(self, vllm_config)
model_config = self.model_config
if model_config is not None and model_config.get_sliding_window():
logger.warning("Sliding window is not supported on Neuron. "
"The model will run without sliding window.")
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()

Expand Down
20 changes: 6 additions & 14 deletions vllm/worker/neuron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import EngineConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)


class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
Expand All @@ -21,20 +21,12 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
vllm_config: EngineConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
Expand All @@ -44,7 +36,7 @@ def __init__(
init_cached_hf_modules()

self.model_runner: NeuronModelRunner = NeuronModelRunner(
model_config, parallel_config, scheduler_config, device_config)
vllm_config=vllm_config, )
self.is_driver_worker = True

def init_device(self) -> None:
Expand Down

0 comments on commit 14b8af4

Please sign in to comment.