From d19d1f44038c6027891c67a50ff8b8e93d3b91ad Mon Sep 17 00:00:00 2001
From: jimpang
Date: Tue, 10 Sep 2024 10:38:40 +0800
Subject: [PATCH] feat: support multi-lora
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 vllm/engine/async_llm_engine.py | 149 +++++++++++++++++---------------
 1 file changed, 79 insertions(+), 70 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 8da02ddb29b4c..d67d46bc21ea2 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -82,13 +82,13 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
         self._finished = False
 
     def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
-                              Exception]) -> None:
+                               Exception]) -> None:
         if not self._finished:
             self._queue.put_nowait(item)
 
     def finish(
-        self,
-        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
+            self,
+            exception: Optional[Union[BaseException, Type[BaseException]]] = None,
     ) -> None:
         if not self._finished:
             self._finished = True
@@ -100,7 +100,7 @@ def finished(self) -> bool:
         return self._finished
 
     async def generator(
-        self
+            self
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
         try:
             while True:
@@ -117,8 +117,8 @@ async def generator(
     @staticmethod
     def _is_raisable(value: Any):
         return isinstance(value, BaseException) or \
-            (isinstance(value, type) and \
-             issubclass(value, BaseException))
+               (isinstance(value, type) and \
+                issubclass(value, BaseException))
 
 
 class RequestTracker:
@@ -128,7 +128,7 @@ def __init__(self) -> None:
         self._request_streams: Dict[str, AsyncStream] = {}
         self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
-                                                dict]] = asyncio.Queue()
+                                                 dict]] = asyncio.Queue()
         self.new_requests_event = asyncio.Event()
 
     def __contains__(self, item):
@@ -152,7 +152,7 @@ def propagate_exception(self,
 
     def process_request_output(self,
                                request_output: Union[RequestOutput,
-                                                     EmbeddingRequestOutput],
+                                                      EmbeddingRequestOutput],
                                *,
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
@@ -211,7 +211,7 @@ def abort_request(self,
                       request_id: str,
                       *,
                       exception: Optional[Union[BaseException,
-                                                Type[BaseException]]] = None,
+                                                 Type[BaseException]]] = None,
                       verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
