Skip to content

feat: Encryption Package Implementation #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/lighthouseweb3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import io
from typing import List, Dict, Any
from .functions import (
upload as d,
deal_status,
Expand All @@ -16,7 +17,10 @@
remove_ipns_record as removeIpnsRecord,
create_wallet as createWallet
)

from .functions.kavach import (
generate,
recover_key as recoverKey
)

class Lighthouse:
def __init__(self, token: str = ""):
Expand Down Expand Up @@ -224,3 +228,19 @@ def getTagged(self, tag: str):
except Exception as e:
raise e

class Kavach:
@staticmethod
def generate(threshold: int, keyCount: int):
try:
return generate.generate(threshold, keyCount)
except Exception as e:
raise e


@staticmethod
def recoverKey(keyShards: List[Dict[str, Any]]):
try:
return recoverKey.recover_key(keyShards)
except Exception as e:
raise e

Empty file.
2 changes: 2 additions & 0 deletions src/lighthouseweb3/functions/kavach/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#A 257-bit prime to accommodate 256-bit secrets
PRIME = 2**256 + 297
82 changes: 82 additions & 0 deletions src/lighthouseweb3/functions/kavach/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import secrets
import logging
from typing import Dict, List, Any
from .config import PRIME
logger = logging.getLogger(__name__)


def evaluate_polynomial(coefficients: List[int], x: int, prime: int) -> int:
"""
Evaluate a polynomial with given coefficients at point x.
msk[0] is constant term (the secret), msk[1] is x coefficient, etc.

Args:
coefficients: List of coefficients where coefficients[0] is the constant term
x: Point at which to evaluate the polynomial
prime: Prime number for the finite field

Returns:
The result of the polynomial evaluation modulo prime
"""
result = 0
x_power = 1 # x^0 = 1

for coefficient in coefficients:
result = (result + coefficient * x_power) % prime
x_power = (x_power * x) % prime

return result

async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]:
"""
Generate threshold cryptography key shards using Shamir's Secret Sharing

Args:
threshold: Minimum number of shards needed to reconstruct the secret
key_count: Total number of key shards to generate

Returns:
{
"masterKey": "<master private key hex string>",
"keyShards": [
{
"key": "<shard value hex string>",
"index": "<shard index hex string>"
}
]
}
"""
logger.info(f"Generating key shards with threshold={threshold}, key_count={key_count}")

msk=[]
idVec=[]
secVec=[]

if threshold > key_count:
raise ValueError("key_count must be greater than or equal to threshold")
if threshold < 1 or key_count < 1:
raise ValueError("threshold and key_count must be positive integers")


msk = [secrets.randbits(256) for _ in range(threshold)]
master_key = msk[0]

used_ids = set()

for i in range(key_count):
while True:
id_vec = secrets.randbits(32)
if id_vec != 0 and id_vec not in used_ids:
idVec.append(id_vec)
used_ids.add(id_vec)
break

for i in range(key_count):
y = evaluate_polynomial(msk, idVec[i], PRIME)
secVec.append(y)

result = {
"masterKey": hex(master_key),
"keyShards": [{"key": hex(secVec[i]), "index": hex(idVec[i])} for i in range(key_count)]
}
return result
177 changes: 177 additions & 0 deletions src/lighthouseweb3/functions/kavach/recover_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
from typing import List, Dict, Any
import logging
from .config import PRIME

logger = logging.getLogger(__name__)

from typing import Tuple

