 import logging
 import math
 import threading
+from contextlib import nullcontext
 from types import TracebackType
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type

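The new import is what lets the stream-scoped sections further down degrade gracefully on CPU-only machines: `nullcontext()` stands in for `torch.cuda.stream(...)` when no stream was created. A minimal standalone sketch of that pattern (illustrative, not torchft code):

import torch
from contextlib import nullcontext

# Use a dedicated CUDA stream when one exists, otherwise a no-op context manager.
stream = torch.cuda.Stream() if torch.cuda.is_available() else None
ctx = torch.cuda.stream(stream) if stream is not None else nullcontext()
with ctx:
    y = torch.ones(4) * 2  # on GPU this runs on the side stream; on CPU it runs inline
print(y)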
@@ -197,9 +198,10 @@ def __init__(
         self._outer_optimizer = outer_optimizer

         # Stores pending all reduce
-        self._allreduce_futures: list[
-            torch.futures.Future[None] | torch.futures.Future[torch.Tensor]
-        ] = []
+        self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.Stream() if torch.cuda.is_available() else None
+        )

         if bucket_cap_mb is not None:
             self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
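Narrowing the futures to `torch.futures.Future[torch.Tensor]` and creating one long-lived side stream in the constructor is what allows the pseudogradient math and its allreduce to run off the default stream. A hedged toy sketch of the ordering rules such a side stream relies on (not the torchft code path):

import torch

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    a = torch.randn(1024, 1024, device="cuda")

    side.wait_stream(torch.cuda.current_stream())  # don't read `a` before it is materialized
    with torch.cuda.stream(side):
        b = a @ a                                  # queued on the side stream

    torch.cuda.current_stream().wait_stream(side)  # order the default stream after the matmul
    print(b.mean().item())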
@@ -222,13 +224,15 @@ def __init__(
                 t = t.pin_memory()
             self.original_parameters[name] = t

+    @torch.profiler.record_function("torchft::local_sgd::save_parameters")
     def save_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
             for name, p in self._model_fragment.named_parameters():
                 param_to_local = extract_local_tensor(p.data)
                 self.original_parameters[name].copy_(param_to_local, non_blocking=True)

+    @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
     def restore_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
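The `record_function` decorators only label these methods so they show up as named ranges in profiler traces; behavior is unchanged. A small self-contained example of the same decorator (names are illustrative):

import torch

@torch.profiler.record_function("example::double")  # named range in profiler output
def double(t: torch.Tensor) -> torch.Tensor:
    return t * 2

with torch.profiler.profile() as prof:
    double(torch.ones(8))

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))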
@@ -248,6 +252,7 @@ def restore_parameters(self) -> None:
                 else:
                     p.data.copy_(self.original_parameters[name], non_blocking=False)

+    @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
         """
         Waits for the previously scheduled allreduce to finish
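Per its docstring, `wait()` blocks on the previously scheduled allreduce, i.e. the futures collected in `self._allreduce_futures`. A toy sketch of that pattern, with `set_result` standing in for a collective completing:

import torch

pending: list[torch.futures.Future[torch.Tensor]] = []

fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
pending.append(fut)
fut.set_result(torch.zeros(3))  # pretend the allreduce just finished

for f in pending:
    f.wait()                    # blocks until each future is complete
pending.clear()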
@@ -272,22 +277,31 @@ def should_sync_fragment(self, step: int) -> bool:
         step_to_sync = step - self._fragment_sync_offset - self._fragment_sync_delay
         return step_to_sync % self._sync_every == 0

+    @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:
         """
         Calculate the pseugradient, average them across the manager group and starts
         allreduce on the pseudo-gradients but doesn't wait for it to finish.
         """
-        # Set the .grad field of each parameter to its pseudogradient
-        for name, p in self._model_fragment.named_parameters():
-            local_param = extract_local_tensor(p.data)
-            pseudogradient = local_param - self.original_parameters[name].to(p.device)
-            if isinstance(p, DTensor):
-                p.grad._local_tensor = pseudogradient
-            else:
-                p.grad = pseudogradient
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            # Set the .grad field of each parameter to its pseudogradient
+            for name, p in self._model_fragment.named_parameters():
+                local_param = extract_local_tensor(p.data)
+                pseudogradient = local_param - self.original_parameters[name].to(
+                    p.device
+                )
+                if isinstance(p, DTensor):
+                    p.grad._local_tensor = pseudogradient
+                else:
+                    p.grad = pseudogradient

-        self._average_grads()
+            self._average_grads()

+    @torch.profiler.record_function("torchft::local_sgd::perform_sync")
     def perform_sync(self) -> bool:
         """
         Overrides the sync method to wait for the scheduled allreduce to finish and
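The body of `prepare_sync` is unchanged in substance; it now just runs inside the optional side stream. The quantity it computes is the pseudogradient: the drift of the live parameters away from the snapshot captured at the last outer sync, written into `.grad` so the outer optimizer can consume it after averaging. A standalone sketch of that idea (toy values, no DTensor handling):

import torch

param = torch.nn.Parameter(torch.randn(4))
snapshot = param.detach().clone()              # what save_parameters() keeps

with torch.no_grad():
    for _ in range(3):                         # stand-in for inner optimizer steps
        param -= 0.1 * torch.randn_like(param)

pseudogradient = param.detach() - snapshot     # drift since the last sync
param.grad = pseudogradient                    # averaged, then applied by the outer optimizer
print(pseudogradient)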
@@ -298,6 +312,9 @@ def perform_sync(self) -> bool:

         self.wait()

+        if self._stream is not None:
+            self._stream.synchronize()
+
         # Restore the parameters back to the previous state
         self.restore_parameters()

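Because `prepare_sync` may have queued work on `self._stream`, `perform_sync` now synchronizes that stream before touching the parameters again; otherwise `restore_parameters` could race with still-running pseudogradient kernels. A toy illustration of what `Stream.synchronize()` guarantees:

import torch

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    x = torch.zeros(1 << 20, device="cuda")

    stream.wait_stream(torch.cuda.current_stream())  # order after x's initialization
    with torch.cuda.stream(stream):
        x += 1                     # enqueued asynchronously on the side stream

    stream.synchronize()           # CPU blocks until the side stream has drained
    print(x.sum().item())          # safe to consume: prints 1048576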
@@ -467,16 +484,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")

-        # TODO: Support multiple fragments
-        # This requires changing the manager to support `should_commit` for each
-        # fragment separately.
-        if len(model_fragments) != 1:
-            raise ValueError("Multiple fragments are not supported yet")
-
-        # TODO: Support `fragment_sync_delay`
-        if fragment_sync_delay != 0:
-            raise ValueError("Fragment synchronization delay is not supported yet")
-
         # TODO: Support `fragment_update_alpha`
         if fragment_update_alpha != 0.0:
             raise ValueError(
@@ -522,6 +529,8 @@ def __init__(
                 use_bucketization,
                 bucket_cap_mb,
                 should_quantize,
+                fragment_sync_delay,
+                fragment_update_alpha,
             )
             for i, model_fragment in enumerate(model_fragments)
         ]
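With the single-fragment and zero-delay guards removed above, every fragment helper is now constructed with the same `fragment_sync_delay` and `fragment_update_alpha` the user passed in. A hypothetical mirror of that fan-out pattern (the `FragmentSync` class and all values below are invented for illustration, not torchft API):

from dataclasses import dataclass

@dataclass
class FragmentSync:                # hypothetical stand-in for the per-fragment helper
    index: int
    sync_every: int
    fragment_sync_delay: int
    fragment_update_alpha: float

model_fragments = ["embeddings", "block_0", "block_1"]
helpers = [
    FragmentSync(i, sync_every=30, fragment_sync_delay=5, fragment_update_alpha=0.0)
    for i, _ in enumerate(model_fragments)
]
print([(h.index, h.fragment_sync_delay) for h in helpers])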
@@ -606,16 +615,20 @@ def _step_post_hook(
         step = self._local_step

         # Start sending fragments
-        for fragment in self._fragments:
+        for i, fragment in enumerate(self._fragments):
             if not fragment.should_prepare_fragment(step):
                 continue

+            logger.info(f"preparing fragment {i} at step {step}")
+
             fragment.prepare_sync()

-        for fragment in self._fragments:
+        for i, fragment in enumerate(self._fragments):
             if not fragment.should_sync_fragment(step):
                 continue

+            logger.info(f"syncing fragment {i} at step {step}")
+
             if not fragment.perform_sync():
                 # Cancel all the previously scheduled allreduce by simply
                 # waiting for them. They should have failed but lets be
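The post-step hook now iterates with `enumerate` purely so the new log lines can name the fragment being prepared or synced. Combined with the check from the earlier hunk, a fragment syncs when `(step - fragment_sync_offset - fragment_sync_delay) % sync_every == 0`, so prepare and sync for each fragment land a fixed number of steps apart. A toy schedule built on that arithmetic (the per-fragment offsets and the prepare rule are assumptions for illustration, not taken from this diff):

sync_every = 30
fragment_sync_delay = 5
offsets = {0: 10, 1: 20}           # hypothetical per-fragment sync offsets

def should_prepare(step: int, offset: int) -> bool:
    # assumption: prepare fires fragment_sync_delay steps before the sync step
    return (step - offset) % sync_every == 0

def should_sync(step: int, offset: int) -> bool:
    # mirrors the modular check from the should_sync_fragment hunk above
    return (step - offset - fragment_sync_delay) % sync_every == 0

for step in range(1, 61):
    for frag, offset in offsets.items():
        if should_prepare(step, offset):
            print(f"step {step:>2}: prepare fragment {frag}")
        if should_sync(step, offset):
            print(f"step {step:>2}: sync fragment {frag}")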