Skip to content

Commit

Permalink
Update DCNv3 model (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
salmon1802 authored Aug 12, 2024
1 parent 940967a commit 789bb88
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 231 deletions.
15 changes: 7 additions & 8 deletions model_zoo/DCNv3/README.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
# DCNv3 & SDCNv3
# DCNv3

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/dcnv3-towards-next-generation-deep-cross/click-through-rate-prediction-on-criteo)](https://paperswithcode.com/sota/click-through-rate-prediction-on-criteo?p=dcnv3-towards-next-generation-deep-cross)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/dcnv3-towards-next-generation-deep-cross/click-through-rate-prediction-on-kdd12)](https://paperswithcode.com/sota/click-through-rate-prediction-on-kdd12?p=dcnv3-towards-next-generation-deep-cross)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/dcnv3-towards-next-generation-deep-cross/click-through-rate-prediction-on-kkbox)](https://paperswithcode.com/sota/click-through-rate-prediction-on-kkbox?p=dcnv3-towards-next-generation-deep-cross)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/dcnv3-towards-next-generation-deep-cross/click-through-rate-prediction-on-ipinyou)](https://paperswithcode.com/sota/click-through-rate-prediction-on-ipinyou?p=dcnv3-towards-next-generation-deep-cross)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/dcnv3-towards-next-generation-deep-cross/click-through-rate-prediction-on-avazu)](https://paperswithcode.com/sota/click-through-rate-prediction-on-avazu?p=dcnv3-towards-next-generation-deep-cross)

We introduces the next generation deep cross networks, DCNv3 and SDCNv3. The former explicitly captures feature interaction through an exponentially growing modeling method, and further filters noise signals via the Self-Mask operation, reducing the parameter count by half. The latter builds on DCNv3 by incorporating the shallow cross network, SCNv3, to capture both high-order and low-order feature interactions without relying on the less interpretable DNN. Tri-BCE helps the two sub-networks in SDCNv3 obtain more suitable supervision signals for themselves.
> Li, Honghao and Zhang, Yiwen and Zhang, Yi and Li, Hanwei and Sang, Lei. [DCNv3: Towards Next Generation Deep Cross Network for Click-Through Rate Prediction](https://arxiv.org/abs/2407.13349).
We introduces the next generation deep cross networks, called DCNv3, which uses sub-networks LCN and ECN to capture both low-order and high-order feature interactions without relying on the less interpretable DNN. LCN uses a linearly growing interaction method for low-order interactions, while ECN employs an exponentially increasing method for high-order interactions. The Self-Mask filters interaction noise and further improves DCNv3’s computational efficiency. Tri-BCE helped the two sub-networks in DCNv3 obtain more suitable supervision signals for themselves. Comprehensive experiments on six datasets demonstrated the effectiveness, efficiency, and interpretability of DCNv3.
> Li, Honghao and Zhang, Yiwen and Zhang, Yi and Li, Hanwei and Sang, Lei and Zhu, Jieming. [DCNv3: Towards Next Generation Deep Cross Network for Click-Through Rate Prediction](https://arxiv.org/abs/2407.13349).
## Model Overview

<div align="center">
<img src="https://github.com/user-attachments/assets/6b0df396-d4ee-4475-ac02-21538ae0ef27" alt="SDCNv3" />
<img src="https://github.com/user-attachments/assets/6f0479ce-edb2-4ad1-92a0-7ec9aeeb5f2d" alt="DCNv3" />
</div>


## Requirements

We have tested FinalMLP with the following requirements.
Expand All @@ -25,7 +26,7 @@ pytorch: 1.10
fuxictr: 2.0.1
```

## SDCNv3 Configuration Guide
## DCNv3 Configuration Guide


The `dataset_config.yaml` file contains all the dataset settings as follows.
Expand Down Expand Up @@ -81,7 +82,7 @@ The `model_config.yaml` file contains all the model hyper-parameters as follows.
| save_best_only | bool | True | whether to save the best model checkpoint only |
| eval_steps | int\|None | None | evaluate the model on validation data every ```eval_steps```. By default, ```None``` means evaluation every epoch. |

## DCNv3 Configuration Guide
## ECN Configuration Guide


The `dataset_config.yaml` file contains all the dataset settings as follows.
Expand Down Expand Up @@ -140,6 +141,4 @@ The `model_config.yaml` file contains all the model hyper-parameters as follows.

## Results

AUC's evaluation results can be found [here](https://github.com/salmon1802/DCNv3).

For reproducing the results, please refer to https://github.com/salmon1802/DCNv3/tree/master/checkpoints
8 changes: 4 additions & 4 deletions model_zoo/DCNv3/config/model_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ Base:
feature_specs: null
feature_config: null

DCNv3_test:
model: DCNv3
ECN_test:
model: ECN
dataset_id: tiny_npz
loss: 'binary_crossentropy'
metrics: ['logloss', 'AUC']
Expand All @@ -35,8 +35,8 @@ DCNv3_test:
monitor: {'AUC': 1, 'logloss': 0}
monitor_mode: 'max'

SDCNv3_test:
model: SDCNv3
DCNv3_test:
model: DCNv3
dataset_id: tiny_npz
loss: 'binary_crossentropy'
metrics: ['logloss', 'AUC']
Expand Down
116 changes: 96 additions & 20 deletions model_zoo/DCNv3/src/DCNv3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =========================================================================
# Copyright (C) 2024 salmon1802@github
# Copyright (C) 2024 salmon@github
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -17,6 +17,7 @@
from torch import nn
from fuxictr.pytorch.models import BaseModel
from fuxictr.pytorch.layers import FeatureEmbedding
from fuxictr.pytorch.torch_utils import get_regularizer


class DCNv3(BaseModel):
Expand All @@ -25,9 +26,11 @@ def __init__(self,
model_id="DCNv3",
gpu=-1,
learning_rate=1e-3,
embedding_dim=16,
num_cross_layers=3,
net_dropout=0,
embedding_dim=10,
num_deep_cross_layers=4,
num_shallow_cross_layers=4,
deep_net_dropout=0.1,
shallow_net_dropout=0.3,
layer_norm=True,
batch_norm=False,
num_heads=1,
Expand All @@ -42,24 +45,50 @@ def __init__(self,
**kwargs)
self.embedding_layer = MultiHeadFeatureEmbedding(feature_map, embedding_dim * num_heads, num_heads)
input_dim = feature_map.sum_emb_out_dim()
self.dcnv3 = DeepCrossNetv3(input_dim=input_dim,
num_cross_layers=num_cross_layers,
net_dropout=net_dropout,
layer_norm=layer_norm,
batch_norm=batch_norm,
num_heads=num_heads)
self.ECN = ExponentialCrossNetwork(input_dim=input_dim,
num_cross_layers=num_deep_cross_layers,
net_dropout=deep_net_dropout,
layer_norm=layer_norm,
batch_norm=batch_norm,
num_heads=num_heads)
self.LCN = LinearCrossNetwork(input_dim=input_dim,
num_cross_layers=num_shallow_cross_layers,
net_dropout=shallow_net_dropout,
layer_norm=layer_norm,
batch_norm=batch_norm,
num_heads=num_heads)
self.compile(kwargs["optimizer"], kwargs["loss"], learning_rate)
self.reset_parameters()
self.model_to_device()

def forward(self, inputs):
X = self.get_inputs(inputs)
feature_emb = self.embedding_layer(X) # B × H × FD/H
y_pred = self.dcnv3(feature_emb).mean(dim=1)
y_pred = self.output_activation(y_pred)
return_dict = {"y_pred": y_pred}
feature_emb = self.embedding_layer(X)
dlogit = self.ECN(feature_emb).mean(dim=1)
slogit = self.LCN(feature_emb).mean(dim=1)
logit = (dlogit + slogit) * 0.5
y_pred = self.output_activation(logit)
return_dict = {"y_pred": y_pred,
"y_d": self.output_activation(dlogit),
"y_s": self.output_activation(slogit)}
return return_dict

def add_loss(self, inputs):
return_dict = self.forward(inputs)
y_true = self.get_labels(inputs)
y_pred = return_dict["y_pred"]
y_d = return_dict["y_d"]
y_s = return_dict["y_s"]
loss = self.loss_fn(y_pred, y_true, reduction='mean')
loss_d = self.loss_fn(y_d, y_true, reduction='mean')
loss_s = self.loss_fn(y_s, y_true, reduction='mean')
weight_d = loss_d - loss
weight_s = loss_s - loss
weight_d = torch.where(weight_d > 0, weight_d, torch.zeros(1).to(weight_d.device))
weight_s = torch.where(weight_s > 0, weight_s, torch.zeros(1).to(weight_s.device))
loss = loss + loss_d * weight_d + loss_s * weight_s
return loss


class MultiHeadFeatureEmbedding(nn.Module):
def __init__(self, feature_map, embedding_dim, num_heads=2):
Expand All @@ -80,15 +109,15 @@ def forward(self, X): # H = num_heads
return multihead_feature_emb # B × H × FD/H


class DeepCrossNetv3(nn.Module):
class ExponentialCrossNetwork(nn.Module):
def __init__(self,
input_dim,
num_cross_layers=3,
layer_norm=True,
batch_norm=True,
batch_norm=False,
net_dropout=0.1,
num_heads=1):
super(DeepCrossNetv3, self).__init__()
super(ExponentialCrossNetwork, self).__init__()
self.num_cross_layers = num_cross_layers
self.layer_norm = nn.ModuleList()
self.batch_norm = nn.ModuleList()
Expand All @@ -97,7 +126,7 @@ def __init__(self,
self.b = nn.ParameterList()
for i in range(num_cross_layers):
self.w.append(nn.Linear(input_dim, input_dim // 2, bias=False))
self.b.append(nn.Parameter(torch.empty((input_dim,))))
self.b.append(nn.Parameter(torch.zeros((input_dim,))))
if layer_norm:
self.layer_norm.append(nn.LayerNorm(input_dim // 2))
if batch_norm:
Expand All @@ -106,7 +135,7 @@ def __init__(self,
self.dropout.append(nn.Dropout(net_dropout))
nn.init.uniform_(self.b[i].data)
self.masker = nn.ReLU()
self.fc = nn.Linear(input_dim, 1)
self.dfc = nn.Linear(input_dim, 1)

def forward(self, x):
for i in range(self.num_cross_layers):
Expand All @@ -122,5 +151,52 @@ def forward(self, x):
x = x * (H + self.b[i]) + x
if len(self.dropout) > i:
x = self.dropout[i](x)
logit = self.fc(x)
logit = self.dfc(x)
return logit


class LinearCrossNetwork(nn.Module):
def __init__(self,
input_dim,
num_cross_layers=3,
layer_norm=True,
batch_norm=True,
net_dropout=0.1,
num_heads=1):
super(LinearCrossNetwork, self).__init__()
self.num_cross_layers = num_cross_layers
self.layer_norm = nn.ModuleList()
self.batch_norm = nn.ModuleList()
self.dropout = nn.ModuleList()
self.w = nn.ModuleList()
self.b = nn.ParameterList()
for i in range(num_cross_layers):
self.w.append(nn.Linear(input_dim, input_dim // 2, bias=False))
self.b.append(nn.Parameter(torch.zeros((input_dim,))))
if layer_norm:
self.layer_norm.append(nn.LayerNorm(input_dim // 2))
if batch_norm:
self.batch_norm.append(nn.BatchNorm1d(num_heads))
if net_dropout > 0:
self.dropout.append(nn.Dropout(net_dropout))
nn.init.uniform_(self.b[i].data)
self.masker = nn.ReLU()
self.sfc = nn.Linear(input_dim, 1)

def forward(self, x):
x0 = x
for i in range(self.num_cross_layers):
H = self.w[i](x)
if len(self.batch_norm) > i:
H = self.batch_norm[i](H)
if len(self.layer_norm) > i:
norm_H = self.layer_norm[i](H)
mask = self.masker(norm_H)
else:
mask = self.masker(H)
H = torch.cat([H, H * mask], dim=-1)
x = x0 * (H + self.b[i]) + x
if len(self.dropout) > i:
x = self.dropout[i](x)
logit = self.sfc(x)
return logit
127 changes: 127 additions & 0 deletions model_zoo/DCNv3/src/ECN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# =========================================================================
# Copyright (C) 2024. The FuxiCTR Library. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

import torch
from torch import nn
from fuxictr.pytorch.models import BaseModel
from fuxictr.pytorch.layers import FeatureEmbedding


class ECN(BaseModel):
def __init__(self,
feature_map,
model_id="ECN",
gpu=-1,
learning_rate=1e-3,
embedding_dim=16,
num_cross_layers=3,
net_dropout=0,
layer_norm=True,
batch_norm=False,
num_heads=1,
embedding_regularizer=None,
net_regularizer=None,
**kwargs):
super(ECN, self).__init__(feature_map,
model_id=model_id,
gpu=gpu,
embedding_regularizer=embedding_regularizer,
net_regularizer=net_regularizer,
**kwargs)
self.embedding_layer = MultiHeadFeatureEmbedding(feature_map, embedding_dim * num_heads, num_heads)
input_dim = feature_map.sum_emb_out_dim()
self.ECN = ExponentialCrossNetwork(input_dim=input_dim,
num_cross_layers=num_cross_layers,
net_dropout=net_dropout,
layer_norm=layer_norm,
batch_norm=batch_norm,
num_heads=num_heads)
self.compile(kwargs["optimizer"], kwargs["loss"], learning_rate)
self.reset_parameters()
self.model_to_device()

def forward(self, inputs):
X = self.get_inputs(inputs)
feature_emb = self.embedding_layer(X) # B × H × FD/H
y_pred = self.ECN(feature_emb).mean(dim=1)
y_pred = self.output_activation(y_pred)
return_dict = {"y_pred": y_pred}
return return_dict


class MultiHeadFeatureEmbedding(nn.Module):
def __init__(self, feature_map, embedding_dim, num_heads=2):
super(MultiHeadFeatureEmbedding, self).__init__()
self.num_heads = num_heads
self.embedding_layer = FeatureEmbedding(feature_map, embedding_dim)

def forward(self, X): # H = num_heads
feature_emb = self.embedding_layer(X) # B × F × D
multihead_feature_emb = torch.tensor_split(feature_emb, self.num_heads, dim=-1)
multihead_feature_emb = torch.stack(multihead_feature_emb, dim=1) # B × H × F × D/H
multihead_feature_emb1, multihead_feature_emb2 = torch.tensor_split(multihead_feature_emb, 2,
dim=-1) # B × H × F × D/2H
multihead_feature_emb1, multihead_feature_emb2 = multihead_feature_emb1.flatten(start_dim=2), \
multihead_feature_emb2.flatten(
start_dim=2) # B × H × FD/2H; B × H × FD/2H
multihead_feature_emb = torch.cat([multihead_feature_emb1, multihead_feature_emb2], dim=-1)
return multihead_feature_emb # B × H × FD/H


class ExponentialCrossNetwork(nn.Module):
def __init__(self,
input_dim,
num_cross_layers=3,
layer_norm=True,
batch_norm=True,
net_dropout=0.1,
num_heads=1):
super(ExponentialCrossNetwork, self).__init__()
self.num_cross_layers = num_cross_layers
self.layer_norm = nn.ModuleList()
self.batch_norm = nn.ModuleList()
self.dropout = nn.ModuleList()
self.w = nn.ModuleList()
self.b = nn.ParameterList()
for i in range(num_cross_layers):
self.w.append(nn.Linear(input_dim, input_dim // 2, bias=False))
self.b.append(nn.Parameter(torch.empty((input_dim,))))
if layer_norm:
self.layer_norm.append(nn.LayerNorm(input_dim // 2))
if batch_norm:
self.batch_norm.append(nn.BatchNorm1d(num_heads))
if net_dropout > 0:
self.dropout.append(nn.Dropout(net_dropout))
nn.init.uniform_(self.b[i].data)
self.masker = nn.ReLU()
self.fc = nn.Linear(input_dim, 1)

def forward(self, x):
for i in range(self.num_cross_layers):
H = self.w[i](x)
if len(self.batch_norm) > i:
H = self.batch_norm[i](H)
if len(self.layer_norm) > i:
norm_H = self.layer_norm[i](H)
mask = self.masker(norm_H)
else:
mask = self.masker(H)
H = torch.cat([H, H * mask], dim=-1)
x = x * (H + self.b[i]) + x
if len(self.dropout) > i:
x = self.dropout[i](x)
logit = self.fc(x)
return logit
Loading

0 comments on commit 789bb88

Please sign in to comment.