From de0c4acb96bcc75d10dcf151e452694b851e1224 Mon Sep 17 00:00:00 2001 From: zhangjiajin Date: Wed, 11 Aug 2021 17:39:15 +0800 Subject: [PATCH] release 1.6.0 --- README.cn.md | 11 +- README.md | 9 +- RELEASE.md | 2 +- docs/cn/developer/developer_guide.md | 130 --------- docs/cn/user/config_reference.md | 19 +- docs/cn/user/evaluate_service.md | 4 +- docs/en/developer/developer_guide.md | 130 --------- docs/en/user/config_reference.md | 17 +- docs/en/user/evaluate_service.md | 4 +- evaluate_service/hardwares/davinci/davinci.py | 6 +- .../hardwares/davinci/model_convert.sh | 14 +- evaluate_service/main.py | 2 +- examples/classification/classification.yml | 2 +- examples/data_augmentation/pba/pba.yml | 2 +- examples/features/quota/quota.yml | 20 +- examples/features/script_runner/bohb.yml | 1 - examples/hpo/bohb/bohb.yml | 2 +- examples/hpo/boss/boss.yml | 2 +- examples/hpo/pbt/pbt.yml | 2 +- examples/nas/dnet_nas/dnet_nas.yml | 5 +- setup.py | 3 +- vega/__init__.py | 8 +- .../compression/prune_ea/prune_ea.py | 5 +- .../prune_ea/prune_trainer_callback.py | 6 +- .../quant_ea/quant_trainer_callback.py | 4 +- vega/algorithms/hpo/bohb_conf.py | 2 +- vega/algorithms/hpo/evolution_search.py | 2 +- vega/algorithms/hpo/sha_base/bohb.py | 2 +- .../hpo/sha_base/tuner/tuner_builder.py | 2 +- vega/algorithms/nas/__init__.py | 3 +- .../adelaide_ea/adelaide_trainer_callback.py | 2 +- .../nas/cars/cars_trainer_callback.py | 4 +- .../nas/fis/ctr_trainer_callback.py | 2 +- vega/algorithms/nas/mfasc/conf.py | 1 + vega/algorithms/nas/opt_nas/__init__.py | 1 + vega/algorithms/nas/opt_nas/ops_nas.py | 83 ++++++ vega/common/__init__.py | 1 + .../__pycache__/__init__.cpython-37.pyc | Bin 949 -> 0 bytes .../__pycache__/arg_parser.cpython-37.pyc | Bin 1385 -> 0 bytes .../backend_register.cpython-37.pyc | Bin 2857 -> 0 bytes vega/common/__pycache__/check.cpython-37.pyc | Bin 2648 -> 0 bytes .../__pycache__/class_factory.cpython-37.pyc | Bin 6687 -> 0 bytes vega/common/__pycache__/config.cpython-37.pyc | Bin 4942 -> 0 bytes .../config_serializable.cpython-37.pyc | Bin 5199 -> 0 bytes vega/common/__pycache__/consts.cpython-37.pyc | Bin 870 -> 0 bytes .../__pycache__/file_ops.cpython-37.pyc | Bin 6990 -> 0 bytes .../common/__pycache__/general.cpython-37.pyc | Bin 3323 -> 0 bytes .../__pycache__/json_coder.cpython-37.pyc | Bin 927 -> 0 bytes .../__pycache__/message_client.cpython-37.pyc | Bin 2026 -> 0 bytes .../__pycache__/message_server.cpython-37.pyc | Bin 2696 -> 0 bytes .../__pycache__/pareto_front.cpython-37.pyc | Bin 723 -> 0 bytes .../__pycache__/task_ops.cpython-37.pyc | Bin 7779 -> 0 bytes .../__pycache__/user_config.cpython-37.pyc | Bin 2358 -> 0 bytes vega/common/__pycache__/utils.cpython-37.pyc | Bin 5786 -> 0 bytes vega/common/backend_register.py | 24 +- vega/common/class_factory.py | 50 +++- vega/common/general.py | 16 +- vega/common/searchable.py | 138 +++++++++ vega/core/__pycache__/__init__.cpython-37.pyc | Bin 314 -> 0 bytes vega/core/__pycache__/run.cpython-37.pyc | Bin 2826 -> 0 bytes .../__pycache__/__init__.cpython-37.pyc | Bin 512 -> 0 bytes .../pipeline/__pycache__/conf.cpython-37.pyc | Bin 3569 -> 0 bytes .../__pycache__/pipe_step.cpython-37.pyc | Bin 2103 -> 0 bytes .../__pycache__/pipeline.cpython-37.pyc | Bin 4676 -> 0 bytes vega/core/pipeline/generator.py | 51 +++- vega/core/pipeline/horovod/horovod_train.py | 2 +- .../pipeline/horovod/run_horovod_train.sh | 3 +- vega/core/pipeline/train_pipe_step.py | 16 +- vega/core/quota/__init__.py | 1 - .../quota/__pycache__/__init__.cpython-37.pyc | Bin 187 -> 0 bytes .../__pycache__/quota_strategy.cpython-37.pyc | Bin 5197 -> 0 bytes vega/core/quota/quota_strategy.py | 141 ---------- vega/core/run.py | 7 +- .../__pycache__/__init__.cpython-37.pyc | Bin 214 -> 0 bytes .../__pycache__/master_ops.cpython-37.pyc | Bin 1247 -> 0 bytes vega/core/scheduler/dask_env.py | 9 + vega/core/scheduler/distribution.py | 9 +- vega/core/scheduler/worker_env.py | 1 + vega/core/search_space/ext_hyper_parameter.py | 12 +- .../__pycache__/__init__.cpython-37.pyc | Bin 1176 -> 0 bytes vega/datasets/common/dataset.py | 10 +- vega/datasets/common/imagenet.py | 13 +- .../conf/__pycache__/__init__.cpython-37.pyc | Bin 138 -> 0 bytes .../conf/__pycache__/dataset.cpython-37.pyc | Bin 788 -> 0 bytes .../__pycache__/__init__.cpython-37.pyc | Bin 412 -> 0 bytes .../evaluator/__pycache__/conf.cpython-37.pyc | Bin 1798 -> 0 bytes vega/evaluator/conf.py | 1 + vega/evaluator/device_evaluator.py | 4 +- vega/evaluator/evaluator.py | 7 +- vega/evaluator/host_evaluator.py | 78 ++++-- vega/evaluator/tools/evaluate_davinci_bolt.py | 14 +- vega/evaluator/tools/pytorch2onnx.py | 23 +- vega/metrics/__init__.py | 2 +- vega/metrics/flops_and_params.py | 25 +- vega/metrics/forward_latency.py | 7 +- vega/metrics/pytorch/classifier_metric.py | 2 +- vega/model_zoo/compressed_model_filter.py | 63 ----- vega/model_zoo/model_zoo.py | 44 ++- vega/modules/arch/architecture.py | 20 +- vega/modules/arch/combiner.py | 15 +- vega/modules/arch/double_channels_arch.py | 19 +- vega/modules/arch/prune_arch.py | 142 ++++++++-- vega/modules/connections/connections.py | 62 ++++- vega/modules/loss/loss.py | 8 +- vega/modules/operators/conv.py | 36 ++- .../modules/operators/functions/pytorch_fn.py | 262 ++++++++++++++++-- .../operators/functions/tensorflow_fn.py | 5 + vega/modules/operators/ops.py | 9 + vega/networks/__init__.py | 1 + .../__pycache__/__init__.cpython-37.pyc | Bin 1467 -> 0 bytes .../__pycache__/model_config.cpython-37.pyc | Bin 1417 -> 0 bytes .../__pycache__/network_desc.cpython-37.pyc | Bin 1168 -> 0 bytes vega/networks/faster_rcnn.py | 2 +- vega/networks/unet.py | 95 +++++++ vega/quota/__init__.py | 13 +- vega/quota/duration_terminate.py | 36 --- vega/quota/flops_params.py | 44 +++ vega/quota/flops_params_filter.py | 53 ---- vega/quota/latency.py | 38 +++ vega/quota/latency_filter.py | 45 --- vega/quota/model_valid.py | 30 ++ vega/quota/quota.py | 121 ++++++++ vega/{core => }/quota/quota_affinity.py | 1 + vega/quota/quota_compare.py | 90 ------ ...r_terminate_base.py => quota_item_base.py} | 37 +-- vega/quota/target_terminate.py | 33 --- vega/quota/trial_terminate.py | 34 --- vega/quota/valid_filter.py | 40 --- .../__pycache__/__init__.cpython-37.pyc | Bin 360 -> 0 bytes .../__pycache__/nsga_iii.cpython-37.pyc | Bin 5623 -> 0 bytes vega/report/__pycache__/record.cpython-37.pyc | Bin 7149 -> 0 bytes .../__pycache__/report_client.cpython-37.pyc | Bin 3646 -> 0 bytes .../report_persistence.cpython-37.pyc | Bin 2554 -> 0 bytes .../__pycache__/report_server.cpython-37.pyc | Bin 11055 -> 0 bytes vega/report/record.py | 6 +- vega/report/report_persistence.py | 14 +- vega/report/report_server.py | 72 +++-- .../tools/__pycache__/__init__.cpython-37.pyc | Bin 130 -> 0 bytes .../__pycache__/query_process.cpython-37.pyc | Bin 4536 -> 0 bytes .../__pycache__/verify_cluster.cpython-37.pyc | Bin 7903 -> 0 bytes vega/tools/query_process.py | 36 ++- vega/tools/query_progress.py | 50 ++-- vega/tools/run_pipeline.py | 2 +- vega/tools/run_slave.py | 40 +++ .../__pycache__/__init__.cpython-37.pyc | Bin 452 -> 0 bytes vega/trainer/__pycache__/conf.cpython-37.pyc | Bin 4421 -> 0 bytes .../__pycache__/task_conf.cpython-37.pyc | Bin 970 -> 0 bytes .../__pycache__/trial_agent.cpython-37.pyc | Bin 2023 -> 0 bytes vega/trainer/callbacks/__init__.py | 5 + vega/trainer/callbacks/callback_list.py | 3 +- vega/trainer/callbacks/callbacks.md | 21 ++ vega/trainer/callbacks/data_parallel.py | 50 ++++ vega/trainer/callbacks/model_builder.py | 14 +- vega/trainer/callbacks/model_checkpoint.py | 3 + vega/trainer/callbacks/progress_logger.py | 29 +- vega/trainer/callbacks/report_callback.py | 5 +- .../callbacks/timm_trainer_callback.py | 6 +- vega/trainer/callbacks/visual_callback.py | 2 +- vega/trainer/conf.py | 3 +- .../__pycache__/__init__.cpython-37.pyc | Bin 183 -> 0 bytes .../config_bakcend_map.cpython-37.pyc | Bin 1866 -> 0 bytes .../conf/__pycache__/__init__.cpython-37.pyc | Bin 145 -> 0 bytes .../conf/__pycache__/loss.cpython-37.pyc | Bin 1737 -> 0 bytes .../__pycache__/lr_scheduler.cpython-37.pyc | Bin 1398 -> 0 bytes .../conf/__pycache__/optim.cpython-37.pyc | Bin 1668 -> 0 bytes .../lr_schedulers/warmup_scheduler_torch.py | 17 +- vega/trainer/modules/optimizer/optim.py | 6 + vega/trainer/script_runner.py | 17 +- vega/trainer/trainer_base.py | 22 +- vega/trainer/trainer_torch.py | 27 +- 170 files changed, 1747 insertions(+), 1282 deletions(-) create mode 100644 vega/algorithms/nas/opt_nas/__init__.py create mode 100644 vega/algorithms/nas/opt_nas/ops_nas.py delete mode 100644 vega/common/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/common/__pycache__/arg_parser.cpython-37.pyc delete mode 100644 vega/common/__pycache__/backend_register.cpython-37.pyc delete mode 100644 vega/common/__pycache__/check.cpython-37.pyc delete mode 100644 vega/common/__pycache__/class_factory.cpython-37.pyc delete mode 100644 vega/common/__pycache__/config.cpython-37.pyc delete mode 100644 vega/common/__pycache__/config_serializable.cpython-37.pyc delete mode 100644 vega/common/__pycache__/consts.cpython-37.pyc delete mode 100644 vega/common/__pycache__/file_ops.cpython-37.pyc delete mode 100644 vega/common/__pycache__/general.cpython-37.pyc delete mode 100644 vega/common/__pycache__/json_coder.cpython-37.pyc delete mode 100644 vega/common/__pycache__/message_client.cpython-37.pyc delete mode 100644 vega/common/__pycache__/message_server.cpython-37.pyc delete mode 100644 vega/common/__pycache__/pareto_front.cpython-37.pyc delete mode 100644 vega/common/__pycache__/task_ops.cpython-37.pyc delete mode 100644 vega/common/__pycache__/user_config.cpython-37.pyc delete mode 100644 vega/common/__pycache__/utils.cpython-37.pyc create mode 100644 vega/common/searchable.py delete mode 100644 vega/core/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/core/__pycache__/run.cpython-37.pyc delete mode 100644 vega/core/pipeline/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/core/pipeline/__pycache__/conf.cpython-37.pyc delete mode 100644 vega/core/pipeline/__pycache__/pipe_step.cpython-37.pyc delete mode 100644 vega/core/pipeline/__pycache__/pipeline.cpython-37.pyc delete mode 100644 vega/core/quota/__init__.py delete mode 100644 vega/core/quota/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/core/quota/__pycache__/quota_strategy.cpython-37.pyc delete mode 100644 vega/core/quota/quota_strategy.py delete mode 100644 vega/core/scheduler/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/core/scheduler/__pycache__/master_ops.cpython-37.pyc delete mode 100644 vega/datasets/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/datasets/conf/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/datasets/conf/__pycache__/dataset.cpython-37.pyc delete mode 100644 vega/evaluator/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/evaluator/__pycache__/conf.cpython-37.pyc delete mode 100644 vega/model_zoo/compressed_model_filter.py delete mode 100644 vega/networks/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/networks/__pycache__/model_config.cpython-37.pyc delete mode 100644 vega/networks/__pycache__/network_desc.cpython-37.pyc create mode 100644 vega/networks/unet.py delete mode 100644 vega/quota/duration_terminate.py create mode 100644 vega/quota/flops_params.py delete mode 100644 vega/quota/flops_params_filter.py create mode 100644 vega/quota/latency.py delete mode 100644 vega/quota/latency_filter.py create mode 100644 vega/quota/model_valid.py create mode 100644 vega/quota/quota.py rename vega/{core => }/quota/quota_affinity.py (99%) delete mode 100644 vega/quota/quota_compare.py rename vega/quota/{filter_terminate_base.py => quota_item_base.py} (51%) delete mode 100644 vega/quota/target_terminate.py delete mode 100644 vega/quota/trial_terminate.py delete mode 100644 vega/quota/valid_filter.py delete mode 100644 vega/report/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/report/__pycache__/nsga_iii.cpython-37.pyc delete mode 100644 vega/report/__pycache__/record.cpython-37.pyc delete mode 100644 vega/report/__pycache__/report_client.cpython-37.pyc delete mode 100644 vega/report/__pycache__/report_persistence.cpython-37.pyc delete mode 100644 vega/report/__pycache__/report_server.cpython-37.pyc delete mode 100644 vega/tools/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/tools/__pycache__/query_process.cpython-37.pyc delete mode 100644 vega/tools/__pycache__/verify_cluster.cpython-37.pyc create mode 100644 vega/tools/run_slave.py delete mode 100644 vega/trainer/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/trainer/__pycache__/conf.cpython-37.pyc delete mode 100644 vega/trainer/__pycache__/task_conf.cpython-37.pyc delete mode 100644 vega/trainer/__pycache__/trial_agent.cpython-37.pyc create mode 100644 vega/trainer/callbacks/callbacks.md create mode 100644 vega/trainer/callbacks/data_parallel.py delete mode 100644 vega/trainer/modules/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/trainer/modules/__pycache__/config_bakcend_map.cpython-37.pyc delete mode 100644 vega/trainer/modules/conf/__pycache__/__init__.cpython-37.pyc delete mode 100644 vega/trainer/modules/conf/__pycache__/loss.cpython-37.pyc delete mode 100644 vega/trainer/modules/conf/__pycache__/lr_scheduler.cpython-37.pyc delete mode 100644 vega/trainer/modules/conf/__pycache__/optim.cpython-37.pyc diff --git a/README.cn.md b/README.cn.md index 9e8a44ae..043cb329 100644 --- a/README.cn.md +++ b/README.cn.md @@ -9,13 +9,14 @@ --- -**Vega ver1.5.0 发布** +**Vega ver1.6.0 发布** - 特性增强 - - 解决了分布式训练的一些bug。 - - 部分网络支持PyTorch + Ascend 910)。 - - 命令Vega-process、Vega-progress、vega-verify-cluster提供Json格式信息。 + - 支持简洁的quota设置,比如:`quota: flops < 11.2 and params in [34.0, 56.0]`。 + - 支持在python虚拟环境下运行Vega。 + - 支持运行环境:Python 3.8和PyTorch 1.9。 + - 解决了并行训练和分布式搜索的一些bug。 --- @@ -90,7 +91,7 @@ Vega提供了40+示例供参考:[示例](https://github.com/huawei-noah/vega/t | 对象 | 参考 | | :--: | :-- | -| [**用户**
(用户指南)](./docs/cn/user/README.md) | [安装指导](./docs/cn/user/install.md)、[部署指导](./docs/cn/user/deployment.md)、[配置指导](./docs/cn/user/config_reference.md)、[示例参考](./docs/cn/user/examples.md)、[评估服务](./docs/cn/user/evaluate_service.md) | +| [**用户**
(用户指南)](./docs/cn/user/README.md) | [安装指导](./docs/cn/user/install.md)、[部署指导](./docs/cn/user/deployment.md)、[配置指导](./docs/cn/user/config_reference.md)、[示例参考](./docs/cn/user/examples.md)、[评估服务](./docs/cn/user/evaluate_service.md)、任务参考([分类](./docs/cn/tasks/classification.md)、[检测](./docs/cn/tasks/detection.md)、[分割](./docs/cn/tasks/segmentation.md)、[超分](./docs/cn/tasks/segmentation.md)) | | [**开发者**
(开发者指南)](./docs/cn/developer/README.md) | [开发者指导](./docs/cn/developer/developer_guide.md)、[快速入门指导](./docs/cn/developer/quick_start.md)、[数据集指导](./docs/cn/developer/datasets.md)、[算法开发指导](./docs/cn/developer/new_algorithm.md)、[细粒度搜索空间指导](./docs/cn/developer/fine_grained_space.md) | ## FAQ diff --git a/README.md b/README.md index ff65a50c..fef52ef7 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,14 @@ --- -**Vega ver1.5.0 released** +**Vega ver1.6.0 released** - Feature enhancement: - - Fixed some bugs in distributed training. - - Some networks support PyTorch + Ascend 910. - - The Vega-process, Vega-progress, and vega-verify-cluster commands provide JSON format information. + - Supports simple quota settings, for example, `quota: flops < 11.2 and params in [34.0, 56.0]`. + - Supports running Vega in a Python virtual environment. + - Supported running environments: Python 3.8 and PyTorch 1.9. + - Fixed some bugs with parallel training and distributed search. --- diff --git a/RELEASE.md b/RELEASE.md index 1ef0acc5..5e166a28 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,4 +1,4 @@ -**Vega ver1.5.0 released:** +**Vega ver1.6.0 released:** **Introduction** diff --git a/docs/cn/developer/developer_guide.md b/docs/cn/developer/developer_guide.md index 7166e1f1..b59c7b1e 100644 --- a/docs/cn/developer/developer_guide.md +++ b/docs/cn/developer/developer_guide.md @@ -238,136 +238,6 @@ class Generator(object): 代码中的sample接口即是NAS中每一次采样,首先调用搜索算法search出一个网络描述,再通过网络描述生成网络模型。 此外,Generator还具有判断迭代搜索是否停止以及更新搜索算法等功能。 -### 4.2 Quota - -Quota是一个可选的插件,允许用户定义特定的规则来控制nas搜索过程并实现自定义功能。 -Quota目前提供的功能包括: - -- 搜索过程控制:如果达到用户定义的条件限制,则停止nas搜索过程 -- Sample过滤:如果sample不满足用户定义的条件限制,则丢弃不符合要求的sample - -Quota的实现如下: - -```python -class Quota(object): - - def __new__(cls, *args, **kwargs): - return super().__new__(cls) - - def __init__(self): - self.strategies = [] - self.filters = [] - # check whether quota is configured - pipe_config = UserConfig().data.get(General.step_name) - if not pipe_config.get('quota'): - return - else: - # get quota configuration - quota_config = pipe_config.get('quota') - # get all defined strategies if any - if quota_config.get('strategy'): - strategy_types = quota_config['strategy'] - for type_name in strategy_types: - t_cls = ClassFactory.get_cls(ClassType.QUOTA, type_name) - self.strategies.append(t_cls()) - # get all defined limitations if any - if quota_config.get('filter'): - filter_types = quota_config['filter'] - for type_name in filter_types: - t_cls = ClassFactory.get_cls(ClassType.QUOTA, type_name) - self.filters.append(t_cls()) - - def halt(self): - raise NotImplementedError - - def filter(self, res): - raise NotImplementedError - - def is_halted(self): - for strategy in self.strategies: - if strategy.halt(): - return True - return False - - # check whether some defined filters are satisfied. - # If reaching constraints, just return false. Otherwise, always return False. - def is_filtered(self, res=None): - for flt in self.filters: - if flt.filter(res): - logging.info("Sample was throw away by strategy = %s", flt.__class__.__name__) - return True - return False -``` - -Quota在初始化时首先尝试检查用户是否开启了Quota配置。如果设置开启Quota,Quota将读取所有用户自定义的搜索停止策略和过滤器。Quota允许用户同时定义多个规则,规则名称同时指定在一个list中。同时定义的多个搜索停止策略属于并列关系,将同时起作用。如果未设置Quota,不会对Vega的运行造成任何影响。 - -Quota的配置使用需要完成四个步骤。首先,继承Quota基类,根据自定义需求构造strategy或filter类。其次,覆写Quota基类中的“halt()”和“filter()”抽象函数,并在其中实现用户的自定义策略。第三,将自定义实现的strategy和filter在Class Factory中进行注册。最后,在用户配置文件中添加Quota配置项和配置参数。之后Vega将自动执行Quota配置的策略。 - -以下是一个Quota的配置样例: - -```yml -general: - -pipeline: [nas] - -nas: - pipe_step: - type: SearchPipeStep - - quota: - strategy: [MaxDurationStrategy, MaxTrialNumberStrategy] - filter: [FlopParamFliter] - policy: - max_duration_time: 3000 - max_trial_num: 300 - flop_range: [!!float 0, !!float 0.6] - param_range: [!!float 0, !!float 1e10] - -``` - -Quota的配置项位于每个具体的pipeline之下,如上述配置中名为“nas”的pipeline,因此,一个Quota配置只对自己所在的Pipeline负责和起作用.如果用户希望Quota对不同的Pipeline步骤生效,则需要在每个流水线步骤中都添加Quota配置。在Quota的配置中,用户可以添加任意自定义的strategy和filter,只需将自定义的具体Quota类使用“QUOTA”关键字注册到类工厂中即可被Vega索引。与Quota相关的参数可以定义到“policy”字段中,在实现具体的Quota类时,可以通过UserConfig()对定义在用户配置中的参数进行使用。 - -```python -class Generator(object): - """Convert search space and search algorithm, sample a new model.""" - - def __init__(self): - ... - self.quota = Quota() - ... - @property - def is_completed(self): - return self.search_alg.is_completed or self.quota.is_halted() - - def sample(self): - """Sample a work id and model from search algorithm.""" - res = self.search_alg.search() - if not res: - return None - if not isinstance(res, list): - res = [res] - if self.quota.is_filtered(res): - return None - if len(res) == 0: - return None - out = [] - for sample in res: - if isinstance(sample, tuple): - sample = dict(worker_id=sample[0], desc=sample[1]) - record = self.record.load_dict(sample) - logging.debug("update record=%s", str(record)) - ReportClient().update(**record.to_dict()) - desc = self._decode_hps(record.desc) - out.append((record.worker_id, desc)) - return out -``` - -下面是关于Quota在Vega中工作的流程。在每个Pipeline中,Vega首先检查nas搜索是否完成(is_completed())或是否达到用户定义的停止条件(is_halted())。如果搜索完成,或者达到用户定义的停止条件所到达停止条件,当前pipeline将被立即停止。 - -在准备sample的阶段,generator首先在从搜索算法中获取待评估的sample,这些sample被移交给Quota进行过滤(is_filtered())。过滤规则由用户在具体Quota类的“filter()”函数中定义。任何不满足用户定义过滤规则的的样本将不会被用来训练并被直接丢弃。过滤完成后,Quota将所有满足条件的Sample传递给generator,之后将完成训练。 - -Vega现在为用户提供两种pipeline停止策略和一种样本过滤器。现有的两种停止策略分别是利用最大采样次数和最长运行时间来作为终止条件。这两种策略都支持“开箱即用”。内置的样本过滤器则允许用户在搜索算法搜索到sample之后立即评估sample的flops和parameters参数来决定是否要保留该sample。由于计算flops和parameters需要知道数据集的相关信息,因此用户必须在相关的数据集类中提供一个“data_case()”接口,或者提供自定义的方法来计算flops和parameters。 - ## 5 Trainer Trainer用于训练模型,在NAS、HPO、fully train等阶段,可将trainer配置这些阶段的pipestep中,完成模型的训练。 diff --git a/docs/cn/user/config_reference.md b/docs/cn/user/config_reference.md index a3148833..f6aca3a1 100644 --- a/docs/cn/user/config_reference.md +++ b/docs/cn/user/config_reference.md @@ -61,12 +61,7 @@ my_fully_train: | logger / level | debug \| info \| warn \| error \| critical | info | 日志级别。 | | cluster / master_ip | - | ~ | 在集群场景下需要设置该参数,设置为master节点的IP地址。 | | cluster / slaves | - | [] | 在集群场景下需要设置该参数,设置为除了master节点外的其他节点的IP地址。 | -| quota / restrict / flops | - | ~ | 过滤模型。设置采样模型的浮点计算量最大值或范围,单位为M。 | -| quota / restrict / params | - | ~ | 过滤模型。设置采样模型的参数量最大值或范围,单位为K。 | -| quota / restrict / latency | - | ~ | 过滤模型。设置采样模型的时延最大值或范围,单位为ms。 | -| quota / target / type | accuracy \| IoUMetric \| PSNR | ~ | 过滤模型。设置模型的训练metric目标类型。 | -| quota / target / value | - | ~ | 过滤模型。设置模型的训练metric目标值。 | -| quota / runtime | - | ~ | 用户设定的Pipeline最大运行时间估计值,单位为h。 | +| quota | - | ~ | 过滤模型。可设置采样模型的浮点计算量最大值或范围(单位为M),模型的参数量最大值或范围(单位为K),采样模型的时延最大值或范围(单位为ms),Pipeline最大运行时间(单位为h)。支持"<"、">"、"in"、"and" 四种操作。
eg: "flops < 10 and params in [100, 1000]" | ```yaml general: @@ -81,15 +76,7 @@ general: cluster: master_ip: ~ slaves: [] - quota: - restrict: - flops: 10 - params: [100, 1000] - latency: 100 - target: - type: accuracy - value: 0.98 - runtime: 10 + quota: "flops < 10 and params in [100, 1000]" ``` ## 2.1 并行和分布式 @@ -244,7 +231,7 @@ search_algorithm: | type | 搜索算法名称,包括RandomSearch、AshaHpo、BohbHpo、BossHpo、PBTHpo | `type: RandomSearch` | | objective_keys | 优化目标 | `objective_keys: 'accuracy'` | | policy.total_epochs | 搜索epoch配额。Vega简化了配置策略,只需要配置该参数。若需了解其他参数配置,可参考HPO和NAGO算法示例。 | `total_epochs: 2430` | -| tuner | tuner类型,用于BOHB算法,包括gp(缺省)、rf、hebo | tuner: "gp" | +| tuner | tuner类型,用于BOHB算法,包括gp、rf(缺省)、hebo | tuner: "rf" | 注意:若参数tuner设置hebo,则需要安装"[HEBO](https://github.com/huawei-noah/noah-research/tree/master/HEBO)",且需要注意gpytorch的版本为1.1.1,torch的版本设置为1.5.0,torchvision的版本为0.5.0。 diff --git a/docs/cn/user/evaluate_service.md b/docs/cn/user/evaluate_service.md index 63c778b7..4d01af4b 100644 --- a/docs/cn/user/evaluate_service.md +++ b/docs/cn/user/evaluate_service.md @@ -269,9 +269,9 @@ from .my_hardware import MyHardware ## 6. FAQ -### 6.1 Pytorch模型评估 +### 6.1 Pytorch模型转换caffe模型 -在评估服务的客户端需要进行`Pytorch`模型的转换,请下载[PytorchToCaffe](https://github.com/xxradon/PytorchToCaffe)获取并放在`./third_party`目录下(third_party目录与vega处于同一目录层级)。 +如果需要将pytorch模型转换为caffe模型,请下载[PytorchToCaffe](https://github.com/xxradon/PytorchToCaffe)获取并放在`./third_party`目录下(third_party目录与vega处于同一目录层级)。 注意: 该第三方开源软件不支持pytorch1.1版本, 并且如果您使用原生torchvisoin中的模型, 当torchvision版本高于0.2.0时, 您需要做以下额外修改: 修改`pytorch_to_caffe.py`文件, 增加以下内容: diff --git a/docs/en/developer/developer_guide.md b/docs/en/developer/developer_guide.md index 5fa4c0c9..0472c7bc 100644 --- a/docs/en/developer/developer_guide.md +++ b/docs/en/developer/developer_guide.md @@ -242,136 +242,6 @@ The sample interface in the code is used for each sampling in the NAS. The sampl In addition, the generator can determine whether the iterative search stops and update the search algorithm. -### 4.2 Quota - -Quota is an optional plugin that enables users to define specific rules to control nas search process for special purpose. -Currently, Quota provides avaliable abilities as follows: - -- Search control: halt nas search process if reaching user-defined constraints -- Sample filtering: throw away sample from search algorithm if disatisfying user-defined limitations - -The implementation of Quota is show below. - -```python -class Quota(object): - - def __new__(cls, *args, **kwargs): - return super().__new__(cls) - - def __init__(self): - self.strategies = [] - self.filters = [] - # check whether quota is configured - pipe_config = UserConfig().data.get(General.step_name) - if not pipe_config.get('quota'): - return - else: - # get quota configuration - quota_config = pipe_config.get('quota') - # get all defined strategies if any - if quota_config.get('strategy'): - strategy_types = quota_config['strategy'] - for type_name in strategy_types: - t_cls = ClassFactory.get_cls(ClassType.QUOTA, type_name) - self.strategies.append(t_cls()) - # get all defined limitations if any - if quota_config.get('filter'): - filter_types = quota_config['filter'] - for type_name in filter_types: - t_cls = ClassFactory.get_cls(ClassType.QUOTA, type_name) - self.filters.append(t_cls()) - - def halt(self): - raise NotImplementedError - - def filter(self, res): - raise NotImplementedError - - def is_halted(self): - for strategy in self.strategies: - if strategy.halt(): - return True - return False - - # check whether some defined filters are satisfied. - # If reaching constraints, just return false. Otherwise, always return False. - def is_filtered(self, res=None): - for flt in self.filters: - if flt.filter(res): - logging.info("Sample was throw away by strategy = %s", flt.__class__.__name__) - return True - return False -``` - -While initializing, Quota tries to find whether users set Quota's configuration or not. If setted, Quota gets all user-defined strategies and filters. Quota allows multiple defined rules, and users can give multi rule names in one list. All of the strategies work with a union relationship, which means they make effect at the same time. If not setted, Quota will have nothing influnce on Vega's running. - -To take advantage of Quota, there are four steps users should walk with. First, construct a strategy or filter class, which is inherited from Quota base class. Second, overwrite the abstract function of "halt()" and "filter()" to put on users' self-defined approach. Third, regist finished concrete class into class factory. At the end, add quota configuration setting item into user configuration file, and then Vega will automatically hold on the rest of things. - -Here is an configuration example of Quota: - -```yml -general: - -pipeline: [nas] - -nas: - pipe_step: - type: SearchPipeStep - - quota: - strategy: [MaxDurationStrategy, MaxTrialNumberStrategy] - filter: [FlopParamFliter] - policy: - max_duration_time: 3000 - max_trial_num: 300 - flop_range: [!!float 0, !!float 0.6] - param_range: [!!float 0, !!float 1e10] - -``` - -Quota configuration item is put under the converage of each pipeline, and only has responsibility of each pipeline step, so users need to add Quota setting in each single pipeline step if they want Quota makes effect on different pipeline steps. In Quota's setting paragraph, users can selectively give their own defined strategies and filters by classname which are implemented free and registed into class factory by type "QUOTA". Relavant parameters can be difined into "policy" and refered in strategy class by UserConfig(). - -```python -class Generator(object): - """Convert search space and search algorithm, sample a new model.""" - - def __init__(self): - ... - self.quota = Quota() - ... - @property - def is_completed(self): - return self.search_alg.is_completed or self.quota.is_halted() - - def sample(self): - """Sample a work id and model from search algorithm.""" - res = self.search_alg.search() - if not res: - return None - if not isinstance(res, list): - res = [res] - if self.quota.is_filtered(res): - return None - if len(res) == 0: - return None - out = [] - for sample in res: - if isinstance(sample, tuple): - sample = dict(worker_id=sample[0], desc=sample[1]) - record = self.record.load_dict(sample) - logging.debug("update record=%s", str(record)) - ReportClient().update(**record.to_dict()) - desc = self._decode_hps(record.desc) - out.append((record.worker_id, desc)) - return out -``` - -It should be anounced how Quota works in a round of nas search process. In each nas pipeline step, Vega first check whether the search procedure has completed or arrived at the user-defined halting conditions. If getting to the stop condition, the current nas pipeline step will halt at once. - -In proposing samples, after receiving a sample res from search algorithm, the res sample is handed to quota to filter. The filtering rules is defined by users in the function"fliter()" in concrete class. Users can throw away any sample that don't reach their expectation. Afterwards, generator gets all satisfactory samples and go for further processing. - -Vega now provides two kinds of halting strategies and one kind of sample filter. The two exisitng halting strategies allow the pipe step to stop by sample trials number and pipe step running time, respectively. These two strategies all support "out of box". The fliter example enables users to remove the sample they don't want by evaluating the flops and parameters of the sample network before training them. Calculating flops and parameters needs to know the dataset's information, so users have to write a "data_case()" interface in related dataset class or just give their own method to compute flops and parameters. - ### 5 Trainer The trainer is used to train models. In the NAS, HPO, and fully train phases, the trainer can be configured in the pipe steps of these phases to complete model training. diff --git a/docs/en/user/config_reference.md b/docs/en/user/config_reference.md index 0414d42f..998c6edc 100644 --- a/docs/en/user/config_reference.md +++ b/docs/en/user/config_reference.md @@ -60,12 +60,7 @@ The following public configuration items can be configured: | logger / level | debug \| info \| warn \| error \| critical | info | Log level | | cluster / master_ip | - | ~ | In the cluster scenario, this parameter needs to be set to the IP address of the master node. | | cluster / slaves | - | [] | In the cluster scenario, this parameter needs to be set to the IP address of other nodes except the master node. | -| quota / restrict / flops | - | ~ | Models filter. Set maximum value or range of the floating-point calculation amount of the sampling model, in MB. | -| quota / restrict / params | - | ~ | Models filter. Set maximum value or range of the parameter of the sampling model, in KB. | -| quota / restrict / latency | - | ~ | Models filter. Set maximum value or range of the latency of the sampling model, in ms. | -| quota / target / type | accuracy \| IoUMetric \| PSNR | ~ | Models filter. set training metric target type. | -| quota / target / value | - | ~ | Models filter. Set target training metric of a model. | -| quota / runtime | - | ~ | Max pipeline estimated running time set by user, in h. | +| quota | - | ~ | Models filter. Set maximum value or range of the floating-point calculation amount of the sampling model (MB), the parameters of the sampling model (KB), the latency of the sampling model (ms), max pipeline estimated running time set by user (hour). The options are "<", ">", "in", and "and".
eg: "flops < 10 and params in [100, 1000]" | ```yaml general: @@ -80,15 +75,7 @@ general: cluster: master_ip: ~ slaves: [] - quota: - restrict: - flops: 10 - params: [100, 1000] - latency: 100 - target: - type: accuracy - value: 0.98 - runtime: 10 + quota: "flops < 10 and params in [100, 1000]" ``` ## 2.1 Parallel and distributed diff --git a/docs/en/user/evaluate_service.md b/docs/en/user/evaluate_service.md index 6cfa75e8..54902d69 100644 --- a/docs/en/user/evaluate_service.md +++ b/docs/en/user/evaluate_service.md @@ -273,9 +273,9 @@ from .my_hardware import MyHardware ## 6. FAQ -### 6.1 Pytorch Model Evaluation +### 6.1 Convert pytorch model to caffe model -The `Pytorch` model needs to be converted on the Appraisal client. Download [PytorchToCaffe](https://github.com/xxradon/PytorchToCaffe) and store it in the `./third_party` directory (the third_party directory and vega directory are at the same directory level). +If you need to convert the pytorch model to caffe model, download [PytorchToCaffe](https://github.com/xxradon/PytorchToCaffe) and store it in the `./third_party` directory (the third_party directory and vega directory are at the same directory level). Note: The third-party open-source software does not support pytorch1.1. If you use the model in the native torchvisoin and the torchvision version is later than 0.2.0, you need to make the following additional modifications: Add the following content to the `pytorch_to_caffe.py` file: diff --git a/evaluate_service/hardwares/davinci/davinci.py b/evaluate_service/hardwares/davinci/davinci.py index 1c4f1345..bf872d60 100644 --- a/evaluate_service/hardwares/davinci/davinci.py +++ b/evaluate_service/hardwares/davinci/davinci.py @@ -28,13 +28,13 @@ def __init__(self, optional_params): self.davinci_environment_type = optional_params.get("davinci_environment_type") def convert_model(self, backend, model, weight, **kwargs): - """Convert the tf/caffe/mindspore model to om model in Davinci. + """Convert the tf/caffe/mindspore/onnx model to om model in Davinci. - :param backend: the backend can be one of "tensorflow", "caffe" and "mindspore" + :param backend: the backend can be one of "tensorflow", "caffe", "mindspore" and "onnx" :type backend: str :param model: the model file need to convert :type model: str - :param weight: the weight file need to converta + :param weight: the weight file need to convert :type weight: str """ om_save_path = kwargs["save_dir"] diff --git a/evaluate_service/hardwares/davinci/model_convert.sh b/evaluate_service/hardwares/davinci/model_convert.sh index fb1b7a1e..818a89a5 100644 --- a/evaluate_service/hardwares/davinci/model_convert.sh +++ b/evaluate_service/hardwares/davinci/model_convert.sh @@ -8,20 +8,22 @@ INPUT_SHAPE=$7 if [ $DAVINCI_ENV_TYPE == "ATLAS200DK" ]; then if [ $BACKEND == "tensorflow" ]; then - omg --model=$MODEL --framework=3 --output=$OM_SAVE_PATH/davinci_model >$LOG_SAVE_PATH/omg.log + omg --model=$MODEL --framework=3 --output=$OM_SAVE_PATH/davinci_model >$LOG_SAVE_PATH/omg.log 2>&1 elif [ $BACKEND == "caffe" ]; then - omg --model=$MODEL --weight=$WEIGHT --framework=0 --output=$OM_SAVE_PATH/davinci_model >$LOG_SAVE_PATH/omg.log + omg --model=$MODEL --weight=$WEIGHT --framework=0 --output=$OM_SAVE_PATH/davinci_model >$LOG_SAVE_PATH/omg.log 2>&1 else echo "[ERROR] Davinci model convert: The backend must be tensorflow, caffe." fi else if [ $BACKEND == "tensorflow" ]; then - atc --model=$MODEL --framework=3 --input_format='NCHW' --disable_reuse_memory=1 --input_shape=$INPUT_SHAPE --output=$OM_SAVE_PATH/davinci_model --soc_version=Ascend310 --core_type=AiCore >$LOG_SAVE_PATH/omg.log + atc --model=$MODEL --framework=3 --input_format='NCHW' --disable_reuse_memory=1 --input_shape=$INPUT_SHAPE --output=$OM_SAVE_PATH/davinci_model --soc_version=Ascend310 --core_type=AiCore >$LOG_SAVE_PATH/omg.log 2>&1 elif [ $BACKEND == "caffe" ]; then - atc --model=$MODEL --weight=$WEIGHT --framework=0 --input_format='NCHW' --disable_reuse_memory=1 --output=$OM_SAVE_PATH/davinci_model --soc_version=Ascend310 --core_type=AiCore >$LOG_SAVE_PATH/omg.log + atc --model=$MODEL --weight=$WEIGHT --framework=0 --input_format='NCHW' --disable_reuse_memory=1 --output=$OM_SAVE_PATH/davinci_model --soc_version=Ascend310 --core_type=AiCore >$LOG_SAVE_PATH/omg.log 2>&1 elif [ $BACKEND == "mindspore" ]; then - atc --model=$MODEL --framework=1 --disable_reuse_memory=1 --output=$OM_SAVE_PATH/davinci_model --soc_version=Ascend310 --core_type=AiCore >$LOG_SAVE_PATH/omg.log + atc --model=$MODEL --framework=1 --disable_reuse_memory=1 --output=$OM_SAVE_PATH/davinci_model --soc_version=Ascend310 --core_type=AiCore >$LOG_SAVE_PATH/omg.log 2>&1 + elif [ $BACKEND == "onnx" ]; then + atc --model=$MODEL --framework=5 --output=$OM_SAVE_PATH/davinci_model --soc_version=Ascend310 --core_type=AiCore >$LOG_SAVE_PATH/omg.log 2>&1 else - echo "[ERROR] Davinci model convert: The backend must be tensorflow, caffe or mindspore." + echo "[ERROR] Davinci model convert: The backend must be tensorflow, caffe, mindspore or onnx." fi fi \ No newline at end of file diff --git a/evaluate_service/main.py b/evaluate_service/main.py index e123b5da..ab6e1115 100644 --- a/evaluate_service/main.py +++ b/evaluate_service/main.py @@ -50,7 +50,7 @@ class Evaluate(Resource): """Evaluate Service for service.""" def __init__(self): - self.result = {"latency": "-1", "out_data": [], "status": "sucess", "timestamp": ""} + self.result = {"latency": "9999", "out_data": [], "status": "sucess", "timestamp": ""} @classmethod def _add_params(cls, work_path, optional_params): diff --git a/examples/classification/classification.yml b/examples/classification/classification.yml index 4bf8aecd..e175605d 100644 --- a/examples/classification/classification.yml +++ b/examples/classification/classification.yml @@ -141,7 +141,7 @@ fully_train: models_folder: "{local_base_path}/output/nas/" trainer: ref: fine_tune.trainer - hps_folder: "{local_base_path}/output/hpo/" + hps_file: "{local_base_path}/output/hpo/" evaluator: ref: fine_tune.evaluator dataset: diff --git a/examples/data_augmentation/pba/pba.yml b/examples/data_augmentation/pba/pba.yml index d0aea5ea..3de50cb3 100644 --- a/examples/data_augmentation/pba/pba.yml +++ b/examples/data_augmentation/pba/pba.yml @@ -65,7 +65,7 @@ fully_train: ref: pba.trainer callbacks: PbaTrainerCallback epochs: 2000 # multiple of 4 - hps_folder: "{local_base_path}/output/pba/" + hps_file: "{local_base_path}/output/pba/" evaluator: type: Evaluator host_evaluator: diff --git a/examples/features/quota/quota.yml b/examples/features/quota/quota.yml index 238fb47d..ebfc2373 100644 --- a/examples/features/quota/quota.yml +++ b/examples/features/quota/quota.yml @@ -11,17 +11,11 @@ #### copy the following configuration to your yaml file #### general: - quota: - restrict: - flops: !!float 1.6 - params: !!float 1e10 - latency: 10 - model_valid: True - filter_rules: "model_valid and max_latency and flops_params" + quota: "flops < 1.6 and params < 1e10 and model_valid and host_latency < 10" ########################### end ############################# -pipeline: [nas, fullytrain] +pipeline: [nas] nas: @@ -52,13 +46,3 @@ nas: type: Cifar10 common: data_path: /cache/datasets/cifar10/ - - -fullytrain: - pipe_step: - type: TrainPipeStep - models_folder: "{local_base_path}/output/nas/" - trainer: - ref: nas.trainer - dataset: - ref: nas.dataset diff --git a/examples/features/script_runner/bohb.yml b/examples/features/script_runner/bohb.yml index ca45460f..82e9bc42 100644 --- a/examples/features/script_runner/bohb.yml +++ b/examples/features/script_runner/bohb.yml @@ -49,4 +49,3 @@ fullytrain: epochs: 2 # script: "./train.py" script: "./train_vega.py" - hps_folder: "{local_base_path}/output/hpo" diff --git a/examples/hpo/bohb/bohb.yml b/examples/hpo/bohb/bohb.yml index c75a2f16..ff8945d0 100644 --- a/examples/hpo/bohb/bohb.yml +++ b/examples/hpo/bohb/bohb.yml @@ -66,7 +66,7 @@ fully_train: trainer: ref: hpo.trainer epochs: 200 - hps_folder: "{local_base_path}/output/hpo" + hps_file: "{local_base_path}/output/hpo" evaluator: type: Evaluator host_evaluator: diff --git a/examples/hpo/boss/boss.yml b/examples/hpo/boss/boss.yml index aa45cc54..5da3c76c 100644 --- a/examples/hpo/boss/boss.yml +++ b/examples/hpo/boss/boss.yml @@ -60,7 +60,7 @@ fully_train: trainer: ref: hpo.trainer epochs: 200 - hps_folder: "{local_base_path}/output/hpo" + hps_file: "{local_base_path}/output/hpo" evaluator: type: Evaluator host_evaluator: diff --git a/examples/hpo/pbt/pbt.yml b/examples/hpo/pbt/pbt.yml index a9798422..ceb597a8 100644 --- a/examples/hpo/pbt/pbt.yml +++ b/examples/hpo/pbt/pbt.yml @@ -66,7 +66,7 @@ fully_train: ref: pbt.trainer callbacks: PbtTrainerCallback epochs: 2000 - hps_folder: "{local_base_path}/output/pbt/" + hps_file: "{local_base_path}/output/pbt/" evaluator: type: Evaluator diff --git a/examples/nas/dnet_nas/dnet_nas.yml b/examples/nas/dnet_nas/dnet_nas.yml index c5f7a657..4ca3ed7e 100644 --- a/examples/nas/dnet_nas/dnet_nas.yml +++ b/examples/nas/dnet_nas/dnet_nas.yml @@ -1,9 +1,6 @@ general: backend: pytorch # pytorch | tensorflow | mindspore - quota: - restrict: - model_valid: True - filter_rules: "model_valid" + quota: "model_valid" pipeline: [block_nas, net_nas] diff --git a/setup.py b/setup.py index 29576ded..69299d1f 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ setuptools.setup( name="noah-vega", - version="1.5.0", + version="1.6.0", packages=["vega", "evaluate_service"], include_package_data=True, python_requires=">=3.6", @@ -61,6 +61,7 @@ "torch==1.3.0", "torchvision==0.4.1", "tensorflow-gpu>=1.14.0,<2.0", + # "onnx-simplifier" ], entry_points=""" [console_scripts] diff --git a/vega/__init__.py b/vega/__init__.py index 8ccf3ab9..d925322a 100644 --- a/vega/__init__.py +++ b/vega/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.5.0" +__version__ = "1.6.0" import sys @@ -10,6 +10,7 @@ from .common.class_factory import ClassFactory, ClassType from .core import run, init_cluster_args, module_existed from .trainer.trial_agent import TrialAgent +from .quota import * def network(name, **kwargs): @@ -25,3 +26,8 @@ def dataset(name, **kwargs): def trainer(name="Trainer", **kwargs): """Return trainer.""" return ClassFactory.get_cls(ClassType.TRAINER, name)(**kwargs) + + +def quota(**kwargs): + """Return quota.""" + return ClassFactory.get_cls(ClassType.QUOTA, "Quota")(**kwargs) diff --git a/vega/algorithms/compression/prune_ea/prune_ea.py b/vega/algorithms/compression/prune_ea/prune_ea.py index 37463c65..f3b3600b 100644 --- a/vega/algorithms/compression/prune_ea/prune_ea.py +++ b/vega/algorithms/compression/prune_ea/prune_ea.py @@ -84,7 +84,10 @@ def search(self): return self.random_count, desc records = ReportServer().get_pareto_front_records(self.step_name, self.num_individual) codes = [record.desc.get('backbone').get('encoding') for record in records] - logging.info("codes=%s", codes) + if len(codes) > 0: + logging.info("codes=%s", codes) + if len(codes) == 0: + return None if len(codes) < 2: encoding1, encoding2 = codes[0], codes[0] else: diff --git a/vega/algorithms/compression/prune_ea/prune_trainer_callback.py b/vega/algorithms/compression/prune_ea/prune_trainer_callback.py index 93a672fd..ba701f43 100644 --- a/vega/algorithms/compression/prune_ea/prune_trainer_callback.py +++ b/vega/algorithms/compression/prune_ea/prune_trainer_callback.py @@ -48,11 +48,11 @@ def before_train(self, logs=None): """Be called before the train process.""" self.config = self.trainer.config self.device = vega.is_gpu_device() if vega.is_gpu_device() is not True else 0 - self.base_net_desc = self.trainer.model.desc + self.base_net_desc = self.trainer.model_desc sess_config = None if vega.is_torch_backend(): if vega.is_npu_device(): - count_input = torch.FloatTensor(1, 3, 32, 32).npu() + count_input = torch.FloatTensor(1, 3, 32, 32).to(vega.get_devices()) elif vega.is_gpu_device(): count_input = torch.FloatTensor(1, 3, 32, 32).to(self.device) elif vega.is_tf_backend(): @@ -129,7 +129,7 @@ def _generate_init_model(self): map_location=torch.device('{}'.format(device))) model_init.load_state_dict(checkpoint) model = PruneResnet(model_init).apply(chn_node_mask, self.base_net_desc.backbone.chn_mask) - model.npu() + model.to(vega.get_devices()) elif vega.is_tf_backend(): model = model_init with tf.compat.v1.Session(config=self.trainer._init_session_config()) as sess: diff --git a/vega/algorithms/compression/quant_ea/quant_trainer_callback.py b/vega/algorithms/compression/quant_ea/quant_trainer_callback.py index 7dc61053..7a3ef32b 100644 --- a/vega/algorithms/compression/quant_ea/quant_trainer_callback.py +++ b/vega/algorithms/compression/quant_ea/quant_trainer_callback.py @@ -55,8 +55,8 @@ def before_train(self, logs=None): model = model.cuda() count_input = torch.FloatTensor(*count_input).cuda() elif vega.is_npu_device(): - model = model.npu() - count_input = torch.FloatTensor(*count_input).npu() + model = model.to(vega.get_devices()) + count_input = torch.FloatTensor(*count_input).to(vega.get_devices()) self.trainer.optimizer = Optimizer()(model=self.trainer.model, distributed=self.trainer.distributed) self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer) elif vega.is_tf_backend(): diff --git a/vega/algorithms/hpo/bohb_conf.py b/vega/algorithms/hpo/bohb_conf.py index c9778b6b..b44ee8cd 100644 --- a/vega/algorithms/hpo/bohb_conf.py +++ b/vega/algorithms/hpo/bohb_conf.py @@ -45,7 +45,7 @@ class BohbConfig(ConfigSerializable): random_samples = None # 32 prob_crossover = 0.6 prob_mutatation = 0.2 - tuner = "GP" # TPE | GP | RF + tuner = "RF" # TPE | GP | RF @classmethod def rules(cls): diff --git a/vega/algorithms/hpo/evolution_search.py b/vega/algorithms/hpo/evolution_search.py index adcf654a..3214ddb5 100644 --- a/vega/algorithms/hpo/evolution_search.py +++ b/vega/algorithms/hpo/evolution_search.py @@ -76,7 +76,7 @@ def search(self): # split codes desc = {} for _name, _size in each_codes_cache.items(): - desc[_name] = encoding_new[:_size] + desc[_name] = encoding_new[:_size][0] encoding_new = encoding_new[_size:] self.sample_count += 1 sample = dict(worker_id=self.sample_count, encoded_desc=desc) diff --git a/vega/algorithms/hpo/sha_base/bohb.py b/vega/algorithms/hpo/sha_base/bohb.py index c4eedaa9..6e9c0752 100644 --- a/vega/algorithms/hpo/sha_base/bohb.py +++ b/vega/algorithms/hpo/sha_base/bohb.py @@ -65,7 +65,7 @@ class BOHB(ShaBase): def __init__(self, search_space, num_samples, max_epochs, repeat_times, min_epochs=1, eta=3, multi_obj=False, random_samples=None, - prob_crossover=0.6, prob_mutatation=0.2, tuner="GP"): + prob_crossover=0.6, prob_mutatation=0.2, tuner="RF"): """Init BOHB.""" super().__init__(search_space, num_samples, max_epochs, min_epochs, eta) # init all the configs diff --git a/vega/algorithms/hpo/sha_base/tuner/tuner_builder.py b/vega/algorithms/hpo/sha_base/tuner/tuner_builder.py index 4d84011f..7017b29d 100644 --- a/vega/algorithms/hpo/sha_base/tuner/tuner_builder.py +++ b/vega/algorithms/hpo/sha_base/tuner/tuner_builder.py @@ -146,7 +146,7 @@ def _propose(self, num): params = self.search_space.get_sample_space(gridding=True) LOG.info('Start to transform hyper-parameters') for param in params: - param = self.search_space.deocde(param) + param = self.search_space.decode(param) # Remove duplicate hyper-parameters if param not in params_list: params_list.append(param) diff --git a/vega/algorithms/nas/__init__.py b/vega/algorithms/nas/__init__.py index f1ae8f37..19515466 100644 --- a/vega/algorithms/nas/__init__.py +++ b/vega/algorithms/nas/__init__.py @@ -29,5 +29,6 @@ "sm_nas": ["SmNasCodec", "SMNasM"], "sp_nas": ["SpNasS", "SpNasP"], "sr_ea": ["SRCodec", "SRMutate", "SRRandom"], - "mfasc": ["search_algorithm:MFASC"] + "mfasc": ["search_algorithm:MFASC"], + "opt_nas": ["OperatorSearchSpace", "OperatorReplaceCallback"] }) diff --git a/vega/algorithms/nas/adelaide_ea/adelaide_trainer_callback.py b/vega/algorithms/nas/adelaide_ea/adelaide_trainer_callback.py index 0af01b44..38189fb4 100644 --- a/vega/algorithms/nas/adelaide_ea/adelaide_trainer_callback.py +++ b/vega/algorithms/nas/adelaide_ea/adelaide_trainer_callback.py @@ -39,7 +39,7 @@ def before_train(self, logs=None): count_input = torch.FloatTensor(*input_shape).cuda() elif vega.is_npu_device(): input_shape = [1, 3, 192, 192] - count_input = torch.FloatTensor(*input_shape).npu() + count_input = torch.FloatTensor(*input_shape).to(vega.get_devices()) elif vega.is_tf_backend(): tf.compat.v1.reset_default_graph() count_input = tf.random.uniform(input_shape, dtype=tf.float32) diff --git a/vega/algorithms/nas/cars/cars_trainer_callback.py b/vega/algorithms/nas/cars/cars_trainer_callback.py index d5e8b508..eb88cd7b 100644 --- a/vega/algorithms/nas/cars/cars_trainer_callback.py +++ b/vega/algorithms/nas/cars/cars_trainer_callback.py @@ -71,14 +71,14 @@ def train_step(self, batch): if vega.is_gpu_device(): alphas = torch.from_numpy(self.alphas).cuda() elif vega.is_npu_device(): - alphas = torch.from_numpy(self.alphas).npu() + alphas = torch.from_numpy(self.alphas).to(vega.get_devices()) for j in range(self.alg_policy.num_individual_per_iter): i = np.random.randint(0, self.alg_policy.num_individual, 1)[0] if self.epoch < self.alg_policy.warmup: if vega.is_gpu_device(): alpha = torch.from_numpy(self.search_alg.random_sample_path()).cuda() elif vega.is_npu_device(): - alpha = torch.from_numpy(self.search_alg.random_sample_path()).npu() + alpha = torch.from_numpy(self.search_alg.random_sample_path()).to(vega.get_devices()) # logits = self.trainer.model.forward_random(input) else: alpha = alphas[i] diff --git a/vega/algorithms/nas/fis/ctr_trainer_callback.py b/vega/algorithms/nas/fis/ctr_trainer_callback.py index 6862d000..3fedc70c 100644 --- a/vega/algorithms/nas/fis/ctr_trainer_callback.py +++ b/vega/algorithms/nas/fis/ctr_trainer_callback.py @@ -41,5 +41,5 @@ def make_batch(self, batch): if vega.is_gpu_device(): input, target = input.cuda(), target.cuda() elif vega.is_npu_device(): - input, target = input.npu(), target.npu() + input, target = input.to(vega.get_devices()), target.to(vega.get_devices()) return (input, target) diff --git a/vega/algorithms/nas/mfasc/conf.py b/vega/algorithms/nas/mfasc/conf.py index 888a3a6f..38518bc2 100644 --- a/vega/algorithms/nas/mfasc/conf.py +++ b/vega/algorithms/nas/mfasc/conf.py @@ -15,6 +15,7 @@ class MFASCConfig(ConfigSerializable): """MF-ASC Config.""" + sample_size = 5000 batch_size = 1000 prior_rho = 1.0 diff --git a/vega/algorithms/nas/opt_nas/__init__.py b/vega/algorithms/nas/opt_nas/__init__.py new file mode 100644 index 00000000..fe30849f --- /dev/null +++ b/vega/algorithms/nas/opt_nas/__init__.py @@ -0,0 +1 @@ +from .ops_nas import OperatorSearchSpace, OperatorReplaceCallback diff --git a/vega/algorithms/nas/opt_nas/ops_nas.py b/vega/algorithms/nas/opt_nas/ops_nas.py new file mode 100644 index 00000000..1f3f1224 --- /dev/null +++ b/vega/algorithms/nas/opt_nas/ops_nas.py @@ -0,0 +1,83 @@ +# -*- coding:utf-8 -*- + +# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. +# This program is free software; you can redistribute it and/or modify +# it under the terms of the MIT License. +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# MIT License for more details. +"""This is Operator SearchSpace.""" +from vega.common import ClassFactory, ClassType, SearchableRegister, Searchable, space, change_space +from vega.core.search_space import SearchSpace +from vega.networks.network_desc import NetworkDesc +from vega.core.pipeline.conf import PipeStepConfig +from vega.trainer.callbacks import Callback + + +@ClassFactory.register(ClassType.SEARCHSPACE) +class OperatorSearchSpace(SearchSpace): + """Operator SearchSpace.""" + + @classmethod + def get_space(self, desc): + """Get model and input.""" + for hp in desc.get("hyperparameters") or []: + change_space(hp) + model = NetworkDesc(PipeStepConfig.model.model_desc).to_model() + searchable = create_searchable_decorator(model) + return {"hyperparameters": searchable.search_space()} + + +@ClassFactory.register(ClassType.CALLBACK) +class OperatorReplaceCallback(Callback): + """Operator Replace callback.""" + + def before_train(self, logs=None): + """Call before train.""" + searchable = create_searchable_decorator(self.trainer.model) + searchable.update(self.trainer.hps) + + +def create_searchable_decorator(model): + """Create searchable class from model.""" + searchable = SearchableRegister().init() + searchable.register(Conv2dSearchable) + searchable.add_search_event(change_module) + for name, m in model.named_modules(): + searchable.add_space(name, m) + return searchable + + +def change_module(model, name, entity): + """Change module.""" + if not entity: + return + tokens = name.split('.') + attr_name = tokens[-1] + parent_names = tokens[:-1] + for s in parent_names: + model = getattr(model, s) + setattr(model, attr_name, entity) + + +@space( + key='conv', + type='CATEGORY', + range=['Conv2d', 'GhostConv2d', 'SeparableConv2d']) +class Conv2dSearchable(Searchable): + """Searchable class of Conv2d.""" + + def search_on(self, module): + """Call search on function.""" + return module.__class__.__name__ == 'Conv2d' + + def __call__(self, module): + """Call searchable.""" + cls = ClassFactory.get_cls(ClassType.NETWORK, self.desc) + in_channels = module.in_channels + out_channels = module.out_channels + kernel_size = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size + stride = module.stride[0] if isinstance(module.stride, tuple) else module.stride + padding = module.padding[0] if isinstance(module.padding, tuple) else module.padding + return cls(in_channels, out_channels, kernel_size, stride, padding) diff --git a/vega/common/__init__.py b/vega/common/__init__.py index cba02abb..2e0f2995 100644 --- a/vega/common/__init__.py +++ b/vega/common/__init__.py @@ -12,3 +12,4 @@ from .message_server import MessageServer from .message_client import MessageClient from .arg_parser import argment_parser +from .searchable import Searchable, SearchableRegister, space, change_space diff --git a/vega/common/__pycache__/__init__.cpython-37.pyc b/vega/common/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 12d82eecd6a022a55c8dbe6346093bfccdea0c8d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 949 zcmY*X%W~5&6t(kk94B_%Bz;p#*=2^od;r5R6v_-Uz#?U%jS;qL)au$ENhzd%#CNda z2iSPaieF&GmD~r|;mh`MOLx!ly$b^Ym4bIUBtb{(u*a)TzrIp+;DN~~iiuUkw zW>^I6`OQvffSms)X4V9`X~Y11xb{G9q43R-+xW8K&!eP}Pm-(i^F*?`O6W>gAd*W~ m&?Kt>sxnDaJAvurdfoU>+ZlXjhEwff*l}Zbe0$Fk-SEHhPYV|S diff --git a/vega/common/__pycache__/arg_parser.cpython-37.pyc b/vega/common/__pycache__/arg_parser.cpython-37.pyc deleted file mode 100644 index 4a92bb105dcd4b7ea04dd45d8270065284a892a9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1385 zcmZ`(ON%5$5YEhcboFC;S{-(FRir%#TO!@+MHpn+ahX*?yNBuZQY@rJWln9?)T7O; zUf81Nv~%?rh%n<#@LvcD-o512lYe1PMr8LmvT8v_Mn*=)7a8$oeX_IDBk+|!|Nix_ zJ|TbL;`(zzxer5s2ttsE2r5`XBPyJfMoe^sE1U}&IiC{Y3IBo!U((!t?nIto3laux z|4^mlWvsMRlRqH{nVDg%(o$9p-q{1Df-XpZ`2jVuWO$+@EYM0EGb zN02|}9QI}*d-SmNn`cXTFmAialR8VJJ{Z5C>$dmKT_I27Rndg)-c~yp-xzhv#&=ux z9yo{Lu5q70A5AaQS*4q}N~A(6MSYI_Jxq`u#zi30*#23BFtOs;Nd(5S`X#4idl%BjC>u+8mzZ@8rS4c|iyn0p8OJG(`uIK_S&?Q0e>CLdx zp`-S|Qz6OiQ)xn=A6oWN;G_A+J?`+$Yvh|E@T_`TCowkS18~$xYI`BpzWeXA|L^E5 z^wcYG(DXPjYq2V@Z3E7~S;fVsiRun)t2a>GBxE(Lu6`d9{;!B zzY9x(@Fx?pM}@&x@J8Q4g9J-JB9c7O5|NA)ltX!>A`R&S)soO0DkAx{XvwIAl-Gi# zpfbuqpF8>S@{c`3T;}5 zaBRO!CTYoy?EY|MBYfocF&76H#o(goM;F${dsDFrrs9zF_os+DJmCP9H@9~B5_1m^ z1VvEbZNMA-Btjs>SU3<5^JIzjNL&5Ly?+dFF2y6 zHni2mmY&MSs&Fcz!l6!nJ{8|cZ#xKgL|Avj zEBM8UA~GI}Udb;{a(ukxRs8BC&&M~t8Gm-7^6@Qi&Yz!Xe0?HdQ1MzNfG9XOam1lfz0tYJsg{9o~#D1L} zLmWQ?A7leFJ{jU=6X88)&Xf>mFYb&qnL2mnrm+jL}M>V>i~eo12~PE|tN6Rq3v^pV-@L%<`kw)26l1>9#go zZTnqZ=RmJzHLdD-6^674B&JXksk(^+OdPK!5uN!Ihg99#-R^W7k7;pzt?{JU-mtfK zx0*G{GGZJ=A=&f7V=Bhf*lxCao$jNjo#!;~4uX(iFsSvN7!N{n0tl2mU8qkNtjraYC>iEdMx>rZyq-Fh z(!?~-d~s!B%kL9AS532uns;~&^_rx_pL6V&&U9JV0BN<%El?{ z`Ep)1W$TppX5`!p)>$ZSf1js)6K-WXOmWL!l0R;DZyaN3%u zXWfPvx?8}7)ATjXC6Gy@JFY>@J#^POp%m}|uWCP62cuLgX|Vw^#q!f>4<9dI`Wrz* zV$uB+0}1ci^YsVtGO;(KS98Rk^*Q(4IpM$YVRO79-4}>MvRf?GT<8nyp>@3~80*E# z>tj0?rY0OuhUqYTef;z#9CY^U!$j@pX*R5UMC2gOJKmejQBdfBI*Z>Wh=flEVvtP0 zVZa0K^QHqh;alSekMXGeYop`q7B*}8L@!X+rtSuHbl#MWOGArTpop}tOKfSHNg$I( z->yOa3s%Q7m3bLo|Ptph-e_h_f%?Z&DcKtP~|*VWlIT6?Wp7&vGu^GY5Oz zdhXmM(rn0xqFvLJ!$3{ZMkPoo?t>poN6pe;Wb$xRg|zkvBz4@Xks{Dc9uE_xmv9Sx z3*G-@=?g-J@FbQ5Uj$FZ0=?E=>s_c>#+jef=@jmk8k%`)Y8o$+Y|PEgE8N(HYsmQ) zCa1}fj*X;eAj#5`^Akrun_$zPq3etp^Q=~U^m0dquj2klh1nK3 zs`;b93oU)=&@w0V9bBst{nX@dP25u)F$ph$0GsD+BCmak)4Q;@g5%f8#DqPXnG^s@ zj*|#@1f_ILKptJPU3TmqJ}8l(Y=OM{lXEPNz1%rrfK-$oqJY~nd1q24TQ#J0K|^^I zHWO4{nyZ1a+1Dm*WC&Y2DC6|z5-K%OT^)mQRlK&WR1NTwO1u9#zOQ}W5%75*Ua#+C zZevpMDf?31un0WH0KUNB8SNAgEcs(Rmws1}nK5=R#tz0pKg1*JH5WQVE*8am-^F_> z9I9V!O(eIFX_8<5N1$)QhNiNksSw*#htlE2w9US$HOK%(36`ngYS@k=2W=KhP0WY#Bse5_f zkMsD>IahCuj1&}n+JE`yhi^ZpDF3EP`q5Fig5=#p#uZcHjH@kXGW=>SwXK;Nqc**z zw=-r&mNTtv+c1rG&djy*W`QX-+b`Os17?oQD?H1MM+!IW(Ou0PLp{gysORkr>Sfdm zyoh>H_MhPDs!}Zl>o=^9wPCmIj$diEEYGV5d&BX3Vfl{RsaPFex#J*fx}9}rqn9*y zh1<1-@7Ufe6_oEe-1aK|rd{zLbnS|}o-F0{x?NZJHn05$cT|?DEXr})?l#@-1C;f8 zr`N7(QNi)ndz~h27#WTyj}hfPXQN~JJz+O#Zu(&;TtV_Sk=Y8UrI;#brp8rM=bD+} zx|!t})8JV%#|<;jb7q0(%_1+DC0;Z~c*z{)Bjy+%HOqX=Ji*K6I6q;Y%ED>i7*9%^lbvs^4Ezi1V^9r}uoeq9u3K9`&kpUW_ zQ3Rt&q1Lv2;WUF2Y2EGmPTL7EWFl>9iH6tQw0Wsyi0vHkn5 z*fz(y0#qQv@W!U=Zkxq;qSv*Wc4Q#mhB!v~_&B)iQuR!fYcx7m+iotCjQ^jRtp{Xg1bXW)_y}D`xgaeQjl7*32*8TwA!YV4|tCxYAgiy;h&QwTPO& zxV*X=YJe+4|v@)%wiJ?6tXt$s5y=IhvZW-&76X4Ym_>uY8nXIWaky1a73%+1a$F1|H0dtDSExn^c|d9FTd zX5PNFyfzaVOZBze%PZH-v70M6G@XC0acz0|x><}jT)jCnTaOIn*JtMIQ9j-YXa1cc zp2d1H`|7pn4e{!9u(dTU?5;a)^?bM8n!aamSkq0n-F7?EQa~H)RuhzcQ0qPr1}0A+ zc{OA?rmH%8+OM4ZWaXq4A6HH&k{*q%N-10=eoJ+F<$tSFJhURmN@%H>MF3I9mRR8} zRWqWDox};s#wk0A%q&7Cm+G@?3(HHZL+Zq(voz`($fl7z!eF|zY&mk30Xlz*A&0Nx zD5h>_?Ch4o)peC?Tz|xNnVCD(mf>012IQr2>x>&bCtC`3hM(klUXZmSYDHdpq+ost zVX4o3o8 zdZKfwa$i_o7^p!lNEEm%GG@>WCvjV&yx|Sm3=*)aln|er}~-gtmy8s{cNcC##RoaVbDWG z67Ltsh$mT*oVt=4!$U+QTBd|-G+UrUO!VVC9o$$>I7SYlnoqNjj9LoL&sv?1>kkEk z_qr{oY58`t%gYb@)oNbq8li5=A;RN#V3X{+5Uz-_WFss;%J#Z29k$e|s1y&j1??*l zLrj6}RW-`UqeKPyBP9c_Kt{4R(l9v6lO2^EL)DFpgh;c*K$1=+JI)HMtQMHcN^FAV zSeXSc9@dUpdW72LmJ7q>y;(~Zk?U4<@e-C1FH`m+WmU=uRjXRuFzhGzh;r$Imr#=! zqU6it!Ko*1+9Z9V4~ZR><;mYeA%sSe=!F_Yk3sbGP=yG_qK6g-;|2zbLk4Rg8Fii-hTMFCsskG-qycyyZMTs;G9OC* zs00P4pU6lO-~mVyGbu@C{se`cIUj`IWpm2=7YQQw;l>{@aprEpSAh!JuG-gqE!20J z_kCZFf6NSWEjf4AQ zU^`3v_AZCYK5<&|axil3@sUtYHTYuA212&s23PKH+K8M|D=KZF<5irFjHjACAt*GH zdb=UqUiS%1e}XCw4P;C(@t8S&M#E0|wmH#T7q8N{J|&479owWFGb} z!R%#Ye35YTI+l-3*lCm~7D?{jNQGFU&Uevi=A`i=Gp2zu$L=k@gCV~}^5pY?T4JN} z#=+E)+osE7fx&&Vm^A>N&rwOlbdJLpZ7FTmr&j|NuLgKQ=<8$+e0?jkt%}z}Z7bVX zI}m8J7-*Rbb=i`xVfs0q+0UnL4FtLnQ_>6ME_o54tvC?!VgIi&GfOjx?&K1I%BB4g zS_wJPI1^?-S0niCyq$*5$KDpmHI<_HbxHWpu^4zuPjVKIdPGZtTtz%ijLgMxFrpJY ze4t^hp2pvCQ}A5V>b&Gv5cwy8e=WgiV*MKr`@zNI17bCP*iXhuT}KVN9_wJLtJP8S z3`K3VBn0;mOn27Pnz)D)nkVIeVYp8t#1GLEOXxEwMLL~CoTA1GWztlBk4hwfQi}x% zHKcs#;=Djzv_-tyc`6-4VZw7l=g-jR5t1uV+7iA6xL-(Gj*Y4ltOUU>tK%$~I3m`` zG-U)@#3d9t8f{gv$7$8(PqEKOVv_5BQ~H~H7F zs4UDt{qakF_qHb1q2M2e*>xqvdlK(k_M*u&fpZ{6@UFD~k^)AwFgkwoanZrrs zsdw4O2-DDt=dMSGTKoEgGog-MK1k01R};7I!`J9s9Gu}dPk)Av`mVBbZcj<5H5pUO zKN%K61JnZee9G?r{H500sHXGK<)SK)rAwID6G{T5)9B z6gCAN>3HDpAaK(#O-DM3Bt^~vtKX;XSs!M1PeH&&{xkbby{oStl2`;4YVFmCZAa`L zw74LnQLlm5#aODzAN;CtnLAA%ZzSC}|3({7n3z?9Z$Bot>CuLv#Bk=78fE{|fD#$PW5+F?ggvQ|vPO)>SPoZ3ft)0YI#V^$+kprZTdi{rY z7$6lH-yD+k`mdm{b178!6^R=)h$0N`&iPOWikyYPgMSETM0)=R(g^`7K|k(JLH~!o z{!tAfx1;v3PcK^(RY(sxhHZnYrxvxJP`xIw%YKfsvCkOsx}#H%#z z7<`Yktl&Ws#oOW^7~#>)6piAb;LH)BKQgq+puzI5{G^As*exu|t?|~q>GD|8Cg}_p zC?khoyhhp33Q38)MQs$39XhK#wG7wD#LxtDcks37L46YtOdBIFyQO(9~G yC1JPm#^K)CqB)Wq^o$q>GD89n4jtt#$1l}yl5LZLrVi^!R&qiuXys?i+5ZB97{2)c diff --git a/vega/common/__pycache__/config.cpython-37.pyc b/vega/common/__pycache__/config.cpython-37.pyc deleted file mode 100644 index 91a68147b7cfeaae9ae50bdfecba3ab29c644e9d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4942 zcmb_g&5s*N6|d@Ux7!|%GnvgMo3MgLSw4({CoBS1vqNAv*6B>DgVMsDEwYk+wPepT%eP#D%Z#B z_kQoQ=IzzhI>Xcd+dsO0d6}_)(U19KVen17X%|hvIxOIU5%UgLz0omLZ+1-8Tm4GM z=6X-HQ^lMaJN;Ux#@Tx;u!72C7F5FevDsM(?4bIXbsB*a)X=X6#%9Df`7Z(>(-&At>E@$j5SO0K( zv7e@Hg#Uy6P()ITOCq)Fbr;X6eLa8R4tHM551fqJT44U%m{w#xV_D_c;Os68_6O{Jc8?AI ztWH*N@}F{L*2v5g4l}TSW?W`d8#AjkGp)+z!MWL5P|d0*I4{p<_A#G2YIZ%So^UvZ zGWN_lcxiUwgxz9TuVDS<*?MNv`ts|88|s_&6Al?J7W`fTzZYlV7hiimr;fZ7tQ_;7 z^23V?gI5#=Zz#GZNrm~#K|}Aj%Q8afj#@QnHLzfytCKgs1>MgLNgL*)NbbA+Q7j{} zlAbbPWn<}F^V;*;lS-R>bYiu~{TTjclS7AMdt$bS<0t=q`t+$Fl`iy3PFKNuoAm0U ziCD(&zw#qv%BM!gX9zmbot0tyE8~#($x@X`(a*r=*Oe* z_e`vS+Q(nm*b~<`CI<%_A{-_g{zxYMc;jKX=Wq0qem@zYqbLLJKCW+(sbYF~qfP8k z@Usn@Pp)d8XnP)ch3B=m;-u%t>5caC1Oz89n|!ha4iHPFyAzhe*5nd71v&a4(IJZf zga@^*sTQ*W6-7=a>wJx0M7xf+ z$sNquM$>GdZ(yd*FB)|k8<^*lb7!=A$|gw(-=hz#p^>o3U7qpZ8kzA6PBu6>y6MO9 z+!3^oAl-;ij>N$A-9h+h9!PZxDx$awyL#8=^S99sdNu`x+RcKB&H*&v;>ti;ruYmv z6Q8B#bJUQzFA*Zu!9-)vdmHm9sgzYqBI~uDn#U6LB=SwXX%h{JoN*NtNMK{~hufiq zo%%8)<5jYzG7(`hr?Ru$58csV81){+VNfh6Z%oVL>#U=)f})4qgL7Rq?yA)7ru)e# z4#;z8DH+M(h*F432?TIog0r&Pi*kml z5v*vdqM$;xhjv1BybO@S9*#&9h|a~$p%ll6rm0!O{2XVNkr>gm0SS7`tl62NVt_*U-v z_h(^KlDI09ca@?2Z{TS+R)h)<)NNsUuNuqc0aN-!>#pKpysltBFY9K;{r3gZv~@BeFlizv3O<@)c|5rnZ8X}l20 zE{U{QJ|z-WRLU~eF`!C$W}qHKy?FxHQu*tzZ|x!^y4_wf9Jhlo9MYhBZJEJ}xbkfF zDuwTU*iSaylza?|o>OtVG~azD8qJ+eZA9uEjwwl?L^zu%&-&UaR!Wh~(viy|Qn@q# zyb5pg&~FhdXMGN#Zob*`ezJhoDFv?OdTW(ZDMbTRL3OXVIARPcha*2OR;r+x^iZQv zWg@;v%?33TCB;{$p-M%3jhZ*n+L^kB(y*jjK)# z|C-LKQ+FIC|3dMkbc!BI@f{2(l1k)o@x!BUqm;%)3F+xZk)9Z?NVN8@G?Bn8n1}`2 zH*mQyaS^Fx7Sb187%ouCx=emSiR%KcMLNKQ(8@E{?xd)Qig;K=$odX8Eqv-v_z_Xg{~dBZQ=tA@gBk!Tq_SCN%#gltOS7=A zf~*fX1Aw^2GE3iAk6y)$8CZ0Ogi;z)JF~M&gnQlLHq{2x>d~z^6|7%V>s8SHFso)Z z?wwW8q`q=E&?B1PV&7+ngS+goxu`UnMWs=hY(Nc4VZ+#u2B-gALdAm=VOYYyo%8k+ zT-TmHq5Il4qOYnLD(NteWL^~^Xr*DxLE-O9Ket7g;(A!Zz2YGYr=5yNggaHz0w!{M zN8l>6e?}E~ePx)a8S2fs_&eP0}`9Y{VyM>Jl|qXuhWZ1{6m->Z+q{i(E}9T~REtr|N3*#@yM`SCtzy TrCic7h!?tN;5FACIp_ZadsYe( diff --git a/vega/common/__pycache__/config_serializable.cpython-37.pyc b/vega/common/__pycache__/config_serializable.cpython-37.pyc deleted file mode 100644 index 36ea9045400cac92563c38666540ffa9fa04549d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5199 zcmai2+m9Pp8K0ZSpRNInNiWpI*vp!?T8+*n% zGj6l)SVG#BninK^I;8BpZI;>89Ten0=CYXZ;sF9yZpY( z@tZ3vRRf>lU;h5d&n_9pKj~%haX|bOE&VGxVfcnHg*7mJlb@Dv@zeGlJncbg==!cn zbDTkW==t8T;#crq67HZntob#j%Y*uG#b4oh-k>pD^;bbxM73x6Ymqfuk2dC}e?inl z9V46X7-B_Kjxi&;cx3xqq9Im~4gZo@6YF?xiw$uB&&y&{RPnqbE{ZKYuVURLv5j@x zk$JF#mDhxI-)LQ)Y~M-ZPa~N%JA*Jyn_1FKGue-O?XMxDvDY${C!%Q7Nk)e)OL_ft z7-bI=p{o6~JB~Y9KZ(CKF%uJAY}M`eK-KTiW_QC*mdHcAS3eC0eG$m#yJH*lAv$VwsBas0c7^HvS78dkx z3kOn{A{X+zqI_%~nSNP#q&eTiXjRmX4UAPl*To9cRnS{vRjl#s8s65$2LDnAeL-w8 zy~29CBr7;}(qt+23o)9S{pNm{Mp~Wiy$lQ|rk+di(gjshycmjod!6rcP$%P6kT;ny4cqbBHEtR+<}AC*I&0G&;&1 zJ(fGUmD{(B$L$A3=4R#GI)YY!0cUc3VGOX%79bdeP0?rbhVrnz6vG1NB+7jf;|jW- zksBXxOwHV!!6oKq>g49{t;Y^FHkB0wWc@S?WtM){&mICoR{N%M!Yq@2G~^mK_~wn> zp1ip`IXKvr(J0vs$5}EQ?0y>c!re|X940XygsUJu0h#vb@SEzbK|jqfXLS1oVuyHh zjY<0sPhZGH0BdJSK+?2t4U$ecNNxf0b7)JG${8BMfw3!PzY zm>Z^JR`K03>*zi7lPmvSnICmXSo)OFJhVG#>EF=hX6og@NNyYe=Z`G8MQ@o2R04Cr zx}96;&D@yTxji>c;5UZ{$h%o7XM~$rdEa>QG`E5L-vI9}#)0Pa7`QJ#DodGpvx=~0 z)!dnzyhDv<I1=*K4LFJvw!>mAsTS;1hoa2;;qTJGl7+$z@U zQCGh`U~XFWO5+<)|D3oF(O?MJTt7Scj(t5E~?d_p9X_|yc}2$21zJR4V52;DcN7G zMo&6}v511=VC4>yUN4fWB%=LsPr2Ph4#P~9(*mpnIZ04ptSN^aMpZ~g#!{=S?9KqO zS%g{WSJOv*#B1VCxmqL!n_uHR5-{Zay8cDKY0#6)Yxuncs*T3JX4npc&jIKFLdydH z9gD!|nQQo3Ol_K5)|xrlIR}(WK9~MXk(V5hQAuu=83!gXXU&bw7Upj$|5=l3i#V~g zD#yMe3ARl6Ev#%g$~j1r7`ht|N2w&wQI);X+i@omk;E+WzcT`oK@rhp5d0iJvVR-a z4S0Lk1`y)!sv>AvaVdO0+57hR9d1bWf4{12oN*J$e2du1-Sllu}8t%o%C74@YewPTY&Kv5cF1aud*M@^UI97J-ME zIl~8q*^f@}LDdLC{lX?3yLrpIr*uR5VGedcP>_+;+@v6b1PBe~M!=)Y+0{3>wkYbG zMD~0AI2<$y#ciac3qi0L3p7q)4YF^cK2?G+h`*;9R~L(lWx19C|4R~u!m?pkDSF;G zN1SsPp3<+4^(-AZBTKhImM4lrU19s3WiaBxZNZ2ZjG*vctd`_t1&GgzP?Av`ef9|^ zq-03Lc1^AZ|3m*Y?G%IO0buo=C~KzU{RJ5Rht*XT`LT#;T{@m&Z%?kBqpfE%TjobC z^KPplFJpMQ?OQ77kdrC`7Gd_nEJtRn~TKostzGP{(K6G!xTvQ2GrVG(|Pkn(DHe7suW3?NO_%T z8@WXOfT$}0!8sj8oqo69DNL=mk}uF&@>UMvysgG2AySq4Su{lMeH0z0El>WC=22W% z)f0FrXz~NUpmna+M-Xn5jUB+zc>gSdD3U%#ODQEACq;M+zl#SgH_bY+605aRNtVHtti3Fu9reW3j(nc3I&zh_+ z>(lmPN0*bQd>wP;_o(|mb+n0EDYoxK0BmstorA=4k0x=NAuD`|#%YF3ZWfMW6H~bO?(4rFc@OuT5b0?6qHbCGUBUKIMM{&a}u( diff --git a/vega/common/__pycache__/consts.cpython-37.pyc b/vega/common/__pycache__/consts.cpython-37.pyc deleted file mode 100644 index 1698064f3cf4485b8d494ea8a87cbfd18910746f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 870 zcmaKq&2AGh5XZf{&h92nDTPu#PPya|L=(;xLM>{bm7tZH0K%8WjZN5gZKqyuP;)KM z!JX&e8Mye$DX+kZne+n=6<+zbLFEfeKluA{HsiXh3B2iik+iuS8H8o=4{at1Wyr5rZ$5bm5Cd!V-}VNT)tqADop= zs@BNp(zT}C=KQFT-8n=doyaG9+Vyp%d|c`Lw343UYJ|I2BLHIyX4wKr!UD)H6akAw z$f$@|O;8q#8bpjCV&a$H8+#({@N6Q>S@WMocaB%hLj;sh=2-K7C-rHTX|80J`9_v0 zBdP-JW|o~+ytv#+vP_r}xRc@G_TGN~(AWF#wqL*99v&R}Te}B?VRpS{D?4(Uk7R$! zi;6pAzZ1I(u@?CKxHq;>d$Z%?o|O~RxmE&P*m!?01$dW^R`dO~f){4RH^#j<}C_ zfOv>_gt*r0!n%XLy8t{<-Uj?t<5CmcovpS%Y_!!zTeRA{o9#E7?Ey4+5!6!&FUw;0 yEq32h1PZ*k*RSnxkRey{rLl^;fwNHKO=Jd%)!1DQ+`M2OPXZQHV2YR&Cd zYvKvu^UQRk!%puG{?U z)E(a*{+q7X$|bYbznyKf_D0>gZC^G^2Vy1&N%WnEv?t0Q|yAQaYT#uIBpcTb#xaLN&bh{oFkCkqQt&T5TkG4g_Wy8?JHYTZ! zMvHg5{*i;x@S|B@u9~Xo?_jbg(lFiK}>372=uNQ~yVCgM?-CJsg?RMB%A|W@zZdB{;>LD8# z^E3%IkH#pL&5D`KrK+saoWzi+A_tyK70;q`lq%#|VO}*@HBiWb`MR}d?>XGh^=+u? zBItbICIe7JJT#{S}O_s9^8sZrCKA7IY)Xqbx1Em z-lLWCYrGr7OZI_ITb4Z82Ugu-kXq>8p1HFe!&fxNzID;~_WgTS-|T~j_=LTcmy2oN z6wdc>k9y~g>x-`%BBw{&D#)GK*5k#_W)4b)S6&4tjj?A##d(ZX*s|pZpbL-T-jEL% z$)m!Hoxa_7PUF35G5P)1-~yQgvhx8VT+4=YEVAW5Z_EURi1_ye&<)PE_G5`i+ z14;7nWkQINKYv7kC8?=ex6Gwm=IblB%;k^%j*mP^cMHl6{El)6UsO(dopqQ@-0Q+B zMz-V-2#gz!GFvL&mHt|5M>(ZmAL^F4vz-h!YHh77%v&G z;TMXulbaWN?XK&&;p!&fmb}h$$^PA5)atCe-Bxoe0A`LrQ$ujdj-O>W2EzFOVjf{k z_L`HxnFj25Z73q;)x+wENNZygC~JFl1+JS(TFzpZ@;Dmhgk9(*uisXVz|yM@yRNb^ zy($hTQs&x(^wkt`H}tD7knqUyjbg=QC_OQh_!w8yA`W~=_LUTeV>B+O4u1!`--yUYBfKpmCr&L@@_P67EDpKHICj#w?9V>VdEHj zP7znCggPE+C6p?rL%5kKVT>y?d;)92j_5^M@HR5~K=*-2H!L+akGc@OcNp9IW?-=pyR{5qW%cilH zH;rV0xg6r^Tk%{!zi%EOEVC0MTI}WD$-icN)96@mW-_u|vVS4$?vev@X6CNRu#KFp zC!0Pww(z4E+K*e|5eQ927Re@{QuAxC%xr4p%OZ-G$A{v~Ys`OFl#x8X2Ka-$VbUh8 zNUvvcxGvw*gQQ@vfh)QfjZvI~ z7g+FviaBG=AS}$_OEg`_t&`^QF%KDAY33mm8Zr-|I^;+Xxhw!y$Qo*}O3oF?db30j z&tTclj(ZENH9sf_hdd&))@4|0X{#(hfUVkWtHhIIud(8C-=45l`yKm`t>%;OULeaz zgU$cPQssjXlAK4AJThr_P$CQyEmcT|tra)fRv)}PVdebJCs^u~n%;=xZbO9I9fG$p z!{zw)5G~~k)O?YeN2&P|HD9LYF*IqUlaJGluTVqbf5N6zNm~Y%@_jswXmdl$qV$4r zb_TXIXO5DL?2VN}rbHPtgY9oXAie=?N80S< zQ?b2?SSSF9NLdZ!Ac%^{Oz19yEr6EjVqZh66_*BOkq1;VEE6xPmJ0PmK@{JEnKA#g zj;>5IfXZL=0e@IoiQYXx+I$161t6BaHxMT-rV@9B7b9(tH7t@cD9+JuvdelG8Fnj;PNn|IftnM!%FhR#Z<1tDJgYWbdfDnaYzCm3SKjG5ED9Ut#$UUT2delEH5o3MOFn-`Px$K zx4X&0P<6;e;f#o8jQw#)?nhO5uGdLg;QRXIW+|%BY!oAA6{pjOWTkRJxD9YGM@UHH zh)zC64yg*>YD90#8Cbq5QUOO0uHyx{r^*R`wYRPv^a;9AKs;fdMf6fntrv%uP%n?P zsSp!TJTmh{b~HBxhej%~pJ6avEmWvhSfDC_^VJyy11dG)&{lGWOaZ4;waM2nHz_30 zhr&V@8h!w$;G-|2ZSFY){=yEB+EDdCWuNv`0Dggc@K1t$BtLoXm-;v=i#7oAd67fT zJHJ=x7Y0-C+=9pggs^q?%Up$h>!$hs4ES<`N+0d!GzFb_2Jg)RMdl80JhoTt7pVkB z0`%tYLf_duHaM=40GuZVCq+@GKAoo>15ZzF&c~4u4##}@Wp>7CLr1fj?+Zxu1q<9%yzhnPAriRM`EKPX99I=BNLlDlqyI6k~Q0pCgP zAs^$z(m?v2*v)3+{1wj#Y3{0*qSj8W9lqs9@;SVQWU$$ZgNb7(c^Qv1c3dQi<3CkW zDz_~WgzC99z{OU*aOK4d&o?f;rtDTHRz;dOz~LrNuVUHi)(g_V-a|^SrXV-Kw`}=u z`J6o{r`18aN;!loUVG^ZhX+-7?nNdznZSW&M=7;zKiKt#y%lr(Kw z^_GzX6LkhhBlGyu9V*cmEaWWN$r?FFmP=)63!NEcg*wR=&}Ia5!et2q>B|BiZgP%` zl*Ar5RJI?`A&CP>LD`OjPPjYVb6C97X_vFY7MeJ6c|F0|A|*!EyedFwWJ#mIp%u%T zEc-jL^qR331j93;kp#bm0hEo`Cxvqj5B`EH;(u=7Jc#gq#?1czP}8c)4kqY;!p6mE z+8GdO4GQS%&NIz>#U9M$$InH>`YSEGDE&cZdpn}&3TQcHALR&~mkb{s#sxn28@@5D z?6L_m;7JpR)i4aQVG%7!Z!A`QBufrawY1%luqVYtG4^k`B0iugjl`pizpA-XH7`>; zr-}`Zw2g);HyZ6w^l+@+P?bjGdd~|e2BTMKz>=CErug4G`rjTL`}x#Kd++CI9S=}L zRhxX08ah~KgO;Q?Ny{2rDhGEuGnTZ>lenV0&=e7}E0xM@XDGE@0XxqpOEMNhCWp2TVJ|=p+Mk#hZ^lp6Lkip_>rRcopU~t=U0=-u!CRx z_kaBPpJ@G?K=ZTFxQ&wk3zh7!j_k@_)MZ`s?6Dr6R%FE%v$`~9M|Rw2eYE@1iJaJF zZntxO_t&6q0*sZ1}U!YPQII%+xWR|j@^`Cd1IkNX0?v(Zyz3Uz3D(~^&*--YMciBifY%E7U++w;UkehsTr)t)V@`Nk-N@c4N#SNfg#`tjqrR#`e1Ws%0w{7kLHJV@g>P3Db!vUBZV z+CEBtfXeIIy~FOfYj6F{8~-B`($a03@;a(*k;opiRsZu2b5;o*dP`(i-(yeQL<4K0w+ez%*i8hrt&MB>(*;Dn@GKZl6$Dg zAiL{yO>esqZk@{cqY>BIub&O|@IPR7LF89oeM!Hs-(AeyHfx|9 z=m&yeLv;p~1ZkNRY)b7C`<8D;TkW}t@6S*gdn^i$rrvgYD30H^2X@$%b3a|JlwSMw zv$3PH!(^E*X8l&`{YagusFm8ASn4N4jm}_lwWgkSTGuf0Jz71unNBvVavlkdV{7f# z&xYD!=CqQsi}OrXee$1brFMknevXopzH1d|>T4(W7Dn31K?6PJh8^ai3^ruFk8up? zFeujhO)E^3E6&3<&aA(T(ky3ggor1>dF5n6i#TT^bDnY75=!RDQi}p#V-8jjQM<;8 z@4G0ukE*t=*|M6(+O2F=x(NYUX~N*Ww_JuvSe&n?P3y|6R%z=WVm&qyn|XX0MvD2D z@c&2XN0|2;lt$Osr_eR^wzF>)5baZ9ZdcyJ0*+HFFmGCy%&~#a&Iu7xRNtkt8b0Az_^tkIbv5Y4X$Ao2_mqB`}*Z zf@(mYHo1j404oRAI?h?<(v7VW2^-Z}7$_c4h*Evd4j$d*cOE?8AAPvs_kVkj-~Zqb zAKW*W6mx_6zBETXbLu3ZvQ%=$4v}F58#!<#f>VzBJvgl_xh&YIelzDu8Jp;^{!`Kx zGBZ&KzD#v23bsRHiXs(l)XokrJC4#1PF0JXvMFDR^3B|gmNH+*sT&daj~6BRFydF zKFIEo=inK*_{wQtffM7jm3E>T&y4N)Cf_8lhr_Tu-$u zn|Wcc?Sz1H465#fVF(&VDWMpeB-A-asY_h|_Duhy63x4&5R}Q)BOAnQ6CLfGSc#g& zYQc&!=dt)!u|kgv3t+{rmZdkK>I)c+)>z{+A`jO1f@tRwPd7Oiau3(I^0u*G1(126 z!!!(kQ*Ez;38~yy5abPiS1%v=c z)6>b8BNJ>vVH&Ptd}bVA;b8O=r9Xp1{Ds$y?f{(h!Q2B?mXZ*=q|1!<$`weec8 zdq<2hCngMv7J;O3|5q&bACzptm+XMg`9hRzVk{{j2*N}xOo{cxddVd1?U(F#bICb_ zgtKJea&hV|JmQ6yhSMK%4!X|G*ytjuM>D z#3S|%%RnVV;H^Prm+@?~-#_G=)x|GHNvdQ}=0TiinTktVw)a)%>QQ1ggRom0)$~d2 z+{&|3JuPdyO+#%w*#5e9R>CLAsO|fCngXWwtesu0#z|4q$^CnG@7@{I-uBa2&B`Rt zYUj?qZLM=18k(Zm=xeNS)ugi}Yq-)LNPw?62g11EE}G${&Q{T8O$FH}hp$!DcL zO9wVwyuJOPw&FA|lzyjS+Rcul(D)JZ#@#nuth6xiz!@wo(g0AbqZU^ben zD60J^n&xshrS)1AJ=q$QGC_*k@Gi>e~{{aug?{NSC diff --git a/vega/common/__pycache__/message_server.cpython-37.pyc b/vega/common/__pycache__/message_server.cpython-37.pyc deleted file mode 100644 index 146138e4f1dbf4cb8b2162878440a257a668f9bb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2696 zcmZuz&2QYs6`$evu2x@`;Uub?ZV|u8ce-^|6Ctj}z zXHgtk+An&ue%uEgOnAH|`T0PuU+{QCMjtQ{ir^Exk(+1k8IQL_Bzm8)cw6+v8ulB) z*=M6cv-w+D*XcwK_oX_KYVRwU#U6~f2_8LFG8LeDbzV*isjCtXYai8Bd9TbWA(dHQ z33sv?iCzvD#Cw?a9W0W;KSsX6FGslez6U#eSjb}^mPR-UL=P6NDT>ewv+B;QjN_CB z?H=a)nDvjbzxxLb#|7{1V zo_r6!o6X~eRHunf>!Tzu535mNqNiyo3aM%n&GIstS4x9QpLLXHpSP5p|L@NekXIlPJ`?qDI2qYfN0qi5j*Os4Z-)>>4E0w&IMaK zt#j^bKJ_m6(rw+Tk9hLW*_VjwG{6jnE4)uVinor=d&0lqoaw$^YZ2|~;Ls7lr_Rz7 zVe2L9t#`h0fppQg_K`BZdDK_uoGpFMnj4TuMUs5$|CK+c%H~V|&;AqkdsaG6@QjT0 znw^J>@_MQsG=a|McgN$e1Y8-pD#E6ESmGX~M5|oZ&G+|ps`PVN1UA=|8fshAdn3mL zct4Tac-ZPk%@2Q*=7khPT@BThZ(f(%Ruo9%HnU?BzF(DEKGP(|V>hCb521@=)$_$kwwOcK*i5NWX}h#n-# z@ggl&Z^9%IRhA@b6EbQ@3w`j45d>dFKH3ho3!;e?sT?3G+z1oAHgf|yNbut`&R23 zwEnNvSjzW zEqU3FWA6!M=?uN0|Df4j6~Hi6DnF6p?(k2uuB$KA8%Tl)4lSmetq;o+bbT>Y60N<~ z7XC`si$a@lC4xTfW<}j>Uh;P3I5d3Ecp}y5Xb|_5%ubRiEsUnI44}xfT&>gk$T0`Uz1# zpoOlwE5&H*@HL#hi&?*og#o$|;2QCp+yS5?ejED$Cl22LkZs8v!8E+UACfP+%XRw$sfToDUiREQ zMG6tu@DWZwp4V|$FR~070_4dZ6LwduFujAcmdQNTPe)OFlP<%(?lRo79M!y5K}P8}pr0G?zB}O2o6#HJ0?9C)$N&HU diff --git a/vega/common/__pycache__/pareto_front.cpython-37.pyc b/vega/common/__pycache__/pareto_front.cpython-37.pyc deleted file mode 100644 index fb074c396ca5aa73eaa96b3df64c0abdaba5fdac..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 723 zcmZva&2G~`5XX1+BXNyV6^Tm_hpHzmA)FJZilTx7r>YW2l_DWyyxWjqd)@U0C6Rqf zIr1Pp2hYG|ublb{oS1bA1c|X`c4v0?Z_jUiIv9imt@!r+-3ce;hfjVR!^t!B`Urz0 z6Cy#<941r-l1p(76DE7Ie@!M_($^%~TMy1tqpUu>G`h5pZ}1p78)wvyXsF-$a~$hO z7#h;j2F!!TE5m0hsgIX*h;>UEx`CE8q+vsX>TG@YrLu?B?i7XVuDuf8TiE`{^(;4jk@50LG`8|2S^~y@%d4ZgA3&Ikt>#f`vIEuuRe~w57{8 zmN&407v$}$mMu^JWKUXd1)|_XJnMmc6dN!j+xw@+eND24Htd+3ldI7MxjMQ)kcf}} z2cNm6C!}3fY8>^Qs6V8Ya1_+RKgdhHnz8p3|}JVZnxbgL1JBjNfFHf3aU61;}72G$6&V1ivE0caSrBX&z7MRJ&) zkqt8x5Wqg?kZUgTE%^cY1<5(LT;`aZ7P;jTU{Cp~dS*C7k+K&#h4%DR_jFfRe^p&o z-TGp6wPxTq{LA0Jeg2VQ{F5rluK?l^o}i6F7#%~H!Wx(zliyay61FG|>|vo(Fb&Yf zL2+2>l!oO_*)*heQjyiSW~cVV5G7H5V~Da`d2M&r#tY1 zpjE4ds0Z5*=;ijr!|j@b-<^B+9RDaBAB9|gI*zmb2@1 z50A!DIGGWr*#)5-J0ou>ALhO#JZF-$y1|R2HUHjnn28qH5nkvmWu@fUSK(Q#4thc8 zt5>|frEG)@y6%IQzB-i30cx#&;EyswZ8<|<$bm!s9A7yHV;zqg$3fkd<48!Z5BDW|ML zp0%LX=KyV224M!SWN`9MCZ^?O)SqRQ*j=*9Ot&0jix8)_P!7kbs$TShBX5va2mZcR zTTYJNSPpt4*}~VPebE~NdA*T?Ud%7)LUU@R=2o$g*;d?hArtw+8ytDC<9G1ca?+m6 zr#4nTKni;@*kQ?HzgOdm_8@L>v^(~~18#z7xw?|U(ICt{(f-av*)*f_mxv*VMNKQJ zQCND+ktHfWkt3YQEW23 z4*HU~%=8B6E8;5C=RjW**O@*qTH*uoAvj(@>qo-j){CG&7B`sQ1id9bVfvE1)W0k~ zO>1RwQ{3V&SHx}6d}DO3if!>DaR>aaaY(-(l?gatIB7jaXuRkhk(N;T+5Zwt?3Ji$*;OpU(LH>ZYcg?7JycX4W}>r-x-TI%N1;x=2| zpW2{G!bEamO;&$8>V-}$POIH4M-|>&-M#&&%v4WAE3O-puB(ceEUKpF7S-GtU5)Lk zCz^$*5Xixvs-U}~)JfG*{K-%ayuQD?y|3(N@y6OXf;Hdeg&5A(djJuy%8D z?ibpvyxMW%<+Ylos-g>Ct9Y@jx>%zst>u?^0z%%fCu_;1@$kiS(kQJRp7rW%);RaU z7QV;oZCa{~&QT-p(w0cGfv(T+B+_gq(&R_Kmr$2xLXlfPCDBU8qb?EW97OpxBhL9m zoXpVok|hHW`R#CUTAW7Sp)Zl>Jfx8siBf>_qu)!Y4`%?)5`fc!ojV;?mI!wNg7q`P zT}|LR&7o_fAN^?4)AnOiegtBQ6UJc$+Z;}n-(Y!+E#EhfvEBPt-#)gd_DMllA>6@h z^Q{Hi25k`yb1ZmG?2RWfOu{zVp`+x$3wtl56Z-6hIA&-Mq(c>D20LRLIQsCQ@YPu6$Eeoqt7e%x5$sEF(bZc)B4STuMliq3fFR8BQ!+eK zqn#OxcP=lsp6=;OL9Pi>zBmNji=%@@=H}JYL zvca@j&>klBj8TyT{hZ5*&YceRKBfuyFuJl12Uvn5oTS!sx_fHg z5w7WMCC$ckG@ zhJgdoNii%@hJ=!HBkaTizLlp1_TUA;VtIC)D^XU$8$3x#<~`H@V``)qq*pSp!nrJ{ zS(%;a<{bc8mfEC(T|sd!%~|dgE>h~E-DaFPkt}}$Cw<~uq4-Qs$-GVu;)BHE7mmxB zeR}+xJBsc8>HxPYxD1B>gdcUp!1u?Lyv`EWH1$iv@vI}Rh)2G@YbU^xl#%Q`9hMyF zsu<=n|HTHcZegMSpvAJGmnR#^^3s9nlBh?Teq2W-G)_2i;{al!X6w`~_H&zGa~d39 z9pi#Y`}n9p*Gq8WHdyF9ft@+)ZF$-hRnuFkf04X&q&93$oKl(79;wWvf>aH;Q}~rw zx_y0r6rUo7o}|M#rPNCRqYU`pc#`99I(eF)cNt1_p*KD37ib(`i{P&`w5uoL8ySXD zA=gJfn+|qf_7Gj^!cu0OSA6K<=FEX7Rrd~`N7BzK{m@Ppqk)Zz+-E)`CE5OiCpLl< zy`08BM2AG5!#BRzgXMrv;p&g zQ;$?$(lp&Bxw`xW%im_~g_`_L=^ivzc_0(F6i% z#bHTpPR$cT&@w1nLmLOon_ug*9sfm!@pim!*H=pLO`gwxD+^leEGHYk(QYn1)*Qo2 zTqN;n|Hbh?s?y);7M=-L$>Rmy3;lI`)x4eoVy;Dp*mh&YM&v$ETW9ko89;>_Y5Db%$E+QL#-0-5O>JzK$a|@+F?2g#ww7 zRkCU;^?$IYX* zwsP#3Gy;@iL{%XB6?#L;`&2K{P^qGlzuTAHka=fhWzk9CsPaI={P1(sIEO1BqE*bg Jxo*|${{pw1R^R{t diff --git a/vega/common/__pycache__/user_config.cpython-37.pyc b/vega/common/__pycache__/user_config.cpython-37.pyc deleted file mode 100644 index e05a78b6ce632e3f2614578cc0dd20b960e00d35..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2358 zcmZuz&2rpC5S|(RtX5ucoL~~5LJ)8u8>)7x6bDFANs1)JC2UdrC8dHAvSxNITYnzS z2Jc3DVRPaca?JC4b0&8Q{yjJm)6`m?{@-1G^I{O=dPy#VVU zyg3;T7jpc6%99d_XjfY7uWBVg^JUn4(^LdJj+xN7d-f>7BWc6Vv3}3?c z$$)B)3sEIyHGfIl2QHidegmevby`d_p~^yU%qku$5%Dxpu*rWCXDN^5EE5Tyb22P2 z?!v6cum}=T$dGXwTAYP8X7`MQj&Qld-Dd=%?g+2V1MIHo?)kinS(&>H?(+>}_c(*x zdfLY{+UOP!9+oi=X0P%JCbG=i2a_}tChc6vsaWS7{GO`afUiMmu^5gM27%&mVPb9W z!K@EqX(*@rp4?J|vlab=D%#K$ZAio#a%44Rj~!D=4j-a*WFM3JpFmCO;a(r;}Flel`*EV?BOGkB9vhBsCwzSBK3;sZ{wzKYlDPOp1v?@bmxTyQ8Hl`6 zb6|#7rIpCiLg>y^6hg+?psW3#<7_4#NLfm01A^L%A|Op1MY@;liex`(&(=<=L|$t@ z1%A~kE)t<_%td;*>w>9TB*jEX0q+K`wri11LhmGtu#2CMiZ~bANp{nWYZ`^!T8S#c zqyHlEZ8+)Q(P%0^8!h(sMp9JeD4wY@&qhzgG#({oo|nZ4`Lm8{Sk1Ma0jb`Bt@So6 z#PO&{ecGejuzHlyOLTGJ6$gjteM%tA7)fLZWSW~Ql0tJU%B|;&+s}a<3dqyYhOI21 z&JCcL^DYwS+R|#QhVFOd4P|ln$c8(z=Nud29ozt7bbuJ%V)GutxF+;a+iG5kAF&)# z9L!QFxWwQGOrrkC?#Z&)e!2|OIymtP(js`e91mQ%1ul?vS}WN>v-{vM5f$o#b|^bU|f!vU2&aDdAtxSWduNCs^D9A4C> z>#f&{iSJH}iSJR%%qOpcp`dJvu2_9Zfu5lrKp9D7aSViw7gun5kk4wzD%P`oKAM#2UC;vV$*xV>WlDAWetY8XMkUHk@Q2 zUys4BgMO`0DGxZzR>)?r0X~D(3YWXM^X>lx_q2@zwJUv&Qlh%cnMofDOMM-7?h0y7b}M;=MBj zKCfMy@_pE?{V2*y4xNHp7)1xOIBT0qcR(Y{B#I0#t=Ns}_UX*atQC0fLEe I_N*TL51F4UJyb5@7D(V@fuewPYz{pH=q;!AZ|G%v3DA=-LC?Olzc))!A8mn>*qJwPX5P-c z_vXFd?B$7xvW8#huYY^=!*Naf2Q_wn1rQ(LNq<1WHCN+|>k)IAs&!Xawc#47HeFNI zmTRfH;1*SFySA!JZV9y!l{*!;!bp}GRXY>zgrcpe)|qrCnHHE^Q^BEa=1%hhFFw`W z!`$X2Ugi~E$G2&Cl^SzfpnUYpVEt+G>qF(Ad@&}HOgfNzxl2&{0>;@e9q6g zB9J{1&pTn9Nf$%(>LHvXv9_004|JI(afHP92Dw=*yGyv)FQg@w1j ztk1Q@n{)k@l{pb~lR3X9lTI}EBxw6{&7{*w;<+AazR}&ts_j6IrOeXO$tN|#W;ND7 zGgeRIQj|3PD1E=NcYK`*+I3wQB~+q91uXy!n`p%)@6SqONvF}w-%&92BSld~+cl#k zP2EB#kRoil)kt`0b2;F>C=jldbY<9a%bkP=k;j9y>6WGNMp5EB&muegY%{y0v1LBnT?)d3_^-3JWgJOi-=mwZQ<}2xSD>GMvjWjEV zuvYHknMsqgO1^rJ8iF*Au?AL}8CZr}4Az@bj|XmbM0iS!fl|3cUXUM%DHL;fQX({) zyle)~3B9IIG1c27kD^}@&)A?#Xbf0t4{&nU7jqNBpf^9)hx!l(gJJb8M9po)&*R$A zT0c8(8(jZ2co7tK+ej5S{T z3%`ZBky)V(I%#I5-6)iqf#X*%A`j3psP2{LR+eE@;i=G-Sx@|^7r14NlQ|1Hf@ti* zEvYd?9xzXKdXJqw+P$~;73yeBE8CdNWYc(R>=>RZ6DPoaS!~QjNjsf;Swz*4>vDWL zDX7q+L`)%iOKpY06fjtNL~2_&W7~|<)XiHZ!1@!!KPfu0#Tt-h3J-^%2GogCUqsKSphy5_?^sBv$6t(P#*`FLQ|_F zGm8gHy>@1XaVyEpHDAP;6$p_C*KUeXhDZVZ(+dfrRIl4jgbet2y-MP|)1Te5(Eq(W z*SK@*C5!w9y1`vls=nHXVe@vr+ z^Si?VOO}w}0=YkaDG@s|RhL`vi0L)4AoZi*PH{kiq;nt$t3A^;_4Jg)f(aZUuL3M> z8aOR_-bXo=LSb>)PZ5jZ9C&q3!1zP}aFEojunM1Tr@%%fVtaHhGLx3HNBJeSBw~Jn zIXM=AKBrm#DD~_@o$)vVqn;Axd3e>km?nNo1(_9$pAfCM3hbCPF$HggbIN+n5vAl+c%9U2gq?0yX@xQOa!Fa`R)z))rt8@-0;QcnYc`%&PZayMmjRLlG=>S|aHX5WN}-QXYT< zsK!FQsP>K645!1Il>!98xvUVTG!?OVnXVR@sqSr-Tgaujg`;ZQ0!a)X=OI%-SR);4all&t0G zp@XY3pYNY?@au-%xAR+T6gd|!_R>H&uYWOo{rY0PBFNpyCaKKIxKOH-?s?)Uv5UAr zMt(B@U`W=G#u~&dP=bhgRv6{BAAu<-pw|u8-r0d$X!)=mca839`=E&@pQ!D6JA=Q$ zWRy| zZ-PHB5kZI_|5I@b;4O#bDc>tkA)6I(d~kEv#HWLSj$6XQLoNP`8f|0rmh!7_Wl5Gv z(k5B@$PkcrtFk*&+4?Ew{Z`GZV&3;^-nsu{-gmfER_RiSRem&D`^2XHoQZQJ1J~Lb z;-R+Q*eC1Qe_O+I_SnWMOx%?64)#Z^YW`;PkbfJQ4yeEIfUIPR0V zZ|HH?wbT%#h8SpZG$OMUAtqs$Y*Hv`UflK6)kEASu7!JgJIIVEh{Zj{ki9N2QZeSf z5dKD1kO_HS-4OUR1~&u{^~GV32V6#c2Htmg(sxj3x($b@sz130xy%t&Pvbs9?P=8} zXCfGNLWld1gcl#cPhU*Fu_v0+*Q0SNzJms$DZURNmR^>0L%|V81T{8Jl{q~jK;hMqHX2`Ly^1m8x0$1u8Xx^^dwJ(>^7UsqISc_aZS8temV%chZw zyz!gW_`gF^I)j3=s6&Uk$~1JB!#%DI{`inlavtmE5*j1jFld9(b$g(Vv;j?#HfX=l zh5>E7n`?tW3lg#kJy`wfD9r6DfDnf%e25lJDmWEtuexXv9Z*5pE-w_ms*J)1iFtvp zJY%ZeP|l>Ch-L7B%udowg2SLaq8|}`CM$X#PnrmGDv{ryo;nrhshCIMT5+$_-4G9m zu|x$eGAk)~OCK!NmoIUXdMN*7<DDSCuK$KKATiX+kN%zLX=a_vx!!N_PPh8rc6=C&VZl_VlqCd%FCpJ!MZ+ N57{Mq+CEdR{Tp&ZzD)oC diff --git a/vega/common/backend_register.py b/vega/common/backend_register.py index d2a1e1b4..fe927696 100644 --- a/vega/common/backend_register.py +++ b/vega/common/backend_register.py @@ -11,12 +11,13 @@ """Backend Register.""" import os - +import sys __all__ = [ "set_backend", "is_cpu_device", "is_gpu_device", "is_npu_device", - "is_ms_backend", "is_tf_backend", "is_torch_backend" + "is_ms_backend", "is_tf_backend", "is_torch_backend", + "get_devices", ] @@ -43,7 +44,7 @@ def set_backend(backend='pytorch', device_category='GPU'): if device_category is not None: os.environ['DEVICE_CATEGORY'] = device_category.upper() from vega.common.general import General - General.device_category == device_category + General.device_category = device_category # backend if backend.lower() in ['pytorch', "p"]: @@ -71,6 +72,14 @@ def set_backend(backend='pytorch', device_category='GPU'): register_networks(backend) register_modelzoo(backend) + # register ext modules + vega_extension_path = os.environ.get("VEGA_EXTENSION_PATH") + if vega_extension_path: + sys.path.append(vega_extension_path) + try: + import vega_extension + except ImportError: + pass # backup config from vega.common.config_serializable import backup_configs backup_configs() @@ -104,3 +113,12 @@ def is_tf_backend(): def is_ms_backend(): """Return whether is tensorflow backend or not.""" return os.environ.get('BACKEND_TYPE', None) == 'MINDSPORE' + + +def get_devices(): + """Get devices.""" + device_id = os.environ.get('DEVICE_ID', 0) + device_category = os.environ.get('DEVICE_CATEGORY', 'CPU') + if device_category == 'GPU': + device_category = 'cuda' + return "{}:{}".format(device_category.lower(), device_id) diff --git a/vega/common/class_factory.py b/vega/common/class_factory.py index d96db02d..66e6064b 100644 --- a/vega/common/class_factory.py +++ b/vega/common/class_factory.py @@ -49,6 +49,18 @@ class SearchSpaceType(Enum): CONNECTIONS = 'connections' + @classmethod + def contains(cls, item): + """Use the contains method to replace the in operation.""" + for _item in cls: + if isinstance(item, str): + if _item.value == item: + return True + else: + if _item.value == item.value: + return True + return False + class ClassFactory(object): """A Factory Class to manage all class need to register with config.""" @@ -78,7 +90,7 @@ def wrapper(t_cls): raise ValueError( "Cannot register duplicate class ({})".format(t_cls_name)) cls.__registry__[type_name].update({t_cls_name: t_cls}) - if type_name in SearchSpaceType: + if SearchSpaceType.contains(type_name): cls.register_cls(t_cls, ClassType.NETWORK, t_cls_name) return t_cls @@ -189,25 +201,39 @@ def get_instance(cls, type_name, params=None, **kwargs): return t_cls(**_params) if _params else t_cls() # remove extra params params_sig = sig(t_cls.__init__).parameters - for k, v in params_sig.items(): + instance = cls._create_instance_params(params_sig, _params, t_cls) + if not instance: + extra_param = {k: v for k, v in _params.items() if k not in params_sig} + _params = {k: v for k, v in _params.items() if k not in extra_param} try: - if '*' in str(v) and '**' not in str(v): - return t_cls(*list(_params.values())) if list(_params.values()) else t_cls() - if '**' in str(v): - return t_cls(**_params) if _params else t_cls() + instance = t_cls(**_params) if _params else t_cls() except Exception as ex: logging.error("Failed to create instance:{}".format(t_cls)) raise ex - extra_param = {k: v for k, v in _params.items() if k not in params_sig} - _params = {k: v for k, v in _params.items() if k not in extra_param} + for k, v in extra_param.items(): + setattr(instance, k, v) + return instance + + @classmethod + def _create_instance_params(cls, params_sig, _params, t_cls): try: - instance = t_cls(**_params) if _params else t_cls() + has_args = any('*' in str(v) and not str(v).startswith('**') for v in params_sig.values()) + has_kwargs = any('**' in str(v) for v in params_sig.values()) + if has_args and not has_kwargs: + return t_cls(*list(_params.values())) if list(_params.values()) else t_cls() + if not has_args and has_kwargs: + return t_cls(**_params) if _params else t_cls() + if has_args and has_kwargs: + if _params and list(_params.values()): + return t_cls(*list(_params.values()), **_params) + if _params and not list(_params.values()): + return t_cls(**_params) + if not _params and list(_params.values()): + return t_cls(*list(_params.values())) + return t_cls() except Exception as ex: logging.error("Failed to create instance:{}".format(t_cls)) raise ex - for k, v in extra_param.items(): - setattr(instance, k, v) - return instance @classmethod def lazy_register(cls, base_pkg, pkg_cls_dict): diff --git a/vega/common/general.py b/vega/common/general.py index 0597329f..7b60d212 100644 --- a/vega/common/general.py +++ b/vega/common/general.py @@ -36,13 +36,15 @@ class ClusterConfig(ConfigSerializable): master_ip = None listen_port = get_available_port() slaves = [] + standalone_boot = False + num_workers = 0 class Worker(ConfigSerializable): """Worker Config.""" # distributed = False - timeout = 1000 + timeout = 5 * 24 * 3600 # 5 days eval_count = 10 evaluate_timeout = 0.1 @@ -86,16 +88,6 @@ class Strategy(ConfigSerializable): only_search = False -class QuotaConfig(ConfigSerializable): - """Quota Config.""" - - strategy = Strategy - target = Target - restrict = Restrict - filter_rules = "(flops_params)" - affinity = Affinity - - class General(ConfigSerializable): """General Config.""" @@ -111,7 +103,7 @@ class General(ConfigSerializable): calc_params_each_epoch = False dft = False workers_num = 1 - quota = QuotaConfig + quota = None data_format = "channels_first" # parallel parallel_search = False diff --git a/vega/common/searchable.py b/vega/common/searchable.py new file mode 100644 index 00000000..0302ba2c --- /dev/null +++ b/vega/common/searchable.py @@ -0,0 +1,138 @@ +# -*- coding:utf-8 -*- + +# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. +# This program is free software; you can redistribute it and/or modify +# it under the terms of the MIT License. +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# MIT License for more details. +"""This is Search on Network.""" +from vega.common.utils import singleton + + +@singleton +class SearchableRegister(object): + """Searchable Register class.""" + + __types__ = [] + __searchable_classes__ = {} + __hooks__ = [] + + def init(self): + """Init items.""" + self.__types__ = [] + self.__searchable_classes__ = {} + self.__hooks__ = [] + return self + + def has_searchable(self): + """Check searchable is not None.""" + return self.__searchable_classes__ + + def search_space(self): + """Get all search space.""" + res = [] + for v in self.__searchable_classes__.values(): + search_space = v.search_space + if isinstance(search_space, list): + res.extend(search_space) + else: + res.append(v.search_space()) + return res + + def add_space(self, name, module): + """Add search space.""" + for searchable in self.__types__: + entity = searchable(name) + if not entity.search_on(module): + continue + self.__searchable_classes__[name] = entity + + def register(self, searchable): + """Register search space.""" + self.__types__.append(searchable) + + def update(self, desc): + """Update.""" + res = {} + for k, v in self.__searchable_classes__.items(): + sub_desc = desc.get(k) + if not sub_desc: + continue + v.update(sub_desc) + res[k] = sub_desc + return res + + def active_searchable(self, name, module): + """Active searchable function.""" + searchable = self.__searchable_classes__.get(name) + if not searchable: + return None + return searchable(module) + + def active_search_event(self, model): + """Active searchable event.""" + if not hasattr(model, "named_modules"): + return model + for name, m in model.named_modules(): + for hook in self.__hooks__: + hook(model, name, self.active_searchable(name, m)) + return model + + def add_search_event(self, fn): + """Add event into searchable class.""" + self.__hooks__.append(fn) + + +class Searchable(object): + """Searchable base class.""" + + _key = None + _type = None + _range = None + + def __init__(self, key, type=None, range=None): + self.key = key or self._key + self.type = type or self._type + self.range = range or self._range + self.desc = None + + def search_space(self): + """Get search space.""" + return dict(key=self.key, type=self.type, range=self.range) + + def update(self, desc): + """Update desc.""" + self.desc = desc + + def search_on(self, module): + """Call search on function.""" + raise NotImplementedError + + def __call__(self, *args, **kwargs): + """Call searchable.""" + raise NotImplementedError + + +def space(key=None, type=None, range=None): + """Set class to singleton class.""" + + def decorator(cls): + """Get decorator.""" + cls._type = type + cls._range = range + cls._key = key + SearchableRegister().register(cls) + return cls + + return decorator + + +def change_space(hps): + """Change space by hps.""" + all_searchable_cls = SearchableRegister().__types__ + for cls in all_searchable_cls: + if cls._key == hps.get('key'): + cls._type = hps.get('type') + cls._range = hps.get('range') diff --git a/vega/core/__pycache__/__init__.cpython-37.pyc b/vega/core/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 141b22db0d429fa179e46b2c1c532cfb62fe4005..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 314 zcmXw!OHRZv42F|Fre)9|A#sVan9KnXLLjk(1+d8`R86Ql#7*V2!{bJrgEO$nmepK= z6>$-k{Kfud$M)rFwPexY^Xu^?zTcGmTM~nd=zbHSndX)OlMJ}zkjR8Ho#@o2kjV^k znZrUZppZq>vpZJhdmh`WG^9B7p)~zY)!6!Q-1Xh48aw&X;4JN99jmfXTzA&o4n7q3 zwST5y`QQy+x&EnZ!x1<=S)&5}N9v37RgpzO&@xhRe uo$+%jTHJ{D?N*=}oT#VKffe+FXdD`$)V!x!4{xD7`QbodG5KV9`^ek3s3p;f4;rf zv#kG8XZd(AK0=Ydpkh{RF`GGs9owdLV#l;@?4orGuWZGw(vN+c=6OX>hH+T7Wc+hW_#?Q=VRjqR{qcwcAE6Kim5-aF>w zG!nCFI6WC&;%?T#z|mn|~=xoGTxt%HB0^5}08xs0k> zMRGQs)}WaythPo2xbpl-75GoFLa1r zLR&Nzu^?_q#i#OMu%`ozj`M1)+heXC!{$PVp!Q?4w)NMKAAbF*AV`rab;<;xuiJ;B zE)$k#N_#9-sdlkgI5gzbeq4xG(GkQiy;9U!T2QtbI_eNKNm4PE+9krQ8|q{QIKxHb z5Agj|DIvPFK3rZ)w5b~`a!#)@ew=4K$tbC6aWb&QHhq*x)l|D=ZJ_!iiX^|)TC;o3 z9+>0UJ;R^S-muM_jE0v_0K;%7hCv)6ejv7FIbf4J@g}W{YhPpe7 zaW05a*E@8d`JbpgPH!_=A}?k#(h8Y*tRCA%S=UZV^&N zd3OVAgSr8H3j^|g%jwu#NYz(R1dUw5{F3;+i4hXNi8FDTbrgyXW<4>IcV?X; zVU-U+0}2882ll}rY);+!7bzY-WTi_cZiSm(%qU8};Zv&G70 zPg7&{sRLD|K>N)9(W5<1ZDl3lsco|17U^W37WtezN>xGNL-D#sOQ5MwYo~O41o2Wlpy~ru+Lr|omV7dkr67( z3Be(fpcRrK_cm(J*zh>u7R@G(-_b#musTZ;u?m+-UGWaAc%6Qa!>lgLx*E<@UPz<9 zf+DKBHwK&Y8~xt8;U1bauaVUh|9E|wAH$_#a$I|~q>-M?UjzD?3D4}vSOY78zx36@A!3thM0dBmztQ1i#$~Kl+E>>`6 z4hPqgSJjxub)gr_qi2!wFfOzROSohT=`a#p$I})<{fXvTzDW;9B9Etv>!+=BGnaKX z;nhY89C}c#G0bpT1>Mm7+JJV1QsW>TS?%dK;3qV3x)+x#?KjkkZrt zQIsn3Fg8Q7$G<7HYs!7emkXvvnn5%A1L+>9Yxi?bQhi32vl-E%EC`?1g_6YZi1`}Q iK13MH@~mboEEkP;p^9WK#iePc8iE^O*megy>-`HAz?nDz diff --git a/vega/core/pipeline/__pycache__/conf.cpython-37.pyc b/vega/core/pipeline/__pycache__/conf.cpython-37.pyc deleted file mode 100644 index 82242d044a362cc7d01d11a196b52c7bd6fa4ac6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3569 zcmbVOOLN>r5Y|ZA_rq%^&XXkKQGq37H&7|6sG_JOP9RXkq--B3T@*z#Ys=QIHZ$5l zHtP#^RrnD(Bu9P+zk$nKIq?@b(cOA@9owNmTk4%hOX}(FufM+2YSj(6hQI&u=;5Ma z{E3tNvSDxwZuJrbH(Y}=Zbocqx+cT11!K##XlzIJu;i9#T#CxWid&&^IjRn8ZjHv3 zs6K4C4H{RY=CI|qnDLFlYrMX1@Vc<}EVm8w25-W=N%IAmw|E=oZJKvrzQ8*$@6dda zFRmNirRjycVi1l5cUI!jz!}8Sc^HmG1WUcYzty>N3~$k30% zssA{_)y9LEizwZKX=_9J@CI2dEZlqQM-xAZ50+YY{lr%yDRk@MeW5zf;Z>SrIH8Z%mXTpl{$etNsk2SrvCKR|s6yKBa(D5A}21yu?eEH0QQ25Ru zj6_e{2{>ObY=M+xU;0Cl2&uXyT~ZTRk~OeH*9Pzk4_}}wTfXv>M9KzibS-TKk*<8T+jQ>_`BcVlsEd*?d`r4t9VR_}q}EHOfGS;9b?FU?&u zF=qxkbq}Fpf!pk9YXvROoYs^0;jT%4x-yJ8z^_!_)`qT%oqRJo)tIV3_^zzYD&Z~FX2Bclp^F|OZ&(n3! zOAUl^)AOE8{3zS0cpi@f&r6TdpnyPxw_;8nmgmuwRu-v1L=mVKhIW};f_;hhQXVYB zfC5rwwq=_p>r`nyfc4@kFm(%-j$&%N>{b9>RZ)8bU3FN=(bXVyH3?nIsV{Cvo3RX& zt>FrRXnOq(gcbWm*3Z|QlTh{<*mxq!dKt=k9FoO@ngH_uV2qr<+AYZ@d`&tiFuipn z0|1&i`89Brn+*yx(ND#epYC(AyAx! z_%w7InRGo7;|vW~b2Q`(`z8t;RPO0^nkMgIhc!t5R5sH3=LBTL7&%9kC3$X<723%CzH90s9T$zBwz*$J_ z+}<(eV)lIcm;B!dMuCCT3Xt@igP7y>0`vxae zPF^^iaO6DSUF}w&fxzCvejLXUl@X#=UAy;lAja6V=rXjALq8$x=g~a}Uvn4c`F7;1 z101Kkg$#!vRSpQG*TcW2RfX1~4Yvt=Hocgeokjjl);)J}@u|S;<;N%-6xUE-^CUk( zf#ps1Q6N9c>nM=$We>#-6bEiRC{Jkp+sYH>{&#RItb}QKvY|~n^h2JmOVm76tdJ0H z!L5*2=8M&);o97URI<`CgAIFDW;*B|oabEg2R*OTS7wM%w9(6E5`(GJSEi=)C=)0x)2KS2K~5#}=d!LisjhyRc4c2-tE`Y6 QZKyAEKdW=8bFTC0KX<1|z5oCK diff --git a/vega/core/pipeline/__pycache__/pipe_step.cpython-37.pyc b/vega/core/pipeline/__pycache__/pipe_step.cpython-37.pyc deleted file mode 100644 index 10142a2ff184a7b6b2e7a02fa5499cf914a13038..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2103 zcmZ8i&5qne5bpNBJu{ohW=Q}cVi5_+2xcO|jR+xj1ByVC6&r+P2^q)TGwWIZOuBm# zb~HI{q}+Ih-6PMzfhXY7S5CYFC#u>rlOJ1Mu5OpRtLpoz{LNqx5xD9<|Nix_fRKN1 zvb`KozJ?(c2u4!ED6=Y>Qh3^xRa>b=am}gh+DV<-OwI`&r`MTvLXv28`?y^26U7&?I?YwQl$^z7 z6XR-CHhlCixJD)k)d9;DS7pr;OZ$&H2wRfU&wUn^bKdf>ozs(EPigO*B*`xCX zhcEZGM;x}zw4A|8|3|)Pg*xHlDHl4}zAx~s?PYHic?1e4kQ93;G$jHc^k$nSe5XpdkUHkq)-%BpC=iT007`&`y)rT%aG? zZrW!$k}4M}L%amiOoz+HEODkiPyjUT*Idec#&y6OW=wU|Eb9!gDNbdl*Dd3fOl<8) zUQIxDETEWNYF;dAWnF(WY zu7fNqA;2vA0QQhW5X1`T9*t-~J#(+FZ-E+ZZ3~KaaA8O|uECJFgCdtyS#w)CbN9-6 zZY%GCT>5MKiURO-^~s^&Kwp!{7+nC4Ll&9>&_9n~r|u{TgbN#Lf5uf-R8n|23y&Wh z-_K4S-9HvS&K(8Z)~;MGfWUs1HT+qYC62b?3+-Tf#7(qw%^pDcFoVs2S6ODVCfP*08wFp=^;0Sib5MxQ&T7!Hy>8 z%)}P&z{es+f!=kVAc)2tZLxL)6KSYzuX;OQSo?+>=wWk~bH6MJy5*hhiP5!X@NKyd(KlDHjFZB_y}fMc$d5YT-|I$SQW zj~N;Q0jKRIDcZpaAcvMFBd^jDN^s2>e-2wn)6cC4xEj!Xy4v4mX=hd-R(RnPlWCIm zDJW~I$ef;8n3hZ1^aBSH16^T_-N0J|X|Dk04YoIaGTwHK?W%S(0fHg`I+#KM$y3fu z*^2?Jy`2)|))sY4ly(@G#oO@>b~(Ye@Fm#oVvM>L?NN#BQ!E4gDo6) zkl2a>>R8dB?*?uVB>UnnENQ=LXS1@Ib1)RFHP=Dg?EL0Yhxhz2iX%l@)1my(sr_W6P`U8X2Sp5rw)%i20wts_JJ4UN9 zdg=Xmz>~PoJ$VpHFOwqfcf8@i3zNj7z8%R2S^F{WYHU9?k+)1$ns=aF`7q<+?Xdsw3$zs<@IDt|(y~-zzG%K! zUZ<1I7B^Mx13pNF+~ML#Pp-emvn=fJydk0^&L0bfyGYqjk#PgcF-S2`ipA_N3@F86 z1-uK)WktMQR$^tmi>$(`c$acv`DNB%OQ=^^6XIIVZ%&6A{q5Z}WL~5%o_S%Pc>|Fi z#f*EOiDCaUugB#<$~>7urlFAD0J21pgjuP0c;4wlw77*xp$&pbgqm`ual&Ei^Cqr^1zFRLhwFm zoDXI_`C1VvnI3wR9N$LD$R0Yz*a!f5NW$EP@(TJoC^#8U_J41PWeob)TbmtmYjf1? zZVIStGaSmamuw#KPPiGRf^X7==%R~4dvKykK@j(283Z?JGuqWKN6Ul=8rq)sF3wiO z06LOB+i!7FGOjO-iOKAVC5=;~Yf7tYccJGmP1DdV1!)DYES?%u^9+~MrIIXnD`yyS z@^@LCR%NwY8=Geqdh=R6uRZ!Bjr)o5=^rMxY+&`JX>)9!nG;L(f_q$$mb6$UzuWL z8WYzvMt9`K^aWPaCqPNQIK6p>T|pl8iNzZ7rBmb3#rQ5(dRe~GeL=Ry?x~i{M+RGB zOB44u?#IS2jlT7QUDxJ*;8(Xm`)hFLLPbM*28mN4}8 z?IexDB)ijoV%~jNh7}KxSF*!+5JU$&Ivm}4<9~ccu<#ui>6Z8TWVA+mU$j|jRgE1U zis*p)Xr!v|iZJfat7M>k+;nvHeh7BfoP?8~D`Hob2AE3Lyxn?T)JY6BMu;+%iZixW zUe3P()kKX3H$XN4Efmm>nW&;sEK!e^dYHf;yz}6l?Oj#Gd-vW4 z?+Nm!${{XOc85z?d;>W5?>=&>&l4?DmUHDO$3`3cUASYn~M{S68y%+S} z3#D=i3S6iH<9kB@LY@q`S^~Gix$y+6kknK;n7634c`4W*Cdo-4NsfZBsoZcd;C-en zma6i-M-d;$IPI$f7a|p^jFA!F3!_6-$u|euB02zpL)C`^T#!RDC{m$Kn&b9Al( zI_YshIv~RU6Yi(2qG-~6o^?HX zE)^Vm#jMyBvte0g1HT%QYnG5#@M|KWA9W6+{=vOrk((Ho1 zcU6<|qd4N{5NH)NRs@|ut$%ZP!297|!c}vz;JoGPt^wy_g3irT>sQvq(KSSmx>lHg z0mp?g9L~OdC1=6Mw=s8AyUDh0-uh0Sf!>N4he0qzUMxQ9rL5)Z+mN?L9HMN1So{nV zRk0WM5j21q57ME?F5Of$fJ2x&tSpdmk26fkh)jtE2|lB%=hywl$c2+8LM|Zbdp<(p z0=5KjfQ00MTFR+F+crsE)$#@krG#vmIWQ^CI(J#(E~;Imj2MvAL%%irX4Y(u zuAi&wNvc`si||R5uk}R`IoH>tErb^q0-#$H6H>RwW)~I=Em;(9kuIk$>P15Gtj8h! zJzPP!F!R}vCs?CT-xqYXtVRQ>XwLLLU8q%E>5j>g`3jx>-1Yo7ZinwngngPWM%*xi zI`w$?Z0V=cLqG*7F)tM!?+xUMph)g9WjYueJ@?-h&@)Em`H(GDK2F6U#TSh149IwE z1`JglaG||wKjN(wL7bw>-8AlNqtV7Bc4=1)6@r4Y227O%4_ScoitkhBd3;dF4njl% zPB$YDp5Q|cAFMJBblo}RCs_`vA7Wa|(V+617tRonJjAQCkB!ln@%0OIOpvnc$PD<; zD`xI-DSW&Rt8|bW@UElvOPI^ee7jYp4?I9Rx~MSqzgLEg(GV6piozk3!`v5Ya&B@V3q9BTcvWzE!9dz-P?Lr5V43W zl>LA*EyR~mQtkjBmtiIXYTl!k@&M6+j{bDW?m9L*mR0pSdU|U9v`K}HN vMTbuQCXIE|y)KWWzVo+rW) diff --git a/vega/core/pipeline/generator.py b/vega/core/pipeline/generator.py index 163fda28..7459149d 100644 --- a/vega/core/pipeline/generator.py +++ b/vega/core/pipeline/generator.py @@ -13,6 +13,7 @@ import os import pickle from copy import deepcopy +import vega from vega.core.search_algs import SearchAlgorithm from vega.core.search_space.search_space import SearchSpace from vega.core.pipeline.conf import PipeStepConfig @@ -20,10 +21,8 @@ from vega.common.task_ops import TaskOps from vega.report import ReportServer, ReportClient from vega.common.config import Config -from vega.common import update_dict +from vega.common import update_dict, SearchableRegister from vega.common.utils import remove_np_value -from vega.quota.quota_compare import QuotaCompare -from vega.core.quota.quota_affinity import QuotaAffinity class Generator(object): @@ -35,29 +34,29 @@ def __init__(self): self.search_alg = SearchAlgorithm(self.search_space) if hasattr(self.search_alg.config, 'objective_keys'): self.objective_keys = self.search_alg.config.objective_keys - self.quota = QuotaCompare('restrict') - self.affinity = None if General.quota.affinity.type is None else QuotaAffinity(General.quota.affinity) @property def is_completed(self): """Define a property to determine search algorithm is completed.""" - return self.search_alg.is_completed or self.quota.is_halted() + return self.search_alg.is_completed or vega.quota().quota_reached def sample(self): """Sample a work id and model from search algorithm.""" + out = [] + num_samples = 1 for _ in range(10): res = self.search_alg.search() if not res: return None if not isinstance(res, list): res = [res] - if len(res) == 0: + num_samples = len(res) + if num_samples == 0: return None - out = [] for sample in res: if isinstance(sample, dict): id = sample["worker_id"] - desc = self._decode_hps(sample["encoded_desc"]) + desc = sample["encoded_desc"] sample.pop("worker_id") sample.pop("encoded_desc") kwargs = sample @@ -68,7 +67,12 @@ def sample(self): if hasattr(self, "objective_keys") and self.objective_keys: kwargs["objective_keys"] = self.objective_keys (id, desc, hps) = sample - + if SearchableRegister().has_searchable(): + hps = SearchableRegister().update(desc) + desc = PipeStepConfig.model.model_desc + else: + desc = self._decode_hps(desc) + hps = self._decode_hps(hps) if "modules" in desc: PipeStepConfig.model.model_desc = deepcopy(desc) elif "network" in desc: @@ -78,15 +82,32 @@ def sample(self): desc.pop('network') desc.update(model_desc) - if self.quota.is_filtered(desc): - continue - if self.affinity and not self.affinity.is_affinity(desc): + (hps, desc) = self._split_hps_desc(hps, desc) + + if not vega.quota().verify_sample(desc) or not vega.quota().verify_affinity(desc): continue + ReportClient().update(General.step_name, id, desc=desc, hps=hps, **kwargs) out.append((id, desc, hps)) - if out: + if len(out) >= num_samples: break - return out + return out[:num_samples] + + def _split_hps_desc(self, hps, desc): + if "type" not in desc or desc.get("type") != "Sequential": + del_items = [] + for item in desc: + # TODO + flag = item in ["modules", "networks", + "bit_candidates", "type", "nbit_a_list", "nbit_w_list", + "_arch_params"] + flag = flag or ("modules" in desc and item in desc["modules"]) + if not flag: + hps[item] = desc[item] + del_items.append(item) + for item in del_items: + desc.pop(item) + return hps, desc def update(self, step_name, worker_id): """Update search algorithm accord to the worker path. diff --git a/vega/core/pipeline/horovod/horovod_train.py b/vega/core/pipeline/horovod/horovod_train.py index 55f457a0..2a1c100c 100644 --- a/vega/core/pipeline/horovod/horovod_train.py +++ b/vega/core/pipeline/horovod/horovod_train.py @@ -40,7 +40,7 @@ ClassFactory.__registry__ = cf_content.get('registry') General.from_dict(cf_content.get('general_config')) PipeStepConfig.from_dict(cf_content.get('pipe_step_config')) -cls_trainer = ClassFactory.get_cls('trainer') +cls_trainer = ClassFactory.get_cls('trainer', "Trainer") # for record in records: trainer = cls_trainer(model_desc=model_desc, id=worker_id) trainer.train_process() diff --git a/vega/core/pipeline/horovod/run_horovod_train.sh b/vega/core/pipeline/horovod/run_horovod_train.sh index f48fe82d..3207d478 100644 --- a/vega/core/pipeline/horovod/run_horovod_train.sh +++ b/vega/core/pipeline/horovod/run_horovod_train.sh @@ -20,6 +20,7 @@ basepath=$(cd `dirname $0`; pwd) SCRIPT_PATH=${basepath}/horovod_train.py nps=$1 IP_ADDRESS=$3 +PYTHON_COMMAND=$4 IFS=',' read -ra IP_ARRAY <<< "$IP_ADDRESS" @@ -37,4 +38,4 @@ run_experiment() { horovodrun -np $np -H $server_list $@ } -run_experiment $nps python3 $SCRIPT_PATH --cf_file $2 +run_experiment $nps $PYTHON_COMMAND $SCRIPT_PATH --cf_file $2 diff --git a/vega/core/pipeline/train_pipe_step.py b/vega/core/pipeline/train_pipe_step.py index 58a6a0fa..a2ff0bce 100644 --- a/vega/core/pipeline/train_pipe_step.py +++ b/vega/core/pipeline/train_pipe_step.py @@ -59,6 +59,7 @@ def do(self): def _get_current_step_records(self): step_name = self.task.step_name models_folder = PipeStepConfig.pipe_step.get("models_folder") + models_folder = models_folder or PipeStepConfig.pipe_step.get("hps_folder") cur_index = PipelineConfig.steps.index(step_name) if cur_index >= 1 or models_folder: if not models_folder: @@ -74,17 +75,17 @@ def _get_current_step_records(self): record.step_name = step_name return records - def _train_single_model(self, model_desc=None, model_id=None, weights_file=None): + def _train_single_model(self, model_desc=None, hps=None, model_id=None, weights_file=None): cls_trainer = ClassFactory.get_cls(ClassType.TRAINER, PipeStepConfig.trainer.type) step_name = self.task.step_name if model_desc is not None: sample = dict(worker_id=model_id, desc=model_desc, step_name=step_name) record = ReportRecord().load_dict(sample) logging.debug("update record=%s", str(record)) - trainer = cls_trainer(model_desc=model_desc, id=model_id, pretrained_model_file=weights_file) + trainer = cls_trainer(model_desc=model_desc, hps=hps, id=model_id, pretrained_model_file=weights_file) else: - trainer = cls_trainer(None, 0) - record = ReportRecord(trainer.step_name, trainer.worker_id, desc=trainer.model_desc) + trainer = cls_trainer(None, 0, hps=hps) + record = ReportRecord(trainer.step_name, trainer.worker_id, desc=trainer.model_desc, hps=hps) ReportClient().update(**record.to_dict()) # resume training if vega.is_torch_backend() and General._resume: @@ -119,7 +120,8 @@ def _do_single_fully_train(self, trainer): def _train_multi_models(self, records): for record in records: weights_file = record.weights_file if PipeStepConfig.pipe_step.get("load_weights", True) else None - self._train_single_model(record.desc, record.worker_id, weights_file) + self._train_single_model( + model_desc=record.desc, hps=record.hps, model_id=record.worker_id, weights_file=weights_file) def _get_evaluator(self, worker_id): if not PipeStepConfig.evaluator_enable: @@ -146,8 +148,8 @@ def _do_horovod_fully_train(self, trainer): worker_ips = General.cluster.master_ip for ip in General.cluster.slaves: worker_ips = worker_ips + ',' + ip - cmd = ['bash', '{}/horovod/run_horovod_train.sh'.format(pwd_dir), - str(self.world_device_size), cf_file, worker_ips] + cmd = ['bash', f'{pwd_dir}/horovod/run_horovod_train.sh', + str(self.world_device_size), cf_file, worker_ips, General.python_command] else: # Roma cmd = ['bash', '/home/work/run_horovod_train.sh', diff --git a/vega/core/quota/__init__.py b/vega/core/quota/__init__.py deleted file mode 100644 index e3f01ef9..00000000 --- a/vega/core/quota/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .quota_strategy import QuotaStrategy diff --git a/vega/core/quota/__pycache__/__init__.cpython-37.pyc b/vega/core/quota/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 3b51622ef02e56fcdedcb86977f3ceaf9d08d374..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 187 zcmZ?b<>g`kg51Mb6SRT!V-N=hn1BoiATAaF5-AKRj5!Rsj8Tk?3@J>(44TX@8G*u@ zjJJ3LOY=(-gG-7MOH$J-{WO`P_zFQH@x@S~B9K8V8H$*J6qxvBpr2l3pkI}hrC*d< zkguOuT9TieqhFSqo~WOkUzDm3)~_EQpP83g5+AQuP zZY?u%>p+YTzVZ(=M=t&ma5dbx)PV~p6Z`Z7We2weeh;?6Awc#78HhojomT#dpV|(QIj!tu} zxH58mSMhdS9eKW|cqgun>VAFH@EfXKi8n@@{-& za2&p!jQhhu(ca6#Nzuk+mu65;k5*nj2ADTc(hosG4ZGH8$JnjG&CfOL*5Wp5n>)OM z+9_oBE8OEXv|V234b)Y>!8cKRyvfg@uJJAG>bz`xG)=N#FB3r)4i0Cp7vi)gfe1z+ z_L;T<9!=A16ppi2(Ca0F566R6mbAoloDD}|D-E*@U7el|3F)EQM0pz}y$_OUkscYj z#!SUo;B3V?Q3bqP@>TFvu7i3*4YZck7mZwt%xL47dN*@z(LBcTT!ZLFcYvkU(>^EY zn56xf_99&orkNP_GP$|rSg$yBH|Y0=<6(9vH<#RM2=Ayfkd<<0vNj41*s76?<3pB) zf#`iIJ=}BeQ%3ujHFZ`g#8DU8vqR4oEx$BG0j25ck;c z6R{7&VdLqj(92EIkT?$GrTR44Ak38ZutIs{MV1an(-;bf?(rd8ZckPQ1-f9yDB;8Y z$xx{`VvD|ip1yC9cF4vMpPoz^51$Nsp)`_IR>JX0${62 zv)3Q^4hZW_q)Ro)AZw>LLfWJhQ(UE2E)=M+lCu=%d5oo(Ks1AN&V{~F?deLkFW;ZtZHhmDO1z+gNPoHtAT_#OlyqZtWZ5%iLLQd?Sd zQzypo=&&_S37?ZOfVuEj#45Pz$c;S(UrR@G`)h(~LFDWI;+a*&sP&TQw?3N}*jKEO z&b&^cg;^qIHs*Xb@3yZC!gyH?q2_RZa1{g_(oW&QQfWfLWtFrT;1o`zl^sq(Spmet z1qxZ?LxSi>(=6oDQPQQNM#5YmqR_lbFS0S*kRTP1)xJnZ3S~uw+6GT#lk7;%6Ct^# zRBY1w$|NXmkehJ=n*`pbAxtg%RJ5trA{O60fE1~8(nBmUl1@JnqX1&BHtTx@-m@Sc zh@a5Pg`Gvoa4ZY?Xr;EN>_~KIZHM4mWKtiBUr^gs;=!n-OUFSinnTk1^%5(TXwqi= zVH_4lrp%2jO^`|b7$v3EHN(-jVNP}3fI+zm(mfa%@f&4N6TRCoxVk?3-qXhREf}-1 zHacdyL}hIc!5sMHnqgxf0{IQth6yV(3Fl#NR%AssSsUy{hb`$jaD1WSl4B#i`PeOb zB69MuD9j+sb8}-?d+b7!O=z+`YuyPmC^6!QQy~Z4bP4R8Qj17`#Osi$D|SKKRi$p? zRq)ax*DHPoR{WgEYeZusF*RfS0KtqxB08b%N}C(52txEd+bhI7>Ult&_Q~LQ;Y%x{p1SlpOW|xA342NfFcJ#vhHx*4`-v zXq(dLye!TA@KD-Mf_S=amEU-~F|ebTj3zfZtx1UGXtUSWbmDB6-Pu*Bl->lhTt%U9 z_x~4I+=B#y_9yNWSySd;p^xI_QqaL)zoO^gznootKR9UdDFA$!j9W_HVaQuS3XpF% z!mrb782689q_ zw?Nt(g@x`=$10}r8{;4?A|*0%zuqt7CS~%$E&UR~!aC-jtpclyCg)3< zsWA6yWK8Mo6JDPt*0C>U5>iJNh&@fzCt1CxOXT$mB8R`;oOU4ri@(cYM`7Z z8nh_kuiR4ot1OG~yR8>jyr!EJFQ!z~e*k*ZxQ>(c^x5?_V>%;gI{O$2`mtpyIz!$W z(ecnh+QreKmkvo+hosX%7KsU~sUy<6?Ruf6ClLw5lHwtefXD+PbU=cxOgtj8+WNO_ z^VgX48A?iS7C9NVIlHm8#j}iUwMTl7fJ~{%Lono=Q(QkmIYGGCr9TB?Yp!DiAuyl~ zDcHhAng*wxo7?n8^$Oh0xP>=7yzxnfEIVp&1f*~eaOTSyIRFvv`O4eKXjER(DzkwW z;=M;mZ#v=?h)}?mXn}=F5ekayM7|HAGLgYaCSn6ekUN`5Mx=+;Vvc}h79wY!^1!bb z`3e#V{I4L-7p;`R>mzMeQKNW~7Bz?COhs|x2WT$+I)Sk0gZRzUZ;lfl`t&U%WR#Ej zjpDoIoOP6Z4h#GNB_({KSj$ol-&M9>V8HCzwJQUTJeBf6c_;^`xRo*KF-qH}G4XW< z5DnsTq{0|adW?x{_>_2=$n!*Q5Fz*xWEX<8Mvz>mWL5T~ffVAeq2P~H8oPivPGo7eR^(&KeiaK zKoEAxi{+3AJIllsi;nw6j7ECvUpr&UDm^FydM9}lg$U$G=(Bb!-QD6WZ<1XqH@>_d LiiGJ0gf{;NmMKor diff --git a/vega/core/quota/quota_strategy.py b/vega/core/quota/quota_strategy.py deleted file mode 100644 index bbafcf9d..00000000 --- a/vega/core/quota/quota_strategy.py +++ /dev/null @@ -1,141 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# This program is free software; you can redistribute it and/or modify -# it under the terms of the MIT License. -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# MIT License for more details. - -"""Sample Filter.""" -import os -import logging -import copy -import vega -from vega.common.general import General -from vega.report import ReportServer -from vega.common.task_ops import TaskOps -from vega.core.pipeline.conf import PipelineConfig, PipeStepConfig -from vega.core.pipeline.pipe_step import PipeStep - - -class QuotaStrategy(object): - """Config parameters adjustment according to runtime setting.""" - - def __init__(self): - self.restrict_config = General.quota.restrict - self.affinity_config = General.quota.affinity - self.max_runtime = General.quota.strategy.runtime - self.only_search = General.quota.strategy.only_search - self.epoch_time = 0. - self.params_dict = {} - self.temp_trials = copy.deepcopy(self.restrict_config.trials) - self._backup_quota_config() - - def adjust_pipeline_config(self, cfg): - """Adjust pipeline config according.""" - cfg_cp = copy.deepcopy(cfg) - cfg_tiny = copy.deepcopy(cfg) - workers_num = self._calc_workers_num() - General.parallel_search = False - self._get_time_params(cfg_cp) - self._simulate_tiny_pipeline(cfg_tiny) - General.parallel_search = cfg.general.parallel_search - self._modify_pipeline_config(workers_num, self.epoch_time, self.params_dict) - if vega.is_npu_device(): - os.environ['RANK_TABLE_FILE'] = os.environ['ORIGIN_RANK_TABLE_FILE'] - os.environ['RANK_SIZE'] = os.environ['ORIGIN_RANK_SIZE'] - logging.info('Adjust runtime config successfully.') - - def _simulate_tiny_pipeline(self, cfg_tiny): - """Simulate tiny pipeline by using one sample one epoch.""" - report = ReportServer() - for i, step_name in enumerate(PipelineConfig.steps): - step_cfg = cfg_tiny.get(step_name) - if step_cfg.pipe_step.type != 'SearchPipeStep': - continue - step_cfg.trainer.distributed = False - step_cfg.trainer.epochs = 1 - self.restrict_config.trials[step_name] = 1 - General.step_name = step_name - PipeStepConfig.from_dict(step_cfg) - pipestep = PipeStep() - if i == 0: - pipestep.do() - record = report.get_step_records(step_name)[-1] - self.epoch_time = record.runtime - _worker_path = TaskOps().local_base_path - if os.path.exists(_worker_path): - os.system('rm -rf {}'.format(_worker_path)) - if step_cfg.pipe_step.type == 'SearchPipeStep': - self.params_dict[step_name]['max_samples'] = pipestep.generator.search_alg.max_samples - _file = os.path.join(TaskOps().step_path, ".generator") - if os.path.exists(_file): - os.system('rm {}'.format(_file)) - - def _get_time_params(self, cfg): - """Get time parameters from config.""" - for step_name in PipelineConfig.steps: - params = dict() - step_cfg = cfg.get(step_name) - pipe_type = step_cfg.pipe_step.type - params['pipe_type'] = pipe_type - if not cfg[step_name].get('trainer', None): - continue - params['epochs'] = cfg[step_name].trainer.epochs - self.params_dict[step_name] = params - - def _modify_pipeline_config(self, workers_num, epoch_time, params_dict): - """Modify pipeline config according to simulated results.""" - self._restore_quota_config() - nas_time_dict, ft_time_dict = dict(), dict() - for step_name in params_dict: - step_time = epoch_time * params_dict[step_name]['epochs'] - if 'max_samples' in params_dict[step_name]: - step_time = step_time * params_dict[step_name]['max_samples'] / workers_num - nas_time_dict[step_name] = step_time - else: - ft_time_dict[step_name] = step_time - nas_total_time = sum([value for key, value in nas_time_dict.items()]) - if nas_total_time == 0: - return - ft_total_time = sum([value for key, value in ft_time_dict.items()]) - left_time = self.max_runtime - if not self.only_search: - if ft_total_time > 0.9 * self.max_runtime: - ft_total_time = 0.9 * self.max_runtime - left_time = self.max_runtime - ft_total_time - scale = left_time / nas_total_time - for key, value in nas_time_dict.items(): - self.restrict_config.duration[key] = float(scale * value) - self.restrict_config.trials = copy.deepcopy(self.temp_trials) - logging.info('Max duration modified as {}'.format(self.restrict_config.duration)) - - def _backup_quota_config(self): - self.temp_trials = copy.deepcopy(self.restrict_config.trials) - self.temp_flops, self.temp_params, self.temp_latency = \ - self.restrict_config.flops, self.restrict_config.params, self.restrict_config.latency - self.restrict_config.flops, self.restrict_config.params, self.restrict_config.latency = None, None, None - self.temp_affinity_type = self.affinity_config.type - self.affinity_config.type = None - - def _restore_quota_config(self): - self.restrict_config.trials = self.temp_trials - self.restrict_config.flops, self.restrict_config.params, self.restrict_config.latency = \ - self.temp_flops, self.temp_params, self.temp_latency - self.affinity_config.type = self.temp_affinity_type - - def _calc_workers_num(self): - """Calculate workers numbers.""" - if not General.parallel_search: - return 1 - if vega.is_gpu_device(): - import torch - world_size = General.env.world_size - devices_per_node = torch.cuda.device_count() - worker_num = (world_size * devices_per_node) // General.devices_per_trainer - elif vega.is_npu_device(): - world_devices = int(os.environ['RANK_SIZE']) - worker_num = world_devices // General.devices_per_trainer - return worker_num diff --git a/vega/core/run.py b/vega/core/run.py index 14d0f738..ddd9dcb1 100644 --- a/vega/core/run.py +++ b/vega/core/run.py @@ -12,6 +12,7 @@ import sys import logging import json +import vega from vega.common.utils import init_log, lazy from vega.common import Config, UserConfig from vega.common.task_ops import TaskOps @@ -19,7 +20,6 @@ from vega import set_backend from vega.common.general import General from vega.core.pipeline.conf import PipelineConfig -from vega.core.quota import QuotaStrategy logger = logging.getLogger(__name__) @@ -67,10 +67,7 @@ def _run_pipeline(): def _adjust_config(): - if General.quota.strategy.runtime is None: - return - adjust_strategy = QuotaStrategy() - adjust_strategy.adjust_pipeline_config(UserConfig().data) + vega.quota().adjuest_pipeline_by_runtime(UserConfig().data) @lazy diff --git a/vega/core/scheduler/__pycache__/__init__.cpython-37.pyc b/vega/core/scheduler/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index e7b124a2991b358848653997286224521a40cd28..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 214 zcmZ?b<>g`kg51Mb6D)xAV-N=hn1BoiATAaG5-AKRj5!Rsj8TlaOi@gX3@J>(44TX@ z8G%xoOt*NGi&7IyQsZ+Ii%U|AZV41;l$NCAm*>SN=ahm({4|-PxFC}8`31#AAftoI_ghYUZ)E7X5!VITKAQVL`KZk@RN+J>Z(l~8z&+K;FqiRn! zJM5rHK^%F8*<%iQiaY}sUpeh7aH8y7Iws^Fci61} zKklM*4^atH5J3fNXhHqX3g-8q2yhRYu#Jj{63I5POfG1Vicmx~D+Yq`OvEDmgA{ur z5fSb~!5)!uy880*Oy)oG&Pe#4S4P8G-bypmVv=PCAi3FDa}D`@UGb)Oo@AHU5;+`G zoBSY^1m4(D)voio>I_$v%ui=h z<*-z_R(a!K-O6Rt$sBf3J&6AUCYKcZM6_{_AHVjljN@>e0z(HoC>I=O(a6SLFSnB@ zcU8->NI+r-U~JqM0DN+M4d24n9THb5u<^+$hr06~{_YilA)5ZZJ+0x3>FW4+3UZ;R zd}(ysOrOb`Pb&>_+Ep_tmW_mIU!&BEZn8K7H}a2Ax_40#7PF8JX~>2&qcQ5r-MZzF zI!+#-bbq2U7-HxK*>-`?;FxpvnF%(b2;A9_O>}`l-Xruk&K9|^pNaZ5Bqlz|-~}aX z^jdv_k%@>msTrJ;mzZ%Ik(g|=cR~6n!30~A-^mm53sLNeuQ^Uu?>yR4D&JCSa_Odi zxO)FF64&T_N4d9lH?h@?gPwu;yz?%(qvd4v@q_t7HghFE?REigGe`gpU**58$3qxk z5*wO%D{a&@QZ8)V=(?6*!?`-rHr>^=gYTYKa$)9L*~n+G@sWm>8yj>6?7#r829JTL5QyuGKM-uxkmTtdvrXqNm&YAl_l(>lTBP# zb5+9!?sVNzVp8e0)oN1rKk1gGwKVs80BKn&-b&2iSmE*^d-3ETd&c-it~BQc3a&$& ce%)%vvpd)D6x&PbFvx 1: + self.cat_transform[key] = idx / (len(self.range) - 1) + else: + self.cat_transform[key] = 0 self.range = [0.0, 1.0] def multi_sample(self, param_range): @@ -396,7 +399,12 @@ def decode(self, x, forbidden=''): """ individual = [] size = self.range[0] - if random.uniform(0, 1) < 0.8: + # TODO: TEST ONLY + from vega.core.pipeline.conf import PipeStepConfig + ratio = 0.8 + if hasattr(PipeStepConfig.search_space, "prune_ratio"): + ratio = 1 - float(PipeStepConfig.search_space.prune_ratio) + if random.uniform(0, 1) < ratio: return [1] * size if len(self.range) == 1: need_convert_code_size = size // 2 diff --git a/vega/datasets/__pycache__/__init__.cpython-37.pyc b/vega/datasets/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 1f58c652b2f1c3cf3f30a81e25e7d30bd9e25f57..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1176 zcmb7D&2H2%5Vo`V+wG?RLVyt3>xwF=I3onDK=eQ_R2*^%CeCg{^Ru$kqFr?_z=O0$ zo`YxLk}Ida0w+E@$#$<;vOWH0Jo9;GeB^lpf|39F{b^?EI+q8hP+k3=#}WtE5#XBCr4B}eSC zDsxt)G0U>Ke+T#EbWH2vamJ*4%i^jO^O!0aL(x$47*n2NsmKkzCE|gm0~$+8NpT3& z1bU#!S-Gvz$3PvRM?lTY&HEZPtHC#VO|Qv!3bNK>?=@#LMX;Q79Q`{LUU51rMSR{G z;M_l(tR3FV?yymtSILys+o_B|KR%B>v-lz@c(^s2Yrtns6k)=1?J2v-%4oMr3R#Lt zR(@UCp&pH0E?0Itd0KFZ)Fvx4TbT_$vg|T>BSb02M(wnjuAP?NPa-y9_SX-l;xJfz z`4Wg^RtCBbS)hmqU2%daN{h6LB7Zg)9xm-*%55yfaOoy(r(sg9U6o1EP>gk%t{_~a zUoH%k$w0S|oa)p+2WkR60BQj>mIjLV1gH)49e22!^_D$o`_P(8bXaBG*nGe@_;5;7 zdJB7KJ!qY!spU<`?$UznqMTd*)qAzmR!8(ec}#_RXg#>E1|XGyy1z=iP@REx6}#9M z`zm>Umpnf%^SmrVH_h2JDUvF5Pg)7RV>enzgG#VMPD+u>ala-ys*V$`6}WY(;|TQw zYA<7pdDQ(Uu?I}-V;S4E)3)2atLEOyMmIpSY$;zmotvm5{}ixP?$!(K(jnc&6c2%` zI8nsQIEsXWbNxuazIm1(tKS<kiR3*P^RN6qg`kg51Mb6F~H15CH>>K!yVl7qb9~6oz01O-8?!3`HPe1o6vQKfTC6zbY$B zzbLgJUq7+5BtJJtzbrL9Q9mWIB(XTPq*y;WKQB!`K0Y%qvm`!Vub}c4hfQvNN@-52 L9mufHK+FID7{DN# diff --git a/vega/datasets/conf/__pycache__/dataset.cpython-37.pyc b/vega/datasets/conf/__pycache__/dataset.cpython-37.pyc deleted file mode 100644 index 092ca516f6e8f20a4b0aff99f9b9bc039e051b42..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 788 zcmZ8fy^a$x5cc10a`_Pl6i8GNqFIXFK#kBH38Fh8lFE%SYbUp1e}cV1x)oH%WAFkz z2hTv`mWo%PV(cX7j#90_c625hGw6yiu>MM_)| zPO#T0=#*28UL(a+^cAT{lWWSEiWey9b*FRLN~_yR-qfpNZPVXS2PN3un+3;0n?jaF zC*PMEyeG3#TD!d3>fGlYiGp$lY@dM8h+~+9D9ouMoaty2^PZx-uTSD^}7xonGV^`gtP25%3JyJ_4a|iwX=9 z|0l^e_62{&(42M;-|BW}>WSHv+J+warOrPT^*VK|1&|3dzN(A!rmC8nN3D@X4KZ!o zgg9zj`dCKk*bejN8YIFc_DS t;J-*Ba~EQ6=s~CZZ5n^HFht0?QO_^U_@b`=e2?UX-+_%mFfzg;`WN{R$P54g diff --git a/vega/evaluator/__pycache__/__init__.cpython-37.pyc b/vega/evaluator/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 16d131dd5b1387d009933df1ba01689866327d7d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 412 zcmYLGy-ve05O$o)QY$XKa0?_23I4ThLIF_Em1@@4@Z$9^V zP;qTuT&6ll<&XndWDb5doqv^jQz*B$X1#tcVEKq+L diff --git a/vega/evaluator/__pycache__/conf.cpython-37.pyc b/vega/evaluator/__pycache__/conf.cpython-37.pyc deleted file mode 100644 index eaf569ad22dd9344b1c8ed925a525cdd3362a989..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1798 zcmaJ>&2rl|5C%v}r2f^UZdxBY>M5hiM9Qg$OeW(*cACy~oJ_3dfD4KMutkR=6_9jC zn)H-BMqZ%L(Pzj7*PQYSJ#`nf2I_oC@ zV-F;M4oVS56jfeA=N|JY+CJDm^C|g-s6f@;6IIjxJD&yM*Hs99==>V^k!pb7aDH7i zkI1mK+TGX3ajMl|lBLIknIP9 zHSEo&Sza8RiKG-oW~$Gw?f_g-CdFX0Fpv(AgTA$6bU3v^0jO;xq%4h)%g2Np4&29d|JX4{++= zJL3~`Z@fA^9UHyK#-c2;c``oJCt|GE>5LKeXtA_4w|f^MV5^BA(A5uHc~qH?T->m0 zBhFJX*PPoX=krVfK(O1KzbQp>z7le-GRe8Y6m9qEv!f|Lc=_nJ-ycn%y?}jK6?clM zVL<;Jmgb2lbSjrNkfjo~KG%hbrEOszrr^g$zk#HY4X`rHlxf5Ei_9ylh3_zzK1hx# zqJdYZf%~EDR<~G#IQKC8A*ipg#s}5}z?xdc7~~7Erwx8bHCPw^dljec&b~g2rQT5F zoI7`~<<6x&BO5N3z&kl(;l4PFQyE+GWaVVR~q>{s7-*nLFqUxc-gsZyu_YYqF zIytzgytx5uHhd-IYn>_<&4f{Z3Zoemc%BuS&w%mfTRh+o6-FlR<3S>j-UZPYw8vh-e(=%tu7rH z+n%b!VWE&g%$<~Y_t_etvh!rh}@;-RVFKA3v@WDHu+1hn8JB5*r zfW2jGw5ocD%pZEy0CchNT(5@{=NlSf0~N~xVj2*7nd*+^NPoQNawG)tBig2`olO>B z4(_!?X>4ZGF{HwW6J)T5Y-bxC*D2n*u79kr9Mvzdvs;(ntmALtfnOciA7S+#$o~cw zxZ)XHJsgIP+_r^R*+^#dd6qh+8|;u9Y6r?iUqwqBJ*Y0MhgiKFFXI48*7p0~^n3pS Dqgu2r diff --git a/vega/evaluator/conf.py b/vega/evaluator/conf.py index d499e576..cf14c96b 100644 --- a/vega/evaluator/conf.py +++ b/vega/evaluator/conf.py @@ -37,6 +37,7 @@ class DeviceEvaluatorConfig(ConfigSerializable): backend = "pytorch" hardware = "Davinci" remote_host = "" + intermediate_format = "onnx" # for torch model convert cuda = False evaluate_latency = True metric = {'type': 'accuracy'} diff --git a/vega/evaluator/device_evaluator.py b/vega/evaluator/device_evaluator.py index 87b16994..ef8f454c 100644 --- a/vega/evaluator/device_evaluator.py +++ b/vega/evaluator/device_evaluator.py @@ -47,6 +47,7 @@ def __init__(self, worker_info=None, model=None, saved_folder=None, saved_step_n # self.backend = self.config.backend self.hardware = self.config.hardware self.remote_host = self.config.remote_host + self.intermediate_format = self.config.intermediate_format self.calculate_metric = self.config.calculate_metric self.quantize = self.config.quantize self.model = model @@ -97,7 +98,8 @@ def valid(self): # noqa: C901 reuse_model = False if global_step == 0 else True results = evaluate(backend="pytorch", hardware=self.hardware, remote_host=self.remote_host, model=self.model, weight=None, test_data=test_data, input_shape=data.shape, - reuse_model=reuse_model, job_id=job_id, repeat_times=repeat_times) + reuse_model=reuse_model, job_id=job_id, repeat_times=repeat_times, + intermediate_format=self.intermediate_format) if results.get("status") != "sucess" and error_count <= error_threshold: error_count += 1 break diff --git a/vega/evaluator/evaluator.py b/vega/evaluator/evaluator.py index 8b86e7ef..dcd8d04c 100644 --- a/vega/evaluator/evaluator.py +++ b/vega/evaluator/evaluator.py @@ -129,12 +129,12 @@ def _use_evaluator(self): use_evaluator = False cls_evaluator_set = [] if EvaluatorConfig.host_evaluator_enable: - cls_host_evaluator = ClassFactory.get_cls(ClassType.HOST_EVALUATOR, "HostEvaluator") + cls_host_evaluator = ClassFactory.get_cls(ClassType.HOST_EVALUATOR, EvaluatorConfig.host_evaluator.type) use_evaluator = True cls_evaluator_set.append(cls_host_evaluator) if EvaluatorConfig.device_evaluator_enable: - cls_device_evaluator = ClassFactory.get_cls( - ClassType.DEVICE_EVALUATOR, "DeviceEvaluator") + cls_device_evaluator = ClassFactory.get_cls(ClassType.DEVICE_EVALUATOR, + EvaluatorConfig.device_evaluator.type) use_evaluator = True cls_evaluator_set.append(cls_device_evaluator) # TODO HAVA_D_EVALUATOR @@ -182,7 +182,6 @@ def _get_model_desc(self): pattern = FileOps.join_path(folder, "desc_*.json") desc_file = glob.glob(pattern)[0] model_desc = Config(desc_file) - elif PipeStepConfig.pipe_step.get("models_folder") is not None: folder = PipeStepConfig.pipe_step.get("models_folder").replace("{local_base_path}", self.local_base_path) diff --git a/vega/evaluator/host_evaluator.py b/vega/evaluator/host_evaluator.py index 928bbc27..5d66dea4 100644 --- a/vega/evaluator/host_evaluator.py +++ b/vega/evaluator/host_evaluator.py @@ -11,8 +11,9 @@ """HostEvaluator used to do evaluate process on gpu.""" import time -import os +import copy import logging +from statistics import mean import vega from vega.common import ClassFactory, ClassType from vega.common import init_log @@ -55,6 +56,18 @@ def __init__(self, worker_info=None, model=None, saved_folder=None, saved_step_n self.saved_folder = saved_folder self.saved_step_name = saved_step_name + def _call_model_batch(self, batch): + input, target = None, None + if isinstance(batch, dict): + logits = self.model(**batch) + elif isinstance(batch, list) and isinstance(batch[0], dict): + target = batch + logits = self.model(batch) + else: + input, target = batch + logits = self.model(input) if not isinstance(input, dict) else self.model(**input) + return logits, target + def valid(self, valid_loader): """Validate one step of mode. @@ -63,37 +76,31 @@ def valid(self, valid_loader): if vega.is_torch_backend(): import torch from vega.metrics.pytorch import Metrics + if vega.is_gpu_device(): + self.model = self.model.cuda() + elif vega.is_npu_device(): + self.model = self.model.to(vega.get_devices()) metrics = Metrics(self.config.metric) self.model.eval() - data_num = 0 - latency_sum = 0.0 + latency_batch = None + cal_lantency_counts = 10 with torch.no_grad(): for step, batch in enumerate(valid_loader): - if isinstance(batch, list) or isinstance(batch, tuple): - data = batch[0] - target = batch[1] - else: - raise ValueError("The dataset format must be tuple or list," - "but get {}.".format(type(batch))) - if vega.is_gpu_device(): - data, target = data.cuda(), target.cuda() - self.model = self.model.cuda() - elif vega.is_npu_device(): - import torch.npu - device = "npu:{}".format(os.environ.get('DEVICE_ID', 0)) - torch.npu.set_device(device) - data, target = data.npu(), target.npu() - self.model = self.model.npu() - time_start = time.time() - logits = self.model(data) - latency_sum += time.time() - time_start - metrics(logits, target) - n = data.size(0) - data_num += n - if step % self.config.report_freq == 0: - logging.info("step [{}/{}], valid metric [{}]".format( - step + 1, len(valid_loader), str(metrics.results))) - latency = latency_sum / data_num + batch = self._set_device(batch) + if not latency_batch: + latency_batch = copy.deepcopy(batch) + logits, target = self._call_model_batch(batch) + metrics_results = metrics(logits, target) + if step % self.config.report_freq == 0 and metrics_results: + logging.info( + "step [{}/{}], valid metric [{}]".format(step + 1, len(valid_loader), metrics_results)) + latency_pre_batch = [] + for i in range(cal_lantency_counts): + time_init = time.perf_counter() + self._call_model_batch(latency_batch) + latency_pre_batch.append((time.perf_counter() - time_init) * 1000) + latency = mean(latency_pre_batch) + logging.info("evaluator latency [{}]".format(latency)) elif vega.is_tf_backend(): from vega.metrics.tensorflow.metrics import Metrics metrics = Metrics(self.config.metric) @@ -129,6 +136,21 @@ def valid(self, valid_loader): logging.info("evaluate performance: {}".format(pfms)) return pfms + def _set_device(self, data): + import torch + if torch.is_tensor(data): + if vega.is_gpu_device(): + return data.cuda() + else: + return data.to(vega.get_devices()) + if isinstance(data, dict): + return {k: self._set_device(v) for k, v in data.items()} + elif isinstance(data, list): + return [self._set_device(v) for v in data] + elif isinstance(data, tuple): + return tuple([self._set_device(v) for v in data]) + return data + def _model_fn(self, features, labels, mode): """Model function of gpu evaluator.""" import tensorflow as tf diff --git a/vega/evaluator/tools/evaluate_davinci_bolt.py b/vega/evaluator/tools/evaluate_davinci_bolt.py index 6456827d..0530c022 100644 --- a/vega/evaluator/tools/evaluate_davinci_bolt.py +++ b/vega/evaluator/tools/evaluate_davinci_bolt.py @@ -18,7 +18,7 @@ def evaluate(backend, hardware, remote_host, model, weight, test_data, input_shape=None, reuse_model=False, - job_id=None, quantize=False, repeat_times=1): + job_id=None, quantize=False, repeat_times=1, **kwargs): """Evaluate interface of the EvaluateService. :param backend: the backend can be one of "tensorflow", "caffe" and "pytorch" @@ -48,7 +48,7 @@ def evaluate(backend, hardware, remote_host, model, weight, test_data, input_sha if not reuse_model: base_save_dir = os.path.dirname(test_data) model, weight, backend = preprocessing_model(backend, hardware, model, weight, input_shape, - base_save_dir, quantize, test_data) + base_save_dir, quantize, test_data, **kwargs) model_file = open(model, "rb") data_file = open(test_data, "rb") if backend == "caffe": @@ -100,7 +100,7 @@ def evaluate(backend, hardware, remote_host, model, weight, test_data, input_sha return evaluate_result -def preprocessing_model(backend, hardware, model, weight, input_shape, base_save_dir, quantize, test_data): +def preprocessing_model(backend, hardware, model, weight, input_shape, base_save_dir, quantize, test_data, **kwargs): """Preprocess the model. :param backend: the backend can be one of "tensorflow", "caffe" , "pytorch" and "mindspore". @@ -119,8 +119,8 @@ def preprocessing_model(backend, hardware, model, weight, input_shape, base_save if backend == "pytorch": if hardware == "Bolt": from .pytorch2onnx import pytorch2onnx - model = pytorch2onnx(model, input_shape) - else: + model = pytorch2onnx(model, input_shape, base_save_dir) + elif kwargs["intermediate_format"] == "caffe": model_file = os.path.join(base_save_dir, "torch_model.pkl") shape_file = os.path.join(base_save_dir, "input_shape.pkl") with open(model_file, "wb") as f: @@ -140,6 +140,10 @@ def preprocessing_model(backend, hardware, model, weight, input_shape, base_save model = os.path.join(base_save_dir, "torch2caffe.prototxt") weight = os.path.join(base_save_dir, "torch2caffe.caffemodel") backend = "caffe" + else: + from .pytorch2onnx import pytorch2onnx + model = pytorch2onnx(model, input_shape, base_save_dir) + backend = "onnx" elif backend == "tensorflow": pb_model_file = os.path.join(base_save_dir, "tf_model.pb") if os.path.exists(pb_model_file): diff --git a/vega/evaluator/tools/pytorch2onnx.py b/vega/evaluator/tools/pytorch2onnx.py index 17ec33f8..51ddd8ff 100644 --- a/vega/evaluator/tools/pytorch2onnx.py +++ b/vega/evaluator/tools/pytorch2onnx.py @@ -13,11 +13,10 @@ import torch import subprocess import logging -import os from vega.common.general import General -def pytorch2onnx(model, input_shape): +def pytorch2onnx(model, input_shape, base_save_dir): """Convert the pytorch model to onnx model. :param model: pytorch model class @@ -30,15 +29,13 @@ def pytorch2onnx(model, input_shape): # model.load_state_dict(torch.load(weight)) # Export the trained model to ONNX dump_input = Variable(torch.randn(input_shape)) - if os.path.exists("./torch_model.onnx"): - os.remove("./torch_model.onnx") - if os.path.exists("./torch_model_sim.onnx"): - os.remove("./torch_model_sim.onnx") - torch.onnx.export(model, dump_input, "./torch_model.onnx") - try: - subprocess.call(f"{General.python_command} -m onnxsim ./torch_model.onnx ./torch_model_sim.onnx", shell=True) - except Exception as e: - logging.error("{}".format(str(e))) - onnx_model = "./torch_model_sim.onnx" - + torch.onnx.export(model, dump_input, "{}/torch_model.onnx".format(base_save_dir)) + # try: + # subprocess.call( + # f"{General.python_command} -m onnxsim {base_save_dir}/torch_model.onnx " + # f"{base_save_dir}/torch_model_sim.onnx", shell=True) + # except Exception as e: + # logging.error("{}".format(str(e))) + # onnx_model = f"{base_save_dir}/torch_model_sim.onnx" + onnx_model = f"{base_save_dir}/torch_model.onnx" return onnx_model diff --git a/vega/metrics/__init__.py b/vega/metrics/__init__.py index 37f8e063..442c2750 100644 --- a/vega/metrics/__init__.py +++ b/vega/metrics/__init__.py @@ -11,7 +11,7 @@ """Import and register metrics automatically.""" from .flops_and_params import calc_model_flops_params -from .forward_latency import calc_forward_latency +from .forward_latency import calc_forward_latency, calc_forward_latency_on_host def register_metrics(backend): diff --git a/vega/metrics/flops_and_params.py b/vega/metrics/flops_and_params.py index ec6a517e..51e1ee6e 100644 --- a/vega/metrics/flops_and_params.py +++ b/vega/metrics/flops_and_params.py @@ -9,6 +9,7 @@ # MIT License for more details. """Model counter of FLOPS and parameters.""" + from copy import deepcopy import vega import numpy as np @@ -40,6 +41,16 @@ def add_new_hooks(custom_hooks): return custom_hooks +def _do_calc_flops_params(model, input, custom_hooks=None, verbose=False): + from thop import profile + if custom_hooks is None: + custom_hooks = {} + custom_hooks = add_new_hooks(custom_hooks) + inputs = (input,) + flops, params = profile(model, inputs, custom_hooks, verbose) + return flops, params + + def calc_model_flops_params(model, input, custom_hooks=None, verbose=False): """Pytorch model flops and parameters calculation. @@ -58,14 +69,11 @@ def calc_model_flops_params(model, input, custom_hooks=None, verbose=False): _model = deepcopy(model) except Exception: _model = model + if vega.is_torch_backend(): - from thop import profile - if custom_hooks is None: - custom_hooks = {} - custom_hooks = add_new_hooks(custom_hooks) - inputs = (input,) - flops, params = profile(_model, inputs, custom_hooks, verbose) - del _model + from vega.modules.arch.architecture import register_clear_module_arch_params_hooks + flops, params = _do_calc_flops_params(_model, input, custom_hooks, verbose) + register_clear_module_arch_params_hooks(model) elif vega.is_tf_backend(): import tensorflow.compat.v1 as tf with tf.Graph().as_default() as graph: @@ -77,7 +85,6 @@ def calc_model_flops_params(model, input, custom_hooks=None, verbose=False): opts = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter() params = tf.profiler.profile(graph, cmd='op', options=opts).total_parameters flops *= 0.5 - del _model elif vega.is_ms_backend(): total_params = 0 for param in model.trainable_params(): @@ -85,4 +92,6 @@ def calc_model_flops_params(model, input, custom_hooks=None, verbose=False): params = total_params # TODO flops = 0 + + del _model return flops, params diff --git a/vega/metrics/forward_latency.py b/vega/metrics/forward_latency.py index 0e78d501..6480a645 100644 --- a/vega/metrics/forward_latency.py +++ b/vega/metrics/forward_latency.py @@ -32,10 +32,10 @@ def calc_forward_latency(model, input, sess_config=None, num=10): """ if DeviceEvaluatorConfig.remote_host: return _calc_forward_latency_davinci(model, input, sess_config, num, DeviceEvaluatorConfig().to_dict()) - return _calc_forward_latency_gpu(model, input, sess_config, num) + return calc_forward_latency_on_host(model, input, sess_config, num) -def _calc_forward_latency_gpu(model, input, sess_config=None, num=100): +def calc_forward_latency_on_host(model, input, sess_config=None, num=100): """Model forward latency calculation. :param model: network model @@ -100,6 +100,7 @@ def _calc_forward_latency_davinci(model, input, sess_config=None, num=10, evalua # backend = evaluate_config.get("backend") hardware = evaluate_config.get("hardware") remote_host = evaluate_config.get("remote_host") + intermediate_format = evaluate_config.get("intermediate_format") worker_path = TaskOps().local_base_path save_data_file = os.path.join(worker_path, "input.bin") @@ -116,7 +117,7 @@ def _calc_forward_latency_davinci(model, input, sess_config=None, num=10, evalua for index in range(num): reuse_model = False if index == 0 else True results = evaluate("pytorch", hardware, remote_host, model, None, save_data_file, input_shape, - reuse_model, job_id) + reuse_model, job_id, intermediate_format=intermediate_format) latency += np.float(results.get("latency")) elif vega.is_tf_backend(): input_shape = input.shape.as_list() diff --git a/vega/metrics/pytorch/classifier_metric.py b/vega/metrics/pytorch/classifier_metric.py index dd979428..2f00eeff 100644 --- a/vega/metrics/pytorch/classifier_metric.py +++ b/vega/metrics/pytorch/classifier_metric.py @@ -36,7 +36,7 @@ def accuracy(output, target, top_k=(1,)): correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in top_k: - correct_k = correct[:k].view(-1).float().sum(0) + correct_k = correct[:k].reshape(-1).float().sum(0) res.append(correct_k.mul_(100.0 / batch_size)) return res diff --git a/vega/model_zoo/compressed_model_filter.py b/vega/model_zoo/compressed_model_filter.py deleted file mode 100644 index 45084820..00000000 --- a/vega/model_zoo/compressed_model_filter.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# This program is free software; you can redistribute it and/or modify -# it under the terms of the MIT License. -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# MIT License for more details. - -"""Compressed model filter.""" - -import pandas as pd - - -class CompressedModelFilter(object): - """Compressed Model Filter.""" - - def __init__(self, models_file): - """Initialize.""" - self.models_file = models_file - self.model_zoo = self._parse_file(self.models_file) - - def _parse_file(self, models_file): - """Parse model files by pandas.""" - model_zoo = pd.read_csv(models_file) - return model_zoo - - def _parse_standard(self, standard): - """Parse quota standard to target and restrict.""" - restrict = standard.restrict().to_dict() - target = standard.target().to_dict() - return target, restrict - - def _filtrate(self, restrict): - """Filtrate models by restrict condition.""" - filters = [] - condition = True - for key, value in restrict.items(): - if not value or key not in self.model_zoo: - continue - filters.append(self.model_zoo[key].map(lambda x: x < value)) - condition = condition & filters[-1] - filtered_data = self.model_zoo[condition] - return filtered_data - - def _choose_models(self, candidates, target, num): - """Choose models by target type and selected number.""" - sort_cands = candidates.sort_values(target.type, inplace=False) - # sort_cands = sort_cands[sort_cands[target.type] > target.value] - if len(sort_cands) < num: - num = len(sort_cands) - satisfied_models = [] - for idx in range(num): - satisfied_models.append(sort_cands.iloc[idx]['desc']) - return satisfied_models - - def select_satisfied_model(self, standard, num): - """Select satisfied models by standard.""" - target, restrict = self._parse_standard(standard) - candidates = self._filtrate(restrict) - satisfied_models = self._choose_models(candidates, target, num) - return satisfied_models diff --git a/vega/model_zoo/model_zoo.py b/vega/model_zoo/model_zoo.py index ca8638c9..b3ced2be 100644 --- a/vega/model_zoo/model_zoo.py +++ b/vega/model_zoo/model_zoo.py @@ -20,6 +20,7 @@ from vega.modules.graph_utils import graph2desc from vega.modules.module import Module from vega.modules.arch import transform_architecture +from vega.common.searchable import SearchableRegister class ModelZoo(object): @@ -63,8 +64,9 @@ def get_model(cls, model_desc=None, pretrained_model_file=None, exclude_weight_p if pretrained_model_file is not None: if exclude_weight_prefix: model.exclude_weight_prefix = exclude_weight_prefix - model = cls._load_pretrained_model(model, pretrained_model_file) + model = cls._load_pretrained_model(model, pretrained_model_file, exclude_weight_prefix) model = transform_architecture(model, pretrained_model_file) + model = SearchableRegister().active_search_event(model) if model is None: raise ValueError("Failed to get mode, model is None.") return model @@ -75,7 +77,9 @@ def to_module(cls, model): if vega.is_ms_backend(): from vega.networks.mindspore.backbones.ms2vega import transform_model return transform_model(model) - else: + if vega.is_torch_backend(): + return model + if vega.is_tf_backend(): try: model_desc = cls.parse_desc_from_pretrained_model(model) except Exception as ex: @@ -110,15 +114,43 @@ def parse_desc_from_pretrained_model(cls, src_model, pb_file=None): return desc @classmethod - def _load_pretrained_model(cls, model, pretrained_model_file): + def _exclude_checkpoint_by_prefix(cls, states, head_prefix): + if head_prefix: + if not isinstance(head_prefix, list): + head_prefix = [head_prefix] + for prefix in head_prefix: + states = {k: v for k, v in states.items() if not k.startswith(prefix)} + return states + + @classmethod + def _load_pretrained_model(cls, model, pretrained_model_file, exclude_weight_prefix=None): pretrained_model_file = cls._get_abs_path(pretrained_model_file) logging.info("load model weights from file, weights file={}".format(pretrained_model_file)) if vega.is_torch_backend(): + import torch if not os.path.isfile(pretrained_model_file): raise Exception(f"Pretrained model is not existed, model={pretrained_model_file}") - import torch - checkpoint = torch.load(pretrained_model_file) - model.load_state_dict(checkpoint) + if vega.is_npu_device(): + device = int(os.environ.get('DEVICE_ID', 0)) + target_model_file = "/tmp/checkpoint_{}.pth".format(device) + cmd = "/bin/cp -f {} {} && sed -i 's/npu:[0-9]/npu:{}/g' {}".format(pretrained_model_file, + target_model_file, + device, + target_model_file) + ret = os.system(cmd) + logging.info("modify weight file result: " + str(ret)) + checkpoint = torch.load(target_model_file) + else: + checkpoint = torch.load(pretrained_model_file) + if exclude_weight_prefix: + # TODO: make it more generalize + if vega.is_torch_backend(): + model.load_state_dict(checkpoint, False, exclude_weight_prefix=exclude_weight_prefix) + else: + checkpoint = cls._exclude_checkpoint_by_prefix(checkpoint, exclude_weight_prefix) + model.load_state_dict(checkpoint, False) + else: + model.load_state_dict(checkpoint) # del checkpoint if vega.is_tf_backend(): diff --git a/vega/modules/arch/architecture.py b/vega/modules/arch/architecture.py index fade3751..bed3f7e8 100644 --- a/vega/modules/arch/architecture.py +++ b/vega/modules/arch/architecture.py @@ -36,10 +36,10 @@ def transform_architecture(model, pretrained_model_file=None): assert len(changed_name_list) == len(mask_weight_list) # change model and rebuild model_desc = model.desc - root_name = [name for name in list(model_desc.keys()) if name not in ('type', '_arch_params')] + # root_name = [name for name in list(model_desc.keys()) if name not in ('type', '_arch_params')] for changed_name, mask in zip(changed_name_list, mask_weight_list): name = changed_name.split('.') - name[0] = root_name[int(name[0])] + # name[0] = root_name[int(name[0])] assert len(name) <= 6 if len(name) == 6: model_desc[name[0]][name[1]][name[2]][name[3]][name[4]][name[5]] = sum(mask) @@ -56,7 +56,7 @@ def transform_architecture(model, pretrained_model_file=None): model_desc.pop('_arch_params') if '_arch_params' in model_desc else model_desc model.desc = model_desc # change weight - if hasattr(model, "pretrained"): + if pretrained_model_file and hasattr(model, "pretrained"): pretrained_weight = model.pretrained(pretrained_model_file) load_checkpoint(pretrained_weight, net=model) os.remove(pretrained_weight) @@ -66,13 +66,23 @@ def transform_architecture(model, pretrained_model_file=None): if not ClassFactory.is_exists(model._arch_params_type, module.model_name): continue arch_cls = ClassFactory.get_cls(model._arch_params_type, module.model_name) - decode_fn(module, arch_cls) module.register_forward_pre_hook(arch_cls.fit_weights) - module.register_forward_hook(module.clear_module_arch_params) + # module.register_forward_hook(module.clear_module_arch_params) return model +def register_clear_module_arch_params_hooks(model): + """Register hooks.""" + if not hasattr(model, "_arch_params") or not model._arch_params or \ + PipeStepConfig.pipe_step.get("type") == "TrainPipeStep": + return + for name, module in model.named_modules(): + if not ClassFactory.is_exists(model._arch_params_type, module.model_name): + continue + module.register_forward_hook(module.clear_module_arch_params) + + def decode_fn(module, arch_cls): """Decode function.""" for name, value in module._arch_params.items(): diff --git a/vega/modules/arch/combiner.py b/vega/modules/arch/combiner.py index 3288993b..339b0a5f 100644 --- a/vega/modules/arch/combiner.py +++ b/vega/modules/arch/combiner.py @@ -12,8 +12,7 @@ from collections import deque from vega.modules.operators import ops from vega.modules.connections import Add -from vega.modules.connections import Sequential -import vega +from vega.modules.connections import Module class ConnectionsArchParamsCombiner(object): @@ -54,13 +53,14 @@ def _traversal(self, module): self.add_condition(module.name + '.num_features', self.pre_conv.name + '.out_channels') elif isinstance(module, ops.Linear): self.add_condition(module.name + '.in_features', self.pre_conv.name + '.out_channels') - elif isinstance(module, Sequential): + elif isinstance(module, Module): for child in module.children(): self._traversal(child) def _traversal_add_connections(self, module): last_convs = [] last_bns = [] + add_bns = [] for child in module.children(): if isinstance(child, ops.Conv2d): add_convs = [child] @@ -71,8 +71,8 @@ def _traversal_add_connections(self, module): add_bns = [bn for name, bn in child.named_modules() if isinstance(bn, ops.BatchNorm2d)] if add_convs: last_convs.append(add_convs[-1]) - if vega.is_ms_backend(): - last_bns.append(add_bns[-1]) + if add_bns: + last_bns.append(add_bns[-1]) tmp_pre_conv = self.pre_conv for child in module.children(): self.pre_conv = tmp_pre_conv @@ -89,7 +89,7 @@ def _traversal_add_connections(self, module): if len(last_convs) == 1: self.add_forbidden(last_convs[0].name + '.out_channels') self.add_condition(last_convs[0].name + '.out_channels', self.pre_conv.name + '.out_channels') - if vega.is_ms_backend(): + if len(last_bns) > 0: self.add_condition(last_bns[-1].name + '.num_features', self.pre_conv.name + '.out_channels') else: for last_conv in last_convs: @@ -99,7 +99,8 @@ def _traversal_add_connections(self, module): self.add_condition(last_convs[0].name + '.out_channels', self.pre_conv.name + '.out_channels') for k, v in [(k, v) for k, v in self.conditions if v == last_conv.name + '.out_channels']: self.add_condition(k, self.pre_conv.name + '.out_channels') - self.pre_conv = last_convs[0] + if len(last_convs) > 0: + self.pre_conv = last_convs[0] def add_condition(self, name, value): """Add condition.""" diff --git a/vega/modules/arch/double_channels_arch.py b/vega/modules/arch/double_channels_arch.py index f3747481..ee29e5cf 100644 --- a/vega/modules/arch/double_channels_arch.py +++ b/vega/modules/arch/double_channels_arch.py @@ -43,13 +43,26 @@ def fit_weights(module, x): continue padding = [0, out_channels_diff] else: - in_channels_diff = int(inputs.shape[1]) - int(weight.shape[in_channels_axis]) + groups = module.groups + # depthwise conv + if groups == module.in_channels and module.out_channels < groups: + module.out_channels = groups + in_channels_diff = int(inputs.shape[1]) - int(weight.shape[in_channels_axis] * module.groups) out_channels_diff = int(module.out_channels) - int(weight.shape[out_channels_axis]) if in_channels_diff == 0 and out_channels_diff == 0: continue padding = [0, 0, 0, 0, 0, 0, 0, 0] - if in_channels_diff != 0: - padding[5] = in_channels_diff + # fit input channel + if groups == 1: + if in_channels_diff != 0: + padding[5] = in_channels_diff + module.in_channels += in_channels_diff + else: + if in_channels_diff > 0: + in_channels_group_diff = int(in_channels_diff / groups) + padding[5] = in_channels_group_diff + elif in_channels_diff < 0: + module.groups = int(abs(in_channels_diff) / weight.shape[in_channels_axis]) module.in_channels += in_channels_diff if out_channels_diff != 0: padding[-1] = out_channels_diff diff --git a/vega/modules/arch/prune_arch.py b/vega/modules/arch/prune_arch.py index 3d8bb321..4696ff39 100644 --- a/vega/modules/arch/prune_arch.py +++ b/vega/modules/arch/prune_arch.py @@ -9,11 +9,31 @@ # MIT License for more details. """Prune ArchSpace.""" -from vega import is_torch_backend + +import vega +from vega import is_torch_backend, is_tf_backend +from vega.modules.operators import ops from vega.common.class_factory import ClassFactory from vega.modules.arch.architecture import Architecture +def _to_cpu(data): + try: + import torch + if torch.is_tensor(data): + return data.cpu() + except Exception: + pass + + if isinstance(data, dict): + return {k: _to_cpu(v) for k, v in data.items()} + elif isinstance(data, list): + return [_to_cpu(v) for v in data] + elif isinstance(data, tuple): + return tuple([_to_cpu(v) for v in data]) + return data + + @ClassFactory.register('Prune', 'Conv2d') class Conv2dPruneArchitecture(Architecture): """Prune Conv2d.""" @@ -29,29 +49,9 @@ def fit_weights(module, x): arch_params = module.module_arch_params if not arch_params: return None - weights = module.get_weights() - if arch_params.get('out_channels'): - out_channels_idx = [idx for idx, value in enumerate(arch_params.out_channels) if value == 1] - for name, weight in weights.items(): - if weight is None: - continue - if 'BatchNorm' in name: - module.set_weights(name, weight[out_channels_idx]) - else: - if is_torch_backend(): - module.set_weights(name, weight[out_channels_idx, :, :, :]) - else: - module.set_weights(name, weight[:, :, :, out_channels_idx]) - if arch_params.get('in_channels'): - in_channels_idx = [idx for idx, value in enumerate(arch_params.in_channels) if value == 1] - for name, weight in weights.items(): - if weight is None or 'BatchNorm' in name: - continue - if weight is not None: - if is_torch_backend(): - module.set_weights(name, weight[:, in_channels_idx, :, :]) - else: - module.set_weights(name, weight[:, :, in_channels_idx, :]) + freeze(module) + prune_conv2d_out_channels(arch_params, module) + prune_conv2d_in_channels(arch_params, module) return None @@ -70,10 +70,17 @@ def fit_weights(module, x): arch_params = module.module_arch_params if not arch_params: return None + adapt(module) idx = [idx for idx, value in enumerate(arch_params.num_features) if value == 1] weights = module.get_weights() + weights = _to_cpu(weights) for name, weight in weights.items(): - module.set_weights(name, weight[idx]) + if name in ["total_ops", "total_params"]: + continue + if is_tf_backend(): + module.set_weights(name, weight[idx]) + else: + module.set_weights(name, weight[idx].to(vega.get_devices())) return None @@ -92,12 +99,91 @@ def fit_weights(module, x): arch_params = module.module_arch_params if not arch_params: return None + # for name, parameter in module.named_parameters(): + # parameter.requires_grad_(False) idx_in = [idx for idx, value in enumerate(arch_params.in_features) if value == 1] weights = module.get_weights() for name, weight in weights.items(): + if name in ["total_ops", "total_params"]: + continue if 'kernel' in name or 'weight' in name: - if is_torch_backend(): - module.set_weights(name, weight[:, idx_in]) - else: + if is_tf_backend(): module.set_weights(name, weight[idx_in, :]) + elif is_torch_backend(): + module.set_weights(name, weight[:, idx_in].to(vega.get_devices())) + else: + module.set_weights(name, weight[idx_in, :].to(vega.get_devices())) return None + + +def freeze(module): + """Freeze parameter.""" + if not is_torch_backend(): + return + for name, parameter in module.named_parameters(): + parameter.requires_grad_(False) + + +def adapt(module): + """Adapt mean and var in dataset.""" + if not is_torch_backend(): + return + module.weight.requires_grad = False + module.bias.requires_grad = False + + +def prune_conv2d_out_channels(arch_params, module): + """Prune out channels of conv2d.""" + weights = module.get_weights() + weights = _to_cpu(weights) + if arch_params.get('out_channels'): + out_channels_idx = [idx for idx, value in enumerate(arch_params.out_channels) if value == 1] + for name, weight in weights.items(): + if weight is None: + continue + if name in ["total_ops", "total_params"]: + continue + if 'BatchNorm' in name: + if is_tf_backend(): + module.set_weights(name, weight[out_channels_idx]) + else: + module.set_weights(name, weight[out_channels_idx].to(vega.get_devices())) + else: + if is_tf_backend(): + module.set_weights(name, weight[:, :, :, out_channels_idx]) + elif is_torch_backend(): + module.set_weights(name, weight[out_channels_idx, :, :, :].to(vega.get_devices())) + else: + module.set_weights(name, weight[:, :, :, out_channels_idx].to(vega.get_devices())) + + +def prune_conv2d_in_channels(arch_params, module): + """Prune in channels of conv2d.""" + weights = module.get_weights() + weights = _to_cpu(weights) + in_channels = module.in_channels + out_channels = module.out_channels + if arch_params.get('in_channels'): + in_channels_idx = [idx for idx, value in enumerate(arch_params.in_channels) if value == 1] + for name, weight in weights.items(): + if name in ["total_ops", "total_params"]: + continue + if weight is None or 'BatchNorm' in name: + continue + if weight is not None: + if is_torch_backend(): + if module.groups == 1: + module.set_weights(name, weight[:, in_channels_idx, :, :].to(vega.get_devices())) + else: + module.groups = min(in_channels, out_channels) + if module.groups < in_channels: + in_channels_diff = int(in_channels) - int(weight.shape[1] * module.groups) + in_channels_group_diff = int(in_channels_diff / module.groups) + padding = [0, 0, 0, 0, 0, 0, 0, 0] + padding[5] = in_channels_group_diff + module.set_weights(name, ops.pad(weight, padding)) + else: + if is_tf_backend(): + module.set_weights(name, weight[:, :, in_channels_idx, :]) + else: + module.set_weights(name, weight[:, :, in_channels_idx, :].to(vega.get_devices())) diff --git a/vega/modules/connections/connections.py b/vega/modules/connections/connections.py index f8de650e..709a088c 100644 --- a/vega/modules/connections/connections.py +++ b/vega/modules/connections/connections.py @@ -20,18 +20,26 @@ class ConnectionsDecorator(Module): """Base class for Connections.""" - def __init__(self, *models): - super(ConnectionsDecorator, self).__init__(*models) - for idx, model in enumerate(models): - if isinstance(model, OrderedDict): - for name, value in model.items(): - if not isinstance(value, Module) and isinstance(value, dict): - value = self.from_desc(value) - self.add_module(name, value) - else: - if not isinstance(model, Module) and isinstance(model, dict): - model = self.from_desc(model) - self.add_module(str(idx), model) + def __init__(self, *models, **kwargs): + super(ConnectionsDecorator, self).__init__(*models, **kwargs) + # for key, model in kwargs.items(): + if kwargs: + for key, model in kwargs.items(): + self.__add_module(key, model) + else: + for idx, model in enumerate(models): + self.__add_module(str(idx), model) + + def __add_module(self, key, model): + if isinstance(model, OrderedDict): + for name, value in model.items(): + if not isinstance(value, Module) and isinstance(value, dict): + value = self.from_desc(value) + self.add_module(name, value) + else: + if not isinstance(model, Module) and isinstance(model, dict): + model = self.from_desc(model) + self.add_module(key, model) def to_desc(self, recursion=True): """Convert module to desc.""" @@ -77,8 +85,8 @@ def out_channels(self): class Sequential(ConnectionsDecorator): """Sequential SearchSpace.""" - def __init__(self, *models): - super(Sequential, self).__init__(*models) + def __init__(self, *models, **kwargs): + super(Sequential, self).__init__(*models, **kwargs) def append(self, module, name=None): """Append new module.""" @@ -260,6 +268,32 @@ def call(self, inputs): return output +@ClassFactory.register(SearchSpaceType.CONNECTIONS) +class Reshape(ConnectionsDecorator): + """Create Lambda for forward x.""" + + def __init__(self, *models): + super(Reshape, self).__init__(*models) + + def call(self, x): + """Forward x.""" + inputs = None + new_shape = None + for model in self.children(): + if model is not None: + if inputs is None: + inputs = model(x) + else: + new_shape = model(x) + import torch + return torch.reshape(inputs, tuple(new_shape.to("cpu").numpy())) + + @property + def out_channels(self): + """Get out channels.""" + return [k.out_channels for k in self.children() if hasattr(k, 'out_channels')] + + @ClassFactory.register(ClassType.NETWORK) class Repeat(Module): """Repeat SearchSpace.""" diff --git a/vega/modules/loss/loss.py b/vega/modules/loss/loss.py index f33a1287..53791bd8 100644 --- a/vega/modules/loss/loss.py +++ b/vega/modules/loss/loss.py @@ -44,8 +44,8 @@ def __call__(self): if vega.is_torch_backend(): if vega.is_gpu_device(): cls_obj = cls_obj.cuda() - elif vega.is_npu_device(): - cls_obj = cls_obj.npu() + elif vega.is_npu_device() and not cls_obj.__class__.__name__ == 'SumLoss': + cls_obj = cls_obj.to(vega.get_devices()) return cls_obj except Exception as ex: logging.error("Failed to call Loss name={}, params={}".format(self._cls.__name__, params)) @@ -54,15 +54,19 @@ def __call__(self): if vega.is_torch_backend(): import torch.nn as torch_nn + ClassFactory.register_from_package(torch_nn, ClassType.LOSS) try: import timm.loss as timm_loss + ClassFactory.register_from_package(timm_loss, ClassType.LOSS) except Exception: pass elif vega.is_tf_backend(): import tensorflow.compat.v1.losses as tf_loss + ClassFactory.register_from_package(tf_loss, ClassType.LOSS) elif vega.is_ms_backend(): import mindspore.nn.loss as ms_loss + ClassFactory.register_from_package(ms_loss, ClassType.LOSS) diff --git a/vega/modules/operators/conv.py b/vega/modules/operators/conv.py index 4201d35f..51eea19d 100644 --- a/vega/modules/operators/conv.py +++ b/vega/modules/operators/conv.py @@ -9,6 +9,7 @@ # MIT License for more details. """Import all torch operators.""" +import math from vega.common import ClassType, ClassFactory from vega.modules.operators import ops @@ -124,7 +125,7 @@ def __init__(self, C_in, C_out): super(GAPConv1x1, self).__init__() self.conv1x1 = conv_bn_relu(C_in, C_out, 1, stride=1, padding=0) - def call(self, x, *args, **kwargs): + def call(self, x=None, *args, **kwargs): """Call GAPConv1x1.""" size = ops.get_shape(x)[2:] out = x @@ -153,7 +154,7 @@ def __init__(self, C_in, C_out, affine=True): self.conv_2 = ops.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) self.bn = ops.BatchNorm2d(C_out, affine=affine) - def call(self, x): + def call(self, x=None, *args, **kwargs): """Do an inference on FactorizedReduce.""" x = self.relu(x) out = ops.concat((self.conv_1(x), self.conv_2(x[:, :, 1:, 1:]))) @@ -182,3 +183,34 @@ def __init__(self, *models): super(Seq, self).__init__() for idx, model in enumerate(models): self.add_module(str(idx), model) + + +@ClassFactory.register(ClassType.NETWORK) +class GhostConv2d(ops.Module): + """Ghost Conv2d Module.""" + + def __init__(self, C_in, C_out, kernel_size=3, stride=1, affine=True, padding=0, ratio=2): + super(GhostConv2d, self).__init__() + self.C_out = C_out + init_channels = math.ceil(C_out / ratio) + new_channels = init_channels * (ratio - 1) + + self.primary_conv = Seq( + ops.Relu(inplace=False), + ops.Conv2d(C_in, init_channels, kernel_size=1, stride=stride, padding=padding, bias=False), + ops.BatchNorm2d(init_channels, affine=affine) + ) + + self.cheap_operation = Seq( + ops.Relu(inplace=False), + ops.Conv2d(init_channels, new_channels, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, + groups=init_channels, bias=False), + ops.BatchNorm2d(new_channels, affine=affine) + ) + + def call(self, x=None, *args, **kwargs): + """Call function.""" + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = ops.concat([x1, x2], dim=1) + return out[:, :self.C_out, :, :] diff --git a/vega/modules/operators/functions/pytorch_fn.py b/vega/modules/operators/functions/pytorch_fn.py index 91c66141..0dca48c1 100644 --- a/vega/modules/operators/functions/pytorch_fn.py +++ b/vega/modules/operators/functions/pytorch_fn.py @@ -17,8 +17,6 @@ from torch.functional import F import torch.nn.init as init from torch.nn.quantized import Conv2d as QuantConv2d -from torch import Tensor -from torch.nn import LayerNorm from torch.nn import Parameter as Torch_Parameter import vega from .serializable import OperatorSerializable @@ -43,12 +41,64 @@ def build(self): """Build network or params.""" pass - def load_state_dict(self, state_dict=None, strict=None, file_path=None): + @classmethod + def remap_state_dict(self, own_state_dict, state_dict, head_prefix=None): + """Remap state dict from npu state files.""" + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + own_keys = list(own_state_dict.keys()) + input_keys = list(state_dict.keys()) + if len(own_keys) != len(input_keys): + raise Exception("own_state_dict and state_dict have unmatched key length") + + new_state_dict = {} + own_key_prefix_occurrence_map = {} + input_key_prefix_occurrence_map = {} + + def _has_prefix(key, prefixes): + if not prefixes or not key: + return False + if isinstance(prefixes, str): + prefixes = [prefixes] + for prefix in prefixes: + if key.startswith(prefix): + return True + return False + + for i in range(len(own_keys)): + own_key = own_keys[i] + input_key = input_keys[i] + if _has_prefix(input_key, head_prefix): + continue + + own_key_prefix = own_key[:own_key.rfind(".")] if own_key.rfind(".") != -1 else own_key + own_key_suffix = own_key[own_key.rfind("."):] if own_key.rfind(".") != -1 else own_key + input_key_prefix = input_key[:input_key.rfind(".")] if input_key.rfind(".") != -1 else input_key + input_key_suffix = input_key[input_key.rfind("."):] if input_key.rfind(".") != -1 else input_key + if own_key_prefix not in own_key_prefix_occurrence_map.keys(): + own_key_prefix_occurrence_map[own_key_prefix] = \ + sum(s.startswith(own_key_prefix + ".") for s in own_keys) + if input_key_prefix not in input_key_prefix_occurrence_map.keys(): + input_key_prefix_occurrence_map[input_key_prefix] = \ + sum(s.startswith(input_key_prefix + ".") for s in input_keys) + own_key_prefix_occurrence = own_key_prefix_occurrence_map[own_key_prefix] + input_key_prefix_occurrence = input_key_prefix_occurrence_map[input_key_prefix] + if own_key_prefix_occurrence == input_key_prefix_occurrence and own_key_suffix == input_key_suffix: + new_state_dict[own_key] = state_dict[input_key] + else: + raise Exception("unmatched own_key {} and input_key {}".format(own_key, input_key)) + + return new_state_dict + + def load_state_dict(self, state_dict=None, strict=None, file_path=None, + exclude_weight_prefix=None): """Load state dict from state_dict or file.""" state_dict = torch.load(file_path) if file_path is not None else state_dict self.strict = strict if strict is not None else self.strict state_dict = self._exclude_checkpoint_by_prefix(state_dict) own_states = self.state_dict() + if vega.is_npu_device(): + state_dict = self.remap_state_dict(own_states, state_dict, exclude_weight_prefix) not_swap_keys = [] for own_key, own_state in own_states.items(): state = state_dict.get(own_key) @@ -87,7 +137,7 @@ def _exclude_checkpoint_by_prefix(self, states): def set_parameters(self, name, value): """Set Parameters.""" if vega.is_npu_device(): - self.register_parameter(name, nn.Parameter(value.npu())) + self.register_parameter(name, nn.Parameter(value.to(vega.get_devices()))) elif vega.is_gpu_device(): self.register_parameter(name, nn.Parameter(value.cuda())) else: @@ -152,7 +202,7 @@ def forward(self, input): input = torch.quantize_per_tensor(input, 1.0, 0, self._quant_type[self.quant_bit]) output = super().forward(input) if vega.is_npu_device(): - output = torch.dequantize(output).npu() + output = torch.dequantize(output).to(vega.get_devices()) elif vega.is_gpu_device(): output = torch.dequantize(output).cuda() else: @@ -160,6 +210,25 @@ def forward(self, input): return output +@ClassFactory.register(ClassType.NETWORK) +class ConvTranspose2d(nn.ConvTranspose2d, OperatorSerializable): + """MaxPool2d Module inherit nn.MaxPool2d.""" + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, output_padding=0, groups=1, bias=True, + dilation=1, padding_mode='zeros'): + """Construct MaxPool2d class.""" + if isinstance(padding, str): + padding = kernel_size // 2 if isinstance(kernel_size, int) else [v // 2 for v in kernel_size] + padding = padding if not isinstance(padding, str) else kernel_size // 2 + super(ConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + groups=groups, padding=padding, bias=bias, dilation=dilation) + + def forward(self, x, output_size=None): + """Do an inference on Identity.""" + return super().forward(x, output_size) + + @ClassFactory.register(ClassType.NETWORK) class Conv2d(nn.Conv2d, OperatorSerializable): """Conv2d Module inherit nn.Module.""" @@ -205,7 +274,10 @@ def get_weights(self): def set_weights(self, name, weight): """Set weights.""" - self._parameters[name] = nn.Parameter(weight) + if name == 'weight': + self.weight.data = weight + elif name == 'bias': + self.bias.data = weight @ClassFactory.register(ClassType.NETWORK) @@ -255,19 +327,6 @@ def set_weights(self, name, weight): self.bias.data = weight -@ClassFactory.register(ClassType.NETWORK) -class Pad(nn.ZeroPad2d, OperatorSerializable): - """Pad Module inherit nn.ZeroPad2d.""" - - def __init__(self, kernel_size=None, stride=1, padding=0): - """Construct MaxPool2d class.""" - super(Pad, self).__init__(padding) - - def forward(self, x): - """Do an inference on Identity.""" - return x - - @ClassFactory.register(ClassType.NETWORK) class MaxPool2d(nn.MaxPool2d, OperatorSerializable): """MaxPool2d Module inherit nn.MaxPool2d.""" @@ -386,7 +445,10 @@ def get_weights(self): def set_weights(self, name, weight): """Set weights.""" - self._parameters[name] = nn.Parameter(weight) + if name == 'weight': + self.weight.data = weight + elif name == 'bias': + self.bias.data = weight @ClassFactory.register(ClassType.NETWORK) @@ -624,6 +686,160 @@ def forward(self, x): return super(Embedding, self).forward(x) +@ClassFactory.register(ClassType.NETWORK) +class Clip(nn.Module, OperatorSerializable): + """Clip of torch.""" + + def __init__(self, min=float("-inf"), max=float("inf")): + """Construct Clip class.""" + super(Clip, self).__init__() + self.min = float(min) + self.max = float(max) + + def forward(self, x): + """Do an inference on Clip. + + :param x: input tensor + :return: output tensor + """ + return torch.clamp(x, min=0, max=self.max) + + +@ClassFactory.register(ClassType.NETWORK) +class Shape(nn.Module, OperatorSerializable): + """Shape of torch.""" + + def __init__(self, start=0, end=None): + """Construct Shape class.""" + super(Shape, self).__init__() + self.start = start + self.end = end + + def forward(self, x): + """Do an inference on Shape. + + :param x: input tensor + :return: output tensor + """ + if self.end: + output = torch.tensor(x.shape)[self.start:self.end] + else: + output = torch.tensor(x.shape)[self.start:] + return output.to(vega.get_devices()) + + +@ClassFactory.register(ClassType.NETWORK) +class Gather(nn.Module, OperatorSerializable): + """Gather block.""" + + def __init__(self, axis=0): + """Construct Gather class.""" + super(Gather, self).__init__() + self.axis = axis # compatible with dim in pytorch + + def forward(self, x): + """Do an inference on Gather. + + :param x: input tensor + :return: output tensor + """ + return torch.gather(x, self.axis, torch.tensor(0)) + + +@ClassFactory.register(ClassType.NETWORK) +class Unsqueeze(nn.Module, OperatorSerializable): + """Unsqueeze block.""" + + def __init__(self, axes): + """Construct Identity class.""" + super(Unsqueeze, self).__init__() + self.axes = axes + + def forward(self, x): + """Do an inference on Unsqueeze. + + :param x: input tensor + :return: output tensor + """ + if not isinstance(self.axes, list): + logging.error("Unsqueeze axes: {} must be list".format(self.axes)) + return None + output = x + for axis in self.axes: + output = torch.unsqueeze(output, axis) + return output + + +@ClassFactory.register(ClassType.NETWORK) +class ConcatTensor(nn.Module, OperatorSerializable): + """ConcatTensor block.""" + + def __init__(self, axis=0): + """Construct ConcatTensor class.""" + super(ConcatTensor, self).__init__() + self.axis = axis + + def forward(self, x): + """Do an inference on ConcatTensor. + + :param x: input tensor + :return: output tensor + """ + return torch.cat((x, (torch.tensor([-1]).to(vega.get_devices()))), dim=self.axis) + + +@ClassFactory.register(ClassType.NETWORK) +class Mean(nn.Module, OperatorSerializable): + """Mean block.""" + + def __init__(self, axes=None, keepdims=False): + """Construct Mean class.""" + super(Mean, self).__init__() + self.axes = axes + self.keepdims = keepdims + + def forward(self, x): + """Do an inference on Mean. + + :param x: input tensor + :return: output tensor + """ + if self.axes: + return torch.mean(x, dim=self.axes, keepdim=self.keepdims) + return torch.mean(x, keepdim=self.keepdims) + + +@ClassFactory.register(ClassType.NETWORK) +class Pad(nn.Module, OperatorSerializable): + """Pad block.""" + + def __init__(self, mode="constant", padding=None): + self.mode = mode + self.padding = padding + super().__init__() + + def forward(self, input, pads=None, value=0): + """Call forward.""" + if self.padding is not None: + pads = self.padding + elif pads is None: + raise TypeError("forward() missing 1 required positional argument: 'pads'") + return F.pad(input, list(pads), mode=self.mode, value=value) + + +@ClassFactory.register(ClassType.NETWORK) +class Reshape(nn.Module, OperatorSerializable): + """Reshape class.""" + + def __init__(self, shape=[-1, 1024]): + super().__init__() + self.shape = shape + + def forward(self, input: torch.Tensor): + """Forward function.""" + return torch.reshape(input, tuple(self.shape)) + + def concat(inputs, dim=1): """Call concat according to backends.""" return torch.cat(inputs, dim=dim) @@ -681,7 +897,9 @@ def mean_all(inputs): def pad(inputs, position): """Apply pad function.""" - return F.pad(inputs, position) + # return F.pad(inputs, position) + dtype = inputs.dtype + return F.pad(inputs.cpu().float(), position).to(vega.get_devices()).to(dtype) def interpolate(input, size, mode='bilinear', align_corners=False): @@ -1044,3 +1262,5 @@ def __new__(cls, data=None, requires_grad=True, name=None): MSELoss = nn.MSELoss +Tensor = torch.Tensor +LayerNorm = torch.nn.LayerNorm diff --git a/vega/modules/operators/functions/tensorflow_fn.py b/vega/modules/operators/functions/tensorflow_fn.py index 4039a105..9cb29483 100644 --- a/vega/modules/operators/functions/tensorflow_fn.py +++ b/vega/modules/operators/functions/tensorflow_fn.py @@ -204,6 +204,11 @@ def _apply_weights(self): values = [(var, self._weights_buffer.get(var.name.replace(':0', ''))) for var in variables if var.name.replace(':0', '') in self._weights_buffer] for v, weight in values: + if len(v.shape) == 4: + if v.shape[2] != weight.shape[2]: + import torch + num = v.shape[2] // weight.shape[2] + weight = torch.cat([weight] * num, 2) v._initializer_op = state_ops.assign(v, weight) self._weights_buffer.clear() diff --git a/vega/modules/operators/ops.py b/vega/modules/operators/ops.py index a25132ee..3e0ebeb1 100644 --- a/vega/modules/operators/ops.py +++ b/vega/modules/operators/ops.py @@ -23,6 +23,15 @@ ConvWS2d = fn.ConvWS2d GroupNorm = fn.GroupNorm SyncBatchNorm = fn.SyncBatchNorm + ConvTranspose2d = fn.ConvTranspose2d + Clip = fn.Clip + Shape = fn.Shape + Gather = fn.Gather + Unsqueeze = fn.Unsqueeze + ConcatTensor = fn.ConcatTensor + Mean = fn.Mean + Pad = fn.Pad + Reshape = fn.Reshape Module = fn.Module Conv2d = fn.Conv2d diff --git a/vega/networks/__init__.py b/vega/networks/__init__.py index 0f1cd1d4..414b69f2 100644 --- a/vega/networks/__init__.py +++ b/vega/networks/__init__.py @@ -37,6 +37,7 @@ "gcn": ["GCN"], "vit": ["VisionTransformer"], "mtm_sr": ["MtMSR"], + "unet": ["Unet"] }) diff --git a/vega/networks/__pycache__/__init__.cpython-37.pyc b/vega/networks/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 8715a77dccc3cf1162106af26040331781d5448d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1467 zcmb7E%aYqf6xGXb*)tOolZ2Of<0NHMge67AV>1I)8C9^RQdDn*tZoZsq?XgInZzz; z$7fI!Ecp(;fsMDU@(HlwbjvmqSRvK9>OQ)C@9iVqK5Vy}2CnS4-@pB{Y8Zd0$@aR9 z!2>}41VasCP?K7znOVd#4Qf*-wKIn}ropVAE3`tZu}N#RP8+dJo3s^M)CJQ9I=~*? z2XjEL!M=j|b>Ifj1zrVi0=IzMz-z$k^bPtZBE3a>Sn)QPcYt?+_kQgDqAGZwe(-}q zs`NwO=zp}j^HsLUMbQg6?FkkqvS6aeS@A3vXT9*e$g{9WqA*RDgBRG3(KpxbI1Qyd z3Zo(y%NJ%zw!Cu9gvn^tzF=`U*hGo1TT`DlkRE4 zmx@2i#UsH^MVRn}$K>XJeZOQOWr8>ptzGHLipwn*8SbJFq`|HN@8p(s>P?X^E1P5qad4f!yQB^+^~WL34C)hya%`f` z1f2y;3bHrXS_ZSN#H|e>B2?VSw3=#3Yd5^OPb!>6XOdK&DuqZ>FbS7{!mYCz$6I~M zIce$V^U}$caVP^U(R*)7tx%v6cH}}u+Evq*=L;rwuJr;EXh%DaII%C1f>g623#5qD zpQ&>r*PzTjtgJCQ_9!>XGdvxHED*RQQ>K|u2;2-dfB=SO*#<$MsSb0V0R2Zd0Q6L*sizk-Z} z`V#AQ8m^YXHlYrwi@su7%CG)emecc;eXeHmQy9l`P1mlP4$4th-PPTwoR@a{-gW*4 DXC<%0 diff --git a/vega/networks/__pycache__/model_config.cpython-37.pyc b/vega/networks/__pycache__/model_config.cpython-37.pyc deleted file mode 100644 index ae80e1aa2008034d1c823a2cb6153bde8b12ef26..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1417 zcmZ`&&5qkP5GEz+$A9Z&lU<;Hs2cR*7O;b!i=b$W^prGc_M{8T6=_>uN^(g$#k$(l z2FPpJmjZo`K0|xpwWq#9PaRU;-KIb!%xK8r48Qqi^z~@eNAOg?|MBCeDMEj_#mz&& z_zFb70L4*`IOc)GIriTm58xZhup&9Z&PTwG@~DdQ*smj*RB4`K^c~_cPhKOQ2yz|f zJs!M7lXUz3w_;w_f@j|}^*oz5Dtl3`g@mo?-*6*(KEd|b`O3vhp-LvpEjyQD64>Nv zDaH3|ZR5@j2m5DCUv#U&1>Zc;!O%~^5D35{2RsjXkQ0yyBnC-9QV;-__dy0?$m2I? z4rTG^b?+MIqc@14{0N{YcJQp>LUz|}A3+^#BTeQvv%({OyPoG~MR)+xgKq-j*4lzH z$lw-T2fLufZLkln!WQpBj`vtCy%p@S25Uuj5f4qgN_MH~t@_GGPuQOpe)?R;)~a%>~LSqr6N7uA~s zCh8E`Z~@0%b8Mo-TCzgepwX7N3oThSWv$#i+429t*u=o}(UL3IRoinXJxWU!vN>1o zPg3xEA?5zINzaCbgtVL)X8Zc0T+?DHiVGVV3XVKJ5RkyVMyMKw(v|A4D>pt}7HoSI zjx((({SE!|#cZJ-&$g@844^kNwlPg5XRpM9&1zwOZq$XIbsTg9I$d9?Atd`4L}#E- zc#I$76l92xKpx^gJ_T{_DUR{>JxWS;J8;}hi>9hzz5=)Hn(KRM qAN-joK=eI44F3fMu6&mO diff --git a/vega/networks/__pycache__/network_desc.cpython-37.pyc b/vega/networks/__pycache__/network_desc.cpython-37.pyc deleted file mode 100644 index e9c3386ec71fec308d0c10190c4cd2a7d2f52773..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1168 zcmY*YJC74F5VpO~JPr;75+I5+6cH#vR1|;^$5V(P!3m^o%(8Yu!tOfQJ|H32aQu*5 z$?xDd(3p~nzd*%|mx$n%XXoJ=&o?u^v$NACu&Qt0U;PLP`H8}2flO}0HV=Uak`X}# zD``gIAC#;LvVfv4EW;|wBF{%uu3mn%S-!@F+sCHFcEwqStvpg!9Nl)B$Bw!RszfJmM|zV6fXP+u=ik_(?Be- zbVZNgTM-VyArM=hdQ?yCe{e?qK%<|=050MJFanIw$a3sLl)e&;rT=|xIMG*!%h_zG zVdof z#*#HO@|aR$**dV{EE4RP9Rj#@ydtyYm>#kfGuLe@f_1CW9MR7JdzP#iZ`;m_iSUT& zXDhbu9urDd*xqP)d0*$X87p1&^6i@RtOA9Bo=x+zS1BRO!Scd$n`?`vJuP!9|G5lY z{6vX^QZ6su&!;7f3mht%L*DsH6QhU4y!q+Sb$LFYmkVC!l{Ebwmy~KUnbs2*3Hkb9 zq7l7ICerevG&=Iq=-&Rzr_UZ|t+7^>FRxC}(7N}rkaIg#wTlr+ZgtB$yva>&t=1<% zf)=N!j3FJ!(?vcl-_?e6(3)Dsz2Q7aeuM~X(-X8yL%K^-+J^mq?=}vjevF+)Je<2W z$1rdY@LkT|9pvR-OTxKO1?Rd8H|TRnaI&@tCZb~?IKc*)l!hz~y{_NU7twwahzs%M zHACCM85|U self.max_duration: - return True - else: - return False diff --git a/vega/quota/flops_params.py b/vega/quota/flops_params.py new file mode 100644 index 00000000..0c8f1d3b --- /dev/null +++ b/vega/quota/flops_params.py @@ -0,0 +1,44 @@ +# -*- coding:utf-8 -*- + +# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. +# This program is free software; you can redistribute it and/or modify +# it under the terms of the MIT License. +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# MIT License for more details. + +"""Flops and Parameters Filter.""" + +import logging +from vega.metrics import calc_model_flops_params +from vega.model_zoo import ModelZoo +from .quota_item_base import QuotaItemBase + +logger = logging.getLogger(__name__) + + +class FlopsParamsVerification(QuotaItemBase): + """Flops and Parameters Filter class.""" + + def __init__(self, params_range, flops_range): + self.params_range = params_range + self.flops_range = flops_range + + def verify(self, model_desc=None): + """Verify params and flops.""" + try: + model = ModelZoo.get_model(model_desc) + count_input = self.get_input_data() + flops, params = calc_model_flops_params(model, count_input) + flops, params = flops * 1e-9, params * 1e-3 + result = flops > self.flops_range[0] and flops < self.flops_range[1] + result = result and params > self.params_range[0] and params < self.params_range[1] + if not result: + logger.info(f"params ({params}) or flops ({flops}) out of range.") + return result + except Exception as e: + import traceback + print(traceback.format_exc()) + logging.info(f"Invild model desc: {model_desc}, error: {e}") + return False diff --git a/vega/quota/flops_params_filter.py b/vega/quota/flops_params_filter.py deleted file mode 100644 index ac33efc0..00000000 --- a/vega/quota/flops_params_filter.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# This program is free software; you can redistribute it and/or modify -# it under the terms of the MIT License. -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# MIT License for more details. - -"""Flops and Parameters Filter.""" -import logging -from vega.common import ClassFactory, ClassType -from vega.metrics import calc_model_flops_params -from .filter_terminate_base import FilterTerminateBase - -logger = logging.getLogger(__name__) - - -@ClassFactory.register(ClassType.QUOTA) -class FlopsParamsFilter(FilterTerminateBase): - """Flops and Parameters Filter class.""" - - def __init__(self): - super(FlopsParamsFilter, self).__init__() - self.flops_range = self.restrict_config.flops - self.params_range = self.restrict_config.params - if self.flops_range and not isinstance(self.flops_range, list): - self.flops_range = [0., self.flops_range] - if self.params_range and not isinstance(self.params_range, list): - self.params_range = [0., self.params_range] - if self.flops_range is not None or self.params_range is not None: - dataset_cls = ClassFactory.get_cls(ClassType.DATASET) - self.dataset = dataset_cls() - from vega.datasets import Adapter - self.dataloader = Adapter(self.dataset).loader - - def is_filtered(self, desc=None): - """Filter function of Flops and Params.""" - if self.flops_range is None and self.params_range is None: - return False - model, count_input = self.get_model_input(desc) - flops, params = calc_model_flops_params(model, count_input) - flops, params = flops * 1e-9, params * 1e-3 - if self.flops_range is not None: - if flops < self.flops_range[0] or flops > self.flops_range[1]: - logger.info("The flops {} is out of range. Skip this network.".format(flops)) - return True - if self.params_range is not None: - if params < self.params_range[0] or params > self.params_range[1]: - logger.info("The parameters {} is out of range. Skip this network.".format(params)) - return True - return False diff --git a/vega/quota/latency.py b/vega/quota/latency.py new file mode 100644 index 00000000..7012b908 --- /dev/null +++ b/vega/quota/latency.py @@ -0,0 +1,38 @@ +# -*- coding:utf-8 -*- + +# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. +# This program is free software; you can redistribute it and/or modify +# it under the terms of the MIT License. +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# MIT License for more details. + +"""Flops and Parameters Filter.""" + +import logging +import vega +from vega.metrics import calc_forward_latency_on_host +from vega.model_zoo import ModelZoo +from .quota_item_base import QuotaItemBase + + +class LatencyVerification(QuotaItemBase): + """Latency Filter class.""" + + def __init__(self, latency_range): + self.latency_range = latency_range + + def verify_on_host(self, model_desc): + """Filter function of latency.""" + model = ModelZoo.get_model(model_desc) + count_input = self.get_input_data() + trainer = vega.trainer(model_desc=model_desc) + sess_config = trainer._init_session_config() if vega.is_tf_backend() else None + latency = calc_forward_latency_on_host(model, count_input, sess_config) + logging.info(f"Sampled model's latency: {latency}ms") + if latency < self.latency_range[0] or latency > self.latency_range[1]: + logging.info(f"The latency ({latency}) is out of range. Skip this network.") + return False + else: + return True diff --git a/vega/quota/latency_filter.py b/vega/quota/latency_filter.py deleted file mode 100644 index aaa38e54..00000000 --- a/vega/quota/latency_filter.py +++ /dev/null @@ -1,45 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# This program is free software; you can redistribute it and/or modify -# it under the terms of the MIT License. -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# MIT License for more details. - -"""Flops and Parameters Filter.""" -import logging -import vega -from vega.common import ClassFactory, ClassType -from vega.metrics import calc_forward_latency -from .filter_terminate_base import FilterTerminateBase - - -@ClassFactory.register(ClassType.QUOTA) -class LatencyFilter(FilterTerminateBase): - """Latency Filter class.""" - - def __init__(self): - super(LatencyFilter, self).__init__() - self.max_latency = self.restrict_config.latency - if self.max_latency is not None: - dataset_cls = ClassFactory.get_cls(ClassType.DATASET) - self.dataset = dataset_cls() - from vega.datasets import Adapter - self.dataloader = Adapter(self.dataset).loader - - def is_filtered(self, desc=None): - """Filter function of latency.""" - if self.max_latency is None: - return False - model, count_input = self.get_model_input(desc) - trainer = ClassFactory.get_cls(ClassType.TRAINER)(model_desc=desc) - sess_config = trainer._init_session_config() if vega.is_tf_backend() else None - latency = calc_forward_latency(model, count_input, sess_config) - logging.info('Sampled model\'s latency: {}ms'.format(latency)) - if latency > self.max_latency: - logging.info('The latency is out of range. Skip this network.') - return True - else: - return False diff --git a/vega/quota/model_valid.py b/vega/quota/model_valid.py new file mode 100644 index 00000000..03d4b559 --- /dev/null +++ b/vega/quota/model_valid.py @@ -0,0 +1,30 @@ +# -*- coding:utf-8 -*- + +# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. +# This program is free software; you can redistribute it and/or modify +# it under the terms of the MIT License. +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# MIT License for more details. + +"""Model Valid Verification.""" + +import logging +from .quota_item_base import QuotaItemBase +from vega.model_zoo import ModelZoo + + +class ModelValidVerification(QuotaItemBase): + """Model valid verification.""" + + def verify(self, model_desc): + """Filter function of latency.""" + try: + model = ModelZoo.get_model(model_desc) + count_input = self.get_input_data() + model(count_input) + return True + except Exception as e: + logging.info(f"Invild model desc: {model_desc}, error: {e}") + return False diff --git a/vega/quota/quota.py b/vega/quota/quota.py new file mode 100644 index 00000000..35799188 --- /dev/null +++ b/vega/quota/quota.py @@ -0,0 +1,121 @@ +# -*- coding:utf-8 -*- + +# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. +# This program is free software; you can redistribute it and/or modify +# it under the terms of the MIT License. +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# MIT License for more details. + +"""Quota.""" + +import logging +from vega.common.class_factory import ClassFactory, ClassType +from vega.common.general import General +from .model_valid import ModelValidVerification +from .flops_params import FlopsParamsVerification +from .quota_affinity import QuotaAffinity +from .latency import LatencyVerification +# from .runtime import RuntimeVerification + +logger = logging.getLogger(__name__) + + +@ClassFactory.register(ClassType.QUOTA) +class Quota(object): + """Determine whether to terminate duration.""" + + def __init__(self, config=None): + self.enable = False + self.params_range = [] + self.flops_range = [] + self.host_latency_range = [] + self.device_latency_range = [] + self.model_valid_enable = False + self.affinity_enable = False + self.runtime = None + self._set_config(config or General.quota) + + def _set_config(self, config): + """Set quota. + + config examples: + "accuray > 12 and flops in [23, 45] and model_valid" + "accuray in [12, 14]" + "accuray > 12" + "model_valid" + """ + if config is None or config.strip() == "": + return + items = config.split("and") + for item in items: + if "model_valid" in item: + self.model_valid_enable = True + elif "affinity" in item: + self.affinity_enable = True + elif "params" in item: + self.params_range = self._set_value(item) + elif "flops" in item: + self.flops_range = self._set_value(item) + elif "host_latency" in item: + self.host_latency_range = self._set_value(item) + elif "device_latency" in item: + self.device_latency_range = self._set_value(item) + elif "runtime" in item: + self.runtime = float(item.split("<")[1].strip()) + self.enable = True + + def _set_value(self, value): + if "in" in value: + value_range = value.split("in")[1].strip()[1:-1].split(",") + return [float(value_range[0]), float(value_range[1])] + elif ">" in value: + return [float(value.split(">")[1].strip()), float("inf")] + elif "<" in value: + return [-float('inf'), float(value.split("<")[1].strip())] + else: + raise ValueError(f"valid quota value: {value}") + + def verify_sample(self, model_desc): + """Verify model_valid, flops, params.""" + if not self.enable: + return True + if self.model_valid_enable and len(self.flops_range) == 0 and len(self.params_range) == 0: + result = ModelValidVerification().verify(model_desc) + if not result: + return False + if len(self.flops_range) == 2 or len(self.params_range) == 2: + result = FlopsParamsVerification(self.params_range, self.flops_range).verify(model_desc) + if not result: + return False + if len(self.host_latency_range) == 2: + result = LatencyVerification(self.host_latency_range).verify_on_host(model_desc) + if not result: + return False + return True + + def verify_affinity(self, model_desc): + """Verify affinity.""" + if not self.enable or not self.affinity_enable: + return True + affinity = QuotaAffinity(General.affinity_config) + return affinity.is_affinity(model_desc) + + def adjuest_pipeline_by_runtime(self, user_config): + """Adjuest pipeline by runtime.""" + if not self.enable or self.runtime is None: + return True + # RuntimeVerification(self.runtime).adjust_config(user_config) + return True + + def verify_metric(self, model_desc): + """Verify metrics.""" + return True + + @property + def quota_reached(self): + """Return True if reach the limits.""" + # runtime|duration, samples|trials, metrics + # get data from report + return False diff --git a/vega/core/quota/quota_affinity.py b/vega/quota/quota_affinity.py similarity index 99% rename from vega/core/quota/quota_affinity.py rename to vega/quota/quota_affinity.py index ba41be2c..48b046c3 100644 --- a/vega/core/quota/quota_affinity.py +++ b/vega/quota/quota_affinity.py @@ -9,6 +9,7 @@ # MIT License for more details. """Quota for Affinity.""" + import logging import pandas as pd from sklearn import ensemble diff --git a/vega/quota/quota_compare.py b/vega/quota/quota_compare.py deleted file mode 100644 index 9fcac360..00000000 --- a/vega/quota/quota_compare.py +++ /dev/null @@ -1,90 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# This program is free software; you can redistribute it and/or modify -# it under the terms of the MIT License. -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# MIT License for more details. - -"""Metric Target Termination.""" -from vega.common import ClassFactory, ClassType -from vega.common.general import General -import copy - - -class QuotaCompare(object): - """Determine whether to satisfy target.""" - - def __init__(self, type): - if type not in ['restrict', 'target']: - raise ValueError('Input type must be restriction or target.') - self.filter_types, self.terminate_types = [], [] - self.filter_compares = dict() - self.terminate_compares = dict() - self.filters_to_params = dict() - self._init_compare_types(type) - for filter in self.filter_types: - t_cls = ClassFactory.get_cls(ClassType.QUOTA, filter) - self.filter_compares[filter] = t_cls() - for terminate in self.terminate_types: - t_cls = ClassFactory.get_cls(ClassType.QUOTA, terminate) - self.terminate_compares[terminate] = t_cls() - self.filter_rules = copy.deepcopy(General.quota.filter_rules) - - def is_filtered(self, res): - """Quota Compare filter function.""" - if len(self.filter_types) == 0: - return False - exact_filters = [] - for filter in self.filter_types: - if self.filters_to_params[filter] in self.filter_rules: - exact_filters.append(filter) - filter_to_bool = dict() - for filter in exact_filters: - filter_to_bool[filter] = 'self.filter_compares[\'{}\'].is_filtered(res)'.format(filter) - filter_rules_str = copy.deepcopy(self.filter_rules) - for filter in exact_filters: - filter_rules_str = filter_rules_str.replace(self.filters_to_params[filter], filter_to_bool[filter]) - return bool(eval(filter_rules_str)) - - def is_halted(self, *args, **kwargs): - """Quota Compare halt function.""" - if len(self.terminate_types) == 0: - return False - for compare in self.terminate_compares.values(): - if compare.is_halted(args, kwargs): - return True - return False - - def filter_by_name(self, res, name): - """Filter sample by filter rule name.""" - filter = self.filter_compares[name] - if filter.is_filtered(res): - return True - return False - - def _init_compare_types(self, type): - """Initialize compare types.""" - if type == 'restrict': - restrict_config = copy.deepcopy(General.quota.restrict) - if restrict_config.flops or restrict_config.params: - self.filter_types.append('FlopsParamsFilter') - if restrict_config.latency: - self.filter_types.append('LatencyFilter') - if restrict_config.model_valid: - self.filter_types.append('ValidFilter') - if restrict_config.duration: - self.terminate_types.append('DurationTerminate') - if restrict_config.trials: - self.terminate_types.append('TrialTerminate') - elif type == 'target': - target_config = copy.deepcopy(General.quota.target) - if target_config.type: - self.terminate_types.append('TargetTerminate') - self.filters_to_params = { - 'ValidFilter': 'model_valid', - 'FlopsParamsFilter': 'flops_params', - 'LatencyFilter': 'max_latency' - } diff --git a/vega/quota/filter_terminate_base.py b/vega/quota/quota_item_base.py similarity index 51% rename from vega/quota/filter_terminate_base.py rename to vega/quota/quota_item_base.py index 3f56a4aa..5ebba904 100644 --- a/vega/quota/filter_terminate_base.py +++ b/vega/quota/quota_item_base.py @@ -8,49 +8,32 @@ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # MIT License for more details. -"""Flops and Parameters Filter.""" -import copy +"""Quota item base.""" + import vega -from vega.common import ClassFactory, ClassType -from vega.common.general import General +from vega.core.pipeline.conf import PipeStepConfig -@ClassFactory.register(ClassType.QUOTA) -class FilterTerminateBase(object): +class QuotaItemBase(object): """Restrict and Terminate Base Calss.""" - def __init__(self): - self.restrict_config = copy.deepcopy(General.quota.restrict) - self.target_config = copy.deepcopy(General.quota.target) - - def is_halted(self, *args, **kwargs): - """Decide to halt or not.""" - - def is_filtered(self, desc=None): - """Decide to filter or not.""" - - def get_model_input(self, desc): - """Get model and input.""" - from vega.networks.network_desc import NetworkDesc - model = NetworkDesc(desc).to_model() - count_input = self.get_input_data() - return model, count_input - def get_input_data(self): """Get input data.""" count_input = None + dataset_name = PipeStepConfig.dataset.type + dataloader = vega.dataset(dataset_name).loader if vega.is_torch_backend(): - data_iter = iter(self.dataloader) - input_data, _ = data_iter.next() + _iter = iter(dataloader) + input_data, _ = _iter.next() count_input = input_data[:1] elif vega.is_tf_backend(): import tensorflow as tf - datasets = self.dataloader.input_fn() + datasets = dataloader.input_fn() data_iter = tf.compat.v1.data.make_one_shot_iterator(datasets) input_data, _ = data_iter.get_next() count_input = input_data[:1] elif vega.is_ms_backend(): - data_iter = self.dataloader.create_dict_iterator() + data_iter = dataloader.create_dict_iterator() for batch in data_iter: count_input = batch['image'] break diff --git a/vega/quota/target_terminate.py b/vega/quota/target_terminate.py deleted file mode 100644 index a15cd628..00000000 --- a/vega/quota/target_terminate.py +++ /dev/null @@ -1,33 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# This program is free software; you can redistribute it and/or modify -# it under the terms of the MIT License. -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# MIT License for more details. - -"""Metric Target Termination.""" -from vega.common import ClassFactory, ClassType -from .filter_terminate_base import FilterTerminateBase - - -@ClassFactory.register(ClassType.QUOTA) -class TargetTerminate(FilterTerminateBase): - """Determine whether to satisfy target.""" - - def __init__(self): - super(TargetTerminate, self).__init__() - self.target_type = self.target_config.type - self.target_value = self.target_config.value - - def is_halted(self, *args, **kwargs): - """Halt or not.""" - if self.target_type is None or self.target_value is None: - return False - valid_metric = kwargs[self.target_type] - if valid_metric > self.target_value: - return True - else: - return False diff --git a/vega/quota/trial_terminate.py b/vega/quota/trial_terminate.py deleted file mode 100644 index 050f8d09..00000000 --- a/vega/quota/trial_terminate.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# This program is free software; you can redistribute it and/or modify -# it under the terms of the MIT License. -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# MIT License for more details. - -"""Run Duration Termination.""" -from vega.common import ClassFactory, ClassType -from vega.common.general import General -from .filter_terminate_base import FilterTerminateBase - - -@ClassFactory.register(ClassType.QUOTA) -class TrialTerminate(FilterTerminateBase): - """Determine whether to terminate duration.""" - - def __init__(self): - super(TrialTerminate, self).__init__() - self.max_trial = self.restrict_config.trials.get(General.step_name, None) - self.count_trial = 0 - - def is_halted(self, *args, **kwargs): - """Halt or not.""" - if self.max_trial is None: - return False - self.count_trial += 1 - if self.count_trial > self.max_trial: - return True - else: - return False diff --git a/vega/quota/valid_filter.py b/vega/quota/valid_filter.py deleted file mode 100644 index f7840db9..00000000 --- a/vega/quota/valid_filter.py +++ /dev/null @@ -1,40 +0,0 @@ -# -*- coding:utf-8 -*- - -# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. -# This program is free software; you can redistribute it and/or modify -# it under the terms of the MIT License. -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# MIT License for more details. - -"""Valid Filter.""" -import logging -from vega.common import ClassFactory, ClassType -from .filter_terminate_base import FilterTerminateBase - - -@ClassFactory.register(ClassType.QUOTA) -class ValidFilter(FilterTerminateBase): - """Valid Filter class.""" - - def __init__(self): - super(ValidFilter, self).__init__() - self.dataloader = None - - def is_filtered(self, desc=None): - """Filter function of latency.""" - try: - if not self.dataloader: - dataset_cls = ClassFactory.get_cls(ClassType.DATASET) - self.dataset = dataset_cls() - from vega.datasets import Adapter - self.dataloader = Adapter(self.dataset).loader - - model, count_input = self.get_model_input(desc) - model(count_input) - return False - except Exception as e: - encoding = desc['backbone']['encoding'] - logging.info(f"Invalid encoding: {encoding}, message: {str(e)}") - return True diff --git a/vega/report/__pycache__/__init__.cpython-37.pyc b/vega/report/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 7e31f013721d0bb2d713289d8acccdbd83b1a9db..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 360 zcmYLEJ5Izf6pZuRfHoBcXNZI_B?UtK>J{2_ZcG$gkyrjiFV0GEBYJMY8EDc{aRn;; zEG_#ik7q`HZ|sMPM}l^+&yxy#vvt%h|@+CufxHycBKd zRCZX-k=$F+N-cZ;?8aGlUOBsIPFM{wfKM?2@cI8;&j^0-Wg5za_FJ~o^r0N@|<6euS<1foS3z|#XnCtY)x?e E0l=4HPyhe` diff --git a/vega/report/__pycache__/nsga_iii.cpython-37.pyc b/vega/report/__pycache__/nsga_iii.cpython-37.pyc deleted file mode 100644 index 181336b76256497ccb7ba8b0559dfbe6bbfd29b5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5623 zcmcIoO>ZPe8LsN?>F)XPtg~yf_HGgl(GnOXv6t|X;5dtSH|%Cn)-3Uc6+;rIGu0l? zc&2+?-A=}%YAKQfBodJV2QFO3M=spB1i^(%#2rDJNE|rrKR~&_^H$IFjMp0!B$!cu zRP|fc_3^yV`@DT+Zmz`e>;3W1ufM&>*gvQ+`5Cxb#}U1PlVA-dxG=iBVc=?Z&7Re; zI1`pAyu%s=VG9RWTezZ#t0PLHjH@duVh-1$sESj#mc+bx0N1iOEzaOt5ymE~FAN{v z*t~xE$s0Fr)FQv@H{*5~)VyvhlkLw7NW5Grr+dL#*ap5;kb$;I*pTJ z9X(26zHcO4arA8sUw*kSyx6a`h3|E1{jeRx%cW8cztz4cy&kIitF?H?uQm6j^n;ix z_PZVpA~TB*`hI5RN$DNv%DXg(+ZdC(>-&At?nSsad+nfZq$WB@tsx4(X3~Y-r%$Z3 zF#@}}SNbcKvpf~e(fwcG9G{y{pQjrax)_!f$ZcI-j3-u{15i4NnoKL zEh8P9IHE^z5v`81O`!P+-lkFiU_X7YtlwYLx6F3;Rc$K?+ zc=47mw?o;RFR z;9_izQFwrv`fVzao ztYeNWLZG||cVmjf3Gq}J=Yhz;Qamv=9*vC*hPK{wz1fnuFSX$~;LNn#^J3W^Y+>(r zC854<$t7(NWB2`dys6&3JGywS9f%h|b+X^}0{=!3`?Bfxan+KC|AUe3O zCDrDZbu)8y)9ZFq(+dtjM|_a-gL_w^`21u%-@`m2&3>+YoMt=0BsqQ_I(~uLamWB4 zW+4N(90v~y9GQE<^>nZxcWkj1Pwc^mle$-b12_B$N{%qkGT0e9sJI9Nmaq^o9CF5_ zn3uJ4CM7s-QI!r2IYKCKme{C3!Ggml~fUSw_)dMJ7;X?%-ZeU zCg-M}L}AyE>jIvm9Pqg<3e#45>v!-xDGz?xVo7DtQMM{CF^u@N#K}iE9`WE>HZnY2 zC32aJ>lr;RDrfKdfXfT2l8?Wv*TGco8vB6nUARMM8zItGl|xQrgZ+Dl(osg9=R&8p8-={k|WF)a?2}eJ-Po3uK}>!y@i$RP)uk8tw-%#OBx~#HJso z{8~1AigPLNr+hEvyD9(vbdpjBZ&1pKfj7`5x`Go!oTL9~Zi7x1z>o7f-eu}>h+rVhl@@FbtZ0{o0t zg1n3A5JcLq3jAE;!)GU3_|jwtwfLU~)&KAkjG-6<^-({Jd<2PA^e>;nDK#2WkD$)= z>oZV4c>h-&qCuufiCDRq8eNj()n1Pq4O`w2}ry^wIS1>?Jej# z-0SVIvl8l^{LUTxDOYcQNQwi6Qnva1q?9YS5`~Z0=oDzX$617~JK`gUGSMRvOJzW1 zCO?@xNh;%(6;&R&@SX}lQeI$-2?qb#FV>D$%8#iE=N(=Y?EW2OfPG3c(1!gCm3CepU+3@W7DZzxJ&7w zhWp0df8m!&$8yH2>-3DAa|(V*2@&NqvX?s2&vJ`PWI~QC&*D@+CBKAA!wdkdoO0#2 zQJ0!>Crr6W9nx%IiiW-s?PyB1WA!;qEdjmMusW1$+@s=_@-@319`}V`;Vlupl-U;O zGo#ro#i+uC7Wr4SJCVX$;PSh$%b92bUZU6NLx7hUNM!88ItG#nQzsTkW&v0zaa?UE z@d6SJGjS4Ea}M|?V;D+=4_Q(g1hEsll!9okdJXU^Le&#cQA>*Ake`J^>{?P~<)n0r z-dzObCP3A81a9phMg5dAQUV@&$a#!(&A4~cp->N_!CzrlR?tC71v+MK@Te+5rJTb! zTxzQ-VB1gB+z7BIg|YVu0Pd_kpVotc-s1%&ZFXqgxmW zMhR6&dqB>IfcF5#b{Z=|ZsiO|CqEdSaGcNQGh=m4lH_s#Wk8>wrWws_9Nd_VgFyJn zZ$z+|L_qCynV8A>lLQtef;t$IP|*Twp1xGQ%XUo(&O0fzSo{je4#%>WPC#O zBEL$f@6hQwP7OQo-@=kNie2<{O@=`{rIf~l8DGS{z?0sK6N^R_jkbDv!M2dwPQu`e z)T`EZXFp?^ehZ6;2F~1)k)@0jrESFNbD$Q3Kf)g}=Cy(&8meg0DWPnGb}F!X6g8yu zY*@?N=%+rHI$K*JY;J9(1t>7&zQnwZO4Mxk50woO|9ayLETbJ@;=TBR<3c diff --git a/vega/report/__pycache__/record.cpython-37.pyc b/vega/report/__pycache__/record.cpython-37.pyc deleted file mode 100644 index 1de3ba882e0b3605e80351878659cfeb9031547b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7149 zcmb_gJ9Hbz72Q7;AP9nAk+LnXC|MR{nUv+v$w_n^*^=$pv1y02BL{J^hA~5MNnjV8 zUFZWZQ-nHBm85Z(#D`4h6e&`sNaH3=y13aQO)6LEb5i8KSzw8!Ngn0{oP9epZ+7?1 zoBKZVULGGWDfo51`0~S&qbPr;&hTfTaT8bgV-Tj)6{a$+t=3g}*XkPXdRy=4bzLQ% z(Kb3}-IQ&!ZFLIuLdUM#onpPHDqPzu@v(iiUVcqs7ArhdSb>l4>h%iK)|8svFRb!z zAflzO@G)hjrY1IPL_G339PJ6gJHcJ<`d#;Kqut{*J+aoJM$`+Fu{XlNf5mSGjEiLE zb_1h$=G8!S8qr!LJipbXHx7R~8aHu;byZQgg2gDbkh;#EXF4+;s&#{z%zCKQO%5@c z{ZQFe>jhS1C5+i@jFoXOvT;_yy~HM1756dBon%vTZW--qHY3~Pe0;OQX0x8f=GYm@ zncx%P%x67|EwHnaQ{`1~9?5zZdz3vUIg@-6oO4;vVvjRNa;EqcIOnsT#TMBG$(iQU z;9Sgl7JGtSlAIYn1J09K&tjKZO>$=WEI3!Pp2ePGS0!hTJr{*v+OzC&&mybAt|RDy2_h@VEt#)yK}4E2ty|d93jV@M!W4`qZ^LrJAu!gkT*ng z!)XUWcWI?byG1_>mx=4UxWeCp#A>X#O00om2SGJSb&uMHY@3o=k`@@vw4z?_i1!e_NSuh%^CD&Wp6Il#b@`OSl>IduiRFCH1mPBW9*pGe5}Mq zZ0^Dkwm-Q8V*#YKm9GFXrM8e1!-#iXztQ1I@m?UdxNtp|m@eaCGcnwaZkUwaE*I+r zTEEHb6U`0Y-0B9NAGzH|w2_oLfOk7>m+x_}wGoBxy4U7O$?tVsSWa^ztdBQ)LU24I z`^lIae6-1%k$0De$;3g&-Qo|z#C8S0*AOgB3Zmze%_Ngvm-|dM1$hbi3LmDH6lfNiWue2vu+id)=|#MQe#pUyM9oT!khj-G5zh&7 zK8dvjIa*<&L$1V>_Wijc7Qy=GmF1SWw%p&`Tozca6`xWhLlB6?o zl8{|UNy=!DBzx*U{FIbs<~u_XoM3%mW)Ma^fvNOCXi~WDF#;JR`U{wu3ZwTYb7~=O zQ;LZxWRPT-QjkBO0YTJA-POb#RHMb}rnUv&E=P5Y{w^~1%*fcZAZ}q_iC~UD*Y?#N zEgM&}G5!Agu?_=!XUBkHX-e^_-t(JR9Wqyk%++c5%t;b0!a+!HKWAXM7&DOcF=(7nXK1=^%tVRC8i*i z)=UYN#0;VhgfxS^W8HEGJ+h|5D6(9aZ@ccd@NoDjh@zF$vRYE-)fq`Isv2-km!b8i za~qmInZp(twA|7uG$LRe+ofT^wlHF=m9nVjCA zyA}9w01?GC+D9+&8{)xIqPpTK$W${Eop~ZDNe>#{^Wf28h%^)CIExvA)=|^t#@B7c zM3;OeLIk1k&~{4|eh>4*IS@rN)CzW4!&OCV4tHDa&*XM_0I$q)$$~RyMlJ*^gC;hW zE!YLz+4I>3w;~5o$qA8}app9d>K+LUTyARnDIyo%gL#smgK*%i@&2FAcS`dct=%-I7NK0q% ziet2tdKa0L{ggg5P}Z3oe5D7Hb8m%?_9MYw2v!We)%+bsFFh&2~1=% z<9N$4S{NId`Z;}Upo;mN&5TUS^g+7FjOHikWcnC&#*_~m$b3ISU@-S!KKrVw++T?) zGr3Bc3C0Us1#vA#wt>83mB|sAkz8Xsantr4#h*v!VQdveKeG3V%#77vDZfGXjA{$T zY^#5T9N37cCpa<$Lvy_yAn~*i3>tJgYopfUHy~4gkkgI6 zTSkv>IUF7#laxkIqES zp!<7VX|f|jOPbtLnp4)Qpr`N?(}asm26+G`Xi$cG3^XeF*?%Gt2IzPshmQQb3?MW& zGZ+GU29Qdu`bZyJWHQ)xX)4pSdHFXXBI{{HHrgJ1sg(5ugLxa@_~JmyQz)g9Qwr&0 z63-v45Ah~ge?EA9KOa^y(et4=vykWTWTpTVOEw1yi<;tE1+K5rDG>NSz;&8iAE}!D z`9t755FpKLfXy_Y5fY5FFhRv5@fML4kYS2?R6awC{R^%vSXWL5v-70UGzA{5=Ug6U zqcu|UG=2Q&!w;R{zNGICke1Kl2f1~&vlJLI4gU=t=oPgIDn3P&94nCq-}5N7)`+Kv z*Ws(t+cT(GKq0+_(-JQ2{wDF^l|J0j@vI(MPZg%emYHmSA!@*O2A#o8`TQJDEh_&vPWm!8la0|${ra{dZ$ zmvTr>6zE_k#)^0ys}z($2^FL^F33HkKIk24kr5;Y)wiU*Bsvn@L~rp(;y1Ax-7eyM zqPMx9p1b5`Sm}XV7d))v1QSvjEYSfe>6Yq6bknl)A?JjuUGtz;mf+jrB@ik*890Db z@u`XywM!^5mC>5RSP93pvub}K=SD`~lnHm7rJ$Us$@Ki`=I7Kf*3#1tvf0QL2&uQc9PiBsl5KqXz zg)6)X60735`%6F@@&&AY5hYm-NY!LSb+#*H3$X?)>$|Ef&{pUqzo+k`rWX^gwEoTj-ZJ$$3;jio&-tkt_YA!L=Qxb#-1l~naDFlZV(|elBdeV3V9TvwtACBULr#4 z7q^K><0q4s)-AV5?uz&Ujn#?#kjP@5T%azM<>)XK()S2~rFq*ZoA#7#*cED(&5C`* zRmIqVXU~?$?UH;`Z9-5zLmmo|Od-KRkt8jiPg1pd2`119{H0#xwL=M?WNc(y7SUx@ yOr8yrB0SccGD{L4&}%7ZB~}{a(t_dRLnG@d<3FR9^nU|x?Lu7u diff --git a/vega/report/__pycache__/report_client.cpython-37.pyc b/vega/report/__pycache__/report_client.cpython-37.pyc deleted file mode 100644 index 89837cda8eb3fde4c701bf22e9e80add6d7dbc60..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3646 zcmaJ@TaO$^6|U-g&-BdtvUcL&gyGI?&^iJ0f*fQyPAniX!Uj2v2<>L7y`GuAjH_$x z^`r+OHm?SYCnOSSmN$L}QvLxi^@N0k6n+8XiSJa;?9Mutd(@}SRGq3mm#@w_*E*dR z!;}C0A9sgm8T%LgSw9Y#t9YwFfCx5Vf(t9-18zoZU}3Z~JGTcmr!^;Ya(Cbw-p#z+ zANYp%vLFu!A$VT|Su1Z3+Gak?qP#Qc7`_F*JLu+HgRT73;1p-6wLG1kIpl-0H<)OP z=wl|Lbo;;_oDX{e+*+&FXfnUvDB2rlNmcb!*{hO=X)npLo;12wy`&Jm*_1qSG(}?f6PfmM zST7&%4#{eIXrvgwjkkIkL@~w3*4PGh1b<@HEOrHh9@fpirLECaJLK}dCj&^!5JdaK zY*wjM>M&1ej3E`}jsIej*50sjBiBxqX7~QaUfY6SVYP+Py26yTw3Rh>Ama-AiFL>!!^Ym&7nl?76Kmn%1cy$LcZF9w z6TiA!JHopFcu#Cus6Yg@^B4B1b-;hmRd~RbEv$zZ7|wEQwrtn#GCIUR&T1FVZOyyC z*o@e&4&lgYTqJp_!$+l@kXH$p_Ask-rwMP|L>zJPlA#)vg?142M0+x=W|`7{H5(2Q zhWUke6Wj$sXgL9lAN1r!xTD|Etx+{9DwPz&6o-t4N;_k~Tf14Ah)TO+mQO3$hOj&X zLSZa~4&T`yrc?T|^l3IE>DX*P8xt<$mFkV6+E=An1$vNZn0$!{p-y{~M~U34Bt@0< zr(0++(vWjzdZjz~>O2UuTHN6Qe%t&Me~ouRyL^831Q?n<5!Z)x$m{l9F$Or&{MTW2kfJF#c%vB_?egumS zV9~{T`v?|WfW@f|SVZ&f+o@XJ;@&*~zq+4_-3`A52t!N>PkV?S(o&zx3ouK*0@AnS zmx!kr8_&wuz|CJiMni?u*T=5Ga9Yx>(sofkl2jOUu$GWtp@Y6k#y%(&7>%1HBpav>ty6FoQVA9 z-l0hpIDk?Dr2~@+NNa^yD>k-jOCbrEs3PZ$ow19wU3V3c^#qzc1|T>?!}S zSfFAqJiNX_aakh8AMyjf2DKi%Bch8B0Qh1)<4_- zr}FCHfL~`HUj3D|2p28kF8w-$9-a!8VeQl&zU`}4?F;X~`s0~JTeX)_-KyJ%7H4P%WV`Olm!BF^TosqS)1L^#o{efZ>59UbsP zhl(Xx1P!C#Mjx=Iq8P>y`QI$1pZ+*hh=+AfIOGTg1>u43k%T z82hKS4XjBAr|ENqx~v=~*}6nd6LnwGEjD>nJK3mGI!udMj;>Q7zZ(;jqztCRaXBht z+M=CunieMMz-ClwlYF(8m3w>W$lO%QQfj+WvfpWLj7hY|mDe@L4BE}rH66#Gg`}II zN2cz>HT)o2t(yd(?UC5mZo~6?+DiAIN!=#r{|Ks9*FhLEbqlq*i+b&VQqA@(3%q5W z;b*WG;djvr_(jZ-?D|LZuRtQ;^WKSD`<(XTX1_)4A!!}Qx)sNHDP}Mu)=?Zkm?hb2 z#gAiA4&zu-6CqzG@&=JNiBQI93f6HdF&woZG6JE>T6IAJ?pT2xbc3zH4SciKKQG@S zIjRBLrx=WiJyTBy+tk+X4$C|*i(TX15*?L^=7}Z3U`$H}C#N@e&D1EX8jP4Anx3n@ zA#vQ3o62npvvY-MsQQ<8-Ku(i@>WvQ!6G8t?s(io9vJ)()_mV Gmi>RYAerSV zu@&l+T!UVE>_J9P{T=-W1$x1Sc)RDQ7*}qDHe8GumodM%yjh=u)c$+Uhxdx8?S|me()0%Kb{K zLWyAGDN)^}t=etEZC-jwcu7pZuv*7BTO$o;=spl>B31J>Y)Fr1}f?I7M_ZBU>c;FC_g=b-%NM3=lcHB986$Y3v zhYrsi@wK@(Fhn1)px*ZfLF5$u+4E#gs5&C9CtGy`89brVg>elg%b<*H&=&&OM)Y(h zlz-suIcUqwhHO|ris|s+{gXyFJovuH*+USV=Ctd`tsKsIo{U(|a?&APS6;-^?~t4= zk?qx--XK3;9a%Xr14iEif2{dQKWJ--Jh`ms7+5ogF-7k#+ocrR*GW8M`fTWSD_C(AJ%Yjy)vvwiyS417h2m{fzaxTop+-(Gcc|d0Yld=z_e@*1YsWaXvO?IdYqkQ!}CYu_wMw4 zZXrz&g1IYiz}Wdqkug>51?fO`*<~XJbvrkw&sm#9bYyMcfLN9whKnORx50KGimMpS zc`|mfp4;=p#Oum_tF63FnOnaValEolwwDYKz8aPCQm2wzu*RiAzPXc$LZXtV5vg5Q$>vWn9-#^OVx7G^r+^{(C?Nc{~ z%FZoCI*>M$1>}V%5y*?lgSFYnRwX2&p-V|Q9T#Ye+T+$M%h1I>M3*wqM8hf4 z{-1GnBb4LSp zXx?(8jPD*kPQI@zz7I)s&_lcG`%eZzZ_?xXK2Jj5msLbjLxHcP{1U|yiUWAY!WT?p z{7=D`;Vh5Z%qw}eH|?Dxh#_# r))wl%N!s!ooO%LhImMp^v&H|R(%w?x3rL`SC{Uyi1&SaKL4dyYu?32GD+&}f5cH*e%0t`l|Ih4= zq->YO&YU^7Ip@E8|K-f9Q&XOXzxF@;^E+Sryr%sdJq-SA6kf(Jx}a;C(0onk!f5Hf z&UeE%_-^_p-!0$byY1V2cYKHM1;2p1*>c-Ozo@G@Jio+kWq%5Ht2N!8@n^WrZdKZ| z{w$Xr)X({I`Iuw=G4w68=GzPYLVM9)#_Dp|0MS* zW6UZ46ql!3r`u=zGhCi-RojpIkE1-ZuKQ<0WABOZ+`jIg4`+6@`jetke=0olx+Z4D z%zNk)KD}%1>V8cui#c)ZJ>9<`j*EG*fYLK!Q7pZu`4@$8L#v%g3fIG4SH>$J6De0~ zdgh8S>@~W*2Pj)_bQ@c^m)?{jl%crVY{YnUZf?l1E^0tQF9_14E=g@)&{dc*aw(Uk(vv+DV+P)5xt4m*tyl~Z2{qo}?f-U#KrP@uS!>ReG1B~b>QisFKp64NMoVn$TZRswbBa4(Bz#4$0CmZ@-BEQm#( zVMZ*8Wwcg6=@YomipRuB+~>q8aT@ny;*6-`J}({@XK`N;Pl$84FJg}K;z^!kNjxQ< zM%}XTz?+M51{+AG6}hSrQ)Z<>RHZ+<8~DA3Uvv_eSliW7Ezo0w?`Ep)S^MZFG)$(i z)@(_X$)(ae&z0MUf&Tm1!_TGcc4Zy%=90bo|g2=%|(_Q>03u>E+ zp(}%46T7A5^KPps&~Epj-m$T7=z}LCHt9*2*s+z&UEl6h)dmLC%ciz_q8`(}NE$Pb z1WT1yYnCiy;;bUWbrMV&Y}7lV6-v2;YPpO{O_#^0OuJW?1(Y&VZg+r4T0|?)q1Oi<m><0@8)b(%@_ zE!phE)jMGnS0O`dUD>Wfp{pQVRfJJvB(5|&vHJ#oEUxRA_x5~Wk1_YAj-Mg44~*{_ zcP*j-u%&pvE4RRfreOIT%w01xO8|%DY0Tu8gR$3Qkg@Vy<|v{a;L@+X%S5|Ix81w) zJG2P3qU4zgg7j!V0qtkya;w{@x1uX348{S9mdBuIO?J^aXHm^$J^;Aluorad?Jz6k zO$Szzd84;nrj;nM@$_Wr5o=j(a2NC?5|SAtCPjaVOA3>=<;rsqA+2x5(1-8nYcPI* zIyW_Uja|KOD!@xkVq4!#&FDqp+|V98FA4;NdA${zv9)KX);@%A&)L_$t9?hqOfLGk zaRGDbitUybyIY?8iLUi+T}vu)anBRQ)F#ba2Y-d$x8Jqj*1n;2jJNRwc6!Ni5@G*B5-o>F%~HFF0!!_5pX61OR;nv`Wd>s@ zb27YNhq?wNWg)Cr2qP&U$2fVGE|a(+pFrI;{05@<>?HdivBaS;8mlAP(L!zfc%e_C zzzaP+P%qd7L!KZ->o{~>pIEq=c}=Bq?R`!AHy`Cw)`)nPrm#i(THFqugpiX zOiI6X|kQAJGH~M3l<(CKC(fEU>K)Bh`l(%OraPg#pRVt#)%k3MRC% zy4Kz9h>O*5=cV&ett3y;VA@dTtTkHzcR5WDMyHqAbt&r)GAHhmo@d3dv)zXO5{Fsw zOFNCQ7dN||tkCMNufrqBYz*znnxjCM#4fWiAd*CMNmea0!W|_LhP*_TESoHr<=lti zrM`l;=majB;p&dQs88z_{tevc@Kn*udUE`c5-|bS#Cb-9D=4ITsGemY_FllosZq?zK3v~WP$2#r4+nJ8W~*J4H`leklbYoE z>DtbP)V`|SKHD#(g-uJ5a7#;_z8wR};{xKVvIih)b!vtN&0F?vG7+S87MUB%2P(mQy`7 zQd<;1F#E-{82t;^Baq%I%3sA(d(&xgANYd4u+rmJ|1Gv-)obJ|Rj=p%?8tIGN}`=l zwE8*1D3vBDVD?E8=9r#Xm|r-=ta1DuIBthUdnjuY=AWdMEm=jG{|M%PbHa`q@)>$J z>s)U32?!6!7~8OWs2|v{FaE!mIdB1vUg)R#DZI>qWq;HP?~ko;`F_0_!}`Ih%4)n3 zRuwWu)x@jD-Rc@bx@ZF$Mc7xHIe^ZH84*JIGeP&xX4r_Eh$^?j2T|tKd%dtD@}6CWWj3F z)z!;Ja{e>1b;v@(iz0rEYlMmv;gS`PjZJe8vL^I>vgY(;r^dc6zYYX21Z?+*6&Rbs z9t;eEH0u`P->+)dwYRktsgC2&-Jpn)%_WaxV<$Ov6BLcsLRoFxi$IQ6NQ@jHW#X<# z4*a9ckat!Y>Lo1>o>T$1Jjm(Q~sJ8pT~1{s4U=^4~Pm1?|iD zp9fEp`3Z{WoreS_WhP6_1SY8;5SU1d5Ne)4D@Gj7^&mK>5-5Bzl2fnN8(Y=wUbVX& z_qJoA6-S5gJ>!iP%B?V2DtFLryBFN4M3)Z?VB+;^29vG;JxDzyy zOp#y8NAth~Tc69@#)ka_Vo|1k1zrwsq zQjIxCDxrf%oPr&s<>lF~BL*H~l}R(bgA4nXMy&4{Lf=D@Z&On&8M^e!^ye-;_k%Mr z#<;go2lG#Cq%;6A4ndh>9QpBuWcf`cChR&nM6^_j=9PbXAXGRI zz_TI$I}B2W{Q0BE#qRr90_(6H|8!ZZ9gw9Q03RVf?CO$zlHZ^Uf%GWqkbgjr-^Ar; zNr8DhPd!*xmU5OK+{53ZM?~2@ayL~Jm`lvL&q9J>L_*+5nO#=!?}HU|h9S&#vv0!X zw02&m9Nz6O^x+!rY5PdzG~tNuEK@4zb_EV7Tuqn@CoNo7?xwW|Gn6{Jto(Q9Z)q`^ zBn#%ojqN=LZm72D5@Pm?c!qn(5Q3JpC@j=^59U&L(@Q;J3+DqQif(D0b4VnWkVyJ_ z9nLQRr@S{M3UDJmCC@5-wDV7?g*E*ntyCa!&@a)u_GSPzCCppKZl=@aCTn+LRuW}#OTirYIcd6U&(PU0DqU~fwWj(2AHzQ0gzfR3A z8CeRIYtv)SAbFr!q0wz^w>vnMQWT8X^UDf3lxHUSh&Sj(|!d}qYLk-k{|u>4u~B2o-$_( zHh=vPhb2b-D}LkqnM?>BoN&0Kf$d}XCEJHX4*2r<;m*>VIO3E|L4HwBR@ZPk#B9NO zUZvH3LV5znXEY`V{(}}qgwu>eM`xlr-4a97UJ2x zxzyR$Q&*HI9T6XcO_(RE)Av$ikI|U5!Bg4*)V&I8v$*8}zn{jzMF|HN^Vs*2DC0dP zWpSWA9^O!Tw?u~*cdx#UA5tykuuLVVhXE5|>fwm#H%72bq0Ru+$t`VH8FUdmQvd*v z=KDdI&h(X!{|EW0Ob_HsU|HfSs0Nxl5Bn>YXY=TrA+=Drn(2s%zA@qpnSHDmc@ zDi?soganyIN4{C583l;W8XE|I8e53$vnkZHqrs!B(jbdgAy-k9q!?L|s0EXO1HG52 zk`2ZyDAmgHE!_OFB7-6t2lZJ2_5vY7C$q^wL<+zuj8RnCK`C>hI{f~Skzan1+8Dcs z0G^%Z7Y{Iy*hB~j+o<03D-4NNaM8?49@UqD+8z+yL!R;^#rXIU##_cbuFF9BC-noz zOrQLYAzqLo4?|N0?{p{(%rGc<`~gl*1e7&ECV%Wq-^586jxB&)VVqe6VK)(oQXU~I z2NW6+Rv^ky@aKqn*v3hQ&8PbV>?`C)`v<5a>15gkRw(ep)*L-~^7?qt^60Mw%+*yK z1;EDOzFIp@XZ|E6L6CU?(RjN><#G_*-LALtmO>DSZUgAHh!@Kfv21Hcnol3(*T9HuA7hJB&BFLf)k6x9M`5E`LZDve#^)B;mS@>9R`~+UeNt zhXTT1U)0bTQLex>EW@>3yNsV#@rthFI^MJ8l3Q>+x8zoFFSu6uvS~2&m`&LuoXNDAOu| zfz!FURZ6q6iWr(h@dV#_+{LE|aTT5deZ^C4HzWFfhYd3`zly}+=s1Jq7U`MoR-6^` zAu2c`R7>oK{orB$W-_bnYYwF^0|M@c_Q`;ps2 zrD^>H>d)%@@yJjrC>=zK&Qd`V`e>;T_P7Lq@5YetS99Lr*r0EBJRI(6bX?F!Y=tPP zz}#|^((A*srIl6r5niVF)Rf@QXXTS3>-TA}VvGC%$`jj{E`6QYk2oE)GclQmvd>aTC_lj#~KQ zZcJRs8ZhShrLk|Y6rizrWC7-M!>i~@zPy?}|IUi>o+rrk@Sx}tt)3$q6R|b2W-PK~ z)us(x%`&W=q+T_AU%>W{$#N3P(u% h)xS~&N_Ll#@TJcKhW}up2McY1dG5Eg&|C1P{tpL(g)#sD diff --git a/vega/report/record.py b/vega/report/record.py index 5a8ce813..7b24c82f 100644 --- a/vega/report/record.py +++ b/vega/report/record.py @@ -255,7 +255,11 @@ def load_dict(self, src_dic): for key, value in src_dic.items(): if key in ["original_rewards", "rewards"]: continue - setattr(self, key, remove_np_value(value)) + if isinstance(value, dict) and isinstance(getattr(self, key), dict): + for value_key, value_value in value.items(): + getattr(self, key)[value_key] = value_value + else: + setattr(self, key, remove_np_value(value)) self._cal_rewards() return self diff --git a/vega/report/report_persistence.py b/vega/report/report_persistence.py index 92879f3a..abed40f7 100644 --- a/vega/report/report_persistence.py +++ b/vega/report/report_persistence.py @@ -51,8 +51,16 @@ def save_report(self, records): try: _file = FileOps.join_path(TaskOps().local_output_path, "reports.json") FileOps.make_base_dir(_file) - data = {"_steps_": []} + data = self.get_report(records) + with open(_file, "w") as f: + json.dump(data, f, indent=4, cls=JsonEncoder) + except Exception: + logging.warning(traceback.format_exc()) + def get_report(self, records): + """Save report to `reports.json`.""" + try: + data = {"_steps_": []} for step in self.step_names: if step in self.steps: data["_steps_"].append(self.steps[step]) @@ -61,14 +69,12 @@ def save_report(self, records): "step_name": step, "status": Status.unstarted }) - for record in records: if record.step_name in data: data[record.step_name].append(record.to_dict()) else: data[record.step_name] = [record.to_dict()] - with open(_file, "w") as f: - json.dump(data, f, indent=4, cls=JsonEncoder) + return data except Exception: logging.warning(traceback.format_exc()) diff --git a/vega/report/report_server.py b/vega/report/report_server.py index 01c65355..05d66220 100644 --- a/vega/report/report_server.py +++ b/vega/report/report_server.py @@ -46,9 +46,11 @@ def __init__(self): self._hist_records = OrderedDict() self.persistence = ReportPersistence() self._start_save_report_thread() + self.old_not_finished_workers = [] def run(self): """Run report server.""" + MessageServer().register_handler("query_report", query_report) MessageServer().register_handler("update_record", update_record) MessageServer().register_handler("get_record", get_record) @@ -117,7 +119,8 @@ def get_pareto_front_records(self, step_name=None, nums=None, selected_key=None, if records: not_finished = [x.worker_id for x in records if not x.rewards_compeleted] records = [x for x in records if x.rewards_compeleted] - if not_finished: + if not_finished and set(not_finished) != set(self.old_not_finished_workers): + self.old_not_finished_workers = not_finished logging.info(f"waiting for the workers {str(not_finished)} to finish") if not records: return [] @@ -240,27 +243,55 @@ def load_records_from_model_folder(cls, model_folder): logging.error("Failed to load records from model folder, folder={}".format(model_folder)) return [] records = [] - pattern = FileOps.join_path(model_folder, "desc_*.json") - files = glob.glob(pattern) - for _file in files: + pattern_model_desc = FileOps.join_path(model_folder, "desc_*.json") + pattern_hps = FileOps.join_path(model_folder, "hps_*.json") + model_desc_files = glob.glob(pattern_model_desc) + hps_files = glob.glob(pattern_hps) + for _file in model_desc_files: try: with open(_file) as f: - worker_id = _file.split(".")[-2].split("_")[-1] - weights_file = os.path.join(os.path.dirname(_file), "model_{}".format(worker_id)) - if vega.is_torch_backend(): - weights_file = '{}.pth'.format(weights_file) - elif vega.is_ms_backend(): - weights_file = '{}.ckpt'.format(weights_file) - if not os.path.exists(weights_file): - weights_file = None - - sample = dict(worker_id=worker_id, desc=json.load(f), weights_file=weights_file) - record = ReportRecord().load_dict(sample) - records.append(record) + desc = json.load(f) + worker_id = _file.split(".")[-2].split("_")[-1] + weights_file = os.path.join(os.path.dirname(_file), "model_{}".format(worker_id)) + if vega.is_torch_backend(): + weights_file = '{}.pth'.format(weights_file) + elif vega.is_ms_backend(): + weights_file = '{}.ckpt'.format(weights_file) + if not os.path.exists(weights_file): + weights_file = None + hps_file = os.path.join(os.path.dirname(_file), os.path.basename(_file).replace("desc_", "hps_")) + hps = None + if hps_file in hps_files: + hps = cls._load_hps(hps_file) + hps_files.remove(hps_file) + sample = dict(worker_id=worker_id, desc=desc, weights_file=weights_file, hps=hps) + record = ReportRecord().load_dict(sample) + records.append(record) except Exception as ex: logging.info('Can not read records from json because {}'.format(ex)) + if len(hps_files) > 0: + for _file in hps_files: + try: + hps = None + hps = cls._load_hps(hps_file) + sample = dict(worker_id=worker_id, hps=hps) + record = ReportRecord().load_dict(sample) + records.append(record) + except Exception as ex: + logging.info('Can not read records from json because {}'.format(ex)) return records + @classmethod + def _load_hps(cls, hps_file): + with open(hps_file) as f: + hps = json.load(f) + if "trainer" in hps: + if "epochs" in hps["trainer"]: + hps["trainer"].pop("epochs") + if "checkpoint_path" in hps["trainer"]: + hps["trainer"].pop("checkpoint_path") + return hps + def _start_save_report_thread(self): _thread = Thread(target=_dump_report, args=(self, self.persistence,)) _thread.daemon = True @@ -308,7 +339,7 @@ def _dump_report(report_server, persistence): with _records_lock: if not _modified: continue - all_records = deepcopy(report_server.all_records) + all_records = report_server.all_records _modified = False try: @@ -318,3 +349,10 @@ def _dump_report(report_server, persistence): report_server.backup_output_path() except Exception as e: logging.warning(f"Failed to dump reports, message={str(e)}") + + +def query_report(): + global _records_lock + with _records_lock: + all_records = ReportServer().all_records + return ReportServer().persistence.get_report(all_records) diff --git a/vega/tools/__pycache__/__init__.cpython-37.pyc b/vega/tools/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index a2e6c6eadb45190ac8fe994a6c6fc2491cb5c0b5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 130 zcmZ?b<>g`kg51Mb6F~H15CH>>K!yVl7qb9~6oz01O-8?!3`HPe1o2B-KfTC6zbY$B zzbLgJUq7+5BtJJtzbrL9QNJWVKc`qfK0Y%qvm`!Vub}c4hfQvNN@-529mtT+K+FID D3n?9z diff --git a/vega/tools/__pycache__/query_process.cpython-37.pyc b/vega/tools/__pycache__/query_process.cpython-37.pyc deleted file mode 100644 index dc30cab0b9fba84461d7c19c5c98c966073ccd31..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4536 zcma)ATaP106|SnjxP71X%uZ$!E>2hhhKy$eQ9#RTfn<}g0)`D6vWvyWaa=v#p0V4W z>hkW4)6R>7SdbzG#KTGnXod$Q^2|TrH^4(bA@RaW{DM63o$9tdJB!4&Rku@BU3G5X zIqjFr}H%t=q?WkD8YspqvkS(cR#SSv59(nDL2E7Cz*l;@<2wj^sZhqf$N zLWGm`cc$;2qxH0gD0i-XrzKrwZpjOd%kY#PPE?+!rft#qv0r!Ni~StbY43^;w))7>$9#IH`R?i? zLqF zXc#7iNDr0Y)@tleHkGf(Dy-8)eRrq^9bJz!MAU1D#O-u-fDQH?zdveN%E6IIsUu|@ zf*F%!l7bNpjY*ug#BCEd6PHcYzc96qZJ%A-?x~C0lfAud<&TEjov|Jc`r9OI(-J8U*u~zQ!0*c4g254Or^Vq2KC?uljuc$OT<`4kuu_-@%5VNc7 z_D8>FTHJ@gRPCsOdFwtq;Zt#wh3huc_I^RV%wb2I-K+hah2j>A#j%5TSg}P`{{e)0 z6}$Ic_0B=RMJM=(1mDTOl)JG4vapKhMrbWVgF}^ z%l#nI-Qi$#<#~FkK;^Q@vt-l@(`{bv54)XybOoLH>LGR^)To}O>KUrYFqSt` zpTn?;l3nGob5?1tj>_=9k?{_V?pw-^c`SDS|Czu5YatAhgE@7Z-kb{Hu~VN|;!nWc zH<#x$a`2OAtb=}H<3f=l#7%P7Q^ZR0uO4>&QR;o9J^12ik zkc=Q3Tz#JUhBNYqXHg=WX2RpPaUH=c_>l|MxIn34MD2oef*87uftd&B9KBE86qC2` z!!U&w_|v-tFMHOWO&+2h?atU<&eYyaz)RM3UUK-3kk*IR)Q;`Ff`KPKb#yUybZI8k zx^!cQ*2d0pS-L0Ow7C<$3*b6fhlhAK7`(Fr@BCfrokY3CWFG!fn3Tw02$Ji8Y)D=PSaW#2W-*+Z^f*I)|pG$axbxf(nyD z5CvhRJ7L!cVg_BESa3Q;s%p5Y{w~n!!t$*247Z!cC0F^;xUbLp59#eCl!%a#2?1@B zp4U;Apx+JXzp~Ua-IOk6wES&!G@B7tJ`eo@IicVB6j44VO|HbiLdiE7qU{QLViQ1g zumbH^(1Jawk`}V?MT$^g!}}8J`n9V&4WWuOaPM~dV;?~xoy%vd5&d#m_bfX(Y>PG$q!8h{bq{LfXxf@FDEr@df1chH+h+Ht7#C=awiw$VGg*q= zkJ%sWjAqtiuAGf4v4vT4m+TGyz-*Ovrc{i~p3p;Yg?34pQ{&vur25vWtj&)uO|18( zmr{et{Vd!MhxbB7rkYsYds2NJJrhk5HmRmrA55tS=fKYMbvy);TXR}{OV=lMeM{9R z zxM?z_N_1T}jiQa5B0B{rN%fuiB%&?6 zDx$#2U`&)ArRct9Fy4d9qzRJ;_aW4$*$OWR^(uBOyovB~VS9hX0IpBga2GlvYlB^F zA^>DMwhaY#x1czN0;^h3a)#oaqT~&waEekil+v6MUihD~j*GU$0 zE!HU|Pl0DrA${Tj{#>fRH#H_>h*_W@;$jg*ZWetPJkN<^Wcx%+Z0Zr!IC;IJ4Yk-+ zen%sNMno}Nx&zq{Lf=g0vVoZxq@Ne*L^dEu{5}p-FQY>6F_v&96(Xw{qeL}S6uK!) zr@u9EDpZrkC@nUK@ew*lZ!beN5CQ2b(j)=T&ybH2*I!N&1-Z2P0adS4MVc^9pL#gO0Eo;ix<>uNDIlhpmu{suF83!M zX^sFo#cB&5H9&Ba98qvxzCYpq88=fT(HTlascITXO9rtE2C)%N#R|q~Hrn0OHbxFXzr z$~33EAZB|)eO2S?doUf7(;0{dh&!=?1&J?^hBH}9-SQLze$?mBD zwB6jyAbqMB61Jk-(9vXr*t^t+=~057c$Wu5Iqv&cXhU=Xl?_)X#9p&e+n(p0_g1{T OSM;9oD(DyRyN`gYKB3U*iS+s=p!X!j-VtLe@>fP<# zd9k{O-0jTx5atwsh7JM(L?*E7Qvw7z2ss!&1$n{8TnreE1Tc_Xc=ath<@^8YnVnfm zwl3MlbamI`ufM9l`u<0MZ)T>T;jjHKzxm|K3!3)d^f3A}aPt;^k*jN((CV7dh0)UM zI+sS>;L@y{Tv~ODOS^7!>C_!AbM+jTZrw#`w({*ly`YmUt5s~5>LtFnTQlu)z0CJc zYqnjfS8$&b?q6#4Ibrzo_$`RMDByWf6h#T;2{9weD3`>nsGvM4=EOY8Ww9U@QJxYf z#1hKW;-pwc`HVOvPNO^{o)KqIJ}aITXHlLN&j}CZbK;yhkJ1y*iwh{v1?R;JyT+%w zw1RUFHRxm9)T$Sgg%1N6Zti={RzHdZ=>?tpwI{S`>s38n?5@WX>>O*ay?cfav zQ)savEMW`hk)x|x>=f+|YOy!tTI^5OcXpz0h4zW|=#fRE56nYts1J0k?ZkRAJH5DC zVyv_icV(sR(^_vOWp(3qx+3syY*ei@CxT7C--=T!-tPseD}#IePzEBkwu4p=<0Bd$ zwfy@*lsc;XQvX&`Q@0~Ym%U)C=B-5D)o;ICyZYL-+LhYN6ZfxRCQAfSoZ201Y$R!i zWF8mH&`C-wE1k`#vDs~jKqhn1whs-w^>=T2S(UM}lQ=6YJ5jfj7NWQ-gGMa-LGrCA zYz3WW;I*-!%U(~0o!FB>gavsI!g$-G5fCEVeq1$D$8W}Ax5E=xEoox@wB!rXfKGjS z+o>7FGA-~sr^<*1DFmbzJ^!61PhixSS68>>mDOZtXH^Eh?yBF9yY1HM{b0*qjl11e zgj+&cBSTrOx1SajlvwODkP$86qUo+)p7<4T@9Od-TADbw{6{6BiV;=8O{_(wSRZH) zb)Z_qo7dN&r`_P~v3U~!Qv1H&>IbZ@y3l;7)8O_M+(#BJnqldZ)EQ4lQ`z{vk6-jM zF0m%PSl=-Q+NO@uq@U1#ZVdFlHhu=+joJn}J+VRtDda@4==NhS11YiL%aDfm*T377 z-DVI)snzsbt<>7MvGEQDn9a6Go!tkF5ZdMmz3m2deT8&#a3QtN81i}CjTNVTq9yS@ z)o=sFPb=k=%8!f#OBjc&@&T4}sE?jmX%h?}pT)Su?#G)e*Qz!vEYG4UEj71;=57P) z>GfmjQI#WtCSZp_=kR!Zqh}hEBVI?3vEoaHyol#1#c5Hj_+{K2DK3t7+<0V-_L`{p z5$KH7w$^V|%{1R^b)%rMDI$3W5_lcZ>_yyX8?V-3Z*+S>hZpcQiUT3kg_Rj?H8ie5 z(pY2KW;(sx=|6CzNKU8HAXx5wL3H=F_VM2iq3b~ZTocw~b7<@@g8V>zA~(>tEYPBP zTkD*$wAkD&$d7c;G-Q|q$oLEWj@B_g(gvn*cTKr7sug*bR$ysoAPtlWQ&JLTuAkk? z9h3SGe?aQ;Nb0@i+oXGCe5B^q6Yr%z@|RR@#=eX_b=JYBwPg8S7zJ=w-1Yn>LEK}W z^e!bvwU%1Fu(JimYo_|ElHi)0JnG--ZiSt3_k}7G81Uxx{3v?Rm7)fAincMlo19wj zx9hnpK&+lK76dvWx}mE_dTScc8`y>1l0AM9Tz)MC;%_Jb%Xo?pEic2eYV#6OvmI@vjw09NU@tT>z6*Oy8<#MSsLj<& z_&4-LkQ~vRYtWN{`%+d@(&bh39V3gFXN)WZdvqLGbf6DRpvyYcqbsoubm3m%i%_%o zuXuKVyk8CJ>=*|*YLDH6{Gql}I5dX9A0x%d#^oN{Lvw#QE(&*`16{`F+GAtO1d6N- zn>MidHyMiT0b0h`v_t{e{QEeo9ZM;Gjg;U00V&r;Qtn-2ESAPtocShLoSVWTY-cT* zzj@q{UP|VwUKqh@$6l}(!YL57WZ|i9-lZ?~wPb$WRY4hjC9zZDw0A`)J%wOLw!IbU zjqp@rJf%h#4u%u>i0vD+H}2|n%lu1LlOK-_>)BC z^36U5;6a;s}SBtcmvM2#95U9-vzS6b*fO49BB;0C5c?#z3SKL0$?| z+_kR+Vr;+=np+^Np}zMf!3nT?b!f!!jt558$|N_?2PS}W5pW%W&je%vJp$9ci?=o9 z0dAQ?Yww>2B)YDB4#RhG3}5)Af1FB}rr_K2o4fv25Y>|W$_l6@BCn0H_rk`0ybYf& zt0XXsaC0+|@ZM?#<=(Cp61TvlwN^-V$mn$aZHU?O)1mw$)tlWDg zwW96rgH$i1TH=1FW?W118%h!I_Y9wgb{IvV(^~S}sAXD#BMtI}kqV=ukyBJlreBf) zP0KxAOQYFsxBZS#Bur-ml-p~Cakan)pk8eIVQ1ue@M$RV>ar>?@vv;8a2nB4pHTxxTl#g!2crHQdPJ|`!aQyeW6y)ZPa6fjfa6opE&Pb#UA@9> zlQm~``7TC|4^s{gV^;qi9%Ai4r@&=P8yb5*8jz>(1F-NF)Efg6Ah!k^m@#-Dae1sohYRLbL9Rc9OpU4S1gwDV@a1S*>Qa(eQ z!)^J4g2<7y&e^Qp-Ob4_(4L=e-#a@npykrgjPnP@xI}g!p4llM!bsiwB_Kb0P!R=^ z#9mo(Jt+%6+Fgm@28wkX&ck=! z?t1&(zW2aKq_Z2gTB_r?hX){e&2A?SJNHO*yPk=i|_K!fKyZYky) zG#SlHMzBFbK8Lu7NQMwk=r$PP6*d-xZspG?pT*h82!gMp9=o}3N2&rMURb# z2-tjce?CSKBFxxSv_9cuyfyL;2(Kd_BX$OQoZE398c_^sXkjlKq|K!hgMJ^g&kaiE zalmJWY4mrgA0slqo}8Q_*)iN|PeB#2HYGKZ=Z@Udh+W-joYs0thS=H30Ro!<54Vs*tB7m#l2IzkD(n<7YhiLGQGDzm_^N=m3<{SX z8AE+y)_)(*T0WW;&hRWO{ckZN*N*04Fjdx*vr}MdbrGL?imxss6k1DM7Cz10K`;S zo3QX+*d`USuCF&ax%Y+|c5rmSL5jUQuX{TN}NPUeYL(`J){mPNtiL@jBrH1Em0nf6l%P>7hwZ|;;3^}%&;eLZR;NQ?Ax`2yj zECZrv#t?-?HpoUBu*q4(rAu;z$8epB+~1pUYd_WZt`KD2);`rgzJS;X@MQYh3%i_-^%ypI94o5G@T*Qyc@!S16Z~J8 zSQ&gAnhO36{UscmC9r*kOutd#qqK;pB|K9dqQnrR#>Z*O;M3U$B#`X~z{my=`x|Va zafF8YzK0mjIt2Vs+Eo4&ync`qnLEpgFc$}BlnBery*e-qGJ8Eap>TvOc<;z#tT9ra zkReK5o;2E%oj}OSTJzX2_7HXSkyYKVRrBf{(Q=O7N9xjx65R)#)MkP_va4)-Gvk^% zYCDf})?$O^Ic8t~7c)fjxM2eZiKXR_Vi(!1Jb98K z(hb6X#6BRn?kC{70zmYIj`&K~lDEKRD%Qz?sfz>173aO2#peuo`9=fKvHj=1F1yoD?Q3MGie9=*PdCD7d&W2-6-fHy* zPtibI^xM5i#U127j=-uv^+{xQ^cu6(;M=7Zk}atR zQ(=WiCxP$y(~|qA_zH(e_st?CMQ`F_jSs@;Ku|gu|E62Ebjx%dSN$!!%dY31a_3N< MacA6uyHL>o59B0sKL7v# diff --git a/vega/tools/query_process.py b/vega/tools/query_process.py index f5f3ace7..01dceaac 100644 --- a/vega/tools/query_process.py +++ b/vega/tools/query_process.py @@ -12,13 +12,14 @@ import psutil import json +import time from psutil import _pprint_secs from vega.common import MessageServer, MessageClient, argment_parser __all__ = [ - "get_task_info", "get_pid", "is_vega_process", "get_vega_pids", - "query_process", "query_processes", "print_process", "print_processes" + "query_task_info", "get_pid", "is_vega_process", "get_vega_pids", + "query_process", "query_processes", "print_process", "print_processes", ] @@ -51,20 +52,21 @@ def get_vega_pids(): return [_pid for (_pid, _ppid) in vega_pids] -def get_task_info(pid): +def get_task_id_path_port(pid): """Get task id.""" try: p = psutil.Process(pid) for connection in p.connections(): port = connection.laddr.port + ip = connection.laddr.ip if port in range(MessageServer().min_port, MessageServer().max_port): - client = MessageClient(ip="127.0.0.1", port=port, timeout=1) + client = MessageClient(ip=ip, port=port, timeout=1) result = client.send(action="query_task_info") if isinstance(result, dict) and "task_id" in result: - return result.get("task_id"), result.get("base_path") - return None, None + return result.get("task_id"), result.get("base_path"), ip, port + return None, None, None, None except Exception: - return None, None + return None, None, None, None def get_pid(task_id): @@ -80,10 +82,10 @@ def is_vega_process(pid): """Is it vega process.""" try: p = psutil.Process(pid) + if p.name().startswith("vega-main"): + return True except Exception: return False - if p.name().startswith("vega-main"): - return True return False @@ -132,7 +134,7 @@ def query_process(pid): """Query process info.""" try: p = psutil.Process(pid) - (task_id, base_path) = get_task_info(pid) + (task_id, base_path, ip, port) = get_task_id_path_port(pid) return { "PID": pid, "cmdline": p.cmdline()[2:], @@ -141,6 +143,9 @@ def query_process(pid): "task_id": task_id if task_id is not None else "Unknown", "base_path": base_path if base_path is not None else "Unknown", "user": p.username(), + "ip": ip, + "port": port, + "running_seconds": int(time.time() - p.create_time()), } except Exception as e: return { @@ -149,6 +154,17 @@ def query_process(pid): } +def query_task_info(task_id): + """Query task info.""" + pids = get_vega_pids() + if pids: + for id, pid in enumerate(pids): + info = query_process(pid) + if isinstance(info, dict) and info.get("task_id", None) == task_id: + return info + return None + + def query_processes(): """Query all process.""" pids = get_vega_pids() diff --git a/vega/tools/query_progress.py b/vega/tools/query_progress.py index 6e88ffa9..0125ccd7 100644 --- a/vega/tools/query_progress.py +++ b/vega/tools/query_progress.py @@ -15,11 +15,11 @@ import time from datetime import datetime from vega.common import Status, JsonEncoder, DatatimeFormatString, argment_parser -from vega.tools.query_process import get_pid +from vega.tools.query_process import query_task_info +from vega.common import MessageClient __all__ = ["query_progress"] -time_limit = 180 def _parse_args(desc): @@ -118,40 +118,46 @@ def _statistic_progress(progress): return progress -def query_progress(): +def _query_report(task_info): + """Get task id.""" + try: + port = task_info["port"] + ip = task_info["ip"] + client = MessageClient(ip=ip, port=port, timeout=1) + return client.send(action="query_report") + except Exception: + return None + + +def query_progress(times=0): """Query vega progress.""" args = _parse_args("Query Vega progress.") - is_running = get_pid(args.task_id) - report_path = _get_report_path(args.root_path, args.task_id) - - if not os.path.exists(report_path): - if is_running is None: - run_time = time.time() - is_running.create_time() - if run_time > 0 and run_time < time_limit: - return json.dumps({ - "status": Status.error, - "message": "The task is being created, please query again." - }, cls=JsonEncoder, indent=4) + task_info = query_task_info(args.task_id) + + if not task_info: + report_path = _get_report_path(args.root_path, args.task_id) + if not os.path.exists(report_path): + times += 1 + if times <= 3: + time.sleep(0.5) + query_progress(times) else: return json.dumps({ "status": Status.error, "message": "The task does not exist, please check root path and task id." }, cls=JsonEncoder, indent=4) - else: - return json.dumps({ - "status": Status.initializing, - }, cls=JsonEncoder, indent=4) - - report = _load_report(report_path) + report = _load_report(report_path) + else: + report = _query_report(task_info) if not report: return json.dumps({ "status": Status.error, - "message": "Failed to read report file." + "message": "Failed to query progress." }, cls=JsonEncoder, indent=4) progress = _parse_report(report) progress = _statistic_progress(progress) - if progress["status"] == Status.running and not is_running: + if progress["status"] == Status.running and not task_info: progress["status"] = Status.stopped return json.dumps(progress, cls=JsonEncoder, indent=4) diff --git a/vega/tools/run_pipeline.py b/vega/tools/run_pipeline.py index 83baf75e..48b341a1 100644 --- a/vega/tools/run_pipeline.py +++ b/vega/tools/run_pipeline.py @@ -155,7 +155,7 @@ def run_pipeline(load_special_lib_func=None): dict_args = vars(args) dict_args = _check_parse(dict_args) config = _modify_config(dict_args, config) - _backup_config(args) + # _backup_config(args) _change_process_name() vega.run(config) diff --git a/vega/tools/run_slave.py b/vega/tools/run_slave.py new file mode 100644 index 00000000..57572813 --- /dev/null +++ b/vega/tools/run_slave.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. +# This program is free software; you can redistribute it and/or modify +# it under the terms of the MIT License. +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# MIT License for more details. + +"""Run dask worker on slave.""" + +import os +import time +import subprocess + + +def run_dask_worker(master_ip, port, num_workers): + """Run dask worker on slave.""" + success = 0 + interval = 3 # sleep 3s + for _ in range(60 * 60 // interval): + try: + subprocess.Popen( + ["dask-worker", f"{master_ip}:{port}", '--nthreads=1', '--nprocs=1', '--memory-limit=0'], + env=os.environ) + success += 1 + if success == num_workers: + break + except Exception as e: + print(f"Failed to start dask-worker ({e}), try again {interval}s later.") + time.sleep(interval) + if success != num_workers: + raise Exception("Failed to start dask-worker. Gave up.") + else: + print("dask-worker running.") + + +if __name__ == "__main__": + run_dask_worker() diff --git a/vega/trainer/__pycache__/__init__.cpython-37.pyc b/vega/trainer/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 684b63ce6b40aaeae48bd49e3c7f9bf477524af9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 452 zcmYLFy-ve05O&h^KVfI<8YE&SCWH`3NDMI0PF{>pvbQk=R{(nsP9TDWCw%ECpZFe1{kNXz zi6HYtA20w60VBZpwf_z;#AJ@*X?=31N^RVX6=G)6EVnK-GiO*{z|Gc9E9P>}q^y!} z=!0Uf32!8`_Lgy{O;v|4X~vR{LBhq3y3mF{lkq-V9>ROu47(}`=iKC_dt4XLEbf!2 zQ=uRX?8zZr{>KXVU7bX>;68#kEqpNceVMy4($HV2`MYB&#dB z|L_Kx5ma0Edx|r2hU?V|o3zxp-GRhgJ=`_Ml;%b5C{4<$8C@tX)-t`Ef`UB)7x-as I;04|J0}I=Vpa1{> diff --git a/vega/trainer/__pycache__/conf.cpython-37.pyc b/vega/trainer/__pycache__/conf.cpython-37.pyc deleted file mode 100644 index 87373a263aa491bdb61b0fc1a9288ba3021c91da..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4421 zcmb7HOOq4H5$>1X(mXW4Fu?9=AI8=$t6^F8#j-pHSk~-d!#M2d2uHi8tE6evtrk@+ zU}SMIMC_m70LMM}SNJ#hQdgh2`QVdJ{<69yjaVSU-I3W@S?{dOud=T+8dU>N{EvVB zB|dH#|DwY5aUfg*Ne0FY*I*{IA~UvJ%Ou&3?AUP~Ejv*mF1kg?1y+noaoH_veJQHM zRkvyyHw;#0m8S-)@bVMet)X0HHI!?*Txa#WM!PXy+2UP)6s4^o8Fa&*?7TEFdD~Rg zTZxocwGyS(TjFl8$Jr?2Vy3L!9;RWO6Z/Y|oRaN6ho{BKYS3h2#ZZTK4vL$xLOR4a~0T-QEx8HHgeh`d=AM6Kq zz9{8ZtF|Y({{K*|I(?syxtgWs{ z;+EfHVUUK&z!&?iM4;RaBi>QPp)dScwjEWFqahcfjEPi*`gqhsR8Vf)$_^r_9OkFK zs>*&i^l&0UALgxiUJ&_Gdfv~*f7_cqae8xn@7|{1!(`JRrAZuZKHxonGvDuKfMe|p z_eBlk9s^05(6F0k*({pl^#${FX3Gf3Gw88EWRR;Mc?67^sS%psCb;#?dJL1m4EFfg z9iEQFpe5ipvem_&^^(IuucMrFf5_XGXkvhDhnBKsDzHQB85MO@sG9EJ=>fzFHPvjD z>Iz!>0Yq5@Gwh;SH^*y>%hzq%=1#k=%APmyWA1sX>UnWO&J4NkdG|+tl-HCzk0k-r zwcMn%N_W*84Pkfg`BCKgX)5G=*;pdk9zL{}Fl|ynHc^fIPy&uV-hAzCQ!Jxi&=&Mb zlK~{{s%$!zW9rwk%$1U^4Nx(AX0bBBS-dG$SivoU%7~i^E4o!s4O9m;KufITF0-<` z!Yb}6s0rN~=n&{I`WykRgN|a14SwwTaejgqHh>-*#tYMZhrc`1)U@V3-kNFZTJt{t zV5Vtk&4(E8BharvCo%4?@%;^`4f+`L3Fs7CazACuZU^hwMENx6GqgQ}^5>wlD1U+S zm-v3gR@|>a-(WB2K;J@t9`qgP0@_{#T>@POT>)JMZSiYtl{Gb5T!*}&v1E-M(pYjF zn#1gfuDtzMU`@eyL6R2BkYEOPhr2kbNwZu zzo26j6iup<5QK(fE>t}do}6N=DrT6ZoCq-7%N*d}7sLOc*Gr1OfEYbA7KN`4#_u7d z#>Ct(9$SyjO|0kU3uDXpbLX+0S`%YpKS8*qz$lU)JE=p`#CcwrSQDEL>WOU{_nUu2 zKx-6x?=DYUY~g~?L7d}AXS-b#B}}QDKHrzBwC76-4pr*$G!tYd=&)!PvXFLCU-sdS z*|;8ok+Ly>a;VQdP;`W-GKTP|M;{YOeR*4!41|YD0ZoDx1 zMR9fl#M%I8-!uF6w8X4_1+qc#PPH#g(@1SH6d=ElLB2+|{|Dg3$C#OHbrN4)5TBBy zec@e_2#DLZHjikKVu=(vP|N^NH%n2I!`(@kVZ6R*g4d%pO{X2?4*X|^)&{e>)(p2Z z^j_py?Jqgu{9;QxrO|Ayw^C;&zdA=DwevX(y6pm=t!LhAE@gAh%SKw5&vHccSI^By2WV(L&RJ;qb6GXpOV zEmdP7?!s_)lyat)rT>7_jm0B0m4T|!)h<%6EBJloNRBo|Xo5Ydmh`AzI7qpGsETrb z5bO?=?Z-n^jr{Sxw>t`>R8<08X1jjSm#WeeKEr(%4%OkIKr%#W5z8l}97U~UlS_m( z)`i?kypMedx_bqsMdCx+RxtvWrBW@2gD^eVeH9rFPZO*+8DP~kVtF(WNfgoU<{QhZ z3ITQ+sOns=%6mhZRaUi>L=2EwC2$02`WUaOZ6s4Z#QeE=l0k0c!?af~gH-83t+%L7 z0H-oBuk-va%;srRs)lalJ>X(Dkz6%0!%mCMtTvx1o41)YSe?o9ES?wsXJ{$p^JJ zzCmkd0Q*W7StTc+ODOF99Rp;!u#@Fs3on}`3j$_Zl+h5 diff --git a/vega/trainer/__pycache__/task_conf.cpython-37.pyc b/vega/trainer/__pycache__/task_conf.cpython-37.pyc deleted file mode 100644 index e8a4c2c6928ef733551000d71e1faa8e7303ef5d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 970 zcmZ8gOK%e~5Z+BTNt33b<^86VXLxiA2ao^}N(u-OTGYU0FN?L~Ze7;)%Ju^3fj@zS z1nM8)Cm?a-!U2w4d_r)F;1_U#aR~@9lE)r@GvjYRd#_JSG(0r^?$g6c&GWvx;E&aD zaSeyI=)j9S@Sy}{=0mVif(lfjmX=@)>KM7528<)tfF}BhjR0D(1twt%wr-R+@olgj zqaA;ZA9E3w7WEu{`%zsr!MsWp>!mBuNTIl0iGLb(tH+**aCWnW@Ibe&d;5@ z7`hxo43k2KOlU(yLUsI@WDcEO(^8CxHabytBw+>V;_-EjipJV~zk7tX$xJYZybU$XpRLzC!YipsRxeQ5RB+tUfG$o-?1aClv?%~Dm;mzg)J9dRj kcimUrzo}=?1(l}1)vx-E@}%Fwx#YLXO3n3&c-~a#s diff --git a/vega/trainer/__pycache__/trial_agent.cpython-37.pyc b/vega/trainer/__pycache__/trial_agent.cpython-37.pyc deleted file mode 100644 index 443e6efb62c93136a165a59c6728314eb00dba5e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2023 zcmZuy&5zqe6d#Wr+i{ZpF6{!9BGdyxkX<1T2vw-Ghe~`bqFq$v%i#4)vd%hjI(D{d zl%BR%;157Zn4fa^hd$#P5x>+pxmc`{!WTZlxQ$R(i1q>E?~GANvmN_3~gG#-Y_4d2<}ak=6aY zHEzdkhbey9Q5#Dq-h9ktAj8*8hHC4=jkhI#%6iRt=V4N+?y*j?iPHVQ;RZYEIVQ{| zStasf2t7DX%Cm2$C8>`Uwk0{#_7gQNboD6D)TB!3yQ_yoc>pRu1yc;cFbV`gaOu8g z2*Q&M$cFS~0O`w6HX#EU$rfZN+iajpZiwtrw&ZgL(&I50+nBaE!ZVn zV!CoJdF75gP-Dq4_Mru%5WM*-Qnp~D2$KKiRjms)YD@Pg7rQzaupl-h!^z%J9Eul( zK2uuEy*?$uHH}=ItK5Vbi9wdDTAL=-X>ZfCk*P$Aw3rOCq4CbuFfpxCRpKN`&(uU} z!r26?k=c~$JWG{GlS&N>{ZdoLj1SX8rFB4qW?8AJm?UFmnk!c_sa;yy$cr?|2^&Gj z>jq7Bv_(Ut77oDUl`0?n&h#y;{&RRR)Q1Q2(ddA-9VD}=80QD%=0Q~>4~blhgp%E# zzBHi_wgSYv_`Ym|F(-7kKs!#`neV=Vrhm2b8XF3>39{JaP%1VeRKQ0l$ziiQ5ugDYP?aAO^lz; zw9K^eimAeAnx$tz-6No_QN0h#`fXzG5VH_B2f7#wnWa_F(|dGuqplWbU_6;riN&L4 zZ_!uO8t#G0Eih)2yXQo!cYb?4&1)rgooXr^DvGZ_B^g9{Bh3{bxzM};UpPNINQyJR z^ZgX(tFBgQp=Gz2bW5erapijry^9ZxKeHQb9GbkbQTBKphJW@V(Zlkl`|o34A5@Yf zjL)~%9Iox9b_I*%=C!MzV1T>G&jP(eb+JHyxQ|}uK9@~~ z`aPOzfyFIqMg6oGkAVpljM-b8=~wvwh+_ s*L9Qqh|b=mllinit_trainer | 2
before_train | 3
before_epoch | 4
before_train_step | 5
make_batch | 6
train_step | 7
valid_step | 8
model_fn | 9
train_input_fn | 10
valid_input_fn | 11
after_train_step | 12
after_epoch | 13
after_train | 14
before_valid | 15
before_valid_step | 16
after_valid_step | 17
after_valid | +| :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | +| ModelBuilder | 200 | √ | | | | | | | | | | | | | | | | | +| RuntimeCallback | 210 | | √ | √ | | | | | | | | √ | √ | √ | | | | | +| ModelStatistics | 220 | | √ | | | | | | | | | | √ | √ | | | | | +| MetricsEvaluator | 230 | | √ | √ | √ | | | | | | | √ | √ | √ | | √ | √ | √ | +| ModelCheckpoint | 240 | | √ | | | | | | | | | | √ | | | | | | +| PerformanceSaver | 250 | | √ | | | | | | | | | | √ | √ | | | | | +| DataParallel | 260 | | √ | | | | | | | | | | | | | | | | +| ProgressLogger | 270 | | √ | √ | √ | | | | | | | √ | | √ | | | √ | √ | +| ReportCallback | 280 | | √ | | | | | | | | | | √ | √ | | | | √ | +| VisualCallBack | 290 | | √ | | | | | | | | | √ | √ | √ | √ | | | | +| DetectionMetricsEvaluator | | | √ | √ | | | | | | | | √ | | | | | √ | | +| DetectionProgressLogger | | | | | | | | | | | | √ | | | √ | | | | +| LearningRateScheduler | | | √ | √ | | | | | | | | √ | | | | | | | +| TimmTrainerCallback | | | √ | √ | | √ | √ | | | | | | √ | | √ | | | | +| | | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| 16| 17| diff --git a/vega/trainer/callbacks/data_parallel.py b/vega/trainer/callbacks/data_parallel.py new file mode 100644 index 00000000..f8d2cc6e --- /dev/null +++ b/vega/trainer/callbacks/data_parallel.py @@ -0,0 +1,50 @@ +# -*- coding:utf-8 -*- + +# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved. +# This program is free software; you can redistribute it and/or modify +# it under the terms of the MIT License. +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# MIT License for more details. + +"""Data parallel callback.""" + +import os +import logging +import vega +from .callback import Callback +from vega.common import ClassFactory, ClassType +from vega.common.general import General + +logger = logging.getLogger(__name__) + + +@ClassFactory.register(ClassType.CALLBACK) +class DataParallel(Callback): + """Callback that saves the evaluated Performance.""" + + def __init__(self): + """Initialize ModelCheckpoint callback.""" + super(DataParallel, self).__init__() + self.priority = 260 + + def before_train(self, logs=None): + """Be called before the training process.""" + if not vega.is_torch_backend() or not General._parallel or General.devices_per_trainer == 1: + return + model = self.trainer.model + import torch + if vega.is_gpu_device(): + if General._parallel and General.devices_per_trainer > 1: + model = torch.nn.DataParallel(model) + elif vega.is_npu_device(): + if General._parallel and General.devices_per_trainer > 1: + import torch.distributed as dist + dist.init_process_group( + backend='hccl', world_size=int(os.environ['WORLD_SIZE']), + rank=int(os.environ['RANK_ID'])) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[int(os.environ['DEVICE_ID'])]) + self.trainer.model = model diff --git a/vega/trainer/callbacks/model_builder.py b/vega/trainer/callbacks/model_builder.py index f5356176..9aa4f79b 100644 --- a/vega/trainer/callbacks/model_builder.py +++ b/vega/trainer/callbacks/model_builder.py @@ -17,7 +17,6 @@ from vega.common import FileOps, Config from vega.common import ClassFactory, ClassType from vega.networks.model_config import ModelConfig -from vega.common.general import General from vega.model_zoo import ModelZoo logger = logging.getLogger(__name__) @@ -30,7 +29,7 @@ class ModelBuilder(Callback): def __init__(self): """Initialize ModelCheckpoint callback.""" super(ModelBuilder, self).__init__() - self.priority = 240 + self.priority = 200 def init_trainer(self, logs=None): """Set trainer object for current callback.""" @@ -52,19 +51,10 @@ def _init_model(self): if hasattr(model, "desc"): self.trainer.model_desc = model.desc if vega.is_torch_backend(): - import torch if vega.is_gpu_device(): model = model.cuda() - if General._parallel and General.devices_per_trainer > 1: - model = torch.nn.DataParallel(model) elif vega.is_npu_device(): - model = model.npu() - if General._parallel and General.devices_per_trainer > 1: - import torch.distributed as dist - dist.init_process_group(backend='hccl', world_size=int(os.environ['WORLD_SIZE']), - rank=int(os.environ['RANK_ID'])) - model = torch.nn.parallel.DistributedDataParallel(model, - device_ids=[int(os.environ['DEVICE_ID'])]) + model = model.to(vega.get_devices()) return model def _get_model_desc(self): diff --git a/vega/trainer/callbacks/model_checkpoint.py b/vega/trainer/callbacks/model_checkpoint.py index b788eb8d..016eeabb 100644 --- a/vega/trainer/callbacks/model_checkpoint.py +++ b/vega/trainer/callbacks/model_checkpoint.py @@ -43,6 +43,9 @@ def after_epoch(self, epoch, logs=None): """Be called after each epoch.""" if not self.trainer.config.save_checkpoint: return + if not self.trainer.do_validation: + self._save_best_model() + return self._save_checkpoint(epoch) if self.trainer.multi_task: self._saved_multi_checkpoint(epoch) diff --git a/vega/trainer/callbacks/progress_logger.py b/vega/trainer/callbacks/progress_logger.py index 924866c6..5cf8d125 100644 --- a/vega/trainer/callbacks/progress_logger.py +++ b/vega/trainer/callbacks/progress_logger.py @@ -10,6 +10,7 @@ """ProgressLogger call defination.""" import logging +import time import numpy as np from collections.abc import Iterable from .callback import Callback @@ -47,10 +48,16 @@ def before_train(self, logs=None): self.train_verbose = 0 if self.valid_report_steps is None: self.valid_verbose = 0 + self.total_time_pre_reports = 0 + self.total_time = 0 logging.debug("Start the unified trainer ... ") self.is_chief = self.params['is_chief'] self.do_validation = self.params['do_validation'] + def before_train_step(self, batch_index, logs=None): + """Be called before a batch training.""" + self.step_start_time = time.perf_counter() + def before_epoch(self, epoch, logs=None): """Be called before each epoch.""" self.cur_epoch = epoch @@ -60,6 +67,7 @@ def before_epoch(self, epoch, logs=None): def after_train_step(self, batch_index, logs=None): """Be called before each batch training.""" + self.total_time_pre_reports += time.perf_counter() - self.step_start_time if self.train_verbose >= 2 and self.is_chief \ and batch_index % self.train_report_steps == 0: metrics_results = logs.get('train_step_metrics', None) @@ -71,24 +79,33 @@ def after_train_step(self, batch_index, logs=None): cur_loss = 0 loss_avg = 0 logging.warning("Cant't get the loss, maybe the loss doesn't update in the metric evaluator.") + + current_time = self.total_time_pre_reports / self.train_report_steps + mean_time = 0 + not_perf_batch = 5 + if batch_index // self.train_report_steps > not_perf_batch: + self.total_time += self.total_time_pre_reports + mean_time = self.total_time / (batch_index - not_perf_batch * self.train_report_steps) + self.total_time_pre_reports = 0 if metrics_results is not None: - log_info = "worker id [{}], epoch [{}/{}], train step {}, " \ - "loss [{:8.3f}, {:8.3f}], lr [{:12.7f}], train metrics {}" + log_info = "worker id [{}], epoch [{}/{}], train step {}, loss [{:8.3f}, {:8.3f}], " \ + "lr [{:12.7f}, time [{:4.3f}], mean time [{:4.3f}s], train metrics {}" log_info = log_info.format( self.trainer.worker_id, self.cur_epoch + 1, self.trainer.epochs, self._format_batch(batch_index, self.train_num_batches), - cur_loss, loss_avg, lr, + cur_loss, loss_avg, lr, current_time, mean_time, self._format_metrics(metrics_results)) logging.info(log_info) else: - log_info = "worker id [{}], epoch [{}/{}], train step {}, loss [{:8.3f}, {:8.3f}], lr [{:12.7f}]" + log_info = "worker id [{}], epoch [{}/{}], train step {}, loss [{:8.3f}, {:8.3f}], lr [{:12.7f}]" \ + ", time [{:4.3f}s] , mean time [{:4.3f}s]" log_info = log_info.format( self.trainer.worker_id, self.cur_epoch + 1, self.trainer.epochs, self._format_batch(batch_index, self.train_num_batches), - cur_loss, loss_avg, lr) + cur_loss, loss_avg, lr, current_time, mean_time) logging.info(log_info) def after_valid_step(self, batch_index, logs=None): @@ -97,7 +114,7 @@ def after_valid_step(self, batch_index, logs=None): and self.do_validation and batch_index % self.valid_report_steps == 0: metrics_results = logs.get('valid_step_metrics', None) if metrics_results is not None: - log_info = "worker id [{}], epoch [{}/{}], valid step {}, valid metrics {}".format( + log_info = "worker id [{}], epoch [{}/{}], valid step {}, valid metrics {}".format( self.trainer.worker_id, self.cur_epoch + 1, self.trainer.epochs, diff --git a/vega/trainer/callbacks/report_callback.py b/vega/trainer/callbacks/report_callback.py index bae52b46..4e15902d 100644 --- a/vega/trainer/callbacks/report_callback.py +++ b/vega/trainer/callbacks/report_callback.py @@ -52,8 +52,9 @@ def after_train(self, logs=None): def _update_report(self, epoch=0): if self.trainer.standalone: return - if self.trainer.distributed and os.environ["DEVICE_ID"] != "0": - return + if self.trainer.distributed: + if "DEVICE_ID" in os.environ and os.environ.get("DEVICE_ID") != "0": + return try: record = ReportClient().get_record(self.trainer.step_name, self.trainer.worker_id) except Exception as e: diff --git a/vega/trainer/callbacks/timm_trainer_callback.py b/vega/trainer/callbacks/timm_trainer_callback.py index 1c354d89..d124116a 100644 --- a/vega/trainer/callbacks/timm_trainer_callback.py +++ b/vega/trainer/callbacks/timm_trainer_callback.py @@ -131,7 +131,7 @@ def make_batch(self, batch): if vega.is_gpu_device(): input, target = input.cuda(), target.cuda() elif vega.is_npu_device(): - input, target = input.npu(), target.npu() + input, target = input.to(vega.get_devices()), target.to(vega.get_devices()) return input, target def train_step(self, batch): @@ -231,7 +231,7 @@ def _init_model(self): if vega.is_gpu_device(): model = model.cuda() elif vega.is_npu_device(): - model = model.npu() + model = model.to(vega.get_devices()) return model def _init_optimizer(self): @@ -261,7 +261,7 @@ def _init_loss(self): if vega.is_gpu_device(): loss_fn = loss_fn.cuda() elif vega.is_npu_device(): - loss_fn = loss_fn.npu() + loss_fn = loss_fn.to(vega.get_devices()) return loss_fn def _reset_sync_opt(self): diff --git a/vega/trainer/callbacks/visual_callback.py b/vega/trainer/callbacks/visual_callback.py index 2a014064..62b2ac1f 100644 --- a/vega/trainer/callbacks/visual_callback.py +++ b/vega/trainer/callbacks/visual_callback.py @@ -80,7 +80,7 @@ def before_train(self, logs=None): if vega.is_gpu_device(): input_data = input_data.cuda() elif vega.is_npu_device(): - input_data = input_data.npu() + input_data = input_data.to(vega.get_devices()) try: self.summary.add_graph(model=model, feed_data=input_data, backend="torch") diff --git a/vega/trainer/conf.py b/vega/trainer/conf.py index 5dc90595..94f22748 100644 --- a/vega/trainer/conf.py +++ b/vega/trainer/conf.py @@ -59,6 +59,7 @@ class TrainerConfig(ConfigSerializable): valid_interval = 1 syncbn = False amp = False + opt_level = 'O1' lazy_built = False callbacks = None grad_clip = None @@ -78,7 +79,6 @@ class TrainerConfig(ConfigSerializable): codec = None model_desc = None hps_file = None - hps_folder = None loss_scale = 1. save_steps = 500 report_on_valid = False @@ -152,7 +152,6 @@ def rules(cls): "codec": {"type": (str, dict, None)}, "model_desc": {"type": (str, dict, None)}, "hps_file": {"type": (str, None)}, - "hps_folder": {"type": (str, None)}, "loss_scale": {"type": (int, float)}, "save_steps": {"type": int}, "report_on_valid": {"type": bool}, diff --git a/vega/trainer/modules/__pycache__/__init__.cpython-37.pyc b/vega/trainer/modules/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 3d54ae2bfb726626d3387298a30ec2903075c65b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 183 zcmZ?b<>g`kg51Mb6BL2;V-N=hn1BoiATAaF5-AKRj5!Rsj8Tk?3@J>(44TX@8G*u@ zjJJf6^YhX&)8ms8vy)TvQsQ$H3;Z;hZZQ;r^si(nVgXWM;+Ls@dXa&CRaTaMQEEZH zeqw1!er}F_S!#Nseo0YcW?pKMer|qBX-;afetdjpUS>&ryk0@&Ee;!qs2#|{&p^xo E0J;n=0{{R3 diff --git a/vega/trainer/modules/__pycache__/config_bakcend_map.cpython-37.pyc b/vega/trainer/modules/__pycache__/config_bakcend_map.cpython-37.pyc deleted file mode 100644 index db3c8daf09bedba68159ef89234d6d8e1c63ae9e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1866 zcmZt{%Z?jGu&3u~uh-dZav%~&BNVh)ifj&WSw$h*fDnh>3lc~pP&1isJ7arhobGnS zTH^~F@e5qNM}8nbz!f2J;4)WE`~oMcx@YzwZCg`aRo(TjYM<=y_X(`x_rHJqt4GK` zNOa2s#_mB{jGp2aWIY#~ZM{ud0AvysyTRxCrhTNIGz z@TnMpAm9!`R}z;~?!G07XB|m=9)JsB^U&%zvASCaHovwG6?Zw*`O)vku#Q)$OpCZM zkuZI}62~nOBN69itYw;)LOzL)s+G=*BS4mwQb%zm0h3H+uLR~n9$Ae>?v_HOr8&x6 zhzb|(_`QHnB>*(pIkj`?)PzxkQ-ix4IiB`+!6lJnPLV1<208@#8Fvl+9?%icXWTP% zqAHXpUgwDA{frQ??ZDhYR< z#CnQR7iv7}H(pk)&YOq}u`<|w44k!k8P7A_JTTc+x0!?11(WT@dn0CPbC4@$OtAK_ z*+=9P&Apk-L&cF$r==w)DAKcZF z<4!ii4#$O~i9={w;157F(aw%y6I?i3TGP3^b!#Z_+*^e5!x~Ci)1A+qoLthar^yA` z`gmHqj|l7mK%cqDhdl&NwDW7%+zx)4*50mnNxmRIB;V6*xQ*(t4i0zY@*Ol<1h?PSWH+cIXia|}*1f0X^|X$z@$Vpp8X6Cp&fENX zny$K@yXzUS<9A~OZPJgSzw%jLrpr!vDy!lO&zHhXl94lm)`{7yV|QZFuP;fV*RnkB zG^A54pn= zAVQv>Yf-4CpR2r7IxRENcxFQGb;HoQH0sNHC=l4Vl6{*FW=U%W)h5GEu)R2J6)@{( z<1WN`Lnokp$1`nx=tOkr0CWzh2k_82q(i!S zY*T%5>&+c^?zE=PMga~?nHGYvrq5XWAhKq_*z0w=Y)?YQc$G0G55c-|=7xhfx6*it zk0stvnCr;*Xf%|_8-+54&DH@&Dmlc79%ofiR3&Wq$7D*?1l7yA$g~l=%SIDDwLU)E O13e1M24*1af%`wr(guS7 diff --git a/vega/trainer/modules/conf/__pycache__/__init__.cpython-37.pyc b/vega/trainer/modules/conf/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index d644a468c1d386b1aa277e4bb6adace633a878cd..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 145 zcmZ?b<>g`kg51Mb6F~H15CH>>K!yVl7qb9~6oz01O-8?!3`HPe1o6vCKfTC6zbY$B zzbLgJUq7+5BtJJtzbrL9QNN@pF*7fUap)c>Fd``F*n-_Q%MJ@qg2)Ngj>#09zphr^j6Ih>h~`(3ZsVfeB?{(QmPjQvHE z+s_B{7~MVsQ7mMNE0OX{gn|?ArC#WL&y=r%S4;(3Tza98c|$cZZ|VT^hVqWsu(iCu zuP4zwEeB~~?I14lNiwygzp*hJa<{!#368W$qBL1XC#ilGlkWBl){NZ;qglw|j!-=G zln8z0O?en7ABP2DQ@2#-7dGLQ8|_OTwjpV$_A3S+8lQz5aB0JJ4&X|C@N)OoJ(3M- z;b_?{O_b!?jL2!{mlrdA?D~5KvnP3JirEF3$KKjG8YyNp+5?c1m3+l6#kE-RmADZv zS?OK-H_SX$e8tb3W~7Aj64=Nff2d$}7C`!R#m<83My}K& z952t##zxPIaWpTBEFIJK<8>s*S)t~swqvS*F@<$ByD%s`_BOh`2g1B<-a_m04)62j zE`@lVkk>xPoDA^^bh^(Tx_t{o@scGRl!J=j#7hJk!Sj|6pX+jN@`0h_2NOi9x_r9&WnuFR7PcK zY~{M!s4|i1^0ZJ@Oe^nPZ=Ir8!(6_(z>u3hd^XiRD2SWrmP9S?3!gW*k5TY`rz)8k ziq{`y>@m9i733}+wXb~?F3MW@S04DFYBi6C0jR8bAX0URR~uJ;7|=RcchM>-_zqgH zvbc98D($ACRS<=+$sGJu)*i(_B{|P8G(`^0Cf4bC%E8U(TA+d4OzkqI@WE_ ub1;gFEGu$D@HO-Usvo6BadM{Pva9qg$l-{?neQ9wr|ERWa1?+2{R{LM`-hCn z;p5^7iv1ClW(m_=i<}oC5u80|+SC3y)4maBUgDwM(*fFn@zL(-{t=4?&An%4u2y*+ z<{)kQ#^_aUAWX~UJX_e=1qNmjcl%GMX6Za;e}nECKB~UOaO_r9c%n@?FoaDg!{bxgnjA zbOHuwN1pT2+&W*YTDg%u$tsB$r6n(R%2=)HvdH6i zW})H@OmR^X^EM{t<891l)hSR5-a)Z+EA#gGfCs#}yM_A(_N)9-PG8~>D3tqC6nh($ z<~7SWYK~fb6z>r$#K&vCIW+YOmLX6c;T*weYsr?284%u)Xs51n;XGRdR*qPpV~l>+ zYfr)rTz6`hRyR?7hhhU%%nSI0H+R0wygeK7SJBiBq+F`PNa;pWcDUerBIVna$~T_A zl)6lDd%S^}`pWIeRfSv_sq$Q^x(4f}()^m{D{bV7IkndnxN%EQVd^)fZUJsH#qYa- zslgEPx|oQ=4(1PSO#~=eb|2M%`@-iwGUD?cA^2{;ZI%>WSA*Dif?~f#wf#QvU%--H zs{<{F2X}i5oENI9kbq~1auXgvc=ZIsjVB~m*S@otQ6Fw$eE5nKaShZoKuNki0!S8J zytaki&928k+q^D#i^6*t;0KhJLMk`#fc9QSL3@K6Q!CHXvM9PmkO*6t|sBCiAbInX(`%Dnda*5z7G}LAyQWV%*p>o84rR+8JP@ z>1kK$X%8Gh9N-AwfHR+gi?5vc3Y>V)nVnq_!jWH|-yi$EpWpMRqfwv0D1Q6>%Wp$M z{=mi6@PT;<)4UHvl7vVqSxyU8!>X|{(x96fXw*0E>3G&k0lq87GqPP2-q$0|Jq=zE-} zxiE(FpUGbz%$NGXyg50UYgLtVv98M^pMRy6V!lOSUX*g3D>KI{oI5yX)tTf=&6P+g3$MZ00Dz=vO_5mP;a$hFG`p~M6YrWI13`EUn5 zd?90W+;x?af@(JqnTIgVFF|gZy7btxdSO{%G`=~>2-|ySZo7vEM_=f&O2%7e zr;0s&4;-AAehWFQsbrKZp;r(j*Kp-FEVdEY;nyl#9@ku|RGfh`TI>5r-Ub3?>jcO@ ziM#p^DqcadhXi{?Pk`77Rv0f@AY8-SJMD|`58-vi+^GIHV1svIng|Hi3xoO}(y7O2 z%Z;6+r5dfD#DP284)95srDai+D~)a3K087A>_owvgSP*5qzyl8JLe-TJ=b+U9eU%} F{{}q!q?rH! diff --git a/vega/trainer/modules/lr_schedulers/warmup_scheduler_torch.py b/vega/trainer/modules/lr_schedulers/warmup_scheduler_torch.py index 2694c8ab..239fbc55 100644 --- a/vega/trainer/modules/lr_schedulers/warmup_scheduler_torch.py +++ b/vega/trainer/modules/lr_schedulers/warmup_scheduler_torch.py @@ -91,12 +91,13 @@ def step(self, epoch=None): self.after_scheduler.step(epoch) return - self.current_iters = epoch - warmup_lr = self.get_lr() - for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): - param_group['lr'] = lr + if epoch is not None: + self.current_iters = epoch + warmup_lr = self.get_lr() + for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): + param_group['lr'] = lr - if epoch >= self.warmup_iters: - self.warmup_finished = True - self.after_scheduler = LrScheduler(self.after_scheduler_config)(self.optimizer) - self.by_epoch = self.after_scheduler.by_epoch + if epoch >= self.warmup_iters: + self.warmup_finished = True + self.after_scheduler = LrScheduler(self.after_scheduler_config)(self.optimizer) + self.by_epoch = self.after_scheduler.by_epoch diff --git a/vega/trainer/modules/optimizer/optim.py b/vega/trainer/modules/optimizer/optim.py index f6ff6615..c0119b16 100644 --- a/vega/trainer/modules/optimizer/optim.py +++ b/vega/trainer/modules/optimizer/optim.py @@ -100,6 +100,12 @@ def set_distributed(cls, optimizer, model=None): import torch.optim as torch_opt ClassFactory.register_from_package(torch_opt, ClassType.OPTIMIZER) + if vega.is_npu_device(): + try: + from apex.optimizers import NpuFusedSGD + ClassFactory.register_cls(NpuFusedSGD, ClassType.OPTIMIZER) + except Exception: + pass elif vega.is_tf_backend(): import tensorflow.compat.v1.train as tf_train diff --git a/vega/trainer/script_runner.py b/vega/trainer/script_runner.py index a1f23ccc..de35cab0 100644 --- a/vega/trainer/script_runner.py +++ b/vega/trainer/script_runner.py @@ -81,18 +81,11 @@ def _get_hps(self, hps): if hps is not None: pass elif self.config.hps_file is not None: - desc_file = self.config.hps_file.replace("{local_base_path}", self.local_base_path) - hps = Config(desc_file) - if "trainer" in hps: - if "epochs" in hps["trainer"]: - hps["trainer"].pop("epochs") - if "checkpoint_path" in hps["trainer"]: - hps["trainer"].pop("checkpoint_path") - elif self.config.hps_folder is not None: - folder = self.config.hps_folder.replace("{local_base_path}", self.local_base_path) - pattern = os.path.join(folder, "hps_*.json") - desc_file = glob.glob(pattern)[0] - hps = Config(desc_file) + hps_file = self.config.hps_file.replace("{local_base_path}", self.local_base_path) + if os.path.isdir(hps_file): + pattern = os.path.join(hps_file, "hps_*.json") + hps_file = glob.glob(pattern)[0] + hps = Config(hps_file) if "trainer" in hps: if "epochs" in hps["trainer"]: hps["trainer"].pop("epochs") diff --git a/vega/trainer/trainer_base.py b/vega/trainer/trainer_base.py index 5370def2..19187b18 100644 --- a/vega/trainer/trainer_base.py +++ b/vega/trainer/trainer_base.py @@ -10,6 +10,7 @@ """Base Trainer.""" +import os import glob import logging import vega @@ -22,6 +23,7 @@ from vega.trainer.utils import WorkerTypes from vega.datasets import Adapter from vega.common.general import General +from vega.common.utils import update_dict class TrainerBase(DistributedWorker): @@ -194,18 +196,11 @@ def _init_hps(self, hps=None): if hps is not None: pass elif self.config.hps_file is not None: - desc_file = self.config.hps_file.replace("{local_base_path}", self.local_base_path) - hps = Config(desc_file) - if "trainer" in hps: - if "epochs" in hps["trainer"]: - hps["trainer"].pop("epochs") - if "checkpoint_path" in hps["trainer"]: - hps["trainer"].pop("checkpoint_path") - elif self.config.hps_folder is not None: - folder = self.config.hps_folder.replace("{local_base_path}", self.local_base_path) - pattern = FileOps.join_path(folder, "hps_*.json") - desc_file = glob.glob(pattern)[0] - hps = Config(desc_file) + hps_file = self.config.hps_file.replace("{local_base_path}", self.local_base_path) + if os.path.isdir(hps_file): + pattern = os.path.join(hps_file, "hps_*.json") + hps_file = glob.glob(pattern)[0] + hps = Config(hps_file) if "trainer" in hps: if "epochs" in hps["trainer"]: hps["trainer"].pop("epochs") @@ -215,8 +210,7 @@ def _init_hps(self, hps=None): if not self.hps: self.hps = hps elif hps: - hps.from_dict(self.hps) - self.hps = hps + self.hps = update_dict(self.hps, hps) # set config if self.hps and self.hps.get('trainer'): self.config.from_dict(self.hps.get('trainer')) diff --git a/vega/trainer/trainer_torch.py b/vega/trainer/trainer_torch.py index 2b7096ac..94cee6da 100644 --- a/vega/trainer/trainer_torch.py +++ b/vega/trainer/trainer_torch.py @@ -9,8 +9,6 @@ # MIT License for more details. """Torch Trainer.""" - -import os import torch import numpy as np import vega @@ -52,7 +50,7 @@ def build(self): if self.use_amp: from apex import amp self.model, self.optimizer = amp.initialize( - self.model, self.optimizer, opt_level='O1') + self.model, self.optimizer, opt_level=self.config.opt_level, loss_scale=64, combine_grad=True) def _set_default_funcs(self): self.make_batch = self._default_make_batch @@ -77,8 +75,7 @@ def _init_setting(self): torch.cuda.manual_seed(self.config.seed) elif vega.is_npu_device(): import torch.npu - device = "npu:{}".format(os.environ.get('DEVICE_ID', 0)) - torch.npu.set_device(device) + torch.npu.set_device(vega.get_devices()) torch.npu.manual_seed(self.config.seed) elif vega.is_cpu_device(): self.config.device = -1 @@ -108,7 +105,7 @@ def _init_horovod_setting(self): def _train_epoch(self): self.model.train() for batch_index, batch in enumerate(self.train_loader): - if self.config.max_train_steps and batch_index > self.config.max_train_steps: + if self.config.max_train_steps and batch_index >= self.config.max_train_steps: return batch = self.make_batch(batch) batch_logs = {'train_batch': batch} @@ -122,7 +119,6 @@ def _train_epoch(self): def _valid_epoch(self): self.callbacks.before_valid() valid_logs = None - self.model.eval() with torch.no_grad(): for batch_index, batch in enumerate(self.valid_loader): @@ -145,7 +141,7 @@ def _set_device(self, data): if vega.is_gpu_device(): return data.cuda() else: - return data.npu() + return data.to(vega.get_devices()) if isinstance(data, dict): return {k: self._set_device(v) for k, v in data.items()} elif isinstance(data, list): @@ -161,6 +157,8 @@ def _default_train_step(self, batch): output = self.model(**batch) elif isinstance(batch, list) and isinstance(batch[0], dict): output = self.model(batch) + elif isinstance(batch, list) and isinstance(batch[0], list): + output = self.model(*batch) else: # classification input, target = batch @@ -177,11 +175,16 @@ def _default_train_step(self, batch): loss = self.loss(output, target) if self.use_amp: from apex import amp - with amp.scale_loss(loss, self.optimizer) as scaled_loss: - scaled_loss.backward() - self.optimizer.synchronize() - with self.optimizer.skip_synchronize(): + if vega.is_npu_device(): + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() self.optimizer.step() + else: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + self.optimizer.synchronize() + with self.optimizer.skip_synchronize(): + self.optimizer.step() else: loss.backward() if self.config.grad_clip: