forked from THU-KEG/MAVEN-dataset
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_MAVEN_infer.sh
executable file
·21 lines (21 loc) · 954 Bytes
/
run_MAVEN_infer.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Run MAVEN inference with a trained checkpoint, then convert the raw
# predictions into the submission format.
#
# Paths default to the original hard-coded values and can be overridden from
# the environment, e.g.:
#   MAVEN_CKPT_DIR=./MAVEN/checkpoint-3000 bash run_MAVEN_infer.sh
#
# NOTE: delete the cached data files before running. Otherwise the test data
# may have been randomly shuffled by an earlier run, and the predictions will
# not line up with MAVEN_TEST_FILE.
#
# Comments must never follow a trailing `\`: `\ #...` escapes the space, not
# the newline, which silently breaks the command. Comments therefore live on
# their own lines below, or inside the argument array.

# Stop at the first failing step so the submission is never built from a
# stale or missing prediction file.
set -euo pipefail

# Directory holding the test data.
MAVEN_DATA_DIR=${MAVEN_DATA_DIR:-../maven/}
# Trained checkpoint used for inference (trailing slash stripped below).
MAVEN_CKPT_DIR=${MAVEN_CKPT_DIR:-./MAVEN/checkpoint-2500}
MAVEN_CKPT_DIR=${MAVEN_CKPT_DIR%/}
# Output directory passed to run_ee.py.
MAVEN_OUTPUT_DIR=${MAVEN_OUTPUT_DIR:-./MAVEN}
# Test data file the predictions are aligned with.
MAVEN_TEST_FILE=${MAVEN_TEST_FILE:-../maven/test.jsonl}
# Prediction file produced by inference. The original script expected
# <ckpt_dir>/<ckpt_name>_preds.npy; override MAVEN_PREDS_FILE if it differs.
MAVEN_PREDS_FILE=${MAVEN_PREDS_FILE:-${MAVEN_CKPT_DIR}/${MAVEN_CKPT_DIR##*/}_preds.npy}
# Final submission file.
MAVEN_RESULTS_FILE=${MAVEN_RESULTS_FILE:-./results.jsonl}

infer_args=(
  --data_dir "$MAVEN_DATA_DIR"
  --model_type bert
  --model_name_or_path "$MAVEN_CKPT_DIR"
  --task_name maven_infer
  --output_dir "$MAVEN_OUTPUT_DIR"
  --max_seq_length 128
  --do_lower_case
  --per_gpu_eval_batch_size 42
  # Training options carried over from the training run; presumably unused
  # when only --do_infer is given (TODO confirm in run_ee.py).
  --per_gpu_train_batch_size 42
  --gradient_accumulation_steps 2
  --learning_rate 5e-5
  --num_train_epochs 5
  --save_steps 500
  --logging_steps 500
  --seed 42
  # Inference only: no training or evaluation.
  --do_infer
)

python3 run_ee.py "${infer_args[@]}"

# Fail with a clear message instead of letting get_submission.py crash on a
# missing prediction file (e.g. if the naming convention differs).
if [[ ! -f "$MAVEN_PREDS_FILE" ]]; then
  printf 'error: prediction file not found: %s\n' "$MAVEN_PREDS_FILE" >&2
  exit 1
fi

# Convert the predictions to the submission format.
python3 get_submission.py \
  --test_data "$MAVEN_TEST_FILE" \
  --preds "$MAVEN_PREDS_FILE" \
  --output "$MAVEN_RESULTS_FILE"