
Classification with x-transformers #264

Open · wants to merge 5 commits into main
Conversation

RyanKim17920 (Author)

Added a CLS token / pooling option for NLP-based full-text classification.

x_transformers/x_transformers.py
x = x[:, 0]

if self.use_pooling:
x = self.pooling(x).squeeze()
lucidrains (Owner)

for the pooling, we need to account for masking (masked averaging)
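
For context, a minimal sketch of what masked averaging means here (an illustration of the idea, not the implementation later added to the repo): padding positions must be excluded from both the sum and the divisor.

import torch

def masked_mean(x, mask):
    # x:    (batch, seq_len, dim) token embeddings
    # mask: (batch, seq_len) bool, True at real (non-padding) tokens
    mask = mask.unsqueeze(-1)                  # (batch, seq_len, 1)
    summed = (x * mask).sum(dim = 1)           # sum over real tokens only
    count = mask.sum(dim = 1).clamp(min = 1)   # token counts, guarding against division by zero
    return summed / count                      # (batch, dim)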

lucidrains (Owner)

i can take care of this if you'd like, it is all around a bit tricky

RyanKim17920 (Author)

Yes, please do so. Thank you 👍

lucidrains (Owner)

@RyanKim17920 do you want to try the latest changes and see if that's enough?

lucidrains (Owner)

@RyanKim17920 hey Ryan, sorry for hijacking your efforts, just that the project is at a size where things need to be a bit more particular

your example should run now as

import torch
from torch import nn

from x_transformers import (
    TransformerWrapper,
    Encoder
)

# CLS token test
transformer = TransformerWrapper(
    num_tokens=6,
    max_seq_len=10,
    logits_dim=2, # num_classes 
    use_cls_token=True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))
y = torch.tensor([0, 1])

print(x.shape)
logits = transformer(x)
print(logits.shape)
loss = nn.CrossEntropyLoss()(logits, y)

print(loss)

# BCE cls token

transformer = TransformerWrapper(
    num_tokens=6,
    max_seq_len=10,
    logits_dim=1, # num_classes 
    use_cls_token=True,
    squeeze_out_last_dim = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))  # token ids must stay integer (long) for the embedding
y = torch.tensor([0, 1]).float()  # BCE targets are float

print(x.shape)
logits = transformer(x)  # squeeze_out_last_dim already yields shape (batch,)
loss = nn.BCEWithLogitsLoss()(logits, y)

print(loss)

# pooling test
transformer = TransformerWrapper(
    num_tokens=6,
    max_seq_len=10,
    logits_dim=2, # num_classes 
    average_pool_embed = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))
y = torch.tensor([0, 1])

print(x.shape)
logits = transformer(x)
print(logits.shape)
loss = nn.CrossEntropyLoss()(logits, y)

print(loss)

# pooling BCE test
transformer = TransformerWrapper(
    num_tokens=6,
    max_seq_len=10,
    logits_dim=1, # num_classes 
    average_pool_embed = True,
    squeeze_out_last_dim = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (2, 10))  # token ids must stay integer (long) for the embedding
y = torch.tensor([0, 1]).float()  # BCE targets are float

print(x.shape)
logits = transformer(x)  # squeeze_out_last_dim already yields shape (batch,)
print(logits.shape)
loss = nn.BCEWithLogitsLoss()(logits, y)

print(loss)

# batch size 1 test

transformer = TransformerWrapper(
    num_tokens=6,
    max_seq_len=10,
    logits_dim=2, # num_classes 
    average_pool_embed = True,
    attn_layers = Encoder(
        dim = 6,
        depth = 1,
        heads = 2,
    )
)

x = torch.randint(0, 5, (1, 10))
y = torch.tensor([0])

print(x.shape)
logits = transformer(x)
print(logits.shape)


RyanKim17920 commented Aug 20, 2024

Thank you for the improvements you've already made to my original additions. I noticed that my changes to test/x_transformers are now outdated, so they aren't needed anymore. However, I believe the example I provided could still be valuable: it demonstrates NLP classification on a well-known dataset, which might help users understand how to implement it while reaching validation accuracy in the high 90% range.

Would it be possible to add the example to the repository?
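
For readers, a minimal sketch of what such an end-to-end classification example might look like, assuming a tokenized binary-classification batch and TransformerWrapper's mask keyword for padding; input_ids, mask, and labels are placeholder names, and real dataset loading is elided.

import torch
from torch import nn
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    logits_dim = 2,             # num_classes
    average_pool_embed = True,  # masked mean over the sequence
    attn_layers = Encoder(
        dim = 128,
        depth = 4,
        heads = 8
    )
)

# placeholder batch; a real example would tokenize an actual dataset
input_ids = torch.randint(0, 20000, (8, 512))
mask = torch.ones(8, 512).bool()   # True at real tokens, False at padding
labels = torch.randint(0, 2, (8,))

logits = model(input_ids, mask = mask)  # (8, 2)
loss = nn.CrossEntropyLoss()(logits, labels)
loss.backward()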
