diff --git a/meshgpt_pytorch/meshgpt_pytorch.py b/meshgpt_pytorch/meshgpt_pytorch.py index 378968db..8a452119 100644 --- a/meshgpt_pytorch/meshgpt_pytorch.py +++ b/meshgpt_pytorch/meshgpt_pytorch.py @@ -67,6 +67,9 @@ def default(v, d): def first(it): return it[0] +def identity(t, *args, **kwargs): + return t + def divisible_by(num, den): return (num % den) == 0 @@ -264,8 +267,8 @@ def forward(self, x, cond): # for initializing to identity - gamma = (1 + self.gamma_mult * gamma) - beta = beta * self.beta_mult + gamma = (1 + self.gamma_mult * gamma.tanh()) + beta = beta.tanh() * self.beta_mult # classic film @@ -1067,6 +1070,7 @@ def __init__( pad_id = -1, num_sos_tokens = None, condition_on_text = False, + text_cond_with_film = False, text_condition_model_types = ('t5',), text_condition_cond_drop_prob = 0.25, quads = False, @@ -1124,8 +1128,8 @@ def __init__( dim_text = self.conditioner.dim_latent cross_attn_dim_context = dim_text - self.text_coarse_film_cond = FiLM(dim_text, dim) - self.text_fine_film_cond = FiLM(dim_text, dim_fine) + self.text_coarse_film_cond = FiLM(dim_text, dim) if text_cond_with_film else identity + self.text_fine_film_cond = FiLM(dim_text, dim_fine) if text_cond_with_film else identity # for summarizing the vertices of each face diff --git a/meshgpt_pytorch/version.py b/meshgpt_pytorch/version.py index 3f099910..7311807d 100644 --- a/meshgpt_pytorch/version.py +++ b/meshgpt_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.2.14' +__version__ = '1.2.15'