From da5ef5e9991cfd0e868f975624518cb183466f69 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 1 Jun 2024 08:17:41 -0700
Subject: [PATCH] add film conditioning of coarse and fine transformer with
 pooled text embedding before being sent to transformers

---
 meshgpt_pytorch/meshgpt_pytorch.py | 40 +++++++++++++++++++++++++++++-
 meshgpt_pytorch/version.py         |  2 +-
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/meshgpt_pytorch/meshgpt_pytorch.py b/meshgpt_pytorch/meshgpt_pytorch.py
index d8e5ca8c..6acdc302 100644
--- a/meshgpt_pytorch/meshgpt_pytorch.py
+++ b/meshgpt_pytorch/meshgpt_pytorch.py
@@ -247,6 +247,24 @@ def scatter_mean(
 
 # resnet block
 
+class FiLM(Module):
+    def __init__(self, dim, dim_out = None):
+        super().__init__()
+        dim_out = default(dim_out, dim)
+        linear = nn.Linear(dim, dim_out * 2)
+
+        self.to_gamma_beta = nn.Sequential(
+            linear,
+            Rearrange('b (gb d) -> gb b 1 d', gb = 2)
+        )
+
+        nn.init.zeros_(linear.weight)
+        nn.init.constant_(linear.bias, 1.)
+
+    def forward(self, x, cond):
+        gamma, beta = self.to_gamma_beta(cond)
+        return x * gamma + beta
+
 class PixelNorm(Module):
     def __init__(self, dim, eps = 1e-4):
         super().__init__()
@@ -1100,7 +1118,8 @@ def __init__(
             dim_text = self.conditioner.dim_latent
             cross_attn_dim_context = dim_text
 
-            self.to_sos_text_cond = nn.Linear(dim_text, dim_fine)
+            self.text_coarse_film_cond = FiLM(dim_text, dim)
+            self.text_fine_film_cond = FiLM(dim_text, dim_fine)
 
         # for summarizing the vertices of each face
 
@@ -1352,6 +1371,8 @@ def forward_on_codes(
 
             text_embed, text_mask = maybe_dropped_text_embeds
 
+            pooled_text_embed = masked_mean(text_embed, text_mask, dim = 1)
+
             attn_context_kwargs = dict(
                 context = text_embed,
                 context_mask = text_mask
@@ -1465,6 +1486,11 @@
 
         should_cache_fine = not divisible_by(curr_vertex_pos + 1, num_tokens_per_face)
 
+        # condition face codes with text if needed
+
+        if self.condition_on_text:
+            face_codes = self.text_coarse_film_cond(face_codes, pooled_text_embed)
+
         # attention on face codes (coarse)
 
         if need_call_first_transformer:
@@ -1543,6 +1569,8 @@
 
         fine_attn_context_kwargs = dict()
 
+        # optional text cross attention conditioning for fine transformer
+
         if self.fine_cross_attend_text:
             repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]
 
@@ -1554,6 +1582,16 @@
                 context = text_embed,
                 context_mask = text_mask
             )
+
+        # also film condition the fine vertex codes
+
+        if self.condition_on_text:
+            repeat_batch = fine_vertex_codes.shape[0] // pooled_text_embed.shape[0]
+
+            pooled_text_embed = repeat(pooled_text_embed, 'b ... -> (b r) ...', r = repeat_batch)
+            fine_vertex_codes = self.text_fine_film_cond(fine_vertex_codes, pooled_text_embed)
+
+        # fine transformer
+
         attended_vertex_codes, fine_cache = self.fine_decoder(
             fine_vertex_codes,
             cache = fine_cache,
diff --git a/meshgpt_pytorch/version.py b/meshgpt_pytorch/version.py
index c7ac77f9..a8102c1e 100644
--- a/meshgpt_pytorch/version.py
+++ b/meshgpt_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.2.10'
+__version__ = '1.2.11'
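
Note (not part of the patch): a minimal standalone sketch of the FiLM conditioning this change introduces. It mirrors the FiLM module from the first hunk and applies it to dummy face codes using a pooled text embedding, the way the coarse and fine transformers now do before attending. The dimensions, batch sizes, and tensor contents below are illustrative assumptions only.

# minimal sketch of the FiLM conditioning added by this patch
# requires: torch, einops

import torch
from torch import nn
from einops.layers.torch import Rearrange

class FiLM(nn.Module):
    # mirrors the module added in the first hunk: project the conditioning
    # vector to a per-channel scale (gamma) and shift (beta)
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = dim_out if dim_out is not None else dim
        linear = nn.Linear(dim, dim_out * 2)

        self.to_gamma_beta = nn.Sequential(
            linear,
            Rearrange('b (gb d) -> gb b 1 d', gb = 2)
        )

        # weights start at zero, so gamma and beta begin as the constant bias
        nn.init.zeros_(linear.weight)
        nn.init.constant_(linear.bias, 1.)

    def forward(self, x, cond):
        gamma, beta = self.to_gamma_beta(cond)  # each (batch, 1, dim_out)
        return x * gamma + beta

# hypothetical dimensions, chosen only for the demo
dim_text, dim_model = 512, 768
film = FiLM(dim_text, dim_model)

face_codes = torch.randn(2, 128, dim_model)    # (batch, faces, dim) dummy codes
pooled_text_embed = torch.randn(2, dim_text)   # e.g. masked mean over text tokens

conditioned = film(face_codes, pooled_text_embed)
assert conditioned.shape == face_codes.shape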