Skip to content

Commit

Permalink
[feature] add support for multival separation using combo feature (#396)
Browse files Browse the repository at this point in the history
* add support for multival separation using combo feature
* add support for git lfs auto pull
* make larger the duration used in the save checkpoints seconds test, which is easy to fail under tf 1.12
  • Loading branch information
chengmengli06 authored Jul 10, 2023
1 parent 7648671 commit 303eef3
Show file tree
Hide file tree
Showing 14 changed files with 903 additions and 35 deletions.
1 change: 1 addition & 0 deletions .git_bin_path
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{"leaf_name": "data/test", "leaf_file": ["data/test/batch_criteo_sample.tfrecord", "data/test/criteo_sample.tfrecord", "data/test/dwd_avazu_ctr_deepmodel_10w.csv", "data/test/embed_data.csv", "data/test/lookup_data.csv", "data/test/tag_kv_data.csv", "data/test/test.csv", "data/test/test_sample_weight.txt", "data/test/test_with_quote.csv"]}
{"leaf_name": "data/test/client", "leaf_file": ["data/test/client/item_lst", "data/test/client/user_table_data", "data/test/client/user_table_schema"]}
{"leaf_name": "data/test/criteo_data", "leaf_file": ["data/test/criteo_data/category.bin", "data/test/criteo_data/dense.bin", "data/test/criteo_data/label.bin", "data/test/criteo_data/readme"]}
{"leaf_name": "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls", "leaf_file": ["data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/ESTIMATOR_TRAIN_DONE", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/atexit_sync_1661483067", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/checkpoint", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/eval_result.txt", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.data-00000-of-00001", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.index", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.meta", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/pipeline.config", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/version"]}
{"leaf_name": "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt", "leaf_file": ["data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/checkpoint", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/eval_result.txt", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.data-00000-of-00001", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.index", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.meta", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/pipeline.config"]}
Expand Down
2 changes: 2 additions & 0 deletions docs/source/eval.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ eval_config {
}
}
```

当转化率很低(万分之3左右)的时候,可以在auc中再设置一个参数num_thresholds:

```sql
auc {
num_thresholds: 10000
Expand Down
2 changes: 2 additions & 0 deletions docs/source/feature/feature.rst
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ ComboFeature:组合特征
来自data\_config.input\_fields.input\_name
- embedding\_dim: embedding的维度,同IdFeature
- hash\_bucket\_size: hash bucket的大小
- combo_join_sep: 连接多个特征的分隔符, 如age是20, sex是'F', combo_join_sep是'X', 那么产生的特征是'20_X_F'
- combo_input_seps: 分隔符数组, 对应每个输入(input_names)的分隔符, 如果不需要分割, 填空字符串''; 如果所有的输入都不需要分割, 可以不设置


ExprFeature:表达式特征
Expand Down
29 changes: 18 additions & 11 deletions easy_rec/python/feature_column/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,17 +425,24 @@ def parse_combo_feature(self, config):
feature_name = config.feature_name if config.HasField('feature_name') \
else None
assert len(config.input_names) >= 2
input_names = []
for input_id in range(len(config.input_names)):
if input_id == 0:
input_names.append(feature_name)
else:
input_names.append(feature_name + '_' + str(input_id))
fc = feature_column.crossed_column(
input_names,
self._get_hash_bucket_size(config),
hash_key=None,
feature_name=feature_name)

if len(config.combo_join_sep) == 0:
input_names = []
for input_id in range(len(config.input_names)):
if input_id == 0:
input_names.append(feature_name)
else:
input_names.append(feature_name + '_' + str(input_id))
fc = feature_column.crossed_column(
input_names,
self._get_hash_bucket_size(config),
hash_key=None,
feature_name=feature_name)
else:
fc = feature_column.categorical_column_with_hash_bucket(
feature_name,
hash_bucket_size=self._get_hash_bucket_size(config),
feature_name=feature_name)

if self.is_wide(config):
self._add_wide_embedding_column(fc, config)
Expand Down
99 changes: 81 additions & 18 deletions easy_rec/python/input/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import six
import tensorflow as tf
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import gfile

from easy_rec.python.core import sampler as sampler_lib
Expand Down Expand Up @@ -330,6 +332,82 @@ def _get_labels(self, fields):
labels[x].get_shape()[1] == 1 else labels[x]) for x in labels
])

