Skip to content

Commit

Permalink
Merge pull request #266 from Crypto-TII/aradi-fixes-mattia
Browse files Browse the repository at this point in the history
Aradi report fixes
  • Loading branch information
peacker authored Aug 21, 2024
2 parents b5353b0 + 5be058d commit bc5ba1d
Showing 1 changed file with 55 additions and 29 deletions.
84 changes: 55 additions & 29 deletions claasp/cipher_modules/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import plotly.graph_objects as go
import pandas as pd
import itertools
import math
import json
import shutil
from claasp.cipher_modules.statistical_tests.dieharder_statistical_tests import DieharderTests
Expand Down Expand Up @@ -77,9 +78,9 @@ def _latex_heatmap(table, table_string, bit_count):
table_string += "{"
for i in range(len(round)):
if i == len(round) - 1:
table_string += str("{:.3f}".format(round[i]))
table_string += f"{float(round[i]):.3f}"
else:
table_string += str("{:.3f}".format(round[i])) + ","
table_string += f"{float(round[i]):.3f},"
table_string += "},\n\t\t"
table_string += ("} {\n\t%heatmap tiles\n\t\\foreach\\x [count=\\m] in \\y {\n\t\t\\pgfmathsetmacro{"
"\\colorgradient}{\\x * 100}\n\t\t\\node[fill=green!\\colorgradient!red, minimum size=6mm, "
Expand Down Expand Up @@ -186,8 +187,10 @@ def show(self, show_as_hex=False, test_name=None, fixed_input='plaintext', fixed
Component_Analysis.print_component_analysis_as_radar_charts(results=self.test_report['test_results'])
return
elif 'avalanche_tests' == self.test_name:
test_list = self.test_report['test_results']['plaintext']['round_output'].keys()
test_list = list(self.test_report['test_results']['plaintext']['round_output'].keys())
print(test_list)
if test_name not in test_list:
print(test_name)
print('Error! Invalid test name. The report.show function requires a test_name input')
print('test_name has to be one of the following : ', end='')
print(test_list)
Expand Down Expand Up @@ -384,14 +387,14 @@ def _export(self, file_format, output_dir, fixed_input=None, fixed_output=None,
table_string = _latex_heatmap(table_split, table_string, bit_count)

table_string += "\\caption{" + self.test_name.replace("_",
"\\_") + "\\_" + it + "\\_" + out.replace(
"_", "\\_") + "\\_" + test.replace("_", "\\_") + "\\_" + result[
"-'") + "-" + it + "-" + out.replace(
"_", "-") + "-" + test.replace("_", "-") + "-" + result[
"input_difference_value"] + ("}"
"\\label{fig:" + self.test_name.replace(
"_", "\\_") + "\\_" + it + "\\_" + out.replace("_",
"\\_") + "\\_" + test.replace(
"_", "-") + "-" + it + "-" + out.replace("_",
"-") + "-" + test.replace(
"_",
"\\_") + "\\_" +
"-") + "-" +
result[
"input_difference_value"] + "}\n")
table_string += "\\end{figure}"
Expand Down Expand Up @@ -460,7 +463,7 @@ def _update_out_list(self, out_list, rel_prob, abs_prob, show_as_hex, comp_id, w
else:
word_list = [
'*' if '*' in ''.join(bin_list[x:x + word_size]) else hex(int(''.join(bin_list[x:x + word_size]), 2))[
2:].zfill(int(word_size/4)) for x
2:].zfill(int(word_size / 4)) for x
in range(0, len(bin_list), word_size)]

if ('intermediate' in comp_id or 'cipher' in comp_id) and comp_id not in key_flow:
Expand Down Expand Up @@ -646,6 +649,36 @@ def _print_trail(self, show_as_hex, word_size, state_size, key_state_size, verbo
if show_key_flow:
self._print_key_flow(key_flow, show_components, out_list, verbose, file)

def create_heatmap_subplot(self, i, graph_data, cipher_rounds):
z_data = [x[32 * i: min(len(x), 32 * (i + 1))] for x in list(graph_data.values())]
yrange = list(graph_data.keys())
xrange = list(range(i * 32, 32 * (i + 1)))
fontsize = max(1, ceil(12//len(yrange)))
heatmap = go.Heatmap(
z=z_data, coloraxis='coloraxis', texttemplate="%{text}",
text=[['{:.2f}'.format(float(y)) for y in x] for x in z_data],
textfont={'size': 2*fontsize},
x=xrange,
y=yrange, zmin=0, zmax=1, zauto=False
)

layout_update = {
f'xaxis{i + 1}': {
'tickmode': 'array',
'tickvals': xrange,
'ticktext': [str(j) for j in range(i * 32, 32 * (i + 1))],
'tickfont': {'size': 2*fontsize}
},
f'yaxis{i + 1}': {
'tickmode': 'array',
'tickvals': yrange,
'ticktext': [str(j) for j in range(1, cipher_rounds + 1)],
'tickfont': {'size': 2*fontsize},
'autorange': 'reversed'
}
}
return heatmap, layout_update

def _produce_graph(self, output_directory=os.getcwd(), show_graph=False, fixed_input=None, fixed_output=None,
fixed_input_difference=None, test_name=None):

Expand Down Expand Up @@ -774,45 +807,38 @@ def _produce_graph(self, output_directory=os.getcwd(), show_graph=False, fixed_i
graph_data[i + 1] = [case[res_key][i]] if type(case[res_key][i]) != list else \
case[res_key][i]


df = pd.DataFrame.from_dict(graph_data).T
if len(graph_data[1]) > 1:
if case[
'input_difference_value'] != fixed_input_difference and fixed_input_difference != None:
continue
num_subplots = int(ceil(len(graph_data[1]) / 32))
fig = make_subplots(num_subplots, 1, vertical_spacing=0.08)
fig = make_subplots(num_subplots, 1)

fig.update_layout({
'coloraxis': {'colorscale': 'rdylgn',
'cmin': 0,
'cmax': 1}})
for i in range(num_subplots):
z_data = [x[32 * i: min(len(x), 32 * (i + 1))] for x in list(graph_data.values())]
fig.add_trace(go.Heatmap(z=z_data, coloraxis='coloraxis', texttemplate="%{text}",
text=[['{:.3f}'.format(float(y)) for y in x] for x in
z_data],
x=list(range(i * 32, 32 * (i + 1))),
y=list(range(1, self.cipher.number_of_rounds + 1)), zmin=0,
zmax=1, zauto=False),
i + 1, 1)
fig.update_layout({
'font': {'size': 8},
'xaxis' + str(i + 1): {'tick0': 0, 'dtick': 1, 'nticks': len(z_data),
'tickfont': {'size': 8}},
'yaxis' + str(i + 1): {'tick0': 0, 'dtick': 1,
'tickfont': {'size': 8}, 'autorange': 'reversed'}
})
heatmap, layout_update = self.create_heatmap_subplot(i,
graph_data,
self.cipher.number_of_rounds)
fig.add_trace(heatmap, i + 1, 1)
fig.update_layout(layout_update)

if show_graph == False:
fig.write_image(output_directory + '/' + it + '/' + out + '/' + res + '/' + str(
res) + '_' + str(case['input_difference_value']) + '.png', scale=4)
if not show_graph:

fig.write_image(
f"{output_directory}/{it}/{out}/{res}/{res}_{case['input_difference_value']}.png",
scale=20)
else:
fig.show(renderer='png')
return
fig.data = []
fig.layout = {}

else:

fig = px.line(df, range_x=[1, self.cipher.number_of_rounds],
range_y=[0, 1])
fig.update_layout(xaxis_title="round", yaxis_title=res_key,
Expand Down

0 comments on commit bc5ba1d

Please sign in to comment.