-
Notifications
You must be signed in to change notification settings - Fork 1
/
embedding.py
33 lines (18 loc) · 979 Bytes
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import Type
from kashgari.tasks.seq_labeling import SequenceLabelingModel
from utils import adapt_lener_to_kashgari, LeNerCorpus
# Module-level side effect: this call runs once, when the module is imported,
# before EmbeddingModel is defined or used.
# To train with a bigger amount of files, increase the data percentage.
# NOTE(review): no percentage is passed here, so whatever default the helper
# uses applies -- confirm in utils how that setting is exposed.
adapt_lener_to_kashgari()
class EmbeddingModel:
    """Train and evaluate a sequence-labeling model on ``LeNerCorpus`` data.

    The wrapped model is created from the class handed to the constructor and
    is driven through its ``fit`` and ``evaluate`` methods.
    """

    def __init__(self, model_type: "Type[SequenceLabelingModel]"):
        """Instantiate ``model_type`` with no arguments and keep the instance.

        Args:
            model_type: The model class (not an instance) to wrap.
        """
        self._model = model_type()

    def train(self, **kwargs):
        """Fit the wrapped model, validating against the ``"dev"`` split.

        Training data comes from ``LeNerCorpus.get_sequence_tagging_data()``
        called without arguments (presumably the training split -- confirm
        against the corpus helper).

        Args:
            **kwargs: Only ``epochs`` is read; any other key is ignored.
                When ``epochs`` is absent or ``None`` it is not forwarded,
                so the model's own default number of epochs applies.
        """
        x_train, y_train = LeNerCorpus.get_sequence_tagging_data()
        x_validate, y_validate = LeNerCorpus.get_sequence_tagging_data(data_type="dev")

        # Forward ``epochs`` only when the caller actually supplied one;
        # passing ``epochs=None`` would override the model's default.
        fit_options = {}
        epochs = kwargs.get("epochs")
        if epochs is not None:
            fit_options["epochs"] = epochs

        self._model.fit(
            x_train,
            y_train,
            x_validate=x_validate,
            y_validate=y_validate,
            **fit_options,
        )

    def evaluate(self, data_type: str = "test"):
        """Evaluate the wrapped model on one of the held-out splits.

        Args:
            data_type: Which split to evaluate on, ``"test"`` or ``"dev"``.

        Returns:
            Whatever the wrapped model's ``evaluate`` returns.

        Raises:
            ValueError: If ``data_type`` is neither ``"test"`` nor ``"dev"``.
        """
        # Validate before loading any data so a bad argument fails fast.
        if data_type not in ("test", "dev"):
            raise ValueError(
                f"Wrong data type for evaluation: {data_type!r} "
                "(expected 'test' or 'dev')"
            )
        x, y = LeNerCorpus.get_sequence_tagging_data(data_type=data_type)
        return self._model.evaluate(x, y, debug_info=True)