
Commit

fix the bug when resume rng_state with different gpu numbers.
lawrence-cj committed Jan 6, 2025
1 parent ac55992 commit 1ca56ea
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion train_scripts/train.py
@@ -947,10 +947,13 @@ def main(cfg: SanaConfig) -> None:
     if rng_state:
         logger.info("resuming randomise")
         torch.set_rng_state(rng_state["torch"])
-        torch.cuda.set_rng_state_all(rng_state["torch_cuda"])
         np.random.set_state(rng_state["numpy"])
         random.setstate(rng_state["python"])
         generator.set_state(rng_state["generator"])  # resume generator status
+        try:
+            torch.cuda.set_rng_state_all(rng_state["torch_cuda"])
+        except:
+            logger.warning("Failed to resume torch_cuda rng state")

# Prepare everything
# There is no specific order to remember, you just need to unpack the
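The change wraps the CUDA RNG restore in a try/except so that a checkpoint saved with a different number of GPUs does not abort the resume: per-device CUDA states cannot be replayed onto a mismatched device list, while the torch, numpy, python, and generator states are device-independent and always restorable. Below is a minimal standalone sketch of that pattern, assuming a checkpoint dict with the same keys as in the diff; the helper names `save_rng_state`/`load_rng_state` and the `print` fallback are illustrative, not from the repository:

```python
import random

import numpy as np
import torch


def save_rng_state(generator: torch.Generator) -> dict:
    """Capture every RNG state the training loop depends on."""
    state = {
        "torch": torch.get_rng_state(),
        "numpy": np.random.get_state(),
        "python": random.getstate(),
        "generator": generator.get_state(),
    }
    if torch.cuda.is_available():
        # One state per visible GPU; this list's length changes with GPU count.
        state["torch_cuda"] = torch.cuda.get_rng_state_all()
    return state


def load_rng_state(rng_state: dict, generator: torch.Generator) -> None:
    """Restore RNG states, tolerating a changed GPU count (the commit's fix)."""
    torch.set_rng_state(rng_state["torch"])
    np.random.set_state(rng_state["numpy"])
    random.setstate(rng_state["python"])
    generator.set_state(rng_state["generator"])
    try:
        # Fails when the checkpoint holds states for a different GPU count.
        torch.cuda.set_rng_state_all(rng_state.get("torch_cuda", []))
    except Exception:
        print("Failed to resume torch_cuda rng state; continuing with fresh CUDA RNG")
```

Catching the failure means a resume on, say, 4 GPUs of a checkpoint written on 8 GPUs proceeds with fresh CUDA RNG states instead of crashing, at the cost of exact bitwise reproducibility of CUDA-side randomness.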
