-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_pretrain.sh
70 lines (66 loc) · 1.87 KB
/
train_pretrain.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# ------------------- Args setting -------------------
# Launch single-node (1-8 GPU) MAE-style pre-training via torch.distributed.run.
# Usage: bash train_pretrain.sh MODEL BATCH_SIZE DATASET DATASET_ROOT WORLD_SIZE [RESUME]
set -euo pipefail

if (( $# < 5 )); then
    echo "Usage: $0 MODEL BATCH_SIZE DATASET DATASET_ROOT WORLD_SIZE [RESUME]" >&2
    exit 1
fi

MODEL=$1
BATCH_SIZE=$2
DATASET=$3
DATASET_ROOT=$4
WORLD_SIZE=$5
RESUME=${6:-}   # optional checkpoint path; empty means train from scratch

# ------------------- Training setting -------------------
MASK_RATIO=0.75
# Optimizer config
OPTIMIZER="adamw"
LRSCHEDULER="cosine"
BASE_LR=0.00015
MIN_LR=0
WEIGHT_DECAY=0.05
# Epoch
MAX_EPOCH=800
WP_EPOCH=40
EVAL_EPOCH=20

# ------------------- Dataset setting -------------------
# NOTE(review): NUM_CLASSES is set for every dataset but never passed to
# main_pretrain.py below — kept for parity with the finetune script; confirm.
case "$DATASET" in
    cifar10)
        IMG_SIZE=32
        PATCH_SIZE=2
        NUM_CLASSES=10
        ;;
    cifar100)
        IMG_SIZE=32
        PATCH_SIZE=2
        NUM_CLASSES=100
        ;;
    imagenet_1k|imagenet_22k)
        IMG_SIZE=224
        PATCH_SIZE=16
        NUM_CLASSES=1000
        ;;
    custom)
        IMG_SIZE=224
        PATCH_SIZE=16
        NUM_CLASSES=2
        ;;
    *)
        echo "Unknown dataset: '${DATASET}'" >&2
        exit 1
        ;;
esac

# ------------------- Training pipeline -------------------
# Only forward --resume when a checkpoint was actually given; an empty value
# would otherwise leave a bare "--resume" flag on the command line.
RESUME_ARGS=()
if [[ -n "$RESUME" ]]; then
    RESUME_ARGS=(--resume "$RESUME")
fi

if (( WORLD_SIZE >= 1 && WORLD_SIZE <= 8 )); then
    # ${RESUME_ARGS[@]+...} guards the empty-array expansion under `set -u`
    # on bash < 4.4.
    python -m torch.distributed.run --nproc_per_node="${WORLD_SIZE}" --master_port 1700 main_pretrain.py \
        --cuda \
        -dist \
        --root "${DATASET_ROOT}" \
        --dataset "${DATASET}" \
        --model "${MODEL}" \
        ${RESUME_ARGS[@]+"${RESUME_ARGS[@]}"} \
        --batch_size "${BATCH_SIZE}" \
        --img_size "${IMG_SIZE}" \
        --patch_size "${PATCH_SIZE}" \
        --max_epoch "${MAX_EPOCH}" \
        --wp_epoch "${WP_EPOCH}" \
        --eval_epoch "${EVAL_EPOCH}" \
        --optimizer "${OPTIMIZER}" \
        --lr_scheduler "${LRSCHEDULER}" \
        --base_lr "${BASE_LR}" \
        --min_lr "${MIN_LR}" \
        --weight_decay "${WEIGHT_DECAY}" \
        --mask_ratio "${MASK_RATIO}"
else
    echo "WORLD_SIZE must be between 1 and 8 (got ${WORLD_SIZE}); values above 8 imply \
multi-machine multi-card training, which is currently unsupported." >&2
    exit 1
fi