@@ -197,9 +197,8 @@ def __init__(
         self._outer_optimizer = outer_optimizer

         # Stores pending all reduce
-        self._allreduce_futures: list[
-            torch.futures.Future[None] | torch.futures.Future[torch.Tensor]
-        ] = []
+        self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+        self._stream: torch.cuda.Stream = torch.cuda.Stream()

         if bucket_cap_mb is not None:
             self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
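
Note: the new `self._stream` gives prepare_sync a dedicated side stream to queue the pseudogradient work on, and perform_sync later blocks on it with `self._stream.synchronize()`. A minimal sketch of that CUDA stream pattern, outside of torchft and with purely illustrative tensors and names:

import torch

# Illustrative only (not torchft code): queue work on a dedicated side stream,
# then block the host before consuming the result, mirroring how prepare_sync
# and perform_sync use self._stream in this diff.
assert torch.cuda.is_available(), "sketch assumes a CUDA device"

side_stream = torch.cuda.Stream()
x = torch.randn(1024, device="cuda")

# Order the side stream after work already queued on the current stream so
# that `x` is fully written before it is read.
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    y = x * 2  # kernels launched here run on side_stream

side_stream.synchronize()  # host-side wait, like self._stream.synchronize()
print(y.sum().item())
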
@@ -222,13 +221,15 @@ def __init__(
                 t = t.pin_memory()
             self.original_parameters[name] = t

+    @torch.profiler.record_function("torchft::local_sgd::save_parameters")
     def save_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
             for name, p in self._model_fragment.named_parameters():
                 param_to_local = extract_local_tensor(p.data)
                 self.original_parameters[name].copy_(param_to_local, non_blocking=True)

+    @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
     def restore_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
@@ -248,6 +249,7 @@ def restore_parameters(self) -> None:
                 else:
                     p.data.copy_(self.original_parameters[name], non_blocking=False)

+    @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
         """
         Waits for the previously scheduled allreduce to finish
@@ -272,22 +274,27 @@ def should_sync_fragment(self, step: int) -> bool:
         step_to_sync = step - self._fragment_sync_offset - self._fragment_sync_delay
         return step_to_sync % self._sync_every == 0

+    @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:
         """
         Calculate the pseudogradients, average them across the manager group, and
         start the allreduce on the pseudo-gradients without waiting for it to finish.
         """
-        # Set the .grad field of each parameter to its pseudogradient
-        for name, p in self._model_fragment.named_parameters():
-            local_param = extract_local_tensor(p.data)
-            pseudogradient = local_param - self.original_parameters[name].to(p.device)
-            if isinstance(p, DTensor):
-                p.grad._local_tensor = pseudogradient
-            else:
-                p.grad = pseudogradient
+        with torch.cuda.stream(self._stream):
+            # Set the .grad field of each parameter to its pseudogradient
+            for name, p in self._model_fragment.named_parameters():
+                local_param = extract_local_tensor(p.data)
+                pseudogradient = local_param - self.original_parameters[name].to(
+                    p.device
+                )
+                if isinstance(p, DTensor):
+                    p.grad._local_tensor = pseudogradient
+                else:
+                    p.grad = pseudogradient

-        self._average_grads()
+            self._average_grads()

+    @torch.profiler.record_function("torchft::local_sgd::perform_sync")
     def perform_sync(self) -> bool:
         """
         Overrides the sync method to wait for the scheduled allreduce to finish and
@@ -297,6 +304,7 @@ def perform_sync(self) -> bool:
             return True

         self.wait()
+        self._stream.synchronize()

         # Restore the parameters back to the previous state
         self.restore_parameters()
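
Taken together, the hunks above follow the DiLoCo-style pseudogradient flow: save_parameters snapshots the weights, prepare_sync forms each parameter's pseudogradient as its drift from that snapshot and kicks off the averaging allreduce, and perform_sync waits, synchronizes the side stream, and restores the snapshot (presumably before an outer optimizer step applies the averaged pseudogradient). A toy, self-contained sketch of that arithmetic; the tensors and values below are illustrative, not torchft APIs:

import torch

original = torch.zeros(4)                        # snapshot from save_parameters()
current = torch.tensor([0.1, -0.2, 0.0, 0.3])    # weights after several local steps

# The pseudogradient is the drift of the local weights from the snapshot;
# the code above installs it as `.grad` and averages it across the replica
# group via _average_grads() before the outer optimizer acts on it.
pseudogradient = current - original
print(pseudogradient)
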
@@ -467,16 +475,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")

-        # TODO: Support multiple fragments
-        # This requires changing the manager to support `should_commit` for each
-        # fragment separately.
-        if len(model_fragments) != 1:
-            raise ValueError("Multiple fragments are not supported yet")
-
-        # TODO: Support `fragment_sync_delay`
-        if fragment_sync_delay != 0:
-            raise ValueError("Fragment synchronization delay is not supported yet")
-
         # TODO: Support `fragment_update_alpha`
         if fragment_update_alpha != 0.0:
             raise ValueError(
@@ -522,6 +520,8 @@ def __init__(
                 use_bucketization,
                 bucket_cap_mb,
                 should_quantize,
+                fragment_sync_delay,
+                fragment_update_alpha,
             )
             for i, model_fragment in enumerate(model_fragments)
         ]
@@ -606,16 +606,20 @@ def _step_post_hook(
         step = self._local_step

         # Start sending fragments
-        for fragment in self._fragments:
+        for i, fragment in enumerate(self._fragments):
             if not fragment.should_prepare_fragment(step):
                 continue

+            logger.info(f"preparing fragment {i} at step {step}")
+
             fragment.prepare_sync()

-        for fragment in self._fragments:
+        for i, fragment in enumerate(self._fragments):
             if not fragment.should_sync_fragment(step):
                 continue

+            logger.info(f"syncing fragment {i} at step {step}")
+
             if not fragment.perform_sync():
                 # Cancel all the previously scheduled allreduce by simply
                 # waiting for them. They should have failed but lets be
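
The enumerate/logging changes above build on the scheduling check visible in should_sync_fragment: a fragment syncs when (step - sync_offset - sync_delay) % sync_every == 0. A standalone sketch of how two fragments then sync on interleaved steps, using made-up numbers and assuming the prepare-side rule is analogous:

def should_sync(step: int, sync_every: int, sync_offset: int, sync_delay: int) -> bool:
    # Same arithmetic as should_sync_fragment in the diff above.
    return (step - sync_offset - sync_delay) % sync_every == 0

# With sync_every=10, offsets 0 and 5, and sync_delay=2, fragment 0 syncs
# at steps 2, 12, 22, ... and fragment 1 at steps 7, 17, ...
for step in range(1, 25):
    for i, offset in enumerate((0, 5)):
        if should_sync(step, sync_every=10, sync_offset=offset, sync_delay=2):
            print(f"fragment {i} syncs at step {step}")
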