import datetime
import os
import random
from typing import Optional, Type, Union

from seqeval.metrics import accuracy_score, classification_report

from pyseqlab.attributes_extraction import GenericAttributeExtractor
from pyseqlab.features_extraction import FeatureExtractor, FOFeatureExtractor
from pyseqlab.linear_chain_crf import LCRF, LCRFModelRepresentation
from pyseqlab.utilities import TemplateGenerator
from pyseqlab.workflow import GenericTrainingWorkflow

from utils import LENER_DATASET_DIR, PySeqLabSequenceBuilder, adapt_lener_to_pyseqlab

# Prepare the LeNER dataset in the layout expected by PySeqLab.
adapt_lener_to_pyseqlab()


class CRFModel:
    """Linear-chain CRF wrapper built on PySeqLab for the LeNER dataset."""

    def __init__(
        self,
        model_type: Type[LCRF],
        model_representation_type: Type[LCRFModelRepresentation],
        feature_extraction_type: Union[Type[FeatureExtractor], Type[FOFeatureExtractor]],
        working_directory: Union[bytes, str],
    ):
        self._pyseqlab_dataset_dir = os.path.join(LENER_DATASET_DIR, "pyseqlab")
        self._model_type = model_type
        self._model_representation_type = model_representation_type
        self._feature_extraction_type = feature_extraction_type
        self._working_directory = working_directory
        self._data_parser_options = dict(
            header="main", y_ref=True, column_sep=" ", seg_other_symbol="O"
        )
        self._attribute_description = dict(
            w=dict(description="word observation track", encoding="categorical")
        )
        self._generic_attribute_extractor = GenericAttributeExtractor(
            self._attribute_description
        )
        self._data_split_option = dict(method="cross_validation", k_fold=5)
        self._training_workflow: Optional[GenericTrainingWorkflow] = None
        self._model_object: Optional[LCRF] = None
        self._trained_model_dir = None
        self._training_data_split = None

    def train(self, **kwargs):
        epochs = kwargs.get("epochs")
        optimization_option = dict(
            method="SGA-ADADELTA",
            regularization_type="l2",
            regularization_value=0,
            num_epochs=epochs,
        )
        self._build_model()
        # TODO: train using data splits
        if self._training_data_split is not None:
            begin = datetime.datetime.now()
            for fold in self._training_data_split:
                train_sequences_id = self._training_data_split[fold]["train"]
                self._trained_model_dir = self._training_workflow.train_model(
                    train_sequences_id, self._model_object, optimization_option
                )
            end = datetime.datetime.now()
            elapsed_time = end - begin
            print("-" * 100)
            print("Model trained successfully.")
            print(f"Elapsed time: {divmod(elapsed_time.total_seconds(), 60)}")

    def predict(self, sequences, output_file):
        decoding_method = "viterbi"
        return self._model_object.decode_seqs(
            decoding_method=decoding_method,
            out_dir=self._trained_model_dir,
            seqs=sequences,
            file_name=output_file,
            sep="\t",
        )

    def evaluate(self, sequence_type):
        sequences = PySeqLabSequenceBuilder(
            self._pyseqlab_dataset_dir
        ).generate_sequences(sequence_type)
        y_true = [sequence.flat_y for sequence in sequences]
        prediction = self.predict(
            sequences,
            os.path.join(self._pyseqlab_dataset_dir, "output_evaluation.txt"),
        )
        y_pred = [value["Y_pred"] for value in prediction.values()]
        print(y_true)
        print(y_pred)
        print(accuracy_score(y_true, y_pred))
        print(classification_report(y_true, y_pred))

    def _build_model(self):
        assert self._model_object is None
        if self._training_workflow is None:
            self._build_training_workflow()
        training_sequence = PySeqLabSequenceBuilder(
            self._pyseqlab_dataset_dir
        ).generate_sequences("train")
        self._training_data_split = self._training_workflow.seq_parsing_workflow(
            self._data_split_option,
            seqs=random.sample(training_sequence, 1000),
            full_parsing=True,
        )
        self._model_object = self._training_workflow.build_crf_model(
            self._training_data_split[0]["train"], "f_0"
        )
        self._model_object.weights.fill(0)

    def get_model_features(self):
        return dict(
            number=len(self._model_object.model.modelfeatures_codebook),
            features=self._model_object.model.modelfeatures,
        )

    def _build_training_workflow(self):
        assert self._training_workflow is None
        template_xy = {}
        template_gen = TemplateGenerator()
        template_gen.generate_template_XY(
            "w", ("1-gram:2-grams", range(-1, 2)), "1-state", template_xy
        )
        template_y = template_gen.generate_template_Y("1-state:2-states")
        if self._feature_extraction_type == FOFeatureExtractor:
            # The first-order feature extractor additionally takes a start_state flag.
            feature_extractor = self._feature_extraction_type(
                template_xy, template_y, self._attribute_description, start_state=False
            )
        else:
            feature_extractor = self._feature_extraction_type(
                template_xy, template_y, self._attribute_description
            )
        self._training_workflow = GenericTrainingWorkflow(
            self._generic_attribute_extractor,
            feature_extractor,
            None,
            self._model_representation_type,
            self._model_type,
            self._working_directory,
        )
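

# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hypothetical example of how CRFModel might be driven end to end.
# It only reuses classes already imported above (LCRF, LCRFModelRepresentation,
# FOFeatureExtractor); the working directory path, epoch count, and the "test"
# split name are illustrative assumptions, not values taken from the project.
if __name__ == "__main__":
    model = CRFModel(
        model_type=LCRF,
        model_representation_type=LCRFModelRepresentation,
        feature_extraction_type=FOFeatureExtractor,
        working_directory="crf_working_dir",  # hypothetical output directory
    )
    model.train(epochs=5)  # epoch count is an arbitrary example value
    model.evaluate("test")  # assumes the dataset adapter produced a "test" split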