Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
juaninf committed Oct 2, 2024
1 parent 3d9c732 commit af8a166
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,32 +82,49 @@ def _get_truncated_xor_differential_components_in_border(self):
return list(set(border_components))

def _get_connecting_constraints(self):
def is_any_string_in_list_substring_of_string(string, string_list):
# Check if any string in the list is a substring of the given string
return any(s in string for s in string_list)
"""
Adds constraints for connecting regular and truncated components.
"""
border_components = self._get_regular_xor_differential_components_in_border()

#print(border_components)
for component_id in border_components:
component = self.cipher.get_component_from_id(component_id)
for idx in range(component.output_bit_size):
constraints = sat_utils.get_cnf_bitwise_truncate_constraints(
f'{component_id}_{idx}', f'{component_id}_{idx}_0', f'{component_id}_{idx}_1'
)
self._model_constraints.extend(constraints)
self._variables_list.extend([
f'{component_id}_{idx}', f'{component_id}_{idx}_0', f'{component_id}_{idx}_1'
])

border_components = self._get_truncated_xor_differential_components_in_border()

print(border_components)

linear_component_ids = [item['component_id'] for item in self.linear_components]

for component_id in border_components:

component = self.cipher.get_component_from_id(component_id)
for idx in range(component.output_bit_size):
constraints = sat_utils.get_cnf_truncated_linear_constraints(
f'{component_id}_{idx}_o', f'{component_id}_{idx}_0'
)
self._model_constraints.extend(constraints)
self._variables_list.extend([
f'{component_id}_{idx}', f'{component_id}_{idx}_0'
])
truncated_component = f'{component_id}_{idx}_o'
component_sucessors = self.bit_bindings[truncated_component]
for component_sucessor in component_sucessors:
length_component_sucessor = len(component_sucessor)
component_sucessor_id = component_sucessor[:length_component_sucessor-2]

if is_any_string_in_list_substring_of_string(component_sucessor_id, linear_component_ids):
#import ipdb;
#ipdb.set_trace()
constraints = sat_utils.get_cnf_truncated_linear_constraints(
component_sucessor, f'{component_id}_{idx}_0'
)
self._model_constraints.extend(constraints)
self._variables_list.extend([component_sucessor, f'{component_id}_{idx}_0'])

def _build_weight_constraints(self, weight):
"""
Expand Down Expand Up @@ -197,6 +214,14 @@ def build_xor_differential_linear_model(self, weight=-1, num_unknown_vars=None):
self._variables_list.extend(variables)
self._model_constraints.extend(constraints)

#ciphertext_output_vars = [f'cipher_output_7_25_{i}_o' for i in range(512)]
#
#variables, constraints = self._sequential_counter_algorithm(ciphertext_output_vars, 9, 'dummy_hw_ac')
#self._variables_list.extend(variables)
#self._model_constraints.extend(constraints)



self._get_connecting_constraints()

@staticmethod
Expand Down Expand Up @@ -254,7 +279,7 @@ def _parse_solver_output(self, variable2value):
components_solutions = self._get_cipher_inputs_components_solutions(out_suffix, variable2value)
total_weight_diff = 0
total_weight_lin = 0

total_weight_lin_input = 0
for component in self._cipher.get_all_components():
if component.id in [d['component_id'] for d in self.regular_components]:
hex_value = self._get_component_hex_value(component, '', variable2value)
Expand All @@ -265,13 +290,32 @@ def _parse_solver_output(self, variable2value):
elif component.id in [d['component_id'] for d in self.truncated_components]:
value = self._get_component_value_double_ids(component, variable2value)
components_solutions[component.id] = set_component_solution(value)

out_sufix = constants.OUTPUT_BIT_ID_SUFFIX
hex_value = self._get_component_hex_value(component, out_sufix, variable2value)
components_solutions[component.id+"_o"] = set_component_solution(hex_value, 0)

elif component.id in [d['component_id'] for d in self.linear_components]:
out_sufix = constants.OUTPUT_BIT_ID_SUFFIX
hex_value = self._get_component_hex_value(component, out_sufix, variable2value)
weight = self.calculate_component_weight(component, out_sufix, variable2value)
total_weight_lin += weight
input_sufix = constants.INPUT_BIT_ID_SUFFIX
hex_value_input = self._get_component_hex_value_input(component, input_sufix, variable2value)
#weight_input = self.calculate_component_weight(component, input_sufix, variable2value)
#total_weight_lin_input += weight_input

components_solutions[component.id] = set_component_solution(hex_value, weight)
components_solutions[component.id + "_input"] = set_component_solution(hex_value_input, 0)
components_solutions[component.id + "_input"]["links"] = str(component.input_id_links)
components_solutions[component.id + "_input_id_links"] = {}
for input_id_link in component.input_id_links:
components_solutions[component.id + "_input_id_links"][input_id_link] = self._get_component_hex_value(
self.cipher.get_component_from_id(input_id_link), out_sufix, variable2value)

#if component.id == "modadd_4_18":
# components_solutions[component.id + "_input"]["modadd_3_18"] = self._get_component_hex_value(self.cipher.get_component_from_id("modadd_3_18"), out_sufix, variable2value)
print("total_weight_diff, total_weight_lin", total_weight_diff, total_weight_lin)

return components_solutions, total_weight_diff + 2*total_weight_lin

Expand Down
Loading

0 comments on commit af8a166

Please sign in to comment.