-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconstraint_utils.py
57 lines (43 loc) · 1.31 KB
/
constraint_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import sys
#sys.path.insert(0, '../')
import torch
from torch import nn
from torch.autograd import Variable
from holder import *
from util import *
def get_label_idx(labels, key):
    """Return the index of the first label in ``labels`` that starts with ``key``.

    Raises:
        ValueError: if no label starts with ``key``.
    """
    found = next((idx for idx, lab in enumerate(labels) if lab.startswith(key)), None)
    if found is None:
        raise ValueError("Label key {0} not present".format(str(key)))
    return found
def parse_constraint_str(s):
    """Parse a constraint string into its left atom, operator, and right atom.

    Expected format is three whitespace-separated tokens, e.g.
    ``"B-NP implies -(I-VP)"``. An atom written as ``-(X)`` is negated
    (multiplier -1); a bare atom ``X`` has multiplier +1.

    Returns:
        ((left_str, left_multiplier), operator, (right_str, right_multiplier))
        where operator is 'subtract' for 'implies'.

    Raises:
        NotImplementedError: if the middle token is not an 'implies' variant.
    """
    parts = s.strip().split()

    def _parse_atom(tok):
        # "-(X)" -> ("X", -1); strip the leading "-(" and trailing ")".
        if '(' in tok:
            return tok[2:-1], -1
        return tok, 1

    left_str, left_multiplier = _parse_atom(parts[0])
    right_str, right_multiplier = _parse_atom(parts[2])

    # Only implication is supported; it is realized as a subtraction
    # of the two atom scores downstream.
    if 'implies' not in parts[1]:
        raise NotImplementedError('Logical operator not implemented')
    operator = 'subtract'

    return (left_str, left_multiplier), operator, (right_str, right_multiplier)
def build_mask(labels, key, is_cuda):
    """Build a 1 x len(labels) float mask over ``labels``.

    Entry i is 1.0 when labels[i] starts with ``key``, else 0.0. The mask
    is wrapped in a non-differentiable Variable and moved to the GPU when
    ``is_cuda`` is true.
    """
    flags = [1.0 if lab.startswith(key) else 0.0 for lab in labels]
    mask = Variable(
        torch.Tensor(flags).view(1, len(labels)),
        requires_grad=False)
    return mask.cuda() if is_cuda else mask