
Generation for PaLI? #238

Open
BurgerAndreas opened this issue Feb 5, 2024 · 0 comments
BurgerAndreas commented Feb 5, 2024

How would one generate an action (output text) using PaLI?

PaLI example from readme.md:

import torch
from x_transformers import ViTransformerWrapper, XTransformer, Encoder

# PaLI is composed of:
# 1. a vision transformer (ViTransformerWrapper) +
# 2. an encoder-decoder transformer (XTransformer)

vit = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

pali = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024
)

# training data

img = torch.randn(1, 3, 256, 256)               # images
prompt = torch.randint(0, 256, (1, 1024))       # prompt
prompt_mask = torch.ones(1, 1024).bool()        # prompt text mask
output_text = torch.randint(0, 256, (1, 1024))  # target output text

# train

img_embeds = vit(
    img,
    return_embeddings = True
)

loss = pali(
    prompt,
    output_text,
    mask = prompt_mask,
    src_prepend_embeds = img_embeds             # will prepend image embeddings to encoder text embeddings before attention
)

loss.backward()

Desired behaviour

with torch.no_grad():
    vit.eval()
    pali.eval()

    img_embeds = vit(
        img,
        return_embeddings = True
    )
    
    # how to do this?
    # XTransformer.generate() does not take src_prepend_embeds, so the image
    # embeddings cannot be fed to the encoder here
    output_text = pali.generate(
        img_embeds,
        prompt,
        mask = prompt_mask,
    )

Idea?

img_embeds = vit(img, return_embeddings = True)

# adapted from XTransformer.forward()
enc = pali.encoder(prompt, mask = prompt_mask, prepend_embeds = img_embeds, return_embeddings = True)

# adapted from XTransformer.generate()
# seq_out_start: (batch, 1) tensor of start-token ids; seq_len: number of tokens to generate
# note: context_mask probably needs to be extended to also cover the prepended image embeddings
output_text = pali.decoder.generate(seq_out_start, seq_len, context = enc, context_mask = prompt_mask)
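One detail in the sketch above: once the image embeddings are prepended, the encoder output is longer than `prompt_mask`, so the mask handed to the decoder as `context_mask` has to be extended to cover the image positions too. A minimal sketch of that extension in plain PyTorch, assuming the readme's shapes (256×256 image, 32-pixel patches, so 64 prepended patch embeddings):

```python
import torch

# shapes from the readme example above
prompt_mask = torch.ones(1, 1024).bool()  # (batch, prompt_len)
num_img_tokens = (256 // 32) ** 2         # 64 patch embeddings prepended by the ViT

# prepend True entries so the decoder can attend to the image
# positions as well as the prompt tokens
context_mask = torch.cat(
    (torch.ones(1, num_img_tokens).bool(), prompt_mask),
    dim = -1
)

print(context_mask.shape)  # torch.Size([1, 1088])
```

This `context_mask` would then replace `prompt_mask` in the `pali.decoder.generate(...)` call; whether the encoder's forward accepts the `prepend_embeds` keyword exactly as written depends on the installed x-transformers version.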