Commit

feat: support multi-lora
jimpang committed Sep 10, 2024
1 parent 623aa6b commit d19d1f4
Showing 1 changed file with 79 additions and 70 deletions.
149 changes: 79 additions & 70 deletions vllm/engine/async_llm_engine.py
@@ -82,13 +82,13 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self._finished = False

def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
Exception]) -> None:
Exception]) -> None:
if not self._finished:
self._queue.put_nowait(item)

def finish(
self,
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
self,
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
) -> None:
if not self._finished:
self._finished = True
@@ -100,7 +100,7 @@ def finished(self) -> bool:
return self._finished

async def generator(
self
self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
try:
while True:
@@ -117,8 +117,8 @@ async def generator(
@staticmethod
def _is_raisable(value: Any):
return isinstance(value, BaseException) or \
(isinstance(value, type) and \
issubclass(value, BaseException))
(isinstance(value, type) and \
issubclass(value, BaseException))


class RequestTracker:
@@ -128,7 +128,7 @@ def __init__(self) -> None:
self._request_streams: Dict[str, AsyncStream] = {}
self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
dict]] = asyncio.Queue()
self.new_requests_event = asyncio.Event()

def __contains__(self, item):
@@ -152,7 +152,7 @@ def propagate_exception(self,

def process_request_output(self,
request_output: Union[RequestOutput,
EmbeddingRequestOutput],
EmbeddingRequestOutput],
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
@@ -211,7 +211,7 @@ def abort_request(self,
request_id: str,
*,
exception: Optional[Union[BaseException,
Type[BaseException]]] = None,
Type[BaseException]]] = None,
verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
@@ -262,7 +262,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def step_async(
self, virtual_engine: int
self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
The workers are run asynchronously if possible.
@@ -405,10 +405,10 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
await self.model_executor.stop_remote_worker_execution_loop_async()

async def _tokenize_prompt_async(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group(
@@ -419,10 +419,10 @@ async def _tokenize_prompt_async(
lora_request=lora_request)

async def _extract_prompt_components_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
if isinstance(inputs, str):
@@ -453,9 +453,9 @@ async def _extract_prompt_components_async(
return prompt, prompt_token_ids, multi_modal_data

async def _process_encoder_decoder_prompt_async(
self,
inputs: PromptInputs,
request_id: str,
self,
inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
@@ -489,11 +489,11 @@ async def _process_encoder_decoder_prompt_async(
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
@@ -508,11 +508,11 @@ async def _process_decoder_only_prompt_async(
)

async def process_model_inputs_async(
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
"""Async version of :meth:`process_model_inputs`."""
if self.is_encoder_decoder_model():
@@ -538,14 +538,14 @@ async def process_model_inputs_async(
return self.input_processor(model_inputs)

async def add_request_async(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
"""Async version of :meth:`add_request`."""
if lora_request is not None and not self.lora_config:
@@ -718,11 +718,11 @@ def _get_executor_cls(

@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
@@ -777,8 +777,8 @@ def _error_callback(self, exc: Exception) -> None:
self._request_tracker.propagate_exception(exc)

async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote( # type: ignore
@@ -944,9 +944,9 @@ async def run_engine_loop(self):
if self.engine_use_ray:
has_unfinished_requests = (
await (self.engine.
has_unfinished_requests_for_virtual_engine.
remote( # type: ignore
virtual_engine)))
has_unfinished_requests_for_virtual_engine.
remote( # type: ignore
virtual_engine)))
else:
has_unfinished_requests = (
self.engine.
@@ -969,14 +969,14 @@ async def add_request(
# This method does not need to be async, but kept that way
# for backwards compatibility.
async def add_request(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
if not self.is_running:
if self.start_engine_loop:
@@ -1001,13 +1001,13 @@ async def add_request(
return stream.generator()

async def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
@@ -1073,6 +1073,15 @@ async def generate(
>>> # Process and return the final output
>>> ...
"""
# jimpang: map lora_name to a stable lora_int_id (multi-lora support)
if lora_request:
if lora_request.lora_name in self.lora_names_map:
lora_request.lora_int_id = self.lora_names_map[lora_request.lora_name]
else:
self.last_lora_id = self.last_lora_id + 1
lora_request.lora_int_id = self.last_lora_id
self.lora_names_map[lora_request.lora_name] = lora_request.lora_int_id

async for output in await self.add_request(
request_id,
inputs,
@@ -1084,12 +1084,12 @@ async def generate(
yield LLMEngine.validate_output(output, RequestOutput)

async def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.

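The substantive change in this commit is the block added to generate(): LoRA adapters are keyed by lora_request.lora_name, each previously unseen name gets the next integer id, and the same lora_int_id is reused whenever that name reappears. Below is a minimal usage sketch, assuming this fork initializes self.lora_names_map = {} and self.last_lora_id = 0 on the engine (that initialization is not part of this hunk); the model name, adapter paths, and request ids are placeholders.

import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.lora.request import LoRARequest


async def main() -> None:
    # enable_lora must be on for the engine to accept LoRARequests.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True))
    sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

    # Adapters are identified by name; the placeholder int id (1) is
    # overwritten by the name -> id mapping added to generate().
    sql_lora = LoRARequest("sql-adapter", 1, "/path/to/sql_adapter")
    chat_lora = LoRARequest("chat-adapter", 1, "/path/to/chat_adapter")

    for i, lora in enumerate([sql_lora, chat_lora, sql_lora]):
        final = None
        async for output in engine.generate(
                "Write a one-line greeting.",
                sampling_params,
                request_id=f"req-{i}",
                lora_request=lora):
            final = output
        # The first and third requests share "sql-adapter" and therefore
        # end up with the same lora_int_id.
        if final is not None:
            print(lora.lora_name, lora.lora_int_id, final.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())

Because ids are handed out per name on the serving side, clients only need to keep adapter names stable across requests; they never have to track or coordinate integer ids themselves.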