 
 import torch
 from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
-from torch.distributed.tensor.placement_types import Replicate
 
 import torch_xla
 import torch_xla.runtime as xr
-from torch_xla.distributed.spmd import XLAShardedTensor
-from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor
 
 import unittest
 import test_xla_sharding_base
@@ -34,6 +31,7 @@ def test_xla_to_dtensor_spec_conversion(self):
     mesh = DeviceMesh("xla", list(range(device_count)))
 
     # Test different sharding patterns
+    from torch.distributed.tensor.placement_types import Replicate
     test_cases = [
         (torch.randn(100, 50), [Shard(0)]),
         (torch.randn(100, 50), [Shard(1)]),
@@ -66,20 +64,30 @@ def test_mesh_conversion(self):
     assert converted_spec.mesh.shape == original_mesh.shape
 
   def test_spec_caching(self):
-    """Test that _spec property caches results
-    """
+    """Test that _spec property caches results for better performance"""
+    import time
     device_count = xr.global_runtime_device_count()
     mesh = DeviceMesh("xla", list(range(device_count)))
-    tensor = torch.randn(100, 100)
+    tensor = torch.randn(1000,
+                         1000)  # Large tensor to make spec creation noticeable
     xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
 
+    # first access should create and cache the spec
+    start_time = time.time()
     spec1 = xla_tensor._spec
+    first_access_time = time.time() - start_time
 
-    assert xla_tensor._cached_spec is not None
-    assert xla_tensor._cached_spec is spec1
-
+    # should be much faster due to caching
+    start_time = time.time()
     spec2 = xla_tensor._spec
+    second_access_time = time.time() - start_time
+
     assert spec1 is spec2
+    print(
+        f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s"
+    )
+    assert second_access_time * 10 < first_access_time, \
+        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"
 
   def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
     """Helper to create tensor and mesh for testing"""
@@ -106,8 +114,22 @@ def test_multi_dim_sharding_spec(self):
     assert len(spec.placements) == 2
     assert spec.mesh.ndim == 2
 
+  def test_tensor_operations_preserve_spec(self):
+    """Test that tensor operations preserve sharding metadata"""
+    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,),
+                                                         [Shard(0)])
+
+    result_add = xla_tensor + 1
+    result_mul = xla_tensor * 2
+    result_relu = torch.relu(xla_tensor)
+
+    for result in [result_add, result_mul, result_relu]:
+      assert hasattr(result, '_spec')
+      assert result._spec.mesh.device_type == "xla"
+
   def test_mixed_placement_spec(self):
     """Test _spec for tensors with mixed shard/replicate placements"""
+    from torch.distributed.tensor.placement_types import Replicate
     device_count = xr.global_runtime_device_count()
     if device_count < 4:
       self.skipTest("Need at least 4 devices for 2D mesh")
@@ -121,114 +143,6 @@ def test_mixed_placement_spec(self):
     assert isinstance(spec.placements[0], Shard)
     assert isinstance(spec.placements[1], Replicate)
 
-  def test_sharding_info_acquisition(self):
-    """Test that non-XLAShardedTensor can acquire sharding information
-
-    Tests case of 'elem is not an XLAShardedTensor but there exists
-    sharding information we want to acquire'
-    """
-
-    device_count = xr.global_runtime_device_count()
-    mesh_shape = (device_count,)
-    partition_spec = (0, None)
-
-    regular_tensor = torch.randn(100, 50).to('xla')
-
-    sharded_tensor = wrap_as_sharded_tensor(
-        regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec)
-
-    # Verify the tensor acquired the sharding information
-    assert isinstance(sharded_tensor, XLAShardedTensor)
-    assert sharded_tensor.mesh_shape == mesh_shape
-    assert sharded_tensor.partition_spec == partition_spec
-
-  def test_resharding_logic(self):
-    """
-    Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t.
-    """
-
-    device_count = xr.global_runtime_device_count()
-    if device_count < 4:
-      self.skipTest("Need at least 4 devices for resharding test")
-
-    # Initial sharding
-    initial_mesh_shape = (device_count,)
-    initial_partition_spec = (0, None)
-    new_mesh_shape = (2, device_count // 2)
-    new_partition_spec = (0, 1)
-
-    # Create tensor and verify resharding
-    tensor = torch.randn(100, 50).to('xla')
-    sharded_tensor = wrap_as_sharded_tensor(
-        tensor,
-        mesh_shape=initial_mesh_shape,
-        partition_spec=initial_partition_spec)
-    initial_spec = sharded_tensor._spec
-
-    resharded_tensor = wrap_as_sharded_tensor(
-        sharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=new_partition_spec)
-
-    # Verify resharding worked and cache was invalidated
-    assert resharded_tensor.mesh_shape == new_mesh_shape
-    assert resharded_tensor.partition_spec == new_partition_spec
-    assert resharded_tensor._spec is not initial_spec
-
-  def test_spec_invalidation_on_resharding(self):
-    """Tests cases where the cached spec may become outdated.
-    """
-
-    device_count = xr.global_runtime_device_count()
-    if device_count < 4:
-      self.skipTest("Need at least 4 devices for resharding test")
-
-    tensor = torch.randn(100, 50).to('xla')
-    initial_mesh_shape = (device_count,)
-    initial_partition_spec = (0, None)
-    new_mesh_shape = (2, device_count // 2)
-    new_partition_spec = (0, 1)
-
-    sharded_tensor = wrap_as_sharded_tensor(
-        tensor,
-        mesh_shape=initial_mesh_shape,
-        partition_spec=initial_partition_spec)
-    initial_spec = sharded_tensor._spec
-    assert sharded_tensor._cached_spec is not None
-
-    # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache
-    resharded_tensor = wrap_as_sharded_tensor(
-        sharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=initial_partition_spec)
-    assert resharded_tensor._spec is not initial_spec
-    assert resharded_tensor._spec.mesh.shape == new_mesh_shape
-
-    initial_spec = resharded_tensor._spec
-    resharded_tensor = wrap_as_sharded_tensor(
-        resharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=new_partition_spec)
-    assert resharded_tensor._spec is not initial_spec
-    assert resharded_tensor._spec.placements[1].dim == 1
-
-  def test_auto_wrapped_tensor_spec_failure(self):
-    """Test that auto-wrapped tensors fail when accessing _spec property.
-
-    Auto-wrapped tensors are created through operations that trigger __torch_dispatch__
-    but don't yet have access to the sharding propagation done through open xla,
-    causing ._spec to fail.
-    """
-    device_count = xr.global_runtime_device_count()
-    mesh = DeviceMesh("xla", torch.arange(device_count))
-    tensor = torch.randn(4, 4)
-    sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
-
-    auto_wrapped = sharded_tensor + sharded_tensor
-
-    with self.assertRaises(ValueError):
-      _ = auto_wrapped._spec
-
 
 if __name__ == '__main__':
   test = unittest.main()