
Commit

add a test for caching
lucidrains committed Jun 20, 2024
1 parent 36d8dba commit 443ebc7
Showing 3 changed files with 22 additions and 3 deletions.
1 change: 1 addition & 0 deletions meshgpt_pytorch/meshgpt_pytorch.py
@@ -1157,6 +1157,7 @@ def __init__(
         self.conditioner = None

         cross_attn_dim_context = None
+        dim_text = None

         if condition_on_text:
             self.conditioner = TextEmbeddingReturner(
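The added dim_text = None mirrors the cross_attn_dim_context = None default just above it: the name is bound before the if condition_on_text: branch so it exists on every code path, whether or not text conditioning is enabled. Below is a minimal, self-contained sketch of that default-before-conditional pattern with hypothetical names, not the library's actual constructor:

import torch

class TextConditionedModule:
    # hypothetical illustration of binding optional config before a conditional block
    def __init__(self, condition_on_text: bool = False, text_embed_dim: int = 768):
        cross_attn_dim_context = None
        dim_text = None  # bound up front so later reads never hit an unbound name

        if condition_on_text:
            dim_text = text_embed_dim
            cross_attn_dim_context = dim_text

        # downstream code can branch on None instead of risking a NameError
        self.dim_text = dim_text
        self.cross_attn_dim_context = cross_attn_dim_context

module = TextConditionedModule(condition_on_text = False)
assert module.dim_text is None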
2 changes: 1 addition & 1 deletion meshgpt_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.5.1'
+__version__ = '1.5.2'
22 changes: 20 additions & 2 deletions tests/test_meshgpt.py
@@ -16,7 +16,7 @@ def test_readme(adaptive_rmsnorm):
     # mock inputs

     vertices = torch.randn((2, 121, 3)) # (batch, num vertices, coor (3))
-    faces = torch.randint(0, 121, (2, 64, 3)) # (batch, num faces, vertices (3))
+    faces = torch.randint(0, 121, (2, 2, 3)) # (batch, num faces, vertices (3))

     # forward in the faces

@@ -33,7 +33,7 @@ def test_readme(adaptive_rmsnorm):
     transformer = MeshTransformer(
         autoencoder,
         dim = 512,
-        max_seq_len = 768,
+        max_seq_len = 60,
         num_sos_tokens = 1,
         fine_cross_attend_text = True,
         text_cond_with_film = False,
@@ -51,3 +51,21 @@ def test_readme(adaptive_rmsnorm):
     loss.backward()

     faces_coordinates, face_mask = transformer.generate(texts = ['a small chair'], cond_scale = 3.)
+
+def test_cache():
+    # test that the output for generation with and without kv (and optional gateloop) cache is equivalent
+
+    autoencoder = MeshAutoencoder(
+        num_discrete_coors = 128
+    )
+
+    transformer = MeshTransformer(
+        autoencoder,
+        dim = 512,
+        max_seq_len = 12
+    )
+
+    uncached_faces_coors, _ = transformer.generate(cache_kv = False, temperature = 0)
+    cached_faces_coors, _ = transformer.generate(cache_kv = True, temperature = 0)
+
+    assert torch.allclose(uncached_faces_coors, cached_faces_coors)
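
In the new test, temperature = 0 should make generation deterministic (effectively greedy decoding), so the cached and uncached paths are expected to yield identical face coordinates and torch.allclose is a valid equivalence check. A short sketch for running only this test via pytest; the file path is an assumption based on the layout shown in this commit:

import sys
import pytest

# Select just the new cache-equivalence test and exit with pytest's status code.
sys.exit(pytest.main(["tests/test_meshgpt.py::test_cache", "-q"]))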
