fix sync_replicas_optimizer que runners
chengmengli06 committed Jul 19, 2023
1 parent 0b6bf36 commit eb535f3
Showing 6 changed files with 343 additions and 15 deletions.
24 changes: 17 additions & 7 deletions easy_rec/python/compat/sync_replicas_optimizer.py
@@ -135,6 +135,8 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
   ```
   """
 
+  sync_que_id = -1
+
   def __init__(self,
                opt,
                replicas_to_aggregate,
@@ -299,15 +301,24 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
     update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                           global_step)
 
+    def _get_token_qname():
+      SyncReplicasOptimizer.sync_que_id += 1
+      if SyncReplicasOptimizer.sync_que_id == 0:
+        return 'sync_token_q'
+      else:
+        return 'sync_token_q_' + str(SyncReplicasOptimizer.sync_que_id)
+
     # Create token queue.
+    token_qname = _get_token_qname()
+    logging.info('create sync_token_queue[%s]' % token_qname)
     with ops.device(global_step.device), ops.name_scope(''):
       sync_token_queue = (
           data_flow_ops.FIFOQueue(
               -1,
               global_step.dtype.base_dtype,
               shapes=(),
-              name='sync_token_q',
-              shared_name='sync_token_q'))
+              name=token_qname,
+              shared_name=token_qname))
     self._sync_token_queue = sync_token_queue
     self._is_sync_que_closed = sync_token_queue.is_closed()
     self._close_sync_que = sync_token_queue.close(
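
The counter-based _get_token_qname gives every SyncReplicasOptimizer built in one process its own token-queue name. With the old fixed shared_name, a second apply_gradients call in the same process (for example, the fit-on-eval retrain added in main.py below) can bind to, and may find closed, the queue left behind by the first run on the parameter server. A minimal, framework-free sketch of the naming pattern; TokenQueueNamer is a hypothetical stand-in for the class attribute above:

class TokenQueueNamer:
  sync_que_id = -1  # class-level counter, shared by all instances in a process

  @classmethod
  def next_name(cls):
    cls.sync_que_id += 1
    # First queue keeps the legacy name; later ones get a unique suffix.
    return ('sync_token_q' if cls.sync_que_id == 0
            else 'sync_token_q_%d' % cls.sync_que_id)

print(TokenQueueNamer.next_name())  # sync_token_q
print(TokenQueueNamer.next_name())  # sync_token_q_1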
@@ -342,6 +353,8 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
 
       self._chief_queue_runner = queue_runner.QueueRunner(
           dummy_queue, [sync_op])
+      ops.add_to_collection(ops.GraphKeys.QUEUE_RUNNERS,
+                            self._chief_queue_runner)
       for accum, dev in self._accumulator_list:
         with ops.device(dev):
           chief_init_ops.append(
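
Registering the chief queue runner in GraphKeys.QUEUE_RUNNERS lets the standard session machinery start its threads, which is why the hook further down no longer calls create_threads itself. A minimal sketch, assuming the TF 1.x compat API:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# A toy queue and a runner that keeps it filled.
queue = tf.FIFOQueue(10, tf.int32, shapes=())
enqueue_op = queue.enqueue(tf.constant(1))
runner = tf.train.QueueRunner(queue, [enqueue_op])

# Equivalent to ops.add_to_collection(ops.GraphKeys.QUEUE_RUNNERS, ...) above.
tf.train.add_queue_runner(runner)

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  # start_queue_runners launches every runner found in GraphKeys.QUEUE_RUNNERS;
  # MonitoredTrainingSession does the same internally.
  threads = tf.train.start_queue_runners(sess, coord)
  print(sess.run(queue.dequeue()))  # 1
  coord.request_stop()
  coord.join(threads)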
@@ -479,14 +492,12 @@ def begin(self):
       self._local_init_op = self._sync_optimizer.chief_init_op
       self._ready_for_local_init_op = (
           self._sync_optimizer.ready_for_local_init_op)
-      self._q_runner = self._sync_optimizer.get_chief_queue_runner()
       self._init_tokens_op = self._sync_optimizer.get_init_tokens_op(
           self._num_tokens)
     else:
       self._local_init_op = self._sync_optimizer.local_step_init_op
       self._ready_for_local_init_op = (
           self._sync_optimizer.ready_for_local_init_op)
-      self._q_runner = None
       self._init_tokens_op = None
 
   def after_create_session(self, session, coord):
@@ -500,11 +511,10 @@ def after_create_session(self, session, coord):
             'local_init. Init op: %s, error: %s' %
             (self._local_init_op.name, msg))
       session.run(self._local_init_op)
+    is_closed = session.run(self._sync_optimizer._is_sync_que_closed)
+    assert not is_closed, 'sync_que is closed'
     if self._init_tokens_op is not None:
       session.run(self._init_tokens_op)
-    if self._q_runner is not None:
-      self._q_runner.create_threads(
-          session, coord=coord, daemon=True, start=True)
 
   def end(self, session):
     try:
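
The new assertion relies on FIFOQueue.is_closed(), a bool tensor reporting whether the shared queue has been closed, possibly by an earlier session in the same job. A minimal sketch, assuming the TF 1.x compat API:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

q = tf.FIFOQueue(2, tf.int32, shapes=(), shared_name='tok_q')
is_closed = q.is_closed()
close_op = q.close(cancel_pending_enqueues=True)

with tf.Session() as sess:
  print(sess.run(is_closed))  # False: freshly created
  sess.run(close_op)
  print(sess.run(is_closed))  # True: enqueue/dequeue would now fail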
3 changes: 2 additions & 1 deletion easy_rec/python/main.py
@@ -312,7 +312,7 @@ def _train_and_evaluate_impl(pipeline_config,
     f.write(easy_rec.__version__ + '\n')
 
   train_steps = None
-  if train_config.HasField('num_steps'):
+  if train_config.HasField('num_steps') and train_config.num_steps > 0:
     train_steps = train_config.num_steps
   assert train_steps is not None or data_config.num_epochs > 0, (
       'either num_steps or num_epochs must be set to an integer > 0.')
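
The extra num_steps > 0 guard matters because the config is a proto2-style message: field presence and field value are tracked separately, so HasField() alone also matches an explicit num_steps: 0. A minimal sketch of that presence semantics, using a stock proto2 message rather than EasyRec's TrainConfig:

from google.protobuf import descriptor_pb2

msg = descriptor_pb2.FieldDescriptorProto()
print(msg.HasField('number'))  # False: never set
msg.number = 0
print(msg.HasField('number'))  # True, although the value is 0
# Hence HasField('num_steps') and num_steps > 0: an explicit num_steps: 0
# no longer forces train_steps = 0, and epoch-driven training can take over.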
@@ -348,6 +348,7 @@ def _train_and_evaluate_impl(pipeline_config,
   estimator_train.train_and_evaluate(estimator, train_spec, eval_spec)
   logging.info('Train and evaluate finish')
   if fit_on_eval and (not estimator_utils.is_evaluator()):
+    tf.reset_default_graph()
     logging.info('Start continue training on eval data')
     eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data,
                                   **input_fn_kwargs)
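
tf.reset_default_graph() discards the graph left over from the first train_and_evaluate before the continue-on-eval run rebuilds the model, avoiding clashes with variables and ops that already exist. A minimal sketch of the failure mode it prevents, assuming the TF 1.x compat API:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

with tf.variable_scope('net'):
  tf.get_variable('w', shape=())

# Without this reset, rebuilding the model below would raise
# "Variable net/w already exists" inside the still-populated default graph.
tf.reset_default_graph()

with tf.variable_scope('net'):
  tf.get_variable('w', shape=())  # fine: a fresh graph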
9 changes: 9 additions & 0 deletions easy_rec/python/test/train_eval_test.py
@@ -722,6 +722,15 @@ def test_fit_on_eval(self):
         fit_on_eval=True)
     self.assertTrue(self._success)
 
+  def test_unbalance_data(self):
+    self._success = test_utils.test_distributed_train_eval(
+        'samples/model_config/multi_tower_on_taobao_unblanace.config',
+        self._test_dir,
+        total_steps=0,
+        num_epoch=1,
+        num_evaluator=1)
+    self.assertTrue(self._success)
+
   def test_train_with_ps_worker_with_evaluator(self):
     self._success = test_utils.test_distributed_train_eval(
         'samples/model_config/multi_tower_on_taobao.config',
6 changes: 3 additions & 3 deletions easy_rec/python/train_eval.py
@@ -68,12 +68,12 @@
       default=None,
       help='eval data input path')
   parser.add_argument(
-      'fit_on_eval',
-      type=bool,
+      '--fit_on_eval',
+      action='store_true',
       default=False,
       help='Fit evaluation data after fitting and evaluating train data')
   parser.add_argument(
-      'fit_on_eval_steps',
+      '--fit_on_eval_steps',
       type=int,
       default=None,
       help='Fit evaluation data steps')
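
Switching from type=bool to action='store_true' fixes a classic argparse pitfall: bool() applied to any non-empty string is True, so a bool-typed option can never be turned off from the command line. A minimal sketch:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--fit_on_eval', action='store_true', default=False)

print(parser.parse_args([]).fit_on_eval)                 # False
print(parser.parse_args(['--fit_on_eval']).fit_on_eval)  # True
print(bool('False'))  # True: why type=bool cannot express a false flag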
12 changes: 8 additions & 4 deletions easy_rec/python/utils/test_utils.py
@@ -159,7 +159,10 @@ def _replace_data_for_test(data_path):
   return data_path
 
 
-def _load_config_for_test(pipeline_config_path, test_dir, total_steps=50):
+def _load_config_for_test(pipeline_config_path,
+                          test_dir,
+                          total_steps=50,
+                          num_epochs=0):
   pipeline_config = config_util.get_configs_from_pipeline_file(
       pipeline_config_path)
   train_config = pipeline_config.train_config
@@ -171,7 +174,7 @@ def _load_config_for_test(pipeline_config_path, test_dir, total_steps=50):
   pipeline_config.model_dir = os.path.join(test_dir, 'train')
   logging.info('test_model_dir %s' % pipeline_config.model_dir)
   eval_config.num_examples = max(10, data_config.batch_size)
-  data_config.num_epochs = 0
+  data_config.num_epochs = num_epochs
   return pipeline_config
 
 
@@ -672,10 +675,11 @@ def test_distributed_train_eval(pipeline_config_path,
                                 num_evaluator=0,
                                 edit_config_json=None,
                                 use_hvd=False,
-                                fit_on_eval=False):
+                                fit_on_eval=False,
+                                num_epoch=0):
   logging.info('testing pipeline config %s' % pipeline_config_path)
   pipeline_config = _load_config_for_test(pipeline_config_path, test_dir,
-                                          total_steps)
+                                          total_steps, num_epoch)
   if edit_config_json is not None:
     config_util.edit_config(pipeline_config, edit_config_json)
 
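The num_epoch plumbing keeps the old default (num_epochs=0, repeat the input indefinitely and stop on total_steps), while the new test_unbalance_data case passes total_steps=0 with num_epoch=1 so training ends when the input is exhausted, exercising the shutdown path the queue-runner changes fix. A minimal sketch of that epoch convention with tf.data, an assumption about the input semantics rather than code from this commit:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

def repeated(num_epochs):
  ds = tf.data.Dataset.range(3)
  # 0 follows the repeat-forever convention; N > 0 stops after N passes.
  return ds.repeat() if num_epochs == 0 else ds.repeat(num_epochs)

with tf.Session() as sess:
  nxt = tf.data.make_one_shot_iterator(repeated(num_epochs=1)).get_next()
  try:
    while True:
      print(sess.run(nxt))  # 0, 1, 2, then OutOfRangeError ends the epoch
  except tf.errors.OutOfRangeError:
    pass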