diff --git a/easy_rec/python/layers/keras/fibinet.py b/easy_rec/python/layers/keras/fibinet.py index 18c75e164..59ad01bd6 100644 --- a/easy_rec/python/layers/keras/fibinet.py +++ b/easy_rec/python/layers/keras/fibinet.py @@ -58,7 +58,7 @@ def build(self, input_shape): kernel_initializer='he_normal', name='W1') self.excite_layer = Dense( - units=emb_size, kernel_regularizer='glorot_normal', name='W2') + units=emb_size, kernel_initializer='glorot_normal', name='W2') def call(self, inputs, **kwargs): g = self.config.num_squeeze_group diff --git a/easy_rec/python/layers/keras/mask_net.py b/easy_rec/python/layers/keras/mask_net.py index ac37a0a36..f76c9ab41 100644 --- a/easy_rec/python/layers/keras/mask_net.py +++ b/easy_rec/python/layers/keras/mask_net.py @@ -31,8 +31,9 @@ def __init__(self, params, name='mask_block', reuse=None, **kwargs): self.reuse = reuse def build(self, input_shape): - input_dim = int(input_shape[0].shape[-1]) - mask_input_dim = int(input_shape[1].shape[-1]) + assert len(input_shape) >= 2, 'MaskBlock must has at least two inputs' + input_dim = int(input_shape[0][-1]) + mask_input_dim = int(input_shape[1][-1]) if self.config.HasField('reduction_factor'): aggregation_size = int(mask_input_dim * self.config.reduction_factor) elif self.config.HasField('aggregation_size') is not None: diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index 859e2442c..dd2bd8f38 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -880,6 +880,15 @@ def test_train_with_multi_worker_mirror(self): self._test_dir) self.assertTrue(self._success) + @unittest.skipIf( + LooseVersion(tf.__version__) != LooseVersion('2.3.0'), + 'MultiWorkerMirroredStrategy need tf version == 2.3') + def test_train_with_multi_worker_mirror(self): + self._success = test_utils.test_distributed_train_eval( + 'samples/model_config/mmoe_mirrored_strategy_on_taobao.config', + self._test_dir) + self.assertTrue(self._success) + def test_fg_dtype(self): self._success = test_utils.test_single_train_eval( 'samples/model_config/taobao_fg_test_dtype.config', self._test_dir) diff --git a/samples/model_config/mmoe_mirrored_strategy_on_taobao.config b/samples/model_config/mmoe_mirrored_strategy_on_taobao.config new file mode 100644 index 000000000..761f3e739 --- /dev/null +++ b/samples/model_config/mmoe_mirrored_strategy_on_taobao.config @@ -0,0 +1,318 @@ +train_input_path: "data/test/tb_data/taobao_train_data" +eval_input_path: "data/test/tb_data/taobao_test_data" +model_dir: "experiments/mmoe_mirrored_strategy_taobao_ckpt" + +train_config { + optimizer_config { + adam_optimizer { + learning_rate { + exponential_decay_learning_rate { + initial_learning_rate: 0.001 + decay_steps: 1000 + decay_factor: 0.5 + min_learning_rate: 1e-07 + } + } + } + use_moving_average: false + } + train_distribute: MultiWorkerMirroredStrategy + num_gpus_per_worker: 2 + num_steps: 200 + sync_replicas: true + save_checkpoints_steps: 100 + log_step_count_steps: 100 +} +data_config { + batch_size: 4096 + label_fields: "clk" + label_fields: "buy" + prefetch_size: 32 + input_type: CSVInput + input_fields { + input_name: "clk" + input_type: INT32 + } + input_fields { + input_name: "buy" + input_type: INT32 + } + input_fields { + input_name: "pid" + input_type: STRING + } + input_fields { + input_name: "adgroup_id" + input_type: STRING + } + input_fields { + input_name: "cate_id" + input_type: STRING + } + input_fields { + input_name: "campaign_id" + input_type: STRING + } + input_fields { + input_name: "customer" + input_type: STRING + } + input_fields { + input_name: "brand" + input_type: STRING + } + input_fields { + input_name: "user_id" + input_type: STRING + } + input_fields { + input_name: "cms_segid" + input_type: STRING + } + input_fields { + input_name: "cms_group_id" + input_type: STRING + } + input_fields { + input_name: "final_gender_code" + input_type: STRING + } + input_fields { + input_name: "age_level" + input_type: STRING + } + input_fields { + input_name: "pvalue_level" + input_type: STRING + } + input_fields { + input_name: "shopping_level" + input_type: STRING + } + input_fields { + input_name: "occupation" + input_type: STRING + } + input_fields { + input_name: "new_user_class_level" + input_type: STRING + } + input_fields { + input_name: "tag_category_list" + input_type: STRING + } + input_fields { + input_name: "tag_brand_list" + input_type: STRING + } + input_fields { + input_name: "price" + input_type: INT32 + } +} +feature_config: { + features { + input_names: "pid" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features { + input_names: "adgroup_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features { + input_names: "cate_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10000 + } + features { + input_names: "campaign_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features { + input_names: "customer" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features { + input_names: "brand" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features { + input_names: "user_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features { + input_names: "cms_segid" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 + } + features { + input_names: "cms_group_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 + } + features { + input_names: "final_gender_code" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features { + input_names: "age_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features { + input_names: "pvalue_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features { + input_names: "shopping_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features { + input_names: "occupation" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features { + input_names: "new_user_class_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features { + input_names: "tag_category_list" + feature_type: TagFeature + embedding_dim: 16 + hash_bucket_size: 100000 + separator: "|" + } + features { + input_names: "tag_brand_list" + feature_type: TagFeature + embedding_dim: 16 + hash_bucket_size: 100000 + separator: "|" + } + features { + input_names: "price" + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 50 + } +} +model_config { + model_name: "MMoE" + model_class: "MultiTaskModel" + feature_groups { + group_name: "all" + feature_names: "user_id" + feature_names: "cms_segid" + feature_names: "cms_group_id" + feature_names: "age_level" + feature_names: "pvalue_level" + feature_names: "shopping_level" + feature_names: "occupation" + feature_names: "new_user_class_level" + feature_names: "adgroup_id" + feature_names: "cate_id" + feature_names: "campaign_id" + feature_names: "customer" + feature_names: "brand" + feature_names: "price" + feature_names: "pid" + feature_names: "tag_category_list" + feature_names: "tag_brand_list" + wide_deep: DEEP + } + backbone { + blocks { + name: 'all' + inputs { + feature_group_name: 'all' + } + input_layer { + only_output_feature_list: true + } + } + blocks { + name: "senet" + inputs { + block_name: "all" + } + keras_layer { + class_name: 'SENet' + senet { + reduction_ratio: 4 + } + } + } + blocks { + name: "mmoe" + inputs { + block_name: "senet" + } + keras_layer { + class_name: 'MMoE' + mmoe { + num_task: 2 + num_expert: 3 + expert_mlp { + hidden_units: [256, 128] + } + } + } + } + } + model_params { + task_towers { + tower_name: "ctr" + label_name: "clk" + dnn { + hidden_units: [128, 64] + } + num_class: 1 + weight: 1.0 + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } + } + task_towers { + tower_name: "cvr" + label_name: "buy" + dnn { + hidden_units: [128, 64] + } + num_class: 1 + weight: 1.0 + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } + } + l2_regularization: 1e-06 + } + embedding_regularization: 5e-05 +}