Fix issues with distributed training #80
Conversation
Log to wandb through the accelerator so that you don't get duplicate logs, one from each process. We had logging tucked inside a check for accelerator.is_main_process, but that wasn't working for a reason I can't explain as I write this. Also fixes an issue where wrapping a customized model in DistributedDataParallel hides the model's custom attributes/methods. You can reach them by unwrapping to the original module with model.module.custom_attribute. I haven't deeply investigated the consequences of this. What does the DDP wrapper do? What do you skip by reaching through it? If you're just reading a scalar argument value, like context_length, I imagine it's safe. But what if you're accessing some custom data-loading functionality?
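For reference, a minimal sketch of what logging through the accelerator looks like with Hugging Face Accelerate; the project name and the loop are placeholders, not code from this PR. Accelerator.log only reports from the main process, so each metric is sent once even under accelerate launch with multiple workers.

import torch
from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers("gato-training")  # project name is a placeholder

# Stand-in training loop: the real loop lives in train.py. The point is that
# metrics go through accelerator.log, which only reports from the main process.
for step in range(10):
    loss = torch.rand(1)  # placeholder for the real loss
    accelerator.log({"train/loss": loss.item()}, step=step)

accelerator.end_training()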
@@ -145,7 +153,7 @@ def evaluate(self, model: GatoPolicy, n_iterations, deterministic=True, promptle
 # trim to context length
 input_dict[self.obs_str] = input_dict[self.obs_str][-context_timesteps:,]
 input_dict[self.action_str] = input_dict[self.action_str][-context_timesteps:,]
-action = model.predict_control(input_dict, task=self, deterministic=deterministic)
+action = model.module.predict_control(input_dict, task=self, deterministic=deterministic)
Having a hard time thinking of a good way to handle this. When training is launched with accelerate launch train.py, you need to access model.module.predict... But when it's launched with python train.py, you need just model.predict... I'd hate to have if conditionals all over the place.
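One way to keep that conditional in a single place would be a small unwrap helper at the call sites; a sketch only, and the unwrap name is hypothetical, not something in this repo.

def unwrap(model):
    # DistributedDataParallel (and DataParallel) expose the wrapped model as
    # .module; a bare nn.Module does not, so fall back to the model itself.
    return getattr(model, "module", model)

# Call sites then read the same either way, e.g.:
# action = unwrap(model).predict_control(input_dict, task=self, deterministic=deterministic)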
Oh! Clearly! https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
Attributes of the wrapped module
After wrapping a Module with DataParallel, the attributes of the module (e.g. custom methods) became inaccessible. This is because DataParallel defines a few new members, and allowing other attributes might lead to clashes in their names. For those who still want to access the attributes, a workaround is to use a subclass of DataParallel as below.
class MyDataParallel(nn.DataParallel):  # https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html#torch.nn.DataParallel
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)
Scratch that. Now I remember. We're getting DataParallel by way of Huggingface's Accelerate library. We'd need to make the change there. Not quite so clear.
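Worth noting as an aside: Accelerate does expose accelerator.unwrap_model, which returns the original module whether or not prepare() wrapped it. The sketch below assumes the accelerator object is reachable from the evaluation code, which may not be the case here.

# Sketch only: assumes `accelerator` and `model` are in scope at the call site.
policy = accelerator.unwrap_model(model)
# policy is the original GatoPolicy, wrapped or not, so its custom methods are available.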
Oh. Here we go:
class GatoPolicy(nn.Module):
    # ...
    @property
    def module(self):
        return self
Add that and just use model.module everywhere. That ought to work for both Accelerated runs and non-distributed runs.
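A quick sanity check of that idea with a throwaway module; TinyPolicy is a made-up stand-in for GatoPolicy, and nn.DataParallel stands in for the DDP wrapper since it exposes the same .module attribute without needing a process group.

import torch.nn as nn

class TinyPolicy(nn.Module):
    """Stand-in for GatoPolicy with the proposed property."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    @property
    def module(self):
        return self

    def predict_control(self, x):
        return self.linear(x)

policy = TinyPolicy()
wrapped = nn.DataParallel(policy)

assert policy.module is policy    # bare run: the property returns the policy itself
assert wrapped.module is policy   # wrapped run: the wrapper's own .module attribute wins

# Either way the call site can be written uniformly:
# action = model.module.predict_control(...)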
Good job, thanks @eihli for fixing this.
Addresses Issue #33
You can see a wandb training run here