This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
In #3740, we added support for FullyShardedDataParallel, but limited the implementation to Zero2, not Zero3. Compared with Zero2, Zero3 substantially reduces memory usage while bringing speed back in line with vanilla DDP.
We have already added support for this (via manual calls to `wrap`) within the Transformer modules, but we still cannot support Zero3. The main issue is that Zero3 assumes every worker calls `forward` exactly the same number of times, and it performs a parameter transfer during each forward (moving the sharded parameters to each worker just in time). ParlAI cannot provide this guarantee because:
- During validation, each worker sees a variable number of examples. This is fine in itself, but it causes a hang if any worker ends up with extra batches.
- During generation, workers perform a variable number of forwards because sequence lengths vary. Everything works for a while, but if one worker ends the run needing more generation steps than the others, we get a hang.
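To make the failure mode concrete, here is a minimal simulation, assuming threads stand in for ranks and a `threading.Barrier` stands in for the collective that Zero3 runs inside each forward. `run_without_join` is a hypothetical helper for illustration, not part of ParlAI or FairScale. Ranks with fewer batches stop calling the collective, so the remaining ranks block; a barrier timeout turns that hang into a list of stuck ranks.

```python
import threading


def run_without_join(batch_counts, timeout=0.5):
    """Emulate ranks whose every forward runs a collective (standing in for
    Zero3's just-in-time parameter transfer). Returns the ranks left waiting
    on peers that never arrived."""
    barrier = threading.Barrier(len(batch_counts), timeout=timeout)
    stuck = []
    lock = threading.Lock()

    def worker(rank, n_batches):
        try:
            for _ in range(n_batches):
                barrier.wait()  # the collective completes only if every rank arrives
        except threading.BrokenBarrierError:
            # In real distributed training this would be an indefinite hang.
            with lock:
                stuck.append(rank)

    threads = [
        threading.Thread(target=worker, args=(rank, n))
        for rank, n in enumerate(batch_counts)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return sorted(stuck)
```

For example, `run_without_join([3, 1, 2])` returns `[0, 2]`: after the first forward, rank 1 has no batches left, so ranks 0 and 2 wait on a collective that never completes. With equal counts, such as `[2, 2, 2]`, nobody gets stuck.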
Forcing this equality in `worlds.py` or in our dataset sharding seems far too difficult (and ugly). Our best bet is to implement something like vanilla DDP's `.join()`. It would work roughly as follows:
1. In every `forward`, each worker synchronizes a `True` boolean meaning "I am doing a real forward."
2. Upon `__exit__` of the context manager, a worker enters a loop in which it syncs a `False` boolean. As long as any worker still provides `True`, the others participate in a forward on a dummy batch.
3. Once all workers agree on `False`, the loop ends.
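The steps above can be sketched as a small simulation, assuming threads stand in for ranks and a reusable `threading.Barrier` stands in for the process group's collectives. `SimulatedGroup`, `all_reduce_or`, `all_gather_params`, `run_worker`, and `simulate` are hypothetical names for illustration, not FairScale or PyTorch APIs. Each rank runs its real batches, then (in place of `__exit__`) keeps syncing `False` and running dummy forwards until every rank reports `False`.

```python
import threading


class SimulatedGroup:
    """Hypothetical stand-in for a process group: each thread plays one rank."""

    def __init__(self, world_size, timeout=5.0):
        self._barrier = threading.Barrier(world_size, timeout=timeout)
        self._flags = [False] * world_size

    def all_reduce_or(self, rank, flag):
        """Return True if any rank passed True (a logical-OR all-reduce)."""
        self._flags[rank] = flag
        self._barrier.wait()  # every rank has written its flag
        result = any(self._flags)
        self._barrier.wait()  # every rank has read before anyone writes again
        return result

    def all_gather_params(self):
        """Stands in for Zero3's just-in-time parameter transfer: all ranks must arrive."""
        self._barrier.wait()


def run_worker(group, rank, batches, log):
    real = dummy = 0
    for _batch in batches:
        group.all_reduce_or(rank, True)  # step 1: "I am doing a real forward"
        group.all_gather_params()        # the real forward's collective
        real += 1
    # Step 2, the equivalent of __exit__: sync False and shadow peers with
    # dummy forwards for as long as any rank still reports True.
    while group.all_reduce_or(rank, False):
        group.all_gather_params()        # dummy-batch forward
        dummy += 1
    # Step 3: every rank reported False, so the loop has ended.
    log[rank] = (real, dummy)


def simulate(batch_counts):
    group = SimulatedGroup(len(batch_counts))
    log = [None] * len(batch_counts)
    threads = [
        threading.Thread(target=run_worker, args=(group, rank, range(n), log))
        for rank, n in enumerate(batch_counts)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


print(simulate([3, 1, 2]))  # [(3, 0), (1, 2), (2, 1)]
```

Every rank ends up calling the collective the same number of times (`real + dummy` equals the largest batch count), which is the guarantee Zero3 needs. In a real implementation the boolean sync would be a collective over the process group rather than a thread barrier.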
This feature makes the most sense to implement upstream in Fairscale, and then integrate into ParlAI.
> During validation, each worker sees a variable number of examples. This is okay in itself, but it is problematic (hang) if it results in any worker having extra batches.

PyTorch distributed has a wrapper for that. I tried to look it up, to no avail (maybe it's not public yet). I'm not sure how applicable it would be; just a heads-up.