Skip to content

Commit

Permalink
define the parameters clearly
Browse files Browse the repository at this point in the history
  • Loading branch information
lvhan028 committed Dec 5, 2024
1 parent bbcf9a5 commit b3a2887
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions lmdeploy/serve/vl_async_engine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Union

import PIL

from lmdeploy.messages import (PytorchEngineConfig, TurbomindEngineConfig,
VisionConfig)
from lmdeploy.pytorch.check_env import try_import_deeplink
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.utils import get_logger
Expand All @@ -19,17 +21,23 @@
class VLAsyncEngine(AsyncEngine):
"""Visual Language Async inference engine."""

def __init__(self, model_path: str, **kwargs) -> None:
vision_config = kwargs.pop('vision_config', None)
backend_config = kwargs.get('backend_config', None)
self.backend = kwargs['backend']
if self.backend == 'pytorch':
def __init__(self,
model_path: str,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
vision_config: Optional[VisionConfig] = None,
**kwargs) -> None:
if backend == 'pytorch':
try_import_deeplink(backend_config.device_type)
self.vl_encoder = ImageEncoder(model_path,
self.backend,
backend,
vision_config,
backend_config=backend_config)
super().__init__(model_path, **kwargs)
super().__init__(model_path,
backend=backend,
backend_config=backend_config,
**kwargs)
if self.model_name == 'base':
raise RuntimeError(
'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template' # noqa: E501
Expand Down

0 comments on commit b3a2887

Please sign in to comment.