Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix of SENet when run with MirroredStrategy #443

Merged
merged 32 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
d39af13
fix bug of `package` syntax of backbone
yangxudong Sep 16, 2023
8ee5af2
fix bug of `package` syntax of backbone
yangxudong Sep 16, 2023
8593b79
add document and test cases of BST & DIN backbone model
yangxudong Sep 17, 2023
c8e1257
modify input layer to support reuse
yangxudong Sep 18, 2023
bab2cea
add support for contrastive learning
yangxudong Sep 18, 2023
bd6dbe8
modify document
yangxudong Sep 18, 2023
670458e
add powerful support for contrastive learning
yangxudong Sep 18, 2023
a357cb9
add powerful support for contrastive learning
yangxudong Sep 19, 2023
61e238f
modify document
yangxudong Sep 19, 2023
4141461
modify document
yangxudong Sep 19, 2023
02d4dce
fix bug of bst when output all token embedding
yangxudong Sep 19, 2023
8672e47
add documents
yangxudong Sep 19, 2023
5bbead1
decrease num_steps of test cases
yangxudong Sep 19, 2023
5cec266
Merge branch 'master' of https://github.com/alibaba/EasyRec into bug_fix
yangxudong Sep 20, 2023
798d72f
fix bug of NaN output of tf.norm
yangxudong Sep 27, 2023
a7cc673
Merge branch 'master' into bug_fix
yangxudong Dec 13, 2023
5157f50
fix doc build problem
yangxudong Dec 13, 2023
2ee7015
fix doc build problem
yangxudong Dec 13, 2023
5cafa73
fix doc build problem
yangxudong Dec 13, 2023
ac9fb01
fix doc build problem
yangxudong Dec 14, 2023
7e8fa34
Merge branch 'master' of https://github.com/alibaba/EasyRec into bug_fix
yangxudong Dec 14, 2023
40ed612
Merge branch 'master' of https://github.com/alibaba/EasyRec into bug_fix
yangxudong Dec 21, 2023
27c95df
fix doc build problem
yangxudong Dec 22, 2023
db73211
fix doc build problem
yangxudong Dec 25, 2023
6ec43f7
fix bug of SENet when run with tf.distribute.MirroredStrategy
yangxudong Dec 26, 2023
6b3eba8
fix bug of SENet when run with tf.distribute.MirroredStrategy
yangxudong Dec 26, 2023
71d91eb
add test case
yangxudong Dec 27, 2023
39d3c05
add test case
yangxudong Dec 27, 2023
0d83b0b
add test case
yangxudong Dec 27, 2023
e524243
add test case
yangxudong Dec 27, 2023
0557c9b
add test case
yangxudong Dec 28, 2023
8572072
add LayerNormalization Layer
yangxudong Dec 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion easy_rec/python/compat/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import os
import threading
import time
from distutils.version import LooseVersion

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
Expand Down
13 changes: 8 additions & 5 deletions easy_rec/python/layers/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, config, features, input_layer, l2_reg=None):
self._block_outputs = {}
self._package_input = None
reuse = None if config.name == 'backbone' else tf.AUTO_REUSE
input_feature_groups = set()
input_feature_groups = {}

