Skip to content

Commit

Permalink
Vectorized working for all block ciphers except rc5, aes_4_3 (bugs in…
Browse files Browse the repository at this point in the history
… GRs) and qarmav2_with_mixcolumn
  • Loading branch information
davidgerault committed May 15, 2024
1 parent 1bab236 commit 9ad1e55
Show file tree
Hide file tree
Showing 13 changed files with 228 additions and 147 deletions.
6 changes: 3 additions & 3 deletions claasp/cipher_modules/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,8 @@ def prepare_input_byte_based_vectorized_python_code_string(bit_sizes, component)
if component.type == 'constant':
return params

assert (input_bit_size % number_of_inputs) == 0, f"The number of inputs does not divide the number of input bits " \
f"for component {component.id}. "
#assert (input_bit_size % number_of_inputs) == 0, f"The number of inputs does not divide the number of input bits " \
# f"for component {component.id}. "
bits_per_input = input_bit_size // number_of_inputs
words_per_input = get_number_of_bytes_needed_for_bit_size(bits_per_input)
# Divide inputs
Expand Down Expand Up @@ -469,7 +469,7 @@ def get_number_of_inputs(component):
else:
number_of_inputs = component.description[1]
elif component.type == 'mix_column':
number_of_inputs = len(component.description[0])
number_of_inputs = len(component.description[0][0])
elif component.type == 'linear_layer':
number_of_inputs = len(component.description[0])
elif component.type == 'sbox':
Expand Down
65 changes: 46 additions & 19 deletions claasp/cipher_modules/generic_functions_vectorized_byte.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def integer_array_to_evaluate_vectorized_input(values, bit_size):
num_bytes = get_number_of_bytes_needed_for_bit_size(bit_size)
# math.ceil(bit_size / 8)
values_as_np = np.array(values, dtype=object) & (2 ** bit_size - 1)
#print(f"In conv function : {values=}, {values_as_np=}")
evaluate_vectorized_input = (np.uint8([(values_as_np >> ((num_bytes - j - 1) * 8)) & 0xff
for j in range(num_bytes)]).reshape((num_bytes, -1)))
return evaluate_vectorized_input
Expand Down Expand Up @@ -123,7 +122,7 @@ def byte_vector_is_consecutive(l):
return np.all(l[::-1] == np.arange(l[-1], l[0] + 1).tolist())


