-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_weak_learners.py
39 lines (31 loc) · 1.55 KB
/
train_weak_learners.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Train one softmax weak learner per pretrained transformer backbone.

For every backbone listed in ``TransformerModel.MODELS`` this script:
  1. loads the pre-computed embedding tensor from disk,
  2. fits a ``SoftMaxLit`` classifier, checkpointing on best val_loss,
  3. reloads the best checkpoint and reports test-split metrics.

The best checkpoint path for each backbone is collected in ``checkpoints``.
"""
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

from helper import load_dataset
from model import TransformerModel, Data, get_dataloaders, SoftMaxLit

BATCH_SIZE = 128
NUM_EPOCH = 300
DEV = False  # embeddings on disk were produced with this same flag in their filenames

# NOTE: T5Tokenizer requires the sentencepiece library; see
# https://stackoverflow.com/questions/65445651/t5tokenizer-requires-the-sentencepiece-library-but-it-was-not-found-in-your-envi
df = load_dataset('../dataset/training.json', test=True)

checkpoints = []  # best checkpoint path per backbone, in MODELS iteration order
for cur_model_name in TransformerModel.MODELS:
    # Pre-computed embeddings for this backbone (written by the pretraining step).
    cur_dataset_x = torch.load(f'pretrained--dev={DEV}--model={cur_model_name}.pt')
    cur_data = Data(df, x=cur_dataset_x)
    cur_dataloaders = get_dataloaders(cur_data, BATCH_SIZE)

    # Softmax head sized to the backbone's embedding dimension; binary output.
    cur_model = SoftMaxLit(TransformerModel.MODELS[cur_model_name]['dim'], 2)
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor='val_loss',
        mode='min',
        filename=f'model={cur_model_name}--dev={DEV}' + '--{epoch}-{step}--{val_loss:.2f}',
    )
    trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=NUM_EPOCH)
    trainer.fit(
        model=cur_model,
        train_dataloaders=cur_dataloaders['train'],
        val_dataloaders=cur_dataloaders['val'],
    )
    checkpoints.append(checkpoint_callback.best_model_path)

    # Reload the lowest-val_loss weights for testing. Call load_from_checkpoint
    # on the class, not the instance — the instance form is deprecated in
    # recent Lightning releases.
    best_model = SoftMaxLit.load_from_checkpoint(
        n_inputs=TransformerModel.MODELS[cur_model_name]['dim'],
        n_outputs=2,
        checkpoint_path=checkpoint_callback.best_model_path,
    )
    trainer.test(best_model, dataloaders=cur_dataloaders['test'])

    # Release the (large) embedding tensors before loading the next backbone.
    del cur_dataset_x
    del cur_data.x
    torch.cuda.empty_cache()