forked from tech-srl/lstar_extraction
-
Notifications
You must be signed in to change notification settings - Fork 1
/
WhiteboxRNNCounterexampleGenerator.py
163 lines (140 loc) · 8.45 KB
/
WhiteboxRNNCounterexampleGenerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from copy import deepcopy
from time import perf_counter
class WhiteboxRNNCounterexampleGenerator:
def __init__(self,network,partitioning,starting_examples):
self.time_limit = None
self.whiteboxrnn = network
self.partitioning = partitioning
self.starting_dict = {cex:network.classify_word(cex) for cex in starting_examples}
return
def set_time_limit(self,time_limit,start_time):
self.time_limit = time_limit
self.start_time = start_time
def _get_counterexample_from(self,words):
words = sorted(words,key=lambda x:len(x)) #prefer shortest possible counterexample
for w in words:
if not self.whiteboxrnn.classify_word(w) == self.proposed_dfa.classify_word(w):
return w
return None
def _counterexample_from_classification_conflict(self,state_info):
res = self._get_counterexample_from(state_info.paths)
if None is res:
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("classification conflict didn't cause counterexample:")
print("check your partitioning is consistent and transition function ")
print("(from one continuous network state (R-State) to another) is correct ")
raise NoCounterexampleFromClassificationConflict()
return res
def _counterexample_from_cluster_conflict(self,old_info,new_info):
q1 =old_info.dfa_state
q2 = new_info.dfa_state
prefixes = old_info.paths + new_info.paths
suffix = self.proposed_dfa.minimal_diverging_suffix(q1,q2)
return self._get_counterexample_from([p+suffix for p in prefixes])
def _process_new_state_except_children(self,new_cluster,new_info):
counterexample = None
split = SplitInfo()
old_info = self.cluster_information[new_cluster] if new_cluster in self.cluster_information else None
full_info = old_info + new_info if not None is old_info else new_info
if not new_info.accepting == (new_info.dfa_state in self.proposed_dfa.F):
counterexample = self._counterexample_from_classification_conflict(new_info)
elif not new_info.dfa_state == full_info.dfa_state:
counterexample = self._counterexample_from_cluster_conflict(old_info,new_info)
if None is counterexample:
split = SplitInfo(agreeing_RStates=old_info.RStates,
conflicted_RState=new_info.RStates[0]) # the one seen now, in new info
else: #no conflicts, store state and continue processing it later
self.cluster_information[new_cluster] = full_info
return counterexample, split
def _add_children_states(self,cluster):
state_info = self.cluster_information[cluster]
if not state_info.explored:
# we explore a state only the first time we successfully visit and process it, and we store a state's
# information in self.cluster_information only if we have successfully processed it.
RState = state_info.RStates[0]
state_info.explored = True
for char in self.proposed_dfa.alphabet:
next_RState, pos = self.whiteboxrnn.get_next_RState(RState, char)
path = state_info.paths[0] + char
# we only ever explore a state the first
# time we find it, so, with the first path in its list of reaching paths
next_dfa_state = self.proposed_dfa.delta[state_info.dfa_state][char]
self.new_RStates.append(UnrollingInfo(next_dfa_state,path,next_RState,pos))
def _process_top_pair(self):
new_info = self.new_RStates.pop(0)
self.new_RStates_backup = new_info #might want to unpop if we refine the partitioning and want to restart from here
new_cluster = self.partitioning.get_partition(new_info.RStates[0])
counterexample, split = self._process_new_state_except_children(new_cluster, new_info)
if (None is counterexample) and (not split.has_info): #i.e. no conflicts
self._add_children_states(new_cluster)
return counterexample, split
def _initialise_unrolling(self):
self.cluster_information = {}
initial_RState, pos = self.whiteboxrnn.get_first_RState()
self.new_RStates = [UnrollingInfo(self.proposed_dfa.q0,"",initial_RState,pos)]
def _cex_from_starting_dict(self,dfa):
for cex in self.starting_dict:
if not dfa.classify_word(cex) == self.starting_dict[cex]:
print("storing provided counterexample of length " + str(len(cex)))
return cex
return None
def _out_of_time(self):
if not None is self.time_limit:
if (perf_counter() - self.start_time) > self.time_limit:
return True
return False
def _split_was_clean(self,old_cluster,split):
new_states_given_to_agreeing = list(set([self.partitioning.get_partition(vec) for vec in split.agreeing_RStates]))
return self.partitioning.refinement_doesnt_hurt_other_clusters \
and new_states_given_to_agreeing == [old_cluster] \
and not self.partitioning.get_partition(split.conflicted_RState) == old_cluster
def counterexample(self,dfa):
print("guided starting equivalence query for DFA of size " + str(len(dfa.Q)))
dfa.draw_nicely(maximum=30)
counterexample = self._cex_from_starting_dict(dfa)
if not None is counterexample:
return counterexample,counterexample_message(counterexample,self.whiteboxrnn)
self.proposed_dfa = dfa
while True: #main loop: restarts every time the partitioning is refined
self._initialise_unrolling() # start BFS exploration of network abstraction with current partitioning
while True: #inner loop: extracts according to the partitioning, comparing to the proposed dfa as it goes
if self._out_of_time(): # note: putting this after all the next bits sometimes leaves the time limit unchecked for a very long time...
return None, "lstar extraction not successful - ran out of time"
if len(self.new_RStates) == 0: # seen everything there is to see here
return None, "lstar successful: unrolling seems equivalent to proposed automaton"
counterexample, split = self._process_top_pair() # always returns a cex, or a split, or neither - but never both
if not None is counterexample:
return counterexample,counterexample_message(counterexample,self.whiteboxrnn)
elif split.has_info:
cluster_being_split = self.partitioning.get_partition(split.agreeing_RStates[0])
self.partitioning.refine(split.agreeing_RStates,split.conflicted_RState)
if self._split_was_clean(cluster_being_split,split):
# the latest R-state got a new cluster of its own and absolutely nothing else changed,
# so we can just reprocess this visitor and continue as if nothing happened
self.new_RStates = [self.new_RStates_backup] + self.new_RStates
else:
print("split wasn't perfect: gotta start over")
break # clustering has changed, have to restart unrolling from the top
def counterexample_message(counterexample,rnn):
return ("returning counterexample of length " + str(len(counterexample)) + ":\t\t" + counterexample +
", this counterexample is " + ("accepted" if rnn.classify_word(counterexample)==True else "rejected") +
" by the given RNN.")
class SplitInfo: #todo: move this to quantisations and just give the whole thing over to the relevant function there, instead of unpacking it to 3 parameters here
def __init__(self,agreeing_RStates=None,conflicted_RState=None):
self.agreeing_RStates = agreeing_RStates
self.conflicted_RState = conflicted_RState
self.has_info = not (None is conflicted_RState)
class UnrollingInfo:
def __init__(self,dfa_state,path,RState,accepting):
self.explored = False
self.dfa_state = dfa_state
self.paths = [path]
self.RStates = [RState]
self.accepting = accepting
def __add__(self, other):
res = deepcopy(self)
res.paths += other.paths
res.RStates += other.RStates
return res
class NoCounterexampleFromClassificationConflict(Exception):
pass