[feature] add support for horovod (#389)
* add support for horovod
* remove dependency on HOROVOD_RANK env
chengmengli06 authored Jul 10, 2023
1 parent 2145e64 commit 7648671
Showing 7 changed files with 110 additions and 5 deletions.
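
As context for the diffs below, a minimal usage sketch (not part of the commit; the config loader name is assumed): point train_config.train_distribute at the new HorovodStrategy value, turn off sync_replicas, and launch easy_rec.python.train_eval under horovodrun, which is what test_utils.test_distributed_train_eval(..., use_hvd=True) automates further down.

# Hypothetical sketch, not part of this commit.
from easy_rec.python.protos.train_pb2 import DistributionStrategy
from easy_rec.python.utils import config_util

# config_util.get_configs_from_pipeline_file is assumed as the loader here.
pipeline_config = config_util.get_configs_from_pipeline_file('pipeline.config')
pipeline_config.train_config.train_distribute = DistributionStrategy.HorovodStrategy
pipeline_config.train_config.sync_replicas = False  # the estimator asserts this is off under horovod
config_util.save_pipeline_config(pipeline_config, 'hvd_exp')

# Then, from a shell (two workers on one machine):
#   horovodrun -np 2 python -m easy_rec.python.train_eval \
#       --pipeline_config_path hvd_exp/pipeline.config
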
15 changes: 15 additions & 0 deletions easy_rec/python/main.py
@@ -32,6 +32,11 @@
from easy_rec.python.utils.export_big_model import export_big_model
from easy_rec.python.utils.export_big_model import export_big_model_to_oss

try:
import horovod.tensorflow as hvd
except Exception:
hvd = None

if tf.__version__ >= '2.0':
gfile = tf.compat.v1.gfile
from tensorflow.core.protobuf import config_pb2
@@ -97,6 +102,16 @@ def _create_estimator(pipeline_config, distribution=None, params={}):
model_config = pipeline_config.model_config
train_config = pipeline_config.train_config
gpu_options = GPUOptions(allow_growth=False)

if hvd is not None:
gpus = estimator_utils.get_available_gpus()
if len(gpus) > 0:
local_rnk = hvd.local_rank()
num_gpus_per_worker = pipeline_config.train_config.num_gpus_per_worker
sid = local_rnk * num_gpus_per_worker
eid = sid + num_gpus_per_worker
gpu_options.visible_device_list = ','.join(gpus[sid:eid])

session_config = ConfigProto(
gpu_options=gpu_options,
allow_soft_placement=True,
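
A quick illustration of the slicing above, with made-up values: each Horovod process takes the window of local GPUs determined by its local rank and num_gpus_per_worker.

# Illustrative arithmetic only; the values below are hypothetical.
gpus = ['0', '1', '2', '3']   # stands in for estimator_utils.get_available_gpus()
local_rnk = 1                 # hvd.local_rank() of this process
num_gpus_per_worker = 2       # pipeline_config.train_config.num_gpus_per_worker
sid = local_rnk * num_gpus_per_worker      # 2
eid = sid + num_gpus_per_worker            # 4
print(','.join(gpus[sid:eid]))             # '2,3' -> gpu_options.visible_device_list
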
18 changes: 14 additions & 4 deletions easy_rec/python/model/easy_rec_estimator.py
@@ -38,6 +38,11 @@
from easy_rec.python.utils import pai_util
from easy_rec.python.utils.multi_optimizer import MultiOptimizer

try:
import horovod.tensorflow as hvd
except Exception:
hvd = None

if tf.__version__ >= '2.0':
tf = tf.compat.v1

@@ -201,6 +206,14 @@ def _train_model_fn(self, features, labels, run_config):
optimizer = MultiOptimizer(all_opts, grouped_vars)

hooks = []
if estimator_utils.has_hvd():
assert not self.train_config.sync_replicas, \
'sync_replicas should not be set when using horovod'
optimizer = hvd.DistributedOptimizer(
optimizer, backward_passes_per_step=1)
bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
hooks.append(bcast_hook)

# for distributed and synced training
if self.train_config.sync_replicas and run_config.num_worker_replicas > 1:
logging.info('sync_replicas: num_worker_replias = %d' %
@@ -359,7 +372,6 @@ def _train_model_fn(self, features, labels, run_config):
# for multi worker strategy, we could not replace the
# inner CheckpointSaverHook, so just use it.
scaffold = tf.train.Scaffold()
chief_hooks = []
else:
var_list = (
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
@@ -407,9 +419,8 @@ def _train_model_fn(self, features, labels, run_config):
write_graph=self.train_config.write_graph,
data_offset_var=data_offset_var,
increment_save_config=self.incr_save_config)
chief_hooks = []
hooks.append(saver_hook)
if estimator_utils.is_chief():
hooks.append(saver_hook)
hooks.append(
basic_session_run_hooks.StepCounterHook(
every_n_steps=log_step_count_steps, output_dir=self.model_dir))
@@ -426,7 +437,6 @@ def _train_model_fn(self, features, labels, run_config):
predictions=predict_dict,
train_op=train_op,
scaffold=scaffold,
training_chief_hooks=chief_hooks,
training_hooks=hooks)

def _eval_model_fn(self, features, labels, run_config):
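
The pieces added to the estimator follow the standard Horovod TF1 recipe: wrap the optimizer, broadcast initial variables from rank 0, and let only rank 0 write checkpoints. For readers unfamiliar with that recipe, a generic self-contained sketch, assuming TF1-style graph mode (or tf.compat.v1) and an installed horovod; it is not code from this repository.

import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# toy linear-regression graph so the sketch runs end to end
x = tf.constant(np.random.rand(32, 4), dtype=tf.float32)
y = tf.reduce_sum(x, axis=1, keepdims=True)
w = tf.get_variable('w', [4, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

opt = tf.train.GradientDescentOptimizer(0.1)
# average gradients across ranks before applying them
opt = hvd.DistributedOptimizer(opt, backward_passes_per_step=1)
global_step = tf.train.get_or_create_global_step()
train_op = opt.minimize(loss, global_step=global_step)

hooks = [
    hvd.BroadcastGlobalVariablesHook(0),   # sync initial weights from rank 0
    tf.train.StopAtStepHook(last_step=100),
]
ckpt_dir = './hvd_ckpt' if hvd.rank() == 0 else None  # only the chief checkpoints
with tf.train.MonitoredTrainingSession(checkpoint_dir=ckpt_dir, hooks=hooks) as sess:
    while not sess.should_stop():
        sess.run(train_op)
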
2 changes: 2 additions & 0 deletions easy_rec/python/protos/train.proto
@@ -19,6 +19,8 @@ enum DistributionStrategy {
// multi worker multi gpu mode
// see tf.distribute.experimental.MultiWorkerMirroredStrategy
MultiWorkerMirroredStrategy = 5;
// use horovod strategy
HorovodStrategy = 6;
}

message IncrementSaveConfig {
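
A hypothetical check, not part of the commit (it assumes the regenerated train_pb2 module and a TrainConfig message that owns the train_distribute field), showing how the new enum value is addressed from Python:

from easy_rec.python.protos.train_pb2 import DistributionStrategy, TrainConfig

train_config = TrainConfig()
train_config.train_distribute = DistributionStrategy.HorovodStrategy
print(DistributionStrategy.Name(train_config.train_distribute))  # 'HorovodStrategy'
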
13 changes: 13 additions & 0 deletions easy_rec/python/test/train_eval_test.py
@@ -24,6 +24,11 @@
except Exception:
gl = None

try:
import horovod as hvd
except Exception:
hvd = None

if tf.__version__ >= '2.0':
tf = tf.compat.v1

@@ -1039,6 +1044,14 @@ def test_multi_tower_recall_neg_sampler_only_sequence_feature(self):
self._test_dir)
self.assertTrue(self._success)

@unittest.skipIf(hvd is None, 'horovod is not installed')
def test_horovod(self):
self._success = test_utils.test_distributed_train_eval(
'samples/model_config/deepfm_combo_on_avazu_ctr.config',
self._test_dir,
use_hvd=True)
self.assertTrue(self._success)


if __name__ == '__main__':
tf.test.main()
4 changes: 4 additions & 0 deletions easy_rec/python/train_eval.py
@@ -9,6 +9,7 @@
from tensorflow.python.platform import gfile

from easy_rec.python.main import _train_and_evaluate_impl
from easy_rec.python.protos.train_pb2 import DistributionStrategy
from easy_rec.python.utils import config_util
from easy_rec.python.utils import ds_util
from easy_rec.python.utils import estimator_utils
@@ -147,6 +148,9 @@
if pipeline_config.train_config.fine_tune_checkpoint:
ds_util.cache_ckpt(pipeline_config)

if pipeline_config.train_config.train_distribute == DistributionStrategy.HorovodStrategy:
estimator_utils.init_hvd()

if args.hpo_param_path:
with gfile.GFile(args.hpo_param_path, 'r') as fin:
hpo_config = json.load(fin)
37 changes: 37 additions & 0 deletions easy_rec/python/utils/estimator_utils.py
@@ -8,12 +8,14 @@
import logging
import os
import re
import sys
import time

import numpy as np
import six
import tensorflow as tf
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.client import device_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
@@ -31,6 +33,11 @@

from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer # NOQA

try:
import horovod.tensorflow as hvd
except Exception:
hvd = None

try:
from kafka import KafkaProducer, KafkaAdminClient
from kafka.admin import NewTopic
@@ -442,6 +449,9 @@ def __init__(self,
self._sparse_timer = None

def after_create_session(self, session, coord):
if not is_chief():
return

global_step = session.run(self._global_step_tensor)
if self._write_graph:
# We do write graph and saver_def at the first call of before_run.
@@ -581,6 +591,9 @@ def _send_sparse(self, global_step, session):
% (global_step, msg_num, len(bytes_buf)))

def after_run(self, run_context, run_values):
if not is_chief():
return

super(CheckpointSaverHook, self).after_run(run_context, run_values)
stale_global_step = run_values.results
global_step = -1
@@ -637,6 +650,8 @@ def _save(self, session, step):
return should_stop

def end(self, session):
if not is_chief():
return
super(CheckpointSaverHook, self).end(session)
global_step = session.run(self._global_step_tensor)
if self._dense_timer is not None and \
@@ -950,6 +965,8 @@ def is_chief():
tf_config = json.loads(os.environ['TF_CONFIG'])
if 'task' in tf_config:
return tf_config['task']['type'] in ['chief', 'master']
elif has_hvd():
return hvd.rank() == 0
return True


@@ -967,3 +984,23 @@ def is_evaluator():
if 'task' in tf_config:
return tf_config['task']['type'] == 'evaluator'
return False


def has_hvd():
return hvd is not None and 'HOROVOD_RANK' in os.environ


def init_hvd():
if hvd is None:
logging.error(
'horovod is not installed: HOROVOD_WITH_TENSORFLOW=1 pip install horovod'
)
sys.exit(1)

hvd.init()
os.environ['HOROVOD_RANK'] = str(hvd.rank())


def get_available_gpus():
local_device_protos = device_lib.list_local_devices()
return [x.name for x in local_device_protos if x.device_type == 'GPU']
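
A short illustration of how the new helpers interact (hypothetical, for illustration; requires horovod to be installed): init_hvd() calls hvd.init() and exports HOROVOD_RANK, which makes has_hvd() true and lets is_chief() resolve to rank 0.

import os

from easy_rec.python.utils import estimator_utils

estimator_utils.init_hvd()                   # hvd.init() + export HOROVOD_RANK
print(os.environ['HOROVOD_RANK'])            # e.g. '0' on the first process
print(estimator_utils.has_hvd())             # True once HOROVOD_RANK is set
print(estimator_utils.is_chief())            # True only on rank 0
print(estimator_utils.get_available_gpus())  # e.g. ['/device:GPU:0', ...]
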
26 changes: 25 additions & 1 deletion easy_rec/python/utils/test_utils.py
@@ -633,24 +633,48 @@ def _multi_worker_mirror_train(pipeline_config_path, test_dir, num_worker):
return procs


def _multi_worker_hvd_train(pipeline_config_path, test_dir, num_worker):
gpus = get_available_gpus()
# not enough gpus, run on cpu only
if len(gpus) < num_worker:
gpus = ''
else:
gpus = ','.join(gpus)
set_gpu_id(gpus)
ports = _get_ports(num_worker)
hosts = ','.join(['localhost:%d' % ports[i] for i in range(num_worker)])
train_cmd = 'horovodrun -np %d --hosts %s python -m easy_rec.python.train_eval --pipeline_config_path %s' % (
num_worker, hosts, pipeline_config_path)
proc = run_cmd(train_cmd, '%s/log_hvd.txt' % test_dir)
proc_wait(proc, timeout=1200)
return proc.returncode == 0


def test_distributed_train_eval(pipeline_config_path,
test_dir,
total_steps=50,
num_evaluator=0,
edit_config_json=None):
edit_config_json=None,
use_hvd=False):
logging.info('testing pipeline config %s' % pipeline_config_path)
pipeline_config = _load_config_for_test(pipeline_config_path, test_dir,
total_steps)
if edit_config_json is not None:
config_util.edit_config(pipeline_config, edit_config_json)

if use_hvd:
pipeline_config.train_config.sync_replicas = False
pipeline_config.train_config.train_distribute = DistributionStrategy.HorovodStrategy

train_config = pipeline_config.train_config
config_util.save_pipeline_config(pipeline_config, test_dir)
test_pipeline_config_path = os.path.join(test_dir, 'pipeline.config')

task_failed = None
procs = None
try:
if use_hvd:
return _multi_worker_hvd_train(test_pipeline_config_path, test_dir, 2)
if train_config.train_distribute == DistributionStrategy.NoStrategy:
num_worker = 2
procs = _ps_worker_train(test_pipeline_config_path, test_dir, num_worker,
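
The new use_hvd path can also be exercised directly from Python (a sketch; the output directory is a placeholder):

from easy_rec.python.utils import test_utils

# launches a 2-worker horovodrun job against the sample config and
# returns True if the run succeeds
ok = test_utils.test_distributed_train_eval(
    'samples/model_config/deepfm_combo_on_avazu_ctr.config',
    '/tmp/hvd_test_dir',
    use_hvd=True)
print(ok)
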