def byte_vector_select_all_words(unformatted_inputs, real_bits, real_inputs, number_of_inputs, words_per_input,
def byte_vector_select_all_words(unformated_inputs, real_bits, real_inputs, number_of_inputs, words_per_input,
actual_inputs_bits):
"""
Parses the inputs from the cipher into a list of numpy byte arrays, each corresponding to one input to the function.
Expand All @@ -139,23 +138,35 @@ def byte_vector_select_all_words(unformatted_inputs, real_bits, real_inputs, num
- ``words_per_input`` -- **integer**; the number of 8-bit words to be reserved for each of the inputs
- ``actual_inputs_bits`` -- **integer**; the bit size of the variables in unformatted_inputs
"""
number_of_columns = [unformatted_inputs[i].shape[1] for i in range(len(unformatted_inputs))]
#
# print("*"*20)
# print("SELECT")
# print(f"{unformated_inputs=}")
# print(f"{real_bits=}")
# print(f"{real_inputs=}")
# print(f"{number_of_inputs=}")
# print(f"{words_per_input=}")
# print(f"{actual_inputs_bits=}")

number_of_columns = [x.shape[1] for x in unformated_inputs]
max_number_of_columns = np.max(number_of_columns)
output = [0 for _ in range(number_of_inputs)]
for i in range(number_of_inputs):
pos = 0
number_of_output_bits = np.sum([len(x) for x in real_bits[i]])
if len(real_inputs[i]) == 1 and np.all(real_bits[i][0] == list(range(actual_inputs_bits[real_inputs[i][0]]))):
output[i] = unformatted_inputs[real_inputs[i][0]]
output[i] = unformated_inputs[real_inputs[i][0]]
if number_of_output_bits % 8 > 0:
left_byte_mask = 2 ** (number_of_output_bits % 8) - 1
else:
left_byte_mask = 0xffff
output[i][0, :] &= left_byte_mask
else:
output[i] = np.zeros(shape=(words_per_input, max_number_of_columns), dtype=np.uint8)
generate_formatted_inputs(actual_inputs_bits, i, output, pos, real_bits, real_inputs, unformatted_inputs,
generate_formatted_inputs(actual_inputs_bits, i, output, pos, real_bits, real_inputs, unformated_inputs,
words_per_input)
#print(f"{output=}")

return output


Expand Down Expand Up @@ -388,7 +399,7 @@ def byte_vector_ROTATE(input, rotation_amount, input_bit_size):
bits_to_cut = 8 - (input_bit_size % 8)
bin_input = np.unpackbits(input[0], axis=0)
rotated = np.vstack([np.zeros((bits_to_cut, bin_input.shape[1]), dtype=np.uint8),
np.roll(bin_input[bits_to_cut:, :], rotation_amount)])
np.roll(bin_input[bits_to_cut:, :], rotation_amount, axis=0)])
ret = np.packbits(rotated, axis=0)
else:
rot = rotation_amount
Expand Down Expand Up @@ -470,18 +481,30 @@ def byte_vector_mix_column(input, matrix, mul_table, word_size):
- ``matrix`` -- **list**; a list of lists of integers
- ``mul_tables`` -- **dictionary**; a dictionary giving the multiplication table by x at key x
"""
assert word_size == 4 or word_size == 8, "Vectorized evaluation of mix_columns does not support word sizes other than 8 and 4"
tmp = np.zeros(shape=(len(input), input[0].shape[1]), dtype=np.uint8)
#assert word_size == 4 or word_size == 8, "Vectorized evaluation of mix_columns does not support word sizes other than 8 and 4"
tmp = np.zeros(shape=(len(matrix) * input[0].shape[0], input[0].shape[1]), dtype=np.uint8)

#tmp = np.zeros(shape=(len(matrix), input[0].shape[1]), dtype=np.uint8)
print("="*30)
print("MC")
print(f"{word_size=}")
print(f"{tmp.shape=}")
print(input)

for i in [*mul_table]:
mul_table[i] = np.array(mul_table[i], dtype=np.uint8)
for i in range(len(matrix)):
for j in range(len(matrix[0])):
tmp[i] = reduce(lambda x, y: x ^ y, [tmp[i], mul_table[matrix[i][j]][input[j]]])
if word_size < 8:
output = np.uint8([(tmp[2 * i, :] << 4) ^ tmp[2 * i + 1, :] for i in range(len(input) // 2)])
return output
else:
if word_size >= 8:
return tmp
#else:
return byte_vector_select_all_words(unformated_inputs=[x.reshape(1,-1) for x in tmp],
real_bits = [[list(range(word_size)) for _ in tmp]],
real_inputs = [list(range(len(tmp)))],
number_of_inputs=1,
words_per_input=get_number_of_bytes_needed_for_bit_size(word_size*len(tmp)),
actual_inputs_bits=[word_size for _ in tmp])[0]


def byte_vector_mix_column_poly0(input, matrix, word_size):
Expand All @@ -493,15 +516,19 @@ def byte_vector_mix_column_poly0(input, matrix, word_size):
- ``input`` -- **np.array(dtype = np.uint8)** A numpy matrix with one row per byte, and one column per byte.
- ``matrix`` -- **list**; a list of lists of integers
"""
assert word_size == 4 or word_size == 8, "Vectorized evaluation of mix_columns does not support word sizes other than 8 and 4"
tmp = np.zeros(shape=(len(input), input[0].shape[1]), dtype=np.uint8)
#tmp = np.zeros(shape=(len(matrix), input[0].shape[1]), dtype=np.uint8)
tmp = np.zeros(shape=(len(matrix) * input[0].shape[0], input[0].shape[1]), dtype=np.uint8)