for block in config.blocks:
if len(block.inputs) == 0:
Expand All @@ -71,9 +71,9 @@ def __init__(self, config, features, input_layer, l2_reg=None):
if group in input_feature_groups:
logging.warning('input `%s` already exists in other block' % group)
else:
input_feature_groups.add(group)
input_fn = EnhancedInputLayer(self._input_layer, self._features,
group, reuse)
input_feature_groups[group] = input_fn
self._name_to_layer[block.name] = input_fn
else:
self.define_layers(layer, block, block.name, reuse)
Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(self, config, features, input_layer, l2_reg=None):
if iname in self._name_to_blocks:
assert iname != name, 'input name can not equal to block name:' + iname
self._dag.add_edge(iname, name)
elif iname not in input_feature_groups:
else:
is_fea_group = input_type == 'feature_group_name'
if is_fea_group and input_layer.has_group(iname):
logging.info('adding an input_layer block: ' + iname)
Expand All @@ -129,8 +129,11 @@ def __init__(self, config, features, input_layer, l2_reg=None):
self._name_to_blocks[iname] = new_block
self._dag.add_node(iname)
self._dag.add_edge(iname, name)
input_feature_groups.add(iname)
fn = EnhancedInputLayer(self._input_layer, self._features, iname)
if iname in input_feature_groups:
fn = input_feature_groups[iname]
else:
fn = EnhancedInputLayer(self._input_layer, self._features, iname)
input_feature_groups[iname] = fn
self._name_to_layer[iname] = fn
elif Package.has_backbone_block(iname):
backbone = Package.__packages['backbone']
Expand Down
45 changes: 23 additions & 22 deletions easy_rec/python/layers/keras/fibinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,31 @@ def __init__(self, params, name='SENet', reuse=None, **kwargs):
self.config = params.get_pb_config()
self.reuse = reuse

def call(self, inputs, **kwargs):
def build(self, input_shape):
g = self.config.num_squeeze_group
for emb in inputs:
assert emb.shape.ndims == 2, 'field embeddings must be rank 2 tensors'
dim = int(emb.shape[-1])
emb_size = 0
for shape in input_shape:
assert shape.ndims == 2, 'field embeddings must be rank 2 tensors'
dim = int(shape[-1])
assert dim >= g and dim % g == 0, 'field embedding dimension %d must be divisible by %d' % (
dim, g)
emb_size += dim

field_size = len(inputs)
feature_size_list = [emb.shape.as_list()[-1] for emb in inputs]
r = self.config.reduction_ratio
field_size = len(input_shape)
reduction_size = max(1, field_size * g * 2 // r)
initializer = tf.keras.initializers.VarianceScaling()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a test case

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there already has one: mmoe_backbone_on_taobao.config

self.reduce_layer = tf.keras.layers.Dense(
units=reduction_size,
activation='relu',
kernel_initializer=initializer,
name='W1')
init = tf.keras.initializers.glorot_normal()
self.excite_layer = tf.keras.layers.Dense(
units=emb_size, kernel_regularizer=init, name='W2')

def call(self, inputs, **kwargs):
g = self.config.num_squeeze_group

# Squeeze
# embedding dimension 必须能被 g 整除
Expand All @@ -59,22 +74,8 @@ def call(self, inputs, **kwargs):
z = tf.concat(squeezed, axis=1) # [bs, field_size * num_groups * 2]

# Excitation
r = self.config.reduction_ratio
reduction_size = max(1, field_size * g * 2 // r)

a1 = tf.layers.dense(
z,
reduction_size,
kernel_initializer=tf.initializers.variance_scaling(),
activation=tf.nn.relu,
reuse=self.reuse,
name='%s/W1' % self.name)
weights = tf.layers.dense(
a1,
sum(feature_size_list),
kernel_initializer=tf.glorot_normal_initializer(),
reuse=self.reuse,
name='%s/W2' % self.name)
a1 = self.reduce_layer(z)
weights = self.excite_layer(a1)

# Re-weight
inputs = tf.concat(inputs, axis=-1)
Expand Down
2 changes: 1 addition & 1 deletion easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import threading
import time
import unittest
from distutils.version import LooseVersion

import numpy as np
import six
import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.platform import gfile

from easy_rec.python.main import predict
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ multi_line_output = 7
force_single_line = true
known_standard_library = setuptools
known_first_party = easy_rec
known_third_party = absl,common_io,distutils,docutils,eas_prediction,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
known_third_party = absl,common_io,docutils,eas_prediction,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
no_lines_before = LOCALFOLDER
default_section = THIRDPARTY
skip = easy_rec/python/protos
Expand Down
Loading