diff --git a/easy_rec/python/model/dssm.py b/easy_rec/python/model/dssm.py index bf3dfdc88..334850d94 100644 --- a/easy_rec/python/model/dssm.py +++ b/easy_rec/python/model/dssm.py @@ -105,15 +105,21 @@ def build_predict_graph(self): def get_outputs(self): if self._loss_type == LossType.CLASSIFICATION: - return ['logits', 'probs', 'user_emb', 'item_emb'] + return [ + 'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb', + 'item_tower_emb' + ] elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY: self._prediction_dict['logits'] = tf.squeeze( self._prediction_dict['logits'], axis=-1) self._prediction_dict['probs'] = tf.nn.sigmoid( self._prediction_dict['logits']) - return ['logits', 'probs', 'user_emb', 'item_emb'] + return [ + 'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb', + 'item_tower_emb' + ] elif self._loss_type == LossType.L2_LOSS: - return ['y', 'user_emb', 'item_emb'] + return ['y', 'user_emb', 'item_emb', 'user_tower_emb', 'item_tower_emb'] else: raise ValueError('invalid loss type: %s' % str(self._loss_type)) diff --git a/easy_rec/python/model/mind.py b/easy_rec/python/model/mind.py index 270060297..0e47f79d8 100644 --- a/easy_rec/python/model/mind.py +++ b/easy_rec/python/model/mind.py @@ -423,14 +423,23 @@ def build_metric_graph(self, eval_config): def get_outputs(self): if self._loss_type == LossType.CLASSIFICATION: - return ['logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num'] + return [ + 'logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num', + 'user_interests', 'item_tower_emb' + ] elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY: self._prediction_dict['logits'] = tf.squeeze( self._prediction_dict['logits'], axis=-1) self._prediction_dict['probs'] = tf.nn.sigmoid( self._prediction_dict['logits']) - return ['logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num'] + return [ + 'logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num', + 'user_interests', 'item_tower_emb' + ] elif self._loss_type == LossType.L2_LOSS: - return ['y', 'user_emb', 'item_emb', 'user_emb_num'] + return [ + 'y', 'user_emb', 'item_emb', 'user_emb_num', 'user_interests', + 'item_tower_emb' + ] else: raise ValueError('invalid loss type: %s' % str(self._loss_type))