-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathcue_classifier_builder.py
executable file
·74 lines (62 loc) · 1.92 KB
/
cue_classifier_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright 2021 BBC
# Authors: Chris Newell <[email protected]>
#
# License: Apache-2.0
"""
This application trains and/or evaluates a Citron Cue Classifier.
"""
import logging
import argparse
import os
from citron.cue import CueClassifier
from citron.logger import logger
from citron import utils
def main():
    """
    Parse command-line arguments, then train and/or evaluate a Cue Classifier.

    At least one of ``--train-path`` and ``--test-path`` must be given, and
    ``--verbnet-path`` must accompany ``--train-path`` (as its help text
    states). Training, when requested, runs before evaluation, so a single
    invocation can build a model into ``--model-path`` and then evaluate it.

    Invalid argument combinations are reported through ``parser.error``,
    which prints the usage message and exits with status 2 before the NLP
    parser is created.
    """
    parser = argparse.ArgumentParser(
        description="Build and/or evaluate a Cue Classifier model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-v",
        action="store_true",
        default=False,
        help="Verbose mode",
    )
    parser.add_argument(
        "--train-path",
        metavar="train_path",
        type=str,
        help="Optional: Path to file or directory containing Citron format training data (default: no training)",
    )
    parser.add_argument(
        "--test-path",
        metavar="test_path",
        type=str,
        help="Optional: Path to file or directory containing Citron format test data (default: no testing)",
    )
    parser.add_argument(
        "--model-path",
        metavar="model_path",
        type=str,
        required=True,
        help="Path to the Citron model directory",
    )
    parser.add_argument(
        "--verbnet-path",
        metavar="verbnet_path",
        type=str,
        help="Path to Verbnet directory (required if training)",
    )
    args = parser.parse_args()

    # Validate argument combinations before calling utils.get_parser(), which
    # presumably loads an NLP model and may be slow. Previously, a missing
    # train/test path was only logged after that load, and the process still
    # exited with status 0. parser.error() exits with status 2 instead.
    if not (args.train_path or args.test_path):
        parser.error("Must specify train_path and/or test_path")

    # The help text declares --verbnet-path required for training; enforce
    # it rather than passing None through to build_model.
    if args.train_path and not args.verbnet_path:
        parser.error("--verbnet-path is required when --train-path is given")

    if args.v:
        logger.setLevel(logging.DEBUG)

    nlp = utils.get_parser()

    if args.train_path:
        CueClassifier.build_model(
            nlp,
            args.train_path,
            args.model_path,
            args.verbnet_path,
        )

    if args.test_path:
        # Load the model from model_path; after a training step above, this
        # is presumably the model that was just built.
        cue_classifier = CueClassifier(args.model_path)
        cue_classifier.evaluate(nlp, args.test_path)


if __name__ == "__main__":
    main()