
Fix issues with distributed training #80

Merged
8 commits merged into ManifoldRG:master on Apr 20, 2024

Conversation

@eihli (Contributor) commented Feb 13, 2024

Addresses Issue #33

Log to wandb through the accelerator so that we don't send duplicate logs, one from each process. I know we had logging tucked inside a check for accelerator.is_main_process, but that wasn't working for some reason I can't explain at the time I'm writing this message.
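
For reference, roughly what logging through the accelerator looks like; the project name and logged keys below are placeholders, not the ones this repo uses:

import torch
from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers(project_name="placeholder-project", config={"lr": 1e-4})

for step in range(10):
    loss = torch.rand(1).item()  # stand-in for the real training loss
    # accelerator.log only emits from the main process, so each metric is sent
    # once instead of once per worker.
    accelerator.log({"train/loss": loss}, step=step)

accelerator.end_training()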

Also fixes an issue where wrapping a customized model in DistributedDataParallel hides access to its custom attributes/methods. You can get at them by reaching through to the original module with model.module.custom_attribute. I haven't deeply investigated the consequences of this. What does the DDP wrapper do? What do you skip by reaching through? If you're just accessing a scalar argument value, like context_length, then I imagine it's safe. But what if you're accessing some custom data loading functionality?
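
A toy illustration of the attribute problem (ToyPolicy is made up for the example; the real class is GatoPolicy):

import torch.nn as nn

class ToyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.context_length = 1024  # custom scalar attribute

    def predict_control(self, x):  # custom method
        return self.linear(x)

# Inside a distributed run the model gets wrapped, roughly:
#   model = torch.nn.parallel.DistributedDataParallel(ToyPolicy())
# After that, model.context_length and model.predict_control raise AttributeError,
# since attribute lookup on the wrapper only falls back to parameters, buffers,
# and submodules, while model.module.context_length and
# model.module.predict_control(x) still work.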

You can see a wandb training run here


@@ -145,7 +153,7 @@ def evaluate(self, model: GatoPolicy, n_iterations, deterministic=True, promptle
 # trim to context length
 input_dict[self.obs_str] = input_dict[self.obs_str][-context_timesteps:,]
 input_dict[self.action_str] = input_dict[self.action_str][-context_timesteps:,]
-action = model.predict_control(input_dict, task=self, deterministic=deterministic)
+action = model.module.predict_control(input_dict, task=self, deterministic=deterministic)
@eihli (Contributor, Author) commented:
Having a hard time thinking of a good way to handle this. When the training is launched with accelerate launch train.py, then you need to access model.module.predict... But when it's launched with python train.py, then you need just model.predict.... I'd hate to have if conditionals all over the place.
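
Concretely, the branching that would otherwise creep into every call site (just a sketch, reusing the call from the diff above):

if hasattr(model, "module"):  # accelerate launch: model is wrapped in DDP
    action = model.module.predict_control(input_dict, task=self, deterministic=deterministic)
else:                         # plain python train.py: model is the bare GatoPolicy
    action = model.predict_control(input_dict, task=self, deterministic=deterministic)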

@eihli (Contributor, Author) commented:

Oh! Clearly! https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html

Attributes of the wrapped module

After wrapping a Module with DataParallel, the attributes of the module (e.g. custom methods) became inaccessible. This is because DataParallel defines a few new members, and allowing other attributes might lead to clashes in their names. For those who still want to access the attributes, a workaround is to use a subclass of DataParallel as below.

import torch.nn as nn

class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        # Fall back to the wrapped module's attributes when the
        # DataParallel wrapper itself doesn't define the name.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

@eihli (Contributor, Author) commented:

Scratch that. Now I remember. We're getting the DistributedDataParallel wrapper by way of Hugging Face's Accelerate library, so we'd need to make the change there. Not quite so clear.

@eihli (Contributor, Author) commented:

Oh. Here we go:

import torch.nn as nn

class GatoPolicy(nn.Module):
    # ...
    @property
    def module(self):
        # Mirror the .module attribute that DDP exposes, so callers can always
        # write model.module whether or not the policy is wrapped.
        return self

Add that and just use model.module everywhere. That ought to work for both Accelerated runs and non-distributed runs.
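
A quick runnable check of the idea; Policy below is just a stand-in for GatoPolicy, and nn.DataParallel stands in for the wrapper accelerate applies, since both expose a .module attribute:

import torch.nn as nn

# Minimal stand-in for GatoPolicy with the proposed property (illustration only).
class Policy(nn.Module):
    @property
    def module(self):
        return self

    def predict_control(self, x):
        return x

policy = Policy()
assert policy.module is policy   # non-distributed run: the property returns self

wrapped = nn.DataParallel(policy)
assert wrapped.module is policy  # the wrapper's own .module attribute wins

# Either way, model.module.predict_control(...) reaches the underlying method.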

@bhavul (Contributor) commented Apr 20, 2024

Good job, thanks @eihli for fixing this.

@bhavul merged commit 71d1a9d into ManifoldRG:master on Apr 20, 2024