Skip to content

Commit

Permalink
add user embedding and item embedding outputs for match model
Browse files Browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji committed Feb 2, 2024
1 parent a08bf22 commit 606aa9b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
12 changes: 9 additions & 3 deletions easy_rec/python/model/dssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,21 @@ def build_predict_graph(self):

def get_outputs(self):
if self._loss_type == LossType.CLASSIFICATION:
return ['logits', 'probs', 'user_emb', 'item_emb']
return [
'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
'item_tower_emb'
]
elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
self._prediction_dict['logits'] = tf.squeeze(
self._prediction_dict['logits'], axis=-1)
self._prediction_dict['probs'] = tf.nn.sigmoid(
self._prediction_dict['logits'])
return ['logits', 'probs', 'user_emb', 'item_emb']
return [
'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
'item_tower_emb'
]
elif self._loss_type == LossType.L2_LOSS:
return ['y', 'user_emb', 'item_emb']
return ['y', 'user_emb', 'item_emb', 'user_tower_emb', 'item_tower_emb']
else:
raise ValueError('invalid loss type: %s' % str(self._loss_type))

Expand Down
15 changes: 12 additions & 3 deletions easy_rec/python/model/mind.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,14 +423,23 @@ def build_metric_graph(self, eval_config):

def get_outputs(self):
if self._loss_type == LossType.CLASSIFICATION:
return ['logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num']
return [
'logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num',
'user_interests', 'item_tower_emb'
]
elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
self._prediction_dict['logits'] = tf.squeeze(
self._prediction_dict['logits'], axis=-1)
self._prediction_dict['probs'] = tf.nn.sigmoid(
self._prediction_dict['logits'])
return ['logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num']
return [
'logits', 'probs', 'user_emb', 'item_emb', 'user_emb_num',
'user_interests', 'item_tower_emb'
]
elif self._loss_type == LossType.L2_LOSS:
return ['y', 'user_emb', 'item_emb', 'user_emb_num']
return [
'y', 'user_emb', 'item_emb', 'user_emb_num', 'user_interests',
'item_tower_emb'
]
else:
raise ValueError('invalid loss type: %s' % str(self._loss_type))

0 comments on commit 606aa9b

Please sign in to comment.