#tmp = np.zeros(shape=(len(input) * input[0].shape[0], input[0].shape[1]), dtype=np.uint8)
for i in range(len(matrix)):
for j in range(len(matrix[0])):
tmp[i * input[0].shape[0]:(i + 1) * input[0].shape[0]] = \
tmp[i * input[0].shape[0]:(i + 1) * input[0].shape[0]] ^ matrix[i][j] * input[j]

if word_size < 8:
output = np.uint8([(tmp[2 * i, :] << 4) ^ tmp[2 * i + 1, :] for i in range(len(input) // 2)])
return output
else:
if word_size >=8:
return tmp
return byte_vector_select_all_words(unformated_inputs=[x.reshape(1,-1) for x in tmp],
real_bits = [[list(range(word_size)) for _ in tmp]],
real_inputs = [list(range(len(tmp)))],
number_of_inputs=1,
words_per_input=get_number_of_bytes_needed_for_bit_size(word_size*len(tmp)),
actual_inputs_bits=[word_size for _ in tmp])[0]
16 changes: 8 additions & 8 deletions claasp/ciphers/block_ciphers/aes_block_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,14 @@ def create_constant_component(self, round_number, state_size, word_size):
# # print("Error : Constant format")
# else:
if word_size != 8:
if self.ROUND_CONSTANT[word_size][round_number][:2] == '0b':
constant = self.add_constant_component(len(self.ROUND_CONSTANT[word_size][round_number]) - 2,
int(self.ROUND_CONSTANT[word_size][round_number], 2))
elif self.ROUND_CONSTANT[word_size][round_number][:2] == '0x':
constant = self.add_constant_component(self.ROW_SIZE,
int(self.ROUND_CONSTANT[word_size][round_number], 16))
else:
print("Error : Constant format")
#if self.ROUND_CONSTANT[word_size][round_number][:2] == '0b':
constant = self.add_constant_component(len(self.ROUND_CONSTANT[word_size][round_number]) - 2,
int(self.ROUND_CONSTANT[word_size][round_number], 0))
#elif self.ROUND_CONSTANT[word_size][round_number][:2] == '0x':
# constant = self.add_constant_component(self.ROW_SIZE,
# int(self.ROUND_CONSTANT[word_size][round_number], 16))
#else:
# print("Error : Constant format")
else:
constant = self.add_constant_component(self.ROW_SIZE,
int(self.ROUND_CONSTANT[word_size][state_size][round_number],
Expand Down
1 change: 0 additions & 1 deletion claasp/components/cipher_output_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes):
return code

def get_byte_based_vectorized_python_code(self, params):
print(params)
return [f' {self.id} = {params}[0]',
f' if "{self.description[0]}" not in intermediateOutputs.keys():',
f' intermediateOutputs["{self.description[0]}"] = []',
Expand Down
6 changes: 0 additions & 6 deletions claasp/components/constant_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,7 @@ def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes):
f'dtype=np.uint8).reshape({self.output_bit_size, 1})']

def get_byte_based_vectorized_python_code(self, params):
#print("*CONSTANT"*10)
#print(self.description, self.output_bit_size)
#print("*"*10, flush=True)
val = constant_to_repr(self.description[0], self.output_bit_size)
val2 = integer_array_to_evaluate_vectorized_input([int(self.description[0],0)], self.output_bit_size)


return [f' {self.id} = np.array({val}, dtype=np.uint8).reshape({len(val)}, 1)']

def get_word_based_c_code(self, verbosity, word_size, wordstring_variables):
Expand Down
1 change: 0 additions & 1 deletion claasp/components/intermediate_output_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def get_bit_based_vectorized_python_code(self, params, convert_output_to_bytes):
return code

def get_byte_based_vectorized_python_code(self, params):
print(params)
return [f' {self.id} = {params}[0]',
f' if "{self.description[0]}" not in intermediateOutputs.keys():',
f' intermediateOutputs["{self.description[0]}"] = []',
Expand Down
Loading

0 comments on commit 9ad1e55

Please sign in to comment.