Skip to content

Commit

Permalink
add other devices supported model list
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 committed Oct 17, 2024
1 parent bffbdb1 commit 4430b50
Show file tree
Hide file tree
Showing 8 changed files with 293 additions and 23 deletions.
2 changes: 1 addition & 1 deletion paddlex/inference/utils/new_ir_blacklist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

NEWIR_BLOCKLIST = [
NEWIR_BLACKLIST = [
"FasterRCNN-ResNet34-FPN",
"FasterRCNN-ResNet50",
"FasterRCNN-ResNet50-FPN",
Expand Down
18 changes: 9 additions & 9 deletions paddlex/inference/utils/pp_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils.device import parse_device, set_env_for_device, get_default_device
from ...utils.device import (
parse_device,
set_env_for_device,
get_default_device,
check_device,
)
from ...utils import logging
from .new_ir_blacklist import NEWIR_BLOCKLIST
from .new_ir_blacklist import NEWIR_BLACKLIST


class PaddlePredictorOption(object):
Expand All @@ -28,7 +33,6 @@ class PaddlePredictorOption(object):
"mkldnn",
"mkldnn_bf16",
)
SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu")

def __init__(self, model_name=None, **kwargs):
super().__init__()
Expand Down Expand Up @@ -60,7 +64,7 @@ def _get_default_config(self):
"cpu_threads": 1,
"trt_use_static": False,
"delete_pass": [],
"enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
"enable_new_ir": True if self.model_name not in NEWIR_BLACKLIST else False,
}

@property
Expand Down Expand Up @@ -95,11 +99,7 @@ def device(self, device: str):
if not device:
return
device_type, device_ids = parse_device(device)
if device_type not in self.SUPPORT_DEVICE:
support_run_mode_str = ", ".join(self.SUPPORT_DEVICE)
raise ValueError(
f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
)
check_device(self.model_name, device_type)
self._cfg["device"] = device_type
device_id = device_ids[0] if device_ids is not None else 0
self._cfg["device_id"] = device_id
Expand Down
11 changes: 9 additions & 2 deletions paddlex/modules/base/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from abc import ABC, abstractmethod

from .build_model import build_model
from ...utils.device import update_device_num, set_env_for_device
from ...utils.device import (
update_device_num,
set_env_for_device,
parse_device,
check_device,
)
from ...utils.misc import AutoRegisterABCMetaClass
from ...utils.config import AttrDict
from ...utils.logging import *
Expand Down Expand Up @@ -138,8 +143,10 @@ def get_device(self, using_device_number: int = None) -> str:
Returns:
str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
"""
device_type, device_ids = parse_device(self.global_config.device)
check_device(self.global_config.model, device_type)
if using_device_number:
return update_device_num(self.global_config.device, using_device_number)
return update_device_num(device_type, device_ids, using_device_number)
set_env_for_device(self.global_config.device)
return self.global_config.device

Expand Down
11 changes: 9 additions & 2 deletions paddlex/modules/base/exportor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from abc import ABC, abstractmethod

from .build_model import build_model
from ...utils.device import update_device_num, set_env_for_device
from ...utils.device import (
update_device_num,
set_env_for_device,
parse_device,
check_device,
)
from ...utils.misc import AutoRegisterABCMetaClass
from ...utils.config import AttrDict
from ...utils import logging
Expand Down Expand Up @@ -103,8 +108,10 @@ def get_device(self, using_device_number: int = None) -> str:
Returns:
str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
"""
device_type, device_ids = parse_device(self.global_config.device)
check_device(self.global_config.model, device_type)
if using_device_number:
return update_device_num(self.global_config.device, using_device_number)
return update_device_num(device_type, device_ids, using_device_number)
set_env_for_device(self.global_config.device)
return self.global_config.device

Expand Down
11 changes: 9 additions & 2 deletions paddlex/modules/base/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from abc import ABC, abstractmethod
from pathlib import Path
from .build_model import build_model
from ...utils.device import update_device_num, set_env_for_device
from ...utils.device import (
update_device_num,
set_env_for_device,
parse_device,
check_device,
)
from ...utils.misc import AutoRegisterABCMetaClass
from ...utils.config import AttrDict

