Research diary: reversibles #11
Research setting:
Limitations:
Results: tl;dr, Reformer is still the best.
Validation loss closely tracks the training loss. The gradient norm supports the hypothesis of numerical instability in momentum-reversible models.

Model config:

```json
{
"model_type": "lean_gpt",
"architectures": [ "LeanGPTModel" ],
"num_hidden_layers": 24,
"num_hidden_groups": 24,
"num_inner_groups": 1,
"hidden_size": 1024,
"embedding_size": 1024,
"intermediate_size": 4096,
"num_attention_heads": 16,
"vocab_size": 50308,
"hidden_act": "gelu_fused",
"position_embedding_type": "rotary",
"tie_word_embeddings": false,
"reversible": YOUR_OPTION_HERE,
"hidden_dropout_prob": 0,
"attention_probs_dropout_prob": 0,
"layer_norm_eps": 1e-12,
"pad_token_id": 1,
"bos_token_id": 0,
"eos_token_id": 2
}
```

Training code: using a slightly modified fairseq fork: https://github.com/justheuristic/junk/tree/fairseq (commit 9f5bff306ee93780e0c9483162dd24e244403919). The only difference is that it supports training with LeanTransformer; it was validated to match the learning curve of the stock fairseq transformer.

```bash
PYTHONPATH=`pwd`:$PYTHONPATH python fairseq_cli/train.py \
$INPUT_PATH/data-bin/openwebtext --task language_modeling --arch lean_lm --hf-model-config $SOURCE_CODE_PATH/model_config.json \
--max-tokens 32768 --update-freq 4 --max-update 50000 --tokens-per-sample 2048 --sample-break-mode none \
--ddp-backend pytorch_ddp --distributed-world-size $NUM_GPUS --seed 4 \
--amp --fp16-no-flatten-grads --min-loss-scale 1e-10 --fp16-scale-window 250 \
--lr-scheduler cosine --lr 0.0003 --warmup-init-lr 0.0 --warmup-updates 5000 \
--optimizer adam --weight-decay 0.1 --clip-norm 1.0 --adam-betas "(0.9, 0.95)" --adam-eps 1e-08 \
--save-dir $SNAPSHOT_PATH --save-interval-updates 1000 --keep-best-checkpoints 1 --no-epoch-checkpoints --keep-interval-updates 2 \
--valid-subset valid,valid_1b,valid_lambada,valid_ccnews,valid_wiki,valid_wiki2,valid_ptb --validate-interval-updates 1000 \
--log-format simple --log-interval 50 --wandb-project $WANDB_PROJECT
```
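A quick sanity check on the effective batch size these flags imply (assuming standard fairseq semantics, where `--max-tokens` is per GPU and `--update-freq` accumulates gradients):

```python
# Illustrative arithmetic, not part of the training script.
max_tokens, update_freq, tokens_per_sample = 32768, 4, 2048
tokens_per_gpu_update = max_tokens * update_freq                       # 131072 tokens
sequences_per_gpu_update = tokens_per_gpu_update // tokens_per_sample  # 64 sequences
# multiply by $NUM_GPUS for the global batch per optimizer step
```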
Libraries & versions:

```bash
#!/usr/bin/env bash
set -euxo pipefail
############################################################################
# core libraries
############################################################################
apt-get update --allow-unauthenticated --allow-insecure-repositories
apt-get install -y --no-install-recommends \
build-essential \
g++ gdb subversion \
software-properties-common
apt-get install -y --no-install-recommends \
wget curl vim nano ssh git libssl-dev
apt-get remove -y swig || true
apt-get install -y --no-install-recommends libstdc++6
apt-get install -y --no-install-recommends swig3.0
ln -s /usr/bin/swig3.0 /usr/bin/swig
############################################################################
# install anaconda (because native python stopped working)
############################################################################
wget https://repo.anaconda.com/archive/Anaconda3-2021.11-Linux-x86_64.sh
bash Anaconda3-2021.11-Linux-x86_64.sh -b -p /anaconda3
source /anaconda3/bin/activate
############################################################################
# common python libraries (project-specific libs are installed later)
############################################################################
conda update -y conda
conda install -y python=3.8.12 --strict-channel-priority
conda install -y numpy scipy cython pandas h5py numba
pip install --upgrade setuptools
# common + devops
pip install \
PyYAML==5.4.1 \
Pillow==8.3.0 \
docopt==0.6.2 \
typer==0.3.2 \
black==21.6b0 \
bokeh==2.4.0dev1 \
isort==5.9.1 \
icecream==2.1.1 \
flake8==3.9.2 \
uvloop==0.15.2 \
packaging==19.0 \
msgpack==0.5.6 \
sortedcontainers==2.4.0 \
configargparse==1.2.3 \
tqdm==4.48.2 \
termcolor==1.0.0
# common data science libs
pip install \
ninja==1.10.0.post1 \
tensorboardX==2.4 \
wandb==0.10.33 \
matplotlib==3.4.2 \
seaborn==0.11.1 \
holoviews==1.14.4 \
plotly==5.1.0 \
jupyterlab==3.0.16
# pytorch utils
conda install -y cudatoolkit=11.3 -c pytorch
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install https://github.com/huggingface/transformers/archive/3dc82427166239e2764196c07fa4c5dcc25b1590.zip # 4.18.dev0
pip install datasets==2.0.0
pip install \
torch_optimizer==0.1.0 \
revlib==1.7.0 \
bitsandbytes-cuda113==0.26.0 \
pytorch-lightning==1.3.8 \
triton==1.0.0 \
einops==0.3.2 \
libzero==0.0.5
# domain-specific ML libs
pip install \
opencv-python==4.4.0.42 \
albumentations==1.0.0 \
scikit-image==0.17.2 \
lmdb==1.2.1 \
librosa==0.7.0 \
sentencepiece==0.1.96 \
nltk==3.6.2 \
gensim==4.0.1 \
sacrebleu==1.5.1 \
sacremoses==0.0.45 \
subword-nmt==0.3.7 \
youtokentome==1.0.6
pip uninstall -y enum34
############################################################################
# Set locale
############################################################################
locale-gen ru_RU.UTF-8
update-locale
############################################################################
# Clean
############################################################################
apt-get autoremove
apt-get clean
apt-get autoclean
rm -rf /var/lib/apt/lists/*
rm -rf /tmp/*
rm -rf /.cache
rm -rf /var/cache/apt/*.bin
find /var/log -iname '*.gz' -delete
find /var/log -iname '*.1' -delete
###########################################################################
# project-specific libraries (aka YOUR CODE HERE)
###########################################################################
# hivemind dependencies
# note: quote the specifiers so bash does not parse ">" as a redirect
pip install \
"prefetch_generator>=1.0.1" \
"grpcio>=1.33.2" \
"grpcio-tools>=1.33.2" \
"multiaddr>=0.0.9" \
"pymultihash>=0.8.2" \
"cryptography>=3.4.6" \
"pydantic>=1.8.1" \
whatsmyip
pip install razdel==0.5.0
# golang
wget https://golang.org/dl/go1.16.4.linux-amd64.tar.gz
rm -rf /usr/local/go && tar -C /usr/local -xzf go1.16.4.linux-amd64.tar.gz
export PATH=$PATH:/usr/local/go/bin
pip install omegaconf==2.0.5 antlr4-python3-runtime==4.8 hydra-core==1.0.7
```
This issue may or may not contain my notes on implementing and training reversible models.
From correspondence with @TimDettmers, @yhn112, @borzunov, @mryab.
Why? Reversible models are one of the few ways to fit large transformers onto low-end DL rigs (single GPU, 8-12 GB VRAM, 16-32 GB RAM). The alternatives are nested checkpointing [slower], checkpoint quantization [may affect convergence, and still uses more memory], or checkpoint offloading [RAM is already used up by master params and optimizer state]. Hence, reversibles are the most memory-efficient training strategy, provided they can match the quality of regular models.
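A rough back-of-the-envelope for the config used in these experiments (24 layers, hidden size 1024, 32768 tokens per GPU) shows why: even keeping one fp16 residual tensor per layer alive for the backward pass costs about 1.5 GiB, and the real activation footprint is several times larger once attention and FFN intermediates are counted, while a reversible model keeps this cost independent of depth. Illustrative numbers only:

```python
# Back-of-the-envelope activation memory for checkpoint-free training, fp16.
layers, tokens, hidden, fp16_bytes = 24, 32768, 1024, 2
per_layer = tokens * hidden * fp16_bytes   # one residual tensor: 64 MiB
total = layers * per_layer                 # kept alive until backward
print(f"{total / 2**30:.1f} GiB")          # 1.5 GiB, before attn/FFN intermediates
```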
As of c076ba7, we support reversible=True, which enables the reversible transformer as defined in Reformer. However, this is not the only possible design. Existing alternatives (a minimal sketch follows the list):
Running both attention and FFN in each branch (the default Reformer has only attention in the F branch and only the FFN in the G branch)
MomentumNet: https://arxiv.org/abs/1907.11818
Adjusted MomentumNet from homebrewnlp - see Lucas' notes from revlib - ported to LeanTransformer
Simple running average: code in LeanTransformer
source: running average coupling
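For concreteness, here is a minimal sketch of the two main couplings in plain PyTorch (illustrative pseudocode, not the LeanTransformer or revlib implementation). Note that the exact inverse of the momentum coupling divides by beta at every layer, which is one plausible mechanism for the fp16 instability mentioned above:

```python
import torch

def reformer_step(x1, x2, attn, ffn):
    # Reformer-style additive coupling: attention in one branch, FFN in the other.
    y1 = x1 + attn(x2)
    y2 = x2 + ffn(y1)
    return y1, y2

def momentum_step(x, v, f, beta=0.9):
    # MomentumNet-style coupling: the second stream is a momentum of f(x).
    v = beta * v + (1 - beta) * f(x)
    return x + v, v

def momentum_step_inverse(x_next, v_next, f, beta=0.9):
    # Exact inverse used to reconstruct activations in the backward pass;
    # the division by beta amplifies rounding error exponentially with depth.
    x = x_next - v_next
    v = (v_next - (1 - beta) * f(x)) / beta
    return x, v
```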
Furthermore, it is unclear how best to use a reversible model's two inputs and two outputs with transformers. If we cannot afford to double the layer sizes, this raises several open questions; the two standard options from the literature are sketched below.
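Assuming RevNet-style channel splitting and Reformer-style input duplication as the two baseline choices (shapes below are hypothetical):

```python
import torch

hidden = torch.randn(4, 2048, 1024)  # (batch, seq, hidden)

# Option A (RevNet-style): split the hidden dimension across the two streams,
# so each branch runs at half width and total activation size stays constant.
x1, x2 = hidden.chunk(2, dim=-1)

# Option B (Reformer-style): duplicate the input into both streams.
x1, x2 = hidden, hidden
# ... reversible layers ...
# Merging the outputs is a further design choice: Reformer concatenates
# the two streams; averaging them is a plausible alternative.
out = (x1 + x2) / 2
```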
Finally, there are implementation hacks that can affect training throughput and memory requirements; the core memory-saving trick is sketched below.
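For example, the input-recomputation trick can be written as a custom autograd function that stores only the block outputs and inverts the coupling during backward. A minimal sketch under those assumptions (ReversibleBlockFn is a hypothetical name, not the LeanTransformer or revlib code):

```python
import torch

class ReversibleBlockFn(torch.autograd.Function):
    """y1 = x1 + f(x2); y2 = x2 + g(y1); inputs are rebuilt in backward."""

    @staticmethod
    def forward(ctx, x1, x2, f, g):
        ctx.f, ctx.g = f, g
        with torch.no_grad():              # build no graph in forward
            y1 = x1 + f(x2)
            y2 = x2 + g(y1)
        ctx.save_for_backward(y1, y2)      # only the outputs are kept
        return y1, y2

    @staticmethod
    def backward(ctx, dy1, dy2):
        y1, y2 = ctx.saved_tensors
        # Recompute g(y1) with grad enabled; backward() fills g's param grads.
        with torch.enable_grad():
            y1 = y1.detach().requires_grad_(True)
            g_out = ctx.g(y1)
        torch.autograd.backward(g_out, dy2)
        x2 = (y2 - g_out).detach()         # invert the coupling: x2 = y2 - g(y1)
        dx1 = dy1 + y1.grad                # total gradient w.r.t. y1 (== w.r.t. x1)
        # Same for f: recompute, backprop, and (implicitly) x1 = y1 - f(x2).
        with torch.enable_grad():
            x2 = x2.requires_grad_(True)
            f_out = ctx.f(x2)
        torch.autograd.backward(f_out, dx1)
        dx2 = dy2 + x2.grad
        return dx1, dx2, None, None
```

Usage would look like `y1, y2 = ReversibleBlockFn.apply(x1, x2, attn_layer, ffn_layer)`; the two `None`s in backward are the gradient slots for the non-tensor arguments `f` and `g`.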