forked from charlesashby/CharLSTM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
36 lines (28 loc) · 1.08 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import argparse
import importlib
# Maps each supported ``model`` command-line value to the module providing its
# network class. Both modules expose a class named ``LSTM``.
MODEL_MODULES = {
    'lstm': 'lib_model.char_lstm',
    'bidirectional_lstm': 'lib_model.bidirectional_lstm',
}


def build_parser():
    """Return the command-line parser for this script.

    Arguments:
        model: positional; must be a key of ``MODEL_MODULES``.
        --train: flag; ``False`` unless given.
        --sentences: zero or more strings; ``None`` when the option is omitted.
    """
    parser = argparse.ArgumentParser(description='Load a model')
    # Restricting choices makes argparse reject unknown models with a usage
    # error, instead of the script printing a banner and silently doing nothing.
    parser.add_argument('model', choices=sorted(MODEL_MODULES))
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--sentences', nargs='*', default=None)
    return parser


def main(argv=None):
    """Parse arguments, build the selected network, then train and/or predict.

    :param argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    print('Using model: %s' % args.model)
    print('Training: %s' % args.train)
    if args.sentences is not None:
        for value in args.sentences:
            print('processing sentence: %s' % value)

    # Import lazily so only the chosen model's module is loaded.
    network_cls = importlib.import_module(MODEL_MODULES[args.model]).LSTM
    network = network_cls()
    network.build()
    if args.train:
        network.train()
    # `--sentences` given with no values yields [], which is still forwarded.
    if args.sentences is not None:
        network.predict_sentences(args.sentences)


if __name__ == '__main__':
    main()