diff --git a/README.md b/README.md
index e8af1bc9..a6f4884f 100644
--- a/README.md
+++ b/README.md
@@ -65,11 +65,12 @@ transformer = MeshTransformer(
 loss = transformer(face_vertex_codes)
 loss.backward()

-# to decode back to continuous coordinates for each face (9 vertices)
+# after much training of transformer, you can now sample from the attention net

-# (batch, number of faces, vertex (3), coord (3))
+faces_coordinates = transformer.generate()

-face_seq_coords = autoencoder.decode_from_codes_to_faces(face_vertex_codes)
+# (batch, num faces, vertices (3), coordinates (3))
+# now post process for the generated 3d asset

 ```

diff --git a/meshgpt_pytorch/meshgpt_pytorch.py b/meshgpt_pytorch/meshgpt_pytorch.py
index ae41a87f..99aa37b0 100644
--- a/meshgpt_pytorch/meshgpt_pytorch.py
+++ b/meshgpt_pytorch/meshgpt_pytorch.py
@@ -9,14 +9,18 @@
 from torchtyping import TensorType

 from beartype import beartype
-from beartype.typing import Tuple
+from beartype.typing import Tuple, Callable, Optional

 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange

 from x_transformers import Decoder
-
 from x_transformers.attend import Attend
+from x_transformers.autoregressive_wrapper import (
+    eval_decorator,
+    top_k,
+    top_p,
+)

 from vector_quantize_pytorch import (
     ResidualVQ,
@@ -445,6 +449,7 @@ def __init__(
     ):
         super().__init__()

+        self.autoencoder = autoencoder
         self.codebook_size = autoencoder.codebook_size
         self.num_quantizers = autoencoder.num_quantizers

@@ -453,10 +458,14 @@ def __init__(

         # they use axial positional embeddings

+        assert divisible_by(max_seq_len, 3 * self.num_quantizers) # 3 vertices per face, with D codes per vertex
+
         self.token_embed = nn.Embedding(self.codebook_size + 1, dim)
         self.quantize_level_embed = nn.Parameter(torch.randn(self.num_quantizers, dim))
         self.abs_pos_emb = nn.Embedding(max_seq_len, dim)

+        self.max_seq_len = max_seq_len
+
         # main autoregressive attention network

         self.decoder = Decoder(
@@ -469,39 +478,88 @@ def __init__(

         self.to_logits = nn.Linear(dim, self.codebook_size + 1)

-    def generate(self):
-        return self
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @eval_decorator
+    @torch.no_grad()
+    @beartype
+    def generate(
+        self,
+        prompt: Optional[Tensor] = None,
+        batch_size: Optional[int] = None,
+        filter_logits_fn: Callable = top_k,
+        filter_kwargs: dict = dict(),
+        temperature = 1.,
+        return_codes = False
+    ):
+        if exists(prompt):
+            assert not exists(batch_size)
+
+            prompt = rearrange(prompt, 'b ... -> b (...)')
+            batch_size = prompt.shape[0]
+
+        batch_size = default(batch_size, 1)
+
+        codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device))
+
+        curr_length = codes.shape[-1]
+
+        for i in range(curr_length, self.max_seq_len):
+            can_eos = divisible_by(i + 1, self.num_quantizers * 3) # only allow for eos to be decoded at the end of each face, defined as 3 vertices with D residual VQ codes
+
+            logits = self.forward(codes, return_loss = False)
+            logits = logits[:, -1]
+
+            if not can_eos:
+                logits[:, -1] = -torch.finfo(logits.dtype).max
+
+            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
+            probs = F.softmax(filtered_logits / temperature, dim = -1)
+
+            sample = torch.multinomial(probs, 1)
+            codes, _ = pack([codes, sample], 'b *')
+
+        if return_codes:
+            codes = codes[:, 1:] # remove sos
+            codes = rearrange(codes, 'b (n q) -> b n q', q = self.num_quantizers)
+            return codes
+
+        return self.autoencoder.decode_from_codes_to_faces(codes)

     def forward(
         self,
-        codes,
+        codes = None,
         return_loss = True
     ):
-        codes = rearrange(codes, 'b ... -> b (...)')
+        if codes.ndim > 2:
+            codes = rearrange(codes, 'b ... -> b (...)')

         batch, seq_len, device = *codes.shape, codes.device
-        assert divisible_by(seq_len, self.num_quantizers)
+
+        assert seq_len <= self.max_seq_len

         if return_loss:
             codes, labels = codes[:, :-1], codes

         codes = self.token_embed(codes)

-        # auto append sos token
-
-        sos = repeat(self.sos_token, 'd -> b d', b = batch)
-        codes, _ = pack([sos, codes], 'b * d')
-
         # codebook embed + absolute positions

-        seq_arange = torch.arange(seq_len, device = device)
+        seq_arange = torch.arange(codes.shape[-2], device = device)

         codes = codes + self.abs_pos_emb(seq_arange)

         # embedding for quantizer level

-        level_embed = repeat(self.quantize_level_embed, 'n d -> (r n) d', r = seq_len // self.num_quantizers)
-        codes = codes + level_embed
+        level_embed = repeat(self.quantize_level_embed, 'n d -> (r n) d', r = ceil(seq_len / self.num_quantizers))
+        codes = codes + level_embed[:codes.shape[1]]
+
+        # auto append sos token
+
+        sos = repeat(self.sos_token, 'd -> b d', b = batch)
+        codes, _ = pack([sos, codes], 'b * d')

         # attention

@@ -511,12 +569,6 @@ def forward(

         logits = self.to_logits(attended)

-        if self.num_quantizers > 1:
-            eos_mask = ((torch.arange(seq_len, device = device) - 1) % self.num_quantizers == 0)
-            eos_mask = rearrange(eos_mask, 'n -> 1 n 1')
-            eos_mask = F.pad(eos_mask, (logits.shape[-1] - 1, 0), value = False)
-            logits = logits.masked_fill(eos_mask, -torch.finfo(logits.dtype).max)
-
         if not return_loss:
             return logits

diff --git a/setup.py b/setup.py
index b737e252..dc303c06 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'meshgpt-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.4',
+  version = '0.0.5',
   license='MIT',
   description = 'MeshGPT Pytorch',
   author = 'Phil Wang',
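Below is a minimal usage sketch of the `MeshTransformer.generate` API added in this patch. It assumes a `transformer` that has already been constructed and trained as in the README snippet above; the batch size, temperature, and choice of `top_p` filtering are illustrative values, not defaults.

```python
from x_transformers.autoregressive_wrapper import top_p

# unconditional sampling - the generated codes are decoded through the attached
# autoencoder into (batch, num faces, vertices (3), coordinates (3))
faces_coordinates = transformer.generate(batch_size = 2)

# nucleus sampling at a lower temperature, returning the raw residual VQ codes
# (grouped by quantizer level) instead of decoded face coordinates
codes = transformer.generate(
    batch_size = 2,
    temperature = 0.8,
    filter_logits_fn = top_p,
    return_codes = True
)
```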