Commit de96ccf

add support for rank distillation

yangxudong committed Sep 5, 2024
1 parent 262e2ca commit de96ccf
Showing 8 changed files with 214 additions and 40 deletions.
43 changes: 37 additions & 6 deletions easy_rec/python/builders/loss_builder.py
@@ -6,6 +6,7 @@

from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
from easy_rec.python.loss.jrc_loss import jrc_loss
from easy_rec.python.loss.listwise_loss import listwise_distill_loss
from easy_rec.python.loss.listwise_loss import listwise_rank_loss
from easy_rec.python.loss.pairwise_loss import pairwise_focal_loss
from easy_rec.python.loss.pairwise_loss import pairwise_hinge_loss
@@ -132,10 +133,11 @@ def build(loss_type,
name=loss_name)
elif loss_type == LossType.LISTWISE_RANK_LOSS:
session = kwargs.get('session_ids', None)
trans_fn, temp, label_is_logits = None, 1.0, False
trans_fn, temp, label_is_logits, scale = None, 1.0, False, False
if loss_param is not None:
temp = loss_param.temperature
label_is_logits = loss_param.label_is_logits
scale = loss_param.scale_logits
if loss_param.HasField('transform_fn'):
trans_fn = loss_param.transform_fn
return listwise_rank_loss(
@@ -145,6 +147,25 @@
temperature=temp,
label_is_logits=label_is_logits,
transform_fn=trans_fn,
scale_logits=scale,
weights=loss_weight)
elif loss_type == LossType.LISTWISE_DISTILL_LOSS:
session = kwargs.get('session_ids', None)
trans_fn, temp, label_clip_max_value, scale = None, 1.0, 512.0, False
if loss_param is not None:
temp = loss_param.temperature
label_clip_max_value = loss_param.label_clip_max_value
scale = loss_param.scale_logits
if loss_param.HasField('transform_fn'):
trans_fn = loss_param.transform_fn
return listwise_distill_loss(
label,
pred,
session,
temperature=temp,
label_clip_max_value=label_clip_max_value,
transform_fn=trans_fn,
scale_logits=scale,
weights=loss_weight)
elif loss_type == LossType.F1_REWEIGHTED_LOSS:
f1_beta_square = 1.0 if loss_param is None else loss_param.f1_beta_square
@@ -199,6 +220,16 @@ def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
loss_name = 'kd_loss_' + kd.pred_name.replace('/', '_')
loss_name += '_' + kd.soft_label_name.replace('/', '_')

loss_weight = kd.loss_weight
if kd.HasField('task_space_indicator_name') and kd.HasField(
'task_space_indicator_value'):
in_task_space = tf.to_float(
tf.equal(feature_dict[kd.task_space_indicator_name],
kd.task_space_indicator_value))
loss_weight = loss_weight * (
kd.in_task_space_weight * in_task_space + kd.out_task_space_weight *
(1 - in_task_space))
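# Hedged illustration (hypothetical values): with in_task_space_weight=1.0 and
# out_task_space_weight=0.1, samples whose indicator field equals the target
# value keep their full loss weight, while all other samples contribute at one
# tenth of it.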

label = label_dict[kd.soft_label_name]
pred = prediction_dict[kd.pred_name]
epsilon = tf.keras.backend.epsilon()
@@ -265,18 +296,18 @@ def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
preds = pred
losses = tf.keras.losses.KLD(labels, preds)
loss_dict[loss_name] = tf.reduce_mean(
losses, name=loss_name) * kd.loss_weight
losses, name=loss_name) * loss_weight
elif kd.loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
losses = tf.keras.backend.binary_crossentropy(
label, pred, from_logits=True)
loss_dict[loss_name] = tf.reduce_mean(
losses, name=loss_name) * kd.loss_weight
losses, name=loss_name) * loss_weight
elif kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
loss_dict[loss_name] = tf.losses.log_loss(
label, pred, weights=kd.loss_weight)
label, pred, weights=loss_weight)
elif kd.loss_type == LossType.L2_LOSS:
loss_dict[loss_name] = tf.losses.mean_squared_error(
labels=label, predictions=pred, weights=kd.loss_weight)
labels=label, predictions=pred, weights=loss_weight)
else:
loss_param = kd.WhichOneof('loss_param')
kwargs = {}
Expand All @@ -288,7 +319,7 @@ def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
kd.loss_type,
label,
pred,
loss_weight=kd.loss_weight,
loss_weight=loss_weight,
loss_param=loss_param,
**kwargs)
return loss_dict
12 changes: 10 additions & 2 deletions easy_rec/python/input/input.py
@@ -921,8 +921,16 @@ def _preprocess(self, field_dict):
], 'invalid label dtype: %s' % str(field_dict[input_name].dtype)
label_dict[input_name] = field_dict[input_name]

if self._data_config.HasField('sample_weight'):
if self._mode != tf.estimator.ModeKeys.PREDICT:
if self._mode != tf.estimator.ModeKeys.PREDICT:
for func_config in self._data_config.extra_label_func:
lbl_name = func_config.label_name
func_name = func_config.label_func
logging.info('generating new label `%s` by transform: %s' %
(lbl_name, func_name))
lbl_fn = load_by_path(func_name)
label_dict[lbl_name] = lbl_fn(label_dict)

if self._data_config.HasField('sample_weight'):
parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[
self._data_config.sample_weight]

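As an illustration of the new `extra_label_func` hook, here is a minimal sketch of a label-generation function that the data config could reference and `load_by_path` would import; the module path `my_pkg.label_funcs` and the label keys `click`/`buy` are hypothetical:

import tensorflow as tf

# hypothetical module my_pkg/label_funcs.py, referenced from the data config
# as label_func: 'my_pkg.label_funcs.gen_ctcvr_label'
def gen_ctcvr_label(label_dict):
  # derive a CTCVR-style label from existing labels; assumes the dataset
  # defines labels named `click` and `buy`
  return tf.to_float(label_dict['click']) * tf.to_float(label_dict['buy'])

The function receives the full `label_dict` and returns a tensor that is stored under the configured `label_name`.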
104 changes: 100 additions & 4 deletions easy_rec/python/loss/listwise_loss.py
@@ -4,8 +4,10 @@

import tensorflow as tf

from easy_rec.python.utils.load_class import load_by_path

def list_wise_loss(x, labels, logits, session_ids, label_is_logits):

def _list_wise_loss(x, labels, logits, session_ids, label_is_logits):
mask = tf.equal(x, session_ids)
logits = tf.boolean_mask(logits, mask)
labels = tf.boolean_mask(labels, mask)
@@ -14,12 +16,22 @@ def list_wise_loss(x, labels, logits, session_ids, label_is_logits):
return -tf.reduce_sum(y * y_hat)


def _list_prob_loss(x, labels, logits, session_ids):
mask = tf.equal(x, session_ids)
logits = tf.boolean_mask(logits, mask)
labels = tf.boolean_mask(labels, mask)
y = labels / tf.reduce_sum(labels)
y_hat = tf.nn.log_softmax(logits)
return -tf.reduce_sum(y * y_hat)


def listwise_rank_loss(labels,
logits,
session_ids,
transform_fn=None,
temperature=1.0,
label_is_logits=False,
scale_logits=False,
weights=1.0,
name='listwise_loss'):
r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
@@ -39,23 +51,107 @@ def listwise_rank_loss(labels,
temperature: (Optional) The temperature to use for scaling the logits.
label_is_logits: Whether `labels` is expected to be a logits tensor.
By default, we consider that `labels` encodes a probability distribution.
scale_logits: Whether to apply a learnable affine rescaling to the logits.
weights: sample weights
name: the name of the loss
"""
loss_name = name if name else 'listwise_rank_loss'
logging.info('[{}] temperature: {}'.format(loss_name, temperature))
logging.info('[{}] temperature: {}, scale logits: {}'.format(
loss_name, temperature, scale_logits))
labels = tf.to_float(labels)
if scale_logits:
with tf.variable_scope(loss_name):
w = tf.get_variable(
'scale_w',
dtype=tf.float32,
shape=(1,),
initializer=tf.ones_initializer())
b = tf.get_variable(
'scale_b',
dtype=tf.float32,
shape=(1,),
initializer=tf.zeros_initializer())
logits = logits * tf.abs(w) + b
if temperature != 1.0:
logits /= temperature
if label_is_logits:
labels /= temperature
if transform_fn is not None:
labels = transform_fn(labels)
trans_fn = load_by_path(transform_fn)
labels = trans_fn(labels)

sessions, _ = tf.unique(tf.squeeze(session_ids))
tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
losses = tf.map_fn(
lambda x: _list_wise_loss(x, labels, logits, session_ids, label_is_logits
),
sessions,
dtype=tf.float32)
if tf.is_numeric_tensor(weights):
logging.error('[%s] sample weights of tensor type are not supported' % loss_name)
return tf.reduce_mean(losses)
else:
return tf.reduce_mean(losses) * weights


def listwise_distill_loss(labels,
logits,
session_ids,
transform_fn=None,
temperature=1.0,
label_clip_max_value=512,
scale_logits=False,
weights=1.0,
name='listwise_distill_loss'):
r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
Definition:
$$
\mathcal{L}(\{y\}, \{s\}) =
\sum_i y_j \log( \frac{\exp(s_i)}{\sum_j exp(s_j)} )
$$
Args:
labels: A `Tensor` of the same shape as `logits`, representing the rank positions produced by a base (teacher) model.
logits: A `Tensor` with shape [batch_size].
session_ids: a `Tensor` with shape [batch_size]. Session id of each sample, used to group samples into lists (optimizing within sessions aligns with the GAUC metric), e.g. user_id.
transform_fn: (Optional) module path of a transformation function applied to the labels.
temperature: (Optional) The temperature to use for scaling the logits.
label_clip_max_value: the maximum value to clip the labels to.
scale_logits: Whether to apply a learnable affine rescaling to the logits.
weights: sample weights
name: the name of the loss
"""
loss_name = name if name else 'listwise_distill_loss'
logging.info('[{}] temperature: {}'.format(loss_name, temperature))
labels = tf.to_float(labels)  # expected to be rank positions from a teacher model
labels = tf.clip_by_value(labels, 1, label_clip_max_value)
if transform_fn is not None:
trans_fn = load_by_path(transform_fn)
labels = trans_fn(labels)
else:
labels = tf.log1p(label_clip_max_value) - tf.log(labels)
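# Hedged numeric check: with the default label_clip_max_value=512, a teacher
# rank of 1 (best) maps to log(513) - log(1) ~= 6.24 while rank 512 maps to
# log(513) - log(512) ~= 0.002, so better-ranked items receive larger soft
# labels before the per-session softmax normalization.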

if scale_logits:
with tf.variable_scope(loss_name):
w = tf.get_variable(
'scale_w',
dtype=tf.float32,
shape=(1,),
initializer=tf.ones_initializer())
b = tf.get_variable(
'scale_b',
dtype=tf.float32,
shape=(1,),
initializer=tf.zeros_initializer())
logits = logits * tf.abs(w) + b
if temperature != 1.0:
logits /= temperature

sessions, _ = tf.unique(tf.squeeze(session_ids))
tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
losses = tf.map_fn(
lambda x: list_wise_loss(x, labels, logits, session_ids, label_is_logits),
lambda x: _list_prob_loss(x, labels, logits, session_ids),
sessions,
dtype=tf.float32)
if tf.is_numeric_tensor(weights):
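As a usage sketch (not part of this commit), the new loss can be exercised with toy tensors under TF 1.x; the rank values, logits, and session ids below are made up:

import tensorflow as tf

from easy_rec.python.loss.listwise_loss import listwise_distill_loss

# teacher ranks (1 = best position) for two sessions of sizes 3 and 2
teacher_ranks = tf.constant([1., 2., 3., 1., 2.])
student_logits = tf.constant([2.0, 1.0, 0.5, 1.5, 0.2])
session_ids = tf.constant([1, 1, 1, 2, 2])

loss = listwise_distill_loss(
    teacher_ranks, student_logits, session_ids, temperature=2.0)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(loss))  # scalar: mean per-session cross entropy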
19 changes: 13 additions & 6 deletions easy_rec/python/model/rank_model.py
@@ -27,7 +27,10 @@ def __init__(self,
self._num_class = self._model_config.num_class
self._losses = self._model_config.losses
if self._labels is not None:
self._label_name = list(self._labels.keys())[0]
if model_config.HasField('label_name'):
self._label_name = model_config.label_name
else:
self._label_name = list(self._labels.keys())[0]
self._outputs = []

def build_predict_graph(self):
@@ -58,7 +61,8 @@ def _output_to_prediction_impl(self,
LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
LossType.PAIRWISE_LOGISTIC_LOSS, LossType.BINARY_CROSS_ENTROPY_LOSS
LossType.PAIRWISE_LOGISTIC_LOSS, LossType.BINARY_CROSS_ENTROPY_LOSS,
LossType.LISTWISE_DISTILL_LOSS
}
if loss_type in binary_loss_type:
assert num_class == 1, 'num_class must be 1 when loss type is %s' % loss_type.name
@@ -133,7 +137,8 @@ def build_rtp_output_dict(self):
LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
LossType.JRC_LOSS
LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
LossType.LISTWISE_RANK_LOSS
}
if loss_types & binary_loss_set:
if 'probs' in self._prediction_dict:
@@ -175,7 +180,8 @@ def _build_loss_impl(self,
LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
LossType.PAIRWISE_LOGISTIC_LOSS, LossType.JRC_LOSS
LossType.PAIRWISE_LOGISTIC_LOSS, LossType.JRC_LOSS,
LossType.LISTWISE_DISTILL_LOSS
}
if loss_type in {
LossType.CLASSIFICATION, LossType.BINARY_CROSS_ENTROPY_LOSS
@@ -280,7 +286,8 @@ def _build_metric_impl(self,
LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
LossType.JRC_LOSS
LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
LossType.LISTWISE_RANK_LOSS
}
metric_dict = {}
if metric.WhichOneof('metric') == 'auc':
@@ -421,7 +428,7 @@ def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
LossType.PAIRWISE_LOGISTIC_LOSS
LossType.PAIRWISE_LOGISTIC_LOSS, LossType.LISTWISE_DISTILL_LOSS
}
if loss_type in binary_loss_set:
return ['probs' + suffix, 'logits' + suffix]
8 changes: 8 additions & 0 deletions easy_rec/python/protos/dataset.proto
@@ -176,6 +176,14 @@ message DatasetConfig {
// whether labels have dimension > 1
repeated uint32 label_dim = 42;

message LabelFunction {
required string label_name = 1;
required string label_func = 2;
}

// extra transformation functions that generate new labels
repeated LabelFunction extra_label_func = 43;
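// example (hypothetical label name and function path):
//   extra_label_func {
//     label_name: "ctcvr"
//     label_func: "my_pkg.label_funcs.gen_ctcvr_label"
//   }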

// whether to shuffle data
optional bool shuffle = 5 [default = true];

13 changes: 13 additions & 0 deletions easy_rec/python/protos/easy_rec_model.proto
@@ -53,6 +53,15 @@ message KD {
optional float loss_weight = 4 [default=1.0];
// only for loss_type == CROSS_ENTROPY_LOSS or BINARY_CROSS_ENTROPY_LOSS or KL_DIVERGENCE_LOSS
optional float temperature = 5 [default=1.0];
// name of the field that indicates the sample space for this task
optional string task_space_indicator_name = 6;
// field value that marks a sample as belonging to the task space
optional string task_space_indicator_value = 7;
// the loss weight for samples in the task space
optional float in_task_space_weight = 8 [default = 1.0];
// the loss weight for samples outside the task space
optional float out_task_space_weight = 9 [default = 1.0];
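// example (hypothetical field names and values): distill only clicked
// samples at full weight and down-weight the rest:
//   task_space_indicator_name: "is_click"
//   task_space_indicator_value: "1"
//   in_task_space_weight: 1.0
//   out_task_space_weight: 0.1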

oneof loss_param {
F1ReweighedLoss f1_reweighted_loss = 101;
SoftmaxCrossEntropyWithNegativeMining softmax_loss = 102;
@@ -65,6 +74,7 @@
JRCLoss jrc_loss = 109;
PairwiseHingeLoss pairwise_hinge_loss = 110;
ListwiseRankLoss listwise_rank_loss = 111;
ListwiseDistillLoss listwise_distill_loss = 112;
}
}

@@ -135,4 +145,7 @@ message EasyRecModel {
required LossWeightStrategy loss_weight_strategy = 16 [default = Fixed];

optional BackboneTower backbone = 17;

// label name for rank_model to select one label among multiple labels
optional string label_name = 18;
}
11 changes: 11 additions & 0 deletions easy_rec/python/protos/loss.proto
@@ -22,6 +22,7 @@ enum LossType {
BINARY_CROSS_ENTROPY_LOSS = 15;
KL_DIVERGENCE_LOSS = 16;
LISTWISE_RANK_LOSS = 18;
LISTWISE_DISTILL_LOSS = 19;
}

message Loss {
@@ -41,6 +42,7 @@ message Loss {
JRCLoss jrc_loss = 109;
PairwiseHingeLoss pairwise_hinge_loss = 110;
ListwiseRankLoss listwise_rank_loss = 111;
ListwiseDistillLoss listwise_distill_loss = 112;
}
};

@@ -120,4 +122,13 @@ message ListwiseRankLoss {
optional string session_name = 2;
optional string transform_fn = 3;
optional bool label_is_logits = 4 [default = false];
optional bool scale_logits = 5 [default = false];
}

message ListwiseDistillLoss {
required float temperature = 1 [default = 1.0];
optional string session_name = 2;
optional string transform_fn = 3;
optional float label_clip_max_value = 4 [default = 512.0];
optional bool scale_logits = 5 [default = false];
}
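// example (hypothetical values), as it could appear inside a Loss or KD config:
//   listwise_distill_loss {
//     temperature: 2.0
//     label_clip_max_value: 100
//     scale_logits: true
//   }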