Skip to content

Commit

Permalink
can also train transformer off raw face data, if autoencoder has fini…
Browse files Browse the repository at this point in the history
…shed training
  • Loading branch information
lucidrains committed Dec 5, 2023
1 parent a57f118 commit 24a9873
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
21 changes: 20 additions & 1 deletion meshgpt_pytorch/meshgpt_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,27 @@ def generate(
codes = rearrange(codes, 'b (n q) -> b n q', q = self.num_quantizers)
return codes

self.autoencoder.eval()
return self.autoencoder.decode_from_codes_to_faces(codes)

def forward_from_raw_face_data(
self,
*,
vertices: TensorType['b', 'nv', 3, int],
faces: TensorType['b', 'nf', 3, int],
face_edges: TensorType['b', 2, 'e', int],
**kwargs
):
with torch.no_grad():
self.autoencoder.eval()
codes = self.autoencoder.tokenize(
vertices = vertices,
faces = faces,
face_edges = face_edges
)

return self.forward(codes, **kwargs)

def forward(
self,
codes = None,
Expand All @@ -548,7 +567,7 @@ def forward(

batch, seq_len, device = *codes.shape, codes.device

assert seq_len <= self.max_seq_len
assert seq_len <= self.max_seq_len, f'received codes of length {seq_len} but needs to be less than or equal to set max_seq_len {self.max_seq_len}'

# auto append eos token

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'meshgpt-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.9',
version = '0.0.10',
license='MIT',
description = 'MeshGPT Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 24a9873

Please sign in to comment.