Conditioning on Category #54

Open
thauptmann opened this issue Aug 29, 2023 · 1 comment

thauptmann commented Aug 29, 2023

Dear @ZENGXH,

Thank you for providing the source code for your interesting work. We want to use a conditioned LION with one of our datasets to create point clouds for specific classes.

I tested it with ShapeNet by setting cfg.data.cond_on_cat to true (1), but in the VAE I get the error AttributeError: 'tuple' object has no attribute 'transpose', because the points and class_label are only combined into a tuple.
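
Illustratively, the failure mode seems to be the following (a minimal sketch; the shapes, the 55-category count, and the exact place .transpose() is called are assumptions, not the actual LION code):

```python
import torch

# Hypothetical batch produced when cfg.data.cond_on_cat = 1: the loader
# yields (points, class_label) as a plain tuple rather than a bare tensor.
points = torch.randn(8, 2048, 3)           # B x N x 3 point clouds
class_label = torch.randint(0, 55, (8,))   # e.g. 55 ShapeNet categories
batch = (points, class_label)

# The VAE then calls .transpose() on what it assumes is a tensor:
try:
    batch.transpose(1, 2)
except AttributeError as e:
    print(e)  # 'tuple' object has no attribute 'transpose'

# Unpacking the tuple first avoids the error:
x, y = batch
x = x.transpose(1, 2)  # B x 3 x N
```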

Our classes are embeddings based on variable-length character sequences. How would one incorporate the embedding vector: concatenate it to every point, or use it in a later layer? Or would it be better to use an architecture similar to the CLIP embeddings?

Greetings, Tony

ZENGXH (Collaborator) commented Sep 1, 2023

Hi!

Currently cond_on_cat is not well supported; you may need to do some additional digging into the code to add this support. For example:

  • turn
    style = z_global  # torch.cat([z_global, cls_emb], dim=1) if self.args.data.cond_on_cat else z_global
    into
    style = torch.cat([z_global, cls_emb], dim=1) if self.args.data.cond_on_cat else z_global
  • set style_mlp to a layer that maps dim_z_global + dim_cls_emb back to dim_z_global (see the sketch after this list)
  • change the global prior to take conditional input
  • the data loader may also require changes
  • there may be other changes needed
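
A rough sketch of the first two items (a hypothetical module; the names dim_z_global, dim_cls_emb, and num_classes follow the wording above, and the real wiring inside LION will differ):

```python
import torch
import torch.nn as nn

class ConditionalStyle(nn.Module):
    """Sketch: concatenate a class embedding onto the global latent, then
    map the result back to the global-latent width via style_mlp."""

    def __init__(self, dim_z_global, dim_cls_emb, num_classes, cond_on_cat=True):
        super().__init__()
        self.cond_on_cat = cond_on_cat
        self.cls_emb = nn.Embedding(num_classes, dim_cls_emb)
        # style_mlp maps dim_z_global + dim_cls_emb back to dim_z_global
        self.style_mlp = nn.Linear(dim_z_global + dim_cls_emb, dim_z_global)

    def forward(self, z_global, class_label=None):
        if self.cond_on_cat:
            cls_emb = self.cls_emb(class_label)            # B x dim_cls_emb
            style = torch.cat([z_global, cls_emb], dim=1)  # B x (dim_z_global + dim_cls_emb)
            return self.style_mlp(style)                   # B x dim_z_global
        return z_global

# Usage with illustrative sizes:
cond = ConditionalStyle(dim_z_global=128, dim_cls_emb=64, num_classes=55)
style = cond(torch.randn(8, 128), torch.randint(0, 55, (8,)))  # -> 8 x 128
```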

In terms of the variable-length issue: I think you could use a transformer-like encoder and do average pooling at the end to get a 1D latent, then feed it into the prior model and the VAE's decoder the same way we use the CLIP embeddings.
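
A hedged sketch of that idea using a plain PyTorch TransformerEncoder with masked average pooling (the vocab size, model width, and padding convention here are all assumptions for illustration):

```python
import torch
import torch.nn as nn

class SequenceConditionEncoder(nn.Module):
    """Sketch: encode a variable-length character sequence into a fixed
    1D latent via a transformer encoder plus masked average pooling."""

    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)

    def forward(self, tokens):
        # tokens: B x L integer ids, with 0 used as padding
        pad_mask = tokens.eq(0)  # B x L, True where padded
        h = self.encoder(self.embed(tokens), src_key_padding_mask=pad_mask)
        # Average only over the non-padded positions -> B x d_model latent
        keep = (~pad_mask).unsqueeze(-1).float()
        return (h * keep).sum(1) / keep.sum(1).clamp(min=1)

# The resulting B x d_model vector would then be fed to the prior and the
# VAE decoder in the same place the CLIP embedding is consumed.
enc = SequenceConditionEncoder(vocab_size=100)
latent = enc(torch.tensor([[5, 9, 3, 0, 0], [7, 2, 4, 8, 1]]))  # -> 2 x 128
```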
