Is there any training example about tabtransformer? #6

Open
pancodex opened this issue Apr 13, 2021 · 1 comment

Comments

@pancodex

Hi,
I want to use it on a tabular dataset for supervised learning, but I don't really know how to train this model on a dataset (there seems to be no such content in the README). Could you please help me? Thank you.

@Alexx776

Alexx776 commented Sep 5, 2024

Hello, here is a simple example of training the model. I hope it helps.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from tab_transformer_pytorch import TabTransformer
from sklearn.metrics import accuracy_score, classification_report

torch.manual_seed(0)  # reproducible fake data, split and initialization

# Generate fake data: 5 categorical columns (10 levels each), 10 continuous columns, 3 classes
def generate_data(num_samples):
    x_categ = torch.randint(0, 10, (num_samples, 5))
    x_cont = torch.randn(num_samples, 10)

    # Class 2 takes precedence over class 1; everything else is class 0
    y = torch.zeros(num_samples, dtype=torch.long)
    y[(x_categ[:, 1] < 3) | (x_cont[:, 1] < -1)] = 1
    y[(x_categ[:, 0] > 5) & (x_cont[:, 0] > 0)] = 2
    return x_categ, x_cont, y

num_samples = 10000
x_categ, x_cont, y = generate_data(num_samples)

# 80/20 train/test split
dataset = TensorDataset(x_categ, x_cont, y)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Mean/std of each continuous column from the training rows only, shape (num_continuous, 2).
# The model standardizes x_cont with these internally, so x_cont is NOT normalized
# by hand here (doing both would normalize the features twice).
train_cont = x_cont[train_dataset.indices]
cont_mean_std = torch.stack([train_cont.mean(dim=0), train_cont.std(dim=0)], dim=1)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model
model = TabTransformer(
    categories = (10, 10, 10, 10, 10),   # number of unique values in each categorical column
    num_continuous = 10,                 # number of continuous columns
    dim = 64,
    dim_out = 3,                         # one logit per class
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    mlp_hidden_mults = (4, 2),
    mlp_act = nn.ReLU(),
    continuous_mean_std = cont_mean_std
)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 20

# Train
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    all_preds, all_labels = [], []

    for batch_categ, batch_cont, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_categ, batch_cont)   # logits, shape (batch, 3)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        all_preds.extend(outputs.argmax(dim=1).tolist())
        all_labels.extend(batch_y.tolist())

    avg_loss = total_loss / len(train_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Train accuracy: {accuracy:.4f}")

# Evaluate on the held-out test split (dropout off, no gradients)
model.eval()
test_preds, test_labels = [], []
with torch.no_grad():
    for batch_categ, batch_cont, batch_y in test_loader:
        outputs = model(batch_categ, batch_cont)
        test_preds.extend(outputs.argmax(dim=1).tolist())
        test_labels.extend(batch_y.tolist())

print(f"Test accuracy: {accuracy_score(test_labels, test_preds):.4f}")
print(classification_report(test_labels, test_preds))
```
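The example above uses synthetic tensors. For a real tabular dataset like the one in the question, each categorical column first has to be mapped to integer codes in `[0, n_unique)`, and the number of unique values per column becomes the `categories` tuple. Below is a minimal sketch, assuming the data sits in a pandas DataFrame with no missing values in the categorical or label columns; `encode_tabular` and the column names are hypothetical and not part of `tab_transformer_pytorch`.

```python
import numpy as np
import pandas as pd

# Hypothetical helper (not part of tab_transformer_pytorch): turns a DataFrame into
# integer-coded categorical features, float continuous features and integer labels.
def encode_tabular(df, categ_cols, cont_cols, label_col):
    def to_codes(col):
        # pd.factorize numbers values 0..n_unique-1 in order of appearance; missing -> -1
        codes, uniques = pd.factorize(df[col])
        if (codes < 0).any():
            raise ValueError(f"column {col!r} has missing values; fill or drop them first")
        return codes.astype(np.int64), list(uniques)

    categ_codes, cardinalities = [], []
    for col in categ_cols:
        codes, uniques = to_codes(col)
        categ_codes.append(codes)
        cardinalities.append(len(uniques))

    x_categ = np.stack(categ_codes, axis=1)                # (n_rows, n_categ), int64
    x_cont = df[cont_cols].to_numpy(dtype=np.float32)      # (n_rows, n_cont), float32
    y, classes = to_codes(label_col)                       # (n_rows,), int64
    return x_categ, x_cont, y, tuple(cardinalities), classes

# Usage on a toy DataFrame (hypothetical columns)
df = pd.DataFrame({
    "color": ["red", "blue", "red", "green"],
    "size": ["S", "M", "L", "M"],
    "price": [1.0, 2.5, 3.0, 0.5],
    "weight": [10.0, 12.0, 9.5, 11.0],
    "label": ["a", "b", "a", "c"],
})
x_categ, x_cont, y, cards, classes = encode_tabular(
    df, ["color", "size"], ["price", "weight"], "label"
)
print(x_categ.shape, x_cont.shape, cards, classes)  # (4, 2) (4, 2) (3, 3) ['a', 'b', 'c']
```

To train on the result, wrap the arrays with `torch.from_numpy(...)` and build the `TensorDataset` from them instead of calling `generate_data`, then pass `cards` as `categories` and `len(classes)` as `dim_out`. Since `pd.factorize` numbers values in order of appearance, encode the whole DataFrame once before splitting so train and test rows share the same codes, and compute the continuous mean/std from the training rows only.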
