Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

The implementation of WuKong #106

Merged
merged 2 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions model_zoo/WuKong/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# WuKong

> Buyun Zhang, Liang Luo, Yuxin Chen, Jade Nie, Xi Liu, Daifeng Guo, Yanli Zhao, Shen Li, Yuchen Hao, Yantao Yao, Guna Lakshminarayanan, Ellie Dingqiao Wen, Jongsoo Park, Maxim Naumov, Wenlin Chen. [Wukong: Towards a Scaling Law for Large-Scale Recommendation](https://arxiv.org/abs/2403.02545), arXiv preprint, 2024.
7 changes: 7 additions & 0 deletions model_zoo/WuKong/config/dataset_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
### Tiny data for tests only
# Dataset entry; the WuKong_test experiment refers to it through `dataset_id: tiny_npz`.
# Paths are relative: run_expid.py changes the working directory to its own folder
# before the configs are loaded.
tiny_npz:
# Joined with the dataset id by run_expid.py to locate feature_map.json.
data_root: ../../data/
# run_expid.py only runs FeatureProcessor/build_dataset when this is "csv".
data_format: npz
train_data: ../../data/tiny_npz/train.npz
valid_data: ../../data/tiny_npz/valid.npz
# Test evaluation in run_expid.py is skipped when this value is empty.
test_data: ../../data/tiny_npz/test.npz
65 changes: 65 additions & 0 deletions model_zoo/WuKong/config/model_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# NOTE(review): presumably these Base settings are merged into every experiment id
# by fuxictr's load_config — confirm against the library.
Base:
model_root: './checkpoints/'
num_workers: 12
verbose: 1
early_stop_patience: 2
pickle_feature_encoder: True
save_best_only: True
eval_steps: null
debug_mode: False
group_id: null
use_features: null
feature_specs: null
feature_config: null

WuKong_default: # This is a config template
# Name of the class looked up in the `src` package by run_expid.py.
model: WuKong
dataset_id: TBD
loss: 'binary_crossentropy'
metrics: ['logloss', 'AUC']
task: binary_classification
optimizer: adam
learning_rate: 1.0e-3
embedding_regularizer: 0
net_regularizer: 0
batch_size: 10000
# Embedding dimension of features.
embedding_dim: 64
# Number of stacked WuKong layers.
num_layers: 8
# Dimension of compressed features in the linear compression block (LCB).
compression_dim: 40
# Hidden units of the MLP inside the factorization machine block (FMB).
fmb_units: [200,200]
# Dimension of the FMB output.
fmb_dim: 40
# Dimension of the projection matrix in the FMB.
project_dim: 8
# Dropout rate used in the LCB.
dropout_rate: 0.2
hidden_activations: relu
# Hidden units of the MLP on top of the WuKong layers.
mlp_hidden_units: [32,32]
epochs: 100
shuffle: True
# Passed to seed_everything by run_expid.py.
seed: 2024
monitor: {'AUC': 1, 'logloss': -1}
monitor_mode: 'max'

# Small experiment that runs on the tiny_npz test dataset.
WuKong_test:
model: WuKong
dataset_id: tiny_npz
loss: 'binary_crossentropy'
metrics: ['logloss', 'AUC']
task: binary_classification
optimizer: adam
learning_rate: 1.0e-3
embedding_regularizer: 0
net_regularizer: 0
batch_size: 2048
embedding_dim: 64
num_layers: 4
compression_dim: 32
fmb_units: [128,128,128]
fmb_dim: 32
project_dim: 24
dropout_rate: 0.2
hidden_activations: relu
mlp_hidden_units: [64]
epochs: 5
shuffle: True
seed: 2024
monitor: 'AUC'
monitor_mode: 'max'
3 changes: 3 additions & 0 deletions model_zoo/WuKong/fuxictr_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# pip install -U fuxictr
import fuxictr


def _version_tuple(version):
    """Return the leading numeric components of a dotted version string as a tuple of ints.

    Parsing stops at the first component that does not start with a digit, and any
    non-digit suffix of a component is dropped (e.g. "2.3.2rc1" -> (2, 3, 2)).
    """
    numbers = []
    for piece in str(version).split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        numbers.append(int(digits))
    return tuple(numbers)


# Compare numerically: as plain strings, "2.10.0" >= "2.3.2" evaluates to False.
assert _version_tuple(fuxictr.__version__) >= (2, 3, 2), \
    "fuxictr >= 2.3.2 is required, found {}".format(fuxictr.__version__)
87 changes: 87 additions & 0 deletions model_zoo/WuKong/run_expid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# =========================================================================
# Copyright (C) 2024. The FuxiCTR Library. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================


import os
os.chdir(os.path.dirname(os.path.realpath(__file__)))
import sys
import logging
import fuxictr_version
from fuxictr import datasets
from datetime import datetime
from fuxictr.utils import load_config, set_logger, print_to_json, print_to_list
from fuxictr.features import FeatureMap
from fuxictr.pytorch.dataloaders import RankDataLoader
from fuxictr.pytorch.torch_utils import seed_everything
from fuxictr.preprocess import FeatureProcessor, build_dataset
import src
import gc
import argparse
import os
from pathlib import Path


if __name__ == '__main__':
    # Usage: python run_expid.py --config {config_dir} --expid {experiment_id} --gpu {gpu_device_id}
    #
    # Pipeline: parse the CLI, load the experiment params, optionally build the dataset
    # (csv only), load the feature map, train the model, evaluate on validation/test
    # data and append one summary line to a CSV file named after the config directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./config/', help='The config directory.')
    # Default to an experiment id that exists in this model's config; the previous
    # default ('DeepFM_test') is not defined for WuKong.
    parser.add_argument('--expid', type=str, default='WuKong_test', help='The experiment id to run.')
    parser.add_argument('--gpu', type=int, default=-1, help='The gpu index, -1 for cpu')
    args = vars(parser.parse_args())

    experiment_id = args['expid']
    params = load_config(args['config'], experiment_id)
    params['gpu'] = args['gpu']
    set_logger(params)
    logging.info("Params: " + print_to_json(params))
    seed_everything(seed=params['seed'])

    data_dir = os.path.join(params['data_root'], params['dataset_id'])
    feature_map_json = os.path.join(data_dir, "feature_map.json")
    if params["data_format"] == "csv":
        # Build feature_map and transform data
        feature_encoder = FeatureProcessor(**params)
        params["train_data"], params["valid_data"], params["test_data"] = \
            build_dataset(feature_encoder, **params)
    feature_map = FeatureMap(params['dataset_id'], data_dir)
    feature_map.load(feature_map_json, params)
    logging.info("Feature specs: " + print_to_json(feature_map.features))

    # The model class is resolved by name from the local `src` package.
    model_class = getattr(src, params['model'])
    model = model_class(feature_map, **params)
    model.count_parameters()  # print number of parameters used in model

    train_gen, valid_gen = RankDataLoader(feature_map, stage='train', **params).make_iterator()
    model.fit(train_gen, validation_data=valid_gen, **params)

    logging.info('****** Validation evaluation ******')
    valid_result = model.evaluate(valid_gen)
    # Drop the training/validation iterators before the test iterator is created.
    del train_gen, valid_gen
    gc.collect()

    test_result = {}
    if params["test_data"]:
        logging.info('******** Test evaluation ********')
        test_gen = RankDataLoader(feature_map, stage='test', **params).make_iterator()
        test_result = model.evaluate(test_gen)

    # Results are appended, so repeated runs accumulate in the same file.
    result_filename = Path(args['config']).name.replace(".yaml", "") + '.csv'
    with open(result_filename, 'a+') as fw:
        fw.write(' {},[command] python {},[exp_id] {},[dataset_id] {},[train] {},[val] {},[test] {}\n' \
            .format(datetime.now().strftime('%Y%m%d-%H%M%S'),
                    ' '.join(sys.argv), experiment_id, params['dataset_id'],
                    "N.A.", print_to_list(valid_result), print_to_list(test_result)))
155 changes: 155 additions & 0 deletions model_zoo/WuKong/src/WuKong.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# =========================================================================
# Copyright (C) 2024. XiaoLongtao. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
""" This model implements the paper: Zhang et al., Wukong: Towards a Scaling Law for
Large-Scale Recommendation, Arxiv 2024.
[PDF] https://arxiv.org/abs/2403.02545
"""

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from fuxictr.pytorch.models import BaseModel
from fuxictr.pytorch.layers import FeatureEmbedding, MLP_Block


class WuKong(BaseModel):
    """
    The WuKong model: feature embeddings are passed through a stack of WuKong layers
    and the flattened result is scored by an MLP head.

    Args:
        feature_map: A FeatureMap instance used to store feature specs (e.g., vocab_size).
        model_id: Equivalent to model class name by default, which is used in config to determine
            which model to call.
        gpu: gpu device used to load model. -1 means cpu (default=-1).
        learning_rate: learning rate for training (default=1e-3).
        embedding_dim: embedding dimension of features (default=64).
        num_layers: number of WuKong layers (default=3).
        compression_dim: dimension of compressed features in LCB (default=40).
        mlp_hidden_units: hidden units of MLP on top of WuKong (default=[32,32]).
        fmb_units: hidden units of FMB (default=[32,32]).
        fmb_dim: dimension of FMB output (default=40).
        project_dim: dimension of projection matrix in FMB (default=8).
        dropout_rate: dropout rate used in LCB (default=0.2).
        embedding_regularizer: regularization term used for embedding parameters (default=None).
        net_regularizer: regularization term used for network parameters (default=None).
        hidden_activations: activation(s) of the final MLP's hidden layers (default="relu").
        **kwargs: remaining settings forwarded to BaseModel; must contain "optimizer" and "loss".
    """
    def __init__(self,
                 feature_map,
                 model_id="WuKong",
                 gpu=-1,
                 learning_rate=1e-3,
                 embedding_dim=64,
                 num_layers=3,
                 compression_dim=40,
                 mlp_hidden_units=[32,32],
                 fmb_units=[32,32],
                 fmb_dim=40,
                 project_dim=8,
                 dropout_rate=0.2,
                 embedding_regularizer=None,
                 net_regularizer=None,
                 hidden_activations="relu",
                 **kwargs):
        super(WuKong, self).__init__(feature_map,
                                     model_id=model_id,
                                     gpu=gpu,
                                     embedding_regularizer=embedding_regularizer,
                                     net_regularizer=net_regularizer,
                                     **kwargs)
        self.feature_map = feature_map
        self.embedding_dim = embedding_dim
        self.embedding_layer = FeatureEmbedding(feature_map, embedding_dim)
        self.interaction_layers = nn.ModuleList([
            WuKongLayer(feature_map.num_fields, embedding_dim, project_dim, fmb_units,
                        fmb_dim, compression_dim, dropout_rate)
            for _ in range(num_layers)
        ])
        # The configured activation is honoured here; it used to be hard-coded to 'relu',
        # which silently ignored the `hidden_activations` entry of the experiment config.
        self.final_mlp = MLP_Block(input_dim=feature_map.num_fields*embedding_dim,
                                   output_dim=1,
                                   hidden_units=mlp_hidden_units,
                                   hidden_activations=hidden_activations,
                                   output_activation=None)
        self.compile(kwargs["optimizer"], kwargs["loss"], learning_rate)
        self.reset_parameters()
        self.model_to_device()

    def forward(self, inputs):
        """Run the model on one batch and return ``{"y_pred": prediction}``."""
        X = self.get_inputs(inputs)
        feature_emb = self.embedding_layer(X)
        for layer in self.interaction_layers:
            feature_emb = layer(feature_emb)
        # Every WuKong layer returns a flat (batch, num_fields * embedding_dim) tensor, so
        # this is a no-op when num_layers >= 1. With num_layers == 0 the embedding output
        # reaches the MLP directly and must be flattened to match its input_dim
        # (assumes the embedding layer returns (batch, num_fields, embedding_dim) — confirm).
        feature_emb = feature_emb.flatten(start_dim=1)
        y_pred = self.final_mlp(feature_emb)
        y_pred = self.output_activation(y_pred)
        return_dict = {"y_pred": y_pred}
        return return_dict

class FactorizationMachineBlock(nn.Module):
    """Projected feature-interaction block.

    Owns a learnable ``(num_features, project_dim)`` projection matrix ``P``. For each
    sample with embedding matrix ``X`` of shape ``(num_features, embedding_dim)`` the
    block computes ``X @ (X^T @ P)`` and returns it flattened to
    ``(batch, num_features * project_dim)``.
    """

    def __init__(self, num_features=14, embedding_dim=16, project_dim=8):
        super().__init__()
        self.num_features = num_features
        self.embedding_dim = embedding_dim
        self.project_dim = project_dim
        self.projection_matrix = nn.Parameter(torch.randn(num_features, project_dim))

    def forward(self, x):
        """Accept ``(batch, F, D)`` or ``(batch, F * D)`` input; return ``(batch, F * P)``."""
        n_samples = x.size(0)
        emb = x.view(n_samples, self.num_features, self.embedding_dim)
        # (B, D, F) @ (F, P) -> (B, D, P)
        compressed = emb.transpose(1, 2) @ self.projection_matrix
        # (B, F, D) @ (B, D, P) -> (B, F, P)
        interactions = emb @ compressed
        return interactions.view(n_samples, -1)

class FMB(nn.Module):
    """Factorization machine block followed by LayerNorm and an MLP.

    Args:
        num_features: number of feature fields.
        embedding_dim: embedding dimension of each field.
        fmb_units: hidden layer widths of the MLP; may be empty, in which case the MLP
            is a single linear projection to ``fmb_dim``.
        fmb_dim: output dimension of the block.
        project_dim: projection dimension used by the factorization machine block.
    """

    def __init__(self, num_features=14, embedding_dim=16, fmb_units=[32,32], fmb_dim=40, project_dim=8):
        super(FMB, self).__init__()
        self.fm_block = FactorizationMachineBlock(num_features, embedding_dim, project_dim)
        fm_out_dim = num_features * project_dim
        self.layer_norm = nn.LayerNorm(fm_out_dim)
        # Chain of layer widths. Building from consecutive pairs avoids indexing
        # fmb_units[0], which raised IndexError for an empty hidden-unit list.
        layer_dims = [fm_out_dim, *(fmb_units or [])]
        model_layers = []
        for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:]):
            model_layers.append(nn.Linear(in_dim, out_dim))
            model_layers.append(nn.ReLU())
        model_layers.append(nn.Linear(layer_dims[-1], fmb_dim))
        self.mlp = nn.Sequential(*model_layers)

    def forward(self, x):
        """Return a non-negative ``(batch, fmb_dim)`` tensor (ReLU is applied last)."""
        y = self.fm_block(x)
        y = self.layer_norm(y)
        y = self.mlp(y)
        y = F.relu(y)
        return y