def _as_string(self, field, fc):
if field.dtype == tf.string:
return field
if field.dtype in [tf.float32, tf.double]:
feature_name = fc.feature_name if fc.HasField(
'feature_name') else fc.input_names[0]
assert fc.precision > 0, 'fc.precision not set for feature[%s], it is dangerous to convert ' \
'float or double to string due to precision problem, it is suggested ' \
' to convert them into string format before using EasyRec; ' \
'if you really need to do so, please set precision (the number of ' \
'decimal digits) carefully.' % feature_name
precision = None
if field.dtype in [tf.float32, tf.double]:
if fc.precision > 0:
precision = fc.precision

# convert to string
if 'as_string' in dir(tf.strings):
return tf.strings.as_string(field, precision=precision)
else:
return tf.as_string(field, precision=precision)

def _parse_combo_feature(self, fc, parsed_dict, field_dict):
# for compatibility with existing implementations
feature_name = fc.feature_name if fc.HasField(
'feature_name') else fc.input_names[0]

if len(fc.combo_input_seps) > 0:
assert len(fc.combo_input_seps) == len(fc.input_names), \
'len(combo_separator)[%d] != len(fc.input_names)[%d]' % (
len(fc.combo_input_seps), len(fc.input_names))

def _get_input_sep(input_id):
if input_id < len(fc.combo_input_seps):
return fc.combo_input_seps[input_id]
else:
return ''

if len(fc.combo_join_sep) == 0:
for input_id, input_name in enumerate(fc.input_names):
if input_id > 0:
key = feature_name + '_' + str(input_id)
else:
key = feature_name
input_sep = _get_input_sep(input_id)
if input_sep != '':
assert field_dict[
input_name].dtype == tf.string, 'could not apply string_split to input-name[%s] dtype=%s' % (
input_name, field_dict[input_name].dtype)
parsed_dict[key] = tf.string_split(field_dict[input_name], input_sep)
else:
parsed_dict[key] = self._as_string(field_dict[input_name], fc)
else:
if len(fc.combo_input_seps) > 0:
split_inputs = []
for input_id, input_name in enumerate(fc.input_names):
input_sep = fc.combo_input_seps[input_id]
if len(input_sep) > 0:
assert field_dict[
input_name].dtype == tf.string, 'could not apply string_split to input-name[%s] dtype=%s' % (
input_name, field_dict[input_name].dtype)
split_inputs.append(
tf.string_split(field_dict[input_name],
fc.combo_input_seps[input_id]))
else:
split_inputs.append(tf.reshape(field_dict[input_name], [-1, 1]))
parsed_dict[feature_name] = sparse_ops.sparse_cross(
split_inputs, fc.combo_join_sep)
else:
inputs = [
self._as_string(field_dict[input_name], fc)
for input_name in fc.input_names
]
parsed_dict[feature_name] = string_ops.string_join(
inputs, fc.combo_join_sep)

def _parse_tag_feature(self, fc, parsed_dict, field_dict):
input_0 = fc.input_names[0]
feature_name = fc.feature_name if fc.HasField('feature_name') else input_0
Expand Down Expand Up @@ -441,24 +519,7 @@ def _parse_id_feature(self, fc, parsed_dict, field_dict):
parsed_dict[feature_name] = field_dict[input_0]
if fc.HasField('hash_bucket_size'):
if field_dict[input_0].dtype != tf.string:
if field_dict[input_0].dtype in [tf.float32, tf.double]:
assert fc.precision > 0, 'it is dangerous to convert float or double to string due to ' \
'precision problem, it is suggested to convert them into string ' \
'format during feature generalization before using EasyRec; ' \
'if you really need to do so, please set precision (the number of ' \
'decimal digits) carefully.'
precision = None
if field_dict[input_0].dtype in [tf.float32, tf.double]:
if fc.precision > 0:
precision = fc.precision
# convert to string

