Hello. Thanks for your work.
What are the minimum requirements for a GPU to run your solution on the test data (deepspeed_xl_runormas_v14_130k_finetune.sh)?
I'm trying to use this .sh file with a few changes to the paths:
```
NUM_GPUS_PER_WORKER=0
gpt_options=" \
--train-data-path /home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/data/public_test_v13/files.list \
--max-files-per-process 10 \
--logging-dir=/mnt/de420497-d399-4936-af66-6bddbd9bfd33/Data/aspect_miner/demo/ \
--load /mnt/de420497-d399-4936-af66-6bddbd9bfd33/Data/aspect_miner/mp_rank_00_model_states.pt \
--save /mnt/de420497-d399-4936-af66-6bddbd9bfd33/Data/aspect_miner/demo/ \
--tokenizer-path sberbank-ai/rugpt3xl \
--no-load-optim \
--finetune \
--cache-prefix p16 \
--save-interval 5000 \
--log-interval 100 \
--model-parallel-size 1 \
--num-layers 24 \
--hidden-size 2048 \
--num-attention-heads 16 \
--batch-size 1 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--train-iters 20000 \
--distributed-backend nccl \
--lr 0.000015 \
--warmup 0.0 \
--lr-decay-style constant \
--weight-decay 1e-2 \
--fp16 \
--local_rank=0 \
--sparse-mode alternating \
--checkpoint-activations \
--deepspeed-activation-checkpointing \
--deepspeed \
--deepspeed_config ../src/deepspeed_config/gpt3_xl_sparse_2048.json \
"
run_cmd="USE_DEEPSPEED=1 python3.7 ../pretrain_gpt3.py $@ ${gpt_options}"
echo ${run_cmd}
eval ${run_cmd}
set +x
```
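For reference, here is my rough estimate of the model size these flags imply (a back-of-envelope sketch in Python; the 50264 vocab size is the padded value printed in the log below):

```python
# Rough estimate of the model size implied by the flags above
# (GPT-2-style parameter counting; the exact count is printed by pretrain_gpt3.py below).
num_layers = 24
hidden = 2048
vocab = 50264      # padded vocab size from the log below
positions = 2048   # --max-position-embeddings

transformer = 12 * num_layers * hidden ** 2   # attention + MLP weight matrices
embeddings = (vocab + positions) * hidden     # token + position embeddings
total = transformer + embeddings              # ~1.3e9 parameters

print(f"~{total / 1e9:.2f}B params, ~{total * 2 / 2**30:.2f} GiB just for fp16 weights")
```

So the fp16 weights alone are roughly 2.5 GiB before any optimizer state or activations.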
When I run it, I get this error:
```
$ ./deepspeed_xl_runormas_v14_130k_finetune.sh
USE_DEEPSPEED=1 python3.7 ../pretrain_gpt3.py --train-data-path /home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/data/public_test_v13/files.list --max-files-per-process 10 --logging-dir=/mnt/de420497-d399-4936-af66-6bddbd9bfd33/Data/aspect_miner/demo/ --load /mnt/de420497-d399-4936-af66-6bddbd9bfd33/Data/aspect_miner/mp_rank_00_model_states.pt --save /mnt/de420497-d399-4936-af66-6bddbd9bfd33/Data/aspect_miner/demo/ --tokenizer-path sberbank-ai/rugpt3xl --no-load-optim --finetune --cache-prefix p16 --save-interval 5000 --log-interval 100 --model-parallel-size 1 --num-layers 24 --hidden-size 2048 --num-attention-heads 16 --batch-size 1 --seq-length 2048 --max-position-embeddings 2048 --train-iters 20000 --distributed-backend nccl --lr 0.000015 --warmup 0.0 --lr-decay-style constant --weight-decay 1e-2 --fp16 --local_rank=0 --sparse-mode alternating --checkpoint-activations --deepspeed-activation-checkpointing --deepspeed --deepspeed_config ../src/deepspeed_config/gpt3_xl_sparse_2048.json
using world size: 1 and model-parallel size: 1
> using dynamic loss scaling
> initializing model parallel with size 1
[2021-05-29 00:53:48,161] [INFO] [checkpointing.py:629:_configure_using_config_file] {'partition_activations': False, 'contiguous_memory_optimization': False, 'cpu_checkpointing': False, 'number_checkpoints': None, 'synchronize_checkpoint_boundary': False, 'profile': False}
Pretrain GPT3 model
arguments:
attention_dropout ............ 0.1
num_attention_heads .......... 16
hidden_size .................. 2048
intermediate_size ............ None
num_layers ................... 24
layernorm_epsilon ............ 1e-05
hidden_dropout ............... 0.1
max_position_embeddings ...... 2048
vocab_size ................... 30522
deep_init .................... False
make_vocab_size_divisible_by . 8
cpu_optimizer ................ False
cpu_torch_adam ............... False
sparse_mode .................. alternating
fp16 ......................... True
fp32_embedding ............... False
fp32_layernorm ............... False
fp32_tokentypes .............. False
fp32_allreduce ............... False
hysteresis ................... 2
loss_scale ................... None
loss_scale_window ............ 1000
min_scale .................... 1
batch_size ................... 1
weight_decay ................. 0.01
checkpoint_activations ....... True
checkpoint_num_layers ........ 1
deepspeed_activation_checkpointing True
clip_grad .................... 1.0
train_iters .................. 20000
log_interval ................. 100
logging_dir .................. /mnt/de420497-d399-4936-af66-6bddbd9bfd33/Data/aspect_miner/demo/
exit_interval ................ None
seed ......................... 1234
reset_position_ids ........... False
reset_attention_mask ......... False
lr_decay_iters ............... None
lr_decay_style ............... constant
lr ........................... 1.5e-05
min_lr ....................... 1e-06
warmup ....................... 0.0
save ......................... /mnt/de420497-d399-4936-af66-6bddbd9bfd33/Data/aspect_miner/demo/
save_interval ................ 5000
no_save_optim ................ False
no_save_rng .................. False
load ......................... /mnt/de420497-d399-4936-af66-6bddbd9bfd33/Data/aspect_miner/mp_rank_00_model_states.pt
no_load_optim ................ True
log_memory ................... False
no_load_rng .................. False
load_huggingface ............. None
export_huggingface ........... None
huggingface_double_pos_embeddings False
load_tag .....................
cache_prefix ................. p16
finetune ..................... True
resume_dataloader ............ False
distributed_backend .......... nccl
local_rank ................... 0
master_port .................. 6000
eval_batch_size .............. None
eval_iters ................... 100
eval_interval ................ 1000
eval_seq_length .............. None
eval_max_preds_per_seq ....... None
overlapping_eval ............. 32
cloze_eval ................... False
eval_hf ...................... False
load_openai .................. False
temperature .................. 1.0
top_p ........................ 0.0
top_k ........................ 0
out_seq_length ............... 256
tg_token_name ................ token.txt
model_parallel_size .......... 1
shuffle ...................... False
train_data ................... None
use_npy_data_loader .......... False
train_data_path .............. /home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/data/public_test_v13/files.list
val_data_path ................
test_data_path ...............
input_data_sizes_file ........ sizes.txt
delim ........................ ,
text_key ..................... sentence
eval_text_key ................ None
valid_data ................... None
split ........................ 1000,1,1
test_data .................... None
overwrite_cache .............. False
lazy_loader .................. False
loose_json ................... False
presplit_sentences ........... False
num_workers .................. 2
tokenizer_path ............... sberbank-ai/rugpt3xl
cache_dir .................... None
use_tfrecords ................ False
seq_length ................... 2048
max_files_per_process ........ 10
max_preds_per_seq ............ None
path ......................... None
tokenizer_name ............... sberbank-ai/rugpt3xl
output_dir ................... None
part ......................... train
train_part_name .............. train
data_parts ................... generic,named
answer_sep ................... <answer>
window_size .................. 0
start_sep .................... <start>
end_sep ...................... <end>
save_preds_path .............. ../test_pred/
do_sample .................... None
num_beams .................... None
loss_only_norm ............... False
line_by_line ................. False
deepspeed .................... True
deepspeed_config ............. ../src/deepspeed_config/gpt3_xl_sparse_2048.json
deepscale .................... False
deepscale_config ............. None
deepspeed_mpi ................ False
cuda ......................... True
rank ......................... 0
world_size ................... 1
dynamic_loss_scale ........... True
[2021-05-29 00:53:48,161] [INFO] [checkpointing.py:256:model_parallel_cuda_manual_seed] > initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
Load tokenizer from sberbank-ai/rugpt3xl
Add answer_sep: <answer>
Add start_sep <start>
Add start_sep <end>
Load RuGPT3 Dataset from /home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/data/public_test_v13/files.list, 10 files per process
R0/1: Loading dataset /home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/data/public_test_v13/files.list
R0/1: Check filelist /home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/data/public_test_v13/files.list with root dir /home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/data/public_test_v13
Shard size 545 > max_file_load 10, only first 10 files of dataset would be loaded!
R0/1: Shard [0, 10]
R0/1: Loaded 0/10 files
R0/1: Loaded 78 examples, 159744 tokens
> padded vocab (size: 50260) with 4 dummy tokens (new size: 50264)
> end-of-document token: 0
building GPT3 model ...
Use sparse attention with mode alternating
Use alternating sparse & dense attention layers
> number of parameters on model parallel rank 0: 1315737600
Optimizer = FusedAdam
learning rate decaying constant
DeepSpeed is enabled.
[2021-05-29 00:54:03,370] [INFO] [logging.py:60:log_dist] [Rank 0] DeepSpeed info: version=0.3.7, git-hash=unknown, git-branch=unknown
[2021-05-29 00:54:03,380] [INFO] [engine.py:588:_configure_optimizer] Using client Optimizer as basic optimizer
[2021-05-29 00:54:03,380] [INFO] [engine.py:597:_configure_optimizer] DeepSpeed Basic Optimizer = FusedAdam (
Parameter Group 0
betas: (0.9, 0.999)
bias_correction: True
eps: 1e-08
lr: 1.5e-05
weight_decay: 0.01
Parameter Group 1
betas: (0.9, 0.999)
bias_correction: True
eps: 1e-08
lr: 1.5e-05
weight_decay: 0.0
)
Checking ZeRO support for optimizer=FusedAdam type=<class 'apex.optimizers.fused_adam.FusedAdam'>
[2021-05-29 00:54:03,380] [INFO] [engine.py:715:_configure_zero_optimizer] Creating fp16 ZeRO stage 2 optimizer
Using /home/mdomrachev/.cache/torch_extensions as PyTorch extensions root...
Emitting ninja build file /home/mdomrachev/.cache/torch_extensions/utils/build.ninja...
Building extension module utils...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module utils...
Time to load utils op: 0.23281097412109375 seconds
[2021-05-29 00:54:03,924] [INFO] [stage2.py:130:__init__] Reduce bucket size 50000000
[2021-05-29 00:54:03,924] [INFO] [stage2.py:131:__init__] Allgather bucket size 500000000
[2021-05-29 00:54:03,924] [INFO] [stage2.py:132:__init__] CPU Offload: False
Traceback (most recent call last):
File "../pretrain_gpt3.py", line 834, in <module>
main()
File "../pretrain_gpt3.py", line 790, in main
model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
File "../pretrain_gpt3.py", line 190, in setup_model_and_optimizer
dist_init_required=False
File "/home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/gpt/lib/python3.7/site-packages/deepspeed/__init__.py", line 118, in initialize
config_params=config_params)
File "/home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/gpt/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 181, in __init__
self._configure_optimizer(optimizer, model_parameters)
File "/home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/gpt/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 609, in _configure_optimizer
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
File "/home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/gpt/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 749, in _configure_zero_optimizer
gradient_accumulation_steps=self.gradient_accumulation_steps())
File "/home/mdomrachev/aspect-extractor/entities_normalizer/RuNormAS-solution/gpt/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 264, in __init__
self.device).clone().float().detach())
RuntimeError: CUDA out of memory. Tried to allocate 4.90 GiB (GPU 0; 10.91 GiB total capacity; 4.90 GiB already allocated; 4.85 GiB free; 4.90 GiB reserved in total by PyTorch)
```

My environment:
- GPU: 1080 Ti
- triton==0.2.2
- deepspeed==0.3.7
- transformers==3.5.0
- torch==1.7.1
However, if I run this code, it works fine:
Are there any parameters in the script that need to be changed?
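For what it's worth, the failing allocation looks like ZeRO stage 2 building an fp32 master copy of all parameters on a single GPU. Here is my rough arithmetic (just a sketch; the Adam state sizes are my assumption about what FusedAdam keeps in fp32):

```python
# Rough arithmetic for the failing allocation on a single ~11 GiB 1080 Ti.
params = 1_315_737_600          # "number of parameters on model parallel rank 0"
gib = 2 ** 30

fp16_weights = params * 2 / gib   # ~2.45 GiB for the fp16 model weights
fp32_master = params * 4 / gib    # ~4.90 GiB, matches "Tried to allocate 4.90 GiB"
adam_states = params * 8 / gib    # ~9.80 GiB for fp32 momentum + variance (assumed)

print(f"fp16 weights ~{fp16_weights:.2f} GiB, fp32 master copy ~{fp32_master:.2f} GiB, "
      f"Adam states ~{adam_states:.2f} GiB")
```

If that is right, the optimizer state alone would not fit in 11 GiB even before activations. Is the intended way to run this on a single 1080 Ti to enable ZeRO CPU offload (the log shows "CPU Offload: False"), for example by setting `"cpu_offload": true` in the `zero_optimization` section of gpt3_xl_sparse_2048.json (if I read the DeepSpeed 0.3.7 config schema correctly), or does the script simply assume a larger GPU?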