LstmTwoClassifier.py
from typing import Dict

import torch

from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder
from allennlp.modules.text_field_embedders import TextFieldEmbedder
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy

class LstmTwoClassifier(Model):
    """Two-class sentence classifier: embed the tokens, encode the sequence
    with a Seq2VecEncoder (e.g. an LSTM), and project to two class logits."""

    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)
        # The embedder converts word IDs to their vector representations.
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        # Project the fixed-size encoding down to one score per class.
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=2)
        self.accuracy = CategoricalAccuracy()
        # CrossEntropyLoss applies log-softmax itself, so the projection output
        # must be passed in as raw logits (no sigmoid beforehand).
        self.loss = torch.nn.CrossEntropyLoss()
    def forward(self,
                tokens: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        # Mask out padding positions so they don't contribute to the encoding.
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        encoder_out = self.encoder(embeddings, mask)
        # Raw, unnormalized scores for the two classes.
        logits = self.hidden2tag(encoder_out)
        output = {"logits": logits}
        if label is not None:
            self.accuracy(logits, label)
            # CrossEntropyLoss expects raw logits and integer class labels.
            output["loss"] = self.loss(logits, label)
        return output
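
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        # Standard AllenNLP Model hook: lets the trainer report accuracy
        # after each batch/epoch.
        return {"accuracy": self.accuracy.get_metric(reset)}


# ---------------------------------------------------------------------------
# Minimal wiring sketch, not part of the original file. It assumes the
# AllenNLP 0.9-era API (a text field tensorizes to {"tokens": LongTensor});
# the toy vocabulary, dimensions, and token IDs below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
    from allennlp.modules.token_embedders import Embedding

    EMBEDDING_DIM = 64
    HIDDEN_DIM = 32

    # A toy vocabulary; real code would build one with
    # Vocabulary.from_instances(...) over a dataset.
    vocab = Vocabulary()
    for word in ["this", "movie", "was", "great", "awful"]:
        vocab.add_token_to_namespace(word, namespace="tokens")

    token_embedding = Embedding(num_embeddings=vocab.get_vocab_size("tokens"),
                                embedding_dim=EMBEDDING_DIM)
    word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
    encoder = PytorchSeq2VecWrapper(
        torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
    model = LstmTwoClassifier(word_embeddings, encoder, vocab)

    # One fake batch: two padded token-ID sequences (0 = padding) and labels.
    tokens = {"tokens": torch.tensor([[2, 3, 4, 5], [2, 3, 6, 0]])}
    labels = torch.tensor([1, 0])
    output = model(tokens, labels)
    print(output["loss"], model.get_metrics())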