import torch
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
+from torch.distributed.tensor.placement_types import Replicate

import torch_xla
import torch_xla.runtime as xr
+from torch_xla.distributed.spmd import XLAShardedTensor
+from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

import unittest
import test_xla_sharding_base
@@ -31,7 +34,6 @@ def test_xla_to_dtensor_spec_conversion(self):
    mesh = DeviceMesh("xla", list(range(device_count)))

    # Test different sharding patterns
-    from torch.distributed.tensor.placement_types import Replicate
    test_cases = [
        (torch.randn(100, 50), [Shard(0)]),
        (torch.randn(100, 50), [Shard(1)]),
@@ -64,30 +66,27 @@ def test_mesh_conversion(self):
    assert converted_spec.mesh.shape == original_mesh.shape

  def test_spec_caching(self):
-    """Test that _spec property caches results for better performance"""
-    import time
+    """Test that _spec property caches results
+
+    Addresses PR comment: "These sorts of tests that rely on the wall clock often lead to
+    annoying flakes in my experience. I think it's sufficient to just test that
+    self._cached_spec has a permanent value after the first call."
+    """
    device_count = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(device_count)))
-    tensor = torch.randn(1000,
-                         1000)  # Large tensor to make spec creation noticeable
+    tensor = torch.randn(100, 100)
    xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
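    # Note: on an "xla" DeviceMesh, distribute_tensor returns an XLA sharded tensor
    # whose DTensor spec is built lazily and kept in _cached_spec after first access.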

-    # first access should create and cache the spec
-    start_time = time.time()
+    # First access should create and cache the spec
    spec1 = xla_tensor._spec
-    first_access_time = time.time() - start_time

-    # should be much faster due to caching
-    start_time = time.time()
-    spec2 = xla_tensor._spec
-    second_access_time = time.time() - start_time
+    # Verify the spec is cached
+    assert xla_tensor._cached_spec is not None
+    assert xla_tensor._cached_spec is spec1

+    # Second access should return the cached spec
+    spec2 = xla_tensor._spec
    assert spec1 is spec2
-    print(
-        f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s"
-    )
-    assert second_access_time * 10 < first_access_time, \
-        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"

  def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
    """Helper to create tensor and mesh for testing"""
@@ -114,22 +113,8 @@ def test_multi_dim_sharding_spec(self):
    assert len(spec.placements) == 2
    assert spec.mesh.ndim == 2

-  def test_tensor_operations_preserve_spec(self):
-    """Test that tensor operations preserve sharding metadata"""
-    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,),
-                                                         [Shard(0)])
-
-    result_add = xla_tensor + 1
-    result_mul = xla_tensor * 2
-    result_relu = torch.relu(xla_tensor)
-
-    for result in [result_add, result_mul, result_relu]:
-      assert hasattr(result, '_spec')
-      assert result._spec.mesh.device_type == "xla"
-
  def test_mixed_placement_spec(self):
    """Test _spec for tensors with mixed shard/replicate placements"""
-    from torch.distributed.tensor.placement_types import Replicate
    device_count = xr.global_runtime_device_count()
    if device_count < 4:
      self.skipTest("Need at least 4 devices for 2D mesh")
@@ -143,6 +128,97 @@ def test_mixed_placement_spec(self):
    assert isinstance(spec.placements[0], Shard)
    assert isinstance(spec.placements[1], Replicate)

+  def test_sharding_info_acquisition(self):
+    """Test that a non-XLAShardedTensor can acquire sharding information.
+
+    Covers the case where 'elem is not an XLAShardedTensor but there exists
+    sharding information we want to acquire'.
+    """
+
+    device_count = xr.global_runtime_device_count()
+    mesh_shape = (device_count,)
+    partition_spec = (0, None)
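+    # (0, None): shard tensor dim 0 across the single mesh axis, replicate dim 1.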
+
+    regular_tensor = torch.randn(100, 50).to('xla')
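+    # .to('xla') gives a plain XLA tensor with no sharding metadata attached yet.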
+
+    sharded_tensor = wrap_as_sharded_tensor(
+        regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec)
+
+    # Verify the tensor acquired the sharding information
+    assert isinstance(sharded_tensor, XLAShardedTensor)
+    assert sharded_tensor.mesh_shape == mesh_shape
+    assert sharded_tensor.partition_spec == partition_spec
+
+  def test_resharding_logic(self):
+    """
+    Test that wrap_as_sharded_tensor reshards an existing XLAShardedTensor
+    before returning it.
+    """
+
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for resharding test")
+
+    # Initial and target shardings
+    initial_mesh_shape = (device_count,)
+    initial_partition_spec = (0, None)
+    new_mesh_shape = (2, device_count // 2)
+    new_partition_spec = (0, 1)
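+    # Reshard from a 1-D mesh (dim 0 sharded, dim 1 replicated) to a 2-D
+    # (2, device_count // 2) mesh with both tensor dims sharded.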
+
+    # Create tensor and verify resharding
+    tensor = torch.randn(100, 50).to('xla')
+    sharded_tensor = wrap_as_sharded_tensor(
+        tensor,
+        mesh_shape=initial_mesh_shape,
+        partition_spec=initial_partition_spec)
+    initial_spec = sharded_tensor._spec
+
+    resharded_tensor = wrap_as_sharded_tensor(
+        sharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=new_partition_spec)
+
+    # Verify resharding worked and cache was invalidated
+    assert resharded_tensor.mesh_shape == new_mesh_shape
+    assert resharded_tensor.partition_spec == new_partition_spec
+    assert resharded_tensor._spec is not initial_spec
+
+  def test_spec_invalidation_on_resharding(self):
+    """Test that the cached spec is invalidated when resharding changes
+    the mesh shape or partition spec.
+    """
+
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for resharding test")
+
+    tensor = torch.randn(100, 50).to('xla')
+    initial_mesh_shape = (device_count,)
+    initial_partition_spec = (0, None)
+    new_mesh_shape = (2, device_count // 2)
+    new_partition_spec = (0, 1)
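+    # Two-step reshard below: first change only the mesh shape, then change the
+    # partition spec; each step should drop the previously cached spec.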
+
+    sharded_tensor = wrap_as_sharded_tensor(
+        tensor,
+        mesh_shape=initial_mesh_shape,
+        partition_spec=initial_partition_spec)
+    initial_spec = sharded_tensor._spec
+    assert sharded_tensor._cached_spec is not None
+
+    # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor
+    # invalidates the cached spec
+    resharded_tensor = wrap_as_sharded_tensor(
+        sharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=initial_partition_spec)
+    assert resharded_tensor._spec is not initial_spec
+    assert resharded_tensor._spec.mesh.shape == new_mesh_shape
+
+    initial_spec = resharded_tensor._spec
+    resharded_tensor = wrap_as_sharded_tensor(
+        resharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=new_partition_spec)
+    assert resharded_tensor._spec is not initial_spec
+    assert resharded_tensor._spec.placements[1].dim == 1
+

if __name__ == '__main__':
  test = unittest.main()