train.py
import torch

# Allow TF32 matmuls on supported GPUs: faster training at slightly reduced precision.
torch.set_float32_matmul_precision('high')


def base_training(trainer, dm, lit_mod, ckpt=None):
    """Fit the Lightning module on the datamodule, then test the best checkpoint."""
    if trainer.logger is not None:
        print()
        print("Logdir:", trainer.logger.log_dir)
        print()

    trainer.fit(lit_mod, datamodule=dm, ckpt_path=ckpt)
    trainer.test(lit_mod, datamodule=dm, ckpt_path='best')


def multi_dm_training(trainer, dm, lit_mod, test_dm=None, test_fn=None, ckpt=None):
    """Fit on `dm`, then optionally evaluate the best checkpoint on `test_dm`."""
    if trainer.logger is not None:
        print()
        print("Logdir:", trainer.logger.log_dir)
        print()

    trainer.fit(lit_mod, datamodule=dm, ckpt_path=ckpt)

    if test_fn is not None:
        if test_dm is None:
            test_dm = dm
        # Propagate the test datamodule's normalization stats to the module.
        lit_mod._norm_stats = test_dm.norm_stats()

        best_ckpt_path = trainer.checkpoint_callback.best_model_path
        # Drop callbacks (e.g. checkpointing) before the final test pass.
        trainer.callbacks = []
        trainer.test(lit_mod, datamodule=test_dm, ckpt_path=best_ckpt_path)

        print("\nBest ckpt score:")
        print(test_fn(lit_mod).to_markdown())
        print("\n###############")