-
Notifications
You must be signed in to change notification settings - Fork 3
/
example_use_train.py
26 lines (19 loc) · 1.41 KB
/
example_use_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from manager import ModelManager
# Example: define an iUNET network through ModelManager and train it.
# The TFRecord files named below must be stored in the format described in readers.py.

# Passed to ModelManager.train() as model_dir (presumably where checkpoints are written).
model_ckpt_dir = 'models'
# Training / validation TFRecord files for fold 0.
train_tfrecord, valid_tfrecord = 'train_fold0.tfrecords', 'validation_fold0.tfrecords'

# Build the network graph. Any setting not given here keeps the default chosen by the
# ModelManager constructor.
manager = ModelManager(
    name='iUNET',
    num_layers=3,
    feature_maps_root=32,
    norm_type='bn',
    n_iterations=2,
)

# Train with the default training settings (see the ModelManager.train() method).
# 'i-bce-topo' selects an iterative loss ('i' is associated with iUNET) made up of
# balanced cross-entropy ('bce') and topological ('topo') terms.
manager.train(
    train_tfrecord=train_tfrecord,
    validation_tfrecord=valid_tfrecord,
    loss_type='i-bce-topo',
    model_dir=model_ckpt_dir,
)

########################################################################################################################
# To define and train an SHN model instead, replace the two calls above with:
# manager = ModelManager(name='SHN', n_modules=2, verbose=False, num_layers=2)
# Here 's-bce-topo' selects the SHN loss variant ('s' is associated with SHN), again made up of
# balanced cross-entropy ('bce') and topological ('topo') terms.
# manager.train(train_tfrecord=train_tfrecord, validation_tfrecord=valid_tfrecord,
#               loss_type='s-bce-topo', model_dir=model_ckpt_dir)