Skip to content

文本之外的模态

zhezhaoa edited this page Oct 27, 2023 · 1 revision

除了文本,TencentPretrain支持图像、语音等模态预训练模型。这里展示如何通过TencentPretrain去预训练和微调不同模态的模型。

ViT

除了文本,TencentPretrain支持图像、语音等模态预训练模型。在CIFAR10数据集上使用ViT模型预训练示例:

python3 preprocess.py --corpus_path datasets/cifar10/train.tsv --tokenizer virtual \
                      --dataset_path dataset.pt --processes_num 8 --data_processor vit

python3 pretrain.py --dataset_path dataset.pt --tokenizer virtual \
                    --pretrained_model_path models/vit_base_patch16_224_model.bin \
                    --config_path models/vit/base-16-224_config.json \
                    --output_model_path models/cifar10_vit_base_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 2000 --save_checkpoint_steps 1000 --batch_size 32 \
                    --labels_num 10

在预训练模型仓库章节可以下载 vit_base_patch16_224_model.bin 。由于处理图像不需要用分词,因此这里使用 --tokenizer virtual 。 在CIFAR10数据集上微调和推理示例:

python3 finetune/run_image_classifier.py --pretrained_model_path models/vit_base_patch16_224_model.bin \
                                         --tokenizer virtual \
                                         --config_path models/vit/base-16-224_config.json \
                                         --train_path datasets/cifar10/train.tsv \
                                         --dev_path datasets/cifar10/test.tsv \
                                         --output_model_path models/image_classifier_model.bin \
                                         --epochs_num 3 --batch_size 64

python3 inference/run_image_classifier_infer.py --load_model_path models/image_classifier_model.bin \
                                                --tokenizer virtual \
                                                --config_path models/vit/base-16-224_config.json \
                                                --test_path datasets/cifar10/test.tsv \
                                                --prediction_path datasets/cifar10/prediction.tsv \
                                                --labels_num 10

CIFAR10数据集有10个分类标签(--labels_num 10)。

S2T

LibriSpeech数据集上使用S2T模型预训练示例: 需要将 tencentpretrain/utils/constants.py 中的 models/special_tokens_map.json 修改为 models/xlmroberta_special_tokens_map.json 。此外需要将数据处理成TencentPretrain可以处理的格式。在下游任务数据集章节我们提供了10小时版本的数据集。

python3 scripts/prepare_librispeech_data.py --input_path datasets/librispeech/train-10h \
                                            --output_path datasets/librispeech/train-10h.tsv

然后预处理和预训练:

python3 preprocess.py --corpus_path datasets/librispeech/train-10h.tsv \
                      --spm_model_path models/sentencepiece.bpe.model \
                      --dataset_path dataset.pt \
                      --processes_num 8 --data_processor s2t

python3 pretrain.py --dataset_path dataset.pt  \
                    --spm_model_path models/sentencepiece.bpe.model \
                    --config_path models/s2t/small_config.json \
                    --output_model_path models/output_model.bin \
                    --accumulation_steps 8 \
                    --world_size 4 --gpu_ranks 0 1 2 3 \
                    --total_steps 100000 --save_checkpoint_steps 10000 --report_steps 100 \
                    --batch_size 8 --learning_rate 2e-3

为了为微调阶段准备数据集,需要使用 --add_column 把列名放在第一行。 在LibriSpeech数据集上微调示例:

python3 scripts/prepare_librispeech_data.py --input_path datasets/librispeech/train-10h \
                                            --output_path datasets/librispeech/train-10h.tsv \
                                            --add_column

python3 scripts/prepare_librispeech_data.py --input_path datasets/librispeech/dev-clean \
                                            --output_path datasets/librispeech/dev-clean.tsv \
                                            --add_column

python3 finetune/run_speech2text.py --pretrained_model_path models/output_model.bin \
                                    --spm_model_path models/sentencepiece.bpe.model \
                                    --config_path models/s2t/small_config.json \
                                    --train_path datasets/librispeech/train-10h.tsv \
                                    --dev_path datasets/librispeech/dev-clean.tsv \
                                    --output_model_path models/finetuned_model.bin \
                                    --batch_size 8 --epochs_num 10 \
                                    --learning_rate 2e-4 --report_steps 200

推理采用beam search的方式进行,通过设置 --beam_width 可以修改beam大小。

python3 scripts/prepare_librispeech_data.py --input_path datasets/librispeech/test-clean \
                                            --output_path datasets/librispeech/test-clean.tsv \
                                            --add_column

python3 inference/run_speech2text_infer.py --load_model_path models/finetuned_model.bin \
                                           --spm_model_path models/sentencepiece.bpe.model \
                                           --config_path models/s2t/small_config.json \
                                           --test_path datasets/librispeech/test-clean.tsv \
                                           --prediction_path output.txt \
                                           --batch_size 8 --tgt_seq_length 100 \
                                           --beam_width 5

可以将Huggingface中的S2T模型转换为TencentPretrain格式并推理。推理前需要把 models/xlmroberta_special_tokens_map.json 中的 "cls_token": "" 改成 "cls_token": "</s>" 。

python3 scripts/convert_s2t_from_huggingface_to_tencentpretrain.py --input_model_path s2t_huggingface_model.bin \
                                                                   --output_model_path s2t_tencentpretrain_model.bin

python3 inference/run_speech2text_infer.py --load_model_path s2t_tencentpretrain_model.bin \
                                           --spm_model_path models/sentencepiece.bpe.model  \
                                           --config_path models/s2t/small_config.json \
                                           --test_path datasets/librispeech/test-clean.tsv \
                                           --prediction_path output.txt \
                                           --batch_size 8 --tgt_seq_length 100 \
                                           --beam_width 5
Clone this wiki locally