run_pretrain_bart.sh returns IndexError #82

Open
shivanraptor opened this issue Aug 30, 2024 · 13 comments

Comments

@shivanraptor

Here is the stack trace of the run_pretrain_bart.sh error:

[rank0]: IndexError: Caught IndexError in DataLoader worker process 0.
[rank0]: Original Traceback (most recent call last):
[rank0]:   File "/home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
[rank0]:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
[rank0]:   File "/home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
[rank0]:     data = [self.dataset[idx] for idx in possibly_batched_index]
[rank0]:   File "/home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
[rank0]:     data = [self.dataset[idx] for idx in possibly_batched_index]
[rank0]:   File "/home/jupyter-raptor/pretrain_tokenizer/megatron/data/blendable_dataset.py", line 83, in __getitem__
[rank0]:     return self.datasets[dataset_idx][sample_idx]
[rank0]:   File "/home/jupyter-raptor/pretrain_tokenizer/megatron/data/bart_dataset.py", line 106, in __getitem__
[rank0]:     return self.build_training_sample(sample, self.max_seq_length, np_rng)
[rank0]:   File "/home/jupyter-raptor/pretrain_tokenizer/megatron/data/bart_dataset.py", line 148, in build_training_sample
[rank0]:     source = self.add_whole_word_mask(source, mask_ratio, replace_length)
[rank0]:   File "/home/jupyter-raptor/pretrain_tokenizer/megatron/data/bart_dataset.py", line 360, in add_whole_word_mask
[rank0]:     source[indices[mask_random]] = torch.randint(
[rank0]: IndexError: The shape of the mask [2] at index 0 does not match the shape of the indexed tensor [1] at index 0

How can I debug it? It seems the dataset is problematic, but the dataset was generated by the Megatron pre-processing scripts in the tools/ folder, and they run without error. The .bin and .idx files are generated properly (and I put them in the dataset/ folder).

@choosewhatulike
Member

Hi,
It seems something is wrong when loading the training data. Could you provide your training script and the structure of the dataset folder?

@shivanraptor
Author

shivanraptor commented Aug 30, 2024

Thank you for the prompt reply.

The training command is just ./run_pretrain_bart.sh. I edited the .sh file beforehand to remove the distributed-node mechanism, so the file now reads:

#!/bin/bash

DATA_PATH="dataset/"
CHECKPOINT_PATH=checkpoints/bart-base
VOCAB_FILE=vocab/

python \
       pretrain_bart.py \
       --num-layers 12 \
       --hidden-size 768 \
       --num-attention-heads 12 \
       --micro-batch-size 16 \
       --global-batch-size 256 \
       --seq-length 512 \
       --max-position-embeddings 512 \
       --mask-prob 0.15 \
       --train-iters 100000 \
       --lr-decay-iters 100000 \
       --save $CHECKPOINT_PATH \
       --load $CHECKPOINT_PATH \
       --data-path $DATA_PATH \
       --vocab-file $VOCAB_FILE \
       --data-impl mmap \
       --split 949,30,1 \
       --distributed-backend nccl \
       --lr 1e-4 \
       --lr-decay-style cosine \
       --min-lr 1e-6 \
       --initial-loss-scale 65536 \
       --weight-decay 1e-2 \
       --clip-grad 1.0 \
       --lr-warmup-fraction .01 \
       --log-interval 1 \
       --save-interval 1600 \
       --eval-interval 500 \
       --eval-iters 10 \
       --fp16 \
       --optimizer adam \
       --num-workers 2 \
       # --checkpoint-activations

The Megatron data pre-processing command is:

python tools/preprocess_data.py --input ../data/megatron_3m.json --output-prefix cantonese-bert --vocab ../bert-tokenizer-cantonese/vocab-bart-base-cantonese.txt --dataset-impl mmap --tokenizer-type BertWordPieceLowerCase --split-sentences

Sample contents of vocab-bart-base-cantonese.txt are as follows:

巴塞
巴士
巴士佬
巴士司機
巴士站
巴士線
巴巴
巴打
巴拿馬
巴林
巴爾
巴絲
巴膠
巴西
巴閉
巴黎
市中心
市值
市價
市區
市場
市政
市民
市況
市長
市集
市面

Please note that I intentionally use vocabulary entries longer than one character, which might be the cause of the problem.

The structure of the dataset/ folder is:

cantonese-bert_text_sentence_0.bin
cantonese-bert_text_sentence_0.idx
cantonese-bert_text_sentence_0_test_indexmap_2573mns_510msl_0.00ssp_1234s.npy
cantonese-bert_text_sentence_0_train_indexmap_25728000mns_510msl_0.00ssp_1234s.npy
cantonese-bert_text_sentence_0_valid_indexmap_517133mns_510msl_0.00ssp_1234s.npy

I checked the content of each .npy file using np.load(filename); it shows:

[[34753249 34753250      510]
 [34773514 34773515      510]
 [34749867 34749868      510]
 ...
 [34774749 34774752      510]
 [34757999 34758000      510]
 [34748040 34748041      510]]

The vocab/ folder structure is:

added_tokens.json
config.json
generation_config.json
pytorch_model.bin
special_tokens_map.json
tokenizer_config.json
vocab.txt

The roberta_zh/ folder structure is:

config.json
pytorch_model.bin

Hope the above information helps.

@choosewhatulike
Member

choosewhatulike commented Aug 30, 2024

I noticed that you have a custom tokenizer. Have you changed the vocab_size in the model config and the tokenizer config accordingly? You can try running the pre-training on the original Chinese BART model to see if it still works.
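
For a quick consistency check, something along these lines could be run (an untested sketch; "path/to/custom-bart" is a placeholder for wherever the modified tokenizer and model are saved):

from transformers import BertTokenizer, BartForConditionalGeneration

# "path/to/custom-bart" is a placeholder for the directory holding the custom tokenizer/model
tokenizer = BertTokenizer.from_pretrained("path/to/custom-bart")
model = BartForConditionalGeneration.from_pretrained("path/to/custom-bart")

print("tokenizer size:", len(tokenizer))
print("config vocab_size:", model.config.vocab_size)
print("embedding rows:", model.get_input_embeddings().num_embeddings)
# All three numbers should agree; a mismatch usually points at the custom tokenizer.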

@shivanraptor
Author

Oh wow, how did you notice that I use a custom tokenizer?

I use the following code to add tokens:

from transformers import BertTokenizer, BartForConditionalGeneration

tokenizer = BertTokenizer.from_pretrained("fnlp/bart-base-chinese")
model = BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese")

with open('data/wordsegs_uniq.txt', 'r') as f:
    tokenizer.add_tokens(f.readlines())

model.resize_token_embeddings(len(tokenizer))

# Save to local and HF
model.save_pretrained("wordseg_tokenizer")
tokenizer.save_pretrained("wordseg_tokenizer")
model.push_to_hub("raptorkwok/bart-base-cantonese")
tokenizer.push_to_hub("raptorkwok/bart-base-cantonese")

I think the .resize_token_embeddings() function already changed the vocab_size, and I verified this by checking the config.json at https://huggingface.co/raptorkwok/bart-base-cantonese/blob/main/config.json. It looks correct to me.

You mentioned the tokenizer config, but tokenizer_config.json does not specify the vocabulary size. You might want to browse the content here.

@shivanraptor
Author

The full output is here:

$ ./run_pretrain_bart.sh
[2024-08-30 20:11:19,870] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
2024-08-30 20:11:21.052166: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-30 20:11:21.072135: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-30 20:11:21.078269: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-30 20:11:21.833711: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
using world size: 1, data-parallel-size: 1, tensor-model-parallel size: 1, pipeline-model-parallel size: 1 
using torch.float16 for parameters ...
------------------------ arguments ------------------------
  accumulate_allreduce_grads_in_fp32 .............. False
  adam_beta1 ...................................... 0.9
  adam_beta2 ...................................... 0.98
  adam_eps ........................................ 1e-06
  adlr_autoresume ................................. False
  adlr_autoresume_interval ........................ 1000
  apply_query_key_layer_scaling ................... True
  apply_residual_connection_post_layernorm ........ False
  attention_dropout ............................... 0.1
  attention_softmax_in_fp32 ....................... False
  bert_binary_head ................................ True
  bert_load ....................................... None
  bf16 ............................................ False
  bias_dropout_fusion ............................. True
  bias_gelu_fusion ................................ True
  biencoder_projection_dim ........................ 0
  biencoder_shared_query_context_model ............ False
  block_data_path ................................. None
  checkpoint_activations .......................... False
  checkpoint_num_layers ........................... 1
  clip_grad ....................................... 1.0
  consumed_train_samples .......................... 0
  consumed_valid_samples .......................... 0
  data_impl ....................................... mmap
  data_parallel_size .............................. 1
  data_path ....................................... ['1', 'dataset/cantonese-bert_text_sentence_0']
  dataloader_type ................................. single
  DDP_impl ........................................ local
  decoder_seq_length .............................. None
  deepscale ....................................... False
  deepscale_config ................................ None
  deepspeed ....................................... False
  deepspeed_config ................................ None
  dim_target_kl ................................... 1.0
  distribute_checkpointed_activations ............. False
  distributed_backend ............................. nccl
  embedding_path .................................. None
  encoder_seq_length .............................. 512
  eod_mask_loss ................................... False
  eval_interval ................................... 500
  eval_iters ...................................... 10
  evidence_data_path .............................. None
  exit_duration_in_mins ........................... None
  exit_interval ................................... None
  ffn_hidden_size ................................. 3072
  finetune ........................................ False
  fp16 ............................................ True
  fp16_lm_cross_entropy ........................... False
  fp32_residual_connection ........................ False
  global_batch_size ............................... 256
  hidden_dropout .................................. 0.1
  hidden_size ..................................... 768
  hysteresis ...................................... 2
  ict_head_size ................................... None
  ict_load ........................................ None
  img_dim ......................................... 224
  indexer_batch_size .............................. 128
  indexer_log_interval ............................ 1000
  init_method_std ................................. 0.02
  init_method_xavier_uniform ...................... False
  initial_loss_scale .............................. 65536.0
  kv_channels ..................................... 64
  layernorm_epsilon ............................... 1e-05
  lazy_mpu_init ................................... None
  load ............................................ checkpoints/bart-base
  local_rank ...................................... None
  log_batch_size_to_tensorboard ................... False
  log_interval .................................... 1
  log_learning_rate_to_tensorboard ................ True
  log_loss_scale_to_tensorboard ................... True
  log_num_zeros_in_grad ........................... False
  log_params_norm ................................. False
  log_timers_to_tensorboard ....................... False
  log_validation_ppl_to_tensorboard ............... False
  loss_scale ...................................... None
  loss_scale_window ............................... 200
  lr .............................................. 0.0001
  lr_decay_iters .................................. 100000
  lr_decay_samples ................................ None
  lr_decay_style .................................. cosine
  lr_encoder ...................................... None
  lr_warmup_fraction .............................. 0.01
  lr_warmup_iters ................................. 0
  lr_warmup_samples ............................... 0
  make_vocab_size_divisible_by .................... 128
  mask_prob ....................................... 0.15
  masked_softmax_fusion ........................... True
  max_position_embeddings ......................... 512
  merge_file ...................................... None
  micro_batch_size ................................ 16
  min_loss_scale .................................. 1.0
  min_lr .......................................... 1e-06
  mmap_warmup ..................................... False
  no_load_optim ................................... None
  no_load_rng ..................................... None
  no_save_optim ................................... None
  no_save_rng ..................................... None
  num_attention_heads ............................. 12
  num_channels .................................... 3
  num_classes ..................................... 1000
  num_decoder_layers .............................. None
  num_layers ...................................... 12
  num_layers_per_virtual_pipeline_stage ........... None
  num_workers ..................................... 2
  onnx_safe ....................................... None
  openai_gelu ..................................... False
  optimizer ....................................... adam
  override_lr_scheduler ........................... False
  params_dtype .................................... torch.float16
  patch_dim ....................................... 16
  pipeline_model_parallel_size .................... 1
  prompt_length ................................... 1
  query_in_block_prob ............................. 0.1
  rampup_batch_size ............................... None
  rank ............................................ 0
  raw_data_path ................................... ['dataset/']
  reset_attention_mask ............................ False
  reset_position_ids .............................. False
  retriever_report_topk_accuracies ................ []
  retriever_score_scaling ......................... False
  retriever_seq_length ............................ 256
  sample_rate ..................................... 1.0
  save ............................................ checkpoints/bart-base
  save_interval ................................... 1600
  scatter_gather_tensors_in_pipeline .............. True
  seed ............................................ 1234
  seq_length ...................................... 512
  sgd_momentum .................................... 0.9
  short_seq_prob .................................. 0.0
  split ........................................... 949,30,1
  tensor_model_parallel_size ...................... 1
  tensorboard_dir ................................. None
  tensorboard_log_interval ........................ 1
  tensorboard_queue_size .......................... 1000
  titles_data_path ................................ None
  tokenizer_type .................................. Huggingface
  train_iters ..................................... 100000
  train_samples ................................... None
  use_checkpoint_lr_scheduler ..................... False
  use_contiguous_buffers_in_ddp ................... False
  use_cpu_initialization .......................... None
  use_one_sent_docs ............................... False
  virtual_pipeline_model_parallel_size ............ None
  vocab_extra_ids ................................. 0
  vocab_file ...................................... vocab/
  weight_decay .................................... 0.01
  world_size ...................................... 1
  z_dim ........................................... 32
-------------------- end of arguments ---------------------
setting number of micro-batches to constant 16
> building Huggingface tokenizer ...
 > padded vocab (size: 51271) with 57 dummy tokens (new size: 51328)
> initializing torch distributed ...
> initializing tensor model parallel with size 1
> initializing pipeline model parallel with size 1
> setting random seeds to 1234 ...
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
> compiling dataset index builder ...
make: Entering directory '/home/jupyter-raptor/pretrain_tokenizer/megatron/data'
make: Nothing to be done for 'default'.
make: Leaving directory '/home/jupyter-raptor/pretrain_tokenizer/megatron/data'
>>> done with dataset index builder. Compilation time: 0.042 seconds
> compiling and loading fused kernels ...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/jupyter-raptor/pretrain_tokenizer/megatron/fused_kernels/build/build.ninja...
/home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module scaled_upper_triang_masked_softmax_cuda...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module scaled_upper_triang_masked_softmax_cuda...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/jupyter-raptor/pretrain_tokenizer/megatron/fused_kernels/build/build.ninja...
/home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module scaled_masked_softmax_cuda...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module scaled_masked_softmax_cuda...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/jupyter-raptor/pretrain_tokenizer/megatron/fused_kernels/build/build.ninja...
/home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module fused_mix_prec_layer_norm_cuda...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fused_mix_prec_layer_norm_cuda...
>>> done with compiling and loading fused kernels. Compilation time: 0.634 seconds
/home/jupyter-raptor/pretrain_tokenizer/megatron/training.py:101: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:78.)
  start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
time to initialize megatron (seconds): 185.079
[after megatron is initialized] datetime: 2024-08-30 20:14:49 
building BART model ...
BartModel(
  (language_model): BartForConditionalGeneration(
    (model): BartModel(
      (shared): Embedding(51271, 768, padding_idx=0)
      (encoder): BartEncoder(
        (embed_tokens): Embedding(51271, 768, padding_idx=0)
        (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
        (layers): ModuleList(
          (0-5): 6 x BartEncoderLayer(
            (self_attn): BartAttention(
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
        )
        (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
      (decoder): BartDecoder(
        (embed_tokens): Embedding(51271, 768, padding_idx=0)
        (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
        (layers): ModuleList(
          (0-5): 6 x BartDecoderLayer(
            (self_attn): BartAttention(
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (activation_fn): GELUActivation()
            (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (encoder_attn): BartAttention(
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
        )
        (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
    (lm_head): Linear(in_features=768, out_features=51271, bias=False)
  )
)
 > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 140193024
<megatron.optimizer.optimizer.Float16OptimizerWithFloat16Params object at 0x7efc7394ff10>
> learning rate decay style: cosine
WARNING: could not find the metadata file checkpoints/bart-base/latest_checkpointed_iteration.txt 
    will not load any checkpoints and will start from random
time (ms) | load-checkpoint: 0.12
[after model, optimizer, and learning rate scheduler are built] datetime: 2024-08-30 20:14:52 
> building train, validation, and test datasets ...
 > datasets target sizes (minimum size):
    train:      25600000
    validation: 514560
    test:       2560
> building train, validation, and test datasets for BART ...
 > building dataset index ...
    reading sizes...
    reading pointers...
    reading document index...
    creating numpy buffer of mmap...
    creating memory view of numpy buffer...
 > finished creating indexed dataset in 0.000316 seconds
 > indexed dataset stats:
    number of documents: 30148984
    number of sentences: 34778811
 > dataset split:
    train:
     document indices in [0, 29195292) total of 29195292 documents
     sentence indices in [0, 33691437) total of 33691437 sentences
    validation:
     document indices in [29195292, 30118220) total of 922928 documents
     sentence indices in [33691437, 34744209) total of 1052772 sentences
    test:
     document indices in [30118220, 30148984) total of 30764 documents
     sentence indices in [34744209, 34778811) total of 34602 sentences
 > loading indexed mapping from dataset/cantonese-bert_text_sentence_0_train_indexmap_25728000mns_510msl_0.00ssp_1234s.npy
    loaded indexed file in 0.000 seconds
    total number of samples: 29195677
 > loading indexed mapping from dataset/cantonese-bert_text_sentence_0_valid_indexmap_517133mns_510msl_0.00ssp_1234s.npy
    loaded indexed file in 0.000 seconds
    total number of samples: 922942
 > loading indexed mapping from dataset/cantonese-bert_text_sentence_0_test_indexmap_2573mns_510msl_0.00ssp_1234s.npy
    loaded indexed file in 0.000 seconds
    total number of samples: 30764
> elapsed time for building blendable dataset indices: 0.00 (sec)
> elapsed time for building blendable dataset indices: 0.00 (sec)
> elapsed time for building blendable dataset indices: 0.00 (sec)
> finished creating BART datasets ...
collate_fn: None
collate_fn: None
collate_fn: None
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
[after dataloaders are built] datetime: 2024-08-30 20:14:54 
done with setup ...
time (ms) | model-and-optimizer-setup: 3226.58 | train/valid/test-data-iterators-setup: 1813.18
training ...
Building prefix dict from the default dictionary ...
[before the start of training step] datetime: 2024-08-30 20:14:54 
Loading model from cache /tmp/jieba.cache
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.672 seconds.
Prefix dict has been built succesfully.
bart_dataset.py line 358: torch.Size([13]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
bart_dataset.py line 358: torch.Size([17]) torch.Size([3, 1]) torch.Size([2]) torch.Size([2])
Loading model cost 0.684 seconds.
Prefix dict has been built succesfully.
bart_dataset.py line 358: torch.Size([15]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
bart_dataset.py line 358: torch.Size([29]) torch.Size([2, 1]) torch.Size([2]) torch.Size([2])
bart_dataset.py line 358: torch.Size([45]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
bart_dataset.py line 358: torch.Size([16]) torch.Size([2, 1]) torch.Size([2]) torch.Size([2])
bart_dataset.py line 358: torch.Size([13]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
bart_dataset.py line 358: torch.Size([18]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
Loading model cost 0.688 seconds.
Prefix dict has been built succesfully.
bart_dataset.py line 358: torch.Size([15]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
Loading model cost 0.688 seconds.
Prefix dict has been built succesfully.
bart_dataset.py line 358: torch.Size([57]) torch.Size([4, 1]) torch.Size([2]) torch.Size([2])
bart_dataset.py line 358: torch.Size([16]) torch.Size([2, 1]) torch.Size([2]) torch.Size([2])
bart_dataset.py line 358: torch.Size([20]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
bart_dataset.py line 358: torch.Size([15]) torch.Size([4, 1]) torch.Size([3]) torch.Size([3])
bart_dataset.py line 358: torch.Size([20]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
bart_dataset.py line 358: torch.Size([55]) torch.Size([2, 1]) torch.Size([2]) torch.Size([2])
bart_dataset.py line 358: torch.Size([18]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
Loading model cost 0.673 seconds.
Prefix dict has been built succesfully.
bart_dataset.py line 358: torch.Size([23]) torch.Size([2, 1]) torch.Size([2]) torch.Size([2])
Loading model cost 0.677 seconds.
Prefix dict has been built succesfully.
bart_dataset.py line 358: torch.Size([15]) torch.Size([2, 1]) torch.Size([2]) torch.Size([2])
bart_dataset.py line 358: torch.Size([13]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
bart_dataset.py line 358: torch.Size([13]) torch.Size([2, 1]) torch.Size([2]) torch.Size([2])
bart_dataset.py line 358: torch.Size([17]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
bart_dataset.py line 358: torch.Size([20]) torch.Size([3, 1]) torch.Size([2]) torch.Size([2])
bart_dataset.py line 358: torch.Size([16]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
bart_dataset.py line 358: torch.Size([18]) torch.Size([2, 1]) torch.Size([2]) torch.Size([2])
bart_dataset.py line 358: torch.Size([18]) torch.Size([1, 1]) torch.Size([1]) torch.Size([2])
[rank0]: ╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
[rank0]: │ /home/jupyter-raptor/pretrain_tokenizer/pretrain_bart.py:133 in <module>                         │
[rank0]: │                                                                                                  │
[rank0]: │   130                                                                                            │
[rank0]: │   131 if __name__ == "__main__":                                                                 │
[rank0]: │   132 │                                                                                          │
[rank0]: │ ❱ 133 │   pretrain(train_valid_test_datasets_provider, model_provider, forward_step,             │
[rank0]: │   134 │   │   │    args_defaults={'tokenizer_type': 'Huggingface'})                              │
[rank0]: │   135                                                                                            │
[rank0]: │                                                                                                  │
[rank0]: │ /home/jupyter-raptor/pretrain_tokenizer/megatron/training.py:143 in pretrain                     │
[rank0]: │                                                                                                  │
[rank0]: │   140 │                                                                                          │
[rank0]: │   141 │   iteration = 0                                                                          │
[rank0]: │   142 │   if args.do_train and args.train_iters > 0:                                             │
[rank0]: │ ❱ 143 │   │   iteration = train(forward_step_func,                                               │
[rank0]: │   144 │   │   │   │   │   │     model, optimizer, lr_scheduler,                                  │
[rank0]: │   145 │   │   │   │   │   │     train_data_iterator, valid_data_iterator)                        │
[rank0]: │   146 │   print_datetime('after training is done')                                               │
[rank0]: │                                                                                                  │
[rank0]: │ /home/jupyter-raptor/pretrain_tokenizer/megatron/training.py:647 in train                        │
[rank0]: │                                                                                                  │
[rank0]: │   644 │   while iteration < args.train_iters:                                                    │
[rank0]: │   645 │   │   update_num_microbatches(args.consumed_train_samples)                               │
[rank0]: │   646 │   │   loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \                          │
[rank0]: │ ❱ 647 │   │   │   train_step(forward_step_func,                                                  │
[rank0]: │   648 │   │   │   │   │      train_data_iterator,                                                │
[rank0]: │   649 │   │   │   │   │      model,                                                              │
[rank0]: │   650 │   │   │   │   │      optimizer,                                                          │
[rank0]: │                                                                                                  │
[rank0]: │ /home/jupyter-raptor/pretrain_tokenizer/megatron/training.py:391 in train_step                   │
[rank0]: │                                                                                                  │
[rank0]: │   388 │   │   │   forward_backward_func = forward_backward_pipelining_without_interleaving       │
[rank0]: │   389 │   else:                                                                                  │
[rank0]: │   390 │   │   forward_backward_func = forward_backward_no_pipelining                             │
[rank0]: │ ❱ 391 │   losses_reduced = forward_backward_func(                                                │
[rank0]: │   392 │   │   forward_step_func, data_iterator, model,                                           │
[rank0]: │   393 │   │   optimizer, timers, forward_only=False)                                             │
[rank0]: │   394                                                                                            │
[rank0]: │                                                                                                  │
[rank0]: │ /home/jupyter-raptor/pretrain_tokenizer/megatron/schedules.py:129 in                             │
[rank0]: │ forward_backward_no_pipelining                                                                   │
[rank0]: │                                                                                                  │
[rank0]: │   126 │   input_tensor, output_tensor_grad = None, None                                          │
[rank0]: │   127 │   with context_handler():                                                                │
[rank0]: │   128 │   │   for i in range(get_num_microbatches() - 1):                                        │
[rank0]: │ ❱ 129 │   │   │   output_tensor = forward_step(forward_step_func, data_iterator, model,          │
[rank0]: │   130 │   │   │   │   │   │   │   │   │   │    input_tensor, losses_reduced)                     │
[rank0]: │   131 │   │   │   if not forward_only:                                                           │
[rank0]: │   132 │   │   │   │   backward_step(optimizer, model, input_tensor, output_tensor,               │
[rank0]: │                                                                                                  │
[rank0]: │ /home/jupyter-raptor/pretrain_tokenizer/megatron/schedules.py:56 in forward_step                 │
[rank0]: │                                                                                                  │
[rank0]: │    53 │   unwrapped_model = unwrap_model(                                                        │
[rank0]: │    54 │   │   model, (torchDDP, LocalDDP, Float16Module, DeepSpeedEngine))                       │
[rank0]: │    55 │   unwrapped_model.set_input_tensor(input_tensor)                                         │
[rank0]: │ ❱  56 │   output_tensor, loss_func = forward_step_func(data_iterator, model)                     │
[rank0]: │    57 │   if mpu.is_pipeline_last_stage():                                                       │
[rank0]: │    58 │   │   output_tensor = loss_func(output_tensor)                                           │
[rank0]: │    59 │   │   loss, loss_reduced = output_tensor                                                 │
[rank0]: │                                                                                                  │
[rank0]: │ /home/jupyter-raptor/pretrain_tokenizer/pretrain_bart.py:99 in forward_step                      │
[rank0]: │                                                                                                  │
[rank0]: │    96 │                                                                                          │
[rank0]: │    97 │   # Get the batch.                                                                       │
[rank0]: │    98 │   timers('batch-generator').start()                                                      │
[rank0]: │ ❱  99 │   source, target, prev_output_tokens, attn_mask, loss_mask, use_decoder = get_batch(da   │
[rank0]: │   100 │   timers('batch-generator').stop()                                                       │
[rank0]: │   101 │                                                                                          │
[rank0]: │   102 │   # Forward model lm_labels                                                              │
[rank0]: │                                                                                                  │
[rank0]: │ /home/jupyter-raptor/pretrain_tokenizer/pretrain_bart.py:53 in get_batch                         │
[rank0]: │                                                                                                  │
[rank0]: │    50 │                                                                                          │
[rank0]: │    51 │   # Broadcast data.                                                                      │
[rank0]: │    52 │   if data_iterator is not None:                                                          │
[rank0]: │ ❱  53 │   │   data = next(data_iterator)                                                         │
[rank0]: │    54 │   else:                                                                                  │
[rank0]: │    55 │   │   data = None                                                                        │
[rank0]: │    56 │   data_b = mpu.broadcast_data(keys, data, datatype)                                      │
[rank0]: │                                                                                                  │
[rank0]: │ /home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:631 in   │
[rank0]: │ __next__                                                                                         │
[rank0]: │                                                                                                  │
[rank0]: │    628 │   │   │   if self._sampler_iter is None:                                                │
[rank0]: │    629 │   │   │   │   # TODO(https://github.com/pytorch/pytorch/issues/76750)                   │
[rank0]: │    630 │   │   │   │   self._reset()  # type: ignore[call-arg]                                   │
[rank0]: │ ❱  631 │   │   │   data = self._next_data()                                                      │
[rank0]: │    632 │   │   │   self._num_yielded += 1                                                        │
[rank0]: │    633 │   │   │   if self._dataset_kind == _DatasetKind.Iterable and \                          │
[rank0]: │    634 │   │   │   │   │   self._IterableDataset_len_called is not None and \                    │
[rank0]: │                                                                                                  │
[rank0]: │ /home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1346 in  │
[rank0]: │ _next_data                                                                                       │
[rank0]: │                                                                                                  │
[rank0]: │   1343 │   │   │   │   self._task_info[idx] += (data,)                                           │
[rank0]: │   1344 │   │   │   else:                                                                         │
[rank0]: │   1345 │   │   │   │   del self._task_info[idx]                                                  │
[rank0]: │ ❱ 1346 │   │   │   │   return self._process_data(data)                                           │
[rank0]: │   1347 │                                                                                         │
[rank0]: │   1348 │   def _try_put_index(self):                                                             │
[rank0]: │   1349 │   │   assert self._tasks_outstanding < self._prefetch_factor * self._num_workers        │
[rank0]: │                                                                                                  │
[rank0]: │ /home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1372 in  │
[rank0]: │ _process_data                                                                                    │
[rank0]: │                                                                                                  │
[rank0]: │   1369 │   │   self._rcvd_idx += 1                                                               │
[rank0]: │   1370 │   │   self._try_put_index()                                                             │
[rank0]: │   1371 │   │   if isinstance(data, ExceptionWrapper):                                            │
[rank0]: │ ❱ 1372 │   │   │   data.reraise()                                                                │
[rank0]: │   1373 │   │   return data                                                                       │
[rank0]: │   1374 │                                                                                         │
[rank0]: │   1375 │   def _mark_worker_as_unavailable(self, worker_id, shutdown=False):                     │
[rank0]: │                                                                                                  │
[rank0]: │ /home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/_utils.py:705 in reraise          │
[rank0]: │                                                                                                  │
[rank0]: │   702 │   │   │   # If the exception takes multiple arguments, don't try to                      │
[rank0]: │   703 │   │   │   # instantiate since we don't know how to                                       │
[rank0]: │   704 │   │   │   raise RuntimeError(msg) from None                                              │
[rank0]: │ ❱ 705 │   │   raise exception                                                                    │
[rank0]: │   706                                                                                            │
[rank0]: │   707                                                                                            │
[rank0]: │   708 def _get_available_device_type():                                                          │
[rank0]: ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
[rank0]: IndexError: Caught IndexError in DataLoader worker process 0.
[rank0]: Original Traceback (most recent call last):
[rank0]:   File "/home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
[rank0]:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
[rank0]:   File "/home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
[rank0]:     data = [self.dataset[idx] for idx in possibly_batched_index]
[rank0]:   File "/home/jupyter-raptor/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
[rank0]:     data = [self.dataset[idx] for idx in possibly_batched_index]
[rank0]:   File "/home/jupyter-raptor/pretrain_tokenizer/megatron/data/blendable_dataset.py", line 83, in __getitem__
[rank0]:     return self.datasets[dataset_idx][sample_idx]
[rank0]:   File "/home/jupyter-raptor/pretrain_tokenizer/megatron/data/bart_dataset.py", line 106, in __getitem__
[rank0]:     return self.build_training_sample(sample, self.max_seq_length, np_rng)
[rank0]:   File "/home/jupyter-raptor/pretrain_tokenizer/megatron/data/bart_dataset.py", line 148, in build_training_sample
[rank0]:     source = self.add_whole_word_mask(source, mask_ratio, replace_length)
[rank0]:   File "/home/jupyter-raptor/pretrain_tokenizer/megatron/data/bart_dataset.py", line 360, in add_whole_word_mask
[rank0]:     source[indices[mask_random]] = torch.randint(
[rank0]: IndexError: The shape of the mask [2] at index 0 does not match the shape of the indexed tensor [1] at index 0

@shivanraptor
Author

After replacing the dataset files in the dataset/ folder with the original Chinese BART ones (regenerated with the Megatron preprocess script), the same error occurs. The pre-processing command is:

python tools/preprocess_data.py --input ../data/megatron_3m.json --output-prefix my-bert --vocab vocab-bart-base-chinese.txt --dataset-impl mmap --tokenizer-type BertWordPieceLowerCase --split-sentences

The vocab-bart-base-chinese.txt file is generated by:

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('fnlp/bart-base-chinese')
tokenizer.save_vocabulary('vocab-bart-base-chinese.txt')

Is this vocab generation approach correct?
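
As a quick self-check (an untested sketch, assuming one token per line in the saved file), the line count of the saved vocab could be compared with the tokenizer size:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('fnlp/bart-base-chinese')
vocab_file = tokenizer.save_vocabulary('vocab-bart-base-chinese.txt')[0]

# count lines in the saved vocab file and compare with the tokenizer size
with open(vocab_file, encoding='utf-8') as f:
    n_lines = sum(1 for _ in f)

print(n_lines, len(tokenizer))  # the two numbers should match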

@choosewhatulike
Member

It seems that everything is fine with your configuration. It could be a mismatch of some package versions, such as torch, since the code was written three years ago and we do not have the environment to run it on our platform. Could you please add some prints at this line:

[rank0]: File "/home/jupyter-raptor/pretrain_tokenizer/megatron/data/bart_dataset.py", line 360, in add_whole_word_mask
[rank0]: source[indices[mask_random]] = torch.randint(

to show the size and dtype of mask_random, indices and source.
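
For example, something like this (a rough sketch; the variable names are taken from the traceback, and the exact surrounding code in add_whole_word_mask may differ):

# inserted just above the failing assignment in add_whole_word_mask()
print(
    f"line 360: mask_random - {mask_random.shape} {mask_random.dtype} , "
    f"indices - {indices.shape} {indices.dtype} , "
    f"source - {source.shape} {source.dtype}"
)
source[indices[mask_random]] = torch.randint(
    1, self.vocab_size, size=(mask_random.sum(),)
)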

@shivanraptor
Author

shivanraptor commented Sep 3, 2024

Here is the output just before the IndexError occurs:

line 360: mask_random - torch.Size([4]) torch.bool , indices - torch.Size([4]) torch.int64 , source - torch.Size([20]) torch.int64
line 360: mask_random - torch.Size([4]) torch.bool , indices - torch.Size([4]) torch.int64 , source - torch.Size([17]) torch.int64
line 360: mask_random - torch.Size([4]) torch.bool , indices - torch.Size([4]) torch.int64 , source - torch.Size([23]) torch.int64
line 360: mask_random - torch.Size([4]) torch.bool , indices - torch.Size([4]) torch.int64 , source - torch.Size([35]) torch.int64
line 360: mask_random - torch.Size([4]) torch.bool , indices - torch.Size([4]) torch.int64 , source - torch.Size([14]) torch.int64
line 360: mask_random - torch.Size([4]) torch.bool , indices - torch.Size([4]) torch.int64 , source - torch.Size([16]) torch.int64
line 360: mask_random - torch.Size([3]) torch.bool , indices - torch.Size([3]) torch.int64 , source - torch.Size([13]) torch.int64
line 360: mask_random - torch.Size([8]) torch.bool , indices - torch.Size([8]) torch.int64 , source - torch.Size([39]) torch.int64
line 360: mask_random - torch.Size([24]) torch.bool , indices - torch.Size([24]) torch.int64 , source - torch.Size([143]) torch.int64
line 360: mask_random - torch.Size([3]) torch.bool , indices - torch.Size([3]) torch.int64 , source - torch.Size([17]) torch.int64
line 360: mask_random - torch.Size([4]) torch.bool , indices - torch.Size([4]) torch.int64 , source - torch.Size([24]) torch.int64

@choosewhatulike
Member

It looks like the dataset generally works and generates samples. You may want to check the dataset and the vocab to see whether some special samples are causing the error.

@shivanraptor
Author

If I use try-except logic to skip the IndexError, the pre-training process can be executed without problems (sort of):

try:
    source[indices[mask_random]] = torch.randint(
        1, self.vocab_size, size=(mask_random.sum(),)
    )
except IndexError:
    pass

Each iteration takes about 2,500 ms, so 100,000 iterations will take about 69 hours to complete.

@shivanraptor
Author

After the long run, the pre-training is finally complete. It generated a checkpoints/bart-base/ folder with a lot of sub-folders; the last one contains a file called model_optim_rng.pt. How can I make use of the generated files for further actions?

For example, how can I use the files with the BartForConditionalGeneration.from_pretrained() function?

@choosewhatulike
Member

Run the script "pretrain/tools/convert_ckpt.py" and pass it the folder path. It will convert the checkpoint into one that can be used by the Hugging Face modules.

@shivanraptor
Author

shivanraptor commented Sep 9, 2024

Thanks. I successfully converted the .pt file to pytorch_model.bin. However, even though I added vocabulary to the tokenizer before training, the model still has 51,271 vocabulary entries, as reported by the convert_ckpt.py script.

How can I use this generated pytorch_model.bin (without any other files such as config.json) to fine-tune a BartForConditionalGeneration model? I believe this fine-tune script is related. Is it?
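
I imagine something like this minimal sketch (untested; it assumes the converted weights still follow the original fnlp/bart-base-chinese architecture with vocab_size 51271):

import torch
from transformers import BartConfig, BartForConditionalGeneration

# build the architecture from the original config, then load the converted weights
config = BartConfig.from_pretrained("fnlp/bart-base-chinese")
model = BartForConditionalGeneration(config)

state_dict = torch.load("pytorch_model.bin", map_location="cpu")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)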

In the vocab/ folder, the added tokens (beyond the 51,271 original tokens) are stored in added_tokens.json. Is that correct?
