From de96ccf1497bce15bd78e6836b2acf50bb69e7d5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=8D=AB=E8=8B=8F?=
Date: Thu, 5 Sep 2024 15:46:57 +0800
Subject: [PATCH] add support for rank distillation

---
 easy_rec/python/builders/loss_builder.py    |  43 ++++++--
 easy_rec/python/input/input.py              |  12 ++-
 easy_rec/python/loss/listwise_loss.py       | 104 +++++++++++++++++++-
 easy_rec/python/model/rank_model.py         |  19 ++--
 easy_rec/python/protos/dataset.proto        |   8 ++
 easy_rec/python/protos/easy_rec_model.proto |  13 +++
 easy_rec/python/protos/loss.proto           |  11 +++
 pai_jobs/deploy_ext.sh                      |  44 ++++-----
 8 files changed, 214 insertions(+), 40 deletions(-)

diff --git a/easy_rec/python/builders/loss_builder.py b/easy_rec/python/builders/loss_builder.py
index a80288cf4..36cdd95b4 100644
--- a/easy_rec/python/builders/loss_builder.py
+++ b/easy_rec/python/builders/loss_builder.py
@@ -6,6 +6,7 @@
 
 from easy_rec.python.loss.focal_loss import sigmoid_focal_loss_with_logits
 from easy_rec.python.loss.jrc_loss import jrc_loss
+from easy_rec.python.loss.listwise_loss import listwise_distill_loss
 from easy_rec.python.loss.listwise_loss import listwise_rank_loss
 from easy_rec.python.loss.pairwise_loss import pairwise_focal_loss
 from easy_rec.python.loss.pairwise_loss import pairwise_hinge_loss
@@ -132,10 +133,11 @@ def build(loss_type,
         name=loss_name)
   elif loss_type == LossType.LISTWISE_RANK_LOSS:
     session = kwargs.get('session_ids', None)
-    trans_fn, temp, label_is_logits = None, 1.0, False
+    trans_fn, temp, label_is_logits, scale = None, 1.0, False, False
     if loss_param is not None:
       temp = loss_param.temperature
       label_is_logits = loss_param.label_is_logits
+      scale = loss_param.scale_logits
       if loss_param.HasField('transform_fn'):
         trans_fn = loss_param.transform_fn
     return listwise_rank_loss(
@@ -145,6 +147,25 @@ def build(loss_type,
         temperature=temp,
         label_is_logits=label_is_logits,
         transform_fn=trans_fn,
+        scale_logits=scale,
         weights=loss_weight)
+  elif loss_type == LossType.LISTWISE_DISTILL_LOSS:
+    session = kwargs.get('session_ids', None)
+    trans_fn, temp, label_clip_max_value, scale = None, 1.0, 512.0, False
+    if loss_param is not None:
+      temp = loss_param.temperature
+      label_clip_max_value = loss_param.label_clip_max_value
+      scale = loss_param.scale_logits
+      if loss_param.HasField('transform_fn'):
+        trans_fn = loss_param.transform_fn
+    return listwise_distill_loss(
+        label,
+        pred,
+        session,
+        temperature=temp,
+        label_clip_max_value=label_clip_max_value,
+        transform_fn=trans_fn,
+        scale_logits=scale,
+        weights=loss_weight)
   elif loss_type == LossType.F1_REWEIGHTED_LOSS:
     f1_beta_square = 1.0 if loss_param is None else loss_param.f1_beta_square
@@ -199,6 +220,16 @@ def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
     loss_name = 'kd_loss_' + kd.pred_name.replace('/', '_')
     loss_name += '_' + kd.soft_label_name.replace('/', '_')
 
+    loss_weight = kd.loss_weight
+    if kd.HasField('task_space_indicator_name') and kd.HasField(
+        'task_space_indicator_value'):
+      in_task_space = tf.to_float(
+          tf.equal(feature_dict[kd.task_space_indicator_name],
+                   kd.task_space_indicator_value))
+      loss_weight = loss_weight * (
+          kd.in_task_space_weight * in_task_space + kd.out_task_space_weight *
+          (1 - in_task_space))
+
     label = label_dict[kd.soft_label_name]
     pred = prediction_dict[kd.pred_name]
     epsilon = tf.keras.backend.epsilon()
@@ -265,18 +296,18 @@ def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
       preds = pred
       losses = tf.keras.losses.KLD(labels, preds)
       loss_dict[loss_name] = tf.reduce_mean(
-          losses, name=loss_name) * kd.loss_weight
+          losses, name=loss_name) * loss_weight
     elif kd.loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
       losses = tf.keras.backend.binary_crossentropy(
           label, pred, from_logits=True)
       loss_dict[loss_name] = tf.reduce_mean(
-          losses, name=loss_name) * kd.loss_weight
+          losses, name=loss_name) * loss_weight
     elif kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
       loss_dict[loss_name] = tf.losses.log_loss(
-          label, pred, weights=kd.loss_weight)
+          label, pred, weights=loss_weight)
     elif kd.loss_type == LossType.L2_LOSS:
       loss_dict[loss_name] = tf.losses.mean_squared_error(
-          labels=label, predictions=pred, weights=kd.loss_weight)
+          labels=label, predictions=pred, weights=loss_weight)
     else:
       loss_param = kd.WhichOneof('loss_param')
       kwargs = {}
@@ -288,7 +319,7 @@ def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
           kd.loss_type,
           label,
           pred,
-          loss_weight=kd.loss_weight,
+          loss_weight=loss_weight,
           loss_param=loss_param,
           **kwargs)
   return loss_dict
diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py
index d94b1de13..f53c4ee45 100644
--- a/easy_rec/python/input/input.py
+++ b/easy_rec/python/input/input.py
@@ -921,8 +921,16 @@ def _preprocess(self, field_dict):
         ], 'invalid label dtype: %s' % str(field_dict[input_name].dtype)
         label_dict[input_name] = field_dict[input_name]
 
-    if self._data_config.HasField('sample_weight'):
-      if self._mode != tf.estimator.ModeKeys.PREDICT:
+    if self._mode != tf.estimator.ModeKeys.PREDICT:
+      for func_config in self._data_config.extra_label_func:
+        lbl_name = func_config.label_name
+        func_name = func_config.label_func
+        logging.info('generating new label `%s` by transform: %s' %
+                     (lbl_name, func_name))
+        lbl_fn = load_by_path(func_name)
+        label_dict[lbl_name] = lbl_fn(label_dict)
+
+      if self._data_config.HasField('sample_weight'):
         parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[
             self._data_config.sample_weight]
 
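Usage sketch (illustrative, not part of the patch): a function referenced by
`extra_label_func` is resolved via `load_by_path` and called with the parsed
`label_dict`; it returns the new label tensor. The module path and label
names below are hypothetical placeholders:

    import tensorflow as tf

    def click_or_buy(label_dict):
      # Combine two parsed labels into one binary target:
      # positive if the sample was clicked or purchased.
      click = tf.to_float(label_dict['click'])
      buy = tf.to_float(label_dict['buy'])
      return tf.minimum(click + buy, 1.0)

    # referenced in data_config as:
    #   extra_label_func {
    #     label_name: 'click_or_buy'
    #     label_func: 'my_pkg.label_util.click_or_buy'
    #   }
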
diff --git a/easy_rec/python/loss/listwise_loss.py b/easy_rec/python/loss/listwise_loss.py
index f778f38f8..24bd5864f 100644
--- a/easy_rec/python/loss/listwise_loss.py
+++ b/easy_rec/python/loss/listwise_loss.py
@@ -4,8 +4,10 @@
 
 import tensorflow as tf
 
+from easy_rec.python.utils.load_class import load_by_path
 
-def list_wise_loss(x, labels, logits, session_ids, label_is_logits):
+
+def _list_wise_loss(x, labels, logits, session_ids, label_is_logits):
   mask = tf.equal(x, session_ids)
   logits = tf.boolean_mask(logits, mask)
   labels = tf.boolean_mask(labels, mask)
@@ -14,12 +16,22 @@ def list_wise_loss(x, labels, logits, session_ids, label_is_logits):
   return -tf.reduce_sum(y * y_hat)
 
 
+def _list_prob_loss(x, labels, logits, session_ids):
+  mask = tf.equal(x, session_ids)
+  logits = tf.boolean_mask(logits, mask)
+  labels = tf.boolean_mask(labels, mask)
+  y = labels / tf.reduce_sum(labels)
+  y_hat = tf.nn.log_softmax(logits)
+  return -tf.reduce_sum(y * y_hat)
+
+
 def listwise_rank_loss(labels,
                        logits,
                        session_ids,
                        transform_fn=None,
                        temperature=1.0,
                        label_is_logits=False,
+                       scale_logits=False,
                        weights=1.0,
                        name='listwise_loss'):
   r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
@@ -39,23 +51,107 @@ def listwise_rank_loss(labels,
     temperature: (Optional) The temperature to use for scaling the logits.
     label_is_logits: Whether `labels` is expected to be a logits tensor. By
       default, we consider that `labels` encodes a probability distribution.
+    scale_logits: Whether to scale the logits.
     weights: sample weights
     name: the name of loss
   """
   loss_name = name if name else 'listwise_rank_loss'
-  logging.info('[{}] temperature: {}'.format(loss_name, temperature))
+  logging.info('[{}] temperature: {}, scale logits: {}'.format(
+      loss_name, temperature, scale_logits))
   labels = tf.to_float(labels)
+  if scale_logits:
+    with tf.variable_scope(loss_name):
+      w = tf.get_variable(
+          'scale_w',
+          dtype=tf.float32,
+          shape=(1,),
+          initializer=tf.ones_initializer())
+      b = tf.get_variable(
+          'scale_b',
+          dtype=tf.float32,
+          shape=(1,),
+          initializer=tf.zeros_initializer())
+    logits = logits * tf.abs(w) + b
   if temperature != 1.0:
     logits /= temperature
     if label_is_logits:
       labels /= temperature
   if transform_fn is not None:
-    labels = transform_fn(labels)
+    trans_fn = load_by_path(transform_fn)
+    labels = trans_fn(labels)
+
+  sessions, _ = tf.unique(tf.squeeze(session_ids))
+  tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
+  losses = tf.map_fn(
+      lambda x: _list_wise_loss(x, labels, logits, session_ids,
+                                label_is_logits),
+      sessions,
+      dtype=tf.float32)
+  if tf.is_numeric_tensor(weights):
+    logging.error('[%s] unsupported sample weights are ignored' % loss_name)
+    return tf.reduce_mean(losses)
+  else:
+    return tf.reduce_mean(losses) * weights
+
+
+def listwise_distill_loss(labels,
+                          logits,
+                          session_ids,
+                          transform_fn=None,
+                          temperature=1.0,
+                          label_clip_max_value=512,
+                          scale_logits=False,
+                          weights=1.0,
+                          name='listwise_distill_loss'):
+  r"""Computes listwise softmax cross entropy loss between `labels` and `logits`.
+
+  Definition:
+  $$
+  \mathcal{L}(\{y\}, \{s\}) =
+  -\sum_i y_i \log\left( \frac{\exp(s_i)}{\sum_j \exp(s_j)} \right)
+  $$
+
+  Args:
+    labels: A `Tensor` of the same shape as `logits`, the rank positions given by a base (teacher) model.
+    logits: A `Tensor` with shape [batch_size].
+    session_ids: a `Tensor` with shape [batch_size]. Session ids of each sample, used to group samples into lists, e.g. user_id
+    transform_fn: a transformation function applied to the labels.
+    temperature: (Optional) The temperature to use for scaling the logits.
+    label_clip_max_value: the maximum value the labels are clipped to.
+    scale_logits: Whether to scale the logits.
+    weights: sample weights
+    name: the name of loss
+  """
+  loss_name = name if name else 'listwise_distill_loss'
+  logging.info('[{}] temperature: {}'.format(loss_name, temperature))
+  labels = tf.to_float(labels)  # expected to be rank positions of a teacher model
+  labels = tf.clip_by_value(labels, 1, label_clip_max_value)
+  if transform_fn is not None:
+    trans_fn = load_by_path(transform_fn)
+    labels = trans_fn(labels)
+  else:
+    labels = tf.log1p(label_clip_max_value) - tf.log(labels)
+
+  if scale_logits:
+    with tf.variable_scope(loss_name):
+      w = tf.get_variable(
+          'scale_w',
+          dtype=tf.float32,
+          shape=(1,),
+          initializer=tf.ones_initializer())
+      b = tf.get_variable(
+          'scale_b',
+          dtype=tf.float32,
+          shape=(1,),
+          initializer=tf.zeros_initializer())
+    logits = logits * tf.abs(w) + b
+  if temperature != 1.0:
+    logits /= temperature
 
   sessions, _ = tf.unique(tf.squeeze(session_ids))
   tf.summary.scalar('loss/%s_num_of_group' % loss_name, tf.size(sessions))
   losses = tf.map_fn(
-      lambda x: list_wise_loss(x, labels, logits, session_ids, label_is_logits),
+      lambda x: _list_prob_loss(x, labels, logits, session_ids),
       sessions,
       dtype=tf.float32)
   if tf.is_numeric_tensor(weights):
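Illustration (not part of the patch): a plain-NumPy sketch of what
`listwise_distill_loss` computes for one session when `transform_fn` is
unset; the default transform turns clipped teacher positions into soft
relevance scores, which are normalized and matched against the student's
in-session log-softmax. Values are made up:

    import numpy as np

    # Teacher rank positions for one session (1 = best).
    positions = np.array([1.0, 2.0, 5.0, 50.0])
    clip_max = 512.0

    # Default label transform from the patch:
    # log1p(clip_max) - log(clip(positions, 1, clip_max)).
    # Top positions get large soft relevance, deep positions small ones.
    labels = np.log1p(clip_max) - np.log(np.clip(positions, 1.0, clip_max))

    # _list_prob_loss: normalize the labels into a distribution and take
    # cross entropy against the student's log-softmax over the session.
    student_logits = np.array([2.0, 1.5, 0.3, -1.0])
    y = labels / labels.sum()
    log_softmax = student_logits - np.log(np.sum(np.exp(student_logits)))
    loss = -np.sum(y * log_softmax)
    print('per-session distill loss: %.4f' % loss)

The TF implementation averages this per-session quantity over all sessions
found in `session_ids` (tf.map_fn over tf.unique).
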
diff --git a/easy_rec/python/model/rank_model.py b/easy_rec/python/model/rank_model.py
index 8c5b21afd..fc8e5214c 100644
--- a/easy_rec/python/model/rank_model.py
+++ b/easy_rec/python/model/rank_model.py
@@ -27,7 +27,10 @@ def __init__(self,
     self._num_class = self._model_config.num_class
     self._losses = self._model_config.losses
     if self._labels is not None:
-      self._label_name = list(self._labels.keys())[0]
+      if model_config.HasField('label_name'):
+        self._label_name = model_config.label_name
+      else:
+        self._label_name = list(self._labels.keys())[0]
     self._outputs = []
 
   def build_predict_graph(self):
@@ -58,7 +61,8 @@ def _output_to_prediction_impl(self,
         LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
         LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
         LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
-        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.BINARY_CROSS_ENTROPY_LOSS
+        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.BINARY_CROSS_ENTROPY_LOSS,
+        LossType.LISTWISE_DISTILL_LOSS
     }
     if loss_type in binary_loss_type:
       assert num_class == 1, 'num_class must be 1 when loss type is %s' % loss_type.name
@@ -133,7 +137,8 @@ def build_rtp_output_dict(self):
         LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
         LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
         LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
-        LossType.JRC_LOSS
+        LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
+        LossType.LISTWISE_RANK_LOSS
     }
     if loss_types & binary_loss_set:
       if 'probs' in self._prediction_dict:
@@ -175,7 +180,8 @@ def _build_loss_impl(self,
         LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
         LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
         LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
-        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.JRC_LOSS
+        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.JRC_LOSS,
+        LossType.LISTWISE_DISTILL_LOSS
     }
     if loss_type in {
         LossType.CLASSIFICATION, LossType.BINARY_CROSS_ENTROPY_LOSS
@@ -280,7 +286,8 @@ def _build_metric_impl(self,
         LossType.CLASSIFICATION, LossType.F1_REWEIGHTED_LOSS,
         LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
         LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
-        LossType.JRC_LOSS
+        LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
+        LossType.LISTWISE_RANK_LOSS
     }
     metric_dict = {}
     if metric.WhichOneof('metric') == 'auc':
@@ -421,7 +428,7 @@ def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
         LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
         LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
         LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
-        LossType.PAIRWISE_LOGISTIC_LOSS
+        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.LISTWISE_DISTILL_LOSS
     }
     if loss_type in binary_loss_set:
       return ['probs' + suffix, 'logits' + suffix]
diff --git a/easy_rec/python/protos/dataset.proto b/easy_rec/python/protos/dataset.proto
index 5ffefd064..ca19dcd04 100644
--- a/easy_rec/python/protos/dataset.proto
+++ b/easy_rec/python/protos/dataset.proto
@@ -176,6 +176,14 @@ message DatasetConfig {
   // are labels have dimension > 1
   repeated uint32 label_dim = 42;
 
+  message LabelFunction {
+    required string label_name = 1;
+    required string label_func = 2;
+  }
+
+  // extra transformation functions that generate new labels
+  repeated LabelFunction extra_label_func = 43;
+
   // whether to shuffle data
   optional bool shuffle = 5 [default = true];
 
diff --git a/easy_rec/python/protos/easy_rec_model.proto b/easy_rec/python/protos/easy_rec_model.proto
index 1fc40b85c..56f5b713e 100644
--- a/easy_rec/python/protos/easy_rec_model.proto
+++ b/easy_rec/python/protos/easy_rec_model.proto
@@ -53,6 +53,15 @@ message KD {
   optional float loss_weight = 4 [default=1.0];
   // only for loss_type == CROSS_ENTROPY_LOSS or BINARY_CROSS_ENTROPY_LOSS or KL_DIVERGENCE_LOSS
   optional float temperature = 5 [default=1.0];
+  // field name for indicating the sample space for this task
+  optional string task_space_indicator_name = 6;
+  // field value for indicating the sample space for this task
+  optional string task_space_indicator_value = 7;
+  // the loss weight for samples in the task space
+  optional float in_task_space_weight = 8 [default = 1.0];
+  // the loss weight for samples outside the task space
+  optional float out_task_space_weight = 9 [default = 1.0];
+
   oneof loss_param {
     F1ReweighedLoss f1_reweighted_loss = 101;
     SoftmaxCrossEntropyWithNegativeMining softmax_loss = 102;
@@ -65,6 +74,7 @@ message KD {
     JRCLoss jrc_loss = 109;
     PairwiseHingeLoss pairwise_hinge_loss = 110;
     ListwiseRankLoss listwise_rank_loss = 111;
+    ListwiseDistillLoss listwise_distill_loss = 112;
   }
 }
 
@@ -135,4 +145,7 @@ message EasyRecModel {
   required LossWeightStrategy loss_weight_strategy = 16 [default = Fixed];
 
   optional BackboneTower backbone = 17;
+
+  // label name for rank_model to select one label from multiple labels
+  optional string label_name = 18;
 }
diff --git a/easy_rec/python/protos/loss.proto b/easy_rec/python/protos/loss.proto
index 4b49126c2..b377cd75c 100644
--- a/easy_rec/python/protos/loss.proto
+++ b/easy_rec/python/protos/loss.proto
@@ -22,6 +22,7 @@ enum LossType {
   BINARY_CROSS_ENTROPY_LOSS = 15;
   KL_DIVERGENCE_LOSS = 16;
   LISTWISE_RANK_LOSS = 18;
+  LISTWISE_DISTILL_LOSS = 19;
 }
 
 message Loss {
@@ -41,6 +42,7 @@ message Loss {
     JRCLoss jrc_loss = 109;
     PairwiseHingeLoss pairwise_hinge_loss = 110;
     ListwiseRankLoss listwise_rank_loss = 111;
+    ListwiseDistillLoss listwise_distill_loss = 112;
   }
 };
 
@@ -120,4 +122,13 @@ message ListwiseRankLoss {
   optional string session_name = 2;
   optional string transform_fn = 3;
   optional bool label_is_logits = 4 [default = false];
+  optional bool scale_logits = 5 [default = false];
+}
+
+message ListwiseDistillLoss {
+  required float temperature = 1 [default = 1.0];
+  optional string session_name = 2;
+  optional string transform_fn = 3;
+  optional float label_clip_max_value = 4 [default = 512.0];
+  optional bool scale_logits = 5 [default = false];
 }
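Usage sketch (illustrative, not part of the patch): the new proto fields
compose in a pipeline config roughly as below. The model class, label and
field names (`click`, `teacher_pos`, `user_id`) and the transform path are
hypothetical, and the exact nesting of `losses` depends on the concrete
model proto:

    data_config {
      label_fields: 'click'
      label_fields: 'teacher_pos'   # rank position exported by the teacher
      # optional: derive an extra label from the parsed ones (see sketch above)
      extra_label_func {
        label_name: 'click_or_buy'
        label_func: 'my_pkg.label_util.click_or_buy'
      }
    }
    model_config {
      model_class: 'MultiTower'
      label_name: 'teacher_pos'   # new: select one label among label_fields
      multi_tower {
        losses {
          loss_type: LISTWISE_DISTILL_LOSS
          listwise_distill_loss {
            temperature: 1.0
            session_name: 'user_id'
            label_clip_max_value: 100.0
            scale_logits: true
          }
        }
      }
    }
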
diff --git a/pai_jobs/deploy_ext.sh b/pai_jobs/deploy_ext.sh
index 26a1dd091..796cde6bf 100755
--- a/pai_jobs/deploy_ext.sh
+++ b/pai_jobs/deploy_ext.sh
@@ -59,27 +59,27 @@ then
   exit 1
 fi
 
-ODPSCMD=`which $ODPSCMD`
-if [ $? -ne 0 ] && [ $mode -ne 2 ]
-then
-  echo "$ODPSCMD is not in PATH"
-  exit 1
-fi
-
-if [ ! -e "$odps_config" ] && [ $mode -ne 2 ]
-then
-  if [ -z "$odps_config" ]
-  then
-    echo "odps_config is not set"
-  else
-    echo "odps_config[$odps_config] does not exist"
-  fi
-  exit 1
-fi
-if [ -e "$odps_config" ]
-then
-  odps_config=`readlink -f $odps_config`
-fi
+#ODPSCMD=`which $ODPSCMD`
+#if [ $? -ne 0 ] && [ $mode -ne 2 ]
+#then
+#  echo "$ODPSCMD is not in PATH"
+#  exit 1
+#fi
+#
+#if [ ! -e "$odps_config" ] && [ $mode -ne 2 ]
+#then
+#  if [ -z "$odps_config" ]
+#  then
+#    echo "odps_config is not set"
+#  else
+#    echo "odps_config[$odps_config] does not exist"
+#  fi
+#  exit 1
+#fi
+#if [ -e "$odps_config" ]
+#then
+#  odps_config=`readlink -f $odps_config`
+#fi
 
 cd $root_dir
 sh scripts/gen_proto.sh
@@ -144,7 +144,7 @@ then
 fi
 
 tar -cvzhf $RES_PATH easy_rec datahub lz4 cprotobuf kafka faiss run.py
-
+exit
 # 2 means generate only
 if [ $mode -ne 2 ]
 then
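End note (illustrative, not part of the patch): the new KD fields can gate
distillation by sample space; this assumes the repeated `kd` field on the
model config, and the indicator field/value names are hypothetical:

    model_config {
      kd {
        pred_name: 'logits'
        soft_label_name: 'teacher_pos'
        loss_type: LISTWISE_DISTILL_LOSS
        loss_weight: 1.0
        # down-weight samples that fall outside this task's sample space
        task_space_indicator_name: 'scene_id'
        task_space_indicator_value: 'main_feed'
        in_task_space_weight: 1.0
        out_task_space_weight: 0.1
        listwise_distill_loss {
          temperature: 1.0
          session_name: 'user_id'
          label_clip_max_value: 100.0
          scale_logits: true
        }
      }
    }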