[feat]: format backbone code, add recurrent and sequential layer
yangxudong committed Jun 20, 2023
1 parent e795f00 commit c4f5ea9
Showing 7 changed files with 170 additions and 27 deletions.
25 changes: 20 additions & 5 deletions easy_rec/python/layers/backbone.py
@@ -129,20 +129,22 @@ def __call__(self, is_training, **kwargs):
input_fn = EnhancedInputLayer(conf, self._input_layer, self._features)
output = input_fn(block, is_training)
block_outputs[block] = output
elif layer == 'sequential':
print(config)
else:
inputs = block_input(config, block_outputs)
output = self.call_layer(inputs, config, block, is_training)
block_outputs[block] = output

temp = []
outputs = []
for output in self._config.concat_blocks:
if output in block_outputs:
temp.append(block_outputs[output])
temp = block_outputs[output]
if type(temp) in (tuple, list):
outputs.extend(temp)
else:
outputs.append(temp)
else:
raise ValueError('No output `%s` of backbone to be concat' % output)
output = concat_inputs(temp, msg='backbone')
output = concat_inputs(outputs, msg='backbone')

if self._config.HasField('top_mlp'):
params = Parameter.make_from_pb(self._config.top_mlp)
@@ -193,6 +195,19 @@ def call_layer(self, inputs, config, name, training):
conf = getattr(config, 'lambda')
fn = eval(conf.expression)
return fn(inputs)
if layer_name == 'repeat':
conf = config.repeat
n_loop = conf.num_repeat
outputs = []
for i in range(n_loop):
name_i = '%s_%d' % (name, i)
output = self.call_keras_layer(conf.keras_layer, inputs, name_i, training)
outputs.append(output)
if len(outputs) == 1:
return outputs[0]
if conf.HasField('output_concat_axis'):
return tf.concat(outputs, conf.output_concat_axis)
return outputs
if layer_name == 'recurrent':
conf = config.recurrent
fixed_input_index = -1
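The new `repeat` branch applies the same keras_layer config several times to one input and optionally concatenates the results. A minimal standalone sketch of the pattern (the `make_layer` factory and the layer sizes are illustrative, not EasyRec APIs):

import tensorflow as tf

def repeat_layer(inputs, make_layer, num_repeat, output_concat_axis=None):
  # Build num_repeat independent copies of the layer and run each on `inputs`.
  outputs = [make_layer(i)(inputs) for i in range(num_repeat)]
  if len(outputs) == 1:
    return outputs[0]
  if output_concat_axis is not None:
    return tf.concat(outputs, output_concat_axis)
  return outputs  # default: return the list of per-repeat outputs

# e.g. three parallel 32-unit dense heads concatenated on the last axis
x = tf.ones([4, 16])
y = repeat_layer(x, lambda i: tf.keras.layers.Dense(32, name='head_%d' % i),
                 num_repeat=3, output_concat_axis=-1)  # shape [4, 96]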
2 changes: 1 addition & 1 deletion easy_rec/python/layers/common_layers.py
@@ -109,7 +109,7 @@ def call(self, group, is_training):
do_feature_dropout = is_training and 0.0 < self._config.feature_dropout_rate < 1.0
if do_feature_dropout:
keep_prob = 1.0 - self._config.feature_dropout_rate
bern = tf.distributions.Bernoulli(probs=keep_prob)
bern = tf.distributions.Bernoulli(probs=keep_prob, dtype=tf.float32)
mask = bern.sample(num_features)
elif do_bn:
features = tf.layers.batch_normalization(features, training=is_training)
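The added `dtype=tf.float32` makes the sampled feature-dropout mask directly multipliable with float features. A hedged sketch of the idea (the shapes and the rescaling step are illustrative, not the exact EasyRec code):

import tensorflow as tf

keep_prob = 0.9
num_features = 8
# Without dtype=tf.float32 the sample defaults to an integer tensor, which
# would need an extra cast before being used as a multiplicative mask.
bern = tf.distributions.Bernoulli(probs=keep_prob, dtype=tf.float32)
mask = bern.sample(num_features)            # shape [num_features], float32
features = tf.ones([4, num_features])       # [batch_size, num_features]
dropped = features * mask / keep_prob       # rescale to keep the expectation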
10 changes: 9 additions & 1 deletion easy_rec/python/layers/keras/blocks.py
@@ -101,8 +101,16 @@ def add_rich_layer(self,

def call(self, x, training=None, **kwargs):
"""Performs the forward computation of the block."""
from inspect import isfunction
for layer in self._sub_layers:
x = layer(x, training=training)
if isfunction(layer):
x = layer(x, training=training)
else:
cls = layer.__class__.__name__
if cls in ('Dropout', 'BatchNormalization'):
x = layer(x, training=training)
else:
x = layer(x)
return x


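The change above passes `training` only to sub-layers that actually use it (Dropout, BatchNormalization); other Keras layers and plain functions are called without it. A toy version of the same dispatch, with an illustrative layer stack:

import tensorflow as tf

layers = [
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dropout(0.1),
]

def forward(x, training):
  for layer in layers:
    if layer.__class__.__name__ in ('Dropout', 'BatchNormalization'):
      x = layer(x, training=training)  # these behave differently in train/eval
    else:
      x = layer(x)
  return x

out = forward(tf.ones([4, 8]), training=True)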
9 changes: 9 additions & 0 deletions easy_rec/python/protos/backbone.proto
@@ -30,11 +30,19 @@ message RecurrentLayer {
required KerasLayer keras_layer = 3;
}

message RepeatLayer {
required uint32 num_repeat = 1 [default = 1];
// if not set, the outputs of all repeats are returned as a list
optional int32 output_concat_axis = 2;
required KerasLayer keras_layer = 3;
}

message Layer {
oneof layer {
Lambda lambda = 1;
KerasLayer keras_layer = 2;
RecurrentLayer recurrent = 3;
RepeatLayer repeat = 4;
}
}

@@ -54,6 +62,7 @@ message Block {
Lambda lambda = 102;
KerasLayer keras_layer = 103;
RecurrentLayer recurrent = 104;
RepeatLayer repeat = 105;
}
}

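The new RepeatLayer message plugs into both the `Layer` oneof and the `Block` oneof. A hedged sketch of filling it in from Python; the generated module name `backbone_pb2` is assumed from the proto's location, and field names follow the messages above and the config examples below:

from easy_rec.python.protos import backbone_pb2  # assumed generated module

repeat = backbone_pb2.RepeatLayer()
repeat.num_repeat = 3
repeat.output_concat_axis = -1            # omit to get a list of outputs
repeat.keras_layer.class_name = 'MLP'
repeat.keras_layer.mlp.hidden_units.extend([128, 64])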
93 changes: 93 additions & 0 deletions easy_rec/python/utils/config_util.py
@@ -5,6 +5,7 @@
Such as Hyper parameter tuning or automatic feature expanding.
"""

import argparse
import datetime
import json
import logging
@@ -605,3 +606,95 @@ def process_multi_file_input_path(sampler_config_input_path):
input_path = sampler_config_input_path

return input_path


def change_configured_embedding_dim(pipeline_config_path, groups, emb_dim):
"""Reads config from a file containing pipeline_pb2.EasyRecConfig.
Args:
pipeline_config_path: Path to pipeline_pb2.EasyRecConfig text
proto.
groups: the names of feature group to be changed
emb_dim: target embedding dimension
Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`. Value are the
corresponding config objects.
"""
if isinstance(pipeline_config_path, pipeline_pb2.EasyRecConfig):
return pipeline_config_path

assert tf.gfile.Exists(
pipeline_config_path
), 'pipeline_config_path [%s] not exists' % pipeline_config_path

pipeline_config = pipeline_pb2.EasyRecConfig()
with tf.gfile.GFile(pipeline_config_path, 'r') as f:
config_str = f.read()
if pipeline_config_path.endswith('.config'):
text_format.Merge(config_str, pipeline_config)
elif pipeline_config_path.endswith('.json'):
json_format.Parse(config_str, pipeline_config)
else:
assert False, 'invalid file format(%s), currently support formats: .config(prototxt) .json' % pipeline_config_path

target_groups = set(groups.split(','))
features = set()
conf = pipeline_config.model_config
for group in conf.feature_groups:
if group.group_name not in target_groups:
continue
for feature in group.feature_names:
features.add(feature)

feature_configs = get_compatible_feature_configs(pipeline_config)
for fea_conf in feature_configs:
fea_name = fea_conf.input_names[0]
if fea_conf.HasField('feature_name'):
fea_name = fea_conf.feature_name
if fea_name in features:
fea_conf.embedding_dim = emb_dim

return pipeline_config


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--pipeline_config_path',
type=str,
default=None,
help='Path to pipeline config file.')
parser.add_argument(
'--feature_groups',
type=str,
default=None,
help='The name of feature group to be changed.')
parser.add_argument(
'--embedding_dim',
type=int,
default=None,
help='The embedding dim to be changed to.')
parser.add_argument(
'--save_config_path',
type=str,
default=None,
help='Path to save changed config.')

args, extra_args = parser.parse_known_args()
if args.pipeline_config_path is None:
raise ValueError('--pipeline_config_path must be set')
if args.save_config_path is None:
raise ValueError('--save_config_path must be set')
if args.feature_groups is None:
raise ValueError('--feature_groups must be set')
if args.embedding_dim is None:
raise ValueError('--embedding_dim must be set')

# Passing a feature group name that does not exist simply reformats the config file.
config = change_configured_embedding_dim(
args.pipeline_config_path,
args.feature_groups,
args.embedding_dim)
save_message(config, args.save_config_path)
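
A hedged usage sketch of the new helper from Python (the config path and the feature-group name are illustrative; `save_message` is the existing util called above):

from easy_rec.python.utils.config_util import (
    change_configured_embedding_dim, save_message)

config = change_configured_embedding_dim(
    'examples/configs/deepfm_backbone_on_criteo_with_autodis.config',
    groups='deep', emb_dim=32)
save_message(config, '/tmp/deepfm_emb32.config')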
31 changes: 20 additions & 11 deletions examples/configs/deepfm_backbone_on_criteo_with_autodis.config
@@ -674,7 +674,7 @@ model_config: {
inputs {
name: 'wide_features'
}
Lambda {
lambda {
expression: 'lambda x: tf.reduce_sum(x, axis=1, keepdims=True)'
}
}
@@ -683,11 +683,14 @@
inputs {
name: 'numerical_features'
}
auto_dis_embedding {
embedding_dim: 16
num_bins: 20
temperature: 0.815
output_tensor_list: true
keras_layer {
class_name: 'AutoDisEmbedding'
auto_dis_embedding {
embedding_dim: 16
num_bins: 20
temperature: 0.815
output_tensor_list: true
}
}
}
blocks {
@@ -706,8 +709,11 @@
name: 'num_emb'
input_fn: 'lambda x: x[1]'
}
fm {
use_variant: true
keras_layer {
class_name: 'FM'
fm {
use_variant: true
}
}
}
blocks {
@@ -720,11 +726,14 @@
name: 'num_emb'
input_fn: 'lambda x: x[0]'
}
mlp {
hidden_units: [256, 128, 64]
keras_layer {
class_name: 'MLP'
mlp {
hidden_units: [256, 128, 64]
}
}
}
// no wide_logit may have better performance
# no wide_logit may have better performance
concat_blocks: ['wide_logit', 'fm', 'deep']
top_mlp {
hidden_units: [256, 128, 64]
27 changes: 18 additions & 9 deletions examples/configs/deepfm_backbone_on_criteo_with_periodic.config
@@ -674,7 +674,7 @@ model_config: {
inputs {
name: 'wide_features'
}
Lambda {
lambda {
expression: 'lambda x: tf.reduce_sum(x, axis=1, keepdims=True)'
}
}
@@ -683,10 +683,13 @@
inputs {
name: 'numerical_features'
}
periodic_embedding {
embedding_dim: 16
sigma: 0.005
output_tensor_list: true
keras_layer {
class_name: 'PeriodicEmbedding'
periodic_embedding {
embedding_dim: 16
sigma: 0.005
output_tensor_list: true
}
}
}
blocks {
@@ -705,8 +708,11 @@
name: 'num_emb'
input_fn: 'lambda x: x[1]'
}
fm {
use_variant: true
keras_layer {
class_name: 'FM'
fm {
use_variant: true
}
}
}
blocks {
@@ -719,8 +725,11 @@
name: 'num_emb'
input_fn: 'lambda x: x[0]'
}
mlp {
hidden_units: [256, 128, 64]
keras_layer {
class_name: 'MLP'
mlp {
hidden_units: [256, 128, 64]
}
}
}
concat_blocks: ['wide_logit', 'fm', 'deep']
