From a310cc775feefa27ddc3c86e34dd06b81cc46fea Mon Sep 17 00:00:00 2001
From: Kyle Herndon
Date: Fri, 22 Nov 2024 15:39:27 -0800
Subject: [PATCH] Cleanup the test

---
 sharktank/tests/layers/mmdit_test.py | 44 ++--------
 1 file changed, 3 insertions(+), 41 deletions(-)

diff --git a/sharktank/tests/layers/mmdit_test.py b/sharktank/tests/layers/mmdit_test.py
index fa40abea0..923d640d0 100644
--- a/sharktank/tests/layers/mmdit_test.py
+++ b/sharktank/tests/layers/mmdit_test.py
@@ -15,13 +15,9 @@
 from iree.turbine import aot
 from sharktank.layers import (
     MMDITDoubleBlock,
-    PagedLlamaAttentionBlock,
-    PagedKVCache,
-    RotaryEmbeddingLayer,
 )
 import sharktank.ops as ops
 from sharktank.layers.testing import (
-    make_llama_attention_block_theta,
     make_mmdit_block_theta,
 )
 from sharktank.types.tensors import DefaultPrimitiveTensor
@@ -32,24 +28,10 @@ def setUp(self):
         torch.manual_seed(12345)
         self.hidden_size = 3072
         self.num_heads = 24
-
-        self.transformer_block_count = 13
-        self.block_index = 1
-        self.shard_count = 3
-        self.head_count_kv = 2 * self.shard_count
-        self.attention_head_count = 5 * self.head_count_kv
-        self.attention_head_dim = 24
-        self.rms_epsilon = 0.01
-        self.cache_partition_count = 2
-        self.page_count = 23
-        self.embedding_length = self.attention_head_count * self.attention_head_dim
-        self.rope_dimension_count = self.attention_head_dim
         self.block_seqlen = 7
         self.block_seq_stride = 17
         self.max_seqlen = self.block_seq_stride * self.block_seqlen
-        self.rope_freq_base = None
         self.batch_size = 3
-        self.start_index = 0
 
     def testExport(self):
         dtype = torch.float32
@@ -66,36 +48,16 @@
             num_heads=self.num_heads,
         )
 
-        seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
-            self.batch_size, -1
-        )
-
-        embedding_module = RotaryEmbeddingLayer(
-            rope_dimension_count=self.rope_dimension_count,
-            max_seqlen=self.max_seqlen,
-            rope_freq_base=self.rope_freq_base,
-        )
-
-        class MyModule(torch.nn.Module):
-            def forward(self, img, txt, vec, pe):
-                return mmdit.forward(
-                    img,
-                    txt,
-                    vec,
-                    pe,
-                )
-
-        mod = MyModule()
         img = torch.rand([self.batch_size, 1024, self.hidden_size])
         txt = torch.rand([self.batch_size, 512, self.hidden_size])
         vec = torch.rand([self.batch_size, self.hidden_size])
         rot = torch.rand([self.batch_size, 1, 1536, 64, 2, 2])
-        mod.forward(img, txt, vec, rot)
-        fxb = aot.FxProgramsBuilder(mod)
+        mmdit.forward(img, txt, vec, rot)
+        fxb = aot.FxProgramsBuilder(mmdit)
 
         @fxb.export_program(name="mmdit", args=(img, txt, vec, rot), strict=False)
         def _(model, img, txt, vec, rot) -> torch.Tensor:
-            return model(img, txt, vec, rot)
+            return model.forward(img, txt, vec, rot)
 
         output = aot.export(fxb)
         output.verify()