diff --git a/easy_rec/python/layers/backbone.py b/easy_rec/python/layers/backbone.py
index 7eee14a4d..22645bee0 100644
--- a/easy_rec/python/layers/backbone.py
+++ b/easy_rec/python/layers/backbone.py
@@ -10,7 +10,6 @@
 from easy_rec.python.layers.keras import MLP
 from easy_rec.python.layers.utils import Parameter
 from easy_rec.python.protos import backbone_pb2
-from easy_rec.python.protos import keras_layer_pb2
 from easy_rec.python.utils.dag import DAG
 from easy_rec.python.utils.load_class import load_keras_layer
@@ -204,13 +203,17 @@ def call_layer(self, inputs, config, name, training):
     output = inputs
     for i in range(conf.num_steps):
       name_i = '%s_%d' % (name, i)
-      output_i = self.call_keras_layer(conf.keras_layer, output, name_i, training)
+      layer = conf.keras_layer
+      output_i = self.call_keras_layer(layer, output, name_i, training)
       if fixed_input_index >= 0:
         j = 0
         for idx in range(len(output)):
           if idx == fixed_input_index:
             continue
-          output[idx] = output_i[j] if type(output_i) in (tuple, list) else output_i
+          if type(output_i) in (tuple, list):
+            output[idx] = output_i[j]
+          else:
+            output[idx] = output_i
           j += 1
       else:
         output = output_i
diff --git a/easy_rec/python/layers/keras/__init__.py b/easy_rec/python/layers/keras/__init__.py
index 64cacf3c9..24f62ffb3 100644
--- a/easy_rec/python/layers/keras/__init__.py
+++ b/easy_rec/python/layers/keras/__init__.py
@@ -1,13 +1,13 @@
 from .blocks import MLP
 from .blocks import Highway
 from .bst import BST
-from .dcn import Cross
 from .din import DIN
-from .dot_interaction import DotInteraction
 from .fibinet import BiLinear
 from .fibinet import FiBiNet
 from .fibinet import SENet
-from .fm import FM
+from .interaction import FM
+from .interaction import Cross
+from .interaction import DotInteraction
 from .mask_net import MaskBlock
 from .mask_net import MaskNet
 from .numerical_embedding import AutoDisEmbedding
diff --git a/easy_rec/python/layers/keras/dot_interaction.py b/easy_rec/python/layers/keras/dot_interaction.py
deleted file mode 100644
index 7ec47c5ad..000000000
--- a/easy_rec/python/layers/keras/dot_interaction.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# -*- encoding:utf-8 -*-
-# Copyright (c) Alibaba, Inc. and its affiliates.
-"""Implements `Dot Interaction` Layer of DLRM model."""
-
-import tensorflow as tf
-
-
-class DotInteraction(tf.keras.layers.Layer):
-  """Dot interaction layer.
-
-  See theory in the DLRM paper: https://arxiv.org/pdf/1906.00091.pdf,
-  section 2.1.3. Sparse activations and dense activations are combined.
-  Dot interaction is applied to a batch of input Tensors [e1,...,e_k] of the
-  same dimension and the output is a batch of Tensors with all distinct pairwise
-  dot products of the form dot(e_i, e_j) for i <= j if self self_interaction is
-  True, otherwise dot(e_i, e_j) i < j.
-
-  Attributes:
-    self_interaction: Boolean indicating if features should self-interact.
-      If it is True, then the diagonal entries of the interaction metric are
-      also taken.
-    skip_gather: An optimization flag. If it's set then the upper triangle part
-      of the dot interaction matrix dot(e_i, e_j) is set to 0. The resulting
-      activations will be of dimension [num_features * num_features] from which
-      half will be zeros. Otherwise activations will be only lower triangle part
-      of the interaction matrix. The later saves space but is much slower.
-    name: String name of the layer.
- """ - - def __init__(self, params, name=None, **kwargs): - self._self_interaction = params.get_or_default('self_interaction', False) - self._skip_gather = params.get_or_default('skip_gather', False) - super(DotInteraction, self).__init__(name=name, **kwargs) - - def call(self, inputs, **kwargs): - """Performs the interaction operation on the tensors in the list. - - The tensors represent as transformed dense features and embedded categorical - features. - Pre-condition: The tensors should all have the same shape. - - Args: - inputs: List of features with shapes [batch_size, feature_dim]. - - Returns: - activations: Tensor representing interacted features. It has a dimension - `num_features * num_features` if skip_gather is True, otherside - `num_features * (num_features + 1) / 2` if self_interaction is True and - `num_features * (num_features - 1) / 2` if self_interaction is False. - """ - if isinstance(inputs, (list, tuple)): - # concat_features shape: batch_size, num_features, feature_dim - try: - concat_features = tf.stack(inputs, axis=1) - except (ValueError, tf.errors.InvalidArgumentError) as e: - raise ValueError('Input tensors` dimensions must be equal, original' - 'error message: {}'.format(e)) - else: - assert inputs.shape.ndims == 3, 'input of dot func must be a 3D tensor or a list of 2D tensors' - concat_features = inputs - - batch_size = tf.shape(concat_features)[0] - - # Interact features, select lower-triangular portion, and re-shape. - xactions = tf.matmul(concat_features, concat_features, transpose_b=True) - num_features = xactions.shape[-1] - ones = tf.ones_like(xactions) - if self._self_interaction: - # Selecting lower-triangular portion including the diagonal. - lower_tri_mask = tf.linalg.band_part(ones, -1, 0) - upper_tri_mask = ones - lower_tri_mask - out_dim = num_features * (num_features + 1) // 2 - else: - # Selecting lower-triangular portion not included the diagonal. - upper_tri_mask = tf.linalg.band_part(ones, 0, -1) - lower_tri_mask = ones - upper_tri_mask - out_dim = num_features * (num_features - 1) // 2 - - if self._skip_gather: - # Setting upper triangle part of the interaction matrix to zeros. - activations = tf.where( - condition=tf.cast(upper_tri_mask, tf.bool), - x=tf.zeros_like(xactions), - y=xactions) - out_dim = num_features * num_features - else: - activations = tf.boolean_mask(xactions, lower_tri_mask) - activations = tf.reshape(activations, (batch_size, out_dim)) - return activations diff --git a/easy_rec/python/layers/keras/fibinet.py b/easy_rec/python/layers/keras/fibinet.py index dc1f7d003..98cdb3179 100644 --- a/easy_rec/python/layers/keras/fibinet.py +++ b/easy_rec/python/layers/keras/fibinet.py @@ -5,7 +5,6 @@ import tensorflow as tf -from easy_rec.python.layers import dnn from easy_rec.python.layers.common_layers import layer_norm from easy_rec.python.layers.keras.blocks import MLP from easy_rec.python.layers.utils import Parameter @@ -15,9 +14,20 @@ class SENet(tf.keras.layers.Layer): - """SENet+ Layer used in FiBiNET,支持不同field的embedding dimension不等. + """SENET Layer used in FiBiNET. - arxiv: 2209.05016 + Input shape + - A list of 2D tensor with shape: ``(batch_size,embedding_size)``. + The ``embedding_size`` of each field can have different value. + + Output shape + - A 2D tensor with shape: ``(batch_size,sum_of_embedding_size)``. + + References: + 1. [FiBiNET](https://arxiv.org/pdf/1905.09433.pdf) + Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction + 2. 
+    2. [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
+      Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
   """
 
   def __init__(self, params, name='SENet', **kwargs):
@@ -25,8 +35,6 @@ def __init__(self, params, name='SENet', **kwargs):
     self.config = params.get_pb_config()
 
   def call(self, inputs, **kwargs):
-    """embedding_list:
-    A list of 2D tensor with shape: ``(batch_size,embedding_size)``."""
-    print('SENET layer with %d inputs' % len(inputs))
     g = self.config.num_squeeze_group
     for emb in inputs:
       assert emb.shape.ndims == 2, 'field embeddings must be rank 2 tensors'
@@ -88,14 +96,26 @@ def _full_interaction(v_i, v_j):
 
 
 class BiLinear(tf.keras.layers.Layer):
-  """双线性特征交互层,支持不同field embeddings的size不等.
+  """BilinearInteraction Layer used in FiBiNET.
+
+  Input shape
+    - A list of 2D tensors with shape: ``(batch_size,embedding_size)``.
+      Its length is ``field_size``.
+      The ``embedding_size`` of each field can have a different value.
 
-  arxiv: 2209.05016
+  Output shape
+    - 2D tensor with shape: ``(batch_size,output_size)``.
 
   Attributes:
-    num_output_units: 输出的size
-    type: ['all', 'each', 'interaction'],支持其中一种
-    use_plus: 是否使用bi-linear+
+    num_output_units: the number of output units
+    type: one of ['all', 'each', 'interaction'], the type of bilinear
+      function used in this layer
+    use_plus: whether to use bi-linear+
+
+  References:
+    1. [FiBiNET](https://arxiv.org/pdf/1905.09433.pdf)
+      Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction
+    2. [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
+      Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
   """
 
   def __init__(self, params, name='bilinear', **kwargs):
@@ -186,36 +206,32 @@ def call(self, inputs, **kwargs):
 
 class FiBiNet(tf.keras.layers.Layer):
   """FiBiNet++:Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction.
 
-  This is almost an exact implementation of the original FiBiNet++ model.
-  See the original paper:
-  https://arxiv.org/pdf/2209.05016.pdf
+  References:
+    - [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
+      Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
   """
 
-  def __init__(self, params, name='fibinet', l2_reg=None, **kwargs):
+  def __init__(self, params, name='fibinet', **kwargs):
     super(FiBiNet, self).__init__(name, **kwargs)
     self._config = params.get_pb_config()
     if self._config.HasField('mlp'):
-      # self.final_dnn = dnn.DNN(
-      #     self._config.mlp,
-      #     kwargs['l2_reg'] if 'l2_reg' in kwargs else None,
-      #     name='%s_fibinet_mlp' % self.name,
-      #     is_training=False)
       p = Parameter.make_from_pb(self._config.mlp)
-      self.final_dnn = MLP(p, name=name, l2_reg=l2_reg)
+      p.l2_regularizer = params.l2_regularizer
+      self.final_mlp = MLP(p, name=name)
     else:
-      self.final_dnn = None
+      self.final_mlp = None
 
   def call(self, inputs, training=None, **kwargs):
     feature_list = []
 
     params = Parameter.make_from_pb(self._config.senet)
-    senet = SENet(params, name='%s_senet' % self.name)
+    senet = SENet(params, name='%s/senet' % self.name)
     senet_output = senet(inputs)
     feature_list.append(senet_output)
 
     if self._config.HasField('bilinear'):
       params = Parameter.make_from_pb(self._config.bilinear)
-      bilinear = BiLinear(params, name='%s_bilinear' % self.name)
+      bilinear = BiLinear(params, name='%s/bilinear' % self.name)
       bilinear_output = bilinear(inputs)
       feature_list.append(bilinear_output)
 
@@ -224,6 +240,6 @@ def call(self, inputs, training=None, **kwargs):
     else:
       feature = feature_list[0]
 
-    if self.final_dnn is not None:
-      feature = self.final_dnn(feature, training=training)
+    if self.final_mlp is not None:
+      feature = self.final_mlp(feature, training=training)
     return feature
diff --git a/easy_rec/python/layers/keras/fm.py b/easy_rec/python/layers/keras/fm.py
deleted file mode 100644
index 56910541f..000000000
--- a/easy_rec/python/layers/keras/fm.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# -*- encoding:utf-8 -*-
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import tensorflow as tf
-
-if tf.__version__ >= '2.0':
-  tf = tf.compat.v1
-
-
-class FM(tf.keras.layers.Layer):
-  """Factorization Machine models pairwise (order-2) feature interactions without linear term and bias.
-
-  References
-    - [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
-  Input shape.
-    - List of 2D tensor with shape: ``(batch_size,embedding_size)``.
-    - Or a 3D tensor with shape: ``(batch_size,field_size,embedding_size)``
-  Output shape
-    - 2D tensor with shape: ``(batch_size, 1)``.
- """ - - def __init__(self, params, name='fm', **kwargs): - super(FM, self).__init__(name, **kwargs) - self.use_variant = params.get_or_default('use_variant', False) - - def call(self, inputs, **kwargs): - if type(inputs) == list: - emb_dims = set(map(lambda x: int(x.shape[-1]), inputs)) - if len(emb_dims) != 1: - dims = ','.join([str(d) for d in emb_dims]) - raise ValueError('all embedding dim must be equal in FM layer:' + dims) - - with tf.name_scope(self.name): - fea = tf.stack(inputs, axis=1) - else: - assert inputs.shape.ndims == 3, 'input of FM layer must be a 3D tensor or a list of 2D tensors' - fea = inputs - - with tf.name_scope(self.name): - square_of_sum = tf.square(tf.reduce_sum(fea, axis=1)) - sum_of_square = tf.reduce_sum(tf.square(fea), axis=1) - cross_term = tf.subtract(square_of_sum, sum_of_square) - if self.use_variant: - cross_term = 0.5 * cross_term - else: - cross_term = 0.5 * tf.reduce_sum(cross_term, axis=-1, keepdims=True) - return cross_term diff --git a/easy_rec/python/layers/keras/dcn.py b/easy_rec/python/layers/keras/interaction.py similarity index 59% rename from easy_rec/python/layers/keras/dcn.py rename to easy_rec/python/layers/keras/interaction.py index 9585893e5..55f56f7a1 100644 --- a/easy_rec/python/layers/keras/dcn.py +++ b/easy_rec/python/layers/keras/interaction.py @@ -1,12 +1,133 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. -"""Implements `Cross` Layer, the cross layer in Deep & Cross Network (DCN).""" - import tensorflow as tf from easy_rec.python.utils.activation import get_activation +class FM(tf.keras.layers.Layer): + """Factorization Machine models pairwise (order-2) feature interactions without linear term and bias. + + References + - [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf) + Input shape. + - List of 2D tensor with shape: ``(batch_size,embedding_size)``. + - Or a 3D tensor with shape: ``(batch_size,field_size,embedding_size)`` + Output shape + - 2D tensor with shape: ``(batch_size, 1)``. + """ + + def __init__(self, params, name='fm', **kwargs): + super(FM, self).__init__(name, **kwargs) + self.use_variant = params.get_or_default('use_variant', False) + + def call(self, inputs, **kwargs): + if type(inputs) == list: + emb_dims = set(map(lambda x: int(x.shape[-1]), inputs)) + if len(emb_dims) != 1: + dims = ','.join([str(d) for d in emb_dims]) + raise ValueError('all embedding dim must be equal in FM layer:' + dims) + with tf.name_scope(self.name): + fea = tf.stack(inputs, axis=1) + else: + assert inputs.shape.ndims == 3, 'input of FM layer must be a 3D tensor or a list of 2D tensors' + fea = inputs + + with tf.name_scope(self.name): + square_of_sum = tf.square(tf.reduce_sum(fea, axis=1)) + sum_of_square = tf.reduce_sum(tf.square(fea), axis=1) + cross_term = tf.subtract(square_of_sum, sum_of_square) + if self.use_variant: + cross_term = 0.5 * cross_term + else: + cross_term = 0.5 * tf.reduce_sum(cross_term, axis=-1, keepdims=True) + return cross_term + + +class DotInteraction(tf.keras.layers.Layer): + """Dot interaction layer of DLRM model.. + + See theory in the DLRM paper: https://arxiv.org/pdf/1906.00091.pdf, + section 2.1.3. Sparse activations and dense activations are combined. + Dot interaction is applied to a batch of input Tensors [e1,...,e_k] of the + same dimension and the output is a batch of Tensors with all distinct pairwise + dot products of the form dot(e_i, e_j) for i <= j if self self_interaction is + True, otherwise dot(e_i, e_j) i < j. 
+
+  Attributes:
+    self_interaction: Boolean indicating if features should self-interact.
+      If it is True, then the diagonal entries of the interaction matrix are
+      also taken.
+    skip_gather: An optimization flag. If it's set then the upper triangle part
+      of the dot interaction matrix dot(e_i, e_j) is set to 0. The resulting
+      activations will be of dimension [num_features * num_features] from which
+      half will be zeros. Otherwise activations will be only the lower triangle
+      part of the interaction matrix. The latter saves space but is much slower.
+    name: String name of the layer.
+  """
+
+  def __init__(self, params, name=None, **kwargs):
+    self._self_interaction = params.get_or_default('self_interaction', False)
+    self._skip_gather = params.get_or_default('skip_gather', False)
+    super(DotInteraction, self).__init__(name=name, **kwargs)
+
+  def call(self, inputs, **kwargs):
+    """Performs the interaction operation on the tensors in the list.
+
+    The tensors represent transformed dense features and embedded categorical
+    features.
+    Pre-condition: The tensors should all have the same shape.
+
+    Args:
+      inputs: List of features with shapes [batch_size, feature_dim].
+
+    Returns:
+      activations: Tensor representing interacted features. It has a dimension
+        `num_features * num_features` if skip_gather is True, otherwise
+        `num_features * (num_features + 1) / 2` if self_interaction is True and
+        `num_features * (num_features - 1) / 2` if self_interaction is False.
+    """
+    if isinstance(inputs, (list, tuple)):
+      # concat_features shape: batch_size, num_features, feature_dim
+      try:
+        concat_features = tf.stack(inputs, axis=1)
+      except (ValueError, tf.errors.InvalidArgumentError) as e:
+        raise ValueError('Input tensors\' dimensions must be equal, original '
+                         'error message: {}'.format(e))
+    else:
+      assert inputs.shape.ndims == 3, 'input of dot func must be a 3D tensor or a list of 2D tensors'
+      concat_features = inputs
+
+    batch_size = tf.shape(concat_features)[0]
+
+    # Interact features, select lower-triangular portion, and re-shape.
+    xactions = tf.matmul(concat_features, concat_features, transpose_b=True)
+    num_features = xactions.shape[-1]
+    ones = tf.ones_like(xactions)
+    if self._self_interaction:
+      # Selecting lower-triangular portion including the diagonal.
+      lower_tri_mask = tf.linalg.band_part(ones, -1, 0)
+      upper_tri_mask = ones - lower_tri_mask
+      out_dim = num_features * (num_features + 1) // 2
+    else:
+      # Selecting lower-triangular portion not including the diagonal.
+      upper_tri_mask = tf.linalg.band_part(ones, 0, -1)
+      lower_tri_mask = ones - upper_tri_mask
+      out_dim = num_features * (num_features - 1) // 2
+
+    if self._skip_gather:
+      # Setting the upper triangle part of the interaction matrix to zeros.
+      activations = tf.where(
+          condition=tf.cast(upper_tri_mask, tf.bool),
+          x=tf.zeros_like(xactions),
+          y=xactions)
+      out_dim = num_features * num_features
+    else:
+      activations = tf.boolean_mask(xactions, lower_tri_mask)
+    activations = tf.reshape(activations, (batch_size, out_dim))
+    return activations
+
+
 class Cross(tf.keras.layers.Layer):
   """Cross Layer in Deep & Cross Network to learn explicit feature interactions.
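
Note: FM, DotInteraction and Cross are now consolidated in interaction.py. A minimal usage sketch of the moved layers (illustrative only: StubParams is a hypothetical stand-in for easy_rec.python.layers.utils.Parameter, of which these layers only use get_or_default):

    import tensorflow as tf

    from easy_rec.python.layers.keras.interaction import DotInteraction
    from easy_rec.python.layers.keras.interaction import FM


    class StubParams(object):
      """Hypothetical stand-in for Parameter; real params come from protobuf."""

      def __init__(self, **kwargs):
        self._kv = kwargs

      def get_or_default(self, key, default):
        return self._kv.get(key, default)


    # Four fields embedded into 8 dims each; both layers accept a list of
    # 2D tensors with shape (batch_size, embedding_size).
    fields = [tf.random.normal([32, 8]) for _ in range(4)]

    fm = FM(StubParams())
    print(fm(fields).shape)  # (32, 1): summed pairwise interactions

    dot = DotInteraction(StubParams(self_interaction=False))
    print(dot(fields).shape)  # (32, 6): 4 * 3 / 2 distinct dot products

With self_interaction=True the last dimension would be 4 * 5 / 2 = 10; with skip_gather=True it would be padded to 4 * 4 = 16.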
diff --git a/easy_rec/python/layers/keras/mask_net.py b/easy_rec/python/layers/keras/mask_net.py
index 8749a1ee8..2e66beb22 100644
--- a/easy_rec/python/layers/keras/mask_net.py
+++ b/easy_rec/python/layers/keras/mask_net.py
@@ -6,9 +6,6 @@
 from easy_rec.python.layers.keras.blocks import MLP
 from easy_rec.python.layers.utils import Parameter
 
-if tf.__version__ >= '2.0':
-  tf = tf.compat.v1
-
 
 class MaskBlock(tf.keras.layers.Layer):
diff --git a/easy_rec/python/model/cmbf.py b/easy_rec/python/model/cmbf.py
index 0f0a8f3aa..a11a30582 100644
--- a/easy_rec/python/model/cmbf.py
+++ b/easy_rec/python/model/cmbf.py
@@ -38,7 +38,7 @@ def __init__(self,
 
   def build_predict_graph(self):
     hidden = self._cmbf_layer(self._is_training, l2_reg=self._l2_reg)
-    final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
+    final_dnn_layer = dnn.DNN(self._model_config.final_mlp, self._l2_reg,
                               'final_dnn', self._is_training)
     all_fea = final_dnn_layer(hidden)
diff --git a/easy_rec/python/model/collaborative_metric_learning.py b/easy_rec/python/model/collaborative_metric_learning.py
index d785e7141..b19537239 100644
--- a/easy_rec/python/model/collaborative_metric_learning.py
+++ b/easy_rec/python/model/collaborative_metric_learning.py
@@ -48,21 +48,22 @@ def __init__(
       raise ValueError('unsupported loss type: %s' %
                        LossType.Name(self._loss_type))
 
-    self._highway_features = {}
-    self._highway_num = len(self._model_config.highway)
-    for _id in range(self._highway_num):
-      highway_cfg = self._model_config.highway[_id]
-      highway_feature, _ = self._input_layer(self._feature_dict,
-                                             highway_cfg.input)
-      self._highway_features[highway_cfg.input] = highway_feature
-
-    self.input_features = []
-    if self._model_config.HasField('input'):
-      input_feature, _ = self._input_layer(self._feature_dict,
-                                           self._model_config.input)
-      self.input_features.append(input_feature)
-
-    self.dnn = copy_obj(self._model_config.dnn)
+    if not self.has_backbone:
+      self._highway_features = {}
+      self._highway_num = len(self._model_config.highway)
+      for _id in range(self._highway_num):
+        highway_cfg = self._model_config.highway[_id]
+        highway_feature, _ = self._input_layer(self._feature_dict,
+                                               highway_cfg.input)
+        self._highway_features[highway_cfg.input] = highway_feature
+
+      self.input_features = []
+      if self._model_config.HasField('input'):
+        input_feature, _ = self._input_layer(self._feature_dict,
+                                             self._model_config.input)
+        self.input_features.append(input_feature)
+
+      self.dnn = copy_obj(self._model_config.dnn)
 
     if self._labels is not None:
       if self._model_config.HasField('session_id'):
@@ -79,32 +80,35 @@ def __init__(
     self.sample_id = None
 
   def build_predict_graph(self):
-    for _id in range(self._highway_num):
-      highway_cfg = self._model_config.highway[_id]
-      highway_fea = tf.layers.batch_normalization(
-          self._highway_features[highway_cfg.input],
-          training=self._is_training,
-          trainable=True,
-          name='highway_%s_bn' % highway_cfg.input)
-      highway_fea = highway(
-          highway_fea,
-          highway_cfg.emb_size,
-          activation=gelu,
-          scope='highway_%s' % _id)
-      print('highway_fea: ', highway_fea)
-      self.input_features.append(highway_fea)
-
-    feature = tf.concat(self.input_features, axis=1)
-
-    num_dnn_layer = len(self.dnn.hidden_units)
-    last_hidden = self.dnn.hidden_units.pop()
-    dnn_net = dnn.DNN(self.dnn, self._l2_reg, 'dnn', self._is_training)
-    net_output = dnn_net(feature)
-    tower_emb = tf.layers.dense(
-        inputs=net_output,
-        units=last_hidden,
-        kernel_regularizer=self._l2_reg,
-        name='dnn/dnn_%d' % (num_dnn_layer - 1))
+    if self.has_backbone:
+      tower_emb = self.backbone
+    else:
+      for _id in range(self._highway_num):
+        highway_cfg = self._model_config.highway[_id]
+        highway_fea = tf.layers.batch_normalization(
+            self._highway_features[highway_cfg.input],
+            training=self._is_training,
+            trainable=True,
+            name='highway_%s_bn' % highway_cfg.input)
+        highway_fea = highway(
+            highway_fea,
+            highway_cfg.emb_size,
+            activation=gelu,
+            scope='highway_%s' % _id)
+        print('highway_fea: ', highway_fea)
+        self.input_features.append(highway_fea)
+
+      feature = tf.concat(self.input_features, axis=1)
+
+      num_dnn_layer = len(self.dnn.hidden_units)
+      last_hidden = self.dnn.hidden_units.pop()
+      dnn_net = dnn.DNN(self.dnn, self._l2_reg, 'dnn', self._is_training)
+      net_output = dnn_net(feature)
+      tower_emb = tf.layers.dense(
+          inputs=net_output,
+          units=last_hidden,
+          kernel_regularizer=self._l2_reg,
+          name='dnn/dnn_%d' % (num_dnn_layer - 1))
 
     if self._model_config.output_l2_normalized_emb:
       norm_emb = tf.nn.l2_normalize(tower_emb, axis=-1)
diff --git a/easy_rec/python/model/dcn.py b/easy_rec/python/model/dcn.py
index fcfa7e780..2a460163a 100644
--- a/easy_rec/python/model/dcn.py
+++ b/easy_rec/python/model/dcn.py
@@ -60,7 +60,7 @@ def build_predict_graph(self):
       tower_fea_arr.append(cross_tensor)
     # final tower
     all_fea = tf.concat(tower_fea_arr, axis=1)
-    final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
+    final_dnn_layer = dnn.DNN(self._model_config.final_mlp, self._l2_reg,
                               'final_dnn', self._is_training)
     all_fea = final_dnn_layer(all_fea)
     output = tf.layers.dense(all_fea, self._num_class, name='output')
diff --git a/easy_rec/python/model/deepfm.py b/easy_rec/python/model/deepfm.py
index d1414c050..0ead36e26 100644
--- a/easy_rec/python/model/deepfm.py
+++ b/easy_rec/python/model/deepfm.py
@@ -39,7 +39,7 @@ def __init__(self,
 
   def build_input_layer(self, model_config, feature_configs):
     # overwrite create input_layer to support wide_output_dim
-    has_final = len(model_config.deepfm.final_dnn.hidden_units) > 0
+    has_final = len(model_config.deepfm.final_mlp.hidden_units) > 0
     if not has_final:
       assert model_config.deepfm.wide_output_dim == model_config.num_class
     self._wide_output_dim = model_config.deepfm.wide_output_dim
@@ -60,9 +60,9 @@ def build_predict_graph(self):
     deep_fea = deep_layer(self._deep_features)
 
     # Final
-    if len(self._model_config.final_dnn.hidden_units) > 0:
+    if len(self._model_config.final_mlp.hidden_units) > 0:
       all_fea = tf.concat([wide_fea, fm_fea, deep_fea], axis=1)
-      final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
+      final_dnn_layer = dnn.DNN(self._model_config.final_mlp, self._l2_reg,
                                 'final_dnn', self._is_training)
       all_fea = final_dnn_layer(all_fea)
       output = tf.layers.dense(
diff --git a/easy_rec/python/model/easy_rec_model.py b/easy_rec/python/model/easy_rec_model.py
index fe9a20ef8..cb6c8a802 100644
--- a/easy_rec/python/model/easy_rec_model.py
+++ b/easy_rec/python/model/easy_rec_model.py
@@ -13,7 +13,6 @@
 from easy_rec.python.compat import regularizers
 from easy_rec.python.layers import input_layer
 from easy_rec.python.layers.backbone import Backbone
-from easy_rec.python.layers.sequence_encoder import SequenceEncoder
 from easy_rec.python.utils import constant
 from easy_rec.python.utils import estimator_utils
 from easy_rec.python.utils import restore_filter
diff --git a/easy_rec/python/model/multi_tower.py b/easy_rec/python/model/multi_tower.py
index 5cdd89ba5..cb0aa6233 100644
--- a/easy_rec/python/model/multi_tower.py
+++ b/easy_rec/python/model/multi_tower.py
@@ -52,7 +52,7 @@ def build_predict_graph(self):
       tower_fea_arr.append(tower_fea)
 
     all_fea = tf.concat(tower_fea_arr, axis=1)
-    final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
+    final_dnn_layer = dnn.DNN(self._model_config.final_mlp, self._l2_reg,
                               'final_dnn', self._is_training)
     all_fea = final_dnn_layer(all_fea)
     output = tf.layers.dense(all_fea, self._num_class, name='output')
diff --git a/easy_rec/python/model/multi_tower_bst.py b/easy_rec/python/model/multi_tower_bst.py
index 4cbc9fd29..478d26a6c 100644
--- a/easy_rec/python/model/multi_tower_bst.py
+++ b/easy_rec/python/model/multi_tower_bst.py
@@ -180,7 +180,7 @@ def build_predict_graph(self):
       tower_fea_arr.append(tower_fea)
 
     all_fea = tf.concat(tower_fea_arr, axis=1)
-    final_dnn = dnn.DNN(self._model_config.final_dnn, self._l2_reg, 'final_dnn',
+    final_dnn = dnn.DNN(self._model_config.final_mlp, self._l2_reg, 'final_dnn',
                         self._is_training)
     all_fea = final_dnn(all_fea)
     output = tf.layers.dense(all_fea, self._num_class, name='output')
diff --git a/easy_rec/python/model/multi_tower_din.py b/easy_rec/python/model/multi_tower_din.py
index e586da1cf..7a1356caa 100644
--- a/easy_rec/python/model/multi_tower_din.py
+++ b/easy_rec/python/model/multi_tower_din.py
@@ -120,7 +120,7 @@ def build_predict_graph(self):
       tower_fea_arr.append(tower_fea)
 
     all_fea = tf.concat(tower_fea_arr, axis=1)
-    final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
+    final_dnn_layer = dnn.DNN(self._model_config.final_mlp, self._l2_reg,
                               'final_dnn', self._is_training)
     all_fea = final_dnn_layer(all_fea)
     output = tf.layers.dense(all_fea, self._num_class, name='output')
diff --git a/easy_rec/python/model/multi_tower_recall.py b/easy_rec/python/model/multi_tower_recall.py
index 8f576944e..101ad36cf 100644
--- a/easy_rec/python/model/multi_tower_recall.py
+++ b/easy_rec/python/model/multi_tower_recall.py
@@ -57,7 +57,7 @@ def build_predict_graph(self):
     tower_fea_arr.append(item_tower_emb)
 
     all_fea = tf.concat(tower_fea_arr, axis=-1)
-    final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
+    final_dnn_layer = dnn.DNN(self._model_config.final_mlp, self._l2_reg,
                               'final_dnn', self._is_training)
    all_fea = final_dnn_layer(all_fea)
     output = tf.layers.dense(all_fea, 1, name='output')
diff --git a/easy_rec/python/model/uniter.py b/easy_rec/python/model/uniter.py
index 40dfc8cb1..9479ce639 100644
--- a/easy_rec/python/model/uniter.py
+++ b/easy_rec/python/model/uniter.py
@@ -37,7 +37,7 @@ def __init__(self,
 
   def build_predict_graph(self):
     hidden = self._uniter_layer(self._is_training, l2_reg=self._l2_reg)
-    final_dnn_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
+    final_dnn_layer = dnn.DNN(self._model_config.final_mlp, self._l2_reg,
                               'final_dnn', self._is_training)
    all_fea = final_dnn_layer(hidden)
diff --git a/easy_rec/python/model/wide_and_deep.py b/easy_rec/python/model/wide_and_deep.py
index f841ed049..e0850abe4 100755
--- a/easy_rec/python/model/wide_and_deep.py
+++ b/easy_rec/python/model/wide_and_deep.py
@@ -34,7 +34,7 @@ def __init__(self,
 
   def build_input_layer(self, model_config, feature_configs):
     # overwrite create input_layer to support wide_output_dim
-    has_final = len(model_config.wide_and_deep.final_dnn.hidden_units) > 0
+    has_final = len(model_config.wide_and_deep.final_mlp.hidden_units) > 0
     self._wide_output_dim = model_config.wide_and_deep.wide_output_dim
     if not has_final:
       model_config.wide_and_deep.wide_output_dim = model_config.num_class
@@ -55,11 +55,11 @@ def build_predict_graph(self):
     logging.info('output deep features dimension: %d' %
                  deep_fea.get_shape()[-1])
 
-    has_final = len(self._model_config.final_dnn.hidden_units) > 0
+    has_final = len(self._model_config.final_mlp.hidden_units) > 0
     print('wide_deep has_final_dnn layers = %d' % has_final)
     if has_final:
       all_fea = tf.concat([wide_fea, deep_fea], axis=1)
-      final_layer = dnn.DNN(self._model_config.final_dnn, self._l2_reg,
+      final_layer = dnn.DNN(self._model_config.final_mlp, self._l2_reg,
                             'final_dnn', self._is_training)
       all_fea = final_layer(all_fea)
       output = tf.layers.dense(
@@ -87,7 +87,7 @@ def get_grouped_vars(self):
 
     Return:
       list of list of variables.
     """
-    assert len(self._model_config.final_dnn.hidden_units) == 0, \
+    assert len(self._model_config.final_mlp.hidden_units) == 0, \
        'if use different optimizers for wide group and deep group, '\
        + ' final_dnn should not be set.'
     wide_vars = []
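
Note: every model above now reads final_mlp where it previously read final_dnn, so the corresponding proto fields must be renamed in lockstep and existing configs updated. A sketch of the renamed field as the code now consumes it (the proto module and message names below are assumptions inferred from the file layout, not confirmed by this patch):

    # Assumes multi_tower.proto now declares `final_mlp` (a DNN message)
    # in place of `final_dnn`; names are inferred, not confirmed here.
    from easy_rec.python.protos.multi_tower_pb2 import MultiTower

    cfg = MultiTower()
    cfg.final_mlp.hidden_units.extend([128, 64])  # illustrative sizes

    # Mirrors the has_final checks in deepfm.py and wide_and_deep.py:
    has_final = len(cfg.final_mlp.hidden_units) > 0
    print(has_final)  # True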