
# Custom Model

## Write the Model Proto File

TorchEasyRec uses Protocol Buffers to define its configuration file format.

Add a `CustomRankModel` message in `tzrec/protos/models/rank_model.proto` to define the model configuration:

```protobuf
message CustomRankModel {
  required MLP mlp = 1;
  ...
}
```

Add `CustomRankModel` to the `oneof model` in `tzrec/protos/model.proto`:

```protobuf
message ModelConfig {
   ...

   oneof model {
      ...
      CustomRankModel custom_rank_model = 1001;
      ...
   }
   ...
}
```

## Generate the Proto Python `*_pb2.py` Files

```bash
bash scripts/gen_proto.sh
```

## Write the Model File

### Inheritance

Inherit from `tzrec.models.model.BaseModel` to implement a custom model. The following functions must be overridden.

Initialization: `__init__`

- Build the submodules from the model configuration `model_config` and the feature configurations `features`.

Forward: `predict`

- Run forward inference on the input `batch` data to produce `predictions`.
  - `batch` is a `tzrec.datasets.utils.Batch` data structure containing `dense_features` (dense features), `sparse_features` (sparse features), and `sequence_dense_features` (sequence dense features).
  - Typically, `dense_features`, `sparse_features`, and `sequence_dense_features` are passed to an EmbeddingGroup module (`tzrec.modules.embedding.EmbeddingGroup`) to obtain grouped embedding results, which are then fed into the rest of the forward pass.
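The group-then-concatenate flow can be sketched in plain PyTorch. The EmbeddingGroup itself is omitted here; the grouped tensors, group dimensions, and tower sizes below are made-up stand-ins for illustration only:

```python
import torch
from torch import nn

# Hypothetical grouped embedding outputs for a batch of 4 samples:
# one tensor per feature group, as an EmbeddingGroup would produce.
grouped = [torch.randn(4, 16), torch.randn(4, 8)]  # e.g. 'group1', 'group2'

# Concatenate groups along the feature dimension, then feed an MLP tower.
features = torch.cat(grouped, dim=-1)              # shape: (4, 24)
tower = nn.Sequential(nn.Linear(24, 64), nn.ReLU())
output_layer = nn.Linear(64, 1)                    # num_class = 1 logits

logits = output_layer(tower(features))
print(logits.shape)  # torch.Size([4, 1])
```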

Loss: `init_loss` & `loss`

- `init_loss` initializes the loss modules according to the model's loss configuration and writes them into `self._loss_modules`.
- `loss` computes the actual loss for each step from the input `predictions` and the labels in `batch`, and returns a `loss_dict`.
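As a rough sketch of the `init_loss` / `loss` contract (the class, module key, and label key below are illustrative stand-ins, not the actual RankModel implementation):

```python
import torch
from torch import nn

class LossSketch:
    def __init__(self):
        self._loss_modules = nn.ModuleDict()

    def init_loss(self):
        # One module per configured loss, keyed so loss() can look it up.
        self._loss_modules["binary_cross_entropy"] = nn.BCEWithLogitsLoss()

    def loss(self, predictions, batch_labels):
        # Return a dict mapping loss name -> scalar tensor for this step.
        logits = predictions["logits"].squeeze(-1)
        label = batch_labels["label"].float()
        loss_fn = self._loss_modules["binary_cross_entropy"]
        return {"binary_cross_entropy": loss_fn(logits, label)}

m = LossSketch()
m.init_loss()
loss_dict = m.loss({"logits": torch.zeros(4, 1)}, {"label": torch.ones(4)})
print(loss_dict)  # BCE of logit 0 vs label 1 is ln(2) ≈ 0.6931
```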

Evaluation: `init_metric` & `update_metric`

- `init_metric` initializes the metric modules according to the model configuration and writes them into `self._metric_modules`.
- `update_metric` updates the state of the metric modules from the input `predictions` and the labels in `batch`.
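The metric protocol can be sketched the same way. The toy streaming-accuracy module below is a hand-rolled stand-in that only mimics the update/compute style of metric modules; it is not the actual tzrec metric implementation:

```python
import torch
from torch import nn

class Accuracy(nn.Module):
    """Toy streaming accuracy with update()/compute(), for illustration."""
    def __init__(self):
        super().__init__()
        self.register_buffer("correct", torch.tensor(0.0))
        self.register_buffer("total", torch.tensor(0.0))

    def update(self, preds, target):
        self.correct += ((preds > 0.5).float() == target).sum()
        self.total += target.numel()

    def compute(self):
        return self.correct / self.total

class MetricSketch:
    def __init__(self):
        self._metric_modules = nn.ModuleDict()

    def init_metric(self):
        self._metric_modules["accuracy"] = Accuracy()

    def update_metric(self, predictions, batch_labels):
        # Called every eval step; compute() is read at the end.
        self._metric_modules["accuracy"].update(
            predictions["probs"], batch_labels["label"]
        )

m = MetricSketch()
m.init_metric()
m.update_metric({"probs": torch.tensor([0.9, 0.2])}, {"label": torch.tensor([1.0, 0.0])})
print(m._metric_modules["accuracy"].compute())  # tensor(1.)
```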

### Common Base Classes

For ranking, multi-task ranking, and retrieval scenarios, you can inherit directly from the following submodels and only need to override the forward function:

- Ranking models can inherit directly from `tzrec.models.rank_model.RankModel`
- Multi-task models can inherit directly from `tzrec.models.multi_task_rank.MultiTaskRank`
- Retrieval models can inherit directly from `tzrec.models.match_model.MatchModel`

Taking a ranking model as an example:

```python
# tzrec/model/custom_rank_model.py
from typing import Dict, List

import torch
from torch import nn

from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.rank_model import RankModel
from tzrec.modules.embedding import EmbeddingGroup
from tzrec.modules.mlp import MLP
from tzrec.protos.model_pb2 import ModelConfig
from tzrec.utils.config_util import config_to_kwargs


class CustomRankModel(RankModel):
    """CustomRankModel.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
    """

    def __init__(
        self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str]
    ) -> None:
        super().__init__(model_config, features, labels)
        # Build the EmbeddingGroup
        self.embedding_group = EmbeddingGroup(
            features, list(model_config.feature_groups)
        )
        # Build the MLP
        total_in_dim = sum(
            self.embedding_group.group_total_dim(n)
            for n in self.embedding_group.group_names()
        )
        self.mlp = MLP(
            in_features=total_in_dim,
            **config_to_kwargs(self._model_config.mlp),
        )
        final_dim = self.mlp.output_dim()
        self.output_mlp = nn.Linear(final_dim, self._num_class)
        # Initialize other modules
        ...

    def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Forward the model.

        Args:
            batch (Batch): input batch data.

        Return:
            predictions (dict): a dict of predicted result.
        """
        grouped_features = self.embedding_group(
            batch.sparse_features, batch.dense_features
        )
        features = torch.cat(grouped_features, dim=-1)
        tower_output = self.mlp(features)
        y = self.output_mlp(tower_output)
        # Additional forward inference
        ...
        return self._output_to_prediction(y)
```
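The final `_output_to_prediction` call (inherited from the base class) turns the raw logits into the prediction dict consumed by the loss and metric functions. Conceptually, for binary classification this amounts to emitting the logits plus their sigmoid, as in this simplified stand-in (not the exact tzrec code):

```python
import torch

def output_to_prediction_sketch(y: torch.Tensor) -> dict:
    # Simplified stand-in: for num_class == 1, emit the raw logits plus
    # sigmoid probabilities, which downstream loss/metrics can consume.
    return {"logits": y, "probs": torch.sigmoid(y.squeeze(-1))}

pred = output_to_prediction_sketch(torch.zeros(4, 1))
print(pred["probs"])  # tensor([0.5000, 0.5000, 0.5000, 0.5000])
```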

## Testing

Write `custom_rank_model.config`:


```
# Data-related configuration
data_config {
  ...
}

# Feature-related configuration
feature_configs : {
  ...
}
feature_configs : {
  ...
}

# Training-related configuration
train_config {
  ...
}

# Evaluation-related configuration
eval_config {
  ...
}

# Model-related configuration
model_config: {
    feature_groups: {
        group_name: 'group1'
        feature_names: 'f1'
        feature_names: 'f2'
        ...
        wide_deep: DEEP
    }
    feature_groups: {
        group_name: 'group2'
        feature_names: 'f3'
        feature_names: 'f4'
        ...
        wide_deep: DEEP
    }
    ...
    custom_rank_model {
        mlp {
            hidden_units: [64]
        }
        ...
    }
    metrics {
        auc {}
    }
    losses {
        binary_cross_entropy {}
    }
}
```

## Run

```bash
PYTHONPATH=. torchrun --master_addr=localhost --master_port=32555 \
    --nnodes=1 --nproc-per-node=2 --node_rank=0 \
    tzrec/train_eval.py \
    --pipeline_config_path custom_rank_model.config \
    --train_input_path ${TRAIN_INPUT_PATH} \
    --eval_input_path ${EVAL_INPUT_PATH} \
    --model_dir ${MODEL_DIR}
```

## Packaging and Release

See the Development Guide.