Skip to content

Commit

Permalink
feat(ff): enhance BabyBear arithmetic with int conversion decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
dynm committed Jan 14, 2025
1 parent 91a939c commit 7936712
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions src/finite-field/babybear.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import random
from typing import List

def convert_int_to_babybear(func):
def wrapper(self, other):
if isinstance(other, int):
other = BabyBear(other)
return func(self, other)
return wrapper

class BabyBear:
P = 15 * (1 << 27) + 1
def __init__(self, value: int):
Expand All @@ -26,20 +33,17 @@ def __eq__(self, other):
return self.value == other.value
return False

@convert_int_to_babybear
def __add__(self, other):
if isinstance(other, BabyBear):
return BabyBear(self.add(self.value, other.value))
raise TypeError("Unsupported operand type for +")
return BabyBear(self.add(self.value, other.value))

@convert_int_to_babybear
def __sub__(self, other):
if isinstance(other, BabyBear):
return BabyBear(self.sub(self.value, other.value))
raise TypeError("Unsupported operand type for -")
return BabyBear(self.sub(self.value, other.value))

@convert_int_to_babybear
def __mul__(self, other):
if isinstance(other, BabyBear):
return BabyBear(self.mul(self.value, other.value))
raise TypeError("Unsupported operand type for *")
return BabyBear(self.mul(self.value, other.value))

def __neg__(self):
return BabyBear(self.P - self.value)
Expand Down Expand Up @@ -74,10 +78,21 @@ def mul(cls, lhs: int, rhs: int) -> int:
def __pow__(self, n: int):
return self.pow(n)

@convert_int_to_babybear
def __truediv__(self, other):
if isinstance(other, BabyBear):
return self * other.inv()
raise TypeError("Unsupported operand type for /")
return self * other.inv()

def __radd__(self, other):
return BabyBear(other) + self

def __rsub__(self, other):
return BabyBear(other) - self

def __rmul__(self, other):
return BabyBear(other) * self

def __rtruediv__(self, other):
return BabyBear(other) / self

class BabyBearExtElem:
BETA = BabyBear(11)
Expand Down Expand Up @@ -174,6 +189,14 @@ def __truediv__(self, other):
print(f"a = {a}")
print(f"b = {b}")
print(f"a + b = {a + b}")
print(f"a + 42 = {a + 42}")
print(f"42 + a = {42 + a}")
print(f"42 - a = {42 - a}")
print(f"-a + 42 = {-a + 42}")
print(f"42 * a = {42 * a}")
print(f"a * 42 = {a * 42}")
print(f"42 / a * a = {42 / a * a}")
print(f"a / 42 * 42 = {a / 42 * 42}")
print(f"a - b = {a - b}")
print(f"a * b = {a * b}")
print(f"a.inv() = {a.inv()}")
Expand Down

0 comments on commit 7936712

Please sign in to comment.