forked from opensbt/opensbt-core
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmnist_simulation.py
144 lines (120 loc) · 5.66 KB
/
mnist_simulation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from typing import List
from mnist import features, predictor, vectorization_tools
from mnist.digit_input import Digit
from mnist.config import EXPECTED_LABEL
from opensbt.simulation.simulator import Simulator, SimulationOutput
import json
from scipy.stats import entropy
import numpy as np
import logging as log
from mnist.mnist_loader import mnist_loader
from mnist.mutations import *
class MnistSimulator(Simulator):
    """Simulator adapter that evaluates individuals by mutating an MNIST digit.

    No physical simulation takes place: each individual encodes mutation
    parameters that are applied to a clone of the problem's first seed digit.
    The mutated digit is classified, and the classification results plus
    digit features are stored in an otherwise empty ``SimulationOutput``.
    """

    do_visualize = True
    sim_time = 2
    time_step = 0.01

    @staticmethod
    def simulate(list_individuals,
                 variable_names,
                 scenario_path: str,
                 sim_time: float,
                 time_step=1,
                 do_visualize=False,
                 **kwargs) -> List[SimulationOutput]:
        """Mutate and classify one digit per individual.

        Args:
            list_individuals: Individuals to evaluate; each must be indexable.
                With 3 variables it is read as ``(extent_1, extent_2, vertex)``.
                With 6 variables it is read as ``(extent_1, extent_2, extent_3,
                extent_4, vertex_1, vertex_2)``. Vertex entries are rounded to
                int.
            variable_names: Only its length is used, to pick the encoding above.
            scenario_path, sim_time, time_step, do_visualize: Unused here.
                Presumably accepted to match the ``Simulator`` interface.
            **kwargs: Must contain ``"problem"``. The problem must provide
                ``seed_digits``, ``min_saturation`` and ``expected_label``.

        Returns:
            One ``SimulationOutput`` per individual, in input order. Its
            ``otherParams["data"]`` holds the evaluation dict and
            ``otherParams["DIG"]`` holds the mutated digit.

        Raises:
            KeyError: If ``"problem"`` is missing from ``kwargs``.
        """
        problem = kwargs["problem"]
        seed_digits = problem.seed_digits
        # Placeholder simulation output: serialized once, then parsed per
        # individual so that every result gets its own independent object.
        simout_json = json.dumps({
            "simTime": 0.0,
            "times": [],
            "location": {},
            "velocity": {},
            "speed": {},
            "acceleration": {},
            "yaw": {},
            "collisions": [],
            "actors": {},
            "otherParams": {},
        })
        results = []
        for ind in list_individuals:
            # Mutations are always applied to the same (first) seed digit.
            new_digit = MnistSimulator._mutate(
                problem, seed_digits[0].clone(), ind, variable_names)
            assert new_digit.seed == problem.seed_digits[0].seed, \
                "mutated digit must originate from the first seed digit"
            data = MnistSimulator._evaluate(problem, new_digit)
            log.info("Individual evaluated and mutated digit created.")
            simout = SimulationOutput.from_json(simout_json)
            simout.otherParams["data"] = data
            simout.otherParams["DIG"] = new_digit
            results.append(simout)
        return results

    @staticmethod
    def _mutate(problem, digit, ind, variable_names):
        """Apply the mutation encoded by ``ind`` to ``digit`` and return the result."""
        n_vars = len(variable_names)
        if n_vars == 3:
            extent_1, extent_2 = ind[0], ind[1]
            vertex = round(ind[2])
            return apply_mutation_index(problem, digit, extent_1, extent_2, vertex)
        if n_vars == 6:
            extent_1, extent_2, extent_3, extent_4 = ind[0], ind[1], ind[2], ind[3]
            vertex_1 = round(ind[4])
            vertex_2 = round(ind[5])
            return apply_mutation_index_bi(problem,
                                           digit,
                                           extent_1,
                                           extent_2,
                                           extent_3,
                                           extent_4,
                                           vertex_1,
                                           vertex_2)
        # Unsupported encoding: the digit is evaluated unmutated (as before),
        # but make the likely misconfiguration visible.
        log.warning("Unsupported number of variables (%d); digit left unmutated.",
                    n_vars)
        return digit

    @staticmethod
    def _evaluate(problem, digit) -> dict:
        """Classify ``digit``, store the prediction on it and collect its metrics."""
        # Classification of the mutated digit (the "simulation").
        predicted_label, confidence = predictor.Predictor.predict(digit.purified)
        predictions = predictor.Predictor.predict_extended(digit.purified)
        digit.predicted_label = predicted_label
        digit.confidence = confidence

        brightness = digit.brightness(min_saturation=problem.min_saturation)
        coverage = digit.coverage(min_saturation=problem.min_saturation)
        coverage_rel = digit.coverage(
            min_saturation=problem.min_saturation,
            relative=True,
        )

        # TODO: pass the archive (and distances to it) into the evaluation data.
        data = {}
        data["predicted_label"] = predicted_label
        data["confidence"] = confidence
        data["predictions"] = predictions
        data["expected_label"] = problem.expected_label
        data["digit"] = digit
        data["coverage"] = coverage
        data["brightness"] = brightness
        data["move_distance"] = features.move_distance(digit)
        data["angle"] = features.angle_calc(digit)
        data["orientation"] = features.orientation_calc(digit, problem.min_saturation)
        # Entropy of the prediction vector, negated when the top-scoring class
        # differs from the expected label.
        pred_entropy = entropy(pk=predictions)
        misclassified = np.argmax(predictions) != problem.expected_label
        data["entropy_signed"] = -pred_entropy if misclassified else pred_entropy
        data["coverage_rel"] = coverage_rel
        return data
def generate_digit(seed):
    """Create a ``Digit`` from the MNIST test image at index ``seed``.

    ``seed`` may be any value accepted by ``int()``; it is used as the index
    into the test set and is also stored on the resulting digit. The digit is
    labelled with the configured ``EXPECTED_LABEL``.
    """
    test_images = mnist_loader.get_x_test()
    vector_desc = vectorization_tools.vectorize(test_images[int(seed)])
    return Digit(vector_desc, EXPECTED_LABEL, seed)
# Create a digit from a test-set seed and attach its prediction (label and confidence) for input validation.
def generate_and_evaluate_digit(seed):
    """Create the digit for ``seed`` and attach the classifier's prediction.

    The digit is built exactly as in ``generate_digit``, which is reused here
    so the two cannot drift apart. The digit's ``predicted_label`` and
    ``confidence`` are then set from ``predictor.Predictor.predict``.

    Returns:
        The created ``Digit`` with its prediction attributes populated.
    """
    digit = generate_digit(seed)
    predicted_label, confidence = predictor.Predictor.predict(digit.purified)
    digit.confidence = confidence
    digit.predicted_label = predicted_label
    return digit