From 315cc028fdfe2e141e8ea4614bc9e216f386f0a4 Mon Sep 17 00:00:00 2001 From: chengmengli06 Date: Thu, 4 Jan 2024 12:31:24 +0800 Subject: [PATCH] refactor sok dynamic variable --- docs/source/train.md | 19 +++++++++++++------ easy_rec/python/eval.py | 6 ++++-- easy_rec/python/export.py | 6 ++++-- easy_rec/python/protos/train.proto | 2 ++ easy_rec/python/train_eval.py | 6 ++++-- easy_rec/python/utils/estimator_utils.py | 2 +- 6 files changed, 28 insertions(+), 13 deletions(-) diff --git a/docs/source/train.md b/docs/source/train.md index c1cdc755d..355b42a7d 100644 --- a/docs/source/train.md +++ b/docs/source/train.md @@ -61,12 +61,19 @@ - 仅在train_distribute为NoStrategy时可以设置成true,其它情况应该设置为false - PS异步训练也设置为false -- train_distribute: 默认不开启Strategy(NoStrategy), strategy确定分布式执行的方式 - - - NoStrategy 不使用Strategy - - PSStrategy 异步ParameterServer模式 - - MirroredStrategy 单机多卡模式,仅在PAI上可以使用,本地和EMR上不能使用 - - MultiWorkerMirroredStrategy 多机多卡模式,在TF版本>=1.15时可以使用 +- train_distribute: 默认不开启Strategy(NoStrategy), strategy确定分布式执行的方式, 可以分成两种模式: PS-Worker模式 和 All-Reduce模式 + + - PS-Worker模式: + - NoStrategy: 不使用Strategy + - PSStrategy: 异步ParameterServer模式 + - All-Reduce模式: + - 数据并行: + - MirroredStrategy: 单机多卡模式,仅在PAI上可以使用,本地和EMR上不能使用 + - MultiWorkerMirroredStrategy: 多机多卡模式,在TF版本>=1.15时可以使用 + - HorovodStragtegy: horovod多机多卡并行, 需要安装horovod + - 混合并行(数据并行 + Embedding分片): + - EmbeddingParallelStrategy: 在horovod多机多卡并行的基础上, 增加了Embedding分片的功能 + - SokStrategy: 在horovod多机多卡并行的基础上, 增加了KV Embedding和Embedding分片的功能 - num_gpus_per_worker: 仅在MirrorredStrategy, MultiWorkerMirroredStrategy, PSStrategy的时候有用 diff --git a/easy_rec/python/eval.py b/easy_rec/python/eval.py index aa54e949e..d41c3d7ae 100644 --- a/easy_rec/python/eval.py +++ b/easy_rec/python/eval.py @@ -69,10 +69,12 @@ def main(argv): if pipeline_config.train_config.train_distribute in [ DistributionStrategy.HorovodStrategy, - DistributionStrategy.EmbeddingParallelStrategy ]: estimator_utils.init_hvd() - elif pipeline_config.train_config.train_distribute == DistributionStrategy.SokStrategy: + elif pipeline_config.train_config.train_distribute in [ + DistributionStrategy.EmbeddingParallelStrategy, + DistributionStrategy.SokStrategy + ]: estimator_utils.init_hvd() estimator_utils.init_sok() diff --git a/easy_rec/python/export.py b/easy_rec/python/export.py index cbe747e78..d39db74ea 100644 --- a/easy_rec/python/export.py +++ b/easy_rec/python/export.py @@ -112,10 +112,12 @@ def main(argv): pipeline_config_path) if pipeline_config.train_config.train_distribute in [ DistributionStrategy.HorovodStrategy, - DistributionStrategy.EmbeddingParallelStrategy ]: estimator_utils.init_hvd() - elif pipeline_config.train_config.train_distribute == DistributionStrategy.SokStrategy: + elif pipeline_config.train_config.train_distribute in [ + DistributionStrategy.EmbeddingParallelStrategy, + DistributionStrategy.SokStrategy + ]: estimator_utils.init_hvd() estimator_utils.init_sok() diff --git a/easy_rec/python/protos/train.proto b/easy_rec/python/protos/train.proto index 0df241fe4..ab3ca4ddc 100644 --- a/easy_rec/python/protos/train.proto +++ b/easy_rec/python/protos/train.proto @@ -21,7 +21,9 @@ enum DistributionStrategy { MultiWorkerMirroredStrategy = 5; // use horovod strategy HorovodStrategy = 6; + // support kv embedding, support kv embedding shard SokStrategy = 7; + // support embedding shard, requires horovod EmbeddingParallelStrategy = 8; } diff --git a/easy_rec/python/train_eval.py b/easy_rec/python/train_eval.py index 164798a8a..085884d1c 100644 --- a/easy_rec/python/train_eval.py +++ b/easy_rec/python/train_eval.py @@ -164,10 +164,12 @@ if pipeline_config.train_config.train_distribute in [ DistributionStrategy.HorovodStrategy, - DistributionStrategy.EmbeddingParallelStrategy ]: estimator_utils.init_hvd() - elif pipeline_config.train_config.train_distribute == DistributionStrategy.SokStrategy: + elif pipeline_config.train_config.train_distribute in [ + DistributionStrategy.EmbeddingParallelStrategy, + DistributionStrategy.SokStrategy + ]: estimator_utils.init_hvd() estimator_utils.init_sok() diff --git a/easy_rec/python/utils/estimator_utils.py b/easy_rec/python/utils/estimator_utils.py index ed29302db..a90f0b0f0 100644 --- a/easy_rec/python/utils/estimator_utils.py +++ b/easy_rec/python/utils/estimator_utils.py @@ -1024,7 +1024,7 @@ def init_sok(): os.environ['ENABLE_SOK'] = '1' return True except Exception: - logging.error('sok is not installed') + logging.warning('sok is not installed') return False