automatically take care of eos
lucidrains committed Dec 5, 2023
1 parent fa17e84 commit a57f118
Showing 3 changed files with 23 additions and 6 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -88,7 +88,7 @@ faces_coordinates = transformer.generate()
- [x] properly mask out eos logit during generation
- [x] make sure it trains
- [x] take care of sos token automatically
- [ ] take care of eos token automatically if sequence length or mask is passed in
- [x] take care of eos token automatically if sequence length or mask is passed in
- [ ] generation + cache kv
- [ ] speculative decoding option
- [ ] hierarchical transformers (using the RQ transformer)
meshgpt_pytorch/meshgpt_pytorch.py (25 changes: 21 additions & 4 deletions)
@@ -517,7 +517,7 @@ def generate(
for i in range(curr_length, self.max_seq_len):
can_eos = divisible_by(i + 1, self.num_quantizers * 3) # only allow for eos to be decoded at the end of each face, defined as 3 vertices with D residual VQ codes

logits = self.forward(codes, return_loss = False)
logits = self.forward(codes, return_loss = False, append_eos = False)
logits = logits[:, -1]

if not can_eos:
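
An aside on the hunk above: eos may only be decoded once a full face (3 vertices, each with `num_quantizers` residual VQ codes) has been emitted; at every other position the eos logit is masked out before sampling. A minimal sketch of that gating, where `num_quantizers`, `eos_token_id`, and the toy `logits` tensor are assumed stand-ins for the module's attributes:

```python
import torch

# hypothetical stand-ins for the attributes used in generate()
num_quantizers = 2              # D residual VQ codes per vertex
eos_token_id   = 16             # assumed id of the eos logit
codes_per_face = 3 * num_quantizers

logits = torch.randn(1, eos_token_id + 1)    # next-token logits: (batch, codebook size + 1 for eos)

i = 7                                        # index of the code about to be decoded
can_eos = (i + 1) % codes_per_face == 0      # eos allowed only once a face is complete

if not can_eos:
    logits[:, eos_token_id] = float('-inf')  # forbid eos mid-face

next_code = logits.softmax(dim = -1).argmax(dim = -1)  # greedy pick, for illustration only
```
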
@@ -539,7 +539,9 @@ def forward(
def forward(
self,
codes = None,
return_loss = True
return_loss = True,
append_eos = False,
code_lens: Optional[Tensor] = None # needed for inserting eos automatically for variable-length meshes
):
if codes.ndim > 2:
codes = rearrange(codes, 'b ... -> b (...)')
@@ -548,10 +550,25 @@ def forward(

assert seq_len <= self.max_seq_len

# auto append eos token

if append_eos:
code_lens = default(code_lens, torch.full((batch, 1), seq_len, device = device))
codes = F.pad(codes, (0, 1), value = 0)

batch_arange = torch.arange(batch, device = device)
batch_arange = rearrange(batch_arange, '... -> ... 1')

codes[batch_arange, code_lens] = self.eos_token_id

# if returning loss, save the labels for cross entropy

if return_loss:
assert seq_len > 0
codes, labels = codes[:, :-1], codes

# token embed (each residual VQ id)

codes = self.token_embed(codes)

# codebook embed + absolute positions
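
The new `append_eos` branch above pads the code sequence by one slot and then uses batched advanced indexing to place the eos id at each mesh's true end, as given by `code_lens`. A self-contained sketch of that trick; the `eos_token_id` value and the toy `codes` / `code_lens` tensors are invented for illustration:

```python
import torch
import torch.nn.functional as F

eos_token_id = 4096                           # assumed id reserved for eos
codes = torch.tensor([[5, 9, 2, 0, 0],        # mesh 1: 3 real codes, right-padded
                      [7, 1, 3, 8, 6]])       # mesh 2: 5 real codes
code_lens = torch.tensor([[3], [5]])          # true length per mesh, shape (batch, 1)

# make room for eos, then scatter it at each mesh's end position
codes = F.pad(codes, (0, 1), value = 0)       # (batch, seq_len + 1)

batch_arange = torch.arange(codes.shape[0])   # (batch,)
batch_arange = batch_arange[:, None]          # (batch, 1), pairs with code_lens

codes[batch_arange, code_lens] = eos_token_id
# codes is now
# [[5, 9, 2, 4096, 0, 0],
#  [7, 1, 3, 8, 6, 4096]]
```
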
@@ -564,12 +581,12 @@ def forward(

code_len = codes.shape[1]

level_embed = repeat(self.quantize_level_embed, 'q d -> (r q) d', r = ceil(seq_len / self.num_quantizers))
level_embed = repeat(self.quantize_level_embed, 'q d -> (r q) d', r = ceil(code_len / self.num_quantizers))
codes = codes + level_embed[:code_len]

# embedding for each vertex

vertex_embed = repeat(self.vertex_embed, 'nv d -> (r nv q) d', r = ceil(seq_len / (3 * self.num_quantizers)), q = self.num_quantizers)
vertex_embed = repeat(self.vertex_embed, 'nv d -> (r nv q) d', r = ceil(code_len / (3 * self.num_quantizers)), q = self.num_quantizers)
codes = codes + vertex_embed[:code_len]

# auto prepend sos token
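
The final hunk sizes the level and vertex embeddings by `code_len` instead of `seq_len`: once the eos handling can change the length of `codes` before embedding, the tiling has to cover the current code length rather than the original input length. A small sketch of the level-embedding tiling under assumed sizes (`num_quantizers`, `dim`, and `code_len` are illustrative):

```python
import torch
from math import ceil
from einops import repeat

num_quantizers = 2                                         # assumed number of residual levels
dim = 4                                                    # assumed embedding dimension

quantize_level_embed = torch.randn(num_quantizers, dim)    # one embedding per residual level
code_len = 13                                              # current length, after eos was appended

# tile the per-level embeddings across the sequence, then trim to the actual length
level_embed = repeat(quantize_level_embed, 'q d -> (r q) d', r = ceil(code_len / num_quantizers))
level_embed = level_embed[:code_len]                       # (code_len, dim)
```
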
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'meshgpt-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.8',
version = '0.0.9',
license='MIT',
description = 'MeshGPT Pytorch',
author = 'Phil Wang',