diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py
index 6d2e693a0..3e2785018 100644
--- a/easy_rec/python/input/input.py
+++ b/easy_rec/python/input/input.py
@@ -1,17 +1,20 @@
 # -*- encoding:utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import logging
+import os
 from abc import abstractmethod
 from collections import OrderedDict
 
 import six
 import tensorflow as tf
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import gfile
 
 from easy_rec.python.core import sampler as sampler_lib
 from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import conditional
 from easy_rec.python.utils import config_util
 from easy_rec.python.utils import constant
 from easy_rec.python.utils.check_utils import check_split
@@ -1015,11 +1018,15 @@ def _input_fn(mode=None, params=None, config=None):
         dataset = self._build(mode, params)
         return dataset
       elif mode is None:  # serving_input_receiver_fn for export SavedModel
+        place_on_cpu = os.getenv('place_embedding_on_cpu')
+        place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
         if export_config.multi_placeholder:
-          inputs, features = self.create_multi_placeholders(export_config)
+          with conditional(place_on_cpu, ops.device('/CPU:0')):
+            inputs, features = self.create_multi_placeholders(export_config)
           return tf.estimator.export.ServingInputReceiver(features, inputs)
         else:
-          inputs, features = self.create_placeholders(export_config)
+          with conditional(place_on_cpu, ops.device('/CPU:0')):
+            inputs, features = self.create_placeholders(export_config)
           print('built feature placeholders.\nfeatures: {}'.format(
               features.keys()))
           return tf.estimator.export.ServingInputReceiver(features, inputs)
diff --git a/easy_rec/python/layers/input_layer.py b/easy_rec/python/layers/input_layer.py
index 731f47c82..92ed236db 100644
--- a/easy_rec/python/layers/input_layer.py
+++ b/easy_rec/python/layers/input_layer.py
@@ -1,8 +1,10 @@
 # -*- encoding: utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 from collections import OrderedDict
 
 import tensorflow as tf
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import variable_scope
 
@@ -14,6 +16,7 @@
 from easy_rec.python.layers import variational_dropout_layer
 from easy_rec.python.layers.common_layers import text_cnn
 from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
+from easy_rec.python.utils import conditional
 from easy_rec.python.utils import shape_utils
 from easy_rec.python.compat.feature_column.feature_column_v2 import EmbeddingColumn  # NOQA
 
@@ -36,13 +39,14 @@
                ev_params=None,
                embedding_regularizer=None,
                kernel_regularizer=None,
-               is_training=False):
+               is_training=False,
+               is_predicting=False):
     self._feature_groups = {
         x.group_name: FeatureGroup(x) for x in feature_groups_config
     }
     self.sequence_feature_layer = sequence_feature_layer.SequenceFeatureLayer(
         feature_configs, feature_groups_config, ev_params,
-        embedding_regularizer, kernel_regularizer, is_training)
+        embedding_regularizer, kernel_regularizer, is_training, is_predicting)
     self._seq_feature_groups_config = []
     for x in feature_groups_config:
       for y in x.sequence_features:
@@ -62,6 +66,7 @@
     self._embedding_regularizer = embedding_regularizer
     self._kernel_regularizer = kernel_regularizer
     self._is_training = is_training
+    self._is_predicting = is_predicting
     self._variational_dropout_config = variational_dropout_config
 
   def has_group(self, group_name):
@@ -92,7 +97,11 @@ def __call__(self, features, group_name, is_combine=True, is_dict=False):
     feature_name_to_output_tensors = {}
     negative_sampler = self._feature_groups[group_name]._config.negative_sampler
     if is_combine:
-      concat_features, group_features = self.single_call_input_layer(
+      place_on_cpu = os.getenv('place_embedding_on_cpu')
+      place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
+      with conditional(self._is_predicting and place_on_cpu,
+                       ops.device('/CPU:0')):
+        concat_features, group_features = self.single_call_input_layer(
           features, group_name, feature_name_to_output_tensors)
       if group_name in self._group_name_to_seq_features:
         # for target attention
diff --git a/easy_rec/python/layers/sequence_feature_layer.py b/easy_rec/python/layers/sequence_feature_layer.py
index a75eedcc9..bb948f785 100644
--- a/easy_rec/python/layers/sequence_feature_layer.py
+++ b/easy_rec/python/layers/sequence_feature_layer.py
@@ -1,10 +1,12 @@
 import logging
+import os
 
 import tensorflow as tf
-
+from tensorflow.python.framework import ops
 from easy_rec.python.compat import regularizers
 from easy_rec.python.layers import dnn
 from easy_rec.python.layers import seq_input_layer
+from easy_rec.python.utils import conditional
 
 if tf.__version__ >= '2.0':
   tf = tf.compat.v1
@@ -18,7 +20,8 @@
                ev_params=None,
                embedding_regularizer=None,
                kernel_regularizer=None,
-               is_training=False):
+               is_training=False,
+               is_predicting=False):
     self._seq_feature_groups_config = []
     for x in feature_groups_config:
       for y in x.sequence_features:
@@ -33,6 +36,7 @@
     self._embedding_regularizer = embedding_regularizer
     self._kernel_regularizer = kernel_regularizer
     self._is_training = is_training
