
Commit

end to end generation
lucidrains committed Dec 5, 2023
1 parent d4965e5 commit cc7becc
Showing 3 changed files with 78 additions and 25 deletions.
7 changes: 4 additions & 3 deletions README.md
@@ -65,11 +65,12 @@ transformer = MeshTransformer(
loss = transformer(face_vertex_codes)
loss.backward()

# to decode back to continuous coordinates for each face (3 vertices, 9 coordinates)
# after much training of transformer, you can now sample from the attention net

# (batch, number of faces, vertex (3), coord (3))
faces_coordinates = transformer.generate()

face_seq_coords = autoencoder.decode_from_codes_to_faces(face_vertex_codes)
# (batch, num faces, vertices (3), coordinates (3))
# now post process for the generated 3d asset

```
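The last comment above leaves the post-processing step open. Below is a minimal sketch of one way to turn the generated `(batch, num faces, 3, 3)` tensor into Wavefront OBJ files; the `save_obj` helper, the vertex deduplication, and the output filenames are illustrative and not part of this repository - only the `transformer.generate()` call and its output shape come from the example above.

```python
import torch

def save_obj(face_coords: torch.Tensor, path: str):
    # face_coords: (num faces, 3 vertices, 3 coords) for a single generated mesh
    verts = face_coords.reshape(-1, 3)                                    # (num faces * 3, 3)
    unique_verts, inverse = torch.unique(verts, dim = 0, return_inverse = True)
    faces = inverse.reshape(-1, 3)                                        # faces re-indexed into unique vertices

    with open(path, 'w') as f:
        for x, y, z in unique_verts.tolist():
            f.write(f'v {x} {y} {z}\n')
        for a, b, c in faces.tolist():
            f.write(f'f {a + 1} {b + 1} {c + 1}\n')                       # OBJ face indices are 1-based

faces_coordinates = transformer.generate()                                # (batch, num faces, 3, 3)

for i, mesh in enumerate(faces_coordinates):
    save_obj(mesh.cpu(), f'mesh_{i}.obj')
```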

94 changes: 73 additions & 21 deletions meshgpt_pytorch/meshgpt_pytorch.py
@@ -9,14 +9,18 @@
from torchtyping import TensorType

from beartype import beartype
from beartype.typing import Tuple
from beartype.typing import Tuple, Callable, Optional

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

from x_transformers import Decoder

from x_transformers.attend import Attend
from x_transformers.autoregressive_wrapper import (
eval_decorator,
top_k,
top_p,
)

from vector_quantize_pytorch import (
ResidualVQ,
@@ -445,6 +449,7 @@ def __init__(
):
super().__init__()

self.autoencoder = autoencoder
self.codebook_size = autoencoder.codebook_size
self.num_quantizers = autoencoder.num_quantizers

@@ -453,10 +458,14 @@

# they use axial positional embeddings

assert divisible_by(max_seq_len, 3 * self.num_quantizers) # 3 vertices per face, with D codes per vertex

self.token_embed = nn.Embedding(self.codebook_size + 1, dim)
self.quantize_level_embed = nn.Parameter(torch.randn(self.num_quantizers, dim))
self.abs_pos_emb = nn.Embedding(max_seq_len, dim)

self.max_seq_len = max_seq_len

# main autoregressive attention network

self.decoder = Decoder(
@@ -469,39 +478,88 @@

self.to_logits = nn.Linear(dim, self.codebook_size + 1)
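As a quick illustration of the new `max_seq_len` divisibility assert above: the sequence is consumed in blocks of `3 * num_quantizers` codes per face, so the length is most easily chosen from a face budget. The numbers below are purely illustrative.

```python
num_quantizers = 2                              # D residual VQ codes per vertex, whatever the autoencoder was built with
max_faces = 1365                                # illustrative budget of faces per generated mesh
max_seq_len = max_faces * 3 * num_quantizers    # 8190, satisfies divisible_by(max_seq_len, 3 * num_quantizers)
```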

def generate(self):
return self
@property
def device(self):
return next(self.parameters()).device

@eval_decorator
@torch.no_grad()
@beartype
def generate(
self,
prompt: Optional[Tensor] = None,
batch_size: Optional[int] = None,
filter_logits_fn: Callable = top_k,
filter_kwargs: dict = dict(),
temperature = 1.,
return_codes = False
):
if exists(prompt):
assert not exists(batch_size)

prompt = rearrange(prompt, 'b ... -> b (...)')
batch_size = prompt.shape[0]

batch_size = default(batch_size, 1)

codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device))

curr_length = codes.shape[-1]

for i in range(curr_length, self.max_seq_len):
can_eos = i != 0 and divisible_by(i, self.num_quantizers * 3) # only allow eos to be decoded at the end of each face, defined as 3 vertices with D residual VQ codes

logits = self.forward(codes, return_loss = False)
logits = logits[:, -1]

if not can_eos:
logits[:, -1] = -torch.finfo(logits.dtype).max

filtered_logits = filter_logits_fn(logits, **filter_kwargs)
probs = F.softmax(filtered_logits / temperature, dim = -1)

sample = torch.multinomial(probs, 1)
codes, _ = pack([codes, sample], 'b *')

if return_codes:
# codes never contain the sos token (it is only prepended inside forward), so nothing needs trimming here
codes = rearrange(codes, 'b (n q) -> b n q', q = self.num_quantizers)
return codes

return self.autoencoder.decode_from_codes_to_faces(codes)
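A hedged usage sketch of the `generate` API defined above, assuming the `autoencoder` and a trained `transformer` from the README example; the random prompt codes and the switch to nucleus sampling are purely illustrative.

```python
import torch
from x_transformers.autoregressive_wrapper import top_p

# unconditional generation with the default top_k filtering
faces = transformer.generate(batch_size = 2, temperature = 1.)        # (2, num faces, 3, 3)

# prompt with the codes of a few faces and switch to nucleus sampling
# (random codes used here only for illustration - in practice they would come from the autoencoder)
prompt = torch.randint(0, transformer.codebook_size, (1, 4, transformer.num_quantizers * 3))

faces = transformer.generate(
    prompt = prompt,
    filter_logits_fn = top_p,
    filter_kwargs = dict(thres = 0.9),
    temperature = 0.8
)

# return the raw residual VQ codes instead of decoded face coordinates
codes = transformer.generate(batch_size = 1, return_codes = True)     # (1, seq len // num quantizers, num quantizers)
```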

def forward(
self,
codes,
codes = None,
return_loss = True
):
codes = rearrange(codes, 'b ... -> b (...)')
if codes.ndim > 2:
codes = rearrange(codes, 'b ... -> b (...)')

batch, seq_len, device = *codes.shape, codes.device
assert divisible_by(seq_len, self.num_quantizers)

assert seq_len <= self.max_seq_len

if return_loss:
codes, labels = codes[:, :-1], codes

codes = self.token_embed(codes)

# auto append sos token

sos = repeat(self.sos_token, 'd -> b d', b = batch)
codes, _ = pack([sos, codes], 'b * d')

# codebook embed + absolute positions

seq_arange = torch.arange(seq_len, device = device)
seq_arange = torch.arange(codes.shape[-2], device = device)

codes = codes + self.abs_pos_emb(seq_arange)

# embedding for quantizer level

level_embed = repeat(self.quantize_level_embed, 'n d -> (r n) d', r = seq_len // self.num_quantizers)
codes = codes + level_embed
level_embed = repeat(self.quantize_level_embed, 'n d -> (r n) d', r = ceil(seq_len / self.num_quantizers))
codes = codes + level_embed[:codes.shape[1]]

# auto append sos token

sos = repeat(self.sos_token, 'd -> b d', b = batch)
codes, _ = pack([sos, codes], 'b * d')

# attention

@@ -511,12 +569,6 @@ def forward(

logits = self.to_logits(attended)

if self.num_quantizers > 1:
eos_mask = ((torch.arange(seq_len, device = device) - 1) % self.num_quantizers == 0)
eos_mask = rearrange(eos_mask, 'n -> 1 n 1')
eos_mask = F.pad(eos_mask, (logits.shape[-1] - 1, 0), value = False)
logits = logits.masked_fill(eos_mask, -torch.finfo(logits.dtype).max)

if not return_loss:
return logits
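To make the face-boundary eos gating used in `generate` concrete, here is a small standalone sketch (toy numbers, not library code): with `num_quantizers = 2`, a face spans `2 * 3 = 6` codes, so the eos logit is only left unmasked at positions 6, 12, 18, and so on.

```python
import torch

num_quantizers = 2
codes_per_face = num_quantizers * 3        # 3 vertices per face, num_quantizers codes per vertex
vocab = 16 + 1                             # toy codebook of 16 codes plus the eos token at the last index

for i in range(13):
    logits = torch.randn(1, vocab)
    can_eos = i != 0 and i % codes_per_face == 0          # only whole faces may precede an eos
    if not can_eos:
        logits[:, -1] = -torch.finfo(logits.dtype).max    # mask the eos logit, mirroring generate()
    # at i = 6 and i = 12 the eos logit survives and may be sampled
```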

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'meshgpt-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.4',
version = '0.0.5',
license='MIT',
description = 'MeshGPT Pytorch',
author = 'Phil Wang',
