From 606aa9bd87f7e6cebcae3b341d56d93553155741 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Fri, 2 Feb 2024 11:44:21 +0800 Subject: [PATCH] add user embedding and item embedding outputs for match model --- easy_rec/python/model/dssm.py | 12 +++++++++--- easy_rec/python/model/mind.py | 15 ++++++++++++--- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/easy_rec/python/model/dssm.py b/easy_rec/python/model/dssm.py index bf3dfdc88..334850d94 100644 --- a/easy_rec/python/model/dssm.py +++ b/easy_rec/python/model/dssm.py @@ -105,15 +105,21 @@ def build_predict_graph(self): def get_outputs(self): if self._loss_type == LossType.CLASSIFICATION: - return ['logits', 'probs', 'user_emb', 'item_emb'] + return [ + 'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb', + 'item_tower_emb' + ] elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY: self._prediction_dict['logits'] = tf.squeeze( self._prediction_dict['logits'], axis=-1) self._prediction_dict['probs'] = tf.nn.sigmoid( self._prediction_dict['logits']) - return ['logits', 'probs', 'user_emb', 'item_emb'] + return [ + 'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb', + 'item_tower_emb' + ] elif self._loss_type == LossType.L2_LOSS: - return ['y', 'user_emb', 'item_emb'] + return ['y', 'user_emb', 'item_emb', 'user_tower_emb', 'item_tower_emb'] else: raise ValueError('invalid loss type: %s' % str(self._loss_type)) diff --git a/easy_rec/python/model/mind.py b/easy_rec/python/model/mind.py index 270060297..0e47f79d8 100644 --- a/easy_rec/python/model/mind.py +++ b/easy_rec/python/model/mind.py @@ -423,14 +423,23 @@ def build_metric_graph(self, eval_config): def get_outputs(self): if self._loss_type == LossType.CLASSIFICATION: - return ['logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num'] + return [ + 'logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num', + 'user_interests', 'item_tower_emb' + ] elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY: self._prediction_dict['logits'] = tf.squeeze( self._prediction_dict['logits'], axis=-1) self._prediction_dict['probs'] = tf.nn.sigmoid( self._prediction_dict['logits']) - return ['logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num'] + return [ + 'logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num', + 'user_interests', 'item_tower_emb' + ] elif self._loss_type == LossType.L2_LOSS: - return ['y', 'user_emb', 'item_emb', 'user_emb_num'] + return [ + 'y', 'user_emb', 'item_emb', 'user_emb_num', 'user_interests', + 'item_tower_emb' + ] else: raise ValueError('invalid loss type: %s' % str(self._loss_type))