Commit da5ef5e

add film conditioning of coarse and fine transformer with pooled text embedding before being sent to transformers

lucidrains committed Jun 1, 2024
1 parent 8d7032d commit da5ef5e
Showing 2 changed files with 40 additions and 2 deletions.
40 changes: 39 additions & 1 deletion meshgpt_pytorch/meshgpt_pytorch.py
@@ -247,6 +247,24 @@ def scatter_mean(

# resnet block

class FiLM(Module):
def __init__(self, dim, dim_out = None):
super().__init__()
dim_out = default(dim_out, dim)
linear = nn.Linear(dim, dim_out * 2)

self.to_gamma_beta = nn.Sequential(
linear,
Rearrange('b (gb d) -> gb b 1 d', gb = 2)
)

nn.init.zeros_(linear.weight)
nn.init.constant_(linear.bias, 1.)

def forward(self, x, cond):
gamma, beta = self.to_gamma_beta(cond)
return x * gamma + beta

class PixelNorm(Module):
def __init__(self, dim, eps = 1e-4):
super().__init__()
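
(Not part of the commit.) A minimal, self-contained sketch of how the FiLM module added above behaves: the conditioning vector is projected to a per-channel scale (gamma) and shift (beta), and with the zeroed weight and unit bias both start at 1, so the text-dependent modulation is learned from a near-neutral starting point. The `default` helper is assumed to return `dim_out` when given and fall back to `dim` otherwise.

    # sketch only, not part of the diff. FiLM class copied from the hunk above;
    # `default` is an assumed helper (return the value if present, else the fallback).

    import torch
    import torch.nn as nn
    from torch.nn import Module
    from einops.layers.torch import Rearrange

    def default(val, d):
        return val if val is not None else d

    class FiLM(Module):
        def __init__(self, dim, dim_out = None):
            super().__init__()
            dim_out = default(dim_out, dim)
            linear = nn.Linear(dim, dim_out * 2)

            self.to_gamma_beta = nn.Sequential(
                linear,
                Rearrange('b (gb d) -> gb b 1 d', gb = 2)
            )

            # weight zeroed and bias set to 1, so gamma = beta = 1 at init (scale by 1, shift by 1)
            nn.init.zeros_(linear.weight)
            nn.init.constant_(linear.bias, 1.)

        def forward(self, x, cond):
            # cond: (batch, dim) -> gamma, beta: (batch, 1, dim_out) each, broadcast over the sequence
            gamma, beta = self.to_gamma_beta(cond)
            return x * gamma + beta

    film = FiLM(dim = 768, dim_out = 512)

    face_codes = torch.randn(2, 128, 512)   # (batch, faces, model dim) - illustrative shapes
    pooled_text = torch.randn(2, 768)       # one pooled text embedding per sample

    out = film(face_codes, pooled_text)
    assert out.shape == face_codes.shape
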
@@ -1100,7 +1118,8 @@ def __init__(
dim_text = self.conditioner.dim_latent
cross_attn_dim_context = dim_text

self.to_sos_text_cond = nn.Linear(dim_text, dim_fine)
self.text_coarse_film_cond = FiLM(dim_text, dim)
self.text_fine_film_cond = FiLM(dim_text, dim_fine)

# for summarizing the vertices of each face

@@ -1352,6 +1371,8 @@ def forward_on_codes(

text_embed, text_mask = maybe_dropped_text_embeds

pooled_text_embed = masked_mean(text_embed, text_mask, dim = 1)

attn_context_kwargs = dict(
context = text_embed,
context_mask = text_mask
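
(Not part of the commit.) `masked_mean` is called here but not shown in the diff. Assuming it averages the per-token text embeddings over the sequence dimension while ignoring padded positions, the pooled embedding used for FiLM could be sketched as:

    # sketch only - an assumed definition of masked_mean, which the diff calls but does not show
    import torch

    def masked_mean(t, mask, dim = 1):
        # t:    (batch, seq, dim) per-token text embeddings
        # mask: (batch, seq) bool, True where the token is real
        t = t.masked_fill(~mask.unsqueeze(-1), 0.)
        denom = mask.sum(dim = dim, keepdim = True).clamp(min = 1)
        return t.sum(dim = dim) / denom

    text_embed = torch.randn(2, 77, 768)
    text_mask = torch.ones(2, 77, dtype = torch.bool)

    pooled_text_embed = masked_mean(text_embed, text_mask, dim = 1)   # (2, 768)
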
@@ -1465,6 +1486,11 @@ def forward_on_codes(

should_cache_fine = not divisible_by(curr_vertex_pos + 1, num_tokens_per_face)

# condition face codes with text if needed

if self.condition_on_text:
face_codes = self.text_coarse_film_cond(face_codes, pooled_text_embed)

# attention on face codes (coarse)

if need_call_first_transformer:
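
(Not part of the commit.) Stripped of the caching logic around it, the coarse path added in this hunk FiLM-modulates the face codes with the pooled text embedding immediately before the coarse transformer. A compressed sketch, with an identity module standing in for the actual coarse decoder and the FiLM class from the first sketch:

    # sketch only - coarse conditioning order, with a placeholder transformer
    import torch
    import torch.nn as nn

    condition_on_text = True
    coarse_decoder = nn.Identity()            # placeholder for the coarse transformer

    text_coarse_film_cond = FiLM(768, 512)    # FiLM class from the first sketch
    face_codes = torch.randn(2, 128, 512)
    pooled_text_embed = torch.randn(2, 768)

    if condition_on_text:
        face_codes = text_coarse_film_cond(face_codes, pooled_text_embed)

    face_codes = coarse_decoder(face_codes)   # attention on the (now text-conditioned) face codes
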
@@ -1543,6 +1569,8 @@ def forward_on_codes(

fine_attn_context_kwargs = dict()

# optional text cross attention conditioning for fine transformer

if self.fine_cross_attend_text:
repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]

@@ -1554,6 +1582,16 @@
context_mask = text_mask
)

# also film condition the fine vertex codes

if self.condition_on_text:
repeat_batch = fine_vertex_codes.shape[0] // pooled_text_embed.shape[0]

pooled_text_embed = repeat(pooled_text_embed, 'b ... -> (b r) ...', r = repeat_batch)
fine_vertex_codes = self.text_fine_film_cond(fine_vertex_codes, pooled_text_embed)

# fine transformer

attended_vertex_codes, fine_cache = self.fine_decoder(
fine_vertex_codes,
cache = fine_cache,
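
(Not part of the commit.) The fine stage runs with faces folded into the batch dimension, so the pooled text embedding is repeated to match that folded batch before the fine FiLM is applied, mirroring how the cross-attention context is expanded just above. A sketch with illustrative shapes, again reusing the FiLM class from the first sketch:

    # sketch only - aligning the pooled text embedding with the fine stage's folded batch
    import torch
    from einops import repeat

    text_fine_film_cond = FiLM(768, 512)                 # FiLM class from the first sketch

    fine_vertex_codes = torch.randn(2 * 128, 6, 512)     # (batch * faces, tokens per face, dim_fine)
    pooled_text_embed = torch.randn(2, 768)              # one pooled embedding per sample

    repeat_batch = fine_vertex_codes.shape[0] // pooled_text_embed.shape[0]   # 128 here

    pooled_text_embed = repeat(pooled_text_embed, 'b ... -> (b r) ...', r = repeat_batch)   # (256, 768)
    fine_vertex_codes = text_fine_film_cond(fine_vertex_codes, pooled_text_embed)

    # fine_vertex_codes then goes into the fine transformer (self.fine_decoder in the diff)
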
2 changes: 1 addition & 1 deletion meshgpt_pytorch/version.py
@@ -1 +1 @@
- __version__ = '1.2.10'
+ __version__ = '1.2.11'
