Fix issues with distributed training #80

Merged
merged 8 commits into from
Apr 20, 2024
Changes from 2 commits
1 change: 0 additions & 1 deletion README.md

@@ -42,7 +42,6 @@ python ./gato/data/download_custom_datasets.py
 ```bash
 docker build -t gato-control -f ./docker/Dockerfile .
 docker run -it --mount "type=bind,source=$(pwd),target=/app/gato-control" --entrypoint /bin/bash --gpus=all gato-control
-
 ```

14 changes: 11 additions & 3 deletions gato/tasks/control_task.py

@@ -110,13 +110,21 @@ def evaluate(self, model: GatoPolicy, n_iterations, deterministic=True, promptle
         ep_lens = []
         metrics = {}
 
-        context_timesteps = model.context_len // self.tokens_per_timestep # amount of timesteps that fit into context
+        context_timesteps = model.module.context_len // self.tokens_per_timestep # amount of timesteps that fit into context
 
         for i in range(n_iterations):
             observation, info = self.env.reset()
 
             # sample prompt
-            input_dict = self.sample_batch_configurable(batch_size=1, device=model.device, prompt_proportions=[1.], prompt_types = ['end'], max_tokens = model.context_len, share_prompt_episodes=True,ep_ids=self.top_ids)[0]
+            input_dict = self.sample_batch_configurable(
+                batch_size=1,
+                device=model.device,
+                prompt_proportions=[1.],
+                prompt_types=['end'],
+                max_tokens=model.module.context_len,
+                share_prompt_episodes=True,
+                ep_ids=self.top_ids
+            )[0]
 
             # infer dtypes
             action_type = input_dict[self.action_str].dtype
@@ -145,7 +153,7 @@ def evaluate(self, model: GatoPolicy, n_iterations, deterministic=True, promptle
                 # trim to context length
                 input_dict[self.obs_str] = input_dict[self.obs_str][-context_timesteps:,]
                 input_dict[self.action_str] = input_dict[self.action_str][-context_timesteps:,]
-                action = model.predict_control(input_dict, task=self, deterministic=deterministic)
+                action = model.module.predict_control(input_dict, task=self, deterministic=deterministic)

Contributor Author

Having a hard time thinking of a good way to handle this. When training is launched with `accelerate launch train.py`, you need to access `model.module.predict...`, but when it's launched with `python train.py`, you need just `model.predict...`. I'd hate to have `if` conditionals all over the place.

Contributor Author

Oh! Clearly! https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html

Attributes of the wrapped module

After wrapping a Module with DataParallel, the attributes of the module (e.g. custom methods) became inaccessible. This is because DataParallel defines a few new members, and allowing other attributes might lead to clashes in their names. For those who still want to access the attributes, a workaround is to use a subclass of DataParallel as below.

class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

Contributor Author

Scratch that. Now I remember. We're getting DataParallel by way of Huggingface's Accelerate library. We'd need to make the change there. Not quite so clear.
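
An aside on the Accelerate API (not from the thread): `Accelerator.unwrap_model` returns the underlying module whether or not `prepare()` wrapped it, and this PR already relies on it in `trainer.py` when saving checkpoints. A minimal sketch:

    import torch.nn as nn
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(nn.Linear(4, 2))

    # Under `accelerate launch`, prepare() returns a DistributedDataParallel
    # wrapper; under `python train.py`, it returns the bare module.
    # unwrap_model resolves to the underlying module in both cases.
    policy = accelerator.unwrap_model(model)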

Contributor Author

Oh. Here we go:

class GatoPolicy(nn.Module):
    # ...
    @property
    def module(self):
        return self

Add that and just use model.module everywhere. That ought to work for both Accelerated runs and non-distributed runs.
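
A minimal sketch (the `GatoPolicy` body here is simplified to the relevant parts) of why the property resolves correctly in both launch modes:

    import torch.nn as nn

    class GatoPolicy(nn.Module):
        @property
        def module(self):
            # In an unwrapped run, model.module resolves to the policy itself.
            return self

        def predict_control(self, input_dict, task=None, deterministic=True):
            ...

    model = GatoPolicy()
    assert model.module is model  # plain `python train.py` run

    # Under `accelerate launch`, the policy is wrapped in DistributedDataParallel,
    # whose built-in `module` attribute already points at the underlying GatoPolicy,
    # so model.module.predict_control(...) resolves identically in both modes.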

                 input_dict[self.action_str][-1,] = action
                 np_action = action.cpu().numpy()
                 observation, reward, terminated, truncated, info = self.env.step(np_action)

4 changes: 2 additions & 2 deletions gato/tasks/text_task.py

@@ -60,7 +60,7 @@ def sample_batch(self, batch_size, is_test=False)->List[Dict]:
         return batch_dicts
 
     def evaluate(self, model: GatoPolicy, num_examples_to_test=50, deterministic=True, log_examples_to_output=False):
-        tokenizer = model.text_tokenizer
+        tokenizer = model.module.text_tokenizer
         loss_fn = nn.CrossEntropyLoss()
         total_loss = 0
         total_tokens = 0
@@ -89,7 +89,7 @@ def evaluate(self, model: GatoPolicy, num_examples_to_test=50, deterministic=True, log_examples_to_output=False):
             new_batch_dict['text'] = input_tokens
 
             # Generate prediction
-            pred_logits, pred_tokens = model.predict_text(new_batch_dict, max_length=len(target_tokens), deterministic=deterministic)
+            pred_logits, pred_tokens = model.module.predict_text(new_batch_dict, max_length=len(target_tokens), deterministic=deterministic)
             # todo: pull 50 into a CLI argument in train.py
             if log_examples_to_output and idx%50==0:
                 print(f'Text Example : {tokenizer.decode(batch_dict["text"])} \n Input passed to model : {tokenizer.decode(new_batch_dict["text"])} \n Predicted output : {tokenizer.decode(pred_tokens)}')

5 changes: 3 additions & 2 deletions gato/training/trainer.py

@@ -45,8 +45,7 @@ def train(self):
         iters = self.args.training_steps // self.args.log_eval_freq
         for i in range(iters):
             logs = self.train_iteration(self.args.log_eval_freq, i)
-            if self.args.use_wandb and self.accelerator.is_main_process:
-                wandb.log(logs)
+            self.accelerator.log(logs)
 
         ## Save model at end of training only if not saving checkpoints
         if self.args.save_model and self.args.save_mode == 'last':
@@ -55,6 +54,8 @@
             unwrapped_model = self.accelerator.unwrap_model(self.model)
             save_model(unwrapped_model, self.exp_dir, f'checkpoint_{self.steps}', self.args)
 
+        self.accelerator.end_training()
 
 
     def train_iteration(self, num_steps, iter):
         logs = {}

21 changes: 15 additions & 6 deletions train.py

@@ -23,7 +23,18 @@

 def main(args):
     ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
-    accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision, split_batches=True, gradient_accumulation_steps=args.gradient_accumulation_steps, kwargs_handlers=[ddp_kwargs])
+    if args.use_wandb:
+        log_with = 'wandb'
+    else:
+        log_with = None
+    accelerator = Accelerator(
+        cpu=args.cpu,
+        mixed_precision=args.mixed_precision,
+        split_batches=True,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        kwargs_handlers=[ddp_kwargs],
+        log_with=log_with,
+    )
     args.device = accelerator.device.type
 
     exp_date = datetime.now().strftime('%y-%m-%d_%H-%M-%S')
@@ -126,11 +137,9 @@ def main(args):
     optimizer, scheduler = accelerator.prepare(optimizer, scheduler)
 
     if args.use_wandb:
-        wandb.init(
-            name = exp_name,
-            project=args.wandb_project,
-            config=args,
-        )
+        accelerator.init_trackers(args.wandb_project, init_kwargs={'wandb': {'name': exp_name, 'config': args}})
+    else:
+        accelerator.init_trackers('')
 
     # Create save dir if does not exist
     if args.save_model and not os.path.exists(args.save_dir):
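
Taken together, the `trainer.py` and `train.py` changes move logging onto Accelerate's tracker API: `init_trackers` replaces `wandb.init`, `accelerator.log` replaces the guarded `wandb.log` (trackers are created on the main process by default, so the explicit `is_main_process` check becomes unnecessary), and `end_training` flushes everything at shutdown. A minimal standalone sketch of that lifecycle, using a toy loop rather than this repo's trainer:

    from accelerate import Accelerator

    # With log_with=None the tracker calls below are harmless no-ops,
    # so the training loop needs no wandb conditionals either way.
    accelerator = Accelerator(log_with=None)
    accelerator.init_trackers('demo-project')

    for step in range(3):
        # In the real trainer this dict comes from train_iteration().
        accelerator.log({'training/loss': 1.0 / (step + 1)}, step=step)

    # Closes all trackers; mirrors the new call at the end of train().
    accelerator.end_training()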