-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain.sh
executable file
·63 lines (50 loc) · 1.68 KB
/
train.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/env bash
#
# Train a MaChAmp baseline for one TASK/MODEL pair using the JSON configs in
# the configs/ folder next to this script. Run without arguments for usage.
#
set -Eeuo pipefail

# Absolute directory containing this script (cd output is silenced so that a
# CDPATH match cannot leak into the captured path).
SCRIPTDIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
readonly SCRIPTDIR
# MaChAmp checkout, expected as a git submodule two levels up.
readonly MACHAMP_DIR="$SCRIPTDIR/../../submodules/machamp"
# Print the command-line usage summary (positional arguments TASK, MODEL and
# the optional SEED) to stdout.
function script_usage() {
  cat << EOF
Usage: train.sh <TASK> <MODEL> [SEED]
Positional Args:
TASK Task to train on; should be the name of a config file
(without the extension) in the configs/ folder (e.g.
'ner_naija' or 'ner_wikiann_ht')
MODEL Model type to use; one of: mbert, mt5, xlmr
SEED (optional) Random seed to pass to MaChAmp.
EOF
}
# Print a diagnostic message ($1, may be omitted) to stderr, followed by a
# newline. Backslash escapes in the message (e.g. "\n") are interpreted just
# as `echo -e` would; printf is used so that a message that looks like an echo
# option (e.g. "-n") is printed verbatim instead of being swallowed.
function msg() {
  printf '%b\n' "${1-}" >&2
}
# Expect TASK and MODEL, plus an optional SEED; anything else is a usage error.
if (( $# < 2 || $# > 3 )); then
  msg "Error: Found $# positional arguments; expected 2 or 3\n"
  script_usage
  exit 1
fi

# MaChAmp lives in a git submodule; its train.py is absent when the
# submodules have not been initialised.
if [[ ! -f "$MACHAMP_DIR/train.py" ]]; then
  msg "Error: MaChAmp not found, did you pull submodules? <git submodule update --init --recursive>"
  exit 1
fi
# Map TASK onto configs/<TASK>.json and MODEL onto configs/params_<MODEL>.json;
# a missing file is reported on stderr and aborts with exit status 2.
task="$1"
task_config="$SCRIPTDIR/configs/$task.json"
[[ -f "$task_config" ]] || {
  msg "Error: Found no configuration file for task '$task'"
  msg " expected: $task_config"
  exit 2
}

model="$2"
model_config="$SCRIPTDIR/configs/params_$model.json"
[[ -f "$model_config" ]] || {
  msg "Error: Found no configuration file for model '$model'"
  msg " expected: $model_config"
  exit 2
}
# Use the caller's SEED when given and non-empty; otherwise draw one from
# $RANDOM (0-32767). `:-` (not `-`) ensures an explicitly empty third argument
# also falls back to a random seed instead of producing an empty --seed value.
seed="${3:-$RANDOM}"
if [[ ! "$seed" =~ ^-?[0-9]+$ ]]; then
  msg "Error: SEED must be an integer, got '$seed'"
  exit 1
fi

# Run from the script directory so relative path names inside the config
# files resolve correctly.
cd "$SCRIPTDIR"
python3 "$MACHAMP_DIR"/train.py \
  --parameters_config "$model_config" \
  --dataset_config "$task_config" \
  --name "${task}_${model}_baseline" \
  --seed "$seed"