# Linear Compression Block (LCB)
class LinearCompressionBlock(nn.Module):
    """Flatten the per-sample features and compress them with one linear layer plus dropout.

    Args:
        num_features: number of feature fields.
        embedding_dim: embedding dimension of each field.
        compressed_dim: output dimension of the linear layer.
        dropout_rate: dropout probability applied to the linear output.
    """

    def __init__(self, num_features=14, embedding_dim=16, compressed_dim=8,dropout_rate=0.2):
        super(LinearCompressionBlock, self).__init__()
        self.linear = nn.Linear(num_features * embedding_dim, compressed_dim)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        """Accept ``(batch, F, D)`` or ``(batch, F * D)`` input; return ``(batch, compressed_dim)``."""
        # reshape (not view): view raises on non-contiguous input such as a transposed
        # tensor, while reshape copies only when it has to.
        flat = x.reshape(x.size(0), -1)
        return self.dropout(self.linear(flat))

# WuKong Layer
class WuKongLayer(nn.Module):
    """One WuKong layer: FMB and LCB run in parallel, followed by a residual add and LayerNorm.

    The FMB and LCB outputs are concatenated, linearly mapped back to
    ``num_features * embedding_dim`` and added to the flattened input before LayerNorm.

    Args:
        num_features: number of feature fields.
        embedding_dim: embedding dimension of each field.
        project_dim: projection dimension used inside the FMB.
        fmb_units: hidden layer widths of the FMB's MLP.
        fmb_dim: output dimension of the FMB.
        compressed_dim: output dimension of the LCB.
        dropout_rate: dropout probability used in the LCB.
    """

    def __init__(self, num_features=14, embedding_dim=16, project_dim=4, fmb_units=[40,40,40], fmb_dim=40, compressed_dim=40, dropout_rate=0.2):
        super(WuKongLayer, self).__init__()
        self.fmb = FMB(num_features, embedding_dim, fmb_units, fmb_dim, project_dim)
        self.lcb = LinearCompressionBlock(num_features, embedding_dim, compressed_dim, dropout_rate)
        self.layer_norm = nn.LayerNorm(num_features * embedding_dim)
        self.transform = nn.Linear(fmb_dim + compressed_dim, num_features*embedding_dim)

    def forward(self, x):
        """Return a flat ``(batch, num_features * embedding_dim)`` tensor."""
        fmb_out = self.fmb(x)
        lcb_out = self.lcb(x)
        concat_out = torch.cat([fmb_out, lcb_out], dim=1)
        concat_out = self.transform(concat_out)
        # Residual connection on the flattened input. reshape (not view) so that a
        # non-contiguous input does not raise here.
        residual = x.reshape(x.size(0), -1)
        add_norm_out = self.layer_norm(concat_out + residual)
        return add_norm_out
1 change: 1 addition & 0 deletions model_zoo/WuKong/src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .WuKong import *
Loading