torchrun for DDP
wenzhangliu committed Sep 20, 2024
1 parent 86be9c7 commit 3a925bf
Showing 4 changed files with 30 additions and 12 deletions.
4 changes: 2 additions & 2 deletions xuance/common/common_tools.py
@@ -158,9 +158,9 @@ def get_runner(method,
     args = get_arguments(method, env, env_id, config_path, parser_args, is_test)

     if type(args) == list:
-        device = f"GPU-{args[0].rank}" if args[0].distributed_training else args[0].device
+        device = f"GPU-{os.environ['RANK']}" if args[0].distributed_training else args[0].device
     else:
-        device = f"GPU-{args.rank}" if args.distributed_training else args.device
+        device = f"GPU-{os.environ['RANK']}" if args.distributed_training else args.device
     dl_toolbox = args[0].dl_toolbox if type(args) == list else args.dl_toolbox
     print("Calculating device:", device)

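The device string above now comes from the RANK variable that torchrun exports into every worker process (alongside LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT), so the rank no longer has to travel through the config. A minimal sketch of reading those variables; the launch command in the comment and the fallback defaults are illustrative assumptions, not part of this commit:

```python
import os

# Launched with something like:  torchrun --nproc_per_node=4 train.py
# torchrun exports these variables into every worker process it spawns.
global_rank = int(os.environ.get("RANK", 0))        # rank across all nodes
local_rank = int(os.environ.get("LOCAL_RANK", 0))   # GPU index on this node
world_size = int(os.environ.get("WORLD_SIZE", 1))   # total number of worker processes

device_label = f"GPU-{global_rank}"                 # same label the patched get_runner prints
print("Calculating device:", device_label)
```
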
3 changes: 1 addition & 2 deletions xuance/torch/agents/base/agent.py
@@ -33,7 +33,7 @@ def __init__(self,
         self.distributed_training = config.distributed_training
         if self.distributed_training:
             master_port = config.master_port if hasattr(config, "master_port") else None
-            init_distributed_mode(config.rank, config.world_size, master_port=master_port)
+            init_distributed_mode(int(os.environ['LOCAL_RANK']), config.world_size, master_port=master_port)

         self.gamma = config.gamma
         self.start_training = config.start_training if hasattr(config, "start_training") else 1
@@ -120,7 +120,6 @@ def save_model(self, model_name):
                                 'var': self.obs_rms.var}
             np.save(obs_norm_path, observation_stat)

-
     def load_model(self, path, model=None):
         # load neural networks
         path_loaded = self.learner.load_model(path, model)
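
init_distributed_mode itself is not shown in this diff; under torchrun such a helper usually only has to join the process group and pin the local GPU, since the rendezvous details already sit in the environment. A minimal sketch of a helper with that shape, assuming an NCCL backend and env:// rendezvous (the body is an assumption, not xuance's actual implementation):

```python
import os
import torch
import torch.distributed as dist

def init_distributed_mode_sketch(local_rank: int, world_size: int, master_port=None):
    # torchrun already exports MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE; only override the port if asked.
    if master_port is not None:
        os.environ["MASTER_PORT"] = str(master_port)
    # env:// rendezvous reads the remaining connection details from the environment,
    # so world_size is kept here only to mirror the call signature used above.
    dist.init_process_group(backend="nccl", init_method="env://")
    # Pin this process to its own GPU so subsequent .cuda()/DDP calls land on the right device.
    torch.cuda.set_device(local_rank)
```
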
27 changes: 26 additions & 1 deletion xuance/torch/learners/learner.py
@@ -5,7 +5,7 @@
 from xuance.common import Optional, List, Union
 from argparse import Namespace
 from operator import itemgetter
-from xuance.torch import Tensor
+from xuance.torch import Tensor, DistributedDataParallel

 MAX_GPUs = 100

@@ -25,6 +25,16 @@ def __init__(self,
         self.scheduler: Union[dict, list, Optional[torch.optim.lr_scheduler.LinearLR]] = None

         self.distributed_training = config.distributed_training
+        if self.distributed_training:
+            self.snapshot_path = os.path.join(config.model_dir, "DDP_Snapshot")
+            if os.path.exists(self.snapshot_path):
+                print("Loading Snapshot...")
+                self.load_snapshot(self.snapshot_path)
+            else:
+                os.makedirs(self.snapshot_path)
+            self.device = int(os.environ['LOCAL_RANK'])
+            self.policy = DistributedDataParallel(self.policy, find_unused_parameters=True,
+                                                  device_ids=[int(os.environ['LOCAL_RANK'])])
         self.use_grad_clip = config.use_grad_clip
         self.grad_clip_norm = config.grad_clip_norm
         self.device = config.device
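
With this hunk the DDP wrap lives in the base Learner, so every learner's policy becomes a replica pinned to its LOCAL_RANK GPU, and gradients are averaged across processes during backward(). A self-contained sketch of the same pattern, assuming it runs under torchrun on a CUDA machine (the toy nn.Linear stands in for the policy network):

```python
import os
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel

dist.init_process_group(backend="nccl")               # rendezvous via the env vars torchrun sets
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

policy = nn.Linear(8, 4).cuda(local_rank)             # stand-in for the learner's policy
policy = DistributedDataParallel(policy,
                                 device_ids=[local_rank],
                                 find_unused_parameters=True)  # tolerate params unused in a forward pass

out = policy(torch.randn(2, 8, device=f"cuda:{local_rank}"))
out.sum().backward()                                   # gradients are all-reduced across ranks here
```

Because DistributedDataParallel does not forward arbitrary attributes, the raw network is reached through policy.module elsewhere, as in save_model below.
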
@@ -35,6 +45,7 @@ def __init__(self,
     def save_model(self, model_path):
         if self.distributed_training:
             torch.save(self.policy.module.state_dict(), model_path)
+            self.save_snapshot()
         else:
             torch.save(self.policy.state_dict(), model_path)

@@ -63,6 +74,20 @@ def load_model(self, path, model=None):
print(f"Successfully load model from '{path}'.")
return path

def load_snapshot(self, snapshot_path):
loc = f"cuda: {self.device}"
snapshot = torch.load(snapshot_path, map_location=loc)
self.policy.load_state_dict(snapshot["MODEL_STATE"])
print("Resuming training from snapshot.")

def save_snapshot(self):
snapshot = {
"MODEL_STATE": self.policy.module.state_dict(),
}
snapshot_pt = os.path.join(self.snapshot_path, "snapshot.pt")
torch.save(snapshot, snapshot_pt)
print(f"Training snapshot saved at {self.snapshot_path}")

@abstractmethod
def update(self, *args):
raise NotImplementedError
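
The snapshot pair added above follows the usual torchrun fault-tolerance recipe: persist only the unwrapped module's weights (policy.module.state_dict()) so the file also loads outside DDP, and restore them when the snapshot directory already exists at startup so an elastic restart resumes instead of starting over. A minimal round-trip sketch with hypothetical paths and a plain nn.Linear standing in for the policy:

```python
import os
import torch
from torch import nn

snapshot_dir = "./models/DDP_Snapshot"          # hypothetical; mirrors os.path.join(config.model_dir, "DDP_Snapshot")
snapshot_pt = os.path.join(snapshot_dir, "snapshot.pt")
policy = nn.Linear(8, 4)                        # stand-in for the (unwrapped) policy network

if os.path.exists(snapshot_pt):
    snapshot = torch.load(snapshot_pt, map_location="cpu")   # map to the local GPU in a real DDP run
    policy.load_state_dict(snapshot["MODEL_STATE"])
    print("Resuming training from snapshot.")
else:
    os.makedirs(snapshot_dir, exist_ok=True)

# ... train ...

torch.save({"MODEL_STATE": policy.state_dict()}, snapshot_pt)  # save the raw module's weights only
```
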
8 changes: 1 addition & 7 deletions xuance/torch/learners/qlearning_family/dqn_learner.py
@@ -7,7 +7,6 @@
 from torch import nn
 from argparse import Namespace
 from xuance.torch.learners import Learner
-from torch.nn.parallel import DistributedDataParallel


 class DQN_Learner(Learner):
@@ -22,12 +21,7 @@ def __init__(self,
         self.sync_frequency = config.sync_frequency
         self.mse_loss = nn.MSELoss()
         self.one_hot = nn.functional.one_hot
-        self.n_actions = self.policy.action_dim
-        # parallel settings
-        if self.distributed_training:
-            self.device = config.rank
-            self.policy = DistributedDataParallel(self.policy, find_unused_parameters=True,
-                                                  device_ids=[self.config.rank])
+        self.n_actions = self.policy.module.action_dim if self.distributed_training else self.policy.action_dim

     def update(self, **samples):
         self.iterations += 1
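
The per-learner DDP block is removed because the base Learner now performs the wrap; the only remaining difference is that DistributedDataParallel does not forward custom attributes such as action_dim, hence the self.policy.module.action_dim lookup. A small illustration with a toy module (not xuance code):

```python
from torch import nn

class TinyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.action_dim = 4                       # custom attribute, not a parameter
        self.net = nn.Linear(8, self.action_dim)

    def forward(self, x):
        return self.net(x)

policy = TinyPolicy()
print(policy.action_dim)                          # 4: direct access works on the raw module

# After wrapping (needs an initialized process group):
#   ddp_policy = DistributedDataParallel(policy, device_ids=[local_rank])
#   ddp_policy.module.action_dim  ->  4
#   ddp_policy.action_dim         ->  AttributeError
```
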
