diff --git a/claasp/cipher_modules/report.py b/claasp/cipher_modules/report.py index b1b36af6..648e4da4 100644 --- a/claasp/cipher_modules/report.py +++ b/claasp/cipher_modules/report.py @@ -653,30 +653,27 @@ def create_heatmap_subplot(self, i, graph_data, cipher_rounds): z_data = [x[32 * i: min(len(x), 32 * (i + 1))] for x in list(graph_data.values())] yrange = list(graph_data.keys()) xrange = list(range(i * 32, 32 * (i + 1))) - + fontsize = max(1, ceil(12//len(yrange))) heatmap = go.Heatmap( z=z_data, coloraxis='coloraxis', texttemplate="%{text}", text=[['{:.2f}'.format(float(y)) for y in x] for x in z_data], - textfont={'size': 12 // len(yrange)}, + textfont={'size': 2*fontsize}, x=xrange, y=yrange, zmin=0, zmax=1, zauto=False ) - tick_step_x = 1 - tick_step_y = 1 - layout_update = { f'xaxis{i + 1}': { 'tickmode': 'array', - 'tickvals': list(range(i * 32, 32 * (i + 1), tick_step_x)), - 'ticktext': [str(j) for j in range(i * 32, 32 * (i + 1), tick_step_x)], - 'tickfont': {'size': 12 // len(yrange)} + 'tickvals': xrange, + 'ticktext': [str(j) for j in range(i * 32, 32 * (i + 1))], + 'tickfont': {'size': 2*fontsize} }, f'yaxis{i + 1}': { 'tickmode': 'array', - 'tickvals': list(range(1, cipher_rounds + 1, tick_step_y)), - 'ticktext': [str(j) for j in range(1, cipher_rounds + 1, tick_step_y)], - 'tickfont': {'size': 12 // len(yrange)}, + 'tickvals': yrange, + 'ticktext': [str(j) for j in range(1, cipher_rounds + 1)], + 'tickfont': {'size': 2*fontsize}, 'autorange': 'reversed' } } @@ -810,61 +807,36 @@ def _produce_graph(self, output_directory=os.getcwd(), show_graph=False, fixed_i graph_data[i + 1] = [case[res_key][i]] if type(case[res_key][i]) != list else \ case[res_key][i] + df = pd.DataFrame.from_dict(graph_data).T if len(graph_data[1]) > 1: if case[ 'input_difference_value'] != fixed_input_difference and fixed_input_difference != None: continue num_subplots = int(ceil(len(graph_data[1]) / 32)) - num_graphs = ceil(self.cipher.number_of_rounds / 5) - fig_list = [] - for n in range(num_graphs): - subgraph_data = {} - rounds = [i for i in - range(n * 5 + 1, min((n + 1) * 5 + 1, self.cipher.number_of_rounds))] - for r in rounds: - subgraph_data[r] = graph_data[r] - fig = make_subplots(num_subplots, 1) - - fig.update_layout({ - 'coloraxis': {'colorscale': 'rdylgn', - 'cmin': 0, - 'cmax': 1}}) - for i in range(num_subplots): - heatmap, layout_update = self.create_heatmap_subplot(i, - subgraph_data, - self.cipher.number_of_rounds) - fig.add_trace(heatmap, i + 1, 1) - fig.update_layout(layout_update) - fig_list.append(fig) - - if num_graphs > 1: - num_rows = ceil(math.sqrt(num_graphs)) - complete_fig = make_subplots(rows=num_rows, cols=num_rows) - - for i, sub_fig in enumerate(fig_list): - row = i // num_rows + 1 - col = i % num_rows + 1 - - for trace in sub_fig.data: - complete_fig.add_trace(trace, row=row, col=col) - else: - complete_fig = fig_list[0] + fig = make_subplots(num_subplots, 1) + + fig.update_layout({ + 'coloraxis': {'colorscale': 'rdylgn', + 'cmin': 0, + 'cmax': 1}}) + for i in range(num_subplots): + heatmap, layout_update = self.create_heatmap_subplot(i, + graph_data, + self.cipher.number_of_rounds) + fig.add_trace(heatmap, i + 1, 1) + fig.update_layout(layout_update) + if not show_graph: - if num_graphs > 1: - for i, sub_fig in enumerate(fig_list): - sub_fig.write_image( - f"{output_directory}/{it}/{out}/{res}/{res}_{case['input_difference_value']}" - f"_{i}.png", scale=20) - else: - complete_fig.write_image( - output_directory + '/' + it + '/' + out + '/' + res + '/' + str( - res) + '_' + str(case['input_difference_value']) + '.png', scale=20) + + fig.write_image( + f"{output_directory}/{it}/{out}/{res}/{res}_{case['input_difference_value']}.png", + scale=20) else: - complete_fig.show(renderer='png') + fig.show(renderer='png') return - complete_fig.data = [] - complete_fig.layout = {} + fig.data = [] + fig.layout = {} else: fig = px.line(df, range_x=[1, self.cipher.number_of_rounds],