# config.yaml
name: ehr_transformer
root: "."
target_key: "Target" # key that points to labels in sample_dict
max_len_seq: 350 # maximum number of tokens in the trajectory
aux_gender_classification: True # set to True to enable gender auxiliary head
aux_next_vis_classification: True # set to True to enable next visit auxiliary head
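# The top-level keys above are referenced throughout this file via OmegaConf interpolation,
# e.g. ${max_len_seq} and ${target_key} below resolve to the values defined here.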

data:
  dataset_cfg: # See PhysioNetCinC.dataset for details about the arguments
    raw_data_pkl: null # to manage the raw data through a pickle file instead, set this to ${oc.env:CINC_DATA_PKL},
                       # the path to the pickle file
    raw_data_path: ${oc.env:CINC_DATA_PATH}
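    # Example of pointing the loader at a local copy of the raw data before launching a run
    # (the path below is only a placeholder, not part of this config):
    #   export CINC_DATA_PATH=/path/to/physionet-challenge-2012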
    split_filename: None
    num_folds: 5
    train_folds: [0, 1, 2]
    validation_folds: [3]
    test_folds: [4]
    seed: 2580
    reset_split: True
    num_percentiles: 4 # number of bins (percentiles) used to convert continuous lab/vital measurements to categorical values
    categorical_max_num_of_values: 5 # maximum number of unique values a categorical variable may have and still not be digitized
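    # Illustrative sketch of the two settings above (an assumption about the behaviour, not taken from the code):
    # with num_percentiles: 4 a continuous measurement such as heart rate is mapped to one of four
    # quartile bins, while a variable with at most 5 unique values is kept as-is rather than digitized.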
    min_hours_in_hospital: 46
    min_number_of_visits: 10
    max_len_seq: ${max_len_seq}
    static_variables_to_embed: ['Age', 'ICUType', 'Height', 'Weight', 'BMI']
    embed_static_in_all_visits: 0
  batch_size: 128
  target_key: ${target_key}
  data_loader_train: # DataLoader constructor parameters
    num_workers: 8
  data_loader_valid: # DataLoader constructor parameters
    num_workers: 8
    batch_size: ${data.batch_size}

model:
  encoder_type: "bert" # supported values: "bert", "transformer"
  transformer_encoder: # TransformerEncoder constructor arguments - used when encoder_type is "transformer"
    num_tokens: ${max_len_seq}
    token_dim: ${model.embed.emb_dim}
    depth: 4
    heads: 10
    mlp_dim: 50
    dropout: 0.0
    emb_dropout: 0.0
    num_cls_tokens: 1
  bert_config_kwargs: # BertConfig constructor arguments - used when encoder_type is "bert"
    hidden_size: ${model.z_dim} # word and segment embedding hidden size (must be a multiple of num_attention_heads)
    num_hidden_layers: 6 # number of multi-head attention layers
    num_attention_heads: 24 # number of attention heads
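    # With the values in this file, hidden_size resolves to ${model.z_dim} = 48,
    # so each of the 24 attention heads works on 48 / 24 = 2 dimensions.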
    intermediate_size: 512 # size of the "intermediate" (feed-forward) layer in the transformer encoder
    hidden_act: gelu # activation function ("gelu", "relu", "swish")
    hidden_dropout_prob: 0.2 # dropout rate
    attention_probs_dropout_prob: 0.22 # multi-head attention dropout rate
    initializer_range: 0.02 # parameter weight initializer range
  embed: # Embed constructor arguments
    emb_dim: ${model.z_dim}
  classifier_head: # HeadD1 constructor arguments - used for the main classifier head
    num_outputs: 2
    layers_description: [256]
  classifier_gender_head: # HeadD1 constructor arguments - used for the gender classifier head
    num_outputs: 2
    layers_description: [256]
  classifier_next_vis_head: # HeadD1 constructor arguments - used for the next-visit classifier head
    layers_description: [256]
  aux_gender_classification: ${aux_gender_classification}
  aux_next_vis_classification: ${aux_next_vis_classification}
  z_dim: 48
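  # Note: z_dim is interpolated into both embed.emb_dim and bert_config_kwargs.hidden_size above,
  # so changing it resizes the token embeddings and the BERT hidden size together.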

# train
train: # arguments for train() in classifiers_main_train.py
  model_dir: ${root}/${name}
  target_key: ${target_key}
  target_loss_weight: 0.8
  aux_gender_classification: ${aux_gender_classification}
  gender_loss_weight: 0.1
  aux_next_vis_classification: ${aux_next_vis_classification}
  next_vis_loss_weight: 0.1
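  # Note: the three loss weights above (0.8 + 0.1 + 0.1) sum to 1.0.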

  # uncomment to track in ClearML
  # track_clearml:
  #   project_name: "ehr_transformer"
  #   task_name: ${name}
  #   tags: "fuse_example"
  #   reuse_last_task_id: True
  #   continue_last_task: False

  # uncomment for SGD
  # opt:
  #   _partial_: true
  #   _target_: torch.optim.SGD
  #   momentum: 0.99
  #   nesterov: True
  #   lr: 0.001
  #   # weight_decay: 1e-5

  # AdamW
  opt:
    _partial_: true
    _target_: torch.optim.AdamW
    lr: 1.0e-3
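    # To add weight decay to AdamW, uncomment the line below
    # (the value 1e-5 is only an example, not part of the original config):
    # weight_decay: 1e-5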

  # linear_schedule_with_warmup
  lr_scheduler:
    _partial_: True
    _target_: transformers.get_linear_schedule_with_warmup
    num_warmup_steps: 500
    num_training_steps: 50000
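  # With get_linear_schedule_with_warmup the learning rate ramps up over num_warmup_steps and then
  # decays linearly, reaching zero after num_training_steps optimizer steps.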

  # uncomment for the ReduceLROnPlateau lr scheduler
  # lr_scheduler:
  #   _partial_: true
  #   _target_: torch.optim.lr_scheduler.ReduceLROnPlateau

  # uncomment for the CosineAnnealingLR lr scheduler
  # lr_scheduler:
  #   _target_: torch.optim.lr_scheduler.CosineAnnealingLR
  #   T_max: ${train.trainer_kwargs.max_epochs}
  #   eta_min: 1e-7
  #   last_epoch: -1

  trainer_kwargs: # arguments for pl.Trainer
    default_root_dir: ${train.model_dir}
    max_epochs: 100
    accelerator: "gpu"
    devices: 1
    num_sanity_val_steps: 0
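    # To run without a GPU, a minimal alternative (assuming nothing else in the pipeline requires CUDA):
    # accelerator: "cpu"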

hydra:
  run:
    dir: ${root}/${name}
  job:
    chdir: False
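
# Example of overriding values from the command line using Hydra's key=value syntax
# (the script name is taken from the train comment above; the override values are only examples):
#   python classifiers_main_train.py train.trainer_kwargs.max_epochs=50 data.batch_size=64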