Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: faster poseidon hades permutation #170

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 29 additions & 22 deletions src/starkware/cairo/common/poseidon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
"""

import hashlib
from typing import Iterable, List, Optional, Type

import numpy as np
from typing import List, Optional, Type

from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME
from starkware.python.math_utils import pow_mod, safe_div
from starkware.python.math_utils import safe_div


def generate_round_constant(fn_name: str, field_prime: int, idx: int) -> int:
Expand Down Expand Up @@ -40,7 +38,7 @@ class PoseidonParams:
poseidon_small_params: Optional["PoseidonParams"] = None

def __init__(
self, field_prime: int, r: int, c: int, r_f: int, r_p: int, mds: Iterable[Iterable[int]]
self, field_prime: int, r: int, c: int, r_f: int, r_p: int
):
self.field_prime = field_prime
self.r = r
Expand All @@ -53,45 +51,41 @@ def __init__(
self.output_size = c
assert self.output_size <= r
# A list of r_f + r_p vectors for the Add-Round Key phase.
self.ark = np.array(
[
self.ark = [
[generate_round_constant("Hades", field_prime, m * i + j) for j in range(m)]
for i in range(n_rounds)
],
dtype=object,
)

# The MDS matrix for the MixLayer phase.
self.mds = np.array(mds, dtype=object)
]

@classmethod
def get_default_poseidon_params(cls: Type["PoseidonParams"]):
if cls.poseidon_small_params is None:
cls.poseidon_small_params = cls(
field_prime=DEFAULT_PRIME, r=2, c=1, r_f=8, r_p=83, mds=SmallMds
field_prime=DEFAULT_PRIME, r=2, c=1, r_f=8, r_p=83
)

return cls.poseidon_small_params


def hades_round(values, params: PoseidonParams, is_full_round: bool, round_idx: int):
# Add-Round Key.
values = (values + params.ark[round_idx]) % params.field_prime

values = [
(val + ark) % params.field_prime
for val, ark in zip(values, params.ark[round_idx])
]
# SubWords.
if is_full_round:
values = pow_mod(values, 3, params.field_prime)
values = [pow(val, 3, params.field_prime) for val in values]
else:
values[-1:] = pow_mod(values[-1:], 3, params.field_prime)
values[-1] = pow(values[-1], 3, params.field_prime)

# MixLayer.
values = params.mds.dot(values) % params.field_prime
values = mds_mul(values, params.field_prime)
return values


def hades_permutation(values: List[int], params: PoseidonParams) -> List[int]:
assert len(values) == params.m
values = np.array(values, dtype=object)

round_idx = 0
# Apply r_f/2 full rounds.
for _ in range(safe_div(params.r_f, 2)):
Expand All @@ -106,8 +100,21 @@ def hades_permutation(values: List[int], params: PoseidonParams) -> List[int]:
values = hades_round(values, params, True, round_idx)
round_idx += 1
assert round_idx == params.n_rounds
return list(values)
return values


def mds_mul(vector, field):
"""
Multiplies a vector by the SmallMds matrix.
[3, 1, 1] [r0] [3* r0 + r1 + r2 ]
[1, -1, 1] * [r1] = [r0 - r1 + r2 ]
[1, 1, -2] [r2] [r0 + r1 - 2 * r2]
"""
return [
(3 * vector[0] + vector[1] + vector[2]) % field,
(vector[0] - vector[1] + vector[2]) % field,
(vector[0] + vector[1] - 2 * vector[2]) % field,
]

# The actual config to be in use, with extremely small MDS coefficients.
SmallMds = [[3, 1, 1], [1, -1, 1], [1, 1, -2]]
# SmallMds = [[3, 1, 1], [1, -1, 1], [1, 1, -2]]