Expand Down Expand Up @@ -95,8 +100,10 @@ def get_device(self, using_device_number: int = None) -> str:
Returns:
str: device setting, such as: `gpu:0,1`, `npu:0,1` `cpu`.
"""
device_type, device_ids = parse_device(self.global_config.device)
check_device(self.global_config.model, device_type)
if using_device_number:
return update_device_num(self.global_config.device, using_device_number)
return update_device_num(device_type, device_ids, using_device_number)
set_env_for_device(self.global_config.device)
return self.global_config.device

Expand Down
15 changes: 15 additions & 0 deletions paddlex/paddlex_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
import argparse
import subprocess
import sys
import shutil
import tempfile
from pathlib import Path

from . import create_pipeline
from .inference.pipelines import create_pipeline_from_config, load_pipeline_config
from .repo_manager import setup, get_all_supported_repo_names
from .utils.cache import CACHE_DIR
from .utils import logging
from .utils.interactive_get_pipeline import interactive_get_pipeline

Expand Down Expand Up @@ -65,6 +68,7 @@ def parse_str(s):

################# install pdx #################
parser.add_argument("--install", action="store_true", default=False, help="")
parser.add_argument("--clear_cache", action="store_true", default=False, help="")
parser.add_argument("plugins", nargs="*", default=[])
parser.add_argument("--no_deps", action="store_true")
parser.add_argument("--platform", type=str, default="github.com")
Expand Down Expand Up @@ -159,6 +163,15 @@ def serve(pipeline, *, device, use_hpip, serial_number, update_license, host, po
run_server(app, host=host, port=port, debug=False)


def clear_cache():
cache_dir = Path(CACHE_DIR) / "official_models"
if cache_dir.exists() and cache_dir.is_dir():
shutil.rmtree(cache_dir)
logging.info(f"Successfully cleared the cache models at {cache_dir}")
else:
logging.info(f"No cache models found at {cache_dir}")


# for CLI
def main():
"""API for commad line"""
Expand All @@ -180,6 +193,8 @@ def main():
host=args.host,
port=args.port,
)
elif args.clear_cache:
clear_cache()
else:
if args.get_pipeline_config is not None:
interactive_get_pipeline(args.get_pipeline_config, args.save_path)
Expand Down
24 changes: 17 additions & 7 deletions paddlex/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@

from . import logging
from .errors import raise_unsupported_device_error

SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu"]
from .other_devices_model_list import OTHER_DEVICES_MODEL_LIST


def _constr_device(device_type, device_ids):
Expand All @@ -38,6 +37,21 @@ def get_default_device():
return _constr_device("gpu", [avail_gpus[0]])


def check_device(model_name, device_type):
supported_device_type = ["cpu", "gpu", "xpu", "npu", "mlu", "dcu"]
device_type = device_type.lower()
if device_type not in supported_device_type:
support_run_mode_str = ", ".join(supported_device_type)
raise ValueError(
f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
)
if device_type in OTHER_DEVICES_MODEL_LIST:
if model_name not in OTHER_DEVICES_MODEL_LIST[device_type]:
raise ValueError(
f"The model '{model_name}' is not supported on {device_type}."
)


def parse_device(device):
"""parse_device"""
# According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html
Expand All @@ -55,14 +69,10 @@ def parse_device(device):
f"Device ID must be an integer. Invalid device ID: {device_id}"
)
device_ids = list(map(int, device_ids))
device_type = device_type.lower()
# raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE)
assert device_type.lower() in SUPPORTED_DEVICE_TYPE
return device_type, device_ids


def update_device_num(device, num):
device_type, device_ids = parse_device(device)
def update_device_num(device_type, device_ids, num):
if device_ids:
assert len(device_ids) >= num
return _constr_device(device_type, device_ids[:num])
Expand Down
Loading

0 comments on commit 4430b50

Please sign in to comment.