From 0aaaaafae9040fa0abe254f2520521c3e01dfa58 Mon Sep 17 00:00:00 2001
From: Fotiligner <3181945719@qq.com>
Date: Fri, 8 Dec 2023 22:51:01 +0800
Subject: [PATCH] add CL4SRec model and its random augmentation transform

---
 recbole/data/transform.py                     |  89 ++++++++
 .../model/sequential_recommender/__init__.py  |   1 +
 .../model/sequential_recommender/cl4srec.py   | 223 ++++++++++++++++++
 recbole/properties/model/CL4SRec.yaml         |  14 +
 4 files changed, 327 insertions(+)
 create mode 100644 recbole/model/sequential_recommender/cl4srec.py
 create mode 100644 recbole/properties/model/CL4SRec.yaml

diff --git a/recbole/data/transform.py b/recbole/data/transform.py
index 225ea747b..ae9d4d7c7 100644
--- a/recbole/data/transform.py
+++ b/recbole/data/transform.py
@@ -24,6 +24,7 @@ def construct_transform(config):
         "crop_itemseq": CropItemSequence,
         "reorder_itemseq": ReorderItemSequence,
         "user_defined": UserDefinedTransform,
+        "random_itemseq": RandomAugmentationSequence,
     }
     if config["transform"] not in str2transform:
         raise NotImplementedError(
@@ -221,6 +222,94 @@
         interaction.update(Interaction(new_dict))
         return interaction
 
+class RandomAugmentationSequence:
+    """
+    Build two randomly augmented views (crop / mask / reorder) of each item
+    sequence, as required by contrastive models such as CL4SRec.
+    """
+
+    def __init__(self, config):
+        self.ITEM_SEQ = config["ITEM_ID_FIELD"] + config["LIST_SUFFIX"]
+        self.RANDOM_ITEM_SEQ = "Random_" + self.ITEM_SEQ
+        self.ITEM_SEQ_LEN = config["ITEM_LIST_LENGTH_FIELD"]
+        self.ITEM_ID = config["ITEM_ID_FIELD"]
+        config["RANDOM_ITEM_SEQ"] = self.RANDOM_ITEM_SEQ
+
+    def __call__(self, dataset, interaction):
+        item_seq = interaction[self.ITEM_SEQ]
+        item_seq_len = interaction[self.ITEM_SEQ_LEN]
+        n_items = dataset.num(self.ITEM_ID)
+
+        aug_seq1 = []
+        aug_len1 = []
+        aug_seq2 = []
+        aug_len2 = []
+
+        for seq, length in zip(item_seq, item_seq_len):
+            if length > 1:
+                # draw two distinct augmentations out of crop/mask/reorder
+                switch = random.sample(range(3), k=2)
+            else:
+                # too short to augment: keep the original sequence twice
+                switch = [3, 3]
+            aug_seq = seq
+            aug_len = length
+            if switch[0] == 0:
+                aug_seq, aug_len = self.item_crop(seq, length)
+            elif switch[0] == 1:
+                aug_seq, aug_len = self.item_mask(seq, n_items, length)
+            elif switch[0] == 2:
+                aug_seq, aug_len = self.item_reorder(seq, length)
+
+            aug_seq1.append(aug_seq)
+            aug_len1.append(aug_len)
+
+            if switch[1] == 0:
+                aug_seq, aug_len = self.item_crop(seq, length)
+            elif switch[1] == 1:
+                aug_seq, aug_len = self.item_mask(seq, n_items, length)
+            elif switch[1] == 2:
+                aug_seq, aug_len = self.item_reorder(seq, length)
+
+            aug_seq2.append(aug_seq)
+            aug_len2.append(aug_len)
+
+        new_dict = {
+            "aug1": torch.stack(aug_seq1),
+            "aug1_len": torch.stack(aug_len1),
+            "aug2": torch.stack(aug_seq2),
+            "aug2_len": torch.stack(aug_len2),
+        }
+        interaction.update(Interaction(new_dict))
+        return interaction
+
+    def item_crop(self, item_seq, item_seq_len, eta=0.6):
+        num_left = math.floor(item_seq_len * eta)
+        crop_begin = random.randint(0, item_seq_len - num_left)
+        cropped_item_seq = np.zeros(item_seq.shape[0])
+        if crop_begin + num_left < item_seq.shape[0]:
+            cropped_item_seq[:num_left] = item_seq.cpu().detach().numpy()[crop_begin:crop_begin + num_left]
+        else:
+            cropped_item_seq[:num_left] = item_seq.cpu().detach().numpy()[crop_begin:]
+        return torch.tensor(cropped_item_seq, dtype=torch.long, device=item_seq.device), \
+            torch.tensor(num_left, dtype=torch.long, device=item_seq.device)
+
+    def item_mask(self, item_seq, n_items, item_seq_len, gamma=0.3):
+        num_mask = math.floor(item_seq_len * gamma)
+        mask_index = random.sample(range(item_seq_len), k=num_mask)
+        masked_item_seq = item_seq.cpu().detach().numpy().copy()
+        masked_item_seq[mask_index] = n_items  # id 0 is the padding token, so the extra id n_items serves as [mask]
+        return torch.tensor(masked_item_seq, dtype=torch.long, device=item_seq.device), item_seq_len
+
+    def item_reorder(self, item_seq, item_seq_len, beta=0.6):
+        num_reorder = math.floor(item_seq_len * beta)
+        reorder_begin = random.randint(0, item_seq_len - num_reorder)
+        reordered_item_seq = item_seq.cpu().detach().numpy().copy()
+        shuffle_index = list(range(reorder_begin, reorder_begin + num_reorder))
+        random.shuffle(shuffle_index)
+        reordered_item_seq[reorder_begin:reorder_begin + num_reorder] = reordered_item_seq[shuffle_index]
+        return torch.tensor(reordered_item_seq, dtype=torch.long, device=item_seq.device), item_seq_len
+
 
 class CropItemSequence:
     """
diff --git a/recbole/model/sequential_recommender/__init__.py b/recbole/model/sequential_recommender/__init__.py
index e6787cb60..07949a165 100644
--- a/recbole/model/sequential_recommender/__init__.py
+++ b/recbole/model/sequential_recommender/__init__.py
@@ -1,6 +1,7 @@
 from recbole.model.sequential_recommender.bert4rec import BERT4Rec
 from recbole.model.sequential_recommender.caser import Caser
+from recbole.model.sequential_recommender.cl4srec import CL4SRec
 from recbole.model.sequential_recommender.core import CORE
 from recbole.model.sequential_recommender.dien import DIEN
 from recbole.model.sequential_recommender.din import DIN
 from recbole.model.sequential_recommender.fdsa import FDSA
diff --git a/recbole/model/sequential_recommender/cl4srec.py b/recbole/model/sequential_recommender/cl4srec.py
new file mode 100644
index 000000000..b2cbc9d4a
--- /dev/null
+++ b/recbole/model/sequential_recommender/cl4srec.py
@@ -0,0 +1,223 @@
+# -*- coding: utf-8 -*-
+# @Time : 2023/11/30
+# @Author : Bingqian Li
+# @Email : lizibing666@gmail.com
+
+import torch
+from torch import nn
+
+from recbole.model.abstract_recommender import SequentialRecommender
+from recbole.model.layers import TransformerEncoder
+from recbole.model.loss import BPRLoss
+
+
+class CL4SRec(SequentialRecommender):
+    r"""CL4SRec augments each item sequence twice (crop / mask / reorder) and adds
+    an InfoNCE contrastive loss between the two views on top of a SASRec-style
+    transformer trained with the usual next-item objective.
+    """
+
+    def __init__(self, config, dataset):
+        super(CL4SRec, self).__init__(config, dataset)
+
+        # load parameters info
+        self.n_layers = config['n_layers']
+        self.n_heads = config['n_heads']
+        self.hidden_size = config['hidden_size']
+        self.inner_size = config['inner_size']
+        self.hidden_dropout_prob = config['hidden_dropout_prob']
+        self.attn_dropout_prob = config['attn_dropout_prob']
+        self.hidden_act = config['hidden_act']
+        self.layer_norm_eps = config['layer_norm_eps']
+
+        self.batch_size = config['train_batch_size']
+        self.lmd = config['lmd']
+        self.tau = config['tau']
+        self.sim = config['sim']
+
+        self.initializer_range = config['initializer_range']
+        self.loss_type = config['loss_type']
+
+        # define layers and loss
+        self.item_embedding = nn.Embedding(self.n_items + 1, self.hidden_size, padding_idx=0)  # extra row: [mask] token
+        self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size)
+        self.trm_encoder = TransformerEncoder(
+            n_layers=self.n_layers,
+            n_heads=self.n_heads,
+            hidden_size=self.hidden_size,
+            inner_size=self.inner_size,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attn_dropout_prob=self.attn_dropout_prob,
+            hidden_act=self.hidden_act,
+            layer_norm_eps=self.layer_norm_eps
+        )
+
+        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+        self.dropout = nn.Dropout(self.hidden_dropout_prob)
+
+        if self.loss_type == 'BPR':
+            self.loss_fct = BPRLoss()
+        elif self.loss_type == 'CE':
+            self.loss_fct = nn.CrossEntropyLoss()
+        else:
+            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")
NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") + + self.mask_default = self.mask_correlated_samples(batch_size=self.batch_size) + self.nce_fct = nn.CrossEntropyLoss() + + # parameters initialization + self.apply(self._init_weights) + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def get_attention_mask(self, item_seq): + """Generate left-to-right uni-directional attention mask for multi-head attention.""" + attention_mask = (item_seq > 0).long() + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # torch.int64 + # mask for left-to-right unidirectional + max_len = attention_mask.size(-1) + attn_shape = (1, max_len, max_len) + subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1) # torch.uint8 + subsequent_mask = (subsequent_mask == 0).unsqueeze(1) + subsequent_mask = subsequent_mask.long().to(item_seq.device) + + extended_attention_mask = extended_attention_mask * subsequent_mask + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward(self, item_seq, item_seq_len): + position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device) + position_ids = position_ids.unsqueeze(0).expand_as(item_seq) + position_embedding = self.position_embedding(position_ids) + + item_emb = self.item_embedding(item_seq) + input_emb = item_emb + position_embedding + input_emb = self.LayerNorm(input_emb) + input_emb = self.dropout(input_emb) + + extended_attention_mask = self.get_attention_mask(item_seq) + + trm_output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True) + output = trm_output[-1] + output = self.gather_indexes(output, item_seq_len - 1) + return output # [B H] + + def calculate_loss(self, interaction): + item_seq = interaction[self.ITEM_SEQ] + item_seq_len = interaction[self.ITEM_SEQ_LEN] + seq_output = self.forward(item_seq, item_seq_len) + pos_items = interaction[self.POS_ITEM_ID] + if self.loss_type == 'BPR': + neg_items = interaction[self.NEG_ITEM_ID] + pos_items_emb = self.item_embedding(pos_items) + neg_items_emb = self.item_embedding(neg_items) + pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B] + neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B] + loss = self.loss_fct(pos_score, neg_score) + else: # self.loss_type = 'CE' + test_item_emb = self.item_embedding.weight[:self.n_items] # unpad the augmentation mask + logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) + loss = self.loss_fct(logits, pos_items) + + # # NCE + # aug_item_seq1, aug_len1, aug_item_seq2, aug_len2 = self.augment(item_seq, item_seq_len) + # # aug_item_seq1, aug_len1, aug_item_seq2, aug_len2 = \ + # # interaction['aug1'], interaction['aug_len1'], interaction['aug2'], interaction['aug_len2'] + # seq_output1 = self.forward(aug_item_seq1, aug_len1) + # seq_output2 = self.forward(aug_item_seq2, aug_len2) + + seq_output1 = 
self.forward(interaction["aug1"], interaction["aug1_len"]) + seq_output2 = self.forward(interaction["aug2"], interaction["aug2_len"]) + + nce_logits, nce_labels = self.info_nce(seq_output1, seq_output2, temp=self.tau, batch_size=item_seq_len.shape[0], sim='dot') + + nce_loss = self.nce_fct(nce_logits, nce_labels) + + with torch.no_grad(): + alignment, uniformity = self.decompose(seq_output1, seq_output2, seq_output, + batch_size=item_seq_len.shape[0]) + + return loss + self.lmd * nce_loss, alignment, uniformity + + def decompose(self, z_i, z_j, origin_z, batch_size): + """ + We do not sample negative examples explicitly. + Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples. + """ + N = 2 * batch_size + + z = torch.cat((z_i, z_j), dim=0) + + # pairwise l2 distace + sim = torch.cdist(z, z, p=2) + + sim_i_j = torch.diag(sim, batch_size) + sim_j_i = torch.diag(sim, -batch_size) + + positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1) + alignment = positive_samples.mean() + + # pairwise l2 distace + sim = torch.cdist(origin_z, origin_z, p=2) + mask = torch.ones((batch_size, batch_size), dtype=bool) + mask = mask.fill_diagonal_(0) + negative_samples = sim[mask].reshape(batch_size, -1) + uniformity = torch.log(torch.exp(-2 * negative_samples).mean()) + + return alignment, uniformity + + def mask_correlated_samples(self, batch_size): + N = 2 * batch_size + mask = torch.ones((N, N), dtype=bool) + mask = mask.fill_diagonal_(0) + for i in range(batch_size): + mask[i, batch_size + i] = 0 + mask[batch_size + i, i] = 0 + return mask + + def info_nce(self, z_i, z_j, temp, batch_size, sim='dot'): + """ + We do not sample negative examples explicitly. + Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples. 
+        N = 2 * batch_size
+
+        z = torch.cat((z_i, z_j), dim=0)
+
+        if sim == 'cos':
+            sim = nn.functional.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2) / temp
+        elif sim == 'dot':
+            sim = torch.mm(z, z.T) / temp
+
+        sim_i_j = torch.diag(sim, batch_size)
+        sim_j_i = torch.diag(sim, -batch_size)
+
+        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
+        if batch_size != self.batch_size:
+            mask = self.mask_correlated_samples(batch_size)
+        else:
+            mask = self.mask_default
+        negative_samples = sim[mask].reshape(N, -1)
+
+        labels = torch.zeros(N).to(positive_samples.device).long()
+        logits = torch.cat((positive_samples, negative_samples), dim=1)
+        return logits, labels
+
+    def predict(self, interaction):
+        item_seq = interaction[self.ITEM_SEQ]
+        item_seq_len = interaction[self.ITEM_SEQ_LEN]
+        test_item = interaction[self.ITEM_ID]
+        seq_output = self.forward(item_seq, item_seq_len)
+        test_item_emb = self.item_embedding(test_item)
+        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
+        return scores
+
+    def full_sort_predict(self, interaction):
+        item_seq = interaction[self.ITEM_SEQ]
+        item_seq_len = interaction[self.ITEM_SEQ_LEN]
+        seq_output = self.forward(item_seq, item_seq_len)
+        test_items_emb = self.item_embedding.weight[:self.n_items]  # exclude the extra [mask] token row
+        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B n_items]
+        return scores
diff --git a/recbole/properties/model/CL4SRec.yaml b/recbole/properties/model/CL4SRec.yaml
new file mode 100644
index 000000000..b0b7e5066
--- /dev/null
+++ b/recbole/properties/model/CL4SRec.yaml
@@ -0,0 +1,14 @@
+n_layers: 2                     # (int) The number of transformer layers in transformer encoder.
+n_heads: 2                      # (int) The number of attention heads for multi-head attention layer.
+hidden_size: 64                 # (int) The number of features in the hidden state.
+inner_size: 256                 # (int) The inner hidden size in feed-forward layer.
+hidden_dropout_prob: 0.5        # (float) The probability of an element to be zeroed.
+attn_dropout_prob: 0.5          # (float) The probability of an attention score to be zeroed.
+hidden_act: 'gelu'              # (str) The activation function in feed-forward layer.
+layer_norm_eps: 1e-12           # (float) A value added to the denominator for numerical stability.
+initializer_range: 0.02         # (float) The standard deviation for normal initialization.
+loss_type: 'BPR'                # (str) The type of loss function. Range in ['BPR', 'CE'].
+transform: 'random_itemseq'     # (str) The type of item transformation.
+lmd: 0.01                       # (float) The weight of the contrastive loss.
+tau: 5                          # (float) The temperature of the contrastive loss.
+sim: 'dot'                      # (str) The similarity measure in the contrastive loss. Range in ['dot', 'cos'].
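-- 
Usage sketch (reviewer note, not part of the patch): once applied, CL4SRec can be
smoke-tested through RecBole's quick-start entry point; 'ml-100k' below is just
the usual demo dataset, and any prepared dataset works the same way.

    from recbole.quick_start import run_recbole

    # The model name resolves recbole/properties/model/CL4SRec.yaml, whose
    # transform: 'random_itemseq' setting makes the dataloader attach the
    # "aug1"/"aug2" views that CL4SRec.calculate_loss consumes.
    run_recbole(model='CL4SRec', dataset='ml-100k')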