Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/differential-linear sat model #285

Merged
merged 13 commits into from
Oct 4, 2024

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def __init__(self, cipher, counter='sequential', compact=False):
super().__init__(cipher, counter, compact)
self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, '_'.join)

def branch_xor_linear_constraints(self):
@staticmethod
def branch_xor_linear_constraints(bindings):
"""
Return lists of variables and clauses for branch in XOR LINEAR model.

Expand All @@ -52,7 +53,7 @@ def branch_xor_linear_constraints(self):
sage: from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher
sage: speck = SpeckBlockCipher(number_of_rounds=3)
sage: sat = SatXorLinearModel(speck)
sage: sat.branch_xor_linear_constraints()
sage: SatXorLinearModel.branch_xor_linear_constraints(sat.bit_bindings)
['-plaintext_0_o rot_0_0_0_i',
'plaintext_0_o -rot_0_0_0_i',
'-plaintext_1_o rot_0_0_1_i',
Expand All @@ -62,7 +63,7 @@ def branch_xor_linear_constraints(self):
'xor_2_10_15_o -cipher_output_2_12_31_i']
"""
constraints = []
for output_bit, input_bits in self.bit_bindings.items():
for output_bit, input_bits in bindings.items():
constraints.extend(utils.cnf_xor(output_bit, input_bits))

return constraints
Expand Down Expand Up @@ -91,7 +92,7 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]):
self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(self._cipher, '_'.join)
if fixed_variables == []:
fixed_variables = get_single_key_scenario_format_for_fixed_values(self._cipher)
constraints = self.fix_variables_value_xor_linear_constraints(fixed_variables)
constraints = SatXorLinearModel.fix_variables_value_xor_linear_constraints(fixed_variables)
self._model_constraints = constraints
component_types = (CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, LINEAR_LAYER, SBOX, MIX_COLUMN, WORD_OPERATION)
operation_types = ("AND", "MODADD", "NOT", "ROTATE", "SHIFT", "XOR", "OR", "MODSUB")
Expand All @@ -106,7 +107,7 @@ def build_xor_linear_trail_model(self, weight=-1, fixed_variables=[]):
self._variables_list.extend(variables)
self._model_constraints.extend(constraints)

constraints = self.branch_xor_linear_constraints()
constraints = SatXorLinearModel.branch_xor_linear_constraints(self.bit_bindings)
self._model_constraints.extend(constraints)

