[feat]: add more backbone blocks
yangxudong committed Jun 19, 2023
1 parent 9234140 commit 7d0e350
Showing 11 changed files with 571 additions and 100 deletions.
67 changes: 52 additions & 15 deletions easy_rec/python/layers/backbone.py
@@ -2,13 +2,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import six
import tensorflow as tf
from google.protobuf import struct_pb2

from easy_rec.python.layers.common_layers import EnhancedInputLayer
from easy_rec.python.layers.keras import MLP
from easy_rec.python.layers.utils import Parameter
from easy_rec.python.protos import backbone_pb2
from easy_rec.python.protos import keras_layer_pb2
from easy_rec.python.utils.dag import DAG
from easy_rec.python.utils.load_class import load_keras_layer

@@ -112,6 +114,14 @@ def __call__(self, is_training, **kwargs):
print('backbone topological order: ' + ','.join(blocks))
for block in blocks:
config = self._name_to_blocks[block]
if config.layers: # sequential layers
logging.info('call sequential %d layers' % len(config.layers))
output = block_input(config, block_outputs)
for layer in config.layers:
output = self.call_layer(output, layer, block, is_training)
block_outputs[block] = output
continue
# otherwise a single layer, chosen via the `layer` oneof
layer = config.WhichOneof('layer')
if layer is None: # identity layer
block_outputs[block] = block_input(config, block_outputs)
@@ -121,14 +131,11 @@ def __call__(self, is_training, **kwargs):
output = input_fn(block, is_training)
block_outputs[block] = output
elif layer == 'sequential':
inputs = block_input(config, block_outputs)
layers = config.sequential.layers
output = self.call_sequential_layers(inputs, layers, block, is_training)
block_outputs[block] = output
print(config)
else:
inputs = block_input(config, block_outputs)
block_outputs[block] = self.call_layer(inputs, config, block,
is_training)
output = self.call_layer(inputs, config, block, is_training)
block_outputs[block] = output

temp = []
for output in self._config.concat_blocks:
@@ -166,16 +173,19 @@ def call_keras_layer(self, layer_conf, inputs, name, training):
layer = layer_cls(name=name)
else:
assert param_type == 'st_params', 'internal keras layer only support st_params'
kwargs = convert_to_dict(layer_conf.st_params)
layer = layer_cls(name=name, **kwargs)
try:
kwargs = convert_to_dict(layer_conf.st_params)
logging.info('call %s layer with params %r' %
(layer_conf.class_name, kwargs))
layer = layer_cls(name=name, **kwargs)
except TypeError as e:
logging.warning(e)
args = map(format_value, layer_conf.st_params.values())
logging.info('try to call %s layer with params %r' %
(layer_conf.class_name, args))
layer = layer_cls(*args, name=name)
return layer(inputs, training=training)

def call_sequential_layers(self, inputs, layers, name, training):
output = inputs
for layer in layers:
output = self.call_layer(output, layer, name, training)
return output

def call_layer(self, inputs, config, name, training):
layer_name = config.WhichOneof('layer')
if layer_name == 'keras_layer':
@@ -184,6 +194,33 @@ def call_layer(self, inputs, config, name, training):
conf = getattr(config, 'lambda')
fn = eval(conf.expression)
return fn(inputs)
if layer_name == 'recurrent':
conf = config.recurrent
fixed_input_index = -1
if conf.HasField('fixed_input_index'):
fixed_input_index = conf.fixed_input_index
if fixed_input_index >= 0:
assert type(inputs) in (tuple, list), '%s inputs must be a list' % name
output = inputs
for i in range(conf.num_steps):
name_i = '%s_%d' % (name, i)
output_i = self.call_keras_layer(conf.keras_layer, output, name_i, training)
if fixed_input_index >= 0:
j = 0
for idx in range(len(output)):
if idx == fixed_input_index:
continue
output[idx] = output_i[j] if type(output_i) in (tuple, list) else output_i
j += 1
else:
output = output_i
if fixed_input_index >= 0:
del output[fixed_input_index]
if len(output) == 1:
return output[0]
return output
return output

raise NotImplementedError('Unsupported backbone layer:' + layer_name)


@@ -205,7 +242,7 @@ def concat_inputs(inputs, axis=-1, msg=''):

def format_value(value):
value_type = type(value)
if value_type in (unicode, str):
if value_type == six.text_type:
return str(value)
if value_type == float:
int_v = int(value)
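
Note on the new `recurrent` branch of `call_layer`: it unrolls one Keras layer for `num_steps` iterations, building a freshly named layer (and thus fresh weights) per step. When `fixed_input_index` is set, the block input must be a list; the element at that index is fed unchanged into every step, while the remaining elements are replaced by the previous step's output, which gives the DCN-style recurrence x_{k+1} = Cross(x_0, x_k). A minimal plain-Python sketch of the unrolling semantics, under illustrative names (`recurrent_unroll`, `layer_fn`) that are not part of the codebase:

def recurrent_unroll(layer_fn, inputs, num_steps, fixed_input_index=-1):
  # Apply the same layer num_steps times. If fixed_input_index >= 0,
  # `inputs` must be a list; that element stays constant across steps and
  # the other elements are overwritten with the previous step's output.
  output = inputs
  for _ in range(num_steps):
    output_i = layer_fn(output)
    if fixed_input_index < 0:
      output = output_i
      continue
    state = list(output)
    j = 0
    for idx in range(len(state)):
      if idx == fixed_input_index:
        continue
      state[idx] = output_i[j] if isinstance(output_i, (tuple, list)) else output_i
      j += 1
    output = state
  if fixed_input_index < 0:
    return output
  # drop the fixed element before returning, as call_layer does
  trimmed = [v for i, v in enumerate(output) if i != fixed_input_index]
  return trimmed[0] if len(trimmed) == 1 else trimmed

# DCN-style usage: slot 0 holds x0 fixed, slot 1 carries x_k forward.
#   cross = lambda pair: pair[0] * pair[1]  # stand-in for the real Cross layer
#   y = recurrent_unroll(cross, [x0, x0], num_steps=3, fixed_input_index=0)
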
7 changes: 5 additions & 2 deletions easy_rec/python/model/easy_rec_model.py
@@ -68,6 +68,7 @@ def __init__(self,
# model_config.feature_groups,
# self._l2_reg)
# self._sequence_encoding_by_group_name = {}
self._backbone_output = None
if model_config.HasField('backbone'):
self._backbone = Backbone(
model_config.backbone,
@@ -83,11 +84,13 @@ def has_backbone(self):

@property
def backbone(self):
if self._backbone_output:
return self._backbone_output
if self._backbone:
output = self._backbone(self._is_training)
self._backbone_output = self._backbone(self._is_training)
loss_dict = self._backbone.loss_dict
self._loss_dict.update(loss_dict)
return output
return self._backbone_output
return None

@property
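
The `backbone` property above is now memoized: without the cache, each access would rebuild the backbone graph and re-merge its `loss_dict` into `_loss_dict`. A minimal sketch of the caching pattern, with hypothetical names (`build_fn` is not part of the codebase):

class CachedBackbone(object):

  def __init__(self, build_fn=None):
    self._build_fn = build_fn  # callable that builds the (expensive) graph
    self._output = None  # populated at most once, then reused

  @property
  def backbone(self):
    if self._output is not None:
      return self._output
    if self._build_fn is not None:
      self._output = self._build_fn()
      return self._output
    return None
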
8 changes: 6 additions & 2 deletions easy_rec/python/model/esmm.py
@@ -31,7 +31,9 @@ def __init__(self,

self._group_num = len(self._model_config.groups)
self._group_features = []
if self._group_num > 0:
if self.has_backbone:
logging.info('use bottom backbone network')
elif self._group_num > 0:
logging.info('group_num: {0}'.format(self._group_num))
for group_id in range(self._group_num):
group = self._model_config.groups[group_id]
@@ -173,7 +175,9 @@ def build_predict_graph(self):
Returns:
self._prediction_dict: Prediction result of two tasks.
"""
if self._group_num > 0:
if self.has_backbone:
all_fea = self.backbone
elif self._group_num > 0:
group_fea_arr = []
# Both towers share the underlying network.
for group_id in range(self._group_num):
5 changes: 4 additions & 1 deletion easy_rec/python/model/mmoe.py
@@ -26,7 +26,10 @@ def __init__(self,
self._model_config = self._model_config.mmoe
assert isinstance(self._model_config, MMoEConfig)

self._features, _ = self._input_layer(self._feature_dict, 'all')
if self.has_backbone:
self._features = self.backbone
else:
self._features, _ = self._input_layer(self._feature_dict, 'all')
self._init_towers(self._model_config.task_towers)

def build_predict_graph(self):
5 changes: 4 additions & 1 deletion easy_rec/python/model/ple.py
@@ -27,7 +27,10 @@ def __init__(self,

self._layer_nums = len(self._model_config.extraction_networks)
self._task_nums = len(self._model_config.task_towers)
self._features, _ = self._input_layer(self._feature_dict, 'all')
if self.has_backbone:
self._features = self.backbone
else:
self._features, _ = self._input_layer(self._feature_dict, 'all')
self._init_towers(self._model_config.task_towers)

def gate(self, selector_fea, vec_feas, name):
5 changes: 4 additions & 1 deletion easy_rec/python/model/simple_multi_task.py
@@ -27,7 +27,10 @@ def __init__(self,
self._model_config = self._model_config.simple_multi_task
assert isinstance(self._model_config, SimpleMultiTaskConfig)

self._features, _ = self._input_layer(self._feature_dict, 'all')
if self.has_backbone:
self._features = self.backbone
else:
self._features, _ = self._input_layer(self._feature_dict, 'all')
self._init_towers(self._model_config.task_towers)

def build_predict_graph(self):
31 changes: 19 additions & 12 deletions easy_rec/python/protos/backbone.proto
@@ -24,18 +24,36 @@ message Input {
optional string input_fn = 2;
}

message RecurrentLayer {
required uint32 num_steps = 1 [default = 1];
optional uint32 fixed_input_index = 2;
required KerasLayer keras_layer = 3;
}

message Layer {
oneof layer {
Lambda lambda = 1;
KerasLayer keras_layer = 2;
RecurrentLayer recurrent = 3;
}
}

message Block {
required string name = 1;
// the input names of feature groups or other blocks
repeated Input inputs = 2;
optional int32 input_concat_axis = 3 [default = -1];
optional bool merge_inputs_into_list = 4;
optional string extra_input_fn = 5;

// sequential layers
repeated Layer layers = 6;
// only takes effect when `layers` above is empty
oneof layer {
InputLayer input_layer = 101;
Lambda lambda = 102;
KerasLayer keras_layer = 103;
Sequential sequential = 104;
RecurrentLayer recurrent = 104;
}
}

@@ -44,14 +62,3 @@ message BackboneTower {
repeated string concat_blocks = 2;
optional MLP top_mlp = 3;
}

message Layer {
oneof layer {
Lambda lambda = 101;
KerasLayer keras_layer = 102;
}
}

message Sequential {
repeated Layer layers = 1;
}
64 changes: 8 additions & 56 deletions examples/configs/dcn_backbone_on_movielens.config
@@ -174,68 +174,20 @@ model_config: {
}
}
blocks {
name: "cross1"
name: "dcn"
inputs {
name: 'all'
input_fn: 'lambda x: [x, x]'
}
keras_layer {
class_name: 'Cross'
}
}
blocks {
name: "cross2"
inputs {
name: 'all'
}
inputs {
name: 'cross1'
}
merge_inputs_into_list: true
keras_layer {
class_name: 'Cross'
}
}
blocks {
name: "cross3"
inputs {
name: 'all'
}
inputs {
name: 'cross2'
}
merge_inputs_into_list: true
keras_layer {
class_name: 'Cross'
}
}
blocks {
name: "cross4"
inputs {
name: 'all'
}
inputs {
name: 'cross3'
}
merge_inputs_into_list: true
keras_layer {
class_name: 'Cross'
}
}
blocks {
name: "cross5"
inputs {
name: 'all'
}
inputs {
name: 'cross4'
}
merge_inputs_into_list: true
keras_layer {
class_name: 'Cross'
recurrent {
num_steps: 3
fixed_input_index: 0
keras_layer {
class_name: 'Cross'
}
}
}
concat_blocks: ['deep', 'cross5']
concat_blocks: ['deep', 'dcn']
top_mlp {
hidden_units: [64, 32, 16]
}
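
Design note: the five hand-chained blocks cross1 … cross5 collapse into this single recurrent block, so the cross depth becomes a one-line `num_steps` change (the depth is also reduced here, from five Cross applications to three). With `fixed_input_index: 0`, the unrolled computation is:

x_1 = Cross(x_0, x_0)
x_2 = Cross(x_0, x_1)
x_3 = Cross(x_0, x_2)  # returned; the fixed x_0 is dropped at the end

For a standard DCN-v2 cross layer each step would compute x_{k+1} = x_0 * (W_k x_k + b_k) + x_k, though that inner formula is an assumption about this `Cross` implementation rather than something the diff states.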