 from absl.testing import absltest
 from flax import nnx
 from jax.sharding import Mesh
-
+from flax.linen import partitioning as nn_partitioning
 from .. import pyconfig
 from ..max_utils import (create_device_mesh, get_flash_block_sizes)
 from ..models.wan.transformers.transformer_wan import (
@@ -48,6 +48,18 @@ class WanTransformerTest(unittest.TestCase):
 
   def setUp(self):
     WanTransformerTest.dummy_data = {}
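+    # Load the Wan 14B test config and build the device mesh here, so the
+    # tests below can reuse self.config and self.mesh for their sharding scope.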
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    self.config = config
+    devices_array = create_device_mesh(config)
+    self.mesh = Mesh(devices_array, config.mesh_axes)
+
 
   def test_rotary_pos_embed(self):
     batch_size = 1
@@ -65,28 +77,31 @@ def test_nnx_pixart_alpha_text_projection(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dummy_caption = jnp.ones((1, 512, 4096))
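+    # Create and invoke the layer inside the mesh + axis-rules scope so the
+    # module's logical sharding annotations can resolve against the mesh.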
-    layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
-    dummy_output = layer(dummy_caption)
-    dummy_output.shape == (1, 512, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
+      dummy_output = layer(dummy_caption)
+      assert dummy_output.shape == (1, 512, 5120)
 
   def test_nnx_timestep_embedding(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
 
     dummy_sample = jnp.ones((1, 256))
-    layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
-    dummy_output = layer(dummy_sample)
-    assert dummy_output.shape == (1, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
+      dummy_output = layer(dummy_sample)
+      assert dummy_output.shape == (1, 5120)
 
   def test_fp32_layer_norm(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     batch_size = 1
     dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
     # expected same output shape with same dtype
-    layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
-    dummy_output = layer(dummy_hidden_states)
-    assert dummy_output.shape == dummy_hidden_states.shape
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
+      dummy_output = layer(dummy_hidden_states)
+      assert dummy_output.shape == dummy_hidden_states.shape
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_time_text_embedding(self):
@@ -97,20 +112,21 @@ def test_wan_time_text_embedding(self):
     time_freq_dim = 256
     time_proj_dim = 30720
     text_embed_dim = 4096
-    layer = WanTimeTextImageEmbedding(
-        rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
-    )
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = WanTimeTextImageEmbedding(
+          rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
+      )
 
-    dummy_timestep = jnp.ones(batch_size)
+      dummy_timestep = jnp.ones(batch_size)
 
-    encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
-    dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
-        dummy_timestep, dummy_encoder_hidden_states
-    )
-    assert temb.shape == (batch_size, dim)
-    assert timestep_proj.shape == (batch_size, time_proj_dim)
-    assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
+      encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
+      dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
+          dummy_timestep, dummy_encoder_hidden_states
+      )
+      assert temb.shape == (batch_size, dim)
+      assert timestep_proj.shape == (batch_size, time_proj_dim)
+      assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
 
   def test_wan_block(self):
     key = jax.random.key(0)
@@ -158,20 +174,19 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
 
     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
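+    # The block is now both constructed and invoked under mesh + logical axis
+    # rules, rather than being built outside the mesh and only run inside it.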
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
       assert dummy_output.shape == dummy_hidden_states.shape
 
@@ -204,40 +219,39 @@ def test_wan_attention(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-      assert dummy_output.shape == dummy_hidden_states_shape
-
-    # dot product
-    try:
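+    # As above: creation and the forward pass now share one mesh + axis-rules scope.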
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       attention = FlaxWanAttention(
           rngs=rngs,
           query_dim=query_dim,
           heads=40,
           dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
+          attention_kernel="flash",
           mesh=mesh,
           flash_block_sizes=flash_block_sizes,
       )
-    except NotImplementedError:
-      pass
+      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+      assert dummy_output.shape == dummy_hidden_states_shape
+
+      # dot_product kernel: construction may raise NotImplementedError, which
+      # this test treats as acceptable.
+      try:
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel="dot_product",
+            split_head_dim=True,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+      except NotImplementedError:
+        pass
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
@@ -267,7 +281,8 @@ def test_wan_model(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     num_layers = 1
-    wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
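+    # Unlike the layer tests above, only the axis-rules context is entered for
+    # construction here; the mesh itself is not activated at this point.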
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
 
     dummy_timestep = jnp.ones((batch_size))
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))