if weight != -1:
Expand Down Expand Up @@ -399,7 +400,8 @@ def find_one_xor_linear_trail_with_fixed_weight(self, fixed_weight, fixed_values

return solution

def fix_variables_value_xor_linear_constraints(self, fixed_variables=[]):
@staticmethod
def fix_variables_value_xor_linear_constraints(fixed_variables=[]):
"""
Return lists variables and clauses for fixing variables in XOR LINEAR model.

Expand Down Expand Up @@ -428,7 +430,7 @@ def fix_variables_value_xor_linear_constraints(self, fixed_variables=[]):
....: 'bit_positions': [0, 1, 2, 3],
....: 'bit_values': [1, 1, 1, 0]
....: }]
sage: sat.fix_variables_value_xor_linear_constraints(fixed_variables)
sage: SatXorLinearModel.fix_variables_value_xor_linear_constraints(fixed_variables)
['plaintext_0_o',
'-plaintext_1_o',
'plaintext_2_o',
Expand Down
32 changes: 32 additions & 0 deletions claasp/cipher_modules/models/sat/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,12 @@ def get_cnf_bitwise_truncate_constraints(a, a_0, a_1):
]


def get_cnf_truncated_linear_constraints(a, a_0):
return [
f'-{a} -{a_0}'
]


def modadd_truncated_lsb(result, variable_0, variable_1, next_carry):
return [f'{next_carry[0]} -{next_carry[1]}',
f'{next_carry[0]} -{variable_1[1]}',
Expand Down Expand Up @@ -819,3 +825,29 @@ def run_yices(solver_specs, options, dimacs_input, input_file_name):
os.remove(input_file_name)

return status, time, memory, values


def _generate_component_model_types(speck_cipher):
"""Generates the component model types for a given Speck cipher."""
component_model_types = []
for component in speck_cipher.get_all_components():
component_model_types.append({
"component_id": component.id,
"component_object": component,
"model_type": "sat_xor_differential_propagation_constraints"
})
return component_model_types


def _update_component_model_types_for_truncated_components(component_model_types, truncated_components):
"""Updates the component model types for truncated components."""
for component_model_type in component_model_types:
if component_model_type["component_id"] in truncated_components:
component_model_type["model_type"] = "sat_bitwise_deterministic_truncated_xor_differential_constraints"


def _update_component_model_types_for_linear_components(component_model_types, linear_components):
"""Updates the component model types for linear components."""
for component_model_type in component_model_types:
if component_model_type["component_id"] in linear_components:
component_model_type["model_type"] = "sat_xor_linear_mask_propagation_constraints"
89 changes: 89 additions & 0 deletions claasp/cipher_modules/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import math
from copy import deepcopy

import numpy as np

from claasp.name_mappings import CONSTANT, CIPHER_OUTPUT, INTERMEDIATE_OUTPUT, WORD_OPERATION, LINEAR_LAYER, SBOX, MIX_COLUMN, \
INPUT_KEY, INPUT_PLAINTEXT, INPUT_MESSAGE, INPUT_STATE

Expand Down Expand Up @@ -791,3 +793,90 @@ def get_related_key_scenario_format_for_fixed_values(_cipher):
fixed_variables.append(fixed_variable)

return fixed_variables


def _extract_bits(columns, positions):
"""Extracts bits from columns at specified positions using vectorization."""
bit_size = columns.shape[0] * 8
positions = np.array(positions)
byte_indices = (bit_size - positions - 1) // 8
bit_indices = positions % 8
if np.any(byte_indices < 0) or np.any(byte_indices >= columns.shape[0]):
raise IndexError("Byte index out of range.")
bytes_at_positions = columns[byte_indices][:, :]
bits = (bytes_at_positions >> bit_indices[:, np.newaxis]) & 1

return bits


def _number_to_n_bit_binary_string(number, n_bits):
"""Converts a number to an n-bit binary string with leading zero padding."""
return format(number, f'0{n_bits}b')


def _extract_bit_positions(hex_number, state_size):
binary_str = _number_to_n_bit_binary_string(hex_number, state_size)
binary_str = binary_str[::-1]
positions = [i for i, bit in enumerate(binary_str) if bit == '1']
return positions


def _repeat_input_difference(input_difference, num_samples, num_bytes):
"""Function to repeat the input difference for a large sample size."""
bytes_array = np.frombuffer(input_difference.to_bytes(num_bytes, 'big'), dtype=np.uint8)
repeated_array = np.broadcast_to(bytes_array[:, np.newaxis], (num_bytes, num_samples))
return repeated_array


def differential_linear_checker_for_permutation(
cipher, input_difference, output_mask, number_of_samples, state_size
):
"""
This method helps to verify experimentally differential-linear distinguishers for permutations using the vectorized evaluator
"""
if state_size % 8 != 0:
raise ValueError("State size must be a multiple of 8.")
num_bytes = int(state_size/8)

rng = np.random.default_rng()
input_difference_data = _repeat_input_difference(input_difference, number_of_samples, num_bytes)
plaintext1 = rng.integers(low=0, high=256, size=(num_bytes, number_of_samples), dtype=np.uint8)
plaintext2 = plaintext1 ^ input_difference_data
ciphertext1 = cipher.evaluate_vectorized([plaintext1])
ciphertext2 = cipher.evaluate_vectorized([plaintext2])
ciphertext3 = ciphertext1[0] ^ ciphertext2[0]
bit_positions_ciphertext = _extract_bit_positions(output_mask, state_size)
ccc = _extract_bits(ciphertext3.T, bit_positions_ciphertext)
parities = np.bitwise_xor.reduce(ccc, axis=0)
count = np.count_nonzero(parities == 0)
corr = 2*count/number_of_samples*1.0-1
return corr


def differential_linear_checker_for_block_cipher_single_key(
cipher, input_difference, output_mask, number_of_samples, block_size, key_size, fixed_key
):
"""
This method helps to verify experimentally differential-linear distinguishers for block ciphers using the vectorized evaluator
"""
if block_size % 8 != 0:
raise ValueError("State size must be a multiple of 8.")
if key_size % 8 != 0:
raise ValueError("Key size must be a multiple of 8.")
state_num_bytes = int(block_size / 8)
key_num_bytes = int(key_size / 8)

rng = np.random.default_rng()
fixed_key_data = _repeat_input_difference(fixed_key, number_of_samples, key_num_bytes)
input_difference_data = _repeat_input_difference(input_difference, number_of_samples, state_num_bytes)
plaintext1 = rng.integers(low=0, high=256, size=(state_num_bytes, number_of_samples), dtype=np.uint8)
plaintext2 = plaintext1 ^ input_difference_data
ciphertext1 = cipher.evaluate_vectorized([plaintext1, fixed_key_data])
ciphertext2 = cipher.evaluate_vectorized([plaintext2, fixed_key_data])
ciphertext3 = ciphertext1[0] ^ ciphertext2[0]
bit_positions_ciphertext = _extract_bit_positions(output_mask, block_size)
ccc = _extract_bits(ciphertext3.T, bit_positions_ciphertext)
parities = np.bitwise_xor.reduce(ccc, axis=0)
count = np.count_nonzero(parities == 0)
corr = 2*count/number_of_samples*1.0-1
return corr
Loading
Loading