From bb18db2d516719d39a2bf02fccddee79eec93967 Mon Sep 17 00:00:00 2001 From: ArvinZhuang Date: Thu, 17 Oct 2024 10:57:12 +1000 Subject: [PATCH] smae code --- .gitignore | 162 +++++++++++++++++++++++++++++++++ smae/README.md | 27 ++++++ smae/data.py | 202 ++++++++++++++++++++++++++++++++++++++++++ smae/modelling.py | 186 ++++++++++++++++++++++++++++++++++++++ smae/pretrain_bert.py | 121 +++++++++++++++++++++++++ 5 files changed, 698 insertions(+) create mode 100644 .gitignore create mode 100644 smae/README.md create mode 100644 smae/data.py create mode 100644 smae/modelling.py create mode 100644 smae/pretrain_bert.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..011b6ca --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +.DS_Store diff --git a/smae/README.md b/smae/README.md new file mode 100644 index 0000000..4a48658 --- /dev/null +++ b/smae/README.md @@ -0,0 +1,27 @@ +# Starbucks Masked Autoencoder + +```bash +python pretrain_bert.py \ +--output_dir checkpoints/bert-base-uncased-fineweb100bt-smae \ +--save_steps 10000 \ +--bf16 \ +--per_device_train_batch_size 32 \ +--gradient_accumulation_steps 2 \ +--learning_rate 1e-4 \ +--lr_scheduler_type cosine \ +--weight_decay 0.001 \ +--warmup_ratio 0.05 \ +--num_train_epochs 1 \ +--logging_steps 100 \ +--mlm_probability 0.2 \ +--decoder_mlm_probability 0.4 \ +--report_to wandb \ +--matryoshka_pretraining True \ +--mae_pretraining True \ +--run_name pretrain-bert-mae-matryoshka-fineweb100bt-starbucks \ +--dataloader_num_workers 16 \ +--num_processes 32 \ +--save_safetensors False \ +--log_level info \ +--logging_nan_inf_filter False +``` diff --git a/smae/data.py b/smae/data.py new file mode 100644 index 0000000..b6f0ee4 --- /dev/null +++ b/smae/data.py @@ -0,0 +1,202 @@ +from transformers import BatchEncoding, PreTrainedTokenizer, DataCollatorForWholeWordMask, BertTokenizer, BertTokenizerFast +from transformers.data.data_collator import _torch_collate_batch +import warnings +from torch.utils.data import Dataset +from datasets import load_dataset +import random +import torch +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union + + +class MLMDataset(Dataset): + def __init__(self, tokenizer: PreTrainedTokenizer, num_processes=16): + self.tokenizer = tokenizer + self.max_length = self.tokenizer.model_max_length - self.tokenizer.num_special_tokens_to_add(pair=False) + self.corpus = load_dataset( + 'HuggingFaceFW/fineweb', # hard code for now + 'sample-100BT', + split='train', + num_proc=num_processes, + ) + + def __len__(self): + return len(self.corpus) + + def __getitem__(self, item) -> BatchEncoding: + text = self.corpus[item]['text'] + + # if the text is too long, truncate it randomly + tokens = self.tokenizer.tokenize(text) + + if len(tokens) > self.max_length: + trunc = len(tokens) - self.max_length + trunc_left = random.randint(0, trunc) + trunc_right = trunc - trunc_left + + truncated = tokens[trunc_left:] + if trunc_right > 0: + truncated = truncated[:-trunc_right] + text = self.tokenizer.convert_tokens_to_string(truncated) + + tokenized_text = self.tokenizer(text, + return_special_tokens_mask=False, + return_token_type_ids=False, + truncation=True) + return tokenized_text + +@dataclass +class DataCollatorForWholeWordMaskWithAttentionMask(DataCollatorForWholeWordMask): + def torch_call(self, 
examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + output = super().torch_call(examples) + batch = self.tokenizer.pad( + examples, + padding=True, + return_attention_mask=True, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors='pt', + ) + output['attention_mask'] = batch['attention_mask'] + return output + + + +@dataclass +class MaeDataCollatorForWholeWordMask(DataCollatorForWholeWordMask): + encoder_mlm_probability: float = 0.3 + decoder_mlm_probability: float = 0.5 + + def __post_init__(self): + super(MaeDataCollatorForWholeWordMask, self).__post_init__() + + from transformers import BertTokenizer, BertTokenizerFast + from transformers import RobertaTokenizer, RobertaTokenizerFast + if isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)): + self.whole_word_cand_indexes = self._whole_word_cand_indexes_bert + elif isinstance(self.tokenizer, (RobertaTokenizer, RobertaTokenizerFast)): + self.whole_word_cand_indexes = self._whole_word_cand_indexes_roberta + else: + raise NotImplementedError(f'{type(self.tokenizer)} collator not supported yet') + + self.specials = self.tokenizer.all_special_tokens + + def _whole_word_cand_indexes_bert(self, input_tokens: List[str]): + cand_indexes = [] + for (i, token) in enumerate(input_tokens): + if token in self.specials: + continue + + if len(cand_indexes) >= 1 and token.startswith("##"): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + return cand_indexes + + def _whole_word_cand_indexes_roberta(self, input_tokens: List[str]): + cand_indexes = [] + for (i, token) in enumerate(input_tokens): + if token in self.specials: + raise ValueError('We expect only raw input for roberta for current implementation') + + if i == 0: + cand_indexes.append([0]) + elif not token.startswith('\u0120'): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + return cand_indexes + + def _whole_word_mask(self, input_tokens: List[str], max_predictions=512): + """ + Get 0/1 labels for masked tokens with whole word mask proxy + """ + + cand_indexes = self._whole_word_cand_indexes_bert(input_tokens) + + random.shuffle(cand_indexes) + encoder_num_to_predict = min(max_predictions, + max(1, int(round(len(input_tokens) * self.encoder_mlm_probability)))) + decoder_num_to_predict = min(max_predictions, + max(1, int(round(len(input_tokens) * self.decoder_mlm_probability)))) + + masked_lms = [] + encoder_masked_lms = [] + decoder_masked_lms = [] + covered_indexes = set() + for index_set in cand_indexes: + if len(masked_lms) >= max(encoder_num_to_predict, decoder_num_to_predict): + break + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. 
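+            # (The cap uses the larger of the encoder and decoder budgets, since both
+            # masks are drawn from the same shuffled whole-word candidates below.)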
+ if len(masked_lms) + len(index_set) > max(encoder_num_to_predict, decoder_num_to_predict): + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + + encoder_add = True if len(encoder_masked_lms) + len(index_set) <= encoder_num_to_predict else False + decoder_add = True if len(decoder_masked_lms) + len(index_set) <= decoder_num_to_predict else False + + for index in index_set: + covered_indexes.add(index) + if encoder_add and decoder_add: + encoder_masked_lms.append(index) + if decoder_add: + decoder_masked_lms.append(index) + masked_lms.append(index) + + assert len(covered_indexes) == len(masked_lms) + encoder_mask_labels = [1 if i in encoder_masked_lms else 0 for i in range(len(input_tokens))] + decoder_mask_labels = [1 if i in decoder_masked_lms else 0 for i in range(len(input_tokens))] + return encoder_mask_labels, decoder_mask_labels + + def __call__(self, examples, return_tensors=None): + batch = self.tokenizer.pad( + examples, + padding=True, + return_attention_mask=True, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors='pt', + ) + + input_ids = [e["input_ids"] for e in examples] + + encoder_mlm_masks = [] + decoder_mlm_masks = [] + + for e in input_ids: + tokens = [] + for tid in e: + tokens.append(self.tokenizer._convert_id_to_token(tid)) + + encoder_mlm_mask, decoder_mlm_mask = self._whole_word_mask(tokens, self.tokenizer.model_max_length) + encoder_mlm_masks.append(encoder_mlm_mask) + decoder_mlm_masks.append(decoder_mlm_mask) + + encoder_mlm_masks = _torch_collate_batch(encoder_mlm_masks, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of) + decoder_mlm_masks = _torch_collate_batch(decoder_mlm_masks, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of) + + + encoder_inputs, encoder_labels = self.torch_mask_tokens( + batch['input_ids'].clone(), + encoder_mlm_masks.clone() + ) + + decoder_inputs, decoder_labels = self.torch_mask_tokens( + batch['input_ids'].clone(), + decoder_mlm_masks.clone() + ) + + output = { + "input_ids": encoder_inputs, + "labels": encoder_labels, + "decoder_input_ids": decoder_inputs, + "decoder_labels": decoder_labels, + "attention_mask": batch['attention_mask'], + } + + return output \ No newline at end of file diff --git a/smae/modelling.py b/smae/modelling.py new file mode 100644 index 0000000..cbedffb --- /dev/null +++ b/smae/modelling.py @@ -0,0 +1,186 @@ +from transformers import Trainer, PreTrainedTokenizer, BertPreTrainedModel, DataCollatorForWholeWordMask, AutoTokenizer, BertForMaskedLM, AutoModelForMaskedLM +from transformers.models.bert.modeling_bert import MaskedLMOutput, BertModel, BertOnlyMLMHead + +from typing import List, Tuple, Dict, Any, Optional, Union +import torch +from torch.nn import CrossEntropyLoss, KLDivLoss +import random +from torch.nn import functional as F +import collections +import copy +from torch import nn + + +class BertPredictionHeadTransformForwardDecorator: + def __init__(self, module): + self.module = module + + def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor: + dim = hidden_states.size(-1) + hidden_states = F.linear(hidden_states, + self.module.dense.weight[:, :dim], + self.module.dense.bias) + hidden_states = self.module.transform_act_fn(hidden_states) + + hidden_states = self.module.LayerNorm(hidden_states) + + return hidden_states + + +class BertFor2DMatryoshkaMaskedLM(BertForMaskedLM): + def __init__(self, config): + super().__init__(config) + + 
self.cls.predictions.transform.forward = BertPredictionHeadTransformForwardDecorator( + self.cls.predictions.transform) + + self.layer_list = [2, 4, 6, 8, 10, 12] + self.dim_list = [32, 64, 128, 256, 512, 768] + + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + hidden_states = outputs.hidden_states + + total_loss = 0 + if labels is not None: + loss_fct = CrossEntropyLoss() + for selected_layer, selected_dim in zip(self.layer_list, self.dim_list): + selected_layer_selected_dim_scores = self.cls(hidden_states[selected_layer][:, :, :selected_dim]) + selected_layer_selected_dim_loss = loss_fct(selected_layer_selected_dim_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + total_loss += selected_layer_selected_dim_loss + + total_loss /= len(self.layer_list) + + + if not return_dict: + output = (None,) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return MaskedLMOutput( + loss=total_loss, + logits=None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BertFor2DMaekMatryoshkaMaskedLM(BertFor2DMatryoshkaMaskedLM): + def __init__(self, config): + super().__init__(config) + n_head_layers = 1 # hard code for now + self.decoder = BertForMaskedLM.from_pretrained(config._name_or_path) + self.decoder.cls = self.cls + self.decoder.bert.embeddings = self.bert.embeddings + self.decoder.bert.encoder.layer = self.decoder.bert.encoder.layer[:n_head_layers] + + + def compute_mae_loss(self, hidden_states, decoder_input_ids, decoder_labels, decoder_attention_mask, loss_fct): + encoder_cls_hiddens = hidden_states[:, :1] + decoder_input_embeds = self.decoder.bert.embeddings(decoder_input_ids) + decoder_input_embeds[:, :1] = encoder_cls_hiddens + decoder_input_mlm = self.decoder.bert.encoder(decoder_input_embeds, attention_mask=decoder_attention_mask)[0] + + mae_scores = self.decoder.cls(decoder_input_mlm) + mae_loss = loss_fct(mae_scores.view(-1, self.config.vocab_size), + decoder_labels.view(-1)) + return mae_loss + + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + decoder_labels: Optional[torch.Tensor] = None, + 
output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + hidden_states = outputs.hidden_states + decoder_attention_mask = self.decoder.get_extended_attention_mask( + attention_mask, + attention_mask.shape, + attention_mask.device + ) + + total_loss = 0 + if labels is not None: + loss_fct = CrossEntropyLoss() + for selected_layer, selected_dim in zip(self.layer_list, self.dim_list): + selected_layer_selected_dim_scores = self.cls(hidden_states[selected_layer][:, :, :selected_dim]) + selected_layer_selected_dim_loss = loss_fct(selected_layer_selected_dim_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + + mae_loss = self.compute_mae_loss(self.cls.predictions.transform(hidden_states[selected_layer][:, :, :selected_dim]), + decoder_input_ids, + decoder_labels, + decoder_attention_mask, + loss_fct) + total_loss += (selected_layer_selected_dim_loss + mae_loss) / 2 + + total_loss /= len(self.layer_list) + + + + + + if not return_dict: + output = (None,) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return MaskedLMOutput( + loss=total_loss, + logits=None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/smae/pretrain_bert.py b/smae/pretrain_bert.py new file mode 100644 index 0000000..4a3692a --- /dev/null +++ b/smae/pretrain_bert.py @@ -0,0 +1,121 @@ +import logging +import os +import sys + +from transformers import ( + HfArgumentParser, + TrainingArguments, + set_seed, +) + +from transformers import Trainer, AutoTokenizer, BertForMaskedLM +import torch +from dataclasses import dataclass, field +from modelling import BertFor2DMatryoshkaMaskedLM, BertFor2DMaekMatryoshkaMaskedLM +from data import MLMDataset, DataCollatorForWholeWordMaskWithAttentionMask, MaeDataCollatorForWholeWordMask + +logger = logging.getLogger(__name__) + + +class SkipNanTrainer(Trainer): + def training_step(self, model, inputs): + loss = super().training_step(model, inputs) + + # Immediately check for NaN gradients + nan_gradients = False + for param in model.parameters(): + if param.grad is not None and torch.isnan(param.grad).any(): + nan_gradients = True + param.grad = None # Reset gradients, so that the optimizer.step() later will do nothing. 
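+                # With grad set to None, the optimizer skips this parameter on its next step.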
+ + if nan_gradients: + # Tip: set '--logging_nan_inf_filter False' for smooth logging + print("NaN gradient detected, skipping optimizer step.") + + return loss + + +@dataclass +class MatryoshkaPretrainingArguments(TrainingArguments): + matryoshka_pretraining: bool = field(default=False, metadata={"help": "Do matryoshka pretraining"}) + mae_pretraining: bool = field(default=False, metadata={"help": "Do MAE pretraining"}) + num_processes: int = field(default=16, metadata={"help": "Number of processes to use for data loading"}) + mlm_probability: float = field(default=0.15, metadata={"help": "Probability of masking tokens"}) + decoder_mlm_probability: float = field(default=0.3, metadata={"help": "Probability of masking tokens for the decoder"}) + + + +def main(): + parser = HfArgumentParser((MatryoshkaPretrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + training_args, = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + training_args, = parser.parse_args_into_dataclasses() + training_args: MatryoshkaPretrainingArguments + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + training_args.local_rank, + training_args.device, + training_args.n_gpu, + bool(training_args.local_rank != -1), + training_args.fp16, + ) + logger.info("Training/evaluation parameters %s", training_args) + + set_seed(training_args.seed) + + tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') + train_dataset = MLMDataset(tokenizer, training_args.num_processes) + + if training_args.matryoshka_pretraining: + if training_args.mae_pretraining: + collator = MaeDataCollatorForWholeWordMask(tokenizer, + encoder_mlm_probability=training_args.mlm_probability, + decoder_mlm_probability=training_args.decoder_mlm_probability, + pad_to_multiple_of=8) + model = BertFor2DMaekMatryoshkaMaskedLM.from_pretrained('bert-base-uncased') + else: + collator = DataCollatorForWholeWordMaskWithAttentionMask(tokenizer, + mlm_probability=training_args.mlm_probability, + pad_to_multiple_of=8) + model = BertFor2DMatryoshkaMaskedLM.from_pretrained('bert-base-uncased') + else: + collator = DataCollatorForWholeWordMaskWithAttentionMask(tokenizer, + mlm_probability=training_args.mlm_probability, + pad_to_multiple_of=8) + model = BertForMaskedLM.from_pretrained('bert-base-uncased') + + + trainer = SkipNanTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + data_collator=collator + ) + train_dataset.trainer = trainer + + trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + if trainer.is_world_process_zero(): + tokenizer.save_pretrained(training_args.output_dir) + + +if __name__ == "__main__": + main()