Commit

Cleanup the test
KyleHerndon committed Nov 22, 2024
1 parent 8727221 commit a310cc7
Showing 1 changed file with 3 additions and 41 deletions.
44 changes: 3 additions & 41 deletions sharktank/tests/layers/mmdit_test.py
@@ -15,13 +15,9 @@
from iree.turbine import aot
from sharktank.layers import (
MMDITDoubleBlock,
PagedLlamaAttentionBlock,
PagedKVCache,
RotaryEmbeddingLayer,
)
import sharktank.ops as ops
from sharktank.layers.testing import (
make_llama_attention_block_theta,
make_mmdit_block_theta,
)
from sharktank.types.tensors import DefaultPrimitiveTensor
@@ -32,24 +28,10 @@ def setUp(self):
torch.manual_seed(12345)
self.hidden_size = 3072
self.num_heads = 24

self.transformer_block_count = 13
self.block_index = 1
self.shard_count = 3
self.head_count_kv = 2 * self.shard_count
self.attention_head_count = 5 * self.head_count_kv
self.attention_head_dim = 24
self.rms_epsilon = 0.01
self.cache_partition_count = 2
self.page_count = 23
self.embedding_length = self.attention_head_count * self.attention_head_dim
self.rope_dimension_count = self.attention_head_dim
self.block_seqlen = 7
self.block_seq_stride = 17
self.max_seqlen = self.block_seq_stride * self.block_seqlen
self.rope_freq_base = None
self.batch_size = 3
self.start_index = 0

def testExport(self):
dtype = torch.float32
@@ -66,36 +48,16 @@
num_heads=self.num_heads,
)

seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
self.batch_size, -1
)

embedding_module = RotaryEmbeddingLayer(
rope_dimension_count=self.rope_dimension_count,
max_seqlen=self.max_seqlen,
rope_freq_base=self.rope_freq_base,
)

class MyModule(torch.nn.Module):
def forward(self, img, txt, vec, pe):
return mmdit.forward(
img,
txt,
vec,
pe,
)

mod = MyModule()
img = torch.rand([self.batch_size, 1024, self.hidden_size])
txt = torch.rand([self.batch_size, 512, self.hidden_size])
vec = torch.rand([self.batch_size, self.hidden_size])
rot = torch.rand([self.batch_size, 1, 1536, 64, 2, 2])
mod.forward(img, txt, vec, rot)
fxb = aot.FxProgramsBuilder(mod)
mmdit.forward(img, txt, vec, rot)
fxb = aot.FxProgramsBuilder(mmdit)

@fxb.export_program(name="mmdit", args=(img, txt, vec, rot), strict=False)
def _(model, img, txt, vec, rot) -> torch.Tensor:
return model(img, txt, vec, rot)
return model.forward(img, txt, vec, rot)

output = aot.export(fxb)
output.verify()
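For readers unfamiliar with the export flow being cleaned up here: after this commit the test hands the model to iree.turbine's FxProgramsBuilder directly instead of routing it through a wrapper torch.nn.Module. The snippet below is a minimal, self-contained sketch of that pattern; ToyBlock, its input shape, and the program name "toy_block" are illustrative placeholders and not part of the sharktank test, while the aot calls mirror the ones visible in the diff above.

import torch
from iree.turbine import aot


class ToyBlock(torch.nn.Module):
    # Stand-in for the real block under test; any eager-traceable module works here.
    def forward(self, x):
        return x * 2.0


block = ToyBlock()
x = torch.rand([3, 1024, 3072])
block.forward(x)  # eager run as a sanity check before export

fxb = aot.FxProgramsBuilder(block)


@fxb.export_program(name="toy_block", args=(x,), strict=False)
def _(model, x) -> torch.Tensor:
    return model.forward(x)


output = aot.export(fxb)
output.verify()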