if 'as_string' in dir(tf.strings):
parsed_dict[feature_name] = tf.strings.as_string(
field_dict[input_0], precision=precision)
else:
parsed_dict[feature_name] = tf.as_string(
field_dict[input_0], precision=precision)
parsed_dict[feature_name] = self._as_string(field_dict[input_0], fc)
elif fc.num_buckets > 0:
if parsed_dict[feature_name].dtype == tf.string:
check_list = [
Expand Down Expand Up @@ -779,6 +840,8 @@ def _preprocess(self, field_dict):
self._parse_id_feature(fc, parsed_dict, field_dict)
elif feature_type == fc.ExprFeature:
self._parse_expr_feature(fc, parsed_dict, field_dict)
elif feature_type == fc.ComboFeature:
self._parse_combo_feature(fc, parsed_dict, field_dict)
else:
feature_name = fc.feature_name if fc.HasField(
'feature_name') else fc.input_names[0]
Expand Down
9 changes: 9 additions & 0 deletions easy_rec/python/protos/feature_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ message FeatureConfig {

// embedding variable params
optional EVParams ev_params = 31;

// for combo feature:
// if not set, use cross_column
// otherwise, the input features are first joined
// and then passed to categorical_column
optional string combo_join_sep = 401 [default = ''];
// separator for each inputs
// if not set, combo inputs will not be split
repeated string combo_input_seps = 402;
}

message FeatureConfigV2 {
Expand Down
16 changes: 15 additions & 1 deletion easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ def test_deepfm_with_combo_feature(self):
'samples/model_config/deepfm_combo_on_avazu_ctr.config', self._test_dir)
self.assertTrue(self._success)

def test_deepfm_with_combo_v2_feature(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/deepfm_combo_v2_on_avazu_ctr.config',
self._test_dir)
self.assertTrue(self._success)

def test_deepfm_with_combo_v3_feature(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/deepfm_combo_v3_on_avazu_ctr.config',
self._test_dir)
self.assertTrue(self._success)

def test_deepfm_freeze_gradient(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/deepfm_freeze_gradient.config', self._test_dir)
Expand Down Expand Up @@ -144,9 +156,11 @@ def test_multi_tower_save_checkpoint_secs(self):
# remove last ckpt time
ckpts_times = np.array(sorted(ckpts_times)[:-1])
# ensure interval is 20s
diffs = list(ckpts_times[1:] - ckpts_times[:-1])
logging.info('nearby ckpts_times diff = %s' % diffs)
self.assertAllClose(
ckpts_times[1:] - ckpts_times[:-1], [20] * (len(ckpts_times) - 1),
atol=16)
atol=20)
self.assertTrue(self._success)

def test_keep_ckpt_max(self):
Expand Down
16 changes: 13 additions & 3 deletions easy_rec/python/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import string
import subprocess
import time
import six
from multiprocessing import Process
from subprocess import getstatusoutput
from tensorflow.python.platform import gfile
Expand All @@ -27,7 +28,8 @@

TEST_DIR = './tmp/easy_rec_test'

TEST_TIME_OUT = int(os.environ.get('TEST_TIME_OUT', 1200))
# parallel run of tests could take more time
TEST_TIME_OUT = int(os.environ.get('TEST_TIME_OUT', 1800))


def get_hdfs_tmp_dir(test_dir):
Expand All @@ -45,6 +47,8 @@ def proc_wait(proc, timeout=1200):
while proc.poll() is None and time.time() - t0 < timeout:
time.sleep(1)
if proc.poll() is None:
logging.warning('proc[pid=%d] timeout[%d], will kill the proc' %
(proc.pid, timeout))
proc.terminate()
while proc.poll() is None:
time.sleep(1)
Expand Down Expand Up @@ -95,8 +99,12 @@ def run_cmd(cmd_str, log_file, env=None):
cmd_str = cmd_str.replace('\r', ' ').replace('\n', ' ')
logging.info('RUNCMD: %s > %s 2>&1 ' % (cmd_str, log_file))
with open(log_file, 'w') as lfile:
return subprocess.Popen(
proc = subprocess.Popen(
cmd_str, stdout=lfile, stderr=subprocess.STDOUT, shell=True, env=env)
if six.PY2:
# for debug purpose
proc.args = cmd_str
return proc


def RunAsSubprocess(f):
Expand Down Expand Up @@ -224,7 +232,9 @@ def test_datahub_train_eval(pipeline_config_path,
proc = run_cmd(train_cmd, '%s/log_%s.txt' % (test_dir, 'master'))
proc_wait(proc, timeout=TEST_TIME_OUT)
if proc.returncode != 0:
logging.error('train %s failed' % test_pipeline_config_path)
logging.warning(
'train %s failed[pid=%d][code=%d][args=%s]' %
(test_pipeline_config_path, proc.pid, proc.returncode, proc.args))
return False
if post_check_func:
return post_check_func(pipeline_config)
Expand Down
2 changes: 1 addition & 1 deletion samples/model_config/deepfm_combo_on_avazu_ctr.config
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ feature_config: {
input_names: ["site_id", "app_id"]
feature_name: "site_id_app_id"
feature_type: ComboFeature
hash_bucket_size: 1000,
hash_bucket_size: 1000
embedding_dim: 16
}

Expand Down
Loading

0 comments on commit 303eef3

Please sign in to comment.