Commit bb18db2 (1 parent: 9683b20). Showing 5 changed files with 698 additions and 0 deletions.
@@ -0,0 +1,162 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

.DS_Store
@@ -0,0 +1,27 @@
# Starbucks Masked Autoencoder

```bash
python pretrain_bert.py \
  --output_dir checkpoints/bert-base-uncased-fineweb100bt-smae \
  --save_steps 10000 \
  --bf16 \
  --per_device_train_batch_size 32 \
  --gradient_accumulation_steps 2 \
  --learning_rate 1e-4 \
  --lr_scheduler_type cosine \
  --weight_decay 0.001 \
  --warmup_ratio 0.05 \
  --num_train_epochs 1 \
  --logging_steps 100 \
  --mlm_probability 0.2 \
  --decoder_mlm_probability 0.4 \
  --report_to wandb \
  --matryoshka_pretraining True \
  --mae_pretraining True \
  --run_name pretrain-bert-mae-matryoshka-fineweb100bt-starbucks \
  --dataloader_num_workers 16 \
  --num_processes 32 \
  --save_safetensors False \
  --log_level info \
  --logging_nan_inf_filter False
```
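For orientation: with these flags the effective batch is `per_device_train_batch_size × gradient_accumulation_steps` = 32 × 2 = 64 sequences per device per optimizer step (times the number of devices under a distributed launcher), and `--mlm_probability 0.2` / `--decoder_mlm_probability 0.4` presumably set the encoder- and decoder-side whole-word masking rates used by the MAE data collator introduced in this commit.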
@@ -0,0 +1,202 @@
from transformers import BatchEncoding, PreTrainedTokenizer, DataCollatorForWholeWordMask, BertTokenizer, BertTokenizerFast
from transformers.data.data_collator import _torch_collate_batch
import warnings
from torch.utils.data import Dataset
from datasets import load_dataset
import random
import torch
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union


class MLMDataset(Dataset):
    """Serves FineWeb sample-100BT documents as plain tokenizer output; masking is applied later by the collator."""

    def __init__(self, tokenizer: PreTrainedTokenizer, num_processes=16):
        self.tokenizer = tokenizer
        # Reserve room for the special tokens ([CLS]/[SEP]) the tokenizer will add.
        self.max_length = self.tokenizer.model_max_length - self.tokenizer.num_special_tokens_to_add(pair=False)
        self.corpus = load_dataset(
            'HuggingFaceFW/fineweb',  # hard code for now
            'sample-100BT',
            split='train',
            num_proc=num_processes,
        )

    def __len__(self):
        return len(self.corpus)

    def __getitem__(self, item) -> BatchEncoding:
        text = self.corpus[item]['text']

        # If the text is too long, truncate it randomly: split the excess between the
        # left and right ends so the kept window starts at a random offset.
        tokens = self.tokenizer.tokenize(text)

        if len(tokens) > self.max_length:
            trunc = len(tokens) - self.max_length
            trunc_left = random.randint(0, trunc)
            trunc_right = trunc - trunc_left

            truncated = tokens[trunc_left:]
            if trunc_right > 0:
                truncated = truncated[:-trunc_right]
            text = self.tokenizer.convert_tokens_to_string(truncated)

        tokenized_text = self.tokenizer(text,
                                        return_special_tokens_mask=False,
                                        return_token_type_ids=False,
                                        truncation=True)
        return tokenized_text
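For reference, a minimal sketch of using this dataset on its own; the `bert-base-uncased` checkpoint is an assumption (the training script presumably picks its own tokenizer), and constructing the dataset downloads the full FineWeb `sample-100BT` split:

```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")  # assumed checkpoint
dataset = MLMDataset(tokenizer, num_processes=16)

example = dataset[0]
# Each item is plain tokenizer output; whole-word masking happens in the collator.
print(list(example.keys()))                                     # ['input_ids', 'attention_mask']
print(len(example["input_ids"]) <= tokenizer.model_max_length)  # True
```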
@dataclass
class DataCollatorForWholeWordMaskWithAttentionMask(DataCollatorForWholeWordMask):
    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        # The stock whole-word-mask collator returns only input_ids and labels, so pad
        # the raw examples again and carry the attention mask over into the output.
        output = super().torch_call(examples)
        batch = self.tokenizer.pad(
            examples,
            padding=True,
            return_attention_mask=True,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors='pt',
        )
        output['attention_mask'] = batch['attention_mask']
        return output
@dataclass
class MaeDataCollatorForWholeWordMask(DataCollatorForWholeWordMask):
    encoder_mlm_probability: float = 0.3
    decoder_mlm_probability: float = 0.5

    def __post_init__(self):
        super(MaeDataCollatorForWholeWordMask, self).__post_init__()

        from transformers import BertTokenizer, BertTokenizerFast
        from transformers import RobertaTokenizer, RobertaTokenizerFast
        # Pick the whole-word grouping routine that matches the tokenizer family.
        if isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
            self.whole_word_cand_indexes = self._whole_word_cand_indexes_bert
        elif isinstance(self.tokenizer, (RobertaTokenizer, RobertaTokenizerFast)):
            self.whole_word_cand_indexes = self._whole_word_cand_indexes_roberta
        else:
            raise NotImplementedError(f'{type(self.tokenizer)} collator not supported yet')

        self.specials = self.tokenizer.all_special_tokens

    def _whole_word_cand_indexes_bert(self, input_tokens: List[str]):
        # Group WordPiece sub-tokens into whole words: a "##" continuation piece joins
        # the previous word's index list; special tokens are never masking candidates.
        cand_indexes = []
        for (i, token) in enumerate(input_tokens):
            if token in self.specials:
                continue

            if len(cand_indexes) >= 1 and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
        return cand_indexes

    def _whole_word_cand_indexes_roberta(self, input_tokens: List[str]):
        # RoBERTa's BPE marks the start of a word with '\u0120' (Ġ), so a token without
        # that prefix continues the previous word.
        cand_indexes = []
        for (i, token) in enumerate(input_tokens):
            if token in self.specials:
                raise ValueError('We expect only raw input for roberta for current implementation')

            if i == 0:
                cand_indexes.append([0])
            elif not token.startswith('\u0120'):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
        return cand_indexes
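To make the grouping concrete, here is a hypothetical WordPiece sequence (the token strings are made up for illustration) and what the BERT helper above returns for it; the tokenizer checkpoint is again an assumption:

```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")  # assumed checkpoint
collator = MaeDataCollatorForWholeWordMask(tokenizer=tokenizer)

tokens = ["[CLS]", "star", "##bucks", "masked", "auto", "##encoder", "[SEP]"]
print(collator._whole_word_cand_indexes_bert(tokens))
# [[1, 2], [3], [4, 5]]  -- "##" pieces join the preceding word; special tokens are skipped
```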
    def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy.
        Returns one label list for the encoder view and one for the decoder view.
        """

        # Dispatch to the tokenizer-specific grouping chosen in __post_init__.
        cand_indexes = self.whole_word_cand_indexes(input_tokens)

        random.shuffle(cand_indexes)
        encoder_num_to_predict = min(max_predictions,
                                     max(1, int(round(len(input_tokens) * self.encoder_mlm_probability))))
        decoder_num_to_predict = min(max_predictions,
                                     max(1, int(round(len(input_tokens) * self.decoder_mlm_probability))))

        masked_lms = []
        encoder_masked_lms = []
        decoder_masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= max(encoder_num_to_predict, decoder_num_to_predict):
                break
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(masked_lms) + len(index_set) > max(encoder_num_to_predict, decoder_num_to_predict):
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue

            # A whole word goes into a view's mask only if it still fits that view's budget.
            encoder_add = len(encoder_masked_lms) + len(index_set) <= encoder_num_to_predict
            decoder_add = len(decoder_masked_lms) + len(index_set) <= decoder_num_to_predict

            for index in index_set:
                covered_indexes.add(index)
                if encoder_add and decoder_add:
                    encoder_masked_lms.append(index)
                if decoder_add:
                    decoder_masked_lms.append(index)
                masked_lms.append(index)

        assert len(covered_indexes) == len(masked_lms)
        encoder_mask_labels = [1 if i in encoder_masked_lms else 0 for i in range(len(input_tokens))]
        decoder_mask_labels = [1 if i in decoder_masked_lms else 0 for i in range(len(input_tokens))]
        return encoder_mask_labels, decoder_mask_labels
    def __call__(self, examples, return_tensors=None):
        # Pad the raw examples once to obtain padded input_ids and the attention mask.
        batch = self.tokenizer.pad(
            examples,
            padding=True,
            return_attention_mask=True,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors='pt',
        )

        input_ids = [e["input_ids"] for e in examples]

        encoder_mlm_masks = []
        decoder_mlm_masks = []

        # Build per-example whole-word mask labels from the (unpadded) token strings.
        for e in input_ids:
            tokens = []
            for tid in e:
                tokens.append(self.tokenizer._convert_id_to_token(tid))

            encoder_mlm_mask, decoder_mlm_mask = self._whole_word_mask(tokens, self.tokenizer.model_max_length)
            encoder_mlm_masks.append(encoder_mlm_mask)
            decoder_mlm_masks.append(decoder_mlm_mask)

        # Collate the label lists into padded tensors (padding value is the tokenizer's pad id, 0 for BERT).
        encoder_mlm_masks = _torch_collate_batch(encoder_mlm_masks, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        decoder_mlm_masks = _torch_collate_batch(decoder_mlm_masks, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

        # Apply the two masks to independent copies of the inputs: a lightly masked view
        # for the encoder and a more heavily masked view for the decoder.
        encoder_inputs, encoder_labels = self.torch_mask_tokens(
            batch['input_ids'].clone(),
            encoder_mlm_masks.clone()
        )

        decoder_inputs, decoder_labels = self.torch_mask_tokens(
            batch['input_ids'].clone(),
            decoder_mlm_masks.clone()
        )

        output = {
            "input_ids": encoder_inputs,
            "labels": encoder_labels,
            "decoder_input_ids": decoder_inputs,
            "decoder_labels": decoder_labels,
            "attention_mask": batch['attention_mask'],
        }

        return output
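Putting the pieces together, a minimal end-to-end sketch of how the dataset and MAE collator above might be wired into a DataLoader. The tokenizer checkpoint and the mapping of the README's `--mlm_probability` / `--decoder_mlm_probability` flags onto the collator fields are assumptions, since the training script itself is not shown here:

```python
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")  # assumed checkpoint
dataset = MLMDataset(tokenizer, num_processes=16)
collator = MaeDataCollatorForWholeWordMask(
    tokenizer=tokenizer,
    encoder_mlm_probability=0.2,   # presumably --mlm_probability in the README command
    decoder_mlm_probability=0.4,   # presumably --decoder_mlm_probability
)
loader = DataLoader(dataset, batch_size=32, collate_fn=collator, num_workers=16)

batch = next(iter(loader))
# Encoder view, decoder view, and the shared attention mask:
# dict_keys(['input_ids', 'labels', 'decoder_input_ids', 'decoder_labels', 'attention_mask'])
print(batch.keys())
```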