diff --git a/lightphe/__init__.py b/lightphe/__init__.py index 39b3b9a..44d316d 100644 --- a/lightphe/__init__.py +++ b/lightphe/__init__.py @@ -12,6 +12,7 @@ from lightphe.cryptosystems.NaccacheStern import NaccacheStern from lightphe.cryptosystems.GoldwasserMicali import GoldwasserMicali from lightphe.cryptosystems.EllipticCurveElGamal import EllipticCurveElGamal +from lightphe.commons import calculations from lightphe.commons.logger import Logger # pylint: disable=eval-used @@ -102,15 +103,19 @@ def build_cryptosystem( raise ValueError(f"unimplemented algorithm - {algorithm_name}") return cs - def encrypt(self, plaintext: int) -> Ciphertext: + def encrypt(self, plaintext: Union[int, float]) -> Ciphertext: """ Encrypt a plaintext with a built cryptosystem Args: - plaintext (int): message + plaintext (int or float): message Returns ciphertext (from lightphe.models.Ciphertext import Ciphertext): encrypted message """ - ciphertext = self.cs.encrypt(plaintext=plaintext) + ciphertext = self.cs.encrypt( + plaintext=calculations.parse_int( + value=plaintext, modulo=self.cs.modulo or self.cs.plaintext_modulo + ) + ) return Ciphertext(algorithm_name=self.algorithm_name, keys=self.cs.keys, value=ciphertext) def decrypt(self, ciphertext: Ciphertext) -> int: diff --git a/lightphe/commons/calculations.py b/lightphe/commons/calculations.py new file mode 100644 index 0000000..ceb6392 --- /dev/null +++ b/lightphe/commons/calculations.py @@ -0,0 +1,23 @@ +from lightphe.commons.logger import Logger + +logger = Logger() + +# pylint: disable=no-else-return + + +def parse_int(value, modulo) -> int: + if isinstance(value, int) and value >= 0: + return value + elif isinstance(value, int) and value < 0: + return value % modulo + elif isinstance(value, float) and value >= 0: + decimal_places = len(str(value).split(".")[1]) + scaling_factor = 10**decimal_places + integer_value = int(value * scaling_factor) + logger.debug(f"{integer_value}*{scaling_factor}^-1 mod {modulo}") + return integer_value * pow(scaling_factor, -1, modulo) + elif isinstance(value, float) and value < 0: + # TODO: think and implement this later + raise ValueError("Case constant float and negative not implemented yet") + else: + raise ValueError(f"Unimplemented case for constant type {type(value)}") diff --git a/lightphe/commons/logger.py b/lightphe/commons/logger.py index 5c0336c..7411823 100644 --- a/lightphe/commons/logger.py +++ b/lightphe/commons/logger.py @@ -2,6 +2,8 @@ import logging from datetime import datetime +# pylint: disable=broad-except + class Logger: def __init__(self): @@ -15,25 +17,25 @@ def __init__(self): ) self.log_level = logging.INFO - def debug(self, message): - if self.log_level <= logging.DEBUG: - self.dump_log(message) - def info(self, message): if self.log_level <= logging.INFO: self.dump_log(message) + def debug(self, message): + if self.log_level <= logging.DEBUG: + self.dump_log(f"🕷️ {message}") + def warn(self, message): if self.log_level <= logging.WARNING: - self.dump_log(message) + self.dump_log(f"⚠️ {message}") def error(self, message): if self.log_level <= logging.ERROR: - self.dump_log(message) + self.dump_log(f"🔴 {message}") def critical(self, message): if self.log_level <= logging.CRITICAL: - self.dump_log(message) + self.dump_log(f"💥 {message}") def dump_log(self, message): print(f"{str(datetime.now())[2:-7]} - {message}") diff --git a/lightphe/models/Ciphertext.py b/lightphe/models/Ciphertext.py index 5df835e..d143b62 100644 --- a/lightphe/models/Ciphertext.py +++ b/lightphe/models/Ciphertext.py @@ -10,6 +10,7 @@ from lightphe.cryptosystems.NaccacheStern import NaccacheStern from lightphe.cryptosystems.GoldwasserMicali import GoldwasserMicali from lightphe.cryptosystems.EllipticCurveElGamal import EllipticCurveElGamal +from lightphe.commons import calculations from lightphe.commons.logger import Logger logger = Logger() @@ -79,7 +80,9 @@ def __mul__(self, other: Union["Ciphertext", int, float]) -> "Ciphertext": elif isinstance(other, int): result = self.cs.multiply_by_contant(ciphertext=self.value, constant=other) elif isinstance(other, float): - constant = self.__convert_to_int(constant=other) + constant = calculations.parse_int( + value=other, modulo=self.cs.modulo or self.cs.plaintext_modulo + ) result = self.cs.multiply_by_contant(ciphertext=self.value, constant=constant) else: raise ValueError( @@ -96,7 +99,9 @@ def __rmul__(self, constant: Union[int, float]) -> "Ciphertext": scalar multiplication of ciphertext """ if isinstance(constant, float): - constant = self.__convert_to_int(constant=constant) + constant = calculations.parse_int( + value=constant, modulo=self.cs.modulo or self.cs.plaintext_modulo + ) # Handle multiplication with a constant on the right result = self.cs.multiply_by_contant(ciphertext=self.value, constant=constant) @@ -112,30 +117,3 @@ def __xor__(self, other: "Ciphertext") -> "Ciphertext": """ result = self.cs.xor(ciphertext1=self.value, ciphertext2=other.value) return Ciphertext(algorithm_name=self.algorithm_name, keys=self.keys, value=result) - - def __convert_to_int(self, constant: Union[int, float]) -> int: - """ - Convert a constant to integer if it is float or negative - """ - if hasattr(self.cs, "modulo") and self.cs.modulo: - modulo = self.cs.modulo - elif hasattr(self.cs, "plaintext_modulo") and self.cs.plaintext_modulo: - modulo = self.cs.plaintext_modulo - else: - raise ValueError("Cryptosystem must have either modulo or plaintext_modulo") - - if isinstance(constant, int) and constant >= 0: - return constant - elif isinstance(constant, int) and constant < 0: - return constant % modulo - elif isinstance(constant, float) and constant >= 0: - decimal_places = len(str(constant).split(".")[1]) - scaling_factor = 10**decimal_places - integer_value = int(constant * scaling_factor) - logger.debug(f"{integer_value}*{scaling_factor}^-1 mod {modulo}") - return integer_value * pow(scaling_factor, -1, modulo) - elif isinstance(constant, float) and constant < 0: - # TODO: think and implement this later - raise ValueError("Case constant float and negative not implemented yet") - else: - raise ValueError(f"Unimplemented case for constant type {type(constant)}") diff --git a/tests/test_rsa.py b/tests/test_rsa.py index 6b18bc8..75560aa 100644 --- a/tests/test_rsa.py +++ b/tests/test_rsa.py @@ -2,6 +2,7 @@ from lightphe.cryptosystems.RSA import RSA from lightphe.commons.logger import Logger +from lightphe import LightPHE logger = Logger() @@ -36,8 +37,6 @@ def test_rsa(): def test_api(): - from lightphe import LightPHE - cs = LightPHE(algorithm_name="RSA") m1 = 17 @@ -63,3 +62,18 @@ def test_api(): _ = 5 * c1 logger.info("✅ RSA api test succeeded") + + +def test_float_multiplication(): + cs = LightPHE(algorithm_name="RSA") + + m1 = 10000 + m2 = 1.05 + + c1 = cs.encrypt(plaintext=m1) + c2 = cs.encrypt(plaintext=m2) + + # homomorphic addition + assert cs.decrypt(c1 * c2) == m1 * m2 + + logger.info("✅ RSA float multiplication test succeeded")