[feature]: add cpu kv support for sok #455

Merged · 9 commits · Feb 28, 2024
68 changes: 56 additions & 12 deletions easy_rec/python/compat/feature_column/feature_column.py
@@ -171,6 +171,7 @@
from tensorflow.python.util import nest

from easy_rec.python.compat.feature_column import utils as fc_utils
from easy_rec.python.utils import conditional
from easy_rec.python.utils import constant
from easy_rec.python.utils import embedding_utils

@@ -248,8 +249,13 @@ def embedding_parallel_lookup(embedding,
lookup_indices,
output_ids,
is_training,
output_tensors=None):
output_tensors=None,
batch_size=None):
N = len(output_ids)
if batch_size is None:
num_segments = None
else:
num_segments = N * batch_size
# first concat all the ids and unique
if isinstance(lookup_indices, dict) and 'sparse_fea' in lookup_indices.keys():
# all_uniq_ids, uniq_idx, segment_lens = features['sparse_fea']
@@ -318,15 +324,23 @@ def embedding_parallel_lookup(embedding,
recv_embeddings = data_flow_ops.parallel_dynamic_stitch(
original_part_ids, recv_embeddings, name='parallel_dynamic_stitch')
embeddings = math_ops.sparse_segment_sum(
recv_embeddings, uniq_idx, segment_ids, name='sparse_segment_sum')
recv_embeddings,
uniq_idx,
segment_ids,
num_segments=num_segments,
name='sparse_segment_sum')
else:
if isinstance(embedding, dynamic_variable.DynamicVariable):
recv_embeddings = embedding.sparse_read(
all_uniq_ids, lookup_only=(not is_training))
else:
recv_embeddings = array_ops.gather(embedding, all_uniq_ids)
embeddings = math_ops.sparse_segment_sum(
recv_embeddings, uniq_idx, segment_ids, name='sparse_segment_sum')
recv_embeddings,
uniq_idx,
segment_ids,
num_segments=num_segments,
name='sparse_segment_sum')

embed_dim = embedding.get_shape()[-1]
output_tensor = array_ops.reshape(embeddings, [N, -1, embed_dim])
@@ -336,8 +350,11 @@ def embedding_parallel_lookup(embedding,
for output, output_id in zip(outputs, output_ids):
output_tensors[output_id] = array_ops.squeeze(output, axis=0)

if batch_size is None:
batch_size = -1
return array_ops.reshape(
array_ops.transpose(output_tensor, perm=[1, 0, 2]), [-1, N * embed_dim])
array_ops.transpose(output_tensor, perm=[1, 0, 2]),
[batch_size, N * embed_dim])

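A note on the new num_segments argument: without it, sparse_segment_sum sizes its output from the largest segment id it actually sees, so the batch dimension stays unknown at graph-build time and trailing samples with no values would be dropped. Passing N * batch_size pins the output shape. A toy sketch (not part of the diff, illustrative values only):

import tensorflow as tf

emb = tf.constant([[1., 1.], [2., 2.], [3., 3.]])  # embeddings of the unique ids
uniq_idx = tf.constant([0, 1, 1, 2])               # maps each value back to a unique id
segment_ids = tf.constant([0, 0, 1, 3])            # sample id of each value; sample 2 is empty

# Without num_segments the output has max(segment_ids) + 1 rows, so if the last
# samples of the batch had no values the batch dimension would silently shrink.
pooled = tf.sparse.segment_sum(emb, uniq_idx, segment_ids, num_segments=4)
# pooled has a fixed shape of (4, 2); empty samples become all-zero rows.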

def _internal_input_layer(features,
@@ -399,6 +416,18 @@ def _get_logits(): # pylint: disable=missing-docstring
def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring
assert hvd is not None, 'horovod is not installed'
builder = _LazyBuilder(features)

if embedding_utils.embedding_on_cpu():
embedding_device = '/cpu:0'
else:
embedding_device = '/gpu:0'

def _get_var_type(column):
if column.ev_params.use_cache:
return 'hybrid'
else:
return None

output_tensors = []
ordered_columns = []

@@ -414,6 +443,8 @@ def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring

shared_weights = {}
dense_cnt = 0

batch_sizes = []
for column in feature_columns:
ordered_columns.append(column)
with variable_scope.variable_scope(
@@ -435,15 +466,14 @@ def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring
if shared_name in shared_weights:
embedding_weights = shared_weights[shared_name]
else:
with ops.device('/gpu:0'):
with ops.device(embedding_device):
if column.ev_params is not None:
assert dynamic_variable is not None, 'sok is not installed'
embedding_weights = dynamic_variable.DynamicVariable(
name='embedding_weights',
dimension=column.dimension,
initializer='random {"stddev":0.0025}', # column.initializer,
var_type=None
if not column.ev_params.use_cache else 'hybrid',
var_type=_get_var_type(column),
trainable=column.trainable and trainable,
dtype=dtypes.float32,
init_capacity=column.ev_params.init_capacity,
@@ -459,14 +489,14 @@ def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring
collections=weight_collections)
shared_weights[shared_name] = embedding_weights
else:
with ops.device('/gpu:0'):
with ops.device(embedding_device):
if column.ev_params is not None:
assert dynamic_variable is not None, 'sok is not installed'
embedding_weights = dynamic_variable.DynamicVariable(
name='embedding_weights',
dimension=column.dimension,
initializer='random {"stddev":0.0025}', # column.initializer,
var_type=None if not column.ev_params.use_cache else 'hybrid',
var_type=_get_var_type(column),
trainable=column.trainable and trainable,
dtype=dtypes.float32,
init_capacity=column.ev_params.init_capacity,
@@ -525,13 +555,19 @@ def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring
output_tensors[dense_output_id] = features[
'dense_fea'][:, fea_dim_s:fea_dim_e]
fea_dim_s = fea_dim_e
batch_sizes.append(array_ops.shape(features['dense_fea'])[0])
else:
for dense_output_id, dense_col in zip(dense_output_ids, dense_cols):
output_tensors[dense_output_id] = features[dense_col.raw_name]
batch_sizes.append(array_ops.shape(output_tensors[dense_output_id])[0])

for tmp_embed_var in set(lookup_embeddings):
ops.add_to_collection(constant.EmbeddingParallel, tmp_embed_var.name)

if len(batch_sizes) == 0:
batch_size = None
else:
batch_size = batch_sizes[0]
# do embedding parallel lookup
if len(lookup_output_ids) > 0:
packed_input = ('sparse_fea' in features or 'ragged_ids' in features)
@@ -540,8 +576,15 @@ def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring
assert uniq_embed_cnt == 1, 'only one uniq embed is support for packed inputs'
outputs = embedding_parallel_lookup(lookup_embeddings[0],
lookup_indices, lookup_output_ids,
is_training, output_tensors)
is_training, output_tensors,
batch_size)
else:
if batch_size is None:
all_indices = []
for lookup_indice in lookup_indices:
all_indices.append(lookup_indice.indices[-1:, 0])
all_indices = array_ops.concat(all_indices, axis=0)
batch_size = math_ops.reduce_max(all_indices) + 1
# group lookup_embeddings
grouped_inputs = {}
for embedding, lookup_indice, output_id in zip(lookup_embeddings,
@@ -561,7 +604,7 @@ def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring
output_ids = grouped_inputs[embedding]['output_id']
outputs = embedding_parallel_lookup(embedding, lookup_indices,
output_ids, is_training,
output_tensors)
output_tensors, batch_size)

for output_tensor, col in zip(output_tensors, feature_columns):
if feature_name_to_output_tensors is not None:
@@ -592,7 +635,8 @@ def _get_logits_embedding_parallel(): # pylint: disable=missing-docstring
if embedding_utils.is_embedding_parallel():
return _get_logits_embedding_parallel()
else:
with ops.device('/CPU:0'):
with conditional(embedding_utils.embedding_on_cpu(),
ops.device('/cpu:0')):
return _get_logits()


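The batch-size fallback above (the reduce_max(all_indices) + 1 branch) leans on the fact that a SparseTensor's indices are sorted by row, so the first coordinate of the last index entry of each lookup feature is that feature's largest sample id; the max over all features plus one recovers the batch size when no dense feature is available. A hedged sketch with toy inputs (names are illustrative; it assumes the final sample has at least one value in at least one feature):

import tensorflow as tf

# Two sparse id features of the same batch; each index entry is [row, position].
fea_a = tf.sparse.SparseTensor(indices=[[0, 0], [2, 0]], values=[7, 9], dense_shape=[3, 5])
fea_b = tf.sparse.SparseTensor(indices=[[0, 0], [3, 1]], values=[4, 5], dense_shape=[4, 5])

# Last row id of each feature, then the max across features plus one.
last_rows = tf.concat([fea_a.indices[-1:, 0], fea_b.indices[-1:, 0]], axis=0)
batch_size = tf.reduce_max(last_rows) + 1  # == 4 here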
57 changes: 31 additions & 26 deletions easy_rec/python/compat/sok_optimizer.py
@@ -135,22 +135,23 @@ def _create_slots_dynamic(self, var):
for slot_name in self._initial_vals:
if key not in self._optimizer._slots[slot_name]:
if var.backend_type == 'hbm':
slot = DynamicVariable(
dimension=var.dimension,
initializer=self._initial_vals[slot_name],
name='DynamicSlot',
trainable=False,
)
with ops.colocate_with(var):
slot = DynamicVariable(
dimension=var.dimension,
initializer=self._initial_vals[slot_name],
name='DynamicSlot',
trainable=False)
else:
tmp_config = var.config_dict
# tmp_initializer = var.initializer_str
slot = DynamicVariable(
dimension=var.dimension,
initializer=self._initial_vals[slot_name],
var_type=var.backend_type,
name='DynamicSlot',
trainable=False,
**tmp_config)
with ops.colocate_with(var):
slot = DynamicVariable(
dimension=var.dimension,
initializer=self._initial_vals[slot_name],
var_type=var.backend_type,
name='DynamicSlot',
trainable=False,
**tmp_config)

self._optimizer._slots[slot_name][key] = slot

@@ -227,23 +228,27 @@ def apply_sparse_gradients(self, grads_and_vars, global_step=None, name=None):
for slot_name in self._initial_vals:
if key not in self._optimizer._slots[slot_name]:
tmp_slot_var_name = v._dummy_handle.op.name + '/' + self._optimizer._name
# import pdb
# pdb.set_trace()
if v.backend_type == 'hbm':
slot = DynamicVariable(
dimension=v.dimension,
initializer=self._initial_vals[slot_name],
name=tmp_slot_var_name,
trainable=False,
)
with ops.colocate_with(v):
slot = DynamicVariable(
dimension=v.dimension,
initializer=self._initial_vals[slot_name],
name=tmp_slot_var_name,
trainable=False,
)
else:
tmp_config = v.config_dict
# tmp_initializer = v.initializer_str
slot = DynamicVariable(
dimension=v.dimension,
initializer=self._initial_vals[slot_name],
var_type=v.backend_type,
name=tmp_slot_var_name,
trainable=False,
**tmp_config)
with ops.colocate_with(v):
slot = DynamicVariable(
dimension=v.dimension,
initializer=self._initial_vals[slot_name],
var_type=v.backend_type,
name=tmp_slot_var_name,
trainable=False,
**tmp_config)

self._optimizer._slots[slot_name][key] = slot
else:
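The only functional change in this file is wrapping slot creation in ops.colocate_with(var) / ops.colocate_with(v), so each optimizer slot is placed on the same device as the embedding it belongs to; that matters now that the embedding itself may live on the CPU. A minimal graph-mode sketch of the same pattern, independent of SOK:

import tensorflow as tf

tf1 = tf.compat.v1

with tf1.Graph().as_default():
  with tf1.device('/cpu:0'):
    embedding = tf1.get_variable('embedding_weights', shape=[100, 16])
  # The slot inherits the embedding's placement instead of the default device.
  with tf1.colocate_with(embedding):
    slot = tf1.get_variable('embedding_weights_adam_m', shape=[100, 16],
                            initializer=tf1.zeros_initializer(), trainable=False)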
4 changes: 2 additions & 2 deletions easy_rec/python/input/input.py
@@ -461,7 +461,7 @@ def _parse_tag_feature(self, fc, parsed_dict, field_dict):
indices, tmp_ks, parsed_dict[feature_name].dense_shape)
parsed_dict[feature_name + '_w'] = tf.sparse.SparseTensor(
indices, tmp_vs, parsed_dict[feature_name].dense_shape)
if not fc.HasField('hash_bucket_size'):
if not fc.HasField('hash_bucket_size') and fc.num_buckets > 0:
check_list = [
tf.py_func(
check_string_to_number,
@@ -1039,7 +1039,7 @@ def _input_fn(mode=None, params=None, config=None):
dataset = self._build(mode, params)
return dataset
elif mode is None: # serving_input_receiver_fn for export SavedModel
place_on_cpu = os.getenv('place_embedding_on_cpu')
place_on_cpu = os.getenv(constant.EmbeddingOnCPU)
place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
if export_config.multi_placeholder:
with conditional(place_on_cpu, ops.device('/CPU:0')):
2 changes: 2 additions & 0 deletions easy_rec/python/model/easy_rec_estimator.py
@@ -657,6 +657,8 @@ def _export_model_fn(self, features, labels, run_config, params):
def _model_fn(self, features, labels, mode, config, params):
os.environ['tf.estimator.mode'] = mode
os.environ['tf.estimator.ModeKeys.TRAIN'] = tf.estimator.ModeKeys.TRAIN
if self._pipeline_config.feature_config.embedding_on_cpu:
os.environ['place_embedding_on_cpu'] = 'True'
if self._pipeline_config.fg_json_path:
EasyRecEstimator._write_rtp_fg_config_to_col(
fg_config_path=self._pipeline_config.fg_json_path)
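This estimator hook is the glue for the new flag: when feature_config.embedding_on_cpu is set, _model_fn exports it as the place_embedding_on_cpu environment variable, and the rest of the PR reads it back through embedding_utils.embedding_on_cpu(). A rough round-trip sketch (assumes this PR is applied; constant.EmbeddingOnCPU holds the same 'place_embedding_on_cpu' key the estimator writes):

import os

from easy_rec.python.utils import constant
from easy_rec.python.utils import embedding_utils

# What _model_fn now does when the config flag is true.
os.environ[constant.EmbeddingOnCPU] = 'True'

# Any later placement decision can query the switch through one helper.
assert embedding_utils.embedding_on_cpu()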
3 changes: 0 additions & 3 deletions easy_rec/python/protos/easy_rec_model.proto
@@ -27,7 +27,6 @@ import "easy_rec/python/protos/variational_dropout.proto";
import "easy_rec/python/protos/multi_tower_recall.proto";
import "easy_rec/python/protos/tower.proto";
import "easy_rec/python/protos/pdn.proto";
// import "easy_rec/python/protos/ppnet.proto";

// for input performance test
message DummyModel {
@@ -92,8 +91,6 @@ message EasyRecModel {
SimpleMultiTask simple_multi_task = 304;
PLE ple = 305;

// PPNetV3 ppnet = 306;

RocketLaunching rocket_launching = 401;
}
repeated SeqAttGroupConfig seq_att_groups = 7;
1 change: 1 addition & 0 deletions easy_rec/python/protos/feature_config.proto
@@ -145,6 +145,7 @@ message FeatureConfig {

message FeatureConfigV2 {
repeated FeatureConfig features = 1 ;
optional bool embedding_on_cpu = 2 [default=false];
}

message FeatureGroupConfig {
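The new FeatureConfigV2 field defaults to false, so existing configs are unaffected; enabling the CPU KV path is a single embedding_on_cpu: true line in the feature_config block. A small sketch that parses the flag with the generated proto module (module path assumed to follow the usual feature_config.proto -> feature_config_pb2 naming):

from google.protobuf import text_format

from easy_rec.python.protos import feature_config_pb2

cfg = text_format.Parse('embedding_on_cpu: true', feature_config_pb2.FeatureConfigV2())
print(cfg.embedding_on_cpu)  # True; when omitted, defaults to False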
3 changes: 3 additions & 0 deletions easy_rec/python/utils/constant.py
@@ -23,6 +23,9 @@
# shard embedding var_name collection
EmbeddingParallel = 'EmbeddingParallel'

# environ variable to force embedding placement on cpu
EmbeddingOnCPU = 'place_embedding_on_cpu'


def enable_avx_str_split():
os.environ[ENABLE_AVX_STR_SPLIT] = '1'
6 changes: 6 additions & 0 deletions easy_rec/python/utils/embedding_utils.py
@@ -65,3 +65,9 @@ def is_embedding_parallel():

def sort_col_by_name():
return constant.SORT_COL_BY_NAME in os.environ


def embedding_on_cpu():
place_on_cpu = os.getenv(constant.EmbeddingOnCPU)
place_on_cpu = eval(place_on_cpu) if place_on_cpu else False
return place_on_cpu
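The helper above centralizes the env-var check so call sites only gate device placement in one line, mirroring the feature_column.py change earlier in this diff. A usage sketch (assumes conditional() is a pass-through context manager when its predicate is false, as the existing call sites suggest; build_toy_layer is a stand-in for the real input-layer code):

import tensorflow as tf
from tensorflow.python.framework import ops

from easy_rec.python.utils import conditional
from easy_rec.python.utils import embedding_utils

def build_toy_layer():
  # Stand-in for the real embedding/input-layer construction.
  return tf.zeros([8, 16])

# Enter ops.device('/cpu:0') only when place_embedding_on_cpu is set;
# otherwise default placement applies.
with conditional(embedding_utils.embedding_on_cpu(), ops.device('/cpu:0')):
  output = build_toy_layer()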
7 changes: 3 additions & 4 deletions samples/model_config/multi_tower_on_taobao_sok.config
@@ -1,5 +1,5 @@
train_input_path: "taobao_train_data_8192"
# train_input_path: "data/test/tb_data/taobao_train_data"
# train_input_path: "taobao_train_data_8192"
train_input_path: "data/test/tb_data/taobao_train_data"
eval_input_path: "data/test/tb_data/taobao_test_data"
model_dir: "experiments/multi_tower_taobao_ckpt"

@@ -136,8 +136,7 @@ data_config {
batch_size: 8192
num_epochs: 1000000
prefetch_size: 64
# input_type: CSVInput
input_type: DummyInput
input_type: CSVInput
}

feature_config: {