+    self._is_predicting = is_predicting
 
   def negative_sampler_target_attention(self,
                                         dnn_config,
@@ -199,9 +203,14 @@
       need_key_feature = seq_att_map_config.need_key_feature
       allow_key_transform = seq_att_map_config.allow_key_transform
       transform_dnn = seq_att_map_config.transform_dnn
-      seq_features = self._seq_input_layer(features, group_name,
-                                           feature_name_to_output_tensors,
-                                           allow_key_search, scope_name)
+
+      place_on_cpu = os.getenv('place_embedding_on_cpu')
+      place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
+      with conditional(self._is_predicting and place_on_cpu,
+                       ops.device('/CPU:0')):
+        seq_features = self._seq_input_layer(features, group_name,
+                                             feature_name_to_output_tensors,
+                                             allow_key_search, scope_name)
       # apply regularization for sequence feature key in seq_input_layer.
diff --git a/easy_rec/python/model/easy_rec_model.py b/easy_rec/python/model/easy_rec_model.py
index 7416c5cc4..325cdc257 100644
--- a/easy_rec/python/model/easy_rec_model.py
+++ b/easy_rec/python/model/easy_rec_model.py
@@ -36,6 +36,7 @@ def __init__(self,
     self._base_model_config = model_config
     self._model_config = model_config
     self._is_training = is_training
+    self._is_predicting = labels is None
    self._feature_dict = features
 
     # embedding variable parameters
@@ -97,7 +98,8 @@ def build_input_layer(self, model_config, feature_configs):
         kernel_regularizer=self._l2_reg,
         variational_dropout_config=model_config.variational_dropout
         if model_config.HasField('variational_dropout') else None,
-        is_training=self._is_training)
+        is_training=self._is_training,
+        is_predicting=self._is_predicting)
 
   @abstractmethod
   def build_predict_graph(self):
diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py
index 698715cdd..777dfd21a 100644
--- a/easy_rec/python/test/train_eval_test.py
+++ b/easy_rec/python/test/train_eval_test.py
@@ -315,6 +315,12 @@ def test_fm(self):
         'samples/model_config/fm_on_taobao.config', self._test_dir)
     self.assertTrue(self._success)
 
+  def test_place_embed_on_cpu(self):
+    os.environ['place_embedding_on_cpu'] = 'True'
+    self._success = test_utils.test_single_train_eval(
+        'samples/model_config/fm_on_taobao.config', self._test_dir)
+    self.assertTrue(self._success)
+
   def test_din(self):
     self._success = test_utils.test_single_train_eval(
         'samples/model_config/din_on_taobao.config', self._test_dir)
diff --git a/easy_rec/python/utils/__init__.py b/easy_rec/python/utils/__init__.py
index e69de29bb..09dc89476 100644
--- a/easy_rec/python/utils/__init__.py
+++ b/easy_rec/python/utils/__init__.py
@@ -0,0 +1,15 @@
+class conditional(object):
+  """Wrap another context manager and enter it only if condition is true."""
+
+  def __init__(self, condition, contextmanager):
+    self.condition = condition
+    self.contextmanager = contextmanager
+
+  def __enter__(self):
+    """Conditionally enter a context manager."""
+    if self.condition:
+      return self.contextmanager.__enter__()
+
+  def __exit__(self, *args):
+    if self.condition:
+      return self.contextmanager.__exit__(*args)
diff --git a/easy_rec/python/utils/meta_graph_editor.py b/easy_rec/python/utils/meta_graph_editor.py
index 34362fcd1..5906213bd 100644
--- a/easy_rec/python/utils/meta_graph_editor.py
+++ b/easy_rec/python/utils/meta_graph_editor.py
@@ -11,6 +11,7 @@
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.saved_model.loader_impl import SavedModelLoader
 
+from easy_rec.python.utils import conditional
 from easy_rec.python.utils import constant
 from easy_rec.python.utils import embedding_utils
 from easy_rec.python.utils import proto_util
@@ -400,9 +401,12 @@ def add_lookup_op(self, lookup_input_indices, lookup_input_values,
   def add_oss_lookup_op(self, lookup_input_indices, lookup_input_values,
                         lookup_input_shapes, lookup_input_weights):
     logging.info('add custom lookup operation to lookup embeddings from oss')
-    for i in range(len(lookup_input_values)):
-      if lookup_input_values[i].dtype == tf.int32:
-        lookup_input_values[i] = tf.to_int64(lookup_input_values[i])
+    place_on_cpu = os.getenv('place_embedding_on_cpu')
+    place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
+    with conditional(place_on_cpu, ops.device('/CPU:0')):
+      for i in range(len(lookup_input_values)):
+        if lookup_input_values[i].dtype == tf.int32:
+          lookup_input_values[i] = tf.to_int64(lookup_input_values[i])
     # N = len(lookup_input_indices)
     # self._lookup_outs = [ None for _ in range(N) ]
     # for i in range(N):
diff --git a/pai_jobs/run.py b/pai_jobs/run.py
index 41c61ad31..986731d36 100644
--- a/pai_jobs/run.py
+++ b/pai_jobs/run.py
@@ -166,6 +166,8 @@
 tf.app.flags.DEFINE_string('oss_embedding_version', '', 'oss embedding version')
 tf.app.flags.DEFINE_bool('verbose', False, 'print more debug information')
 
+tf.app.flags.DEFINE_bool('place_embedding_on_cpu', False,
+                         'whether to place embedding variables on cpu')
 # for automl hyper parameter tuning
 tf.app.flags.DEFINE_string('model_dir', None, 'model directory')
 
@@ -434,7 +436,10 @@
   elif FLAGS.cmd == 'export':
     check_param('export_dir')
     check_param('config')
-
+    if FLAGS.place_embedding_on_cpu:
+      os.environ['place_embedding_on_cpu'] = 'True'
+    else:
+      os.environ['place_embedding_on_cpu'] = 'False'
     redis_params = {}
     if FLAGS.redis_url:
       redis_params['redis_url'] = FLAGS.redis_url