Add LAAT #323

Draft · wants to merge 21 commits into base: master
80 changes: 80 additions & 0 deletions example_config/EUR-Lex/laat_tune.yml
@@ -0,0 +1,80 @@
# data
training_file: data/EUR-Lex/train.txt
test_file: data/EUR-Lex/test.txt
data_name: EUR-Lex
min_vocab_freq: 1
max_seq_length: 500
include_test_labels: true
remove_no_label_data: true
add_special_tokens: false

# train
seed: 0
epochs: 50
batch_size: 16
optimizer: adamw
learning_rate: ['grid_search', [0.001, 0.0003]]
eps: 0.00000001 # 1e-08, written out so YAML parses it as a float rather than a string for Ray
momentum: 0.9 # not used
weight_decay: 0
patience: 10
early_stopping_metric: RP@5
shuffle: true
lr_scheduler: ReduceLROnPlateau
scheduler_config:
factor: 0.9
patience: 9
min_lr: 0.0001

# eval
eval_batch_size: 16
monitor_metrics: ['Another-Macro-F1', 'Macro-F1', 'Micro-F1', 'P@1', 'P@5', 'RP@5', 'nDCG@5']
val_metric: RP@5

# model
model_name: LAAT
loss_function: binary_cross_entropy_with_logits
init_weight: kaiming_uniform
network_config:
embed_dropout: ['grid_search', [0, 0.2, 0.4, 0.6, 0.8]]
encoder_dropout: ['grid_search', [0, 0.2, 0.4]]
rnn_dim: ['grid_search', [512, 768, 1024]] # 256, 512, 1024
num_layers: 1
d_a: ['grid_search', [128, 256, 384, 512]] # d_a = 128 added for EUR-Lex
freeze_embed: false

# pretrained vocab / embeddings
vocab_file: null
embed_file: glove.6B.200d
normalize_embed: false

# hyperparameter search
search_alg: basic_variant
embed_cache_dir: .vector_cache
num_samples: 1
scheduler: null
no_merge_train_val: true # do not merge train and validation sets to retrain after the search

# Uncomment the following lines to enable the ASHAScheduler.
# See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler
# scheduler:
#   time_attr: training_iteration
#   max_t: 50 # the maximum number of epochs to run for each config (parameter R in the ASHA paper)
#   grace_period: 10 # the minimum number of epochs to run for each config (parameter r in the ASHA paper)
#   reduction_factor: 3 # keep only the top 1/reduction_factor of configurations at each round of successive halving (a round is called a rung in the ASHA paper)
#   brackets: 1 # number of brackets; a smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., fewer total resources used)

# other parameters specified in main.py::get_args
checkpoint_path: null
cpu: false
data_workers: 4
eval: false
label_file: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
metric_threshold: 0.5
result_dir: runs
save_k_predictions: 0
silent: true
val_size: 0.2
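
The bracketed `['grid_search', [...]]` entries above are the YAML-friendly form of a Ray Tune grid search. As a rough illustration of the intended expansion, each such list maps onto `tune.grid_search`; the `to_search_space` helper below is hypothetical, not LibMultiLabel's actual code:

```python
# Hypothetical sketch: how ['grid_search', [...]] config entries could be
# expanded into Ray Tune's search-space API.
from ray import tune

def to_search_space(value):
    """Map ['grid_search', [v1, v2, ...]] to tune.grid_search([v1, v2, ...])."""
    if isinstance(value, list) and len(value) == 2 and value[0] == "grid_search":
        return tune.grid_search(value[1])
    return value  # fixed values pass through unchanged

config = {
    "learning_rate": to_search_space(["grid_search", [0.001, 0.0003]]),
    "rnn_dim": to_search_space(["grid_search", [512, 768, 1024]]),
    "d_a": to_search_space(["grid_search", [128, 256, 384, 512]]),
    "num_layers": to_search_space(1),
}
```

With `search_alg: basic_variant` and `num_samples: 1`, Ray Tune then runs every combination of the gridded values exactly once.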
81 changes: 81 additions & 0 deletions example_config/MIMIC/laat_tune.yml
@@ -0,0 +1,81 @@
# data
training_file: data/MIMIC/train.txt
val_file: data/MIMIC/valid.txt
test_file: data/MIMIC/test.txt
data_name: MIMIC
min_vocab_freq: 1
max_seq_length: 4000
include_test_labels: true
remove_no_label_data: true
add_special_tokens: false

# train
seed: 0
epochs: 50
batch_size: 8
optimizer: adamw
learning_rate: ['grid_search', [0.001, 0.0003]]
eps: 0.00000001 # 1e-08, written out so YAML parses it as a float rather than a string for Ray
momentum: 0.9 # not used
weight_decay: 0
patience: 6
early_stopping_metric: Micro-F1
shuffle: true
lr_scheduler: ReduceLROnPlateau
scheduler_config:
factor: 0.9
patience: 5
min_lr: 0.0001

# eval
eval_batch_size: 8
monitor_metrics: ['Another-Macro-F1', 'Macro-F1', 'Micro-F1', 'P@5', 'P@8', 'P@15']
val_metric: Micro-F1

# model
model_name: LAAT
loss_function: binary_cross_entropy_with_logits
init_weight: kaiming_uniform
network_config:
embed_dropout: ['grid_search', [0, 0.2, 0.4, 0.6, 0.8]]
encoder_dropout: ['grid_search', [0, 0.2, 0.4]] # LAAT: 0
rnn_dim: ['grid_search', [512, 768, 1024]] # LAAT: 512
num_layers: 1
d_a: ['grid_search', [256, 384, 512]] # LAAT: 256
freeze_embed: false

# pretrained vocab / embeddings
vocab_file: null # let LibMultiLabel build the vocabulary with min_vocab_freq=1
embed_file: data/MIMIC-50/word2vec_sg0_100.embed
normalize_embed: false

# hyperparameter search
search_alg: basic_variant
embed_cache_dir: null
num_samples: 1
scheduler: null
no_merge_train_val: true # do not merge train and validation sets to retrain after the search

# Uncomment the following lines to enable the ASHAScheduler.
# See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler
# scheduler:
#   time_attr: training_iteration
#   max_t: 50 # the maximum number of epochs to run for each config (parameter R in the ASHA paper)
#   grace_period: 10 # the minimum number of epochs to run for each config (parameter r in the ASHA paper)
#   reduction_factor: 3 # keep only the top 1/reduction_factor of configurations at each round of successive halving (a round is called a rung in the ASHA paper)
#   brackets: 1 # number of brackets; a smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., fewer total resources used)

# other parameters specified in main.py::get_args
checkpoint_path: null
cpu: false
data_workers: 8
eval: false
label_file: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
metric_threshold: 0.5
result_dir: runs
save_k_predictions: 0
silent: true
val_size: 0.2
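
If the commented-out scheduler block in the config above is enabled, it corresponds roughly to the following Ray Tune scheduler (a sketch only; the argument values are copied from the comments):

```python
# Sketch of the ASHAScheduler that the commented-out block describes.
from ray.tune.schedulers import ASHAScheduler

scheduler = ASHAScheduler(
    time_attr="training_iteration",
    max_t=50,            # maximum epochs per config (R in the ASHA paper)
    grace_period=10,     # minimum epochs per config (r in the ASHA paper)
    reduction_factor=3,  # promote the top 1/3 of configs at each rung
    brackets=1,          # a single bracket, i.e. plain ASHA
)
```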
40 changes: 40 additions & 0 deletions example_config/MIMICIV-ICD10-50/laat.yml
@@ -0,0 +1,40 @@
# data
training_file: data/MIMICIV-ICD10-50/train.txt
val_file: data/MIMICIV-ICD10-50/valid.txt
test_file: data/MIMICIV-ICD10-50/test.txt
data_name: MIMICIV-ICD10-50
min_vocab_freq: 1
max_seq_length: 4000
include_test_labels: true
remove_no_label_data: true

# train
seed: 0
epochs: 50
batch_size: 8
optimizer: adamw
learning_rate: 0.001
weight_decay: 0
eps: 1e-06
patience: 6
shuffle: true

# eval
eval_batch_size: 8
monitor_metrics: ['Loss', 'Another-Macro-F1', 'Macro-F1', 'Micro-F1', 'P@5', 'P@8', 'P@15']
val_metric: Micro-F1

# model
model_name: LAAT
init_weight: null
network_config:
embed_dropout: 0.3
encoder_dropout: 0
rnn_dim: 512 # 512 // 2 = 256 hidden units per LSTM direction
num_layers: 1
d_a: 256

# pretrained vocab / embeddings
vocab_file: data/MIMICIV-ICD10-50/vocab.csv # 179,131, min_vocab_freq=1
embed_file: data/MIMICIV-ICD10-50/processed_full.embed
normalize_embed: false
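
For reference, `rnn_dim` and `d_a` play the roles shown below in a LAAT-style label-attention network: `rnn_dim` is the concatenated size of the bidirectional LSTM output (hence the `512 // 2 = 256` comment), and `d_a` is the dimension of the label-attention projection. This is a simplified sketch following the LAAT paper, not LibMultiLabel's exact implementation, and the default sizes are illustrative:

```python
# Simplified LAAT-style model showing how rnn_dim and d_a are used.
import torch
import torch.nn as nn

class LAATSketch(nn.Module):
    def __init__(self, embed_dim=100, rnn_dim=512, d_a=256, num_classes=50):
        super().__init__()
        # Bidirectional LSTM: rnn_dim is the concatenated output size,
        # so each direction uses rnn_dim // 2 hidden units (512 // 2 = 256).
        self.lstm = nn.LSTM(embed_dim, rnn_dim // 2, num_layers=1,
                            batch_first=True, bidirectional=True)
        self.W = nn.Linear(rnn_dim, d_a, bias=False)      # project to the d_a attention space
        self.U = nn.Linear(d_a, num_classes, bias=False)  # one attention vector per label
        self.output = nn.Linear(rnn_dim, num_classes)     # per-label output weights

    def forward(self, x):                      # x: (batch, seq_len, embed_dim)
        H, _ = self.lstm(x)                    # (batch, seq_len, rnn_dim)
        Z = torch.tanh(self.W(H))              # (batch, seq_len, d_a)
        A = torch.softmax(self.U(Z), dim=1)    # attention over tokens, per label
        V = A.transpose(1, 2) @ H              # (batch, num_classes, rnn_dim)
        # label-wise dot product with the output weights, plus bias
        return (self.output.weight * V).sum(dim=2) + self.output.bias

# e.g. logits = LAATSketch()(torch.randn(2, 4000, 100))  # -> shape (2, 50)
```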
72 changes: 72 additions & 0 deletions example_config/MIMICIV-ICD10-50/laat_tune.yml
@@ -0,0 +1,72 @@
# data
training_file: data/MIMICIV-ICD10-50/train.txt
val_file: data/MIMICIV-ICD10-50/valid.txt
test_file: data/MIMICIV-ICD10-50/test.txt
data_name: MIMICIV-ICD10-50
min_vocab_freq: 1
max_seq_length: 4000
include_test_labels: true
remove_no_label_data: true
add_special_tokens: false

# train
seed: 0
epochs: 50
batch_size: 8
optimizer: adamw
learning_rate: 0.001
eps: 0.00000001 # 1e-08, written out so YAML parses it as a float rather than a string for Ray
momentum: 0.9
weight_decay: 0
patience: 6
early_stopping_metric: Micro-F1
shuffle: true
lr_scheduler: ReduceLROnPlateau
scheduler_config:
factor: 0.9
patience: 5
min_lr: 0.0001

# eval
eval_batch_size: 8
monitor_metrics: ['Another-Macro-F1', 'Macro-F1', 'Micro-F1', 'P@5', 'P@8', 'P@15']
val_metric: Micro-F1

# model
model_name: LAAT
loss_function: binary_cross_entropy_with_logits
init_weight: kaiming_uniform # null
network_config:
embed_dropout: 0.3
encoder_dropout: 0
rnn_dim: ['grid_search', [512, 1024, 768, 256]] # 2u in the LAAT paper (u = hidden units per direction)
num_layers: 1
d_a: ['grid_search', [256, 512, 384, 128]]
freeze_embed: false

# pretrained vocab / embeddings
vocab_file: data/MIMICIV-ICD10-50/vocab.csv # 179,131, min_vocab_freq=1
embed_file: data/MIMICIV-ICD10-50/processed_full.embed
normalize_embed: false

# hyperparameter search
search_alg: basic_variant
embed_cache_dir: .vector_cache
num_samples: 1
scheduler: null
no_merge_train_val: true # do not merge train and validation sets to retrain after the search

# other parameters specified in main.py::get_args
checkpoint_path: null
cpu: false
data_workers: 8
eval: false
label_file: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
metric_threshold: 0.5
result_dir: runs
save_k_predictions: 0
silent: true
val_size: 0.2
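
The tuning configs write `eps: 0.00000001` rather than `1e-08`. The likely reason is PyYAML's YAML 1.1 resolver, which does not recognize bare scientific notation as a float, so `1e-08` would reach Ray Tune as a string:

```python
# Demonstration of the PyYAML parsing behavior the eps comments refer to.
import yaml

print(yaml.safe_load("eps: 1e-08"))       # {'eps': '1e-08'}  -> loaded as a string
print(yaml.safe_load("eps: 0.00000001"))  # {'eps': 1e-08}    -> loaded as a float
```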
40 changes: 40 additions & 0 deletions example_config/MIMICIV-ICD9-50/laat.yml
@@ -0,0 +1,40 @@
# data
training_file: data/MIMICIV-ICD9-50/train.txt
val_file: data/MIMICIV-ICD9-50/valid.txt
test_file: data/MIMICIV-ICD9-50/test.txt
data_name: MIMICIV-ICD9-50
min_vocab_freq: 1
max_seq_length: 4000
include_test_labels: true
remove_no_label_data: true

# train
seed: 0 # 1337
epochs: 50
batch_size: 8
optimizer: adamw
learning_rate: 0.001
weight_decay: 0
eps: 1e-06
patience: 6
shuffle: true

# eval
eval_batch_size: 8
monitor_metrics: ['Loss', 'Another-Macro-F1', 'Macro-F1', 'Micro-F1', 'P@5', 'P@8', 'P@15']
val_metric: Micro-F1

# model
model_name: LAAT
init_weight: null
network_config:
embed_dropout: 0.3
encoder_dropout: 0
rnn_dim: 512 # 512 // 2 = 256 hidden units per LSTM direction
num_layers: 1
d_a: 256

# pretrained vocab / embeddings
vocab_file: data/MIMICIV-ICD9-50/vocab.csv # 282,173, min_vocab_freq=1
embed_file: data/MIMICIV-ICD9-50/processed_full.embed
normalize_embed: false
72 changes: 72 additions & 0 deletions example_config/MIMICIV-ICD9-50/laat_tune.yml
@@ -0,0 +1,72 @@
# data
training_file: data/MIMICIV-ICD9-50/train.txt
val_file: data/MIMICIV-ICD9-50/valid.txt
test_file: data/MIMICIV-ICD9-50/test.txt
data_name: MIMICIV-ICD9-50
min_vocab_freq: 1
max_seq_length: 4000
include_test_labels: true
remove_no_label_data: true
add_special_tokens: false

# train
seed: 0
epochs: 50
batch_size: 8
optimizer: adamw
learning_rate: ['grid_search', [0.001, 0.0003]] # LAAT: 0.001
eps: 0.00000001
momentum: 0.9
weight_decay: 0
patience: 6
early_stopping_metric: Micro-F1
shuffle: true
lr_scheduler: ReduceLROnPlateau
scheduler_config:
factor: 0.9
patience: 5
min_lr: 0.0001

# eval
eval_batch_size: 8
monitor_metrics: ['Another-Macro-F1', 'Macro-F1', 'Micro-F1', 'P@5', 'P@8', 'P@15']
val_metric: Micro-F1

# model
model_name: LAAT
loss_function: binary_cross_entropy_with_logits
init_weight: kaiming_uniform # null
network_config:
embed_dropout: 0.3
encoder_dropout: ['grid_search', [0, 0.2, 0.4]] # LAAT: 0
rnn_dim: ['grid_search', [512, 768, 1024]] # LAAT: 512
num_layers: 1
d_a: ['grid_search', [256, 384, 512]] # LAAT: 256
freeze_embed: false # true

# pretrained vocab / embeddings
vocab_file: data/MIMICIV-ICD9-50/vocab.csv # 282,173, min_vocab_freq=1
embed_file: data/MIMICIV-ICD9-50/processed_full.embed
normalize_embed: false

# hyperparameter search
search_alg: basic_variant
embed_cache_dir: null
num_samples: 1
scheduler: null
no_merge_train_val: true # do not merge train and validation sets to retrain after the search

# other parameters specified in main.py::get_args
checkpoint_path: null
cpu: false
data_workers: 8
eval: false
label_file: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
metric_threshold: 0.5
result_dir: runs
save_k_predictions: 0
silent: true
val_size: 0.2
1 change: 1 addition & 0 deletions libmultilabel/nn/data_utils.py
@@ -396,6 +396,7 @@ def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache=N
if load_embedding_from_file:
# Add UNK embedding
# AttentionXML: np.random.uniform(-1.0, 1.0, embed_size)
# LAAT: np.random.uniform(-0.25, 0.25, embedding_size)
# CAML: np.random.randn(embed_size)
unk_vector = torch.randn(embed_size)
embedding_weights[word_dict[UNK]] = unk_vector
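
The added comment documents how different models initialize the out-of-vocabulary (UNK) vector; the code itself keeps the CAML-style `torch.randn`. A side-by-side sketch of the three strategies, with ranges taken from the comments above and an illustrative embedding size:

```python
# Side-by-side sketch of the UNK-vector initializations noted in the comments.
import numpy as np
import torch

embed_size = 200  # illustrative, e.g. for glove.6B.200d

unk_attentionxml = np.random.uniform(-1.0, 1.0, embed_size)   # AttentionXML
unk_laat = np.random.uniform(-0.25, 0.25, embed_size)         # LAAT
unk_caml = torch.randn(embed_size)                            # CAML (what the code uses)
```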