Skip to content

Commit

Permalink
add test case
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Dec 27, 2023
1 parent 6b3eba8 commit 71d91eb
Show file tree
Hide file tree
Showing 4 changed files with 331 additions and 3 deletions.
2 changes: 1 addition & 1 deletion easy_rec/python/layers/keras/fibinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def build(self, input_shape):
kernel_initializer='he_normal',
name='W1')
self.excite_layer = Dense(
units=emb_size, kernel_regularizer='glorot_normal', name='W2')
units=emb_size, kernel_initializer='glorot_normal', name='W2')

def call(self, inputs, **kwargs):
g = self.config.num_squeeze_group
Expand Down
5 changes: 3 additions & 2 deletions easy_rec/python/layers/keras/mask_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def __init__(self, params, name='mask_block', reuse=None, **kwargs):
self.reuse = reuse

def build(self, input_shape):
input_dim = int(input_shape[0].shape[-1])
mask_input_dim = int(input_shape[1].shape[-1])
assert len(input_shape) >= 2, 'MaskBlock must has at least two inputs'
input_dim = int(input_shape[0][-1])
mask_input_dim = int(input_shape[1][-1])
if self.config.HasField('reduction_factor'):
aggregation_size = int(mask_input_dim * self.config.reduction_factor)
elif self.config.HasField('aggregation_size') is not None:
Expand Down
9 changes: 9 additions & 0 deletions easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,15 @@ def test_train_with_multi_worker_mirror(self):
self._test_dir)
self.assertTrue(self._success)

@unittest.skipIf(
LooseVersion(tf.__version__) != LooseVersion('2.3.0'),
'MultiWorkerMirroredStrategy need tf version == 2.3')
def test_train_with_multi_worker_mirror(self):
self._success = test_utils.test_distributed_train_eval(
'samples/model_config/mmoe_mirrored_strategy_on_taobao.config',
self._test_dir)
self.assertTrue(self._success)

def test_fg_dtype(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/taobao_fg_test_dtype.config', self._test_dir)
Expand Down
318 changes: 318 additions & 0 deletions samples/model_config/mmoe_mirrored_strategy_on_taobao.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
train_input_path: "data/test/tb_data/taobao_train_data"
eval_input_path: "data/test/tb_data/taobao_test_data"
model_dir: "experiments/mmoe_mirrored_strategy_taobao_ckpt"

train_config {
optimizer_config {
adam_optimizer {
learning_rate {
exponential_decay_learning_rate {
initial_learning_rate: 0.001
decay_steps: 1000
decay_factor: 0.5
min_learning_rate: 1e-07
}
}
}
use_moving_average: false
}
train_distribute: MultiWorkerMirroredStrategy
num_gpus_per_worker: 2
num_steps: 200
sync_replicas: true
save_checkpoints_steps: 100
log_step_count_steps: 100
}
data_config {
batch_size: 4096
label_fields: "clk"
label_fields: "buy"
prefetch_size: 32
input_type: CSVInput
input_fields {
input_name: "clk"
input_type: INT32
}
input_fields {
input_name: "buy"
input_type: INT32
}
input_fields {
input_name: "pid"
input_type: STRING
}
input_fields {
input_name: "adgroup_id"
input_type: STRING
}
input_fields {
input_name: "cate_id"
input_type: STRING
}
input_fields {
input_name: "campaign_id"
input_type: STRING
}
input_fields {
input_name: "customer"
input_type: STRING
}
input_fields {
input_name: "brand"
input_type: STRING
}
input_fields {
input_name: "user_id"
input_type: STRING
}
input_fields {
input_name: "cms_segid"
input_type: STRING
}
input_fields {
input_name: "cms_group_id"
input_type: STRING
}
input_fields {
input_name: "final_gender_code"
input_type: STRING
}
input_fields {
input_name: "age_level"
input_type: STRING
}
input_fields {
input_name: "pvalue_level"
input_type: STRING
}
input_fields {
input_name: "shopping_level"
input_type: STRING
}
input_fields {
input_name: "occupation"
input_type: STRING
}
input_fields {
input_name: "new_user_class_level"
input_type: STRING
}
input_fields {
input_name: "tag_category_list"
input_type: STRING
}
input_fields {
input_name: "tag_brand_list"
input_type: STRING
}
input_fields {
input_name: "price"
input_type: INT32
}
}
feature_config: {
features {
input_names: "pid"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
features {
input_names: "adgroup_id"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100000
}
features {
input_names: "cate_id"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10000
}
features {
input_names: "campaign_id"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100000
}
features {
input_names: "customer"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100000
}
features {
input_names: "brand"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100000
}
features {
input_names: "user_id"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100000
}
features {
input_names: "cms_segid"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100
}
features {
input_names: "cms_group_id"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 100
}
features {
input_names: "final_gender_code"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
features {
input_names: "age_level"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
features {
input_names: "pvalue_level"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
features {
input_names: "shopping_level"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
features {
input_names: "occupation"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
features {
input_names: "new_user_class_level"
feature_type: IdFeature
embedding_dim: 16
hash_bucket_size: 10
}
features {
input_names: "tag_category_list"
feature_type: TagFeature
embedding_dim: 16
hash_bucket_size: 100000
separator: "|"
}
features {
input_names: "tag_brand_list"
feature_type: TagFeature
embedding_dim: 16
hash_bucket_size: 100000
separator: "|"
}
features {
input_names: "price"
feature_type: IdFeature
embedding_dim: 16
num_buckets: 50
}
}
model_config {
model_name: "MMoE"
model_class: "MultiTaskModel"
feature_groups {
group_name: "all"
feature_names: "user_id"
feature_names: "cms_segid"
feature_names: "cms_group_id"
feature_names: "age_level"
feature_names: "pvalue_level"
feature_names: "shopping_level"
feature_names: "occupation"
feature_names: "new_user_class_level"
feature_names: "adgroup_id"
feature_names: "cate_id"
feature_names: "campaign_id"
feature_names: "customer"
feature_names: "brand"
feature_names: "price"
feature_names: "pid"
feature_names: "tag_category_list"
feature_names: "tag_brand_list"
wide_deep: DEEP
}
backbone {
blocks {
name: 'all'
inputs {
feature_group_name: 'all'
}
input_layer {
only_output_feature_list: true
}
}
blocks {
name: "senet"
inputs {
block_name: "all"
}
keras_layer {
class_name: 'SENet'
senet {
reduction_ratio: 4
}
}
}
blocks {
name: "mmoe"
inputs {
block_name: "senet"
}
keras_layer {
class_name: 'MMoE'
mmoe {
num_task: 2
num_expert: 3
expert_mlp {
hidden_units: [256, 128]
}
}
}
}
}
model_params {
task_towers {
tower_name: "ctr"
label_name: "clk"
dnn {
hidden_units: [128, 64]
}
num_class: 1
weight: 1.0
loss_type: CLASSIFICATION
metrics_set: {
auc {}
}
}
task_towers {
tower_name: "cvr"
label_name: "buy"
dnn {
hidden_units: [128, 64]
}
num_class: 1
weight: 1.0
loss_type: CLASSIFICATION
metrics_set: {
auc {}
}
}
l2_regularization: 1e-06
}
embedding_regularization: 5e-05
}

0 comments on commit 71d91eb

Please sign in to comment.