Skip to content

Commit

Permalink
[feat]: place embedding on cpu if necessary (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Jul 14, 2023
1 parent f21683b commit 1b4e8ec
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 15 deletions.
11 changes: 9 additions & 2 deletions easy_rec/python/input/input.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
from abc import abstractmethod
from collections import OrderedDict

import six
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import gfile

from easy_rec.python.core import sampler as sampler_lib
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
from easy_rec.python.utils import conditional
from easy_rec.python.utils import config_util
from easy_rec.python.utils import constant
from easy_rec.python.utils.check_utils import check_split
Expand Down Expand Up @@ -1015,11 +1018,15 @@ def _input_fn(mode=None, params=None, config=None):
dataset = self._build(mode, params)
return dataset
elif mode is None: # serving_input_receiver_fn for export SavedModel
place_on_cpu = os.getenv('place_embedding_on_cpu')
place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
if export_config.multi_placeholder:
inputs, features = self.create_multi_placeholders(export_config)
with conditional(place_on_cpu, ops.device('/CPU:0')):
inputs, features = self.create_multi_placeholders(export_config)
return tf.estimator.export.ServingInputReceiver(features, inputs)
else:
inputs, features = self.create_placeholders(export_config)
with conditional(place_on_cpu, ops.device('/CPU:0')):
inputs, features = self.create_placeholders(export_config)
print('built feature placeholders. features: {}'.format(
features.keys()))
return tf.estimator.export.ServingInputReceiver(features, inputs)
Expand Down
15 changes: 12 additions & 3 deletions easy_rec/python/layers/input_layer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from collections import OrderedDict

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope

Expand All @@ -14,6 +16,7 @@
from easy_rec.python.layers import variational_dropout_layer
from easy_rec.python.layers.common_layers import text_cnn
from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
from easy_rec.python.utils import conditional
from easy_rec.python.utils import shape_utils

from easy_rec.python.compat.feature_column.feature_column_v2 import EmbeddingColumn # NOQA
Expand All @@ -36,13 +39,14 @@ def __init__(self,
ev_params=None,
embedding_regularizer=None,
kernel_regularizer=None,
is_training=False):
is_training=False,
is_predicting=False):
self._feature_groups = {
x.group_name: FeatureGroup(x) for x in feature_groups_config
}
self.sequence_feature_layer = sequence_feature_layer.SequenceFeatureLayer(
feature_configs, feature_groups_config, ev_params,
embedding_regularizer, kernel_regularizer, is_training)
embedding_regularizer, kernel_regularizer, is_training, is_predicting)
self._seq_feature_groups_config = []
for x in feature_groups_config:
for y in x.sequence_features:
Expand All @@ -62,6 +66,7 @@ def __init__(self,
self._embedding_regularizer = embedding_regularizer
self._kernel_regularizer = kernel_regularizer
self._is_training = is_training
self._is_predicting = is_predicting
self._variational_dropout_config = variational_dropout_config

def has_group(self, group_name):
Expand Down Expand Up @@ -92,7 +97,11 @@ def __call__(self, features, group_name, is_combine=True, is_dict=False):
feature_name_to_output_tensors = {}
negative_sampler = self._feature_groups[group_name]._config.negative_sampler
if is_combine:
concat_features, group_features = self.single_call_input_layer(
place_on_cpu = os.getenv('place_embedding_on_cpu')
place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
with conditional(self._is_predicting and place_on_cpu,
ops.device('/CPU:0')):
concat_features, group_features = self.single_call_input_layer(
features, group_name, feature_name_to_output_tensors)
if group_name in self._group_name_to_seq_features:
# for target attention
Expand Down
19 changes: 14 additions & 5 deletions easy_rec/python/layers/sequence_feature_layer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
import os

import tensorflow as tf

from tensorflow.python.framework import ops
from easy_rec.python.compat import regularizers
from easy_rec.python.layers import dnn
from easy_rec.python.layers import seq_input_layer
from easy_rec.python.utils import conditional

if tf.__version__ >= '2.0':
tf = tf.compat.v1
Expand All @@ -18,7 +20,8 @@ def __init__(self,
ev_params=None,
embedding_regularizer=None,
kernel_regularizer=None,
is_training=False):
is_training=False,
is_predicting=False):
self._seq_feature_groups_config = []
for x in feature_groups_config:
for y in x.sequence_features:
Expand All @@ -33,6 +36,7 @@ def __init__(self,
self._embedding_regularizer = embedding_regularizer
self._kernel_regularizer = kernel_regularizer
self._is_training = is_training
self._is_predicting = is_predicting

def negative_sampler_target_attention(self,
dnn_config,
Expand Down Expand Up @@ -199,9 +203,14 @@ def __call__(self,
need_key_feature = seq_att_map_config.need_key_feature
allow_key_transform = seq_att_map_config.allow_key_transform
transform_dnn = seq_att_map_config.transform_dnn
seq_features = self._seq_input_layer(features, group_name,
feature_name_to_output_tensors,
allow_key_search, scope_name)

place_on_cpu = os.getenv('place_embedding_on_cpu')
place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
with conditional(self._is_predicting and place_on_cpu,
ops.device('/CPU:0')):
seq_features = self._seq_input_layer(features, group_name,
feature_name_to_output_tensors,
allow_key_search, scope_name)

# apply regularization for sequence feature key in seq_input_layer.

Expand Down
4 changes: 3 additions & 1 deletion easy_rec/python/model/easy_rec_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self,
self._base_model_config = model_config
self._model_config = model_config
self._is_training = is_training
self._is_predicting = labels is None
self._feature_dict = features

# embedding variable parameters
Expand Down Expand Up @@ -97,7 +98,8 @@ def build_input_layer(self, model_config, feature_configs):
kernel_regularizer=self._l2_reg,
variational_dropout_config=model_config.variational_dropout
if model_config.HasField('variational_dropout') else None,
is_training=self._is_training)
is_training=self._is_training,
is_predicting=self._is_predicting)

@abstractmethod
def build_predict_graph(self):
Expand Down
6 changes: 6 additions & 0 deletions easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,12 @@ def test_fm(self):
'samples/model_config/fm_on_taobao.config', self._test_dir)
self.assertTrue(self._success)

def test_place_embed_on_cpu(self):
  """Run a single train/eval pass with embeddings forced onto CPU.

  The `place_embedding_on_cpu` environment variable is read (via
  `os.getenv`) by the input/embedding layers to decide whether to wrap
  embedding construction in a `/CPU:0` device scope.
  """
  # Set the flag for this test only; restore afterwards so the setting
  # does not leak into other test cases running in the same process.
  os.environ['place_embedding_on_cpu'] = 'True'
  try:
    self._success = test_utils.test_single_train_eval(
        'samples/model_config/fm_on_taobao.config', self._test_dir)
    self.assertTrue(self._success)
  finally:
    del os.environ['place_embedding_on_cpu']

def test_din(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/din_on_taobao.config', self._test_dir)
Expand Down
15 changes: 15 additions & 0 deletions easy_rec/python/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class conditional(object):
  """Delegate to a wrapped context manager only when a condition holds.

  When ``condition`` is truthy, ``with conditional(cond, cm):`` behaves
  exactly like ``with cm:``. When it is falsy, the wrapped manager is
  never entered or exited: the ``with`` target is None and exceptions
  raised in the body propagate normally.
  """

  def __init__(self, condition, contextmanager):
    # condition: truthy/falsy switch; contextmanager: the manager to
    # (possibly) delegate to.
    self.condition = condition
    self.contextmanager = contextmanager

  def __enter__(self):
    """Enter the wrapped manager iff the condition is truthy."""
    if not self.condition:
      return None
    return self.contextmanager.__enter__()

  def __exit__(self, *args):
    """Exit the wrapped manager iff the condition is truthy."""
    if not self.condition:
      return None
    return self.contextmanager.__exit__(*args)
10 changes: 7 additions & 3 deletions easy_rec/python/utils/meta_graph_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model.loader_impl import SavedModelLoader

from easy_rec.python.utils import conditional
from easy_rec.python.utils import constant
from easy_rec.python.utils import embedding_utils
from easy_rec.python.utils import proto_util
Expand Down Expand Up @@ -400,9 +401,12 @@ def add_lookup_op(self, lookup_input_indices, lookup_input_values,
def add_oss_lookup_op(self, lookup_input_indices, lookup_input_values,
lookup_input_shapes, lookup_input_weights):
logging.info('add custom lookup operation to lookup embeddings from oss')
for i in range(len(lookup_input_values)):
if lookup_input_values[i].dtype == tf.int32:
lookup_input_values[i] = tf.to_int64(lookup_input_values[i])
place_on_cpu = os.getenv('place_embedding_on_cpu')
place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
with conditional(place_on_cpu, ops.device('/CPU:0')):
for i in range(len(lookup_input_values)):
if lookup_input_values[i].dtype == tf.int32:
lookup_input_values[i] = tf.to_int64(lookup_input_values[i])
# N = len(lookup_input_indices)
# self._lookup_outs = [ None for _ in range(N) ]
# for i in range(N):
Expand Down
7 changes: 6 additions & 1 deletion pai_jobs/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@
tf.app.flags.DEFINE_string('oss_embedding_version', '', 'oss embedding version')

tf.app.flags.DEFINE_bool('verbose', False, 'print more debug information')
tf.app.flags.DEFINE_bool('place_embedding_on_cpu', False,
'whether to place embedding variables on cpu')

# for automl hyper parameter tuning
tf.app.flags.DEFINE_string('model_dir', None, 'model directory')
Expand Down Expand Up @@ -434,7 +436,10 @@ def main(argv):
elif FLAGS.cmd == 'export':
check_param('export_dir')
check_param('config')

if FLAGS.place_embedding_on_cpu:
os.environ['place_embedding_on_cpu'] = 'True'
else:
os.environ['place_embedding_on_cpu'] = 'False'
redis_params = {}
if FLAGS.redis_url:
redis_params['redis_url'] = FLAGS.redis_url
Expand Down

0 comments on commit 1b4e8ec

Please sign in to comment.