@@ -262,7 +262,7 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     async def step_async(
-        self, virtual_engine: int
+            self, virtual_engine: int
     ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
         The workers are ran asynchronously if possible.
@@ -405,10 +405,10 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
         await self.model_executor.stop_remote_worker_execution_loop_async()
 
     async def _tokenize_prompt_async(
-        self,
-        prompt: str,
-        request_id: str,
-        lora_request: Optional[LoRARequest],
+            self,
+            prompt: str,
+            request_id: str,
+            lora_request: Optional[LoRARequest],
     ) -> List[int]:
         """Async version of :meth:`_tokenize_prompt`."""
         tokenizer = self.get_tokenizer_group(
@@ -419,10 +419,10 @@ async def _tokenize_prompt_async(
                                             lora_request=lora_request)
 
     async def _extract_prompt_components_async(
-        self,
-        inputs: SingletonPromptInputs,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
+            self,
+            inputs: SingletonPromptInputs,
+            request_id: str,
+            lora_request: Optional[LoRARequest] = None,
     ) -> PromptComponents:
         """Async version of :meth:`_extract_prompt_components`."""
         if isinstance(inputs, str):
@@ -453,9 +453,9 @@ async def _extract_prompt_components_async(
         return prompt, prompt_token_ids, multi_modal_data
 
     async def _process_encoder_decoder_prompt_async(
-        self,
-        inputs: PromptInputs,
-        request_id: str,
+            self,
+            inputs: PromptInputs,
+            request_id: str,
     ) -> EncoderDecoderLLMInputs:
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
         encoder_comps: PromptComponents
@@ -489,11 +489,11 @@ async def _process_encoder_decoder_prompt_async(
         return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
 
     async def _process_decoder_only_prompt_async(
-        self,
-        inputs: SingletonPromptInputs,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+            self,
+            inputs: SingletonPromptInputs,
+            request_id: str,
+            lora_request: Optional[LoRARequest] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
         """Async version of :meth:`_process_decoder_only_prompt`."""
         prompt_comps = await self._extract_prompt_components_async(
@@ -508,11 +508,11 @@ async def _process_decoder_only_prompt_async(
         )
 
     async def process_model_inputs_async(
-        self,
-        inputs: PromptInputs,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+            self,
+            inputs: PromptInputs,
+            request_id: str,
+            lora_request: Optional[LoRARequest] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
         """Async version of :meth:`process_model_inputs`."""
         if self.is_encoder_decoder_model():
@@ -538,14 +538,14 @@ async def process_model_inputs_async(
         return self.input_processor(model_inputs)
 
     async def add_request_async(
-        self,
-        request_id: str,
-        inputs: PromptInputs,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+            self,
+            request_id: str,
+            inputs: PromptInputs,
+            params: Union[SamplingParams, PoolingParams],
+            arrival_time: Optional[float] = None,
+            lora_request: Optional[LoRARequest] = None,
+            trace_headers: Optional[Mapping[str, str]] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         """Async version of :meth:`add_request`."""
         if lora_request is not None and not self.lora_config:
@@ -718,11 +718,11 @@ def _get_executor_cls(
 
     @classmethod
     def from_engine_args(
-        cls,
-        engine_args: AsyncEngineArgs,
-        start_engine_loop: bool = True,
-        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+            cls,
+            engine_args: AsyncEngineArgs,
+            start_engine_loop: bool = True,
+            usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+            stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
@@ -777,8 +777,8 @@ def _error_callback(self, exc: Exception) -> None:
         self._request_tracker.propagate_exception(exc)
 
     async def get_tokenizer(
-        self,
-        lora_request: Optional[LoRARequest] = None,
+            self,
+            lora_request: Optional[LoRARequest] = None,
     ) -> AnyTokenizer:
         if self.engine_use_ray:
             return await self.engine.get_tokenizer.remote(  # type: ignore
@@ -944,9 +944,9 @@ async def run_engine_loop(self):
                 if self.engine_use_ray:
                     has_unfinished_requests = (
                         await (self.engine.
-                               has_unfinished_requests_for_virtual_engine.
-                               remote(  # type: ignore
-                                   virtual_engine)))
+                                has_unfinished_requests_for_virtual_engine.
+                                remote(  # type: ignore
+                                    virtual_engine)))
                 else:
                     has_unfinished_requests = (
                         self.engine.
@@ -969,14 +969,14 @@ async def run_engine_loop(self):
     # This method does not need to be async, but kept that way
     # for backwards compatibility.
     async def add_request(
-        self,
-        request_id: str,
-        inputs: PromptInputs,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+            self,
+            request_id: str,
+            inputs: PromptInputs,
+            params: Union[SamplingParams, PoolingParams],
+            arrival_time: Optional[float] = None,
+            lora_request: Optional[LoRARequest] = None,
+            trace_headers: Optional[Mapping[str, str]] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
         if not self.is_running:
             if self.start_engine_loop:
@@ -1001,13 +1001,13 @@ async def add_request(
         return stream.generator()
 
     async def generate(
-        self,
-        inputs: PromptInputs,
-        sampling_params: SamplingParams,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+            self,
+            inputs: PromptInputs,
+            sampling_params: SamplingParams,
+            request_id: str,
+            lora_request: Optional[LoRARequest] = None,
+            trace_headers: Optional[Mapping[str, str]] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.
 
@@ -1073,6 +1073,15 @@ async def generate(
             >>> # Process and return the final output
             >>> ...
""" + # jimpang: process lora id + if lora_request: + if lora_request.lora_name in self.lora_names_map: + lora_request.lora_int_id = self.lora_names_map[lora_request.lora_name] + else: + self.last_lora_id = self.last_lora_id + 1 + lora_request.lora_int_id = self.last_lora_id + self.lora_names_map[lora_request.lora_name] = lora_request.lora_int_id + async for output in await self.add_request( request_id, inputs, @@ -1084,12 +1093,12 @@ async def generate( yield LLMEngine.validate_output(output, RequestOutput) async def encode( - self, - inputs: PromptInputs, - pooling_params: PoolingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, + self, + inputs: PromptInputs, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> AsyncGenerator[EmbeddingRequestOutput, None]: """Generate outputs for a request from an embedding model.