diff --git a/docs/source/conf.py b/docs/source/conf.py index 8be9c268..eaf721c0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -55,7 +55,7 @@ def _normalize_docstring_lines(lines: list[str]) -> list[str]: if not l.strip(): # Blank line reset param is_param_field = False else: # Restore indentation - l = " " + l.lstrip() + l = f" {l.lstrip()}" new_lines.append(l) return new_lines diff --git a/notebooks/common.py b/notebooks/common.py index 01ad4a19..8fa8b4f6 100644 --- a/notebooks/common.py +++ b/notebooks/common.py @@ -19,21 +19,19 @@ def collect_replica_files(path, prefix): begin with the given prefix.""" if not os.path.exists(path): - raise Exception("No path '{}'.".format(path)) + raise Exception(f"No path '{path}'.") results = [] for dir, subdirs, files in os.walk(path): - for file in files: - if file.startswith(prefix): - results.append(os.path.join(dir, file)) - + results.extend( + os.path.join(dir, file) + for file in files + if file.startswith(prefix) + ) return results def load_results(files, load_sut_output=True): - results = [] - for file in files: - results.append(STGEMResult.restore_from_file(file)) - + results = [STGEMResult.restore_from_file(file) for file in files] # This reduces memory usage if these values are not needed. if not load_sut_output: for result in results: @@ -48,7 +46,9 @@ def loadExperiments(path, benchmarks, prefixes): for prefix in prefixes[benchmark]: files = collect_replica_files(os.path.join(path, benchmark), prefix) if len(files) == 0: - raise Exception("Empty experiment for prefix '{}' for benchmark '{}'.".format(prefix, benchmark)) + raise Exception( + f"Empty experiment for prefix '{prefix}' for benchmark '{benchmark}'." + ) experiments[benchmark][prefix] = load_results(files) return experiments @@ -57,10 +57,10 @@ def falsification_rate(experiment): if len(experiment) == 0: return None - c = 0 - for result in experiment: - c += 1 if any(step.success for step in result.step_results) else 0 - + c = sum( + 1 if any(step.success for step in result.step_results) else 0 + for result in experiment + ) return c/len(experiment) def times(replica): @@ -102,17 +102,11 @@ def mean_min_along(results, length=None): A.append(B) A = np.array(A) - C = np.mean(A, axis=0) - - return C + return np.mean(A, axis=0) def first_falsification(replica): _, _, Y = replica.test_repository.get() - for i in range(len(Y)): - if min(Y[i]) <= 0.0: - return i - - return None + return next((i for i in range(len(Y)) if min(Y[i]) <= 0.0), None) def set_boxplot_color(bp, color): plt.setp(bp["boxes"], color=color) @@ -158,10 +152,15 @@ def plotTest(replica, idx): if replica.sut_parameters["input_type"] == "vector": input_type = "vector" - elif replica.sut_parameters["input_type"] == "signal" or replica.sut_parameters["input_type"] == "piecewise constant signal": + elif replica.sut_parameters["input_type"] in [ + "signal", + "piecewise constant signal", + ]: input_type = "signal" else: - raise Exception("Unknown input type '{}'.".format(replica.sut_parameters["input_type"])) + raise Exception( + f"""Unknown input type '{replica.sut_parameters["input_type"]}'.""" + ) output_type = replica.sut_parameters["output_type"] inputs = replica.sut_parameters["inputs"] @@ -213,32 +212,31 @@ def plotTest(replica, idx): # Output. print(", ".join(outputs)) print(Y[idx].outputs) + elif output_type == "signal": + fig, ax = plt.subplots(1, len(outputs), figsize=(10*len(outputs), 10)) + + # Input. + print(", ".join(inputs)) + print(X[idx].input_denormalized) + + # Output. 
+ for i, var in enumerate(outputs): + o = ax[i] if len(outputs) > 1 else ax + o.set_title(var) + o.set_xlim((0, simulation_time)) + o.set_ylim(output_range[i]) + x = Z[idx].output_timestamps + y = Z[idx].outputs[i] + o.plot(x, y) else: - if output_type == "signal": - fig, ax = plt.subplots(1, len(outputs), figsize=(10*len(outputs), 10)) - - # Input. - print(", ".join(inputs)) - print(X[idx].input_denormalized) - - # Output. - for i, var in enumerate(outputs): - o = ax[i] if len(outputs) > 1 else ax - o.set_title(var) - o.set_xlim((0, simulation_time)) - o.set_ylim(output_range[i]) - x = Z[idx].output_timestamps - y = Z[idx].outputs[i] - o.plot(x, y) - else: - # Input. - print(", ".join(inputs)) - print(X[idx].input_denormalized) - print() + # Input. + print(", ".join(inputs)) + print(X[idx].input_denormalized) + print() - # Output. - print(", ".join(outputs)) - print(Y[idx].outputs) + # Output. + print(", ".join(outputs)) + print(Y[idx].outputs) def animateResult(replica): inputs = replica.sut_parameters["inputs"] @@ -329,7 +327,7 @@ def visualize3DTestSuite(experiment, idx): angle = 25 result = experiment[idx] - if not result.sut_parameters["input_type"] == "vector": + if result.sut_parameters["input_type"] != "vector": raise Exception("Test suite visualization available only for vector input SUTs.") X, _, Y = result.test_repository.get() @@ -371,7 +369,7 @@ def visualize3DTestSuite(experiment, idx): # of the robustness values in both classes. interval_false = [1, -1] # Min & Max robustness values interval_persist = [1, -1] - c = list() # List for tracking input indexes that fail the test + c = [] for i in range(len(X)): if (Y[i] <= falsify_pct): diff --git a/notebooks/losses.sync.py b/notebooks/losses.sync.py index f32c9f56..d9108c7c 100644 --- a/notebooks/losses.sync.py +++ b/notebooks/losses.sync.py @@ -230,7 +230,9 @@ def generator_loss_on_batch(model, batch_size): # TODO: This needs to be updated to use the latest performance records. What does this do? A = model.perf.histories["generator_loss"][-1][0] B = model.perf.histories["generator_loss"][-1][-1] -print("training loss: {} -> {}".format(A, B)) -print("noise batch loss: {}".format(generator_loss_on_batch(model, batch_size))) -print("discriminator batch loss: {}".format(discriminator_loss_on_batch(model, batch_size))) +print(f"training loss: {A} -> {B}") +print(f"noise batch loss: {generator_loss_on_batch(model, batch_size)}") +print( + f"discriminator batch loss: {discriminator_loss_on_batch(model, batch_size)}" +) diff --git a/problems/arch-comp-2021/afc/benchmark.py b/problems/arch-comp-2021/afc/benchmark.py index b2c26ef2..8736d8df 100644 --- a/problems/arch-comp-2021/afc/benchmark.py +++ b/problems/arch-comp-2021/afc/benchmark.py @@ -72,9 +72,11 @@ def build_specification(selected_specification, afc_mode="normal"): if selected_specification == "AFC27": E = 0.1 # Used in Ernst et al. #E = 0.05 # Used in ARCH-COMP 2021. 
- rise = "(THROTTLE < 8.8) and (eventually[0,{}](THROTTLE > 40.0))".format(E) - fall = "(THROTTLE > 40.0) and (eventually[0,{}](THROTTLE < 8.8))".format(E) - specification = "always[11,50](({} or {}) -> always[1,5](|MU| < 0.008))".format(rise, fall) + rise = f"(THROTTLE < 8.8) and (eventually[0,{E}](THROTTLE > 40.0))" + fall = f"(THROTTLE > 40.0) and (eventually[0,{E}](THROTTLE < 8.8))" + specification = ( + f"always[11,50](({rise} or {fall}) -> always[1,5](|MU| < 0.008))" + ) specifications = [specification] strict_horizon_check = False @@ -85,7 +87,7 @@ def build_specification(selected_specification, afc_mode="normal"): specifications = [specification] strict_horizon_check = True else: - raise Exception("Unknown specification '{}'.".format(selected_specification)) + raise Exception(f"Unknown specification '{selected_specification}'.") return sut_parameters, specifications, strict_horizon_check @@ -102,15 +104,13 @@ def step_factory(): step_1 = Search(mode=mode, budget_threshold={"executions": 75}, algorithm=Random(model_factory=(lambda: Uniform())) - ) + ) step_2 = Search(mode=mode, budget_threshold={"executions": 300}, algorithm=OGAN(model_factory=(lambda: OGAN_Model(ogan_model_parameters["convolution"])), parameters=ogan_parameters) #algorithm=WOGAN(model_factory=(lambda: WOGAN_Model()) ) - #steps = [step_1] - steps = [step_1, step_2] - return steps + return [step_1, step_2] def get_step_factory(): return step_factory diff --git a/problems/arch-comp-2021/at/benchmark.py b/problems/arch-comp-2021/at/benchmark.py index eafef4a6..8346a6c4 100644 --- a/problems/arch-comp-2021/at/benchmark.py +++ b/problems/arch-comp-2021/at/benchmark.py @@ -159,16 +159,14 @@ def step_factory(): step_1 = Search(mode=mode, budget_threshold={"executions": 75}, algorithm=Random(model_factory=(lambda: Uniform())) - ) + ) step_2 = Search(mode=mode, budget_threshold={"executions": 300}, algorithm=OGAN(model_factory=(lambda: OGAN_Model(ogan_model_parameters["convolution"])), parameters=ogan_parameters), #algorithm=WOGAN(model_factory=(lambda: WOGAN_Model())), results_include_models=False ) - #steps = [step_1] - steps = [step_1, step_2] - return steps + return [step_1, step_2] def get_step_factory(): return step_factory diff --git a/problems/arch-comp-2021/cc/benchmark.py b/problems/arch-comp-2021/cc/benchmark.py index 23dbaf93..ae26a176 100644 --- a/problems/arch-comp-2021/cc/benchmark.py +++ b/problems/arch-comp-2021/cc/benchmark.py @@ -128,16 +128,14 @@ def step_factory(): step_1 = Search(mode=mode, budget_threshold={"executions": 75}, algorithm=Random(model_factory=(lambda: Uniform())) - ) + ) step_2 = Search(mode=mode, budget_threshold={"executions": 300}, algorithm=OGAN(model_factory=(lambda: OGAN_Model(ogan_model_parameters["convolution"])), parameters=ogan_parameters), #algorithm=WOGAN(model_factory=(lambda: WOGAN_Model())), results_include_models=False ) - #steps = [step_1] - steps = [step_1, step_2] - return steps + return [step_1, step_2] def get_step_factory(): return step_factory diff --git a/problems/arch-comp-2021/check_specification_correctness.py b/problems/arch-comp-2021/check_specification_correctness.py index c2770027..e4e31d65 100644 --- a/problems/arch-comp-2021/check_specification_correctness.py +++ b/problems/arch-comp-2021/check_specification_correctness.py @@ -48,15 +48,17 @@ """ - for benchmark in data: - module = importlib.import_module("{}.benchmark".format(benchmark.lower())) + for benchmark, value in data.items(): + module = importlib.import_module(f"{benchmark.lower()}.benchmark") 
build_specification = module.build_specification - for specification, ((idx, robustness, precision), test) in data[benchmark].items(): + for specification, ((idx, robustness, precision), test) in value.items(): sut_parameters, specifications, strict_horizon_check = build_specification(specification) - if "type" in sut_parameters and sut_parameters["type"] == "simulink": - sut = Matlab_Simulink(sut_parameters) - else: - sut = Matlab(sut_parameters) + sut = ( + Matlab_Simulink(sut_parameters) + if "type" in sut_parameters + and sut_parameters["type"] == "simulink" + else Matlab(sut_parameters) + ) sut.setup() objectives = [FalsifySTL(specification=s, scale=False, strict_horizon_check=strict_horizon_check) for s in specifications] for objective in objectives: @@ -67,7 +69,9 @@ output = [objective(sut_input, sut_output) for objective in objectives] if abs(output[idx] - robustness) >= 10**(-precision): - raise SystemExit("Incorrect output robustness {} for benchmark '{}' and specification '{}'. Expected {}.".format(output[idx], benchmark, specification, robustness)) + raise SystemExit( + f"Incorrect output robustness {output[idx]} for benchmark '{benchmark}' and specification '{specification}'. Expected {robustness}." + ) print("All correct!") diff --git a/problems/arch-comp-2021/create_validation_data.py b/problems/arch-comp-2021/create_validation_data.py index a98b7481..528af0a1 100644 --- a/problems/arch-comp-2021/create_validation_data.py +++ b/problems/arch-comp-2021/create_validation_data.py @@ -22,14 +22,21 @@ def main(selected_benchmark, selected_specification, mode, n, init_seed, identifier): N = n - if not selected_specification in benchmark_specifications[selected_benchmark]: - raise Exception("No specification '{}' for benchmark {}.".format(selected_specification, selected_benchmark)) - - output_file = "{}.npy.gz".format(identifier) + if ( + selected_specification + not in benchmark_specifications[selected_benchmark] + ): + raise Exception( + f"No specification '{selected_specification}' for benchmark {selected_benchmark}." 
+ ) + + output_file = f"{identifier}.npy.gz" if os.path.exists(output_file): - raise Exception("Output file '{}' already exists.".format(output_file)) + raise Exception(f"Output file '{output_file}' already exists.") - benchmark_module = importlib.import_module("{}.benchmark".format(selected_benchmark.lower())) + benchmark_module = importlib.import_module( + f"{selected_benchmark.lower()}.benchmark" + ) sut_parameters, specifications, strict_horizon_check = benchmark_module.build_specification(selected_specification, mode) @@ -38,9 +45,10 @@ def main(selected_benchmark, selected_specification, mode, n, init_seed, identif else: sut = Matlab(sut_parameters) - ranges = {} - for n in range(len(sut_parameters["input_range"])): - ranges[sut_parameters["inputs"][n]] = sut_parameters["input_range"][n] + ranges = { + sut_parameters["inputs"][n]: sut_parameters["input_range"][n] + for n in range(len(sut_parameters["input_range"])) + } for n in range(len(sut_parameters["output_range"])): ranges[sut_parameters["outputs"][n]] = sut_parameters["output_range"][n] diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/Autopilot.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/Autopilot.py index d089c8c2..645ed28c 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/Autopilot.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/Autopilot.py @@ -54,12 +54,10 @@ def get_u_ref(self, t, x_f16): Nz, ps, Ny_r, throttle = self.get_u_ref(t, x_f16) - assert Nz <= self.ctrlLimits.NzMax, "autopilot commanded too low Nz ({})".format(Nz) - assert Nz >= self.ctrlLimits.NzMin, "autopilot commanded too high Nz ({})".format(Nz) + assert Nz <= self.ctrlLimits.NzMax, f"autopilot commanded too low Nz ({Nz})" + assert Nz >= self.ctrlLimits.NzMin, f"autopilot commanded too high Nz ({Nz})" - u_ref = np.array([Nz, ps, Ny_r, throttle], dtype=float) - - return u_ref + return np.array([Nz, ps, Ny_r, throttle], dtype=float) def get_num_integrators(self): 'get the number of integrators in the autopilot' @@ -100,9 +98,9 @@ def advance_discrete_state(self, t, x_f16): eps_phi = deg2rad(5) # Max roll angle magnitude before pulling g's eps_p = deg2rad(1) # Max roll rate magnitude before pulling g's path_goal = deg2rad(0) # Final desired path angle - man_start = 2 # maneuver starts after 2 seconds - if self.state == GcasAutopilot.STATE_START: + man_start = 2 # maneuver starts after 2 seconds + if t >= man_start: self.state = GcasAutopilot.STATE_ROLL rv = True diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/BuildLqrControllers.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/BuildLqrControllers.py index 950c98fa..8628843d 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/BuildLqrControllers.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/BuildLqrControllers.py @@ -60,11 +60,7 @@ def subindex_mat(mat, rows, cols): rv = [] for row in rows: - vals = [] - - for col in cols: - vals.append(mat[row, col]) - + vals = [mat[row, col] for col in cols] rv.append(vals) return np.array(rv, dtype=float) @@ -165,7 +161,7 @@ def build_a_b_tilde_longitudinal(A, B, C, D, _): Atilde[A_con.shape[0]:, :C_con.shape[1]] = C_con # stack B and D - Btilde = np.array([row for row in B_con] + [row for row in D_con], dtype=float) + Btilde = np.array(list(B_con) + list(D_con), dtype=float) return Atilde, Btilde @@ -186,12 +182,12 @@ def build_a_b_tilde_lateral(A, B, C, D, xequil): C_top = subindex_mat(C, [Y_P], si) + subindex_mat(C, [Y_R], si) * xequil[X_ALPHA] # ps 
~= p + r*alpha C_bottom = subindex_mat(C, [Y_AY], si) + subindex_mat(C, [Y_R], si) # Ny + r - C_con = np.array([row for row in C_top] + [row for row in C_bottom], dtype=float) # stack C_top and C_bottom + C_con = np.array(list(C_top) + list(C_bottom), dtype=float) D_top = subindex_mat(D, [Y_P], ii) + subindex_mat(D, [Y_R], ii) * xequil[X_ALPHA] # ps ~= p + r*alpha D_bottom = subindex_mat(D, [Y_AY], ii) + subindex_mat(D, [Y_R], ii) # Ny + r - D_con = np.array([row for row in D_top] + [row for row in D_bottom], dtype=float) # stack D_top and D_bottom + D_con = np.array(list(D_top) + list(D_bottom), dtype=float) # stack A and C in a square matrix size = A_con.shape[0] + C_con.shape[0] @@ -200,7 +196,7 @@ def build_a_b_tilde_lateral(A, B, C, D, xequil): Atilde[A_con.shape[0]:, :C_con.shape[1]] = C_con # stack B and D - Btilde = np.array([row for row in B_con] + [row for row in D_con], dtype=float) + Btilde = np.array(list(B_con) + list(D_con), dtype=float) return Atilde, Btilde diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/CtrlLimits.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/CtrlLimits.py index 32537490..2e06907d 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/CtrlLimits.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/CtrlLimits.py @@ -30,13 +30,18 @@ def check(self): ctrlLimits = self - assert not (ctrlLimits.ThrottleMin < 0 or ctrlLimits.ThrottleMax > 1), 'ctrlLimits: Throttle Limits (0 to 1)' + assert ( + ctrlLimits.ThrottleMin >= 0 and ctrlLimits.ThrottleMax <= 1 + ), 'ctrlLimits: Throttle Limits (0 to 1)' - assert not (ctrlLimits.ElevatorMaxDeg > 25 or ctrlLimits.ElevatorMinDeg < -25), \ - 'ctrlLimits: Elevator Limits (-25 deg to 25 deg)' + assert ( + ctrlLimits.ElevatorMaxDeg <= 25 and ctrlLimits.ElevatorMinDeg >= -25 + ), 'ctrlLimits: Elevator Limits (-25 deg to 25 deg)' - assert not (ctrlLimits.AileronMaxDeg > 21.5 or ctrlLimits.AileronMinDeg < -21.5), \ - 'ctrlLimits: Aileron Limits (-21.5 deg to 21.5 deg)' + assert ( + ctrlLimits.AileronMaxDeg <= 21.5 and ctrlLimits.AileronMinDeg >= -21.5 + ), 'ctrlLimits: Aileron Limits (-21.5 deg to 21.5 deg)' - assert not (ctrlLimits.RudderMaxDeg > 30 or ctrlLimits.RudderMinDeg < -30), \ - 'ctrlLimits: Rudder Limits (-30 deg to 30 deg)' + assert ( + ctrlLimits.RudderMaxDeg <= 30 and ctrlLimits.RudderMinDeg >= -30 + ), 'ctrlLimits: Rudder Limits (-30 deg to 30 deg)' diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/LowLevelController.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/LowLevelController.py index 8d8b4add..2978f2a3 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/LowLevelController.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/LowLevelController.py @@ -61,7 +61,7 @@ def get_u_deg(self, u_ref, f16_state): u_deg[0] = u_ref[3] # Add in equilibrium control - u_deg[0:4] += self.uequil + u_deg[:4] += self.uequil ## Limit controls to saturation limits ctrlLimits = self.ctrlLimits diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/PassFailAutomaton.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/PassFailAutomaton.py index f55e6433..3c846e44 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/PassFailAutomaton.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/PassFailAutomaton.py @@ -37,13 +37,17 @@ def check(self): flightLimits = self - assert not (flightLimits.vMin < 300 or flightLimits.vMax > 2500), \ - 'flightLimits: Airspeed limits outside model limits 
(300 to 2500)' + assert ( + flightLimits.vMin >= 300 and flightLimits.vMax <= 2500 + ), 'flightLimits: Airspeed limits outside model limits (300 to 2500)' - assert not (flightLimits.alphaMinDeg < -10 or flightLimits.alphaMaxDeg > 45), \ - 'flightLimits: Alpha limits outside model limits (-10 to 45)' + assert ( + flightLimits.alphaMinDeg >= -10 and flightLimits.alphaMaxDeg <= 45 + ), 'flightLimits: Alpha limits outside model limits (-10 to 45)' - assert not (abs(flightLimits.betaMaxDeg) > 30), 'flightLimits: Beta limit outside model limits (30 deg)' + assert ( + abs(flightLimits.betaMaxDeg) <= 30 + ), 'flightLimits: Beta limit outside model limits (30 deg)' class PassFailAutomaton(Freezable): '''The parent class for a pass fail automaton... checks each state against the flight envelope limits''' @@ -95,7 +99,7 @@ def result(self): returns True iff all conditions passed ''' - return all([pfa.result() for pfa in self.pfa_list]) + return all(pfa.result() for pfa in self.pfa_list) class FlightLimitsPFA(PassFailAutomaton): '''An automaton that checks the flight limits at each step''' diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/adc.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/adc.py index a70390eb..98df2c49 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/adc.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/adc.py @@ -16,11 +16,7 @@ def adc(vt, alt): ro = 2.377e-3 tfac = 1 - .703e-5 * alt - if alt >= 35000: # in stratosphere - t = 390 - else: - t = 519 * tfac # 3 rankine per atmosphere (3 rankine per 1000 ft) - + t = 390 if alt >= 35000 else 519 * tfac # rho = freestream mass density rho = ro * tfac**4.14 diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/controlledF16.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/controlledF16.py index 022a6147..79162e0a 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/controlledF16.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/controlledF16.py @@ -17,10 +17,12 @@ def controlledF16(t, x_f16, F16_model, autopilot, llc, multipliers=None): 'returns the LQR-controlled F-16 state derivatives and more' assert isinstance(x_f16, np.ndarray) - assert isinstance(autopilot, Autopilot), "autopilot type was {}".format(type(autopilot)) + assert isinstance( + autopilot, Autopilot + ), f"autopilot type was {type(autopilot)}" assert isinstance(llc, LowLevelController) - assert F16_model == 'stevens' or F16_model == 'morelli', 'Unknown F16_model: {}'.format(F16_model) + assert F16_model in ['stevens', 'morelli'], f'Unknown F16_model: {F16_model}' # Get Reference Control Vector (commanded Nz, ps, Ny + r, throttle) u_ref = autopilot.get_u_ref(t, x_f16) # in g's & rads / sec @@ -29,7 +31,9 @@ def controlledF16(t, x_f16, F16_model, autopilot, llc, multipliers=None): # Note: Control vector (u) for subF16 is in units of degrees - xd_model, Nz, Ny, _, _ = subf16_model(x_f16[0:13], u_deg, F16_model, multipliers=multipliers) + xd_model, Nz, Ny, _, _ = subf16_model( + x_f16[:13], u_deg, F16_model, multipliers=multipliers + ) # Nonlinear (Actual): ps = p * cos(alpha) + r * sin(alpha) ps = x_ctrl[4] * cos(x_ctrl[0]) + x_ctrl[5] * sin(x_ctrl[0]) @@ -58,6 +62,6 @@ def controlledF16(t, x_f16, F16_model, autopilot, llc, multipliers=None): for i in xrange(1, 4): u_rad[i] = deg2rad(u_deg[i]) - u_rad[4:7] = u_ref[0:3] + u_rad[4:7] = u_ref[:3] return xd, u_rad, Nz, ps, Ny_r diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/cx.py 
b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/cx.py index 29025a43..5113b446 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/cx.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/cx.py @@ -45,6 +45,4 @@ def cx(alpha, el): u = a[k-1, n-1] v = t + abs(da) * (a[l-1, m-1] - t) w = u + abs(da) * (a[l-1, n-1] - u) - cxx = v + (w - v) * abs(de) - - return cxx + return v + (w - v) * abs(de) diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/pdot.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/pdot.py index ca4fd6c0..a9f49408 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/pdot.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/pdot.py @@ -16,14 +16,11 @@ def pdot(p3, p1): else: p2 = 60 t = rtau(p2 - p3) + elif p3 >= 50: + t = 5 + p2 = 40 else: - if p3 >= 50: - t = 5 - p2 = 40 - else: - p2 = p1 - t = rtau(p2 - p3) - - pd = t * (p2 - p3) + p2 = p1 + t = rtau(p2 - p3) - return pd + return t * (p2 - p3) diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/plot.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/plot.py index c9fb8094..db84401b 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/plot.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/plot.py @@ -254,7 +254,7 @@ def plot2d(filename, times, plot_data_list): and each var_data is a list of tuples: (state_index, label) ''' - num_plots = sum([len(var_data) for _, var_data in plot_data_list]) + num_plots = sum(len(var_data) for _, var_data in plot_data_list) fig = plt.figure(figsize=(7, 5)) diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/rtau.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/rtau.py index a2f00d6e..5f64c2e4 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/rtau.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/rtau.py @@ -9,10 +9,8 @@ def rtau(dp): 'rtau function' if dp <= 25: - rt = 1.0 + return 1.0 elif dp >= 50: - rt = .1 + return .1 else: - rt = 1.9 - .036 * dp - - return rt + return 1.9 - .036 * dp diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/tgear.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/tgear.py index 8c0fe473..d15085c0 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/tgear.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/tgear.py @@ -6,9 +6,4 @@ def tgear(thtl): 'tgear function' - if thtl <= .77: - tg = 64.94 * thtl - else: - tg = 217.38 * thtl - 117.38 - - return tg + return 64.94 * thtl if thtl <= .77 else 217.38 * thtl - 117.38 diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/thrust.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/thrust.py index 4a39ba37..293c07cf 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/thrust.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/thrust.py @@ -66,11 +66,9 @@ def thrust(power, alt, rmach): s = a[i, m] * cdh + a[i + 1, m] * dh t = a[i, m + 1] * cdh + a[i + 1, m + 1] * dh tidl = s + (t - s) * dm - thrst = tidl + (tmil - tidl) * power * .02 + return tidl + (tmil - tidl) * power * .02 else: s = c[i, m] * cdh + c[i + 1, m] * dh t = c[i, m + 1] * cdh + c[i + 1, m + 1] * dh tmax = s + (t - s) * dm - thrst = tmil + (tmax - tmil) * (power - 50) * .02 - - return thrst + return tmil + (tmax - tmil) * (power - 50) * .02 diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/util.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/util.py 
index 6d5bacbd..764b06e8 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/util.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v1/code/util.py @@ -16,7 +16,9 @@ def freeze_attrs(self): def __setattr__(self, key, value): if self._frozen and not hasattr(self, key): - raise TypeError("{} does not contain attribute '{}' (object was frozen)".format(self, key)) + raise TypeError( + f"{self} does not contain attribute '{key}' (object was frozen)" + ) object.__setattr__(self, key, value) @@ -89,22 +91,15 @@ def fix(ele): assert isinstance(ele, float) - if ele > 0: - rv = int(floor(ele)) - else: - rv = int(ceil(ele)) - - return rv + return int(floor(ele)) if ele > 0 else int(ceil(ele)) def sign(ele): 'sign of a number' if ele < 0: - rv = -1 + return -1 elif ele == 0: - rv = 0 + return 0 else: - rv = 1 - - return rv + return 1 diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/examples/gcas/gcas_autopilot.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/examples/gcas/gcas_autopilot.py index 5ce8151e..56a5007e 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/examples/gcas/gcas_autopilot.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/examples/gcas/gcas_autopilot.py @@ -71,12 +71,7 @@ def advance_discrete_mode(self, t, x_f16): if self.is_nose_high_enough(x_f16) and t >= self.pull_start_time + self.cfg_min_pull_time: self.mode = 'standby' - rv = premode != self.mode - - #if rv: - # self.log(f"GCAS transition {premode} -> {self.mode} at time {t}") - - return rv + return premode != self.mode def are_wings_level(self, x_f16): 'are the wings level?' @@ -117,16 +112,14 @@ def get_u_ref(self, _t, x_f16): '''get the reference input signals''' if self.mode == 'standby': - rv = np.zeros(4) + return np.zeros(4) elif self.mode == 'waiting': - rv = self.waiting_cmd + return self.waiting_cmd elif self.mode == 'roll': - rv = self.roll_wings_level(x_f16) + return self.roll_wings_level(x_f16) else: assert self.mode == 'pull', f"unknown mode: {self.mode}" - rv = self.pull_nose_level() - - return rv + return self.pull_nose_level() def pull_nose_level(self): 'get commands in mode PULL' diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/examples/waypoint/waypoint_autopilot.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/examples/waypoint/waypoint_autopilot.py index 0c9575ff..ada5b1ba 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/examples/waypoint/waypoint_autopilot.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/examples/waypoint/waypoint_autopilot.py @@ -84,10 +84,7 @@ def get_u_ref(self, _t, x_f16): nz_cmd = max(self.cfg_min_nz_cmd, min(self.cfg_max_nz_cmd, nz_cmd)) throttle = max(min(throttle, 1), 0) - # Create reference vector - rv = [nz_cmd, ps_cmd, 0, throttle] - - return rv + return [nz_cmd, ps_cmd, 0, throttle] def track_altitude(self, x_f16): 'get nz to track altitude, taking turning into account' @@ -104,15 +101,13 @@ def track_altitude(self, x_f16): if h_error > 0: # Ascend wings level or banked - nz = nz_alt + nz_roll + return nz_alt + nz_roll elif abs(phi) < np.deg2rad(15): # Descend wings (close enough to) level - nz = nz_alt + nz_roll + return nz_alt + nz_roll else: # Descend in bank (no negative Gs) - nz = max(0, nz_alt + nz_roll) - - return nz + return max(0, nz_alt + nz_roll) def get_phi_to_track_heading(self, x_f16, psi_cmd): 'get phi from psi_cmd' @@ -144,20 +139,14 @@ def 
track_roll_angle(self, x_f16, phi_cmd): phi = x_f16[StateIndex.PHI] p = x_f16[StateIndex.P] - # Calculate PD control - ps = (phi_cmd-phi) * self.cfg_k_prop_phi - p * self.cfg_k_der_phi - - return ps + return (phi_cmd-phi) * self.cfg_k_prop_phi - p * self.cfg_k_der_phi def track_airspeed(self, x_f16): 'get throttle command' vt_cmd = self.cfg_airspeed - # Proportional control on airspeed using throttle - throttle = self.cfg_k_vt * (vt_cmd - x_f16[StateIndex.VT]) - - return throttle + return self.cfg_k_vt * (vt_cmd - x_f16[StateIndex.VT]) def track_altitude_wings_level(self, x_f16): 'get nz to track altitude' @@ -174,17 +163,12 @@ def track_altitude_wings_level(self, x_f16): gamma = get_path_angle(x_f16) h_dot = vt * sin(gamma) # Calculated, not differentiated - # Calculate Nz command - nz = self.cfg_k_alt*h_error - self.cfg_k_h_dot*h_dot - - return nz + return self.cfg_k_alt*h_error - self.cfg_k_h_dot*h_dot def is_finished(self, t, x_f16): 'is the maneuver done?' - rv = self.waypoint_index >= len(self.waypoints) and self.done_time + 5.0 < t - - return rv + return self.waypoint_index >= len(self.waypoints) and self.done_time + 5.0 < t def advance_discrete_mode(self, t, x_f16): ''' @@ -239,7 +223,7 @@ def get_waypoint_data(self, x_f16): heading = wrap_to_pi(pi/2 - atan2(delta[1], delta[0])) - horiz_range = np.linalg.norm(delta[0:2]) + horiz_range = np.linalg.norm(delta[:2]) vert_range = np.linalg.norm(delta[2]) return heading, inclination, horiz_range, vert_range, slant_range @@ -252,12 +236,7 @@ def get_nz_for_level_turn_ol(x_f16): # Calculate theta phi = x_f16[StateIndex.PHI] - if abs(phi): # if cos(phi) ~= 0, basically - nz = 1 / cos(phi) - 1 # Keeps plane at altitude - else: - nz = 0 - - return nz + return 1 / cos(phi) - 1 if abs(phi) else 0 def get_path_angle(x_f16): 'get the path angle gamma' @@ -267,11 +246,11 @@ def get_path_angle(x_f16): phi = x_f16[StateIndex.PHI] # Roll anle (rad) theta = x_f16[StateIndex.THETA] # Pitch angle (rad) - gamma = asin((cos(alpha)*sin(theta)- \ - sin(alpha)*cos(theta)*cos(phi))*cos(beta) - \ - (cos(theta)*sin(phi))*sin(beta)) - - return gamma + return asin( + (cos(alpha) * sin(theta) - sin(alpha) * cos(theta) * cos(phi)) + * cos(beta) + - (cos(theta) * sin(phi)) * sin(beta) + ) def wrap_to_pi(psi_rad): '''handle angle wrapping diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/highlevel/controlled_f16.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/highlevel/controlled_f16.py index c7ee1159..849a2377 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/highlevel/controlled_f16.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/highlevel/controlled_f16.py @@ -19,12 +19,12 @@ def controlled_f16(t, x_f16, u_ref, llc, f16_model='morelli', v2_integrators=Fal assert isinstance(llc, LowLevelController) assert u_ref.size == 4 - assert f16_model in ['stevens', 'morelli'], 'Unknown F16_model: {}'.format(f16_model) + assert f16_model in ['stevens', 'morelli'], f'Unknown F16_model: {f16_model}' x_ctrl, u_deg = llc.get_u_deg(u_ref, x_f16) # Note: Control vector (u) for subF16 is in units of degrees - xd_model, Nz, Ny, _, _ = subf16_model(x_f16[0:13], u_deg, f16_model) + xd_model, Nz, Ny, _, _ = subf16_model(x_f16[:13], u_deg, f16_model) if v2_integrators: # integrators from matlab v2 model @@ -55,6 +55,6 @@ def controlled_f16(t, x_f16, u_ref, llc, f16_model='morelli', v2_integrators=Fal for i in range(1, 4): u_rad[i] = deg2rad(u_deg[i]) - u_rad[4:7] = u_ref[0:3] # inner-loop 
commands are 4-7 + u_rad[4:7] = u_ref[:3] return xd, u_rad, Nz, ps, Ny_r diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/adc.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/adc.py index a70390eb..98df2c49 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/adc.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/adc.py @@ -16,11 +16,7 @@ def adc(vt, alt): ro = 2.377e-3 tfac = 1 - .703e-5 * alt - if alt >= 35000: # in stratosphere - t = 390 - else: - t = 519 * tfac # 3 rankine per atmosphere (3 rankine per 1000 ft) - + t = 390 if alt >= 35000 else 519 * tfac # rho = freestream mass density rho = ro * tfac**4.14 diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/cx.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/cx.py index c23a96cb..afbf362e 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/cx.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/cx.py @@ -45,6 +45,4 @@ def cx(alpha, el): u = a[k-1, n-1] v = t + abs(da) * (a[l-1, m-1] - t) w = u + abs(da) * (a[l-1, n-1] - u) - cxx = v + (w - v) * abs(de) - - return cxx + return v + (w - v) * abs(de) diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/low_level_controller.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/low_level_controller.py index d472bf30..ae318221 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/low_level_controller.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/low_level_controller.py @@ -79,7 +79,7 @@ def get_u_deg(self, u_ref, f16_state): u_deg[0] = u_ref[3] # Add in equilibrium control - u_deg[0:4] += self.uequil + u_deg[:4] += self.uequil ## Limit controls to saturation limits ctrlLimits = self.ctrlLimits diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/pdot.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/pdot.py index dbfb1c32..3392a981 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/pdot.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/pdot.py @@ -16,14 +16,11 @@ def pdot(p3, p1): else: p2 = 60 t = rtau(p2 - p3) + elif p3 >= 50: + t = 5 + p2 = 40 else: - if p3 >= 50: - t = 5 - p2 = 40 - else: - p2 = p1 - t = rtau(p2 - p3) - - pd = t * (p2 - p3) + p2 = p1 + t = rtau(p2 - p3) - return pd + return t * (p2 - p3) diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/rtau.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/rtau.py index a2f00d6e..5f64c2e4 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/rtau.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/rtau.py @@ -9,10 +9,8 @@ def rtau(dp): 'rtau function' if dp <= 25: - rt = 1.0 + return 1.0 elif dp >= 50: - rt = .1 + return .1 else: - rt = 1.9 - .036 * dp - - return rt + return 1.9 - .036 * dp diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/tgear.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/tgear.py index 8c0fe473..d15085c0 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/tgear.py +++ 
b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/tgear.py @@ -6,9 +6,4 @@ def tgear(thtl): 'tgear function' - if thtl <= .77: - tg = 64.94 * thtl - else: - tg = 217.38 * thtl - 117.38 - - return tg + return 64.94 * thtl if thtl <= .77 else 217.38 * thtl - 117.38 diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/thrust.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/thrust.py index 3de5f8b1..9d20cb6c 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/thrust.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/lowlevel/thrust.py @@ -66,11 +66,9 @@ def thrust(power, alt, rmach): s = a[i, m] * cdh + a[i + 1, m] * dh t = a[i, m + 1] * cdh + a[i + 1, m + 1] * dh tidl = s + (t - s) * dm - thrst = tidl + (tmil - tidl) * power * .02 + return tidl + (tmil - tidl) * power * .02 else: s = c[i, m] * cdh + c[i + 1, m] * dh t = c[i, m + 1] * cdh + c[i + 1, m + 1] * dh tmax = s + (t - s) * dm - thrst = tmil + (tmax - tmil) * (power - 50) * .02 - - return thrst + return tmil + (tmax - tmil) * (power - 50) * .02 diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/run_f16_sim.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/run_f16_sim.py index 446e65a5..535cd5d8 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/run_f16_sim.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/run_f16_sim.py @@ -120,12 +120,12 @@ def run_f16_sim(initial_state, tmax, ap, step=1/30, extended_states=False, model assert 'finished' in integrator.status - res = {} - res['status'] = integrator.status - res['times'] = times - res['states'] = np.array(states, dtype=float) - res['modes'] = modes - + res = { + 'status': integrator.status, + 'times': times, + 'states': np.array(states, dtype=float), + 'modes': modes, + } if extended_states: res['xd_list'] = xd_list res['ps_list'] = ps_list diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/util.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/util.py index d2f898fd..878bdc2a 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/util.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/util.py @@ -43,7 +43,9 @@ def freeze_attrs(self): def __setattr__(self, key, value): if self._frozen and not hasattr(self, key): - raise TypeError("{} does not contain attribute '{}' (object was frozen)".format(self, key)) + raise TypeError( + f"{self} does not contain attribute '{key}' (object was frozen)" + ) object.__setattr__(self, key, value) @@ -135,10 +137,10 @@ def printmat(mat, main_label, row_label_str, col_label_str): width = 7 - width = max(width, max([len(l) for l in col_labels])) + width = max(width, max(len(l) for l in col_labels)) if row_labels is not None: - width = max(width, max([len(l) for l in row_labels])) + width = max(width, max(len(l) for l in row_labels)) width += 1 @@ -156,9 +158,9 @@ def printmat(mat, main_label, row_label_str, col_label_str): print('') if row_labels is not None: - assert len(row_labels) == mat.shape[0], \ - "row labels (len={}) expected one element for each row of the matrix ({})".format( \ - len(row_labels), mat.shape[0]) + assert ( + len(row_labels) == mat.shape[0] + ), f"row labels (len={len(row_labels)}) expected one element for each row of the matrix ({mat.shape[0]})" for r in range(mat.shape[0]): row = mat[r] @@ -183,24 +185,17 @@ def 
fix(ele): assert isinstance(ele, float) - if ele > 0: - rv = int(floor(ele)) - else: - rv = int(ceil(ele)) - - return rv + return int(floor(ele)) if ele > 0 else int(ceil(ele)) def sign(ele): 'sign of a number' if ele < 0: - rv = -1 + return -1 elif ele == 0: - rv = 0 + return 0 else: - rv = 1 - - return rv + return 1 def extract_single_result(res, index, llc): 'extract a res object for a sinlge aircraft from a multi-aircraft simulation' @@ -212,14 +207,15 @@ def extract_single_result(res, index, llc): assert index == 0 rv = res else: - rv = {} - rv['status'] = res['status'] - rv['times'] = res['times'] - rv['modes'] = res['modes'] - full_states = res['states'] - rv['states'] = full_states[:, num_vars*index:num_vars*(index+1)] - + rv = { + 'status': res['status'], + 'times': res['times'], + 'modes': res['modes'], + 'states': full_states[ + :, num_vars * index : num_vars * (index + 1) + ], + } if 'xd_list' in res: # extended states key_list = ['xd_list', 'ps_list', 'Nz_list', 'Ny_r_list', 'u_list'] diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/visualize/anim3d.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/visualize/anim3d.py index 17d4159e..939c15bf 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/visualize/anim3d.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/visualize/anim3d.py @@ -77,11 +77,7 @@ def make_anim(res, filename, viewsize=1000, viewsize_z=1000, f16_scale=30, trail ##### # fill in defaults - if filename == '': - full_plot = False - else: - full_plot = True - + full_plot = filename != '' for i, skip in enumerate(skip_frames): if skip is not None: continue @@ -121,7 +117,7 @@ def make_anim(res, filename, viewsize=1000, viewsize_z=1000, f16_scale=30, trail all_modes.append(m) all_ps_list.append(ps) all_Nz_list.append(Nz) - + ## fig = plt.figure(figsize=(8, 7)) @@ -229,7 +225,7 @@ def anim_func(global_frame): mode_names = [] for mode in modes: - if not mode in mode_names: + if mode not in mode_names: mode_names.append(mode) mode = modes[frame] @@ -354,7 +350,7 @@ def anim_func(global_frame): extra_args = [] if codec is not None: - extra_args += ['-vcodec', str(codec)] + extra_args += ['-vcodec', codec] anim_obj.save(filename, fps=fps, extra_args=extra_args) print("Finished saving to {} in {:.1f} sec".format(filename, time.time() - start)) diff --git a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/visualize/plot.py b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/visualize/plot.py index c6901452..0a98b916 100644 --- a/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/visualize/plot.py +++ b/problems/arch-comp-2021/f16/AeroBenchVVPython/v2/code/aerobench/visualize/plot.py @@ -136,19 +136,15 @@ def plot_outer_loop(run_sim_result, title='Outer Loop Controls'): # u is: throt, ele, ail, rud, Nz_ref, ps_ref, Ny_r_ref # u_ref is: Nz, ps, Ny + r, throttle - ys_list = [] - - ys_list.append(nz_list) - ys_list.append([u[4] for u in u_list]) - - ys_list.append(ps_list) - ys_list.append([u[5] for u in u_list]) - - ys_list.append(ny_r_list) - ys_list.append([u[6] for u in u_list]) - - # throttle reference is not included... 
although it's just a small offset so probably less important - ys_list.append([u[0] for u in u_list]) + ys_list = [ + nz_list, + [u[4] for u in u_list], + ps_list, + [u[5] for u in u_list], + ny_r_list, + [u[6] for u in u_list], + [u[0] for u in u_list], + ] labels = ['N_z', 'N_z,ref', 'P_s', 'P_s,ref', 'N_yr', 'N_yr,ref', 'Throttle'] colors = ['r', 'r', 'lime', 'lime', 'b', 'b', 'c'] @@ -251,7 +247,7 @@ def plot2d(filename, times, plot_data_list): and each var_data is a list of tuples: (state_index, label) ''' - num_plots = sum([len(var_data) for _, var_data in plot_data_list]) + num_plots = sum(len(var_data) for _, var_data in plot_data_list) fig = plt.figure(figsize=(7, 5)) diff --git a/problems/arch-comp-2021/f16/benchmark.py b/problems/arch-comp-2021/f16/benchmark.py index eeed4dcd..3a734c20 100644 --- a/problems/arch-comp-2021/f16/benchmark.py +++ b/problems/arch-comp-2021/f16/benchmark.py @@ -76,15 +76,11 @@ def build_specification(selected_specification, mode=None): "simulation_time": 15 } - # Notice that here the input is a vector. - if selected_specification == "F16": - specification = "always[0,15] ALTITUDE > 0" - - specifications = [specification] - strict_horizon_check = True - else: - raise Exception("Unknown specification '{}'.".format(selected_specification)) + if selected_specification != "F16": + raise Exception(f"Unknown specification '{selected_specification}'.") + specifications = ["always[0,15] ALTITUDE > 0"] + strict_horizon_check = True return sut_parameters, specifications, strict_horizon_check def objective_selector_factory(): @@ -99,7 +95,7 @@ def step_factory(): step_1 = Search(mode=mode, budget_threshold={"executions": 75}, algorithm=Random(model_factory=(lambda: Uniform())) - ) + ) step_2 = Search(mode=mode, budget_threshold={"executions": 300}, algorithm=OGAN(model_factory=(lambda: OGAN_Model(ogan_model_parameters["dense"])), parameters=ogan_parameters), @@ -107,9 +103,7 @@ def step_factory(): results_include_models=False ) - #steps = [step_1] - steps = [step_1, step_2] - return steps + return [step_1, step_2] def get_step_factory(): return step_factory diff --git a/problems/arch-comp-2021/f16/f16_python_sut.py b/problems/arch-comp-2021/f16/f16_python_sut.py index 6e4f6618..486a112f 100644 --- a/problems/arch-comp-2021/f16/f16_python_sut.py +++ b/problems/arch-comp-2021/f16/f16_python_sut.py @@ -16,7 +16,7 @@ class F16GCAS_PYTHON2(SUT): def __init__(self, parameters): SUT.__init__(self, parameters) - if not "initial_altitude" in self.parameters: + if "initial_altitude" not in self.parameters: raise Exception("Initial altitude not defined as a SUT parameter.") self.input_type = "vector" @@ -45,7 +45,7 @@ class F16GCAS_PYTHON3(SUT): def __init__(self, parameters): SUT.__init__(self, parameters) - if not "initial_altitude" in self.parameters: + if "initial_altitude" not in self.parameters: raise Exception("Initial altitude not defined as a SUT parameter.") try: diff --git a/problems/arch-comp-2021/nn/benchmark.py b/problems/arch-comp-2021/nn/benchmark.py index 9ed11ff3..ab39e1dc 100644 --- a/problems/arch-comp-2021/nn/benchmark.py +++ b/problems/arch-comp-2021/nn/benchmark.py @@ -57,7 +57,7 @@ def build_specification(selected_specification, mode=None): elif selected_specification == "NNX": ref_input_range = [1.95, 2.05] else: - raise Exception("Unknown specification '{}'.".format(selected_specification)) + raise Exception(f"Unknown specification '{selected_specification}'.") sut_parameters = {"model_file": "nn/run_neural", "init_model_file": "nn/init_neural", @@ 
-75,10 +75,10 @@ def build_specification(selected_specification, mode=None): if selected_specification == "NN": alpha = 0.005 beta = 0.03 - inequality1 = "|POS - REF| > {} + {}*|REF|".format(alpha, beta) - inequality2 = "{} + {}*|REF| <= |POS - REF|".format(alpha, beta) + inequality1 = f"|POS - REF| > {alpha} + {beta}*|REF|" + inequality2 = f"{alpha} + {beta}*|REF| <= |POS - REF|" - specification = "always[1,37]( {} implies (eventually[0,2]( always[0,1] not {} )) )".format(inequality1, inequality2) + specification = f"always[1,37]( {inequality1} implies (eventually[0,2]( always[0,1] not {inequality2} )) )" specifications = [specification] strict_horizon_check = True @@ -87,13 +87,13 @@ def build_specification(selected_specification, mode=None): F2 = "eventually[1,1.5]( always[0,0.5](1.75 < POS and POS < 2.25) )" F3 = "always[2,3](1.825 < POS and POS < 2.175)" - conjunctive_specification = "{} and {} and {}".format(F1, F2, F3) + conjunctive_specification = f"{F1} and {F2} and {F3}" specifications = [conjunctive_specification] #specifications = [F1, F2, F3] strict_horizon_check = True else: - raise Exception("Unknown specification '{}'.".format(selected_specification)) + raise Exception(f"Unknown specification '{selected_specification}'.") return sut_parameters, specifications, strict_horizon_check @@ -109,15 +109,13 @@ def step_factory(): step_1 = Search(mode=mode, budget_threshold={"executions": 75}, algorithm=Random(model_factory=(lambda: Uniform())) - ) + ) step_2 = Search(mode=mode, budget_threshold={"executions": 300}, algorithm=OGAN(model_factory=(lambda: OGAN_Model(ogan_model_parameters["convolution"])), parameters=ogan_parameters) #algorithm=WOGAN(model_factory=(lambda: WOGAN_Model())) ) - #steps = [step_1] - steps = [step_1, step_2] - return steps + return [step_1, step_2] def get_step_factory(): return step_factory diff --git a/problems/arch-comp-2021/run.py b/problems/arch-comp-2021/run.py index fd2c04dd..6b235cdd 100644 --- a/problems/arch-comp-2021/run.py +++ b/problems/arch-comp-2021/run.py @@ -108,13 +108,17 @@ def main(selected_benchmark, selected_specification, mode, n, init_seed, identif if N > 1 and N_workers[selected_benchmark] > 1: os.environ["CUDA_VISIBLE_DEVICES"] = "" - if not selected_specification in specifications[selected_benchmark]: + if selected_specification not in specifications[selected_benchmark]: raise Exception("No specification '{}' for benchmark {}.".format(selected_specification, selected_benchmark)) def callback(idx, result, done): path = os.path.join("..", "..", "output", selected_benchmark) time = str(result.timestamp).replace(" ", "_") - file_name = "{}{}_{}.pickle.gz".format(selected_specification, "_" + identifier if identifier is not None else "", time) + file_name = "{}{}_{}.pickle.gz".format( + selected_specification, + f"_{identifier}" if identifier is not None else "", + time, + ) os.makedirs(path, exist_ok=True) result.dump_to_file(os.path.join(path, file_name)) diff --git a/problems/arch-comp-2021/run_hp.py b/problems/arch-comp-2021/run_hp.py index f3ff03c7..e903d7cd 100644 --- a/problems/arch-comp-2021/run_hp.py +++ b/problems/arch-comp-2021/run_hp.py @@ -21,7 +21,7 @@ @click.argument("init_seed_experiments", type=int) @click.argument("seed_hp", type=int) def main(selected_benchmark, selected_specification, mode, init_seed_experiments, seed_hp): - if not selected_specification in specifications[selected_benchmark]: + if selected_specification not in specifications[selected_benchmark]: raise Exception("No specification '{}' for benchmark 
{}.".format(selected_specification, selected_benchmark)) # Disable CUDA if multiprocessing is used. diff --git a/problems/arch-comp-2021/sc/benchmark.py b/problems/arch-comp-2021/sc/benchmark.py index 7198fb6a..df4c88db 100644 --- a/problems/arch-comp-2021/sc/benchmark.py +++ b/problems/arch-comp-2021/sc/benchmark.py @@ -66,14 +66,11 @@ def build_specification(selected_specification, mode=None): "sampling_step": 0.5 } - if selected_specification == "SC": - specification = "always[30,35](87 <= PRESSURE and PRESSURE <= 87.5)" - - specifications = [specification] - strict_horizon_check = True - else: - raise Exception("Unknown specification '{}'.".format(selected_specification)) + if selected_specification != "SC": + raise Exception(f"Unknown specification '{selected_specification}'.") + specifications = ["always[30,35](87 <= PRESSURE and PRESSURE <= 87.5)"] + strict_horizon_check = True return sut_parameters, specifications, strict_horizon_check def objective_selector_factory(): @@ -88,15 +85,13 @@ def step_factory(): step_1 = Search(mode=mode, budget_threshold={"executions": 75}, algorithm=Random(model_factory=(lambda: Uniform())) - ) + ) step_2 = Search(mode=mode, budget_threshold={"executions": 300}, algorithm=OGAN(model_factory=(lambda: OGAN_Model(ogan_model_parameters["convolution"])), parameters=ogan_parameters) #algorithm=WOGAN(model_factory=(lambda: WOGAN_Model())) ) - #steps = [step_1] - steps = [step_1, step_2] - return steps + return [step_1, step_2] def get_step_factory(): return step_factory diff --git a/problems/odroid/run.py b/problems/odroid/run.py index 2499d59d..5fe5bf53 100644 --- a/problems/odroid/run.py +++ b/problems/odroid/run.py @@ -126,7 +126,9 @@ def main(n, init_seed, identifier): def callback(idx, result, done): path = os.path.join("..", "..", "output", "Odroid") time = str(result.timestamp).replace(" ", "_") - file_name = "{}_{}.pickle.gz".format("Odroid_" + identifier if identifier is not None else "", time) + file_name = "{}_{}.pickle.gz".format( + f"Odroid_{identifier}" if identifier is not None else "", time + ) os.makedirs(path, exist_ok=True) result.dump_to_file(os.path.join(path, file_name)) diff --git a/problems/odroid/sut.py b/problems/odroid/sut.py index f99bb239..148091b9 100644 --- a/problems/odroid/sut.py +++ b/problems/odroid/sut.py @@ -39,9 +39,11 @@ def _load_odroid_data(self): if not os.path.exists(self.data_file): if not self.data_file.endswith(".npy"): raise Exception("The Odroid data file does not have extension .npy.") - csv_file = self.data_file[:-4] + ".csv" + csv_file = f"{self.data_file[:-4]}.csv" if not os.path.exists(csv_file): - raise Exception("No Odroid csv file '{}' available for data generation.".format(csv_file)) + raise Exception( + f"No Odroid csv file '{csv_file}' available for data generation." 
+ ) generate_odroid_data(csv_file) data = np.load(self.data_file) @@ -71,11 +73,10 @@ def _execute_test(self, sut_input): """ test = sut_input.inputs - if not (test.shape == (1, self.idim) or test.shape == (self.idim,)): + if test.shape not in [(1, self.idim), (self.idim,)]: raise ValueError("Input array expected to have shape (1, {0}) or ({0}).".format(self.ndimensions)) distances = np.sum((self.dataX - test)**2, axis=1) retdata = self.dataY[np.argmin(distances)] - output = SUTOutput(retdata, None, None, None) - return output + return SUTOutput(retdata, None, None, None) diff --git a/problems/odroid/util.py b/problems/odroid/util.py index cb035bb6..a7952086 100644 --- a/problems/odroid/util.py +++ b/problems/odroid/util.py @@ -27,12 +27,8 @@ def generate_odroid_data(data_file): encoding = {} def encode(s): - if not s in encoding: - if len(encoding) == 0: - encoding[s] = 0 - else: - encoding[s] = max(encoding.values()) + 1 - + if s not in encoding: + encoding[s] = 0 if not encoding else max(encoding.values()) + 1 return encoding[s] data = [] @@ -69,4 +65,4 @@ def encode(s): data.append(new) - np.save(data_file[:-4] + ".npy", data) + np.save(f"{data_file[:-4]}.npy", data) diff --git a/problems/sbst/code_pipeline/beamng_executor.py b/problems/sbst/code_pipeline/beamng_executor.py index bf46eb2c..3e3221b1 100644 --- a/problems/sbst/code_pipeline/beamng_executor.py +++ b/problems/sbst/code_pipeline/beamng_executor.py @@ -99,16 +99,16 @@ def _is_the_car_moving(self, last_state): self.last_observation = last_state return True - # If the car moved since the last observation, we store the last state and move one - if Point(self.last_observation.pos[0],self.last_observation.pos[1]).distance(Point(last_state.pos[0], last_state.pos[1])) > self.min_delta_position: - self.last_observation = last_state - return True - else: + if ( + Point( + self.last_observation.pos[0], self.last_observation.pos[1] + ).distance(Point(last_state.pos[0], last_state.pos[1])) + <= self.min_delta_position + ): # How much time has passed since the last observation? 
- if last_state.timer - self.last_observation.timer > 10.0: - return False - else: - return True + return last_state.timer - self.last_observation.timer <= 10.0 + self.last_observation = last_state + return True def _run_simulation(self, the_test) -> SimulationData: if not self.brewer: @@ -173,9 +173,13 @@ def _run_simulation(self, the_test) -> SimulationData: if points_distance(last_state.pos, waypoint_goal.position) < 8.0: break - assert self._is_the_car_moving(last_state), "Car is not moving fast enough " + str(sim_data_collector.name) + assert self._is_the_car_moving( + last_state + ), f"Car is not moving fast enough {str(sim_data_collector.name)}" - assert not last_state.is_oob, "Car drove out of the lane " + str(sim_data_collector.name) + assert ( + not last_state.is_oob + ), f"Car drove out of the lane {str(sim_data_collector.name)}" beamng.step(steps) diff --git a/problems/sbst/code_pipeline/dave2_executor.py b/problems/sbst/code_pipeline/dave2_executor.py index c61bff97..00eb012c 100644 --- a/problems/sbst/code_pipeline/dave2_executor.py +++ b/problems/sbst/code_pipeline/dave2_executor.py @@ -105,16 +105,16 @@ def _is_the_car_moving(self, last_state): self.last_observation = last_state return True - # If the car moved since the last observation, we store the last state and move one - if Point(self.last_observation.pos[0],self.last_observation.pos[1]).distance(Point(last_state.pos[0], last_state.pos[1])) > self.min_delta_position: - self.last_observation = last_state - return True - else: + if ( + Point( + self.last_observation.pos[0], self.last_observation.pos[1] + ).distance(Point(last_state.pos[0], last_state.pos[1])) + <= self.min_delta_position + ): # How much time has passed since the last observation? - if last_state.timer - self.last_observation.timer > 10.0: - return False - else: - return True + return last_state.timer - self.last_observation.timer <= 10.0 + self.last_observation = last_state + return True def _run_simulation(self, the_test) -> SimulationData: if not self.brewer: @@ -183,9 +183,13 @@ def _run_simulation(self, the_test) -> SimulationData: if points_distance(last_state.pos, waypoint_goal.position) < 8.0: break - assert self._is_the_car_moving(last_state), "Car is not moving fast enough " + str(sim_data_collector.name) + assert self._is_the_car_moving( + last_state + ), f"Car is not moving fast enough {str(sim_data_collector.name)}" - assert not last_state.is_oob, "Car drove out of the lane " + str(sim_data_collector.name) + assert ( + not last_state.is_oob + ), f"Car drove out of the lane {str(sim_data_collector.name)}" img = vehicle_state_reader.sensors['cam_center']['colour'].convert('RGB') # TODO diff --git a/problems/sbst/code_pipeline/test_analysis.py b/problems/sbst/code_pipeline/test_analysis.py index 869f3280..45f4a122 100644 --- a/problems/sbst/code_pipeline/test_analysis.py +++ b/problems/sbst/code_pipeline/test_analysis.py @@ -30,7 +30,7 @@ def _calc_angle_distance(v0, v1): def _calc_dist_angle(points): - assert len(points) >= 2, f'at least two points are needed' + assert len(points) >= 2, 'at least two points are needed' def vector(idx): return np.subtract(points[idx + 1], points[idx]) @@ -120,9 +120,7 @@ def max_curvature(the_test, w=5): # Standard Deviation of Steering Angle accounts for the variability of the steering # angle during the execution def sd_steering(execution_data): - steering = [] - for state in execution_data: - steering.append(state.steering) + steering = [state.steering for state in execution_data] sd_steering = 
np.std(steering) return "STD_SA", sd_steering @@ -130,25 +128,19 @@ def sd_steering(execution_data): # Mean of Lateral Position of the car accounts for the average behavior of the car, i.e., # whether it spent most of the time traveling in the center or on the side of the lane def mean_lateral_position(execution_data): - lp = [] - for state in execution_data: - lp.append(state.oob_distance) - + lp = [state.oob_distance for state in execution_data] mean_lp = np.mean(lp) return "MEAN_LP", mean_lp def max_lateral_position(execution_data): - lp = [] - for state in execution_data: - lp.append(state.oob_distance) - + lp = [state.oob_distance for state in execution_data] max_lp = np.max(lp) return "MAX_LP", max_lp def compute_all_features(the_test, execution_data): - features = dict() + features = {} # Structural Features structural_features = [max_curvature, direction_coverage] diff --git a/problems/sbst/code_pipeline/test_generation_utils.py b/problems/sbst/code_pipeline/test_generation_utils.py index 8550e055..607cfa0f 100644 --- a/problems/sbst/code_pipeline/test_generation_utils.py +++ b/problems/sbst/code_pipeline/test_generation_utils.py @@ -74,10 +74,9 @@ def fun_wrapper(): _executed_exit_funs.add(fun) def signal_wrapper(signum=None, frame=None): - if signum is not None: - if logfun is not None: - logfun("signal {} received by process with PID {}".format( - stringify_sig(signum), os.getpid())) + if signum is not None and logfun is not None: + logfun("signal {} received by process with PID {}".format( + stringify_sig(signum), os.getpid())) fun_wrapper() # Only return the original signal this process was hit with # in case fun returns with no errors, otherwise process will @@ -91,7 +90,7 @@ def signal_wrapper(signum=None, frame=None): def register_fun(fun, signals): if not callable(fun): raise TypeError("{!r} is not callable".format(fun)) - set([fun]) # raise exc if obj is not hash-able + {fun} signals = set(signals) for sig in signals: diff --git a/problems/sbst/code_pipeline/tests_evaluation.py b/problems/sbst/code_pipeline/tests_evaluation.py index 57d1e70d..bc38695c 100644 --- a/problems/sbst/code_pipeline/tests_evaluation.py +++ b/problems/sbst/code_pipeline/tests_evaluation.py @@ -49,11 +49,14 @@ def _interpolate_and_resample_splines(sample_nodes, nodes_per_meter = 1, smoothn new_x_vals, new_y_vals = splev(unew, pos_tck) # Reduce floating point rounding errors otherwise these may cause problems with calculating parallel_offset - return list(zip([round(v, rounding_precision) for v in new_x_vals], - [round(v, rounding_precision) for v in new_y_vals], - # TODO Brutally hard-coded - [-28.0 for v in new_x_vals], - [8.0 for w in new_x_vals])) + return list( + zip( + [round(v, rounding_precision) for v in new_x_vals], + [round(v, rounding_precision) for v in new_y_vals], + [-28.0 for _ in new_x_vals], + [8.0 for _ in new_x_vals], + ) + ) def _find_circle_and_return_the_center_and_the_radius(x1, y1, x2, y2, x3, y3): @@ -234,15 +237,12 @@ def _identify_segments(nodes): else: type = "turn" - current_segment = {} - - current_segment["type"] = type - current_segment["center"] = center - current_segment["radius"] = radius - current_segment["points"] = [] - current_segment["points"].append(three_points[0]) - current_segment["points"].append(three_points[1]) - current_segment["points"].append(three_points[2]) + current_segment = { + "type": type, + "center": center, + "radius": radius, + "points": [three_points[0], three_points[1], three_points[2]], + } segments.append(current_segment) @@ -269,7 
+269,7 @@ def _identify_segments(nodes): # If two consecutive segments are similar we put them together for s in segments: - if len(refined_segments) == 0: + if not refined_segments: refined_segments.append(s) elif refined_segments[-1]["type"] == "straight" and s["type"] == "straight": # print("Merging ", refined_segments[-1], "and", s) @@ -286,8 +286,8 @@ def _identify_segments(nodes): segments = [] # Move forward - for index, segment in enumerate(refined_segments[:]): - if len(segments) == 0: + for segment in refined_segments[:]: + if not segments: segments.append(segment) elif len(segment["points"]) <= 5: @@ -304,8 +304,8 @@ def _identify_segments(nodes): refined_segments = segments[:] reversed(refined_segments) segments = [] - for index, segment in enumerate(refined_segments[:]): - if len(segments) == 0: + for segment in refined_segments[:]: + if not segments: segments.append(segment) elif len(segment["points"]) <= 5: @@ -364,7 +364,7 @@ def identify_interesting_road_segments(self, road_nodes, execution_data): oob_pos = Point(record.pos[0], record.pos[1]) break - if oob_pos == None: + if oob_pos is None: # No oob, no interesting segments and we cannot tell whether the OOB was left/rigth return None, None, None, None @@ -375,11 +375,7 @@ def identify_interesting_road_segments(self, road_nodes, execution_data): # if the distance between oob and the center of the road is greater than 2.0 (half of lane) then the oob is # on the right side, otherwise on the left side # - if oob_pos.distance(road_line) < 2.0: - oob_side = "LEFT" - else: - oob_side = "RIGHT" - + oob_side = "LEFT" if oob_pos.distance(road_line) < 2.0 else "RIGHT" # https://gis.stackexchange.com/questions/84512/get-the-vertices-on-a-linestring-either-side-of-a-point before = None after = None @@ -387,7 +383,7 @@ def identify_interesting_road_segments(self, road_nodes, execution_data): road_coords = list(road_line.coords) for i, p in enumerate(road_coords): if Point(p).distance(np) < 0.5: # Since we interpolate at every meter, whatever is closer than half of if - before = road_coords[0:i] + before = road_coords[:i] before.append(np.coords[0]) after = road_coords[i:] @@ -397,7 +393,7 @@ def identify_interesting_road_segments(self, road_nodes, execution_data): temp = [] for p1, p2 in _window(reversed(before), 2): - if len(temp) == 0: + if not temp: temp.append(p1) distance += LineString([p1, p2]).length @@ -413,7 +409,7 @@ def identify_interesting_road_segments(self, road_nodes, execution_data): temp = [] for p1, p2 in _window(after, 2): - if len(temp) == 0: + if not temp: temp.append(p1) distance += LineString([p1, p2]).length @@ -457,7 +453,7 @@ def _load_oobs_from(self, result_folder): # If the test is not valid or passed we skip it the analysis - if not is_valid or not test_outcome == "FAIL": + if not is_valid or test_outcome != "FAIL": self.logger.debug("\t Test is invalid") continue @@ -515,16 +511,22 @@ def _compute_sparseness(self): self.logger.debug("Distance of OOB %s from OOB %s is %.3f", oob1["test id"], oob2["test id"], distance) # Update the max values - if oob1['test id'] in max_distances_starting_from.keys(): + if oob1['test id'] in max_distances_starting_from: max_distances_starting_from[oob1['test id']] = max( max_distances_starting_from[oob1['test id']], distance) else: max_distances_starting_from[oob1['test id']] = distance - mean_distance = np.mean([list(max_distances_starting_from.values())]) if len( - max_distances_starting_from) > 0 else np.NaN - std_dev = 
np.std([list(max_distances_starting_from.values())]) if len( - max_distances_starting_from) > 0 else np.NaN + mean_distance = ( + np.mean([list(max_distances_starting_from.values())]) + if max_distances_starting_from + else np.NaN + ) + std_dev = ( + np.std([list(max_distances_starting_from.values())]) + if max_distances_starting_from + else np.NaN + ) self.logger.debug("Sparseness: Mean: %.3f, StdDev: %3f", mean_distance, std_dev) @@ -550,12 +552,10 @@ def _analyse(self): mean_sparseness, stdev_sparseness = self._compute_sparseness() n_oobs_on_the_left, n_oobs_on_the_right = self._compute_oob_side_stats() - report_data = {} - - report_data["sparseness"] = (mean_sparseness, stdev_sparseness) - report_data["oob_side"] = (n_oobs_on_the_left, n_oobs_on_the_right) - - return report_data + return { + "sparseness": (mean_sparseness, stdev_sparseness), + "oob_side": (n_oobs_on_the_left, n_oobs_on_the_right), + } def create_summary(self): diff --git a/problems/sbst/code_pipeline/tests_generation.py b/problems/sbst/code_pipeline/tests_generation.py index fc9b9e13..8b4998d6 100644 --- a/problems/sbst/code_pipeline/tests_generation.py +++ b/problems/sbst/code_pipeline/tests_generation.py @@ -21,9 +21,7 @@ def _interpolate(the_test): # This is an approximation based on whatever input is given test_road_lenght = LineString([(t[0], t[1]) for t in the_test]).length num_nodes = int(test_road_lenght / interpolation_distance) - if num_nodes < min_num_nodes: - num_nodes = min_num_nodes - + num_nodes = max(num_nodes, min_num_nodes) assert len(old_x_vals) >= 2, "You need at leas two road points to define a road" assert len(old_y_vals) >= 2, "You need at leas two road points to define a road" @@ -94,12 +92,14 @@ def set_validity(self, is_valid, validation_message): self.validation_message = validation_message def to_json(self): - theobj = {} - # Statically generated attributes - theobj['is_valid'] = self.is_valid - theobj['validation_message'] = self.validation_message - theobj['road_points'] = self.road_points - theobj['interpolated_points'] = [(p[0], p[1]) for p in self.interpolated_points] + theobj = { + 'is_valid': self.is_valid, + 'validation_message': self.validation_message, + 'road_points': self.road_points, + 'interpolated_points': [ + (p[0], p[1]) for p in self.interpolated_points + ], + } # Dynamically generated attributes. 
# https://stackoverflow.com/questions/610883/how-to-know-if-an-object-has-an-attribute-in-python # "easier to ask for forgiveness than permission" (EAFP) @@ -166,14 +166,20 @@ def __init__(self): def __str__(self): msg = "" - msg += "test generated: " + str(self.test_generated) + "\n" - msg += "test valid: " + str(self.test_valid) + "\n" - msg += "test invalid: " + str(self.test_invalid) + "\n" - msg += "test passed: " + str(self.test_passed) + "\n" - msg += "test failed: " + str(self.test_failed) + "\n" - msg += "test in_error: " + str(self.test_in_error) + "\n" - msg += "(real) time spent in execution :" + str(sum(self.test_execution_real_times)) + "\n" - msg += "(simulated) time spent in execution :" + str(sum(self.test_execution_simulation_times)) + "\n" + msg += f"test generated: {str(self.test_generated)}" + "\n" + msg += f"test valid: {str(self.test_valid)}" + "\n" + msg += f"test invalid: {str(self.test_invalid)}" + "\n" + msg += f"test passed: {str(self.test_passed)}" + "\n" + msg += f"test failed: {str(self.test_failed)}" + "\n" + msg += f"test in_error: {str(self.test_in_error)}" + "\n" + msg += ( + f"(real) time spent in execution :{str(sum(self.test_execution_real_times))}" + + "\n" + ) + msg += ( + f"(simulated) time spent in execution :{str(sum(self.test_execution_simulation_times))}" + + "\n" + ) return msg def as_csv(self): diff --git a/problems/sbst/code_pipeline/validation.py b/problems/sbst/code_pipeline/validation.py index 402aec54..6098e05b 100644 --- a/problems/sbst/code_pipeline/validation.py +++ b/problems/sbst/code_pipeline/validation.py @@ -24,8 +24,7 @@ def find_circle(p1, p2, p3): cx = (bc*(p2[1] - p3[1]) - cd*(p1[1] - p2[1])) / det cy = ((p1[0] - p2[0]) * cd - (p2[0] - p3[0]) * bc) / det - radius = np.sqrt((cx - p1[0])**2 + (cy - p1[1])**2) - return radius + return np.sqrt((cx - p1[0])**2 + (cy - p1[1])**2) def min_radius(x, w=5): @@ -65,11 +64,7 @@ def is_not_self_intersecting(self, the_test): return road_polygon.is_valid() def is_too_sharp(self, the_test, TSHD_RADIUS=47): - if TSHD_RADIUS > min_radius(the_test.interpolated_points) > 0.0: - check = True - else: - check = False - return check + return TSHD_RADIUS > min_radius(the_test.interpolated_points) > 0.0 def is_inside_map(self, the_test): """ @@ -81,27 +76,30 @@ def is_inside_map(self, the_test): min_x, max_x = min(xs), max(xs) min_y, max_y = min(ys), max(ys) - return 0 < min_x or min_x > self.map_size and \ - 0 < max_x or max_x > self.map_size and \ - 0 < min_y or min_y > self.map_size and \ - 0 < max_y or max_y > self.map_size + return ( + min_x > 0 + or min_x > self.map_size + and max_x > 0 + or max_x > self.map_size + and min_y > 0 + or min_y > self.map_size + and max_y > 0 + or max_y > self.map_size + ) def is_right_type(self, the_test): """ The type of the_test must be RoadTest """ - check = type(the_test) is RoadTestFactory.RoadTest - return check + return type(the_test) is RoadTestFactory.RoadTest def is_valid_polygon(self, the_test): road_polygon = the_test.get_road_polygon() - check = road_polygon.is_valid() - return check + return road_polygon.is_valid() def intersects_boundary(self, the_test): road_polygon = the_test.get_road_polygon() - check = self.road_bbox.intersects_boundary(road_polygon.polygon) - return check + return self.road_bbox.intersects_boundary(road_polygon.polygon) def is_minimum_length(self, the_test): # This is approximated because at this point the_test is not yet interpolated diff --git a/problems/sbst/code_pipeline/visualization.py 
b/problems/sbst/code_pipeline/visualization.py index 9839c231..c0b4e75a 100644 --- a/problems/sbst/code_pipeline/visualization.py +++ b/problems/sbst/code_pipeline/visualization.py @@ -42,14 +42,16 @@ def visualize_road_test(self, the_test): # Add information about the test validity title_string = "" if the_test.is_valid is not None: - title_string = title_string + "Test is " + ("valid" if the_test.is_valid else "invalid") + title_string = f"{title_string}Test is " + ( + "valid" if the_test.is_valid else "invalid" + ) if not the_test.is_valid: - title_string = title_string + ":" + the_test.validation_message + title_string = f"{title_string}:{the_test.validation_message}" plt.suptitle(title_string, fontsize=14) plt.draw() plt.pause(0.001) - + # Plot the map. Trying to re-use an artist in more than one Axes which is supported map_patch = patches.Rectangle((0, 0), self.map_size, self.map_size, linewidth=1, edgecolor='black', facecolor='none') plt.gca().add_patch(map_patch) @@ -98,7 +100,7 @@ def visualize_road_test(self, the_test): if the_test.is_valid is not None: title_string = " ".join([title_string, "Test", str(the_test.id), "is" , ("valid" if the_test.is_valid else "invalid")]) if not the_test.is_valid: - title_string = title_string + ":" + the_test.validation_message + title_string = f"{title_string}:{the_test.validation_message}" plt.suptitle(title_string, fontsize=14) plt.draw() diff --git a/problems/sbst/convert_sbst_stgem.py b/problems/sbst/convert_sbst_stgem.py index 0e019c7b..1f72ec40 100644 --- a/problems/sbst/convert_sbst_stgem.py +++ b/problems/sbst/convert_sbst_stgem.py @@ -14,11 +14,11 @@ base_path = sys.argv[1] if not os.path.exists(base_path): - raise SystemExit("Directory {} not found.".format(base_path)) + raise SystemExit(f"Directory {base_path} not found.") output_path = sys.argv[2] if not os.path.exists(output_path): - raise SystemExit("Directory {} not found.".format(output_path)) + raise SystemExit(f"Directory {output_path} not found.") identifier = sys.argv[3] @@ -40,7 +40,6 @@ csv_file = os.path.join(dir_name, "generation_stats.csv") if not os.path.exists(csv_file): continue - raise Exception("No generation_stats.csv in {}. Is the replica incomplete?".format(dir_name)) data = pd.read_csv(csv_file) tests = int(data["test_generated"][0]) generation_time = float(data["real_time_generation"][0]) / tests @@ -65,7 +64,7 @@ # --------------------------------------------------------------------- # Fields in order: timer, pos, dir, vel, steering, steering_input, brake, brake_input, throttle, throttle_input, wheelspeed, vel_kmh, is_oob, oob_counter, max_oob_percentage, oob_distance, oob_percentage # The final test whose execution is partially saved and needs to be omitted. 
- if not "execution_data" in data: continue + if "execution_data" not in data: continue timestamps = np.zeros(len(data["execution_data"])) signals = np.zeros(shape=(4, len(data["execution_data"]))) for i, state in enumerate(data["execution_data"]): @@ -100,13 +99,13 @@ ) result = STGEMResult( - description="SBST converted results replica {}".format(i + 1), - sut_name="BeamNG", - sut_parameters={}, - seed=None, - step_results=[step_result], - test_repository=test_repository - ) + description=f"SBST converted results replica {i + 1}", + sut_name="BeamNG", + sut_parameters={}, + seed=None, + step_results=[step_result], + test_repository=test_repository, + ) results.append(result) for i, result in enumerate(results): diff --git a/problems/sbst/run.py b/problems/sbst/run.py index 2c1390d3..6d9209fb 100644 --- a/problems/sbst/run.py +++ b/problems/sbst/run.py @@ -10,14 +10,14 @@ identifier = sys.argv[2] if len(sys.argv) > 2 else None if not os.path.exists(python_exe): - raise Exception("No Python executable {}.".format(python_exe)) + raise Exception(f"No Python executable {python_exe}.") def run_on_powershell(python_exe, seed, identifier=None): python_exe = python_exe.strip() if identifier is None: - command = "{} sbst.py 1 {}".format(python_exe, seed) + command = f"{python_exe} sbst.py 1 {seed}" else: - command = "{} sbst.py 1 {} {}".format(python_exe, seed, identifier) + command = f"{python_exe} sbst.py 1 {seed} {identifier}" p = subprocess.Popen(["powershell.exe", command], stdout=sys.stdout) p.communicate() diff --git a/problems/sbst/sbst.py b/problems/sbst/sbst.py index a610d669..c10171f1 100644 --- a/problems/sbst/sbst.py +++ b/problems/sbst/sbst.py @@ -226,7 +226,9 @@ def seed_generator(init_seed): def result_callback(idx, result, done): path = os.path.join("..", "..", "output", "sbst") time = str(result.timestamp).replace(" ", "_").replace(":", "") - file_name = "SBST{}_{}_{}.pickle.gz".format("_" + identifier if len(identifier) > 0 else "", time, idx) + file_name = "SBST{}_{}_{}.pickle.gz".format( + f"_{identifier}" if len(identifier) > 0 else "", time, idx + ) os.makedirs(path, exist_ok=True) result.dump_to_file(os.path.join(path, file_name)) diff --git a/problems/sbst/sbst_results_analysis.sync.py b/problems/sbst/sbst_results_analysis.sync.py index 1fc48497..d8c96f96 100644 --- a/problems/sbst/sbst_results_analysis.sync.py +++ b/problems/sbst/sbst_results_analysis.sync.py @@ -40,36 +40,35 @@ def road_visualization(result, start, end): # Input range for descaling tests. input_range = [-result.sut_parameters["curvature_range"], result.sut_parameters["curvature_range"]] - + fig, axes = plt.subplots(rows, columns, figsize=(64, 64), sharex = True, sharey = True) plt.xticks([]) plt.yticks([]) idx = 0 failed_cnt = 0 - for row in range(rows): - for column in range(columns): - _input, _, _objective = result.test_repository.get(start + idx) - robustness = round(_objective[0], 3) - axes[row, column].title.set_text(f"[{(start+idx)}] - Robustness: {robustness}") - - # Highlight the roads that produced a failed test - if robustness <= 0.05: - color = "r" - failed_cnt += 1 - else: - color = "b" - - # Plot interpolated points connected by lines. - x, y = _input.input_denormalized - axes[row,column].plot(x, y, color=color) - - # Plot the control points. 
- points = np.array(test_to_road_points(descale(_input.inputs, input_range), result.sut_parameters["step_length"], result.sut_parameters["map_size"])) - axes[row,column].plot(points[:,0], points[:,1], "{}o".format(color)) - - idx += 1 - - fig.suptitle(f'Road visualization of {idx} test runs where {failed_cnt} failed - Seed: {result.seed}', fontsize=40) + for row, column in itertools.product(range(rows), range(columns)): + _input, _, _objective = result.test_repository.get(start + idx) + robustness = round(_objective[0], 3) + axes[row, column].title.set_text(f"[{(start+idx)}] - Robustness: {robustness}") + + # Highlight the roads that produced a failed test + if robustness <= 0.05: + color = "r" + failed_cnt += 1 + else: + color = "b" + + # Plot interpolated points connected by lines. + x, y = _input.input_denormalized + axes[row,column].plot(x, y, color=color) + + # Plot the control points. + points = np.array(test_to_road_points(descale(_input.inputs, input_range), result.sut_parameters["step_length"], result.sut_parameters["map_size"])) + axes[row,column].plot(points[:,0], points[:,1], f"{color}o") + + idx += 1 + + fig.suptitle(f'Road visualization of {idx} test runs where {failed_cnt} failed - Seed: {result.seed}', fontsize=40) #plt.savefig(f'road_images/{filename}.png', pad_inches=0.1, dpi=150) #plt.close(fig) plt.show() @@ -98,10 +97,7 @@ def move_road(P, x0, y0): Q[n,0] = math.cos(angle) * x - math.sin(angle) * y + x0 Q[n,1] = math.sin(angle) * x + math.cos(angle) * y + y0 - if isinstance(P, list): - return Q.tolist() - else: - return Q + return Q.tolist() if isinstance(P, list) else Q def steering_sd(test_repository): """Compute the standard deviation of the steering angles for each test in @@ -110,9 +106,7 @@ def steering_sd(test_repository): _, Z, _ = test_repository.get() - data = [np.std(sut_output.outputs[3]) for sut_output in Z] - - return data + return [np.std(sut_output.outputs[3]) for sut_output in Z] def direction_coverage(test_repository, bins=36): """Compute the coverage of road directions of the test suite. 
That is, for diff --git a/problems/sbst/self_driving/bbox.py b/problems/sbst/self_driving/bbox.py index 6d1d5be4..32e3014c 100644 --- a/problems/sbst/self_driving/bbox.py +++ b/problems/sbst/self_driving/bbox.py @@ -10,16 +10,10 @@ def __init__(self, bbox_size: Tuple[float, float, float, float]): self.bbox = box(*bbox_size) def intersects_sides(self, point: Point) -> bool: - for side in self.get_sides(): - if side.intersects(point): - return True - return False + return any(side.intersects(point) for side in self.get_sides()) def intersects_vertices(self, point: Point) -> bool: - for vertex in self.get_vertices(): - if vertex.intersects(point): - return True - return False + return any(vertex.intersects(point) for vertex in self.get_vertices()) def intersects_boundary(self, other: Polygon) -> bool: return other.intersects(self.bbox.boundary) @@ -28,12 +22,9 @@ def contains(self, other: RoadPolygon) -> bool: return self.bbox.contains(other.polyline) def get_sides(self) -> List[LineString]: - sides = [] xs, ys = self.bbox.exterior.coords.xy xys = list(zip(xs, ys)) - for p1, p2 in zip(xys[:-1], xys[1:]): - sides.append(LineString([p1, p2])) - return sides + return [LineString([p1, p2]) for p1, p2 in zip(xys[:-1], xys[1:])] def get_vertices(self) -> List[Point]: xs, ys = self.bbox.exterior.coords.xy diff --git a/problems/sbst/self_driving/beamng_brewer.py b/problems/sbst/self_driving/beamng_brewer.py index 7b827b43..9248cda6 100644 --- a/problems/sbst/self_driving/beamng_brewer.py +++ b/problems/sbst/self_driving/beamng_brewer.py @@ -26,8 +26,7 @@ def get_rgb_image(self): self.camera.pos = self.pose.pos self.camera.direction = self.pose.rot cam = self.beamng.render_cameras() - img = cam[self.name]['colour'].convert('RGB') - return img + return cam[self.name]['colour'].convert('RGB') class BeamNGBrewer: diff --git a/problems/sbst/self_driving/beamng_tig_maps.py b/problems/sbst/self_driving/beamng_tig_maps.py index af39412b..b63dee23 100644 --- a/problems/sbst/self_driving/beamng_tig_maps.py +++ b/problems/sbst/self_driving/beamng_tig_maps.py @@ -34,14 +34,14 @@ def version_info(self): return json.load(f) def tig_version_json_path(self): - return self.path + '/' + return f'{self.path}/' def delete_all_map(self): print(f'Removing [{self.path}]') shutil.rmtree(self.path, ignore_errors=True) # sometimes rmtree fails to remove files - for tries in range(20): + for _ in range(20): if os.path.exists(self.path): time.sleep(0.1) shutil.rmtree(self.path, ignore_errors=True) @@ -71,7 +71,7 @@ class Maps: def __init__(self): self.beamng_levels = LevelsFolder(os.path.join(os.environ['USERPROFILE'], r'Documents/BeamNG.research/levels')) - self.source_levels = LevelsFolder(os.getcwd()+'/levels_template') + self.source_levels = LevelsFolder(f'{os.getcwd()}/levels_template') self.source_map = self.source_levels.get_map('tig') self.beamng_map = self.beamng_levels.get_map('tig') self.never_logged_path = True @@ -93,18 +93,18 @@ def install_map_if_needed(self): f'It does not contains the distinctive file [{self.beamng_map.tig_version_json_path}]') print('Stopping execution') exit(1) - else: - if not self.beamng_map.same_version(self.source_map): - print(f'Maps have different version information. ' - f'Do you want to remove all {self.beamng_map.path} folder and copy it anew?' - f'.\nType yes to accept, no to keep it as it is') - while True: - resp = input('>') - if resp in ['yes', 'no']: - break + elif not self.beamng_map.same_version(self.source_map): + print(f'Maps have different version information. 
' + f'Do you want to remove all {self.beamng_map.path} folder and copy it anew?' + f'.\nType yes to accept, no to keep it as it is') + while True: + resp = input('>') + if resp in ['yes', 'no']: + break + else: print('Type yes or no') - if resp == 'yes': - self.beamng_map.delete_all_map() + if resp == 'yes': + self.beamng_map.delete_all_map() if not self.beamng_map.exists(): print(f'Copying from [{self.source_map.path}] to [{self.beamng_map.path}]') diff --git a/problems/sbst/self_driving/beamng_waypoint.py b/problems/sbst/self_driving/beamng_waypoint.py index ec5d7273..a8d726bf 100644 --- a/problems/sbst/self_driving/beamng_waypoint.py +++ b/problems/sbst/self_driving/beamng_waypoint.py @@ -9,11 +9,12 @@ def __init__(self, name, position, persistentId=None): self.persistentId = persistentId if persistentId else str(uuid.uuid4()) def to_json(self): - obj = {} - obj['name'] = self.name - obj['class'] = 'BeamNGWaypoint' - obj['persistentId'] = self.persistentId - obj['__parent'] = 'generated' - obj['position'] = self.position - obj['scale'] = [4, 4, 4] + obj = { + 'name': self.name, + 'class': 'BeamNGWaypoint', + 'persistentId': self.persistentId, + '__parent': 'generated', + 'position': self.position, + 'scale': [4, 4, 4], + } return json.dumps(obj) diff --git a/problems/sbst/self_driving/decal_road.py b/problems/sbst/self_driving/decal_road.py index f4b29976..368571af 100644 --- a/problems/sbst/self_driving/decal_road.py +++ b/problems/sbst/self_driving/decal_road.py @@ -18,10 +18,10 @@ def __init__(self, name, def add_4d_points(self, nodes: List[Tuple[float, float, float, float]]): self._safe_add_nodes(nodes) - assert len(nodes) > 0, 'nodes should be a non empty list' + assert nodes, 'nodes should be a non empty list' assert all(len(item) == 4 for item in nodes), 'nodes list should contain tuple of 4 elements' assert all(all(isinstance(val, float) for val in item) for item in nodes), \ - 'points list can contain only float' + 'points list can contain only float' self.nodes += [list(item) for item in nodes] return self @@ -41,17 +41,18 @@ def _safe_add_nodes(self, nodes: List): def to_json(self): assert len(self.nodes) > 0, 'there are no points in this road' - roadobj = {} - roadobj['name'] = self.name - roadobj['class'] = 'DecalRoad' - roadobj['breakAngle'] = 180 - roadobj['distanceFade'] = [1000, 1000] - roadobj['drivability'] = self.drivability - roadobj['material'] = self.material - roadobj['overObjects'] = True - roadobj['persistentId'] = self.persistentId - roadobj['__parent'] = 'generated' - roadobj['position'] = tuple(self.nodes[0][:3]) # keep x,y,z discard width - roadobj['textureLength'] = 2.5 - roadobj['nodes'] = self.nodes + roadobj = { + 'name': self.name, + 'class': 'DecalRoad', + 'breakAngle': 180, + 'distanceFade': [1000, 1000], + 'drivability': self.drivability, + 'material': self.material, + 'overObjects': True, + 'persistentId': self.persistentId, + '__parent': 'generated', + 'position': tuple(self.nodes[0][:3]), + 'textureLength': 2.5, + 'nodes': self.nodes, + } return json.dumps(roadobj) diff --git a/problems/sbst/self_driving/edit_distance_polyline.py b/problems/sbst/self_driving/edit_distance_polyline.py index 9f96654c..5120ecea 100644 --- a/problems/sbst/self_driving/edit_distance_polyline.py +++ b/problems/sbst/self_driving/edit_distance_polyline.py @@ -15,13 +15,7 @@ def _calc_cost_discrete(u: AngleLength, v: AngleLength): # print(str(delta_angle)) eps_angle = 0.3 eps_len = 0.2 - if delta_angle < eps_angle and delta_len < eps_len: - res = 0 - else: - res = 2 - - # 
res = 1 / 2 * (delta_angle / (1 + delta_angle) + delta_len / (1 + delta_len)) - return res + return 0 if delta_angle < eps_angle and delta_len < eps_len else 2 def _calc_cost_weighted(u: AngleLength, v: AngleLength): @@ -29,11 +23,13 @@ def _calc_cost_weighted(u: AngleLength, v: AngleLength): delta_angle = np.abs((delta_angle + 180) % 360 - 180) eps_angle = 0.3 eps_len = 0.2 - if delta_angle < eps_angle and delta_len < eps_len: - res = 0 - else: - res = 1 / 2 * (delta_angle / (1 + delta_angle) + delta_len / (1 + delta_len)) - return res + return ( + 0 + if delta_angle < eps_angle and delta_len < eps_len + else 1 + / 2 + * (delta_angle / (1 + delta_angle) + delta_len / (1 + delta_len)) + ) #_calc_cost = _calc_cost_discrete @@ -51,7 +47,7 @@ def _iterative_levenshtein_dist_angle(s: ListOfAngleLength, t: ListOfAngleLength """ rows = len(s) + 1 cols = len(t) + 1 - dist = [[0 for x in range(cols)] for x in range(rows)] + dist = [[0 for _ in range(cols)] for _ in range(rows)] # source prefixes can be transformed into empty strings # by deletions: for i in range(1, rows): @@ -80,7 +76,7 @@ def _calc_angle_distance(v0, v1): def _calc_dist_angle(points: ListOfPoints) -> ListOfAngleLength: - assert len(points) >= 2, f'at least two points are needed' + assert len(points) >= 2, 'at least two points are needed' def vector(idx): return np.subtract(points[idx + 1], points[idx]) diff --git a/problems/sbst/self_driving/nvidia_prediction.py b/problems/sbst/self_driving/nvidia_prediction.py index 5b9e4e30..faf1b113 100644 --- a/problems/sbst/self_driving/nvidia_prediction.py +++ b/problems/sbst/self_driving/nvidia_prediction.py @@ -21,10 +21,7 @@ def predict(self, image, car_state: SimulationDataRecord): steering_angle = float(self.model.predict(image, batch_size=1)) speed = car_state.vel_kmh - if speed > self.speed_limit: - self.speed_limit = MIN_SPEED # slow down - else: - self.speed_limit = self.max_speed + self.speed_limit = MIN_SPEED if speed > self.speed_limit else self.max_speed throttle = 1.0 - steering_angle ** 2 - (speed / self.speed_limit) ** 2 return steering_angle, throttle diff --git a/problems/sbst/self_driving/oob_monitor.py b/problems/sbst/self_driving/oob_monitor.py index 749c3b8c..7e39381e 100644 --- a/problems/sbst/self_driving/oob_monitor.py +++ b/problems/sbst/self_driving/oob_monitor.py @@ -77,8 +77,7 @@ def oob_distance(self, wrt="right") -> float: else: distance = self.road_polygon.polyline.distance(car_point) divisor = 2.0 - difference = self.road_polygon.road_width / divisor - distance - return difference + return self.road_polygon.road_width / divisor - distance def oob_distances(self, wrt="right"): """Returns the signed distances of the car to the left and right edges @@ -88,7 +87,7 @@ def oob_distances(self, wrt="right"): # right lane and self.road_polygon.left_polyline is the middle of # the left lane. This explains the slightly awkward code. - if not wrt == "right": + if wrt != "right": raise Exception("Currently only the distance with respect to the right lane is supported.") # Old point. Were not sure what this is exactly. @@ -111,10 +110,6 @@ def oob_distances(self, wrt="right"): elif dL > dR and dL > lane_width: # Out from the right. dR *= -1 - else: - # In the middle. 
- pass - return dL, dR def is_oob(self, wrt="right") -> bool: diff --git a/problems/sbst/self_driving/road_points.py b/problems/sbst/self_driving/road_points.py index 5bf72cfe..a5beb87d 100644 --- a/problems/sbst/self_driving/road_points.py +++ b/problems/sbst/self_driving/road_points.py @@ -25,10 +25,11 @@ def __init__(self): def add_middle_nodes(self, middle_nodes): n = len(self.middle) + len(middle_nodes) - assert n >= 2, f'At least, two nodes are needed' + assert n >= 2, 'At least, two nodes are needed' - assert all(len(point) >= 4 for point in middle_nodes), \ - f'A node is a tuple of 4 elements (x,y,z,road_width)' + assert all( + len(point) >= 4 for point in middle_nodes + ), 'A node is a tuple of 4 elements (x,y,z,road_width)' self.n = n self.middle += list(middle_nodes) @@ -48,9 +49,9 @@ def _recalculate_nodes(self): @classmethod def calc_point_edges(cls, p1, p2) -> Tuple[Tuple, Tuple]: - origin = np.array(p1[0:2]) + origin = np.array(p1[:2]) - a = np.subtract(p2[0:2], origin) + a = np.subtract(p2[:2], origin) # calculate the vector which length is half the road width v = (a / np.linalg.norm(a)) * p1[3] / 2 @@ -59,19 +60,17 @@ def calc_point_edges(cls, p1, p2) -> Tuple[Tuple, Tuple]: r = origin + np.array([v[1], -v[0]]) return tuple(l), tuple(r) - def vehicle_start_pose(self, meters_from_road_start=2.5, road_point_index=0) \ - -> BeamNGPose: + def vehicle_start_pose(self, meters_from_road_start=2.5, road_point_index=0) -> BeamNGPose: assert self.n > road_point_index, f'road length is {self.n} it does not have index {road_point_index}' p1 = self.middle[road_point_index] p1r = self.right[road_point_index] p2 = self.middle[road_point_index + 1] - p2v = np.subtract(p2[0:2], p1[0:2]) + p2v = np.subtract(p2[:2], p1[:2]) v = (p2v / np.linalg.norm(p2v)) * meters_from_road_start - origin = np.add(p1[0:2], p1r[0:2]) / 2 + origin = np.add(p1[:2], p1r[:2]) / 2 deg = np.degrees(np.arctan2([-v[0]], [-v[1]])) - res = BeamNGPose(pos=tuple(origin + v) + (p1[2],), rot=(0, 0, deg[0])) - return res + return BeamNGPose(pos=tuple(origin + v) + (p1[2],), rot=(0, 0, deg[0])) def new_imagery(self): from .beamng_road_imagery import BeamNGRoadImagery diff --git a/problems/sbst/self_driving/road_polygon.py b/problems/sbst/self_driving/road_polygon.py index 69a9e272..2638237a 100644 --- a/problems/sbst/self_driving/road_polygon.py +++ b/problems/sbst/self_driving/road_polygon.py @@ -125,9 +125,9 @@ def is_valid(self) -> bool: logging.debug("No polygon constructed.") return False - for i, polygon in enumerate(self.polygons): + for polygon in self.polygons: if not polygon.is_valid: - logging.debug("Polygon %s is invalid." % polygon) + logging.debug(f"Polygon {polygon} is invalid.") return False for i, polygon in enumerate(self.polygons): @@ -140,11 +140,14 @@ def is_valid(self) -> bool: logging.debug("No polygon should contain any other polygon.") return False if not self._are_neighbouring_polygons(i, j) and other.intersects(polygon): - logging.debug("The non-neighbouring polygons %s and %s intersect." % (polygon, other)) + logging.debug( + f"The non-neighbouring polygons {polygon} and {other} intersect." + ) return False if self._are_neighbouring_polygons(i, j) and not isinstance(other.intersection(polygon), LineString): - logging.debug("The neighbouring polygons %s and %s have an intersection of type %s." % ( - polygon, other, type(other.intersection(polygon)))) + logging.debug( + f"The neighbouring polygons {polygon} and {other} have an intersection of type {type(other.intersection(polygon))}." 
+ ) return False #logging.debug("The road is apparently valid.") return True diff --git a/problems/sbst/self_driving/simulation_data.py b/problems/sbst/self_driving/simulation_data.py index 632378e1..72862395 100644 --- a/problems/sbst/self_driving/simulation_data.py +++ b/problems/sbst/self_driving/simulation_data.py @@ -30,7 +30,7 @@ def delete_folder_recursively(path: Union[str, Path], exception_if_fail: bool = shutil.rmtree(path, ignore_errors=True) # sometimes rmtree fails to remove files - for tries in range(20): + for _ in range(20): if os.path.exists(path): sleep(0.1) shutil.rmtree(path, ignore_errors=True) @@ -142,8 +142,6 @@ def load_from_json(path_json): info=info) return self - pass - def complete(self) -> bool: return self.path_json.exists() diff --git a/problems/sbst/sut.py b/problems/sbst/sut.py index f09b6f36..13a3396c 100644 --- a/problems/sbst/sut.py +++ b/problems/sbst/sut.py @@ -77,7 +77,7 @@ def __init__(self, parameters=None): super().__init__(parameters) - if not "curvature_points" in self.parameters: + if "curvature_points" not in self.parameters: raise Exception("Number of curvature points not defined.") if self.curvature_points <= 0: raise ValueError("The number of curvature points must be positive.") @@ -111,12 +111,14 @@ def __init__(self, parameters=None): # Check for activation key. if not os.path.exists(os.path.join(self.beamng_user, "tech.key")): - raise Exception("The activation key 'tech.key' must be in the directory {}.".format(self.beamng_user)) + raise Exception( + f"The activation key 'tech.key' must be in the directory {self.beamng_user}." + ) # Check for DAVE-2 model if requested. if "dave2_model" in self.parameters and self.dave2_model is not None: if not os.path.exists(self.dave2_model): - raise Exception("The DAVE-2 model file '{}' does not exist.".format(self.dave2_model)) + raise Exception(f"The DAVE-2 model file '{self.dave2_model}' does not exist.") from tensorflow.python.keras.models import load_model self.load_model = load_model self.dave2 = True @@ -152,16 +154,16 @@ def _is_the_car_moving(self, last_state): self.last_observation = last_state return True - # If the car moved since the last observation, we store the last state and move one - if (Point(self.last_observation.pos[0], self.last_observation.pos[1]).distance(Point(last_state.pos[0], last_state.pos[1])) > self.min_delta_position): - self.last_observation = last_state - return True - else: + if ( + Point( + self.last_observation.pos[0], self.last_observation.pos[1] + ).distance(Point(last_state.pos[0], last_state.pos[1])) + <= self.min_delta_position + ): # How much time has passed since the last observation? 
- if last_state.timer - self.last_observation.timer > 10.0: - return False - else: - return True + return last_state.timer - self.last_observation.timer <= 10.0 + self.last_observation = last_state + return True def end_iteration(self): try: @@ -257,9 +259,13 @@ def _execute_test_beamng(self, test): if (points_distance(last_state.pos, waypoint_goal.position) < 8.0): break - assert self._is_the_car_moving(last_state), "Car is not moving fast enough " + str(sim_data_collector.name) + assert self._is_the_car_moving( + last_state + ), f"Car is not moving fast enough {str(sim_data_collector.name)}" - assert (not last_state.is_oob), "Car drove out of the lane " + str(sim_data_collector.name) + assert ( + not last_state.is_oob + ), f"Car drove out of the lane {str(sim_data_collector.name)}" if self.dave2: img = vehicle_state_reader.sensors['cam_center']['colour'].convert('RGB') @@ -335,7 +341,7 @@ class SBSTSUT_validator(SUT): def __init__(self, parameters): super().__init__(parameters) - if not "curvature_points" in self.parameters: + if "curvature_points" not in self.parameters: raise Exception("Number of curvature points not defined.") if self.curvature_points <= 0: raise ValueError("The number of curvature points must be positive.") diff --git a/problems/sbst/util.py b/problems/sbst/util.py index a5d0130a..c6c3d7a3 100644 --- a/problems/sbst/util.py +++ b/problems/sbst/util.py @@ -71,7 +71,7 @@ def sbst_test_to_image(test, map_size): # Add information about the test validity title_string = "Test is " + ("valid" if valid else "invalid") if not valid: - title_string = title_string + ":" + msg + title_string = f"{title_string}:{msg}" plt.suptitle(title_string, fontsize=14) plt.draw() diff --git a/stgem/algorithm/algorithm.py b/stgem/algorithm/algorithm.py index a6852d4f0..64ae38fe 100644 --- a/stgem/algorithm/algorithm.py +++ b/stgem/algorithm/algorithm.py @@ -29,7 +29,7 @@ def __init__(self, model_factory=None, model=None, models=None, parameters=None) #self.parameters = self.default_parameters | parameters self.parameters = parameters for key in self.default_parameters: - if not key in self.parameters: + if key not in self.parameters: self.parameters[key] = self.default_parameters[key] def setup(self, search_space, device=None, logger=None): @@ -47,15 +47,16 @@ def setup(self, search_space, device=None, logger=None): self.log = lambda msg: (self.logger("algorithm", msg) if logger is not None else None) # Set input dimension. 
- if not "input_dimension" in self.parameters: + if "input_dimension" not in self.parameters: self.parameters["input_dimension"] = self.search_space.input_dimension # Create models by cloning if self.model: self.models = [] - for _ in range(self.search_space.objectives): - self.models.append(copy.deepcopy(self.model)) - + self.models.extend( + copy.deepcopy(self.model) + for _ in range(self.search_space.objectives) + ) # Create models by factory if self.model_factory: self.models = [self.model_factory() for _ in range(self.search_space.objectives)] @@ -66,9 +67,8 @@ def setup(self, search_space, device=None, logger=None): m.setup(self.search_space, self.device, self.logger) def __getattr__(self, name): - if "parameters" in self.__dict__: - if name in self.parameters: - return self.parameters.get(name) + if "parameters" in self.__dict__ and name in self.parameters: + return self.parameters.get(name) raise AttributeError(name) diff --git a/stgem/algorithm/bayesian/algorithm.py b/stgem/algorithm/bayesian/algorithm.py index 0993eef6..fc50b7cc 100644 --- a/stgem/algorithm/bayesian/algorithm.py +++ b/stgem/algorithm/bayesian/algorithm.py @@ -11,11 +11,10 @@ class BayesianOptimization(Algorithm): def setup(self, search_space, device=None, logger=None): super().setup(search_space, device, logger) - self.bounds = [] - for i in range(self.search_space.input_dimension): - self.bounds.append({"name": "x_{}".format(i), - "type": "continuous", - "domain": (-1, 1)}) + self.bounds = [ + {"name": f"x_{i}", "type": "continuous", "domain": (-1, 1)} + for i in range(self.search_space.input_dimension) + ] def do_train(self, active_outputs, test_repository, budget_remaining): pass diff --git a/stgem/algorithm/model.py b/stgem/algorithm/model.py index 66cec7c4..0a322d59 100644 --- a/stgem/algorithm/model.py +++ b/stgem/algorithm/model.py @@ -23,9 +23,8 @@ def __init__(self, parameters): self.parameters = copy.deepcopy(parameters) def __getattr__(self, name): - if "parameters" in self.__dict__: - if name in self.parameters: - return self.parameters.get(name) + if "parameters" in self.__dict__ and name in self.parameters: + return self.parameters.get(name) raise AttributeError(name) @@ -65,7 +64,7 @@ def __init__(self, parameters=None): # We would like to write the following but this is not supported in Python 3.7. 
#super().__init__(self.default_parameters | parameters) for key in self.default_parameters: - if not key in parameters: + if key not in parameters: parameters[key] = self.default_parameters[key] super().__init__(parameters) @@ -82,10 +81,10 @@ def setup(self, search_space, device, logger=None, use_previous_rng=False): self.log = lambda msg: (self.logger("model", msg) if logger is not None else None) @classmethod - def setup_from_skeleton(C, skeleton, search_space, device, logger=None, use_previous_rng=False): - model = C(skeleton.parameters) + def setup_from_skeleton(cls, skeleton, search_space, device, logger=None, use_previous_rng=False): + model = cls(skeleton.parameters) model.setup(search_space, device, logger, use_previous_rng) - return C + return cls def skeletonize(self): return ModelSkeleton(self.parameters) diff --git a/stgem/algorithm/ogan/algorithm.py b/stgem/algorithm/ogan/algorithm.py index 87b3549c..8e54b05e 100644 --- a/stgem/algorithm/ogan/algorithm.py +++ b/stgem/algorithm/ogan/algorithm.py @@ -44,7 +44,7 @@ def do_train(self, active_outputs, test_repository, budget_remaining): for i in active_outputs: if self.first_training or self.tests_generated - self.model_trained[i] >= self.train_delay: - self.log("Training the OGAN model {}...".format(i + 1)) + self.log(f"Training the OGAN model {i + 1}...") if not self.first_training and self.reset_each_training: # Reset the model. self.models[i].reset() @@ -52,7 +52,7 @@ def do_train(self, active_outputs, test_repository, budget_remaining): dataX = np.asarray([sut_input.inputs for sut_input in X]) dataY = np.array(Y)[:, i].reshape(-1, 1) epochs = self.models[i].train_settings["epochs"] if not self.first_training else self.models[i].train_settings_init["epochs"] - for epoch in range(epochs): + for _ in range(epochs): if self.first_training: train_settings = self.models[i].train_settings_init else: @@ -72,7 +72,9 @@ def do_generate_next_test(self, active_outputs, test_repository, budget_remainin entry_count = 0 # this is to avoid comparing tests when two tests added to the heap have the same predicted objective N_generated = 0 N_invalid = 0 - self.log("Generating using OGAN models {}.".format(",".join(str(m + 1) for m in active_outputs))) + self.log( + f'Generating using OGAN models {",".join(str(m + 1) for m in active_outputs)}.' + ) # PerformanceRecordHandler for the current test. performance = test_repository.performance(test_repository.current_test) @@ -85,7 +87,9 @@ def do_generate_next_test(self, active_outputs, test_repository, budget_remainin # invalid, we give up and hope that the next training phase # will fix things. if N_invalid >= self.invalid_threshold: - raise GenerationException("Could not generate a valid test within {} tests.".format(N_invalid)) + raise GenerationException( + f"Could not generate a valid test within {N_invalid} tests." + ) # Generate several tests and pick the one with best # predicted objective function component. We do this as @@ -129,7 +133,9 @@ def do_generate_next_test(self, active_outputs, test_repository, budget_remainin best_model = heap[0][2] best_estimated_objective = heap[0][0] - self.log("Chose test {} with predicted minimum objective {} on OGAN model {}. Generated total {} tests of which {} were invalid.".format(best_test, best_estimated_objective, best_model + 1, N_generated, N_invalid)) + self.log( + f"Chose test {best_test} with predicted minimum objective {best_estimated_objective} on OGAN model {best_model + 1}. 
Generated total {N_generated} tests of which {N_invalid} were invalid." + ) return best_test diff --git a/stgem/algorithm/ogan/mlm.py b/stgem/algorithm/ogan/mlm.py index eb141522..58e1fa45 100644 --- a/stgem/algorithm/ogan/mlm.py +++ b/stgem/algorithm/ogan/mlm.py @@ -21,8 +21,8 @@ def __init__(self, noise_dim, output_shape, hidden_neurons, hidden_activation): "sigmoid": torch.sigmoid, "tanh": torch.tanh} - if not hidden_activation in activations: - raise Exception("Unknown activation function '{}'.".format(hidden_activation)) + if hidden_activation not in activations: + raise Exception(f"Unknown activation function '{hidden_activation}'.") self.hidden_activation = activations[hidden_activation] self.layers = nn.ModuleList() @@ -74,8 +74,8 @@ def __init__(self, input_shape, hidden_neurons, hidden_activation, discriminator "sigmoid": torch.sigmoid, "tanh": torch.tanh} - if not hidden_activation in activations: - raise Exception("Unknown activation function '{}'.".format(hidden_activation)) + if hidden_activation not in activations: + raise Exception(f"Unknown activation function '{hidden_activation}'.") self.hidden_activation = activations[hidden_activation] self.layers = nn.ModuleList() @@ -106,7 +106,7 @@ def __init__(self, input_shape, hidden_neurons, hidden_activation, discriminator elif a == "sigmoid": self.output_activation = torch.sigmoid else: - raise Exception("Unknown output activation function '{}'.".format(a)) + raise Exception(f"Unknown output activation function '{a}'.") def forward(self, x): """:meta private:""" @@ -153,8 +153,8 @@ def __init__(self, input_shape, feature_maps, kernel_sizes, convolution_activati "tanh": torch.tanh} # Convolution activation function. - if not convolution_activation in activations: - raise Exception("Unknown activation function '{}'.".format(convolution_activation)) + if convolution_activation not in activations: + raise Exception(f"Unknown activation function '{convolution_activation}'.") self.convolution_activation = activations[convolution_activation] # Define the convolutional layers and maxpool layers. Compute diff --git a/stgem/algorithm/ogan/model.py b/stgem/algorithm/ogan/model.py index 26927f7b..87987bc1 100644 --- a/stgem/algorithm/ogan/model.py +++ b/stgem/algorithm/ogan/model.py @@ -127,9 +127,7 @@ def setup(self, search_space, device, logger=None, use_previous_rng=False): current_rng_state = torch.random.get_rng_state() torch.random.set_rng_state(self.previous_rng_state["torch"]) else: - self.previous_rng_state = {} - self.previous_rng_state["torch"] = torch.random.get_rng_state() - + self.previous_rng_state = {"torch": torch.random.get_rng_state()} self._initialize() # Restore RNG state. @@ -161,7 +159,7 @@ def get_loss(loss_s): loss = torch.nn.MSELoss() elif loss_s == "l1": loss = torch.nn.L1Loss() - elif loss_s == "mse,logit" or loss_s == "l1,logit": + elif loss_s in ["mse,logit", "l1,logit"]: # When doing regression with values in [0, 1], we can use a # logit transformation to map the values from [0, 1] to \R # to make errors near 0 and 1 more drastic. 
Since logit is @@ -188,8 +186,8 @@ def f(X, Y): raise @classmethod - def setup_from_skeleton(C, skeleton, search_space, device, logger=None, use_previous_rng=False): - model = C(skeleton.parameters) + def setup_from_skeleton(cls, skeleton, search_space, device, logger=None, use_previous_rng=False): + model = cls(skeleton.parameters) model.setup(search_space, device, logger, use_previous_rng) model.modelG = skeleton.modelG.to(device) model.modelD = skeleton.modelD.to(device) @@ -265,7 +263,9 @@ def train_with_batch(self, dataX, dataY, train_settings=None): m = np.mean(D_losses) if discriminator_epochs > 0: - self.log("Discriminator epochs {}, Loss: {} -> {} (mean {})".format(discriminator_epochs, D_losses[0], D_losses[-1], m)) + self.log( + f"Discriminator epochs {discriminator_epochs}, Loss: {D_losses[0]} -> {D_losses[-1]} (mean {m})" + ) self.modelD.train(False) @@ -316,7 +316,9 @@ def train_with_batch(self, dataX, dataY, train_settings=None): m = np.mean(G_losses) if self.noise_batch_size > 0: - self.log("Generator steps {}, Loss: {} -> {}, mean {}".format(self.noise_batch_size//generator_batch_size + 1, G_losses[0], G_losses[-1], m)) + self.log( + f"Generator steps {self.noise_batch_size // generator_batch_size + 1}, Loss: {G_losses[0]} -> {G_losses[-1]}, mean {m}" + ) self.modelG.train(False) diff --git a/stgem/algorithm/ogan/model_keras.py b/stgem/algorithm/ogan/model_keras.py index da0c1c38..51941cb8 100644 --- a/stgem/algorithm/ogan/model_keras.py +++ b/stgem/algorithm/ogan/model_keras.py @@ -54,8 +54,7 @@ def generate_test(self, N=1): self.init_model() noise = np.random.normal(0, 1, size=(N, self.noise_dimensions)) - tests = self.modelG.predict(noise) - return tests + return self.modelG.predict(noise) def predict_objective(self, test): """ @@ -101,8 +100,8 @@ def skeletonize(self): return skeleton @classmethod - def setup_from_skeleton(C, skeleton, search_space, device, logger=None, use_previous_rng=False): - model = C(skeleton.parameters) + def setup_from_skeleton(cls, skeleton, search_space, device, logger=None, use_previous_rng=False): + model = cls(skeleton.parameters) model.setup(search_space, device, logger, use_previous_rng) model.init_model() model.modelG.set_weights(skeleton.modelG_weights) diff --git a/stgem/algorithm/random/model.py b/stgem/algorithm/random/model.py index 1b5eab1f..fadc66f9 100644 --- a/stgem/algorithm/random/model.py +++ b/stgem/algorithm/random/model.py @@ -8,8 +8,8 @@ def skeletonize(self): return Random_ModelSkeleton(self.parameters) @classmethod - def setup_from_skeleton(C, skeleton, search_space, device, logger=None, use_previous_rng=False): - model = C(skeleton.parameters) + def setup_from_skeleton(cls, skeleton, search_space, device, logger=None, use_previous_rng=False): + model = cls(skeleton.parameters) model.setup(search_space, device, logger, use_previous_rng) return model @@ -58,11 +58,9 @@ def setup(self, search_space, device, logger=None, use_previous_rng=False): self.used_points = [] def _satisfies_min_distance(self, test): - for p in self.used_points: - if np.linalg.norm(p - test) < self.min_distance: - return False - - return True + return all( + np.linalg.norm(p - test) >= self.min_distance for p in self.used_points + ) def generate_test(self, N=1): result = np.empty(shape=(N, self.input_dimension)) @@ -112,7 +110,7 @@ def __init__(self, parameters=None): def setup(self, search_space, device, logger=None, use_previous_rng=False): super().setup(search_space, device, logger, use_previous_rng) - if not "samples" in self.parameters: + if 
"samples" not in self.parameters: raise Exception("The 'samples' key must be provided for the algorithm for determining random sample size.") # Save current RNG state and use previous. @@ -134,7 +132,7 @@ def setup(self, search_space, device, logger=None, use_previous_rng=False): def random_func(self): self.current += 1 - + if self.current >= len(self.random_tests): raise Exception("Random sample exhausted.") @@ -213,23 +211,30 @@ def lhs(self, n, samples=None, criterion=None, iterations=None): """ H = None - + if samples is None: samples = n - + if criterion is not None: - assert criterion.lower() in ('center', 'c', 'maximin', 'm', - 'centermaximin', 'cm', 'correlation', - 'corr'), 'Invalid value for "criterion": {}'.format(criterion) + assert criterion.lower() in ( + 'center', + 'c', + 'maximin', + 'm', + 'centermaximin', + 'cm', + 'correlation', + 'corr', + ), f'Invalid value for "criterion": {criterion}' else: H = self._lhsclassic(n, samples) if criterion is None: criterion = 'center' - + if iterations is None: iterations = 5 - + if H is None: if criterion.lower() in ('center', 'c'): H = _lhscentered(n, samples) @@ -239,7 +244,7 @@ def lhs(self, n, samples=None, criterion=None, iterations=None): H = _lhsmaximin(n, samples, iterations, 'centermaximin') elif criterion.lower() in ('correlate', 'corr'): H = _lhscorrelate(n, samples, iterations) - + return H def _lhsclassic(self, n, samples): @@ -279,39 +284,39 @@ def _lhscentered(n, samples): return H - def _lhsmaximin(n, samples, iterations, lhstype): + def _lhsmaximin(self, samples, iterations, lhstype): maxdist = 0 - + # Maximize the minimum distance between points - for i in range(iterations): - if lhstype=='maximin': - Hcandidate = _lhsclassic(n, samples) - else: - Hcandidate = _lhscentered(n, samples) - + for _ in range(iterations): + Hcandidate = ( + _lhsclassic(self, samples) + if lhstype == 'maximin' + else _lhscentered(self, samples) + ) d = _pdist(Hcandidate) if maxdist= self.invalid_threshold: - raise GenerationException("Could not generate a valid test within {} tests.".format(N_invalid)) + raise GenerationException( + f"Could not generate a valid test within {N_invalid} tests." + ) # Generate several tests and pick the one with best # predicted objective function component. We do this as @@ -258,6 +260,8 @@ def do_generate_next_test(self, active_outputs, test_repository, budget_remainin best_model = heap[0][2] best_estimated_objective = heap[0][0] - self.log("Chose test {} with predicted minimum objective {} on WGAN model {}. Generated total {} tests of which {} were invalid.".format(best_test, best_estimated_objective, best_model + 1, N_generated, N_invalid)) + self.log( + f"Chose test {best_test} with predicted minimum objective {best_estimated_objective} on WGAN model {best_model + 1}. Generated total {N_generated} tests of which {N_invalid} were invalid." 
+ ) return best_test diff --git a/stgem/algorithm/wogan/analyzer.py b/stgem/algorithm/wogan/analyzer.py index 0e2f2d5b..37e3de7b 100644 --- a/stgem/algorithm/wogan/analyzer.py +++ b/stgem/algorithm/wogan/analyzer.py @@ -24,9 +24,8 @@ def setup(self, device, logger=None): self.modelA = None def __getattr__(self, name): - if "parameters" in self.__dict__: - if name in self.parameters: - return self.parameters.get(name) + if "parameters" in self.__dict__ and name in self.parameters: + return self.parameters.get(name) raise AttributeError(name) @@ -63,18 +62,16 @@ def get_loss(loss_s): loss = torch.nn.MSELoss() elif loss_s == "l1": loss = torch.nn.L1Loss() - elif loss_s == "mse,logit" or loss_s == "l1,logit": + elif loss_s in ["mse,logit", "l1,logit"]: # When doing regression with values in [0, 1], we can use a # logit transformation to map the values from [0, 1] to \R # to make errors near 0 and 1 more drastic. Since logit is # undefined in 0 and 1, we actually first transform the values # to the interval [0.01, 0.99]. - if loss_s == "mse,logit": - g = torch.nn.MSELoss() - else: - g = torch.nn.L1Loss() + g = torch.nn.MSELoss() if loss_s == "mse,logit" else torch.nn.L1Loss() def f(X, Y): return g(torch.logit(0.98*X + 0.01), torch.logit(0.98*Y + 0.01)) + loss = f else: raise Exception("Unknown loss function '{}'.".format(loss_s)) @@ -102,9 +99,7 @@ def analyzer_loss(self, data_X, data_Y): else: self.parameters["l2_regularization_coef"] = 0 - A_loss = model_loss + self.l2_regularization_coef*l2_regularization - - return A_loss + return model_loss + self.l2_regularization_coef*l2_regularization def _train_with_batch(self, data_X, data_Y, train_settings): # Save the training modes for later restoring. diff --git a/stgem/algorithm/wogan/mlm.py b/stgem/algorithm/wogan/mlm.py index 1814a10a..48393cd7 100644 --- a/stgem/algorithm/wogan/mlm.py +++ b/stgem/algorithm/wogan/mlm.py @@ -30,13 +30,13 @@ def __init__(self, input_shape, hidden_neurons, output_shape, output_activation, "tanh": torch.tanh} # Hidden layer activation. - if not hidden_activation in activations: - raise Exception("Unknown activation function '{}'.".format(hidden_activation)) + if hidden_activation not in activations: + raise Exception(f"Unknown activation function '{hidden_activation}'.") self.hidden_activation = activations[hidden_activation] # Output activation. - if not output_activation in activations: - raise Exception("Unknown activation function '{}'.".format(output_activation)) + if output_activation not in activations: + raise Exception(f"Unknown activation function '{output_activation}'.") self.output_activation = activations[output_activation] # We use fully connected layers with the specified number of neurons. @@ -163,8 +163,8 @@ def __init__(self, input_shape, feature_maps, kernel_sizes, convolution_activati "tanh": torch.tanh} # Convolution activation function. - if not convolution_activation in activations: - raise Exception("Unknown activation function '{}'.".format(convolution_activation)) + if convolution_activation not in activations: + raise Exception(f"Unknown activation function '{convolution_activation}'.") self.convolution_activation = activations[convolution_activation] # Define the convolutional layers and maxpool layers. 
Compute diff --git a/stgem/algorithm/wogan/model.py b/stgem/algorithm/wogan/model.py index c8fd83c2..d0114f58 100644 --- a/stgem/algorithm/wogan/model.py +++ b/stgem/algorithm/wogan/model.py @@ -130,9 +130,7 @@ def setup(self, search_space, device, logger=None, use_previous_rng=False): current_rng_state = torch.random.get_rng_state() torch.random.set_rng_state(self.previous_rng_state["torch"]) else: - self.previous_rng_state = {} - self.previous_rng_state["torch"] = torch.random.get_rng_state() - + self.previous_rng_state = {"torch": torch.random.get_rng_state()} # Infer input and output dimensions for ML models. self.parameters["analyzer_parameters"]["analyzer_mlm_parameters"]["input_shape"] = self.search_space.input_dimension self.parameters["generator_mlm_parameters"]["output_shape"] = self.search_space.input_dimension @@ -165,8 +163,8 @@ def setup(self, search_space, device, logger=None, use_previous_rng=False): torch.random.set_rng_state(current_rng_state) @classmethod - def setup_from_skeleton(C, skeleton, search_space, device, logger=None, use_previous_rng=False): - model = C(skeleton.parameters) + def setup_from_skeleton(cls, skeleton, search_space, device, logger=None, use_previous_rng=False): + model = cls(skeleton.parameters) model.setup(search_space, device, logger, use_previous_rng) model.modelA.device = device model.modelA.modelA = skeleton.modelA.modelA.to(device) @@ -211,7 +209,9 @@ def train_analyzer_with_batch(self, data_X, data_Y, train_settings): losses.append(loss) m = np.mean(losses) - self.log("Analyzer epochs {}, Loss: {} -> {} (mean {})".format(train_settings["analyzer_epochs"], losses[0], losses[-1], m)) + self.log( + f'Analyzer epochs {train_settings["analyzer_epochs"]}, Loss: {losses[0]} -> {losses[-1]} (mean {m})' + ) return losses @@ -306,7 +306,9 @@ def train_with_batch(self, data_X, train_settings=None): m1 = np.mean(C_losses) m2 = np.mean(gradient_penalties) - self.log("Critic steps {}, Loss: {} -> {} (mean {}), GP: {} -> {} (mean {})".format(critic_steps, C_losses[0], C_losses[-1], m1, gradient_penalties[0], gradient_penalties[-1], m2)) + self.log( + f"Critic steps {critic_steps}, Loss: {C_losses[0]} -> {C_losses[-1]} (mean {m1}), GP: {gradient_penalties[0]} -> {gradient_penalties[-1]} (mean {m2})" + ) self.modelC.train(False) @@ -329,7 +331,9 @@ def train_with_batch(self, data_X, train_settings=None): self.optimizerG.step() m = np.mean(G_losses) - self.log("Generator steps {}, Loss: {} -> {} (mean {})".format(generator_steps, G_losses[0], G_losses[-1], m)) + self.log( + f"Generator steps {generator_steps}, Loss: {G_losses[0]} -> {G_losses[-1]} (mean {m})" + ) self.modelG.train(False) @@ -348,7 +352,7 @@ def train_with_batch(self, data_X, train_settings=None): W_distance = real_loss - fake_loss - self.log("Batch W. distance: {}".format(W_distance[0])) + self.log(f"Batch W. distance: {W_distance[0]}") # Visualize the computational graph. # print(make_dot(G_loss, params=dict(self.modelG.named_parameters()))) diff --git a/stgem/budget.py b/stgem/budget.py index c004b975..450e3506 100644 --- a/stgem/budget.py +++ b/stgem/budget.py @@ -33,19 +33,19 @@ def update_threshold(self, budget_threshold): # Use specified values; infinite budget for nonspecified quantities. 
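A hedged sketch of what the threshold update and the used() computation that follow amount to (simplified standalone version; rng, new_threshold and value are illustrative names, not the module's API): raising a threshold turns the old upper bound into the new lower bound, and the consumed fraction is then measured inside that [start, end] slice.

rng = [0, 300]                 # current range for, say, the "executions" budget
new_threshold = 500
if rng[1] < new_threshold:
    rng[0] = rng[1]            # start measuring from the old limit
rng[1] = new_threshold

value = 350                    # executions consumed so far
remaining = 1.0 - (value - rng[0]) / (rng[1] - rng[0]) if value >= rng[0] else 1.0
print(remaining)               # 0.75: a quarter of the new budget slice is already gone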
for name in budget_threshold: if budget_threshold[name] < self.budgets[name](self.quantities): - raise Exception("Cannot update budget threshold '{}' to '{}' since its below the already consumed budget '{}'.".format(name, quantity_threshold[name], self.budgets[name](self.quantities))) + raise Exception( + f"Cannot update budget threshold '{name}' to '{budget_threshold[name]}' since it's below the already consumed budget '{self.budgets[name](self.quantities)}'." + ) # If budget range does not exist, we set it to default. This can # happen if the user defines a new budget by adding a key to # self.budgets. - if not name in self.budget_ranges: + if name not in self.budget_ranges: self.budget_ranges[name] = [0,math.inf] if self.budget_ranges[name][1] < budget_threshold[name]: self.budget_ranges[name][0] = self.budget_ranges[name][1] - self.budget_ranges[name][1] = budget_threshold[name] - else: - self.budget_ranges[name][1] = budget_threshold[name] + self.budget_ranges[name][1] = budget_threshold[name] def remaining(self): """Return the minimum amount of budget left among all budget as a @@ -62,17 +62,13 @@ def used(self): # If budget range does not exist, we set it to default. This can # happen if the user defines a new budget by adding a key to # self.budgets. - if not name in self.budget_ranges: + if name not in self.budget_ranges: self.budget_ranges[name] = [0,math.inf] start = self.budget_ranges[name][0] end = self.budget_ranges[name][1] value = self.budgets[name](self.quantities) - if value >= start: - remaining = 1.0 - (value - start) / (end - start) - else: - remaining = 1.0 - + remaining = 1.0 - (value - start) / (end - start) if value >= start else 1.0 result[name] = remaining return result diff --git a/stgem/experiment.py b/stgem/experiment.py index 75cdbf13..dff30e5a 100644 --- a/stgem/experiment.py +++ b/stgem/experiment.py @@ -35,18 +35,18 @@ def run(self, N_workers=1, silent=False, use_gpu=True, done=None): for idx in range(self.N): generator = self.stgem_factory() seed = self.seed_factory() - if not idx in done: + if idx not in done: generator.setup(seed=seed, use_gpu=use_gpu) if silent: generator.logger.silent = True - if not self.generator_callback is None: + if self.generator_callback is not None: self.generator_callback(generator) r = generator._run() - if not self.result_callback is None: + if self.result_callback is not None: self.result_callback(idx, r, done) done.append(idx) @@ -67,13 +67,13 @@ def run(self, N_workers=1, silent=False, use_gpu=True, done=None): # but currently we just exit and instruct the user. if torch.cuda.is_available(): raise SystemExit("Subprocesses are being used and these do " \ - "not work with any CUDA device being " \ - "available due to a pickling error (even in " \ - "the case that only CPU is requested as the " \ - "Pytorch device). Please disable " \ - "multiprocessing or set 'export " \ - "CUDA_VISIBLE_DEVICES=\"\"' to use CPU and " \ - "multiprocessing.") + "not work with any CUDA device being " \ + "available due to a pickling error (even in " \ + "the case that only CPU is requested as the " \ + "Pytorch device). 
Please disable " \ + "multiprocessing or set 'export " \ + "CUDA_VISIBLE_DEVICES=\"\"' to use CPU and " \ + "multiprocessing.") def consumer(queue_generators, queue_results, silent, generator_callback): while True: @@ -83,11 +83,11 @@ def consumer(queue_generators, queue_results, silent, generator_callback): idx, generator, seed = msg generator.setup(seed=seed, use_gpu=use_gpu) - + if silent: generator.logger.silent = True - if not generator_callback is None: + if generator_callback is not None: generator_callback(generator) r = generator._run() @@ -97,10 +97,10 @@ def consumer(queue_generators, queue_results, silent, generator_callback): del generator if self.garbage_collect: gc.collect() - + def producer(queue_generators, N_workers, N, stgem_factory, seed_factory, done): for idx in range(N): - if not idx in done: + if idx not in done: queue_generators.put((idx, stgem_factory(), seed_factory())) else: stgem_factory() @@ -125,7 +125,7 @@ def producer(queue_generators, N_workers, N, stgem_factory, seed_factory, done): # Wait for results and process them via the callback. while len(done) < self.N: idx, r = queue_results.get() - if not self.result_callback is None: + if self.result_callback is not None: self.result_callback(idx, r, done) done.append(idx) diff --git a/stgem/generator.py b/stgem/generator.py index b000f626..91f45c89 100644 --- a/stgem/generator.py +++ b/stgem/generator.py @@ -47,7 +47,7 @@ def dump_to_file(self, file_name: str): o = gzip.open if file_name.endswith(".gz") else open # first create a temporary file - temp_file_name = "{}.tmp".format(file_name) + temp_file_name = f"{file_name}.tmp" with o(temp_file_name, "wb") as file: pickle.dump(self, file) # then we rename it to its final name @@ -77,7 +77,7 @@ def __init__(self, algorithm: Algorithm, budget_threshold, mode="exhaust_budget" self.budget = None self.budget_threshold = budget_threshold if mode not in ["exhaust_budget", "stop_at_first_objective"]: - raise Exception("Unknown test generation mode '{}'.".format(mode)) + raise Exception(f"Unknown test generation mode '{mode}'.") self.mode = mode self.results_include_models = results_include_models @@ -103,7 +103,10 @@ def run(self) -> StepResult: self.algorithm.initialize() self.success = True - if not (self.mode == "stop_at_first_objective" and self.test_repository.minimum_objective <= 0.0): + if ( + self.mode != "stop_at_first_objective" + or self.test_repository.minimum_objective > 0.0 + ): self.success = False # Below we omit including a test into the test repository if the @@ -113,19 +116,19 @@ def run(self) -> StepResult: i = 0 while self.budget.remaining() > 0: - self.log("Budget remaining {}.".format(self.budget.remaining())) + self.log(f"Budget remaining {self.budget.remaining()}.") # Create a new test repository record to be filled. performance = self.test_repository.new_record() self.algorithm.train(self.objective_selector.select(), self.test_repository, self.budget.remaining()) self.budget.consume("training_time", performance.obtain("training_time")) - if not self.budget.remaining() > 0: + if self.budget.remaining() <= 0: self.log("Ran out of budget during training. 
Discarding the test.") self.test_repository.discard_record() break - self.log("Starting to generate test {}.".format(self.test_repository.tests + 1)) + self.log(f"Starting to generate test {self.test_repository.tests + 1}.") could_generate = True try: next_test = self.algorithm.generate_next_test(self.objective_selector.select(), self.test_repository, self.budget.remaining()) @@ -142,12 +145,12 @@ def run(self) -> StepResult: could_generate = False self.budget.consume("generation_time", performance.obtain("generation_time")) - if not self.budget.remaining() > 0: + if self.budget.remaining() <= 0: self.log("Ran out of budget during test generation. Discarding the test.") self.test_repository.discard_record() break if could_generate: - self.log("Generated test {}.".format(next_test)) + self.log(f"Generated test {next_test}.") self.log("Executing the test...") performance.timer_start("execution") @@ -160,26 +163,28 @@ def run(self) -> StepResult: self.budget.consume("execution_time", performance.obtain("execution_time")) self.budget.consume(sut_output) - if not self.budget.remaining() > 0: + if self.budget.remaining() <= 0: self.log("Ran out of budget during test execution. Discarding the test.") self.test_repository.discard_record() break self.budget.consume("executions") - self.log("Input to the SUT: {}".format(sut_input)) + self.log(f"Input to the SUT: {sut_input}") if sut_output.error is None: - self.log("Output from the SUT: {}".format(sut_output)) + self.log(f"Output from the SUT: {sut_output}") objectives = [objective(sut_input, sut_output) for objective in self.objective_funcs] self.test_repository.record_objectives(objectives) - self.log("The actual objective: {}".format(objectives)) + self.log(f"The actual objective: {objectives}") # TODO: Argmin does not take different scales into account. self.objective_selector.update(np.argmin(objectives)) else: - self.log("An error '{}' occurred during the test execution. No output available.".format(sut_output.error)) + self.log( + f"An error '{sut_output.error}' occurred during the test execution. No output available." + ) self.test_repository.record_objectives([]) idx = self.test_repository.finalize_record() @@ -187,7 +192,7 @@ def run(self) -> StepResult: if not self.success and self.test_repository.minimum_objective <= 0.0: self.success = True - self.log("First success at test {}.".format(i + 1)) + self.log(f"First success at test {i + 1}.") else: self.log("Encountered a problem with test generation. Skipping to next training phase.") @@ -206,17 +211,18 @@ def run(self) -> StepResult: self.algorithm.finalize() # Report results. - self.log("Step minimum objective component: {}".format(self.test_repository.minimum_objective)) + self.log( + f"Step minimum objective component: {self.test_repository.minimum_objective}" + ) - result = self._generate_step_result(test_idx, model_skeletons) - - return result + return self._generate_step_result(test_idx, model_skeletons) def _generate_step_result(self, test_idx, model_skeletons): # Save certain parameters in the StepResult object. 
- parameters = {} - parameters["algorithm_name"] = self.algorithm.__class__.__name__ - parameters["algorithm"] = copy.deepcopy(self.algorithm.parameters) + parameters = { + "algorithm_name": self.algorithm.__class__.__name__, + "algorithm": copy.deepcopy(self.algorithm.parameters), + } parameters["model_name"] = [self.algorithm.models[i].__class__.__name__ for i in range(self.algorithm.N_models)] parameters["model"] = [copy.deepcopy(self.algorithm.models[i].parameters) for i in range(self.algorithm.N_models)] parameters["objective_name"] = [objective.__class__.__name__ for objective in self.objective_funcs] @@ -244,11 +250,11 @@ class Load(Step): def __init__(self, file_name, mode="initial", load_range=None, consume_budget=True, recompute_objective=False): self.file_name = file_name if not os.path.exists(self.file_name): - raise Exception("Pregenerated date file '{}' does not exist.".format(self.file_name)) + raise Exception(f"Pregenerated data file '{self.file_name}' does not exist.") if mode not in ["initial", "random"]: - raise ValueError("Unknown load mode '{}'.".format(mode)) + raise ValueError(f"Unknown load mode '{mode}'.") if load_range < 0: - raise ValueError("The load range {} cannot be negative.".format(load_range)) + raise ValueError(f"The load range {load_range} cannot be negative.") self.mode = mode self.load_range = load_range self.consume_budget = consume_budget @@ -274,7 +280,9 @@ def run(self, results_include_models=False, results_checkpoint_period=1) -> Step try: raw_data = STGEMResult.restore_from_file(self.file_name) except: - raise Exception("Error loading STGEMResult object from file '{}'.".format(self.file_name)) + raise Exception( + f"Error loading STGEMResult object from file '{self.file_name}'." + ) """ If load_range is defined and consume_budget is True, then we stop when @@ -285,7 +293,9 @@ def run(self, results_include_models=False, results_checkpoint_period=1) -> Step if self.load_range is None: self.load_range = range_max elif self.load_range > range_max: - raise ValueError("The load range {} is out of bounds. The maximum range for loaded data is {}.".format(self.load_range, range_max)) + raise ValueError( + f"The load range {self.load_range} is out of bounds. The maximum range for loaded data is {range_max}." + ) if self.mode == "random": # Use the search space RNG to ensure deterministic selection.
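To make the two load modes concrete, here is an illustrative stand-in (select_indices is a hypothetical helper, not the module's API): "initial" replays the first load_range tests in order, while "random" draws load_range distinct indices with a seeded RNG, so the same seed always replays the same subset.

import numpy as np

def select_indices(n_available, load_range, mode, rng):
    if mode == "initial":
        return list(range(load_range))
    # "random": deterministic for a fixed RNG state.
    return list(rng.choice(n_available, size=load_range, replace=False))

rng = np.random.RandomState(42)
print(select_indices(10, 3, "initial", rng))  # [0, 1, 2]
print(select_indices(10, 3, "random", rng))   # same triple on every run with seed 42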
@@ -295,25 +305,30 @@ def run(self, results_include_models=False, results_checkpoint_period=1) -> Step for i in idx: if self.budget.remaining() == 0: break - self.log("Budget remaining {}.".format(self.budget.remaining())) + self.log(f"Budget remaining {self.budget.remaining()}.") X, Z, Y = raw_data.test_repository.get(i) old_performance = raw_data.test_repository.performance(i) if len(X.inputs) != self.search_space.input_dimension: - raise ValueError("Loaded sample input dimension {} does not match SUT input dimension {}".format(len(X.inputs), self.search_space.input_dimension)) + raise ValueError( + f"Loaded sample input dimension {len(X.inputs)} does not match SUT input dimension {self.search_space.input_dimension}" + ) if Z.output_timestamps is None: if len(Z.outputs) != self.search_space.output_dimension: - raise ValueError("Loaded sample vector output dimension {} does not match SUT vector output dimension {}.".format(len(Z.outputs), self.search_space.output_dimension)) - else: - if Z.outputs.shape[0] != self.search_space.output_dimension: - raise ValueError("Loaded sample signal number {} does not match SUT signal number {}.".format(Z.outputs.shape[0], self.search_space.output_dimension)) + raise ValueError( + f"Loaded sample vector output dimension {len(Z.outputs)} does not match SUT vector output dimension {self.search_space.output_dimension}." + ) + elif Z.outputs.shape[0] != self.search_space.output_dimension: + raise ValueError( + f"Loaded sample signal number {Z.outputs.shape[0]} does not match SUT signal number {self.search_space.output_dimension}." + ) # Consume the budget if requested. if self.consume_budget: self.budget.consume("training_time", old_performance.obtain("training_time")) self.budget.consume("generation_time", old_performance.obtain("generation_time")) self.budget.consume("execution_time", old_performance.obtain("execution_time")) - if not self.budget.remaining() > 0: + if self.budget.remaining() <= 0: self.log("Ran out of budget during training, generation, or execution. Discarding the test.") break self.budget.consume("executions") @@ -339,30 +354,31 @@ def run(self, results_include_models=False, results_checkpoint_period=1) -> Step if Z.error is not None: self.log("and output") self.log(str(Z)) - self.log("and objective {}.".format(Y)) + self.log(f"and objective {Y}.") else: self.log("which failed to execute and produce output.") if self.mode == "initial": - self.log("Loaded initial {} tests from the result file {}.".format(self.load_range, self.file_name)) + self.log( + f"Loaded initial {self.load_range} tests from the result file {self.file_name}." + ) else: - self.log("Loaded randomly {} tests from the result file {}.".format(self.load_range, self.file_name)) + self.log( + f"Loaded randomly {self.load_range} tests from the result file {self.file_name}." + ) success = self.test_repository.minimum_objective <= 0 # Save certain parameters in the StepResult object. 
- parameters = {} - parameters["file_name"] = self.file_name - parameters["mode"] = self.mode - parameters["load_range"] = self.load_range - parameters["consume_budget"] = self.consume_budget - parameters["recompute_objective"] = self.recompute_objective - parameters["executed_tests"] = test_idx - - # Build StepResult object with test_repository - step_result = StepResult(self.test_repository, success, parameters) - - return step_result + parameters = { + "file_name": self.file_name, + "mode": self.mode, + "load_range": self.load_range, + "consume_budget": self.consume_budget, + "recompute_objective": self.recompute_objective, + "executed_tests": test_idx, + } + return StepResult(self.test_repository, success, parameters) class STGEM: @@ -374,7 +390,9 @@ def __init__(self, description, sut: SUT, objectives, objective_selector=None, b nonsafe_chars = "/\<>:\"|?*" for c in self.description: if c in nonsafe_chars: - raise ValueError("Character '{}' not allowed in a description (could be used as a file name).".format(c)) + raise ValueError( + f"Character '{c}' not allowed in a description (could be used as a file name)." + ) self.sut = sut self.step_results = [] @@ -440,7 +458,9 @@ def setup(self, seed=None, use_gpu=True): if use_gpu: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if self.device.type != "cuda": - self.log("Warning: requested torch device 'cuda' but got '{}'.".format(self.device.type)) + self.log( + f"Warning: requested torch device 'cuda' but got '{self.device.type}'." + ) else: self.device = torch.device("cpu") @@ -460,9 +480,7 @@ def _run(self) -> STGEMResult: # Setup and run steps sequentially. self.step_results = [] - for step in self.steps: - self.step_results.append(step.run()) - + self.step_results.extend(step.run() for step in self.steps) return self._generate_result(self.step_results) def run(self, seed=None) -> STGEMResult: diff --git a/stgem/logger.py b/stgem/logger.py index c6084327..8a56172f 100644 --- a/stgem/logger.py +++ b/stgem/logger.py @@ -9,5 +9,5 @@ def __init__(self): def __call__(self, name, log): if not self.silent: - print("{}: {}".format(name, log)) + print(f"{name}: {log}") diff --git a/stgem/objective/objective.py b/stgem/objective/objective.py index a45a9737..596e2b92 100644 --- a/stgem/objective/objective.py +++ b/stgem/objective/objective.py @@ -17,9 +17,8 @@ def setup(self, sut): self.sut = sut def __getattr__(self, name): - if "parameters" in self.__dict__: - if name in self.parameters: - return self.parameters.get(name) + if "parameters" in self.__dict__ and name in self.parameters: + return self.parameters.get(name) raise AttributeError(name) @@ -33,7 +32,7 @@ class Minimize(Objective): def __init__(self, selected=None, scale=False, invert=False, clip=True): super().__init__() - if not (isinstance(selected, list) or isinstance(selected, tuple) or selected is None): + if not (isinstance(selected, (list, tuple)) or selected is None): raise Exception("The parameter 'selected' must be None or a list or a tuple.") self.parameters["selected"] = selected @@ -62,10 +61,7 @@ def __call__(self, t, r): else: output = v - if self.clip: - return max(0, min(1, min(output))) - else: - return min(output) + return max(0, min(1, min(output))) if self.clip else min(output) class FalsifySTL(Objective): """Objective function to falsify an STL specification. 
By default the @@ -185,7 +181,7 @@ def _evaluate_vector(self, test, output): idx = self.sut.inputs.index[var] trajectories[var] = np.array([test[idx]]) except ValueError: - raise Exception("Variable '{}' not in input or output variables.".format(var)) + raise Exception(f"Variable '{var}' not in input or output variables.") # Notice that the return value is a Cython MemoryView. #robustness_signal = self.specification.eval_interval(trajectories, timestamps) @@ -223,15 +219,13 @@ def _evaluate_signal(self, test, result): args.append(var) try: idx = self.sut.outputs.index(var) - args.append(output_timestamps) - args.append(output_signals[idx]) + args.extend((output_timestamps, output_signals[idx])) except ValueError: try: idx = self.sut.inputs.index(var) - args.append(input_timestamps) - args.append(input_signals[idx]) + args.extend((input_timestamps, input_signals[idx])) except ValueError: - raise Exception("Variable '{}' not in input or output variables.".format(var)) + raise Exception(f"Variable '{var}' not in input or output variables.") trajectories = STL.Traces.from_mixed_signals(*args, sampling_period=self.sampling_period) @@ -240,7 +234,9 @@ def _evaluate_signal(self, test, result): # Allow slight inaccuracy in horizon check. if self.strict_horizon_check and self.horizon - 1e-2 > trajectories.timestamps[-1]: - raise Exception("The horizon {} of the formula is too long compared to signal length {}. The robustness cannot be computed.".format(self.horizon, trajectories.timestamps[-1])) + raise Exception( + f"The horizon {self.horizon} of the formula is too long compared to signal length {trajectories.timestamps[-1]}. The robustness cannot be computed." + ) # Adjust time bounds. self.adjust_time_bounds() diff --git a/stgem/objective_selector/objective_selector.py b/stgem/objective_selector/objective_selector.py index 171bc076..11c87f3d 100644 --- a/stgem/objective_selector/objective_selector.py +++ b/stgem/objective_selector/objective_selector.py @@ -12,9 +12,8 @@ def setup(self, objectives): self.dim = len(objectives) def __getattr__(self, name): - if "parameters" in self.__dict__: - if name in self.parameters: - return self.parameters.get(name) + if "parameters" in self.__dict__ and name in self.parameters: + return self.parameters.get(name) raise AttributeError(name) @@ -55,13 +54,12 @@ def setup(self, objectives): def select(self): if self.total_calls <= self.warm_up: return self.select_all() - else: - p = [s / self.total_calls for s in self.model_successes] - return [np.random.choice(range(0, self.dim), p=p)] + p = [s / self.total_calls for s in self.model_successes] + return [np.random.choice(range(0, self.dim), p=p)] def update(self, idx): try: self.model_successes[idx] += 1 except IndexError: - raise Exception("No model with index {}.".format(idx)) + raise Exception(f"No model with index {idx}.") self.total_calls += 1 diff --git a/stgem/sut/__init__.py b/stgem/sut/__init__.py index 71024d42..aab46c3f 100644 --- a/stgem/sut/__init__.py +++ b/stgem/sut/__init__.py @@ -47,8 +47,7 @@ def setup(self, sut, objectives, rng): def is_valid(self, test) -> bool: # This is here until valid tests are changed to preconditions. This # line ensures that model-based SUTs work and can be pickled. 
- if self.sut is None: return 1 - return self.sut.validity(test) + return 1 if self.sut is None else self.sut.validity(test) def sample_input_space(self): return self.rng.uniform(-1, 1, size=self.input_dimension) @@ -68,20 +67,19 @@ def __init__(self, parameters=None): #self.parameters = self.default_parameters | parameters self.parameters = parameters for key in self.default_parameters: - if not key in self.parameters: + if key not in self.parameters: self.parameters[key] = self.default_parameters[key] - if not "input_type" in self.parameters: + if "input_type" not in self.parameters: self.parameters["input_type"] = None - if not "output_type" in self.parameters: + if "output_type" not in self.parameters: self.parameters["output_type"] = None self.base_has_been_setup = False def __getattr__(self, name): - if "parameters" in self.__dict__: - if name in self.parameters: - return self.parameters.get(name) + if "parameters" in self.__dict__ and name in self.parameters: + return self.parameters.get(name) raise AttributeError(name) @@ -103,43 +101,40 @@ def setup(self): if hasattr(self, "inputs") and isinstance(self.inputs, int): if not hasattr(self, "idim"): self.idim = self.inputs - self.inputs = ["i{}".format(i) for i in range(self.inputs)] + self.inputs = [f"i{i}" for i in range(self.inputs)] # If idim is not set, it can be inferred from input names (a list of # names) or input ranges. if hasattr(self, "idim"): # idim set already, set default input names if necessary. if not hasattr(self, "inputs"): - self.inputs = ["i{}".format(i) for i in range(self.idim)] + self.inputs = [f"i{i}" for i in range(self.idim)] + elif hasattr(self, "inputs"): + self.idim = len(self.inputs) else: - # idim can be inferred from input names, if defined. - if hasattr(self, "inputs"): - self.idim = len(self.inputs) - else: - # idim can be inferred from input ranges. Otherwise we do not - # know what to do. - if not hasattr(self, "input_range"): - raise Exception("SUT input dimension not defined and cannot be inferred.") - self.idim = len(self.input_range) - self.inputs = ["i{}".format(i) for i in range(self.idim)] + # idim can be inferred from input ranges. Otherwise we do not + # know what to do. + if not hasattr(self, "input_range"): + raise Exception("SUT input dimension not defined and cannot be inferred.") + self.idim = len(self.input_range) + self.inputs = [f"i{i}" for i in range(self.idim)] # The same as above for outputs. if hasattr(self, "outputs") and isinstance(self.outputs, int): if not hasattr(self, "odim"): self.odim = self.outputs - self.outputs = ["o{}".format(i) for i in range(self.outputs)] + self.outputs = [f"o{i}" for i in range(self.outputs)] if hasattr(self, "odim"): if not hasattr(self, "outputs"): - self.outputs = ["o{}".format(i) for i in range(self.odim)] + self.outputs = [f"o{i}" for i in range(self.odim)] + elif hasattr(self, "outputs"): + self.odim = len(self.outputs) else: - if hasattr(self, "outputs"): - self.odim = len(self.outputs) - else: - if not hasattr(self, "output_range"): - raise Exception("SUT output dimension not defined and cannot be inferred.") - self.odim = len(self.output_range) - self.outputs = ["o{}".format(i) for i in range(self.odim)] + if not hasattr(self, "output_range"): + raise Exception("SUT output dimension not defined and cannot be inferred.") + self.odim = len(self.output_range) + self.outputs = [f"o{i}" for i in range(self.odim)] # Setup input and output ranges and fill unspecified input and output # ranges with Nones. 
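The inference rules in setup() above can be summarized with a small standalone sketch (infer_inputs is illustrative, not part of the SUT class): the input dimension comes from whichever of idim, inputs, or input_range is available, and missing input names default to "i0", "i1", ...

def infer_inputs(idim=None, inputs=None, input_range=None):
    if isinstance(inputs, int):
        # An integer means "this many unnamed inputs".
        idim, inputs = inputs, [f"i{i}" for i in range(inputs)]
    if idim is None:
        idim = len(inputs) if inputs is not None else len(input_range)
    if inputs is None:
        inputs = [f"i{i}" for i in range(idim)]
    return idim, inputs

print(infer_inputs(inputs=3))                        # (3, ['i0', 'i1', 'i2'])
print(infer_inputs(input_range=[(-1, 1), (0, 5)]))   # (2, ['i0', 'i1'])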
@@ -169,7 +164,7 @@ def variable_range(self, var_name): if var_name == v: return self.input_range[n] - raise Exception("No variable '{}'.".format(var_name)) + raise Exception(f"No variable '{var_name}'.") def scale(self, x, intervals, target_A=-1, target_B=1): """ @@ -180,7 +175,8 @@ def scale(self, x, intervals, target_A=-1, target_B=1): if len(intervals) < x.shape[1]: raise Exception( - "Not enough intervals ({}) for scaling a vector of length {}.".format(len(intervals), x.shape[1])) + f"Not enough intervals ({len(intervals)}) for scaling a vector of length {x.shape[1]}." + ) y = np.zeros_like(x) for i in range(x.shape[1]): @@ -221,7 +217,8 @@ def descale(self, x, intervals, A=-1, B=1): if len(intervals) < x.shape[1]: raise Exception( - "Not enough intervals ({}) for descaling a vector of length {}.".format(len(intervals), x.shape[1])) + f"Not enough intervals ({len(intervals)}) for descaling a vector of length {x.shape[1]}." + ) y = np.zeros_like(x) for i in range(x.shape[1]): diff --git a/stgem/sut/hyper/sut.py b/stgem/sut/hyper/sut.py index c0ec5023..c57a4a6e 100644 --- a/stgem/sut/hyper/sut.py +++ b/stgem/sut/hyper/sut.py @@ -41,7 +41,7 @@ class HyperParameter(SUT): def __init__(self, experiment_factory, parameters=None): super().__init__(parameters) - if not "hyperparameters" in self.parameters: + if "hyperparameters" not in self.parameters: raise Exception("No hyperparameters selected for optimization.") self.experiment_factory = experiment_factory diff --git a/stgem/sut/matlab/sut.py b/stgem/sut/matlab/sut.py index e5a26b90..809e0793 100644 --- a/stgem/sut/matlab/sut.py +++ b/stgem/sut/matlab/sut.py @@ -18,13 +18,15 @@ def __init__(self, parameters): mandatory_parameters = ["simulation_time", "sampling_step", "model_file"] for p in mandatory_parameters: - if not p in self.parameters: - raise Exception("Parameter '{}' not specified.".format(p)) + if p not in self.parameters: + raise Exception(f"Parameter '{p}' not specified.") # How often input signals are sampled for execution (in time units). self.steps = int(self.simulation_time // self.sampling_step) - if not os.path.exists(self.model_file + ".mdl") and not os.path.exists(self.model_file + ".slx"): + if not os.path.exists(f"{self.model_file}.mdl") and not os.path.exists( + f"{self.model_file}.slx" + ): raise Exception("Neither '{0}.mdl' nor '{0}.slx' exists.".format(self.model_file)) def setup_matlab(self): @@ -115,8 +117,10 @@ def __init__(self, parameters): mandatory_parameters = ["time_slices", "simulation_time", "sampling_step"] for p in mandatory_parameters: - if not p in self.parameters: - raise Exception("Parameter '{}' must be defined for piecewise constant signal inputs.".format(p)) + if p not in self.parameters: + raise Exception( + f"Parameter '{p}' must be defined for piecewise constant signal inputs." + ) # How many inputs we have for each input signal. self.pieces = [math.ceil(self.simulation_time / time_slice) for time_slice in self.time_slices] @@ -128,17 +132,19 @@ def setup(self): if self.has_been_setup: return - if not len(self.time_slices) == self.idim: - raise Exception("Expected {} time slices, found {}.".format(self.idim, len(self.time_slices))) + if len(self.time_slices) != self.idim: + raise Exception( + f"Expected {self.idim} time slices, found {len(self.time_slices)}." 
+ ) self.N_signals = self.idim self.idim = sum(self.pieces) self.descaling_intervals = [] for i in range(len(self.input_range)): - for _ in range(self.pieces[i]): - self.descaling_intervals.append(self.input_range[i]) - + self.descaling_intervals.extend( + self.input_range[i] for _ in range(self.pieces[i]) + ) self.has_been_setup = True def _execute_test(self, test): @@ -188,24 +194,32 @@ def __init__(self, parameters): mandatory_parameters = ["model_file", "input_type", "output_type"] for p in mandatory_parameters: - if not p in self.parameters: - raise Exception("Parameter '{}' not specified.".format(p)) - - if not os.path.exists(self.model_file + ".m"): - raise Exception("The file '{}.m' does not exist.".format(self.model_file)) - if "init_model_file" in self.parameters and not os.path.exists(self.init_model_file + ".m"): - raise Exception("The file '{}.m' does not exist.".format(self.init_model_file)) - - if not self.input_type.lower() in ["vector", "piecewise constant signal", "signal"]: - raise Exception("Unknown Matlab call input type '{}'.".format(self.input_type)) - if not self.output_type.lower() in ["vector", "signal"]: - raise Exception("Unknown Matlab call output type '{}'.".format(self.output_type)) + if p not in self.parameters: + raise Exception(f"Parameter '{p}' not specified.") + + if not os.path.exists(f"{self.model_file}.m"): + raise Exception(f"The file '{self.model_file}.m' does not exist.") + if "init_model_file" in self.parameters and not os.path.exists( + f"{self.init_model_file}.m" + ): + raise Exception(f"The file '{self.init_model_file}.m' does not exist.") + + if self.input_type.lower() not in [ + "vector", + "piecewise constant signal", + "signal", + ]: + raise Exception(f"Unknown Matlab call input type '{self.input_type}'.") + if self.output_type.lower() not in ["vector", "signal"]: + raise Exception(f"Unknown Matlab call output type '{self.output_type}'.") if self.input_type == "piecewise constant signal": mandatory_parameters = ["time_slices", "simulation_time", "sampling_step"] for p in mandatory_parameters: - if not p in self.parameters: - raise Exception("Parameter '{}' must be defined for piecewise constant signal inputs.".format(p)) + if p not in self.parameters: + raise Exception( + f"Parameter '{p}' must be defined for piecewise constant signal inputs." + ) # How often input signals are sampled for execution (in time units). self.steps = int(self.simulation_time / self.sampling_step) @@ -241,17 +255,19 @@ def setup(self): # Adjust the SUT parameters if the input is a piecewise constant signal. if self.input_type == "piecewise constant signal": - if not len(self.time_slices) == self.idim: - raise Exception("Expected {} time slices, found {}.".format(self.idim, len(self.time_slices))) + if len(self.time_slices) != self.idim: + raise Exception( + f"Expected {self.idim} time slices, found {len(self.time_slices)}." 
+ ) self.N_signals = self.idim self.idim = sum(self.pieces) self.descaling_intervals = [] for i in range(len(self.input_range)): - for _ in range(self.pieces[i]): - self.descaling_intervals.append(self.input_range[i]) - + self.descaling_intervals.extend( + self.input_range[i] for _ in range(self.pieces[i]) + ) self.has_been_setup = True def __del__(self): @@ -342,13 +358,11 @@ def _execute_test(self, test): offset += self.pieces[i] test.input_timestamps = timestamps - test.input_denormalized = signals else: timestamps = test.input_timestamps signals = test.inputs - test.input_denormalized = signals - + test.input_denormalized = signals if self.output_type == "vector": return self._execute_signal_vector(timestamps, signals) else: diff --git a/stgem/test_repository.py b/stgem/test_repository.py index 7bc73482..9a4ec5aa 100644 --- a/stgem/test_repository.py +++ b/stgem/test_repository.py @@ -75,7 +75,7 @@ def get(self, *args, **kwargs): return_list = True - if len(args) == 0: + if not args: # Return all tests. args = self.indices @@ -89,18 +89,22 @@ def get(self, *args, **kwargs): # Return multiple tests. include_all = "include_all" in kwargs and kwargs["include_all"] - X = []; Z = []; Y = [] + X = [] + Z = [] + Y = [] for i in args: if i >= self.tests or (i < 0 and i < -self.tests): - raise IndexError("Index {} out of bounds.".format(i)) + raise IndexError(f"Index {i} out of bounds.") if self._outputs[i].error is not None and not include_all: continue X.append(self._tests[i]) Z.append(self._outputs[i]) Y.append(self._objectives[i]) if not return_list: - if len(X) == 0: - raise Exception("The test with index {} failed to execute, so it is not returned. Set include_all=True to obtain it.".format(args[0])) + if not X: + raise Exception( + f"The test with index {args[0]} failed to execute, so it is not returned. Set include_all=True to obtain it." + ) X = X[0] Z = Z[0] Y = Y[0] @@ -118,15 +122,15 @@ def __init__(self, record): def timer_start(self, timer_id): if timer_id in self.timers and self.timers[timer_id] is not None: - raise Exception("Restarting timer '{}' without resetting.".format(timer_id)) + raise Exception(f"Restarting timer '{timer_id}' without resetting.") self.timers[timer_id] = time.perf_counter() def timer_reset(self, timer_id): - if not timer_id in self.timers: - raise Exception("No timer '{}' to be reset.".format(timer_id)) + if timer_id not in self.timers: + raise Exception(f"No timer '{timer_id}' to be reset.") if self.timers[timer_id] is None: - raise Exception("Timer '{}' already reset.".format(timer_id)) + raise Exception(f"Timer '{timer_id}' already reset.") time_elapsed = time.perf_counter() - self.timers[timer_id] self.timers[timer_id] = None @@ -142,8 +146,8 @@ def timers_resume(self): self.timers_hold() def obtain(self, performance_id): - if not performance_id in self._record: - raise Exception("No record with identifier '{}'.".format(performance_id)) + if performance_id not in self._record: + raise Exception(f"No record with identifier '{performance_id}'.") return self._record[performance_id] def record(self, performance_id, value): diff --git a/stl/robustness.py b/stl/robustness.py index 3e84063d..bd36b31e 100644 --- a/stl/robustness.py +++ b/stl/robustness.py @@ -35,64 +35,53 @@ def update(self, start_pos, end_pos): start = max(start, 0) if start >= end: - raise Exception("Window start position {} before its end position {}.".format(start, end)) + raise Exception( + f"Window start position {start} before its end position {end}." 
+ ) # We have three areas we need to care about: an overlap, for which we # hopefully know the answer, and two areas to the left and right of the # overlap. Each of these three areas can be empty. if start < self.prev_start_pos: + # Disjoint and to the left. + l_s = start if end <= self.prev_start_pos: - # Disjoint and to the left. - l_s = start l_e = end o_s = -1 o_e = -1 r_s = -1 r_e = -1 else: + l_e = self.prev_start_pos if end <= self.prev_end_pos: - # Intersects from left but does not extend over to the right. - l_s = start - l_e = self.prev_start_pos - o_s = self.prev_start_pos o_e = end r_s = -1 r_e = -1 else: - # Contains the previous completely and has left and right areas nonempty. - l_s = start - l_e = self.prev_start_pos - o_s = self.prev_start_pos o_e = self.prev_end_pos r_s = self.prev_end_pos r_e = end + o_s = self.prev_start_pos else: if start >= self.prev_end_pos: - # Disjoint and to the right. - l_s = -1 - l_e = -1 o_s = -1 o_e = -1 r_s = start r_e = end else: + o_s = start if end <= self.prev_end_pos: - # Is contained completely in the previous. - l_s = -1 - l_e = -1 - o_s = start o_e = end r_s = -1 r_e = -1 else: - # Intersects from the right but does not extend over to the left. - l_s = -1 - l_e = -1 - o_s = start o_e = self.prev_end_pos r_s = self.prev_end_pos r_e = end + l_e = -1 + # Disjoint and to the right. + l_s = -1 # Find the minimums from each area. If the previous best value is not # in the overlap, we need to search the whole overlap. best_idx = -1 @@ -131,7 +120,7 @@ def __init__(self, timestamps, signals): raise ValueError("All signals must have exactly as many samples as there are timestamps.") @classmethod - def from_mixed_signals(C, *args, sampling_period=None): + def from_mixed_signals(cls, *args, sampling_period=None): """Instantiate the class from signals that have different timestamps (with 0 as a first timestamp) and different lengths. This is done by finding the maximum signal length and using that as a signal length, @@ -181,7 +170,7 @@ def from_mixed_signals(C, *args, sampling_period=None): value = signal_values[pos - 1] signals[name][n] = value - return C(timestamps, signals) + return cls(timestamps, signals) def search_time_index(self, t, start=0): """Finds the index of the time t in the timestamps using binary @@ -200,10 +189,7 @@ def search_time_index(self, t, start=0): middle = (lower_idx + upper_idx)//2 - if self.timestamps[middle] == t: - return middle - else: - return -1 + return middle if self.timestamps[middle] == t else -1 class TreeIterator: @@ -428,12 +414,9 @@ def __init__(self, formula): A = self.formulas[0].range[0] B = self.formulas[0].range[1] if A <= 0: - if B > 0: - self.range = [0, B] - else: - self.range = [-B, -A] + self.range = [0, B] if B > 0 else [-B, -A] else: - self.range = [A, B] + self.range = [A, B] self.horizon = 0 @@ -460,7 +443,7 @@ def __init__(self, left_formula, right_formula): B = self.formulas[0].range[1] - self.formulas[1].range[0] if A >= 0: self.range = [-B, -A] - elif A < 0 and B >= 0: + elif B >= 0: self.range = [-max(-A, B), 0] else: self.range = [A, B] @@ -555,18 +538,21 @@ def eval(self, traces, return_effective_range=True): # inaccuracies. We now raise an exception as otherwise the # user gets unexpected behavior. if lower_bound_pos < 0: - raise Exception("No timestamp '{}' found even though it should exist.".format(lower_bound)) + raise Exception( + f"No timestamp '{lower_bound}' found even though it should exist." + ) # Upper bound. 
if upper_bound > traces.timestamps[-1]: upper_bound_pos = len(traces.timestamps) - 1 + elif traces.timestamps[prev_upper_bound_pos - 1] == upper_bound: + upper_bound_pos = prev_upper_bound_pos - 1 else: - if traces.timestamps[prev_upper_bound_pos - 1] == upper_bound: - upper_bound_pos = prev_upper_bound_pos - 1 - else: - upper_bound_pos = traces.search_time_index(upper_bound, start=lower_bound_pos) + upper_bound_pos = traces.search_time_index(upper_bound, start=lower_bound_pos) # See above. - if upper_bound_pos < 0: - raise Exception("No timestamp '{}' found even though it should exist.".format(upper_bound)) + if upper_bound_pos < 0: + raise Exception( + f"No timestamp '{upper_bound}' found even though it should exist." + ) # Move a window with start position current_time_pos and end # position in the interval determined by lower_bound_pos and upper_bound_pos. @@ -650,27 +636,29 @@ def eval(self, traces, return_effective_range=True): # Lower bound. if lower_bound > traces.timestamps[-1]: lower_bound_pos = len(traces.timestamps) + elif traces.timestamps[prev_lower_bound_pos - 1] == lower_bound: + lower_bound_pos = prev_lower_bound_pos - 1 else: - if traces.timestamps[prev_lower_bound_pos - 1] == lower_bound: - lower_bound_pos = prev_lower_bound_pos - 1 - else: - lower_bound_pos = traces.search_time_index(lower_bound, start=current_time_pos) + lower_bound_pos = traces.search_time_index(lower_bound, start=current_time_pos) # TODO: This should never happen except for floating point # inaccuracies. We now raise an exception as otherwise the # user gets unexpected behavior. - if lower_bound_pos < 0: - raise Exception("No timestamp '{}' found even though it should exist.".format(lower_bound)) + if lower_bound_pos < 0: + raise Exception( + f"No timestamp '{lower_bound}' found even though it should exist." + ) # Upper bound. if upper_bound > traces.timestamps[-1]: upper_bound_pos = len(traces.timestamps) - 1 + elif traces.timestamps[prev_upper_bound_pos - 1] == upper_bound: + upper_bound_pos = prev_upper_bound_pos - 1 else: - if traces.timestamps[prev_upper_bound_pos - 1] == upper_bound: - upper_bound_pos = prev_upper_bound_pos - 1 - else: - upper_bound_pos = traces.search_time_index(upper_bound, start=lower_bound_pos) + upper_bound_pos = traces.search_time_index(upper_bound, start=lower_bound_pos) # See above. - if upper_bound_pos < 0: - raise Exception("No timestamp '{}' found even though it should exist.".format(upper_bound)) + if upper_bound_pos < 0: + raise Exception( + f"No timestamp '{upper_bound}' found even though it should exist." + ) # Slide a window corresponding to the indices and find the index of # the minimum. 
The value -1 signifies that the window was out of @@ -759,7 +747,7 @@ class And(STL): def __init__(self, *args, nu=None): self.formulas = list(args) self.nu = nu - + if self.nu is not None and self.nu <= 0: raise ValueError("The nu parameter must be positive.") @@ -767,10 +755,7 @@ def __init__(self, *args, nu=None): A = [f.range[0] if f.range is not None else None for f in self.formulas] B = [f.range[1] if f.range is not None else None for f in self.formulas] - if None in A or None in B: - self.range = None - else: - self.range = [min(A), min(B)] + self.range = None if None in A or None in B else [min(A), min(B)] def eval(self, traces, return_effective_range=True): if self.nu is None: diff --git a/stl/stlParser.py b/stl/stlParser.py index 26c3f488..14365b50 100644 --- a/stl/stlParser.py +++ b/stl/stlParser.py @@ -412,202 +412,178 @@ def accept(self, visitor:ParseTreeVisitor): def phi(self, _p:int=0): - _parentctx = self._ctx - _parentState = self.state - localctx = stlParser.PhiContext(self, self._ctx, _parentState) - _prevctx = localctx - _startState = 2 - self.enterRecursionRule(localctx, 2, self.RULE_phi, _p) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 35 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,2,self._ctx) - if la_ == 1: - localctx = stlParser.ParenPhiExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - - self.state = 12 - self.match(stlParser.LPAREN) - self.state = 13 - self.phi(0) - self.state = 14 - self.match(stlParser.RPAREN) - pass - - elif la_ == 2: - localctx = stlParser.OpNegExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 16 - self.match(stlParser.NEGATION) - self.state = 17 - self.phi(10) - pass - - elif la_ == 3: - localctx = stlParser.OpNextExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 18 - self.match(stlParser.NEXTOP) - self.state = 19 - self.phi(9) - pass - - elif la_ == 4: - localctx = stlParser.OpFutureExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 20 - self.match(stlParser.FUTUREOP) - self.state = 22 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,0,self._ctx) - if la_ == 1: - self.state = 21 - self.interval() - - - self.state = 24 - self.phi(8) - pass - - elif la_ == 5: - localctx = stlParser.OpGloballyExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 25 - self.match(stlParser.GLOBALLYOP) - self.state = 27 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,1,self._ctx) - if la_ == 1: - self.state = 26 - self.interval() - - - self.state = 29 - self.phi(7) - pass - - elif la_ == 6: - localctx = stlParser.PredicateExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 30 - self.signal(0) - self.state = 31 - _la = self._input.LA(1) - if not(_la==stlParser.RELOP or _la==stlParser.EQUALITYOP): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 32 - self.signal(0) - pass - - elif la_ == 7: - localctx = stlParser.SignalExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 34 - self.signal(0) - pass - - - self._ctx.stop = self._input.LT(-1) - self.state = 54 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,5,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - if 
self._parseListeners is not None: - self.triggerExitRuleEvent() - _prevctx = localctx - self.state = 52 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,4,self._ctx) - if la_ == 1: - localctx = stlParser.OpUntilExprContext(self, stlParser.PhiContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_phi) - self.state = 37 - if not self.precpred(self._ctx, 6): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 6)") - self.state = 38 - self.match(stlParser.UNTILOP) - self.state = 40 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,3,self._ctx) - if la_ == 1: - self.state = 39 - self.interval() - - - self.state = 42 - self.phi(7) - pass - - elif la_ == 2: - localctx = stlParser.OpAndExprContext(self, stlParser.PhiContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_phi) - self.state = 43 - if not self.precpred(self._ctx, 5): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 5)") - self.state = 44 - self.match(stlParser.ANDOP) - self.state = 45 - self.phi(6) - pass - - elif la_ == 3: - localctx = stlParser.OpOrExprContext(self, stlParser.PhiContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_phi) - self.state = 46 - if not self.precpred(self._ctx, 4): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 4)") - self.state = 47 - self.match(stlParser.OROP) - self.state = 48 - self.phi(5) - pass - - elif la_ == 4: - localctx = stlParser.OpPropExprContext(self, stlParser.PhiContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_phi) - self.state = 49 - if not self.precpred(self._ctx, 3): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") - self.state = 50 - _la = self._input.LA(1) - if not(_la==stlParser.IMPLIESOP or _la==stlParser.EQUIVOP): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 51 - self.phi(4) - pass - - - self.state = 56 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,5,self._ctx) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.unrollRecursionContexts(_parentctx) - return localctx + _parentctx = self._ctx + _parentState = self.state + localctx = stlParser.PhiContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 2 + self.enterRecursionRule(localctx, 2, self.RULE_phi, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 35 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,2,self._ctx) + if la_ == 1: + localctx = stlParser.ParenPhiExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + + self.state = 12 + self.match(stlParser.LPAREN) + self.state = 13 + self.phi(0) + self.state = 14 + self.match(stlParser.RPAREN) + elif la_ == 2: + localctx = stlParser.OpNegExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 16 + self.match(stlParser.NEGATION) + self.state = 17 + self.phi(10) + elif la_ == 3: 
+ localctx = stlParser.OpNextExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 18 + self.match(stlParser.NEXTOP) + self.state = 19 + self.phi(9) + elif la_ == 4: + localctx = stlParser.OpFutureExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 20 + self.match(stlParser.FUTUREOP) + self.state = 22 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,0,self._ctx) + if la_ == 1: + self.state = 21 + self.interval() + + + self.state = 24 + self.phi(8) + elif la_ == 5: + localctx = stlParser.OpGloballyExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 25 + self.match(stlParser.GLOBALLYOP) + self.state = 27 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,1,self._ctx) + if la_ == 1: + self.state = 26 + self.interval() + + + self.state = 29 + self.phi(7) + elif la_ == 6: + localctx = stlParser.PredicateExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 30 + self.signal(0) + self.state = 31 + _la = self._input.LA(1) + if _la not in [stlParser.RELOP, stlParser.EQUALITYOP]: + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 32 + self.signal(0) + elif la_ == 7: + localctx = stlParser.SignalExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 34 + self.signal(0) + self._ctx.stop = self._input.LT(-1) + self.state = 54 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,5,self._ctx) + while _alt not in [2, ATN.INVALID_ALT_NUMBER]: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + self.state = 52 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,4,self._ctx) + if la_ == 1: + localctx = stlParser.OpUntilExprContext(self, stlParser.PhiContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_phi) + self.state = 37 + if not self.precpred(self._ctx, 6): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 6)") + self.state = 38 + self.match(stlParser.UNTILOP) + self.state = 40 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,3,self._ctx) + if la_ == 1: + self.state = 39 + self.interval() + + + self.state = 42 + self.phi(7) + elif la_ == 2: + localctx = stlParser.OpAndExprContext(self, stlParser.PhiContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_phi) + self.state = 43 + if not self.precpred(self._ctx, 5): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 5)") + self.state = 44 + self.match(stlParser.ANDOP) + self.state = 45 + self.phi(6) + elif la_ == 3: + localctx = stlParser.OpOrExprContext(self, stlParser.PhiContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_phi) + self.state = 46 + if not self.precpred(self._ctx, 4): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 4)") + self.state = 47 + self.match(stlParser.OROP) + self.state = 48 + self.phi(5) + elif la_ == 4: + localctx = stlParser.OpPropExprContext(self, stlParser.PhiContext(self, _parentctx, _parentState)) + 
self.pushNewRecursionContext(localctx, _startState, self.RULE_phi) + self.state = 49 + if not self.precpred(self._ctx, 3): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") + self.state = 50 + _la = self._input.LA(1) + if _la not in [stlParser.IMPLIESOP, stlParser.EQUIVOP]: + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 51 + self.phi(4) + self.state = 56 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,5,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx class SignalContext(ParserRuleContext): @@ -751,118 +727,109 @@ def accept(self, visitor:ParseTreeVisitor): def signal(self, _p:int=0): - _parentctx = self._ctx - _parentState = self.state - localctx = stlParser.SignalContext(self, self._ctx, _parentState) - _prevctx = localctx - _startState = 4 - self.enterRecursionRule(localctx, 4, self.RULE_signal, _p) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 68 - self._errHandler.sync(self) - token = self._input.LA(1) - if token in [stlParser.NUMBER]: - localctx = stlParser.SignalNumberContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - - self.state = 58 - self.match(stlParser.NUMBER) - pass - elif token in [stlParser.NAME]: - localctx = stlParser.SignalNameContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 59 - self.match(stlParser.NAME) - pass - elif token in [stlParser.LPAREN]: - localctx = stlParser.SignalParenthesisExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 60 - self.match(stlParser.LPAREN) - self.state = 61 - self.signal(0) - self.state = 62 - self.match(stlParser.RPAREN) - pass - elif token in [stlParser.VBAR]: - localctx = stlParser.SignalAbsExprContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 64 - self.match(stlParser.VBAR) - self.state = 65 - self.signal(0) - self.state = 66 - self.match(stlParser.VBAR) - pass - else: - raise NoViableAltException(self) - - self._ctx.stop = self._input.LT(-1) - self.state = 78 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,8,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: - if self._parseListeners is not None: - self.triggerExitRuleEvent() - _prevctx = localctx - self.state = 76 - self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,7,self._ctx) - if la_ == 1: - localctx = stlParser.SignalMultExprContext(self, stlParser.SignalContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_signal) - self.state = 70 - if not self.precpred(self._ctx, 3): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") - self.state = 71 - _la = self._input.LA(1) - if not(_la==stlParser.MULT or _la==stlParser.DIV): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 72 - self.signal(4) - pass - - elif la_ == 2: - localctx = stlParser.SignalSumExprContext(self, stlParser.SignalContext(self, _parentctx, _parentState)) - self.pushNewRecursionContext(localctx, _startState, self.RULE_signal) - 
self.state = 73 - if not self.precpred(self._ctx, 2): - from antlr4.error.Errors import FailedPredicateException - raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") - self.state = 74 - _la = self._input.LA(1) - if not(_la==stlParser.PLUS or _la==stlParser.MINUS): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 75 - self.signal(3) - pass - - - self.state = 80 - self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,8,self._ctx) - - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.unrollRecursionContexts(_parentctx) - return localctx + _parentctx = self._ctx + _parentState = self.state + localctx = stlParser.SignalContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 4 + self.enterRecursionRule(localctx, 4, self.RULE_signal, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 68 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [stlParser.NUMBER]: + localctx = stlParser.SignalNumberContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + + self.state = 58 + self.match(stlParser.NUMBER) + elif token in [stlParser.NAME]: + localctx = stlParser.SignalNameContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 59 + self.match(stlParser.NAME) + elif token in [stlParser.LPAREN]: + localctx = stlParser.SignalParenthesisExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 60 + self.match(stlParser.LPAREN) + self.state = 61 + self.signal(0) + self.state = 62 + self.match(stlParser.RPAREN) + elif token in [stlParser.VBAR]: + localctx = stlParser.SignalAbsExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 64 + self.match(stlParser.VBAR) + self.state = 65 + self.signal(0) + self.state = 66 + self.match(stlParser.VBAR) + else: + raise NoViableAltException(self) + + self._ctx.stop = self._input.LT(-1) + self.state = 78 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,8,self._ctx) + while _alt not in [2, ATN.INVALID_ALT_NUMBER]: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + self.state = 76 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,7,self._ctx) + if la_ == 1: + localctx = stlParser.SignalMultExprContext(self, stlParser.SignalContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_signal) + self.state = 70 + if not self.precpred(self._ctx, 3): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") + self.state = 71 + _la = self._input.LA(1) + if _la not in [stlParser.MULT, stlParser.DIV]: + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 72 + self.signal(4) + elif la_ == 2: + localctx = stlParser.SignalSumExprContext(self, stlParser.SignalContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_signal) + self.state = 73 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 74 + _la = self._input.LA(1) + if _la not in 
[stlParser.PLUS, stlParser.MINUS]: + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 75 + self.signal(3) + self.state = 80 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,8,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx class IntervalContext(ParserRuleContext): @@ -913,61 +880,61 @@ def accept(self, visitor:ParseTreeVisitor): def interval(self): - localctx = stlParser.IntervalContext(self, self._ctx, self.state) - self.enterRule(localctx, 6, self.RULE_interval) - self._la = 0 # Token type - try: - self.enterOuterAlt(localctx, 1) - self.state = 81 - _la = self._input.LA(1) - if not(_la==stlParser.LPAREN or _la==stlParser.LBRACK): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 82 - _la = self._input.LA(1) - if not(_la==stlParser.INF or _la==stlParser.NUMBER): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 83 - self.match(stlParser.COMMA) - self.state = 84 - _la = self._input.LA(1) - if not(_la==stlParser.INF or _la==stlParser.NUMBER): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - self.state = 85 - _la = self._input.LA(1) - if not(_la==stlParser.RPAREN or _la==stlParser.RBRACK): - self._errHandler.recoverInline(self) - else: - self._errHandler.reportMatch(self) - self.consume() - except RecognitionException as re: - localctx.exception = re - self._errHandler.reportError(self, re) - self._errHandler.recover(self, re) - finally: - self.exitRule() - return localctx + localctx = stlParser.IntervalContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_interval) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 81 + _la = self._input.LA(1) + if _la not in [stlParser.LPAREN, stlParser.LBRACK]: + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 82 + _la = self._input.LA(1) + if _la not in [stlParser.INF, stlParser.NUMBER]: + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 83 + self.match(stlParser.COMMA) + self.state = 84 + _la = self._input.LA(1) + if _la not in [stlParser.INF, stlParser.NUMBER]: + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 85 + _la = self._input.LA(1) + if _la not in [stlParser.RPAREN, stlParser.RBRACK]: + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): - if self._predicates == None: - self._predicates = dict() - self._predicates[1] = self.phi_sempred - self._predicates[2] = self.signal_sempred - pred = self._predicates.get(ruleIndex, None) - if pred is None: - raise Exception("No predicate with index:" + str(ruleIndex)) - else: - return pred(localctx, predIndex) + if self._predicates is None: + self._predicates = {} + self._predicates[1] = self.phi_sempred + 
self._predicates[2] = self.signal_sempred + pred = self._predicates.get(ruleIndex, None) + if pred is None: + raise Exception(f"No predicate with index:{ruleIndex}") + else: + return pred(localctx, predIndex) def phi_sempred(self, localctx:PhiContext, predIndex:int): if predIndex == 0: diff --git a/stl/visitor.py b/stl/visitor.py index 76256831..d93e3c63 100644 --- a/stl/visitor.py +++ b/stl/visitor.py @@ -108,11 +108,7 @@ def visitOpAndExpr(self, ctx:stlParser.OpAndExprContext): else: formulas.append(phi2) - if hasattr(self, "nu"): - nu = self.nu - else: - nu = None - + nu = self.nu if hasattr(self, "nu") else None return And(*formulas, nu=nu) @@ -146,11 +142,7 @@ def visitOpOrExpr(self, ctx:stlParser.OpOrExprContext): else: formulas.append(phi2) - if hasattr(self, "nu"): - nu = self.nu - else: - nu = None - + nu = self.nu if hasattr(self, "nu") else None return Or(*formulas, nu=nu) diff --git a/tests/test_mo3d_platypus_python.py b/tests/test_mo3d_platypus_python.py index 486fb248..b3933214 100644 --- a/tests/test_mo3d_platypus_python.py +++ b/tests/test_mo3d_platypus_python.py @@ -36,7 +36,7 @@ def test_plattypus1(self): ) r = generator.run() - file_name = generator.description + ".pickle" + file_name = f"{generator.description}.pickle" r.dump_to_file(file_name) os.remove(file_name) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index e78b46e0..09118ffb 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -37,7 +37,7 @@ def test_dump(self): ) r = generator.run() - file_name = generator.description + ".pickle" + file_name = f"{generator.description}.pickle" r.dump_to_file(file_name) result2 = STGEMResult.restore_from_file(file_name) diff --git a/tests/test_stl.py b/tests/test_stl.py index cec8ceb9..40ceb42a 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -32,15 +32,10 @@ def get_with_range(self, specification, timestamps, signals, ranges, time, nu=No if isinstance(node, (STL.Global, STL.Until)): time_bounded.append(node) if isinstance(node, STL.Finally): - time_bounded.append(node) - time_bounded.append(node.formula_robustness.formulas[0]) - + time_bounded.extend((node, node.formula_robustness.formulas[0])) args = [] for var in formula_variables: - args.append(var) - args.append(timestamps) - args.append(signals[var]) - + args.extend((var, timestamps, signals[var])) # Adjust time bounds to integers. sampling_period = 1/10 for x in time_bounded: