From e939ea2a1aabcf9c548c320c25d58b1ef95cfabc Mon Sep 17 00:00:00 2001 From: Nathan Shreve Date: Mon, 12 Aug 2024 21:21:29 -0400 Subject: [PATCH] Fixed multithreading in parsing script --- .../profiling_utils/parse_lookahead_data.py | 83 +++++++------------ 1 file changed, 31 insertions(+), 52 deletions(-) diff --git a/vtr_flow/scripts/profiling_utils/parse_lookahead_data.py b/vtr_flow/scripts/profiling_utils/parse_lookahead_data.py index 4a39099b780..4041f104e17 100755 --- a/vtr_flow/scripts/profiling_utils/parse_lookahead_data.py +++ b/vtr_flow/scripts/profiling_utils/parse_lookahead_data.py @@ -12,10 +12,11 @@ import seaborn as sns import argparse from pathlib import Path -from multiprocessing import Process, Lock, Queue +from multiprocessing import Lock +from concurrent.futures import ThreadPoolExecutor # Output directory -output_dir = "../tasks/lookahead_verifier_output" +output_dir = "./vtr_flow/tasks/lookahead_verifier_output" # The graph types (pie, heatmap, bar, scatter) that will be created graph_types: list # The components that will be used for graphs (cost, delay, congestion) @@ -75,9 +76,9 @@ "test name" ] -# Lock and Queue for multithreading +# Lock and Pool for multithreading lock = Lock() -q = Queue() +pool = ThreadPoolExecutor(1) # Check if a component is valid, otherwise raise exception @@ -350,10 +351,7 @@ def make_standard_scatter_plots(self, test_name_plot: bool): if first_it and col == "iteration no.": continue - proc = Process( - target=self.make_scatter_plot, args=(comp, plot_type, col, first_it) - ) - q.put(proc) + pool.submit(self.make_scatter_plot, comp, plot_type, col, first_it) # Create a bar graph displaying average error # comp: The component (cost, delay, or congestion) @@ -427,6 +425,7 @@ def make_bar_graph(self, comp: str, column: str, first_it_only: bool, use_absolu avg_error_df.plot.bar(title=title, xlabel=column, ylabel=y_label, legend=False) self.write_exclusions_info() + print(os.path.join(curr_dir, file_name)) plt.savefig(os.path.join(curr_dir, file_name), dpi=300, bbox_inches="tight") plt.close() @@ -447,10 +446,7 @@ def make_standard_bar_graphs(self, test_name_plot: bool): for col in columns: for use_abs in [True, False]: for first_it in [True, False]: - proc = Process( - target=self.make_bar_graph, args=(comp, col, use_abs, first_it) - ) - q.put(proc) + pool.submit(self.make_bar_graph, comp, col, use_abs, first_it) # Create a heatmap comparing two quantitative columns # comp: The component (cost, delay, or congestion) @@ -559,23 +555,14 @@ def make_standard_heatmaps(self): for comp in components: for first_it in [True, False]: for use_abs in [True, False]: - proc = Process( - target=self.make_heatmap, - args=( - comp, - "sink cluster tile width", - "sink cluster tile height", - first_it, - use_abs, - ), - ) - q.put(proc) - - proc = Process( - target=self.make_heatmap, - args=(comp, "delta x", "delta y", first_it, use_abs), - ) - q.put(proc) + pool.submit(self.make_heatmap, + comp, + "sink cluster tile width", + "sink cluster tile height", + first_it, + use_abs, + ) + pool.submit(self.make_heatmap, comp, "delta x", "delta y", first_it, use_abs) # Create a pie chart showing the proportion of cases where error is under percent_error_threshold # comp: The component (cost, delay, or congestion) @@ -671,16 +658,13 @@ def make_standard_pie_charts(self, test_name_plot: bool): if test_name_plot: columns = self.__standard_bar_columns else: - columns = self.__standard_bar_columns[0:-1] + columns = self.__standard_bar_columns[:-1] for comp in components: for col in columns: for first_it in [True, False]: for weighted in [True, False]: - proc = Process( - target=self.make_pie_chart, args=(comp, col, first_it, weighted) - ) - q.put(proc) + pool.submit(self.make_pie_chart, comp, col, first_it, weighted) # Make "standard" graphs of all types. # test_name_plot: whether to create plots where data is split by test name. This option @@ -777,10 +761,7 @@ def make_csv(df_out: pd.DataFrame, file_name: str): # Write out the csv if csv_data and (not os.path.exists(os.path.join(directory, "data.csv")) or not no_replace): - proc = Process( - target=make_csv, args=(df, os.path.join(directory, "data.csv")) - ) - q.put(proc) + pool.submit(make_csv, df, os.path.join(directory, "data.csv")) if should_print: print("Created ", os.path.join(directory, "data.csv"), sep="") @@ -916,8 +897,8 @@ def main(): args = parser.parse_args() - global q - q = Queue(args.j) + global pool + pool = ThreadPoolExecutor(args.j) global graph_types global components @@ -1098,21 +1079,19 @@ def main(): # If --collect used, create output files for all csv files provided - if args.collect == "": - return - - results_folder = args.collect[0] - results_folder_path = os.path.join(output_dir, results_folder) - if len(args.dir_app) > 0: - results_folder_path += f"{args.dir_app[0]}" - make_dir(results_folder_path, False) + if args.collect != "": + results_folder = args.collect[0] + results_folder_path = os.path.join(output_dir, results_folder) + if len(args.dir_app) > 0: + results_folder_path += f"{args.dir_app[0]}" + make_dir(results_folder_path, False) - df_complete = df_complete.reset_index(drop=True) + df_complete = df_complete.reset_index(drop=True) - record_df_info(df_complete, results_folder_path) + record_df_info(df_complete, results_folder_path) - global_plots = Graphs(df_complete, os.path.join(results_folder, "plots"), "All Tests") - global_plots.make_standard_plots(True) + global_plots = Graphs(df_complete, os.path.join(results_folder, "plots"), "All Tests") + global_plots.make_standard_plots(True) if __name__ == "__main__":