def extended_gcd(a: int, b: int) -> Tuple[int, int, int]:
"""Extended Euclidean algorithm to find modular inverse.

Args:
a: First integer
b: Second integer

Returns:
A tuple (g, x, y) such that a*x + b*y = g = gcd(a, b)
"""
if a == 0:
return b, 0, 1
else:
g, y, x = extended_gcd(b % a, a)
return g, x - (b // a) * y, y

def modinv(a: int, m: int) -> int:
"""Find the modular inverse of a mod m."""
g, x, y = extended_gcd(a, m)
if g != 1:
raise ValueError('Modular inverse does not exist')
else:
return x % m

def lagrange_interpolation(shares: List[Dict[str, str]], prime: int) -> int:
"""
Reconstruct the secret using Lagrange interpolation.

Args:
shares: List of dictionaries with 'key' and 'index' fields
prime: The prime number used in the finite field

Returns:
The reconstructed secret as integer

Raises:
ValueError: If there are duplicate indices
"""

points = []
seen_indices = set()

for i, share in enumerate(shares):
try:
key_str, index_str = validate_share(share, i)
x = int(index_str, 16)

if x in seen_indices:
raise ValueError(f"Duplicate share index found: 0x{x:x}")
seen_indices.add(x)

y = int(key_str, 16)
points.append((x, y))
except ValueError as e:
raise ValueError(f"Invalid share at position {i}: {e}")


secret = 0

for i, (x_i, y_i) in enumerate(points):
# Calculate the Lagrange basis polynomial L_i(0)
# Evaluate at x=0 to get the constant term
numerator = 1
denominator = 1

for j, (x_j, _) in enumerate(points):
if i != j:
numerator = (numerator * (-x_j)) % prime
denominator = (denominator * (x_i - x_j)) % prime

try:
inv_denominator = modinv(denominator, prime)
except ValueError as e:
raise ValueError(f"Error in modular inverse calculation: {e}")

term = (y_i * numerator * inv_denominator) % prime
secret = (secret + term) % prime

return secret

def validate_share(share: Dict[str, str], index: int) -> Tuple[str, str]:
"""Validate and normalize a single share.

Args:
share: Dictionary containing 'key' and 'index' fields
index: Position of the share in the input list (for error messages)

Returns:
Tuple of (normalized_key, normalized_index) as strings without '0x' prefix

Raises:
ValueError: If the share is invalid
"""
if not isinstance(share, dict):
raise ValueError(f"Share at index {index} must be a dictionary")

if 'key' not in share or 'index' not in share:
raise ValueError(f"Share at index {index} is missing required fields 'key' or 'index'")

key_str = str(share['key']).strip().lower()
index_str = str(share['index']).strip().lower()

if key_str.startswith('0x'):
key_str = key_str[2:]
if index_str.startswith('0x'):
index_str = index_str[2:]

if not key_str:
raise ValueError(f"Empty key in share at index {index}")
if not index_str:
raise ValueError(f"Empty index in share at index {index}")

if len(key_str) % 2 != 0:
key_str = '0' + key_str

if len(index_str) % 2 != 0:
index_str = '0' + index_str

try:
bytes.fromhex(key_str)
except ValueError:
raise ValueError(f"Invalid key format in share at index {index}: must be a valid hex string")

try:
bytes.fromhex(index_str)
except ValueError:
raise ValueError(f"Invalid index format in share at index {index}: must be a valid hex string")

index_int = int(index_str, 16)
if not (0 <= index_int <= 0xFFFFFFFF):
raise ValueError(f"Index out of range in share at index {index}: must be between 0 and 2^32-1")

return key_str, index_str

async def recover_key(keyShards: List[Dict[str, str]]) -> Dict[str, Any]:
"""
Recover the master key from a subset of key shares using Lagrange interpolation.

Args:
keyShards: List of dictionaries containing 'key' and 'index' fields

Returns:
{
"masterKey": "<recovered master key hex string>",
"error": "<error message if any>"
}
"""
logger.info(f"Attempting to recover master key from {len(keyShards)} shares")

try:
for i, share in enumerate(keyShards):
validate_share(share, i)
secret = lagrange_interpolation(keyShards, PRIME)
return {
"masterKey": hex(secret),
"error": None
}
except ValueError as e:
logger.error(f"Validation error during key recovery: {str(e)}")
return {
"masterKey": None,
"error": f"Validation error: {str(e)}"
}
except Exception as e:
logger.error(f"Error during key recovery: {str(e)}")
return {
"masterKey": None,
"error": f"Recovery error: {str(e)}"
}
Empty file added tests/tests_kavach/__init__.py
Empty file.
79 changes: 79 additions & 0 deletions tests/tests_kavach/test_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import unittest
import asyncio
import logging
from src.lighthouseweb3 import Kavach

logger = logging.getLogger(__name__)

class TestGenerate(unittest.TestCase):
"""Test cases for the generate module."""

def test_generate_basic(self):
"""Test basic key generation with default parameters."""
async def run_test():
result = await Kavach.generate(threshold=2, keyCount=3)

self.assertIn('masterKey', result)
self.assertIn('keyShards', result)

# Check master key format (hex string with 0x prefix)
self.assertIsInstance(result['masterKey'], str)
self.assertTrue(result['masterKey'].startswith('0x'))
self.assertTrue(all(c in '0123456789abcdef' for c in result['masterKey'][2:]))

# Check key shards
self.assertEqual(len(result['keyShards']), 3)
for shard in result['keyShards']:
self.assertIn('key', shard)
self.assertIn('index', shard)

# Check key format (hex string with 0x prefix)
self.assertTrue(shard['key'].startswith('0x'))
self.assertTrue(all(c in '0123456789abcdef' for c in shard['key'][2:]))

# Check index format (hex string with 0x prefix)
self.assertTrue(shard['index'].startswith('0x'))
self.assertTrue(all(c in '0123456789abcdef' for c in shard['index'][2:]))

return result

return asyncio.run(run_test())

def test_generate_custom_parameters(self):
"""Test key generation with custom parameters."""
async def run_test():
threshold = 3
key_count = 5

result = await Kavach.generate(threshold=threshold, keyCount=key_count)

self.assertEqual(len(result['keyShards']), key_count)

# Check all indices are present and unique
indices = [shard['index'] for shard in result['keyShards']]
self.assertEqual(len(set(indices)), key_count) # All unique

# Verify all indices are valid hex strings with 0x prefix
for index in indices:
self.assertTrue(index.startswith('0x'))
self.assertTrue(all(c in '0123456789abcdef' for c in index[2:]))

return result

return asyncio.run(run_test())

def test_invalid_threshold(self):
"""Test that invalid threshold raises an error."""
async def run_test():
with self.assertRaises(ValueError) as context:
await Kavach.generate(threshold=0, keyCount=3)
self.assertIn("must be positive integers", str(context.exception))

with self.assertRaises(ValueError) as context:
await Kavach.generate(threshold=4, keyCount=3)
self.assertIn("must be greater than or equal to threshold", str(context.exception))

return asyncio.run(run_test())

if __name__ == '__main__':
unittest.main(verbosity=2)
Loading