diff --git a/claasp/cipher_modules/models/milp/utils/dictionary_that_contains_inequalities_for_large_sboxes.obj b/claasp/cipher_modules/models/milp/utils/dictionary_that_contains_inequalities_for_large_sboxes.obj index a6f8009a..0b91dc59 100644 Binary files a/claasp/cipher_modules/models/milp/utils/dictionary_that_contains_inequalities_for_large_sboxes.obj and b/claasp/cipher_modules/models/milp/utils/dictionary_that_contains_inequalities_for_large_sboxes.obj differ diff --git a/claasp/ciphers/block_ciphers/aradi_block_cipher.py b/claasp/ciphers/block_ciphers/aradi_block_cipher.py new file mode 100644 index 00000000..3f8df0b4 --- /dev/null +++ b/claasp/ciphers/block_ciphers/aradi_block_cipher.py @@ -0,0 +1,168 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** + + +from claasp.cipher import Cipher +from claasp.utils.utils import get_ith_word +from claasp.DTOs.component_state import ComponentState +from claasp.name_mappings import INPUT_PLAINTEXT, INPUT_KEY + +input_types = [INPUT_KEY, INPUT_PLAINTEXT] +PARAMETERS_CONFIGURATION_LIST = [{'block_bit_size': 128, 'key_bit_size': 256, 'number_of_rounds': 16}] + + +class AradiBlockCipher(Cipher): + """ + Construct an instance of the AradiBlockCipher class. + + This class is used to store compact representations of a cipher, + used to generate the corresponding cipher. + + INPUT: + + - ``block_bit_size`` -- **integer** (default: `64`); cipher input and output block bit size of the cipher + - ``key_bit_size`` -- **integer** (default: `128`); cipher key bit size of the cipher + - ``number_of_rounds`` -- **integer** (default: `0`); number of rounds of the cipher. The cipher uses the + corresponding amount given the other parameters (if available) when number_of_rounds is 0 + - ``sub_keys_zero`` -- **boolean** (default: `False`) + - ``transformations_flag`` -- **boolean** (default: `True`) + + EXAMPLES:: + + sage: from claasp.ciphers.block_ciphers.aradi_block_cipher import AradiBlockCipher + sage: aradi = AradiBlockCipher(number_of_rounds=16) + sage: aradi.evaluate([0, 0x1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100]) + 83791582030165712186104466959690447122 + """ + + def __init__(self, number_of_rounds=16): + self.block_bit_size = 128 + self.key_bit_size = 256 + self.WORD_SIZE = 32 + + super().__init__(family_name="aradi", + cipher_type="block_cipher", + cipher_inputs=[INPUT_PLAINTEXT, INPUT_KEY], + cipher_inputs_bit_size=[self.block_bit_size, self.key_bit_size], + cipher_output_bit_size=self.block_bit_size) + self.A = [11,10,9,8] + self.B = [8,9,4,9] + self.C = [14,11,14,7] + state = INPUT_PLAINTEXT + key = INPUT_KEY + + for round_i in range(number_of_rounds): + self.add_round() + round_key = self.get_round_key_id(key, round_i) + state = self.round_function(state, round_key, round_i) + key = self.update_key(key, round_i) + + round_key = self.get_round_key_id(key, 0) + w = self.add_XOR_component([round_key, state], [list(range(32)), list(range(32))], 32).id + x = self.add_XOR_component([round_key, state], [list(range(32, 64)), list(range(32, 64))], 32).id + y = self.add_XOR_component([round_key, state], [list(range(64, 96)), list(range(64, 96))], 32).id + z = self.add_XOR_component([round_key, state], [list(range(96, 128)), list(range(96, 128))], 32).id + self.add_cipher_output_component([w, x, y, z], [list(range(32)) for _ in range(4)], 128) + + + def get_key_word_bit_indexes(self, word_index): + return list(range(32*(8-word_index-1), 32*(8-word_index))) + + def get_round_key_id(self, key, round_i): + j = round_i % 2 + return self.add_round_key_output_component([key, key, key, key], [self.get_key_word_bit_indexes(4*j + i) for i in range(4)], 128).id + + def l_function(self, xy_id_links, x_bits, y_bits, round_index): + j = round_index % 4 + rot_x_a = self.add_rotate_component([xy_id_links], [x_bits], 16, -self.A[j]).id + rot_x_b = self.add_rotate_component([xy_id_links], [x_bits], 16, -self.B[j]).id + rot_y_a = self.add_rotate_component([xy_id_links], [y_bits], 16, -self.A[j]).id + rot_y_c = self.add_rotate_component([xy_id_links], [y_bits], 16, -self.C[j]).id + + left_part = self.add_XOR_component([xy_id_links] + [rot_x_a, rot_y_c], [x_bits] + [list(range(16)), list(range(16))], 16).id + right_part = self.add_XOR_component([xy_id_links] + [rot_y_a, rot_x_b], [y_bits] + [list(range(16)), list(range(16))], 16).id + return left_part, right_part + + def m_function(self, i, j, xy_id_links, xy_input_bits): + y_inds = xy_input_bits[32:64] + x_inds = xy_input_bits[:32] + rot_i_y = self.add_rotate_component(xy_id_links, [y_inds], 32, -i).id + rot_j_x = self.add_rotate_component(xy_id_links, [x_inds], 32, -j).id + left_part = self.add_XOR_component(xy_id_links + [rot_i_y, rot_j_x], [x_inds] + [list(range(32)) for _ in range(2)], 32).id + right_part = self.add_XOR_component(xy_id_links + [rot_i_y], [x_inds] + [list(range(32))], 32).id + return left_part, right_part + + def update_key(self, key, round_i): + round_constant = self.add_constant_component(32, round_i).id + k1, k0 = self.m_function(1, 3, [key], self.get_key_word_bit_indexes(1) + self.get_key_word_bit_indexes(0)) + k3, k2 = self.m_function(9, 28, [key], self.get_key_word_bit_indexes(3) + self.get_key_word_bit_indexes(2)) + k5, k4 = self.m_function(1, 3, [key], self.get_key_word_bit_indexes(5) + self.get_key_word_bit_indexes(4)) + k7, k6 = self.m_function(9, 28, [key], self.get_key_word_bit_indexes(7) + self.get_key_word_bit_indexes(6)) + k7 = self.add_XOR_component([k7, round_constant], [list(range(32)) for _ in range(2)], 32).id + if round_i % 2 == 0: + updated_key = self.add_intermediate_output_component([k7, k5, k6, k4, k3, k1, k2, k0], [list(range(32)) for _ in range(8)], 256, f"key_{round_i}") + else: + updated_key = self.add_intermediate_output_component([k7, k3, k5, k1, k6, k2, k4, k0], [list(range(32)) for _ in range(8)], 256, f"key_{round_i}") + return updated_key.id + + def round_function(self, state, round_key, round_i): + w_ind = list(range(32)) + x_ind = list(range(32, 64)) + y_ind = list(range(64, 96)) + z_ind = list(range(96, 128)) + w = self.add_XOR_component([round_key, state], [list(range(32)), list(range(32))], 32).id + x = self.add_XOR_component([round_key, state], [list(range(32, 64)), list(range(32, 64))], 32).id + y = self.add_XOR_component([round_key, state], [list(range(64, 96)), list(range(64, 96))], 32).id + z = self.add_XOR_component([round_key, state], [list(range(96, 128)), list(range(96, 128))], 32).id + # SBOX + wy = self.add_AND_component([w, y], [list(range(32)) for _ in range(2)], 32).id + x = self.add_XOR_component([x, wy], [list(range(32)) for _ in range(2)], 32).id + + xy = self.add_AND_component([x, y], [list(range(32)) for _ in range(2)], 32).id + z = self.add_XOR_component([z, xy], [list(range(32)) for _ in range(2)], 32).id + + wz = self.add_AND_component([w,z], [list(range(32)) for _ in range(2)], 32).id + y = self.add_XOR_component([y, wz], [list(range(32)) for _ in range(2)], 32).id + + xz = self.add_AND_component([x,z], [list(range(32)) for _ in range(2)], 32).id + w = self.add_XOR_component([w, xz], [list(range(32)) for _ in range(2)], 32).id + + # L Function + #w_id_links = [sb_id for sb_id in sb_outputs] + #w_input_bits = [[0] for _ in sb_outputs] + x_bits = list(range(16)) + y_bits = list(range(16, 32)) + wl, wr = self.l_function( + w, x_bits, y_bits, round_i + ) + + xl, xr = self.l_function( + x, x_bits, y_bits, round_i + ) + + yl, yr = self.l_function( + y, x_bits, y_bits, round_i + ) + + zl, zr = self.l_function( + z, x_bits, y_bits, round_i + ) + + # Round output + state = self.add_round_output_component([wl, wr, xl, xr, yl, yr, zl, zr], [list(range(16)) for _ in range(8)], + 128) + return state.id diff --git a/claasp/ciphers/block_ciphers/aradi_block_cipher_sbox.py b/claasp/ciphers/block_ciphers/aradi_block_cipher_sbox.py new file mode 100644 index 00000000..e3c91449 --- /dev/null +++ b/claasp/ciphers/block_ciphers/aradi_block_cipher_sbox.py @@ -0,0 +1,164 @@ +# **************************************************************************** +# Copyright 2023 Technology Innovation Institute +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# **************************************************************************** + + +from claasp.cipher import Cipher +from claasp.utils.utils import get_ith_word +from claasp.DTOs.component_state import ComponentState +from claasp.name_mappings import INPUT_PLAINTEXT, INPUT_KEY + +input_types = [INPUT_KEY, INPUT_PLAINTEXT] +PARAMETERS_CONFIGURATION_LIST = [{'block_bit_size': 128, 'key_bit_size': 256, 'number_of_rounds': 16}] + + +class AradiBlockCipherSBox(Cipher): + """ + Construct an instance of the AradiBlockCipher class using an SBox component. + + This class is used to store compact representations of a cipher, + used to generate the corresponding cipher. + + INPUT: + + - ``block_bit_size`` -- **integer** (default: `64`); cipher input and output block bit size of the cipher + - ``key_bit_size`` -- **integer** (default: `128`); cipher key bit size of the cipher + - ``number_of_rounds`` -- **integer** (default: `0`); number of rounds of the cipher. The cipher uses the + corresponding amount given the other parameters (if available) when number_of_rounds is 0 + - ``sub_keys_zero`` -- **boolean** (default: `False`) + - ``transformations_flag`` -- **boolean** (default: `True`) + + EXAMPLES:: + + sage: from claasp.ciphers.block_ciphers.aradi_block_cipher_sbox import AradiBlockCipherSBox + sage: aradi = AradiBlockCipherSBox(number_of_rounds=16) + sage: aradi.evaluate([0, 0x1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100]) + 83791582030165712186104466959690447122 + """ + + def __init__(self, number_of_rounds=16): + self.block_bit_size = 128 + self.key_bit_size = 256 + self.WORD_SIZE = 32 + self.SBOX = [0, 1, 2, 3, 4, 13, 15, 6, 8, 11, 5, 14, 12, 7, 10, 9] + + super().__init__(family_name="aradi", + cipher_type="block_cipher", + cipher_inputs=[INPUT_PLAINTEXT, INPUT_KEY], + cipher_inputs_bit_size=[self.block_bit_size, self.key_bit_size], + cipher_output_bit_size=self.block_bit_size) + self.A = [11,10,9,8] + self.B = [8,9,4,9] + self.C = [14,11,14,7] + state = INPUT_PLAINTEXT + key = INPUT_KEY + + for round_i in range(number_of_rounds): + self.add_round() + round_key = self.get_round_key_id(key, round_i) + state = self.round_function(state, round_key, round_i) + key = self.update_key(key, round_i) + + round_key = self.get_round_key_id(key, 0) + w = self.add_XOR_component([round_key, state], [list(range(32)), list(range(32))], 32).id + x = self.add_XOR_component([round_key, state], [list(range(32, 64)), list(range(32, 64))], 32).id + y = self.add_XOR_component([round_key, state], [list(range(64, 96)), list(range(64, 96))], 32).id + z = self.add_XOR_component([round_key, state], [list(range(96, 128)), list(range(96, 128))], 32).id + self.add_cipher_output_component([w, x, y, z], [list(range(32)) for _ in range(4)], 128) + + + def get_key_word_bit_indexes(self, word_index): + return list(range(32*(8-word_index-1), 32*(8-word_index))) + + def get_round_key_id(self, key, round_i): + j = round_i % 2 + return self.add_round_key_output_component([key, key, key, key], [self.get_key_word_bit_indexes(4*j + i) for i in range(4)], 128).id + + def l_function(self, xy_id_links, xy_input_bits, round_index): + j = round_index % 4 + + rot_x_a = self.add_rotate_component(xy_id_links[:16], xy_input_bits[:16], 16, -self.A[j]).id + rot_x_b = self.add_rotate_component(xy_id_links[:16], xy_input_bits[:16], 16, -self.B[j]).id + rot_y_a = self.add_rotate_component(xy_id_links[16:], xy_input_bits[16:], 16, -self.A[j]).id + rot_y_c = self.add_rotate_component(xy_id_links[16:], xy_input_bits[16:], 16, -self.C[j]).id + + left_part = self.add_XOR_component(xy_id_links[:16] + [rot_x_a, rot_y_c], xy_input_bits[:16] + [list(range(16)), list(range(16))], 16).id + right_part = self.add_XOR_component(xy_id_links[16:] + [rot_y_a, rot_x_b], xy_input_bits[16:] + [list(range(16)), list(range(16))], 16).id + return left_part, right_part + + def m_function(self, i, j, xy_id_links, xy_input_bits): + y_inds = xy_input_bits[32:64] + x_inds = xy_input_bits[:32] + rot_i_y = self.add_rotate_component(xy_id_links, [y_inds], 32, -i).id + rot_j_x = self.add_rotate_component(xy_id_links, [x_inds], 32, -j).id + left_part = self.add_XOR_component(xy_id_links + [rot_i_y, rot_j_x], [x_inds] + [list(range(32)) for _ in range(2)], 32).id + right_part = self.add_XOR_component(xy_id_links + [rot_i_y], [x_inds] + [list(range(32))], 32).id + return left_part, right_part + + def update_key(self, key, round_i): + round_constant = self.add_constant_component(32, round_i).id + k1, k0 = self.m_function(1, 3, [key], self.get_key_word_bit_indexes(1) + self.get_key_word_bit_indexes(0)) + k3, k2 = self.m_function(9, 28, [key], self.get_key_word_bit_indexes(3) + self.get_key_word_bit_indexes(2)) + k5, k4 = self.m_function(1, 3, [key], self.get_key_word_bit_indexes(5) + self.get_key_word_bit_indexes(4)) + k7, k6 = self.m_function(9, 28, [key], self.get_key_word_bit_indexes(7) + self.get_key_word_bit_indexes(6)) + k7 = self.add_XOR_component([k7, round_constant], [list(range(32)) for _ in range(2)], 32).id + if round_i % 2 == 0: + updated_key = self.add_intermediate_output_component([k7, k5, k6, k4, k3, k1, k2, k0], [list(range(32)) for _ in range(8)], 256, f"key_{round_i}") + else: + updated_key = self.add_intermediate_output_component([k7, k3, k5, k1, k6, k2, k4, k0], [list(range(32)) for _ in range(8)], 256, f"key_{round_i}") + return updated_key.id + + def round_function(self, state, round_key, round_i): + w = self.add_XOR_component([round_key, state], [list(range(32)), list(range(32))], 32).id + x = self.add_XOR_component([round_key, state], [list(range(32, 64)), list(range(32, 64))], 32).id + y = self.add_XOR_component([round_key, state], [list(range(64, 96)), list(range(64, 96))], 32).id + z = self.add_XOR_component([round_key, state], [list(range(96, 128)), list(range(96, 128))], 32).id + # SBOX + sb_outputs = [] + for ind in range(32): + sb_outputs.append(self.add_SBOX_component([w, x, y, z], [[ind], [ind], [ind], [ind]], 4, self.SBOX).id) + + # L Function + w_id_links = [sb_id for sb_id in sb_outputs] + w_input_bits = [[0] for _ in sb_outputs] + wl, wr = self.l_function( + w_id_links, w_input_bits, round_i + ) + + x_id_links = [sb_id for sb_id in sb_outputs] + x_input_bits = [[1] for _ in sb_outputs] + xl, xr = self.l_function( + x_id_links, x_input_bits, round_i + ) + + y_id_links = [sb_id for sb_id in sb_outputs] + y_input_bits = [[2] for _ in sb_outputs] + yl, yr = self.l_function( + y_id_links, y_input_bits, round_i + ) + + z_id_links = [sb_id for sb_id in sb_outputs] + z_input_bits = [[3] for _ in sb_outputs] + zl, zr = self.l_function( + z_id_links, z_input_bits, round_i + ) + + # Round output + state = self.add_round_output_component([wl, wr, xl, xr, yl, yr, zl, zr], [list(range(16)) for _ in range(8)], + 128) + return state.id + + diff --git a/claasp/components/intermediate_output_component.py b/claasp/components/intermediate_output_component.py index 6f12a517..0e1e491a 100644 --- a/claasp/components/intermediate_output_component.py +++ b/claasp/components/intermediate_output_component.py @@ -124,6 +124,7 @@ def get_byte_based_vectorized_python_code(self, params): f' else:', f' intermediateOutputs["{self.description[0]}"].append({self.id}.transpose())'] + def milp_xor_linear_mask_propagation_constraints(self, model): """ Return a list of variables and a list of constraints for OUTPUT component for MILP xor linear. diff --git a/tests/unit/ciphers/block_ciphers/aradi_block_cipher_sbox_test.py b/tests/unit/ciphers/block_ciphers/aradi_block_cipher_sbox_test.py new file mode 100644 index 00000000..d53b7632 --- /dev/null +++ b/tests/unit/ciphers/block_ciphers/aradi_block_cipher_sbox_test.py @@ -0,0 +1,9 @@ +from claasp.ciphers.block_ciphers.aradi_block_cipher_sbox import AradiBlockCipherSBox + +def test_aradi_block_cipher_sbox(): + aradi = AradiBlockCipherSBox() + plaintext = 0 + key = 0x1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100 + ciphertext = 0x3f09abf400e3bd7403260defb7c53912 + assert aradi.evaluate([plaintext, key]) == ciphertext + assert aradi.evaluate_vectorized([plaintext, key], evaluate_api=True) == ciphertext diff --git a/tests/unit/ciphers/block_ciphers/aradi_block_cipher_test.py b/tests/unit/ciphers/block_ciphers/aradi_block_cipher_test.py new file mode 100644 index 00000000..3031838d --- /dev/null +++ b/tests/unit/ciphers/block_ciphers/aradi_block_cipher_test.py @@ -0,0 +1,10 @@ +from claasp.ciphers.block_ciphers.aradi_block_cipher import AradiBlockCipher + + +def test_aradi_block_cipher(): + aradi = AradiBlockCipher() + plaintext = 0 + key = 0x1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100 + ciphertext = 0x3f09abf400e3bd7403260defb7c53912 + assert aradi.evaluate([plaintext, key]) == ciphertext + assert aradi.evaluate_vectorized([plaintext, key], evaluate_api=True) == ciphertext