Cleaned padding #59

Merged 11 commits on Oct 26, 2023
60 changes: 42 additions & 18 deletions megatron/data/data_utils.py
@@ -32,6 +32,10 @@

from streaming import Stream, StreamingDataset
from omegaconf import OmegaConf as om
import pickle as pkl

import os

def make_data_loader(dataset, neox_args):
    """Build dataloader given an input dataset."""
    if dataset is None:
@@ -324,8 +328,9 @@ def build_streaming_train_valid_test_data_iterators(neox_args):
    def prepare_config(dataset_config):
        dataset_config['num_workers'] = neox_args.num_workers
        dataset_config['dataset']['max_seq_length'] = neox_args.seq_length
        dataset_config['dataset']['eos_token_id'] = neox_args.tokenizer.eod_id
        dataset_config['dataset']['remote'] = None  # TODO Allow remote datasets
        dataset_config['dataset']['position_pad_id'] = neox_args.position_pad_id
        dataset_config['dataset']['vision_pad_id'] = neox_args.vision_pad_id

    prepare_config(neox_args.train_streaming_data_config)
    prepare_config(neox_args.valid_streaming_data_config)
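For illustration, here is a minimal shape such a streaming data config might take before prepare_config fills in the runtime fields. The keys are inferred from the assignments above and the state-dict handling further down; every value shown is hypothetical, not a confirmed schema.

example_streaming_data_config = {
    'num_workers': None,                     # filled in from neox_args.num_workers
    'state_dict_path': '/tmp/train_loader',  # hypothetical resume directory (used below)
    'dataset': {
        'local': '/tmp/train_mds',           # hypothetical local shard directory
        'split': 'train',
        'remote': None,                      # remote datasets not yet supported (see TODO)
        'max_seq_length': None,              # filled in from neox_args.seq_length
        'eos_token_id': None,                # filled in from neox_args.tokenizer.eod_id
        'position_pad_id': None,             # filled in from neox_args.position_pad_id
        'vision_pad_id': None,               # filled in from neox_args.vision_pad_id
    },
}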
@@ -340,9 +345,7 @@ def prepare_config(dataset_config):
    tokenizer = neox_args.tokenizer

    train_dataloader = build_interleaved_dataloader(train_dataset_cfg, tokenizer, device_batch_size)
    train_dataset_cfg['dataset']['split'] = "validation"
    valid_dataloader = build_interleaved_dataloader(validation_dataset_cfg, tokenizer, device_batch_size)
    validation_dataset_cfg['dataset']['split'] = "test"
    test_dataloader = build_interleaved_dataloader(test_dataset_cfg, tokenizer, device_batch_size)
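All three loaders above go through build_interleaved_dataloader. A rough sketch of what such a builder could look like, inferred only from the streaming imports at the top of the file; every name and field here is an assumption, and the real builder must return a loader that supports the load_state_dict() calls further down:

def _sketch_build_interleaved_dataloader(dataset_cfg, tokenizer, device_batch_size):
    from torch.utils.data import DataLoader
    ds_cfg = dataset_cfg['dataset']
    # StreamingDataset comes from the `streaming` import above; only standard kwargs are used.
    dataset = StreamingDataset(local=ds_cfg.get('local'), remote=ds_cfg.get('remote'),
                               split=ds_cfg.get('split'), batch_size=device_batch_size)
    # In the real builder, `tokenizer` would drive padding/collation; omitted in this sketch.
    return DataLoader(dataset, batch_size=device_batch_size,
                      num_workers=dataset_cfg['num_workers'])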

    # Flags to know if we need to do training/validation/testing.
@@ -368,25 +371,46 @@ def prepare_config(dataset_config):
    neox_args.do_train = flags[0].item()
    neox_args.do_valid = flags[1].item()
    neox_args.do_test = flags[2].item()
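The flags read just above follow the usual Megatron-style pattern of computing split availability on one rank and broadcasting it to the rest. A minimal sketch of that pattern, offered as an assumption since the actual flags construction sits outside this hunk:

import torch

def _broadcast_split_flags(do_train, do_valid, do_test, src_rank=0):
    # Pack the three booleans into one tensor so a single broadcast syncs all ranks.
    flags = torch.tensor([int(do_train), int(do_valid), int(do_test)],
                         dtype=torch.long, device='cuda')
    torch.distributed.broadcast(flags, src=src_rank)
    return flags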


    # Build iterators.

    # Shift the start iterations.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None
    train_state_dict_path = neox_args.train_streaming_data_config['state_dict_path']
    if os.path.exists(train_state_dict_path):
        file_name = os.path.join(train_state_dict_path, f'{neox_args.iteration}_checkpoint.pkl')

        if os.path.isfile(file_name):
            # Restore the streaming dataloader's position from the saved state dict,
            # closing the file handle properly.
            with open(file_name, 'rb') as f:
                train_state_dict = pkl.load(f)
            print(train_state_dict)
            train_dataloader.load_state_dict(train_state_dict)
        else:
            print("No matching state dict found.")
    else:
        print_rank_0("setting training data start iteration to 0")
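The resume path above only loads state; for a {iteration}_checkpoint.pkl file to exist, checkpointing needs a matching save step. A sketch of that counterpart, assuming the streaming dataloader exposes state_dict() as the inverse of the load_state_dict() call used here (the function name is illustrative, not part of this PR):

def save_dataloader_state(dataloader, state_dict_path, iteration):
    # Persist the loader's position so a restart can resume mid-epoch.
    os.makedirs(state_dict_path, exist_ok=True)
    file_name = os.path.join(state_dict_path, f'{iteration}_checkpoint.pkl')
    with open(file_name, 'wb') as f:
        pkl.dump(dataloader.state_dict(), f)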

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None
    valid_state_dict_path = neox_args.valid_streaming_data_config['state_dict_path']
    if os.path.exists(valid_state_dict_path):
        file_name = os.path.join(valid_state_dict_path, f'{neox_args.iteration}_checkpoint.pkl')

        if os.path.isfile(file_name):
            # Restore the validation dataloader's position from the saved state dict,
            # closing the file handle properly.
            with open(file_name, 'rb') as f:
                valid_state_dict = pkl.load(f)
            print(valid_state_dict)
            valid_dataloader.load_state_dict(valid_state_dict)
        else:
            print("No matching state dict found.")
    else:
        print_rank_0("setting validation data start iteration to 0")

    return train_data_iterator, valid_data_iterator, test_data_iterator
    return train_dataloader, valid_dataloader, test_dataloader
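Since this change replaces the returned iterators with the dataloaders themselves (the two return lines above are the old and new versions from the diff), a caller now wraps them itself. A minimal caller-side sketch (the helper name is illustrative, not from this PR):

def _iterators_from_dataloaders(neox_args):
    train_dl, valid_dl, test_dl = build_streaming_train_valid_test_data_iterators(neox_args)
    # Only wrap the splits that actually exist.
    return tuple(iter(dl) if dl is not None else None
                 for dl in (train_dl, valid_dl, test_dl))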

def build_train_valid_test_data_iterators(neox_args):
    """XXX"""