Skip to content

Commit

Permalink
remove commented code, add verbosity for printing inside a methode
Browse files Browse the repository at this point in the history
  • Loading branch information
SiMohamedRachidi committed Dec 4, 2024
1 parent 984ada1 commit 97936a8
Showing 1 changed file with 13 additions and 32 deletions.
45 changes: 13 additions & 32 deletions claasp/cipher_modules/division_trail_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from gurobipy import Model, GRB
import os

verbosity = False

class MilpDivisionTrailModel():
"""
Expand Down Expand Up @@ -190,7 +192,6 @@ def add_sbox_constraints(self, component):
x = B.variable_names()
anfs = self.get_anfs_from_sbox(component)
anfs = [B(anfs[i]) for i in range(component.input_bit_size)]
# print(anfs)

copy_monomials_deg = self.create_gurobi_vars_sbox(component, input_vars_concat)

Expand Down Expand Up @@ -270,13 +271,10 @@ def add_linear_layer_constraints(self, component):

def add_xor_constraints(self, component):
output_vars = self.get_output_vars(component)
# print(output_vars)

input_vars_concat = []
constant_flag = []
for index, input_name in enumerate(component.input_id_links):
# print(input_name)
# print(self._variables[input_name])
for pos in component.input_bit_positions[index]:
current = self._variables[input_name][pos]["current"]
if input_name[:8] == "constant":
Expand All @@ -286,20 +284,15 @@ def add_xor_constraints(self, component):
else:
input_vars_concat.append(self._variables[input_name][pos][current])
self._variables[input_name][pos]["current"] += 1
# print(input_vars_concat)

block_size = component.output_bit_size
nb_blocks = component.description[1]
if constant_flag != []:
nb_blocks -= 1
# print(self._occurences[component.id])
# print(list(self._occurences[component.id].keys()))
# print(len(list(self._occurences[component.id].keys())))
for index, bit_pos in enumerate(list(self._occurences[component.id].keys())):
constr = 0
for j in range(nb_blocks):
constr += input_vars_concat[index + block_size * j]
# print(input_vars_concat[index + block_size * j])
self.set_as_used_variables([input_vars_concat[index + block_size * j]])
if (constant_flag != []) and (constant_flag[index]):
self._model.addConstr(output_vars[index] >= constr)
Expand Down Expand Up @@ -432,7 +425,6 @@ def add_constraints(self, predecessors, input_id_link_needed, block_needed):
for component_id in used_predecessors_sorted:
if component_id not in self._cipher.inputs:
component = self._cipher.get_component_from_id(component_id)
# print(f"---------> {component.id}")
if component.type == "sbox":
self.add_sbox_constraints(component)
elif component.type in ["linear_layer", "mix_column"]:
Expand Down Expand Up @@ -473,13 +465,9 @@ def get_where_component_is_used(self, predecessors, input_id_link_needed, block_
component = self._cipher.get_component_from_id(input_id_link_needed)
occurences[input_id_link_needed] = [[i for i in range(component.output_bit_size)]]

# print("occurences")
# print(occurences)
occurences_final = {}
for component_id in occurences.keys():
occurences_final[component_id] = self.find_copy_indexes(occurences[component_id])
# print("occurences_final")
# print(occurences_final)

self._occurences = occurences_final
return occurences_final
Expand Down Expand Up @@ -521,13 +509,10 @@ def create_gurobi_vars_from_all_components(self, predecessors, input_id_link_nee
occurences = self.get_where_component_is_used(predecessors, input_id_link_needed, block_needed)
all_vars = {}
used_predecessors_sorted = self.order_predecessors(list(occurences.keys()))
# print("used_predecessors_sorted")
# print(used_predecessors_sorted)
for component_id in used_predecessors_sorted:
all_vars[component_id] = {}
# We need the inputs vars to be the first ones defined by gurobi in order to find their values with X.values method.
# That's why we split the following loop: we first created the original vars, and then the copies vars when necessary.
# print(f"###### {component_id}")
if component_id[:3] == "rot":
component = self._cipher.get_component_from_id(component_id)
rotate_offset = component.description[1]
Expand Down Expand Up @@ -662,10 +647,7 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex
start = time.time()
output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot = self.get_output_bit_index_previous_component(
output_bit_index_ciphertext, chosen_cipher_output)
# print(output_id)
# print(block_needed)
# print(input_id_link_needed)
# print(output_bit_index_previous_comp)

self._output_id = output_id
self._output_bit_index_previous_comp = output_bit_index_previous_comp
self._block_needed = block_needed
Expand All @@ -682,16 +664,12 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex
var_from_block_needed = []
for i in block_needed:
var_from_block_needed.append(self._variables[input_id_link_needed][i][0])
# print("var_from_block_needed")
# print(var_from_block_needed)

output_vars = self._model.addVars(list(range(pivot, pivot + len(block_needed))), vtype=GRB.BINARY,
name=output_id)
self._variables[output_id] = output_vars
output_vars = list(output_vars.values())
self._model.update()
# print("output_vars")
# print(output_vars)

for i in range(len(block_needed)):
self._model.addConstr(output_vars[i] == var_from_block_needed[i])
Expand All @@ -712,10 +690,10 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex

self.set_unused_variables_to_zero()
self._model.update()
# self._model.write("division_trail_model.lp")
end = time.time()
building_time = end - start
print(f"########## building_time : {building_time}")
if verbosity:
print(f"########## building_time : {building_time}")
self._model.update()

def get_solutions(self):
Expand All @@ -734,7 +712,6 @@ def get_solutions(self):
first_input_bit_positions = list(self._occurences[self._cipher.inputs[0]].keys())

solCount = self._model.SolCount
print('Number of solutions (might cancel each other) found: ' + str(solCount))
monomials = []
for sol in range(solCount):
self._model.setParam(GRB.Param.SolutionNumber, sol)
Expand Down Expand Up @@ -770,22 +747,26 @@ def get_solutions(self):

end = time.time()
printing_time = end - start
print(f"########## printing_time : {printing_time}")
print(f'Number of monomials found: {len(monomials)}')
if verbosity:
print('Number of solutions (might cancel each other) found: ' + str(solCount))
print(f"########## printing_time : {printing_time}")
print(f'Number of monomials found: {len(monomials)}')
return monomials

def optimize_model(self):
print(self._model)
start = time.time()
self._model.optimize()
end = time.time()
solving_time = end - start
print(f"########## solving_time : {solving_time}")
if verbosity:
print(self._model)
print(f"########## solving_time : {solving_time}")

def find_anf_of_specific_output_bit(self, output_bit_index, fixed_degree=None, chosen_cipher_output=None):
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)
self._model.setParam("PoolSolutions", 200000000) # 200000000 to be large
self._model.setParam(GRB.Param.PoolSearchMode, 2)
self._model.write("division_trail_model.lp")

self.optimize_model()
return self.get_solutions()
Expand Down

0 comments on commit 97936a8

Please sign in to comment.