diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py
index 4727c4c3..cdc3240c 100755
--- a/pytorch_tabnet/tab_model.py
+++ b/pytorch_tabnet/tab_model.py
@@ -178,8 +178,6 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
                 self.patience_counter = 0
                 # Saving model
                 self.best_network = copy.deepcopy(self.network)
-                # Updating feature_importances_
-                self.feature_importances_ = fit_metrics['train']['feature_importances_']
             else:
                 self.patience_counter += 1
 
@@ -209,6 +207,9 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
         # load best models post training
         self.load_best_model()
 
+        # compute feature importance once the best model is defined
+        self._compute_feature_importances(train_dataloader)
+
     def fit_epoch(self, train_dataloader, valid_dataloader):
         """
         Evaluates and updates network for one epoch.
@@ -333,7 +334,7 @@ def explain(self, X):
         for batch_nb, data in enumerate(dataloader):
             data = data.to(self.device).float()
 
-            output, M_loss, M_explain, masks = self.network(data)
+            M_explain, masks = self.network.forward_masks(data)
             for key, value in masks.items():
                 masks[key] = csc_matrix.dot(value.cpu().detach().numpy(),
                                             self.reducing_matrix)
@@ -350,6 +351,18 @@ def explain(self, X):
                     res_masks[key] = np.vstack([res_masks[key], value])
         return res_explain, res_masks
 
+    def _compute_feature_importances(self, loader):
+        self.network.eval()
+        feature_importances_ = np.zeros((self.network.post_embed_dim))
+        for data, targets in loader:
+            data = data.to(self.device).float()
+            M_explain, masks = self.network.forward_masks(data)
+            feature_importances_ += M_explain.sum(dim=0).cpu().detach().numpy()
+
+        feature_importances_ = csc_matrix.dot(feature_importances_,
+                                              self.reducing_matrix)
+        self.feature_importances_ = feature_importances_ / np.sum(feature_importances_)
+
 
 class TabNetClassifier(TabModel):
 
@@ -471,7 +484,6 @@ def train_epoch(self, train_loader):
         y_preds = []
         ys = []
         total_loss = 0
-        feature_importances_ = np.zeros((self.network.post_embed_dim))
 
         for data, targets in train_loader:
             batch_outs = self.train_batch(data, targets)
@@ -483,13 +495,6 @@ def train_epoch(self, train_loader):
                 y_preds.append(indices.cpu().detach().numpy())
             ys.append(batch_outs["y"].cpu().detach().numpy())
             total_loss += batch_outs["loss"]
-            feature_importances_ += batch_outs['batch_importance']
-
-        # Reduce to initial input_dim
-        feature_importances_ = csc_matrix.dot(feature_importances_,
-                                              self.reducing_matrix)
-        # Normalize feature_importances_
-        feature_importances_ = feature_importances_ / np.sum(feature_importances_)
 
         y_preds = np.hstack(y_preds)
         ys = np.hstack(ys)
@@ -501,7 +506,6 @@ def train_epoch(self, train_loader):
         total_loss = total_loss / len(train_loader)
         epoch_metrics = {'loss_avg': total_loss,
                          'stopping_loss': stopping_loss,
-                         'feature_importances_': feature_importances_
                          }
 
         if self.scheduler is not None:
@@ -525,7 +529,7 @@ def train_batch(self, data, targets):
         targets = targets.to(self.device).long()
         self.optimizer.zero_grad()
 
-        output, M_loss, M_explain, _ = self.network(data)
+        output, M_loss = self.network(data)
 
         loss = self.loss_fn(output, targets)
         loss -= self.lambda_sparse*M_loss
@@ -538,8 +542,7 @@ def train_batch(self, data, targets):
         loss_value = loss.item()
         batch_outs = {'loss': loss_value,
                       'y_preds': output,
-                      'y': targets,
-                      'batch_importance': M_explain.sum(dim=0).cpu().detach().numpy()}
+                      'y': targets}
         return batch_outs
 
     def predict_epoch(self, loader):
@@ -599,7 +602,7 @@ def predict_batch(self, data, targets):
         self.network.eval()
         data = data.to(self.device).float()
         targets = targets.to(self.device).long()
 
-        output, M_loss, M_explain, _ = self.network(data)
+        output, M_loss = self.network(data)
         loss = self.loss_fn(output, targets)
         loss -= self.lambda_sparse*M_loss
@@ -632,7 +635,7 @@ def predict(self, X):
         for batch_nb, data in enumerate(dataloader):
             data = data.to(self.device).float()
 
-            output, M_loss, M_explain, masks = self.network(data)
+            output, M_loss = self.network(data)
             predictions = torch.argmax(torch.nn.Softmax(dim=1)(output),
                                        dim=1)
             predictions = predictions.cpu().detach().numpy().reshape(-1)
@@ -667,7 +670,7 @@ def predict_proba(self, X):
         for batch_nb, data in enumerate(dataloader):
             data = data.to(self.device).float()
 
-            output, M_loss, M_explain, masks = self.network(data)
+            output, M_loss = self.network(data)
             predictions = torch.nn.Softmax(dim=1)(output).cpu().detach().numpy()
             results.append(predictions)
         res = np.vstack(results)
@@ -741,20 +744,12 @@ def train_epoch(self, train_loader):
         y_preds = []
         ys = []
         total_loss = 0
-        feature_importances_ = np.zeros((self.network.post_embed_dim))
 
         for data, targets in train_loader:
             batch_outs = self.train_batch(data, targets)
             y_preds.append(batch_outs["y_preds"].cpu().detach().numpy())
             ys.append(batch_outs["y"].cpu().detach().numpy())
             total_loss += batch_outs["loss"]
-            feature_importances_ += batch_outs['batch_importance']
-
-        # Reduce to initial input_dim
-        feature_importances_ = csc_matrix.dot(feature_importances_,
-                                              self.reducing_matrix)
-        # Normalize feature_importances_
-        feature_importances_ = feature_importances_ / np.sum(feature_importances_)
 
         y_preds = np.vstack(y_preds)
         ys = np.vstack(ys)
@@ -763,7 +758,6 @@ def train_epoch(self, train_loader):
         total_loss = total_loss / len(train_loader)
         epoch_metrics = {'loss_avg': total_loss,
                          'stopping_loss': stopping_loss,
-                         'feature_importances_': feature_importances_
                          }
 
         if self.scheduler is not None:
@@ -788,7 +782,7 @@ def train_batch(self, data, targets):
         targets = targets.to(self.device).float()
         self.optimizer.zero_grad()
 
-        output, M_loss, M_explain, _ = self.network(data)
+        output, M_loss = self.network(data)
 
         loss = self.loss_fn(output, targets)
         loss -= self.lambda_sparse*M_loss
@@ -801,8 +795,7 @@ def train_batch(self, data, targets):
         loss_value = loss.item()
         batch_outs = {'loss': loss_value,
                       'y_preds': output,
-                      'y': targets,
-                      'batch_importance': M_explain.sum(dim=0).cpu().detach().numpy()}
+                      'y': targets}
         return batch_outs
 
     def predict_epoch(self, loader):
@@ -855,7 +848,7 @@ def predict_batch(self, data, targets):
         data = data.to(self.device).float()
         targets = targets.to(self.device).float()
 
-        output, M_loss, M_explain, _ = self.network(data)
+        output, M_loss = self.network(data)
 
         loss = self.loss_fn(output, targets)
         loss -= self.lambda_sparse*M_loss
@@ -890,7 +883,7 @@ def predict(self, X):
         for batch_nb, data in enumerate(dataloader):
             data = data.to(self.device).float()
 
-            output, M_loss, M_explain, masks = self.network(data)
+            output, M_loss = self.network(data)
             predictions = output.cpu().detach().numpy()
             results.append(predictions)
         res = np.vstack(results)
diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py
index 8e8c4531..e60c8ff8 100644
--- a/pytorch_tabnet/tab_network.py
+++ b/pytorch_tabnet/tab_network.py
@@ -83,6 +83,8 @@ def __init__(self, input_dim, output_dim,
         self.n_shared = n_shared
         self.virtual_batch_size = virtual_batch_size
 
+        self.initial_bn = BatchNorm1d(self.input_dim, momentum=0.01)
+
         if self.n_shared > 0:
             shared_feat_transform = torch.nn.ModuleList()
             for i in range(self.n_shared):
@@ -120,6 +122,32 @@ def __init__(self, input_dim, output_dim,
 
     def forward(self, x):
         res = 0
+        x = self.initial_bn(x)
+
+        prior = torch.ones(x.shape).to(x.device)
+        M_loss = 0
+        att = self.initial_splitter(x)[:, self.n_d:]
+
+        for step in range(self.n_steps):
+            M = self.att_transformers[step](prior, att)
+            M_loss += torch.mean(torch.sum(torch.mul(M, torch.log(M+self.epsilon)),
+                                           dim=1))
+            # update prior
+            prior = torch.mul(self.gamma - M, prior)
+            # output
+            masked_x = torch.mul(M, x)
+            out = self.feat_transformers[step](masked_x)
+            d = ReLU()(out[:, :self.n_d])
+            res = torch.add(res, d)
+            # update attention
+            att = out[:, self.n_d:]
+
+        M_loss /= self.n_steps
+        res = self.final_mapping(res)
+        return res, M_loss
+
+    def forward_masks(self, x):
+        x = self.initial_bn(x)
         prior = torch.ones(x.shape).to(x.device)
         M_explain = torch.zeros(x.shape).to(x.device)
@@ -138,15 +166,13 @@ def forward(self, x):
             masked_x = torch.mul(M, x)
             out = self.feat_transformers[step](masked_x)
             d = ReLU()(out[:, :self.n_d])
-            res = torch.add(res, d)
             # explain
             step_importance = torch.sum(d, dim=1)
             M_explain += torch.mul(M, step_importance.unsqueeze(dim=1))
             # update attention
             att = out[:, self.n_d:]
 
-        res = self.final_mapping(res)
-        return res, M_loss, M_explain, masks
+        return M_explain, masks
 
 
 class TabNet(torch.nn.Module):
@@ -215,7 +241,6 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
         self.tabnet = TabNetNoEmbeddings(self.post_embed_dim, output_dim, n_d, n_a, n_steps,
                                          gamma, n_independent, n_shared, epsilon,
                                          virtual_batch_size, momentum)
-        self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01)
 
         # Defining device
         if device_name == 'auto':
@@ -228,9 +253,12 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
 
     def forward(self, x):
         x = self.embedder(x)
-        x = self.initial_bn(x)
         return self.tabnet(x)
 
+    def forward_masks(self, x):
+        x = self.embedder(x)
+        return self.tabnet.forward_masks(x)
+
 
 class AttentiveTransformer(torch.nn.Module):
     def __init__(self, input_dim, output_dim, virtual_batch_size=128,
                  momentum=0.02):