From 2e0cd36359fd780cafe80079e9e11b76b251ca5b Mon Sep 17 00:00:00 2001 From: chengmengli06 Date: Fri, 23 Feb 2024 15:38:23 +0800 Subject: [PATCH 1/9] add support for cpu kv in sok --- .../compat/feature_column/feature_column.py | 49 +++++++++++----- easy_rec/python/compat/sok_optimizer.py | 57 ++++++++++--------- easy_rec/python/input/input.py | 2 +- easy_rec/python/protos/easy_rec_model.proto | 4 +- 4 files changed, 69 insertions(+), 43 deletions(-) diff --git a/easy_rec/python/compat/feature_column/feature_column.py b/easy_rec/python/compat/feature_column/feature_column.py index 5428690ed..f53b90cb3 100644 --- a/easy_rec/python/compat/feature_column/feature_column.py +++ b/easy_rec/python/compat/feature_column/feature_column.py @@ -435,28 +435,49 @@ def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring if shared_name in shared_weights: embedding_weights = shared_weights[shared_name] else: - with ops.device('/gpu:0'): - if column.ev_params is not None: - assert dynamic_variable is not None, 'sok is not installed' + if column.ev_params is not None: + assert dynamic_variable is not None, 'sok is not installed' + with ops.device('/cpu:0'): embedding_weights = dynamic_variable.DynamicVariable( name='embedding_weights', dimension=column.dimension, initializer='random {"stddev":0.0025}', # column.initializer, - var_type=None - if not column.ev_params.use_cache else 'hybrid', + var_type='dram', trainable=column.trainable and trainable, dtype=dtypes.float32, init_capacity=column.ev_params.init_capacity, max_capacity=column.ev_params.max_capacity) - else: - embedding_weights = variable_scope.get_variable( - name='embedding_weights', - shape=embedding_shape, - dtype=dtypes.float32, - initializer=column.initializer, - trainable=column.trainable and trainable, - partitioner=None, - collections=weight_collections) + else: + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=column.initializer, + trainable=column.trainable and trainable, + partitioner=None, + collections=weight_collections) + # with ops.device('/gpu:0'): + # if column.ev_params is not None: + # assert dynamic_variable is not None, 'sok is not installed' + # embedding_weights = dynamic_variable.DynamicVariable( + # name='embedding_weights', + # dimension=column.dimension, + # initializer='random {"stddev":0.0025}', # column.initializer, + # var_type=None + # if not column.ev_params.use_cache else 'hybrid', + # trainable=column.trainable and trainable, + # dtype=dtypes.float32, + # init_capacity=column.ev_params.init_capacity, + # max_capacity=column.ev_params.max_capacity) + # else: + # embedding_weights = variable_scope.get_variable( + # name='embedding_weights', + # shape=embedding_shape, + # dtype=dtypes.float32, + # initializer=column.initializer, + # trainable=column.trainable and trainable, + # partitioner=None, + # collections=weight_collections) shared_weights[shared_name] = embedding_weights else: with ops.device('/gpu:0'): diff --git a/easy_rec/python/compat/sok_optimizer.py b/easy_rec/python/compat/sok_optimizer.py index 429c25963..4344e4f75 100644 --- a/easy_rec/python/compat/sok_optimizer.py +++ b/easy_rec/python/compat/sok_optimizer.py @@ -135,22 +135,23 @@ def _create_slots_dynamic(self, var): for slot_name in self._initial_vals: if key not in self._optimizer._slots[slot_name]: if var.backend_type == 'hbm': - slot = DynamicVariable( - dimension=var.dimension, - initializer=self._initial_vals[slot_name], - 
name='DynamicSlot', - trainable=False, - ) + with ops.colocate_with(var): + slot = DynamicVariable( + dimension=var.dimension, + initializer=self._initial_vals[slot_name], + name='DynamicSlot', + trainable=False) else: tmp_config = var.config_dict # tmp_initializer = var.initializer_str - slot = DynamicVariable( - dimension=var.dimension, - initializer=self._initial_vals[slot_name], - var_type=var.backend_type, - name='DynamicSlot', - trainable=False, - **tmp_config) + with ops.colocate_with(var): + slot = DynamicVariable( + dimension=var.dimension, + initializer=self._initial_vals[slot_name], + var_type=var.backend_type, + name='DynamicSlot', + trainable=False, + **tmp_config) self._optimizer._slots[slot_name][key] = slot @@ -227,23 +228,27 @@ def apply_sparse_gradients(self, grads_and_vars, global_step=None, name=None): for slot_name in self._initial_vals: if key not in self._optimizer._slots[slot_name]: tmp_slot_var_name = v._dummy_handle.op.name + '/' + self._optimizer._name + # import pdb + # pdb.set_trace() if v.backend_type == 'hbm': - slot = DynamicVariable( - dimension=v.dimension, - initializer=self._initial_vals[slot_name], - name=tmp_slot_var_name, - trainable=False, - ) + with ops.colocate_with(v): + slot = DynamicVariable( + dimension=v.dimension, + initializer=self._initial_vals[slot_name], + name=tmp_slot_var_name, + trainable=False, + ) else: tmp_config = v.config_dict # tmp_initializer = v.initializer_str - slot = DynamicVariable( - dimension=v.dimension, - initializer=self._initial_vals[slot_name], - var_type=v.backend_type, - name=tmp_slot_var_name, - trainable=False, - **tmp_config) + with ops.colocate_with(v): + slot = DynamicVariable( + dimension=v.dimension, + initializer=self._initial_vals[slot_name], + var_type=v.backend_type, + name=tmp_slot_var_name, + trainable=False, + **tmp_config) self._optimizer._slots[slot_name][key] = slot else: diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py index 5e3de03a2..a932448ad 100644 --- a/easy_rec/python/input/input.py +++ b/easy_rec/python/input/input.py @@ -461,7 +461,7 @@ def _parse_tag_feature(self, fc, parsed_dict, field_dict): indices, tmp_ks, parsed_dict[feature_name].dense_shape) parsed_dict[feature_name + '_w'] = tf.sparse.SparseTensor( indices, tmp_vs, parsed_dict[feature_name].dense_shape) - if not fc.HasField('hash_bucket_size'): + if not fc.HasField('hash_bucket_size') and fc.num_buckets > 0: check_list = [ tf.py_func( check_string_to_number, diff --git a/easy_rec/python/protos/easy_rec_model.proto b/easy_rec/python/protos/easy_rec_model.proto index 1cbd55440..8b74bb492 100644 --- a/easy_rec/python/protos/easy_rec_model.proto +++ b/easy_rec/python/protos/easy_rec_model.proto @@ -27,7 +27,7 @@ import "easy_rec/python/protos/variational_dropout.proto"; import "easy_rec/python/protos/multi_tower_recall.proto"; import "easy_rec/python/protos/tower.proto"; import "easy_rec/python/protos/pdn.proto"; -// import "easy_rec/python/protos/ppnet.proto"; +import "easy_rec/python/protos/ppnet.proto"; // for input performance test message DummyModel { @@ -92,7 +92,7 @@ message EasyRecModel { SimpleMultiTask simple_multi_task = 304; PLE ple = 305; - // PPNetV3 ppnet = 306; + PPNetV3 ppnet = 306; RocketLaunching rocket_launching = 401; } From 89989c2f8473fce0ac7dd50ed228bdc9b8d8025e Mon Sep 17 00:00:00 2001 From: chengmengli06 Date: Sun, 25 Feb 2024 19:02:09 +0800 Subject: [PATCH 2/9] add embedding_on_cpu param --- .../compat/feature_column/feature_column.py | 64 ++++++++----------- 
easy_rec/python/input/input.py | 2 +- easy_rec/python/model/easy_rec_estimator.py | 1 + easy_rec/python/protos/feature_config.proto | 1 + easy_rec/python/utils/constant.py | 3 + easy_rec/python/utils/embedding_utils.py | 6 ++ 6 files changed, 39 insertions(+), 38 deletions(-) diff --git a/easy_rec/python/compat/feature_column/feature_column.py b/easy_rec/python/compat/feature_column/feature_column.py index f53b90cb3..dbd42842d 100644 --- a/easy_rec/python/compat/feature_column/feature_column.py +++ b/easy_rec/python/compat/feature_column/feature_column.py @@ -399,6 +399,18 @@ def _get_logits(): # pylint: disable=missing-docstring def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring assert hvd is not None, 'horovod is not installed' builder = _LazyBuilder(features) + + if embedding_utils.embedding_on_cpu(): + embedding_device = '/cpu:0' + else: + embedding_device = '/gpu:0' + + def _get_var_type(column): + if column.ev_params.use_cache: + return 'hybrid' + else: + return None + output_tensors = [] ordered_columns = [] @@ -435,59 +447,37 @@ def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring if shared_name in shared_weights: embedding_weights = shared_weights[shared_name] else: - if column.ev_params is not None: - assert dynamic_variable is not None, 'sok is not installed' - with ops.device('/cpu:0'): + with ops.device(embedding_device): + if column.ev_params is not None: + assert dynamic_variable is not None, 'sok is not installed' embedding_weights = dynamic_variable.DynamicVariable( name='embedding_weights', dimension=column.dimension, initializer='random {"stddev":0.0025}', # column.initializer, - var_type='dram', + var_type=_get_var_type(column), trainable=column.trainable and trainable, dtype=dtypes.float32, init_capacity=column.ev_params.init_capacity, max_capacity=column.ev_params.max_capacity) - else: - embedding_weights = variable_scope.get_variable( - name='embedding_weights', - shape=embedding_shape, - dtype=dtypes.float32, - initializer=column.initializer, - trainable=column.trainable and trainable, - partitioner=None, - collections=weight_collections) - # with ops.device('/gpu:0'): - # if column.ev_params is not None: - # assert dynamic_variable is not None, 'sok is not installed' - # embedding_weights = dynamic_variable.DynamicVariable( - # name='embedding_weights', - # dimension=column.dimension, - # initializer='random {"stddev":0.0025}', # column.initializer, - # var_type=None - # if not column.ev_params.use_cache else 'hybrid', - # trainable=column.trainable and trainable, - # dtype=dtypes.float32, - # init_capacity=column.ev_params.init_capacity, - # max_capacity=column.ev_params.max_capacity) - # else: - # embedding_weights = variable_scope.get_variable( - # name='embedding_weights', - # shape=embedding_shape, - # dtype=dtypes.float32, - # initializer=column.initializer, - # trainable=column.trainable and trainable, - # partitioner=None, - # collections=weight_collections) + else: + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=column.initializer, + trainable=column.trainable and trainable, + partitioner=None, + collections=weight_collections) shared_weights[shared_name] = embedding_weights else: - with ops.device('/gpu:0'): + with ops.device(embedding_device): if column.ev_params is not None: assert dynamic_variable is not None, 'sok is not installed' embedding_weights = dynamic_variable.DynamicVariable( name='embedding_weights', 
dimension=column.dimension, initializer='random {"stddev":0.0025}', # column.initializer, - var_type=None if not column.ev_params.use_cache else 'hybrid', + var_type=_get_var_type(column), trainable=column.trainable and trainable, dtype=dtypes.float32, init_capacity=column.ev_params.init_capacity, diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py index a932448ad..d94b1de13 100644 --- a/easy_rec/python/input/input.py +++ b/easy_rec/python/input/input.py @@ -1039,7 +1039,7 @@ def _input_fn(mode=None, params=None, config=None): dataset = self._build(mode, params) return dataset elif mode is None: # serving_input_receiver_fn for export SavedModel - place_on_cpu = os.getenv('place_embedding_on_cpu') + place_on_cpu = os.getenv(constant.EmbeddingOnCPU) place_on_cpu = eval(place_on_cpu) if place_on_cpu else False if export_config.multi_placeholder: with conditional(place_on_cpu, ops.device('/CPU:0')): diff --git a/easy_rec/python/model/easy_rec_estimator.py b/easy_rec/python/model/easy_rec_estimator.py index be22ec928..2669ea771 100644 --- a/easy_rec/python/model/easy_rec_estimator.py +++ b/easy_rec/python/model/easy_rec_estimator.py @@ -657,6 +657,7 @@ def _export_model_fn(self, features, labels, run_config, params): def _model_fn(self, features, labels, mode, config, params): os.environ['tf.estimator.mode'] = mode os.environ['tf.estimator.ModeKeys.TRAIN'] = tf.estimator.ModeKeys.TRAIN + os.environ['place_embedding_on_cpu'] = 'True' if self._pipeline_config.fg_json_path: EasyRecEstimator._write_rtp_fg_config_to_col( fg_config_path=self._pipeline_config.fg_json_path) diff --git a/easy_rec/python/protos/feature_config.proto b/easy_rec/python/protos/feature_config.proto index fb5375822..6b4cbed28 100644 --- a/easy_rec/python/protos/feature_config.proto +++ b/easy_rec/python/protos/feature_config.proto @@ -145,6 +145,7 @@ message FeatureConfig { message FeatureConfigV2 { repeated FeatureConfig features = 1 ; + optional bool embedding_on_cpu = 2 [default=false]; } message FeatureGroupConfig { diff --git a/easy_rec/python/utils/constant.py b/easy_rec/python/utils/constant.py index 7b7818681..50d071b83 100644 --- a/easy_rec/python/utils/constant.py +++ b/easy_rec/python/utils/constant.py @@ -23,6 +23,9 @@ # shard embedding var_name collection EmbeddingParallel = 'EmbeddingParallel' +# environ variable to force embedding placement on cpu +EmbeddingOnCPU = 'place_embedding_on_cpu' + def enable_avx_str_split(): os.environ[ENABLE_AVX_STR_SPLIT] = '1' diff --git a/easy_rec/python/utils/embedding_utils.py b/easy_rec/python/utils/embedding_utils.py index 5d171b4e4..960513801 100644 --- a/easy_rec/python/utils/embedding_utils.py +++ b/easy_rec/python/utils/embedding_utils.py @@ -65,3 +65,9 @@ def is_embedding_parallel(): def sort_col_by_name(): return constant.SORT_COL_BY_NAME in os.environ + + +def embedding_on_cpu(): + place_on_cpu = os.getenv(constant.EmbeddingOnCPU) + place_on_cpu = eval(place_on_cpu) if place_on_cpu else False + return place_on_cpu From 3cf4058d3be70e82bacda4dccf5e79db81848301 Mon Sep 17 00:00:00 2001 From: chengmengli06 Date: Mon, 26 Feb 2024 09:43:44 +0800 Subject: [PATCH 3/9] fix bug --- easy_rec/python/model/easy_rec_estimator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/easy_rec/python/model/easy_rec_estimator.py b/easy_rec/python/model/easy_rec_estimator.py index 2669ea771..c40260b18 100644 --- a/easy_rec/python/model/easy_rec_estimator.py +++ b/easy_rec/python/model/easy_rec_estimator.py @@ -657,7 +657,8 @@ def 
_export_model_fn(self, features, labels, run_config, params): def _model_fn(self, features, labels, mode, config, params): os.environ['tf.estimator.mode'] = mode os.environ['tf.estimator.ModeKeys.TRAIN'] = tf.estimator.ModeKeys.TRAIN - os.environ['place_embedding_on_cpu'] = 'True' + if self._pipeline_config.feature_config.embedding_on_cpu: + os.environ['place_embedding_on_cpu'] = 'True' if self._pipeline_config.fg_json_path: EasyRecEstimator._write_rtp_fg_config_to_col( fg_config_path=self._pipeline_config.fg_json_path) From 73d3484865a8cf15c47bf2ae85a6fb61a41c2d3b Mon Sep 17 00:00:00 2001 From: chengmengli06 Date: Mon, 26 Feb 2024 12:07:32 +0800 Subject: [PATCH 4/9] fix embedding parallel lookup --- .../compat/feature_column/feature_column.py | 41 ++++++++++++++++--- .../multi_tower_on_taobao_sok.config | 7 ++-- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/easy_rec/python/compat/feature_column/feature_column.py b/easy_rec/python/compat/feature_column/feature_column.py index dbd42842d..af61d5d64 100644 --- a/easy_rec/python/compat/feature_column/feature_column.py +++ b/easy_rec/python/compat/feature_column/feature_column.py @@ -248,8 +248,13 @@ def embedding_parallel_lookup(embedding, lookup_indices, output_ids, is_training, - output_tensors=None): + output_tensors=None, + batch_size=None): N = len(output_ids) + if batch_size is None: + num_segments = None + else: + num_segments = N * batch_size # first concat all the ids and unique if isinstance(lookup_indices, dict) and 'sparse_fea' in lookup_indices.keys(): # all_uniq_ids, uniq_idx, segment_lens = features['sparse_fea'] @@ -318,7 +323,11 @@ def embedding_parallel_lookup(embedding, recv_embeddings = data_flow_ops.parallel_dynamic_stitch( original_part_ids, recv_embeddings, name='parallel_dynamic_stitch') embeddings = math_ops.sparse_segment_sum( - recv_embeddings, uniq_idx, segment_ids, name='sparse_segment_sum') + recv_embeddings, + uniq_idx, + segment_ids, + num_segments=num_segments, + name='sparse_segment_sum') else: if isinstance(embedding, dynamic_variable.DynamicVariable): recv_embeddings = embedding.sparse_read( @@ -326,7 +335,11 @@ def embedding_parallel_lookup(embedding, else: recv_embeddings = array_ops.gather(embedding, all_uniq_ids) embeddings = math_ops.sparse_segment_sum( - recv_embeddings, uniq_idx, segment_ids, name='sparse_segment_sum') + recv_embeddings, + uniq_idx, + segment_ids, + num_segments=num_segments, + name='sparse_segment_sum') embed_dim = embedding.get_shape()[-1] output_tensor = array_ops.reshape(embeddings, [N, -1, embed_dim]) @@ -337,7 +350,8 @@ def embedding_parallel_lookup(embedding, output_tensors[output_id] = array_ops.squeeze(output, axis=0) return array_ops.reshape( - array_ops.transpose(output_tensor, perm=[1, 0, 2]), [-1, N * embed_dim]) + array_ops.transpose(output_tensor, perm=[1, 0, 2]), + [batch_size, N * embed_dim]) def _internal_input_layer(features, @@ -426,6 +440,8 @@ def _get_var_type(column): shared_weights = {} dense_cnt = 0 + + batch_sizes = [] for column in feature_columns: ordered_columns.append(column) with variable_scope.variable_scope( @@ -536,13 +552,19 @@ def _get_var_type(column): output_tensors[dense_output_id] = features[ 'dense_fea'][:, fea_dim_s:fea_dim_e] fea_dim_s = fea_dim_e + batch_sizes.append(array_ops.shape(features['dense_fea'])[0]) else: for dense_output_id, dense_col in zip(dense_output_ids, dense_cols): output_tensors[dense_output_id] = features[dense_col.raw_name] + batch_sizes.append(array_ops.shape(output_tensors[dense_output_id])[0]) for 
tmp_embed_var in set(lookup_embeddings): ops.add_to_collection(constant.EmbeddingParallel, tmp_embed_var.name) + if len(batch_sizes) == 0: + batch_size = None + else: + batch_size = batch_sizes[0] # do embedding parallel lookup if len(lookup_output_ids) > 0: packed_input = ('sparse_fea' in features or 'ragged_ids' in features) @@ -551,8 +573,15 @@ def _get_var_type(column): assert uniq_embed_cnt == 1, 'only one uniq embed is support for packed inputs' outputs = embedding_parallel_lookup(lookup_embeddings[0], lookup_indices, lookup_output_ids, - is_training, output_tensors) + is_training, output_tensors, + batch_size) else: + if batch_size is None: + all_indices = [] + for lookup_indice in lookup_indices: + all_indices.append(lookup_indice.indices[-1:, 0]) + all_indices = array_ops.concat(all_indices, axis=0) + batch_size = math_ops.reduce_max(all_indices) + 1 # group lookup_embeddings grouped_inputs = {} for embedding, lookup_indice, output_id in zip(lookup_embeddings, @@ -572,7 +601,7 @@ def _get_var_type(column): output_ids = grouped_inputs[embedding]['output_id'] outputs = embedding_parallel_lookup(embedding, lookup_indices, output_ids, is_training, - output_tensors) + output_tensors, batch_size) for output_tensor, col in zip(output_tensors, feature_columns): if feature_name_to_output_tensors is not None: diff --git a/samples/model_config/multi_tower_on_taobao_sok.config b/samples/model_config/multi_tower_on_taobao_sok.config index ccf840b44..7dfe7c987 100644 --- a/samples/model_config/multi_tower_on_taobao_sok.config +++ b/samples/model_config/multi_tower_on_taobao_sok.config @@ -1,5 +1,5 @@ -train_input_path: "taobao_train_data_8192" -# train_input_path: "data/test/tb_data/taobao_train_data" +# train_input_path: "taobao_train_data_8192" +train_input_path: "data/test/tb_data/taobao_train_data" eval_input_path: "data/test/tb_data/taobao_test_data" model_dir: "experiments/multi_tower_taobao_ckpt" @@ -136,8 +136,7 @@ data_config { batch_size: 8192 num_epochs: 1000000 prefetch_size: 64 - # input_type: CSVInput - input_type: DummyInput + input_type: CSVInput } feature_config: { From 18109581b7b365a669a5117653fce5a4e9aa2fbe Mon Sep 17 00:00:00 2001 From: chengmengli06 Date: Mon, 26 Feb 2024 14:36:45 +0800 Subject: [PATCH 5/9] fix batch_size none bug --- easy_rec/python/compat/feature_column/feature_column.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/easy_rec/python/compat/feature_column/feature_column.py b/easy_rec/python/compat/feature_column/feature_column.py index af61d5d64..3af58f949 100644 --- a/easy_rec/python/compat/feature_column/feature_column.py +++ b/easy_rec/python/compat/feature_column/feature_column.py @@ -349,6 +349,8 @@ def embedding_parallel_lookup(embedding, for output, output_id in zip(outputs, output_ids): output_tensors[output_id] = array_ops.squeeze(output, axis=0) + if batch_size is None: + batch_size = -1 return array_ops.reshape( array_ops.transpose(output_tensor, perm=[1, 0, 2]), [batch_size, N * embed_dim]) From 93a25a64b822b4dc7c4c684626d40da29022301d Mon Sep 17 00:00:00 2001 From: chengmengli06 Date: Mon, 26 Feb 2024 14:38:09 +0800 Subject: [PATCH 6/9] remove test code --- easy_rec/python/protos/easy_rec_model.proto | 3 --- 1 file changed, 3 deletions(-) diff --git a/easy_rec/python/protos/easy_rec_model.proto b/easy_rec/python/protos/easy_rec_model.proto index 8b74bb492..b0f79fe0f 100644 --- a/easy_rec/python/protos/easy_rec_model.proto +++ b/easy_rec/python/protos/easy_rec_model.proto @@ -27,7 +27,6 @@ import 
"easy_rec/python/protos/variational_dropout.proto"; import "easy_rec/python/protos/multi_tower_recall.proto"; import "easy_rec/python/protos/tower.proto"; import "easy_rec/python/protos/pdn.proto"; -import "easy_rec/python/protos/ppnet.proto"; // for input performance test message DummyModel { @@ -92,8 +91,6 @@ message EasyRecModel { SimpleMultiTask simple_multi_task = 304; PLE ple = 305; - PPNetV3 ppnet = 306; - RocketLaunching rocket_launching = 401; } repeated SeqAttGroupConfig seq_att_groups = 7; From 295531593698885902b37b210668233cacdd15e8 Mon Sep 17 00:00:00 2001 From: chengmengli06 Date: Mon, 26 Feb 2024 15:01:31 +0800 Subject: [PATCH 7/9] add conditional placement of feature column on cpu --- easy_rec/python/compat/feature_column/feature_column.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/easy_rec/python/compat/feature_column/feature_column.py b/easy_rec/python/compat/feature_column/feature_column.py index 3af58f949..c7ce6199e 100644 --- a/easy_rec/python/compat/feature_column/feature_column.py +++ b/easy_rec/python/compat/feature_column/feature_column.py @@ -171,6 +171,7 @@ from tensorflow.python.util import nest from easy_rec.python.compat.feature_column import utils as fc_utils +from easy_rec.python.utils import conditional from easy_rec.python.utils import constant from easy_rec.python.utils import embedding_utils @@ -634,7 +635,7 @@ def _get_var_type(column): if embedding_utils.is_embedding_parallel(): return _get_logits_embedding_parallel() else: - with ops.device('/CPU:0'): + with conditional(embedding_utils.embedding_on_cpu(), '/cpu:0'): return _get_logits() From bb5e69eeb8ed9fb1f7b8ba1c6b9d143b5aae0fa0 Mon Sep 17 00:00:00 2001 From: chengmengli06 Date: Tue, 27 Feb 2024 08:55:00 +0800 Subject: [PATCH 8/9] fix type error --- easy_rec/python/compat/feature_column/feature_column.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/easy_rec/python/compat/feature_column/feature_column.py b/easy_rec/python/compat/feature_column/feature_column.py index c7ce6199e..73a568d9c 100644 --- a/easy_rec/python/compat/feature_column/feature_column.py +++ b/easy_rec/python/compat/feature_column/feature_column.py @@ -635,7 +635,8 @@ def _get_var_type(column): if embedding_utils.is_embedding_parallel(): return _get_logits_embedding_parallel() else: - with conditional(embedding_utils.embedding_on_cpu(), '/cpu:0'): + with conditional(embedding_utils.embedding_on_cpu(), + ops.device('/cpu:0')): return _get_logits() From d504cf39d75fbbd765417317fea9bad114d93e21 Mon Sep 17 00:00:00 2001 From: chengmengli06 Date: Tue, 27 Feb 2024 11:35:46 +0800 Subject: [PATCH 9/9] add comments and update docker --- docs/source/train.md | 4 ++-- easy_rec/python/compat/sok_optimizer.py | 2 -- easy_rec/python/protos/feature_config.proto | 2 ++ 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/train.md b/docs/source/train.md index dcbae3407..843955e81 100644 --- a/docs/source/train.md +++ b/docs/source/train.md @@ -194,9 +194,9 @@ pai -name easy_rec_ext -project algo_public ### 依赖 - 混合并行使用Horovod做底层的通信, 因此需要安装Horovod, 可以直接使用下面的镜像 -- mybigpai-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:sok-tf212-gpus-v4 +- mybigpai-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:sok-tf212-gpus-v5 ``` - sudo docker run --gpus=all --privileged -v /home/easyrec/:/home/easyrec/ -ti mybigpai-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:sok-tf212-gpus-v4 bash + sudo docker run --gpus=all --privileged -v /home/easyrec/:/home/easyrec/ -ti 
mybigpai-registry.cn-beijing.cr.aliyuncs.com/easyrec/easyrec:sok-tf212-gpus-v5 bash ``` ### 配置 diff --git a/easy_rec/python/compat/sok_optimizer.py b/easy_rec/python/compat/sok_optimizer.py index 4344e4f75..7f368a9a1 100644 --- a/easy_rec/python/compat/sok_optimizer.py +++ b/easy_rec/python/compat/sok_optimizer.py @@ -228,8 +228,6 @@ def apply_sparse_gradients(self, grads_and_vars, global_step=None, name=None): for slot_name in self._initial_vals: if key not in self._optimizer._slots[slot_name]: tmp_slot_var_name = v._dummy_handle.op.name + '/' + self._optimizer._name - # import pdb - # pdb.set_trace() if v.backend_type == 'hbm': with ops.colocate_with(v): slot = DynamicVariable( diff --git a/easy_rec/python/protos/feature_config.proto b/easy_rec/python/protos/feature_config.proto index 6b4cbed28..3dba84072 100644 --- a/easy_rec/python/protos/feature_config.proto +++ b/easy_rec/python/protos/feature_config.proto @@ -145,6 +145,8 @@ message FeatureConfig { message FeatureConfigV2 { repeated FeatureConfig features = 1 ; + // force place embedding lookup ops on cpu to improve + // training and inference efficiency. optional bool embedding_on_cpu = 2 [default=false]; }
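
Usage note (not part of the patches above): a minimal sketch of how the new `FeatureConfigV2.embedding_on_cpu` switch could be enabled in a pipeline config. Only the two field names shown in the proto diff above are used; the inner feature entries are left elided rather than invented.

```
feature_config: {
  # force embedding lookup ops onto CPU memory
  # (new optional field added by this patch series, default false)
  embedding_on_cpu: true
  features: {
    # ... existing FeatureConfig entries, unchanged ...
  }
}
```

Per the `easy_rec_estimator.py` and `embedding_utils.py` hunks above, setting this field makes the estimator export `place_embedding_on_cpu=True` into the environment (`constant.EmbeddingOnCPU`), and `embedding_utils.embedding_on_cpu()` reads it back to choose the placement device for embedding lookups in both training and the exported serving graph.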