     Float8ColwiseParallel,
     Float8RowwiseParallel,
 )
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)
 from float8_experimental.float8_utils import tensor_to_scale
 from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
@@ -221,6 +228,73 @@ def test_fp8_mlp_tensor_parallelism_base(
         tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad
     )
 
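+# Note: this reports the peak allocation seen so far (max_memory_allocated),
+# not the current one, so the printed value can only stay flat or grow.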
+def get_cuda_mem_allocated_gb():
+    return torch.cuda.max_memory_allocated() / 1e9
+
+class EmbLNLinear(nn.Module):
+    def __init__(self, dim0, dim1, dim2):
+        super().__init__()
+        self.emb = nn.Embedding(dim0, dim1)
+        self.ln = nn.LayerNorm(dim1)
+        self.fc = nn.Linear(dim1, dim2)
+
+    def forward(self, x):
+        x = self.emb(x)
+        x = self.ln(x)
+        x = self.fc(x)
+        return x
+
+def test_fp8_compile_tp_sp_oom(
+    mesh: DeviceMesh, size=16, compile: bool = False
+):
+    """
+    A standalone repro of the OOM we observed on LLaMa 3 8B in torchtitan
+    with float8, compile, TP and SP on. When you run this test you should
+    see a memory leak, as evidenced by printouts of CUDA memory used as well
+    as by tensors not being freed in the dumped memory snapshot.
+
+    TODO: root cause the issue and write a better test once we fix it.
+    """
+
+    vocab_size = 128256
+    model_dim = 4096
+    device = mesh.device_type
+    bsz = 1
+    world_size = mesh.size()
+
+    m = EmbLNLinear(vocab_size, model_dim, model_dim).cuda()
+    m = swap_linear_with_float8_linear(
+        m, Float8DynamicLinear, emulate=True
+    )
+
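+    # The sequence length is model_dim * world_size, so each rank ends up with
+    # a model_dim-long local shard once activations are sharded on dim 1 below.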
+    tokens = torch.ones(bsz, model_dim * world_size, device=device, dtype=torch.int64)
+
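+    # TP/SP plan, loosely mirroring the torchtitan LLaMa setup: the embedding
+    # is rowwise-sharded and emits output sharded on the sequence dim (Shard(1)),
+    # the LayerNorm runs sequence-parallel, and the fp8 colwise linear gathers
+    # the sequence shards on the way in and returns a replicated local tensor.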
+    m = parallelize_module(
+        m,
+        mesh,
+        {
+            "emb": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "ln": SequenceParallel(),
+            "fc": Float8ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate(),
+                use_local_output=True,
+            ),
+        },
+    )
+
+    m = torch.compile(m, dynamic=False)
+    torch.cuda.memory._record_memory_history()
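+    # Without the leak, the peak memory printed below should flatten out after
+    # the first few iterations; with the leak it keeps growing every iteration.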
+    for i in range(100):
+        print(i, get_cuda_mem_allocated_gb())
+        y = m(tokens)
+        y.sum().backward()
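+    # The snapshot pickle can be loaded into https://pytorch.org/memory_viz to
+    # inspect which allocations were never freed.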
+    torch.cuda.memory._dump_snapshot("dtensor_test_memory.pickle")
+    print('done')
+
 
 def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
     test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
@@ -231,6 +305,11 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
     # other test files to not use TestCase but instead just add the test
     # cases in the main func.
     device_mesh = setup_distributed()
+    test_fp8_compile_tp_sp_oom(device_mesh)
+    # TODO(before land): remove early return, this is for debugging only
+    import sys; sys.exit(0)
+
+
     tests = [
         test_scaled_mm,
         test_fp8_redistribute,