diff --git a/.gitignore b/.gitignore
index 3904fd6..a831437 100644
--- a/.gitignore
+++ b/.gitignore
@@ -149,6 +149,7 @@ configs/local/default.yaml
 /data/
 /logs/
 .env
+/data
 
 # Aim logging
 .aim
@@ -156,5 +157,3 @@ configs/local/default.yaml
 
 # Scripts
 *.sh
-# Ignore data
-data
diff --git a/README.md b/README.md
index 091c7c3..81fdbf4 100644
--- a/README.md
+++ b/README.md
@@ -2,20 +2,26 @@
 [[Paper](https://openreview.net/pdf?id=QZfdDpTX1uM)][[Poster](https://iclr.cc/media/PosterPDFs/ICLR%202023/11395.png?t=1682361273.0520558)][[OpenReview](https://openreview.net/forum?id=QZfdDpTX1uM)]
 
 ## Datasets
-We provide the compressed datasets: Stack Overflow, Mooc, Reddit, Wiki, Sin, Uber, NYC Taxi, in this link.
+We provide the compressed datasets: Stack Overflow, Mooc, Reddit, Wiki, Sin, Uber, NYC Taxi, in this [link](https://drive.google.com/file/d/1pL1wDG1elgtUa0CPv4GP21xGII-Ymk0x/view?usp=drive_link).
 Unzip the compressed file and locate it in the `$ROOT` directory.
+
 ## Setup
 Setup the pipeline by installing dependencies using the following command. pretrained models and utils.
 ```bash
 pip install -r requirements.txt
 ```
+For the nfe package, install the package from the [neural flows repo](https://github.com/mbilos/neural-flows-experiments) using
+```bash
+pip install -e .
+```
 
 ## Pre-trained models
 We also provide the checkpoints for Intensity free, THP+ and Attentive TPP on all the datasets.
-Please download the compress file in this link, unzip it and locate it in the `$ROOT` directory.
+Please download the compressed file from this [link](https://drive.google.com/file/d/1frnaUoToJIMh9BnQaqz4zy3HNtaoKe35/view?usp=drive_link), unzip it, and locate it in the `$ROOT` directory.
+
 ## Train
@@ -26,6 +32,7 @@ python src/train.py data/datasets=$DATASET model=$MODEL
 
 `$DATASET` can be chosen from `{so_fold1, mooc, reddit, wiki, sin, uber_drop, taxi_times_jan_feb}` and `$MODEL` can be chosen from `{intensity_free,thp_mix,attn_lnp}`.
 Other configurations can be also easily modified using hydra syntax. Please refer to [hydra](https://hydra.cc/docs/intro/) for further details.
+
 ## Eval
 A model can be evaluated on test datasets using the following command.
 ```bash
@@ -34,6 +41,7 @@ python src/eval.py data/datasets=$DATASET model=$MODEL
 
 Here, the default checkpoint paths are set to the ones in `checkpoints` directory we provided above.
 To use different checkpoints, please chagne `ckpt_path` argument in `configs/eval.yaml`.
+
 ## Modifications
 We made some modifications during code refactorization after ICLR 2023.
 For the NLL metric, we took out L2 norm of model params, which we had to include for implementation purpose using the internal pipeline.
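For illustration, the train and eval commands described in the README changes above can be combined with hydra overrides on the command line. The override keys used below (`trainer.max_epochs`, `data.batch_size`) are assumptions for this sketch and may not match the exact keys defined in this repository's configs.

```bash
# Train the intensity-free model on the Mooc dataset; the hydra overrides
# shown here are illustrative and should be checked against configs/.
python src/train.py data/datasets=mooc model=intensity_free trainer.max_epochs=100 data.batch_size=64

# Evaluate the same model on the test split (uses the default ckpt_path from configs/eval.yaml).
python src/eval.py data/datasets=mooc model=intensity_free
```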
diff --git a/requirements.txt b/requirements.txt
index c3bb075..3567dd1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -65,7 +65,6 @@ mpmath==1.3.0
 msgpack==1.0.5
 multidict==6.0.4
 networkx==3.1
--e git+https://github.com/mbilos/neural-flows-experiments.git@bd19f7c92461e83521e268c1a235ef845a3dd963#egg=nfe
 nodeenv==1.8.0
 numpy==1.24.3
 nvidia-cublas-cu11==11.10.3.66
@@ -80,7 +79,7 @@ nvidia-cusparse-cu11==11.7.4.91
 nvidia-nccl-cu11==2.14.3
 nvidia-nvtx-cu11==11.7.91
 omegaconf==2.3.0
-optuna==2.10.1
+optuna==3.6.1
 ordered-set==4.1.0
 packaging==23.1
 pandas==2.0.2
diff --git a/src/data/tpp_dataset.py b/src/data/tpp_dataset.py
new file mode 100644
index 0000000..4fa13dd
--- /dev/null
+++ b/src/data/tpp_dataset.py
@@ -0,0 +1,165 @@
+import os
+import numpy as np
+import torch
+import logging
+from torch.utils.data import Dataset
+from torch.utils.data import DataLoader
+
+from lightning import LightningModule
+from src import constants
+#from plato.bear.dataset.base_dataset import BaseDataset
+#from plato.bear.utils.shared_data import SharedData
+#from code import constants
+
+logger = logging.getLogger(__name__)
+
+class TPPDataModule(LightningModule):
+    def __init__(self, datasets, data_dir, batch_size, num_workers,
+                 pin_memory=False, **kwargs):
+        super().__init__()
+        self.dataset = datasets['dataset']
+        self.num_classes = datasets['num_classes']
+        self.data_dir = data_dir
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
+        self.kwargs = kwargs
+
+    def prepare_data(self):
+        pass
+
+    def setup(self, stage):
+        self.train_dataset = TPPDataset(
+            self.data_dir, self.dataset, self.num_classes, mode='train', **self.kwargs)
+        self.val_dataset = TPPDataset(
+            self.data_dir, self.dataset, self.num_classes, mode='val', **self.kwargs)
+        self.test_dataset = TPPDataset(
+            self.data_dir, self.dataset, self.num_classes, mode='test', **self.kwargs)
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers,
+            shuffle=True, pin_memory=self.pin_memory)
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset, batch_size=int(self.batch_size/4), num_workers=self.num_workers,
+            shuffle=False, pin_memory=self.pin_memory)
+
+    def test_dataloader(self):
+        return DataLoader(
+            self.test_dataset, batch_size=int(self.batch_size/4), num_workers=self.num_workers,
+            shuffle=False, pin_memory=self.pin_memory)
+
+
+
+class TPPDataset(Dataset):
+    synthetic_data = ['sin']
+    real_data = ['so_fold1', 'mooc', 'reddit', 'wiki',
+                 'uber_drop', 'taxi_times_jan_feb']
+
+    data_fixed_indices = {
+        'retweet': [20000, 2000], 'mimic_fold1': [527, 58], 'so_fold1': [4777, 530]
+    }
+
+    def __init__(self, data_dir, dataset, num_classes, mode, **kwargs):
+        '''
+        data_dir: the root directory where all .npz files are.
+            Default is /shared-data/TPP
+        dataset: the name of a dataset
+        mode: dataset type - [train, val, test]
+        '''
+        super().__init__()
+        self.mode = mode
+
+        if dataset in self.synthetic_data:
+            data_path = os.path.join(data_dir, 'synthetic', dataset + '.npz')
+        elif dataset in self.real_data:
+            data_path = os.path.join(data_dir, 'real', dataset + '.npz')
+        else:
+            logger.error(f'{dataset} is not valid for dataset argument'); exit()
+
+        use_marks = kwargs.get('use_mark', True)
+        data_dict = dict(np.load(data_path, allow_pickle=True))
+        times = data_dict[constants.TIMES]
+        marks = data_dict.get(constants.MARKS, np.ones_like(times))
+        masks = data_dict.get(constants.MASKS, np.ones_like(times))
+        if not use_marks:
+            marks = np.ones_like(times)
+        self._num_classes = num_classes
+
+        if dataset not in self.data_fixed_indices:
+            (train_size, val_size) = (
+                kwargs.get('train_size', 0.6), kwargs.get('val_size', 0.2))
+        else:
+            train_size, val_size = self.data_fixed_indices[dataset]
+
+        train_rate = kwargs.get('train_rate', 1.0)
+        eval_rate = kwargs.get('eval_rate', 1.0)
+        num_data = len(times)
+        (start_idx, end_idx) = self._get_split_indices(
+            num_data, mode=mode, train_size=train_size, val_size=val_size,
+            train_rate=train_rate, eval_rate=eval_rate)
+
+        self._times = torch.tensor(
+            times[start_idx:end_idx], dtype=torch.float32).unsqueeze(-1)
+        self._marks = torch.tensor(
+            marks[start_idx:end_idx], dtype=torch.long).unsqueeze(-1)
+        self._masks = torch.tensor(
+            masks[start_idx:end_idx], dtype=torch.float32).unsqueeze(-1)
+
+    def _sanity_check(self, time, mask):
+        valid_time = time[mask.bool()]
+        prev_time = valid_time[0]
+        for i in range(1, valid_time.shape[0]):
+            curr_time = valid_time[i]
+            if curr_time < prev_time:
+                logger.error(f'sanity check failed - prev time: {prev_time}, curr time: {curr_time}'); exit()
+        logger.info('sanity check passed')
+
+    def _get_split_indices(self, num_data, mode, train_size=0.6, val_size=0.2,
+                           train_rate=1.0, eval_rate=1.0):
+        if mode == 'train':
+            start_idx = 0
+            if train_size > 1.0:
+                end_idx = int(train_size * train_rate)
+            else:
+                end_idx = int(num_data * train_size * train_rate)
+        elif mode == 'val':
+            if val_size > 1.0:
+                start_idx = train_size
+                end_idx = train_size + val_size
+            else:
+                start_idx = int(num_data * train_size)
+                end_idx = start_idx + int(num_data * val_size * eval_rate)
+        elif mode == 'test':
+            if train_size > 1.0 and val_size > 1.0:
+                start_idx = train_size + val_size
+            else:
+                start_idx = int(num_data * train_size) + int(num_data * val_size)
+            end_idx = start_idx + int((num_data - start_idx) * eval_rate)
+        else:
+            logger.error(f'Wrong mode {mode} for dataset'); exit()
+        return (start_idx, end_idx)
+
+    def __getitem__(self, idx):
+        time, mark, mask = self._times[idx], self._marks[idx], self._masks[idx]
+
+        missing_mask = []
+        input_dict = {
+            constants.TIMES: time,
+            constants.MARKS: mark,
+            constants.MASKS: mask,
+            constants.MISSING_MASKS: missing_mask
+        }
+        return input_dict
+
+    def __len__(self):
+        return self._times.shape[0]
+
+    @property
+    def num_classes(self):
+        return self._num_classes
+
+    @property
+    def num_seq(self):
+        return self._times.shape[1]
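The new `TPPDataModule` above is constructed from a `datasets` dict plus loader settings. A minimal usage sketch, assuming the unzipped data lives under `data/` and using an illustrative `num_classes` value (the real values come from the hydra configs):

```python
# Hypothetical usage of the TPPDataModule added in src/data/tpp_dataset.py.
from src.data.tpp_dataset import TPPDataModule

dm = TPPDataModule(
    datasets={'dataset': 'so_fold1', 'num_classes': 22},  # num_classes is an assumed value
    data_dir='data',    # directory containing the real/ and synthetic/ .npz folders
    batch_size=64,
    num_workers=4,
)
dm.setup(stage='fit')   # builds the train/val/test TPPDataset instances

batch = next(iter(dm.train_dataloader()))
# batch is a dict keyed by constants.TIMES / MARKS / MASKS, as returned by TPPDataset.__getitem__.
```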
""" -import os -import logging -import copy -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from torch.distributions import Categorical - -from src import constants -from src.models.tpp import util -from src.models.tpp.prob_dists import NormalMixture, LogNormalMixture -from src.models.tpp.flow import ContinuousGRULayer, ContinuousLSTMLayer -from src.models.tpp.thp.models import ( - TransformerEncoder, TransformerAttnEncoder, NPVIEncoder, NPMLEncoder, - TransformerRNN, TransformerDecoder) -from src.models.tpp.thp import util as thp_util - -logger = logging.getLogger(__name__) - -class IntensityFreePredictor(nn.Module): - def __init__(self, hidden_dim, num_components, num_classes, flow=None, activation=None, - weights_path=None, perm_invar=False, compute_acc=True): - ''' - hidden_dim: the size of intermediate features - num_components: the number of mixtures - encoder: dictionary that specifices arguments for the encoder - activation: dictionary that specifices arguments for the activation function - weights_path: path to a checkpoint point - ''' - super().__init__() - self.num_classes = num_classes - self.compute_acc = compute_acc - - self.perm_invar = perm_invar - self.hidden_dim = hidden_dim - self.embedding = nn.Embedding(self.num_classes+1, hidden_dim, padding_idx=constants.PAD) - - # if flow is specified, it correponds to neural flow else intensity-free - self.flow = flow - if self.flow == 'gru': - self.encoder = ContinuousGRULayer( - 1 + hidden_dim, hidden_dim=hidden_dim, - model='flow', flow_model='resnet', flow_layers=1, - hidden_layers=2, time_net='TimeTanh', time_hidden_dime=8) - elif self.flow == 'lstm': - self.encoder = ContinuousLSTMLayer( - 1 + hidden_dim, hidden_dim=hidden_dim+1, - model='flow', flow_model='resnet', flow_layers=1, - hidden_layers=2, time_net='TimeTanh', time_hidden_dime=8) - else: - self.encoder = nn.GRU( - 1 + hidden_dim, hidden_dim, batch_first=True) - self.activation = util.build_activation(activation) - - if self.perm_invar: - decoder_hidden_dim = self.hidden_dim * 2 - else: - decoder_hidden_dim = self.hidden_dim - - self.prob_dist = LogNormalMixture( - decoder_hidden_dim, num_components, activation=self.activation) - - if self.num_classes > 1: - self.mark_linear = nn.Linear(decoder_hidden_dim, self.num_classes) - - trainable_params = sum( - p.numel() for p in self.parameters() if p.requires_grad) - print(f'The number of trainable model parameters: {trainable_params}', flush=True) - - #if weights_path: - # load_weights_module(self, weights_path, key=('model', 'network')) - # shared_data[constants.CHECKPOINT_METRIC] = weights_path.split('/')[-2] - - - def forward(self, times, marks, masks, missing_masks=[]): - if isinstance(missing_masks, torch.Tensor): - masks = torch.logical_and(masks.bool(), missing_masks.bool()).float() - - # obtain the features from the encoder - if self.flow != 'gru' and self.flow != 'lstm': - hidden = torch.zeros( - 1, 1, self.hidden_dim).repeat(1, times.shape[0], 1).to(times) # (1, B, D) - marks_emb = self.embedding(marks.squeeze(-1)) # (B, Seq, D) - inputs = torch.cat([times, marks_emb], -1) # (B, Seq, D+1) - - histories, _ = self.encoder(inputs, hidden) # (B, Seq, D) - else: - marks_emb = self.embedding(marks.squeeze(-1)) - histories = self.encoder(torch.cat([times, marks_emb], -1), times) - - histories = histories[:,:-1] # (B, Seq-1, D) - - prob_output_dict = self.prob_dist( - histories, times[:,1:], masks[:,1:]) # (B, Seq-1, 1): ignore the first event since that's only input not output 
-        event_ll = prob_output_dict['event_ll']
-        surv_ll = prob_output_dict['surv_ll']
-        time_predictions = prob_output_dict['preds']
-
-        # compute log-likelihood and class predictions if marks are available
-        class_predictions = None
-        if self.num_classes > 1 and self.compute_acc:
-            batch_size = times.shape[0]
-            last_event_idx = masks.squeeze(-1).sum(-1, keepdim=True).long().squeeze(-1) - 1 # (batch_size,)
-            masks_without_last = masks.clone()
-            masks_without_last[torch.arange(batch_size), last_event_idx] = 0
-
-            mark_logits = torch.log_softmax(self.mark_linear(histories), dim=-1) # (B, Seq-1, num_marks)
-            mark_dist = Categorical(logits=mark_logits)
-            adjusted_marks = torch.where(marks-1 >= 0, marks-1, torch.zeros_like(marks)).squeeze(-1) # original dataset uses 1-index
-            mark_log_probs = mark_dist.log_prob(adjusted_marks[:,1:]) # (B, Seq-1)
-            mark_log_probs = torch.stack(
-                [torch.sum(mark_log_prob[mask.bool()]) for
-                 mark_log_prob, mask in zip(mark_log_probs, masks_without_last.squeeze(-1)[:,:-1])])
-            event_ll = event_ll + mark_log_probs
-            class_predictions = torch.argmax(mark_logits, dim=-1)
-
-        output_dict = {
-            constants.HISTORIES: histories,
-            constants.EVENT_LL: event_ll,
-            constants.SURV_LL: surv_ll,
-            constants.KL: None,
-            constants.TIME_PREDS: time_predictions,
-            constants.CLS_PREDS: class_predictions,
-            constants.ATTENTIONS: None,
-        }
-        return output_dict
-
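The `forward` method above returns the per-sequence event log-likelihood (`EVENT_LL`) and the survival term for the interval after the last event (`SURV_LL`), which feed the NLL metric mentioned in the README's Modifications note. A generic sketch of that reduction, assuming the usual event-count normalization (the exact reduction used in this repo may differ):

```python
# Sketch: negative log-likelihood from the forward() outputs; the normalization
# by the number of valid events is an assumption, not necessarily the repo's exact metric.
from src import constants

def nll_from_outputs(output_dict, masks):
    event_ll = output_dict[constants.EVENT_LL]  # (B,) log-likelihood of observed events
    surv_ll = output_dict[constants.SURV_LL]    # (B,) log survival probability after the last event
    num_events = masks.squeeze(-1).sum()        # number of valid (non-padded) events in the batch
    return -(event_ll + surv_ll).sum() / num_events
```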