Skip to content

Commit

Permalink
[feat]: format backbone code, add recurrent and sequential layer
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Jun 19, 2023
1 parent 7d0e350 commit 136cf37
Show file tree
Hide file tree
Showing 18 changed files with 232 additions and 227 deletions.
9 changes: 6 additions & 3 deletions easy_rec/python/layers/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from easy_rec.python.layers.keras import MLP
from easy_rec.python.layers.utils import Parameter
from easy_rec.python.protos import backbone_pb2
from easy_rec.python.protos import keras_layer_pb2
from easy_rec.python.utils.dag import DAG
from easy_rec.python.utils.load_class import load_keras_layer

Expand Down Expand Up @@ -204,13 +203,17 @@ def call_layer(self, inputs, config, name, training):
output = inputs
for i in range(conf.num_steps):
name_i = '%s_%d' % (name, i)
output_i = self.call_keras_layer(conf.keras_layer, output, name_i, training)
layer = conf.keras_layer
output_i = self.call_keras_layer(layer, output, name_i, training)
if fixed_input_index >= 0:
j = 0
for idx in range(len(output)):
if idx == fixed_input_index:
continue
output[idx] = output_i[j] if type(output_i) in (tuple, list) else output_i
if type(output_i) in (tuple, list):
output[idx] = output_i[j]
else:
output[idx] = output_i
j += 1
else:
output = output_i
Expand Down
6 changes: 3 additions & 3 deletions easy_rec/python/layers/keras/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from .blocks import MLP
from .blocks import Highway
from .bst import BST
from .dcn import Cross
from .din import DIN
from .dot_interaction import DotInteraction
from .fibinet import BiLinear
from .fibinet import FiBiNet
from .fibinet import SENet
from .fm import FM
from .interaction import FM
from .interaction import Cross
from .interaction import DotInteraction
from .mask_net import MaskBlock
from .mask_net import MaskNet
from .numerical_embedding import AutoDisEmbedding
Expand Down
89 changes: 0 additions & 89 deletions easy_rec/python/layers/keras/dot_interaction.py

This file was deleted.

66 changes: 41 additions & 25 deletions easy_rec/python/layers/keras/fibinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import tensorflow as tf

from easy_rec.python.layers import dnn
from easy_rec.python.layers.common_layers import layer_norm
from easy_rec.python.layers.keras.blocks import MLP
from easy_rec.python.layers.utils import Parameter
Expand All @@ -15,18 +14,27 @@


class SENet(tf.keras.layers.Layer):
"""SENet+ Layer used in FiBiNET,支持不同field的embedding dimension不等.
"""SENET Layer used in FiBiNET.
arxiv: 2209.05016
Input shape
- A list of 2D tensor with shape: ``(batch_size,embedding_size)``.
The ``embedding_size`` of each field can have different value.
Output shape
- A 2D tensor with shape: ``(batch_size,sum_of_embedding_size)``.
References:
1. [FiBiNET](https://arxiv.org/pdf/1905.09433.pdf)
Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction
2. [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
"""

def __init__(self, params, name='SENet', **kwargs):
super(SENet, self).__init__(name, **kwargs)
self.config = params.get_pb_config()

def call(self, inputs, **kwargs):
"""embedding_list: - A list of 2D tensor with shape: ``(batch_size,embedding_size)``."""
print('SENET layer with %d inputs' % len(inputs))
g = self.config.num_squeeze_group
for emb in inputs:
assert emb.shape.ndims == 2, 'field embeddings must be rank 2 tensors'
Expand Down Expand Up @@ -88,14 +96,26 @@ def _full_interaction(v_i, v_j):


class BiLinear(tf.keras.layers.Layer):
"""双线性特征交互层,支持不同field embeddings的size不等.
"""BilinearInteraction Layer used in FiBiNET.
Input shape
- A list of 2D tensor with shape: ``(batch_size,embedding_size)``.
Its length is ``filed_size``.
The ``embedding_size`` of each field can have different value.
arxiv: 2209.05016
Output shape
- 2D tensor with shape: ``(batch_size,output_size)``.
Attributes:
num_output_units: 输出的size
type: ['all', 'each', 'interaction'],支持其中一种
use_plus: 是否使用bi-linear+
num_output_units: the number of output units
type: ['all', 'each', 'interaction'], types of bilinear functions used in this layer
use_plus: whether to use bi-linear+
References:
1. [FiBiNET](https://arxiv.org/pdf/1905.09433.pdf)
Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction
2. [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
"""

def __init__(self, params, name='bilinear', **kwargs):
Expand Down Expand Up @@ -186,36 +206,32 @@ def call(self, inputs, **kwargs):
class FiBiNet(tf.keras.layers.Layer):
"""FiBiNet++:Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction.
This is almost an exact implementation of the original FiBiNet++ model.
See the original paper:
https://arxiv.org/pdf/2209.05016.pdf
References:
- [FiBiNet++](https://arxiv.org/pdf/2209.05016.pdf)
Improving FiBiNet by Greatly Reducing Model Size for CTR Prediction
"""

def __init__(self, params, name='fibinet', l2_reg=None, **kwargs):
def __init__(self, params, name='fibinet', **kwargs):
super(FiBiNet, self).__init__(name, **kwargs)
self._config = params.get_pb_config()
if self._config.HasField('mlp'):
# self.final_dnn = dnn.DNN(
# self._config.mlp,
# kwargs['l2_reg'] if 'l2_reg' in kwargs else None,
# name='%s_fibinet_mlp' % self.name,
# is_training=False)
p = Parameter.make_from_pb(self._config.mlp)
self.final_dnn = MLP(p, name=name, l2_reg=l2_reg)
p.l2_regularizer = params.l2_regularizer
self.final_mlp = MLP(p, name=name)
else:
self.final_dnn = None
self.final_mlp = None

def call(self, inputs, training=None, **kwargs):
feature_list = []

params = Parameter.make_from_pb(self._config.senet)
senet = SENet(params, name='%s_senet' % self.name)
senet = SENet(params, name='%s/senet' % self.name)
senet_output = senet(inputs)
feature_list.append(senet_output)

if self._config.HasField('bilinear'):
params = Parameter.make_from_pb(self._config.bilinear)
bilinear = BiLinear(params, name='%s_bilinear' % self.name)
bilinear = BiLinear(params, name='%s/bilinear' % self.name)
bilinear_output = bilinear(inputs)
feature_list.append(bilinear_output)

Expand All @@ -224,6 +240,6 @@ def call(self, inputs, training=None, **kwargs):
else:
feature = feature_list[0]

if self.final_dnn is not None:
feature = self.final_dnn(feature, training=training)
if self.final_mlp is not None:
feature = self.final_mlp(feature, training=training)
return feature
46 changes: 0 additions & 46 deletions easy_rec/python/layers/keras/fm.py

This file was deleted.

Loading

0 comments on commit 136cf37

Please sign in to comment.