diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index d48d4a173..170374e37 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -32,6 +32,10 @@ from streaming import Stream, StreamingDataset from omegaconf import OmegaConf as om +import pickle as pkl + +import os + def make_data_loader(dataset, neox_args): """Build dataloader given an input dataset.""" if dataset is None: @@ -324,8 +328,9 @@ def build_streaming_train_valid_test_data_iterators(neox_args): def prepare_config(dataset_config): dataset_config['num_workers'] = neox_args.num_workers dataset_config['dataset']['max_seq_length'] = neox_args.seq_length - dataset_config['dataset']['eos_token_id'] = neox_args.tokenizer.eod_id dataset_config['dataset']['remote'] = None # TODO Allow remote datasets + dataset_config['dataset']['position_pad_id'] = neox_args.position_pad_id + dataset_config['dataset']['vision_pad_id'] = neox_args.vision_pad_id prepare_config(neox_args.train_streaming_data_config) prepare_config(neox_args.valid_streaming_data_config) @@ -340,9 +345,7 @@ def prepare_config(dataset_config): tokenizer = neox_args.tokenizer train_dataloader = build_interleaved_dataloader(train_dataset_cfg, tokenizer, device_batch_size) - train_dataset_cfg['dataset']['split'] = "validation" valid_dataloader = build_interleaved_dataloader(validation_dataset_cfg, tokenizer, device_batch_size) - validation_dataset_cfg['dataset']['split'] = "test" test_dataloader = build_interleaved_dataloader(test_dataset_cfg, tokenizer, device_batch_size) # Flags to know if we need to do training/validation/testing. @@ -368,25 +371,46 @@ def prepare_config(dataset_config): neox_args.do_train = flags[0].item() neox_args.do_valid = flags[1].item() neox_args.do_test = flags[2].item() - - - # Build iterators. + + # Shift the start iterations. 
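+    # Restore the dataloader state written by save_dataloader_checkpoint() (megatron/training.py),
+    # keyed by the current iteration, so the StreamingDataLoader resumes from the same sample on
+    # restart; if no matching checkpoint exists, the data iterator simply starts from iteration 0.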
if train_dataloader is not None: - train_data_iterator = iter(train_dataloader) - else: - train_data_iterator = None + train_state_dict_path = neox_args.train_streaming_data_config['state_dict_path'] + if os.path.exists(train_state_dict_path): + file_name = os.path.join(train_state_dict_path, f'{neox_args.iteration}_checkpoint.pkl') + + if os.path.isfile(file_name): # If the file exists + train_state_dict = pkl.load(open(file_name, 'rb')) # Load the file + print(train_state_dict) + train_dataloader.load_state_dict(train_state_dict) + else: + print("No matching state dict found.") + else: + print_rank_0( + "setting training data start iteration to {}".format( + 0 + ) + ) + if valid_dataloader is not None: - valid_data_iterator = iter(valid_dataloader) - else: - valid_data_iterator = None - - if test_dataloader is not None: - test_data_iterator = iter(test_dataloader) - else: - test_data_iterator = None + valid_state_dict_path = neox_args.valid_streaming_data_config['state_dict_path'] + if os.path.exists(valid_state_dict_path): + file_name = os.path.join(valid_state_dict_path, f'{neox_args.iteration}_checkpoint.pkl') + + if os.path.isfile(file_name): # If the file exists + valid_state_dict = pkl.load(open(file_name, 'rb')) # Load the file + print(valid_state_dict) + valid_dataloader.load_state_dict(valid_state_dict) + else: + print("No matching state dict found.") + else: + print_rank_0( + "setting validation data start iteration to {}".format( + 0 + ) + ) - return train_data_iterator, valid_data_iterator, test_data_iterator + return train_dataloader, valid_dataloader, test_dataloader def build_train_valid_test_data_iterators(neox_args): """XXX""" diff --git a/megatron/data/streaming_dataset/interleaved_text_image/create_interleaved_dataset.py b/megatron/data/streaming_dataset/interleaved_text_image/create_interleaved_dataset.py index 3fb19701f..6b7b280c3 100644 --- a/megatron/data/streaming_dataset/interleaved_text_image/create_interleaved_dataset.py +++ b/megatron/data/streaming_dataset/interleaved_text_image/create_interleaved_dataset.py @@ -11,7 +11,7 @@ from streaming import MDSWriter from torch.utils.data import DataLoader, IterableDataset from tqdm import tqdm -from transformers import AutoTokenizer, PreTrainedTokenizerBase +# from transformers import AutoTokenizer, PreTrainedTokenizerBase # Import Image type class from PIL import Image @@ -25,23 +25,79 @@ import numpy as np from torch.utils.data import IterableDataset -from transformers import PreTrainedTokenizerBase +# from transformers import PreTrainedTokenizerBase import torch from torch.nn.utils.rnn import pad_sequence # from hamcrest.core.core.isnone import none +import datasets as hf_datasets +import lm_dataformat as lmd +from threading import Semaphore # Import webdataset import webdataset as wds from streaming.base.format.mds.encodings import Encoding, _encodings +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +import torchvision + +from megatron.tokenizer.tokenizer import build_tokenizer + +from multiprocessing import Pool +import multiprocessing + +from functools import partial + + +class ListPIL(Encoding): + """Store PIL image raw. + + Format: [width: 4] [height: 4] [mode size: 4] [mode] [raw image]. 
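+    Per image, encode() writes four uint32 header fields — width, height, mode size and raw size —
+    followed by the mode string and the raw pixel bytes; decode() uses the raw-size field to walk
+    the concatenated buffer and rebuild each PIL image in order.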
+ """ + + def encode(self, images: List[Image.Image]) -> bytes: + # self._validate(images, List[Image.Image]) + final_bytes = b'' + for obj in images: + mode = obj.mode.encode('utf-8') + width, height = obj.size + raw = obj.tobytes() + ints = np.array([width, height, len(mode), len(raw)], np.uint32) + final_bytes += ints.tobytes() + mode + raw + return final_bytes + + def decode(self, data: bytes) -> List[Image.Image]: + images = [] + idx = 4 * 4 + start = 0 + # print("Data length", len(data)) + while True: + if start == len(data): + break + width, height, mode_size, raw_size = np.frombuffer(data[start:start+idx], np.uint32) + # print("width, height, mode_size, raw_size", width, height, mode_size, raw_size) + start = start + idx + idx2 = start + mode_size + # print("start", start, " idx2", idx2) + mode = data[start:idx2].decode('utf-8') + start = idx2 + size = width, height + idx3 = start + raw_size + raw = data[start:idx3] + start = idx3 + images.append(Image.frombytes(mode, size, raw)) # pyright: ignore + return images + +_encodings['listpil'] = ListPIL + class ImageEncoding(Encoding): def encode(self, images: List[Image.Image]) -> bytes: bytes_arr = [] for image in images: byte_io = io.BytesIO() - image.save(byte_io, format='JPEG') + image.save(byte_io, format='png') bytes_arr.append(byte_io.getvalue()) return b''.join(bytes_arr) @@ -114,6 +170,8 @@ def decode(self, data: bytes) -> np.ndarray: class simple_encoding(Encoding): def encode(self, data: List[Image.Image]) -> bytes: + if data == []: + return np.array([]).tobytes() # Read all images into numpy array data = map(lambda x: np.array(x), data) data = np.stack(list(data)) @@ -173,7 +231,7 @@ class ConcatTokensDataset(IterableDataset): def __init__( self, dataset: IterableDataset, - tokenizer: PreTrainedTokenizerBase, + tokenizer, max_length: int, image_seq_length: int, bos_text: str, @@ -181,7 +239,8 @@ def __init__( image_start_text: str, image_end_text: str, no_wrap: bool, - after_image_extra_tokens: int = 10 + after_image_extra_tokens: int = 10, + position_pad_id: int = -1 ): self.dataset = dataset self.tokenizer = tokenizer @@ -193,44 +252,31 @@ def __init__( self.after_image_extra_tokens = after_image_extra_tokens self.bos_text = bos_text self.eos_text = eos_text - self.pad_token_id = self.tokenizer("<|padding|>", - truncation=False, - padding=False, - add_special_tokens=False)['input_ids'][0] + self.pad_token_id = self.tokenizer.pad_id + self.position_pad_id = position_pad_id self.should_wrap = not no_wrap - self.bos_tokens = self.tokenizer(self.bos_text, - truncation=False, - padding=False, - add_special_tokens=False)['input_ids'] + self.bos_tokens = self.tokenizer.tokenize(self.bos_text) if len(self.bos_tokens) > 1: warnings.warn( f'You specified --concat_tokens with --bos_text, but your BOS text is not tokenizing to one token\ , instead we got {self.bos_tokens}. Quit if this was in error.') - self.eos_tokens = self.tokenizer(self.eos_text, - truncation=False, - padding=False, - add_special_tokens=False)['input_ids'] + self.eos_tokens = self.tokenizer.tokenize(self.eos_text) print("eos token", self.eos_tokens) if len(self.eos_tokens) > 1: warnings.warn( f'You specified --concat_tokens with --eos_text, but your EOS text is not tokenizing to one token\ , instead we got {self.eos_tokens}. 
Quit if this was in error.') - self.image_start_token = self.tokenizer(self.image_start_text, - truncation=False, - padding=False, - add_special_tokens=False)['input_ids'][0] - self.image_end_token = self.tokenizer(self.image_end_text, - truncation=False, - padding=False, - add_special_tokens=False)['input_ids'][0] + self.image_start_token = self.tokenizer.tokenize(self.image_start_text)[0] + + self.image_end_token = self.tokenizer.tokenize(self.image_end_text)[0] eos_text_provided = self.eos_text != '' bos_text_provided = self.bos_text != '' - test_text = self.tokenizer('') - if len(test_text['input_ids']) > 0 and (eos_text_provided or + test_text = self.tokenizer.tokenize('') + if len(test_text) > 0 and (eos_text_provided or bos_text_provided): message = 'both eos and bos' if eos_text_provided and bos_text_provided else ( 'eos_text' if eos_text_provided else 'bos_text') @@ -250,22 +296,22 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]: curr_image = [] for sample in self.dataset: - text = sample["text"] - images = sample["images"] + sample_text = sample["text"] + sample_images = sample["images"] self.text_buffer.append(self.bos_tokens) self.image_buffer.insert(0, None) # To batch bos - for section in text: + for section in sample_text: if section != None: # Need to check this for max length however - self.text_buffer.append(self.tokenizer(section, truncation=False, padding=False)["input_ids"]) + self.text_buffer.append(self.tokenizer.tokenize(section)) else: self.text_buffer.append(None) self.text_buffer.append(self.eos_tokens) - self.image_buffer.extend(images) + self.image_buffer.extend(sample_images) self.image_buffer.append(None) #We want to add text and image to our upcoming output (setup), and remove them from the buffer. @@ -326,7 +372,7 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]: text_tokens = text_ids text_positions = torch.from_numpy(np.where(np_text != None)[0]) - images = list(filter(lambda a: a != None, curr_image)) + images = list(filter(lambda a: a != None, curr_image)) # FIX THIS image_positions = torch.from_numpy(np.where(np_text == None)[0]) labels = np.roll(np_text, -1, axis = 0) labels[-1] = self.pad_token_id @@ -338,9 +384,9 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]: text_labels = np.where(text_labels == None, self.pad_token_id, text_labels).astype(np.int64) image_labels = np.where(image_labels == None, self.pad_token_id, image_labels).astype(np.int64) - multimodal_position_ids = torch.nn.utils.rnn.pad_sequence([text_positions, image_positions], batch_first = True, padding_value = -1) # TODO: Make this position pad id + multimodal_position_ids = torch.nn.utils.rnn.pad_sequence([text_positions, image_positions], batch_first = True, padding_value = self.position_pad_id) - labels = torch.nn.utils.rnn.pad_sequence([torch.from_numpy(text_labels), torch.from_numpy(image_labels)], batch_first = True, padding_value = -1) + labels = torch.nn.utils.rnn.pad_sequence([torch.from_numpy(text_labels), torch.from_numpy(image_labels)], batch_first = True, padding_value = self.pad_token_id) # convert tensor to numpy array labels = labels.numpy().tobytes() @@ -348,19 +394,8 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]: text_tokens = text_tokens.tobytes() multimodal_position_ids = multimodal_position_ids.numpy().tobytes() - images = map(lambda x: np.array(x), images) - images = np.stack(list(images)) - images = np.expand_dims(images, axis=1) - - # print("text_id", text_tokens) - # print("text_positions", text_positions) - # print("image_positions", 
image_positions) - # print("multimodal_position_ids", multimodal_position_ids) - # print("labels", labels) - # print("images", images) - yield { - 'images': images.tobytes(), + 'images': images, 'tokens': text_tokens, 'multimodal_position_ids' : multimodal_position_ids, 'labels': labels @@ -377,106 +412,72 @@ class ConcatMode(Enum): NO_CONCAT = 'NO_CONCAT' CONCAT_TOKENS = 'CONCAT_TOKENS' +class TextConcatDataset(IterableDataset): + def __init__(self, path, group): + # List all jsonl files in the folder mentioned by path: + all_json_ls = glob(path + "/*.jsonl") + # Sort the list of jsonl files: + all_json_ls.sort() + # Get the start and end indices of the group: + start, end = group + # Get the jsonl files in the group: + self.paths = all_json_ls[start:end] -''' -python create_dataset.py \ - --path /p/fastdata/mmlaion/hummingbird/streaming/arxiv.jsonl \ - --out_root /p/fastdata/mmlaion/hummingbird/streaming/text/train --split train \ - --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' \ - --compression zstd -''' -def parse_args() -> Namespace: - """Parse commandline arguments.""" - parser = ArgumentParser( - description= - 'Convert dataset into MDS format, optionally concatenating and tokenizing' - ) - parser.add_argument('--path', type=str, required=True) - parser.add_argument('--out_root', type=str, required=True) - parser.add_argument('--compression', type=str, default=None) - - group = parser.add_mutually_exclusive_group(required=False) - group.add_argument( - '--concat_tokens', - type=int, - help='Convert text to tokens and concatenate up to this many tokens') - parser.add_argument('--split', type=str, default='train') - - parser.add_argument('--tokenizer', type=str, required=False, default=None) - parser.add_argument('--bos_text', type=str, required=False, default=None) - parser.add_argument('--eos_text', type=str, required=False, default=None) - parser.add_argument('--no_wrap', default=False, action='store_true') - - parsed = parser.parse_args() - - if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set( - parsed.split))) > 0: - raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.' 
- ) - - # Make sure we have needed concat options - if (parsed.concat_tokens is not None and - isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None): - parser.error( - 'When setting --concat_tokens, you must specify a --tokenizer') - - # now that we have validated them, change BOS/EOS to strings - if parsed.bos_text is None: - parsed.bos_text = '' - if parsed.eos_text is None: - parsed.eos_text = '' - return parsed - + def __iter__(self): + for fname in self.paths: + for doc in filter(lambda x: x, lmd.Reader(fname).stream_data()): + sample = { + "images": [None], + "text": [doc] + } + yield sample class ImageCaptionDataset(IterableDataset): - def __init__(self, path): - fpath = path + "/{00000..41455}.tar" - self.dataset = wds.WebDataset(fpath).decode("pilrgb").rename(image="jpg;png;jpeg;webp", text="txt").to_tuple("image", "text") - - # def __iter__(self): - # for image, text in self.dataset: - # sample = { - # "images": [None, image], - # "text": [text, None] - # } - # yield sample + def __init__(self, path, group): + start, end = group + fpath = f"{path}/{{{str(start).zfill(5)}..{str(end).zfill(5)}}}.tar" + self.dataset = iter(wds.WebDataset(fpath).decode("pilrgb").rename(image="jpg;png;jpeg;webp", text="txt").to_tuple("image", "text")) def __iter__(self): - data_iter = iter(self.dataset) while True: try: - image, text = next(data_iter) + image, text = next(self.dataset) + image = torchvision.transforms.functional.resize(image, [224, 224], interpolation=InterpolationMode.BICUBIC) + if text is None: + print("key 'text' not found in the sample, skipping this datapoint") + continue + yield { + "images": [None, image], + "text": [text, None] + } except StopIteration: - # If StopIteration is raised, break from loop break - except Exception as e: + except ValueError as e: print(f"Error encountered: {e}. Skipping this datapoint.") continue - - sample = { - "images": [None, image], - "text": [text, None] - } - yield sample + except Exception as e: + print(f"Unexpected Error encountered: {e}. Skipping this datapoint.") + continue -def build_image_caption_dataset( +def build_interleaved_multimodal_dataset( path: str, - split: str, + group: tuple, mode: ConcatMode, max_length: Optional[int] = None, bos_text: str = '', eos_text: str = '', + image_start_text: str = '<|image_start|>', + image_end_text: str = '<|image_end|>', no_wrap: bool = False, - tokenizer: PreTrainedTokenizerBase = None, + tokenizer = None, vision_seq_length: int = 64, -) -> IterableDataset: + after_image_extra_tokens: int = 10, + position_pad_id: int = -1 +): """Build an IterableDataset over the HF C4 or pile source data. Args: dataset_name (str): Dataset name - split (str): Split name. mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS max_length (int): The length of concatenated tokens bos_text (str): text to insert at the beginning of each sequence @@ -490,20 +491,18 @@ def build_image_caption_dataset( An IterableDataset. 
""" - dataset = ImageCaptionDataset(path) + dataset = ImageCaptionDataset(path, group) + # dataset = TextConcatDataset(path, group) if mode == ConcatMode.NO_CONCAT: dataset = NoConcatDataset(dataset) else: - if not isinstance(tokenizer, PreTrainedTokenizerBase): - raise ValueError( - f'{tokenizer} must be of type PreTrainedTokenizerBase') if max_length is None: raise ValueError(f'max_length must be set.') if bos_text + eos_text == '': test_tokens = tokenizer('test') - if test_tokens['input_ids'][ - 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ + if test_tokens[ + 0] != tokenizer.bos_token_id and test_tokens[ -1] != tokenizer.eos_token_id: tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' @@ -519,36 +518,88 @@ def build_image_caption_dataset( image_seq_length=vision_seq_length, bos_text=bos_text, eos_text=eos_text, - image_start_text='hello', - image_end_text='world', - no_wrap=no_wrap + image_start_text=image_start_text, + image_end_text=image_end_text, + no_wrap=no_wrap, + after_image_extra_tokens=after_image_extra_tokens, + position_pad_id=position_pad_id ) - return dataset + for sample in tqdm(dataset): + yield sample -def generate_samples( - loader: DataLoader, - truncate_num_samples: Optional[int] = None -) -> Iterable[Dict[str, bytes]]: - """Generator over samples of a dataloader. +def data_generator(task_queue, data_queue, args, worker_id): + + if args.concat_tokens is not None: + mode = ConcatMode.CONCAT_TOKENS + args.rank = 0 + args.model_parallel_size = 1 + args.make_vocab_size_divisible_by = 128 + tokenizer = build_tokenizer(args) + # we will enforce length, so suppress warnings about sequences too long for the model + tokenizer.model_max_length = int(1e30) + else: + mode = ConcatMode.NO_CONCAT + tokenizer = None + + partial_builder = partial(build_interleaved_multimodal_dataset, + path=args.path, + mode=mode, + max_length=args.concat_tokens, + bos_text=tokenizer.bos_text, + eos_text=tokenizer.eos_text, + image_start_text=tokenizer.image_start_text, + image_end_text=tokenizer.image_end_text, + no_wrap=args.no_wrap, + tokenizer=tokenizer, + vision_seq_length=args.vision_seq_length, + after_image_extra_tokens=args.after_image_extra_tokens, + position_pad_id=args.position_pad_id) + + while not task_queue.empty(): + group = task_queue.get() + start, end = group + print(f'Worker {worker_id} started processing data: {start}-{end}') + for data in partial_builder(group=group): + data_queue.put(data) + print(f'Worker {worker_id} finished processed data: {start}-{end}') + +def data_writer(data_queue, args, index): + if args.concat_tokens is not None: + columns = {'tokens': 'bytes', 'images': 'listpil', 'multimodal_position_ids': 'bytes', 'labels': 'bytes'} + else: + columns = {'text': 'str', 'images': 'ndarray'} + + with MDSWriter(columns=columns, + out=os.path.join(f"{args.out_root}/{index}"), + compression=args.compression, size_limit=1e+9) as out: + + total_samples = 0 + total_images = 0 + while True: + print("The queue size is", data_queue.qsize()) + try: + sample = data_queue.get(timeout=100) + total_samples += 1 + total_images += len(sample["images"]) + out.write(sample) + print(f'\rWriter {index} Writing sample {total_samples} with {total_images} images.........', flush=True, end='') + except multiprocessing.queues.Empty: + print(f'\rNo more data to write. Exiting. 
{index}') + break + +def get_dataset_groups(start_ind:int, end_ind:int, groups: int): + """Get the sub-directory path and the sample range. Args: - loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} - truncate_num_samples (Optional[int]): An optional # of samples to stop at. + out_root (str): base output mds directory + groups (int): Number of sub-directories to create Yields: - Sample dicts. + Iterator[Tuple[str, int, int]]: Each argument tuple """ - n_samples = 0 - for batch in loader: - keys = list(batch.keys()) - print(keys) - current_bs = len(batch[keys[0]]) - for idx in range(current_bs): - if truncate_num_samples is not None and n_samples == truncate_num_samples: - return - n_samples += 1 - yield {k: v[idx] for k, v in batch.items()} - + group_size = (end_ind - start_ind) // groups + for group_start in range(start_ind, end_ind, group_size): + yield (group_start, group_start + group_size) def main(args: Namespace) -> None: """Main: create C4/pile streaming dataset. @@ -556,50 +607,98 @@ def main(args: Namespace) -> None: Args: args (Namespace): Commandline arguments. """ - if args.concat_tokens is not None: - mode = ConcatMode.CONCAT_TOKENS - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) - # we will enforce length, so suppress warnings about sequences too long for the model - tokenizer.model_max_length = int(1e30) - columns = {'tokens': 'bytes', 'images': 'bytes', 'multimodal_position_ids': 'bytes', 'labels': 'bytes'} - else: - mode = ConcatMode.NO_CONCAT - tokenizer = None - columns = {'text': 'str', 'images': 'ndarray'} - # Write samples print(f'Converting to MDS format...') print( f'Note that the progress bar is based on the dataset length before tokenization.' ) print(f'It will finish at a value below 100% if tokenizing') - with MDSWriter(columns=columns, - out=os.path.join(args.out_root), - compression=args.compression, size_limit=1e+10) as out: - # Get samples - dataset = build_image_caption_dataset(path='/p/fastdata/mmlaion/laion-400m/LAION-400m-webdataset/data', - split=args.split, - mode=mode, - max_length=args.concat_tokens, - bos_text=args.bos_text, - eos_text=args.eos_text, - no_wrap=args.no_wrap, - tokenizer=tokenizer) - total_samples = 0 - total_images = 0 - for sample in tqdm(dataset): - total_samples += 1 - total_images += len(sample["images"]) - print(total_samples, total_images) - # simple_encoder = simple_encoding() - out.write(sample) - if total_samples >= 145: - break + dataset_group_iterator = get_dataset_groups(args.start_ind, args.end_ind, args.num_groups) + + task_queue = multiprocessing.Queue() + for index_range in dataset_group_iterator: + task_queue.put(index_range) + + data_queue = multiprocessing.Queue(maxsize=args.queue_size) + + workers = [] + for i in range(args.workers): + worker_process = multiprocessing.Process(target=data_generator, args=(task_queue, data_queue, args, i)) + worker_process.start() + workers.append(worker_process) + + # writers + writers = [] + + for i in range(args.num_writers): + writer_process = multiprocessing.Process(target=data_writer, args=(data_queue, args, i)) + writer_process.start() + writers.append(writer_process) + + # Wait for all the workers to finish + for worker in workers: + worker.join() + + # Now the master can terminate + for writer in writers: + writer.join() + +def parse_args() -> Namespace: + """Parse commandline arguments.""" + parser = ArgumentParser( + description= + 'Convert dataset into MDS format, optionally concatenating and 
tokenizing' + ) + parser.add_argument('--path', type=str, required=True) + parser.add_argument('--out_root', type=str, required=True) + parser.add_argument('--compression', type=str, default=None) + + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument( + '--concat_tokens', + type=int, + help='Convert text to tokens and concatenate up to this many tokens') + parser.add_argument('--queue_size', type=int, default=5000) + parser.add_argument('--split', type=str, default='train') + parser.add_argument('--num_groups', type=int, default=100) + parser.add_argument('--workers', type=int, default=24) + parser.add_argument('--num_writers', type=int, default=10) + parser.add_argument('--start_ind', type=int, default=0) + parser.add_argument('--end_ind', type=int, default=41455) + parser.add_argument('--tokenizer_type', type=str, required=False, default=None) + parser.add_argument('--vocab_file', type=str, required=False, default=None) + parser.add_argument('--merge_file', type=str, required=False, default=None) + parser.add_argument('--no_wrap', default=False, action='store_true') + parser.add_argument('--vision_seq_length', type=int, default=64) + parser.add_argument('--after_image_extra_tokens', type=int, default=10) + parser.add_argument('--position_pad_id', type=int, default=-1) + + parsed = parser.parse_args() + + if os.path.isdir(parsed.out_root) and len( + set(os.listdir(parsed.out_root)).intersection(set( + parsed.split))) > 0: + raise ValueError( + f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.' + ) + else: + os.makedirs(parsed.out_root) + # Make sure we have needed concat options + if (parsed.concat_tokens is not None and + isinstance(parsed.concat_tokens, int) and parsed.tokenizer_type is None): + parser.error( + 'When setting --concat_tokens, you must specify a tokenizer') + + return parsed if __name__ == '__main__': main(parse_args()) ''' -python create_dataset.py --path /p/fastdata/mmlaion/hummingbird/streaming/arxiv.jsonl --out_root /p/fastdata/mmlaion/hummingbird/streaming/interleaved/train --split train --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' --compression zstd''' \ No newline at end of file +python create_interleaved_dataset.py --path /p/fastdata/mmlaion/hummingbird/red_pajama_raw/arxiv/arxiv_0af50072-df4c-4084-a833-cebbd046e70e.jsonl --compression zstd --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' --out_root /p/fastdata/mmlaion/hummingbird/test_laion400M/test7 + +python megatron/data/streaming_dataset/interleaved_text_image/create_interleaved_dataset.py --path /p/fastdata/mmlaion/hummingbird/red_pajama_raw/arxiv --compression zstd --concat_tokens 2048 --tokenizer_type HFTokenizer --vocab_file /p/project/ccstdl/gupta6/multimodal/20B_tokenizer.json --out_root /p/fastdata/mmlaion/hummingbird/hummingbird_dataset/text_train_final +python megatron/data/streaming_dataset/interleaved_text_image/create_interleaved_dataset.py --path /p/fastdata/mmlaion/laion2B-en --compression zstd --concat_tokens 2048 --tokenizer_type HFTokenizer --vocab_file /p/project/ccstdl/gupta6/multimodal/20B_tokenizer.json --out_root /p/fastdata/mmlaion/hummingbird/hummingbird_dataset/laion_5b_test +''' \ No newline at end of file diff --git a/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py b/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py index 423e56992..cea17699d 100644 --- 
a/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py +++ b/megatron/data/streaming_dataset/interleaved_text_image/dataloader.py @@ -22,35 +22,11 @@ from transformers import AutoImageProcessor, AutoModel from streaming.base.format.mds.encodings import Encoding, _encodings +from einops import rearrange -class PickleEncoding(Encoding): - def encode(self, data: List[Image.Image]) -> bytes: - return pickle.dumps(data) - - def decode(self, data: bytes) -> np.ndarray: - data = pickle.loads(data) - # Convert PIL Images to numpy arrays - data = map(lambda x: np.array(x), data) - return np.stack(list(data)) - +from megatron.data.streaming_dataset.interleaved_text_image.create_interleaved_dataset import simple_encoding, ListPIL, PickleEncoding _encodings['pickleencoding'] = PickleEncoding - -class simple_encoding(Encoding): - def encode(self, data: List[Image.Image]) -> bytes: - # Read all images into numpy array - data = map(lambda x: np.array(x), data) - data = np.stack(list(data)) - assert data.shape == (len(data), 256, 256, 3), f'Expected shape (N, 256, 256, 3), got {data.shape}' - return data.tobytes() - - def decode(self, data: bytes) -> np.ndarray: - # convert bytes to numpy array - data = np.frombuffer(data, dtype=np.uint8) - # print(data.shape, data.reshape(-1, 256, 256, 3).shape) - # reshape to original shape - data = data.reshape(-1, 256, 256, 3) - return data - +_encodings['listpil'] = ListPIL _encodings['simple_encoding'] = simple_encoding def build_tokenizer(om_tokenizer_config: DictConfig) -> PreTrainedTokenizerBase: @@ -126,7 +102,7 @@ class StreamingInterleavedDataset(StreamingDataset): """ def __init__(self, - tokenizer: PreTrainedTokenizerBase, + tokenizer, max_seq_length: int, streams: Optional[Sequence[Stream]] = None, remote: Optional[str] = None, @@ -146,6 +122,8 @@ def __init__(self, shuffle_algo: str = 'py1b', shuffle_seed: int = 9176, shuffle_block_size: int = 1 << 18, + batching_method: str = 'random', + vision_pad_id: int = -100, **kwargs: Any): group_method = kwargs.pop('group_method', None) @@ -188,10 +166,11 @@ def __init__(self, shuffle_algo=shuffle_algo, shuffle_seed=shuffle_seed, shuffle_block_size=shuffle_block_size, + batching_method=batching_method, ) self.tokenizer = tokenizer self.max_seq_length = max_seq_length - self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base') + self.vision_pad_id = vision_pad_id # How to tokenize a text sample to a token sample def _tokenize(self, text_sample: Mapping): @@ -226,9 +205,22 @@ def __getitem__(self, idx: int): raise RuntimeError( 'StreamingTextDataset needs samples to have a `text` or `tokens` column' ) - vision_input = np.frombuffer(sample.get('images', None), dtype=np.uint8).copy().reshape(-1, 256, 256, 3) - vision_input = self.processor(vision_input, return_tensors="pt") - vision_input = vision_input["pixel_values"].to(torch.int64).unsqueeze(1) # TODO: Fix for num_frames > 1 + images = list(map(lambda x: np.array(x), sample.get('images', None))) + if images != []: + images = np.stack(images) + else: + images = np.array([]) + vision_input = images.reshape(-1, 224, 224, 3) + is_vision_empty = vision_input.shape[0] == 0 + if is_vision_empty: + vision_input = np.zeros((1, 224, 224, 3), dtype=np.uint8) + + vision_input = torch.from_numpy(vision_input).to(torch.int64) + vision_input = vision_input.unsqueeze(1) # TODO: Fix for num_frames > 1 + vision_input = rearrange(vision_input, "t f h w c -> t f c h w") + + if is_vision_empty: + vision_input = torch.ones_like(vision_input) * 
self.vision_pad_id multimodal_position_ids = torch.from_numpy(np.frombuffer(sample.get('multimodal_position_ids', None), dtype=np.int64).copy()).reshape(2, -1) labels = torch.from_numpy(np.frombuffer(sample.get('labels', None), dtype=np.int64).copy()).reshape(2, -1) return (token_sample, vision_input, multimodal_position_ids, labels) @@ -296,8 +288,9 @@ def build_interleaved_dataloader( # get kwargs streams_dict = cfg.dataset.pop('streams', None) mlm_probability = cfg.dataset.pop('mlm_probability', None) - eos_token_id = cfg.dataset.pop('eos_token_id', None) - bos_token_id = cfg.dataset.pop('bos_token_id', None) + position_pad_id = cfg.dataset.pop('position_pad_id', None) + pad_token_id = tokenizer.pad_token_id + vision_pad_id = cfg.dataset.pop('vision_pad_id', None) # build streams streams = None @@ -312,21 +305,22 @@ def build_interleaved_dataloader( dataset = StreamingInterleavedDataset( tokenizer=tokenizer, streams=streams, + vision_pad_id=vision_pad_id, **cfg.dataset, ) text_collate_fn = transformers.DataCollatorForLanguageModeling( - tokenizer=dataset.tokenizer, + tokenizer=tokenizer, mlm=mlm_probability is not None, mlm_probability=mlm_probability) text_collate_fn = TextNeoXCollateWrapper(text_collate_fn) - vision_collate_fn = PaddedCollateWrapper(pad_token_id=-1) # Each sample: (timesteps, num_vision, H, W, C) + vision_collate_fn = PaddedCollateWrapper(pad_token_id=vision_pad_id) # Each sample: (timesteps, num_vision, H, W, C) - multimodal_position_ids_collate_fn = PaddedCollateWrapper(pad_token_id=-1, take_transpose=True) # Each sample: (num_modalities, max_seq_length) + multimodal_position_ids_collate_fn = PaddedCollateWrapper(pad_token_id=position_pad_id, take_transpose=True) # Each sample: (num_modalities, max_seq_length) - label_collate_fn = PaddedCollateWrapper(pad_token_id=-1, take_transpose=True) # Each sample: (num_modalities, max_seq_length) + label_collate_fn = PaddedCollateWrapper(pad_token_id=pad_token_id, take_transpose=True) # Each sample: (num_modalities, max_seq_length) collate_fn = MultimodalCollateWrapper(text_collator=text_collate_fn, vision_collator=vision_collate_fn, @@ -368,11 +362,11 @@ def build_interleaved_dataloader( help='the path to the remote copy to stream from (optional)') parser.add_argument('--split', type=str, - default='val', + default='validation', help='which split of the dataset to use') parser.add_argument('--max_seq_length', type=int, - default=32, + default=2048, help='max sequence length to test') args = parser.parse_args() @@ -388,14 +382,15 @@ def build_interleaved_dataloader( 'name': 'text', 'dataset': { 'local': args.local_path, - 'remote': args.remote_path, + 'remote': None, 'split': args.split, 'shuffle': False, 'max_seq_length': args.max_seq_length, 'keep_zip': True, # in case we need compressed files after testing + 'eos_token_id': 50256, }, - 'drop_last': False, - 'num_workers': 4, + 'drop_last': True, + 'num_workers': 5, } cfg = om.create(cfg) device_batch_size = 2 @@ -405,9 +400,11 @@ def build_interleaved_dataloader( tokenizer_cfg = om.create(tokenizer_cfg) tokenizer = build_tokenizer(tokenizer_cfg) - loader = build_interleaved_dataloader(cfg, tokenizer, device_batch_size) - tokenizer = loader.dataset.tokenizer # type: ignore - for batch_ix, batch in enumerate(islice(loader, 5)): + loader = iter(build_interleaved_dataloader(cfg, tokenizer, device_batch_size)) + # tokenizer = loader.dataset.tokenizer # type: ignore + print("I am ready") + print(next(loader)) + for batch_ix, batch in enumerate(loader): print('\n') print('#' * 20, 
f'Batch {batch_ix}', '#' * 20) for k, v in batch.items(): diff --git a/megatron/data/streaming_dataset/interleaved_text_image/merge_files.py b/megatron/data/streaming_dataset/interleaved_text_image/merge_files.py new file mode 100644 index 000000000..77ed7c38d --- /dev/null +++ b/megatron/data/streaming_dataset/interleaved_text_image/merge_files.py @@ -0,0 +1,11 @@ +from streaming.base.util import merge_index +import argparse + +# main function +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--out_root', type=str, default=None) + args = parser.parse_args() + out_root = args.out_root + merge_index(out_root, keep_local=True) + diff --git a/megatron/model/embeddings.py b/megatron/model/embeddings.py index bdd85c49a..1b303ec8c 100644 --- a/megatron/model/embeddings.py +++ b/megatron/model/embeddings.py @@ -238,9 +238,10 @@ class EmbeddingPipe(Embedding): def __init__(self, neox_args, *args, **kwargs): super().__init__(neox_args, *args, **kwargs) # self.image_encoder = ImageEncoder(neox_args) + self.neox_args = neox_args self.seq_length = neox_args.seq_length - image_encoder_args = Args() - self.image_encoder = MultiModalEncoder(image_encoder_args, neox_args.hidden_size) + self.image_encoder_args = Args() + self.image_encoder = MultiModalEncoder(self.image_encoder_args, neox_args.hidden_size) @property def word_embeddings_weight(self): @@ -261,7 +262,12 @@ def forward(self, args): word_embeddings = super().forward(input_ids, position_ids) # [B, T, E] # Vision Input is [B, T, F, C, H, W] - image_embeddings = self.image_encoder(vision_input) # [B, T, N, E] where N=1 for now + # print("Inside embedding", vision_input) + if torch.all(vision_input == self.neox_args.vision_pad_id): + image_embeddings = torch.zeros((vision_input.shape[0], vision_input.shape[1], self.image_encoder_args.perceiver_seq_length, self.neox_args.hidden_size), dtype=word_embeddings.dtype, device=word_embeddings.device) + else: + image_embeddings = self.image_encoder(vision_input) # [B, T, N, E] where N=1 for now + image_embeddings = rearrange(image_embeddings, "b t n e -> b (t n) e") # [B, T*N, E] # Concatenate the embeddings diff --git a/megatron/model/encoders/vision/transforms.py b/megatron/model/encoders/vision/transforms.py new file mode 100644 index 000000000..e4232e201 --- /dev/null +++ b/megatron/model/encoders/vision/transforms.py @@ -0,0 +1,97 @@ +# Adapted from https://github.com/facebookresearch/dinov2/blob/main/dinov2/data/transforms.py + +from typing import Sequence + +import torch +from torchvision import transforms + + +class GaussianBlur(transforms.RandomApply): + """ + Apply Gaussian Blur to the PIL image. + """ + + def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): + # NOTE: torchvision is applying 1 - probability to return the original image + keep_p = 1 - p + transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) + super().__init__(transforms=[transform], p=keep_p) + + +class MaybeToTensor(transforms.ToTensor): + """ + Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor. + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor. + Returns: + Tensor: Converted image. 
+ """ + if isinstance(pic, torch.Tensor): + return pic + return super().__call__(pic) + +class ReScale(torch.nn.Module): + def __init__(self, scale) -> None: + super().__init__() + self.scale = scale + + def __call__(self, image): + return image * self.scale + + +# Use timm's names +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +DINO_SCALE = 0.00392156862745098 + +def make_normalize_transform( + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> transforms.Normalize: + return transforms.Normalize(mean=mean, std=std) + + +# This roughly matches torchvision's preset for classification training: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44 +def make_classification_train_transform( + *, + crop_size: int = 224, + interpolation=transforms.InterpolationMode.BICUBIC, + hflip_prob: float = 0.5, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +): + transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + if hflip_prob > 0.0: + transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob)) + transforms_list.extend( + [ + MaybeToTensor(), + make_normalize_transform(mean=mean, std=std), + ] + ) + return transforms.Compose(transforms_list) + + +# This matches (roughly) torchvision's preset for classification evaluation: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69 +def make_classification_eval_transform( + *, + resize_size: int = 256, + interpolation=transforms.InterpolationMode.BICUBIC, + crop_size: int = 224, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> transforms.Compose: + transforms_list = [ + transforms.Resize(resize_size, interpolation=interpolation), + transforms.CenterCrop(crop_size), + ReScale(DINO_SCALE), + MaybeToTensor(), + make_normalize_transform(mean=mean, std=std), + ] + return transforms.Compose(transforms_list) \ No newline at end of file diff --git a/megatron/model/encoders/vision/vision_encoder.py b/megatron/model/encoders/vision/vision_encoder.py index 33da02b93..0dabc6b3b 100644 --- a/megatron/model/encoders/vision/vision_encoder.py +++ b/megatron/model/encoders/vision/vision_encoder.py @@ -10,6 +10,7 @@ from .dinov2 import layers from transformers import AutoImageProcessor, AutoModel +from .transforms import make_classification_eval_transform class DinoWrapper(nn.Module): def __init__(self, encoder, config): @@ -17,6 +18,7 @@ def __init__(self, encoder, config): self.encoder = encoder self.config = config self.prepare_encoder() + self.transform = make_classification_eval_transform() def freeze_model(self): num_layers_to_unfreeze = self.config.num_layers_to_unfreeze @@ -44,6 +46,9 @@ def forward(self, x): c=number of channels, h=height, w=width ''' b, t, c, h, w = x.shape + combined_batch = rearrange(x, "b t c h w -> (b t) c h w") + preprocessed_vision = self.transform(combined_batch).half().contiguous() + x = rearrange(preprocessed_vision, "(b t) c h w -> b t c h w", b=b, t=t) if "vision" in self.config.arch: embeddings = self.encoder(x) # B, N_E, E else: diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 57c9973c7..abdb3aec8 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -685,6 +685,9 @@ def enable_logging(self): file_prefix = os.path.join(self.log_dir, 
hostname) Tee(file_prefix + "_stdout.txt", err=False) Tee(file_prefix + "_stderr.txt", err=True) + + if self.wandb_dir: + os.makedirs(self.wandb_dir, exist_ok=True) def print(self): """Print arguments.""" diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 1d55916b6..5aa694aa3 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -499,6 +499,11 @@ class NeoXArgsLogging(NeoXArgsTemplate): git_hash: str = get_git_commit_hash() """current git hash of repository""" + wandb_dir: str = None + """ + Directory to save logs to. + """ + log_dir: str = None """ Directory to save logs to. @@ -715,8 +720,10 @@ class NeoXArgsTokenizer(NeoXArgsTemplate): """ tokenizer object loaded into memory and accessible by other functions """ - - + position_pad_id: int = -1 + + vision_pad_id: int = 0 + @dataclass class NeoXArgsTraining(NeoXArgsTemplate): """ diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index d4089f2b5..68a1b2b99 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -228,12 +228,21 @@ def __init__(self, vocab_file): name = "HFTokenizer" super().__init__(name) self.tokenizer = Tokenizer.from_file(vocab_file) + self.eos_text = "<|endoftext|>" self.eod_id = self.tokenizer.token_to_id("<|endoftext|>") + self.bos_text = "" + self.bos_id = None + self.pad_text = "<|padding|>" self.pad_id = self.tokenizer.token_to_id("<|padding|>") - self.pad_token_id = self.pad_id - self._pad_token = self.pad_id + self.pad_token_id = self.pad_id # TODO Fix + self._pad_token = self.pad_id # TODO Fix self.padding_side = "right" + self.image_start_text = "<|image_start|>" + self.image_start_id = self.tokenizer.token_to_id("<|image_start|>") + self.image_end_text = "<|image_end|>" + self.image_end_id = self.tokenizer.token_to_id("<|image_end|>") + @property def vocab_size(self): return self.tokenizer.get_vocab_size() diff --git a/megatron/training.py b/megatron/training.py index f136860e9..67dc429ce 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -57,6 +57,28 @@ ) from megatron.model.gpt2_model import cross_entropy from eval_tasks import run_eval_harness +import pickle as pkl + +import os + +def save_dataloader_checkpoint(checkpoint_path, iteration, dataloader_state_dict): + save_checkpoint = False + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + save_checkpoint = True + else: + save_checkpoint = True + + if save_checkpoint: + if not os.path.isdir(checkpoint_path): # If the folder does not exist + os.makedirs(checkpoint_path) # Create the folder + + file_name = os.path.join(checkpoint_path, f'{iteration}_checkpoint.pkl') + + with open(file_name, 'wb') as file: + pkl.dump(dataloader_state_dict, file) + + def mup_weights_reinit(neox_args, model): @@ -198,14 +220,14 @@ def pretrain(neox_args): # Data stuff. timers("train/valid/test data iterators").start() ( - train_data_iterator, - valid_data_iterator, - test_data_iterator, + train_dataloader, + valid_dataloader, + test_dataloader, ) = build_streaming_train_valid_test_data_iterators(neox_args=neox_args) timers("train/valid/test data iterators").stop() if neox_args.use_mup and neox_args.coord_check: - mup_coord_check(neox_args, timers, lr_scheduler, train_data_iterator) + mup_coord_check(neox_args, timers, lr_scheduler, iter(train_dataloader) if train_dataloader is not None else None) # Print setup timing. 
print_rank_0("done with setups ...") @@ -223,6 +245,9 @@ def pretrain(neox_args): optimizer=optimizer, lr_scheduler=lr_scheduler, ) + save_dataloader_checkpoint(neox_args.train_streaming_data_config['state_dict_path'], iteration, train_dataloader.state_dict()) + save_dataloader_checkpoint(neox_args.valid_streaming_data_config['state_dict_path'], iteration, valid_dataloader.state_dict()) + iteration = train( neox_args=neox_args, @@ -230,8 +255,8 @@ def pretrain(neox_args): model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, - train_data_iterator=train_data_iterator, - valid_data_iterator=valid_data_iterator, + train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, ) if neox_args.do_valid: @@ -240,7 +265,7 @@ def pretrain(neox_args): neox_args=neox_args, prefix=prefix, forward_step_func=forward_step, - data_iterator=valid_data_iterator, + data_iterator=iter(valid_dataloader) if valid_dataloader is not None else None, model=model, iteration=iteration, verbose=False, @@ -255,6 +280,8 @@ def pretrain(neox_args): optimizer=optimizer, lr_scheduler=lr_scheduler, ) + save_dataloader_checkpoint(neox_args.train_streaming_data_config['state_dict_path'], iteration, train_dataloader.state_dict()) + save_dataloader_checkpoint(neox_args.valid_streaming_data_config['state_dict_path'], iteration, valid_dataloader.state_dict()) if neox_args.do_test: # Run on test data. @@ -263,7 +290,7 @@ def pretrain(neox_args): neox_args=neox_args, prefix=prefix, forward_step_func=forward_step, - data_iterator=test_data_iterator, + data_iterator=iter(test_dataloader) if test_dataloader is not None else None, model=model, iteration=iteration, verbose=True, @@ -281,7 +308,7 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype): if "labels" in data_b: # This our custom approach labels = data_b["labels"].long().contiguous() text_input = text_input_.contiguous() - else: + else: # This is not supported labels = text_input_[:, 1:].contiguous() text_input = text_input_[:, :-1].contiguous() @@ -292,8 +319,8 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype): max_text_length = text_input.shape[1] text_positions = multimodal_position_ids[:, MODALITY_DICT['text'], :max_text_length] text_labels = labels[:, MODALITY_DICT['text'], :max_text_length] - assert torch.all(multimodal_position_ids[:, MODALITY_DICT['text'], max_text_length:] == -1) - assert torch.all(labels[:, MODALITY_DICT['text'], max_text_length:] == -1) + assert torch.all(multimodal_position_ids[:, MODALITY_DICT['text'], max_text_length:] == neox_args.position_pad_id) + assert torch.all(labels[:, MODALITY_DICT['text'], max_text_length:] == tokenizer.pad_id) text_input_info = { "input": text_input, "labels": text_labels, @@ -302,12 +329,12 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype): } # Unpack vision_input and get padded vision length - vision_input = data_b["vision_input"].half().contiguous() + vision_input = data_b["vision_input"] max_vision_length = vision_input.shape[1] vision_positions = multimodal_position_ids[:, MODALITY_DICT['vision'], :max_vision_length] vision_labels = labels[:, MODALITY_DICT['vision'], :max_vision_length] - assert torch.all(multimodal_position_ids[:, MODALITY_DICT['vision'], max_vision_length:] == -1) - assert torch.all(labels[:, MODALITY_DICT['vision'], max_vision_length:] == -1) + assert torch.all(multimodal_position_ids[:, MODALITY_DICT['vision'], max_vision_length:] == neox_args.position_pad_id) + assert torch.all(labels[:, MODALITY_DICT['vision'], max_vision_length:] == 
tokenizer.pad_id) vision_input_info = { "input": vision_input, "labels": vision_labels, @@ -334,9 +361,11 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype): attention_mask, loss_mask, position_ids, shifted_multimodal_position_ids, labels = get_multimodal_ltor_masks_and_position_ids( input_info=input_info, input_seq_length=neox_args.seq_length, - eod_token=neox_args.tokenizer.eod_id, - bos_token=neox_args.tokenizer.bos_id if hasattr(neox_args.tokenizer, "bos_id") else None, - pad_token=neox_args.tokenizer.pad_id, + eod_token=tokenizer.eod_id, + bos_token=tokenizer.bos_id if hasattr(tokenizer, "bos_id") else None, + pad_token=tokenizer.pad_id, + position_pad_token_id=neox_args.position_pad_id, + vision_start_token = tokenizer.image_start_id, concat_data=neox_args.concat_data, attn_uses_sequence_id=neox_args.attn_uses_sequence_id ) @@ -817,11 +846,12 @@ def train( model, optimizer, lr_scheduler, - train_data_iterator, - valid_data_iterator, + train_dataloader, + valid_dataloader, ): """Train the model function.""" - + train_data_iterator = iter(train_dataloader) if train_dataloader else None + valid_data_iterator = iter(valid_dataloader) if valid_dataloader else None # Turn on training mode which enables dropout. model.train() @@ -887,6 +917,8 @@ def train( optimizer=optimizer, lr_scheduler=lr_scheduler, ) + save_dataloader_checkpoint(neox_args.train_streaming_data_config['state_dict_path'], iteration, train_dataloader.state_dict()) + save_dataloader_checkpoint(neox_args.valid_streaming_data_config['state_dict_path'], iteration, valid_dataloader.state_dict()) # Evaluation if ( diff --git a/megatron/utils.py b/megatron/utils.py index 6271428ba..65ebe26ea 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -182,7 +182,7 @@ def get_shifted_multimodal_position_ids(input_info, position_pad_id=-1): shited_audio_positions = None return shifted_text_positions, shifted_vision_positions, shited_audio_positions -def get_proxy_tokens(position_ids, seq_length, text_pad_id, position_pad_id=-1, start_ind=100): +def get_proxy_tokens(position_ids, seq_length, pad_id, position_pad_id=-1, start_ind=100): multimodal_mask = position_ids != position_pad_id # All vision tokens are given a negative index @@ -190,7 +190,7 @@ def get_proxy_tokens(position_ids, seq_length, text_pad_id, position_pad_id=-1, proxy_tokens = torch.repeat_interleave(proxy_tokens, seq_length, dim=1) multimodal_mask = torch.repeat_interleave(multimodal_mask, seq_length, dim=1) proxy_masked_tokens = proxy_tokens*multimodal_mask - proxy_masked_tokens[proxy_masked_tokens == 0] = text_pad_id # This is any random text token. This cannot be equal to eos, or eod: TODO, can this be padid? + proxy_masked_tokens[proxy_masked_tokens == 0] = pad_id # This is any random text token. This cannot be equal to eos, or eod: TODO, can this be padid? return proxy_masked_tokens def get_multimodal_mask(interleaved_tokens, text_pad_id): @@ -226,8 +226,8 @@ def get_multimodal_attn_mask( # lower triangular attention mask across all tokens mask = torch.tril(torch.ones((1, input_seq_length, input_seq_length), device=device)).expand((batch_size, -1, -1)).clone() - # Form vision proxy tokens using shifted multimodal position ids - proxy_vision_tokens = get_proxy_tokens(vision_positions, vision_seq_length, position_pad_id=position_pad_token_id, text_pad_id=text_pad_token_id, start_ind=100) + # Form vision proxy tokens using shifted multimodal position ids. 
Use text_pad_token_id as pad_id + proxy_vision_tokens = get_proxy_tokens(vision_positions, vision_seq_length, position_pad_id=position_pad_token_id, pad_id=text_pad_token_id, start_ind=100) # Do the same process for Audio #TODO # Concatenate vision proxy tokens with text tokens @@ -269,6 +269,8 @@ def get_multimodal_ltor_masks_and_position_ids( eod_token, bos_token, pad_token, + position_pad_token_id, + vision_start_token, concat_data=True, attn_uses_sequence_id=False, ): @@ -294,7 +296,7 @@ def get_multimodal_ltor_masks_and_position_ids( shifted_multimodal_position_ids=shifted_multimodal_position_ids, eos_token_id=eod_token, bos_token_id=bos_token, - position_pad_token_id=-1, # TODO, get whatever is used in streaming + position_pad_token_id=position_pad_token_id, text_pad_token_id=pad_token, concat_data=concat_data, attn_uses_sequence_id=attn_uses_sequence_id, @@ -319,6 +321,7 @@ def get_multimodal_ltor_masks_and_position_ids( loss_mask[labels == pad_token] = 0.0 loss_mask[labels == bos_token] = 0.0 loss_mask[labels == eod_token] = 0.0 + loss_mask[labels == vision_start_token] = 0.0 position_ids = torch.arange(input_seq_length, dtype=torch.long, device=labels.device) # FIX THIS #TODO position_ids = position_ids.unsqueeze(0).expand(batch_size, input_seq_length) @@ -381,6 +384,7 @@ def init_wandb(neox_args): if neox_args.use_wandb: group_name = neox_args.wandb_group name = f"{socket.gethostname()}-{local_rank()}" if group_name else None + print("Logging wandb to:", neox_args.wandb_dir, flush=True) try: wandb.init( project=neox_args.wandb_project, @@ -389,6 +393,7 @@ def init_wandb(neox_args): save_code=False, force=False, entity=neox_args.wandb_team, + dir=neox_args.wandb_dir, ) except wandb.UsageError as e: neox_args.update_value("use_wandb", False) diff --git a/mytests/multimodal_utils_test.py b/mytests/multimodal_utils_test.py index 4cbd794c3..d7673a60e 100644 --- a/mytests/multimodal_utils_test.py +++ b/mytests/multimodal_utils_test.py @@ -187,32 +187,49 @@ def test_get_shifted_multimodal_position_ids(): def test_get_proxy_tokens(): eps = 1e-7 + # Case 0: 0 images # TODO + vision_positions = torch.tensor([[-1]]) + vision_seq_len = 1 + proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0) + assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[0]])) < eps) + # Case 1: 1 image vision_positions = torch.tensor([[1]]) vision_seq_len = 1 - proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, text_pad_id=0) + proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0) assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[-100]])) < eps) # Case 2: multiple images vision_positions = torch.tensor([[1, 2]]) vision_seq_len = 1 - proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, text_pad_id=0) + proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0) assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[-100, -101]])) < eps) # Case 3: multiple images, seq len > 1 vision_positions = torch.tensor([[3, 5]]) vision_seq_len = 3 - proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, text_pad_id=0) + proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0) assert torch.all(torch.abs(proxy_vision_tokens - 
torch.tensor([[-100, -100, -100, -101, -101, -101]])) < eps) # Case 4: multiple samples with padding vision_positions = torch.tensor([[1, 2], [3, -1]]) vision_seq_len = 2 - proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, text_pad_id=0) + proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0) assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[-100, -100, -101, -101], [-100, -100, 0, 0]])) < eps) def test_get_multimodal_mask(): + # Case 0: 0 images + tokens = torch.tensor([[2, 2]]) + multimodal_mask = utils.get_multimodal_mask(tokens, text_pad_id=0) + correct_mask = torch.tensor([ + [ + [False, False], + [False, False] + ] + ]) + assert torch.all(multimodal_mask == correct_mask) + # Case 1: 1 image tokens = torch.tensor([[-100]]) multimodal_mask = utils.get_multimodal_mask(tokens, text_pad_id=0)