refactor sok dynamic variable
chengmengli06 committed Jan 4, 2024
1 parent 8530ca8 commit 315cc02
Showing 6 changed files with 28 additions and 13 deletions.
19 changes: 13 additions & 6 deletions docs/source/train.md
@@ -61,12 +61,19 @@
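   - Can be set to true only when train_distribute is NoStrategy; it should be false in all other cases
   - Also set it to false for asynchronous PS training
 
-- train_distribute: Strategy is disabled by default (NoStrategy); the strategy determines how distributed execution is performed
-
-  - NoStrategy: do not use a Strategy
-  - PSStrategy: asynchronous ParameterServer mode
-  - MirroredStrategy: single-machine multi-GPU mode; available only on PAI, not locally or on EMR
-  - MultiWorkerMirroredStrategy: multi-machine multi-GPU mode; available when the TF version is >= 1.15
+- train_distribute: Strategy is disabled by default (NoStrategy); the strategy determines how distributed execution is performed and falls into two modes: PS-Worker mode and All-Reduce mode
+
+  - PS-Worker mode:
+    - NoStrategy: do not use a Strategy
+    - PSStrategy: asynchronous ParameterServer mode
+  - All-Reduce mode:
+    - Data parallel:
+      - MirroredStrategy: single-machine multi-GPU mode; available only on PAI, not locally or on EMR
+      - MultiWorkerMirroredStrategy: multi-machine multi-GPU mode; available when the TF version is >= 1.15
+      - HorovodStrategy: multi-machine multi-GPU parallelism with horovod; requires horovod to be installed
+    - Hybrid parallel (data parallel + embedding sharding):
+      - EmbeddingParallelStrategy: adds embedding sharding on top of horovod multi-machine multi-GPU parallelism
+      - SokStrategy: adds KV embedding and embedding sharding on top of horovod multi-machine multi-GPU parallelism
 
 - num_gpus_per_worker: only takes effect with MirroredStrategy, MultiWorkerMirroredStrategy, or PSStrategy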
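
The grouping of train_distribute values described in this hunk can be sketched as a small lookup table. This is only an illustration: the strategy names come from the documentation above, while `STRATEGY_MODES` and `describe_strategy` are hypothetical names that are not part of EasyRec.

```python
# Hypothetical lookup table, not part of EasyRec: maps each train_distribute
# value named in the docs to the mode the docs file it under.
STRATEGY_MODES = {
    'NoStrategy': 'PS-Worker',
    'PSStrategy': 'PS-Worker',
    'MirroredStrategy': 'All-Reduce / data parallel',
    'MultiWorkerMirroredStrategy': 'All-Reduce / data parallel',
    'HorovodStrategy': 'All-Reduce / data parallel',
    'EmbeddingParallelStrategy': 'All-Reduce / hybrid parallel',
    'SokStrategy': 'All-Reduce / hybrid parallel',
}


def describe_strategy(name):
    """Return the documented mode for a train_distribute value."""
    if name not in STRATEGY_MODES:
        raise ValueError('unknown train_distribute value: %s' % name)
    return STRATEGY_MODES[name]
```

A table like this makes it easy to check, for example, that both hybrid-parallel strategies are treated the same way by downstream code.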

6 changes: 4 additions & 2 deletions easy_rec/python/eval.py
@@ -69,10 +69,12 @@ def main(argv):

   if pipeline_config.train_config.train_distribute in [
       DistributionStrategy.HorovodStrategy,
-      DistributionStrategy.EmbeddingParallelStrategy
   ]:
     estimator_utils.init_hvd()
-  elif pipeline_config.train_config.train_distribute == DistributionStrategy.SokStrategy:
+  elif pipeline_config.train_config.train_distribute in [
+      DistributionStrategy.EmbeddingParallelStrategy,
+      DistributionStrategy.SokStrategy
+  ]:
     estimator_utils.init_hvd()
     estimator_utils.init_sok()
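
The same branch is repeated in eval.py, export.py, and train_eval.py. As a rough sketch of the resulting behaviour, the dispatch could be written as one helper. The enum below is a stand-in for the generated protobuf enum (values copied from train.proto), and `init_distribution` with its callback arguments is a hypothetical name introduced for illustration, not something in the commit.

```python
import enum


class DistributionStrategy(enum.Enum):
    """Stand-in for the protobuf enum; values copied from train.proto."""
    MultiWorkerMirroredStrategy = 5
    HorovodStrategy = 6
    SokStrategy = 7
    EmbeddingParallelStrategy = 8


def init_distribution(strategy, init_hvd, init_sok):
    """Run the initializers a strategy needs; return the names that ran.

    init_hvd and init_sok are passed in as callables so the sketch does
    not depend on horovod or sok being installed.
    """
    called = []
    if strategy == DistributionStrategy.HorovodStrategy:
        # Plain horovod data parallelism only needs horovod.
        init_hvd()
        called.append('hvd')
    elif strategy in (DistributionStrategy.EmbeddingParallelStrategy,
                      DistributionStrategy.SokStrategy):
        # After this commit both hybrid strategies initialize horovod and sok.
        init_hvd()
        called.append('hvd')
        init_sok()
        called.append('sok')
    return called
```

Under these assumptions, `init_distribution(DistributionStrategy.EmbeddingParallelStrategy, f, g)` calls both `f` and `g`, whereas before the commit that strategy only reached `init_hvd`.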

6 changes: 4 additions & 2 deletions easy_rec/python/export.py
@@ -112,10 +112,12 @@ def main(argv):
pipeline_config_path)
   if pipeline_config.train_config.train_distribute in [
       DistributionStrategy.HorovodStrategy,
-      DistributionStrategy.EmbeddingParallelStrategy
   ]:
     estimator_utils.init_hvd()
-  elif pipeline_config.train_config.train_distribute == DistributionStrategy.SokStrategy:
+  elif pipeline_config.train_config.train_distribute in [
+      DistributionStrategy.EmbeddingParallelStrategy,
+      DistributionStrategy.SokStrategy
+  ]:
     estimator_utils.init_hvd()
     estimator_utils.init_sok()

2 changes: 2 additions & 0 deletions easy_rec/python/protos/train.proto
@@ -21,7 +21,9 @@ enum DistributionStrategy {
   MultiWorkerMirroredStrategy = 5;
   // use horovod strategy
   HorovodStrategy = 6;
+  // support kv embedding, support kv embedding shard
   SokStrategy = 7;
+  // support embedding shard, requires horovod
   EmbeddingParallelStrategy = 8;
 }

6 changes: 4 additions & 2 deletions easy_rec/python/train_eval.py
@@ -164,10 +164,12 @@

   if pipeline_config.train_config.train_distribute in [
       DistributionStrategy.HorovodStrategy,
-      DistributionStrategy.EmbeddingParallelStrategy
   ]:
     estimator_utils.init_hvd()
-  elif pipeline_config.train_config.train_distribute == DistributionStrategy.SokStrategy:
+  elif pipeline_config.train_config.train_distribute in [
+      DistributionStrategy.EmbeddingParallelStrategy,
+      DistributionStrategy.SokStrategy
+  ]:
     estimator_utils.init_hvd()
     estimator_utils.init_sok()

2 changes: 1 addition & 1 deletion easy_rec/python/utils/estimator_utils.py
@@ -1024,7 +1024,7 @@ def init_sok():
     os.environ['ENABLE_SOK'] = '1'
     return True
   except Exception:
-    logging.error('sok is not installed')
+    logging.warning('sok is not installed')
     return False
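
The hunk shows only the tail of `init_sok`, so the following is a minimal sketch of the general pattern it appears to follow: try to import an optional dependency, set an environment flag on success, and log a warning rather than an error on failure. `try_enable` and its parameters are hypothetical; the real function's import statement is not visible in the diff.

```python
import importlib
import logging
import os


def try_enable(module_name, env_var):
    """Set env_var to '1' if module_name imports; warn and return False otherwise.

    Hypothetical helper: the module and variable names are parameters so the
    pattern can be exercised without sok installed.
    """
    try:
        importlib.import_module(module_name)
    except Exception:
        # A missing optional dependency is treated as expected, hence
        # warning rather than error.
        logging.warning('%s is not installed', module_name)
        return False
    os.environ[env_var] = '1'
    return True
```

Logging at warning level fits a code path where the dependency is optional, which may be the case now that strategies other than SokStrategy also call `init_sok`.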

