diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py index 2b945ba39..e9b6852d5 100644 --- a/lmdeploy/archs.py +++ b/lmdeploy/archs.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os from typing import Literal, Optional, Union from lmdeploy.serve.async_engine import AsyncEngine @@ -100,6 +101,9 @@ def check_vl_llm(config: dict) -> bool: def get_task(model_path: str): """get pipeline type and pipeline class from model config.""" + if os.path.exists(os.path.join(model_path, 'triton_models', 'weights')): + # workspace model + return 'llm', AsyncEngine config = get_hf_config_content(model_path) if check_vl_llm(config): return 'vlm', VLAsyncEngine