forked from THU-KEG/MAVEN-dataset
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
27 lines (24 loc) · 763 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import tensorflow as tf
import utils
from models import Trigger_Model
import os
from constant import *
# Command-line flags for this script, registered through TensorFlow's flags module.
flags = tf.flags
# --gpu: written verbatim to CUDA_VISIBLE_DEVICES in main(); defaults to "1".
flags.DEFINE_string("gpu", "1", "The GPU to run on")
# --mode: forwarded to Trigger_Model as its last constructor argument; the help
# text names DMCNN and MOGANED as the options.
flags.DEFINE_string("mode", "MOGANED", "DMCNN or MOGANED")
# --eval: when set, main() calls eval_trigger() instead of train_trigger().
flags.DEFINE_bool('eval', False, "Eval or Train")
def main(_):
    """Build the trigger model and either train or evaluate it.

    Reads the parsed command-line flags, exports the GPU selection and the
    TensorFlow log level to the environment, runs the extractor, loads the
    trigger data, and constructs a ``Trigger_Model``. The model is trained
    unless ``--eval`` is set, in which case it is evaluated instead.

    Args:
        _: Unused. ``tf.app.run`` calls ``main`` with one positional
            argument, which this script ignores.
    """
    config = flags.FLAGS

    # Export the device selection and log level before any model is built.
    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # NOTE(review): presumably prepares the data that the loader reads next;
    # confirm against utils.Extractor.
    extractor = utils.Extractor()
    extractor.extract()

    # cut_len is not defined in this file; it is expected to come from the
    # `from constant import *` star import.
    loader = utils.Loader(cut_len)
    t_data = loader.load_trigger()

    trigger = Trigger_Model(t_data, loader.maxlen, loader.wordemb, config.mode)
    if not config.eval:
        trigger.train_trigger()
    else:
        trigger.eval_trigger()
if __name__ == "__main__":
    # tf.app.run() parses the command-line flags and then invokes main().
    tf.app.run()